# Source code for general_python.lattices

"""Lattice factory and registry for geometry-aware simulations.

The package provides canonical lattice classes (square, triangular, honeycomb,
hexagonal, graph) together with registry helpers for custom lattices.

Input/output contracts
----------------------
Factory functions return instances of :class:`Lattice` subclasses with explicit
geometry metadata (dimensions, boundary conditions, primitive vectors, and neighbor maps).
Typical constructor inputs are integer sizes ``(lx, ly, lz)``, a boundary mode,
and optional flux or graph descriptors.

Shape and dtype expectations
----------------------------
Coordinate arrays are expected as real-valued arrays with shape ``(ns, dim)``.
Index-based neighbor structures are integer arrays or lists over site ids in
``[0, ns)``. Plotting helpers consume NumPy-compatible arrays.

Numerical stability and determinism
-----------------------------------
Topology construction is deterministic for fixed parameters. Floating-point
roundoff can affect reciprocal-space formatting or plotting labels but should not
change connectivity.


-----------------------------------
file            : general_python/lattices/__init__.py
author          : Maksymilian Kliczkowski
email           : maxgrom97@gmail.com
license         : MIT
version         : 1.0
-----------------------------------
"""

from __future__ import annotations
from typing     import TYPE_CHECKING

import importlib
from collections                import OrderedDict
from typing                     import Any, Optional, Tuple, Type, Union
import numpy                    as np
if TYPE_CHECKING:
    import matplotlib.axes      as pltAxes
    import matplotlib.pyplot    as plt

# All type checks
if TYPE_CHECKING:
    from .                      import tools
    from .graph                 import GraphLattice
    from .hexagonal             import HexagonalLattice
    from .honeycomb             import HoneycombLattice
    from .lattice               import (
                                    BoundaryFlux,
                                    Lattice,
                                    Backend as LatticeBackend,
                                    LatticeBC,
                                    LatticeDirection,
                                    LatticeType,
                                    # Visualization helpers
                                    handle_boundary_conditions,
                                    handle_dim,
                                    HighSymmetryPoints,
                                    HighSymmetryPoint,
                                    KPathResult,
                                    StandardBZPath,
                                )
    from .square                import SquareLattice
    from .triangular            import TriangularLattice
    from .visualization import (
        format_lattice_summary,
        format_vector_table,
        format_real_space_vectors,
        format_reciprocal_space_vectors,
        format_brillouin_zone_overview,
        LatticePlotter,
        plot_real_space,
        plot_reciprocal_space,
        plot_brillouin_zone,
    )
    from .tools.region_handler import RegionType, LatticeRegionHandler
    from .tools.regions import KPRegion, LWRegion, HalfRegions, DiskRegion, PlaquetteRegion, CustomRegion, Region

# --------------------------------------------------------------------------------------------------

# Public API of the package. Most of these names are NOT bound at import time:
# they are resolved on first attribute access by the module-level ``__getattr__``
# through the export tables defined below (``from ... import *`` therefore
# triggers those lazy imports for every name listed here).
# NOTE(review): the TYPE_CHECKING block above imports only part of the region
# names (e.g. not KitaevPreskillRegion / LevinWenRegion); static checkers may
# flag the remaining entries.
__all__ = [
    # Core lattice base class, boundary/direction/type enums and backend (from ``.lattice``)
    "BoundaryFlux",
    "Lattice",
    "LatticeBC",
    "LatticeDirection",
    "LatticeType",
    "LatticeBackend",
    # Core lattice classes
    "SquareLattice",
    "HexagonalLattice",
    "HoneycombLattice",
    "TriangularLattice",
    "GraphLattice",
    # High-symmetry points and Brillouin-zone path helpers (from ``.lattice``)
    "HighSymmetryPoints",
    "HighSymmetryPoint",
    "KPathResult",
    "StandardBZPath",
    # Lattice registry helpers
    "register_lattice",
    "available_lattices",
    # Factory function - main entry point
    "choose_lattice",
    # Visualization utilities
    "plot_bonds",
    "format_lattice_summary",
    "format_vector_table",
    "format_real_space_vectors",
    "format_reciprocal_space_vectors",
    "format_brillouin_zone_overview",
    "LatticePlotter",
    "plot_real_space",
    "plot_reciprocal_space",
    "plot_brillouin_zone",
    "plot_lattice_structure",
    # Region utilities (from ``.tools`` / ``.tools.regions``)
    "tools",
    "RegionType",
    "LatticeRegionHandler",
    "Region",
    "KitaevPreskillRegion",
    "LevinWenRegion",
    "HalfRegions",
    "DiskRegion",
    "PlaquetteRegion",
    "CustomRegion",
    "KPRegion",
    "LWRegion",
    "get_predefined_region",
    "list_predefined_regions",
    # Testing utilities
    "run_lattice_tests",
]

