"""
Contains the general lattice class hierarchy and helpers.
This module defines the base :class:`Lattice` API used across general_python, together with
utility routines for boundary handling and symmetry metadata.
Currently, up to three spatial dimensions are supported.
------------------------------------------------------------------------------
File : general_python/lattices/lattice.py
Author : Maksymilian Kliczkowski
Date : 2025-02-01
Version : 2.0
------------------------------------------------------------------------------
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Dict, Iterable, List, Literal, Mapping, Optional, Tuple, Union, TYPE_CHECKING
import numpy as np
from .tools.lattice_tools import (
LatticeDirection, LatticeBC, LatticeType,
handle_boundary_conditions, handle_boundary_conditions_detailed, handle_dim
)
from .tools.lattice_kspace import (
bz_path_data, brillouin_zone_path, StandardBZPath,
PathTypes,
reciprocal_from_real, extract_momentum, reconstruct_k_grid_from_blocks,
build_translation_operators,
HighSymmetryPoints, HighSymmetryPoint,
KPathResult, find_nearest_kpoints
)
from .tools.region_handler import LatticeRegionHandler, RegionType
from .tools.lattice_flux import BoundaryFlux
if TYPE_CHECKING:
from ..common.flog import Logger
from .tools.lattice_kspace import KPathSelection
################################################################################
# Array backend used for the lattice's internal buffers (the ``zeros``
# allocations in ``Lattice.__init__``). Currently fixed to NumPy.
Backend = np
[docs]
class Lattice(ABC):
r"""
Abstract Base Class for defining lattice structures.
This class serves as the foundation for all lattice implementations in the `lattices` module.
It handles geometry, connectivity, boundary conditions, and k-space properties.
Indexing Convention
-------------------
Lattice sites are indexed linearly from ``0`` to ``Ns - 1``.
The mapping from spatial coordinates to linear index depends on the concrete implementation,
but typically follows a row-major (lexicographic) order:
* **1D**: Left to right.
* **2D**: Bottom-left to top-right (x varies fastest).
* **3D**: Front-bottom-left to back-top-right.
Features
--------
* **Geometry**: Calculation of real-space coordinates, unit vectors, and basis vectors.
* **Connectivity**: Automatic identification of Nearest Neighbors (NN) and Next-Nearest Neighbors (NNN).
* **Boundaries**: Support for various boundary conditions:
* ``PBC``: Periodic Boundary Conditions (torus topology).
* X-direction periodic, Y-direction periodic, Z-direction periodic
* ``OBC``: Open Boundary Conditions (hard edges).
* X-direction open, Y-direction open, Z-direction open
* ``MBC``: Mixed Boundary Conditions (e.g., cylinder topology).
* X-direction periodic, Y-direction open, Z-direction open
* ``SBC``: Switched Boundary Conditions (a cylinder with the open and periodic directions of ``MBC`` swapped).
* X-direction open, Y-direction periodic, Z-direction open
* **TWISTED**: Twisted Boundary Conditions with specified fluxes.
* **Reciprocal Space**: Automatic calculation of reciprocal lattice vectors and Brillouin Zone paths.
* **Visualization**: Integration with plotting utilities via ``.plot``.
Attributes
----------
Ns : int
Total number of sites in the lattice.
dim : int
Spatial dimension of the lattice (1, 2, or 3).
Lx, Ly, Lz : int
Linear dimensions of the lattice.
bc : LatticeBC
Active boundary condition.
coordinates : np.ndarray
Array of shape ``(Ns, 3)`` containing real-space coordinates of all sites.
nn : List[List[int]]
Adjacency list for nearest neighbors. ``nn[i]`` is a list of neighbors for site ``i``.
"""
# Value exposed by the ``bad_lattice_site`` property; presumably marks an
# invalid / missing site index - NOTE(review): confirm against callers.
_BAD_LATTICE_SITE = None
# ``init()`` builds the dense DFT matrix only when ``Ns`` is below this
# threshold (unless ``force_dft=True`` is passed).
_DFT_LIMIT_SITES = 100
# ---------------------------------------------
#! INITIALIZATION
# ---------------------------------------------
@property
def bad_lattice_site(self):
''' Bad lattice site '''
return self._BAD_LATTICE_SITE
# Lattice constants - physical units where applicable
a = 1
b = 1
c = 1
unit_length = 1 # unit length in Angstroms - helper for physical calculations
[docs]
def __init__(self,
            dim     : Optional[int]         = None,
            lx      : int                   = 1,
            ly      : int                   = 1,
            lz      : int                   = 1,
            bc      : Optional[str]         = None,     # boundary conditions
            adj_mat : Optional[np.ndarray]  = None,     # can be controlled by the user for generic graphs
            flux    : Optional[Union[np.ndarray, Dict[str, float], float]] = None, # flux piercing the boundaries - for each direction - for topological models
            *args,
            **kwargs):
    r'''
    Set up the size, boundary-condition and flux bookkeeping shared by all lattices.

    Only sizes, boundary conditions, flux and empty / zero placeholder buffers
    are created here. Coordinates, neighbor lists, reciprocal vectors and the
    DFT matrix are filled later by :meth:`init`.

    Parameters
    ----------
    dim : int, optional
        Spatial dimension (1, 2 or 3). If None, it is inferred from
        ``lx, ly, lz`` via ``handle_dim``.
    lx, ly, lz : int, optional
        Linear sizes along x, y and z. Their product is the initial ``Ns``;
        subclasses may overwrite ``Ns`` afterwards.
    bc : str, optional
        Boundary conditions (e.g. 'PBC', 'OBC'), normalized by
        ``handle_boundary_conditions``.
    adj_mat : np.ndarray, optional
        User-supplied weighted adjacency matrix for generic graphs. When given,
        :meth:`init` derives nearest / next-nearest neighbors from it (largest /
        second-largest ``|weight|`` per row) instead of calling
        ``calculate_nn`` / ``calculate_nnn``.
    flux : dict, float or np.ndarray, optional
        Flux piercing the boundaries: either a dictionary specifying the flux
        in each direction or a single value applied to all directions. A flux
        implies **TWISTED** boundary conditions, so ``bc`` can be left as None
        or set to 'TWISTED' for clarity.
    *args, **kwargs
        Forwarded to ``super().__init__``. NOTE(review): if the next class in
        the MRO is ``object``, any extra argument raises ``TypeError``.

    Raises
    ------
    ImportError
        If the lattice flux tools cannot be imported.
    '''
    try:
        from .tools.lattice_flux import _normalize_flux_dict
    except ImportError as e:
        raise ImportError("Failed to import lattice flux tools. Ensure that the QES package is correctly installed.") from e

    self._dim               = handle_dim(lx, ly, lz)[0] if dim is None else dim
    self._bc                = handle_boundary_conditions(bc, flux=flux) # Normalize boundary conditions and handle flux

    # flux piercing the boundaries - for topological models
    _raw_flux               = None
    if isinstance(self._bc, tuple):
        # If we have a tuple, it means we have TWISTED BCs with flux information
        self._bc, _raw_flux = self._bc
    # Prefer the flux returned by the BC normalizer; fall back to the user's argument.
    self._flux              = _normalize_flux_dict(_raw_flux if _raw_flux is not None else flux)
    self._raw_flux          = _raw_flux if _raw_flux is not None else flux

    # Sizes and cached products (kept in sync by the lx/ly/lz setters).
    self._lx                = lx
    self._ly                = ly
    self._lz                = lz
    self._lxly              = lx * ly
    self._lxlz              = lx * lz
    self._lylz              = ly * lz
    self._lxlylz            = lx * ly * lz
    self._ns                = lx * ly * lz  # Number of sites - set only initially as it is implemented in the children
    self._type              = LatticeType.SQUARE
    self._adj_mat           = adj_mat

    # Region handler - receives the lattice before the neighbor lists below exist.
    self.regions            = LatticeRegionHandler(self)

    super().__init__(*args, **kwargs)

    # neighbors (placeholders until init())
    self._nn                = [[]]
    self._nn_forward        = [[]]
    self._nn_max_num        = 0
    self._nnn               = [[]]
    self._nnn_forward       = [[]]

    # helping lists
    self._cells             = []    # real space coordinates
    self._fracs             = []    # fractional coordinates
    self._subs              = []    # sub-lattice indices
    self._spatial_norm      = [[[]]] # three dimensional array for the spatial norm
    # NOTE(review): ``self._coordinates`` (read by the ``coordinates`` property) is not
    # set here; presumably ``calculate_coordinates`` creates it - confirm.

    # matrices for real space and inverse space vectors
    vec_size                = max(3, self._dim)  # at least 3 so 1D/2D lattices still store 3D vectors
    self._vectors           = Backend.zeros((vec_size, vec_size))   # real space vectors - base vectors of the lattice
    self._a1                = Backend.zeros(vec_size)               # real space vectors - base vectors of the lattice
    self._a2                = Backend.zeros(vec_size)
    self._a3                = Backend.zeros(vec_size)
    self._basis             = Backend.zeros((0, vec_size))          # basis vectors within the unit cell (empty -> multipartity 0)

    # inverse space vectors - ALWAYS 3D vectors for consistency
    self._k1                = Backend.zeros(3)  # inverse space vectors - reciprocal lattice vectors (3D)
    self._k2                = Backend.zeros(3)
    self._k3                = Backend.zeros(3)

    # normal vectors (along the bonds - if required); init() replaces these with unit 3-vectors
    self._n1                = Backend.zeros((self._dim, self._dim ))# normal vectors along the bonds
    self._n2                = Backend.zeros((self._dim, self._dim ))
    self._n3                = Backend.zeros((self._dim, self._dim ))

    # nearest neighbors vectors of the cells
    self._delta_z           = np.array([0.0, 0.0, self.a])      # UP
    self._delta_x           = np.array([self.a, 0.0, 0.0])      # RIGHT
    self._delta_y           = np.array([0.0, self.a, 0.0])      # FRONT

    # bonds
    self._bonds             = []                                # empty if not calculated...
    self._rvectors          = Backend.zeros((self._ns, 3))      # allowed values of the real space vectors
    self._kvectors          = Backend.zeros((self._ns, 3))      # allowed values of the inverse space vectors

    # initialize dft matrix (placeholder sized with the initial Ns)
    self._dft               = Backend.zeros((self._ns, self._ns), dtype = complex) # Discrete Fourier Transform matrix for the lattice model

    # symmetries for momenta (if one uses the symmetry,
    # returning to original one may result in using normalization)
    self.sym_norm           = {}
    self.sym_map            = {}
# -----------------------------------------------------------------------------
[docs]
def __str__(self):
''' String representation of the lattice '''
return "General Lattice"
[docs]
def __repr__(self):
    '''
    Compact one-line summary of the lattice.

    Format: ``TYPE,BC[,flux=...],d=...,Ns=...,Lx=...,Ly=...,Lz=...`` where the
    flux field appears only for TWISTED boundary conditions.
    '''
    twisted = self._bc is LatticeBC.TWISTED
    fields  = [f"{self._type.name}", f"{self._bc.name}"]
    if twisted:
        fields.append(f"flux={self._flux}")
    fields.extend((
        f"d={self._dim}",
        f"Ns={self._ns}",
        f"Lx={self._lx}",
        f"Ly={self._ly}",
        f"Lz={self._lz}",
    ))
    return ",".join(fields)
@property
def _flux_suffix(self) -> str:
"""Return a suffix string for ``__str__`` / ``__repr__`` that includes flux info."""
if self._flux is not None and self._flux.is_nontrivial:
return f",flux={self._flux}"
return ""
[docs]
def __len__(self):
''' Length of the lattice (number of sites) '''
return self._ns
[docs]
def __getitem__(self, index: int):
''' Get the site at the given index '''
if index < 0 or index >= self._ns:
raise IndexError("Lattice index out of range")
return index
[docs]
def __iter__(self):
''' Iterate over the lattice sites '''
for i in range(self._ns):
yield i
[docs]
def __contains__(self, item: int):
''' Check if the lattice contains the given site '''
return 0 <= item < self._ns
# -----------------------------------------------------------------------------
[docs]
def init(self, verbose: bool = False, *, force_dft: bool = False, **kwargs):
"""
Initializes the lattice object by calculating coordinates, reciprocal vectors, and neighbor lists.
This method performs the following steps:
1. Calculates the real-space coordinates, r-vectors, and k-vectors of the lattice.
2. If the number of sites (`self.Ns`) is less than 100, computes the discrete Fourier transform (DFT) matrix.
3. If an adjacency matrix (`self._adj_mat`) is provided:
- Determines the number of sites (`Ns`) from the adjacency matrix.
- For each site, identifies nearest neighbors (nn) as those connected by the highest weight in the adjacency matrix, and next-nearest neighbors (nnn) as those connected by the next highest distinct weight.
- Stores forward neighbors (indices greater than the current site) for both nn and nnn.
4. If no adjacency matrix is provided, calculates nearest and next-nearest neighbors using default methods.
5. Calculates normalization or symmetry properties of the lattice.
This method sets up all necessary neighbor lists and lattice properties required for further computations.
"""
self.calculate_reciprocal_vectors()
self.calculate_coordinates()
if verbose: print(" Lattice: Calculated coordinates.")
self.calculate_r_vectors()
if verbose: print(" Lattice: Calculated r-vectors.")
self.calculate_k_vectors()
if verbose: print(" Lattice: Calculated k-vectors.")
if self.Ns < Lattice._DFT_LIMIT_SITES or force_dft:
self.calculate_dft_matrix()
if verbose: print(" Lattice: Calculated DFT matrix.")
if self._adj_mat is not None:
Ns = self._adj_mat.shape[0]
self._ns = Ns
W = self._adj_mat
nn_list = []
nnn_list = []
for i in range(Ns):
#! sort by |weight| so signed couplings keep topology semantics
js = [j for j in range(Ns) if j != i and W[i, j] != 0]
sorted_js = sorted(js, key=lambda j: abs(W[i, j]), reverse=True)
if not sorted_js:
nn_list.append([])
nnn_list.append([])
continue
#! highest |weight| defines nn
max_w_abs = abs(W[i, sorted_js[0]])
nn_js = [j for j in sorted_js if abs(W[i, j]) == max_w_abs]
nn_list.append(nn_js)
if len(sorted_js) > len(nn_js):
#! find next distinct |weight|
remaining = [abs(W[i, j]) for j in sorted_js if abs(W[i, j]) != max_w_abs]
if remaining:
second_w = max(remaining)
nnn_js = [j for j in sorted_js if abs(W[i, j]) == second_w]
else:
nnn_js = []
else:
nnn_js = []
nnn_list.append(nnn_js)
self._nn = nn_list
self._nn_forward = [[j for j in nn_list[i] if j>i] for i in range(Ns)]
self._nnn = nnn_list
self._nnn_forward = [[j for j in nnn_list[i] if j>i] for i in range(Ns)]
if verbose: print(" Lattice: Calculated neighbors from adjacency matrix.")
if verbose: print(" Lattice: Calculated forward neighbors from adjacency matrix.")
if verbose: print(" Lattice: Calculated next-nearest neighbors from adjacency matrix.")
if verbose: print(" Lattice: Calculated forward next-nearest neighbors from adjacency matrix.")
else:
self.calculate_nn()
if verbose: print(" Lattice: Calculated nearest neighbors.")
self.calculate_nnn()
if verbose: print(" Lattice: Calculated next-nearest neighbors.")
self.calculate_norm_sym()
if verbose: print(" Lattice: Calculated normalization/symmetry.")
# Initialize the normal vectors along the bonds
self._n1 = self._delta_x / np.linalg.norm(self._delta_x)
self._n2 = self._delta_y / np.linalg.norm(self._delta_y)
self._n3 = self._delta_z / np.linalg.norm(self._delta_z)
# -----------------------------------------------------------------------------
#! Region generators
# -----------------------------------------------------------------------------
[docs]
def get_region(
self,
kind : Union[str, RegionType] = RegionType.HALF,
*,
origin : Optional[Union[int, List[float]]] = None,
radius : Optional[float] = None,
direction : Optional[str] = None,
sublattice : Optional[int] = None,
sites : Optional[List[int]] = None,
depth : Optional[int] = None,
plaquettes : Optional[List[int]] = None,
**kwargs
) -> List[int]:
r"""
Return a list of site indices defining a spatial region.
Parameters
----------
kind : str or RegionType
Type of region: 'half', 'disk', 'sublattice', 'graph', 'plaquette', 'custom'.
We also support specific half cuts like 'half_x', 'half_y', 'half_z' for convenience.
origin : int or list[float], optional
Center of the region. Can be a site index or coordinate vector.
radius : float, optional
Radius for 'disk' regions.
direction : str, optional
Direction for 'half' cuts ('x', 'y', 'z').
sublattice : int, optional
Sublattice index for 'sublattice' regions.
sites : list[int], optional
Explicit list of sites for 'custom' regions.
depth : int, optional
Depth/distance for 'graph' regions.
plaquettes : list[int], optional
List of plaquette indices for 'plaquette' regions.
Returns
-------
list[int]
Sorted list of site indices belonging to the region.
"""
return self.regions.get_region(
kind=kind,
origin=origin,
radius=radius,
direction=direction,
sublattice=sublattice,
sites=sites,
depth=depth,
plaquettes=plaquettes,
**kwargs
)
[docs]
def get_entropy_cuts(self, cut_type: str = "all", *, include_sublattice: bool = True, sweep_by_unit_cell: Optional[bool] = None) -> Dict[str, List[int]]:
"""
Return canonical bipartition cuts for entanglement-entropy workflows.
This is a convenience wrapper around :meth:`self.regions.get_entropy_cuts`.
"""
return self.regions.get_entropy_cuts(cut_type=cut_type, include_sublattice=include_sublattice, sweep_by_unit_cell=sweep_by_unit_cell)
[docs]
def generate_regions(self, kind: Union[str, RegionType] = RegionType.KITAEV_PRESKILL, **kwargs,):
"""
Generate many region candidates for a selected region type.
This is a thin wrapper around :meth:`self.regions.generate_regions`.
"""
return self.regions.generate_regions(kind=kind, **kwargs)
################################### GETTERS ###################################
@property
def lx(self): return self._lx
@property
def Lx(self): return self._lx
@lx.setter
def lx(self, value): self._lx = value; self._lxly = self._lx * self._ly; self._lxlylz = self._lxly * self._lz; self._lxlz = self._lx * self._lz
@Lx.setter
def Lx(self, value): self._lx = value; self._lxly = self._lx * self._ly; self._lxlylz = self._lxly * self._lz; self._lxlz = self._lx * self._lz
@property
def ly(self): return self._ly
@property
def Ly(self): return self._ly
@ly.setter
def ly(self, value): self._ly = value; self._lxly = self._lx * self._ly; self._lxlylz = self._lxly * self._lz; self._lylz = self._ly * self._lz
@Ly.setter
def Ly(self, value): self._ly = value; self._lxly = self._lx * self._ly; self._lxlylz = self._lxly * self._lz; self._lylz = self._ly * self._lz
@property
def lz(self): return self._lz
@property
def Lz(self): return self._lz
@lz.setter
def lz(self, value): self._lz = value; self._lxlylz = self._lxly * self._lz; self._lylz = self._ly * self._lz; self._lxlz = self._lx * self._lz
@Lz.setter
def Lz(self, value): self._lz = value; self._lxlylz = self._lxly * self._lz; self._lylz = self._ly * self._lz; self._lxlz = self._lx * self._lz
@property
def area(self): return self._lxly
@property
def volume(self): return self._lxlylz
@property
def lxly(self): return self._lxly
@property
def lxlz(self): return self._lxlz
@property
def lylz(self): return self._lylz
@property
def lxlylz(self): return self._lxlylz
@property
def dim(self): return self._dim
@dim.setter
def dim(self, value): self._dim = value
@property
def ns(self): return self._ns
@property
def Ns(self): return self._ns
@property
def sites(self): return self._ns
@property
def size(self): return self._ns
@property
def nsites(self): return self._ns
@ns.setter
def ns(self, value): self._ns = value
@Ns.setter
def Ns(self, value): self._ns = value
# -----------------------------------------------------------------------------
#! Physical
# -----------------------------------------------------------------------------
@property
def sites_per_cell(self) -> int:
"""Sites per unit cell (1 for Bravais, 2 for honeycomb, etc.)."""
n_cells = max(1, self._lx * self._ly * self._lz)
return max(1, self._ns // n_cells)
[docs]
def symmetry_perms(self, point_group: str = "full") -> np.ndarray:
"""
Generate space-group permutation table for this lattice.
Delegates to :func:`~.tools.lattice_symmetry.generate_space_group_perms`.
When TWISTED boundary conditions are active, the point-group part is
disabled (only translations are returned) because a generic flux
breaks point-group symmetry unless the flux respects it.
Parameters
----------
point_group : str
``'full'`` for maximal point group, ``'translations'`` for translations only.
Returns
-------
ndarray, shape (|G|, Ns)
"""
from .tools.lattice_symmetry import generate_space_group_perms
# Flux generically breaks point-group symmetry → translations only
if self.is_twisted and point_group == "full":
point_group = "translations"
return generate_space_group_perms(self.Lx, self.Ly, self.sites_per_cell, point_group)
# ------------------------------------------------------------------
#! Lattice symmetry information
# ------------------------------------------------------------------
[docs]
def lattice_symmetries(self) -> Dict[str, object]:
"""
Return a dictionary describing the spatial symmetries of this lattice.
The information is consistent for both single-particle and many-body
representations. When TWISTED boundary conditions are present the
point-group part is absent (flux generically breaks it).
Returns
-------
dict
Keys:
- ``'lattice_type'`` : :class:`LatticeType` enum
- ``'sites_per_cell'`` : int
- ``'n_cells'`` : number of unit cells
- ``'dim'`` : spatial dimension
- ``'bc'`` : boundary condition enum
- ``'is_periodic'`` : (bool, bool, bool) per direction
- ``'is_twisted'`` : bool
- ``'translation_group'`` : ZL_x x ZL_y (as tuple ``(Lx, Ly)``)
- ``'point_group'`` : str or None (``'D4'`` for square Lx==Ly, etc.)
- ``'space_group_order'`` : total number of space-group elements
- ``'flux'`` : :class:`BoundaryFlux` or None
"""
# Get periodicity flags for each direction
pbc_flags = self.periodic_flags()
# Determine point group
pg = None
if not self.is_twisted:
if hasattr(self, '_type') and self._type == LatticeType.SQUARE:
if self._lx == self._ly and pbc_flags[0] and pbc_flags[1]:
pg = 'D4'
elif pbc_flags[0] and pbc_flags[1]:
pg = 'D2'
elif hasattr(self, '_type') and self._type in (LatticeType.HONEYCOMB, LatticeType.HEXAGONAL):
if self._lx == self._ly and pbc_flags[0] and pbc_flags[1]:
pg = 'C6v' # full hexagonal point group for the lattice
n_cells = max(1, self._lx * self._ly * self._lz)
n_trans = self._lx * (self._ly if pbc_flags[1] else 1) * (self._lz if pbc_flags[2] else 1) if pbc_flags[0] else 1
pg_order = {'D4': 8, 'D2': 4, 'C6v': 12}.get(pg, 1) if pg else 1
return {
'lattice_type': self._type if hasattr(self, '_type') else None,
'sites_per_cell': self.sites_per_cell,
'n_cells': n_cells,
'dim': self._dim,
'bc': self._bc,
'is_periodic': pbc_flags,
'is_twisted': self.is_twisted,
'translation_group': (self._lx, self._ly if self._dim >= 2 else 1),
'point_group': pg,
'space_group_order': n_trans * pg_order,
'flux': self._flux,
}
[docs]
def symmetry_info(self) -> str:
"""
Return a human-readable summary of the lattice symmetries.
Consistent for both single-particle (band-structure / Bloch) and
many-body (Hilbert-space symmetry sectors) viewpoints.
Returns
-------
str
"""
d = self.lattice_symmetries()
lines = [
f"Lattice symmetry info ({d['lattice_type']})",
f" dim = {d['dim']}",
f" sites / cell = {d['sites_per_cell']}",
f" unit cells = {d['n_cells']}",
f" boundary cond. = {d['bc']}",
f" periodic (x,y,z) = {d['is_periodic']}",
]
if d['is_twisted']:
lines.append(f" twisted = True (flux breaks point-group!)")
lines.append(f" flux = {d['flux']}")
lines.append(f" translation group = Z_{d['translation_group'][0]} x Z_{d['translation_group'][1]}")
lines.append(f" point group = {d['point_group'] or 'trivial'}")
lines.append(f" |space group| = {d['space_group_order']}")
return "\n".join(lines)
@property
def a1(self): return self._a1
@a1.setter
def a1(self, value): self._a1 = value
@property
def a2(self): return self._a2
@a2.setter
def a2(self, value): self._a2 = value
@property
def a3(self): return self._a3
@a3.setter
def a3(self, value): self._a3 = value
# -----------------------------------------------------------------------------
# Inverse space vectors
# -----------------------------------------------------------------------------
@property
def k1(self): return self._k1
@k1.setter
def k1(self, value): self._k1 = value
@property
def b1(self): return self._k1
@b1.setter
def b1(self, value): self._k1 = value
@property
def k2(self): return self._k2
@k2.setter
def k2(self, value): self._k2 = value
@property
def b2(self): return self._k2
@b2.setter
def b2(self, value): self._k2 = value
@property
def k3(self): return self._k3
@k3.setter
def k3(self, value): self._k3 = value
@property
def b3(self): return self._k3
@b3.setter
def b3(self, value): self._k3 = value
# ------------------------------------------------------------------
@property
def n1(self): return self._n1
@n1.setter
def n1(self, value): self._n1 = value
@property
def n2(self): return self._n2
@n2.setter
def n2(self, value): self._n2 = value
@property
def n3(self): return self._n3
@n3.setter
def n3(self, value): self._n3 = value
@property
def basis(self): return self._basis
@basis.setter
def basis(self, value): self._basis = value
@property
def multipartity(self): return self._basis.shape[0]
@property
def vectors(self): return self._vectors
@vectors.setter
def vectors(self, value): self._vectors = value
@property
def avec(self): return np.stack((self._a1, self._a2, self._a3), axis=0)
@avec.setter
def avec(self, value): self._a1 = value[0]; self._a2 = value[1]; self._a3 = value[2]
@property
def bvec(self): return np.stack((self._k1, self._k2, self._k3), axis=0)
@bvec.setter
def bvec(self, value): self._k1 = value[0]; self._k2 = value[1]; self._k3 = value[2]
# ------------------------------------------------------------------
#! DFT Matrix
# ------------------------------------------------------------------
@property
def dft(self):
''' Return the discrete Fourier transform (DFT) matrix for the lattice. '''
return self._dft
@dft.setter
def dft(self, value): self._dft = value
@property
def nn(self):
''' Return the nearest-neighbor connectivity matrix for the lattice. '''
return self._nn
@nn.setter
def nn(self, value): self._nn = value
@property
def bonds(self):
''' Return the bond connectivity matrix for the lattice. '''
return self._bonds
@bonds.setter
def bonds(self, value): self._bonds = value
@property
def nn_forward(self):
''' Return the forward nearest-neighbor connectivity matrix for the lattice. '''
return self._nn_forward
@nn_forward.setter
def nn_forward(self, value): self._nn_forward = value
@property
def nnn(self):
''' Return the next-nearest-neighbor connectivity matrix for the lattice. '''
return self._nnn
@nnn.setter
def nnn(self, value): self._nnn = value
@property
def nnn_forward(self):
''' Return the forward next-nearest-neighbor connectivity matrix for the lattice. '''
return self._nnn_forward
@nnn_forward.setter
def nnn_forward(self, value): self._nnn_forward = value
@property
def coordinates(self):
''' Return the real-space coordinates of the lattice sites. '''
return self._coordinates
@coordinates.setter
def coordinates(self, value): self._coordinates = value
@property
def subs(self):
''' Return the sublattice indices of the lattice sites.
For a Bravais lattice, this would simply be an array of zeros.
For a non-Bravais lattice, this would indicate which sublattice each site belongs to. '''
return self._subs
@subs.setter
def subs(self, value): self._subs = value
@property
def cells(self):
''' Return the unit cell coordinates of the lattice sites. For a Bravais lattice,
this would simply be the integer coordinates of the unit cells.
For a non-Bravais lattice, this would include the basis vectors as well. '''
return self._cells
@cells.setter
def cells(self, value): self._cells = value
@property
def fracs(self):
''' Return fractional coordinates of the lattice sites. Example: for a square lattice, these would be (x/Lx, y/Ly, z/Lz) for each site. '''
return self._fracs
@fracs.setter
def fracs(self, value): self._fracs = value
@property
def kvectors(self):
''' Return the allowed k-vectors in reciprocal space for the lattice. '''
return self._kvectors
@kvectors.setter
def kvectors(self, value): self._kvectors = value
@property
def rvectors(self):
''' Return the allowed r-vectors in real space for the lattice. '''
return self._rvectors
@rvectors.setter
def rvectors(self, value): self._rvectors = value
@property
def bc(self): return self._bc
@bc.setter
def bc(self, value): self._bc = value
@property
def bc_x(self): return handle_boundary_conditions_detailed(self._bc, self._raw_flux).get('x', False)
@property
def bc_y(self): return handle_boundary_conditions_detailed(self._bc, self._raw_flux).get('y', False)
@property
def bc_z(self): return handle_boundary_conditions_detailed(self._bc, self._raw_flux).get('z', False)
@property
def cardinality(self): return self.get_nn_forward_num_max()
@cardinality.setter
def cardinality(self, value): self._nn_max_num = value
@property
def flux(self): return self._flux
@flux.setter
def flux(self, flux: Union['BoundaryFlux', Dict[str, float], None]):
'''
Set the flux piercing the boundaries for twisted boundary conditions.
'''
try:
from .tools.lattice_flux import _normalize_flux_dict
except ImportError:
raise ImportError("Setting flux requires the lattice_flux module. Please ensure it is available.")
self._bc = handle_boundary_conditions(self._bc, flux=flux) # Normalize boundary conditions and handle flux
# flux piercing the boundaries - for topological models
_raw_flux = None
if isinstance(self._bc, tuple):
# If we have a tuple, it means we have TWISTED BCs with flux information
self._bc, _raw_flux = self._bc
self._flux = _normalize_flux_dict(_raw_flux if _raw_flux is not None else flux)
self._raw_flux = _raw_flux if _raw_flux is not None else flux
@property
def name(self): return self.__str__()
@property
def type(self): return self._type if hasattr(self, '_type') else None
# ------------------------------------------------------------------
#! Sublattice
# ------------------------------------------------------------------
[docs]
def sublattice(self, site: int) -> int:
"""
Return the sublattice index for a given site.
By default, returns 0 for all sites (single sublattice).
Override in subclasses for multi-sublattice lattices.
"""
return site % self.multipartity
# ------------------------------------------------------------------
#! K-space
# ------------------------------------------------------------------
[docs]
def k_vector(self, qx, qy=0.0, qz=0.0) -> np.ndarray:
"""
Return the k-vector in Cartesian coordinates for given (qx, qy, qz)
in reciprocal lattice units.
"""
if self.k1 == None or self.k2 == None or self.k3 == None:
self.k1, self.k2, self.k3 = reciprocal_from_real(self.a1, self.a2, self.a3)
kvec = qx * self.k1[0,:]
if self.dim > 1:
kvec += qy * self.k2[0,:]
if self.dim > 2:
kvec += qz * self.k3[0,:]
return kvec
[docs]
def k_grid(self, n_k: Union[int, Tuple[int, int, int]], shift: Optional[Union[bool, Tuple[bool, bool, bool]]] = None) -> np.ndarray:
r"""
Generate a full k-point grid for the given lattice.
Parameters
----------
lattice : Lattice
Lattice object with reciprocal lattice vectors _k1, _k2, _k3.
n_k : Iterable[int]
Number of points (Lx, Ly, Lz) along each reciprocal direction.
We define the k-points as:
k = f1 * b1 + f2 * b2 + f3 * b3,
where f_i = n_i / N_i, with n_i = 0, 1, ..., N_i - 1.
Returns
-------
k_points : np.ndarray, shape (Nk, dim)
Cartesian coordinates of k-points in reciprocal space.
"""
try:
from .tools.lattice_kspace import generate_k_grid
except ImportError:
raise ImportError("k_grid requires the lattice_kspace module. Please ensure it is available.")
return generate_k_grid(lattice=self, n_k=n_k, shift=shift)
[docs]
def wigner_seitz_extend(self, k_points: np.ndarray, data: Optional[np.ndarray] = None, *, copies: Optional[Union[int, Iterable[int]]] = None, **kwargs) -> Tuple[np.ndarray, Optional[np.ndarray]]:
r'''
Extend k-space points and optional data across translated Brillouin zones.
The helper works for arbitrary k-space dimensions and any number of
reciprocal translation vectors. Legacy ``b1``/``b2``/``b3`` with
``nx``/``ny``/``nz`` remain supported for existing callers.
Allows to generate extended k-point grids for plotting band structures along high-symmetry paths...
Parameters
----------
k_points : ndarray, shape (N, dim)
Array of k-points in reciprocal space to be extended.
data : ndarray, shape (N, ...) or None
Optional data associated with each k-point (e.g. eigenvalues) to be extended alongside the k-points. Must have the same leading dimension as k_points.
copies : int or iterable of ints, optional
Number of translated copies to generate in each reciprocal direction. If an integer is provided, the same number of copies will be generated in all directions. If an iterable is
provided, it should have a length equal to the number of reciprocal lattice vectors (e.g. 3 for 3D), specifying the number of copies in each direction separately.
**kwargs
Additional keyword arguments to pass to the underlying ws_extend function. See its documentation for details.
Returns
-------
extended_k_points : ndarray, shape (M, dim)
Array of extended k-points in reciprocal space, including the original points and their translated copies
extended_data : ndarray, shape (M, ...) or None
Extended data associated with each k-point, if the input data was provided. Otherwise, None
'''
try:
from .tools.lattice_kspace import extend_kspace_data
except ImportError:
raise ImportError("wigner_seitz_extend requires the lattice_kspace module. Please ensure it is available.")
return extend_kspace_data(k_points=k_points, data=data, lattice=self, copies=copies, **kwargs)
[docs]
def wigner_seitz_mask(self, Kx, Ky=None, Kz=None, *, shells: int = 1, tol: float = 1e-12, **kwargs) -> np.ndarray:
    """
    Boolean mask flagging which k-points lie inside the Wigner-Seitz cell.

    Delegates to ``tools.lattice_kspace.ws_bz_mask`` with this lattice
    attached, so it can be used to select the points of a reciprocal-space
    grid that belong to the first Brillouin zone.

    Parameters
    ----------
    Kx, Ky, Kz : array-like
        Reciprocal-space coordinates of the k-points to classify. ``Ky`` and
        ``Kz`` default to ``None`` and are forwarded unchanged.
    shells : int, default=1
        Number of Wigner-Seitz shells to take into account (forwarded as
        ``shells``).
    tol : float, default=1e-12
        Numerical tolerance used for the inside/outside decision.
    **kwargs
        Extra options forwarded verbatim to ``ws_bz_mask``.

    Returns
    -------
    np.ndarray
        The mask produced by ``ws_bz_mask``.

    Raises
    ------
    ImportError
        If ``tools.lattice_kspace`` cannot be imported.
    """
    try:
        from .tools.lattice_kspace import ws_bz_mask as _mask_impl
    except ImportError:
        raise ImportError("wigner_seitz_mask requires the lattice_kspace module. Please ensure it is available.")
    grid = dict(KX=Kx, KY=Ky, KZ=Kz)
    return _mask_impl(**grid, shells=shells, tol=tol, lattice=self, **kwargs)
[docs]
def wigner_seitz_shifts(self, *, copies: Optional[Union[int, Iterable[int]]] = None,
                        include_origin: bool = False, tol: float = 1e-12, **kwargs) -> np.ndarray:
    """
    Reciprocal-lattice translation vectors that place copies of the Brillouin zone.

    Only zone-centre shifts are returned (no extended k-mesh). These are
    the geometric building blocks for selecting or drawing translated
    Brillouin zones.

    Parameters
    ----------
    copies : int or iterable of int, optional
        How many translated copies to generate along each reciprocal
        direction.
    include_origin : bool, default=False
        If True, the central zone (at ``Gamma``) is included as well.
    tol : float, default=1e-12
        Tolerance used to discard numerically duplicated shifts.
    **kwargs
        Forwarded verbatim to ``tools.lattice_kspace.ws_bz_shifts``.

    Returns
    -------
    np.ndarray
        Reciprocal-space translation vectors, one per zone copy.

    Raises
    ------
    ImportError
        If ``tools.lattice_kspace`` cannot be imported.
    """
    try:
        from .tools.lattice_kspace import ws_bz_shifts as _shifts_impl
    except ImportError:
        raise ImportError("wigner_seitz_shifts requires the lattice_kspace module. Please ensure it is available.")
    options = dict(copies=copies, include_origin=include_origin, tol=tol)
    return _shifts_impl(lattice=self, **options, **kwargs)
# ------------------------------------------------------------------
#! High-symmetry points and BZ paths
# ------------------------------------------------------------------
[docs]
def high_symmetry_points(self) -> Optional[HighSymmetryPoints]:
"""
Return high-symmetry points for this lattice type.
Override in subclasses to provide lattice-specific high-symmetry points.
Returns None if not defined for this lattice type.
Returns
-------
HighSymmetryPoints or None
High-symmetry points with default path, or None if not defined.
Example
-------
>>> lattice = SquareLattice(dim=2, lx=4, ly=4)
>>> pts = lattice.high_symmetry_points()
>>> print(pts.Gamma.frac_coords) # (0.0, 0.0, 0.0)
>>> print(pts.default_path()) # ['Gamma', 'X', 'M', 'Gamma']
"""
# Base implementation tries to guess from lattice type
if hasattr(self, '_type'):
if self._type == LatticeType.SQUARE:
if self.dim == 1:
return HighSymmetryPoints.chain_1d()
elif self.dim == 2:
return HighSymmetryPoints.square_2d()
elif self.dim == 3:
return HighSymmetryPoints.cubic_3d()
elif self._type == LatticeType.HONEYCOMB:
return HighSymmetryPoints.honeycomb_2d()
elif self._type == LatticeType.HEXAGONAL:
return HighSymmetryPoints.hexagonal_2d()
return None
[docs]
def default_bz_path(self) -> Optional[List[Tuple[str, List[float]]]]:
    """
    Default Brillouin-zone path taken from :meth:`high_symmetry_points`.

    Returns
    -------
    List[Tuple[str, List[float]]] or None
        ``(label, [f1, f2, f3])`` pairs from
        ``HighSymmetryPoints.get_default_path_points()``, or ``None`` when no
        high-symmetry table is available for this lattice.
    """
    points = self.high_symmetry_points()
    return None if points is None else points.get_default_path_points()
[docs]
def default_resolve_path(self, path: Iterable[tuple[str, Iterable[float]]] | StandardBZPath | str | List[str] | HighSymmetryPoints) -> List[Tuple[str, List[float]]]:
    """
    Resolve a path specification into ``(label, fractional_coord)`` pairs.

    Thin wrapper around ``tools.lattice_kspace.resolve_path_input``; this
    lattice is passed along so that point labels can be looked up.

    Parameters
    ----------
    path : list[(label, coords)], StandardBZPath, str, List[str], or HighSymmetryPoints
        Path definition: explicit fractional coordinates, a standard enum
        member, an enum-name string, a list of point labels (resolved via
        this lattice), or a HighSymmetryPoints object.

    Returns
    -------
    resolved_path : list[(label, list[float])]
        The path as a list of ``(label, fractional_coord)`` pairs.

    Raises
    ------
    ImportError
        If ``tools.lattice_kspace`` cannot be imported.

    Example
    -------
    >>> path = lattice.default_resolve_path("SQUARE_2D")
    >>> for label, coord in path:
    ...     print(f"{label}: {coord}")
    """
    try:
        from .tools.lattice_kspace import resolve_path_input
    except ImportError as exc:
        # The helper lives in lattice_kspace; name that module so the message is actionable.
        raise ImportError("default_resolve_path requires the lattice_kspace module. Please ensure it is available.") from exc
    return resolve_path_input(path, lattice=self)
[docs]
def contains_special_point(self, point: Union[str, HighSymmetryPoint, Tuple[float, ...], np.ndarray], *, tol: float = 1e-12) -> bool:
    r"""
    Return ``True`` if the lattice momentum grid contains a special point.

    Useful to check whether a finite lattice samples a particular
    high-symmetry point of the Brillouin zone (relevant for band-structure
    and topological analyses).

    Parameters
    ----------
    point
        Special point identifier. Accepted forms:
        - label string (e.g. ``"Gamma"``, ``"K"``, ``"K'"``), looked up via
          :meth:`high_symmetry_points`,
        - :class:`HighSymmetryPoint` (or any object exposing ``frac_coords``),
        - explicit fractional coordinate tuple/array (padded with zeros to 3).
    tol : float
        Absolute tolerance per fractional coordinate.

    Returns
    -------
    bool
        ``False`` when the k-grid is unavailable, the label cannot be
        resolved, or the coordinates cannot be parsed.

    Notes
    -----
    The comparison is done in *fractional* reciprocal coordinates modulo
    reciprocal-lattice translations: each coordinate difference is reduced
    to its nearest periodic image before the tolerance is applied, so points
    on either side of the 0/1 seam (e.g. ``-1e-16`` vs ``0``) match. Because
    ``self.kvectors_frac`` is used, flux-induced shifts from twisted boundary
    conditions are included automatically. A k-grid stored with several
    leading axes, ``(..., d)``, is flattened to ``(Nk, d)``.
    """
    # Fractional k-vectors; compute them lazily if they are not cached yet.
    kfrac = getattr(self, "kvectors_frac", None)
    if kfrac is None:
        self.calculate_k_vectors()
        kfrac = getattr(self, "kvectors_frac", None)
    if kfrac is None:
        return False

    # Resolve the requested point to fractional reciprocal coordinates.
    if isinstance(point, str):
        hs_pts = self.high_symmetry_points()
        if hs_pts is None:
            return False
        p_obj = hs_pts.resolve(point) if hasattr(hs_pts, "resolve") else hs_pts.get(point)
        if p_obj is None:
            return False
        target_frac = np.asarray(p_obj.frac_coords, dtype=float).reshape(-1)
    elif hasattr(point, "frac_coords"):
        # HighSymmetryPoint (or any point-like object carrying fractional coordinates)
        target_frac = np.asarray(point.frac_coords, dtype=float).reshape(-1)
    else:
        try:
            target_frac = np.asarray(point, dtype=float).reshape(-1)
        except (TypeError, ValueError):
            return False

    if target_frac.size == 0:
        return False
    if target_frac.size < 3:
        target_frac = np.pad(target_frac, (0, 3 - target_frac.size), mode="constant")

    dim = 1 if self.dim == 1 else (2 if self.dim == 2 else 3)
    grid = np.asarray(kfrac, dtype=float)
    if grid.ndim > 2:
        # k-grid stored as (..., d): collapse all leading grid axes into one.
        grid = grid.reshape(-1, grid.shape[-1])
    if grid.ndim != 2 or grid.shape[1] < dim:
        return False

    # Minimum-image difference: subtract the nearest integer (reciprocal-lattice
    # translation) so 0.999999... and -1e-16 both compare as close to 0.
    diff = grid[:, :dim] - target_frac[None, :dim]
    diff -= np.round(diff)
    hits = np.all(np.abs(diff) <= tol, axis=1)
    return bool(np.any(hits))
[docs]
def bz_path(self, path: Optional[Union[List[str], str, StandardBZPath]] = None, *, points_per_seg: int = 40) -> Tuple[np.ndarray, np.ndarray, List[Tuple[int, str]], np.ndarray]:
    """
    Sample k-points along a Brillouin-zone path.

    Parameters
    ----------
    path : list of str, str, StandardBZPath, or None
        - ``None``: the lattice's default path (:meth:`default_bz_path`).
        - list of point labels, e.g. ``['Gamma', 'X', 'M', 'Gamma']``:
          looked up in :meth:`high_symmetry_points`.
        - anything else (StandardBZPath member, enum-name string such as
          ``'SQUARE_2D'``, explicit ``(label, coords)`` list): passed through
          to ``brillouin_zone_path`` unchanged.
    points_per_seg : int
        Number of interpolated points per path segment.

    Returns
    -------
    k_path : np.ndarray, shape (Npath, 3)
        Cartesian k-points along the path.
    k_dist : np.ndarray, shape (Npath,)
        Cumulative distance (x-axis for plots).
    labels : List[Tuple[int, str]]
        Indices and labels of the high-symmetry points.
    k_path_frac : np.ndarray, shape (Npath, 3)
        Fractional k-coordinates along the path.

    Raises
    ------
    ValueError
        If no default path exists (``path=None``) or labels cannot be
        resolved because the lattice has no high-symmetry table.

    Example
    -------
    >>> lattice = SquareLattice(dim=2, lx=4, ly=4)
    >>> k_path, k_dist, labels, k_frac = lattice.bz_path()
    >>> k_path, k_dist, labels, k_frac = lattice.bz_path(['Gamma', 'M', 'Gamma'])
    """
    owner = type(self).__name__
    if path is None:
        route = self.default_bz_path()
        if route is None:
            raise ValueError(f"No default BZ path for {owner}. Specify path explicitly.")
    elif isinstance(path, list) and all(isinstance(name, str) for name in path):
        table = self.high_symmetry_points()
        if table is None:
            raise ValueError(f"Cannot resolve point names for {owner}. Use explicit fractional coordinates instead.")
        route = table.get_path_points(path)
    else:
        # Enum members, enum-name strings and explicit coordinate lists are
        # all resolved downstream by brillouin_zone_path.
        route = path
    return brillouin_zone_path(self, route, points_per_seg=points_per_seg)
[docs]
def bz_path_points(self, path: Optional[Union[List[str], str, StandardBZPath]] = None, *,
                   points_per_seg: int = 40,
                   k_vectors: Optional[np.ndarray] = None, k_vectors_frac: Optional[np.ndarray] = None,
                   tol: float = 1e-12, periodic: bool = True) -> 'KPathSelection':
    '''
    Build an ideal Brillouin-zone path, optionally matched onto an existing k-grid.

    Thin wrapper around ``tools.lattice_kspace.bz_path_points`` with this
    lattice attached. Without a k-grid the result only describes the
    continuous path geometry (handy for plotting, or for paths not tied to
    the sampled mesh). With a sampled grid, reciprocal-lattice copies are
    generated as needed so that paths running through extended
    Brillouin-zone regions can still be matched to existing data.

    Parameters
    ----------
    path : list of str, str, StandardBZPath, or None
        - list of high-symmetry labels, e.g. ``['Gamma', 'X', 'M', 'Gamma']``
        - StandardBZPath member or its name, e.g. ``'SQUARE_2D'``
        - ``None``: the lattice's default path
    points_per_seg : int
        Interpolated points per path segment.
    k_vectors : np.ndarray, shape (Nk, 3), optional
        Cartesian k-vectors of the grid to match against.
    k_vectors_frac : np.ndarray, shape (Nk, 3), optional
        Fractional k-vectors of the same grid.
    tol : float
        Matching tolerance: in fractional reciprocal coordinates when
        ``periodic=True``, in plotted Cartesian coordinates otherwise.
    periodic : bool, default=True
        Allow matches that differ by a reciprocal-lattice translation. Use
        False for purely visual matching in the displayed zone copy.

    Returns
    -------
    KPathSelection
        Object produced by the lattice_kspace helper.

    Raises
    ------
    ImportError
        If ``tools.lattice_kspace`` cannot be imported.
    '''
    try:
        from .tools.lattice_kspace import bz_path_points as _select_path
    except ImportError:
        raise ImportError("bz_path_points requires the lattice_kspace module. Please ensure it is available.")
    grid_kwargs = dict(k_vectors=k_vectors, k_vectors_frac=k_vectors_frac)
    return _select_path(
        lattice=self,
        path=path,
        points_per_seg=points_per_seg,
        **grid_kwargs,
        tol=tol,
        periodic=periodic,
    )
[docs]
def bz_path_data(self,
                 k_vectors      : np.ndarray,
                 k_vectors_frac : np.ndarray,
                 values         : np.ndarray,
                 path           : Optional[Union[List[str], PathTypes, str, StandardBZPath]] = None,
                 *,
                 points_per_seg : int = 40,
                 return_result  : bool = True,
                 ) -> Union[KPathResult, Tuple[np.ndarray, np.ndarray, List[Tuple[int, str]], np.ndarray]]:
    r"""
    Extract values along a Brillouin-zone path from data sampled on a k-grid.

    Thin wrapper around ``tools.lattice_kspace.bz_path_data`` with this
    lattice attached. Grid points closest to an ideal path through
    high-symmetry points are found by fractional-coordinate matching; the
    helper accounts for k-space periodicity and reuses reciprocal-lattice
    copies of the sampled grid when the path lies in an extended zone.

    Parameters
    ----------
    k_vectors : np.ndarray, shape (..., 3)
        Cartesian k-points of the grid (flattened by the helper).
    k_vectors_frac : np.ndarray, shape (..., 3)
        Fractional coordinates of the same k-points (flattened by the helper).
    values : np.ndarray
        Data sampled on the k-grid. The grid axes may appear as
        ``(Lx, Ly, Lz, ...)`` or after leading batch axes (time, frequency),
        e.g. ``(Nw, Lx, Ly, Lz)`` or ``(Nw, Lx, Ly, Lz, ...)``. A single
        flattened grid axis of length ``Nk`` is also supported.
    path : various, optional
        - StandardBZPath member (e.g. ``StandardBZPath.SQUARE_2D``)
        - its name as a string (e.g. ``'SQUARE_2D'``)
        - list of ``(label, [f1, f2, f3])`` tuples
        - HighSymmetryPoints object (its default path is used)
        - ``None``: the lattice's default path, if any
    points_per_seg : int, default=40
        Interpolated points per path segment.
    return_result : bool, default=True
        True returns a KPathResult dataclass; False returns the legacy
        plain tuple.

    Returns
    -------
    KPathResult or tuple
        With ``return_result=True``: a KPathResult whose ``values`` keep any
        leading batch axes and replace the k-grid axes by one path axis.
        With ``return_result=False``: a plain tuple whose exact layout is
        defined by ``tools.lattice_kspace.bz_path_data``.
        NOTE(review): older docs listed ``(k_cart, k_frac, k_dist, labels, values)``
        while the type hint lists four entries -- confirm against the helper.

    Raises
    ------
    ImportError
        If ``tools.lattice_kspace`` cannot be imported.

    Examples
    --------
    >>> result = lattice.bz_path_data(k_grid, k_frac, energies, HighSymmetryPoints.square_2d())
    >>> plt.plot(result.k_dist, result.values)
    >>> result = lattice.bz_path_data(k_grid, k_frac, energies, 'SQUARE_2D')
    >>> custom_path = [('G', [0, 0, 0]), ('X', [0.5, 0, 0]), ('G', [0, 0, 0])]
    >>> result = lattice.bz_path_data(k_grid, k_frac, energies, custom_path)
    """
    try:
        from .tools.lattice_kspace import bz_path_data as _extract_path
    except ImportError:
        raise ImportError("bz_path_data requires the lattice_kspace module. Please ensure it is available.")
    grid = dict(k_vectors=k_vectors, k_vectors_frac=k_vectors_frac, values=values)
    return _extract_path(lattice=self, **grid, path=path,
                         points_per_seg=points_per_seg, return_result=return_result)
# ------------------------------------------------------------------
#! Boundary fluxes
# ------------------------------------------------------------------
@property
def flux(self) -> 'BoundaryFlux':
    """Boundary-flux object attached to the lattice (``None`` when unset)."""
    return self._flux

@flux.setter
def flux(self, value: Optional[Union[float, Mapping[Union[str, LatticeDirection], float]]]):
    """
    Normalise ``value`` with ``_normalize_flux_dict`` and store it.

    When the normalised flux exists and reports ``is_nontrivial``, the
    boundary condition is switched to ``LatticeBC.TWISTED``. The BC is not
    changed otherwise.
    """
    try:
        from .tools.lattice_flux import _normalize_flux_dict as _normalize
    except ImportError:
        raise ImportError("Setting flux requires the lattice_flux module. Please ensure it is available.")
    normalized = _normalize(value)
    self._flux = normalized
    twisted = normalized is not None and normalized.is_nontrivial
    if twisted:
        self._bc = LatticeBC.TWISTED
[docs]
def set_flux(self, value: Optional[Union[float, Mapping[Union[str, LatticeDirection], float]]], *, reinit: bool = True) -> None:
    """
    Assign a new boundary flux and, by default, rebuild flux-dependent data.

    Parameters
    ----------
    value : float, Mapping, or None
        Flux specification, routed through the ``flux`` property setter
        (normalised by ``_normalize_flux_dict``).
    reinit : bool
        If ``True`` (default), recompute the k-vectors, the DFT matrix (only
        when ``Ns < Lattice._DFT_LIMIT_SITES``), and the NN and NNN lists so
        they are consistent with the new flux. Reciprocal vectors are not
        recomputed here.
    """
    self.flux = value  # goes through the property setter
    if not reinit:
        return
    self.calculate_k_vectors()
    small_enough_for_dft = self.Ns < Lattice._DFT_LIMIT_SITES
    if small_enough_for_dft:
        self.calculate_dft_matrix()
    self.calculate_nn()
    self.calculate_nnn()
@property
def has_flux(self) -> bool:
    """
    ``True`` when a flux object is attached and it is truthy.

    Truthiness is delegated to the flux object's ``__bool__`` (presumably
    "non-trivial"; confirm in ``BoundaryFlux``).
    """
    attached = self._flux
    if attached is None:
        return False
    return bool(attached)

@property
def is_twisted(self) -> bool:
    """``True`` when the boundary condition is ``LatticeBC.TWISTED`` (identity check)."""
    twisted_bc = LatticeBC.TWISTED
    return self._bc is twisted_bc

@property
def is_topological(self) -> bool:
    r"""
    Alias of :attr:`has_flux`.

    A non-trivial flux (mod :math:`2\pi`) produces a measurable
    Aharonov-Bohm phase and may change the topological sector of the
    ground state.
    """
    flagged = self.has_flux
    return flagged
[docs]
def flux_summary(self) -> str:
    """
    Human-readable multi-line description of the boundary fluxes.

    One line per ``LatticeDirection`` shows the flux angle and its phase
    factor; the header says whether the flux object reports ``is_trivial``.
    """
    current = self._flux
    if current is None:
        return "No boundary flux (standard BC)"
    lines = [
        f" {direction.name}: phi={current.get(direction):.4f} rad -> exp(i*phi)={current.phase(direction):.4f}"
        for direction in LatticeDirection
    ]
    verdict = "NON-TRIVIAL" if not current.is_trivial else "TRIVIAL"
    return f"Boundary fluxes ({verdict}):\n" + "\n".join(lines)
# ------------------------------------------------------------------
[docs]
def boundary_phase(self, direction: LatticeDirection, winding: int = 1) -> complex:
    """
    Phase factor picked up when crossing the boundary along ``direction``.

    Parameters
    ----------
    direction : LatticeDirection
        Lattice direction (X, Y or Z).
    winding : int
        Number of boundary crossings.

    Returns
    -------
    complex
        ``self._flux.phase(direction, winding=winding)``, nominally
        e^{i * flux * winding}; the plain float ``1.0`` when no flux is set.
    """
    attached = self._flux
    return 1.0 if attached is None else attached.phase(direction, winding=winding)
[docs]
def boundary_phases(self) -> np.ndarray:
    """
    Lookup table of boundary phase factors indexed by direction and winding.

    Returns
    -------
    table : np.ndarray, complex128, shape ``(3, ns + 1)``
        ``table[d, w] == boundary_phase(d, winding=w)`` for each
        ``LatticeDirection`` ``d`` (row ``d.value``) and winding
        ``0 <= w <= ns``. All ones when no flux is attached.
    """
    n_windings = self.ns + 1
    table = np.ones((3, n_windings), dtype=np.complex128)
    if self._flux is None:
        return table
    for direction in LatticeDirection:
        table[direction.value, :] = [self.boundary_phase(direction, winding=w) for w in range(n_windings)]
    return table
[docs]
def boundary_phase_from_winding(self, wx: int, wy: int, wz: int) -> complex:
    """
    Total boundary phase for the winding numbers ``(wx, wy, wz)``.

    The product of :meth:`boundary_phase` over every direction with a
    non-zero winding. The result is returned as a Python float when it has
    no imaginary part, so all-zero windings give ``1.0``.
    """
    if not (wx or wy or wz):
        return 1.0
    total = 1.0
    for direction, count in ((LatticeDirection.X, wx), (LatticeDirection.Y, wy), (LatticeDirection.Z, wz)):
        if count != 0:
            total *= self.boundary_phase(direction, winding=count)
    if np.iscomplexobj(total) and not np.isreal(total):
        return total
    return float(np.real(total))
[docs]
def bond_winding(self, i: int, j: int) -> tuple[int, int, int]:
    """
    Winding numbers of the bond ``i -> j`` across the periodic boundary.

    A bond is treated as wrapping along an axis when the coordinate jump
    exceeds ``L // 2`` for that axis (``Lx``, ``Ly``, ``Lz``). Only the first
    ``dim`` axes are inspected. For ``L <= 2`` direct and wrapped bonds are
    indistinguishable and no wrap is reported.
    NOTE(review): assumes ``get_coordinates`` returns integer cell
    coordinates in ``0..L-1``; confirm for multi-sublattice lattices.

    Parameters
    ----------
    i : int
        Index of the starting site.
    j : int
        Index of the ending site.

    Returns
    -------
    tuple[int, int, int]
        ``(wx, wy, wz)``; each is ``0`` (no crossing), ``-1`` when ``j``
        sits at the high end of the axis, or ``+1`` when it sits at the low end.
    """
    i, j = int(i), int(j)
    x1, y1, z1 = self.get_coordinates(i)
    x2, y2, z2 = self.get_coordinates(j)
    n_axes = 1 + (self.dim > 1) + (self.dim > 2)
    axes = (("Lx", x1, x2), ("Ly", y1, y2), ("Lz", z1, z2))[:n_axes]
    winding = [0, 0, 0]
    for axis, (extent_name, start, end) in enumerate(axes):
        if abs(end - start) > getattr(self, extent_name) // 2:
            winding[axis] = -1 if end > start else +1
    return tuple(winding)
[docs]
def is_spanning(self, sites: Iterable[int]) -> bool:
    """
    Decide whether ``sites`` wraps around a periodic direction of the lattice.

    Each connected component of the induced NN subgraph is traversed while
    accumulating ``bond_winding`` offsets relative to the component root.
    Reaching a site twice with offsets differing along a periodic direction
    means a non-contractible loop exists, i.e. the set spans the torus.
    Empty input, or a lattice without periodic directions, gives ``False``.
    """
    members = set(sites)
    if not members:
        return False
    periodic = self.periodic_flags()
    if not any(periodic):
        return False  # nothing can wrap without periodic boundaries

    unvisited = set(members)
    while unvisited:
        root = next(iter(unvisited))
        unvisited.discard(root)
        offset = {root: (0, 0, 0)}
        frontier = [root]
        while frontier:
            u = frontier.pop()
            ux, uy, uz = offset[u]
            for v in self.nn[u]:
                if v not in members:
                    continue
                wx, wy, wz = self.bond_winding(u, v)
                expected = (ux + wx, uy + wy, uz + wz)
                if v not in offset:
                    offset[v] = expected
                    frontier.append(v)
                    unvisited.discard(v)
                    continue
                seen = offset[v]
                if seen == expected:
                    continue
                mismatch = (expected[0] - seen[0], expected[1] - seen[1], expected[2] - seen[2])
                if any(delta != 0 and flag for delta, flag in zip(mismatch, periodic)):
                    return True
    return False
[docs]
def bond_phase(self, i: int, j: int) -> complex:
    r"""
    Hopping phase factor for the bond :math:`i \to j`.

    Without a flux this is ``1.0``. Otherwise the winding numbers from
    :meth:`bond_winding` are converted into the product of
    :math:`\exp(i\,\phi_\mu)` factors via
    :meth:`boundary_phase_from_winding`; bonds that do not wrap get ``1.0``.
    Multiply the bare hopping amplitude by this factor when building
    real-space Hamiltonians.

    Parameters
    ----------
    i, j : int
        Source and target site indices.

    Returns
    -------
    complex
        Unit-modulus phase factor (a float when purely real).
    """
    if self._flux is not None:
        wx, wy, wz = self.bond_winding(i, j)
        return self.boundary_phase_from_winding(wx, wy, wz)
    return 1.0
[docs]
def hopping_matrix_with_flux(self, *, include_nnn: bool = False) -> np.ndarray:
    r"""
    Dense :math:`N_s \times N_s` hopping matrix including boundary-flux phases.

    ``H[i, j] = bond_phase(i, j)`` for every valid NN entry ``j`` of site
    ``i`` (unit bare amplitude); the diagonal stays zero. Neighbour entries
    flagged by ``wrong_nei`` or outside ``[0, Ns)`` are skipped. NN entries
    are *assigned*, so a neighbour listed twice still contributes once. NNN
    entries (optional) are *added* on top.

    Parameters
    ----------
    include_nnn : bool
        If ``True`` and NNN lists exist, also add next-nearest-neighbour hops.

    Returns
    -------
    H : np.ndarray, complex, shape ``(Ns, Ns)``
    """
    n_sites = self.Ns
    hop = np.zeros((n_sites, n_sites), dtype=complex)

    def _valid_targets(neighbour_lists, src):
        # Yield in-range integer neighbours of ``src``, skipping placeholders.
        for raw in neighbour_lists[src]:
            if self.wrong_nei(raw):
                continue
            tgt = int(raw)
            if 0 <= tgt < n_sites:
                yield tgt

    for src in range(n_sites):
        for tgt in _valid_targets(self._nn, src):
            hop[src, tgt] = self.bond_phase(src, tgt)

    if include_nnn and self._nnn is not None:
        for src in range(n_sites):
            for tgt in _valid_targets(self._nnn, src):
                hop[src, tgt] += self.bond_phase(src, tgt)
    return hop
# ------------------------------------------------------------------
#! Chirality helpers
# ------------------------------------------------------------------
[docs]
def get_nnn_middle_sites(self, i: int, j: int, orientation: Optional[str] = None) -> list[int]:
    """
    Return the 'middle' sites ``l`` that are nearest neighbours of both ``i`` and ``j``.

    These are the sites of the two-step paths ``i - l - j``. Requires
    ``get_nn(site)`` (full neighbour list), ``wrong_nei`` and
    ``get_coordinates``.

    Parameters
    ----------
    i, j : int
        Site indices.
    orientation : {'anticlockwise', 'clockwise', None}, optional
        When several middle sites exist, keep only those whose path
        ``i -> l -> j`` turns anticlockwise (positive z component of
        ``(r_l - r_i) x (r_j - r_l)``) or clockwise (negative). Matching is
        by prefix (``'anti...'`` / ``'clock...'``). A single middle site is
        returned unfiltered. Default ``None`` returns all middle sites.

    Returns
    -------
    list[int]
        Middle-site indices (typically 0, 1 or 2 entries).

    Notes
    -----
    Neighbour-list entries flagged by ``wrong_nei`` (placeholders for missing
    bonds) are discarded, so a placeholder shared by both lists is never
    reported as a middle site. Coordinates with fewer than three components
    are zero-padded so the 2D cross product is well defined.
    """
    common = set(self.get_nn(i)).intersection(self.get_nn(j))
    mids = [l for l in common if not self.wrong_nei(l)]
    if not mids or orientation is None:
        return mids

    if len(mids) > 1:
        def _vec3(site):
            # Pad 1D/2D coordinates to 3D so np.cross always yields a 3-vector.
            r = np.asarray(self.get_coordinates(site), dtype=float).reshape(-1)
            return np.pad(r, (0, max(0, 3 - r.size)))

        ri = _vec3(i)
        rj = _vec3(j)
        turns = []
        for l in mids:
            rl = _vec3(l)
            turns.append((l, np.cross(rl - ri, rj - rl)[-1]))
        key = orientation.lower()
        if key.startswith("anti"):
            mids = [l for (l, cz) in turns if cz > 0]
        elif key.startswith("clock"):
            mids = [l for (l, cz) in turns if cz < 0]
    return mids
[docs]
def get_chirality_sign(self, i: int, j: int, normal: Optional[np.ndarray] = None, orientation: Optional[str] = None) -> int:
    r"""
    Local orientation (chirality) sign \nu_{ij} = \pm 1 for an NNN pair (i, j).

    Defined as the sign of ``((r_l - r_i) x (r_j - r_l)) . normal``, where
    ``l`` is the first middle site returned by :meth:`get_nnn_middle_sites`.
    Coordinates (and ``normal``) with fewer than three components are
    zero-padded, so plain 2D coordinates work.

    Parameters
    ----------
    i, j : int
        Site indices (next-nearest neighbours).
    normal : np.ndarray, optional
        Normal of the lattice plane; defaults to +z.
    orientation : {'anticlockwise', 'clockwise', None}, optional
        Forwarded to :meth:`get_nnn_middle_sites` to pre-select middle sites.

    Returns
    -------
    int
        +1 for anticlockwise, -1 for clockwise, 0 if no middle site exists.

    Raises
    ------
    ValueError
        If the lattice dimension is below 2.
    """
    if self.dim < 2:
        raise ValueError("Chirality sign is only defined for 2D or higher-dimensional lattices.")

    def _vec3(values):
        # Zero-pad to 3 components so the cross/dot products are 3D.
        r = np.asarray(values, dtype=float).reshape(-1)
        return np.pad(r, (0, max(0, 3 - r.size)))

    normal_vec = np.array([0, 0, 1.0]) if normal is None else _vec3(normal)

    mids = self.get_nnn_middle_sites(i, j, orientation=orientation)
    if not mids:
        return 0
    l = mids[0]  # several candidates: the first one decides
    ri = _vec3(self.get_coordinates(i))
    rj = _vec3(self.get_coordinates(j))
    rl = _vec3(self.get_coordinates(l))
    cross = np.cross(rl - ri, rj - rl)
    return int(np.sign(np.dot(cross, normal_vec)))
# ------------------------------------------------------------------
#! Bond type helper
# ------------------------------------------------------------------
[docs]
def bond_type(self, i: int, j: int) -> str:
    """
    Classify the bond between sites ``i`` and ``j``.

    Parameters
    ----------
    i, j : int
        Site indices (converted with ``int``).

    Returns
    -------
    str
        ``'nn'`` if ``j`` is in the NN list of ``i``, ``'nnn'`` if it is in
        the NNN list, ``'none'`` otherwise. This includes the case where NNN
        lists have not been built (``self.nnn is None``).
    """
    i, j = int(i), int(j)
    if j in self.nn[i]:
        return 'nn'
    nnn = self.nnn
    # NNN lists are optional (hopping_matrix_with_flux also guards _nnn is None).
    if nnn is not None and j in nnn[i]:
        return 'nnn'
    return 'none'
# ------------------------------------------------------------------
#! Boundary helpers
# ------------------------------------------------------------------
[docs]
def periodic_flags(self) -> Tuple[bool, bool, bool]:
    """
    Periodicity of the (x, y, z) directions for the current boundary condition.

    PBC and TWISTED give all ``True``: twisted boundaries are still a torus,
    only with extra phases on boundary hops. OBC gives all ``False``; MBC
    only x; SBC only y.

    Raises
    ------
    ValueError
        For an unrecognised boundary condition.
    """
    bc = self._bc
    if bc == LatticeBC.PBC:
        flags = (True, True, True)
    elif bc == LatticeBC.OBC:
        flags = (False, False, False)
    elif bc == LatticeBC.MBC:
        flags = (True, False, False)
    elif bc == LatticeBC.SBC:
        flags = (False, True, False)
    elif bc == LatticeBC.TWISTED:
        flags = (True, True, True)  # periodic, with boundary phases
    else:
        raise ValueError(f"Unsupported boundary condition {self._bc!r}")
    return flags
[docs]
def is_periodic(self, direction: Optional[LatticeDirection] = None, allow_twisted: bool = True) -> bool:
    """
    Check whether the lattice (or one direction of it) is periodic.

    Parameters
    ----------
    direction : LatticeDirection, optional
        Direction to test. If ``None``, the whole lattice is tested. Only
        fully periodic BCs count (PBC, or TWISTED when allowed); MBC and SBC
        give ``False``.
    allow_twisted : bool, default=True
        Whether TWISTED boundary conditions count as periodic. This now
        applies both to the whole-lattice check and to per-direction checks.

    Returns
    -------
    bool
    """
    if direction is None:
        return self.bc == LatticeBC.PBC or (allow_twisted and self.bc == LatticeBC.TWISTED)
    # periodic_flags() always reports TWISTED as periodic; honour allow_twisted here.
    if not allow_twisted and self.bc == LatticeBC.TWISTED:
        return False
    flags = self.periodic_flags()
    index = {LatticeDirection.X: 0, LatticeDirection.Y: 1, LatticeDirection.Z: 2}[direction]
    return bool(flags[index])
@property
def typek(self):
    """Lattice type tag (stored in ``_type``; compared against ``LatticeType`` members)."""
    return self._type

@typek.setter
def typek(self, value):
    self._type = value

@property
def spatial_norm(self):
    """Value stored in ``_spatial_norm`` (semantics defined elsewhere -- TODO confirm)."""
    return self._spatial_norm

@spatial_norm.setter
def spatial_norm(self, value):
    self._spatial_norm = value
# -----------------------------------------------------------------------------
[docs]
def site_index(self, x : int, y : int, z : int):
    """
    Linear site index for cell coordinates ``(x, y, z)``.

    Row-major (lexicographic) order with x varying fastest:
    ``z * (Lx * Ly) + y * Lx + x``. Subclasses with a different indexing
    convention should override this.
    """
    plane = self._lx * self._ly
    row = self._lx
    return z * plane + y * row + x
# -----------------------------------------------------------------------------
#! SITE HELPERS
# -----------------------------------------------------------------------------
[docs]
def site_diff(self, i: Union[int, tuple], j: Union[int, tuple], *, minimum_image: bool = False, real_space: bool = False) -> Tuple[float, float, float]:
    """
    Return the displacement ``i -> j`` with optional PBC minimum-image wrapping.

    Parameters
    ----------
    i, j : int or tuple
        Site indices (Python ``int`` or NumPy integer scalars, e.g. entries of
        neighbour arrays) or explicit coordinates (padded with zeros to 3).
    minimum_image : bool, default=False
        If True, wrap each periodic direction (see :meth:`periodic_flags`)
        to the shortest displacement using ``Lx``, ``Ly``, ``Lz``.
    real_space : bool, default=False
        If True and both ``i`` and ``j`` are site indices, return the
        real-space displacement from :meth:`displacement`. Otherwise use
        lattice coordinates from :meth:`get_coordinates`.

    Returns
    -------
    (dx, dy, dz) : tuple of float
    """
    def _is_index(x) -> bool:
        # NumPy integers are not ``int`` subclasses; without this they would
        # silently be interpreted as a 1-component coordinate.
        return isinstance(x, (int, np.integer))

    def _as_vec3(x) -> np.ndarray:
        c = np.asarray(self.get_coordinates(int(x)) if _is_index(x) else x, dtype=float).reshape(-1)
        if c.size < 3:
            c = np.pad(c, (0, 3 - c.size), mode="constant")
        return c[:3]

    if real_space and _is_index(i) and _is_index(j):
        dr = np.asarray(self.displacement(int(i), int(j), minimum_image=minimum_image), dtype=float).reshape(-1)
        if dr.size < 3:
            dr = np.pad(dr, (0, 3 - dr.size), mode="constant")
        return float(dr[0]), float(dr[1]), float(dr[2])

    c1 = _as_vec3(i)
    c2 = _as_vec3(j)
    delta = c2 - c1
    if minimum_image:
        flags = self.periodic_flags()
        dims = (float(self.Lx), float(max(self.Ly, 1)), float(max(self.Lz, 1)))
        for d, is_periodic in enumerate(flags):
            if is_periodic and dims[d] > 0.0:
                delta[d] -= dims[d] * np.round(delta[d] / dims[d])
    return float(delta[0]), float(delta[1]), float(delta[2])
[docs]
def site_distance(self, i: Union[int, tuple], j: Union[int, tuple], *, minimum_image: bool = False, real_space: bool = False) -> float:
    """
    Return the Euclidean distance between two sites/coordinates.

    Parameters
    ----------
    i, j : int or tuple
        Site indices (Python ``int`` or NumPy integer scalars) or explicit
        coordinates.
    minimum_image : bool, default=False
        If True, periodic directions use the minimum-image convention.
    real_space : bool, default=False
        If True and both inputs are indices, delegate to :meth:`distance`
        (real-space lattice vectors). Otherwise the norm of
        :meth:`site_diff` is returned.

    Returns
    -------
    float
    """
    both_indices = isinstance(i, (int, np.integer)) and isinstance(j, (int, np.integer))
    if real_space and both_indices:
        return float(self.distance(int(i), int(j), minimum_image=minimum_image))
    dx, dy, dz = self.site_diff(i, j, minimum_image=minimum_image, real_space=real_space)
    return float(np.sqrt(dx**2 + dy**2 + dz**2))
# -----------------------------------------------------------------------------
#! DFT MATRIX
# -----------------------------------------------------------------------------
[docs]
def calculate_reciprocal_vectors(self):
    '''
    Compute and cache the reciprocal lattice vectors from ``a1, a2, a3``.

    Uses ``reciprocal_from_real`` and zero-pads each vector to three
    components, so lower-dimensional lattices still get 3D vectors.

    Returns
    -------
    k1, k2, k3 : np.ndarray
        Reciprocal vectors (length 3), also stored as ``_k1``, ``_k2``, ``_k3``.
    '''
    self._k1, self._k2, self._k3 = reciprocal_from_real(self.a1, self.a2, self.a3)
    self._k1, self._k2, self._k3 = (
        np.pad(vec, (0, 3 - len(vec))) for vec in (self._k1, self._k2, self._k3)
    )
    return self._k1, self._k2, self._k3
[docs]
def calculate_dft_matrix(self, phase = False, use_fft: bool = False) -> np.ndarray:
    r'''
    Bloch-type DFT matrix on the site basis; the result is cached on ``self._dft``.

    Indices:
        i = (R, beta)   real-space cell R and sublattice beta
        n = (k, alpha)  k-point k and sublattice alpha
    Elements:
    $$
    F_{(k,alpha),(R,beta)} =
    1/sqrt(Nc) * delta_{alpha,beta} * exp(-i k . R).
    $$
    For a complete k-grid with Ns = Nc * Nb this is unitary:
    $$
    F^\dagger F = I_{Ns}, F F^\dagger = I_{Ns},
    $$
    where Nc = Lx * Ly * Lz is the number of unit cells and Nb the number of
    sublattices (``len(self._basis)``, or 1 if no basis is set).

    Row ``ik * Nb + alpha`` belongs to k-point ``ik`` and sublattice
    ``alpha``. k-points come from a ``meshgrid(..., indexing='ij')`` over the
    fractional grid, so the x-fraction varies slowest. Column ``i`` is site
    ``i``; ``R`` is taken from ``self._cells[i]``.

    IMPORTANT:
    When boundary fluxes are present (TWISTED BC), the fractional k-grid is
    shifted by ``self._flux_frac_shift()`` in each direction, matching
    :meth:`calculate_k_vectors`. Basis-dependent phases
    (exp(-i k . r_basis)) are *not* included.

    Reference: https://en.wikipedia.org/wiki/DFT_matrix

    Parameters
    -----------
    phase : bool, optional
        Currently unused; accepted for backward compatibility.
    use_fft : bool, optional
        Currently unused (no FFT path is implemented); accepted for
        backward compatibility.

    Returns
    -------
    F_block : np.ndarray, complex, shape (Nc * Nb, Ns)
        The DFT matrix, also stored as ``self._dft``.

    Raises
    ------
    ValueError
        If ``self._cells`` does not have one row per site.
    '''
    Ns = self.Ns
    Lx, Ly, Lz = self._lx, max(self._ly, 1), max(self._lz, 1)
    Nc = Lx * Ly * Lz
    # Number of sublattices; a missing or empty basis means one site per cell.
    Nb = len(self._basis) if (self._basis is not None and len(self._basis) > 0) else 1

    cells = np.asarray(self._cells, dtype=float)  # (Ns, 3) cell vector R of each site
    if cells.shape[0] != Ns:
        raise ValueError("Mismatch in number of sites and coordinates.")
    sub_idx = np.asarray(self.subs)  # (Ns,) sublattice index of each site

    # Fractional k-grid, shifted by the flux-induced offset (zero without flux).
    dfx, dfy, dfz = self._flux_frac_shift()
    frac_x = np.linspace(0, 1, Lx, endpoint=False) + dfx
    frac_y = np.linspace(0, 1, Ly, endpoint=False) + dfy
    frac_z = np.linspace(0, 1, Lz, endpoint=False) + dfz
    kx_frac, ky_frac, kz_frac = np.meshgrid(frac_x, frac_y, frac_z, indexing='ij')

    b1 = np.asarray(self._k1, float).reshape(3)
    b2 = np.asarray(self._k2, float).reshape(3)
    b3 = np.asarray(self._k3, float).reshape(3)
    kgrid = (kx_frac[..., None] * b1 +
             ky_frac[..., None] * b2 +
             kz_frac[..., None] * b3)
    k_vectors = kgrid.reshape(-1, 3)  # (Nc, 3)

    norm = np.sqrt(Nc)  # Bloch normalisation
    phase_matrix = np.exp(-1j * (k_vectors @ cells.T)) / norm       # (Nc, Ns)
    selector = (sub_idx[None, :] == np.arange(Nb)[:, None])         # (Nb, Ns): delta_{alpha,beta}
    # Broadcast to (Nc, Nb, Ns) and flatten rows to ik * Nb + alpha.
    F_block = (phase_matrix[:, None, :] * selector[None, :, :]).reshape(Nc * Nb, Ns)
    self._dft = F_block
    return F_block
# -----------------------------------------------------------------------------
#! NEAREST NEIGHBORS
# -----------------------------------------------------------------------------
[docs]
def get_nei(self, site: int, **kwargs):
    '''
    Return the neighbour(s) of ``site``; what is returned depends on the keywords.

    - ``direction`` and ``corr_len`` given : the single site reached by moving
      ``corr_len`` steps along ``direction`` (via ``_get_neighbor_with_corr_len``).
    - only ``direction`` given            : ``get_nn_direction(site, direction)``.
    - no ``direction``                    : the full nearest-neighbour list ``self._nn[site]``
      (a ``corr_len`` without ``direction`` is ignored).

    In both directional cases ``direction`` is first clamped to the lattice
    dimension by ``_adjust_direction``.

    Parameters
    -----------
    - site      : lattice site index
    - direction : direction of the lattice (X, Y or Z)
    - corr_len  : number of steps along ``direction``
    '''
    if 'direction' not in kwargs:
        return self._nn[site]
    axis = self._adjust_direction(kwargs['direction'])
    if 'corr_len' in kwargs:
        return self._get_neighbor_with_corr_len(site, axis, kwargs['corr_len'])
    return self.get_nn_direction(site, axis)
[docs]
def get_nei_forward(self, site: int, num : int = -1):
    '''
    Access the forward nearest neighbours of ``site``.

    Parameters
    -----------
    - site : lattice site
    - num  : position inside the forward list; a negative value selects the whole list

    Returns:
    - the forward-neighbour list, or its ``num``-th entry (no bounds check)
    '''
    forward_row = self._nn_forward[site]
    return forward_row if num < 0 else forward_row[num]
def _adjust_direction(self, direction : Union[LatticeDirection, int]):
    '''
    Clamp ``direction`` to the axes that exist on this lattice.

    - 1D lattices always use X.
    - 2D lattices map anything beyond Y onto Y.
    - 3D lattices map anything beyond Z onto Z.
    Otherwise ``direction`` is returned unchanged (not converted to LatticeDirection).

    Parameters
    -----------
    - direction : direction of the lattice (X, Y, Z or the matching integer)
    '''
    dim = self.dim
    if dim == 1:
        return LatticeDirection.X
    if dim == 2:
        return LatticeDirection.Y if LatticeDirection(direction) > LatticeDirection.Y else direction
    if dim == 3:
        return LatticeDirection.Z if LatticeDirection(direction) > LatticeDirection.Z else direction
    return direction
def _get_neighbor_with_corr_len(self, site : int, direction : Union[LatticeDirection, int], corr_len : int = 1):
    '''
    Returns the site reached from ``site`` after ``corr_len`` steps along ``direction``.

    The boundary condition ``self.bc`` selects the helper:
    PBC and TWISTED -> ``_pbc_neighbor`` (TWISTED has the same connectivity as PBC;
    flux phases are applied elsewhere, see ``calculate_nn``),
    OBC -> ``_obc_neighbor``, MBC -> ``_mbc_neighbor``, SBC -> ``_sbc_neighbor``.

    Parameters
    -----------
    - site      : lattice site
    - direction : direction of the lattice
    - corr_len  : correlation length (number of steps)

    Returns:
    - neighbor site (int), or the helper's sentinel for a missing neighbor
    '''
    bc = self.bc
    if bc == LatticeBC.PBC or bc == LatticeBC.TWISTED:
        return self._pbc_neighbor(site, direction, corr_len)
    if bc == LatticeBC.OBC:
        return self._obc_neighbor(site, direction, corr_len)
    if bc == LatticeBC.MBC:
        return self._mbc_neighbor(site, direction, corr_len)
    if bc == LatticeBC.SBC:
        return self._sbc_neighbor(site, direction, corr_len)
    # Unrecognised boundary condition: flat periodic shift over all sites.
    return (site + corr_len) % self._lxlylz
# -----------------------------------------------------------------------------
#! BOUNDARY CONDITIONS HELPERS
# -----------------------------------------------------------------------------
def _shift_site_along_axis(self, site, direction, corr_len, periodic):
    '''
    Shift ``site`` by ``corr_len`` steps along ``direction``.

    Assumes lexicographic indexing ``site = x + y*Lx + z*Lx*Ly`` (one site per
    cell in this indexing; multi-sublattice lattices may need their own logic).
    The shift changes only the coordinate along ``direction``:
    - X moves inside the site's row (block of ``_lx`` sites),
    - Y moves inside the site's layer (block of ``_lxly`` sites),
    - Z moves inside the whole lattice (``_lxlylz`` sites).

    Parameters
    -----------
    - site      : lattice site
    - direction : LatticeDirection.X / Y / Z
    - corr_len  : number of steps (may be negative)
    - periodic  : wrap around if True, otherwise leaving the block is invalid

    Returns:
    - neighbour index, ``self.bad_lattice_site`` when an open boundary is
      crossed, or None for an unrecognised direction
    '''
    if direction == LatticeDirection.X:
        stride, period = 1, self._lx
    elif direction == LatticeDirection.Y:
        stride, period = self._lx, self._lxly
    elif direction == LatticeDirection.Z:
        stride, period = self._lxly, self._lxlylz
    else:
        return None
    # Position of the site inside its block; since ``period`` is a multiple of
    # ``stride``, shifting the offset leaves all other coordinates intact.
    offset = site % period
    base = site - offset
    target = offset + corr_len * stride
    if periodic:
        return base + target % period
    if 0 <= target < period:
        return base + target
    return self.bad_lattice_site

def _pbc_neighbor(self, site, direction, corr_len):
    '''Neighbour ``corr_len`` steps along ``direction`` with periodic X, Y and Z.'''
    return self._shift_site_along_axis(site, direction, corr_len, periodic=True)

def _obc_neighbor(self, site, direction, corr_len):
    '''Neighbour ``corr_len`` steps along ``direction`` with open X, Y and Z.'''
    return self._shift_site_along_axis(site, direction, corr_len, periodic=False)

def _mbc_neighbor(self, site, direction, corr_len):
    '''Neighbour with periodic X and open Y/Z (same as ``calculate_nn_in(True, False, False)``).'''
    return self._shift_site_along_axis(site, direction, corr_len,
                                       periodic=(direction == LatticeDirection.X))

def _sbc_neighbor(self, site, direction, corr_len):
    '''Neighbour with periodic Y and open X/Z (same as ``calculate_nn_in(False, True, False)``).'''
    return self._shift_site_along_axis(site, direction, corr_len,
                                       periodic=(direction == LatticeDirection.Y))
# -----------------------------------------------------------------------------
#! Virtual methods
# -----------------------------------------------------------------------------
[docs]
@abstractmethod
def get_real_vec(self, x : int, y : int, z : int):
    '''
    Returns the real vector given the coordinates. Uses the lattice constants.

    Abstract: every concrete lattice must implement this. The base class
    provides no implementation.
    '''
    pass
[docs]
@abstractmethod
def get_norm(self, x : int, y : int, z : int):
    '''
    Returns the norm of the vector given the coordinates.

    Abstract: every concrete lattice must implement this. Presumably the norm
    of ``get_real_vec(x, y, z)`` (TODO confirm in subclasses).
    '''
    pass
[docs]
@abstractmethod
def get_nn_direction(self, site : int, direction : LatticeDirection):
    '''
    Returns the nearest neighbors in a given direction.

    Abstract: concrete lattices must implement this. ``get_nei`` calls it
    (with ``direction`` already clamped by ``_adjust_direction``) when a
    ``direction`` keyword is given without ``corr_len``.
    '''
    pass
[docs]
def get_nnn_direction(self, site : int, direction : LatticeDirection):
    '''
    Returns the next nearest neighbors in a given direction.

    Not abstract: the base implementation does nothing and returns None;
    subclasses that support directional NNN lookups override it.
    '''
    pass
# -----------------------------------------------------------------------------
#! NEIGHBOR VALIDATION
# -----------------------------------------------------------------------------
[docs]
def wrong_nei(self, nei):
    """
    Decide whether ``nei`` is an unusable neighbour index.

    The index is rejected when it is ``None``, equals ``self.bad_lattice_site``,
    is NaN, or is negative; the checks run in that order and stop at the first hit.

    Parameters
    ----------
    nei : Any
        Candidate neighbour index.

    Returns
    -------
    bool
        Truthy when the index is invalid, falsy otherwise.
    """
    if nei is None:
        return True
    is_sentinel = nei == self.bad_lattice_site
    if is_sentinel:
        return is_sentinel
    is_nan = np.isnan(nei)
    if is_nan:
        return is_nan
    return nei < 0
# -----------------------------------------------------------------------------
#! NEAREST NEIGHBORS HELPERS
# -----------------------------------------------------------------------------
[docs]
def get_nn_num(self, site : int):
    '''
    Count the nearest neighbours of ``site``.

    Parameters
    -----------
    - site : lattice site

    Returns:
    - ``len(self.nn[site])``, or 0 when no neighbour table (``self._nn``) exists
    '''
    return 0 if self._nn is None else len(self.nn[site])
[docs]
def get_nn(self, site, num : int = -1):
    '''
    Returns the nearest neighbors of a given site.

    Parameters
    -----------
    - site : lattice site
    - num  : index into the neighbour list; negative returns the whole list

    Returns:
    - the neighbour list (or its ``num``-th entry); ``[]`` when no neighbour
      table has been computed (``self._nn`` is None)
    '''
    # Check for a missing table first: indexing None would raise TypeError.
    if self._nn is None:
        return []
    neighbours = self._nn[site]
    return neighbours if num < 0 else neighbours[num]
[docs]
def get_nnn_num(self, site : int):
    '''
    Count the next-nearest neighbours of ``site``.

    Parameters
    -----------
    - site : lattice site

    Returns:
    - ``len(self._nnn[site])``, or 0 when the NNN table is missing (None)
    '''
    nnn_table = self._nnn
    if nnn_table is None:
        return 0
    return len(nnn_table[site])
[docs]
def get_nnn(self, site, num : int = -1):
    '''
    Returns the next nearest neighbors of a given site.

    Parameters
    -----------
    - site : lattice site
    - num  : index into the NNN list; negative returns the whole list

    Returns:
    - the NNN list (or its ``num``-th entry); ``[]`` when no NNN table has
      been computed (``self._nnn`` is None)
    '''
    # Check for a missing table first: indexing None would raise TypeError.
    if self._nnn is None:
        return []
    neighbours = self._nnn[site]
    return neighbours if num < 0 else neighbours[num]
# -----------------------------------------------------------------------------
#! FORWARD NEAREST NEIGHBORS HELPERS
# -----------------------------------------------------------------------------
[docs]
def get_nn_forward_num_max(self):
    '''
    Largest number of forward nearest neighbours over all sites.

    The value is cached in ``self._nn_max_num``; it is (re)computed from
    ``self.nn_forward`` whenever the cache is None or 0 and a forward table exists.

    Returns:
    - maximum forward-neighbour count (``self._nn_max_num``)
    '''
    cache_empty = self._nn_max_num is None or self._nn_max_num == 0
    if cache_empty and self.nn_forward is not None:
        forward = self.nn_forward
        self._nn_max_num = max((len(forward[s]) for s in range(self.ns)), default=0)
    return self._nn_max_num
[docs]
def get_nn_forward_num(self, site : int):
    '''
    Returns the number of forward nearest neighbors of a given site.

    Parameters
    -----------
    - site : lattice site

    Returns:
    - number of forward nearest neighbors; 0 when no forward table exists
      (consistent with ``get_nn_forward`` returning ``[]`` in that case)
    '''
    forward = self.nn_forward
    if forward is None:
        return 0
    return len(forward[site])
[docs]
def get_nn_forward(self, site : int, num : int = -1):
    '''
    Look up the forward nearest neighbours of ``site``.

    Parameters
    -----------
    - site : lattice site
    - num  : position in the forward list; negative selects the whole list

    Returns:
    - whole list (``[]`` if no forward table exists) when ``num < 0``
    - the ``num``-th entry, or -1 if it is out of range or no table exists
    '''
    forward = getattr(self, '_nn_forward', None)
    if forward is None:
        return [] if num < 0 else -1
    row = forward[site]
    if num < 0:
        return row
    return row[num] if num < len(row) else -1
# -----------------------------------------------------------------------------
#! FORWARD NEXT NEAREST NEIGHBORS HELPERS
# -----------------------------------------------------------------------------
[docs]
def get_nnn_forward_num(self, site : int):
    '''
    Returns the number of forward next nearest neighbors of a given site.

    Parameters
    -----------
    - site : lattice site

    Returns:
    - number of forward next nearest neighbors; 0 when no forward NNN table
      exists (consistent with ``get_nnn_forward`` returning ``[]``)
    '''
    forward = self.nnn_forward
    if forward is None:
        return 0
    return len(forward[site])
[docs]
def get_nnn_forward(self, site : int, num : int = -1):
    '''
    Look up the forward next-nearest neighbours of ``site``.

    Parameters
    -----------
    - site : lattice site
    - num  : position in the forward NNN list; negative selects the whole list

    Returns:
    - whole list (``[]`` if no forward NNN table exists) when ``num < 0``
    - the ``num``-th entry, or -1 if it is out of range or no table exists
    '''
    forward = getattr(self, '_nnn_forward', None)
    if forward is None:
        return [] if num < 0 else -1
    row = forward[site]
    if num < 0:
        return row
    return row[num] if num < len(row) else -1
# -----------------------------------------------------------------------------
#! GENERAL NEIGHBORS HELPERS
# -----------------------------------------------------------------------------
[docs]
def neighbors(self, site: int, order=1):
    '''
    Return neighbors of a site.

    - ``order=1``     : nearest neighbours ``self._nn[site]``
    - ``order=2``     : next-nearest neighbours ``self._nnn[site]``
    - ``order='all'`` : every nonzero off-diagonal entry of row ``site`` of
      ``self._adj_mat`` when that matrix is set; otherwise the union of the NN
      and NNN lists (a missing table counts as empty).

    Raises ValueError for any other ``order``.
    '''
    if order == 1:
        return self._nn[site]
    if order == 2:
        return self._nnn[site]
    if order == 'all':
        adj = getattr(self, '_adj_mat', None)
        if adj is not None:
            non_zero_indices = np.nonzero(adj[site])[0]
            return [i for i in non_zero_indices if i != site]
        nn_table = getattr(self, '_nn', None)
        nnn_table = getattr(self, '_nnn', None)
        nn = nn_table[site] if nn_table is not None else []
        nnn = nnn_table[site] if nnn_table is not None else []
        return list(set(nn) | set(nnn))
    raise ValueError(f"Invalid neighbor order: {order}")
[docs]
def neighbors_forward(self, site: int, order=1):
    '''
    Return forward neighbors of a site (each bond counted once).

    - ``order=1``     : forward nearest neighbours ``self._nn_forward[site]``
    - ``order=2``     : forward next-nearest neighbours ``self._nnn_forward[site]``
    - ``order='all'`` : indices ``j > site`` with a nonzero entry in row ``site``
      of ``self._adj_mat`` when that matrix is set; otherwise the union of the
      forward NN and NNN lists (a missing table counts as empty).

    Raises ValueError for any other ``order``.
    '''
    if order == 1:
        return self._nn_forward[site]
    if order == 2:
        return self._nnn_forward[site]
    if order == 'all':
        adj = getattr(self, '_adj_mat', None)
        if adj is not None:
            non_zero_indices = np.nonzero(adj[site])[0]
            return [i for i in non_zero_indices if i > site]
        nn_table = getattr(self, '_nn_forward', None)
        nnn_table = getattr(self, '_nnn_forward', None)
        nn = nn_table[site] if nn_table is not None else []
        nnn = nnn_table[site] if nnn_table is not None else []
        return list(set(nn) | set(nnn))
    raise ValueError(f"Invalid neighbor order: {order}")
[docs]
def any_neighbor(self, site: int, order=1):
    '''
    Return the first neighbor of the given order.

    Returns ``Lattice._BAD_LATTICE_SITE`` when there is none. An explicit
    ``None``/length check is used so NumPy neighbour arrays work; their
    truth value is ambiguous.
    '''
    neigh = self.neighbors(site, order)
    if neigh is None or len(neigh) == 0:
        return Lattice._BAD_LATTICE_SITE
    return neigh[0]
[docs]
def any_neighbor_forward(self, site: int, order=1):
    '''
    Return the first forward neighbor of the given order.

    Returns ``Lattice._BAD_LATTICE_SITE`` when there is none. An explicit
    ``None``/length check is used so NumPy neighbour arrays work; their
    truth value is ambiguous.
    '''
    neigh = self.neighbors_forward(site, order)
    if neigh is None or len(neigh) == 0:
        return Lattice._BAD_LATTICE_SITE
    return neigh[0]
# =========================================================================
#! NetKet-inspired convenience API
# =========================================================================
@property
def n_nodes(self) -> int:
    """Total number of lattice sites (the same value as ``Ns``)."""
    return self._ns

@property
def n_edges(self) -> int:
    """Count of distinct undirected nearest-neighbour bonds, i.e. ``len(self.edges())``."""
    edge_list = self.edges()
    return len(edge_list)

@property
def positions(self) -> np.ndarray:
    """Alias of ``rvectors``: the real-space position of every site."""
    return self.rvectors

@property
def site_offsets(self) -> np.ndarray:
    """Alias of the unit-cell basis ``self._basis``: site offsets inside one cell."""
    return self._basis

@property
def basis_coords(self) -> np.ndarray:
    """
    Per-site integer coordinates ``[nx, ny, nz, sub]``, shape ``(Ns, 4)``.

    The first three columns come from ``self._fracs`` (cell indices) and the last
    from ``self._subs`` (sublattice label). Returns None until both are available.
    """
    cell_idx = self._fracs
    if cell_idx is None:
        return None
    sub_idx = self._subs
    if sub_idx is None:
        return None
    return np.column_stack((cell_idx, sub_idx))

@property
def ndim(self) -> int:
    """Spatial dimension of the lattice (``self._dim``)."""
    return self._dim

@property
def extent(self) -> Tuple[int, ...]:
    """Unit-cell counts along each axis as ``(Lx, Ly, Lz)``."""
    return self._lx, self._ly, self._lz

@property
def pbc(self) -> Tuple[bool, bool, bool]:
    """Per-axis periodicity flags; forwards to ``periodic_flags()``."""
    return self.periodic_flags()
# Edge / bond queries
[docs]
def edges(self, *, filter_color: Optional[int] = None,
          return_color: bool = False) -> List:
    """
    Return list of nearest-neighbour edges.

    Bonds come from ``self._bonds``; ``calculate_bonds()`` is called first if
    that list is missing or empty.

    Parameters
    ----------
    filter_color : int, optional
        If given, return only edges whose ``bond_type`` equals this colour.
    return_color : bool
        If *True* each element is ``(i, j, color)``; otherwise ``(i, j)``.

    Returns
    -------
    list[tuple]
        Unique undirected edges with ``i < j``, sorted by ``(i, j)``. The
        coloured and uncoloured forms list the same edges in the same order,
        so ``edges()`` and ``edges(return_color=True)`` line up index by index.
    """
    if not hasattr(self, '_bonds') or not self._bonds:
        self.calculate_bonds()
    # Canonicalise to (min, max) and drop duplicates / reversed pairs.
    keys = sorted({(i, j) if i < j else (j, i) for i, j in self._bonds})
    if not return_color and filter_color is None:
        return keys
    result = []
    for a, b in keys:
        c = self.bond_type(a, b)
        if filter_color is not None and c != filter_color:
            continue
        result.append((a, b, c) if return_color else (a, b))
    return result
@property
def edge_colors(self) -> List[int]:
    """
    Bond-type colour of every edge, aligned index by index with ``edges()``.

    ``edges()`` lists unique ``(i, j)`` pairs (``i < j``) in sorted order. The
    coloured list is therefore deduplicated and sorted the same way here, so the
    two sequences match even if ``edges(return_color=True)`` repeats pairs.
    """
    colors = {}
    for a, b, c in self.edges(return_color=True):
        colors.setdefault((a, b), c)
    return [colors[key] for key in sorted(colors)]
# -- Displacement helpers -----------------------------------------------
[docs]
def displacement(self, i: int, j: int, *, minimum_image: bool = True) -> np.ndarray:
    """
    Real-space vector pointing from site *i* to site *j*.

    Parameters
    ----------
    i, j : int
        Site indices.
    minimum_image : bool
        When *True* (default), fold the cell-index difference along every
        periodic axis into the nearest periodic image before converting it to
        real space with the lattice vectors, then add the sublattice offset.
        When *False*, return ``rvectors[j] - rvectors[i]`` directly.

    Returns
    -------
    np.ndarray shape (3,)
    """
    src, dst = int(i), int(j)
    direct = self.rvectors[dst] - self.rvectors[src]
    if not minimum_image:
        return direct

    periodic = self.periodic_flags()
    sizes = (self._lx, max(self._ly, 1), max(self._lz, 1))
    cell_shift = np.array(self._fracs[dst], dtype=float) - np.array(self._fracs[src], dtype=float)
    for axis, (is_periodic, size) in enumerate(zip(periodic, sizes)):
        if is_periodic:
            # Subtract the nearest multiple of the period along this axis.
            cell_shift[axis] -= size * np.round(cell_shift[axis] / size)

    folded = cell_shift[0] * self._a1 + cell_shift[1] * self._a2 + cell_shift[2] * self._a3
    folded += self._basis[self._subs[dst]] - self._basis[self._subs[src]]
    return folded
[docs]
def distance(self, i: int, j: int, *, minimum_image: bool = True) -> float:
    """Euclidean length of ``displacement(i, j)``; periodic images are honoured by default."""
    vec = self.displacement(i, j, minimum_image=minimum_image)
    return float(np.linalg.norm(vec))
# -----------------------------------------------------------------------------
#! Standard getters
# -----------------------------------------------------------------------------
[docs]
def get_coordinates(self, *args):
    '''All site coordinates, or only those of site ``args[0]`` when an index is passed.'''
    coords = self._coordinates
    return coords[args[0]] if args else coords
[docs]
def get_r_vectors(self, *args):
    '''All real-space vectors, or only the one of site ``args[0]`` when an index is passed.'''
    vectors = self._rvectors
    return vectors[args[0]] if args else vectors
[docs]
def get_k_vectors(self, *args):
    '''All k-vectors, or only the k-vector at index ``args[0]`` when an index is passed.'''
    vectors = self._kvectors
    return vectors[args[0]] if args else vectors
[docs]
def get_site_diff(self, i: int, j: int):
    '''Coordinate difference ``coord(j) - coord(i)`` between two sites.'''
    end = self.get_coordinates(j)
    start = self.get_coordinates(i)
    return end - start
[docs]
def get_k_vec_idx(self, sym = False):
    '''
    Placeholder for retrieving k-vector indices.

    Not implemented in the base class: it always returns None. The meaning of
    ``sym`` is not defined here (presumably it selects a symmetry-reduced set;
    TODO confirm against subclasses).
    '''
    pass
[docs]
def get_dft(self, *args):
    '''
    Access the DFT matrix ``self.dft``.

    - no arguments   : the full matrix
    - one argument   : row ``args[0]``
    - two or more    : element ``[args[0], args[1]]`` (extra arguments are ignored)
    '''
    matrix = self.dft
    count = len(args)
    if count == 0:
        return matrix
    if count == 1:
        return matrix[args[0]]
    return matrix[args[0], args[1]]
# -----------------------------------------------------------------------------
#! Spatial information
# -----------------------------------------------------------------------------
[docs]
def get_spatial_norm(self, *args):
    '''
    Access ``self.spatial_norm``, optionally indexed.

    The first up-to-three positional arguments are applied as successive
    indices (``norm[a][b][c]``); any further arguments are ignored. With no
    arguments the whole structure is returned.
    '''
    value = self.spatial_norm
    for index in args[:3]:
        value = value[index]
    return value
[docs]
def get_difference_idx_matrix(self, cut = True) -> list:
    '''
    Returns the matrix of position-difference triplets matching a slice from the QMC.

    A useful function for reading position Green's functions saved by
    https://github.com/makskliczkowski/DQMC

    Along each axis of length L_i the differences are:
    - cut=True  : 2*L_i - 1 values, -(L_i - 1), ..., 0, ..., L_i - 1
    - cut=False : L_i non-negative values, 0, ..., L_i - 1

    Layout: rows index the x-difference. Columns index the (y, z) pair, with y
    varying fastest, so column ``j + k * ynum`` holds y-index j and z-index k
    (for 3D the z slices follow one another along the columns). Axes of
    length < 1 (absent dimensions) are treated as length 1, whose only
    difference is 0.

    Parameters
    ----------
    - cut : use signed differences (True) or only the non-negative ones (False)

    Returns
    -------
    list of lists of ``(dx, dy, dz)`` tuples with shape ``(xnum, ynum * znum)``.
    '''
    def _axis_differences(length):
        length = max(int(length), 1)
        if cut:
            # Symmetric range: exactly 2L - 1 values centred on 0.
            return list(range(-(length - 1), length))
        return list(range(length))

    xs = _axis_differences(self.Lx)
    ys = _axis_differences(self.Ly)
    zs = _axis_differences(self.Lz)
    # z is the outer loop so that y varies fastest inside every z slice.
    return [[(x, y, z) for z in zs for y in ys] for x in xs]
############################ ABSTRACT CALCULATORS #############################
[docs]
def calculate_bonds(self):
    '''
    Build ``self._bonds`` from the forward nearest neighbours.

    Every valid forward neighbour j of every site i contributes one ``(i, j)``
    pair; entries rejected by ``wrong_nei`` are skipped. Returns the list.
    '''
    bonds = self._bonds = []
    for site in range(self.Ns):
        for slot in range(self.get_nn_forward_num(site)):
            partner = self.get_nn_forward(site, slot)
            if not self.wrong_nei(partner):
                bonds.append((site, partner))
    return self._bonds
[docs]
def calculate_coordinates(self):
    """
    Compute the real-space coordinates of every site (up to 3D).

    Site ``i`` lives in cell ``i // n_basis`` on sublattice ``i % n_basis``,
    where ``n_basis = len(self._basis)``. The cell index is unpacked
    lexicographically into ``(nx, ny, nz)`` (x fastest). Components beyond
    the lattice dimension are zero.

    Side effects: sets ``self._coordinates`` (cell vector + basis offset),
    ``self._cells`` (cell vector), ``self._fracs`` ((Ns, 3) cell indices) and
    ``self._subs`` (sublattice labels). Returns ``self._coordinates``.
    """
    basis_count = len(self._basis)
    cell_ids, sub_ids = np.divmod(np.arange(self.Ns), basis_count)

    nx = cell_ids % self.Lx
    if self._dim >= 2:
        ny = (cell_ids // self.Lx) % self.Ly
    else:
        ny = np.zeros_like(cell_ids)
    if self._dim >= 3:
        nz = (cell_ids // (self.Lx * self.Ly)) % self.Lz
    else:
        nz = np.zeros_like(cell_ids)

    cell_vectors = nx[:, None] * self._a1 + ny[:, None] * self._a2 + nz[:, None] * self._a3
    self._cells = cell_vectors
    self._coordinates = cell_vectors + self._basis[sub_ids]
    self._fracs = np.stack((nx, ny, nz), axis=1)
    self._subs = sub_ids
    return self._coordinates
[docs]
def calculate_r_vectors(self):
    """
    Compute the real-space position vector of every site.

    Uses the same ordering as ``calculate_coordinates``: site ``i`` is in cell
    ``i // n_basis`` on sublattice ``i % n_basis``, and the cell index is
    unpacked lexicographically (x fastest). The result is an (Ns, 3) float
    array stored in ``self.rvectors`` and returned.
    """
    basis_count = len(self._basis)
    cell_ids, sub_ids = np.divmod(np.arange(self.Ns), basis_count)

    nx = cell_ids % self.Lx
    ny = (cell_ids // self.Lx) % self.Ly if self._dim >= 2 else np.zeros_like(cell_ids)
    nz = (cell_ids // (self.Lx * self.Ly)) % self.Lz if self._dim >= 3 else np.zeros_like(cell_ids)

    positions = np.zeros((self.Ns, 3))
    positions[:] = (nx[:, None] * self._a1 + ny[:, None] * self._a2
                    + nz[:, None] * self._a3 + self._basis[sub_ids])
    self.rvectors = positions
    return self.rvectors
def _flux_frac_shift(self) -> Tuple[float, float, float]:
r"""
Return the fractional k-grid shift induced by boundary fluxes.
If the flux in direction :math:`\mu` is :math:`\phi_\mu`, the standard
fractional coordinate :math:`f_\mu = n_\mu / L_\mu` is shifted to
:math:`f_\mu + \phi_\mu / (2\pi\,L_\mu)`.
Returns
-------
(dfx, dfy, dfz) : tuple[float, float, float]
"""
if self._flux is None:
return (0.0, 0.0, 0.0)
Ly = self._ly if self._dim >= 2 else 1
Lz = self._lz if self._dim >= 3 else 1
return self._flux.k_shift_fractions(self._lx, Ly, Lz)
[docs]
def calculate_k_vectors(self):
    r"""
    Build the allowed momentum grid from the lattice size and reciprocal vectors.

    Fractional coordinates along each axis follow ``np.fft.fftfreq`` ordering
    (Γ at index [0,0,0], then the positive frequencies, then the negative
    branch), matching the Bloch transform so analytic grids line up with the
    numerically built H(k) blocks. With boundary fluxes (TWISTED BC) every
    axis is offset by :math:`\phi_\mu / (2\pi L_\mu)` (see ``_flux_frac_shift``).

    Side effects: sets ``self.kvectors`` ((Nk, 3) Cartesian k-points) and
    ``self.kvectors_frac`` ((Nk, 3) shifted fractional coordinates).
    Returns ``self.kvectors``.
    """
    size_x = self.Lx
    size_y = self.Ly if self._dim >= 2 else 1
    size_z = self.Lz if self._dim >= 3 else 1

    shift_x, shift_y, shift_z = self._flux_frac_shift()
    fx = np.fft.fftfreq(size_x) + shift_x
    fy = np.fft.fftfreq(size_y) + shift_y
    fz = np.fft.fftfreq(size_z) + shift_z

    grid_x, grid_y, grid_z = np.meshgrid(fx, fy, fz, indexing="ij")
    cartesian = (grid_x[..., None] * self._k1
                 + grid_y[..., None] * self._k2
                 + grid_z[..., None] * self._k3)

    self.kvectors = cartesian.reshape(-1, 3)
    self.kvectors_frac = np.stack([grid_x, grid_y, grid_z], axis=-1).reshape(-1, 3)
    return self.kvectors
[docs]
def filter_k_vectors(self, qx: Optional[int] = None, qy: Optional[int] = None, qz: Optional[int] = None,
                     *, atol: float = 1e-9) -> np.ndarray:
    """
    Indices of k-vectors whose integer momentum labels match the request.

    ``self.kvectors_frac`` holds fractional coordinates in ``np.fft.fftfreq``
    ordering (labels above L/2 appear as negative fractions) plus the
    flux-induced shift from ``_flux_frac_shift``. A k-point matches label ``q``
    on an axis of length L when ``frac - shift - q / L`` is an integer within
    ``atol``. Because the comparison is modulo 1, ``q``, ``q + L`` and ``q - L``
    are equivalent labels.

    Parameters
    -----------
    qx (int, optional): momentum label along x.
    qy (int, optional): momentum label along y (ignored for dim < 2).
    qz (int, optional): momentum label along z (ignored for dim < 3).
    atol (float): tolerance for the floating-point comparison.

    Returns:
        np.ndarray: indices of k-vectors matching every given label.

    Raises:
        ValueError: if the k-vectors have not been calculated yet.
    """
    fracs = self.kvectors_frac
    if fracs is None:
        raise ValueError("k-vectors have not been calculated yet.")
    fracs = np.asarray(fracs, dtype=float)
    shifts = self._flux_frac_shift()

    def _matches(axis, q, length):
        delta = fracs[:, axis] - shifts[axis] - q / length
        return np.abs(delta - np.round(delta)) <= atol

    mask = np.ones(len(fracs), dtype=bool)
    if qx is not None:
        mask &= _matches(0, qx, self.Lx)
    if qy is not None and self._dim >= 2:
        mask &= _matches(1, qy, self.Ly)
    if qz is not None and self._dim >= 3:
        mask &= _matches(2, qz, self.Lz)
    return np.where(mask)[0]
[docs]
def translation_operators(self):
    """
    Build the translation matrices on the one-hot basis and cache them.

    Delegates to ``build_translation_operators(self)``, stores the three results
    as ``self._T1``, ``self._T2``, ``self._T3`` and returns them as a tuple.
    """
    t1, t2, t3 = build_translation_operators(self)
    self._T1, self._T2, self._T3 = t1, t2, t3
    return self._T1, self._T2, self._T3
# -----------------------------------------------------------------------------
#! Spatial norm calculators
# -----------------------------------------------------------------------------
[docs]
def calculate_norm_sym(self):
    """
    Store a per-site symmetry-normalisation value in ``self._spatial_norm``.

    The base implementation maps each site index to the Euclidean norm of its
    coordinate vector. Subclasses may override this with lattice-specific logic.
    Returns None.
    """
    coords = self._coordinates
    norms = {}
    for site in range(self._ns):
        norms[site] = np.linalg.norm(coords[site])
    self._spatial_norm = norms
# -----------------------------------------------------------------------------
#! Nearest neighbors
# -----------------------------------------------------------------------------
def _calculate_nn_pbc(self): return self.calculate_nn_in(True, True, True)
def _calculate_nn_obc(self): return self.calculate_nn_in(False, False, False)
def _calculate_nn_mbc(self): return self.calculate_nn_in(True, False, False)
def _calculate_nn_sbc(self): return self.calculate_nn_in(False, True, False)
[docs]
@abstractmethod
def calculate_nn_in(self, pbcx : bool, pbcy : bool, pbcz : bool):
    '''
    Subclass hook: compute nearest neighbours for per-axis periodicity flags.

    ``pbcx``/``pbcy``/``pbcz`` select periodic (True) or open (False) boundaries
    along each axis. ``calculate_nn`` discards the return value, so
    implementations presumably store their result on the instance
    (e.g. ``self._nn`` / ``self._nn_forward``); confirm in subclasses.
    '''
    pass
[docs]
def calculate_nn(self):
    '''
    Compute the nearest neighbours for the current boundary condition ``self._bc``.

    PBC/OBC/MBC/SBC map to the matching ``_calculate_nn_*`` helper. TWISTED
    uses the PBC connectivity: flux phases are applied separately when the
    Hamiltonian or the DFT matrix is built. Any other value raises ValueError.
    '''
    dispatch = (
        (LatticeBC.PBC, self._calculate_nn_pbc),
        (LatticeBC.OBC, self._calculate_nn_obc),
        (LatticeBC.MBC, self._calculate_nn_mbc),
        (LatticeBC.SBC, self._calculate_nn_sbc),
        (LatticeBC.TWISTED, self._calculate_nn_pbc),
    )
    for candidate, builder in dispatch:
        if self._bc == candidate:
            builder()
            return
    raise ValueError("The boundary conditions are not implemented.")
[docs]
def calculate_plaquettes(self, use_obc: bool = True):
    '''
    Plaquette construction hook.

    The base class has no plaquette logic and always raises
    NotImplementedError; lattices that support plaquettes override this.
    '''
    message = "Plaquette calculation not implemented for this lattice."
    raise NotImplementedError(message)
[docs]
def calculate_wilson_loops(self):
    """
    List the non-contractible (Wilson) loops allowed by the boundary conditions.

    One loop is produced per periodic axis that exists on the lattice:
    along X at y=z=0, along Y at x=z=0, along Z at x=y=0. Each loop is a list
    of site indices, assuming lexicographic indexing ``x + y*Lx + z*Lx*Ly``.
    """
    periodic_x, periodic_y, periodic_z = self.periodic_flags()
    loops = []
    if periodic_x and self.lx > 0:
        loops.append(list(range(self.lx)))
    if periodic_y and self.dim >= 2 and self.ly > 0:
        row_stride = self.lx
        loops.append(list(range(0, self.ly * row_stride, row_stride)))
    if periodic_z and self.dim >= 3 and self.lz > 0:
        layer_stride = self.lx * self.ly
        loops.append(list(range(0, self.lz * layer_stride, layer_stride)))
    return loops
# -----------------------------------------------------------------------------
#! Next nearest neighbors
# -----------------------------------------------------------------------------
def _calculate_nnn_pbc(self): return self.calculate_nnn_in(True, True, True)
def _calculate_nnn_obc(self): return self.calculate_nnn_in(False, False, False)
def _calculate_nnn_mbc(self): return self.calculate_nnn_in(True, False, False)
def _calculate_nnn_sbc(self): return self.calculate_nnn_in(False, True, False)
[docs]
@abstractmethod
def calculate_nnn_in(self, pbcx : bool, pbcy : bool, pbcz : bool):
    '''
    Subclass hook: compute next-nearest neighbours for per-axis periodicity flags.

    ``pbcx``/``pbcy``/``pbcz`` select periodic (True) or open (False) boundaries
    along each axis. Implementations are expected to set ``self._nnn`` (and
    optionally ``self._nnn_forward``) directly; ``calculate_nnn`` discards the
    return value.
    '''
    pass
[docs]
def calculate_nnn(self):
    '''
    Compute the next-nearest neighbours for the current boundary condition ``self._bc``.

    PBC/OBC/MBC/SBC map to the matching ``_calculate_nnn_*`` helper. TWISTED
    uses the PBC connectivity. Each ``calculate_nnn_in`` implementation is
    expected to set ``self._nnn`` (and optionally ``self._nnn_forward``)
    itself; the helper's return value is discarded. Any other boundary
    condition raises ValueError.
    '''
    dispatch = (
        (LatticeBC.PBC, self._calculate_nnn_pbc),
        (LatticeBC.OBC, self._calculate_nnn_obc),
        (LatticeBC.MBC, self._calculate_nnn_mbc),
        (LatticeBC.SBC, self._calculate_nnn_sbc),
        (LatticeBC.TWISTED, self._calculate_nnn_pbc),
    )
    for candidate, builder in dispatch:
        if self._bc == candidate:
            builder()
            return
    raise ValueError("The boundary conditions are not implemented.")
# -----------------------------------------------------------------------------
#! Adjacency matrices and neighbor printing
# -----------------------------------------------------------------------------
[docs]
def adjacency_matrix(self,
                    sparse              : bool = False,
                    save                : bool = True,
                    *,
                    mode                : str = 'binary',
                    include_self        : bool = False,
                    include_nnn         : bool = False,
                    typed_self_separate : bool = True,
                    n_types             : int = 3
                    ) -> Union[np.ndarray, tuple]:
    r"""
    Construct an adjacency matrix of the lattice.

    Parameters
    ----------
    sparse : bool
        Return scipy.sparse CSR matrices instead of dense arrays.
    save : bool
        Cache the result in ``self._adj_cache`` (keyed by every option) and
        return a cached result when one exists.
    mode : {'binary', 'typed'}
        'binary' : A_ij = 1 if i and j are neighbours (from ``self._nn``), 0 otherwise.
        'typed'  : one 0/1 channel per bond type t in ``[0, n_types)``.
    include_self : bool
        Add self-connections (diagonal entries).
    include_nnn : bool
        'binary' mode only: also connect next-nearest neighbours (``self._nnn``).
    typed_self_separate : bool
        'typed' mode with ``include_self``: return the identity as a separate
        matrix instead of writing the diagonal into every type channel.
    n_types : int
        Number of bond-type channels in 'typed' mode.

    Returns
    -------
    'binary' : (Ns, Ns) float32 ndarray, or a CSR matrix if ``sparse``.
    'typed'  : tuple ``(A_types, A_self)``. ``A_types`` is an (n_types, Ns, Ns)
               float32 array, or a 1-D object array of n_types CSR matrices if
               ``sparse``. ``A_self`` is the (dense or CSR) identity when
               ``include_self and typed_self_separate``, otherwise None.

    Raises
    ------
    ValueError
        Unknown ``mode``; a bond type outside ``[0, n_types)``; or 'typed'
        mode on a lattice without any bond-type information.
    """
    mode = str(mode).lower()
    if mode not in ("binary", "typed"):
        raise ValueError("mode must be 'binary' or 'typed'")
    Ns = int(self.ns)

    # One cache entry per option combination so different requests never collide.
    if not hasattr(self, "_adj_cache"):
        self._adj_cache = {}
    cache_key = (mode, include_self, include_nnn, typed_self_separate, n_types, sparse)
    if save and cache_key in self._adj_cache:
        return self._adj_cache[cache_key]

    # =====================================================================
    # Binary adjacency
    # =====================================================================
    if mode == "binary":
        A = np.zeros((Ns, Ns), dtype=np.float32)

        def mark_neighbor_lists(neighbor_lists):
            # Symmetrically mark every valid (i, j) pair. Explicit None/len checks
            # (not truthiness) so NumPy arrays of neighbour indices also work.
            if neighbor_lists is None or len(neighbor_lists) == 0:
                return
            for i in range(Ns):
                nbrs = neighbor_lists[i]
                if nbrs is None:
                    continue
                for j in nbrs:
                    if self.wrong_nei(j):
                        continue
                    j = int(j)
                    if 0 <= j < Ns and j != i:
                        A[i, j] = 1.0
                        A[j, i] = 1.0

        mark_neighbor_lists(getattr(self, "_nn", None))
        if include_nnn:
            mark_neighbor_lists(getattr(self, "_nnn", None))

        if include_self:
            np.fill_diagonal(A, 1.0)
        if sparse:
            import scipy.sparse as sp
            A = sp.csr_matrix(A)
        if save:
            self._adj_cache[cache_key] = A
        return A

    # =====================================================================
    # Typed adjacency (e.g., Kitaev x/y/z bonds)
    # =====================================================================
    # Supported sources of bond types, tried in order:
    # 1) self.bonds_by_type() -> sequence of length >= n_types, each a list of (i, j)
    # 2) self.edge_types dict {(i, j): t}
    # 3) self._bond_types list of (i, j, t)
    A_types = np.zeros((n_types, Ns, Ns), dtype=np.float32)

    def add_typed_edge(i, j, t):
        i = int(i); j = int(j); t = int(t)
        if not (0 <= i < Ns and 0 <= j < Ns):
            return
        if i == j:
            return
        if not (0 <= t < n_types):
            raise ValueError(f"Bond type t={t} outside [0,{n_types})")
        A_types[t, i, j] = 1.0
        A_types[t, j, i] = 1.0

    used = False
    # (1) bonds_by_type()
    if hasattr(self, "bonds_by_type") and callable(self.bonds_by_type):
        by_t = self.bonds_by_type()
        if by_t is not None and len(by_t) >= n_types:
            for t in range(n_types):
                for (i, j) in by_t[t]:
                    add_typed_edge(i, j, t)
            used = True
    # (2) edge_types mapping
    if not used and hasattr(self, "edge_types"):
        et = getattr(self, "edge_types")
        if isinstance(et, dict) and len(et) > 0:
            for (i, j), t in et.items():
                add_typed_edge(i, j, t)
            used = True
    # (3) internal list of typed bonds
    if not used and hasattr(self, "_bond_types"):
        bt = getattr(self, "_bond_types")
        if bt is not None and len(bt) > 0:
            for (i, j, t) in bt:
                add_typed_edge(i, j, t)
            used = True
    # Plain _nn lists carry no type information, so typed adjacency is impossible.
    if not used:
        raise ValueError(
            "typed adjacency requested, but no bond-type information found on this Lattice.\n"
            "Add one of:\n"
            " - Lattice.bonds_by_type() -> list length 3 of (i,j) bonds for x,y,z\n"
            " - Lattice.edge_types dict {(i,j): t}\n"
            " - Lattice._bond_types list of (i,j,t)\n"
            "Otherwise use mode='binary'."
        )

    A_self = None
    if include_self:
        if typed_self_separate:
            A_self = np.eye(Ns, dtype=np.float32)
        else:
            # Add self-loops into every type channel.
            for t in range(n_types):
                np.fill_diagonal(A_types[t], 1.0)

    if sparse:
        import scipy.sparse as sp
        # Fill an object array element by element; np.array([...csr...]) would
        # let NumPy try to treat the sparse matrices as nested sequences.
        sparse_types = np.empty(n_types, dtype=object)
        for t in range(n_types):
            sparse_types[t] = sp.csr_matrix(A_types[t])
        A_types = sparse_types
        if A_self is not None:
            A_self = sp.csr_matrix(A_self)

    out = (A_types, A_self)
    if save:
        self._adj_cache[cache_key] = out
    return out
[docs]
def print_neighbors(self, logger : 'Logger'):
    """
    Logs the neighbors of each site in the lattice using the provided logger.

    For each site, the full nearest-neighbour list (``get_nn(i)``) is logged at
    level 1, then every individual neighbour (``get_nn(i, j)``) at level 2.

    Parameters
    -----------
    logger:
        An object with an ``info`` method accepting ``lvl`` (int, verbosity)
        and ``color`` (str) keyword arguments. If None, messages are printed.
    """
    def print_nei(msg, lvl = 1, color = 'green'):
        if logger is not None:
            logger.info(msg, lvl = lvl, color = color)
        else:
            print(msg)

    for i in range(self.ns):
        neighbors = self.get_nn(i)
        print_nei(f"Neighbors of site {i}: {neighbors}", lvl = 1, color = 'green')
        for j in range(len(neighbors)):
            # get_nei takes no positional index, so fetch the j-th entry via get_nn.
            nei_in = self.get_nn(i, j)
            print_nei(f"Neighbor {j} of site {i}: {nei_in}", lvl = 2, color = 'blue')
[docs]
def print_forward(self, logger : 'Logger'):
    """
    Report the forward nearest neighbours of every lattice site.

    For each site ``i`` the value of ``self.get_nn_forward_num(i)`` (the
    number of forward neighbours) is emitted at verbosity level 1 (green).
    Then, for every ``j`` in ``range(count)``, the neighbour
    ``self.get_nn_forward(i, j)`` is emitted at level 2 (blue).

    Parameters
    ----------
    logger : Logger or None
        Object exposing ``info(msg, lvl=..., color=...)``. When ``None`` the
        messages are written to standard output with ``print`` instead.
    """
    def _emit(text, level, tint):
        # Fall back to stdout when no logger is supplied.
        if logger is None:
            print(text)
            return
        logger.info(text, lvl = level, color = tint)

    for site in range(self.ns):
        count = self.get_nn_forward_num(site)
        _emit(f"Neighbors of site {site}: {count}", 1, 'green')
        for slot in range(count):
            _emit(f"Neighbor {slot} of site {site}: {self.get_nn_forward(site, slot)}", 2, 'blue')
# -------------------------------------------------------------------------
#! Bloch Transform & Basis Operations
# -------------------------------------------------------------------------
[docs]
def get_geometric_encoding(self, *, tol=1e-6):
    """
    Map each site ``i`` to ``(cell_idx, sub_idx)`` purely from geometry.

    Every site position ``r_i`` is decomposed as ``r_i = R + b_s``, where
    ``R = n1*a1 + n2*a2 + n3*a3`` is a lattice translation (integer ``n``)
    and ``b_s`` is one of the basis positions in ``self._basis``. The basis
    offset is subtracted *before* rounding to integer cell coordinates. This
    keeps sites whose basis offset is 0.5 or more in fractional units (e.g.
    a basis site at fractional offset 0.5 along ``a1``) in the correct cell.

    Coordinates, basis positions and lattice vectors with fewer than three
    components are zero-padded to 3D. Only the first ``self._dim`` lattice
    vectors are used to resolve the cell coordinates.

    Parameters
    ----------
    tol : float, default=1e-6
        Threshold on the *squared* Cartesian distance between a site and its
        best-matching ``R + b_s``. A site whose best match is not below this
        value triggers a ``ValueError``.

    Returns
    -------
    cell_idx : (Ns,) int array in [0, Nc-1]
        Linear cell index ``(cz*Ly + cy)*Lx + cx``. The cell coordinates are
        wrapped periodically into ``[0, L)`` along every active direction.
    sub_idx : (Ns,) int array in [0, Nb-1]
        Index of the matched basis position.

    Raises
    ------
    ValueError
        If some site cannot be matched to any ``R + b_s`` within ``tol``.
    """
    def _pad3(vec):
        # Embed a vector with up to 3 components into R^3 (zero-padded).
        flat = np.asarray(vec, dtype=float).ravel()[:3]
        out = np.zeros(3, dtype=float)
        out[:flat.size] = flat
        return out

    def _pad_rows3(arr):
        # Embed an (N, d<=3) array of points into (N, 3) (zero-padded).
        arr = np.asarray(arr, dtype=float)
        if arr.ndim == 1:
            arr = arr[:, None]
        out = np.zeros((arr.shape[0], 3), dtype=float)
        d = min(arr.shape[1], 3)
        out[:, :d] = arr[:, :d]
        return out

    coords  = _pad_rows3(self._coordinates)                                  # (Ns, 3)
    basis   = _pad_rows3(np.atleast_2d(np.asarray(self._basis, dtype=float)))  # (Nb, 3)
    Ns      = coords.shape[0]
    dim     = min(max(int(self._dim), 1), 3)

    # Only the first `dim` lattice vectors span the lattice. Using just those
    # (with a pseudo-inverse) avoids a singular matrix when unused vectors are zero.
    A       = np.column_stack([_pad3(v) for v in (self._a1, self._a2, self._a3)][:dim])  # (3, dim)
    A_pinv  = np.linalg.pinv(A)                                              # (dim, 3)
    L       = (int(self._lx), max(int(self._ly), 1), max(int(self._lz), 1))

    # Displacement of every site from every basis position: (Ns, Nb, 3).
    disp    = coords[:, None, :] - basis[None, :, :]
    # Round *after* removing the basis offset. Rounding the raw fractional
    # coordinate (np.rint rounds half to even, e.g. 1.5 -> 2) picks the wrong cell.
    n_int   = np.rint(disp @ A_pinv.T)                                       # (Ns, Nb, dim)
    resid   = disp - n_int @ A.T                                             # (Ns, Nb, 3)
    d2      = np.sum(resid * resid, axis=-1)                                 # (Ns, Nb)

    sub     = np.argmin(d2, axis=1)
    rows    = np.arange(Ns)
    if not np.all(d2[rows, sub] < tol):
        # If this trips, increase tol or check a1,a2,a3/basis consistency
        raise ValueError("Some sites could not be matched to a basis position; adjust tolerance or geometry.")

    # Integer cell coordinates of the matched decomposition, wrapped into [0, L).
    cell_xyz            = np.zeros((Ns, 3), dtype=int)
    cell_xyz[:, :dim]   = n_int[rows, sub].astype(int)
    cell_xyz            = np.mod(cell_xyz, np.asarray(L))
    cx, cy, cz          = cell_xyz[:, 0], cell_xyz[:, 1], cell_xyz[:, 2]
    Lx, Ly, _           = L
    cell                = ((cz * Ly + cy) * Lx + cx).astype(int)          # (Ns,)
    return cell, sub
# -------------------------------------------------------------------------
#! INVERSE BLOCH TRANSFORM & K-SPACE OPERATIONS
# -------------------------------------------------------------------------
[docs]
def realspace_from_kspace(
    self,
    H_k         : np.ndarray,
    *,
    block_diag  : bool = True,
    kgrid       : Optional[np.ndarray] = None,
):
    r"""
    Inverse Bloch transform: rebuild a real-space matrix from k-space data.

    This undoes :meth:`kspace_from_realspace`. In block mode the real-space
    operator is assembled from the momentum blocks as

    .. math::

        H_{\text{real}} = \sum_k W(k)^\dagger H(k) W(k),

    with :math:`W(k)` the Bloch unitary.

    Parameters
    ----------
    H_k : np.ndarray
        Momentum-space data. With ``block_diag=True`` this is either

        - a full-BZ grid of blocks, shape ``(Lx, Ly, Lz, Nb, Nb)``, as produced
          by ``kspace_from_realspace(..., block_diag=True)``, or
        - a list of blocks at custom k-points, shape ``(Nk, Nb, Nb)``.

        Blocks must be in **fftfreq order** (no fftshift) so they match the
        forward transform. With ``block_diag=False`` this is the full
        transformed ``(Ns, Ns)`` matrix.
    block_diag : bool, default=True
        Mode selector matching the forward transform:

        - ``True``: block-wise inverse, delegated to
          ``tools.lattice_kspace.realspace_from_kspace``.
        - ``False``: inverse of the full DFT, delegated to
          ``tools.lattice_kspace.full_k_space_transform`` with ``inverse=True``.
    kgrid : np.ndarray, optional
        k-points belonging to ``H_k``. Only meaningful when ``block_diag=True``.
        ``None`` means the full BZ grid in fftfreq order. Otherwise it must
        match the forward-transform k-points, with shape ``(Lx, Ly, Lz, 3)`` or
        ``(Nk, 3)``, in fftfreq order.

    Returns
    -------
    H_real : np.ndarray
        Real-space matrix of shape ``(Ns, Ns)`` with ``Ns = Nc * Nb``.

    Raises
    ------
    ValueError
        If ``kgrid`` is given together with ``block_diag=False``.

    Notes
    -----
    - Round trip: eigenvalues are preserved to machine precision (~1e-15).
      For translationally invariant systems with periodic boundary conditions
      the reconstructed matrix equals the original to numerical precision,
      and the forward and inverse transforms form an isometry.
    - ``H_k`` and ``kgrid`` must both be in fftfreq order.

    Examples
    --------
    Round trip on the full grid:

    >>> H_k, k_grid, k_frac = lattice.kspace_from_realspace(H_real, block_diag=True)
    >>> H_real_recon = lattice.realspace_from_kspace(H_k, kgrid=k_grid)
    >>> np.allclose(H_real, H_real_recon)  # True

    Without an explicit ``kgrid`` (the fftfreq grid is reconstructed):

    >>> H_real_recon = lattice.realspace_from_kspace(H_k)
    >>> np.allclose(H_real, H_real_recon)  # True

    Full-matrix (inverse DFT) mode:

    >>> H_k_full = lattice.kspace_from_realspace(H_real, block_diag=False)
    >>> H_real_recon = lattice.realspace_from_kspace(H_k_full, block_diag=False)
    >>> np.allclose(H_real, H_real_recon)  # True

    See Also
    --------
    kspace_from_realspace : Forward Bloch transform (real space to k-space).
    structure_factor : Momentum-resolved structure factors.

    References
    ----------
    .. [1] Bloch's theorem and Fourier analysis on lattices
    .. [2] Ashcroft & Mermin, "Solid State Physics" (1976), Chapter 8
    """
    from .tools.lattice_kspace import realspace_from_kspace as _inverse_bloch, full_k_space_transform as _full_dft

    if not block_diag:
        # Full (Ns, Ns) inverse DFT: a k-grid plays no role in this mode.
        if kgrid is not None:
            raise ValueError("kgrid parameter is only used with block_diag=True. For full matrix mode (block_diag=False), kgrid is not needed.")
        return _full_dft(lattice=self, mat=H_k, inverse=True)

    # Block-diagonal inverse (standard mode).
    return _inverse_bloch(lattice=self, H_k=H_k, kgrid=kgrid)
[docs]
def kspace_from_realspace(
    self,
    mat                 : np.ndarray,
    block_diag          : bool = False,
    kpoints             : Optional[np.ndarray] = None,
    unitary_norm        : bool = True,
    return_transform    : bool = False,
):
    r"""
    Transform a real-space matrix (Hamiltonian, operator, correlator) to momentum space.

    Convenience front-end to the Bloch transform for periodic systems. In
    block mode each k-block is

    .. math::

        H_{ab}(k) = \sum_{i,j} W^*_{i,a}(k) H_{i,j} W_{j,b}(k),

    with :math:`W_{i,a}(k) = \frac{1}{\sqrt{N_c}} e^{-ik \cdot r_i} \delta_{\text{sub}(i),a}`.

    Parameters
    ----------
    mat : np.ndarray
        Real-space matrix of shape ``(Ns, Ns)``, ``Ns = Nc * Nb`` (unit cells
        times basis sites per cell).
    block_diag : bool, default=False
        Output-format selector:

        - ``False``: the full transformed ``(Ns, Ns)`` matrix, computed by
          ``tools.lattice_kspace.full_k_space_transform``. This is the complete
          DFT of ``mat`` and is useful for structure factors.
        - ``True``: the block-diagonal form computed by
          ``tools.lattice_kspace.kspace_from_realspace``. This is the standard
          mode for band structures. Without custom ``kpoints`` it returns
          ``(H_k, k_grid, k_grid_frac)``:

          * ``H_k``: shape ``(Lx, Ly, Lz, Nb, Nb)``, one block per k-point
          * ``k_grid``: shape ``(Lx, Ly, Lz, 3)``, Cartesian k-points
          * ``k_grid_frac``: shape ``(Lx, Ly, Lz, 3)``, fractional k-points
    kpoints : np.ndarray, optional
        Custom k-point sampling; only allowed with ``block_diag=True``.
        ``None`` selects the automatic full-BZ grid for the lattice size, which
        is recommended for most use cases. Otherwise pass an ``(Nk, 3)`` array of
        Cartesian k-points; the result is ``(H_k, kpoints)`` with ``H_k`` of
        shape ``(Nk, Nb, Nb)``.
    unitary_norm : bool, default=True
        ``True`` uses the unitary :math:`1/\sqrt{N_c}` normalisation, the
        standard quantum-mechanics convention that preserves operator norms.
        ``False`` uses :math:`1/N_c`. This option is only forwarded in
        ``block_diag=True`` mode and is ignored for the full transform.
    return_transform : bool, default=False
        Also return the Bloch unitary ``W`` as an extra trailing output, useful
        for transforming further operators. Only allowed with
        ``block_diag=True``. ``W`` has shape ``(Lx, Ly, Lz, Ns, Nb)``, or
        ``(Nk, Ns, Nb)`` with custom k-points.

    Returns
    -------
    block_diag=False (default)
        H_k_full : np.ndarray
            Full transformed matrix, shape ``(Ns, Ns)``; no information is lost.
    block_diag=True, kpoints=None (full grid)
        H_k : np.ndarray
            Blocks of shape ``(Lx, Ly, Lz, Nb, Nb)``, where ``Lx, Ly, Lz`` are the
            lattice dimensions, ``Nb`` is the number of basis sites per cell, and
            ``H_k[ix, iy, iz]`` is the ``Nb x Nb`` block at that k-point.
        k_grid : np.ndarray
            Cartesian k-points, shape ``(Lx, Ly, Lz, 3)``.
        k_grid_frac : np.ndarray
            Fractional k-points, shape ``(Lx, Ly, Lz, 3)``, in ``[0, 1)``
            (first Brillouin zone).
        W : np.ndarray, optional
            Bloch unitary, shape ``(Lx, Ly, Lz, Ns, Nb)``; returned only if
            ``return_transform=True``. Transform operators as ``O_k = W† @ O_real @ W``.
    block_diag=True, kpoints given (custom sampling)
        H_k : np.ndarray
            Blocks of shape ``(Nk, Nb, Nb)``.
        kpoints_out : np.ndarray
            Echo of the input k-points, shape ``(Nk, 3)``.
        W : np.ndarray, optional
            Bloch unitary, shape ``(Nk, Ns, Nb)``; returned only if
            ``return_transform=True``.

    Raises
    ------
    ValueError
        If ``return_transform=True`` or ``kpoints`` is given while
        ``block_diag=False`` (checked in that order).

    Examples
    --------
    Full matrix transform for a structure factor:

    >>> H_k_full = lattice.kspace_from_realspace(H_real, block_diag=False)
    >>> # H_k_full has shape (Ns, Ns)

    Block-diagonal form for a band structure (recommended):

    >>> H_k, k_grid, k_frac = lattice.kspace_from_realspace(H_real, block_diag=True)
    >>> energies = np.linalg.eigvalsh(H_k)  # shape (Lx, Ly, Lz, Nb)

    Custom k-points, e.g. a high-symmetry path:

    >>> k_path = lattice.generate_kpath(['Γ', 'X', 'M', 'Γ'], npoints=100)
    >>> H_k, k_pts = lattice.kspace_from_realspace(H_real, block_diag=True, kpoints=k_path)
    >>> energies = np.linalg.eigvalsh(H_k)  # shape (100, Nb)

    Reusing the Bloch unitary for another operator:

    >>> H_k, k_grid, k_frac, W = lattice.kspace_from_realspace(
    ...     H_real, block_diag=True, return_transform=True
    ... )
    >>> O_k = np.einsum('kia,ij,kjb->kab', W.conj(), O_real, W)

    Notes
    -----
    - Periodic boundary conditions are assumed for the Bloch transform.
    - Translational invariance is assumed; it guarantees that the spectrum of
      ``H_real`` is the union of the spectra of the ``H(k)`` blocks.
    - The ordering of the full-grid k-points is fixed by
      ``tools.lattice_kspace``. This file's docs describe it both as fftfreq
      order (Γ at ``[0, 0, 0]``) and as fftshift-centred (Γ at
      ``[Lx//2, Ly//2, Lz//2]``); verify against the backend before relying on
      the Γ-point index. Pass the returned arrays unchanged to
      :meth:`realspace_from_kspace` for a round trip.
    - Site ordering is arbitrary: sublattices and phases are identified from
      the lattice geometry (coordinates + basis).
    - Sparse input matrices are converted to dense format automatically.

    See Also
    --------
    realspace_from_kspace : Inverse transform from k-space to real space.
    structure_factor : Momentum-resolved structure factors with reduction options.
    generate_kpath : High-symmetry k-point paths for band-structure plots.

    References
    ----------
    .. [1] Ashcroft & Mermin, "Solid State Physics" (1976), Chapter 8
    .. [2] Bloch's theorem and periodic boundary conditions
    """
    from .tools.lattice_kspace import full_k_space_transform as _full_dft, kspace_from_realspace as _bloch_blocks

    if block_diag:
        # Block-diagonal k-space form: the standard choice for band structures.
        return _bloch_blocks(
            lattice             = self,
            H_real              = mat,
            kpoints             = kpoints,
            unitary_norm        = unitary_norm,
            return_transform    = return_transform,
        )

    # Full (Ns, Ns) transform. Options that only make sense for the block form
    # are rejected, checked in this order.
    block_only_options = (
        (return_transform,    "return_transform=True is only available with block_diag=True. Use block_diag=True to get the Bloch unitary matrix."),
        (kpoints is not None, "Custom kpoints are only available with block_diag=True. Use block_diag=True for custom k-point sampling."),
    )
    for requested, message in block_only_options:
        if requested:
            raise ValueError(message)
    return _full_dft(lattice=self, mat=mat, inverse=False)
[docs]
def structure_factor(self, mat: np.ndarray, *,
                    reduction   : Literal["none", "sum", "trace", "mean", "diag"] = "sum",
                    norm        : Literal["none", "cell", "site"] = "none",
                    ):
    r"""
    Convert a real-space correlation matrix into a momentum-resolved structure factor.

    This is a convenience wrapper around the basis-aware Bloch projector in
    ``QES.general_python.lattices.tools.lattice_kspace.kspace_from_realspace``.
    The real-space input ``mat`` is first transformed into the multipartite
    k-space block representation evaluated on ``self.kvectors``:

    .. math::

        C_{\alpha\beta}(q)
        =
        \frac{1}{N_c}
        \sum_{R,R'}
        e^{-i q\cdot(R-R')}
        \langle O_{R,\alpha} O_{R',\beta} \rangle,

    where ``R, R'`` label unit cells and ``alpha, beta`` label basis sites
    inside the unit cell. The ``reduction`` argument then decides how this
    multipartite object is converted into a scalar structure factor at each
    sampled momentum ``q``.

    Parameters
    ----------
    mat : np.ndarray
        Real-space correlation or operator matrix with shape ``(Ns, Ns)``
        or batched shape ``(..., Ns, Ns)``. Any leading axes, e.g. time,
        frequency, disorder sample or state index, are preserved. Objects
        exposing ``toarray()`` (e.g. ``scipy.sparse`` matrices) are densified
        first. The batch must not be empty.
    reduction : {"none", "sum", "trace", "mean", "diag"}, default="sum"
        How to reduce the multipartite k-space blocks (case-insensitive):

        - ``"none"``: return the full blocks ``C(q)``, shape
          ``(Lx, Ly, Lz, Nb, Nb)``.
        - ``"sum"``: ``sum_{alpha,beta} C_{alpha beta}(q)``, the sum over all
          entries of each block.
        - ``"trace"``: ``sum_alpha C_{alpha alpha}(q)``, the sum over the
          diagonal of each block.
        - ``"mean"``: the arithmetic mean of all block entries at each ``q``
          (the sum divided by ``Nb^2``).
        - ``"diag"``: the eigenvalues of each block (``np.linalg.eigvals``),
          useful for spotting dominant modes or instabilities. The output shape
          is ``(Lx, Ly, Lz, Nb)``.
    norm : {"none", "cell", "site"}, default="none"
        Optional post-normalisation of the returned quantity (case-insensitive):

        - ``"none"``: keep the raw Bloch-projector normalisation, i.e. the
          blocks ``C(q)`` above with prefactor ``1 / N_c``.
        - ``"cell"``: alias for ``"none"``, kept for readability when you want
          to emphasise unit-cell normalisation.
        - ``"site"``: divide by the number of basis sites ``N_b``
          (``self.multipartity``). For scalar reductions such as ``"sum"``
          this turns unit-cell normalisation into the common site
          normalisation ``1 / N_s`` used in
          :math:`S(q) = \langle O_{-q} O_q \rangle`.

    Returns
    -------
    values : np.ndarray
        Momentum-resolved structure factor. For input shape ``(..., Ns, Ns)``:

        - ``(..., Lx, Ly, Lz, Nb, Nb)`` for ``reduction="none"``
        - ``(..., Lx, Ly, Lz, Nb)`` for ``reduction="diag"``
        - ``(..., Lx, Ly, Lz)`` for ``"sum"``, ``"trace"`` or ``"mean"``

        For a single ``(Ns, Ns)`` input the leading ``...`` is absent.
    k_grid : np.ndarray
        Cartesian sampled k-grid, shape ``(Lx, Ly, Lz, 3)``.
    k_frac : np.ndarray
        Fractional sampled k-grid, shape ``(Lx, Ly, Lz, 3)``.

    Raises
    ------
    ImportError
        If the k-space tools cannot be imported.
    ValueError
        On a wrong input shape, an empty batch, or an unknown
        ``reduction`` / ``norm``. The options are validated before the
        transform runs.

    Notes
    -----
    Use ``reduction="none"`` when sublattice-resolved information matters.
    Use one of the scalar reductions to get a single value per momentum that
    can be fed directly into ``bz_path_data``.

    The default ``norm="none"`` preserves the existing unit-cell
    normalisation. For comparisons against structure factors built from
    Fourier-transformed site operators, ``norm="site"`` is typically the
    physically relevant choice.

    Examples
    --------
    >>> Sq, k_grid, k_frac = lattice.structure_factor(corr_zz, reduction="sum")
    >>> path = lattice.bz_path_data(k_grid, k_frac, Sq, path=['Gamma', 'K', 'M', 'Gamma'])
    >>>
    >>> # Frequency-resolved data with shape (Nw, Ns, Ns)
    >>> Sqw, k_grid, k_frac = lattice.structure_factor(corr_zz_w, reduction="sum")
    >>> # Sqw has shape (Nw, Lx, Ly, Lz)
    """
    try:
        from .tools.lattice_kspace import kspace_from_realspace as _kspace_from_realspace
    except ImportError as exc:
        raise ImportError("k-space transformation tools not found. Ensure that the lattice_kspace module is available.") from exc

    # Densify sparse inputs, consistent with kspace_from_realspace. Otherwise
    # np.asarray would wrap them in a 0-d object array and the shape check
    # below would fail with a misleading message.
    if hasattr(mat, "toarray"):
        mat = mat.toarray()
    mat = np.asarray(mat)
    Ns  = self.Ns
    if mat.ndim < 2 or mat.shape[-2:] != (Ns, Ns):
        raise ValueError(f"mat must have shape (Ns, Ns) or (..., Ns, Ns) with Ns={Ns}, got {mat.shape}")

    # Validate the string options up front so a typo fails before the transform.
    reducers = {
        "none"  : lambda blocks: blocks,
        "sum"   : lambda blocks: np.sum(blocks, axis=(-2, -1)),
        "trace" : lambda blocks: np.trace(blocks, axis1=-2, axis2=-1),
        "mean"  : lambda blocks: np.mean(blocks, axis=(-2, -1)),
        "diag"  : lambda blocks: np.linalg.eigvals(blocks),   # eigenvalues of each block
    }
    reduction_key = reduction.lower()
    if reduction_key not in reducers:
        raise ValueError(f"Unknown reduction '{reduction}'. Use one of: 'none', 'sum', 'trace', 'mean', 'diag'.")

    norm_key = norm.lower()
    if norm_key in {"none", "cell"}:
        scale = 1.0
    elif norm_key == "site":
        # Blocks carry 1/N_c; dividing by N_b yields 1/N_s = 1/(N_c N_b).
        scale = 1.0 / float(max(1, self.multipartity))
    else:
        raise ValueError(f"Unknown norm '{norm}'. Use one of: 'none', 'cell', 'site'.")

    # Flatten any leading batch axes. Each matrix goes through the explicit
    # k-point Bloch projector so the transform follows the same site-coordinate
    # convention as other Fourier-based observables.
    batch_shape = mat.shape[:-2]
    flat_mat    = mat.reshape((-1, Ns, Ns))
    if flat_mat.shape[0] == 0:
        raise ValueError(f"mat has an empty batch dimension (shape {mat.shape}); nothing to transform.")

    grid_shape  = (self._lx, max(self._ly, 1), max(self._lz, 1))
    kpoints     = self.kvectors
    Ck_blocks   = []
    k_grid      = None
    for mat_i in flat_mat:
        blocks_i, k_points_i = _kspace_from_realspace(
            lattice = self,
            H_real  = mat_i,
            kpoints = kpoints,
        )
        blocks_i = np.asarray(blocks_i)
        Ck_blocks.append(blocks_i.reshape(grid_shape + blocks_i.shape[-2:]))
        if k_grid is None:
            k_grid = np.asarray(k_points_i, dtype=float).reshape(grid_shape + (3,))
    k_frac = np.asarray(self.kvectors_frac, dtype=float).reshape(grid_shape + (3,))

    # Restore the original batch shape (or drop the singleton batch axis).
    Ck = np.stack(Ck_blocks, axis=0)
    Ck = Ck.reshape(batch_shape + Ck.shape[1:]) if batch_shape else Ck[0]

    values = reducers[reduction_key](Ck)
    if scale != 1.0:
        values = values * scale
    return values, k_grid, k_frac
# -------------------------------------------------------------------------
#! Presentation helpers (text / plots)
# -------------------------------------------------------------------------
[docs]
def summary_string(self, *, precision: int = 3) -> str:
    """
    Build a human-readable summary of this lattice's metadata.

    Delegates to ``visualization.format_lattice_summary``.

    Parameters
    ----------
    precision : int, default=3
        Numeric precision forwarded to the formatter.

    Returns
    -------
    str
        The formatted summary text.
    """
    from .visualization import format_lattice_summary as _format_summary
    text = _format_summary(self, precision=precision)
    return text
[docs]
def real_space_table(self, *, max_rows: int = 10, precision: int = 3) -> str:
    """
    Render the real-space vectors of this lattice as a text table.

    Delegates to ``visualization.format_real_space_vectors``.

    Parameters
    ----------
    max_rows : int, default=10
        Row limit forwarded to the formatter.
    precision : int, default=3
        Numeric precision forwarded to the formatter.

    Returns
    -------
    str
        The formatted table.
    """
    from .visualization import format_real_space_vectors as _format_table
    table = _format_table(self, max_rows=max_rows, precision=precision)
    return table
[docs]
def reciprocal_space_table(self, *, max_rows: int = 10, precision: int = 3) -> str:
    """
    Render the reciprocal-space vectors of this lattice as a text table.

    Delegates to ``visualization.format_reciprocal_space_vectors``.

    Parameters
    ----------
    max_rows : int, default=10
        Row limit forwarded to the formatter.
    precision : int, default=3
        Numeric precision forwarded to the formatter.

    Returns
    -------
    str
        The formatted table.
    """
    from .visualization import format_reciprocal_space_vectors as _format_table
    table = _format_table(self, max_rows=max_rows, precision=precision)
    return table
[docs]
def brillouin_zone_overview(self, *, precision: int = 3) -> str:
    """
    Describe the sampled Brillouin zone of this lattice in text form.

    Delegates to ``visualization.format_brillouin_zone_overview``.

    Parameters
    ----------
    precision : int, default=3
        Numeric precision forwarded to the formatter.

    Returns
    -------
    str
        The formatted overview.
    """
    from .visualization import format_brillouin_zone_overview as _format_bz
    overview = _format_bz(self, precision=precision)
    return overview
[docs]
def describe(self, *,
            precision               : int = 3,
            max_rows                : int = 10,
            include_vectors         : bool = True,
            include_reciprocal      : bool = True,
            include_brillouin_zone  : bool = True) -> str:
    """
    Assemble a multi-section text description of the lattice.

    The summary (``summary_string``) always comes first. It is followed,
    when enabled and in this order, by the real-space table, the
    reciprocal-space table and the Brillouin-zone overview. Each of these
    is prefixed with a one-line header. Sections are separated by a blank line.

    Parameters
    ----------
    precision : int, default=3
        Forwarded to every section helper.
    max_rows : int, default=10
        Forwarded to the two vector tables.
    include_vectors, include_reciprocal, include_brillouin_zone : bool
        Toggle the corresponding optional section.

    Returns
    -------
    str
        The joined sections.
    """
    # (enabled, header, lazy renderer): only enabled sections are rendered.
    optional_sections = (
        (include_vectors,        "Real-space vectors:\n",       lambda: self.real_space_table(max_rows=max_rows, precision=precision)),
        (include_reciprocal,     "Reciprocal-space vectors:\n", lambda: self.reciprocal_space_table(max_rows=max_rows, precision=precision)),
        (include_brillouin_zone, "Brillouin zone:\n",           lambda: self.brillouin_zone_overview(precision=precision)),
    )
    parts = [self.summary_string(precision=precision)]
    parts.extend(header + render() for enabled, header, render in optional_sections if enabled)
    return "\n\n".join(parts)
# -------------------------------------------------------------------------
#! PLOTTING HELPERS
# -------------------------------------------------------------------------
[docs]
def plot_real_space(self, **kwargs):
    """
    Scatter-plot the real-space sites of this lattice.

    Thin wrapper around ``visualization.plot_real_space``. The lattice is
    passed as the first argument and every keyword argument is forwarded
    unchanged.

    Returns
    -------
    Whatever ``visualization.plot_real_space`` returns (documented as the
    matplotlib figure and axes).
    """
    from .visualization import plot_real_space as _plot_sites
    return _plot_sites(self, **kwargs)
[docs]
def plot_reciprocal_space(self, **kwargs):
    """
    Scatter-plot the reciprocal-lattice k-points of this lattice.

    Thin wrapper around ``visualization.plot_reciprocal_space``. The lattice
    itself (``self``) is passed as the first argument and every keyword
    argument is forwarded unchanged. The options mirror those of
    :meth:`plot_real_space`.

    Keyword Arguments
    -----------------
    ax : Axes, optional
        Matplotlib axes to plot on. If None, a new figure is created.
    show_indices : bool, default=False
        If True, annotate each k-point with its index.
    show_axes : bool, default=True
        If False, hide the coordinate axes.
    color : str, default="C1"
        Color of the k-point markers.
    marker : str, default="o"
        Marker style.
    figsize : tuple, optional
        Figure size in inches (width, height).
    title : str, optional
        Title of the plot.
    elev, azim : float, optional
        Elevation and azimuth angles for 3D plots.
    extend_kpoints : bool, default=False
        If True, draw translated reciprocal-space copies around the original mesh.
    extend_copies : int or iterable of int, default=2
        Number of copies per reciprocal direction used when ``extend_kpoints=True``.
        A scalar is applied to all active reciprocal directions.
    extend_tol : float, default=1e-10
        Tolerance used to identify which extended points already exist in
        the original reciprocal mesh.
    **scatter_kwargs
        Includes:

        - point_edgecolor: color of the marker edges (default "white").
        - point_zorder: z-order of the scatter points (default 5).
        - color_extended: color of translated copies (default "C2").
        - edgecolor_extended: edge color of translated copies (default "gray").
        - marker_extended: marker of translated copies (default ``marker``).
        - any other valid argument of ``ax.scatter``.

    Returns
    -------
    Whatever ``visualization.plot_reciprocal_space`` returns.
    """
    from .visualization import plot_reciprocal_space as _plot_kpoints
    return _plot_kpoints(self, **kwargs)
[docs]
def plot_brillouin_zone(self, **kwargs):
    """
    Draw the Brillouin zone of this lattice.

    Thin wrapper around ``visualization.plot_brillouin_zone``. The lattice
    itself (``self``, which provides the k-vectors) is passed as the first
    argument and every keyword argument is forwarded unchanged.

    Keyword Arguments
    -----------------
    ax : Axes, optional
        Matplotlib axes to plot on. If None, a new figure is created.
    facecolor : str, default="tab:blue"
        Fill color of the Brillouin-zone area.
    edgecolor : str, default="black"
        Color of the Brillouin-zone boundary.
    alpha : float, default=0.25
        Transparency of the fill.
    figsize : tuple, optional
        Figure size in inches (width, height).
    title : str, optional
        Title of the plot.
    elev, azim : float, optional
        Elevation and azimuth angles for 3D plots.

    Returns
    -------
    Whatever ``visualization.plot_brillouin_zone`` returns (documented as the
    matplotlib figure and axes).
    """
    from .visualization import plot_brillouin_zone as _plot_bz
    return _plot_bz(self, **kwargs)
[docs]
def plot_structure(self, **kwargs):
    """
    Draw a detailed connectivity diagram of this lattice.

    Thin wrapper around ``visualization.plot_lattice_structure``. The lattice
    is passed as the first argument and every keyword argument is forwarded
    unchanged.

    Keyword Arguments
    -----------------
    show_indices : bool
        Annotate nodes with their site indices.
    highlight_boundary : bool
        Draw boundary nodes with a distinct color/edge.
    show_axes : bool
        If False, hide the coordinate axes for a cleaner diagram.
    partition_colors : tuple of str, optional
        Colors used for bipartite/sublattice coloring. If given, nodes are
        colored by sublattice parity.
    show_periodic_connections : bool
        Indicate wrap-around connections textually or graphically.
    show_primitive_cell : bool
        Overlay the primitive unit-cell vectors/box.
    **kwargs
        Any other option understood by ``plot_lattice_structure()`` (e.g.
        node size or color map); see that function for details.

    Returns
    -------
    Whatever ``visualization.plot_lattice_structure`` returns (documented as
    the matplotlib figure and axes).
    """
    from .visualization import plot_lattice_structure as _plot_graph
    return _plot_graph(self, **kwargs)
[docs]
def plot_high_symmetry(self, **kwargs):
    """
    Convenience wrapper for plotting the Brillouin zone, high-symmetry path,
    and sampled reciprocal mesh.

    The lattice is passed as the first argument to
    ``visualization.plot_high_symmetry_points`` and every keyword argument is
    forwarded unchanged.

    Parameters
    ----------
    path : list[str], str, or iterable[(label, frac)], optional
        High-symmetry path specification. If omitted, the lattice default path
        is used.
    show_kpoints : bool, default=True
        Draw sampled reciprocal-space mesh points.
    show_bz : bool, default=True
        Draw the first Brillouin zone.
    show_path : bool, default=True
        Draw the ideal high-symmetry path.
    show_matched_kpoints : bool, default=True
        Highlight sampled k-points whose distance to the path is within the
        matching tolerance.
    points_per_seg : int, default=40
        Number of interpolation points per path segment for the ideal path.
    path_match_tol : float, optional
        Distance tolerance used when highlighting mesh points near the
        drawn path.
    extend : bool, default=False
        Draw translated copies of the sampled k-mesh.
    extend_copies : int or iterable[int], optional
        Number of reciprocal-cell copies per direction. In 2D,
        ``extend_copies=1`` includes the first shell around the first Brillouin
        zone and ``extend_copies=2`` includes the second shell as well.
    show_background_bz : bool, default=False
        Draw translated Brillouin-zone copies behind the mesh.
    hs_plot : {"none", "markers", "labels", "both"}, default="markers"
        Whether to draw exact high-symmetry markers, labels, or both.
    legend_kwargs : dict, optional
        Extra keyword arguments passed to ``axis.legend``.
    **kwargs
        Additional style overrides forwarded to ``plot_high_symmetry_points``.

    Returns
    -------
    Whatever ``visualization.plot_high_symmetry_points`` returns.

    Raises
    ------
    ImportError
        If the plotting module cannot be imported. The original import error
        is attached as ``__cause__``.
    """
    try:
        from .visualization import plot_high_symmetry_points
    except ImportError as e:
        # Chain explicitly so the underlying import failure is reported as the cause.
        raise ImportError(f"Failed to import plotting module for high-symmetry points: {e}") from e
    return plot_high_symmetry_points(self, **kwargs)
@property
def plot(self):
    """
    Plotting namespace for this lattice.

    Each access builds a fresh ``visualization.plotting.LatticePlotter``
    bound to this lattice. Per the plotter's interface it offers:

    - real_space(**kwargs)       : scatter plot of sites.
    - reciprocal_space(**kwargs) : scatter plot of reciprocal lattice vectors.
    - brillouin_zone(**kwargs)   : visualization of the Brillouin zone.
    - structure(**kwargs)        : detailed connectivity plot with boundaries.

    Example:
        >>> lat.plot.structure(show_indices=True, highlight_boundary=True)
        >>> lat.plot.brillouin_zone()
    """
    from .visualization.plotting import LatticePlotter as _Plotter
    plotter = _Plotter(self)
    return plotter
#############################################################################################################
#! SAVE LATTICE HELPERS
#############################################################################################################
[docs]
def save_bonds(lattice : Lattice, directory : str, filename : str):
    '''
    Save the forward-neighbour bond table of a honeycomb lattice to HDF5.

    The table is a float array of shape ``(lattice.ns, 3)`` initialised to
    ``-1``. Entry ``[i, nn]`` holds ``lattice.get_nei_forward(i, nn)`` for
    ``nn < lattice.get_nn_forward_num(i)``, provided the neighbour index is
    non-negative. Only ``LatticeType.HONEYCOMB`` lattices are handled; for any
    other type nothing is written.

    Parameters
    -----------
    - lattice   : lattice model
    - directory : directory to save the file
    - filename  : filename

    Returns
    -------
    - True if the file has been saved, False otherwise (unsupported lattice
      type, or an exception raised by the HDF5 writer, which is printed).

    Raises
    ------
    - ImportError if the HDF5 helper module cannot be imported.
    '''
    try:
        from ..common import hdf5man as HDF5Mod
    except ImportError as e:
        raise ImportError(f"Failed to import HDF5 module for saving bonds: {e}") from e

    if lattice.type != LatticeType.HONEYCOMB:
        return False

    bonds = -Backend.ones((lattice.ns, 3))
    for i in range(lattice.ns):
        # get_nn_forward_num returns a count (print_forward iterates range()
        # over it). Calling len() on it raised TypeError; a sequence is still accepted.
        raw         = lattice.get_nn_forward_num(i)
        num_of_nn   = len(raw) if Backend.ndim(raw) > 0 else int(raw)
        for nn in range(num_of_nn):
            nei = lattice.get_nei_forward(i, nn)
            if nei >= 0:
                bonds[i, nn] = nei

    # Best-effort save: report the failure and signal it through the return value.
    try:
        HDF5Mod.save_hdf5(directory, filename, bonds)
    except Exception as e:
        print(f"An error has occured while saving the bonds: {e}")
        return False
    return True
#############################################################################################################
#! END OF FILE
#############################################################################################################