"""Lattice factory and registry for geometry-aware simulations.
The package provides canonical lattice classes (square, triangular, honeycomb,
hexagonal, graph) together with registry helpers for custom lattices.
Input/output contracts
----------------------
Factory functions return subclasses of :class:`Lattice` with explicit geometry
metadata (dimensions, boundary conditions, primitive vectors, and neighbor maps).
Typical constructor inputs are integer sizes ``(lx, ly, lz)``, a boundary mode,
and optional flux or graph descriptors.
Shape and dtype expectations
----------------------------
Coordinate arrays are expected as real-valued arrays with shape ``(ns, dim)``.
Index-based neighbor structures are integer arrays or lists over site ids in
``[0, ns)``. Plotting helpers consume NumPy-compatible arrays.
Numerical stability and determinism
-----------------------------------
Topology construction is deterministic for fixed parameters. Floating-point
roundoff can affect reciprocal-space formatting or plotting labels but should not
change connectivity.
-----------------------------------
file : general_python/lattices/__init__.py
author : Maksymilian Kliczkowski
email : maxgrom97@gmail.com
license : MIT
version : 1.0
-----------------------------------
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import importlib
from collections import OrderedDict
from typing import Any, Optional, Tuple, Type, Union
import numpy as np
if TYPE_CHECKING:
import matplotlib.axes as pltAxes
import matplotlib.pyplot as plt
# All type checks
if TYPE_CHECKING:
from . import tools
from .graph import GraphLattice
from .hexagonal import HexagonalLattice
from .honeycomb import HoneycombLattice
from .lattice import (
BoundaryFlux,
Lattice,
Backend as LatticeBackend,
LatticeBC,
LatticeDirection,
LatticeType,
# Visualization helpers
handle_boundary_conditions,
handle_dim,
HighSymmetryPoints,
HighSymmetryPoint,
KPathResult,
StandardBZPath,
)
from .square import SquareLattice
from .triangular import TriangularLattice
from .visualization import (
format_lattice_summary,
format_vector_table,
format_real_space_vectors,
format_reciprocal_space_vectors,
format_brillouin_zone_overview,
LatticePlotter,
plot_real_space,
plot_reciprocal_space,
plot_brillouin_zone,
)
from .tools.region_handler import RegionType, LatticeRegionHandler
from .tools.regions import KPRegion, LWRegion, HalfRegions, DiskRegion, PlaquetteRegion, CustomRegion, Region
# --------------------------------------------------------------------------------------------------
# Public API for ``from <package> import *``. Most entries are not bound at import
# time (their imports above are TYPE_CHECKING-only); a star-import fetches every
# name listed here, which resolves the lazy ones through the module ``__getattr__``.
__all__ = [
    # Core lattice types (lazily exported from .lattice)
    "BoundaryFlux",
    "Lattice",
    "LatticeBC",
    "LatticeDirection",
    "LatticeType",
    "LatticeBackend",
    # Core lattice classes
    "SquareLattice",
    "HexagonalLattice",
    "HoneycombLattice",
    "TriangularLattice",
    "GraphLattice",
    # Brillouin-zone / high-symmetry-point helpers (lazily exported from .lattice)
    "HighSymmetryPoints",
    "HighSymmetryPoint",
    "KPathResult",
    "StandardBZPath",
    # Factory functions - registry
    "register_lattice",
    "available_lattices",
    # Factory function - main entry point
    "choose_lattice",
    # Visualization utilities
    "plot_bonds",
    "format_lattice_summary",
    "format_vector_table",
    "format_real_space_vectors",
    "format_reciprocal_space_vectors",
    "format_brillouin_zone_overview",
    "LatticePlotter",
    "plot_real_space",
    "plot_reciprocal_space",
    "plot_brillouin_zone",
    "plot_lattice_structure",
    # Region utilities (lazily exported from .tools, .tools.region_handler, .tools.regions)
    "tools",
    "RegionType",
    "LatticeRegionHandler",
    "Region",
    "KitaevPreskillRegion",
    "LevinWenRegion",
    "HalfRegions",
    "DiskRegion",
    "PlaquetteRegion",
    "CustomRegion",
    "KPRegion",
    "LWRegion",
    "get_predefined_region",
    "list_predefined_regions",
    # Testing utilities
    "run_lattice_tests",
]
# Annotation alias for the ``lattice_cls`` argument of ``register_lattice``.
LatticeFactory = Type[Any]
# Registry: key (primary name or alias) -> lattice class, or a lazy
# ``(module_name, class_name)`` tuple that ``_resolve_lattice_entry`` imports on
# demand. Insertion order is kept, so ``available_lattices()`` reports keys in
# registration order.
_LATTICE_REGISTRY : "OrderedDict[str, Any]" = OrderedDict()
# Lazy-export tables consumed by ``__getattr__`` and ``__dir__``:
#   public name -> (module name relative to this package, attribute name).
# An attribute name of ``None`` exports the module object itself (see ``_load_export``).
_CORE_EXPORTS = {
    "BoundaryFlux"                  : (".lattice", "BoundaryFlux"),
    "Lattice"                       : (".lattice", "Lattice"),
    "LatticeBackend"                : (".lattice", "Backend"),
    "LatticeBC"                     : (".lattice", "LatticeBC"),
    "LatticeDirection"              : (".lattice", "LatticeDirection"),
    "LatticeType"                   : (".lattice", "LatticeType"),
    "handle_boundary_conditions"    : (".lattice", "handle_boundary_conditions"),
    "handle_dim"                    : (".lattice", "handle_dim"),
    "HighSymmetryPoints"            : (".lattice", "HighSymmetryPoints"),
    "HighSymmetryPoint"             : (".lattice", "HighSymmetryPoint"),
    "KPathResult"                   : (".lattice", "KPathResult"),
    "StandardBZPath"                : (".lattice", "StandardBZPath"),
}
# Concrete lattice classes; these tuples double as the lazy registry entries for
# the default registrations below.
_LATTICE_EXPORTS = {
    "SquareLattice"                 : (".square", "SquareLattice"),
    "HexagonalLattice"              : (".hexagonal", "HexagonalLattice"),
    "HoneycombLattice"              : (".honeycomb", "HoneycombLattice"),
    "TriangularLattice"             : (".triangular", "TriangularLattice"),
    "GraphLattice"                  : (".graph", "GraphLattice"),
}
_VIS_EXPORTS = {
    "format_lattice_summary"            : (".visualization", "format_lattice_summary"),
    "format_vector_table"               : (".visualization", "format_vector_table"),
    "format_real_space_vectors"         : (".visualization", "format_real_space_vectors"),
    "format_reciprocal_space_vectors"   : (".visualization", "format_reciprocal_space_vectors"),
    "format_brillouin_zone_overview"    : (".visualization", "format_brillouin_zone_overview"),
    "LatticePlotter"                    : (".visualization", "LatticePlotter"),
    "plot_real_space"                   : (".visualization", "plot_real_space"),
    "plot_reciprocal_space"             : (".visualization", "plot_reciprocal_space"),
    "plot_brillouin_zone"               : (".visualization", "plot_brillouin_zone"),
}
_TOOLS_EXPORTS = {
    # ``None`` attribute: ``tools`` resolves to the subpackage module itself.
    "tools"                         : (".tools", None),
    "RegionType"                    : (".tools.region_handler", "RegionType"),
    "LatticeRegionHandler"          : (".tools.region_handler", "LatticeRegionHandler"),
}
# NOTE(review): KitaevPreskillRegion and LevinWenRegion are not in the
# TYPE_CHECKING import list of .tools.regions above; confirm they exist there,
# otherwise accessing them raises AttributeError at lookup time.
_LAZY_REGION_EXPORTS = {
    "Region"                : (".tools.regions", "Region"),
    "KitaevPreskillRegion"  : (".tools.regions", "KitaevPreskillRegion"),
    "LevinWenRegion"        : (".tools.regions", "LevinWenRegion"),
    "HalfRegions"           : (".tools.regions", "HalfRegions"),
    "DiskRegion"            : (".tools.regions", "DiskRegion"),
    "PlaquetteRegion"       : (".tools.regions", "PlaquetteRegion"),
    "CustomRegion"          : (".tools.regions", "CustomRegion"),
    "KPRegion"              : (".tools.regions", "KPRegion"),
    "LWRegion"              : (".tools.regions", "LWRegion"),
    "get_predefined_region" : (".tools.regions", "get_predefined_region"),
    "list_predefined_regions": (".tools.regions", "list_predefined_regions"),
}
# ---------------------------------------------------------------------------------------------------
def _load_export(module_name: str, attr_name: str | None):
    """
    Import ``module_name`` and return the module or one of its attributes.

    Relative module names (leading dots) are anchored at this package, because
    ``__name__`` is passed as the ``package`` argument of ``importlib.import_module``.

    Args:
        module_name: absolute or package-relative module name.
        attr_name: attribute to fetch from the module, or ``None`` for the module itself.

    Raises:
        ImportError: if the module cannot be imported.
        AttributeError: if the module has no attribute ``attr_name``.
    """
    target = importlib.import_module(module_name, package=__name__)
    if attr_name is None:
        return target
    return getattr(target, attr_name)
def _resolve_lattice_entry(entry: Any):
    """
    Turn a registry entry into a usable object.

    Lazy entries are ``(module_name, class_name)`` tuples and get imported via
    ``_load_export``; any other value (typically a class) is returned as-is.
    """
    if not isinstance(entry, tuple):
        return entry
    module_name, class_name = entry
    return _load_export(module_name, class_name)
[docs]
def register_lattice(name: str, lattice_cls: LatticeFactory, *aliases: str, overwrite: bool = False):
    """
    Register a lattice class under ``name`` and optional ``aliases``.

    Args:
        name: primary registry key.
        lattice_cls: either a subclass of :class:`Lattice`, or a lazy
            ``(module_name, class_name)`` pair of strings that is imported on first
            use (relative module names are resolved against this package).
        *aliases: extra keys that map to the same entry.
        overwrite: replace existing keys instead of raising.

    Raises:
        TypeError: if ``lattice_cls`` is neither a ``(str, str)`` pair nor a
            subclass of ``Lattice``.
        KeyError: if any key is already registered and ``overwrite`` is False.
    """
    if isinstance(lattice_cls, tuple):
        # Lazy entry: check its shape only. Importing ``.lattice`` here would make
        # the default (tuple) registrations import it at package-import time,
        # defeating the lazy-export design.
        if len(lattice_cls) != 2 or not all(isinstance(part, str) for part in lattice_cls):
            raise TypeError(
                f"Lazy lattice entries must be (module_name, class_name) string pairs; got {lattice_cls!r}"
            )
    else:
        lattice_base = _load_export(".lattice", "Lattice")
        # isinstance(..., type) first: issubclass() on a non-class raises a generic TypeError.
        if not (isinstance(lattice_cls, type) and issubclass(lattice_cls, lattice_base)):
            raise TypeError(f"Registered lattice must inherit from Lattice; got {lattice_cls!r}")

    keys = (name, *aliases)
    # Validate every key before mutating, so a failure leaves the registry unchanged.
    for key in keys:
        if not overwrite and key in _LATTICE_REGISTRY:
            raise KeyError(f"Lattice '{key}' already registered. Pass overwrite=True to replace.")
    for key in keys:
        _LATTICE_REGISTRY[key] = lattice_cls
[docs]
def available_lattices() -> Tuple[str, ...]:
    """
    List every identifier currently present in the lattice registry.

    Returns:
        Tuple[str, ...]: registry keys (primary names and aliases) in insertion order.
    """
    return (*_LATTICE_REGISTRY,)
# Default registrations: each built-in lattice is stored as its lazy
# ``(module, class)`` tuple under a short key, with the class name as an alias.
for _default_key, _default_export in (
    ("square",      "SquareLattice"),
    ("hexagonal",   "HexagonalLattice"),
    ("honeycomb",   "HoneycombLattice"),
    ("triangular",  "TriangularLattice"),
    ("graph",       "GraphLattice"),
):
    register_lattice(_default_key, _LATTICE_EXPORTS[_default_export], _default_export)
# Do not leak the loop variables into the module namespace / dir().
del _default_key, _default_export
# --------------------------------------------------------------------------------------------------
[docs]
def plot_bonds(lattice  : Lattice,
               ax       : 'pltAxes' = None,
               **line_kwargs) -> 'pltAxes':
    '''
    Plot physical bonds of the lattice using primitive vectors (a1, a2, a3).

    Site positions are ``coordinates[:, :dim] @ P`` where ``P`` stacks the first
    ``dim`` components of ``a1, a2, a3`` (first ``dim`` rows used). Every nonzero
    entry of the dense adjacency matrix is drawn as a segment. A symmetric pair
    ``A[i, j]`` / ``A[j, i]`` is drawn once, and self-loops are skipped.

    Args:
        lattice (Lattice):
            lattice to draw; must expose ``dim``, ``ns``, ``coordinates``,
            ``a1``, ``a2``, ``a3`` and ``adjacency_matrix(sparse=..., save=...)``.
        ax (Axes):
            existing matplotlib Axes; a new one (3D when ``dim == 3``) if None.
        **line_kwargs:
            passed to ``ax.plot`` (bonds) and ``ax.scatter`` (sites). A ``color``
            kwarg is used for every bond instead of the per-neighbor color cycle.
            ``c`` is forwarded to ``ax.scatter`` only.

    Returns:
        Axes: the axes that were drawn on.

    Raises:
        ValueError: if ``lattice`` is None, ``dim`` is not 1, 2 or 3, or the
            coordinates have fewer than ``dim`` columns.
    '''
    if lattice is None:
        raise ValueError("Lattice cannot be None")

    dim = lattice.dim
    if dim not in (1, 2, 3):
        raise ValueError(f"Unsupported lattice dimension: {dim}")
    Ns = lattice.ns

    #! real-space positions (validated before any figure is created)
    coords = np.asarray(lattice.coordinates)
    if coords.ndim != 2 or coords.shape[1] < dim:
        raise ValueError(f"Coordinates must have at least {dim} columns, got {coords.shape}")
    prims = np.vstack([np.array(vec, dtype=float)[:dim]
                       for vec in (lattice.a1, lattice.a2, lattice.a3)])   # shape (3, dim)
    pos = coords[:, :dim].dot(prims[:dim, :dim])                            # shape (Ns, dim)
    if dim == 1:
        # Draw a chain along the x-axis at y = 0.
        pos = np.column_stack((pos[:, 0], np.zeros(pos.shape[0])))

    #! adjacency as a plain ndarray (rows of an np.matrix stay 2-D and break np.nonzero)
    A = np.asarray(lattice.adjacency_matrix(sparse=False, save=False))

    #! colors: per-neighbor cycle unless the caller fixes one color for all bonds
    colorsCycle = _load_export("..common.plot", "colorsCycle")
    colors = [next(colorsCycle) for _ in range(10)]
    bond_kwargs = dict(line_kwargs)
    # ``color`` would clash with the explicit color= below (TypeError: multiple
    # values); ``c`` is scatter's marker-color argument and an alias of color in plot.
    fixed_color = bond_kwargs.pop("color", None)
    bond_kwargs.pop("c", None)

    if ax is None:
        import matplotlib.pyplot as plt
        if dim == 3:
            fig = plt.figure()
            ax = fig.add_subplot(111, projection='3d')
        else:
            _, ax = plt.subplots()

    #! draw bonds
    for i in range(Ns):
        for ctr, j in enumerate(np.nonzero(A[i])[0]):
            # Skip self-loops and bonds already drawn from row j (reverse entry present).
            if j == i or (j < i and A[j, i] != 0):
                continue
            color = fixed_color if fixed_color is not None else colors[ctr % len(colors)]
            # pos[[i, j]].T unpacks into per-axis coordinate pairs: (xs, ys[, zs]).
            ax.plot(*pos[[i, j]].T, color=color, **bond_kwargs)
    ax.scatter(*pos.T, **line_kwargs)

    #! axes decoration
    ax.set_xlabel('$x$')
    if dim == 1:
        ax.set_yticks([])
    else:
        ax.set_ylabel('$y$')
    if dim == 2:
        ax.set_aspect('equal', 'box')
    elif dim == 3:
        ax.set_zlabel('$z$')
        ax.set_box_aspect((1, 1, 1))
    return ax
[docs]
def plot_lattice_structure(lattice, **kwargs):
    """
    Draw the structure of ``lattice`` using the visualization module.

    Thin wrapper around ``visualization._plot_lattice_structure_visual``. The import
    happens at call time, keeping it out of package import. All keyword arguments
    are forwarded unchanged and the plotter's return value is passed through.
    """
    from .visualization import _plot_lattice_structure_visual as _plot_impl
    return _plot_impl(lattice, **kwargs)
####################################################################################################
#! Tests
####################################################################################################
[docs]
def run_lattice_tests(dim=2, lx=5, ly=5, lz=1, bc=None, typek="square"):
    """
    Run automated smoke tests for a lattice in 1D, 2D, or 3D.

    The lattice is built with :func:`choose_lattice`. The routine then:
      1. prints the nearest neighbors of every site;
      2. prints the forward neighbors of every site;
      3. checks the round trip ``site_index(*get_coordinates(i)) == i`` for every site;
      4. times ``calculate_dft_matrix`` when the lattice has more than 1000 sites;
      5. draws the lattice with :func:`plot_lattice_structure`.

    Args:
        dim (int): Lattice dimension (1, 2, or 3)
        lx (int): Number of sites in the x-direction
        ly (int): Number of sites in the y-direction (ignored, i.e. set to 1, if dim=1)
        lz (int): Number of sites in the z-direction (ignored, i.e. set to 1, if dim < 3)
        bc : Boundary condition (e.g., LatticeBC.PBC or LatticeBC.OBC); defaults to PBC
        typek (str) : Registered lattice type (e.g. "square", "hexagonal", "honeycomb")
    """
    import time  # not imported at module level

    # If no boundary condition is provided, default to periodic
    if bc is None:
        bc = _load_export(".lattice", "LatticeBC").PBC

    # Honor the documented contract: extents beyond ``dim`` are ignored.
    if dim is not None:
        if dim < 2:
            ly = 1
        if dim < 3:
            lz = 1

    lattice = choose_lattice(typek, dim=dim, lx=lx, ly=ly, lz=lz, bc=bc)
    print(f"Running tests for {lattice}")
    failures = []

    ## Test 1: Nearest Neighbors
    print("\n1) Testing nearest neighbors...")
    for i in range(lattice.Ns):
        neighbors = lattice.get_nei(i)
        print(f"\tSite {i}: Nearest Neighbors: {neighbors}")

    ## Test 2: Forward Nearest Neighbors
    print("\n2) Testing forward nearest neighbors...")
    for i in range(lattice.Ns):
        forward_neighbors = lattice.get_nn_forward(i)
        print(f"\tSite {i}: Forward Neighbors: {forward_neighbors}")

    ## Test 3: Coordinate Mapping (site -> coordinates -> site must be the identity)
    print("\n3) Testing coordinate mapping...")
    mismatched_sites = []
    for i in range(lattice.Ns):
        coords = lattice.get_coordinates(i)
        idx = lattice.site_index(*coords)
        print(f"\tSite {i}: Coordinates {coords} -> Index {idx}")
        if idx != i:
            mismatched_sites.append(i)
    if mismatched_sites:
        print(f"\tCoordinate mapping test FAILED for sites: {mismatched_sites}")
        failures.append("coordinate mapping")
    else:
        print("\tCoordinate mapping test passed!")

    ## Test 4: Performance (for large lattices)
    if lattice.Ns > 1000:
        print("\n4) Running performance test (large lattice)...")
        try:
            start_time = time.perf_counter()
            lattice.calculate_dft_matrix()
            elapsed = time.perf_counter() - start_time
        except Exception as e:
            # Best-effort smoke test: report instead of aborting the remaining steps.
            print(f"\tPerformance test failed: {e}")
            failures.append("performance")
        else:
            print(f"\tPerformance test passed! Time taken: {elapsed:.2f} seconds")

    ## Generate Lattice Plot
    plot_lattice_structure(lattice)
    if failures:
        print(f"\n(!) Tests completed with failures ({', '.join(failures)}) for {lattice}.")
    else:
        print(f"\n(ok) All tests completed successfully for {lattice}!")
####################################################################################################
#! Lattice type resolution and factory
####################################################################################################
def _handle_type(typek):
    """
    Resolve an identifier (string or ``LatticeType``) to a registered lattice class.

    Resolution rules:
      * ``None`` -> the ``"square"`` registry entry.
      * ``str`` -> exact registry key. If that key is absent, retry with the
        stripped, lower-cased string, so ``"Square"`` finds ``"square"``.
      * ``LatticeType`` member -> SQUARE/HEXAGONAL/HONEYCOMB/GRAPH map to their
        lower-case keys. Any other member is looked up by its lower-cased ``name``,
        so e.g. a TRIANGULAR member, if the enum defines one, maps to ``"triangular"``.

    Returns:
        The lattice class; lazy ``(module, class)`` registry entries are imported.

    Raises:
        ValueError: for unknown strings, unsupported enum members, or other types.
    """
    if typek is None:
        return _resolve_lattice_entry(_LATTICE_REGISTRY["square"])

    if isinstance(typek, str):
        entry = _LATTICE_REGISTRY.get(typek)
        if entry is None:
            entry = _LATTICE_REGISTRY.get(typek.strip().lower())
        if entry is None:
            raise ValueError(f"Unknown lattice type '{typek}'. Available: {available_lattices()}.")
        return _resolve_lattice_entry(entry)

    lattice_type_enum = _load_export(".lattice", "LatticeType")
    if isinstance(typek, lattice_type_enum):
        # Explicit mapping kept for backward compatibility. Built defensively:
        # hasattr() avoids AttributeError if the enum lacks one of these members.
        known = (("SQUARE", "square"), ("HEXAGONAL", "hexagonal"),
                 ("HONEYCOMB", "honeycomb"), ("GRAPH", "graph"))
        mapping = {getattr(lattice_type_enum, member): key
                   for member, key in known if hasattr(lattice_type_enum, member)}
        key = mapping.get(typek)
        if key is None:
            # Generic fallback: any member whose lower-cased name is a registry key.
            member_name = getattr(typek, "name", None)
            key = member_name.lower() if isinstance(member_name, str) else None
        if key is None or key not in _LATTICE_REGISTRY:
            raise ValueError(f"Unsupported lattice type enum {typek!r}.")
        return _resolve_lattice_entry(_LATTICE_REGISTRY[key])

    raise ValueError(f"Unknown lattice type: {typek!r}")
[docs]
def choose_lattice(typek : Optional[str] = 'square',
                   dim   : Optional[int] = None,
                   lx    : Optional[int] = 1,
                   ly    : Optional[int] = 1,
                   lz    : Optional[int] = 1,
                   bc    : Optional[Union[str, LatticeBC]] = None,
                   flux  : Optional[Union[float, BoundaryFlux, dict]] = None,
                   **kwargs):
    """
    Returns an instance of a lattice of the desired type.

    Args:
        typek (str | LatticeType):
            Registered lattice type, e.g. "square", "hexagonal", "honeycomb",
            "triangular", "graph" (see :func:`available_lattices`), or a
            ``LatticeType`` member.
        dim (int):
            Dimension (1, 2, or 3). When given, ``ly`` is ignored (set to 1) for
            ``dim < 2`` and ``lz`` is ignored (set to 1) for ``dim < 3``; the final
            dimension is then derived from the extents by ``handle_dim``.
        lx (int):
            Number of sites in x-direction
        ly (int):
            Number of sites in y-direction (ignored if dim < 2)
        lz (int):
            Number of sites in z-direction (ignored if dim < 3)
        bc: Boundary condition (e.g., LatticeBC.PBC or LatticeBC.OBC)
        flux:
            Optional boundary flux specification forwarded to the lattice
            constructor. Accepts a scalar phase, :class:`BoundaryFlux`, or a
            mapping from directions to phases (in radians).
        **kwargs:
            Forwarded to the lattice constructor.

    Returns:
        Lattice: An instance of the desired lattice. Graph lattices receive only
        ``bc``, ``flux`` and ``kwargs`` (no dimension/extent arguments).
    """
    #! handle boundary conditions
    handle_boundary_conditions = _load_export(".lattice", "handle_boundary_conditions")
    handle_dim = _load_export(".lattice", "handle_dim")
    _bc = handle_boundary_conditions(bc)

    #! handle dimensions
    # ``handle_dim`` only sees the extents, so an explicit ``dim`` is honored by
    # discarding extents along axes beyond it (previously ``dim`` was silently dropped).
    if dim is not None:
        if dim < 2:
            ly = 1
        if dim < 3:
            lz = 1
    dim, lx, ly, lz = handle_dim(lx, ly, lz)

    #! handle type
    _class = _handle_type(typek)
    graph_cls = _load_export(".graph", "GraphLattice")
    if issubclass(_class, graph_cls):
        return _class(bc=_bc, flux=flux, **kwargs)
    return _class(dim=dim, lx=lx, ly=ly, lz=lz, bc=_bc, flux=flux, **kwargs)
def __getattr__(name: str):
    """
    Module-level attribute hook (PEP 562): resolve lazy exports on first access.

    The export tables are searched in the order core, lattices, visualization,
    tools, regions. The first hit is imported through ``_load_export`` and cached
    in the module globals, so later lookups no longer reach this hook.

    Raises:
        AttributeError: if ``name`` appears in none of the export tables.
    """
    export_tables = (_CORE_EXPORTS, _LATTICE_EXPORTS, _VIS_EXPORTS, _TOOLS_EXPORTS, _LAZY_REGION_EXPORTS)
    table = next((tbl for tbl in export_tables if name in tbl), None)
    if table is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    module_name, attr_name = table[name]
    resolved = _load_export(module_name, attr_name)
    globals()[name] = resolved
    return resolved
def __dir__():
    """
    Module-level ``dir()`` hook (PEP 562): module globals plus every lazy export.

    Names already materialized into globals by ``__getattr__`` appear in both
    sources. The union is therefore de-duplicated before sorting, so each name is
    listed once.
    """
    names = set(globals())
    for table in (_CORE_EXPORTS, _LATTICE_EXPORTS, _VIS_EXPORTS, _TOOLS_EXPORTS, _LAZY_REGION_EXPORTS):
        names.update(table)
    return sorted(names)
####################################################################################################
#! EOF
####################################################################################################