# Type alias used to annotate ``register_lattice``'s ``lattice_cls`` parameter.
LatticeFactory      = Type[Any]
# Registry of lattice identifiers -> either a Lattice subclass or a lazy
# ``(relative_module, attribute)`` tuple resolved on demand by ``_resolve_lattice_entry``.
# Insertion order is what ``available_lattices()`` reports.
_LATTICE_REGISTRY   : "OrderedDict[str, Any]" = OrderedDict()
# Lazy export tables consumed by the module-level ``__getattr__``/``__dir__``:
# public name -> (module path relative to this package, attribute name).
# An attribute name of ``None`` means "the module itself" (see ``"tools"``).
_CORE_EXPORTS = {
    "BoundaryFlux"                      : (".lattice", "BoundaryFlux"),
    "Lattice"                           : (".lattice", "Lattice"),
    # Exported under a package-specific name to avoid a generic "Backend" symbol.
    "LatticeBackend"                    : (".lattice", "Backend"),
    "LatticeBC"                         : (".lattice", "LatticeBC"),
    "LatticeDirection"                  : (".lattice", "LatticeDirection"),
    "LatticeType"                       : (".lattice", "LatticeType"),
    "handle_boundary_conditions"        : (".lattice", "handle_boundary_conditions"),
    "handle_dim"                        : (".lattice", "handle_dim"),
    "HighSymmetryPoints"                : (".lattice", "HighSymmetryPoints"),
    "HighSymmetryPoint"                 : (".lattice", "HighSymmetryPoint"),
    "KPathResult"                       : (".lattice", "KPathResult"),
    "StandardBZPath"                    : (".lattice", "StandardBZPath"),
}
# Concrete lattice classes; these tuples are also used as the lazy registry
# entries for the default registrations.
_LATTICE_EXPORTS = {
    "SquareLattice"                     : (".square", "SquareLattice"),
    "HexagonalLattice"                  : (".hexagonal", "HexagonalLattice"),
    "HoneycombLattice"                  : (".honeycomb", "HoneycombLattice"),
    "TriangularLattice"                 : (".triangular", "TriangularLattice"),
    "GraphLattice"                      : (".graph", "GraphLattice"),
}
# Formatting and plotting helpers from the ``visualization`` submodule.
_VIS_EXPORTS = {
    "format_lattice_summary"            : (".visualization", "format_lattice_summary"),
    "format_vector_table"               : (".visualization", "format_vector_table"),
    "format_real_space_vectors"         : (".visualization", "format_real_space_vectors"),
    "format_reciprocal_space_vectors"   : (".visualization", "format_reciprocal_space_vectors"),
    "format_brillouin_zone_overview"    : (".visualization", "format_brillouin_zone_overview"),
    "LatticePlotter"                    : (".visualization", "LatticePlotter"),
    "plot_real_space"                   : (".visualization", "plot_real_space"),
    "plot_reciprocal_space"             : (".visualization", "plot_reciprocal_space"),
    "plot_brillouin_zone"               : (".visualization", "plot_brillouin_zone"),
}
# The ``tools`` subpackage itself plus its region-handler types.
_TOOLS_EXPORTS = {
    "tools"                             : (".tools", None),
    "RegionType"                        : (".tools.region_handler", "RegionType"),
    "LatticeRegionHandler"              : (".tools.region_handler", "LatticeRegionHandler"),
}
# Region classes and predefined-region helpers from ``tools.regions``.
_LAZY_REGION_EXPORTS = {
    "Region"                : (".tools.regions", "Region"),
    "KitaevPreskillRegion"  : (".tools.regions", "KitaevPreskillRegion"),
    "LevinWenRegion"        : (".tools.regions", "LevinWenRegion"),
    "HalfRegions"           : (".tools.regions", "HalfRegions"),
    "DiskRegion"            : (".tools.regions", "DiskRegion"),
    "PlaquetteRegion"       : (".tools.regions", "PlaquetteRegion"),
    "CustomRegion"          : (".tools.regions", "CustomRegion"),
    "KPRegion"              : (".tools.regions", "KPRegion"),
    "LWRegion"              : (".tools.regions", "LWRegion"),
    "get_predefined_region" : (".tools.regions", "get_predefined_region"),
    "list_predefined_regions": (".tools.regions", "list_predefined_regions"),
}

