r'''
K-space utilities for lattice systems.
Provides:
- Brillouin zone path generation and high-symmetry points
- Bloch transformation caching for efficient k-space transforms
- K-point grid generation and path extraction
- Wigner-Seitz cell masking and BZ extension utilities
--------------------------------
File : lattices/tools/lattice_kspace.py
Author : Maksymilian Kliczkowski
Date : 2025-01-15
Changelog :
- 2026-03-03: Improved handling of reciprocal vector inputs and copy counts in extend_kspace_data().
Version : 2.1
- Added HighSymmetryPoint and HighSymmetryPoints classes for better management of high-symmetry points and paths.
- Enhanced ws_bz_mask() to support multiple shells of reciprocal lattice points for improved accuracy at the BZ boundary.
--------------------------------
'''
from __future__ import annotations
from typing import TYPE_CHECKING, Iterable, List, Optional, Literal, Tuple, Dict, NamedTuple, Union
from dataclasses import dataclass, field
from enum import Enum
import numpy as np
import scipy.sparse as sp
if TYPE_CHECKING:
from ..lattice import Lattice
from QES.Algebra.hamil_quadratic import QuadraticBlockDiagonalInfo
# -----------------------------------------------------------------------------------------------------------
# BRILLOUIN ZONE UTILITIES
# -----------------------------------------------------------------------------------------------------------
def _resolve_reciprocal_vector_inputs(*, lattice=None, reciprocal_vectors: Optional[Iterable[np.ndarray]] = None, b1=None, b2=None, b3=None) -> List[np.ndarray]:
"""Collect raw reciprocal vectors from the provided sources without changing dimensionality."""
if reciprocal_vectors is not None:
if isinstance(reciprocal_vectors, np.ndarray):
if reciprocal_vectors.ndim == 1:
return [reciprocal_vectors]
if reciprocal_vectors.ndim == 2:
return [row for row in reciprocal_vectors]
raise ValueError("reciprocal_vectors must be 1D or 2D")
return [vec for vec in reciprocal_vectors]
if lattice is not None:
dim = int(getattr(lattice, "dim", 0) or 0)
if dim > 0:
return [getattr(lattice, f"k{i + 1}", None) for i in range(dim)]
return [getattr(lattice, "k1", None), getattr(lattice, "k2", None), getattr(lattice, "k3", None)]
return [b1, b2, b3]
def _coerce_k_points(KX, KY=None, KZ=None, *, k_points: Optional[np.ndarray] = None) -> np.ndarray:
"""Return k-points with shape ``(..., dim)`` from either grids or explicit points."""
if k_points is not None:
k_vec = np.asarray(k_points, dtype=float)
elif KY is None:
k_vec = np.asarray(KX, dtype=float)
else:
components = [np.asarray(KX, dtype=float), np.asarray(KY, dtype=float)]
if KZ is not None:
components.append(np.asarray(KZ, dtype=float))
k_vec = np.stack(components, axis=-1)
if k_vec.ndim == 0:
raise ValueError("k-points must contain at least one coordinate axis")
if k_vec.ndim == 1:
k_vec = k_vec.reshape(1, -1)
return k_vec
def _normalize_reciprocal_vectors(vectors: Iterable[np.ndarray], kdim: int, *, copies: Optional[Iterable[int]] = None, drop_zero: bool = True, tol: float = 1e-12) -> Tuple[List[np.ndarray], List[int]]:
"""Normalize reciprocal vectors to ``kdim`` and keep copy counts aligned."""
raw_vectors = list(vectors)
raw_copies = None if copies is None else list(copies)
if raw_copies is not None:
if len(raw_copies) < len(raw_vectors):
raise ValueError("Number of reciprocal vectors must not exceed number of copy counts")
if len(raw_copies) > len(raw_vectors):
raw_copies = raw_copies[:len(raw_vectors)]
normalized_vectors: List[np.ndarray] = []
normalized_copies: List[int] = []
for idx, vec in enumerate(raw_vectors):
if vec is None:
continue
arr = np.asarray(vec, dtype=float).ravel()
if arr.size < kdim:
arr = np.pad(arr, (0, kdim - arr.size))
elif arr.size > kdim:
arr = arr[:kdim]
if drop_zero and np.linalg.norm(arr) <= tol:
continue
normalized_vectors.append(arr)
if raw_copies is not None:
normalized_copies.append(int(raw_copies[idx]))
return normalized_vectors, normalized_copies
# -----------------------------------------------------------------------------------------------------------
def ws_bz_mask(KX, KY=None, b1=None, b2=None, shells=1, *, KZ=None, b3=None, k_points=None, reciprocal_vectors=None, lattice=None, tol: float = 1e-12):
r"""
Wigner-Seitz (first BZ) mask for a reciprocal lattice.
Keeps points closer to Gamma than to any other reciprocal lattice point
in a finite neighborhood of reciprocal-lattice translations.
The Wigner-Seitz condition is: |k|^2 <= |k-G|^2 for all reciprocal lattice vectors G.
This simplifies to: 2k \cdot G <= G^2 (mathematically equivalent, computationally faster).
Parameters
----------
KX, KY, KZ : array_like, optional
Coordinate grids. The legacy 2D form is ``ws_bz_mask(KX, KY, b1, b2)``.
k_points : array_like, optional
Explicit k-points with shape ``(..., dim)``.
b1, b2, b3 : array_like, optional
Legacy reciprocal lattice vectors.
reciprocal_vectors : iterable of array_like, optional
General reciprocal translation vectors.
lattice : object, optional
Lattice-like object exposing ``dim`` and reciprocal vectors ``k1``, ``k2``, ``k3``.
shells : int or iterable of int, default=1
Number of reciprocal shells to consider along each reciprocal direction. Each new
shell adds more reciprocal lattice points to the Wigner-Seitz condition, improving accuracy at the BZ boundary.
The condition is modified to: 2k \cdot G <= |G|^2 + tol, where G are the reciprocal lattice vectors up to the specified shell count.
tol : float, default=1e-12
Numerical tolerance used at the BZ boundary.
Returns
-------
inside : ndarray (bool)
True for points inside the first Brillouin zone.
"""
# Process inputs and normalize to k-point array and list of reciprocal vectors
k_vec = _coerce_k_points(KX, KY, KZ, k_points=k_points)
kdim = k_vec.shape[-1]
raw_vectors = _resolve_reciprocal_vector_inputs(lattice=lattice, reciprocal_vectors=reciprocal_vectors, b1=b1, b2=b2, b3=b3,)
basis, _ = _normalize_reciprocal_vectors(raw_vectors, kdim, tol=tol)
if len(basis) == 0:
raise ValueError("At least one reciprocal lattice vector must be provided")
# Determine shell counts for each reciprocal vector and generate the grid of reciprocal lattice points to check against
if np.isscalar(shells):
shell_counts = [int(shells)] * len(basis)
else:
shell_counts = [int(shell) for shell in shells]
if len(shell_counts) != len(basis):
raise ValueError("Number of shell counts must match number of reciprocal vectors")
# Generate the grid of reciprocal lattice points G = n1*b1 + n2*b2 + n3*b3 for all integer combinations of n_i in [-shells, shells].
if any(shell < 0 for shell in shell_counts):
raise ValueError("shells must be non-negative")
coeff_axes = [np.arange(-shell, shell + 1, dtype=int) for shell in shell_counts]
coeff_grid = np.stack(np.meshgrid(*coeff_axes, indexing='ij'), axis=-1).reshape(-1, len(basis))
coeff_grid = coeff_grid[np.any(coeff_grid != 0, axis=1)]
if len(coeff_grid) == 0:
return np.ones(k_vec.shape[:-1], dtype=bool)
# Wigner-Seitz condition: 2k \cdot G <= |G|^2 for all reciprocal lattice vectors G.
G = coeff_grid @ np.vstack(basis)
k_dot_G = k_vec @ G.T
G_squared = np.sum(G**2, axis=1)
inside = np.all(2 * k_dot_G <= G_squared + tol, axis=-1)
return inside
def ws_bz_shifts(
*,
lattice = None,
reciprocal_vectors: Optional[Iterable[np.ndarray]] = None,
b1: Optional[np.ndarray] = None,
b2: Optional[np.ndarray] = None,
b3: Optional[np.ndarray] = None,
copies: Optional[Union[int, Iterable[int]]] = None,
nx: int = 1,
ny: int = 1,
nz: int = 0,
include_origin: bool = False,
tol: float = 1e-12,
) -> np.ndarray:
r"""
Return reciprocal-lattice translation vectors for Brillouin-zone copies.
This is the shared geometric primitive for drawing or selecting translated
Brillouin zones. It returns the centers of reciprocal-cell copies, not an
extended k-mesh.
Parameters
----------
lattice : object, optional
Lattice-like object exposing ``dim`` and reciprocal vectors ``k1``,
``k2``, ``k3``.
reciprocal_vectors : iterable of array_like, optional
Explicit reciprocal translation vectors. If provided, they take
precedence over ``b1``/``b2``/``b3``.
b1, b2, b3 : array_like, optional
Legacy reciprocal lattice vectors.
copies : int or iterable of int, optional
Number of translated copies along each reciprocal direction.
nx, ny, nz : int, default=(1, 1, 0)
Legacy per-direction copy counts used when ``copies`` is not given.
include_origin : bool, default=False
Whether to include the central Brillouin zone at ``Gamma``.
tol : float, default=1e-12
Tolerance used when removing numerically duplicated shifts.
Returns
-------
np.ndarray
Array of shape ``(Nshift, dim)`` containing unique reciprocal
translation vectors for Brillouin-zone copies.
"""
raw_vectors = _resolve_reciprocal_vector_inputs(
lattice=lattice,
reciprocal_vectors=reciprocal_vectors,
b1=b1,
b2=b2,
b3=b3,
)
nonzero_vectors = [np.asarray(vec, dtype=float).ravel() for vec in raw_vectors if vec is not None]
if lattice is not None and int(getattr(lattice, "dim", 0) or 0) > 0:
kdim = int(getattr(lattice, "dim"))
elif nonzero_vectors:
kdim = max(int(vec.size) for vec in nonzero_vectors)
else:
raise ValueError("At least one reciprocal lattice vector must be provided")
centers, _ = extend_kspace_data(
k_points=np.zeros((1, kdim), dtype=float),
lattice=lattice,
reciprocal_vectors=reciprocal_vectors,
b1=b1,
b2=b2,
b3=b3,
nx=nx,
ny=ny,
nz=nz,
copies=copies,
)
scale = max(float(tol), np.finfo(float).eps)
seen = set()
unique = []
for shift in np.asarray(centers, dtype=float):
key = tuple(np.rint(shift / scale).astype(np.int64))
if key in seen:
continue
seen.add(key)
if not include_origin and np.allclose(shift, 0.0, atol=tol, rtol=0.0):
continue
unique.append(shift)
if not unique:
return np.zeros((0, kdim), dtype=float)
return np.asarray(unique, dtype=float)
def extend_kspace_data(
k_points : np.ndarray,
data : Optional[np.ndarray] = None,
b1 : Optional[np.ndarray] = None,
b2 : Optional[np.ndarray] = None,
b3 : np.ndarray = None,
nx : int = 2,
ny : int = 2,
nz : int = 0,
*,
lattice = None,
reciprocal_vectors : Optional[Iterable[np.ndarray]] = None,
copies : Optional[Union[int, Iterable[int]]] = None,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
r"""
Extend k-space points and optional data across translated Brillouin zones.
The helper works for arbitrary k-space dimensions and any number of
reciprocal translation vectors. Legacy ``b1``/``b2``/``b3`` with
``nx``/``ny``/``nz`` remain supported for existing callers.
Allows to generate extended k-point grids for plotting band structures along high-symmetry paths...
Parameters
----------
k_points : array_like, shape (Nk, dim) or (dim,)
Original Cartesian k-points.
data : array_like, shape (Nk, ...), optional
Data associated with each k-point. If omitted, only the extended
k-points are returned and the second return value is ``None``.
b1, b2, b3 : array_like, optional
Reciprocal lattice vectors.
nx, ny, nz : int, default=(2, 2, 0)
Copy counts for ``b1``, ``b2``, ``b3``.
---------------------------------
lattice : object, optional
Lattice-like object exposing ``dim`` and reciprocal vectors ``k1``, ``k2``, ``k3``.
Used when ``reciprocal_vectors`` and explicit ``b1``/``b2``/``b3`` are not given.
reciprocal_vectors : iterable of array_like, optional
General reciprocal translation vectors. If provided, these take
precedence over ``b1``/``b2``/``b3``.
copies : int or iterable of int, optional
Copy counts for ``reciprocal_vectors``. A scalar applies to every
reciprocal vector. If omitted, one copy is used for each provided
reciprocal vector.
Returns
-------
extended_k_points : np.ndarray
Extended k-points covering multiple BZs.
extended_data : np.ndarray or None
Corresponding data for the extended k-points, if ``data`` was given.
Example
-------
>>> k_points = np.array([[0.0, 0.0, 0.0], [0.5, 0.0, 0.0]])
>>> b1 = np.array([1.0, 0.0, 0.0])
>>> b2 = np.array([0.0, 1.0, 0.0])
>>> extended_k, _ = extend_kspace_data(k_points, b1=b1, b2=b2, nx=1, ny=1)
>>> extended_k.shape
(18, 3)
"""
k_points = np.asarray(k_points, dtype=float)
if k_points.ndim == 1:
k_points = k_points.reshape(1, -1)
elif k_points.ndim != 2:
raise ValueError("k_points must have shape (Nk, dim) or (dim,)")
nk, kdim = k_points.shape
legacy_copies = [nx, ny, nz]
# Determine the reciprocal vectors to use and their copy counts, with proper precedence and normalization.
if reciprocal_vectors is None:
raw_vectors = _resolve_reciprocal_vector_inputs(lattice=lattice, b1=b1, b2=b2, b3=b3)
if copies is not None:
if np.isscalar(copies):
raw_copies = [int(copies)] * len(raw_vectors)
else:
raw_copies = [int(copy) for copy in copies]
else:
raw_copies = legacy_copies[:len(raw_vectors)]
else:
raw_vectors = _resolve_reciprocal_vector_inputs(reciprocal_vectors=reciprocal_vectors)
if copies is None:
raw_copies = [1] * len(raw_vectors)
elif np.isscalar(copies):
raw_copies = [int(copies)] * len(raw_vectors)
else:
raw_copies = [int(copy) for copy in copies]
# Normalize reciprocal vectors and copy counts, ensuring they match the k-space dimensionality and dropping zero vectors if needed.
active_vectors, active_copies = _normalize_reciprocal_vectors(raw_vectors, kdim, copies=raw_copies)
if any(copy < 0 for copy in active_copies):
raise ValueError("Copy counts must be non-negative")
# Generate the shifts for extending the k-points. Each shift is a linear combination of the active reciprocal vectors, with coefficients in the range defined by the copy counts.
if len(active_vectors) == 0:
shifts = np.zeros((1, kdim), dtype=float)
else:
coeff_axes = [np.arange(-copy, copy + 1, dtype=int) for copy in active_copies]
coeff_grid = np.stack(np.meshgrid(*coeff_axes, indexing='ij'), axis=-1).reshape(-1, len(active_vectors))
basis = np.vstack(active_vectors)
shifts = coeff_grid @ basis
# Apply the shifts to the original k-points to generate the extended k-point grid. The resulting shape will be (Nk * number_of_shifts, dim).
extended_k_points = (k_points[None, :, :] + shifts[:, None, :]).reshape(-1, kdim)
extended_data = None
# If data is provided, match it with the extended k-points by tiling it according to the number of shifts.
# The leading dimension of data must match the number of original k-points.
if data is not None:
data = np.asarray(data)
if nk == 1 and data.ndim == 0:
data = data.reshape(1)
if data.shape[0] != nk:
raise ValueError("data must have the same leading dimension as k_points")
tile_shape = (len(shifts),) + (1,) * max(data.ndim - 1, 0)
extended_data = np.tile(data, tile_shape)
return extended_k_points, extended_data
# -----------------------------------------------------------------------------------------------------------
#! HIGH-SYMMETRY POINTS AND PATHS
# -----------------------------------------------------------------------------------------------------------
[docs]
@dataclass
class HighSymmetryPoint:
r"""
A high-symmetry point in the Brillouin zone.
Attributes
----------
label : str
Label for the point (e.g., 'Gamma', 'K', 'M', 'X')
frac_coords : Tuple[float, float, float]
Fractional coordinates in reciprocal lattice units (f1, f2, f3).
The actual k-vector is: k = f1*b1 + f2*b2 + f3*b3
latex_label : str, optional
LaTeX-formatted label for plotting (e.g., r'$\\Gamma$')
description : str, optional
Description of the point
"""
label : str
frac_coords : Tuple[float, float, float]
latex_label : str = ""
description : str = ""
def __post_init__(self):
if not self.latex_label:
# Auto-generate LaTeX label
special_labels = {
'Gamma' : r'$\Gamma$', 'G': r'$\Gamma$',
'K' : r'$K$',
'M' : r'$M$', 'X': r'$X$', 'Y': r'$Y$', 'Z': r'$Z$',
'R' : r'$R$', 'A': r'$A$', 'L': r'$L$', 'H': r'$H$',
}
self.latex_label = special_labels.get(self.label, f'${self.label}$')
[docs]
def __contains__(self, coord: Union[Tuple[float, float, float], str]) -> bool:
"""Check if given fractional coordinates match this point."""
if isinstance(coord, str):
return coord == self.label
return np.allclose(self.frac_coords, coord)
[docs]
def to_cartesian(self, b1: np.ndarray, b2: np.ndarray, b3: np.ndarray) -> np.ndarray:
"""Convert fractional coordinates to Cartesian k-vector."""
f1, f2, f3 = self.frac_coords
return f1 * np.asarray(b1) + f2 * np.asarray(b2) + f3 * np.asarray(b3)
[docs]
def as_tuple(self) -> Tuple[str, List[float]]:
"""Return as (label, [f1, f2, f3]) tuple for path generation."""
return (self.latex_label, list(self.frac_coords))
[docs]
@dataclass
class HighSymmetryPoints:
"""
Collection of high-symmetry points for a lattice type.
Provides named access to standard high-symmetry points and defines
default paths through the Brillouin zone.
Example
-------
>>> pts = HighSymmetryPoints.square_2d()
>>> print(pts.Gamma) # HighSymmetryPoint for Gamma
>>> print(pts.default_path()) # ['Gamma', 'X', 'M', 'Gamma']
>>> print(pts.get_path_points(['Gamma', 'M'])) # Custom path
"""
points : Dict[str, HighSymmetryPoint] = field(default_factory=dict)
_default_path : List[str] = field(default_factory=list)
def __getattr__(self, name: str) -> HighSymmetryPoint:
if name.startswith('_') or name == 'points':
raise AttributeError(name)
if name in self.points:
return self.points[name]
raise AttributeError(f"No high-symmetry point named '{name}'")
def __contains__(self, name: str) -> bool:
return name in self.points
def __iter__(self):
return iter(self.points.values())
[docs]
def add(self, point: HighSymmetryPoint) -> 'HighSymmetryPoints':
"""Add a high-symmetry point."""
self.points[point.label] = point
return self
[docs]
def get(self, name: str) -> Optional[HighSymmetryPoint]:
"""Get a point by name, returns None if not found."""
return self.points.get(name)
@staticmethod
def _normalize_label(name: str) -> str:
"""Normalize common aliases for high-symmetry point labels."""
if name is None:
return ""
label = str(name).strip()
label = label.replace("Γ", "Gamma").replace("γ", "Gamma")
label = label.replace("’", "'")
if label in ("G", "g", "\\Gamma", "gamma", "Gamma"):
return "Gamma"
if label.endswith("'"):
label = f"{label[:-1]}p"
return label
[docs]
def resolve_label(self, name: str) -> Optional[str]:
"""
Resolve a label or alias to a canonical key in ``self.points``.
Examples: ``"Γ" -> "Gamma"``, ``"K'" -> "Kp"``.
"""
if not self.points:
return None
# First try direct normalization and lookup
norm = self._normalize_label(name)
if norm in self.points:
return norm
# Fallback: case-insensitive match against keys and normalized keys
norm_l = norm.lower()
for key in self.points:
if key.lower() == norm_l:
return key
if self._normalize_label(key).lower() == norm_l:
return key
return None
[docs]
def resolve(self, name: str) -> Optional[HighSymmetryPoint]:
"""Resolve a label/alias and return the matching point object."""
key = self.resolve_label(name)
if key is None:
return None
return self.points.get(key)
@property
def default_path(self) -> List[str]:
"""Return the default path through high-symmetry points."""
return self._default_path
[docs]
def get_path_points(self, path_labels: List[str]) -> List[Tuple[str, List[float]]]:
"""
Get path as list of (label, frac_coords) tuples.
Parameters
----------
path_labels : List[str]
List of point labels defining the path (e.g., ['Gamma', 'X', 'M', 'Gamma'])
Returns
-------
List[Tuple[str, List[float]]]
Path suitable for brillouin_zone_path() function
"""
path = []
for label in path_labels:
resolved = self.resolve_label(label)
if resolved is None:
raise ValueError(f"Unknown high-symmetry point: '{label}'. "
f"Available: {list(self.points.keys())}")
path.append(self.points[resolved].as_tuple())
return path
[docs]
def get_default_path_points(self) -> List[Tuple[str, List[float]]]:
"""Get the default path as list of (label, frac_coords) tuples."""
return self.get_path_points(self._default_path)
# -----------------------------------------------------------------
# Factory methods for common lattice types
# -----------------------------------------------------------------
[docs]
@classmethod
def chain_1d(cls) -> 'HighSymmetryPoints':
"""High-symmetry points for 1D chain."""
pts = cls(
_default_path=['Gamma', 'X', 'Gamma2']
)
pts.add(HighSymmetryPoint('Gamma', (0.0, 0.0, 0.0), r'$0$', '1D BZ center'))
pts.add(HighSymmetryPoint('X', (0.5, 0.0, 0.0), r'$\pi$', 'Zone boundary'))
pts.add(HighSymmetryPoint('Gamma2', (1.0, 0.0, 0.0), r'$2\pi$', 'Wrapped Gamma'))
return pts
[docs]
@classmethod
def square_2d(cls) -> 'HighSymmetryPoints':
"""High-symmetry points for 2D square lattice."""
pts = cls(
_default_path=['Gamma', 'X', 'M', 'Gamma']
)
pts.add(HighSymmetryPoint('Gamma', (0.0, 0.0, 0.0), r'$\Gamma$', 'BZ center'))
pts.add(HighSymmetryPoint('X', (0.5, 0.0, 0.0), r'$X$', 'Zone face center'))
pts.add(HighSymmetryPoint('M', (0.5, 0.5, 0.0), r'$M$', 'Zone corner'))
pts.add(HighSymmetryPoint('Y', (0.0, 0.5, 0.0), r'$Y$', 'Zone face center (y)'))
return pts
[docs]
@classmethod
def cubic_3d(cls) -> 'HighSymmetryPoints':
"""High-symmetry points for 3D cubic lattice."""
pts = cls(
_default_path=['Gamma', 'X', 'M', 'Gamma', 'R', 'X']
)
pts.add(HighSymmetryPoint('Gamma', (0.0, 0.0, 0.0), r'$\Gamma$', 'BZ center'))
pts.add(HighSymmetryPoint('X', (0.5, 0.0, 0.0), r'$X$', 'Face center'))
pts.add(HighSymmetryPoint('M', (0.5, 0.5, 0.0), r'$M$', 'Edge center'))
pts.add(HighSymmetryPoint('R', (0.5, 0.5, 0.5), r'$R$', 'Corner'))
return pts
[docs]
@classmethod
def triangular_2d(cls) -> 'HighSymmetryPoints':
"""High-symmetry points for 2D triangular lattice."""
pts = cls(
_default_path=['Gamma', 'M', 'K', 'Gamma']
)
pts.add(HighSymmetryPoint('Gamma', (0.0, 0.0, 0.0), r'$\Gamma$', 'BZ center'))
pts.add(HighSymmetryPoint('M', (0.5, 0.0, 0.0), r'$M$', 'Edge midpoint'))
pts.add(HighSymmetryPoint('K', (1/3, 1/3, 0.0), r'$K$', 'Corner (Dirac point)'))
pts.add(HighSymmetryPoint('Kp', (2/3, 1/3, 0.0), r"$K'$", 'Other Dirac point'))
return pts
[docs]
@classmethod
def honeycomb_2d(cls) -> 'HighSymmetryPoints':
"""High-symmetry points for honeycomb/graphene lattice."""
pts = cls(_default_path=['Gamma', 'K', 'M', 'Gamma'])
pts.add(HighSymmetryPoint('Gamma', (0.0, 0.0, 0.0), r'$\Gamma$', 'BZ center'))
pts.add(HighSymmetryPoint('K', (2.0/3.0, 1.0/3.0, 0.0), r'$K$', 'Dirac point'))
pts.add(HighSymmetryPoint('Kp', (1.0/3.0, 2.0/3.0, 0.0), r"$K'$", 'Other Dirac point'))
pts.add(HighSymmetryPoint('M', (0.5, 0.0, 0.0), r'$M$', 'Edge midpoint'))
return pts
[docs]
@classmethod
def hexagonal_2d(cls) -> 'HighSymmetryPoints':
"""High-symmetry points for 2D hexagonal lattice (same as honeycomb)."""
return cls.honeycomb_2d()
[docs]
class StandardBZPath(Enum):
r"""
Enumeration of standard high-symmetry paths in the Brillouin zone.
We define the k-space paths in a general representation of momentum vectors:
\[
k = f1 * b1 + f2 * b2 + f3 * b3,
\]
where (b1, b2, b3) are the reciprocal lattice vectors, and (f1, f2, f3) are the fractional coordinates,
f_i = n_i / N_i, with n_i = 0, 1, ..., N_i - 1 for each direction i.
Each value returns a list of (label, fractional_coord) pairs.
The fractional coordinates are expressed in units of reciprocal lattice vectors.
Example:
>>> path = StandardBZPath.SQUARE_2D.value
>>> for label, coord in path:
... print(f"{label}: {coord}")
G: [0.0, 0.0, 0.0]
X: [0.5, 0.0, 0.0]
M: [0.5, 0.5, 0.0]
G: [0.0, 0.0, 0.0]
"""
CHAIN_1D = [
("0", [0.0, 0.0, 0.0]),
(r"\pi", [0.5, 0.0, 0.0]),
(r"2\pi", [1.0, 0.0, 0.0])
]
SQUARE_2D = [
(r"$\Gamma$", [0.0, 0.0, 0.0]),
(r"$X$", [0.5, 0.0, 0.0]),
(r"$M$", [0.5, 0.5, 0.0]),
(r"$\Gamma$", [0.0, 0.0, 0.0])
]
TRIANGULAR_2D = [
(r"$\Gamma$", [0.0, 0.0, 0.0]),
(r"$M$", [0.5, 0.0, 0.0]),
(r"$K$", [1/3, 1/3, 0.0]),
(r"$\Gamma$", [0.0, 0.0, 0.0])
]
CUBIC_3D = [
(r"$\Gamma$", [0.0, 0.0, 0.0]),
(r"$X$", [0.5, 0.0, 0.0]),
(r"$M$", [0.5, 0.5, 0.0]),
(r"$R$", [0.5, 0.5, 0.5]),
(r"$\Gamma$", [0.0, 0.0, 0.0])
]
HONEYCOMB_2D = [
(r"$\Gamma$", [0.0, 0.0, 0.0]),
(r"$K$", [2/3, 1/3, 0.0]),
(r"$M$", [0.5, 0.0, 0.0]),
(r"$\Gamma$", [0.0, 0.0, 0.0])
]
PathTypes = Literal['CHAIN_1D', 'SQUARE_2D', 'TRIANGULAR_2D', 'CUBIC_3D', 'HONEYCOMB_2D']
# -----------------------------------------------------------------------------------------------------------
#! BRILLOUIN ZONE PATH GENERATION
# -----------------------------------------------------------------------------------------------------------
def resolve_path_input(path: Iterable[tuple[str, Iterable[float]]] | StandardBZPath | str | List[str] | HighSymmetryPoints, lattice: Optional[Lattice] = None) -> list[tuple[str, list[float]]]:
"""
Resolve path input to a list of (label, fractional_coord) pairs.
Parameters
----------
path : list[(label, coords)], StandardBZPath, str, List[str], or HighSymmetryPoints
Path definition (fractional coordinates), standard enum, enum name string,
list of point labels, or HighSymmetryPoints object.
lattice : Lattice, optional
Lattice object used to resolve labels if path is a list of strings.
Returns
-------
resolved_path : list[(label, list[float])]
Resolved path as a list of (label, fractional_coord) pairs.
Example
-------
>>> path = resolve_path_input("SQUARE_2D")
>>> for label, coord in path:
... print(f"{label}: {coord}")
"""
if isinstance(path, str):
path = StandardBZPath[path.upper()].value
elif isinstance(path, StandardBZPath):
path = path.value
elif isinstance(path, HighSymmetryPoints):
path = path.get_default_path_points()
elif isinstance(path, list) and len(path) > 0 and isinstance(path[0], str):
# List of strings, resolve via lattice if available
if lattice is not None and hasattr(lattice, 'high_symmetry_points'):
hs = lattice.high_symmetry_points()
if hs is not None:
path = hs.get_path_points(path)
else:
raise ValueError(f"Cannot resolve path labels {path} as lattice has no high-symmetry points defined.")
else:
raise ValueError(f"Cannot resolve path labels {path} without a lattice providing high-symmetry points.")
return [(label, list(map(float, frac))) for label, frac in path]
def generate_kgrid(lattice: Lattice, n_k: Iterable[int], shift: Optional[Union[bool, Tuple[bool, bool, bool]]] = None) -> np.ndarray:
"""
Generate a full k-point grid for the given lattice.
Parameters
----------
lattice : Lattice
Lattice object with reciprocal lattice vectors _k1, _k2, _k3.
n_k : Iterable[int]
Number of points (Lx, Ly, Lz) along each reciprocal direction.
We define the k-points as:
k = f1 * b1 + f2 * b2 + f3 * b3,
where f_i = n_i / N_i, with n_i = 0, 1, ..., N_i - 1.
Returns
-------
k_points : np.ndarray, shape (Nk, dim)
Cartesian coordinates of k-points in reciprocal space.
"""
recip = np.vstack([v for v in [lattice._k1, lattice._k2, lattice._k3] if v is not None]) # shape (dim, 3)
nk = np.array(list(n_k)) if isinstance(n_k, Iterable) else np.array([n_k] * lattice.dim)
grids = [np.arange(n) / n for n in nk] # fractional grids
mesh = np.meshgrid(*grids, indexing="ij") # meshgrid
frac = np.stack([m.ravel() for m in mesh], axis=-1) # shape (Nk, dim)
# define vectors in a full 3D array for matrix multiplication -> k=f1*b1 + f2*b2 + f3*b3
k_pts = frac @ recip # fractional -> cartesian
return k_pts
def brillouin_zone_path(lattice: Lattice, path: Iterable[tuple[str, Iterable[float]]] | StandardBZPath | List[str] | HighSymmetryPoints, *, points_per_seg : int = 40,) -> tuple[np.ndarray, np.ndarray, list[tuple[int, str]], np.ndarray]:
"""
Generate k-points along a specified Brillouin zone path. It takes
a list of (label, fractional_coord) pairs defining the path in reciprocal lattice units and
interpolates k-points along the straight lines connecting these points in Cartesian coordinates.
In general, if coordinates are given as c1 = (f1, f2, f3) and c2 = (g1, g2, g3)
we want to follow the straight line in Cartesian coordinates:
k(t) = (1-t) * (f1*b1 + f2*b2 + f3*b3) + t * (g1*b1 + g2*b2 + g3*b3)
for t in [0, 1], t = 0, 1/points_per_seg, 2/points_per_seg, ..., (points_per_seg-1)/points_per_seg.
Each segment between two labels is interpolated with `points_per_seg` points.
Example:
>>> path = StandardBZPath.SQUARE_2D.value
>>> # define the interpolated path
>>> k_path, k_dist, labels = brillouin_zone_path(lattice, path, points_per_seg=10) # 10 points per segment for demonstration
>>> print("k-path shape:", k_path.shape)
k-path shape: (30, 3)
>>> print("k-dist shape:", k_dist.shape)
k-dist shape: (30,)
>>> print("Labels:", labels)
Labels: [(0, 'G'), (10, 'X'), (20, 'M'), (30, 'G')]
>>> print("First 3 k-points:\n", k_path[:3])
First 3 k-points:
[[0. 0. 0.] # = 0.0 * b1 + 0.0 * b2 + 0.0 * b3
[0.05 0. 0. ] # = 0.1 * b1 + 0.0 * b2 + 0.0 * b3
[0.1 0. 0. ]] # = 0.2 * b1 + 0.0 * b2 + 0.0 * b3
Parameters
----------
lattice : Lattice
Lattice object with reciprocal lattice vectors (_k1, _k2, _k3).
path : list[(label, coords)], StandardBZPath, List[str], or HighSymmetryPoints
Path definition (fractional coordinates), one of the standard enums,
list of symmetry point labels, or HighSymmetryPoints object.
points_per_seg : int
Number of interpolated points between labels.
Returns
-------
k_path : np.ndarray, shape (Npath, 3)
k-points along the path.
k_dist : np.ndarray, shape (Npath,)
Cumulative distance for x-axis plotting.
labels : list[(int, str)]
Indices and labels for symmetry points.
"""
path = resolve_path_input(path, lattice=lattice)
# Reciprocal lattice matrix: columns = b1, b2, b3
b1 = np.asarray(lattice._k1, float).reshape(3)
b2 = np.asarray(lattice._k2, float).reshape(3)
b3 = np.asarray(lattice._k3, float).reshape(3)
B = np.column_stack([b1, b2, b3]) # (3,3)
cart_pts = []
frac_pts = []
for _, frac in path:
f = np.zeros(3, float)
f[:len(frac)] = np.array(frac, float)
k_cart = B @ f
cart_pts.append(k_cart)
frac_pts.append(f)
k_path = []
k_path_frac = []
k_dist = [0.0]
labels = [(0, path[0][0])]
for i in range(len(cart_pts) - 1):
p0, p1 = cart_pts[i], cart_pts[i + 1]
f0, f1 = frac_pts[i], frac_pts[i + 1]
nseg = points_per_seg
seg_cart = np.linspace(p0, p1, nseg, endpoint=False)
seg_frac = np.linspace(f0, f1, nseg, endpoint=False)
for j in range(nseg):
k = seg_cart[j]
if k_path:
dk = np.linalg.norm(k - k_path[-1])
k_dist.append(k_dist[-1] + dk)
k_path.append(k)
k_path_frac.append(seg_frac[j])
labels.append((len(k_path), path[i + 1][0]))
k_path = np.array(k_path)
k_path_frac = np.array(k_path_frac)
k_dist = np.array(k_dist)
return k_path, k_dist, labels, k_path_frac
@dataclass
class KPathSelection:
"""
Ideal Brillouin-zone path with optional nearest matches on an existing k-grid.
This container is intended for visualization and path-selection logic where
the path geometry is needed even when no sampled k-grid exists yet.
"""
path_cart : np.ndarray
path_frac : np.ndarray
k_dist : np.ndarray
labels : List[Tuple[int, str]]
matched_cart : Optional[np.ndarray] = None
matched_frac : Optional[np.ndarray] = None
matched_grid_indices : np.ndarray = field(default_factory=lambda: np.array([], dtype=int))
matched_indices : np.ndarray = field(default_factory=lambda: np.array([], dtype=int))
matched_distances : np.ndarray = field(default_factory=lambda: np.array([], dtype=float))
match_tolerance : float = 0.0
@property
def has_matches(self) -> bool:
"""Whether this path was matched to an existing k-grid."""
return self.matched_cart is not None and len(self.matched_indices) > 0
def unique_match_positions(self) -> np.ndarray:
"""Return path positions of unique matched k-points, preserving path order."""
if not self.has_matches:
return np.array([], dtype=int)
key_indices = self.matched_grid_indices if len(self.matched_grid_indices) > 0 else self.matched_indices
_, first = np.unique(key_indices, return_index=True)
return np.sort(first.astype(int))
def unique_matched_cart(self) -> np.ndarray:
"""Return unique matched Cartesian k-points in path order."""
if not self.has_matches:
return np.zeros((0, 3), dtype=float)
return self.matched_cart[self.unique_match_positions()]
# -----------------------------------------------------------------------------------------------------------
#! PATH SELECTION ON EXISTING K-GRID
# -----------------------------------------------------------------------------------------------------------
def _cartesian_from_fractional(lattice: Lattice, frac_vectors: np.ndarray) -> np.ndarray:
"""Convert fractional reciprocal coordinates to Cartesian k-vectors."""
frac = np.asarray(frac_vectors, dtype=float)
if frac.ndim == 1:
frac = frac.reshape(1, -1)
if frac.shape[1] < 3:
frac = np.pad(frac, ((0, 0), (0, 3 - frac.shape[1])))
elif frac.shape[1] > 3:
frac = frac[:, :3]
b1 = np.asarray(lattice._k1, float).reshape(3)
b2 = np.asarray(lattice._k2, float).reshape(3)
b3 = np.asarray(lattice._k3, float).reshape(3)
B = np.column_stack([b1, b2, b3])
return frac @ B.T
def _fractional_from_cartesian(lattice: Lattice, k_vectors: np.ndarray) -> np.ndarray:
"""Convert Cartesian k-vectors to fractional reciprocal coordinates."""
k_cart = np.asarray(k_vectors, dtype=float)
if k_cart.ndim == 1:
k_cart = k_cart.reshape(1, -1)
if k_cart.shape[1] < 3:
k_cart = np.pad(k_cart, ((0, 0), (0, 3 - k_cart.shape[1])))
elif k_cart.shape[1] > 3:
k_cart = k_cart[:, :3]
b1 = np.asarray(lattice._k1, float).reshape(3)
b2 = np.asarray(lattice._k2, float).reshape(3)
b3 = np.asarray(lattice._k3, float).reshape(3)
B = np.column_stack([b1, b2, b3])
B_pinv = np.linalg.pinv(B)
return (B_pinv @ k_cart.T).T
def _default_kpath_tolerance(lattice: Lattice) -> float:
"""Estimate a reasonable fractional-coordinate tolerance from the lattice size."""
Lx = max(getattr(lattice, "_lx", 1), 1)
Ly = max(getattr(lattice, "_ly", 1), 1)
Lz = max(getattr(lattice, "_lz", 1), 1)
return 0.5 * np.sqrt((1 / Lx) ** 2 + (1 / Ly) ** 2 + (1 / Lz) ** 2)
def _kpath_tolerance_to_cartesian(lattice: Lattice, tol: float) -> float:
"""Map a fractional k-path tolerance to a conservative Cartesian tolerance."""
b1 = np.asarray(lattice._k1, float).reshape(3)
b2 = np.asarray(lattice._k2, float).reshape(3)
b3 = np.asarray(lattice._k3, float).reshape(3)
B = np.column_stack([b1, b2, b3])
return float(np.linalg.norm(B, ord=2) * float(tol))
def _path_covering_copies(k_grid_frac: np.ndarray, path_frac: np.ndarray, *, tol: float = 1e-12, dim: int = 3) -> Tuple[int, ...]:
"""Return symmetric reciprocal-copy counts needed to cover a path range."""
if dim <= 0:
return tuple()
kf = np.asarray(k_grid_frac, dtype=float).reshape(-1, 3)[:, :dim]
pf = np.asarray(path_frac, dtype=float).reshape(-1, 3)[:, :dim]
if len(kf) == 0 or len(pf) == 0:
return tuple(0 for _ in range(dim))
grid_min = np.min(kf, axis=0)
grid_max = np.max(kf, axis=0)
path_min = np.min(pf, axis=0)
path_max = np.max(pf, axis=0)
copies = []
for idx in range(dim):
upper = max(0.0, path_max[idx] - grid_max[idx] - tol)
lower = max(0.0, grid_min[idx] - path_min[idx] - tol)
copies.append(int(np.ceil(max(upper, lower))))
return tuple(copies)
def _extend_kgrid_to_path(
lattice : Lattice,
k_cart : np.ndarray,
k_frac : np.ndarray,
path_frac : np.ndarray,
*,
tol : float = 1e-12,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Extend an existing k-grid just enough to cover a requested path."""
kf_flat = np.asarray(k_frac, dtype=float).reshape(-1, 3)
kc_flat = np.asarray(k_cart, dtype=float).reshape(-1, 3)
dim = int(getattr(lattice, "dim", 0) or min(3, kf_flat.shape[1], path_frac.shape[1]))
dim = max(1, min(dim, 3))
copies = _path_covering_copies(kf_flat, path_frac, tol=tol, dim=dim)
if not any(copies):
base_indices = np.arange(len(kf_flat), dtype=int)
return kc_flat, kf_flat, base_indices, base_indices.copy()
ext_frac, _ = extend_kspace_data(kf_flat[:, :dim], reciprocal_vectors=np.eye(dim, dtype=float), copies=copies)
ext_cart = _cartesian_from_fractional(lattice, ext_frac)
if dim < 3:
ext_frac = np.pad(ext_frac, ((0, 0), (0, 3 - dim)))
ext_cart = np.pad(ext_cart[:, :dim], ((0, 0), (0, 3 - dim)))
n_shifts = len(ext_frac) // max(len(kf_flat), 1)
base_indices = np.tile(np.arange(len(kf_flat), dtype=int), n_shifts)
grid_indices = np.arange(len(ext_frac), dtype=int)
return ext_cart, ext_frac, base_indices, grid_indices
def _polyline_point_distances(points: np.ndarray, polyline: np.ndarray) -> np.ndarray:
"""Return the shortest Cartesian distance from each point to a polyline."""
pts = np.asarray(points, dtype=float)
path = np.asarray(polyline, dtype=float)
if pts.ndim == 1:
pts = pts.reshape(1, -1)
if path.ndim == 1:
path = path.reshape(1, -1)
if len(pts) == 0:
return np.zeros(0, dtype=float)
if len(path) == 0:
return np.full(len(pts), np.inf, dtype=float)
if len(path) == 1:
return np.linalg.norm(pts - path[0], axis=1)
distances = np.full(len(pts), np.inf, dtype=float)
for start, end in zip(path[:-1], path[1:]):
segment = end - start
seg_norm_sq = float(np.dot(segment, segment))
if seg_norm_sq <= 0.0:
candidate = np.linalg.norm(pts - start, axis=1)
else:
rel = pts - start
proj = np.clip((rel @ segment) / seg_norm_sq, 0.0, 1.0)
closest = start + proj[:, None] * segment
candidate = np.linalg.norm(pts - closest, axis=1)
distances = np.minimum(distances, candidate)
return distances
def find_nearest_kpoints(k_grid_frac: np.ndarray, target_frac: np.ndarray, tol: float = 0.5, *, periodic: bool = True) -> Tuple[np.ndarray, np.ndarray]:
r"""
Find nearest k-point indices for target fractional coordinates.
Parameters
----------
k_grid_frac : np.ndarray, shape (Nk, 3)
Fractional coordinates of available k-points
target_frac : np.ndarray, shape (Ntarget, 3)
Target fractional coordinates to match
tol : float
Warning threshold for match distance
periodic : bool, default=True
If True, measure distance modulo reciprocal-lattice translations.
If False, use direct fractional-coordinate distance without wrapping.
Returns
-------
indices : np.ndarray, shape (Ntarget,), dtype=int
Index of nearest k-point for each target
distances : np.ndarray, shape (Ntarget,)
Distance to nearest point (in fractional units, accounting for periodicity)
"""
k_grid_frac = np.asarray(k_grid_frac, dtype=float)
target_frac = np.asarray(target_frac, dtype=float)
if k_grid_frac.ndim == 1:
k_grid_frac = k_grid_frac.reshape(1, -1)
if target_frac.ndim == 1:
target_frac = target_frac.reshape(1, -1)
if k_grid_frac.shape[1] < 3:
k_grid_frac = np.pad(k_grid_frac, ((0, 0), (0, 3 - k_grid_frac.shape[1])))
elif k_grid_frac.shape[1] > 3:
k_grid_frac = k_grid_frac[:, :3]
if target_frac.shape[1] < 3:
target_frac = np.pad(target_frac, ((0, 0), (0, 3 - target_frac.shape[1])))
elif target_frac.shape[1] > 3:
target_frac = target_frac[:, :3]
n_targets = len(target_frac)
indices = np.zeros(n_targets, dtype=int)
distances = np.zeros(n_targets)
for i, kf_target in enumerate(target_frac):
diff = k_grid_frac - kf_target
if periodic:
diff -= np.round(diff)
dist = np.linalg.norm(diff, axis=1)
idx = np.argmin(dist)
indices[i] = idx
distances[i] = dist[idx]
if dist[idx] > tol:
continue # Optionally log a warning about unmatched point
return indices, distances
def bz_path_points(
lattice,
path : Iterable[tuple[str, Iterable[float]]] | StandardBZPath | HighSymmetryPoints | None = None,
*,
points_per_seg : int = 40,
k_vectors : Optional[np.ndarray] = None,
k_vectors_frac : Optional[np.ndarray] = None,
tol : Optional[float] = None,
periodic : bool = True,
) -> KPathSelection:
"""
Build an ideal Brillouin-zone path and optionally match it to an existing k-grid.
If no k-grid is provided, the returned object still contains the continuous
path geometry, which is useful for plotting or for constructing a path that
is not constrained to the sampled reciprocal mesh. When a sampled grid is
provided, it is automatically extended by reciprocal-lattice translations
if needed so paths through higher Brillouin-zone copies can still be matched
against the existing data.
Parameters
----------
lattice : Lattice
Lattice object with reciprocal lattice vectors and optionally high-symmetry points.
path : list[(label, coords)], StandardBZPath, HighSymmetryPoints, or None
Path definition (fractional coordinates), one of the standard enums, HighSymmetryPoints object
or None to use lattice default path.
points_per_seg : int
Number of interpolated points between labels for the ideal path.
k_vectors : np.ndarray, optional
Cartesian k-vectors of the existing grid to match against.
k_vectors_frac : np.ndarray, optional
Fractional k-vectors of the existing grid to match against (in reciprocal lattice units).
tol : float, optional
Tolerance for matching path points to the grid. With
``periodic=True`` it is interpreted in fractional reciprocal
coordinates. With ``periodic=False`` it is interpreted in the plotted
Cartesian reciprocal coordinates.
periodic : bool, default=True
Whether path matching should identify reciprocal-translation-equivalent
k-points as the same point. Use ``False`` for visual matching in the
displayed Brillouin-zone copy.
Returns
-------
KPathSelection
Object containing the ideal path geometry and matched grid points if available.
"""
if path is None:
if hasattr(lattice, 'high_symmetry_points'):
hs_pts = lattice.high_symmetry_points()
if hs_pts is not None:
path = hs_pts.get_default_path_points()
if path is None:
raise ValueError("No path specified and lattice has no default path. Specify path explicitly or use a lattice with high_symmetry_points().")
elif isinstance(path, HighSymmetryPoints):
path = path.get_default_path_points()
# Generate the ideal path geometry in Cartesian and fractional coordinates
path_cart, k_dist, labels, path_frac = brillouin_zone_path(lattice=lattice, path=path, points_per_seg=points_per_seg)
if k_vectors is None:
k_vectors = getattr(lattice, "kvectors", None)
if k_vectors_frac is None:
k_vectors_frac = getattr(lattice, "kvectors_frac", None)
if k_vectors is None and k_vectors_frac is None:
return KPathSelection(path_cart=path_cart, path_frac=path_frac, k_dist=k_dist, labels=labels)
if k_vectors is None:
k_vectors = _cartesian_from_fractional(lattice, k_vectors_frac)
if k_vectors_frac is None:
k_vectors_frac = _fractional_from_cartesian(lattice, k_vectors)
kc_flat = np.asarray(k_vectors, dtype=float).reshape(-1, 3)
kf_flat = np.asarray(k_vectors_frac, dtype=float).reshape(-1, 3)
tol = _default_kpath_tolerance(lattice) if tol is None else float(tol)
# Extend the sampled grid to cover higher-zone path coordinates when needed,
# while still tracking indices back to the original data grid.
ext_kc, ext_kf, base_indices, grid_indices = _extend_kgrid_to_path(
lattice,
kc_flat,
kf_flat,
path_frac,
tol=tol,
)
# Find nearest k-points on the grid for each point along the path
indices, match_distances = find_nearest_kpoints(ext_kf, path_frac, tol=tol, periodic=periodic and len(ext_kf) == len(kf_flat))
match_tolerance = tol
if not periodic:
match_distances = _polyline_point_distances(ext_kc[indices], path_cart)
match_tolerance = _kpath_tolerance_to_cartesian(lattice, tol)
return KPathSelection(path_cart=path_cart, path_frac=path_frac, k_dist=k_dist, labels=labels,
matched_cart=ext_kc[indices], matched_frac=ext_kf[indices],
matched_grid_indices=grid_indices[indices], matched_indices=base_indices[indices],
matched_distances=match_distances,
match_tolerance=match_tolerance)
[docs]
@dataclass
class KPathResult:
"""
Result of extracting data along a k-path in the Brillouin zone.
This dataclass holds all information needed for band structure plots and
analysis along a high-symmetry path.
Attributes
----------
k_cart : np.ndarray, shape (Npath, 3)
Cartesian k-vectors along the path
k_frac : np.ndarray, shape (Npath, 3)
Fractional k-vectors along the path (in reciprocal lattice units)
k_dist : np.ndarray, shape (Npath,)
Cumulative distance along the path for x-axis plotting
labels : List[Tuple[int, str]]
List of (index, label) pairs for high-symmetry points
values : np.ndarray
Data values along the path. The path axis is ``path_axis``. Examples:
``(Npath,)``, ``(Npath, n_bands)``, ``(Nw, Npath)``, or
``(Nw, Npath, n_bands)``.
indices : np.ndarray, shape (Npath,), dtype=int
Indices into the original k-grid for each path point.
Use to map path data back to the full k-grid.
matched_distances : np.ndarray, shape (Npath,)
Distance from ideal path point to matched grid point (for quality check)
Example
-------
>>> result = lattice.extract_kpath_data(energies, path='SQUARE_2D')
>>> plt.plot(result.k_dist, result.values)
>>> for idx, label in result.labels:
... plt.axvline(result.k_dist[min(idx, len(result.k_dist)-1)], label=label)
"""
k_cart : np.ndarray
k_frac : np.ndarray
k_dist : np.ndarray
labels : List[Tuple[int, str]]
values : np.ndarray
indices : np.ndarray
matched_distances : np.ndarray = field(default_factory=lambda: np.array([]))
path_axis : int = 0
@property
def n_points(self) -> int:
"""Number of points along the path."""
return len(self.k_dist)
@property
def n_bands(self) -> int:
"""Number of trailing channels per path point, flattening axes after ``path_axis``."""
tail_shape = self.values.shape[self.path_axis + 1:]
return int(np.prod(tail_shape)) if len(tail_shape) > 0 else 1
@property
def label_positions(self) -> np.ndarray:
"""X-axis positions (k_dist values) of the high-symmetry point labels."""
positions = []
for idx, _ in self.labels:
pos_idx = min(idx, len(self.k_dist) - 1) if len(self.k_dist) > 0 else 0
positions.append(self.k_dist[pos_idx] if len(self.k_dist) > 0 else 0.0)
return np.array(positions)
@property
def label_texts(self) -> List[str]:
"""Just the label strings for plotting."""
return [label for _, label in self.labels]
[docs]
def unique_indices(self) -> np.ndarray:
"""Return unique k-point indices (no duplicates from path segments)."""
return np.unique(self.indices)
[docs]
def max_match_distance(self) -> float:
"""Maximum distance from path to matched grid point."""
if len(self.matched_distances) == 0:
return 0.0
return float(np.max(self.matched_distances))
def bz_path_data(
lattice,
k_vectors : np.ndarray,
k_vectors_frac : np.ndarray,
values : np.ndarray,
path : Iterable[tuple[str, Iterable[float]]] | StandardBZPath | HighSymmetryPoints | None = None,
*,
points_per_seg : int = 40,
return_result : bool = True,
) -> KPathResult | Tuple[np.ndarray, np.ndarray, np.ndarray, List[Tuple[int, str]], np.ndarray]:
"""
Extract k-path data from a k-grid using fractional coordinate matching.
This function finds the closest k-points on the actual grid to an ideal path
through high-symmetry points. It handles periodic boundary conditions in k-space
and automatically reuses reciprocal-lattice copies of the sampled grid when the
requested path lies in an extended Brillouin-zone region.
It also allows to return a structured KPathResult dataclass or a tuple...
Parameters
----------
lattice : Lattice
Lattice object with reciprocal lattice vectors
k_vectors : np.ndarray, shape (..., 3)
Cartesian k-points (will be flattened)
k_vectors_frac : np.ndarray, shape (..., 3)
Fractional coordinates of k-points (will be flattened)
values : np.ndarray
Data values sampled on the k-grid. The k-grid axes may appear either as
the leading axes (e.g. ``(Lx, Ly, Lz, n_bands)``) or after batch axes
(e.g. ``(Nw, Lx, Ly, Lz)`` or ``(Nw, Lx, Ly, Lz, n_bands)``). A single
flattened k-grid axis of length ``Nk`` is also supported.
path : various, optional
Path specification. Can be:
- StandardBZPath enum value (e.g., StandardBZPath.SQUARE_2D)
- String name (e.g., 'SQUARE_2D')
- List of (label, [f1,f2,f3]) tuples
- HighSymmetryPoints object (uses default path)
- None: uses lattice's default path if available
points_per_seg : int
Number of interpolated points per path segment
return_result : bool
If True (default), return KPathResult dataclass.
If False, return tuple for backwards compatibility.
Returns
-------
KPathResult or tuple
If return_result=True: KPathResult dataclass with all path data. The
returned ``values`` preserve any leading batch axes and replace the
sampled k-grid axes with a path axis.
If return_result=False: (k_cart, k_frac, k_dist, labels, values) tuple
Examples
--------
>>> # Using default path from HighSymmetryPoints
>>> result = bz_path_data(lattice, k_grid, k_frac, energies, HighSymmetryPoints.square_2d())
>>> plt.plot(result.k_dist, result.values)
>>> # Using standard path enum
>>> result = bz_path_data(lattice, k_grid, k_frac, energies, 'SQUARE_2D')
>>> # Custom path
>>> custom_path = [('G', [0,0,0]), ('X', [0.5,0,0]), ('G', [0,0,0])]
>>> result = bz_path_data(lattice, k_grid, k_frac, energies, custom_path)
>>>
>>> # Frequency-resolved values with shape (Nw, Lx, Ly, Lz)
>>> result_w = bz_path_data(lattice, k_grid, k_frac, S_qw, custom_path)
>>> result_w.values.shape
(Nw, result_w.n_points)
"""
k_vectors_arr = np.asarray(k_vectors)
kgrid_shape = tuple(int(dim) for dim in k_vectors_arr.shape[:-1])
if len(kgrid_shape) == 0:
kgrid_shape = (int(k_vectors_arr.reshape(-1, 3).shape[0]),)
nk = int(np.prod(kgrid_shape))
values_arr = np.asarray(values)
if values_arr.ndim == 0:
raise ValueError("values must have at least one axis corresponding to the sampled k-grid")
grid_axis = None
grid_span = None
for start in range(values_arr.ndim - len(kgrid_shape) + 1):
if tuple(values_arr.shape[start:start + len(kgrid_shape)]) == kgrid_shape:
grid_axis = start
grid_span = len(kgrid_shape)
break
if grid_axis is None:
for axis, axis_size in enumerate(values_arr.shape):
if int(axis_size) == nk:
grid_axis = axis
grid_span = 1
break
if grid_axis is None or grid_span is None:
raise ValueError(
f"Could not identify k-grid axes {kgrid_shape} in values with shape {values_arr.shape}. "
"Expected values to contain the sampled k-grid axes contiguously or as a flattened Nk axis."
)
batch_shape = values_arr.shape[:grid_axis]
feature_shape = values_arr.shape[grid_axis + grid_span:]
val_flat = values_arr.reshape(batch_shape + (nk,) + feature_shape)
# Select path points and find nearest k-points on the grid
selection: KPathSelection = bz_path_points(lattice, path=path, points_per_seg=points_per_seg, k_vectors=k_vectors, k_vectors_frac=k_vectors_frac,)
if not selection.has_matches:
raise ValueError("Path selection requires sampled k-vectors to extract path data.")
# Extract values for the matched k-points along the path
vals_sel = np.take(val_flat, selection.matched_indices, axis=len(batch_shape))
# Optionally return a structured dataclass with all path information, or a tuple for backwards compatibility
if return_result:
return KPathResult(
k_cart = selection.matched_cart,
k_frac = selection.matched_frac,
k_dist = selection.k_dist,
labels = selection.labels,
values = vals_sel,
indices = selection.matched_indices,
matched_distances = selection.matched_distances,
path_axis = len(batch_shape),
)
else:
return selection.matched_cart, selection.matched_frac, selection.k_dist, selection.labels, vals_sel
# -----------------------------------------------------------------------------------------------------------
#! BLOCH TRANSFORMATION
# -----------------------------------------------------------------------------------------------------------
@dataclass
class BlochTransformCache:
r"""
Cache for Bloch transformation matrices to avoid recomputation.
Attributes
----------
W : np.ndarray
Bloch projector matrix, shape (Nc, Ns, Nb)
W[ik, i, a] = (1/sqrt Nc) * exp(-ik\cdot r_i) * delta_{sub(i),a}
W_conj : np.ndarray
Complex conjugate of W for efficiency
kpoints : np.ndarray
K-point grid used for this cache, shape (Nc, 3)
kgrid : np.ndarray
Structured k-grid, shape (Lx, Ly, Lz, 3)
kgrid_frac : np.ndarray
Fractional k-grid coordinates, shape (Lx, Ly, Lz, 3)
lattice_hash : int
Hash of lattice parameters to detect changes
"""
W : np.ndarray
W_conj : np.ndarray
kpoints : np.ndarray
kgrid : np.ndarray
kgrid_frac : np.ndarray
lattice_hash : int
# Global cache dictionary: lattice_id -> BlochTransformCache
_bloch_cache: Dict[int, BlochTransformCache] = {}
def _get_lattice_hash(lattice: 'Lattice') -> int:
"""Generate a hash from lattice parameters to detect changes.
Includes boundary flux so the cache is invalidated when flux changes.
"""
flux_hash = ()
if hasattr(lattice, '_flux') and lattice._flux is not None:
flux_hash = tuple(lattice._flux.as_array())
return hash((
lattice._lx, lattice._ly, lattice._lz,
len(lattice._basis),
tuple(lattice._a1.flatten()),
tuple(lattice._a2.flatten()),
tuple(lattice._a3.flatten()),
tuple(lattice._k1.flatten()),
tuple(lattice._k2.flatten()),
tuple(lattice._k3.flatten()),
flux_hash,
))
def _get_bloch_transform_cache(lattice: 'Lattice', unitary_norm: bool = True) -> BlochTransformCache:
"""
Get or create cached Bloch transformation matrices.
Parameters
----------
lattice : Lattice
Lattice object
unitary_norm : bool
Whether to use unitary normalization (1/sqrt Nc)
Returns
-------
cache : BlochTransformCache
Cached transformation matrices
"""
lattice_hash = _get_lattice_hash(lattice)
lattice_id = id(lattice)
# 1. cache reuse
if lattice_id in _bloch_cache:
cache = _bloch_cache[lattice_id]
if cache.lattice_hash == lattice_hash:
return cache
# 2. lattice sizes
Lx, Ly, Lz = lattice._lx, max(lattice._ly, 1), max(lattice._lz, 1)
Nc = Lx * Ly * Lz
Nb = len(lattice._basis)
Ns = lattice.Ns
# 3. reciprocal basis and k-grid (same as calculate_dft_matrix)
b1 = np.asarray(lattice._k1, float).reshape(3)
b2 = np.asarray(lattice._k2, float).reshape(3)
b3 = np.asarray(lattice._k3, float).reshape(3)
frac_x = np.linspace(0.0, 1.0, Lx, endpoint=False)
frac_y = np.linspace(0.0, 1.0, Ly, endpoint=False)
frac_z = np.linspace(0.0, 1.0, Lz, endpoint=False)
# Apply flux-induced shift when boundary fluxes are present
if hasattr(lattice, '_flux_frac_shift'):
dfx, dfy, dfz = lattice._flux_frac_shift()
frac_x = frac_x + dfx
frac_y = frac_y + dfy
frac_z = frac_z + dfz
kx_frac, ky_frac, kz_frac = np.meshgrid(frac_x, frac_y, frac_z, indexing="ij")
kgrid_frac = np.stack([kx_frac, ky_frac, kz_frac], axis=-1) # (Lx,Ly,Lz,3)
kgrid = (kx_frac[..., None] * b1
+ ky_frac[..., None] * b2
+ kz_frac[..., None] * b3) # (Lx,Ly,Lz,3)
kpoints = kgrid.reshape(-1, 3) # (Nc,3)
# 4. real-space Bravais vectors and sublattice indices
R_cells = np.asarray(lattice.cells, float) # (Ns,3)
if R_cells.shape[0] != Ns:
raise ValueError("Mismatch in number of sites and lattice.cells.")
sub_idx = np.asarray(lattice.subs, dtype=int) # (Ns,)
if sub_idx.shape[0] != Ns:
raise ValueError("Mismatch in number of sites and lattice.subs.")
# Projector S[i, alpha] = delta_{beta_i, alpha}
S = np.zeros((Ns, Nb), dtype=complex)
S[np.arange(Ns), sub_idx] = 1.0
# 5. Bloch projectors: W[k,i,alpha] = exp(-i k·R_i) / sqrt(Nc) * S[i,alpha]
phases = np.exp(-1j * (kpoints @ R_cells.T)) # (Nc,Ns)
if unitary_norm:
phases /= np.sqrt(Nc)
W = phases[:, :, None] * S[None, :, :] # (Nc,Ns,Nb)
W_conj = W.conj()
cache = BlochTransformCache(
W = W,
W_conj = W_conj,
kpoints = kpoints,
kgrid = kgrid,
kgrid_frac = kgrid_frac,
lattice_hash= lattice_hash,
)
_bloch_cache[lattice_id] = cache
return cache
# -----------------------------------------------------------------------------------------------------------
#? Reciprocal Lattice Vectors
# -----------------------------------------------------------------------------------------------------------
def reciprocal_from_real(a1: np.ndarray, a2: np.ndarray, a3: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Compute reciprocal lattice vectors from real-space lattice vectors.
b1,b2,(b3) satisfying a_i . b_j = 2*pi*delta_ij.
Parameters
----------
a1, a2, a3 : np.ndarray
Real-space lattice vectors.
Returns
-------
k1, k2, k3 : np.ndarray
Reciprocal lattice vectors.
"""
A = np.column_stack([a1[:3], a2[:3], (a3 if a3 is not None else np.array([0.,0.,1.]))[:3]])
B = 2.0 * np.pi * np.linalg.inv(A).T
b1, b2, b3 = B[:,0], B[:,1], B[:,2]
return b1, b2, b3
def extract_momentum(eigvecs : np.ndarray,
lattice : 'Lattice',
*,
eigvals : np.ndarray = None,
tol : float = 1e-10,
):
"""
Extract crystal momentum vectors k from real-space eigenvectors.
Parameters
----------
eigvecs : np.ndarray
Eigenvectors in real-space basis, shape (Ns, n_states)
eigvals : np.ndarray, optional
Corresponding eigenvalues, shape (n_states,). Required if degeneracies are to be resolved.
tol : float
Degeneracy tolerance. States within |E_i - E_j| < tol are treated as degenerate.
Returns
-------
k_vectors : np.ndarray
Array of shape (n_states, dim), containing crystal momenta for each eigenstate.
"""
# Translation operators (list of matrices, one per dimension)
T_ops = lattice.translation_operators()
dim = len(T_ops)
# Precompute real and reciprocal bases
A = np.column_stack([lattice._a1, lattice._a2, lattice._a3])[:, :dim]
Ns, n_states = eigvecs.shape
k_vectors = np.zeros((n_states, dim), dtype=float)
# Case 1: No degeneracy information (just extract phases directly)
if eigvals is None:
for q in range(n_states):
psi = eigvecs[:, q]
thetas = []
for Ti in T_ops:
phase = np.vdot(psi, Ti @ psi) / np.vdot(psi, psi)
thetas.append(np.angle(phase))
# Solve theta = A^T * k -> k = (A^T)^{-1} * theta
kvec = np.linalg.solve(A.T, thetas)
k_vectors[q, :] = kvec % (2 * np.pi)
return k_vectors
# Case 2: Degeneracy-aware version
used = np.zeros(n_states, dtype=bool)
for i in range(n_states):
if used[i]:
continue
# find degenerate subspace
mask = np.abs(eigvals - eigvals[i]) < tol
used[mask] = True
subspace = eigvecs[:, mask]
if subspace.shape[1] == 1:
# non-degenerate state
psi = subspace[:, 0]
thetas = []
for Ti in T_ops:
phase = np.vdot(psi, Ti @ psi) / np.vdot(psi, psi)
thetas.append(np.angle(phase))
# Solve theta = A^T * k -> k = (A^T)^{-1} * theta
kvec = np.linalg.solve(A.T, thetas)
k_vectors[i, :] = kvec % (2 * np.pi)
else:
# degenerate subspace: diagonalize translations
# we need to extract multiple k-vectors
# each Ti gives phases along direction i
# mathematically, we diagonalize Ti in the subspace
# to get eigenvalues exp(i * theta_i)
# and eigenvectors give the k-vectors
for Ti in T_ops:
Ti_sub = subspace.conj().T @ (Ti @ subspace)
evals, evecsT = np.linalg.eig(Ti_sub)
phases = np.angle(evals)
# Each eigenvalue gives one momentum along direction i
for j, phi in enumerate(phases):
# Insert per subspace component
if j + i < n_states:
if k_vectors[j + i, :].any():
# Combine existing info
k_vectors[j + i, np.argmax(k_vectors[j + i, :] == 0)] = phi
else:
k_vectors[j + i, 0] = phi
# Solve theta_i = k * a_i for each state
for q in range(n_states):
kvec = np.linalg.solve(A.T, k_vectors[q, :dim])
k_vectors[q, :] = kvec % (2 * np.pi)
return k_vectors
# -------------------------------------------------------------------------------------------
#? Single site translation operators
# -------------------------------------------------------------------------------------------
def build_translation_operators(lattice: 'Lattice') -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Construct translation matrices (T1, T2, T3) acting on the full real-space basis.
Parameters
----------
lattice : Lattice
Lattice object with attributes:
- a1, a2, a3 : Real-space lattice vectors : shape (3,dim).
- b1, b2, b3 : Reciprocal lattice vectors : shape (3,dim).
- basis : np.ndarray of shape (n_basis, 3) with fractional coordinates of basis sites.
- Lx, Ly, Lz : Number of unit cells along each direction
- boundary_phase_from_winding(wx, wy, wz) : Method returning phase factor for given windings.
Returns
-------
T1, T2, T3 : np.ndarray
Translation matrices (Ns x Ns), each a complex permutation unitary
shifting states by +a1, +a2, +a3 respectively.
Notes
-----
- Works for 1D, 2D, or 3D (dimension inferred from provided a_vectors).
- Periodic boundaries are implemented automatically.
- Sublattice permutations are handled by matching translated site positions
modulo the lattice vectors.
"""
dim = lattice.dim
n_basis = lattice.multipartity
Ns = lattice.Ns
Lx, Ly, Lz = lattice.Lx, lattice.Ly, lattice.Lz
# Allocate matrices
T1 = np.zeros((Ns, Ns), dtype=complex)
T2 = np.zeros_like(T1)
T3 = np.zeros_like(T1)
Ts = [T1, T2, T3]
# Precompute indices
indices = np.arange(Ns)
basis_idx = indices % n_basis
cell_idx = indices // n_basis
nx = cell_idx % Lx
ny = (cell_idx // Lx) % Ly
nz = (cell_idx // (Lx * Ly)) % Lz
coords_idx = [nx, ny, nz]
dims = [Lx, Ly, Lz]
for dir_idx in range(dim):
# Calculate new coordinates and winding
# Translation along dir_idx means adding 1 to the corresponding coordinate
# The coordinate is coords_idx[dir_idx]
# New coordinate: (n + 1) % L
n_curr = coords_idx[dir_idx]
L = dims[dir_idx]
n_next = (n_curr + 1) % L
# Winding number: 1 if wrapping around boundary, 0 otherwise
# Wrapped when n_curr + 1 == L => n_next == 0
winding = (n_curr + 1) // L
# Compute new cell index
# cell_new = nx_new + Lx * ny_new + Lx * Ly * nz_new
# Only one coordinate changes
if dir_idx == 0:
nx_new = n_next
ny_new = ny
nz_new = nz
wx, wy, wz = winding, np.zeros_like(winding), np.zeros_like(winding)
elif dir_idx == 1:
nx_new = nx
ny_new = n_next
nz_new = nz
wx, wy, wz = np.zeros_like(winding), winding, np.zeros_like(winding)
elif dir_idx == 2:
nx_new = nx
ny_new = ny
nz_new = n_next
wx, wy, wz = np.zeros_like(winding), np.zeros_like(winding), winding
else:
raise ValueError(f"Invalid direction index {dir_idx}")
cell_new = nx_new + Lx * ny_new + Lx * Ly * nz_new
j_indices = cell_new * n_basis + basis_idx
# Compute phases
if hasattr(lattice, "boundary_phase_from_winding"):
# If lattice has boundary flux, we need to apply phases
# Optimization:
# 1. Compute phase for trivial winding (0,0,0) -> 1.0
# 2. Compute phase for non-trivial winding (only where winding > 0)
# Default phase is 1.0
phases = np.ones(Ns, dtype=complex)
# Find sites where winding occurred
mask_w = (winding > 0)
if np.any(mask_w):
# We can either iterate over unique windings or just compute for the single case w=1
# Usually winding is only 1.
# Check if we can vectorize boundary_phase_from_winding?
# It usually calls self._flux.phase(direction, winding).
# If we access lattice.flux directly? No, abstraction.
# Assume standard implementation: phase = exp(i * phi * w)
# We can just compute phase for w=1 in this direction.
# Calculate phase for w=1
w_args = [0, 0, 0]
w_args[dir_idx] = 1
phase_1 = lattice.boundary_phase_from_winding(*w_args)
if abs(phase_1 - 1.0) > 1e-14:
phases[mask_w] = phase_1
# If winding > 1 (unlikely for +1 translation unless L=1)
# L=1 => winding = 1.
# If L < 1 (impossible), etc.
# Technically if L=1, winding is 1.
# If someone defined L=0.5 (impossible for integer).
# If there are sites with winding > 1?
# (n+1)//L can be > 1 only if n >= 2L-1.
# But n < L. So n+1 <= L. So (n+1)//L <= 1.
# So winding is always 0 or 1.
pass
else:
phases = np.ones(Ns, dtype=complex)
# Assign to matrix
# Ts[dir_idx][j_indices, indices] = phases
Ts[dir_idx][j_indices, indices] = phases
return T1, T2, T3
def reconstruct_k_grid_from_blocks(blocks: List['QuadraticBlockDiagonalInfo']) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Reconstruct the structured k-space grid and energy array
from a flat list of QuadraticBlockDiagonalInfo objects.
Parameters
----------
blocks : list
Output of `ham.block_diagonal_bdg()[0]` (the list of per-k blocks).
Returns
-------
k_grid : np.ndarray
Shape (Lx, Ly, Lz, 3) array of k-vectors.
energies : np.ndarray
Shape (Lx, Ly, Lz, n_bands) array of eigenvalues at each k.
"""
# Extract indices and unique grid dimensions
indices = np.array([blk.block_index for blk in blocks])
energies = np.array([blk.en for blk in blocks])
Lx = indices[:, 0].max() + 1
Ly = indices[:, 1].max() + 1 if indices.shape[1] > 1 else 1
Lz = indices[:, 2].max() + 1 if indices.shape[1] > 2 else 1
n_bands = energies.shape[1]
# Allocate structured arrays
k_grid = np.zeros((Lx, Ly, Lz, 3))
k_grid_frac = np.zeros((Lx, Ly, Lz, 3))
energy_grid = np.zeros((Lx, Ly, Lz, n_bands))
# Fill by index
for blk in blocks:
ix, iy, iz = blk.block_index[0], blk.block_index[1], blk.block_index[2]
k_grid_frac[ix, iy, iz, :] = blk.frac_point
k_grid[ix, iy, iz, :] = blk.point
energy_grid[ix, iy, iz, :] = blk.en
return k_grid, k_grid_frac, energy_grid
# -------------------------------------------------------------------------------------------
#! SPACE TRANSFORMATIONS
# -------------------------------------------------------------------------------------------
def full_k_space_transform(lattice: Lattice, mat: np.ndarray, inverse: bool = False) -> np.ndarray:
r"""
Full Ns x Ns k-space transform using DFT matrix.
Computes:
H_k = F @ H_real @ F†
where F is the Ns x Ns DFT matrix:
F[n, i] = (1/sqrt Ns) * exp(-i k_n r_i)
This works for ANY lattice, independent of multipartition structure.
The result is an NsxNs matrix that is block-diagonal for translationally
invariant systems.
Parameters
----------
lattice : Lattice
Lattice object with DFT matrix and k-vectors
mat : np.ndarray
Real-space matrix, shape (Ns, Ns)
Returns
-------
H_k : np.ndarray
K-space matrix, shape (Ns, Ns)
"""
Ns = lattice.Ns
if mat.shape != (Ns, Ns):
raise ValueError(f"mat must have shape ({Ns}, {Ns}), got {mat.shape}")
# Get or calculate DFT matrix
F = lattice.dft
if F is None or F.shape != (Ns, Ns) or not np.any(F):
F = lattice.calculate_dft_matrix()
# Transform: H_k = F @ H_real @ F†
if not inverse:
F_dagger = F.conj().T
H_k_full = F @ mat @ F_dagger
return H_k_full
else:
# Inverse transform
F_inv = F.conj().T # DFT is unitary, so inverse is conjugate transpose
H_real_full = F_inv @ mat @ F
return H_real_full
def realspace_from_kspace(lattice, H_k: np.ndarray, kgrid: Optional[np.ndarray] = None) -> np.ndarray:
r"""
Inverse Bloch transform: H(k) blocks -> H_real (Ns x Ns).
Reconstructs the real-space Hamiltonian from k-space blocks using the inverse
Fourier transform. This is the exact inverse of `kspace_from_realspace()`.
Formula:
H_real = Σ_k W(k)† H(k) W(k)
where W[i,a] = (1/√Nc) . exp(-ik.r_i) . delta _{sublattice(i),a}
Parameters
----------
lattice : Lattice
Lattice object with geometry information
H_k : np.ndarray
K-space Hamiltonian blocks. Shape (Lx, Ly, Lz, Nb, Nb) or (Nk, Nb, Nb).
Must be in fftfreq order (as returned by kspace_from_realspace).
kgrid : Optional[np.ndarray]
K-point grid for reference. If None, reconstructs using fftfreq convention.
Shape (Lx, Ly, Lz, 3) or (Nk, 3). Must be in fftfreq order.
Returns
-------
H_real : np.ndarray
Real-space Hamiltonian (Ns x Ns).
Notes
-----
- Eigenvalues are preserved to machine precision (error ~1e-15)
- Both H_k and kgrid must be in fftfreq order (no fftshift applied)
- The reconstruction is exact: H_real_reconstructed ≈ H_real_original
Examples
--------
>>> lat = HoneycombLattice(dim=2, lx=2, ly=2, bc='pbc')
>>> H_real_orig = np.random.randn(lat.Ns, lat.Ns)
>>> H_real_orig = H_real_orig + H_real_orig.conj().T # Make Hermitian
>>>
>>> # Forward transform
>>> H_k, k_grid, k_frac = kspace_from_realspace(lattice, H_real_orig)
>>>
>>> # Inverse transform
>>> H_real_recon = realspace_from_kspace(lattice, H_k, k_grid)
>>>
>>> # Check reconstruction
>>> np.allclose(H_real_orig, H_real_recon) # True (to machine precision)
"""
import numpy as np
# Parse input shape
if H_k.ndim == 5:
# (Lx, Ly, Lz, Nb, Nb) format - blocks are already in correct order from kspace_from_realspace
Lx, Ly, Lz, Ns_block, Ns2 = H_k.shape
Nc = Lx * Ly * Lz
H_k_flat = H_k.reshape(Nc, Ns_block, Ns2)
elif H_k.ndim == 3:
# (Nk, Nb, Nb) format
Nk, Ns_block, Ns2 = H_k.shape
H_k_flat = H_k
Nc = Nk
else:
raise ValueError(f"H_k must be 3D or 5D array, got shape {H_k.shape}")
# Check Hermiticity
if Ns_block != Ns2:
raise ValueError(f"H_k blocks must be square: got {Ns_block}x{Ns2}")
Ns = Ns_block
# Infer lattice properties
# For multi-sublattice systems, Ns_block = Nb (number of sublattices), not total sites
Nb = lattice.multipartity # Number of sublattices
if Ns != Nb:
raise ValueError(f"H_k block size {Ns} != lattice sublattices {Nb}")
Lx = lattice._lx
Ly = max(lattice._ly, 1)
Lz = max(lattice._lz, 1)
expected_Nc = Lx * Ly * Lz
if Nc != expected_Nc:
raise ValueError(f"Number of k-points {Nc} != expected {expected_Nc}")
# Reciprocal vectors
b1 = np.asarray(lattice._k1, float).reshape(3)
b2 = np.asarray(lattice._k2, float).reshape(3)
b3 = np.asarray(lattice._k3, float).reshape(3)
# Build k-grid if not provided - use same convention as kspace_from_realspace
if kgrid is None:
# Use fftfreq convention to match forward transform
frac_x = np.linspace(0, 1, Lx, endpoint=False)
frac_y = np.linspace(0, 1, Ly, endpoint=False)
frac_z = np.linspace(0, 1, Lz, endpoint=False)
kx_frac, ky_frac, kz_frac = np.meshgrid(frac_x, frac_y, frac_z, indexing="ij")
kgrid = (kx_frac[..., None] * b1 +
ky_frac[..., None] * b2 +
kz_frac[..., None] * b3)
kpoints = kgrid.reshape(-1, 3)
else:
# kgrid is already in fftfreq order from kspace_from_realspace
if kgrid.ndim == 4:
kpoints = kgrid.reshape(-1, 3)
else:
kpoints = np.asarray(kgrid, float).reshape(-1, 3)
# Site coordinates and sublattice indices
coords = np.asarray(lattice.coordinates, float)
Ns_total = lattice.Ns # Total number of sites
indices = np.arange(Ns_total)
sub_idx = indices % Nb
# Projector S[i, a] = delta_{sub(i), a}
S = np.zeros((Ns_total, Nb), dtype=complex)
S[np.arange(Ns_total), sub_idx] = 1.0
# Inverse Bloch transform: H_real = (1/Nc) Σ_k W(k)† H(k) W(k)
# where W[i,a] = (1/sqrt Nc) * exp(-ik\cdot r_i) * S[i,a]
# Note: Using -i (same as forward transform) because W is the unitary transform
H_real = np.zeros((Ns_total, Ns_total), dtype=complex)
for ik, kvec in enumerate(kpoints):
# Phase: exp(-i k\cdot r_i) (same sign as forward transform)
phases = np.exp(-1j * (coords @ kvec))
# Bloch projector W[i,a] = (1/sqrt Nc) * exp(-ik\cdot r_i) * S[i,a]
W = (phases[:, None] * S) / np.sqrt(Nc)
# Accumulate: W H(k) W† (since W is (Ns_total, Nb) and H_k is (Nb, Nb))
H_real += W @ H_k_flat[ik] @ W.conj().T
# Note: No division by Nc needed - it's already in W normalization
# However, we accumulated Nc terms, so effectively: H = Σ_k WHW† = I due to completeness
# Ensure Hermiticity (average with conjugate to remove numerical noise)
H_real = 0.5 * (H_real + H_real.conj().T)
return H_real.astype(lattice._dtype if hasattr(lattice, '_dtype') else np.complex128)
def kspace_from_realspace(
lattice : Lattice,
H_real : np.ndarray,
kpoints : Optional[np.ndarray] = None,
require_full_grid : bool = False,
unitary_norm : bool = True,
return_transform : bool = False):
r"""
Bloch projector: H_real (NsxNs) -> H(k) $\in$ C^{NbxNb} at each k.
Transforms a real-space Hamiltonian into momentum space using the Bloch transform:
H_ab(k) = Σ_{i,j} W*_{i,a}(k) H_{i,j} W_{j,b}(k)
where W[i,a](k) = (1/√Nc) . exp(-ik.r_i) . delta _{sublattice(i),a}
Assumptions:
- Periodic boundary conditions (PBC)
- True translational invariance to preserve spectrum
- Site ordering is arbitrary; geometry (coordinates + basis) defines sublattices
Parameters
----------
lattice : Lattice
Lattice object with geometry information
H_real : np.ndarray
Real-space Hamiltonian matrix (Ns x Ns)
kpoints : Optional[np.ndarray]
Custom k-points to evaluate at. If None, uses full BZ grid.
require_full_grid : bool
If True, raises error if kpoints doesn't match full grid size
unitary_norm : bool
Use unitary normalization (1/√Nc) for Bloch transform
use_cache : bool
Use cached Bloch transformation matrices for speed (default: True)
return_transform : bool
If True, also return the Bloch unitary W for computing correlation functions
Returns
-------
Hk_grid : np.ndarray
Shape (Lx, Ly, Lz, Nb, Nb) if kpoints is None
Shape (Nk, Nb, Nb) if kpoints is provided
Momentum-space Hamiltonian blocks in fftfreq order
kgrid : np.ndarray
Shape (Lx, Ly, Lz, 3) or (Nk, 3)
K-point coordinates in fftfreq order (Γ at [0,0,0])
kgrid_frac : np.ndarray
Shape (Lx, Ly, Lz, 3) or None
Fractional k-point coordinates in fftfreq order
W : np.ndarray [only if return_transform=True]
Shape (Nc, Ns, Nb) or (Nk, Ns, Nb)
Bloch unitary matrix W[ik, i, a] = (1/√Nc) . exp(-ik.r_i) . delta _{sub(i),a}
Use for transforming operators: O_k = W† @ O_real @ W
Notes
-----
- K-points are in fftfreq order: k[0,0,0] = Γ point
- No fftshift is applied to maintain correspondence between k_grid and H_k indices
- For translationally invariant systems: spectrum(H_real) = union of spectrum(H(k))
"""
if H_real.ndim != 2 or H_real.shape[0] != H_real.shape[1]:
raise ValueError("H_real must be a square matrix.")
Ns = H_real.shape[0]
if Ns != lattice.Ns:
raise ValueError(f"H_real size {Ns} != lattice Ns {lattice.Ns}.")
# lattice sizes and counts
Lx, Ly, Lz = lattice._lx, max(lattice._ly, 1), max(lattice._lz, 1)
Nc = Lx * Ly * Lz # number of unit cells
Nb = len(lattice._basis) # number of basis sites per cell
if Ns % Nc != 0 or (Ns // Nc) != Nb:
raise ValueError(f"Ns={Ns} not compatible with Nc={Nc} and Nb={Nb} (Ns must be Nc*Mb).")
# Use full DFT transform - simpler and works for any lattice
if kpoints is None:
# Full k-space transform using DFT matrix
H_k_full = full_k_space_transform(lattice, H_real)
# Extract blocks: for translationally invariant systems,
# H_k_full is block-diagonal with Nc blocks of size NbxNb
# Sites are ordered as [cell0_sub0, cell0_sub1, ..., cell1_sub0, cell1_sub1, ...]
Hk_blocks = np.zeros((Nc, Nb, Nb), dtype=complex)
# i_move = (Nc // 2) * Nb if Nc % 2 == 0 else 0
i_move = 0
for ik in range(Nc):
i_start = ik * Nb + i_move % Ns
i_end = (ik + 1) * Nb + i_move % Ns
Hk_blocks[ik] = H_k_full[i_start:i_end, i_start:i_end]
# Get k-grid from lattice
cache = _get_bloch_transform_cache(lattice, unitary_norm)
# Reshape blocks to grid (fftfreq order)
Hk_grid = Hk_blocks.reshape(Lx, Ly, Lz, Nb, Nb)
Hk_grid = np.ascontiguousarray(Hk_grid)
Hk_grid = np.fft.fftshift(Hk_grid, axes=(0,1,2))
if return_transform:
# For compatibility, return dummy transform
W_grid = np.zeros((Lx, Ly, Lz, Ns, Nb), dtype=complex)
return Hk_grid, cache.kgrid, cache.kgrid_frac, W_grid
else:
return Hk_grid, cache.kgrid, cache.kgrid_frac
# Fallback: manual computation for custom k-points
# reciprocal basis
b1 = np.asarray(lattice._k1, float).reshape(3)
b2 = np.asarray(lattice._k2, float).reshape(3)
b3 = np.asarray(lattice._k3, float).reshape(3)
#! k-point mesh - either full grid or provided points
if kpoints is None:
# Use fftfreq convention: k_n = n/N for n = 0, 1, ..., N/2-1, -N/2, ..., -1
# This gives k $\in$ [-0.5, 0.5) in fractional coordinates
# which maps to the first Brillouin zone correctly for both even and odd N
frac_x = np.linspace(0, 1, Lx, endpoint=False) # sorted: [0, 1/N, ..., (N/2-1)/N, -N/2/N, ..., -1/N]
frac_y = np.linspace(0, 1, Ly, endpoint=False)
frac_z = np.linspace(0, 1, Lz, endpoint=False)
# Create meshgrid in the fftfreq order
kx_frac, ky_frac, kz_frac = np.meshgrid(frac_x, frac_y, frac_z, indexing="ij")
# Store fractional coordinates (shape: Lx, Ly, Lz, 3)
kgrid_frac = np.stack([kx_frac, ky_frac, kz_frac], axis=-1)
# Construct Cartesian k-vectors: k = f1*b1 + f2*b2 + f3*b3
kgrid = (kx_frac[..., None] * b1 +
ky_frac[..., None] * b2 +
kz_frac[..., None] * b3) # shape (Lx, Ly, Lz, 3)
# Γ point is at index [0,0,0] (fftfreq order)
# Flatten for computation
kpoints = kgrid.reshape(-1, 3) # shape (Nc, 3)
return_grid = True
else:
kpoints = np.asarray(kpoints, float).reshape(-1, 3)
kgrid = None
kgrid_frac = None
return_grid = False
# Determine number of k-points
Nk = kpoints.shape[0]
if require_full_grid and Nk != Nc:
raise ValueError(f"Round-trip requires Nk == Nc == {Nc}, got Nk={Nk}.")
#! geometric labeling: (cell, sub) while keeping order
indices = np.arange(Ns)
sub_idx = indices % Nb
# Ensure basis is padded with zeros for 1D/2D
basis_coords = np.zeros((Nb, 3), dtype=float)
basis_coords[:, :lattice._basis.shape[1]] = np.asarray(lattice._basis, float)
# Total position vector r_i = R_n + τ_a
coords = np.asarray(lattice.coordinates, float) # shape (Ns, 3)
# Projector S[i, a] = delta_{sub(i), a}
S = np.zeros((Ns, Nb), dtype=complex)
S[np.arange(Ns), sub_idx] = 1.0 # shape (Ns, Nb)
# Bloch transform: H_ab(k) = Σ_{i,j} W*_{i,a} H_{i,j} W_{j,b}
# where W_{i,a} = (1/sqrt Nc) * e^{-ik\cdot r_i} * delta_{sub(i),a}
# where r_i = R_cell + τ_sublattice is the full site position
# This accounts for both unit cell position AND basis vector in the phase
phases = np.exp(-1j * (kpoints @ coords.T)) # (Nk, Ns)
norm = np.sqrt(Nc) if unitary_norm else 1.0
phases /= norm
# Vectorized projector: W[ik, i, a] = phases[ik, i] * S[i, a]
W = phases[:, :, None] * S[None, :, :] # (Nk, Ns, Nb)
# Transform: H(k) = W† @ H @ W for all k
# Use einsum for efficiency: (Nk, Nb, Ns) @ (Ns, Ns) @ (Nk, Ns, Nb) -> (Nk, Nb, Nb)
if sp.issparse(H_real):
# For sparse matrices, use loop (einsum doesn't support sparse)
Hk = np.zeros((Nk, Nb, Nb), dtype=complex)
for ik in range(Nk):
Hk[ik] = W[ik].conj().T @ (H_real @ W[ik])
else:
Hk = np.einsum('kia,ij,kjb->kab', W.conj(), H_real, W)
if return_grid:
# Reshape blocks to grid (keep in fftfreq order - NO SHIFT)
# This ensures H_k[ix,iy,iz] corresponds to k_grid[ix,iy,iz]
Hk_grid = Hk.reshape(Lx, Ly, Lz, Nb, Nb)
if return_transform:
W_grid = W.reshape(Lx, Ly, Lz, Ns, Nb)
return Hk_grid, kgrid, kgrid_frac, W_grid
else:
return Hk_grid, kgrid, kgrid_frac
else:
if return_transform:
return Hk, kpoints, W
else:
return Hk, kpoints
# -----------------------------------------------------------------------------------------------------------
#! END OF FILE
# -----------------------------------------------------------------------------------------------------------