# ---------------------------------------------------------------------------------------------------

def _load_export(module_name: str, attr_name: str | None):
    module = importlib.import_module(module_name, package=__name__)
    value = module if attr_name is None else getattr(module, attr_name)
    return value

def _resolve_lattice_entry(entry: Any):
    """
    Convert a registry value into the object it stands for.

    Registry values are either already-loaded classes, which are returned as
    they are, or lazy ``(module_name, class_name)`` tuples, which are imported
    now via ``_load_export``.
    """
    if not isinstance(entry, tuple):
        return entry
    module_name, class_name = entry
    return _load_export(module_name, class_name)

def register_lattice(name: str, lattice_cls: LatticeFactory, *aliases: str, overwrite: bool = False):
    """
    Register a lattice class under ``name`` and optional ``aliases``.

    Args:
        name:
            Primary identifier accepted by :func:`choose_lattice`.
        lattice_cls:
            Either a subclass of :class:`Lattice`, or a lazy
            ``(module_name, attr_name)`` tuple of strings. A tuple is only
            imported when the lattice is first requested through
            :func:`choose_lattice`.
        *aliases:
            Additional identifiers that map to the same entry.
        overwrite:
            If True, replace existing entries instead of raising.

    Raises:
        TypeError:
            If ``lattice_cls`` is neither a ``Lattice`` subclass nor a
            two-element tuple of strings.
        KeyError:
            If any key is already registered and ``overwrite`` is False. In
            that case no key is registered.
    """
    if isinstance(lattice_cls, tuple):
        # Lazy entry: validate its shape now so a malformed entry fails at
        # registration time, but do not import anything until it is used.
        if len(lattice_cls) != 2 or not all(isinstance(part, str) for part in lattice_cls):
            raise TypeError(
                f"Lazy lattice entry must be a (module_name, attr_name) tuple of strings; got {lattice_cls!r}"
            )
    else:
        lattice_base = _load_export(".lattice", "Lattice")
        # isinstance(..., type) first: issubclass() on a non-class raises an
        # unrelated TypeError that hides the real problem.
        if not (isinstance(lattice_cls, type) and issubclass(lattice_cls, lattice_base)):
            raise TypeError(f"Registered lattice must inherit from Lattice; got {lattice_cls!r}")

    keys = (name, *aliases)
    # Check every key before writing any, so a conflict leaves the registry untouched.
    if not overwrite:
        for key in keys:
            if key in _LATTICE_REGISTRY:
                raise KeyError(f"Lattice '{key}' already registered. Pass overwrite=True to replace.")
    for key in keys:
        _LATTICE_REGISTRY[key] = lattice_cls
def available_lattices() -> Tuple[str, ...]:
    """
    List every identifier currently registered, both primary names and aliases.

    Returns:
        Tuple[str, ...]: Registry keys in insertion order.
    """
    return tuple(key for key in _LATTICE_REGISTRY)
# Default registrations: every built-in lattice is registered lazily (as an
# import tuple from _LATTICE_EXPORTS) under a lowercase short name, with its
# class name as an alias.
for _short_name, _class_name in (
    ("square",     "SquareLattice"),
    ("hexagonal",  "HexagonalLattice"),
    ("honeycomb",  "HoneycombLattice"),
    ("triangular", "TriangularLattice"),
    ("graph",      "GraphLattice"),
):
    register_lattice(_short_name, _LATTICE_EXPORTS[_class_name], _class_name)
# Do not leak the loop variables into the module namespace.
del _short_name, _class_name
# --------------------------------------------------------------------------------------------------
def plot_bonds(lattice : Lattice, ax : 'pltAxes' = None, **line_kwargs) -> 'pltAxes':
    '''
    Plot the bonds of ``lattice`` in real space.

    Site positions are the rows of ``lattice.coordinates`` (first ``dim``
    columns) multiplied by the matrix of primitive vectors ``(a1, a2, a3)``
    truncated to ``dim x dim``. Every non-zero off-diagonal entry of the dense
    adjacency matrix is drawn exactly once as a line segment. All sites are
    then drawn with ``ax.scatter``. One-dimensional lattices are drawn on the
    line ``y = 0``.

    Args:
        lattice:
            Lattice to draw. Must expose ``dim``, ``ns``, ``coordinates``,
            ``a1``, ``a2``, ``a3`` and ``adjacency_matrix(sparse=..., save=...)``.
        ax:
            Existing matplotlib Axes (a 3D Axes when ``dim == 3``). If None, a
            new figure and Axes are created.
        **line_kwargs:
            Passed to both ``ax.plot`` (bonds) and ``ax.scatter`` (sites). A
            ``color`` entry overrides the per-bond color cycle.

    Returns:
        The Axes that was drawn on.

    Raises:
        ValueError: If ``lattice`` is None, ``lattice.dim`` is not 1, 2 or 3,
            or the coordinate array is not 2D with at least ``dim`` columns.
    '''
    if lattice is None:
        raise ValueError("Lattice cannot be None")

    dim = lattice.dim
    if dim not in (1, 2, 3):
        # Validate before creating a figure or doing any linear algebra.
        raise ValueError(f"Unsupported lattice dimension: {dim}")
    Ns = lattice.ns

    # Ten colors from the shared cycle; the k-th neighbor of a site uses colors[k % 10].
    colorsCycle = _load_export("..common.plot", "colorsCycle")
    colors = [next(colorsCycle) for _ in range(10)]

    if ax is None:
        import matplotlib.pyplot as plt
        if dim == 3:
            fig = plt.figure()
            ax = fig.add_subplot(111, projection='3d')
        else:
            _, ax = plt.subplots()

    #! lattice indices
    coords = np.asarray(lattice.coordinates)
    if coords.ndim != 2 or coords.shape[1] < dim:
        raise ValueError(f"Coordinates must have at least {dim} columns, got {coords.shape}")
    coords_idx = coords[:, :dim]

    #! primitive vectors: rows a1, a2, a3, each truncated to dim components
    prims = np.vstack([np.array(vec, dtype=float)[:dim] for vec in (lattice.a1, lattice.a2, lattice.a3)])
    M = prims[:dim, :dim]                                   # shape (dim, dim)

    #! real-space positions, shape (Ns, dim); pad 1D with y = 0 so bonds are 2D lines
    pos = coords_idx.dot(M)
    pts = np.column_stack((pos[:, 0], np.zeros(pos.shape[0]))) if dim == 1 else pos

    #! adjacency
    A = lattice.adjacency_matrix(sparse=False, save=False)

    #! draw every bond once (A[i, j] and A[j, i] describe the same bond)
    drawn = set()
    for i in range(Ns):
        for ctr, j in enumerate(np.nonzero(A[i])[0]):
            j = int(j)
            pair = (min(i, j), max(i, j))
            if i == j or pair in drawn:
                continue
            drawn.add(pair)
            # User-supplied color wins; otherwise use the neighbor-slot color.
            bond_kwargs = {"color": colors[ctr % len(colors)], **line_kwargs}
            # pts[[i, j]].T has one row per axis -> ax.plot(xs, ys[, zs])
            ax.plot(*pts[[i, j]].T, **bond_kwargs)
    ax.scatter(*pts.T, **line_kwargs)

    ax.set_xlabel('$x$')
    if dim == 1:
        ax.set_yticks([])
    elif dim == 2:
        ax.set_ylabel('$y$')
        ax.set_aspect('equal', 'box')
    else:
        ax.set_ylabel('$y$')
        ax.set_zlabel('$z$')
        ax.set_box_aspect((1, 1, 1))
    return ax
def plot_lattice_structure(lattice, **kwargs):
    """
    Draw ``lattice`` using the structure plotter of the ``visualization`` submodule.

    The submodule is imported only when this function is called, so plotting
    code is not loaded when the package is imported.

    Args:
        lattice: Lattice instance to draw.
        **kwargs: Forwarded unchanged to the underlying plotter.

    Returns:
        Whatever the underlying plotter returns.
    """
    from .visualization import _plot_lattice_structure_visual as structure_plotter

    return structure_plotter(lattice, **kwargs)
#################################################################################################### #! Tests ####################################################################################################
def run_lattice_tests(dim=2, lx=5, ly=5, lz=1, bc=None, typek="square"):
    """
    Build a lattice with :func:`choose_lattice`, run smoke tests on it and print the results.

    Steps:
        1. Print ``get_nei(i)`` for every site.
        2. Print ``get_nn_forward(i)`` for every site.
        3. Check the coordinate round trip
           ``site_index(*get_coordinates(i)) == i`` for every site.
        4. For lattices with more than 1000 sites, time ``calculate_dft_matrix()``.
    Finally the structure is drawn with :func:`plot_lattice_structure`.

    Args:
        dim (int): Lattice dimension (1, 2, or 3).
        lx (int): Number of sites in the x-direction.
        ly (int): Number of sites in the y-direction (ignored if dim=1).
        lz (int): Number of sites in the z-direction (ignored if dim < 3).
        bc: Boundary condition (e.g. ``LatticeBC.PBC`` or ``LatticeBC.OBC``);
            defaults to ``LatticeBC.PBC``.
        typek (str): Registered lattice identifier (see :func:`available_lattices`).

    Returns:
        bool: True if the coordinate round trip and, when it runs, the
        performance test succeeded; False otherwise.
    """
    import time

    # If no boundary condition is provided, default to periodic
    if bc is None:
        bc = _load_export(".lattice", "LatticeBC").PBC

    lattice = choose_lattice(typek, dim=dim, lx=lx, ly=ly, lz=lz, bc=bc)
    print(f"Running tests for {lattice}")
    all_passed = True

    ## Test 1: Nearest Neighbors
    print("\n1) Testing nearest neighbors...")
    for i in range(lattice.Ns):
        neighbors = lattice.get_nei(i)
        print(f"\tSite {i}: Nearest Neighbors: {neighbors}")

    ## Test 2: Forward Nearest Neighbors
    print("\n2) Testing forward nearest neighbors...")
    for i in range(lattice.Ns):
        forward_neighbors = lattice.get_nn_forward(i)
        print(f"\tSite {i}: Forward Neighbors: {forward_neighbors}")

    ## Test 3: Coordinate Mapping -- coordinates must map back to the same site index
    print("\n3) Testing coordinate mapping...")
    mismatches = []
    for i in range(lattice.Ns):
        coords = lattice.get_coordinates(i)
        idx = lattice.site_index(*coords)
        print(f"\tSite {i}: Coordinates {coords} -> Index {idx}")
        if idx != i:
            mismatches.append(i)
    if mismatches:
        all_passed = False
        print(f"\tCoordinate mapping test FAILED for sites: {mismatches}")
    else:
        print("\tCoordinate mapping test passed!")

    ## Test 4: Performance (for large lattices)
    if lattice.Ns > 1000:
        print("\n4) Running performance test (large lattice)...")
        try:
            start_time = time.perf_counter()
            lattice.calculate_dft_matrix()
            elapsed = time.perf_counter() - start_time
        except Exception as e:
            # Report and continue: the remaining steps are still informative.
            all_passed = False
            print(f"\tPerformance test failed: {e}")
        else:
            print(f"\tPerformance test passed! Time taken: {elapsed:.2f} seconds")

    ## Generate Lattice Plot
    plot_lattice_structure(lattice)

    if all_passed:
        print(f"\n(ok) All tests completed successfully for {lattice}!")
    else:
        print(f"\n(!!) Some tests failed for {lattice}.")
    return all_passed
####################################################################################################
#! Lattice type resolution
####################################################################################################

def _handle_type(typek):
    """
    Resolve an identifier (string or ``LatticeType``) to a registered lattice class.

    Resolution rules:
        * ``None`` -> the lattice registered as ``"square"``.
        * ``str`` -> the exact registry key if present. Otherwise the stripped,
          lower-cased key, so ``"Square"`` or ``" honeycomb "`` also work.
        * ``LatticeType`` member -> the registry key equal to the member's
          lower-cased name, e.g. ``LatticeType.HONEYCOMB`` -> ``"honeycomb"``.
          Every registered built-in, including ``"triangular"``, is reachable
          this way.

    Returns:
        The lattice class. A lazy registry tuple is imported on demand.

    Raises:
        ValueError: For unknown strings, enum members without a registered
            lattice, or any other identifier type.
    """
    if typek is None:
        return _resolve_lattice_entry(_LATTICE_REGISTRY["square"])

    if isinstance(typek, str):
        entry = _LATTICE_REGISTRY.get(typek)
        if entry is None:
            # Case/whitespace-insensitive fallback; exact keys still take precedence.
            entry = _LATTICE_REGISTRY.get(typek.strip().lower())
        if entry is None:
            raise ValueError(f"Unknown lattice type '{typek}'. Available: {available_lattices()}.")
        return _resolve_lattice_entry(entry)

    lattice_type_enum = _load_export(".lattice", "LatticeType")
    if isinstance(typek, lattice_type_enum):
        # Enum member names mirror the registry's lowercase short names
        # (SQUARE -> "square", HONEYCOMB -> "honeycomb", ...).
        entry = _LATTICE_REGISTRY.get(typek.name.lower())
        if entry is None:
            raise ValueError(f"Unsupported lattice type enum {typek!r}.")
        return _resolve_lattice_entry(entry)

    raise ValueError(f"Unknown lattice type: {typek!r}")
def choose_lattice(typek    : Optional[Union[str, LatticeType]] = 'square',
                dim         : Optional[int]                     = None,
                lx          : Optional[int]                     = 1,
                ly          : Optional[int]                     = 1,
                lz          : Optional[int]                     = 1,
                bc          : Optional[Union[str, LatticeBC]]   = None,
                flux        : Optional[Union[float, BoundaryFlux, dict]] = None,
                **kwargs):
    """
    Returns an instance of a lattice of the desired type.

    Args:
        typek (str | LatticeType):
            Registered lattice identifier (e.g. "square", "hexagonal",
            "honeycomb", "triangular", "graph") or a ``LatticeType`` member.
            None selects "square".
        dim (int):
            Dimension (1, 2, or 3). If given, extents along excluded axes are
            dropped (``ly`` is ignored for dim=1, ``lz`` for dim<3) and the
            lattice is built with exactly this dimension. If None, the
            dimension is inferred from ``(lx, ly, lz)``.
        lx (int): Number of sites in x-direction.
        ly (int): Number of sites in y-direction.
        lz (int): Number of sites in z-direction (ignored if dim < 3).
        bc: Boundary condition (e.g., LatticeBC.PBC or LatticeBC.OBC).
        flux:
            Optional boundary flux specification forwarded to the lattice
            constructor. Accepts a scalar phase, :class:`BoundaryFlux`, or a
            mapping from directions to phases (in radians).
        **kwargs: Forwarded to the lattice constructor.

    Returns:
        Lattice: An instance of the desired lattice.

    Raises:
        ValueError: If ``dim`` is given and is not 1, 2 or 3, or if ``typek``
            cannot be resolved.
    """
    #! handle boundary conditions
    handle_boundary_conditions = _load_export(".lattice", "handle_boundary_conditions")
    _bc = handle_boundary_conditions(bc)

    #! handle type
    _class = _handle_type(typek)
    graph_cls = _load_export(".graph", "GraphLattice")
    if issubclass(_class, graph_cls):
        # Graph lattices are constructed without dim/size arguments.
        return _class(bc=_bc, flux=flux, **kwargs)

    #! handle dimensions
    if dim is not None:
        if dim not in (1, 2, 3):
            raise ValueError(f"dim must be 1, 2 or 3; got {dim!r}")
        # Honour the requested dimension: drop extents along excluded axes
        # before size-based inference (previously `dim` was silently ignored).
        if dim < 2:
            ly = 1
        if dim < 3:
            lz = 1
    handle_dim = _load_export(".lattice", "handle_dim")
    inferred_dim, lx, ly, lz = handle_dim(lx, ly, lz)
    if dim is None:
        dim = inferred_dim
    return _class(dim=dim, lx=lx, ly=ly, lz=lz, bc=_bc, flux=flux, **kwargs)
# Export tables searched, in order, by the module-level __getattr__ / __dir__ (PEP 562).
_LAZY_EXPORT_TABLES = (_CORE_EXPORTS, _LATTICE_EXPORTS, _VIS_EXPORTS, _TOOLS_EXPORTS, _LAZY_REGION_EXPORTS)


def __getattr__(name: str):
    """
    Lazily resolve public names listed in the export tables.

    The first table containing ``name`` wins. The resolved object is cached in
    the module globals, so later lookups bypass this hook.

    Raises:
        AttributeError: If ``name`` is not in any export table.
    """
    for table in _LAZY_EXPORT_TABLES:
        if name in table:
            module_name, attr_name = table[name]
            value = _load_export(module_name, attr_name)
            globals()[name] = value
            return value
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


def __dir__():
    """
    Return the module globals plus every lazily exported name.

    The result is sorted and free of duplicates, even after lazy names have
    been cached in the globals.
    """
    names = set(globals())
    for table in _LAZY_EXPORT_TABLES:
        names.update(table)
    return sorted(names)

####################################################################################################
#! EOF
####################################################################################################