API Reference

This section contains the detailed API reference for the modules in the General Python Utilities library. The documentation below is generated automatically by Sphinx using autodoc, listing the classes, functions, and constants available in each module.

Algebra Module

Backend-aware linear algebra interfaces and solver entry points.

This package exposes numerical kernels used across the project:

  • Krylov and direct linear solvers.

  • Preconditioner abstractions.

  • Backend helpers for NumPy/JAX interoperability.

  • Random wrappers used in reproducible scientific workflows.

Input/output and dtype contracts

Most public APIs accept array-like vectors and matrices that are converted to the active backend where possible. Shapes follow linear-algebra conventions: for example, A has shape (n, n) and b has shape (n,) or (n, k). Dtype promotion follows backend rules; explicit float64 or complex128 is recommended for ill-conditioned problems.

Numerical stability and determinism

Stability depends on solver choice, conditioning, and preconditioning quality. For reproducibility, set random seeds via algebra.ran_wrapper and keep the backend selection fixed within a run. NumPy and JAX results should agree up to floating-point roundoff; small differences can appear due to kernel fusion and reduction order.

The module uses lazy imports to keep import time low.

general_python.algebra.overlap(v1, mat, v2: object | None = None)

Compute the overlap v1^H @ mat @ v2 using the active backend.

general_python.algebra.norm(v, ord=None)

Compute the norm of vector v using the active backend.

general_python.algebra.matvec(mat, vec)

Compute the matrix-vector product mat @ vec using the active backend.

general_python.algebra.project(v, basis)

Project vector v onto the subspace spanned by the basis vectors.
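For orientation, the four top-level helpers map onto the plain NumPy expressions below. This is a sketch that assumes the NumPy backend and dense inputs. It also assumes that basis stores the spanning vectors as columns and that overlap reuses v1 when v2 is omitted; neither convention is stated above. overlap_np and project_np are hypothetical names, and the projection uses a QR factorization so that the basis need not be orthonormal.

```python
import numpy as np

def overlap_np(v1, mat, v2=None):
    # v1^H @ mat @ v2; assumes v2 defaults to v1 when omitted
    v2 = v1 if v2 is None else v2
    return np.conj(v1) @ (mat @ v2)

def project_np(v, basis):
    # Orthogonal projection onto span(columns of basis);
    # the QR step orthonormalizes the basis first
    q, _ = np.linalg.qr(basis)
    return q @ (np.conj(q).T @ v)

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 4))
v = rng.standard_normal(4)
basis = rng.standard_normal((4, 2))   # two spanning vectors stored as columns

print(np.linalg.norm(v))      # norm(v)
print(A @ v)                  # matvec(A, v)
print(overlap_np(v, A))       # overlap(v, A)
p = project_np(v, basis)      # project(v, basis)
print(np.allclose(basis.T @ (v - p), 0.0))  # True: the residual is orthogonal to the basis
```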

general_python.algebra.utils

This module provides utilities for importing and managing the linear algebra backend (NumPy/JAX), random number generation, JIT compilation, and backend configuration for the general_python package.

Features

  • BackendManager: Class to detect and manage the active backend (NumPy/JAX), including linear algebra, random number generation, and SciPy modules.

  • Functions to retrieve backend components (get_backend, get_global_backend).

  • Global access to the active backend via the backend_mgr instance.

  • Utilities for JIT compilation (maybe_jit), hardware info (get_hardware_info), and array padding (pad_array).

  • Clear version reporting and backend status printing.

Usage

Import the backend manager and utilities:

from general_python.algebra.utils import get_backend, backend_mgr, maybe_jit

xp = backend_mgr.np            # NumPy or JAX numpy, depending on the active backend
rng = backend_mgr.default_rng  # default NumPy random generator

@maybe_jit
def my_func(x):
    return xp.sum(x)

Testing

To run tests for this module and the general_python package:

pytest
# or
python -m unittest discover -s Python/test

Test coverage includes:

  • Singleton identity and initialization

  • Operator algebra correctness

  • Monte Carlo sampling

  • NQS training and evaluation

See the test/ directory for details.

Email : maksymilian.kliczkowski@pwr.edu.pl

general_python.algebra.utils.JIT(x)
general_python.algebra.utils.jax_jit(x)
general_python.algebra.utils.ACTIVE_JIT(x)
general_python.algebra.utils.is_jax_array(x: Any) → bool

Checks if an object is likely a JAX array (including traced).

Parameters:

x (Any) – The object to check.

Returns:

True if x is a JAX array (including traced), False otherwise.

Return type:

bool

general_python.algebra.utils.is_traced_jax(x: Any) → bool

Checks if an object is likely a JAX array (including traced).

Parameters:

x (Any) – The object to check.

Returns:

True if x is a JAX array (including traced), False otherwise.

Return type:

bool

general_python.algebra.utils.get_backend(backend_spec: str | Any | None = None, random: bool = False, seed: int | None = None, scipy: bool = False) → Any | Tuple[Any, ...]

Return backend modules based on the provided specifier.

Delegates to the global backend_mgr.get_backend_modules.

Parameters:
  • backend_spec (str or module or None, optional) – Backend specifier (“numpy”, “jax”, np, jnp, “default”, None). Defaults to the globally active backend.

  • random (bool, optional) – If True, include the random module/state. For JAX, a PRNG key is also returned; for NumPy, a seeded RNG instance. Default is False. With the JAX backend, split the returned PRNG key before use.

  • seed (int, optional) – Seed for the random component. If None, uses the global default seed to generate the component for this call. Providing a seed here creates a new RNG/Key for this call based on that specific seed.

  • scipy (bool, optional) – If True, also return the associated SciPy module. Default is False.

Returns:

module or tuple – Requested backend components. See BackendManager.get_backend_modules docs.

Example

Using the JAX backend with random number generation:

>>> import general_python.algebra.utils as abu
>>> if abu.JAX_AVAILABLE:
...     jax_np, (jax_rnd, key), jax_sp = abu.get_backend("jax", random=True, scipy=True, seed=42)
...     key, subkey = jax_rnd.split(key)  # Split the key!
...     random_vector = jax_rnd.uniform(subkey, shape=(5,))  # Use subkey
...     print(random_vector)

general_python.algebra.utils.get_global_backend(random: bool = False, seed: int | None = None, scipy: bool = False) → Any | Tuple[Any, ...]

Return the globally configured default backend modules.

Delegates to backend_mgr.get_global_backend_modules.

Parameters:
  • random (bool, optional) – If True, include the random module/state and potentially a key/RNG.

  • seed (int, optional) – Optional seed for this specific request’s random component.

  • scipy (bool, optional) – If True, also return the associated SciPy module.

Returns:

The global default backend module(s). See BackendManager.get_backend_modules.

Return type:

module or tuple

general_python.algebra.utils.maybe_jit(func)

Conditionally apply JAX JIT compilation to func.

general_python.algebra.utils.distinguish_type(typek: Any, backend: Literal['numpy', 'jax'] = 'numpy') → Type[generic] | Any

Given a type (e.g. np.float32, jnp.int64, or int), return the corresponding dtype object in either NumPy or JAX.

Parameters:
  • typek – A dtype class or Python int; may be from numpy, jax.numpy, or the builtin int.

  • backend – ‘numpy’ or ‘jax’ — which library the returned dtype should belong to.

Returns:

The requested dtype object (e.g. np.float64 or jnp.int32).

Return type:

dtype

Raises:

ValueError – If typek is not one of the supported types, or if backend='jax' is requested but JAX is not available.

general_python.algebra.utils.get_hardware_info() → Tuple[int, int]

Get the number of available JAX devices and CPU cores.

Returns:

  • n_devices (int) – Number of JAX devices (e.g., GPUs/TPUs) if JAX is available, else 0.

  • n_threads (int) – Number of CPU cores available to the system.

Return type:

Tuple[int, int]

class general_python.algebra.utils.RNGManager(np_rng: Generator | None, jax_rng: Any | None, py_rng: Random | None)

Bases: object

Container for synchronized random-number generator state.

np_rng

NumPy random generator used by NumPy-backed helpers.

Type:

numpy.random._generator.Generator | None

jax_rng

JAX PRNG key or key-like state used by JAX-backed helpers.

Type:

Any | None

py_rng

Python random.Random instance for standard-library randomness.

Type:

random.Random | None

np_rng: Generator | None
jax_rng: Any | None
py_rng: Random | None
__init__(np_rng: Generator | None, jax_rng: Any | None, py_rng: Random | None) → None

class general_python.algebra.utils.BackendManager(default_seed: int = 42, prefer_jax: bool = False)

Bases: object

Manages the numerical backend (NumPy or JAX) state.

Provides access to the appropriate linear algebra module (np/jnp), random number generator, SciPy module, JIT compiler, and backend info.

is_jax_available

True if JAX was successfully imported.

Type:

bool

name

Name of the active backend (“numpy” or “jax”).

Type:

str

np

The active array module (numpy or jax.numpy).

Type:

module

random

The active random module (numpy.random or jax.random).

Type:

module

scipy

The active SciPy module (scipy or jax.scipy).

Type:

module

key

The default JAX PRNG key (if JAX is active).

Type:

Optional[PRNGKey]

jit

The JIT compiler function (jax.jit or identity).

Type:

Callable

default_seed

The seed used for default RNG initialization.

Type:

int

default_rng

Default NumPy RNG instance.

Type:

np.random.Generator | np_random

default_jax_key

Default JAX key instance.

Type:

Optional[PRNGKey]

int_dtype

Default integer type for the active backend.

Type:

Type

float_dtype

Default float type for the active backend.

Type:

Type

complex_dtype

Default complex type for the active backend.

Type:

Type

__init__(default_seed: int = 42, prefer_jax: bool = False)

Initializes the manager, detects JAX, and sets the active backend.

Parameters:
  • default_seed – The seed for initializing default random generators.

  • prefer_jax – If True and JAX is available, use JAX as the default. Otherwise, use NumPy.

default_seed: int
is_jax_available: bool
name: str
np: Any
random: Any
scipy: Any
key: PRNGKey | None
jit: Callable
default_jax_key: PRNGKey | None
detected_jax_backend: str | None
detected_jax_devices: List[Device] | None
set_active_backend(name: str)

Explicitly sets the active backend globally managed by this instance.

Parameters:

name – “numpy”, “jax”

Raises:

ValueError – If ‘jax’ is requested but not available, or if name is not a valid backend name.

print_info()

Print the backend configuration as a formatted table: the active backend, the available libraries and their versions, and the active integer, float, and complex types.

get_backend_modules(backend_spec: str | Any | None, use_random: bool = False, seed: int | None = None, use_scipy: bool = False) → Any | Tuple[Any, ...]

Returns backend modules based on the specifier.

Parameters:
  • backend_spec – Backend identifier. Can be a string (“numpy”, “np”, “jax”, “jnp”, “default”), a module (numpy or jax.numpy), or None to use the manager’s active backend.

  • use_random – If True, include the random module/state. For JAX, returns (jax.random, key). For NumPy, returns (rng_instance, None).

  • seed – Seed for the random number generator. If None, uses the manager’s default seed. If provided, creates a new RNG/Key for this request, independent of the manager’s default state.

  • use_scipy – If True, include the SciPy module.

Returns:

The requested module(s) as a single module or a tuple. Format:

(main_module) or (main_module, random_part, scipy_module) where random_part is (rng_instance, None) for numpy or (jrn_module, key) for jax.

get_global_backend_modules(use_random: bool = False, seed: int | None = None, use_scipy: bool = False) → Any | Tuple[Any, ...]

Returns the globally configured default backend modules from the manager.

Uses the manager’s current active backend (self.name). If a specific seed is provided, it generates a random component based on that seed for this request, otherwise uses the manager’s default seed/key framework.

Parameters:
  • use_random – If True, include the random module/state.

  • seed – Optional seed for this specific request’s random component.

  • use_scipy – If True, include the SciPy module.

Returns:

The global default backend module(s). See get_backend_modules for format.

reseed(seed: int) → RNGManager

Reseed the manager’s RNGs without doing work at import time. Returns an RNGManager instance you can stash if needed.

next_key() → None

Return a fresh JAX subkey and advance the manager’s internal key.

split_keys(n: int) → Any

Return n fresh subkeys and advance the manager’s internal key once.

seed_scope(seed: int, *, touch_numpy_global: bool = False, touch_python_random: bool = True)

Temporarily set deterministic seeds and restore the previous states on exit. Use it as a context manager:

with backend_mgr.seed_scope(seed):
    # Your code here
    ...

Parameters:
  • seed (int) – The seed to set.

  • touch_numpy_global (bool) – Whether to seed the global NumPy random state and restore it on exit.

  • touch_python_random (bool) – Whether to seed the Python random module state and restore it on exit.
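The save, seed, and restore pattern that seed_scope describes can be sketched with contextlib. This is a hypothetical stand-in named seed_scope_sketch, not the library's implementation. It handles only the Python random module and, optionally, NumPy's global state, mirroring the two flags above. It does not touch the manager's own generators or JAX keys.

```python
import random
from contextlib import contextmanager

import numpy as np

@contextmanager
def seed_scope_sketch(seed, *, touch_numpy_global=False, touch_python_random=True):
    # Snapshot only the states we are about to overwrite
    py_state = random.getstate() if touch_python_random else None
    np_state = np.random.get_state() if touch_numpy_global else None
    try:
        if touch_python_random:
            random.seed(seed)
        if touch_numpy_global:
            np.random.seed(seed)
        yield
    finally:
        # Restore previous states even if the body raised
        if py_state is not None:
            random.setstate(py_state)
        if np_state is not None:
            np.random.set_state(np_state)

with seed_scope_sketch(123):
    a = random.random()
with seed_scope_sketch(123):
    b = random.random()
print(a == b)  # True: same seed, same draw
```

Restoring in a finally clause means that an exception inside the block cannot leak the temporary seed into later code.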

spawn_np_generators(n: int) → list[Generator]

Create n independent NumPy generators using SeedSequence. Use one per worker/process to avoid correlated streams.
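The documented SeedSequence approach corresponds to the standard NumPy pattern below. spawn_generators is a hypothetical helper name, and the sketch does not claim to be how spawn_np_generators is written.

```python
import numpy as np

def spawn_generators(seed, n):
    # One parent SeedSequence spawns n statistically independent child sequences
    children = np.random.SeedSequence(seed).spawn(n)
    return [np.random.default_rng(child) for child in children]

gens = spawn_generators(42, 3)
draws = [g.integers(0, 2**32, size=4) for g in gens]
# Reproducible: the same seed yields the same per-worker streams
again = [g.integers(0, 2**32, size=4) for g in spawn_generators(42, 3)]
print(all((d == e).all() for d, e in zip(draws, again)))  # True
```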

spawn_jax_keys(n: int)

Deterministically produce n independent JAX keys.

general_python.algebra.utils.pad_array(x, target_size: int, pad_value, *, backend=None)

Pad a one-dimensional array to target_size with pad_value.

Parameters:
  • x – Input one-dimensional NumPy or JAX array.

  • target_size – Desired output length. Must be at least x.shape[0].

  • pad_value – Value used to initialize positions beyond the input length.

  • backend – Optional array module. If omitted, the backend is inferred from x.

Returns:

Array with shape (target_size,) and the same dtype as x.

Return type:

array-like
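A NumPy-only sketch of the documented pad_array contract: one-dimensional input, target_size at least x.shape[0], and the dtype preserved. The name pad_array_sketch and the ValueError on bad input are assumptions made for illustration.

```python
import numpy as np

def pad_array_sketch(x, target_size, pad_value):
    # Mirrors the documented contract: 1D input, target_size >= len(x), dtype preserved
    x = np.asarray(x)
    if x.ndim != 1 or target_size < x.shape[0]:
        raise ValueError("x must be 1D and target_size >= x.shape[0]")
    out = np.full((target_size,), pad_value, dtype=x.dtype)
    out[: x.shape[0]] = x
    return out

print(pad_array_sketch(np.array([1, 2, 3]), 5, -1).tolist())  # [1, 2, 3, -1, -1]
```

Under JAX, arrays are immutable, so the slice assignment would instead be written as out.at[:n].set(x).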

general_python/algebra/backend_linalg.py

Unified linear algebra backend providing matrix operations, decompositions, and transformations with support for multiple backends (NumPy, JAX).

This module consolidates linalg.py and linalg_sparse.py functionality into a single backend-aware interface.

Features:
  • Matrix transformations and basis changes

  • Outer and Kronecker products (dense and sparse)

  • Inner products and overlaps

  • Matrix properties (trace, Hilbert-Schmidt norm)

  • Matrix creation (identity, etc.)

  • Eigendecomposition (dense, sparse, Lanczos)

  • State manipulation (Givens rotations)

  • Backend support: NumPy, SciPy, JAX

Type Safety:
  • Proper dtype handling and promotion

  • Complex values preserved through computations

  • Backend-specific type conversions handled transparently

Author: Maksymilian Kliczkowski Email: maksymilian.kliczkowski@pwr.edu.pl

general_python.algebra.backend_linalg.change_basis(unitary_matrix: ndarray | Any, state_vector: ndarray | Any, backend: str = 'default') → ndarray | Any

Transform state vector to new basis.

$V' = U V$

Parameters:
  • unitary_matrix (array-like, shape (N, N)) – Unitary transformation matrix.

  • state_vector (array-like, shape (N,) or (N, 1)) – State vector to transform.

  • backend (str, optional) – Numerical backend (default: ‘default’).

Returns:

Transformed state vector.

Return type:

Array

general_python.algebra.backend_linalg.change_basis_matrix(unitary_matrix: ndarray | Any, matrix: ndarray | Any, direction: Literal['forward', 'backward'] = 'forward', backend: str = 'default') → ndarray | Any

Change basis of matrix using unitary transformation.

Forward: $A' = U A U^\dagger$. Backward: $A' = U^\dagger A U$.

Parameters:
  • unitary_matrix (array-like, shape (N, N)) – Unitary transformation matrix U.

  • matrix (array-like, shape (N, N)) – Matrix to transform A.

  • direction ({'forward', 'backward'}, optional) – Transformation direction (default: ‘forward’).

  • backend (str, optional) – Numerical backend (default: ‘default’).

Returns:

Transformed matrix.

Return type:

Array
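The two basis-change conventions can be checked numerically with plain NumPy. This sketch assumes that U is unitary, which the functions require but do not appear to enforce. It shows that the backward transform undoes the forward one and that matrix elements are basis-independent.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 4
# Random unitary from the QR decomposition of a complex Gaussian matrix
z = rng.standard_normal((n, n)) + 1j * rng.standard_normal((n, n))
U, _ = np.linalg.qr(z)
A = rng.standard_normal((n, n))
v = rng.standard_normal(n) + 0j

A_fwd = U @ A @ U.conj().T          # forward: A' = U A U^dagger
A_back = U.conj().T @ A_fwd @ U     # backward undoes forward
v_new = U @ v                       # change_basis: V' = U V

print(np.allclose(A_back, A))                                      # True
# Matrix elements are basis-independent: <v|A|v> = <Uv|U A U^dagger|Uv>
print(np.isclose(v.conj() @ A @ v, v_new.conj() @ A_fwd @ v_new))  # True
```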

general_python.algebra.backend_linalg.outer(A: ndarray | Any, B: ndarray | Any, backend: str = 'default') → ndarray | Any

Compute outer product of two vectors.

$C = A \otimes B$ (element-wise: C[i,j] = A[i] * B[j])

Parameters:
  • A (array-like, shape (N,)) – First vector.

  • B (array-like, shape (M,)) – Second vector.

  • backend (str, optional) – Numerical backend (default: ‘default’).

Returns:

Outer product, shape (N, M).

Return type:

Array

general_python.algebra.backend_linalg.kron(A: ndarray | Any, B: ndarray | Any, backend: str = 'default') → ndarray | Any

Compute Kronecker product of two matrices (dense).

$C = A \otimes B$

Parameters:
  • A (array-like, shape (N, M)) – First matrix.

  • B (array-like, shape (P, Q)) – Second matrix.

  • backend (str, optional) – Numerical backend (default: ‘default’).

Returns:

Kronecker product, shape (N*P, M*Q).

Return type:

Array
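The shape rules documented for outer and kron can be checked with np.outer and np.kron. This assumes the NumPy backend maps onto these functions, which is not stated above. Note that np.outer does not conjugate either argument, which matches C[i,j] = A[i] * B[j].

```python
import numpy as np

a = np.array([1.0, 2.0])
b = np.array([3.0, 4.0, 5.0])
C = np.outer(a, b)                 # C[i, j] = a[i] * b[j], shape (N, M) = (2, 3)
print(C.shape)                     # (2, 3)

A = np.eye(2)                      # shape (N, M) = (2, 2)
B = np.arange(6.0).reshape(2, 3)   # shape (P, Q) = (2, 3)
K = np.kron(A, B)                  # shape (N*P, M*Q) = (4, 6)
print(K.shape)                     # (4, 6)
# Block structure: block (i, j) of K equals A[i, j] * B
print(np.array_equal(K[2:4, 3:6], A[1, 1] * B))  # True
```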

general_python.algebra.backend_linalg.kron_sparse(A: ndarray | Any, B: ndarray | Any, backend: str = 'default') → ndarray | Any

Compute Kronecker product of sparse matrices.

Preserves sparsity structure efficiently.

Parameters:
  • A (array-like or sparse, shape (N, M)) – First matrix.

  • B (array-like or sparse, shape (P, Q)) – Second matrix.

  • backend (str, optional) – Numerical backend (default: ‘default’).

Returns:

Kronecker product, sparse if JAX BCOO available, else dense.

Return type:

Array

general_python.algebra.backend_linalg.inner(vec1: ndarray | Any, vec2: ndarray | Any, backend: str = 'default') → ndarray | Any

Compute inner product of two vectors.

$\langle v_1 | v_2 \rangle = v_1^\dagger \cdot v_2$

Parameters:
  • vec1 (array-like, shape (N,)) – First vector.

  • vec2 (array-like, shape (N,)) – Second vector.

  • backend (str, optional) – Numerical backend (default: ‘default’).

Returns:

Inner product (complex if inputs complex, real otherwise).

Return type:

scalar

general_python.algebra.backend_linalg.ket_bra(vec: ndarray | Any, backend: str = 'default') → ndarray | Any

Compute ket-bra (outer product) of a vector.

$|\Psi\rangle\langle\Psi| = \Psi \Psi^\dagger$

Parameters:
  • vec (array-like, shape (N,)) – Vector.

  • backend (str, optional) – Numerical backend (default: ‘default’).

Returns:

Outer product matrix, shape (N, N).

Return type:

Array

general_python.algebra.backend_linalg.bra_ket(vec: ndarray | Any, backend: str = 'default') → ndarray | Any

Compute bra-ket (inner product) of a vector with itself.

$\langle\Psi|\Psi\rangle = \Psi^\dagger \cdot \Psi = \|\Psi\|^2$

Parameters:
  • vec (array-like, shape (N,)) – Vector.

  • backend (str, optional) – Numerical backend (default: ‘default’).

Returns:

Squared norm (always real and non-negative).

Return type:

scalar
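The conjugation conventions in inner, ket_bra, and bra_ket match the NumPy idioms below. This is a sketch assuming the NumPy backend. A common pitfall is using np.dot for the inner product, because np.dot does not conjugate its first argument; np.vdot does.

```python
import numpy as np

psi = np.array([1 + 1j, 2 - 1j])
phi = np.array([0.5j, 1.0])

ip = np.vdot(psi, phi)                 # <psi|phi>: np.vdot conjugates its first argument
rho = np.outer(psi, psi.conj())        # |psi><psi| = psi psi^dagger, shape (2, 2)
nrm2 = np.vdot(psi, psi).real          # <psi|psi> = ||psi||^2, here 2 + 5 = 7

print(np.isclose(nrm2, np.linalg.norm(psi) ** 2))  # True
print(np.allclose(rho, rho.conj().T))              # True: the ket-bra is Hermitian
print(np.isclose(np.trace(rho), nrm2))             # True
```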

general_python.algebra.backend_linalg.overlap(a: ndarray | Any, O: ndarray | Any, b: ndarray | Any | None = None, backend: str = 'default') → ndarray | Any

Compute matrix element <a|O|b>.

Supports:
  • 1D vectors: returns scalar

  • 2D matrices (columns are states): returns matrix of overlaps

  • Mixed dimensions: returns vector

Parameters:
  • a (array-like, shape (dim,) or (dim, n)) – Left states. Columns are individual states if 2D.

  • O (array-like or sparse, shape (dim, dim)) – Operator matrix.

  • b (array-like, shape (dim,) or (dim, m), optional) – Right states. If None, defaults to a.

  • backend (str, optional) – Numerical backend (default: ‘default’).

Returns:

<a|O|b>. Shape depends on input dimensions.

Return type:

scalar, 1D or 2D Array

Examples

>>> # Single pair of vectors: scalar result
>>> s = overlap(psi, H, phi)  # shape: (), dtype: complex
>>> # Multiple right states: vector result
>>> v = overlap(psi, H, phi_array)  # shape: (n,), dtype: complex
>>> # Multiple left and right: matrix result
>>> M = overlap(psi_array, H, phi_array)  # shape: (m, n), dtype: complex

general_python.algebra.backend_linalg.overlap_diagonal(a: ndarray | Any, O: ndarray | Any, b: ndarray | Any | None = None, backend: str = 'default') → ndarray | Any

Compute only diagonal elements <a_i|O|b_i>.

More efficient than the full overlap() when only the diagonal elements are needed.

Parameters:
  • a (array-like, shape (dim,) or (dim, n)) – Left states (columns are states if 2D).

  • O (array-like, shape (dim, dim)) – Operator matrix.

  • b (array-like, shape (dim,) or (dim, n), optional) – Right states. If None, defaults to a.

  • backend (str, optional) – Numerical backend (default: ‘default’).

Returns:

Diagonal elements <a_i|O|b_i>.

Return type:

scalar or 1D Array
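With states stored as columns, the full and diagonal overlaps reduce to the NumPy expressions below. This is a sketch, not the library's code. The diagonal variant uses a column-wise sum, which avoids forming the full matrix of overlaps and shows why overlap_diagonal can be cheaper.

```python
import numpy as np

rng = np.random.default_rng(2)
dim, m, n = 5, 3, 4
O = rng.standard_normal((dim, dim))
a = rng.standard_normal((dim, m)) + 1j * rng.standard_normal((dim, m))  # m left states (columns)
b = rng.standard_normal((dim, n)) + 1j * rng.standard_normal((dim, n))  # n right states (columns)

full = a.conj().T @ O @ b                   # all <a_i|O|b_j>, shape (m, n)
print(full.shape)                           # (3, 4)

# Diagonal-only variant needs equally many left and right states
b_sq = b[:, :m]
diag = np.sum(a.conj() * (O @ b_sq), axis=0)   # <a_i|O|b_i>, shape (m,)
print(np.allclose(diag, np.diag(a.conj().T @ O @ b_sq)))  # True
```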

general_python.algebra.backend_linalg.trace(matrix: ndarray | Any, backend: str = 'default') → Any

Compute matrix trace.

$\mathrm{Tr}(A) = \sum_i A_{ii}$

Parameters:
  • matrix (array-like, shape (N, N)) – Matrix to compute trace of.

  • backend (str, optional) – Numerical backend (default: ‘default’).

Returns:

Trace of matrix.

Return type:

scalar

general_python.algebra.backend_linalg.hilbert_schmidt_norm(matrix: ndarray | Any, backend: str = 'default') → Any

Compute Hilbert-Schmidt norm of matrix.

$\|A\|_{HS} = \sqrt{\mathrm{Tr}(A^\dagger A)}$

Parameters:
  • matrix (array-like, shape (N, N)) – Matrix to compute norm of.

  • backend (str, optional) – Numerical backend (default: ‘default’).

Returns:

Hilbert-Schmidt norm.

Return type:

scalar (real)

general_python.algebra.backend_linalg.frobenius_norm(matrix: ndarray | Any, backend: str = 'default') → Any

Compute Frobenius norm of matrix.

$\|A\|_F = \sqrt{\sum_{ij} |A_{ij}|^2}$

Parameters:
  • matrix (array-like, shape (N, M)) – Matrix to compute norm of.

  • backend (str, optional) – Numerical backend (default: ‘default’).

Returns:

Frobenius norm.

Return type:

scalar (real)
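The Hilbert-Schmidt and Frobenius norms are the same quantity, because Tr(A†A) equals the sum of |A_ij|² over all entries. The NumPy sketch below checks this on a small complex matrix; it uses plain NumPy calls rather than the library functions.

```python
import numpy as np

A = np.array([[1 + 2j, 0.5], [-1j, 3.0]])

tr = np.trace(A)                                   # sum of diagonal entries
hs = np.sqrt(np.trace(A.conj().T @ A).real)        # sqrt(Tr(A^dagger A))
fro = np.linalg.norm(A, "fro")                     # sqrt(sum |A_ij|^2)

print(tr)                   # (4+2j)
print(np.isclose(hs, fro))  # True: the two norms coincide
```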

general_python.algebra.backend_linalg.identity(n: int, dtype: str | dtype | None = None, backend: str = 'default') → ndarray | Any

Create identity matrix I_n.

Parameters:
  • n (int) – Size of identity matrix.

  • dtype (data-type, optional) – Element data type (default: backend default).

  • backend (str, optional) – Numerical backend (default: ‘default’).

Returns:

Identity matrix, shape (n, n).

Return type:

Array

general_python.algebra.backend_linalg.identity_sparse(n: int, dtype: str | dtype | None = None, backend: str = 'default') → csr_matrix | BCOO

Create sparse identity matrix.

Parameters:
  • n (int) – Size of identity matrix.

  • dtype (data-type, optional) – Element data type (default: float64).

  • backend (str, optional) – Numerical backend (default: ‘default’).

Returns:

Sparse identity: scipy.sparse.csr_matrix (numpy) or BCOO (jax).

Return type:

sparse matrix

general_python.algebra.backend_linalg.to_dense(matrix: ndarray | Any, backend: str = 'default') → ndarray | Any

Convert sparse or other formats to dense array.

Parameters:
  • matrix (array-like) – Matrix to convert.

  • backend (str, optional) – Numerical backend (default: ‘default’).

Returns:

Dense array.

Return type:

Array

general_python.algebra.backend_linalg.to_sparse(matrix: ndarray | Any, backend: str = 'default', format: str = 'csr') → csr_matrix | BCOO

Convert dense matrix to sparse format.

Parameters:
  • matrix (array-like, shape (N, M)) – Matrix to convert.

  • backend (str, optional) – Numerical backend (default: ‘default’).

  • format (str, optional) – Sparse format: ‘csr’, ‘csc’, ‘coo’, ‘bsr’ (numpy) or ‘bcoo’ (jax). Default: ‘csr’.

Returns:

Sparse matrix in specified format.

Return type:

sparse matrix

general_python.algebra.backend_linalg.eig(matrix: ndarray | Any, backend: str = 'default', **kwargs) → Tuple[ndarray | Any, ndarray | Any]

Eigendecomposition of general matrix.

Computes all eigenvalues and eigenvectors.

Parameters:
  • matrix (array-like, shape (N, N)) – Matrix to diagonalize.

  • backend (str, optional) – Numerical backend (default: ‘default’).

  • **kwargs – Additional arguments (backend-specific).

Returns:

  • eigenvalues (Array, shape (N,)) – Eigenvalues.

  • eigenvectors (Array, shape (N, N)) – Eigenvectors as columns.

general_python.algebra.backend_linalg.eigh(matrix: ndarray | Any, backend: str = 'default', **kwargs) → Tuple[ndarray | Any, ndarray | Any]

Eigendecomposition of Hermitian matrix (dense).

More stable than eig() for Hermitian/symmetric matrices.

Parameters:
  • matrix (array-like, shape (N, N)) – Hermitian matrix to diagonalize.

  • backend (str, optional) – Numerical backend (default: ‘default’).

  • **kwargs – Additional arguments (backend-specific).

Returns:

  • eigenvalues (Array, shape (N,), real) – Eigenvalues in ascending order.

  • eigenvectors (Array, shape (N, N)) – Eigenvectors as columns.
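The documented eigh return contract can be illustrated with numpy.linalg.eigh: real eigenvalues in ascending order and eigenvectors as columns. Whether the NumPy backend delegates to this exact function is an assumption. The checks confirm the ordering, the eigen-equation column by column, and orthonormality.

```python
import numpy as np

rng = np.random.default_rng(3)
X = rng.standard_normal((4, 4)) + 1j * rng.standard_normal((4, 4))
H = (X + X.conj().T) / 2                 # Hermitian test matrix

w, V = np.linalg.eigh(H)                 # real eigenvalues, ascending; eigenvectors as columns
print(np.all(np.diff(w) >= 0))                 # True
print(np.allclose(H @ V, V * w))               # True: H v_k = w_k v_k for every column k
print(np.allclose(V.conj().T @ V, np.eye(4)))  # True: orthonormal eigenvectors
```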

general_python.algebra.backend_linalg.eigsh(matrix: ndarray | Any, k: int = 6, which: Literal['smallest', 'largest'] = 'smallest', backend: str = 'default', **kwargs) → Tuple[ndarray | Any, ndarray | Any]

Partial eigendecomposition of Hermitian matrix (sparse).

Computes k extremal eigenvalues and eigenvectors. Uses SciPy sparse methods.

Parameters:
  • matrix (array-like, shape (N, N)) – Hermitian matrix to diagonalize.

  • k (int, optional) – Number of eigenvalues to compute (default: 6).

  • which ({'smallest', 'largest'}, optional) – Which eigenvalues to compute (default: ‘smallest’).

  • backend (str, optional) – Numerical backend (default: ‘default’).

  • **kwargs – Additional arguments passed to scipy.sparse.linalg.eigsh.

Returns:

  • eigenvalues (Array, shape (k,)) – k extremal eigenvalues.

  • eigenvectors (Array, shape (N, k)) – Corresponding eigenvectors as columns.

general_python.algebra.backend_linalg.givens_rotation(V: ndarray | Any, i: int, j: int, theta: float, backend: str = 'default') → ndarray | Any

Apply Givens rotation to matrix or vector.

Rotates the (i,j) plane by angle θ.

Parameters:
  • V (array-like) – Matrix or vector to rotate.

  • i (int) – First index of the rotation plane.

  • j (int) – Second index of the rotation plane.

  • theta (float) – Rotation angle in radians.

  • backend (str, optional) – Numerical backend (default: ‘default’).

Returns:

Rotated matrix/vector.

Return type:

Array
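A minimal sketch of a rotation in the (i, j) plane, written as left multiplication by a Givens matrix that touches only rows i and j. Two conventions here are assumptions, not taken from the library: the sign of the rotation and whether rows or columns are rotated. givens_rotate_rows is a hypothetical name.

```python
import numpy as np

def givens_rotate_rows(V, i, j, theta):
    # Hypothetical sketch: left-multiply by G(i, j, theta), touching only rows i and j
    # (sign convention assumed: new_i = c*v_i - s*v_j, new_j = s*v_i + c*v_j)
    V = np.asarray(V)
    V = V.astype(np.result_type(V.dtype, np.float64))  # astype copies; ints promoted to float
    c, s = np.cos(theta), np.sin(theta)
    G = np.array([[c, -s], [s, c]])
    V[[i, j]] = G @ V[[i, j]]   # fancy indexing copies, so there is no aliasing
    return V

x = np.array([1.0, 0.0, 2.0])
y = givens_rotate_rows(x, 0, 1, np.pi / 2)
print(np.allclose(y, [0.0, 1.0, 2.0]))                   # True: e_0 rotated onto e_1
print(np.isclose(np.linalg.norm(y), np.linalg.norm(x)))  # True: rotations preserve norms
```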

general_python.algebra.solver

Abstract interfaces for backend-aware linear-system solvers.

The module defines shared types and base classes for algorithms that solve linear systems of the form A x = b. Implementations may operate on explicit matrices, matrix-vector callables, or Fisher/Gram factors and can use optional left preconditioning through callables of the form r -> M^{-1} r.

The base interface is intentionally compatible with NumPy and JAX. Static solver kernels are preferred for algorithm implementations because they are easier to JIT-compile and reuse from configured solver instances.

email : maxgrom97@gmail.com

class general_python.algebra.solver.SolverType(*values)

Bases: Enum

Enumeration class for the different types of solvers.

DIRECT = 1
BACKEND = 2
PSEUDO_INVERSE = 3
SCIPY_CG = 4
SCIPY_MINRES = 5
SCIPY_GMRES = 6
SCIPY_DIRECT = 7
CG = 8
MINSR = 9
MINRES = 10
MINRES_QLP = 11
GMRES = 12
ARNOLDI = 13

class general_python.algebra.solver.SolverErrorMsg(*values)

Bases: Enum

Enumeration class for solver error messages.

MATVEC_FUNC_NOT_SET = 101
MAT_NOT_SET = 102
DIM_MISMATCH = 106
METHOD_NOT_IMPL = 109
BACKEND_MISMATCH = 111
INVALID_INPUT = 112
COMPILATION_NA = 113

exception general_python.algebra.solver.SolverError(code: SolverErrorMsg, message: str | None = None)

Bases: Exception

Base class for exceptions in the solver module.

__init__(code: SolverErrorMsg, message: str | None = None)

class general_python.algebra.solver.SolverResult(x: ndarray | Array, converged: bool, iterations: int, residual_norm: float | None)

Bases: NamedTuple

Stores the result of a solver’s static execution.

x

The computed solution vector.

Type:

Array

converged

Whether the solver reached the desired tolerance.

Type:

bool

iterations

The number of iterations performed.

Type:

int

residual_norm

The norm of the final residual (||b - Ax||).

Type:

Optional[float]

x: ndarray | Array

Alias for field number 0

converged: bool

Alias for field number 1

iterations: int

Alias for field number 2

residual_norm: float | None

Alias for field number 3

class general_python.algebra.solver.Solver(backend: str = 'default', dtype: Type | None = None, eps: float = 1e-08, maxiter: int = 1000, default_precond: Preconditioner | None = None, matvec_func: Callable[[ndarray | Array], ndarray | Array] | None = None, a: ndarray | Array | None = None, s: ndarray | Array | None = None, s_p: ndarray | Array | None = None, sigma: float | None = None, is_gram: bool = False, **kwargs)

Bases: ABC

Abstract base class for linear system solvers.

Targets problems of the form $Ax = b$.

Primarily defines the static interface solve that concrete algorithm implementations (like CG, MINRES) must provide.

Also includes static helpers create_matvec_from_* to construct the matrix-vector product function and an optional instance method solve_instance for convenience when working with configured Solver objects.

For performance-critical code, prefer the solver function configured by the constructor, since that call path is optimized; instance-based setups are convenient for less time-critical tasks.

The solver can be initialized with a matrix A, a matrix-vector multiplication function, or a Fisher matrix S. The solver can also be initialized with a preconditioner M.

The solver can be used to solve the linear system Ax = b or M^{-1}Ax = M^{-1}b.
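To make the A x = b contract and the r -> M^{-1} r preconditioner callable concrete, here is a minimal preconditioned conjugate-gradient kernel in plain NumPy. It is an illustrative sketch, not the library's CG solver. The Result NamedTuple only mimics the SolverResult fields. CG assumes that A is symmetric (Hermitian) positive definite, and the relative stopping rule on ||b - Ax|| is an assumption.

```python
from typing import Callable, NamedTuple

import numpy as np

class Result(NamedTuple):
    # Same fields as SolverResult, redefined here so the sketch is self-contained
    x: np.ndarray
    converged: bool
    iterations: int
    residual_norm: float

def pcg(matvec: Callable, b, x0=None, precond: Callable = lambda r: r,
        eps: float = 1e-8, maxiter: int = 1000) -> Result:
    x = np.zeros_like(b) if x0 is None else x0.copy()
    r = b - matvec(x)
    z = precond(r)                       # z = M^{-1} r
    p = z.copy()
    rz = np.vdot(r, z)
    for k in range(1, maxiter + 1):
        Ap = matvec(p)
        alpha = rz / np.vdot(p, Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        res = np.linalg.norm(r)
        if res < eps * np.linalg.norm(b):    # relative tolerance on ||b - A x||
            return Result(x, True, k, float(res))
        z = precond(r)
        rz_new = np.vdot(r, z)
        p = z + (rz_new / rz) * p
        rz = rz_new
    return Result(x, False, maxiter, float(np.linalg.norm(r)))

rng = np.random.default_rng(4)
Q = rng.standard_normal((50, 50))
A = Q @ Q.T + 50 * np.eye(50)            # symmetric positive definite test matrix
b = rng.standard_normal(50)
diag = np.diag(A)
res = pcg(lambda v: A @ v, b, precond=lambda r: r / diag)   # Jacobi preconditioner
print(res.converged, np.allclose(A @ res.x, b, atol=1e-6))  # True True
```

The preconditioner enters only through a callable, so the kernel never needs M itself. This matches the "static kernel plus configured instance" split described for the Solver class.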

__init__(backend: str = 'default', dtype: Type | None = None, eps: float = 1e-08, maxiter: int = 1000, default_precond: Preconditioner | None = None, matvec_func: Callable[[ndarray | Array], ndarray | Array] | None = None, a: ndarray | Array | None = None, s: ndarray | Array | None = None, s_p: ndarray | Array | None = None, sigma: float | None = None, is_gram: bool = False, **kwargs)

Initializes solver metadata and optionally pre-configures for instance usage.

Parameters:
  • backend (str) – Preferred backend (‘numpy’, ‘jax’). Affects helpers.

  • dtype (Type, optional) – Default data type.

  • eps (float) – Default tolerance for convenience methods.

  • maxiter (int) – Default max iterations for convenience methods.

  • default_precond (Preconditioner, optional) – A default preconditioner instance.

  • a (Array, optional) – Explicit matrix A for instance setup.

  • s (Array, optional) – Matrix S for Fisher setup.

  • s_p (Array, optional) – Matrix Sp for Fisher setup.

  • matvec_func (Callable, optional) – Explicit matvec function.

  • sigma (float, optional) – Default regularization for setup helpers.

  • is_gram (bool) – Whether the solver is set up from Fisher/Gram factors (S, Sp).

Example

>>> solver = CgSolver(backend='jax', dtype=jnp.float32, eps=1e-6, maxiter=500)
>>> result = solver.solve_instance(b, x0)

static create_matvec_from_matrix_jax(a: ndarray | Array, sigma: float | None = None) → Callable[[ndarray | Array], ndarray | Array]
Static Helper:

Creates matvec function x -> (A + sigma*I) @ x.

Parameters:
  • a (Array) – The matrix (dense or sparse compatible with JAX).

  • sigma (float, optional) – Optional regularization parameter.

Returns:

The matrix-vector product function.

Return type:

MatVecFunc

static create_matvec_from_matrix_np(a: ndarray | Array, sigma: float | None = None) → Callable[[ndarray | Array], ndarray | Array]
Static Helper:

Creates matvec function x -> (A + sigma*I) @ x.

Parameters:
  • a (Array) – The matrix (dense or sparse compatible with NumPy).

  • sigma (float, optional) – Optional regularization parameter.

Returns:

The matrix-vector product function.

Return type:

MatVecFunc

static create_matvec_from_matrix(a: ndarray | Array, sigma: float | None = None, backend_module: Any = numpy, compile_func: bool = False) → Callable[[ndarray | Array], ndarray | Array]
Static Helper:

Creates matvec function x -> (A + sigma*I) @ x.

Parameters:
  • a (np.ndarray, jnp.ndarray) – The matrix (dense or sparse compatible with backend).

  • sigma (float) – Optional regularization parameter.

  • backend_module – The backend (e.g., np, jnp) to use for operations.

Returns:

The matrix-vector product function.

Return type:

Callable[[Array], Array]
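The helpers above all return a closure implementing x -> (A + sigma*I) @ x. A NumPy sketch of that idea follows, with make_matvec as a hypothetical stand-in for create_matvec_from_matrix. The regularized matrix is never materialized, because sigma * x is simply added after the product.

```python
from typing import Callable, Optional

import numpy as np

def make_matvec(a, sigma: Optional[float] = None) -> Callable[[np.ndarray], np.ndarray]:
    # Hypothetical stand-in for create_matvec_from_matrix with the NumPy backend;
    # None or 0 means no regularization
    if sigma:
        return lambda x: a @ x + sigma * x
    return lambda x: a @ x

A = np.array([[2.0, 1.0], [1.0, 3.0]])
mv = make_matvec(A, sigma=0.5)
x = np.array([1.0, -1.0])
print(mv(x).tolist())                                 # [1.5, -2.5]
print(np.allclose(mv(x), (A + 0.5 * np.eye(2)) @ x))  # True
```

With JAX, a closure like this can be wrapped in jax.jit, which is presumably the role of the compile_func flag, although its exact behaviour is not documented here.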

static create_matvec_from_fisher_jax(s: ndarray | Array, s_p: ndarray | Array, n: int | None = None, sigma: float | None = None, create_full: bool | None = False) → Callable[[ndarray | Array], ndarray | Array]

Creates a matrix-vector multiplication function (MatVecFunc) based on a Fisher information inspired formulation. This function constructs a custom matrix-vector product operator using the input arrays s and s_p. It first computes the normalization constant n as the number of rows of s. Depending on the flag create_full, it either:

  • Computes the full matrix as (s_p @ s) / n and passes it along with sigma to Solver.create_matvec_from_matrix_jax,

    or

  • Constructs a matvec function that, for a given vector x, computes:
    1. s_dot_x = dot(s, x)

    2. sp_dot_s_dot_x = dot(s_p, s_dot_x)

    3. The output as sp_dot_s_dot_x / n, with an additional term sigma * x if sigma is not None and non-zero.

The resulting matvec function is then compiled (presumably via JAX JIT) using Solver._compile_helper_jax. :param s: A JAX array representing the first component used in constructing the operator. :type s: Array :param s_p: A JAX array representing the second component used alongside s. :type s_p: Array :param sigma: A scalar value to be added to the diagonal (identity) part of the operation.

Defaults to None.

Parameters:

create_full (Optional[bool], optional) – A flag to determine whether to create a full matrix operator using Solver.create_matvec_from_matrix_jax. Defaults to False.

Returns:

A function that computes the matrix-vector product corresponding to the constructed operator.

This operator is compiled using Solver._compile_helper_jax.

Return type:

MatVecFunc

static create_matvec_from_fisher_np(s: ndarray, s_p: ndarray, n: int | None = None, sigma: float | None = None, create_full: bool | None = False) → Callable[[ndarray | Array], ndarray | Array][source]

Creates a matrix-vector multiplication function (matvec) based on the Fisher information matrix. The returned function computes the product of a vector with the matrix derived from the input arrays s and s_p. Optionally, a regularization term sigma is added to the diagonal of the matrix. If create_full is True, the function instead returns a matvec built on the explicitly formed matrix.

Parameters:
  • s (np.ndarray) – A 2D array representing the first operand in the Fisher information matrix computation.

  • s_p (np.ndarray) – A 2D array representing the second operand in the Fisher information matrix computation.

  • n (int, optional) – Normalization factor. Defaults to the number of rows of s.

  • sigma (float, optional) – A regularization parameter added to the diagonal of the matrix. Defaults to None, which means no regularization is applied.

  • create_full (Optional[bool], optional) – If True, creates a matvec function based on the full matrix s_p @ s / n. Defaults to False.

Returns:

A function that performs matrix-vector multiplication with the derived matrix.

Return type:

MatVecFunc

static create_matvec_from_fisher(s: ndarray | Array, s_p: ndarray | Array, n: int | None = None, sigma: float | None = None, backend_module: Any = numpy, create_full: bool = False, compile_func: bool = False) → Callable[[ndarray | Array], ndarray | Array][source]
Static Helper:

Creates matvec function x -> (Sp @ S / N + sigma*I) @ x.

Parameters:
  • s – Matrix S.

  • s_p – Matrix Sp (transpose/adjoint of S).

  • n – Normalization factor (often number of samples/outputs). Defaults to S.shape[0].

  • sigma – Optional regularization parameter.

  • backend_module – The backend (e.g., np, jnp) to use for operations.

Returns:

The matrix-vector product function.

Return type:

Callable[[Array], Array]
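The two strategies described for the Fisher helpers (a matrix-free product versus forming (s_p @ s) / n once) can be compared in a short NumPy sketch. The function make_fisher_matvec is hypothetical; it assumes s has shape (n_samples, n_params) and uses s_p = s.T, matching "transpose/adjoint of S" for real data.

```python
import numpy as np


def make_fisher_matvec(s, s_p, n=None, sigma=None, create_full=False):
    """Hypothetical sketch of x -> (s_p @ s / n + sigma*I) @ x."""
    n = s.shape[0] if n is None else n  # default normalization: rows of s
    if create_full:
        # Form the (n_params, n_params) matrix once, then reuse it.
        mat = (s_p @ s) / n
        if sigma:
            mat = mat + sigma * np.eye(mat.shape[0], dtype=mat.dtype)
        return lambda x: mat @ x

    def matvec(x):
        s_dot_x = s @ x                 # shape (n_samples,)
        out = (s_p @ s_dot_x) / n       # shape (n_params,)
        if sigma:
            out = out + sigma * x       # regularization on the diagonal
        return out

    return matvec


rng = np.random.default_rng(0)
s = rng.standard_normal((50, 4))        # 50 samples, 4 parameters
x = rng.standard_normal(4)
lazy = make_fisher_matvec(s, s.T, sigma=1e-3)
full = make_fisher_matvec(s, s.T, sigma=1e-3, create_full=True)
print(np.allclose(lazy(x), full(x)))    # True: both paths apply the same operator
```

The matrix-free path never stores the n_params × n_params matrix, which is the usual reason to prefer it when there are many more parameters than samples.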

static run_solver_func(backend_module: Any, solver_func: Callable[[...], SolverResult], *, matvec: Callable[[ndarray | Array], ndarray | Array] | None = None, a: ndarray | Array | None = None, s: ndarray | Array | None = None, s_p: ndarray | Array | None = None, b: ndarray | Array, x0: ndarray | Array | None, tol: float, maxiter: int, precond_apply: Callable[[ndarray | Array], ndarray | Array] | None = None, sigma: float | None = None, normalization: int | None = None, **kwargs: Any) → SolverResult[source]

Dispatch a wrapped solver function in one of three modes: matrix mode (a), Fisher mode (s and s_p), or matvec mode (matvec).
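A hypothetical NumPy sketch of such a three-way dispatch. The function name dispatch_matvec, the priority order (explicit matvec first, then a, then s and s_p), and the choice not to shift a user-supplied matvec by sigma are all assumptions, not taken from the library.

```python
import numpy as np


def dispatch_matvec(matvec=None, a=None, s=None, s_p=None, sigma=None, normalization=None):
    """Hypothetical sketch: choose matvec, matrix, or Fisher mode from the inputs given.

    The priority order (matvec, then a, then s/s_p) is an assumption.
    """
    shift = sigma or 0.0
    if matvec is not None:
        return "matvec", matvec
    if a is not None:
        return "matrix", lambda x: a @ x + shift * x
    if s is not None and s_p is not None:
        n = s.shape[0] if normalization is None else normalization
        return "fisher", lambda x: (s_p @ (s @ x)) / n + shift * x
    raise ValueError("Provide matvec, a, or both s and s_p.")


mode, mv = dispatch_matvec(a=np.eye(3), sigma=1.0)
print(mode, mv(np.ones(3)))  # matrix [2. 2. 2.]
```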

abstractmethod static solve(matvec: Callable[[ndarray | Array], ndarray | Array], b: ndarray | Array, x0: ndarray | Array, *, tol: float, maxiter: int, precond_apply: Callable[[ndarray | Array], ndarray | Array] | None = None, backend_module: Any, **kwargs: Any) → SolverResult[source]
Abstract Static:

Solves the linear system Ax = b using a specific algorithm.

Requires all inputs explicitly. Concrete implementations (e.g., CgSolver.solve) contain the actual algorithm for the specified backend.

Parameters:

matvec : Callable[[Array], Array]

Function implementing the matrix-vector product A @ x. It must accept a vector of shape (N,) and return a vector of shape (N,). Must be compatible with backend_module (NumPy or JAX).

b : Array

Right-hand side vector of shape (N,). Must be a backend_module array.

x0 : Array

Initial guess vector of shape (N,). Must be a backend_module array.

tol : float

Relative convergence tolerance (||Ax - b|| / ||b||).

maxiter : int

Maximum number of iterations allowed.

precond_apply : Callable[[Array], Array], optional

Function applying the preconditioner M^{-1}. Takes a vector r of shape (N,) and returns M^{-1}r of shape (N,). Must be compatible with backend_module.

backend_module : module

The numerical backend module to use for array operations (e.g., numpy or jax.numpy). This allows the solver logic to be backend-agnostic.

**kwargs : Any

Additional solver-specific keyword arguments (e.g., restart for GMRES).

Returns:

A named tuple containing:
  • x (Array) – The computed solution vector of shape (N,).

  • converged (bool) – True if the solver reached the desired tolerance.

  • iterations (int) – The number of iterations performed.

  • residual_norm (float) – The norm of the final residual (||b - Ax||).

Return type:

SolverResult

Raises:
  • NotImplementedError – If a subclass has not implemented this method.

  • SolverError – If convergence fails catastrophically or inputs are invalid.
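Solver leaves solve abstract. The sketch below shows what one concrete implementation could look like: preconditioned conjugate gradient for a Hermitian positive-definite A, following the documented keyword-only tol, maxiter, precond_apply, and backend_module arguments and the relative stopping rule ||b - Ax|| / ||b|| <= tol. SolverResult is redefined locally as a stand-in with the documented fields, and cg_solve is a hypothetical name; this is not the library's CgSolver.solve.

```python
from typing import Any, NamedTuple, Optional

import numpy as np


class SolverResult(NamedTuple):
    # Local stand-in mirroring the documented fields of SolverResult.
    x: Any
    converged: bool
    iterations: int
    residual_norm: Optional[float]


def cg_solve(matvec, b, x0, *, tol, maxiter, precond_apply=None, backend_module=np, **kwargs):
    """Sketch of a concrete static solve: preconditioned CG for Hermitian positive-definite A."""
    xp = backend_module
    apply_m = precond_apply if precond_apply is not None else (lambda r: r)
    x = x0
    r = b - matvec(x)
    z = apply_m(r)                      # preconditioned residual M^{-1} r
    p = z
    rz = xp.vdot(r, z)
    b_norm = xp.linalg.norm(b)
    b_norm = b_norm if b_norm > 0 else 1.0
    res = float(xp.linalg.norm(r))
    if res / b_norm <= tol:
        return SolverResult(x, True, 0, res)
    for k in range(1, maxiter + 1):
        ap = matvec(p)
        alpha = rz / xp.vdot(p, ap)
        x = x + alpha * p
        r = r - alpha * ap
        res = float(xp.linalg.norm(r))
        if res / b_norm <= tol:         # relative criterion ||b - Ax|| / ||b||
            return SolverResult(x, True, k, res)
        z = apply_m(r)
        rz_new = xp.vdot(r, z)
        p = z + (rz_new / rz) * p       # new search direction
        rz = rz_new
    return SolverResult(x, False, maxiter, res)


a = np.array([[4.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]])
b = np.array([1.0, 2.0, 3.0])
diag = np.diag(a)
result = cg_solve(lambda v: a @ v, b, np.zeros(3), tol=1e-10, maxiter=100,
                  precond_apply=lambda r: r / diag)  # Jacobi preconditioner M = diag(A)
print(result.converged, np.allclose(a @ result.x, b))  # True True
```

The Jacobi preconditioner here is only the simplest example of a precond_apply callable of the form r -> M^{-1}r.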

abstractmethod static get_solver_func(backend_module: Any, use_matvec: bool = True, use_fisher: bool = False, use_matrix: bool = False, sigma: float | None = None, **kwargs) → Callable[[...], SolverResult][source]
Abstract Static:

Retrieves the solver function, which may be JIT-compiled (with JAX), Numba-compiled, or a plain Python function, depending on the provided backend_module.

Parameters:
  • backend_module – The numerical backend module (e.g., numpy or jax.numpy).

Returns:

A callable with the signature:

(matvec, b, x0, tol, maxiter, precond_apply, **kwargs) -> SolverResult

Return type:

StaticSolverFunc

Note

The backend_module helps in tailoring the solver function for the appropriate numerical library.

solve_instance(b: ndarray | Array, x0: ndarray | Array | None = None, *, tol: float | None = None, maxiter: int | None = None, precond: Preconditioner | Callable[[ndarray | Array], ndarray | Array] | None = 'default', sigma: float | None = None, compile_matvec: bool = False, **kwargs) → SolverResult[source]

Convenience instance method to run the solver.

Sets up matvec and precond_apply based on instance configuration (if provided during __init__) or arguments, then calls the static solve method of this solver’s class. Stores the result in instance attributes.

Parameters:
  • b (Array) – Right-hand side vector.

  • x0 (Optional[Array]) – Initial guess. Defaults to zeros.

  • tol (Optional[float]) – Tolerance override. Uses instance default if None.

  • maxiter (Optional[int]) – Max iterations override. Uses instance default if None.

  • precond (Union[Preconditioner, Callable, None, str]) –

    Preconditioner for this solve.
    • If Preconditioner instance:

      Uses its __call__ method.

    • If Callable:

      Assumes it’s r -> M^{-1}r and uses it directly.

    • If None:

      No preconditioning.

    • If ‘default’:

      Uses None…

  • sigma (Optional[float]) – Regularization for matvec creation if matvec is not already defined for the instance. Uses instance default _conf_sigma if None.

  • **kwargs – Additional arguments passed directly to the static solve.

Returns:

Result from the static solve method.

Return type:

SolverResult

property solution: ndarray | Array | None

Solution vector from the most recent instance solve, if any.

property converged: bool | None

Whether the most recent instance solve reported convergence.

property iterations: int | None

Iteration count from the most recent instance solve.

property residual_norm: float | None

Final residual norm from the most recent instance solve.

property backend_str: str

Normalized backend name used by this solver.

property dtype: Type

Default dtype used by this solver instance.

property default_eps: float

Default convergence tolerance for instance solves.

property default_maxiter: int

Default maximum iteration count for instance solves.

general_python.algebra.solver.sym_ortho(a, b, backend: str = 'default')[source]

Stable symmetric Householder (Givens) reflection.

Computes parameters c, s, r such that:

\begin{pmatrix} c & s \\ s & -c \end{pmatrix} \begin{pmatrix} a \\ b \end{pmatrix} = \begin{pmatrix} r \\ 0 \end{pmatrix}

For real inputs, r = sqrt(a^2 + b^2) is nonnegative. For complex inputs, r preserves the phase of a (if b==0) or b (if a==0), and the reflectors are computed in a stable manner.

Parameters:
  • a (scalar (real or complex)) – The first element of the two-vector [a; b].

  • b (scalar (real or complex)) – The second element of the two-vector [a; b].

  • backend (str, optional) – Specifies which backend to use. Defaults to “default”. If set to “jax”, the function uses jax.numpy and is jitted for speed.

Returns:

  • (c, s, r) (tuple of scalars) –

    The computed reflection parameters, satisfying:

    c = a / r and s = b / r,

    with r = sqrt(a^2 + b^2) for real numbers (or the appropriately phased value for complex numbers).

Numerical stability

This function avoids overflow and underflow by scaling by the larger-magnitude component (either |a| or |b|). This keeps the intermediate quantity tau and the hypotenuse within floating-point range.
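For real inputs, the scaling strategy can be sketched with Python's math module. This follows the real-valued branch of the SymOrtho scheme known from MINRES-QLP; the complex-valued and JAX paths of sym_ortho are not reproduced, and sym_ortho_real is a hypothetical name.

```python
import math


def sym_ortho_real(a: float, b: float):
    """Real-valued sketch: returns (c, s, r) with [[c, s], [s, -c]] @ [a, b] = [r, 0]."""
    if b == 0.0:
        c = 1.0 if a == 0.0 else math.copysign(1.0, a)
        return c, 0.0, abs(a)
    if a == 0.0:
        return 0.0, math.copysign(1.0, b), abs(b)
    if abs(b) > abs(a):
        tau = a / b                      # |tau| < 1, so 1 + tau*tau cannot overflow
        s = math.copysign(1.0, b) / math.sqrt(1.0 + tau * tau)
        c = s * tau
        r = b / s
    else:
        tau = b / a                      # |tau| <= 1
        c = math.copysign(1.0, a) / math.sqrt(1.0 + tau * tau)
        s = c * tau
        r = a / c
    return c, s, r


c, s, r = sym_ortho_real(3.0, 4.0)
print(c, s, r)                           # c ≈ 0.6, s ≈ 0.8, r ≈ 5.0
print(sym_ortho_real(1e200, 1e200)[2])   # finite, whereas sqrt(a*a + b*b) would overflow
```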

Solver module for various linear algebra solvers.

Initialization file for the solvers module. Exports solver classes, the SolverType enum, and the choose_solver factory function.

File: general_python/algebra/solvers/__init__.py

Author: Maksymilian Kliczkowski

License: MIT

Description: This module provides a factory function to choose and instantiate different solver types based on user input. It includes various solver implementations (direct, iterative, etc.) and allows customization through keyword arguments. The module also includes a utility function that generates test matrix-vector pairs. The solvers are designed to work with different numerical backends (e.g., NumPy, SciPy, JAX) and support various data types. The module is part of the larger algebra library and aims to provide a flexible, extensible framework for solving linear algebra problems.


class general_python.algebra.solvers.Solver(backend: str = 'default', dtype: Type | None = None, eps: float = 1e-08, maxiter: int = 1000, default_precond: Preconditioner | None = None, matvec_func: Callable[[ndarray | Array], ndarray | Array] | None = None, a: ndarray | Array | None = None, s: ndarray | Array | None = None, s_p: ndarray | Array | None = None, sigma: float | None = None, is_gram: bool = False, **kwargs)[source]

Bases: ABC

Abstract base class for linear system solvers.

Targets problems of the form $Ax = b$.

Primarily defines the static interface solve that concrete algorithm implementations (such as CG and MINRES) must provide.

Also includes static helpers create_matvec_from_* to construct the matrix-vector product function, and an optional instance method solve_instance for convenience when working with configured Solver objects.

For performance-critical code, prefer the solver function configured at construction, since its call path is optimized; instance-based setups are convenient for less time-consuming tasks.

The solver can be initialized with a matrix A, a matrix-vector multiplication function, or the Fisher factors S and Sp. It can also be given a preconditioner M.

The solver can be used to solve the linear system Ax = b or the left-preconditioned system M^{-1}Ax = M^{-1}b.

__init__(backend: str = 'default', dtype: Type | None = None, eps: float = 1e-08, maxiter: int = 1000, default_precond: Preconditioner | None = None, matvec_func: Callable[[ndarray | Array], ndarray | Array] | None = None, a: ndarray | Array | None = None, s: ndarray | Array | None = None, s_p: ndarray | Array | None = None, sigma: float | None = None, is_gram: bool = False, **kwargs)[source]

Initializes solver metadata and optionally pre-configures for instance usage.

Parameters:
  • backend (str) – Preferred backend (‘numpy’, ‘jax’). Affects helpers.

  • dtype (Type, optional) – Default data type.

  • eps (float) – Default tolerance for convenience methods.

  • maxiter (int) – Default max iterations for convenience methods.

  • default_precond (Preconditioner, optional) – A default preconditioner instance.

  • a (Array, optional) – Explicit matrix A for instance setup.

  • s (Array, optional) – Matrix S for Fisher setup.

  • s_p (Array, optional) – Matrix Sp for Fisher setup.

  • matvec_func (Callable, optional) – Explicit matvec function.

  • sigma (float, optional) – Default regularization for setup helpers.

  • is_gram (bool) – Whether the Fisher (Gram) setup with S and Sp is used.

Example

>>> solver = CgSolver(backend='jax', dtype=jnp.float32, eps=1e-6, maxiter=500)
>>> result = solver.solve_instance(b, x0)

The remaining members of this class (the static helpers create_matvec_from_matrix_jax, create_matvec_from_matrix_np, create_matvec_from_matrix, create_matvec_from_fisher_jax, create_matvec_from_fisher_np, create_matvec_from_fisher, run_solver_func, solve, and get_solver_func; the instance method solve_instance; and the properties solution, converged, iterations, residual_norm, backend_str, dtype, default_eps, and default_maxiter) are identical to those documented above for the Solver class in general_python.algebra.solver.

class general_python.algebra.solvers.SolverResult(x: ndarray | Array, converged: bool, iterations: int, residual_norm: float | None)[source]

Bases: NamedTuple

Stores the result of a solver’s static execution.

x

The computed solution vector.

Type:

Array

converged

Whether the solver reached the desired tolerance.

Type:

bool

iterations

The number of iterations performed.

Type:

int

residual_norm

The norm of the final residual (||b - Ax||).

Type:

Optional[float]

x: ndarray | Array

Alias for field number 0

converged: bool

Alias for field number 1

iterations: int

Alias for field number 2

residual_norm: float | None

Alias for field number 3
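Because SolverResult is a NamedTuple, its fields can be read by name or by the index listed above, and the result unpacks positionally. The sketch uses a local stand-in class with the documented fields rather than importing the library.

```python
from typing import NamedTuple, Optional

import numpy as np


class SolverResult(NamedTuple):
    # Local stand-in with the same fields, in the same order, as the documented class.
    x: np.ndarray
    converged: bool
    iterations: int
    residual_norm: Optional[float]


res = SolverResult(x=np.array([1.0, 2.0]), converged=True, iterations=7, residual_norm=3e-9)
x, ok, its, rnorm = res            # tuple unpacking follows field numbers 0..3
assert res[2] == res.iterations    # "Alias for field number 2"
if not res.converged:
    print(f"stopped after {res.iterations} iterations, residual {res.residual_norm}")
print(res._replace(iterations=8).iterations)  # 8; NamedTuples are immutable, _replace returns a copy
```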

exception general_python.algebra.solvers.SolverError(code: SolverErrorMsg, message: str | None = None)[source]

Bases: Exception

Base class for exceptions in the solver module.

__init__(code: SolverErrorMsg, message: str | None = None)[source]
class general_python.algebra.solvers.SolverErrorMsg(*values)[source]

Bases: Enum

Enumeration class for solver error messages.

MATVEC_FUNC_NOT_SET = 101
MAT_NOT_SET = 102
DIM_MISMATCH = 106
METHOD_NOT_IMPL = 109
BACKEND_MISMATCH = 111
INVALID_INPUT = 112
COMPILATION_NA = 113
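A self-contained sketch of how an enum-coded exception like SolverError can be raised and inspected. Both classes are local stand-ins copying two of the documented codes; storing the code as err.code and falling back to the enum name when message is None are assumptions, not documented behavior, and check_shapes is a hypothetical helper.

```python
from enum import Enum
from typing import Optional


class SolverErrorMsg(Enum):
    # Subset of the documented codes, copied for a self-contained sketch.
    MATVEC_FUNC_NOT_SET = 101
    DIM_MISMATCH = 106


class SolverError(Exception):
    """Hypothetical stand-in: keeps the enum code; falls back to its name as the message."""

    def __init__(self, code: SolverErrorMsg, message: Optional[str] = None):
        self.code = code
        super().__init__(message or f"{code.name} ({code.value})")


def check_shapes(n_rows: int, b_len: int) -> None:
    # Hypothetical input validation that reports a dimension mismatch.
    if n_rows != b_len:
        raise SolverError(SolverErrorMsg.DIM_MISMATCH, f"A has {n_rows} rows, b has {b_len}")


try:
    check_shapes(3, 4)
except SolverError as err:
    print(err.code.value, err)   # 106 A has 3 rows, b has 4
```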
class general_python.algebra.solvers.SolverType(*values)[source]

Bases: Enum

Enumeration class for the different types of solvers.

DIRECT = 1
BACKEND = 2
PSEUDO_INVERSE = 3
SCIPY_CG = 4
SCIPY_MINRES = 5
SCIPY_GMRES = 6
SCIPY_DIRECT = 7
CG = 8
MINSR = 9
MINRES = 10
MINRES_QLP = 11
GMRES = 12
ARNOLDI = 13
class general_python.algebra.solvers.SolverForm(*values)[source]

Bases: Enum

Enum for solver forms.

MATRIX = 1
MATVEC = 2
GRAM = 3
general_python.algebra.solvers.choose_solver(solver_id: str | int | SolverType | Type[Solver], backend: str = 'default', *, sigma: float | None = None, is_gram: bool = False, default_precond: Preconditioner | None = None, **kwargs) → Solver[source]

Factory function to select and instantiate a solver based on identifier. Uses lazy loading to import specific solver classes only when requested.

Parameters:
  • solver_id (Union[str, int, SolverType, Type[Solver]]) – Identifier for the solver. Can be a string name, integer code, SolverType enum, or a Solver subclass.

  • backend (str, optional) – Numerical backend to use (e.g., “numpy”, “scipy”, “jax”). Default is “default”.

  • sigma (Optional[float], optional) – Shift parameter for the solver, if applicable. Default is None.

  • is_gram (bool, optional) – Whether to treat the problem as a Gram matrix problem. Default is False.

  • default_precond (Optional[Preconditioner], optional) – Default preconditioner to use if none is specified. Default is None.

  • **kwargs – Additional keyword arguments to pass to the solver constructor.

Returns:

An instance of the selected solver class.

Return type:

Solver

Examples

>>> solver = choose_solver("CG", backend="numpy", sigma=0.1)
>>> solver = choose_solver(SolverType.MINRES, is_gram=True)
>>> solver = choose_solver(MyCustomSolverClass, custom_param=42)
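choose_solver accepts a string, an integer code, a SolverType member, or a Solver subclass, and imports solver classes only on demand. The sketch below shows one way such normalization and lazy loading could work. pick_solver, _LOADERS, _load_cg, and the placeholder CgSolver are hypothetical, and the library's actual name-matching rules may differ.

```python
import functools
from enum import Enum
from typing import Callable, Dict, Type, Union


class SolverType(Enum):
    # Subset of the documented enum values.
    CG = 8
    MINRES = 10


class Solver:
    def __init__(self, backend: str = "default", **kwargs):
        self.backend, self.options = backend, kwargs


@functools.lru_cache(maxsize=None)
def _load_cg() -> Type[Solver]:
    # Placeholder loader: a real package would import its CG module here, so that
    # module is loaded only on first request (and cached, like Python's import system).
    class CgSolver(Solver):
        pass
    return CgSolver


_LOADERS: Dict[SolverType, Callable[[], Type[Solver]]] = {SolverType.CG: _load_cg}


def pick_solver(solver_id: Union[str, int, SolverType, Type[Solver]],
                backend: str = "default", **kwargs) -> Solver:
    """Hypothetical sketch: normalize the identifier, then instantiate lazily."""
    if isinstance(solver_id, type) and issubclass(solver_id, Solver):
        return solver_id(backend=backend, **kwargs)      # class passed directly
    try:
        if isinstance(solver_id, str):
            solver_id = SolverType[solver_id.upper()]    # "cg" -> SolverType.CG
        elif isinstance(solver_id, int):
            solver_id = SolverType(solver_id)            # 8 -> SolverType.CG
        cls = _LOADERS[solver_id]()
    except (KeyError, ValueError):
        raise ValueError(f"Unknown or unavailable solver: {solver_id!r}") from None
    return cls(backend=backend, **kwargs)


solver = pick_solver("cg", backend="numpy", sigma=0.1)
print(type(solver).__name__, solver.options)  # CgSolver {'sigma': 0.1}
```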

Initial-value ODE integrators with NumPy and optional JAX backends.

The module provides a compact interface for explicit one-step methods used in simulation loops where the caller owns time evolution state. Integrators return (y_next, dt, info) and do not impose a global solver object or trajectory storage.

Use choose_ode() for string-based construction or instantiate Euler, Heun, AdaptiveHeun, RK, and ScipyRK directly when full configuration is needed.

email : maxgrom97@gmail.com
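The interface described above leaves the time loop, and any trajectory storage, to the caller. A minimal sketch of such a loop, using a stand-alone forward-Euler step that follows the (y_next, dt, info) return convention. euler_step and decay are hypothetical helpers, not the module's Euler class; the f(y, t) argument order follows the step formulas used in this module.

```python
import math

import numpy as np


def euler_step(f, t, y, dt, **rhs_args):
    """Hypothetical stand-alone forward-Euler step returning (y_next, dt, info)."""
    dy = f(y, t, **rhs_args)
    info = {}                              # placeholder for auxiliary metadata
    return y + dt * dy, dt, info


def decay(y, t, rate=1.0):
    # Right-hand side of dy/dt = -rate * y (t is unused but kept for the f(y, t) form).
    return -rate * y


t, y = 0.0, np.array([1.0])
history = [(t, y)]                         # the caller, not the integrator, owns the trajectory
for _ in range(1000):
    y, dt, info = euler_step(decay, t, y, 1e-3, rate=2.0)
    t += dt
    history.append((t, y))

print(y[0], math.exp(-2.0))                # Euler result vs exact solution at t = 1
```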

class general_python.algebra.ode.IVP(backend: str = 'numpy', rhs_prefactor: float = 1.0, dt: float = 0.001)[source]

Bases: ABC

Abstract initial value problem solver interface.

step(f, t, y, **rhs_args)[source]

Compute one integration step without modifying internal state.

update(y, h, f, t, **rhs_args)[source]

Update and return new state given current y and step size h.

dt(h, i)[source]

Return the time step used (may depend on h or step index i).

xp

Array module (numpy or jax.numpy) selected by backend.

__init__(backend: str = 'numpy', rhs_prefactor: float = 1.0, dt: float = 0.001)[source]

Initialize the ODE solver with a specified backend.

Parameters:
  • backend (str) – Backend to use for numerical operations (‘numpy’ or ‘jax’).

  • rhs_prefactor (float) – Prefactor for the right-hand side of the ODE.

  • dt (float) – Time step for the integration. Defaults to 1e-3.

dt(h: float = 0.0, i: int = 0) → float[source]

Return the step size used for step index i.

Parameters:
  • h – Optional external step-size proposal. Fixed-step integrators ignore it and return their configured dt.

  • i – Optional integration step index.

set_dt(dt: float)[source]

Set the time step for the integration.

Parameters:

dt (float) – The new time step to set.

abstractmethod step(f, t: float, y, **rhs_args)[source]

Compute one integration step without modifying internal state. This method should be implemented by subclasses.

update(y, h: float, f, t: float, **rhs_args)[source]

Advance y by one step and return only the updated state.

This convenience method discards the dt and auxiliary info returned by step(). Subclasses can override it if they need custom update semantics.

property order: int

Return the order of the integration method. This method should be implemented by subclasses.

property is_jax: bool

Check if the backend is JAX.

Returns:

True if the backend is JAX, False otherwise.

Return type:

bool

property is_numpy: bool

Check if the backend is NumPy.

Returns:

True if the backend is NumPy, False otherwise.

Return type:

bool

__repr__()[source]

Return a string representation of the IVP object.

__str__()[source]

Return a string representation of the IVP object.

__call__(f, t, y, **rhs_args)[source]

Call the step method to compute one integration step.

Parameters:
  • f (callable) – Function representing the right-hand side of the ODE.

  • t (float) – Current time.

  • y (array-like) – Current state.

  • **rhs_args (keyword arguments) – Additional arguments to pass to the function f.

Returns:

  • yout (array-like) – New state after one integration step.

  • dt (float) – The step size used.

__len__()[source]

Return the length of the IVP object.

class general_python.algebra.ode.Euler(dt: float = 0.001, backend: str = 'numpy', rhs_prefactor: float = 1.0)[source]

Bases: IVP

Simple forward Euler integrator.

Parameters:
  • dt (float) – Fixed step size for the integration.

  • backend (str) – ‘numpy’ or ‘jax’

__init__(dt: float = 0.001, backend: str = 'numpy', rhs_prefactor: float = 1.0)[source]

Initializes the object with a specified time step and computational backend.

Parameters:
  • dt (float, optional) – The time step to use for the ODE solver. Defaults to 1e-3.

  • backend (str, optional) – The computational backend to use (e.g., ‘numpy’). Defaults to ‘numpy’.

  • rhs_prefactor (float, optional) – A prefactor to multiply with the right-hand side of the ODE. Defaults to 1.0.

step(f, t: float, y, **rhs_args)[source]

Compute one Euler step: y_{n+1} = y_n + \Delta t \, f(y_n, t).

Returns:

  • yout – New state after one Euler step.

  • dt – The step size used.

__repr__()[source]

Return a string representation of the Euler object.

class general_python.algebra.ode.Heun(dt: float = 0.001, backend: str = 'numpy', rhs_prefactor: float = 1.0)[source]

Bases: IVP

Second-order Heun (explicit trapezoidal) integrator.

Parameters:
  • dt (float) – Fixed step size delta t (can be adapted externally).

  • backend (str) – ‘numpy’ or ‘jax’

__init__(dt: float = 0.001, backend: str = 'numpy', rhs_prefactor: float = 1.0)[source]

Initialize the ODE solver with a specified backend.

Parameters:
  • backend (str) – Backend to use for numerical operations (‘numpy’ or ‘jax’).

  • rhs_prefactor (float) – Prefactor for the right-hand side of the ODE.

  • dt (float) – Fixed time step for the integration. Defaults to 1e-3.

step(f, t: float, y, **rhs_args)[source]

Compute one Heun step:

y_{n+1} = y_n + \frac{\Delta t}{2} \left( f(y_n, t) + f(y_n + \Delta t f(y_n, t), t + \Delta t) \right)

where \(\Delta t\) is the time step. This is a second-order Runge-Kutta method.

Parameters:
  • f (callable) – Function representing the right-hand side of the ODE.

  • t (float) – Current time.

  • y (array-like) – Current state.

  • **rhs_args (keyword arguments) – Additional arguments to pass to the function f.

Returns:

  • yout – New state after one Heun step.

  • dt – The step size used.

Example

>>> k0 = f(y, t)                     # slope at t
>>> y_pred = y + dt * k0             # predictor step
>>> k1 = f(y_pred, t + dt)           # slope at t + dt
>>> yout = y + (dt / 2) * (k0 + k1)  # corrector step
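The following is a runnable version of the predictor-corrector steps. It uses hypothetical conventions: the helper name heun_step, f called as f(t, y, **rhs_args), and rhs_prefactor scaling f. It is a sketch, not the library's code.

```python
import numpy as np


def heun_step(f, t, y, dt, rhs_prefactor=1.0, **rhs_args):
    """Hypothetical sketch: one Heun (explicit trapezoidal) step."""
    y = np.asarray(y)
    k0 = rhs_prefactor * f(t, y, **rhs_args)            # slope at t
    y_pred = y + dt * k0                                 # Euler predictor
    k1 = rhs_prefactor * f(t + dt, y_pred, **rhs_args)   # slope at t + dt
    yout = y + 0.5 * dt * (k0 + k1)                      # trapezoidal corrector
    return yout, dt
```

Halving dt should reduce the global error by roughly a factor of four, consistent with second-order accuracy.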

__repr__()[source]

Return a string representation of the Heun object.

class general_python.algebra.ode.AdaptiveHeun(dt: float = 0.001, tol: float = 1e-08, max_step: float = 1.0, backend: str = 'numpy', rhs_prefactor: float = 1.0)[source]

Bases: IVP

Adaptive second-order Heun integrator with error control.

Parameters:
  • dt (float) – Initial time step delta t.

  • tol (float) – Error tolerance.

  • max_step (float) – Maximum allowed time step.

  • backend (str) – ‘numpy’ or ‘jax’

__init__(dt: float = 0.001, tol: float = 1e-08, max_step: float = 1.0, backend: str = 'numpy', rhs_prefactor: float = 1.0)[source]

Initialize the ODE solver with a specified backend.

Parameters:
  • backend (str) – Backend to use for numerical operations (‘numpy’ or ‘jax’).

  • rhs_prefactor (float) – Prefactor for the right-hand side of the ODE.

step(f, t: float, y, norm_fun=None, **rhs_args)[source]

Perform one adaptive Heun step with local error control.

Parameters:
  • f – Right-hand side callable.

  • t – Current time.

  • y – Current state.

  • norm_fun – Optional norm function used for the local error estimate. Defaults to the active backend’s vector norm.

  • **rhs_args – Extra keyword arguments forwarded to f.

Returns:

(y_next, dt_used, info) where info contains the last right-hand-side metadata returned by f.

Return type:

tuple
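The docstring does not state how the local error is estimated. One common choice pairs the Heun result with the embedded Euler predictor. Their difference estimates the local error, and the step is shrunk or grown with the usual (tol/err)^(1/2) rule. The sketch below follows that scheme, which is an assumption and not necessarily what the library does. The helper name, the safety factor 0.9, and the clipping bounds [0.2, 2.0] are all hypothetical. The sketch also returns the next step size in place of the library's info.

```python
import numpy as np


def adaptive_heun_step(f, t, y, dt, tol=1e-8, max_step=1.0, norm_fun=np.linalg.norm,
                       safety=0.9, **rhs_args):
    """Hypothetical sketch: Heun step with an embedded Euler error estimate.

    Retries with a smaller dt until the estimated local error is <= tol and
    returns (y_next, dt_used, dt_next).
    """
    y = np.asarray(y)
    k0 = f(t, y, **rhs_args)                          # slope at t, reused on retries
    while True:
        y_euler = y + dt * k0                         # first-order (Euler) solution
        k1 = f(t + dt, y_euler, **rhs_args)
        y_heun = y + 0.5 * dt * (k0 + k1)             # second-order (Heun) solution
        err = float(norm_fun(y_heun - y_euler))       # local error estimate
        # Step-size factor with exponent 1/2, since the embedded Euler solution is order 1.
        factor = safety * (tol / err) ** 0.5 if err > 0 else 2.0
        factor = min(2.0, max(0.2, factor))
        if err <= tol:
            return y_heun, dt, min(dt * factor, max_step)
        dt = dt * factor                              # rejected: shrink the step and retry
```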

__repr__()[source]

Return a string representation of the AdaptiveHeun object.

class general_python.algebra.ode.RK(a: list, b: list, c: list, dt: float = 0.001, backend: str = 'numpy', rhs_prefactor: float = 1.0)[source]

Bases: IVP

General explicit Runge-Kutta integrator with arbitrary Butcher tableau.

Parameters:
  • a (array-like of shape (s, s)) – Lower-triangular matrix of stage coefficients.

  • b (array-like of length s) – Weights for final combination.

  • c (array-like of length s) – Nodes, with c_i = sum_j a_{ij} (the row sums of a).

  • dt (float) – Fixed step size dt.

  • backend (str) – ‘numpy’ or ‘jax’

Notes

For common orders, use the from_order classmethod.

__init__(a: list, b: list, c: list, dt: float = 0.001, backend: str = 'numpy', rhs_prefactor: float = 1.0)[source]

Initialize the ODE solver with a specified backend.

Parameters:
  • backend (str) – Backend to use for numerical operations (‘numpy’ or ‘jax’).

  • rhs_prefactor (float) – Prefactor for the right-hand side of the ODE.

property order: int

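Return the number of stages s in the configured Butcher tableau. For the built-in tableaus created by from_order (Euler, RK2, RK4), this equals the order of accuracy.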

classmethod from_order(order: int, dt: float = 0.001, backend: str = 'numpy', rhs_prefactor: float = 1.0)[source]

Create a Runge-Kutta method instance from a specified order.

Parameters:
  • order (int) – The order of the Runge-Kutta method. Supported values are 1 (Euler), 2 (RK2), and 4 (RK4).

  • dt (float, optional) – The time step size. Defaults to 1e-3.

  • backend (str, optional) – The computational backend to use (e.g., ‘numpy’). Defaults to ‘numpy’.

  • rhs_prefactor (float, optional) – Prefactor for the right-hand side of the ODE. Defaults to 1.0.

Returns:

An instance of the class initialized with the appropriate Butcher tableau for the specified order.

Return type:

cls

Raises:

ValueError – If the specified order is not supported.

step(f, t: float, y, **rhs_args)[source]

Perform one Runge-Kutta step.

The right-hand-side prefactor \(\alpha\) (rhs_prefactor) is treated as part of the ODE, dy/dt = \alpha f(t, y), rather than as a multiplier applied after the step. This matters for complex real-time evolution in TDVP, where the stage states must be evaluated at

y_i = y_n + h \alpha \sum_{j < i} a_{ij} k_j, with h = dt,

so that RK2 and RK4 integrate the same physical time variable as Euler and exact diagonalization.

Returns:

  • yout – new state

  • dt – step size used
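The sketch below shows a general explicit Runge-Kutta step in which the prefactor enters every stage state, as described above. The helper rk_step and the convention f(t, y, **rhs_args) are assumptions. The tableau shown is the classical RK4 Butcher tableau.

```python
import numpy as np


def rk_step(f, t, y, dt, a, b, c, rhs_prefactor=1.0, **rhs_args):
    """Hypothetical sketch: one explicit RK step for dy/dt = rhs_prefactor * f(t, y)."""
    a, b, c = np.asarray(a), np.asarray(b), np.asarray(c)
    y = np.asarray(y)
    k = []
    for i in range(len(b)):
        incr = 0.0
        for j in range(i):                       # explicit method: only earlier stages
            incr = incr + a[i, j] * k[j]
        y_i = y + dt * rhs_prefactor * incr      # prefactor enters every stage state
        k.append(f(t + c[i] * dt, y_i, **rhs_args))
    incr = 0.0
    for i in range(len(b)):
        incr = incr + b[i] * k[i]
    return y + dt * rhs_prefactor * incr, dt


# Classical RK4 Butcher tableau.
RK4_A = [[0.0, 0.0, 0.0, 0.0],
         [0.5, 0.0, 0.0, 0.0],
         [0.0, 0.5, 0.0, 0.0],
         [0.0, 0.0, 1.0, 0.0]]
RK4_B = [1 / 6, 1 / 3, 1 / 3, 1 / 6]
RK4_C = [0.0, 0.5, 0.5, 1.0]
```

With rhs_prefactor = -1j and f(t, y) = y, one step from y = 1 approximates exp(-1j * dt).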

__repr__()[source]

Return a string representation of the RK object.

class general_python.algebra.ode.ScipyRK(dt: float = 0.001, tol: float = 1e-06, max_step: float = None, method: str = 'RK45', backend: str = 'numpy', rhs_prefactor: float = 1.0)[source]

Bases: IVP

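Wrapper for scipy.integrate.solve_ivp. The default method, RK45, is an explicit Runge-Kutta pair. The implicit methods Radau and BDF and the stiffness-switching LSODA can also be selected through method.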

Parameters:
  • dt (float) – Initial and maximum time step delta t.

  • tol (float) – Relative and absolute tolerance for solver.

  • max_step (float or None) – Maximum allowed step size; if None uses dt.

  • method (str) – Solver method: one of ‘RK45’, ‘RK23’, ‘DOP853’, ‘Radau’, ‘BDF’, ‘LSODA’.

  • backend (str) – ‘numpy’ or ‘jax’. The JAX backend is not supported; the solver falls back to NumPy.

__init__(dt: float = 0.001, tol: float = 1e-06, max_step: float = None, method: str = 'RK45', backend: str = 'numpy', rhs_prefactor: float = 1.0)[source]

Initialize the ODE solver with a specified backend.

Parameters:
  • backend (str) – Backend to use for numerical operations (‘numpy’ or ‘jax’).

  • rhs_prefactor (float) – Prefactor for the right-hand side of the ODE.

step(f, t: float, y, **rhs_args)[source]

Perform one adaptive step using solve_ivp over [t, t + dt].

Returns:

  • yout (ndarray) – State at end of interval.

  • dt_actual (float) – Actual time step taken.

__repr__()[source]

Return a string representation of the ScipyRK object.

class general_python.algebra.ode.OdeTypes[source]

Bases: object

Enum-like class for ODE types.

EULER = 'euler'
HEUN = 'heun'
RK2 = 'rk2'
RK4 = 'rk4'
ADAPTIVE = 'adaptive'
SCIPY = 'scipy'
general_python.algebra.ode.choose_ode(ode_type: str | int | OdeTypes, *, dt: float = 0.1, rhs_prefactor: float = 1.0, backend: Any = 'numpy', **kwargs) IVP[source]

Choose an ODE solver based on the specified type.

Parameters:
  • ode_type (str, int or OdeTypes) – Type of ODE solver to use.

  • dt (float, optional) – Time step size. Default is 1e-1.

  • rhs_prefactor (float, optional) – Prefactor for the right-hand side of the ODE. Default is 1.0.

  • backend (str, optional) – Computational backend to use. Default is ‘numpy’.

  • **kwargs (keyword arguments) – Additional arguments for the ODE solver.

Returns:

An instance of the selected ODE solver.

Return type:

IVP

file: general_python/algebra/preconditioners.py
author: Maksymilian Kliczkowski

This module contains the implementation of preconditioners for iterative solvers of linear systems Ax = b. Preconditioners transform the system into an equivalent one,

M^{-1}Ax = M^{-1}b (left preconditioning)

or

A M^{-1} y = b, x = M^{-1}y (right preconditioning),

where M is the preconditioner matrix.

The goal is for the transformed system to have more favorable spectral properties (e.g., eigenvalues clustered around 1, lower condition number), leading to faster convergence of iterative methods like CG, MINRES, GMRES. The matrix M should approximate A in some sense, while the operation M^{-1}r should be computationally inexpensive.

Date: 2025-02-02
Version: 1.0
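The sketch below makes the effect of a preconditioner concrete. It implements textbook preconditioned conjugate gradients with the preconditioner supplied as a function r -> M^{-1} r, in the spirit of the apply(r) functions in this module. The pcg helper and the badly scaled test matrix A = D B D are illustrative assumptions, not part of the library. Since B has a unit diagonal, diag(A) = d², so Jacobi preconditioning removes the bad scaling and it_jac should come out much smaller than it_id.

```python
import numpy as np


def pcg(A, b, apply_minv, tol=1e-10, max_iter=1000):
    """Hypothetical sketch: preconditioned CG for SPD A; apply_minv(r) returns M^{-1} r."""
    x = np.zeros_like(b, dtype=float)
    r = b - A @ x
    z = apply_minv(r)
    p = z.copy()
    rz = r @ z
    b_norm = np.linalg.norm(b)
    for it in range(1, max_iter + 1):
        Ap = A @ p
        alpha = rz / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        if np.linalg.norm(r) <= tol * b_norm:   # stop on the unpreconditioned residual
            return x, it
        z = apply_minv(r)
        rz_new = r @ z
        p = z + (rz_new / rz) * p
        rz = rz_new
    return x, max_iter


# Badly scaled SPD matrix A = D B D, where B has unit diagonal and is well conditioned.
n = 50
B = np.eye(n) + 0.2 * (np.eye(n, k=1) + np.eye(n, k=-1))
d = np.logspace(0, 3, n)
A = d[:, None] * B * d[None, :]
b = np.ones(n)

x_id, it_id = pcg(A, b, lambda r: r)                 # identity preconditioner (plain CG)
x_jac, it_jac = pcg(A, b, lambda r: r / np.diag(A))  # Jacobi preconditioner M = diag(A)
```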

general_python.algebra.preconditioners.preconditioner_idn(r: ndarray | Array) ndarray | Array[source]

Identity function for preconditioner apply.

Parameters:

r (Array) – The input array.

Returns:

The same input array.

Return type:

Array

class general_python.algebra.preconditioners.PreconditionersType(*values)[source]

Bases: Enum

Enumeration of the symmetry type of preconditioners.

SYMMETRIC = 1
NONSYMMETRIC = 2
class general_python.algebra.preconditioners.PreconditionersTypeSym(*values)[source]

Bases: Enum

Enumeration of specific symmetric preconditioner types.

IDENTITY = 0
JACOBI = 1
INCOMPLETE_CHOLESKY = 2
COMPLETE_CHOLESKY = 3
SSOR = 4
class general_python.algebra.preconditioners.PreconditionersTypeNoSym(*values)[source]

Bases: Enum

Enumeration of specific potentially non-symmetric preconditioner types.

IDENTITY = 0
INCOMPLETE_LU = 1
class general_python.algebra.preconditioners.Preconditioner(is_positive_semidefinite=False, is_gram=False, backend='default', apply_func: Callable[[ndarray | Array], ndarray | Array] | None = None, **kwargs)[source]

Bases: ABC

Abstract base class for preconditioners M used in iterative solvers.

Provides a framework for setting up the preconditioner based on a matrix A (or its factors S, Sp for Gram matrices) and applying the inverse operation M^{-1}r efficiently. Supports different computational backends (NumPy, JAX).

is_positive_semidefinite

Indicates if the original matrix A (and potentially M) is assumed to be positive semi-definite. Important for methods like Cholesky.

Type:

bool

is_gram

True if the preconditioner setup uses factors S, Sp such that A = Sp @ S / N.

Type:

bool

sigma

Regularization parameter sigma added during setup, effectively forming M based on A + sigma*I.

Type:

float

type

The specific type of the preconditioner (e.g., JACOBI, ILU). Set by subclass.

Type:

Enum

stype

Symmetry type (SYMMETRIC/NONSYMMETRIC).

Type:

PreconditionersType

backend_str

The name of the current backend (‘numpy’, ‘jax’).

Type:

str

__init__(is_positive_semidefinite=False, is_gram=False, backend='default', apply_func: Callable[[ndarray | Array], ndarray | Array] | None = None, **kwargs)[source]

Initialize the preconditioner.

Parameters:
  • is_positive_semidefinite (bool) – True if the matrix is positive semidefinite.

  • is_gram (bool) – True if the preconditioner setup uses factors S, Sp such that A = Sp @ S / N.

  • backend (optional) – The computational backend to be used by the preconditioner.

  • apply_func (Callable[[Array], Array], optional) – The apply function for the preconditioner.

log(msg: str, log: int | str = 'info', lvl: int = 0, color: str = 'white', append_msg=True)[source]

Log the message.

Parameters:
  • msg (str) – The message to log.

  • log (Union[int, str]) – The logging level, given as a level name or integer (default is ‘info’).

  • lvl (int) – The level of the message.

  • color (str) – The color of the message.

  • append_msg (bool) – Flag to append the message.

reset_backend(backend: str)[source]

Resets the backend and recompiles the internal apply function.

Parameters:

backend (str) – The name of the new backend (‘numpy’, ‘jax’).

get_apply() Callable[[ndarray | Array], ndarray | Array][source]

Returns the potentially JIT-compiled function apply(r), which uses the precomputed data stored by the last call to set(). Raises RuntimeError if called before set().

get_apply_mat(**default_setup_kwargs) Callable[[ndarray | Array, ndarray | Array, float], ndarray | Array][source]

Returns a potentially JIT-compiled function apply_mat(r, A, sigma, **override_kwargs) that computes preconditioner data from A and applies it on the fly.

Params:
**default_setup_kwargs:

Default keyword arguments for the setup kernel (e.g., tol_small for Jacobi). These are fixed at compile time if using JIT.

Returns:

The compiled function.

Return type:

Callable

get_apply_gram(**default_setup_kwargs) Callable[[ndarray | Array, ndarray | Array, ndarray | Array, float], ndarray | Array][source]

Returns a potentially JIT-compiled function apply_gram(r, S, Sp, sigma, **override_kwargs) that computes preconditioner data from S, Sp and applies it on the fly.

Params:

**default_setup_kwargs: Default keyword arguments for the setup kernel.

Returns:

The compiled function.

Return type:

Callable

Note

For JAX compatibility, kwargs are frozen at function creation time. The returned function does NOT accept runtime kwargs to avoid dictionary operations inside JIT-traced code.
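The frozen-kwargs pattern can be expressed as a closure. The setup defaults are captured when the function is created, so the returned callable takes only arrays and sigma. This NumPy sketch builds a Jacobi apply_gram without forming A. It assumes S has shape (N, n) and Sp has shape (n, N), with N = S.shape[0]. The factory name is hypothetical. Under JAX one could wrap the returned function with jax.jit, since it contains no dictionary handling.

```python
import numpy as np


def make_jacobi_apply_gram(tol_small=1e-13, zero_replacement=1e13):
    """Hypothetical sketch: return apply_gram(r, S, Sp, sigma) with setup kwargs frozen."""
    def apply_gram(r, S, Sp, sigma):
        n_samples = S.shape[0]
        # diag(Sp @ S) without forming the full matrix: sum_k Sp[i, k] * S[k, i]
        diag = np.einsum("ik,ki->i", Sp, S) / n_samples + sigma
        # Guard against (near-)zero diagonal entries using the frozen defaults.
        diag = np.where(np.abs(diag) < tol_small, zero_replacement, diag)
        return r / diag
    return apply_gram
```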

property name: str

Name of the preconditioner.

property dcol: str

Color for logging messages.

property backend_str: str

Name of the current backend (‘numpy’, ‘jax’).

property backend: Any

The backend module (e.g., np, jnp).

property backends: Any

The backend module for scipy-like operations.

property isjax: bool

True if using JAX backend.

property precomputed_data: dict

Returns empty dict as no precomputed data is needed.

property type: PreconditionersTypeNoSym | PreconditionersTypeSym | None

Specific preconditioner type (e.g., JACOBI, ILU).

property stype: PreconditionersType

Symmetry type (SYMMETRIC/NONSYMMETRIC).

property is_positive_semidefinite: bool

True if the matrix A (and potentially M) is positive semidefinite.

property is_gram: bool

True if the preconditioner is set up from Gram matrix factors S, Sp.

property sigma

Regularization parameter.

property tol_big

Tolerance for big values.

property tol_small

Tolerance for small values.

property zero

Value treated as zero.

set(a: ndarray | Array, sigma: float = 0.0, ap: ndarray | Array | None = None, backend: str | None = None, **kwargs)[source]

Sets up the preconditioner using the provided matrix A and optional parameters. This method computes the preconditioner data and prepares the apply function.

Params:
a (Array):

The matrix to be used for setting up the preconditioner.

sigma (float, optional):

The regularization parameter. Defaults to 0.0.

ap (Optional[Array], optional):

An optional second matrix for Gram setup. Defaults to None.

backend (Optional[str], optional):

The backend to use for computations. Defaults to None.

**kwargs:

Additional keyword arguments for specific implementations.

__call__(r: ndarray | Array) ndarray | Array[source]

Apply the configured preconditioner instance M^{-1} to vector r using precomputed data.

__repr__() str[source]

Returns the name and configuration of the preconditioner.

__str__() str[source]

Returns the name of the preconditioner.

class general_python.algebra.preconditioners.IdentityPreconditioner(backend: str = 'default')[source]

Bases: Preconditioner

Identity preconditioner M = I. Applying M^{-1} simply returns the input vector.

This is the simplest preconditioner and has no effect on the system. It serves as a baseline or placeholder.

Math:

M = I

M^{-1} r = I^{-1} r = r

__init__(backend: str = 'default')[source]

Initialize the Identity preconditioner.

Parameters:

backend (str) – The computational backend (‘numpy’, ‘jax’, ‘default’).

static apply(r: ndarray | Array, backend_mod: Any, sigma: float = 0.0, **precomputed_data: Any) ndarray | Array[source]

Static apply wrapper with the signature apply(r, backend_mod, sigma=0.0, **precomputed_data). For the identity preconditioner it returns r unchanged.

__repr__() str[source]

Returns the name and configuration of the Identity preconditioner.

class general_python.algebra.preconditioners.JacobiPreconditioner(is_positive_semidefinite: bool = False, is_gram: bool = False, backend: str = 'default', tol_small: float = 1e-13, zero_replacement: float = 10000000000000.0, **kwargs)[source]

Bases: Preconditioner

Jacobi (Diagonal) Preconditioner. M = diag(A + sigma*I).

Uses the diagonal of the (potentially regularized) matrix A as the preconditioner M. Applying the inverse M^{-1}r involves element-wise division by the diagonal entries.

Math:

M = diag(A) + sigma*I = D + sigma*I

M^{-1} r = (D + sigma*I)^{-1} r, i.e. (M^{-1} r)_i = r_i / (A_ii + sigma)

References

  • Saad, Y. (2003). Iterative Methods for Sparse Linear Systems (2nd ed.). SIAM. Chapter 10.

__init__(is_positive_semidefinite: bool = False, is_gram: bool = False, backend: str = 'default', tol_small: float = 1e-13, zero_replacement: float = 10000000000000.0, **kwargs)[source]

Initialize the Jacobi preconditioner.

Parameters:
  • is_positive_semidefinite (bool) – If A is positive semi-definite.

  • is_gram (bool) – If setting up from Gram matrix factors.

  • backend (str) – The computational backend.

  • tol_small (float) – Values on diagonal smaller than this (in magnitude) after regularization are considered zero.

  • zero_replacement (float) – Value used to replace division by zero (effectively setting the result component to zero, 1 / large_number -> 0).
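The safeguard described by tol_small and zero_replacement can be written as follows. jacobi_apply is a hypothetical helper that forms the diagonal directly from a dense A rather than from the class's precomputed data.

```python
import numpy as np


def jacobi_apply(r, A, sigma=0.0, tol_small=1e-13, zero_replacement=1e13):
    """Hypothetical sketch: apply (diag(A) + sigma)^{-1} with a small-diagonal guard."""
    diag = np.diag(A) + sigma
    # Entries below tol_small in magnitude are replaced, so 1 / zero_replacement ~ 0.
    safe = np.where(np.abs(diag) < tol_small, zero_replacement, diag)
    return r / safe
```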

static apply(r: ndarray | Array, backend_mod: Any, sigma: float = 0.0, **precomputed_data: Any) ndarray | Array[source]

Static apply wrapper with the signature apply(r, backend_mod, sigma=0.0, **precomputed_data).

__repr__() str[source]

Returns the name and configuration of the Jacobi preconditioner.

property zero_replacement: float

Value substituted for unsafe diagonal entries in Jacobi setup.

class general_python.algebra.preconditioners.CholeskyPreconditioner(backend: str = 'default')[source]

Bases: Preconditioner

Cholesky Preconditioner using complete Cholesky decomposition.

Suitable for symmetric positive-definite matrices A. The preconditioner M is defined by the Cholesky factorization of the regularized matrix: M = L @ L.T (or L @ L.conj().T for complex), where L is the lower Cholesky factor of (A + sigma*I).

Applying the inverse M^{-1}r involves solving two triangular systems:

  1. Solve L y = r for y (forward substitution).

  2. Solve L.T z = y for z (backward substitution), or L.conj().T z = y for complex matrices.

Note

This performs a complete Cholesky factorization. For incomplete Cholesky (suitable for large sparse matrices), specialized routines (often from sparse linear algebra libraries) are required.
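Below is a dense NumPy sketch of the setup and the two triangular solves, with hypothetical helper names. np.linalg.solve is used for brevity and does not exploit the triangular structure; scipy.linalg.solve_triangular would.

```python
import numpy as np


def cholesky_setup(A, sigma=0.0):
    """Hypothetical sketch: lower Cholesky factor of A + sigma * I."""
    A = np.asarray(A)
    return np.linalg.cholesky(A + sigma * np.eye(A.shape[0]))


def cholesky_apply(r, L):
    """Hypothetical sketch: M^{-1} r with M = L @ L^H."""
    y = np.linalg.solve(L, r)               # forward substitution: L y = r
    return np.linalg.solve(L.conj().T, y)   # backward substitution: L^H z = y
```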

References

  • Golub, G. H., & Van Loan, C. F. (2013). Matrix Computations (4th ed.). JHU Press. Chapter 4.

  • Saad, Y. (2003). Iterative Methods for Sparse Linear Systems (2nd ed.). SIAM. Chapter 10.

__init__(backend: str = 'default')[source]

Initialize the Cholesky preconditioner.

Parameters:

backend (str) – The computational backend (‘numpy’, ‘jax’, ‘default’).

__repr__() str[source]

Returns a string representation of the Cholesky preconditioner, including its status.

Returns:

A string indicating whether the preconditioner is factorized or not.

Return type:

str

class general_python.algebra.preconditioners.SSORPreconditioner(omega: float = 1.0, is_positive_semidefinite: bool = False, is_gram: bool = False, backend: str = 'default', tol_small: float = 1e-13, zero_replacement: float = 10000000000000.0)[source]

Bases: Preconditioner

Symmetric Successive Over-Relaxation (SSOR) Preconditioner.

Suitable for symmetric matrices A, particularly those that are positive definite. It involves forward and backward Gauss-Seidel-like sweeps with a relaxation parameter omega.

Math:

Let A = D + L + U, where D = diag(A), L is the strictly lower triangular part of A, and U is the strictly upper triangular part. The standard SSOR preconditioner is

M = (omega / (2 - omega)) * (D/omega + L) D^{-1} (D/omega + U),

so its inverse is

M^{-1} = ((2 - omega) / omega) * (D/omega + U)^{-1} D (D/omega + L)^{-1}.

Applying M^{-1}r therefore requires two triangular solves:

  1. Solve (D/omega + L) y = r (forward sweep).

  2. Solve (D/omega + U) z = D y (backward sweep), and return ((2 - omega) / omega) * z.

The choice of omega (0 < omega < 2) is crucial for performance. omega = 1 gives Symmetric Gauss-Seidel (SGS), for which M = (D + L) D^{-1} (D + U).
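Here is a dense sketch of the SSOR application, assuming the standard definition M = (omega / (2 - omega)) (D/omega + L) D^{-1} (D/omega + U). ssor_apply is a hypothetical helper, and np.linalg.solve stands in for dedicated triangular solves.

```python
import numpy as np


def ssor_apply(r, A, omega=1.0):
    """Hypothetical sketch: M^{-1} r for the SSOR preconditioner of a dense A."""
    A = np.asarray(A)
    D = np.diag(np.diag(A))
    L = np.tril(A, k=-1)                            # strictly lower part
    U = np.triu(A, k=1)                             # strictly upper part
    y = np.linalg.solve(D / omega + L, r)           # forward sweep
    z = np.linalg.solve(D / omega + U, D @ y)       # backward sweep
    return (2.0 - omega) / omega * z                # scalar factor from the definition of M
```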

References

  • Saad, Y. (2003). Iterative Methods for Sparse Linear Systems (2nd ed.). SIAM. Chapter 10.

  • Axelsson, O. (1996). Iterative Solution Methods. Cambridge University Press.

__init__(omega: float = 1.0, is_positive_semidefinite: bool = False, is_gram: bool = False, backend: str = 'default', tol_small: float = 1e-13, zero_replacement: float = 10000000000000.0)[source]

Initialize the SSOR preconditioner.

Parameters:
  • omega (float) – Relaxation parameter (0 < omega < 2). Default is 1.0 (SGS).

  • is_positive_semidefinite (bool) – If A is assumed positive semi-definite.

  • is_gram (bool) – If setting up from Gram matrix factors (less common for SSOR).

  • backend (str) – The computational backend.

  • tol_small (float) – Tolerance for safe diagonal inversion.

  • zero_replacement (float) – Value for safe diagonal inversion.

property omega: float

Relaxation parameter omega (0 < omega < 2).

class general_python.algebra.preconditioners.IncompleteCholeskyPreconditioner(backend: str = 'default')[source]

Bases: Preconditioner

Incomplete Cholesky Preconditioner (Approximation using ILU(0)).

Suitable for large, sparse, symmetric positive-definite matrices A. Because SciPy, NumPy, and JAX do not provide a direct IC(0) routine, this implementation uses SciPy’s sparse incomplete LU factorization (scipy.sparse.linalg.spilu) as a proxy for IC(0). spilu computes factors L and U such that L @ U approximates A; with fill_factor = 1 the factors hold roughly as many nonzeros as A, approximating the zero fill-in of ILU(0).

Applying the inverse M^{-1}r involves solving L y = P r and U z = y, where P is a permutation matrix handled internally by the ILU object.

Note:

  • Requires SciPy and the NumPy backend; not JAX compatible.

  • Designed for sparse matrices; dense input matrices are converted.

  • Uses ILU(0) factorization as a proxy for IC(0).

Args for set:

  • fill_factor (float) – See scipy.sparse.linalg.spilu. Default 1.

  • drop_tol (float) – See scipy.sparse.linalg.spilu. Default None.

References

  • Saad, Y. (2003). Iterative Methods for Sparse Linear Systems (2nd ed.). SIAM. Chapter 10 (discusses both IC and ILU).

  • SciPy documentation for scipy.sparse.linalg.spilu.

__init__(backend: str = 'default')[source]

Initialize the Incomplete Cholesky (ILU Proxy) preconditioner.

Parameters:

backend (str) – The computational backend. Must be ‘numpy’.

property fill_factor: float

The fill factor for ILU(0) (default 1.0).

property drop_tol: float | None

The drop tolerance for ILU(0) (default None).

static apply(r: ndarray | Array, sigma: float, backend_module: Any, ilu_obj: SuperLU | None) ndarray | Array[source]

Static apply method for ILU: solves Mz = r using the LU factors.

Parameters:
  • r (Array) – The residual vector.

  • sigma (float) – Regularization (used during setup).

  • backend_module (Any) – The backend numpy module (must be numpy).

  • ilu_obj (Optional[SuperLU]) – The precomputed ILU object from spilu.

Returns:

The preconditioned vector M^{-1}r, or r if ilu_obj is None.

Return type:

Array

__repr__() str[source]

Returns the name and configuration of the Incomplete Cholesky preconditioner.

class general_python.algebra.preconditioners.ILUPreconditioner(is_positive_semidefinite: bool = False, is_gram: bool = False, backend: str = 'default')[source]

Bases: Preconditioner

Incomplete LU Preconditioner (ILU(0)).

Suitable for large, sparse, non-symmetric matrices A. This implementation uses SciPy’s sparse incomplete LU factorization (scipy.sparse.linalg.spilu). spilu computes factors L and U such that L @ U approximates A; with fill_factor = 1 the factors hold roughly as many nonzeros as A, approximating the zero fill-in of ILU(0).

Applying the inverse M^{-1}r involves solving L y = P r and U z = y, where P is a permutation matrix handled internally by the ILU object.
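For comparison, textbook ILU(0) keeps exactly the sparsity pattern of A. The dense sketch below follows the IKJ formulation described in Saad, Chapter 10. It is illustrative only and is not what spilu computes, since spilu uses threshold-based dropping controlled by drop_tol and fill_factor. The helper name is hypothetical.

```python
import numpy as np


def ilu0(A):
    """Hypothetical sketch: textbook ILU(0) on a dense array (IKJ variant)."""
    LU = np.array(A, dtype=float)
    n = LU.shape[0]
    pattern = LU != 0                                # sparsity pattern of A
    for i in range(1, n):
        for k in range(i):
            if not pattern[i, k]:
                continue
            LU[i, k] /= LU[k, k]                     # multiplier l_ik
            for j in range(k + 1, n):
                if pattern[i, j]:                    # update only positions nonzero in A
                    LU[i, j] -= LU[i, k] * LU[k, j]
    L = np.tril(LU, k=-1) + np.eye(n)                # unit lower triangular factor
    U = np.triu(LU)
    return L, U
```

The product L @ U matches A on the nonzero pattern of A; entries that exact LU would fill in are dropped.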

Note:

  • Requires SciPy and the NumPy backend; not JAX compatible.

  • Designed for sparse matrices; dense input matrices are converted.

Args for set:

  • fill_factor (float) – See scipy.sparse.linalg.spilu. Default 1.

  • drop_tol (float) – See scipy.sparse.linalg.spilu. Default None.

References

  • Saad, Y. (2003). Iterative Methods for Sparse Linear Systems (2nd ed.). SIAM. Chapter 10.

  • SciPy documentation for scipy.sparse.linalg.spilu.

__init__(is_positive_semidefinite: bool = False, is_gram: bool = False, backend: str = 'default')[source]

Initialize the Incomplete LU preconditioner.

Parameters:
  • is_positive_semidefinite (bool) – If A is assumed positive semi-definite (not required for ILU).

  • is_gram (bool) – If setting up from Gram matrix factors.

  • backend (str) – The computational backend. Must be ‘numpy’.

property fill_factor: float

The fill factor for ILU(0) (default 1.0).

property drop_tol: float | None

The drop tolerance for ILU(0) (default None).

static apply(r: ndarray | Array, sigma: float, backend_module: Any, ilu_obj: SuperLU | None) ndarray | Array[source]

Static apply method for ILU: solves Mz = r using the LU factors.

Parameters:
  • r (Array) – The residual vector.

  • sigma (float) – Regularization (used during setup).

  • backend_module (Any) – The backend numpy module (must be numpy).

  • ilu_obj (Optional[SuperLU]) – The precomputed ILU object from spilu.

Returns:

The preconditioned vector M^{-1}r, or r if ilu_obj is None.

Return type:

Array

__repr__() str[source]

Returns the name and configuration of the Incomplete LU preconditioner.

general_python.algebra.preconditioners.choose_precond(precond_id: Any, **kwargs) Preconditioner[source]

Factory function to select and instantiate a preconditioner.

Accepts various identifiers (Enum, str, int, instance) and passes kwargs to the specific preconditioner’s constructor.

Parameters:
  • precond_id (Any) – Identifier (instance, Enum, str, int).

  • **kwargs – Additional arguments for the constructor (e.g., backend=’jax’).

Returns:

An instance of the selected preconditioner.

Return type:

Preconditioner

Common Utilities

Shared utilities for IO, plotting, logging, and runtime helpers.

The common package contains cross-cutting infrastructure used by all major subpackages, including directory management, plotting, HDF5 helpers, timers, and low-level binary utilities.

Input/output contracts

Utilities are intentionally lightweight: they accept standard Python objects or NumPy-compatible arrays and return plain Python or NumPy outputs where possible. Plotting helpers expect Matplotlib-compatible inputs.

Determinism and stability

Most helpers are deterministic for fixed inputs. Logging or timing output may vary with runtime scheduling. Binary helper functions operate on integer bit patterns and are deterministic by construction.

Submodules are loaded lazily on first access to reduce import overhead.

class general_python.common.Directories(*parts: str | Path)[source]

Bases: object

Directory handler wrapping a pathlib.Path, with helpers for composing paths, listing and filtering files and directories, and copying or moving files.

__init__(*parts: str | Path) None[source]

Initialize with one or more path components.

>>> d = Directories("foo", "bar")  # -> Path("foo/bar")

__len__() int | str[source]

Return the number of items in the directory if it is a directory,

__add__(other: str | Path) Directories[source]

Concatenate with another path component.

>>> d = Directories("foo") + "bar"  # -> Path("foo/bar")

__iadd__(other: str | Path) Directories[source]

In-place concatenation with another path component.

>>> d = Directories("foo"); d += "bar"  # -> Path("foo/bar")

__radd__(other: str | Path) Directories[source]

Concatenate with another path component from the left.

>>> d = "foo" + Directories("bar")  # -> Path("foo/bar")

__truediv__(other: str | Path) Directories[source]

Concatenate with another path component using the / operator.

>>> d = Directories("foo") / "bar"  # -> Path("foo/bar")

__rtruediv__(other: str | Path) Directories[source]

Concatenate with another path component from the left using the / operator.

>>> d = "foo" / Directories("bar")  # -> Path("foo/bar")

__iter__() Iterator[Path][source]

Iterate over the parts of the path.

>>> d = Directories("foo/bar")  # iterating over d yields the parts "foo", "bar"

__eq__(other: str | Path) bool[source]

Check equality with another path component.

>>> Directories("foo") == "foo"  # -> True

__ne__(other: str | Path) bool[source]

Check inequality with another path component.

>>> Directories("foo") != "bar"  # -> True

__hash__() int[source]

Hash the path for use in sets or dictionaries.

>>> hash(Directories("foo"))  # -> hash(Path("foo"))

__repr__() str[source]

Return a string representation of the path.

>>> repr(Directories("foo"))  # -> "Directories(PosixPath('foo'))" on POSIX systems

__str__() str[source]

Return a string representation of the path.

>>> str(Directories("foo"))  # -> "foo"

static f_h5(p: List[Path]) List[str][source]

Filter for .h5 files.

static f_csv(p: List[Path]) List[str][source]

Filter for .csv files.

static f_nonempty(p: List[Path]) List[str][source]

Filter for non-empty files.

static f_contains(substr: str) Callable[[Path], bool][source]

Return a filter that checks if the filename contains a substring.

join(*parts: str | Path, create: bool = False) Directories[source]

Return a new Directories for self/path joined with parts. If create=True, mkdir(parents=True, exist_ok=True) is called.

property parent: Directories

Return Directories for parent directory (..).

classmethod win(raw: str) Directories[source]

Parse a Windows-style backslash path into Directories.

format(*args, **kwargs) Directories[source]

Format the path using str.format() and return a new Directories.

>>> Directories("runs/{}").format("2025")  # -> Path("runs/2025")

resolve() Directories[source]

Return a new Directories with the absolute resolved path.

endswith(suffix: str) bool[source]

Check if the path ends with the given suffix.

mkdir(parents: bool = True, exist_ok: bool = True) Directories[source]

Create this directory on disk. Returns self for chaining.

static mkdirs(paths: Iterable[str | Path], parents: bool = True, exist_ok: bool = True) None[source]

Create multiple directories.

list_files(*, include_empty: bool = True, filters: List[Callable[[Path], bool]] = None, sort_key: Callable[[Path], any] | None = None) List[Path][source]

List files (not directories) in this directory.

Parameters:
  • include_empty (bool) – If False, skip files of size zero.

  • filters (list of callables Path -> bool) – All must return True for a file to be included.

  • sort_key (callable, optional) – Key function for sorting the results.
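The filtering logic can be reproduced with pathlib alone. This stand-alone function is hypothetical and returns plain Path objects. It assumes filters are applied after the empty-file check.

```python
from pathlib import Path


def list_files(directory, *, include_empty=True, filters=None, sort_key=None):
    """Hypothetical sketch: list regular files, optionally skipping empty ones."""
    paths = [p for p in Path(directory).iterdir() if p.is_file()]   # files only
    if not include_empty:
        paths = [p for p in paths if p.stat().st_size > 0]
    for keep in filters or []:                                      # every filter must pass
        paths = [p for p in paths if keep(p)]
    return sorted(paths, key=sort_key) if sort_key is not None else paths
```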

list_dirs(*, include_empty: bool = True, include_hidden: bool = True, relative: bool = False, as_string: bool = False, filters: List[Callable[[Path], bool]] = [], sort_key: Callable[[Path], Any] | None = None) List[Path][source]

List directories in this directory.

Parameters:
  • include_empty (bool) – if False, skip empty directories. If True, include all directories. This checks only if the directory has any entries, not if they are files or directories.

  • filters (list of callables Path -> bool) – A list of callables; all must return True for a directory to be included.

  • sort_key (callable, optional) – Key function for sorting the results.

static list_data_roots(base: str | Path, *, sort: bool = True, as_dirs: bool = True) List[Directories] | List[Path][source]

List all first-level directories inside base…

Parameters:
  • base (PathLike) – Root directory (e.g. data_path)

  • sort (bool) – Sort lexicographically (useful for YYYYMMDD)

  • as_dirs (bool) – Return Directories objects instead of Path

Return type:

List of directories

static expand_data_roots(base: str | Path, *subpath: str | Path, require_exist: bool = True) List[Directories][source]

Expand a relative subpath across all first-level directories.

Parameters:
  • base (PathLike) – Root directory (e.g. data_path)

  • subpath (PathLike) – Relative path to append to each root (e.g. hamil/occ/ns/sp)

  • require_exist (bool) – If True, only include paths that exist on disk.

Example

expand_data_roots(data_path, 'data', hamil, …, 'sp')

Returns list of:

base/<date>/data/…/sp
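Here is a pathlib version of the same expansion, assuming roots are the sorted first-level subdirectories of base. Unlike the method, this hypothetical helper returns plain Path objects rather than Directories.

```python
from pathlib import Path


def expand_data_roots(base, *subpath, require_exist=True):
    """Hypothetical sketch: append subpath to every first-level directory of base."""
    roots = sorted(p for p in Path(base).iterdir() if p.is_dir())   # e.g. YYYYMMDD folders
    candidates = [root.joinpath(*subpath) for root in roots]
    return [p for p in candidates if p.exists()] if require_exist else candidates
```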

static collect_files(dirs: List[Directories], *, prefix: str = None, suffix: str = None, filters: List[Callable[[Path], bool]] = None, sort: bool = False) List[Path][source]

Collect files from multiple directories.

Parameters:
  • dirs – list of Directories

  • prefix – optional filename prefix filter

  • suffix – optional filename suffix filter

  • filters – additional filters (Path -> bool)

  • sort – global sorting

Return type:

Flat list of Paths

clear_empty() List[Path][source]

Remove all zero-length files in this directory. Returns list of files left after removal.

walk() Iterator[Path][source]

Walk the directory tree and yield all files.

glob(pattern: str) List[Path][source]

Return a list of all files matching the pattern in this directory.

random_file(condition: Callable[[Path], bool] = <function Directories.<lambda>>) Path[source]

Return a random Path in this directory satisfying condition. Raises ValueError if none match.

copy_files(dest: str | Path, condition: Callable[[Path], bool], overwrite: bool = False) None[source]

Copy all files satisfying condition() from self to dest. Creates dest if needed.

Parameters:
  • dest (PathLike) – Destination directory.

  • condition (Callable[[Path], bool]) – Function that takes a Path and returns True if the file should be copied.

  • overwrite (bool, optional) – If True, overwrite existing files in the destination directory. Default is False.

transfer_files(dest: str | Path, condition: Callable[[Path], bool]) None[source]

Move all files satisfying condition() from self to dest. Creates dest if needed.
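The documented copy semantics (filter with `condition`, create `dest` if needed, skip existing files unless `overwrite=True`) can be sketched with `shutil`. `copy_files_sketch` is hypothetical; unlike `copy_files`, which returns None, it returns the copied targets so the behavior is easy to inspect.

```python
import shutil
import tempfile
from pathlib import Path
from typing import Callable, List, Union


def copy_files_sketch(src: Union[str, Path], dest: Union[str, Path],
                      condition: Callable[[Path], bool],
                      overwrite: bool = False) -> List[Path]:
    """Hypothetical stand-in for copy_files; returns the copied targets."""
    src, dest = Path(src), Path(dest)
    dest.mkdir(parents=True, exist_ok=True)  # create dest if needed
    copied: List[Path] = []
    for f in sorted(src.iterdir()):
        if not (f.is_file() and condition(f)):
            continue
        target = dest / f.name
        if target.exists() and not overwrite:
            continue  # keep the existing file in dest
        shutil.copy2(f, target)  # copy contents and metadata
        copied.append(target)
    return copied


with tempfile.TemporaryDirectory() as tmp:
    src, dest = Path(tmp, "src"), Path(tmp, "out")
    src.mkdir()
    dest.mkdir()
    (src / "a.txt").write_text("new")
    (src / "b.dat").write_text("skip me")
    (dest / "a.txt").write_text("old")
    print(copy_files_sketch(src, dest, lambda p: p.suffix == ".txt"))  # [] (a.txt exists)
    copy_files_sketch(src, dest, lambda p: p.suffix == ".txt", overwrite=True)
    print((dest / "a.txt").read_text())  # new
```

A sketch of transfer_files would use the same loop with shutil.move in place of shutil.copy2.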

property exists: bool

Check if the path exists.

property as_path: Path

Return the path as a Path object.

property is_empty: bool

Check if the directory is empty.

property is_dir: bool

Check if the path is a directory.

property is_file: bool

Check if the path is a file.

Check if the path is a symlink.

property size: int

Return the size of the directory in bytes.

property size_human: str

Return the size of the directory in a human-readable format.

property disk_usage: str

Return the disk usage of the directory in a human-readable format.

property checksum: str

Return the checksum of the directory.

static temp_dir(prefix: str = 'tmp') Directories[source]

Create and return a temporary directory with the given prefix.

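current = Directories(PosixPath('/home/docs/checkouts/readthedocs.org/user_builds/general-python/checkouts/latest/docs'))[source]

home = Directories(PosixPath('/home/docs'))[source]

root = Directories(PosixPath('/'))[source]

Class-level Directories instances for the current, home, and root directories. The paths shown were resolved on the documentation build host and will differ on other machines.
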
static from_env(var_name: str) Directories | None[source]

Create a Directories object from an environment variable. Returns None if the variable is not set or the path does not exist.

static from_config(config: dict, key: str) Directories | None[source]

Create a Directories object from a configuration dictionary. Returns None if the key is not found or the path does not exist.

static from_string(s: str) Directories[source]

Create a Directories object from a string path.

static from_parts(*parts: str | Path) Directories[source]

Create a Directories object from multiple path components.

static from_path(p: str | Path) Directories[source]

Create a Directories object from a Path-like object.

class general_python.common.Plotter(default_cmap='viridis', font_size=12, dpi=200)[source]

Bases: object

Publication-quality plotting utilities for scientific computing.

This class provides static methods for creating, customizing, and saving Matplotlib figures suitable for scientific journals (Nature, Science, PRL, etc.).

Most methods are @staticmethod, so you can call them without instantiation:

>>> Plotter.plot(ax, x, y, color='C0', label='Data')
>>> Plotter.set_legend(ax, style='publication')

Main Categories

  • Plotting methods: plot, scatter, tripcolor_field, semilogy, semilogx, loglog, errorbar, fill_between, histogram

  • Axis setup: set_ax_params, set_tickparams, setup_log_x, setup_log_y

  • Annotations: set_annotate, set_annotate_letter, set_arrow

  • Colorbars: add_colorbar, get_colormap, discrete_colormap

  • Layouts: get_subplots, get_grid, get_inset

  • Legends: set_legend, set_legend_custom

  • Saving: save_fig, savefig

For full documentation, call: Plotter.help()

__init__(default_cmap='viridis', font_size=12, dpi=200)[source]

Initialize the Plotter with default parameters.

Parameters:
  • default_cmap (str, default='viridis') – Default colormap for heatmaps and colorbars.

  • font_size (int, default=12) – Default font size for labels and text.

  • dpi (int, default=200) – Resolution for rasterized output (PNG, TIFF).

Note

Most methods are @staticmethod and don’t require instantiation. Use the class directly: Plotter.plot(ax, x, y)

ax_off(ax: Axes | List[Axes])[source]

Completely turn off the axis (no ticks, labels, spines, data).

static ax(ax: Axes | List[Axes], *args, **kwargs)[source]

Alias for ax method to allow direct calls.

static disable(axes, target, *, warn: bool = False, hide: bool = True, reason: str | None = None) AxesList[source]

Convenience wrapper for AxesList.disable.

Parameters:
  • axes (AxesList or list-like of axes) – Axes container to operate on.

  • target (selector) – Forwarded to AxesList.disable().

  • warn (bool, default=False) – Emit warnings when disabled axes are used.

  • hide (bool, default=True) – Hide underlying axes before disabling.

  • reason (str, optional) – Optional warning context.

Returns:

The modified axes list.

Return type:

AxesList

static help(topic: str = None)[source]

Print help information about available plotting methods.

Parameters:

topic (str, optional) – Specific topic to get help on. Options:

  • 'plot': basic plotting methods

  • 'axis': axis configuration

  • 'color': colors and colorbars

  • 'layout': subplots and grids

  • 'save': saving figures

  • None: print an overview of all topics

Examples

>>> Plotter.help()           # Overview
>>> Plotter.help('plot')     # Plotting methods
>>> Plotter.help('axis')     # Axis configuration
static plot_style(**kwargs) PlotStyle[source]

Return a PlotStyle config instance.

static kspace_config(**kwargs) KSpaceConfig[source]

Return a KSpaceConfig instance.

static kpath_config(**kwargs) KPathConfig[source]

Return a KPathConfig instance.

static spectral_config(**kwargs) SpectralConfig[source]

Return a SpectralConfig instance.

static figure_config(**kwargs) FigureConfig[source]

Return a FigureConfig instance.

static plotters()[source]

Expose the general_python.common.plotters package.

static statistical_fitter()[source]

Backward-compatible alias for fitter().

static fitter()[source]

Expose shared fitting/scaling helpers from general_python.maths.math_utils.

static math(label: str, *, auto_wrap: bool = True, escape_text: bool = True, **values: Any) str[source]

Build a LaTeX-ready math label from a template.

This method extends Python str.format-style placeholders with simple math filters for scientific labels.

Parameters:
  • label (str) – Template string. Standard format fields are supported, e.g. {J:.3g}, and can be combined with filters, e.g. {point|vec:.2f}, {kpoint|sym}. Greek-name values (for example Gamma or omega) are automatically rendered as LaTeX variables.

  • auto_wrap (bool, default=True) – If True, wrap the final string with $...$ when no dollar sign is present in the rendered output.

  • escape_text (bool, default=True) – If True, plain substituted strings are LaTeX-escaped by default. Use |raw or |tex to bypass escaping for a specific field.

  • **values (Any) – Values used by template fields.

Supported filters:

  • raw / tex – Insert value as-is (no escaping).

  • sym – Force symbol conversion for common Greek names (e.g. Gamma -> \Gamma).

  • num – Numeric formatting helper. Uses the format spec if provided, otherwise .6g.

  • vec – Render iterable as \left(v_1, v_2, ...\right).

  • set – Render iterable as \left\{v_1, v_2, ...\right\}.

Returns:

Rendered LaTeX/mathtext-compatible label.

Return type:

str

Examples

>>> Plotter.math(r"\langle S_i^z \rangle = {value|num:.3e}", value=1.2e-4)
'$\\langle S_i^z \\rangle = 1.200e-04$'
>>> Plotter.math(r"{kx|sym}-{ky|sym} path, q={q|vec:.2f}", kx="Gamma", ky="K", q=[0, 1/3])
'$\\Gamma-K path, q=\\left(0, 0.33\\right)$'
>>> Plotter.math(r"E={expr|raw}", expr=r"E_0 + \Delta")
'$E=E_0 + \\Delta$'

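A minimal sketch of how `{name|filter:spec}` placeholders can be layered on Python's `string.Formatter`. `MathLabelFormatter`, `math_label`, and the small `GREEK` set are hypothetical names for illustration; LaTeX escaping of plain strings (`escape_text`) is omitted, so this is not the library's implementation.

```python
import string
from typing import Any, Tuple

# Deliberately small Greek table for the sketch.
GREEK = {"alpha", "beta", "gamma", "Gamma", "delta", "Delta",
         "theta", "Theta", "omega", "Omega"}


class MathLabelFormatter(string.Formatter):
    """Hypothetical sketch of the {name|filter:spec} mini-language."""

    def get_field(self, field_name: str, args, kwargs) -> Tuple[Any, Any]:
        # Split "value|num" into the lookup name and the filter name.
        name, _, filt = field_name.partition("|")
        obj, used_key = super().get_field(name, args, kwargs)
        return (obj, filt), used_key

    def format_field(self, value: Tuple[Any, str], format_spec: str) -> str:
        obj, filt = value
        if filt == "sym":
            text = str(obj)
            return "\\" + text if text in GREEK else text
        if filt == "num":
            return format(obj, format_spec or ".6g")
        if filt in ("vec", "set"):
            body = ", ".join(format(v, format_spec) for v in obj)
            if filt == "vec":
                return r"\left(" + body + r"\right)"
            return r"\left\{" + body + r"\right\}"
        # "raw"/"tex" and unfiltered fields: plain format (no escaping here).
        return format(obj, format_spec)


def math_label(template: str, *, auto_wrap: bool = True, **values: Any) -> str:
    out = MathLabelFormatter().format(template, **values)
    # Wrap in $...$ only when the rendered string has no dollar sign yet.
    return f"${out}$" if auto_wrap and "$" not in out else out


print(math_label(r"\langle S^z \rangle = {m|num:.3e}", m=1.2e-4))
# $\langle S^z \rangle = 1.200e-04$
print(math_label("{k|sym} point, q={q|vec:.2f}", k="Gamma", q=[0.25, 0.5]))
# $\Gamma point, q=\left(0.25, 0.50\right)$
```

Because this sketch builds on str.format, literal LaTeX braces in a template must be doubled, e.g. "S_i^{{z}}".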
static ensure_list(x)[source]

Return x as a list-like container for axis utilities.

static unify_limits(axes, which='y')[source]

Set all axes to the shared x or y limits.

static resolve_planar_limits(points, *, limits: tuple | list | ndarray | None = None, x_limits: tuple | list | ndarray | None = None, y_limits: tuple | list | ndarray | None = None, xmin: float | None = None, xmax: float | None = None, ymin: float | None = None, ymax: float | None = None, limit_to_pi: bool = False, pad_fraction: float = 0.08) Tuple[tuple, tuple][source]

Resolve visible (xlim, ylim) for planar data.

Parameters:
  • points (array-like) – Planar sample points shaped like (N, 2) or (N, D).

  • limits (sequence, optional) – Explicit bounds. Length 2 means shared (min, max) for both axes. Length 4 means (xmin, xmax, ymin, ymax).

  • x_limits (sequence, optional) – Explicit (min, max) bounds for the x axis.

  • y_limits (sequence, optional) – Explicit (min, max) bounds for the y axis.

  • xmin (float, optional) – Scalar override for the lower x bound.

  • xmax (float, optional) – Scalar override for the upper x bound.

  • ymin (float, optional) – Scalar override for the lower y bound.

  • ymax (float, optional) – Scalar override for the upper y bound. All four scalar overrides take precedence over inferred limits and can refine limits / x_limits / y_limits.

  • limit_to_pi (bool, default=False) – If True and limits is not provided, use [-pi, pi] on both axes.

  • pad_fraction (float, default=0.08) – Relative padding applied when limits are inferred from points.
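How inferred limits, pad_fraction, and the scalar overrides might interact, as a sketch: `infer_planar_limits` is a hypothetical helper that pads each axis by `pad_fraction` of its data span and then lets xmin/xmax/ymin/ymax win. The fallback for a zero-width span is an assumption, and `limits`, `x_limits`, `y_limits`, and `limit_to_pi` are left out.

```python
from typing import Optional, Sequence, Tuple

Limits = Tuple[float, float]


def infer_planar_limits(points: Sequence[Sequence[float]], *,
                        xmin: Optional[float] = None, xmax: Optional[float] = None,
                        ymin: Optional[float] = None, ymax: Optional[float] = None,
                        pad_fraction: float = 0.08) -> Tuple[Limits, Limits]:
    """Hypothetical sketch: padded data bounds, then scalar overrides take precedence."""

    def padded(lo: float, hi: float) -> Limits:
        span = hi - lo
        # Assumption: a zero-width span is padded by pad_fraction in absolute units.
        pad = pad_fraction * span if span > 0 else pad_fraction
        return lo - pad, hi + pad

    xs = [p[0] for p in points]  # only the first two components are used
    ys = [p[1] for p in points]
    (x0, x1), (y0, y1) = padded(min(xs), max(xs)), padded(min(ys), max(ys))
    # Scalar overrides replace the inferred values.
    x0 = x0 if xmin is None else xmin
    x1 = x1 if xmax is None else xmax
    y0 = y0 if ymin is None else ymin
    y1 = y1 if ymax is None else ymax
    return (x0, x1), (y0, y1)


print(infer_planar_limits([(0, 0), (10, 5)], pad_fraction=0.1))
# ((-1.0, 11.0), (-0.5, 5.5))
print(infer_planar_limits([(0, 0), (10, 5)], xmin=0.0, pad_fraction=0.1))
# ((0.0, 11.0), (-0.5, 5.5))
```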

static markers()[source]

Markers with common options for line and scatter plots.

static markersC()[source]

Markers cycle with common options for line and scatter plots.

static colors()[source]

Colors with common options for line and scatter plots.

static linestyles()[source]

Linestyles with solid and dashed options.

static linestylesC()[source]

Linestyles cycle with solid and dashed options.

static linestylesCE()[source]

Extended linestyles cycle with more options for dashed lines.

static palette(name: str = 'tableau', n: int | None = None) List[str][source]

Return a named color palette as a list of hex strings.

Parameters:
  • name (str, default='tableau') –

    Palette name. Built-in options:

    • wong – Wong (2011) 8-color colorblind-friendly palette (Nature Methods)

    • okabe – Okabe & Ito palette, identical to wong

    • tol – Paul Tol’s muted 10-color qualitative set

    • tol_bright – Paul Tol’s bright 7-color high-contrast set

    • ibm – IBM Carbon 5-color accessible palette

    • colorblind – seaborn colorblind 10-color cycle

    • deep – seaborn deep 10-color perceptual palette

    • muted – seaborn muted, toned-down palette

    • tableau – Matplotlib default Tableau-10 cycle

    • classic – Matplotlib pre-2.0 default cycle

    • nature – Nature/BioRxiv warm editorial palette

    • science – Science journal-inspired muted palette

    • pastel – soft pastel tones for presentations

    • sunset – 9-stop cool→warm diverging gradient

  • n (int, optional) – Return exactly n colors. When n > len(palette), colors are repeated cyclically.

Returns:

Hex color strings.

Return type:

list[str]

Examples

>>> Plotter.palette('wong')           # 8 colorblind-safe colors
>>> Plotter.palette('nature', n=4)    # first 4 of nature palette
>>> Plotter.palette('deep', n=12)     # 12 colors (cycles)
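The n-color behavior of palette(), and the infinite iterator returned by palette_cycle() and colorsC(), can be mimicked with itertools. `take_colors`, `color_cycle`, and the `WONG` list are illustrative assumptions; the hex values are the commonly published Okabe-Ito/Wong colors, and the library's own list may use a different order.

```python
from itertools import cycle, islice
from typing import Iterator, List, Optional

# Commonly published Okabe-Ito / Wong (2011) colors; order in the library may differ.
WONG = ["#000000", "#E69F00", "#56B4E9", "#009E73",
        "#F0E442", "#0072B2", "#D55E00", "#CC79A7"]


def take_colors(palette: List[str], n: Optional[int] = None) -> List[str]:
    """Return the whole palette, or exactly n colors repeated cyclically."""
    if n is None:
        return list(palette)
    return list(islice(cycle(palette), n))


def color_cycle(palette: List[str]) -> Iterator[str]:
    """Infinite iterator over the palette, analogous to palette_cycle()."""
    return cycle(palette)


twelve = take_colors(WONG, 12)
print(len(twelve), twelve[8] == WONG[0])  # 12 True
cyc = color_cycle(WONG)
print(next(cyc), next(cyc))  # #000000 #E69F00
```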
static palette_cycle(name: str = 'tableau') cycle[source]

Return an infinite itertools.cycle over a named palette.

Parameters:

name (str) – Same keys as palette().

Examples

>>> cyc   = Plotter.palette_cycle('wong')
>>> color = next(cyc)
static colorsC(palette: str = 'tableau') cycle[source]

Return a color cycle for a named palette.

Alias for palette_cycle().

Examples

>>> cyc = Plotter.colorsC('wong')
>>> c1  = next(cyc)
static colorsN(n: int, palette: str = 'tableau') List[str][source]

Return exactly n colors from a named palette (cycling if needed).

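Parameters:
  • n (int) – Number of colors to return.

  • palette (str, default='tableau') – Palette name; same keys as palette().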

Examples

>>> c5 = Plotter.colorsN(5, 'wong')
static set_color_cycle(ax, palette: str | List = 'tableau') None[source]

Set the Matplotlib color cycle on one or more axes.

This makes subsequent ax.plot(...) calls auto-pick colors from the selected palette in order.

Parameters:
  • ax (axes or list of axes)

  • palette (str or list of colors, default='tableau') – Named palette string (see palette()) or an explicit list of any Matplotlib-compatible color specs.

Examples

>>> Plotter.set_color_cycle(ax, 'wong')
>>> ax.plot(x1, y1)   # first wong color
>>> ax.plot(x2, y2)   # second wong color
static apply_palette(axes, palette: str = 'tableau') None[source]

Apply a named color palette as the default color cycle for one or more axes. Shorthand for set_color_cycle().

Examples

>>> fig, axes = Plotter.get_subplots(1, 3)
>>> Plotter.apply_palette(axes, 'wong')
static to_rgba(color, alpha: float | None = None) Tuple[float, float, float, float][source]

Convert any Matplotlib-compatible color spec to an (r, g, b, a) tuple.

Parameters:
  • color (color spec) – Named string, hex, RGB/RGBA tuple, 'C0'-style, etc.

  • alpha (float, optional) – Override the alpha channel (0–1).

Return type:

tuple[float, float, float, float]

Examples

>>> Plotter.to_rgba('C0')
>>> Plotter.to_rgba('#E64B35', alpha=0.5)
static to_hex(color, keep_alpha: bool = False) str[source]

Convert any Matplotlib-compatible color spec to a hex string.

Parameters:
  • color (color spec)

  • keep_alpha (bool, default=False) – If True, return an 8-character #RRGGBBAA string.

Examples

>>> Plotter.to_hex('C0')              # '#1f77b4'
>>> Plotter.to_hex((0.2, 0.4, 0.6, 0.8), keep_alpha=True)
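Once a color is in float RGBA form (the step to_rgba() covers), the hex conversion amounts to scaling each channel to 0-255 and writing it as two hex digits. `rgba_to_hex` is a hypothetical helper showing only that step; it does not parse named colors or 'C0'-style specs.

```python
from typing import Sequence


def rgba_to_hex(rgba: Sequence[float], keep_alpha: bool = False) -> str:
    """Sketch of the float RGBA -> hex step; each channel is scaled to 0-255 and rounded."""
    channels = rgba if keep_alpha else rgba[:3]
    return "#" + "".join(f"{round(c * 255):02x}" for c in channels)


print(rgba_to_hex((0.2, 0.4, 0.6, 0.8)))                   # #336699
print(rgba_to_hex((0.2, 0.4, 0.6, 0.8), keep_alpha=True))  # #336699cc
```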
static adjust_color(color, *, lighten: float = 0.0, darken: float = 0.0, saturate: float = 0.0, desaturate: float = 0.0, alpha: float | None = None) Tuple[float, float, float, float][source]

Perceptually adjust a color in HLS space.

Each parameter shifts the corresponding channel by a fraction of the remaining headroom, so operations compose gracefully and values are always clamped to [0, 1].

Parameters:
  • color (color spec) – Any Matplotlib-compatible color.

  • lighten (float, default=0.0) – Push lightness toward 1 (white). 0 = no change, 1 = white.

  • darken (float, default=0.0) – Push lightness toward 0 (black). 0 = no change, 1 = black.

  • saturate (float, default=0.0) – Push saturation toward 1. 0 = no change, 1 = fully saturated.

  • desaturate (float, default=0.0) – Push saturation toward 0 (grey). 0 = no change, 1 = grey.

  • alpha (float, optional) – Override alpha channel (0–1).

Returns:

Adjusted RGBA color.

Return type:

tuple[float, float, float, float]

Examples

>>> Plotter.adjust_color('C0', lighten=0.3)
>>> Plotter.adjust_color('#E64B35', darken=0.4)
>>> Plotter.adjust_color('C2', desaturate=0.5, alpha=0.7)
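The "fraction of the remaining headroom" rule can be sketched with the standard-library colorsys module. `adjust_rgb` is hypothetical: it works on RGB triples only (no alpha), and the order in which lighten/darken and saturate/desaturate are applied is an assumption.

```python
import colorsys
from typing import Tuple

RGB = Tuple[float, float, float]


def _clamp(v: float) -> float:
    return min(max(v, 0.0), 1.0)


def adjust_rgb(rgb: RGB, *, lighten: float = 0.0, darken: float = 0.0,
               saturate: float = 0.0, desaturate: float = 0.0) -> RGB:
    """Sketch: move HLS lightness/saturation by a fraction of the remaining headroom."""
    h, l, s = colorsys.rgb_to_hls(*rgb)
    l = l + (1.0 - l) * lighten   # toward white; lighten=1 gives l == 1
    l = l - l * darken            # toward black; darken=1 gives l == 0
    s = s + (1.0 - s) * saturate  # toward full saturation
    s = s - s * desaturate        # toward grey
    return colorsys.hls_to_rgb(h, _clamp(l), _clamp(s))


print(adjust_rgb((0.5, 0.5, 0.5), lighten=0.5))     # (0.75, 0.75, 0.75)
print(adjust_rgb((1.0, 0.0, 0.0), desaturate=1.0))  # (0.5, 0.5, 0.5)
```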
static lighten(color, amount: float = 0.3) Tuple[float, float, float, float][source]

Return a lightened version of color (push lightness toward white).

Parameters:
  • color (color spec)

  • amount (float, default=0.3) – 0 = no change, 1 = white.

Examples

>>> fill = Plotter.lighten('C0', 0.5)
>>> Plotter.fill_between(ax, x, y1, y2, color=fill)
static darken(color, amount: float = 0.3) Tuple[float, float, float, float][source]

Return a darkened version of color (push lightness toward black).

Parameters:
  • color (color spec)

  • amount (float, default=0.3) – 0 = no change, 1 = black.

Examples

>>> edge = Plotter.darken('C0', 0.25)
static desaturate(color, amount: float = 0.5) Tuple[float, float, float, float][source]

Return a desaturated (greyed-out) version of color.

Parameters:
  • color (color spec)

  • amount (float, default=0.5) – 0 = original, 1 = fully grey.

Examples

>>> faded = Plotter.desaturate('C1', 0.6)
static with_alpha(color, a: float = 0.5) Tuple[float, float, float, float][source]

Return color with a modified alpha channel.

Parameters:
  • color (color spec)

  • a (float) – New alpha value (0–1).

Examples

>>> Plotter.fill_between(ax, x, y1, y2, color=Plotter.with_alpha('C0', 0.25))
static blend(c1, c2, t: float = 0.5, *, n: int | None = None) Tuple[float, float, float, float] | List[Tuple[float, float, float, float]][source]

Linearly interpolate between two colors in linear RGB space.

Parameters:
  • c1 (color spec) – Start color.

  • c2 (color spec) – End color.

  • t (float, default=0.5) – Blend position: 0 → c1, 1 → c2. Ignored when n is set.

  • n (int, optional) – If provided, return n evenly-spaced colors from c1 to c2 (inclusive of both endpoints).

Returns:

Single RGBA tuple when n is None; list of n tuples otherwise.

Return type:

tuple or list[tuple]

Examples

>>> mid      = Plotter.blend('red', 'blue')
>>> gradient = Plotter.blend('#E64B35', '#4DBBD5', n=7)
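A sketch of the interpolation, assuming both colors are already RGBA float tuples. `blend_sketch` is hypothetical and interpolates the channel values directly; if the library's "linear RGB" means gamma-decoded sRGB, its mid-range results will differ from this.

```python
from typing import List, Optional, Sequence, Tuple, Union

Color = Tuple[float, ...]


def blend_sketch(c1: Sequence[float], c2: Sequence[float], t: float = 0.5, *,
                 n: Optional[int] = None) -> Union[Color, List[Color]]:
    """Component-wise linear interpolation between two RGBA tuples."""

    def mix(u: float) -> Color:
        return tuple(a + (b - a) * u for a, b in zip(c1, c2))

    if n is None:
        return mix(t)
    if n == 1:
        return [mix(0.0)]
    # n evenly spaced positions, both endpoints included.
    return [mix(i / (n - 1)) for i in range(n)]


print(blend_sketch((1, 0, 0, 1), (0, 0, 1, 1)))            # (0.5, 0.0, 0.5, 1.0)
print(len(blend_sketch((1, 0, 0, 1), (0, 0, 1, 1), n=7)))  # 7
```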
static n_colors(n: int, cmap: str | Colormap = 'viridis', vmin: float = 0.0, vmax: float = 1.0, *, as_hex: bool = False) List[source]

Sample n evenly-spaced colors from a colormap.

Ideal for encoding a continuous parameter (temperature, time, β …) as line colors when you want a smooth gradient rather than a categorical palette.

Parameters:
  • n (int) – Number of colors to sample.

  • cmap (str or Colormap, default='viridis') – Source colormap.

  • vmin (float, default=0.0) – Lower end of the colormap fraction range to sample from.

  • vmax (float, default=1.0) – Upper end of the fraction range. Sampling a sub-range, e.g. vmin=0.1, vmax=0.9, avoids the near-white ends of sequential maps.

  • as_hex (bool, default=False) – If True, return hex strings instead of RGBA tuples.

Return type:

list[tuple] or list[str]

Examples

>>> colors = Plotter.n_colors(5, 'plasma')
>>> for c, (x, y) in zip(colors, datasets):
...     Plotter.plot(ax, x, y, color=c)
>>> # Avoid extreme ends of the colormap
>>> colors = Plotter.n_colors(8, 'RdBu_r', vmin=0.1, vmax=0.9)
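Sampling n colors comes down to choosing n evenly spaced positions in [vmin, vmax] and evaluating the colormap there. `sample_fractions` is a hypothetical helper for the first step; each position f would then be passed to a Matplotlib colormap object, which is callable as cmap(f) and returns an RGBA tuple.

```python
from typing import List


def sample_fractions(n: int, vmin: float = 0.0, vmax: float = 1.0) -> List[float]:
    """Evenly spaced colormap positions in [vmin, vmax], endpoints included."""
    if n == 1:
        return [vmin]
    step = (vmax - vmin) / (n - 1)
    return [vmin + i * step for i in range(n)]


print(sample_fractions(5))  # [0.0, 0.25, 0.5, 0.75, 1.0]
```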
static cmap_colors(cmap: str | Colormap, values: ndarray, *, vmin: float | None = None, vmax: float | None = None, norm: Normalize | None = None, scale: str = 'linear') List[Tuple][source]

Map an array of scalar values to RGBA colors via a colormap.

Convenience wrapper around get_colormap() when you only need the list of colors (not the full getcolor / norm / mappable bundle).

Parameters:
  • cmap (str or Colormap)

  • values (array-like) – Scalar values to map.

  • vmin (float, optional) – Lower color limit. Defaults to min(values).

  • vmax (float, optional) – Upper color limit. Defaults to max(values).

  • norm (Normalize, optional) – Explicit normalization. Takes precedence over scale.

  • scale ({'linear', 'log', 'symlog'}, default='linear')

Returns:

RGBA tuples, one per value.

Return type:

list[tuple]

Examples

>>> beta_values = np.linspace(0.1, 2.0, 8)
>>> colors      = Plotter.cmap_colors('plasma', beta_values)
>>> for val, c in zip(beta_values, colors):
...     Plotter.plot(ax, x, data[val], color=c, label=rf'$\beta={val:.1f}$')
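Before the colormap lookup, the values are mapped to colormap fractions. A sketch of that step for the 'linear' and 'log' scales, with vmin/vmax defaulting to min/max(values) as documented; `normalize_values` is hypothetical, 'symlog' is omitted, and values outside [vmin, vmax] are not clipped (they land outside [0, 1]).

```python
import math
from typing import List, Optional, Sequence


def normalize_values(values: Sequence[float], *, vmin: Optional[float] = None,
                     vmax: Optional[float] = None, scale: str = "linear") -> List[float]:
    """Map values to colormap fractions; limits default to min/max(values)."""
    lo = min(values) if vmin is None else vmin
    hi = max(values) if vmax is None else vmax
    if scale == "log":
        # Log scale: normalize log10 of the data (values must be > 0).
        lo, hi = math.log10(lo), math.log10(hi)
        values = [math.log10(v) for v in values]
    elif scale != "linear":
        raise ValueError(f"unsupported scale in this sketch: {scale!r}")
    span = hi - lo
    return [(v - lo) / span if span else 0.0 for v in values]


print(normalize_values([1, 2, 3]))                  # [0.0, 0.5, 1.0]
print(normalize_values([1, 10, 100], scale="log"))  # [0.0, 0.5, 1.0]
```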
static filter_results(results, filters=None, get_params_fun: callable = None, *, tol=1e-09)[source]

Backward-compatible wrapper around plotters.data_loader.filter_results.

static get_figsize(columnwidth, wf=0.5, hf=None, aspect_ratio=None)[source]

Compute a Matplotlib figure size matched to a LaTeX column width.

Parameters:
  • columnwidth (float) – Width of the column in LaTeX. Get this from LaTeX with \showthe\columnwidth.

  • wf (float, default=0.5) – Width fraction in columnwidth units.

  • hf (float, optional) – Height fraction in columnwidth units. If None, it is calculated from aspect_ratio.

  • aspect_ratio (float, optional) – Aspect ratio (height/width). If None, defaults to the golden ratio.

Returns:

[fig_width, fig_height] – Figure size to pass to Matplotlib.
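A sketch of the usual computation, under two assumptions not stated above: columnwidth is given in TeX points (72.27 per inch) and the result is in inches, the unit Matplotlib's figsize uses. `figsize_from_columnwidth` is a hypothetical name, and "golden ratio" is taken to mean height/width = 1/φ, about 0.618.

```python
import math
from typing import List, Optional

TEX_PT_PER_INCH = 72.27  # TeX points per inch


def figsize_from_columnwidth(columnwidth_pt: float, wf: float = 0.5,
                             hf: Optional[float] = None,
                             aspect_ratio: Optional[float] = None) -> List[float]:
    """Hypothetical sketch: [width, height] in inches from a LaTeX column width in pt."""
    if aspect_ratio is None:
        aspect_ratio = (math.sqrt(5.0) - 1.0) / 2.0  # 1/phi, about 0.618
    width_in = columnwidth_pt * wf / TEX_PT_PER_INCH
    if hf is None:
        height_in = width_in * aspect_ratio
    else:
        height_in = columnwidth_pt * hf / TEX_PT_PER_INCH
    return [width_in, height_in]


# Illustrative column width of 240 pt, full-width figure.
print(figsize_from_columnwidth(240.0, wf=1.0))
```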

static get_color(color, alpha=None, edgecolor=(0, 0, 0, 1), facecolor=(1, 1, 1, 0))[source]

Get a dictionary with color properties for Matplotlib patches.

Parameters:
  • color (str or tuple) – Color to use; a named color or an RGB tuple.

  • alpha (float) – Transparency level (0 to 1).

  • edgecolor (tuple) – Edge color as an RGB(A) tuple.

  • facecolor (tuple) – Face color as an RGB(A) tuple.

Returns:

Dictionary with color properties.

Return type:

dict

static add_colorbar(fig: Figure, pos: List[float], mappable: ndarray | list | _ScalarMappable, cmap: str | Colormap = 'viridis', norm: Normalize | None = None, vmin: float | None = None, vmax: float | None = None, scale: str = 'linear', orientation: str = 'vertical', label: str = '', label_kwargs: dict = None, title: str = '', title_kwargs: dict = None, ticks: List | ndarray | None = None, ticklabels: List[str] | None = None, tick_location: str = 'auto', tick_params: dict = None, extend: str = None, format: str | Formatter | None = None, discrete: bool | int = False, boundaries: List[float] = None, invert: bool = False, remove_pdf_lines: bool = True, **kwargs) Tuple[Colorbar, Axes][source]

Add a fully customizable colorbar to the figure at a specific position.

Parameters:
  • fig (matplotlib.figure.Figure) – Parent figure onto which the colorbar axis is added.

  • pos (list[float] | tuple[float, float, float, float]) – [left, bottom, width, height] in figure coordinates (0..1).

  • mappable (array-like | matplotlib.cm.ScalarMappable) –

    • If array-like: a new ScalarMappable is built from cmap/norm (and scale, vmin, vmax).

    • If ScalarMappable: it is used directly. vmin/vmax update its clim; norm is taken from it when not provided. Note: in this case discrete/boundaries resampling is not applied.

  • cmap (str | Colormap, default='viridis') – Colormap name or object. If mappable is a ScalarMappable, its cmap is used unless cmap is explicitly different from the default and a new mappable is constructed (array-like path).

  • norm (matplotlib.colors.Normalize, optional) – Normalization to map data to 0-1. Ignored if mappable is ScalarMappable and norm is None (then the mappable’s norm is used).

  • vmin (float, optional) – Lower data limit. When scale='log', a non-positive vmin is clamped internally.

  • vmax (float, optional) – Upper data limit.

  • scale ({'linear', 'log', 'symlog'}, default='linear') – Creates a suitable Normalize when mappable is array-like and norm is None:

    • 'linear' -> Normalize

    • 'log' -> LogNorm (vmin <= 0 clamped to ~1e-10)

    • 'symlog' -> SymLogNorm with linthresh=0.1

  • orientation ({'vertical', 'horizontal'}, default='vertical') – Colorbar orientation.

  • label (str, default='') – Axis label along the long side of the colorbar.

  • label_kwargs (dict, optional) – Passed to ColorbarBase.set_label (e.g., dict(fontsize=…, labelpad=…)).

  • title (str, default='') – Title text set at the end/top of the colorbar. For horizontal bars, the title is placed to the side.

  • title_kwargs (dict, optional) – Text properties for the title (e.g., dict(fontsize=…, pad=…)).

  • ticks (list[float] | np.ndarray, optional) – Explicit major tick locations.

  • ticklabels (list[str], optional) – Custom labels for the ticks (same length as ticks).

  • tick_location ({'auto','left','right','top','bottom'}, default='auto') – Side on which to draw ticks/labels (respects orientation).

  • tick_params (dict, optional) – Passed to cbar.ax.tick_params (e.g., dict(length=4, width=1, direction=’in’)).

  • extend ({'neither','both','min','max','neutral'}, default=None) – Colorbar extension behavior. Standard Matplotlib values are 'neither', 'both', 'min', 'max'. 'neutral' is treated as a pass-through here and may behave like 'neither' depending on Matplotlib.

  • format (str | matplotlib.ticker.Formatter, optional) – Tick formatting. If str (e.g., ‘%.2e’), uses FormatStrFormatter.

  • discrete (bool | int, default=False) – Discretize the colormap when building from array-like data:

    • True -> 10 bins

    • int N -> N bins

    Ignored when mappable is a ScalarMappable.

  • boundaries (list[float], optional) – Discrete bin edges. Enables BoundaryNorm and passes boundaries to fig.colorbar (default spacing=’proportional’, overridable via kwargs[‘spacing’]).

  • invert (bool, default=False) – If True, invert the colorbar axis direction.

  • remove_pdf_lines (bool, default=True) – Set solids edgecolor to ‘face’ to avoid white hairlines in vector exports (PDF/SVG).

  • **kwargs – Additional arguments forwarded to fig.colorbar, e.g. alpha, spacing ('uniform' | 'proportional'), fraction, pad, shrink, aspect, drawedges.

Returns:

(cbar, cax) – The created colorbar and its axes.

Return type:

tuple[matplotlib.colorbar.Colorbar, matplotlib.axes.Axes]

Notes

  • When mappable is a ScalarMappable, this helper does not modify its colormap discretization. To use discrete/boundaries, pass raw data (array-like) instead.

  • For ‘log’ scale, ensure your data are strictly positive (this function clamps vmin if needed).

Examples

>>> # Vertical, linear scale from raw data
>>> cbar, cax = Plotter.add_colorbar(fig, [0.92, 0.15, 0.02, 0.7], data, label='Mz')
>>> # Horizontal, log scale with scientific formatting and extensions
>>> cbar, cax = Plotter.add_colorbar(
...     fig, [0.2, 0.9, 0.6, 0.03], data, scale='log', orientation='horizontal',
...     format='%.0e', extend='both', tick_location='top', label='Conductance'
... )
>>> # Discrete categorical-like bar with custom tick labels
>>> cbar, cax = Plotter.add_colorbar(
...     fig, [0.85, 0.1, 0.03, 0.8], [0, 1, 2], cmap='Set1', discrete=3,
...     ticklabels=['Insulator', 'Metal', 'SC']
... )
>>> # Non-uniform boundaries
>>> cbar, cax = Plotter.add_colorbar(
...     fig, [0.86, 0.15, 0.02, 0.7], data, boundaries=[0, 0.5, 2.0, 10.0],
...     spacing='proportional'
... )

static get_colormap(values: ndarray | None = None, vmin=None, vmax=None, *, cmap='PuBu', elsecolor='blue', get_mappable: bool = False, return_mappable: bool | None = None, norm=None, scale='linear', **kwargs)[source]

Get a colormap for the given values.

Parameters:
  • values (array-like) – The values to map to colors.

  • cmap (str, optional) – The colormap to use. Defaults to 'PuBu'.

  • elsecolor (str, optional) – The color to use if there is only one value. Defaults to 'blue'.

  • get_mappable (bool, optional) – If True, also return a ScalarMappable as the 4th item, ready to pass into Plotter.add_colorbar(…, mappable=…).

  • return_mappable (bool, optional) – Alias for get_mappable.

Returns:
  • getcolor (function) – A function that maps a value to a color.

  • colors (Colormap) – The colormap object.

  • norm (Normalize) – The normalization object.

  • mappable (ScalarMappable, optional) – Returned when get_mappable=True (or return_mappable=True).

Example

>>> getcolor, colors, norm = Plotter.get_colormap([1, 2, 3], cmap='viridis')
>>> color = getcolor(2.5)
>>> getcolor, colors, norm, mappable = Plotter.get_colormap(
...     [1, 2, 3], cmap='viridis', return_mappable=True
... )

static apply_colormap(ax, data, cmap='PuBu', colorbar=True, **kwargs)[source]

Apply a colormap to the given data and plot it on the provided axis.

Parameters:
  • ax (object) – The axis object to plot on.

  • data (array-like) – The data to plot.

  • cmap (str, optional) – The colormap to use. Defaults to 'PuBu'.

  • colorbar (bool, optional) – Whether to add a colorbar. Defaults to True.

Returns:

img (AxesImage) – The image object.

static discrete_colormap(N, base_cmap=None)[source]

Create an N-bin discrete colormap from the specified input map.

Parameters:
  • N (int) – Number of discrete colors.

  • base_cmap (str or Colormap, optional) – The base colormap to use. Defaults to None.

Returns:

cmap (Colormap) – The discrete colormap.

static set_annotate(ax, elem: str, x: float = 0, y: float = 0, fontsize=None, xycoords='axes fraction', cond=True, zorder=50, boxaround=True, **kwargs)[source]

Make an annotation on the plot.

Parameters:
  • ax – Axis to annotate on.

  • elem – Annotation string.

  • x – x coordinate (ignored if xycoords='best').

  • y – y coordinate (ignored if xycoords='best').

  • fontsize – Font size of the annotation.

  • xycoords – How to interpret the coordinates (as in Matplotlib), or 'best' to place the annotation in the best corner.

  • cond – Condition to make the annotation.

static set_annotate_letter(ax: Axes, iter: int, x: float = 0, y: float = 0, fontsize=12, xycoords='axes fraction', addit='', condition=True, zorder=50, boxaround=False, fontweight='normal', color='black', **kwargs)[source]

Annotate plot with the letter.

Parameters:
  • ax (matplotlib.axes.Axes) – Axis to annotate on.

  • iter (int) – Index of the panel, which selects the letter.

  • x (float) – x coordinate.

  • y (float) – y coordinate.

  • fontsize – Font size.

  • xycoords – How to interpret the coordinates (as in Matplotlib).

  • addit (str) – Additional string to add after the letter.

  • condition (bool) – Condition to make the annotation.

  • zorder – Z-order of the annotation.

  • boxaround (bool) – Whether to put a box around the annotation.

  • fontweight – Weight of the text ('bold', 'normal', etc.).

  • color – Color of the text.

  • **kwargs – Additional arguments for the annotation.

Example

>>> Plotter.set_annotate_letter(ax, 0, x=0.1, y=0.9, fontsize=14, addit=' Test', color='red')

static set_arrow(ax, start_T: str, end_T: str, xystart: float, xystart_T: float, xyend: float, xyend_T: float, arrowprops={'arrowstyle': '->'}, startcolor='black', endcolor='black', **kwargs)[source]

Make an arrow annotation on the plot.

Parameters:
  • ax – Axis to annotate on.

  • start_T – Start text.

  • end_T – End text.

  • xystart – x coordinate of the arrow start.

  • xystart_T – x coordinate of the start text.

  • xyend – x coordinate of the arrow end.

  • xyend_T – x coordinate of the end text.

  • arrowprops – Properties of the arrow.

  • startcolor – Color of the arrow at the start.

  • endcolor – Color of the arrow at the end.

  • **kwargs – Additional arguments for the annotation.

static callout(ax, text: str, xy, xytext=None, *, xycoords: str = 'data', textcoords: str | None = None, arrowstyle: str = '->', color: str = 'black', lw: float = 1.0, boxaround: bool = True, box_alpha: float = 0.85, zorder: int = 20, **kwargs)[source]

Add a compact callout (text + optional arrow) to an axis.

static highlight_box(ax, x: float, y: float, width: float, height: float, *, coords: str = 'data', edgecolor='crimson', facecolor='none', lw: float = 1.3, ls='-', alpha: float = 0.95, zorder: int = 15, **kwargs)[source]

Draw a highlighted rectangular region in data or axes coordinates.

static highlight_circle(ax, x: float, y: float, radius: float, *, coords: str = 'data', edgecolor='darkorange', facecolor='none', lw: float = 1.3, ls='-', alpha: float = 0.95, zorder: int = 15, **kwargs)[source]

Draw a highlighted circular region in data or axes coordinates.

static plot_fit(ax, funct, x, **kwargs)[source]

Plot the user-provided fitting function on a given axis; any **kwargs are passed on to the plotting call.

Parameters:
  • ax – Axis to plot on.

  • funct – Function to use for the fitting.

  • x – Arguments to the function.

static hline(ax: Axes, val: float, ls='--', lw=2.0, color='black', label=None, zorder=10, label_cond=True, **kwargs)[source]

Draw a horizontal line at y = val on the given axis.

static vline(ax, val: float, ls='--', lw=2.0, color='black', label=None, zorder=10, label_cond=True, **kwargs)[source]

Draw a vertical line at x = val on the given axis.

static scatter(ax, x, y, *, s=10, c='blue', marker='o', alpha=1.0, label=None, edgecolor=None, zorder=5, label_cond=True, linewidths=1.0, cmap=None, norm=None, vmin=None, vmax=None, plotnonfinite=False, clip_on=True, rasterized=False, **kwargs)[source]

Creates a scatter plot on the provided axis, styled for Nature-like plots.

Parameters:
  • ax (matplotlib.axes.Axes) – The axis on which to draw the scatter plot.

  • x (array-like) – The x-coordinates of the points.

  • y (array-like) – The y-coordinates of the points.

  • s (float or array-like, optional) – The size of the points (default: 10).

  • c (color or array-like, optional) – The color of the points (default: ‘blue’).

  • marker (str, optional) – The shape of the points (default: ‘o’).

  • alpha (float, optional) – The transparency of the points (0.0 to 1.0, default: 1.0).

  • label (str, optional) – The label for the points (default: None).

  • edgecolor (str or array-like, optional) – The edge color of the points (default: None).

  • zorder (int, optional) – The drawing order of the points (default: 5).

  • **kwargs – Additional keyword arguments passed to matplotlib.axes.Axes.scatter.

Example

>>> Plotter.scatter(ax, x_data, y_data, s=20, c='red', alpha=0.5, label='Sample Data')

static tripcolor_field(ax, points, values, *, triangles=None, mask=None, shading: str = 'gouraud', **kwargs)[source]

Plot a scalar field sampled on irregular planar points using triangulation.

This helper is meant for 2D scattered data where imshow is not appropriate because the samples do not lie on a regular rectangular grid. Matplotlib first builds a triangulation of the point cloud and then interpolates values inside each triangle.

Typical use cases:

  • real-space lattice-site scalar fields

  • Brillouin-zone data on irregular planar k-point sets

  • any scattered 2D measurement data

Parameters:
  • ax (matplotlib.axes.Axes) – Target axis.

  • points (array-like) – Sample positions shaped like (N, 2) or (N, D) with D >= 2. Only the first two Cartesian components are used.

  • values (array-like) – Scalar values of length N.

  • triangles (array-like, optional) – Explicit connectivity passed to matplotlib.tri.Triangulation.

  • mask (array-like of bool, optional) – Triangle mask. True hides the corresponding triangle.

  • shading ({'flat', 'gouraud'}, default='gouraud') – Interpolation mode inside triangles.

  • **kwargs – Forwarded to Axes.tripcolor.

Returns:

The created artist, or None if fewer than three points are given.

Return type:

matplotlib.collections.Collection | None
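The input contract above (points shaped (N, 2) or (N, D), N values, nothing drawn for fewer than three points) can be sketched in plain Python. The helper name prepare_tripcolor_inputs is hypothetical and not part of the library; it only illustrates the documented checks, not how tripcolor_field is implemented.

```python
from typing import List, Optional, Sequence, Tuple


def prepare_tripcolor_inputs(
    points: Sequence[Sequence[float]], values: Sequence[float]
) -> Optional[Tuple[List[float], List[float], List[float]]]:
    """Hypothetical helper mirroring the documented tripcolor_field input contract."""
    if len(points) != len(values):
        raise ValueError("points and values must have the same length N")
    if len(points) < 3:
        # A triangulation needs at least three points; the documented
        # behaviour in that case is to draw nothing and return None.
        return None
    xs: List[float] = []
    ys: List[float] = []
    for p in points:
        if len(p) < 2:
            raise ValueError("each point needs at least two components")
        # Only the first two Cartesian components are used (D >= 2 allowed).
        xs.append(float(p[0]))
        ys.append(float(p[1]))
    return xs, ys, [float(v) for v in values]
```

The resulting coordinates could then be passed to matplotlib.tri.Triangulation(xs, ys) and on to ax.tripcolor(tri, vals, shading='gouraud').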

static plot(ax, *args, y=None, x=None, ls='-', lw=2.0, color='black', label=None, label_cond=True, marker=None, ms=None, zorder=5, drawstyle='default', markevery=None, clip_on=True, rasterized=False, antialiased=True, solid_capstyle=None, solid_joinstyle=None, **kwargs)[source]

Plot line data (lines and/or markers) on the given axis.

static fill_between(ax, x, y1, y2, color='blue', alpha=0.5, where=None, interpolate=False, step=None, linewidth=0.0, edgecolor=None, zorder=4, clip_on=True, rasterized=False, **kwargs)[source]

Fills the area between two curves on the provided axis.

Parameters:
  • ax (matplotlib.axes.Axes) – The axis on which to fill the area.

  • x (array-like) – The x-coordinates of the points.

  • y1 (array-like) – The y-coordinates of the first curve.

  • y2 (array-like) – The y-coordinates of the second curve.

  • color (str, optional) – The color of the filled area (default: ‘blue’).

  • alpha (float, optional) – The transparency of the filled area (0.0 to 1.0, default: 0.5).

  • **kwargs – Additional keyword arguments passed to matplotlib.axes.Axes.fill_between.

Example

>>> Plotter.fill_between(ax, x_data, y1_data, y2_data, color='red', alpha=0.3)

static semilogy(ax, x, y, ls='-', lw=1.5, color='black', label=None, marker=None, ms=None, label_cond=True, zorder=5, **kwargs)[source]

Plot with logarithmic y-axis.

Parameters:
  • ax (matplotlib.axes.Axes) – Axis to plot on.

  • x (array-like) – x-coordinates of the data.

  • y (array-like) – y-coordinates of the data. Only positive values can be shown on the logarithmic y-axis.

  • ls (str, default='-') – Line style.

  • lw (float, default=1.5) – Line width.

  • color (str or int, default='black') – Line color. If int, uses colorsList[color].

  • label (str, optional) – Legend label.

  • marker (str, optional) – Marker style.

  • ms (float, optional) – Marker size.

  • **kwargs – Additional arguments passed to ax.semilogy.

Examples

>>> Plotter.semilogy(ax, x, np.exp(-x), color='C0', label=r'$e^{-x}$')

static semilogx(ax, x, y, ls='-', lw=1.5, color='black', label=None, marker=None, ms=None, label_cond=True, zorder=5, **kwargs)[source]

Plot with logarithmic x-axis.

Parameters:
  • ax (matplotlib.axes.Axes) – Axis to plot on.

  • x (array-like) – x-coordinates of the data. Only positive values can be shown on the logarithmic x-axis.

  • y (array-like) – y-coordinates of the data.

  • ls (str, default='-') – Line style.

  • lw (float, default=1.5) – Line width.

  • color (str or int, default='black') – Line color. If int, uses colorsList[color].

  • label (str, optional) – Legend label.

  • marker (str, optional) – Marker style.

  • ms (float, optional) – Marker size.

  • **kwargs – Additional arguments passed to ax.semilogx.

Examples

>>> Plotter.semilogx(ax, np.logspace(-3, 3, 100), y, color='C1')

static loglog(ax, x, y, ls='-', lw=1.5, color='black', label=None, marker=None, ms=None, label_cond=True, zorder=5, **kwargs)[source]

Plot with logarithmic x and y axes.

Parameters:
  • ax (matplotlib.axes.Axes) – Axis to plot on.

  • x (array-like) – x-coordinates of the data (must be positive).

  • y (array-like) – y-coordinates of the data (must be positive).

  • ls (str, default='-') – Line style.

  • lw (float, default=1.5) – Line width.

  • color (str or int, default='black') – Line color. If int, uses colorsList[color].

  • label (str, optional) – Legend label.

  • marker (str, optional) – Marker style.

  • ms (float, optional) – Marker size.

  • **kwargs – Additional arguments passed to ax.loglog.

Examples

>>> # Power law: y = x^(-2)
>>> x = np.logspace(0, 3, 50)
>>> Plotter.loglog(ax, x, x**(-2), label=r'$x^{-2}$', color='C2')
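Because log axes can only display positive, finite values, one option is to drop invalid pairs before calling Plotter.loglog. A minimal sketch, assuming filtering (rather than matplotlib's own masking or clipping of non-positive values) is what you want; positive_pairs is a hypothetical helper, not part of the library:

```python
import math
from typing import List, Sequence, Tuple


def positive_pairs(x: Sequence[float], y: Sequence[float]) -> Tuple[List[float], List[float]]:
    """Keep only the (x, y) pairs that can be drawn on log-log axes."""
    xs: List[float] = []
    ys: List[float] = []
    for xi, yi in zip(x, y):
        # Log axes cannot show zero, negative or non-finite values.
        if xi > 0 and yi > 0 and math.isfinite(xi) and math.isfinite(yi):
            xs.append(xi)
            ys.append(yi)
    return xs, ys
```

The filtered lists can then be passed as x and y.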
static errorbar(ax, x, y, yerr=None, xerr=None, fmt='o', color='black', capsize=2, capthick=1.0, elinewidth=1.0, markersize=5, label=None, label_cond=True, alpha=1.0, zorder=5, ecolor=None, errorevery=1, barsabove=False, uplims=False, lolims=False, xuplims=False, xlolims=False, clip_on=True, rasterized=False, **kwargs)[source]

Plot data with error bars.

Parameters:
  • ax (matplotlib.axes.Axes) – Axis to plot on.

  • x (array-like) – x-coordinates of the data points.

  • y (array-like) – y-coordinates of the data points.

  • yerr (float or array-like, optional) – Vertical error bars. Can be:
    • scalar: symmetric error for all points
    • 1D array: symmetric error per point
    • 2D array of shape (2, N): asymmetric [lower, upper] errors

  • xerr (float or array-like, optional) – Horizontal error bars (same formats as yerr).

  • fmt (str, default='o') – Format string for the markers; use '' to draw only the error bars, without markers.

  • color (str or int, default='black') – Color for markers and error bars.

  • capsize (float, default=2) – Length of error bar caps.

  • capthick (float, default=1.0) – Thickness of error bar caps.

  • elinewidth (float, default=1.0) – Width of error bar lines.

  • markersize (float, default=5) – Size of markers.

  • label (str, optional) – Legend label.

  • alpha (float, default=1.0) – Transparency.

  • **kwargs – Additional arguments passed to ax.errorbar.

Examples

>>> # Symmetric error
>>> Plotter.errorbar(ax, x, y, yerr=sigma, label='Data')
>>> # Asymmetric error
>>> Plotter.errorbar(ax, x, y, yerr=[lower_err, upper_err])
>>> # Error band without markers
>>> Plotter.errorbar(ax, x, y, yerr=sigma, fmt='', elinewidth=2)
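The three yerr/xerr formats listed above follow matplotlib's errorbar conventions. The sketch below (split_errors is a hypothetical helper, not part of the library) expands any of them into explicit lower and upper error lists, which makes the (2, N) row order concrete: row 0 holds the lower errors, row 1 the upper errors.

```python
from numbers import Real
from typing import List, Sequence, Tuple, Union

ErrSpec = Union[float, Sequence[float], Sequence[Sequence[float]]]


def split_errors(err: ErrSpec, n: int) -> Tuple[List[float], List[float]]:
    """Expand an error specification into explicit (lower, upper) lists of length n."""
    if isinstance(err, Real):
        # Scalar: the same symmetric error for every point.
        return [float(err)] * n, [float(err)] * n
    rows = list(err)
    if len(rows) == 2 and all(not isinstance(r, Real) for r in rows):
        # Shape (2, N): first row = lower errors, second row = upper errors.
        lower = [float(v) for v in rows[0]]
        upper = [float(v) for v in rows[1]]
    else:
        # 1D: one symmetric error per point.
        lower = [float(v) for v in rows]
        upper = list(lower)
    if len(lower) != n or len(upper) != n:
        raise ValueError("error specification does not match the number of points")
    return lower, upper
```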
static histogram(ax, data, bins=50, density=True, histtype='stepfilled', alpha=0.7, color='C0', edgecolor='black', linewidth=1.0, label=None, orientation='vertical', cumulative=False, log=False, label_cond=True, zorder=5, weights=None, range=None, align='mid', rwidth=None, stacked=False, hatch=None, **kwargs)[source]

Plot a histogram.

Parameters:
  • ax (matplotlib.axes.Axes) – Axis to plot on.

  • data (array-like) – Input data.

  • bins (int or array-like, default=50) – Number of bins or bin edges.

  • density (bool, default=True) – If True, normalize to form a probability density.

  • histtype (str, default='stepfilled') – Type of histogram: ‘bar’, ‘barstacked’, ‘step’, ‘stepfilled’.

  • alpha (float, default=0.7) – Transparency.

  • color (str, default='C0') – Fill color.

  • edgecolor (str, default='black') – Edge color.

  • linewidth (float, default=1.0) – Edge line width.

  • label (str, optional) – Legend label.

  • orientation (str, default='vertical') – ‘vertical’ or ‘horizontal’.

  • cumulative (bool, default=False) – If True, plot cumulative histogram.

  • log (bool, default=False) – If True, use log scale for counts axis.

  • **kwargs – Additional arguments passed to ax.hist.

Returns:

  • n (array) – Histogram values.

  • bins (array) – Bin edges.

  • patches (list) – Patch objects.

Examples

>>> # Basic histogram
>>> Plotter.histogram(ax, data, bins=30, label='Distribution')
>>> # Step histogram (unfilled)
>>> Plotter.histogram(ax, data, histtype='step', linewidth=2)
>>> # Cumulative distribution
>>> Plotter.histogram(ax, data, cumulative=True, density=True)
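With density=True, bar heights are counts divided by (total count * bin width), so the area under the histogram is 1; this is the normalisation used by numpy.histogram and matplotlib's hist. A pure-Python sketch for equal-width bins (density_histogram is a hypothetical name, and this is not the code path Plotter.histogram uses):

```python
from typing import List, Sequence, Tuple


def density_histogram(data: Sequence[float], bins: int) -> Tuple[List[float], List[float]]:
    """Equal-width histogram normalised so that sum(height * width) == 1."""
    lo, hi = min(data), max(data)
    if lo == hi:
        # Avoid zero-width bins for constant data (numpy widens the range by 0.5).
        lo, hi = lo - 0.5, hi + 0.5
    width = (hi - lo) / bins
    edges = [lo + i * width for i in range(bins)] + [hi]
    counts = [0] * bins
    for value in data:
        # Bins are half-open [a, b), except the last one, which also includes hi.
        idx = min(int((value - lo) / width), bins - 1)
        counts[idx] += 1
    total = len(data)
    heights = [c / (total * width) for c in counts]
    return heights, edges
```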
static contourf(ax, x, y, z, **kwargs)[source]

Draw filled contours of z over the (x, y) grid on the given axis.

static grid(ax, **kwargs)[source]

Configure grid lines on the given axis.

Supported keyword arguments include:

  • which ({'major', 'minor', 'both'}, default='major') – Which grid lines the settings apply to.

  • axis ({'both', 'x', 'y'}, default='both') – Which axis the grid settings apply to.

  • color (color, optional) – Color of the grid lines.

  • linestyle (str, optional) – Style of the grid lines (e.g., '-', '--', '-.', ':').

  • linewidth (float, optional) – Width of the grid lines.

  • alpha (float, optional) – Transparency of the grid lines (0.0 to 1.0).

static set_tickparams(ax, labelsize=None, left=True, right=True, top=True, bottom=True, xticks=None, yticks=None, xticklabels=None, yticklabels=None, maj_tick_l=4, min_tick_l=2, **kwargs)[source]

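Set the tick parameters of the axis.

Parameters:
  • ax (matplotlib.axes.Axes) – Axis to modify.

  • labelsize (float, optional) – Font size of the tick labels.

  • left (bool, default=True) – Whether to show ticks on the left side.

  • right (bool, default=True) – Whether to show ticks on the right side.

  • top (bool, default=True) – Whether to show ticks on the top side.

  • bottom (bool, default=True) – Whether to show ticks on the bottom side.

  • xticks (list, optional) – Explicit x-tick positions.

  • yticks (list, optional) – Explicit y-tick positions.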

static set_ax_params(ax, which: str = 'both', xlabel: str | None = None, ylabel: str | None = None, title: str | None = None, fontsize: int | None = None, labelsize_title: int | None = None, labelsize_tick: int | None = None, labelpad: float | dict = 0.0, title_pad: float = 10.0, xlabel_position: Literal['top', 'bottom'] = 'bottom', ylabel_position: Literal['left', 'right'] = 'left', xlim: tuple | None = None, ylim: tuple | None = None, xscale: Literal['linear', 'log', 'symlog'] = 'linear', yscale: Literal['linear', 'log', 'symlog'] = 'linear', xticks: list | ndarray | None = None, yticks: list | ndarray | None = None, xticklabels: list | None = None, yticklabels: list | None = None, xtickpos: Literal['top', 'bottom', 'both'] = None, ytickpos: Literal['left', 'right', 'both'] = None, tick_length_major: float = 4.0, tick_length_minor: float = 2.0, tick_width: float = 0.8, tick_direction: Literal['in', 'out', 'inout'] = 'in', show_minor_ticks: bool = True, minor_tick_locator: str | None = 'auto', grid: bool = False, grid_axis: Literal['both', 'x', 'y'] = 'both', grid_which: Literal['major', 'minor', 'both'] = 'major', grid_style: str = '--', grid_color: str | None = None, grid_alpha: float = 0.3, grid_linewidth: float = 0.8, show_spines: bool | dict = True, spine_width: float = 1.0, spine_color: str = 'black', aspect: str | float | None = None, tight_layout: bool = False, legend: bool = False, legend_kwargs: dict | None = None, invert_xaxis: bool = False, invert_yaxis: bool = False, auto_formatter: bool = True, label_cond: bool = True, label_pos: dict = None, tick_pos: dict = None, **kwargs)[source]

Comprehensive axis configuration method for publication-quality plots.

This method provides centralized control over all major axis properties with sensible defaults and advanced options for fine-tuning. It integrates with other Plotter methods for a cohesive styling experience.

Parameters:
  • ax (matplotlib.axes.Axes) – The axis object to modify.

  • which ({'both', 'x', 'y'}, default='both') – Specifies which axes to update. Allows independent configuration of x and y axes.

  Labels and titles:

  • xlabel (str, optional) – x-axis label. Set to '' to hide the label while keeping the formatting.

  • ylabel (str, optional) – y-axis label. Set to '' to hide the label while keeping the formatting.

  • title (str, optional) – Axis title.

  • fontsize (int, optional) – Default font size for labels (overridable per-element).

  • labelsize_title (int, optional) – Font size for title. If None, uses fontsize + 2.

  • labelsize_tick (int, optional) – Font size for tick labels. If None, uses fontsize - 2.

  • labelpad (float or dict, default=0.0) – Padding between label and axis. Can be {‘x’: val, ‘y’: val}.

  • title_pad (float, default=10.0) – Vertical padding between title and plot area.

  Label positioning:

  • xlabel_position ({'top', 'bottom'}, default='bottom') – Position of x-axis label.

  • ylabel_position ({'left', 'right'}, default='left') – Position of y-axis label.

  Axis limits and scales:

  • xlim (tuple, optional) – x-axis limits as (min, max). Use None for automatic limits.

  • ylim (tuple, optional) – y-axis limits as (min, max). Use None for automatic limits.

  • xscale ({'linear', 'log', 'symlog'}, default='linear') – x-axis scale type. 'symlog' uses a symmetric log scale.

  • yscale ({'linear', 'log', 'symlog'}, default='linear') – y-axis scale type. 'symlog' uses a symmetric log scale.

  Tick configuration:

  • xticks (list or np.ndarray, optional) – Explicit x-tick positions. Leave None for matplotlib's automatic ticks.

  • yticks (list or np.ndarray, optional) – Explicit y-tick positions. Leave None for matplotlib's automatic ticks.

  • xticklabels (list, optional) – Custom x-tick labels. Must match the length of xticks if provided.

  • yticklabels (list, optional) – Custom y-tick labels. Must match the length of yticks if provided.

  • tick_length_major (float, default=4.0) – Length of major ticks in points.

  • tick_length_minor (float, default=2.0) – Length of minor ticks in points.

  • tick_width (float, default=0.8) – Width of ticks in points.

  • tick_direction ({'in', 'out', 'inout'}, default='in') – Direction in which ticks point ('in' is recommended for publication).

  • show_minor_ticks (bool, default=True) – Whether to show minor ticks.

  • minor_tick_locator ({'auto', 'log'}, default='auto') – How to locate minor ticks. ‘log’ uses LogLocator for log scales.

  Grid configuration:

  • grid (bool, default=False) – Enable gridlines.

  • grid_axis ({'both', 'x', 'y'}, default='both') – Which axes to show grid on.

  • grid_which ({'major', 'minor', 'both'}, default='major') – Which ticks to grid on.

  • grid_style (str, default='--') – Line style ('-', '--', '-.', ':').

  • grid_color (str, optional) – Grid color. If None, uses current axes color scheme.

  • grid_alpha (float, default=0.3) – Transparency of gridlines (0=transparent, 1=opaque).

  • grid_linewidth (float, default=0.8) – Width of gridlines in points.

  Spine configuration:

  • show_spines (bool or dict, default=True) – Visibility of spines:
    • True: show all spines
    • False: hide all spines
    • dict: {'top': bool, 'bottom': bool, 'left': bool, 'right': bool}

  • spine_width (float, default=1.0) – Width of spines in points.

  • spine_color (str, default='black') – Color of spines.

  Appearance:

  • aspect (str or float, optional) – Aspect ratio (‘equal’, ‘auto’) or numeric value.

  • tight_layout (bool, default=False) – Apply tight layout after configuration.

  Legend:

  • legend (bool, default=False) – Whether to display legend using set_legend().

  • legend_kwargs (dict, optional) – Arguments to pass to set_legend() if legend=True.

  Advanced options:

  • invert_xaxis (bool, default=False) – Invert the direction of the x-axis.

  • invert_yaxis (bool, default=False) – Invert the direction of the y-axis.

  • auto_formatter (bool, default=True) – Automatically apply scientific notation formatter for large/small numbers.

  • **kwargs – Additional keyword arguments passed to matplotlib functions.

Examples

Example 1: Basic publication-ready plot

>>> ax = plt.gca()
>>> Plotter.set_ax_params(
...     ax,
...     xlabel=r'$x$ (nm)',
...     ylabel=r'Energy (eV)',
...     title='Band Structure',
...     xlim=(0, 10),
...     ylim=(-5, 5),
...     grid=True
... )

Example 2: Log-scale with custom ticks

>>> Plotter.set_ax_params(
...     ax,
...     xlabel='Frequency (Hz)',
...     ylabel='Magnitude',
...     yscale='log',
...     yticks=[1, 10, 100, 1000],
...     yticklabels=['1', '10', '100', '1 k'],
...     grid=True,
...     grid_which='both',
...     minor_tick_locator='log'
... )

Example 3: Detailed styling

>>> Plotter.set_ax_params(
...     ax,
...     xlabel='Temperature (K)',
...     ylabel=r'$\rho$ (Ω·cm)',
...     title='Resistivity vs Temperature',
...     fontsize=12,
...     labelsize_title=14,
...     labelsize_tick=10,
...     xlim=(0, 300),
...     ylim=(0, None),  # auto max
...     grid=True,
...     grid_style='--',
...     grid_alpha=0.4,
...     show_spines={'top': False, 'right': False},
...     spine_width=1.5,
...     tick_length_major=6,
...     legend=True,
...     tight_layout=True
... )

Example 4: Custom tick labels and positions

>>> import numpy as np
>>> Plotter.set_ax_params(
...     ax,
...     xticks=np.linspace(0, 2*np.pi, 5),
...     xticklabels=['0', 'π/2', 'π', '3π/2', '2π'],
...     xlabel=r'Phase',
...     ylabel=r'$\sin(\phi)$'
... )

Example 5: Asymmetric spines (Nature style)

>>> Plotter.set_ax_params(
...     ax,
...     xlabel='Parameter',
...     ylabel='Value',
...     show_spines={'left': True, 'bottom': True, 'top': False, 'right': False},
...     grid=False,
...     tick_direction='out'
... )

Notes

  • Set labelsize_* to None to auto-scale relative to fontsize

  • Grid is best used with light colors and low alpha (0.2-0.4)

  • For log scales, minor_tick_locator='log' is recommended

  • Use which='x' or which='y' for independent axis control

  • Integrates with Plotter.set_legend() for unified styling

See also

set_legend

Configure legend appearance

set_tickparams

Alternative tick configuration method

grid

Add gridlines to axis
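Several set_ax_params defaults are described only in prose: title and tick font sizes derived from fontsize, labelpad as a float or per-axis dict, and show_spines as a bool or dict. The helpers below are a hypothetical sketch of how those rules resolve; how fontsize=None is handled, and the choice that a missing dict key means 0.0 (labelpad) or visible (spines), are assumptions rather than confirmed library behaviour.

```python
from typing import Dict, Optional, Union

SPINES = ("top", "bottom", "left", "right")


def resolve_sizes(fontsize: Optional[int],
                  labelsize_title: Optional[int] = None,
                  labelsize_tick: Optional[int] = None) -> Dict[str, Optional[int]]:
    """Documented defaults: title = fontsize + 2, ticks = fontsize - 2."""
    if fontsize is None:
        # Assumption: without a base size, only explicit values are used.
        return {"label": None, "title": labelsize_title, "tick": labelsize_tick}
    return {
        "label": fontsize,
        "title": fontsize + 2 if labelsize_title is None else labelsize_title,
        "tick": fontsize - 2 if labelsize_tick is None else labelsize_tick,
    }


def resolve_labelpad(labelpad: Union[float, Dict[str, float]], axis: str) -> float:
    """A float applies to both axes; a dict such as {'x': 4} gives per-axis values."""
    if isinstance(labelpad, dict):
        return float(labelpad.get(axis, 0.0))  # assumption: missing key -> 0.0
    return float(labelpad)


def resolve_spines(show_spines: Union[bool, Dict[str, bool]]) -> Dict[str, bool]:
    """True/False applies to all spines; a dict overrides individual spines."""
    if isinstance(show_spines, bool):
        return {name: show_spines for name in SPINES}
    # Assumption: spines missing from the dict stay visible (as in Example 3).
    return {name: bool(show_spines.get(name, True)) for name in SPINES}
```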

static set_xlabel(ax, xlabel, fontsize=None, labelpad=0, loc=None, x=None, y=None, coords: str = 'axes', transform=None, **kwargs)[source]

Set x-axis label with optional alignment and explicit coordinates.

Parameters:
  • ax (matplotlib.axes.Axes) – Target axis.

  • xlabel (str) – Label text.

  • fontsize (int, optional) – Label font size.

  • labelpad (float, default=0) – Padding in points.

  • loc ({'left', 'center', 'right'}, optional) – Matplotlib label location argument.

  • x (float, optional) – Explicit x-coordinate of the label.

  • y (float, optional) – Explicit y-coordinate of the label. Explicit placement is used if either x or y is provided.

  • coords ({'axes', 'data'}, default='axes') – Coordinate system used for x/y when transform is not provided.

  • transform (matplotlib transform, optional) – Explicit transform for label coordinates.

  • **kwargs – Forwarded to ax.set_xlabel.

static set_ylabel(ax, ylabel, fontsize=None, labelpad=0, loc=None, x=None, y=None, coords: str = 'axes', transform=None, **kwargs)[source]

Set y-axis label with optional alignment and explicit coordinates.

Parameters:
  • ax (matplotlib.axes.Axes) – Target axis.

  • ylabel (str) – Label text.

  • fontsize (int, optional) – Label font size.

  • labelpad (float, default=0) – Padding in points.

  • loc ({'bottom', 'center', 'top'}, optional) – Matplotlib label location argument.

  • x (float, optional) – Explicit x-coordinate of the label.

  • y (float, optional) – Explicit y-coordinate of the label. Explicit placement is used if either x or y is provided.

  • coords ({'axes', 'data'}, default='axes') – Coordinate system used for x/y when transform is not provided.

  • transform (matplotlib transform, optional) – Explicit transform for label coordinates.

  • **kwargs – Forwarded to ax.set_ylabel.

static set_ax_labels(ax, fontsize=None, xlabel='', ylabel='', title='', xPad=0, yPad=0, xloc=None, yloc=None, xcoords: str = 'axes', ycoords: str = 'axes', x_pos=None, y_pos=None)[source]

Set the x- and y-axis labels (and optionally the title) of the axis.

static set_label_cords(ax, which: str, inX=0.0, inY=0.0, **kwargs)[source]

Set the coordinates (inX, inY) of the axis label(s) selected by which.

static setup_log_y(ax: Axes, ylims=(1e-12, 1000000.0), decade_step=4)[source]

Configure clean log-scale y ticks at powers of 10 with LaTeX-like labels.

static setup_log_x(ax: Axes, xlims=(1e-12, 1000000.0), decade_step=4)[source]

Configure clean log-scale x ticks at powers of 10 with LaTeX-like labels.
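A sketch of the tick computation these two helpers describe: powers of ten inside the limits, one every decade_step decades, labelled as $10^{n}$. Assumptions: ticks start at the lowest decade inside the limits, and the label format is illustrative; decade_ticks is a hypothetical name, not the library's code.

```python
import math
from typing import List, Tuple


def decade_ticks(lims: Tuple[float, float], decade_step: int = 4) -> Tuple[List[float], List[str]]:
    """Powers of ten inside lims, every decade_step decades, with LaTeX-like labels."""
    lo, hi = lims
    if lo <= 0 or hi <= lo:
        raise ValueError("log limits must satisfy 0 < lo < hi")
    # The small tolerance guards against round-off in log10 for exact powers of ten.
    first = math.ceil(math.log10(lo) - 1e-9)
    last = math.floor(math.log10(hi) + 1e-9)
    exponents = list(range(first, last + 1, decade_step))
    ticks = [10.0 ** e for e in exponents]
    labels = [rf"$10^{{{e}}}$" for e in exponents]
    return ticks, labels
```

On a log-scaled axis the result could be applied with ax.set_yticks(ticks) followed by ax.set_yticklabels(labels).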

static set_smart_lim(ax, *, which: str = 'both', data: ndarray | None = None, margin_p: float = 0, margin_m: float = 1, xlim: tuple | None = None, ylim: tuple | None = None)[source]

Auto-compute robust axis limits and apply them to ax.
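The docstring does not spell out what "robust" means or how margin_p and margin_m are applied. As one common approach, offered only as an assumption, the sketch below ignores non-finite values and pads the data range by a fraction. padded_limits and pad_frac are hypothetical and do not correspond to the library's margin parameters.

```python
import math
from typing import Iterable, Tuple


def padded_limits(data: Iterable[float], pad_frac: float = 0.05) -> Tuple[float, float]:
    """Min/max of the finite values, widened by pad_frac of the data range."""
    finite = [float(v) for v in data if math.isfinite(v)]
    if not finite:
        raise ValueError("no finite data to compute limits from")
    lo, hi = min(finite), max(finite)
    span = hi - lo
    if span == 0:
        # Constant data: pad relative to the value itself (or 1.0 around zero).
        span = abs(lo) if lo != 0 else 1.0
    pad = pad_frac * span
    return lo - pad, hi + pad
```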

static hide_unused_panels(axes: Axes, n_panels: int)[source]

Hide unused panels in a subplot grid.
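For a flat, row-major list of axes such as the one make_grid returns, the panels left over after n_panels are the trailing indices. A hypothetical helper grid_shape that computes the row count and those indices, assuming a fixed number of columns:

```python
import math
from typing import List, Tuple


def grid_shape(n_panels: int, ncols: int) -> Tuple[int, int, List[int]]:
    """Rows needed for n_panels at ncols per row, plus the leftover indices."""
    if n_panels < 1 or ncols < 1:
        raise ValueError("n_panels and ncols must be positive")
    nrows = math.ceil(n_panels / ncols)
    unused = list(range(n_panels, nrows * ncols))
    return nrows, ncols, unused
```

Each leftover axis could then be hidden with axes[i].set_visible(False); whether hide_unused_panels uses this or ax.axis('off') is not stated here.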

static labellines(ax, align=False, xvals=None, yoffsets=[], zorder=2, **kwargs)[source]

Add inline labels to the lines on the axis, optionally aligned with their slope. Uses the labelLines package.

Parameters:
  • ax (matplotlib.axes.Axes) – Axis whose lines are labelled.

  • align (bool, default=False) – Align each label with the slope of its line.

  • xvals (array-like, optional) – x values at which to place the labels.

  • yoffsets (list, default=[]) – y offsets for the labels.

  • zorder (float, default=2) – Drawing order of the labels.

static unset_spines(ax, top: bool = True, right: bool = True, bottom: bool = False, left: bool = False)[source]

Remove specified spines from the axis for cleaner publication-style plots.

Parameters:
  • ax (matplotlib.axes.Axes) – The axes to modify.

  • top (bool, default=True) – If True, REMOVE the top spine. If False, KEEP it.

  • right (bool, default=True) – If True, REMOVE the right spine. If False, KEEP it.

  • bottom (bool, default=False) – If True, REMOVE the bottom spine. If False, KEEP it.

  • left (bool, default=False) – If True, REMOVE the left spine. If False, KEEP it.

Examples

>>> # Nature style: remove top and right, keep left and bottom (default)
>>> Plotter.unset_spines(ax)
>>> # Remove all spines (frameless plot)
>>> Plotter.unset_spines(ax, top=True, right=True, bottom=True, left=True)
>>> # Keep all spines
>>> Plotter.unset_spines(ax, top=False, right=False, bottom=False, left=False)
>>> # Keep only the bottom spine (minimal style)
>>> Plotter.unset_spines(ax, top=True, right=True, bottom=False, left=True)

Notes

The default settings (top=True, right=True) produce the classic “Nature” or “Science” journal style with only left and bottom spines.
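Because a True flag means "remove", it is easy to invert the logic by accident. This small sketch (remaining_spines is hypothetical, not part of the library) lists which spines stay visible for a given set of flags:

```python
from typing import Set


def remaining_spines(top: bool = True, right: bool = True,
                     bottom: bool = False, left: bool = False) -> Set[str]:
    """Spines that stay visible; a True flag means 'remove this spine'."""
    flags = {"top": top, "right": right, "bottom": bottom, "left": left}
    return {name for name, remove in flags.items() if not remove}
```

In plain matplotlib, removing a single spine corresponds to ax.spines['top'].set_visible(False).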

static unset_ticks(ax, xticks: bool = False, yticks: bool = False, xlabel: bool = False, ylabel: bool = False, remove_labels_only: bool = True)[source]

Remove tick labels (and optionally tick marks) from the axis.

Useful for creating clean shared-axis plots where inner panels don’t need redundant tick labels.

Parameters:
  • ax (matplotlib.axes.Axes) – The axes to modify.

  • xticks (bool, default=False) – If True, REMOVE x-tick labels. If False, KEEP them.

  • yticks (bool, default=False) – If True, REMOVE y-tick labels. If False, KEEP them.

  • xlabel (bool, default=False) – If True, also REMOVE the x-axis label.

  • ylabel (bool, default=False) – If True, also REMOVE the y-axis label.

  • remove_labels_only (bool, default=True) – If True, only remove the text labels, keeping tick marks visible. If False, remove both the tick marks and labels.

Examples

>>> # Remove x-tick labels for stacked plots with a shared x-axis
>>> for ax in axes[:-1]:  # all except the bottom panel
...     Plotter.unset_ticks(ax, xticks=True, xlabel=True)
>>> # Remove all tick labels (keep tick marks)
>>> Plotter.unset_ticks(ax, xticks=True, yticks=True)
>>> # Remove tick marks AND labels (completely clean)
>>> Plotter.unset_ticks(ax, xticks=True, yticks=True, remove_labels_only=False)
>>> # Remove y-tick labels and the y-axis label for side-by-side panels with a shared y-axis
>>> for ax in axes[1:]:  # all except the leftmost panel
...     Plotter.unset_ticks(ax, yticks=True, ylabel=True)

Notes

This function is commonly used in combination with sharex/sharey in multi-panel figures to avoid redundant labels.
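In a regular grid with sharex='col' and sharey='row', the panels whose labels are redundant can be computed from their row-major index. A sketch, assuming a flat row-major axes list like the one make_grid returns (inner_panels is a hypothetical helper):

```python
from typing import List, Tuple


def inner_panels(nrows: int, ncols: int) -> Tuple[List[int], List[int]]:
    """Row-major indices whose x (resp. y) tick labels are redundant."""
    n = nrows * ncols
    # Every panel outside the bottom row shares its x-axis with the panel below.
    strip_x = [i for i in range(n) if i // ncols < nrows - 1]
    # Every panel outside the left column shares its y-axis with its left neighbour.
    strip_y = [i for i in range(n) if i % ncols != 0]
    return strip_x, strip_y
```

The indices could then drive the calls above, e.g. Plotter.unset_ticks(axes[i], xticks=True, xlabel=True) for each i in strip_x.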

static unset_all(ax, spines: bool = True, ticks: bool = True, labels: bool = True)[source]

Completely strip an axis of spines, ticks, and labels.

Useful for image plots, heatmaps, or decorative panels where axis elements are not needed.

Parameters:
  • ax (matplotlib.axes.Axes) – The axes to modify.

  • spines (bool, default=True) – If True, remove all spines.

  • ticks (bool, default=True) – If True, remove all tick marks and labels.

  • labels (bool, default=True) – If True, remove axis labels.

Examples

>>> # Completely clean axis (for images/heatmaps)
>>> Plotter.unset_all(ax)
>>> # Keep only the spines (box around the plot)
>>> Plotter.unset_all(ax, spines=False)

static unset_ticks_and_spines(ax, xticks: bool = True, yticks: bool = True, top: bool = True, right: bool = True, bottom: bool = False, left: bool = False)[source]

Convenience method to remove both ticks and spines in one call.

Parameters:
  • ax (matplotlib.axes.Axes) – The axes to modify.

  • xticks (bool, default=True) – If True, REMOVE x-tick labels.

  • yticks (bool, default=True) – If True, REMOVE y-tick labels.

  • top (bool, default=True) – If True, REMOVE the top spine.

  • right (bool, default=True) – If True, REMOVE the right spine.

  • bottom (bool, default=False) – If True, REMOVE the bottom spine.

  • left (bool, default=False) – If True, REMOVE the left spine. With the defaults, only the top and right spines are removed (Nature style).

Examples

>>> # Clean Nature style with no tick labels
>>> Plotter.unset_ticks_and_spines(ax)
>>> # Only remove the top/right spines, keep all tick labels
>>> Plotter.unset_ticks_and_spines(ax, xticks=False, yticks=False)

static set_formater(ax, formater='%.1e', axis='xy')[source]

Sets the formatter for the given axis on the plot.

Parameters:
  • ax (object) – The axis object on which to set the formatter.

  • formater (str, optional) – The format string for the axis labels. Defaults to '%.1e'.

  • axis (str, optional) – The axis on which to set the formatter. Defaults to 'xy'.

Returns:

None
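Assuming formater is a printf-style string applied as fmt % value (the convention of matplotlib.ticker.FormatStrFormatter), the default '%.1e' renders ticks in scientific notation with one decimal. The sketch also shows one plausible reading of axis='xy'; format_ticks and target_axes are hypothetical names.

```python
from typing import List, Sequence


def format_ticks(values: Sequence[float], fmt: str = "%.1e") -> List[str]:
    """Apply a printf-style format string to each tick value."""
    return [fmt % v for v in values]


def target_axes(axis: str = "xy") -> List[str]:
    """'xy' -> ['x', 'y']; 'x' or 'y' -> that axis only (assumed semantics)."""
    if not axis or set(axis) - {"x", "y"}:
        raise ValueError(f"axis must contain only 'x' and/or 'y', got {axis!r}")
    return [a for a in "xy" if a in axis]
```

With matplotlib directly, the equivalent is ax.xaxis.set_major_formatter(FormatStrFormatter('%.1e')), with FormatStrFormatter imported from matplotlib.ticker.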

static set_standard_formater(ax, axis='xy')[source]

Sets the formatter for the given axis on the plot.

Parameters:
  • ax (object) – The axis object on which to set the formatter.

  • axis (str, optional) – The axis on which to set the formatter. Defaults to 'xy'.

Returns:

None

class GridBuilder(figsize=(10, 8))[source]

Bases: object

Builder class for creating complex figure layouts with nested grids.

Use this when you need different numbers of columns in different rows, or complex nested arrangements that can’t be achieved with a simple grid.

Parameters:

figsize (tuple, default=(10, 8)) – Figure size in inches (width, height).

Examples

Create a layout with varying column counts per row:

>>> builder = Plotter.GridBuilder(figsize=(12, 8))
>>> builder.add_row(ncols=1, height_ratio=1)     # Header row (1 panel)
>>> builder.add_row(ncols=3, height_ratio=2)     # Main row (3 panels)
>>> builder.add_row(ncols=2, height_ratio=1.5)   # Footer row (2 panels)
>>> fig, axes = builder.build(wspace=0.2, hspace=0.3)
>>> # axes = [[ax00], [ax10, ax11, ax12], [ax20, ax21]]

Access axes:

>>> header_ax = axes[0][0]
>>> main_left, main_center, main_right = axes[1]
>>> footer_left, footer_right = axes[2]

__init__(figsize=(10, 8))[source]

add_row(ncols: int, height_ratio: float = 1.0, width_ratios: List[float] = None)[source]

Add a row to the layout.

Parameters:
  • ncols (int) – Number of columns in this row.

  • height_ratio (float, default=1.0) – Relative height of this row compared to others.

  • width_ratios (list of float, optional) – Relative widths of columns within this row. If None, columns are equal width.

Returns:

self – For method chaining.

Return type:

GridBuilder

build(wspace: float = 0.2, hspace: float = 0.2, left: float = 0.1, right: float = 0.95, top: float = 0.95, bottom: float = 0.1)[source]

Build the figure with the specified layout.

Parameters:
  • wspace (float, default=0.2) – Horizontal space between columns within rows.

  • hspace (float, default=0.2) – Vertical space between rows.

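  • left (float, default=0.1) – Position of the left edge of the subplot area, as a fraction of the figure width.

  • right (float, default=0.95) – Position of the right edge of the subplot area, as a fraction of the figure width.

  • top (float, default=0.95) – Position of the top edge of the subplot area, as a fraction of the figure height.

  • bottom (float, default=0.1) – Position of the bottom edge of the subplot area, as a fraction of the figure height.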

Returns:

  • fig (matplotlib.figure.Figure) – The created figure.

  • axes (list of lists) – 2D list of axes, where axes[row][col] gives the axis at that position.
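The bookkeeping behind this builder (one record per row, chainable add_row, height ratios normalised across rows, axes indexed as axes[row][col]) can be sketched without matplotlib. LayoutSketch is a hypothetical stand-in, not the library's class; a real implementation would plausibly map each row to a nested grid, for example via SubplotSpec.subgridspec in matplotlib.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class RowSpec:
    ncols: int
    height_ratio: float = 1.0
    width_ratios: Optional[List[float]] = None


@dataclass
class LayoutSketch:
    """Hypothetical stand-in for GridBuilder that records rows without drawing."""
    rows: List[RowSpec] = field(default_factory=list)

    def add_row(self, ncols: int, height_ratio: float = 1.0,
                width_ratios: Optional[List[float]] = None) -> "LayoutSketch":
        if width_ratios is not None and len(width_ratios) != ncols:
            raise ValueError("width_ratios must have one entry per column")
        self.rows.append(RowSpec(ncols, height_ratio, width_ratios))
        return self  # enables method chaining, as documented for add_row

    def height_fractions(self) -> List[float]:
        # Each row's share of the total height.
        total = sum(r.height_ratio for r in self.rows)
        return [r.height_ratio / total for r in self.rows]

    def cell_names(self) -> List[List[str]]:
        # Same nesting as build(): result[row][col].
        return [[f"ax{i}{j}" for j in range(r.ncols)] for i, r in enumerate(self.rows)]
```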

static make_grid(nrows: int, ncols: int, figsize: tuple = (10, 8), width_ratios: List[float] = None, height_ratios: List[float] = None, wspace: float = 0.2, hspace: float = 0.2, left: float = 0.1, right: float = 0.95, top: float = 0.95, bottom: float = 0.1, sharex: str = False, sharey: str = False, panel_labels: bool = False, panel_label_style: str = 'parenthesis', despine: bool = False)[source]

Create a figure with a grid of subplots with full control over layout.

This is the recommended method for creating publication-quality multi-panel figures with precise control over spacing and sizing.

Parameters:
  • nrows (int) – Number of rows.

  • ncols (int) – Number of columns.

  • figsize (tuple, default=(10, 8)) – Figure size in inches (width, height).

  • width_ratios (list of float, optional) – Relative widths of columns. Length must equal ncols. Example: [2, 1, 1] makes first column 2x wider.

  • height_ratios (list of float, optional) – Relative heights of rows. Length must equal nrows. Example: [1, 3] makes second row 3x taller.

  • wspace (float, default=0.2) – Horizontal space between columns, as a fraction of the average axis width.

  • hspace (float, default=0.2) – Vertical space between rows, as a fraction of the average axis height.

  • left (float, default=0.1) – Position of the left edge of the grid (0 to 1, fraction of the figure width).

  • right (float, default=0.95) – Position of the right edge of the grid (0 to 1, fraction of the figure width).

  • top (float, default=0.95) – Position of the top edge of the grid (0 to 1, fraction of the figure height).

  • bottom (float, default=0.1) – Position of the bottom edge of the grid (0 to 1, fraction of the figure height).

  • sharex (str or bool, default=False) – Share x-axis: ‘row’, ‘col’, ‘all’, or False.

  • sharey (str or bool, default=False) – Share y-axis: ‘row’, ‘col’, ‘all’, or False.

  • panel_labels (bool, default=False) – Add panel labels (a), (b), (c), etc.

  • panel_label_style (str, default='parenthesis') – Style for panel labels: ‘parenthesis’, ‘plain’, ‘bold’.

  • despine (bool, default=False) – Remove top and right spines (Nature-style).

Returns:

  • fig (matplotlib.figure.Figure) – The created figure.

  • axes (list of Axes) – Flat list of axes [ax0, ax1, ax2, …], row-major order.
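Because the axes come back flat and row-major, index i sits at (i // ncols, i % ncols), and panel labels follow the same order. A sketch of that mapping for the documented label styles; panel_labels is hypothetical, and treating 'bold' as a change of font weight only is an assumption:

```python
import string
from typing import List, Tuple


def panel_labels(nrows: int, ncols: int,
                 style: str = "parenthesis") -> List[Tuple[int, int, str, str]]:
    """(row, col, text, fontweight) for each panel, in row-major order."""
    n = nrows * ncols
    if n > 26:
        raise ValueError("this sketch only supports up to 26 panels (a-z)")
    out = []
    for i, letter in enumerate(string.ascii_lowercase[:n]):
        row, col = divmod(i, ncols)  # row-major: i -> (i // ncols, i % ncols)
        if style == "parenthesis":
            text, weight = f"({letter})", "normal"
        elif style == "plain":
            text, weight = letter, "normal"
        elif style == "bold":
            text, weight = letter, "bold"  # assumption: only the weight changes
        else:
            raise ValueError(f"unknown style {style!r}")
        out.append((row, col, text, weight))
    return out
```

A label could then be drawn with ax.text(-0.1, 1.05, text, transform=ax.transAxes, fontweight=weight); the offsets are illustrative.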

Examples

Basic 2x3 grid:

>>> fig, axes = Plotter.make_grid(2, 3, figsize=(10, 6))
>>> ax0, ax1, ax2, ax3, ax4, ax5 = axes

Unequal column widths:

>>> fig, axes = Plotter.make_grid(1, 2, width_ratios=[3, 1])

Stacked panels with shared x-axis:

>>> fig, axes = Plotter.make_grid(3, 1, sharex='col', hspace=0.05)
>>> for ax in axes[:-1]:
...     Plotter.unset_ticks(ax, xticks=True, xlabel=True)

Publication figure:

>>> fig, axes = Plotter.make_grid(2, 2, figsize=(8, 8),
...                                panel_labels=True, despine=True)

static figure(figsize: tuple = (10, 8), **kwargs) → Figure[source]

Create a Matplotlib figure with specified size and options.

Parameters:
  • figsize (tuple, default=(10, 8)) – Figure size in inches (width, height).

  • **kwargs – Additional keyword arguments passed to plt.figure().

Returns:

The created figure object.

Return type:

matplotlib.figure.Figure

Examples

Basic figure creation:

>>> fig = Plotter.figure(figsize=(12, 6))

With additional options:

>>> fig = Plotter.figure(figsize=(8, 8), dpi=150, facecolor='white')

static get_grid(nrows: int, ncols: int, *, wspace: float = None, hspace: float = None, width_ratios: List[float] = None, height_ratios: List[float] = None, ax_sub=None, left: float = None, right: float = None, top: float = None, bottom: float = None, figure=None, **kwargs) → GridSpec[source]

Create a GridSpec for flexible subplot layouts.

This is the foundation for creating complex multi-panel figures with control over panel sizes and spacing.

Parameters:
  • nrows (int) – Number of rows in the grid.

  • ncols (int) – Number of columns in the grid.

  • wspace (float, optional) – Width space between columns, as a fraction of the average axis width. Recommended: 0.2-0.4 for labels, 0.05-0.1 for tight layouts.

  • hspace (float, optional) – Height space between rows, as a fraction of the average axis height. Recommended: 0.2-0.4 for titles, 0.05-0.1 for tight layouts.

  • width_ratios (list of float, optional) – Relative widths of columns. E.g., [2, 1, 1] makes first column 2x wider. Length must equal ncols.

  • height_ratios (list of float, optional) – Relative heights of rows. E.g., [1, 2] makes second row 2x taller. Length must equal nrows.

  • ax_sub (SubplotSpec, optional) – If provided, creates a nested GridSpec within this subplot. Use for complex layouts with grids inside grids.

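  • left (float, optional) – Position of the left edge of the grid (0.0 to 1.0, fraction of the figure width). Leaves room for labels.

  • right (float, optional) – Position of the right edge of the grid (0.0 to 1.0, fraction of the figure width). Leaves room for labels.

  • top (float, optional) – Position of the top edge of the grid (0.0 to 1.0, fraction of the figure height). Leaves room for labels.

  • bottom (float, optional) – Position of the bottom edge of the grid (0.0 to 1.0, fraction of the figure height). Leaves room for labels.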

  • **kwargs – Additional arguments passed to GridSpec.

Returns:

The grid specification object.

Return type:

GridSpec or GridSpecFromSubplotSpec

Examples

Basic 2x3 grid:

>>> fig = plt.figure(figsize=(12, 8))
>>> gs  = Plotter.get_grid(2, 3, wspace=0.3, hspace=0.4)
>>> ax0 = fig.add_subplot(gs[0, 0])  # Row 0, Col 0
>>> ax1 = fig.add_subplot(gs[0, 1:]) # Row 0, Cols 1-2 (span)
>>> ax2 = fig.add_subplot(gs[1, :])  # Row 1, all columns (span)

Unequal widths (main panel + sidebar):

>>> gs = Plotter.get_grid(1, 2, width_ratios=[3, 1])
>>> # First column is 3x wider than second

Nested grid (inset layout):

>>> outer       = Plotter.get_grid(1, 2)
>>> ax_left     = fig.add_subplot(outer[0])
>>> inner       = Plotter.get_grid(2, 2, ax_sub=outer[1], wspace=0.1, hspace=0.1)
>>> ax_inner_00 = fig.add_subplot(inner[0, 0])

Control margins:

>>> gs = Plotter.get_grid(2, 2, left=0.1, right=0.95, top=0.95, bottom=0.1)

See also

get_grid_subplot

Create subplot from GridSpec index

get_subplots

High-level function for simple layouts
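The nested-grid case corresponds to the two Matplotlib grid classes named in the return type: GridSpec for a top-level grid and GridSpecFromSubplotSpec for a grid placed inside one cell of another grid. The following is a minimal sketch, assuming get_grid dispatches on ax_sub roughly like this; make_grid_sketch is a hypothetical name, not the library's implementation.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec


def make_grid_sketch(nrows, ncols, *, ax_sub=None, figure=None, **kwargs):
    """Hypothetical stand-in for Plotter.get_grid."""
    if ax_sub is not None:
        # Nested grid inside an existing cell. GridSpecFromSubplotSpec accepts
        # wspace/hspace and ratios, but not left/right/top/bottom margins.
        return GridSpecFromSubplotSpec(nrows, ncols, subplot_spec=ax_sub, **kwargs)
    return GridSpec(nrows, ncols, figure=figure, **kwargs)


fig = plt.figure(figsize=(8, 4))
outer = make_grid_sketch(1, 2, figure=fig, width_ratios=[3, 1], wspace=0.3)
ax_main = fig.add_subplot(outer[0])          # wide left panel
inner = make_grid_sketch(2, 1, ax_sub=outer[1], hspace=0.1)
ax_top = fig.add_subplot(inner[0])           # upper cell of the right column
ax_bottom = fig.add_subplot(inner[1])        # lower cell of the right column
```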

static get_grid_subplot(gs, fig, index, sharex=None, sharey=None, **kwargs)[source]

Create a subplot from a GridSpec at the specified index.

Parameters:
  • gs (GridSpec) – The GridSpec object.

  • fig (matplotlib.figure.Figure) – The figure to add the subplot to.

  • index (int, tuple, or slice) – Position in the grid. Can be:
    • int: linear index (0, 1, 2, …)

    • tuple: (row, col) for a single cell

    • slice, or tuple containing slices: spans multiple cells

  • sharex (Axes, optional) – Axis whose x-axis the new subplot shares. Use for aligned multi-panel figures.

  • sharey (Axes, optional) – Axis whose y-axis the new subplot shares. Use for aligned multi-panel figures.

  • **kwargs – Additional arguments passed to fig.add_subplot.

Returns:

The created subplot.

Return type:

matplotlib.axes.Axes

Examples

Single cell by linear index:

ax0 = Plotter.get_grid_subplot(gs, fig, 0)  # First cell
ax1 = Plotter.get_grid_subplot(gs, fig, 1)  # Second cell

Single cell by (row, col):

ax = fig.add_subplot(gs[1, 2])  # Row 1, Col 2

Span multiple cells:

ax_wide = fig.add_subplot(gs[0, :])   # Entire first row
ax_tall = fig.add_subplot(gs[:, 0])   # Entire first column
ax_block = fig.add_subplot(gs[0:2, 1:3])  # 2x2 block

Shared axes (for aligned panels):

ax0 = Plotter.get_grid_subplot(gs, fig, 0)
ax1 = Plotter.get_grid_subplot(gs, fig, 1, sharex=ax0)
ax2 = Plotter.get_grid_subplot(gs, fig, 2, sharex=ax0, sharey=ax0)
# ax1 and ax2 share x-axis with ax0; ax2 also shares y-axis
static get_grid_map(nrows: int, ncols: int) dict[source]

Generate a mapping from panel labels to grid indices.

Useful for referencing panels by name rather than index.

Parameters:
  • nrows (int) – Number of rows.

  • ncols (int) – Number of columns.

Returns:

Mapping with keys:
  • ‘by_index’: {0: (0,0), 1: (0,1), …}

  • ‘by_letter’: {‘a’: 0, ‘b’: 1, …}

  • ‘by_rowcol’: {(0,0): 0, (0,1): 1, …}

  • ‘grid’: 2D list of indices

Return type:

dict

Examples

>>> gmap = Plotter.get_grid_map(2, 3)
>>> gmap['by_letter']['c']  # Get index for panel 'c'
2
>>> gmap['by_index'][4]  # Get (row, col) for index 4
(1, 1)
>>> gmap['grid']  # 2D layout
[[0, 1, 2], [3, 4, 5]]
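The mapping is a row-major enumeration of the grid cells. A minimal sketch of how such a mapping could be built, assuming panel letters follow the same row-major order as the indices (grid_map_sketch is a hypothetical helper, not the library's code):

```python
from string import ascii_lowercase


def grid_map_sketch(nrows: int, ncols: int) -> dict:
    """Row-major panel map; a hypothetical stand-in for Plotter.get_grid_map."""
    n = nrows * ncols
    if n > len(ascii_lowercase):
        # Limitation of this sketch: one letter per panel, a..z.
        raise ValueError("this sketch supports at most 26 panels")
    by_index = {i: divmod(i, ncols) for i in range(n)}    # i -> (row, col)
    by_rowcol = {rc: i for i, rc in by_index.items()}     # (row, col) -> i
    by_letter = {ascii_lowercase[i]: i for i in range(n)} # 'a' -> 0, 'b' -> 1, ...
    grid = [[r * ncols + c for c in range(ncols)] for r in range(nrows)]
    return {"by_index": by_index, "by_letter": by_letter,
            "by_rowcol": by_rowcol, "grid": grid}


gmap = grid_map_sketch(2, 3)
print(gmap["by_index"][4])   # (1, 1)
```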
static configure_axes(ax, visible: bool = True, spines: bool | dict | str = True, ticks: bool | dict | str = True, tick_labels: bool | dict | str = True, xlabel: str = None, ylabel: str = None, title: str = None, xscale: str = None, yscale: str = None, xlim: tuple = None, ylim: tuple = None, fontsize: int = None, **kwargs)[source]

Configure axis visibility, spines, ticks, and labels in one call.

This is a convenience function for common axis customizations.

Parameters:
  • ax (matplotlib.axes.Axes) – The axis to configure.

  • visible (bool, default=True) – If False, hide the entire axis (ax.axis(‘off’)).

  • spines (bool, dict, or str, default=True) – Control spine visibility:
    • True: show all spines

    • False: hide all spines

    • ‘left’: hide all except left

    • ‘bottom’: hide all except bottom

    • ‘minimal’: hide top and right (Nature-style)

    • dict: {‘top’: False, ‘right’: False, …}

  • ticks (bool, dict, or str, default=True) – Control tick visibility:
    • True/False: show/hide all ticks

    • ‘x’/‘y’: show only x/y ticks

    • dict: {‘x’: True, ‘y’: False}

  • tick_labels (bool, dict, or str, default=True) – Control tick label visibility (same format as ticks).

  • xlabel (str, optional) – Label for the x-axis.

  • ylabel (str, optional) – Label for the y-axis.

  • title (str, optional) – Axis title.

  • xscale (str, optional) – Axis scale: ‘linear’, ‘log’, ‘symlog’.

  • yscale (str, optional) – Axis scale: ‘linear’, ‘log’, ‘symlog’.

  • xlim (tuple, optional) – Axis limits as (min, max).

  • ylim (tuple, optional) – Axis limits as (min, max).

  • fontsize (int, optional) – Font size for labels.

  • **kwargs – Additional arguments (e.g., labelpad).

Examples

Minimal style (no top/right spines):

Plotter.configure_axes(ax, spines='minimal', xlabel='Time', ylabel='Value')

Hide axis completely (for images/heatmaps):

Plotter.configure_axes(ax, visible=False)

Keep only left spine and y-ticks:

Plotter.configure_axes(ax, spines='left', ticks='y', tick_labels='y')

Log scale with custom limits:

Plotter.configure_axes(ax, yscale='log', ylim=(1e-6, 1e0))

Full configuration:

Plotter.configure_axes(
    ax,
    spines='minimal',
    xlabel=r'$x$ (nm)', ylabel=r'$\rho$ (a.u.)',
    xscale='linear', yscale='log',
    xlim=(0, 100), ylim=(1e-3, 1),
    fontsize=12
)
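The string options for spines reduce to a per-spine visibility table that is then applied through ax.spines. A sketch of that translation, following the options listed above; resolve_spines and apply_spines are hypothetical helpers, and keeping unlisted spines visible in the dict form is an assumption.

```python
SPINES = ("left", "right", "top", "bottom")


def resolve_spines(spec) -> dict:
    """Translate a `spines` argument into {spine_name: visible}."""
    if isinstance(spec, dict):
        # Spines missing from the dict stay visible (assumed default).
        return {name: bool(spec.get(name, True)) for name in SPINES}
    if isinstance(spec, bool):
        return {name: spec for name in SPINES}
    if spec == "minimal":
        # Hide top and right, keep left and bottom.
        return {name: name in ("left", "bottom") for name in SPINES}
    if spec in ("left", "bottom"):
        # Hide every spine except the named one.
        return {name: name == spec for name in SPINES}
    raise ValueError(f"unsupported spines value: {spec!r}")


def apply_spines(ax, spec) -> None:
    """Apply the resolved visibility to a Matplotlib Axes."""
    for name, visible in resolve_spines(spec).items():
        ax.spines[name].set_visible(visible)
```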
static disable_axis(ax, which: str = 'both')[source]

Disable axis components for clean images/heatmaps.

Parameters:
  • ax (matplotlib.axes.Axes) – The axis to modify.

  • which (str, default='both') – What to disable:
    • ‘both’: disable x and y (full axis off)

    • ‘x’: disable x-axis only

    • ‘y’: disable y-axis only

    • ‘labels’: keep ticks but hide labels

    • ‘ticks’: keep labels but hide ticks

    • ‘spines’: hide all spines

Examples

>>> Plotter.disable_axis(ax)  # Completely clean
>>> Plotter.disable_axis(ax, 'x')  # Keep y-axis
>>> Plotter.disable_axis(ax, 'labels')  # Keep ticks, no labels
static get_grid_ax(nrows: int, ncols: int, wspace: float = None, hspace: float = None, width_ratios: List[float] = None, height_ratios: List[float] = None, ax_sub=None, **kwargs) Tuple[GridSpec, list][source]

Get a GridSpec and an empty list for axes (convenience wrapper).

Parameters:
  • nrows (int) – Number of rows in the grid.

  • ncols (int) – Number of columns in the grid.

  • wspace (float, optional) – Horizontal spacing between columns.

  • hspace (float, optional) – Vertical spacing between rows.

  • width_ratios (list, optional) – Relative column widths.

  • height_ratios (list, optional) – Relative row heights.

  • ax_sub (SubplotSpec, optional) – For nested grids.

  • **kwargs – Additional GridSpec arguments.

Returns:

(GridSpec, empty_axes_list)

Return type:

tuple

Examples

>>> fig      = plt.figure()
>>> gs, axes = Plotter.get_grid_ax(2, 3, wspace=0.3)
>>> for i in range(6):
...     Plotter.app_grid_subplot(axes, gs, fig, i)
static app_grid_subplot(axes: list, gs, fig, index: int, sharex=None, sharey=None, **kwargs)[source]

Append a subplot to an axes list (convenience method).

Parameters:
  • axes (list) – List to append the new axis to.

  • gs (GridSpec) – The GridSpec.

  • fig (Figure) – The figure.

  • index (int) – Grid index.

  • sharex (Axes, optional) – Axis whose x-axis the new subplot shares.

  • sharey (Axes, optional) – Axis whose y-axis the new subplot shares.

  • **kwargs – Additional arguments.

Examples

>>> gs, axes    = Plotter.get_grid_ax(2, 2)
>>> fig         = plt.figure()
>>> for i in range(4):
...     Plotter.app_grid_subplot(axes, gs, fig, i)
>>> # axes is now [ax0, ax1, ax2, ax3]
static twin_axis(ax, which='y', label='', color='C1', scale='linear', lim=None, fontsize=None, labelpad=0, **kwargs)[source]

Create a twin axis with a secondary scale.

Parameters:
  • ax (matplotlib.axes.Axes) – Primary axis.

  • which (str, default='y') – Which axis to twin: ‘y’ creates twinx(), ‘x’ creates twiny().

  • label (str, default='') – Label for the secondary axis.

  • color (str, default='C1') – Color for the secondary axis (spine, ticks, label).

  • scale (str, default='linear') – Scale for secondary axis: ‘linear’ or ‘log’.

  • lim (tuple, optional) – Limits for the secondary axis.

  • fontsize (int, optional) – Font size for the label.

  • labelpad (float, default=0) – Padding for the label.

  • **kwargs – Additional arguments passed to set_ylabel/set_xlabel.

Returns:

ax2 – The secondary axis.

Return type:

matplotlib.axes.Axes

Examples

>>> ax2 = Plotter.twin_axis(ax, which='y', label='Temperature (K)', color='red')
>>> Plotter.plot(ax2, x, temperature, color='red')
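A twin y-axis shares the x-axis of the primary axis and draws its own spine on the right. A minimal sketch of such a helper built from standard Matplotlib calls (twinx, set_yscale, tick_params, spines), assuming which='y'; twin_y_sketch is a hypothetical name, not the library's implementation.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt


def twin_y_sketch(ax, label="", color="C1", scale="linear", lim=None):
    """Hypothetical twin y-axis helper built from standard Matplotlib calls."""
    ax2 = ax.twinx()                           # new y-axis sharing ax's x-axis
    ax2.set_yscale(scale)
    if lim is not None:
        ax2.set_ylim(*lim)                     # also turns off y autoscaling
    ax2.set_ylabel(label, color=color)
    ax2.tick_params(axis="y", colors=color)    # tick marks and tick labels
    ax2.spines["right"].set_color(color)       # the twin's visible spine
    return ax2


fig, ax = plt.subplots()
ax.plot([1, 2, 3], [1, 4, 9])
ax2 = twin_y_sketch(ax, label="Temperature (K)", color="red", scale="log", lim=(1, 100))
ax2.plot([1, 2, 3], [10, 20, 50], color="red")
```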
static power_law_guide(ax, x_range, exponent, *, add_label: bool = True, label=None, position='lower right', color='gray', ls='--', lw=1.5, offset_log=0, zorder=3, **kwargs)[source]

Add a power-law guide line to a log-log plot.

Useful for showing scaling behavior (e.g., y ~ x^{-2}).

Parameters:
  • ax (matplotlib.axes.Axes) – Axis with log-log scale.

  • x_range (tuple) – (x_start, x_end) for the guide line.

  • exponent (float) – Power-law exponent (slope in log-log).

  • label (str, optional) – Label (e.g., r’$\sim N^{-2}$’). If None, auto-generates.

  • position (str, default='lower right') – Where to anchor the line: ‘lower right’, ‘upper left’, etc.

  • color (str, default='gray') – Line color.

  • ls (str, default='--') – Line style.

  • lw (float, default=1.5) – Line width.

  • offset_log (float, default=0) – Vertical offset in log10 units.

  • **kwargs – Additional arguments passed to ax.plot.

Returns:

line – The guide line object.

Return type:

Line2D

Examples

>>> # Show y ~ x^{-2} scaling
>>> Plotter.power_law_guide(ax, (10, 1000), -2, label=r'$\sim N^{-2}$')
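On log-log axes a power law y = C * x**p is a straight line, because log10(y) = log10(C) + p * log10(x): the exponent is the slope, and offset_log adds a constant to log10(y), which multiplies y by 10**offset_log. How power_law_guide chooses C from position is not documented, so this sketch takes an explicit anchor point (x0, y0); power_law_points is a hypothetical helper.

```python
import numpy as np


def power_law_points(x_range, exponent, x0, y0, offset_log=0.0, num=50):
    """Points of y = y0 * (x / x0)**exponent * 10**offset_log on a log grid.

    The anchor (x0, y0) is an explicit assumption of this sketch.
    """
    x_start, x_end = x_range
    x = np.logspace(np.log10(x_start), np.log10(x_end), num)
    y = y0 * (x / x0) ** exponent * 10.0 ** offset_log
    return x, y


x, y = power_law_points((10, 1000), -2, x0=10, y0=1.0)
# The slope in log-log coordinates is constant and equal to the exponent.
slope = np.diff(np.log10(y)) / np.diff(np.log10(x))
# On an axis with log-log scale: ax.plot(x, y, ls="--", color="gray")
```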
static get_inset(ax, position=[0.0, 0.0, 1.0, 1.0], add_box=False, box_alpha=0.5, box_ext=0.02, facecolor='white', zorder=1, **kwargs)[source]

Create an inset axis within the given axis.

Parameters:
  • ax (matplotlib.axes.Axes) – The parent axis.

  • position (list) – [x0, y0, width, height] for the inset axis in relative coordinates.

  • add_box (bool, default=False) – Whether to add a semi-transparent box around the inset.

  • box_alpha (float, default=0.5) – Transparency of the box.

  • box_ext (float, default=0.02) – Extension of the box beyond the inset axis.

  • facecolor (str, default='white') – Face color of the box.

  • zorder (int, default=1) – Z-order of the inset axis.

  • **kwargs – Additional arguments passed to fig.add_axes.

Returns:

ax2 – The inset axis.

static set_transparency(ax, alpha=0.0)[source]

Set the background patch transparency for an axis.

static set_legend(ax, fontsize=None, frameon: bool = False, loc: str = 'best', alignment: str = 'left', markerfirst: bool = False, framealpha: float = 1.0, reverse: bool = False, style=None, labelspacing: float = 0.5, handlelength: float = 1.5, handletextpad: float = 0.4, borderpad: float = 0.4, columnspacing: float = 1.0, ncol: int = 1, **kwargs)[source]

Sets the legend with a preferred style for publication-quality plots.

Parameters:
  • ax – Axis to which the legend will be added.

  • fontsize – Font size of the legend labels.

  • frameon – Whether to draw a frame around the legend.

  • loc – Location of the legend (‘best’, ‘upper right’, etc.).

  • alignment – Text alignment (‘left’, ‘center’, ‘right’).

  • markerfirst – If True, the marker is placed before (to the left of) the label.

  • framealpha – Opacity of the legend frame (1.0 is fully opaque).

  • reverse – Reverse the order of the legend items.

  • style – Predefined style for the legend (‘minimal’, ‘boxed’, etc.).

  • labelspacing – Vertical space between legend entries, in font-size units.

  • handlelength – Length of the legend handles, in font-size units.

  • handletextpad – Space between legend handles and text, in font-size units.

  • borderpad – Padding inside the legend frame, in font-size units.

  • columnspacing – Spacing between legend columns, in font-size units.

  • ncol – Number of columns in the legend.

  • **kwargs – Additional arguments passed to ax.legend().

static set_legend_custom(ax, conditions: list, fontsize=None, frameon=False, loc='best', alignment='left', markerfirst=False, framealpha=1.0, reverse=False, **kwargs)[source]

Set the legend with custom conditions for the labels.

Parameters:
  • ax – Axis to use.

  • conditions – List of conditions for the labels.

  • fontsize – Font size of the legend labels.

  • frameon – Whether to draw a frame around the legend.

  • loc – Location of the legend.

  • alignment – Alignment of the legend entries.

  • markerfirst – If True, the marker is placed before the label.

  • framealpha – Opacity of the legend frame.

static get_subplots(nrows=1, ncols=1, sizex=10.0, sizey=10.0, sizex_def=3, sizey_def=3, annot_x_pos=None, annot_y_pos=None, panel_labels=False, single_if_1=False, share_x=False, share_y=False, width_ratios=None, height_ratios=None, constrained_layout=None, tight_layout=False, layout=None, mosaic=None, spans=None, named_panels=None, **kwargs) Tuple[Figure, AxesList][source]

Create subplot layouts and return a list-like AxesList wrapper.

Parameters:
  • nrows (int, default=1) – Number of grid rows, used for regular subplot creation and for spans.

  • ncols (int, default=1) – Number of grid columns, used for regular subplot creation and for spans.

  • sizex (float or sequence, default=10.0) – Figure width in inches, or a sequence of per-column width ratios.

  • sizey (float or sequence, default=10.0) – Figure height in inches, or a sequence of per-row height ratios.

  • sizex_def (float, default=3) – Inch scaling used when sizex is a ratio sequence.

  • sizey_def (float, default=3) – Inch scaling used when sizey is a ratio sequence.

  • annot_x_pos (float or sequence, optional) – x position(s) of the panel-label annotations, in axes-fraction units.

  • annot_y_pos (float or sequence, optional) – y position(s) of the panel-label annotations, in axes-fraction units.

  • panel_labels (bool or sequence, default=False) – If truthy, annotate each axis with labels (auto or user-provided).

  • single_if_1 (bool, default=False) – If True and only one axis is created, return that axis instead of an AxesList.

  • share_x (bool, default=False) – Share the x-axis across the created panels.

  • share_y (bool, default=False) – Share the y-axis across the created panels.

  • width_ratios (sequence, optional) – GridSpec column ratios, overriding ratios inferred from sizex.

  • height_ratios (sequence, optional) – GridSpec row ratios, overriding ratios inferred from sizey.

  • constrained_layout (bool, optional) – Explicitly control constrained layout engine.

  • tight_layout (bool, default=False) – Call fig.tight_layout() after creation (when compatible).

  • layout (str, optional) – Matplotlib layout engine name (e.g. 'constrained', 'tight').

  • mosaic (subplot-mosaic spec, optional) – Use plt.subplot_mosaic with named panels.

  • spans (dict, optional) – Named span panels on a regular grid. Example: {'main': (0, 2, 0, 3), 'side': (0, 2, 3, 4)}.

  • named_panels (sequence or dict, optional) – Panel aliases for regular grids or mosaic alias remapping.

  • **kwargs (dict) – Forwarded Matplotlib options. Common keys include: dpi, subplot_kw, gridspec_kw, hspace, wspace, left/right/top/bottom, grid, grid_kws, despine, axis_off, suptitle, suptitle_kws, post_hook.

Returns:

  • fig (matplotlib.figure.Figure) – Created figure.

  • axes (AxesList or matplotlib.axes.Axes) – AxesList wrapper (list-compatible, grid-aware, named-panel access). Returns single axis only when single_if_1=True.

Notes

AxesList supports:
  • list operations (iterate, append-like access, slicing)

  • grid indexing: axes[row, col]

  • named access: axes['main']

  • helpers: row(), col(), span(), select(), apply()

Examples

Standard grid:

fig, axes = Plotter.get_subplots(2, 3, sizex=9, sizey=5)
axes[1, 2].plot(x, y)

Named aliases:

fig, axes = Plotter.get_subplots(1, 3, named_panels=['left', 'mid', 'right'])
axes['mid'].set_title('Center')

Mosaic:

fig, axes = Plotter.get_subplots(mosaic=[['A', 'A', 'B'], ['C', 'D', 'D']])
axes['A'].plot(x, y)

Spans:

fig, axes = Plotter.get_subplots(nrows=3, ncols=4,
                                 spans={'main': (0, 2, 0, 3),
                                        'side': (0, 2, 3, 4),
                                        'bottom': (2, 3, 0, 4)})
axes['main'].plot(x, y)
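In the spans example, each tuple reads as (row_start, row_stop, col_start, col_stop) with exclusive stops, like Python slices; this reading is inferred from the 3x4 example and is an assumption. Under that reading, spans convert directly into a subplot-mosaic grid, where '.' is Matplotlib's default marker for an empty cell. spans_to_mosaic is a hypothetical helper:

```python
def spans_to_mosaic(nrows, ncols, spans, empty="."):
    """Convert {name: (r0, r1, c0, c1)} spans into a subplot-mosaic grid."""
    grid = [[empty] * ncols for _ in range(nrows)]
    for name, (r0, r1, c0, c1) in spans.items():
        if not (0 <= r0 < r1 <= nrows and 0 <= c0 < c1 <= ncols):
            raise ValueError(f"span {name!r} is outside the {nrows}x{ncols} grid")
        for r in range(r0, r1):
            for c in range(c0, c1):
                if grid[r][c] != empty:
                    raise ValueError(f"span {name!r} overlaps {grid[r][c]!r}")
                grid[r][c] = name
    return grid


mosaic = spans_to_mosaic(3, 4, {"main": (0, 2, 0, 3),
                                "side": (0, 2, 3, 4),
                                "bottom": (2, 3, 0, 4)})
# [['main', 'main', 'main', 'side'],
#  ['main', 'main', 'main', 'side'],
#  ['bottom', 'bottom', 'bottom', 'bottom']]
```

The resulting list can be passed to plt.subplot_mosaic(mosaic), which returns the figure and a dict of axes keyed by name. Checking overlaps while filling the grid catches conflicting spans before any axes are created.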

static subplots(*args, **kwargs)[source]

Alias of Plotter.get_subplots().

static subplot_mosaic(mosaic, *args, **kwargs)[source]

Convenience alias for mosaic layouts.

Equivalent to: Plotter.get_subplots(mosaic=mosaic, *args, **kwargs)

static save_fig(directory: str, filename: str, format='pdf', dpi=200, adjust=True, fig=None, **kwargs)[source]

Save a figure to a specific directory.

Parameters:
  • directory – Directory to save the file in.

  • filename – Name of the file.

  • format – File format (default ‘pdf’).

  • dpi – Resolution in dots per inch (default 200).

  • adjust – If True, adjust the figure before saving.

static savefig(directory, filename, format, dpi, adjust, fig=None, **kwargs)[source]

Alias for save_fig() with the historical lowercase name.

static plot_heatmaps(dfs: list, colormap='viridis', cb_width=0.1, movefirst=True, index=None, columns=None, values=None, sortidx=True, zlabel='', sizemult=3, xvals=True, yvals=True, vmin=None, vmax=None, **kwargs)[source]

Plot a sequence of pivoted DataFrame heatmaps on a shared figure.

class general_python.common.PlotterSave[source]

Bases: object

File-output helpers for simple plot-adjacent data artifacts.

static dict2json(directory: str, fileName: str, data)[source]

Save a dictionary to a JSON file.

Parameters:
  • directory – Directory to save the file in.

  • fileName – Name of the file.

  • data – Dictionary to save.

static json2dict(directory: str, fileName: str) dict[source]

Load a dictionary from a JSON file.
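A JSON round trip with the standard library, as a sketch of what dict2json and json2dict do. The file-name handling (appending .json when missing) is an assumption, and the *_sketch names are hypothetical:

```python
import json
import os
import tempfile


def _json_path(directory: str, file_name: str) -> str:
    # Assumption: ".json" is appended when missing; the library may not do this.
    if not file_name.endswith(".json"):
        file_name += ".json"
    return os.path.join(directory, file_name)


def dict2json_sketch(directory: str, file_name: str, data: dict) -> str:
    """Write `data` as JSON and return the path of the written file."""
    os.makedirs(directory, exist_ok=True)
    path = _json_path(directory, file_name)
    with open(path, "w", encoding="utf-8") as fh:
        json.dump(data, fh, indent=2)
    return path


def json2dict_sketch(directory: str, file_name: str) -> dict:
    """Read a JSON file back into a dictionary."""
    with open(_json_path(directory, file_name), encoding="utf-8") as fh:
        return json.load(fh)


with tempfile.TemporaryDirectory() as tmp:
    path = dict2json_sketch(tmp, "params", {"L": 8, "J": (1.0, 0.5)})
    loaded = json2dict_sketch(tmp, "params")
```

The round trip is not exact: tuples come back as lists, non-string keys come back as strings, and NumPy arrays must be converted (for example with .tolist()) before json.dump accepts them.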

static json2dict_multiple(directory: str, keys: list)[source]

Load one dictionary per key from the JSON files in directory; each key is also used as the file name.

static singleColumnData(directory: str, fileName: str, y, typ='.npy')[source]

Stores the values as a single vector

static twoColumnsData(directory: str, fileName: str, x, y, typ='.npy')[source]

Stores the x, y vectors in 2D form (multiple rows and two columns)

static matrixData(directory: str, fileName: str, x, y, typ='.npy')[source]

Stores x and y in matrix form, with the x values prepended as the first column.
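The three storage helpers differ in array layout. A sketch of the layouts with NumPy, assuming y in matrixData is a 2D array with one column per series and that the arrays are then written with np.save (both assumptions):

```python
import numpy as np

x = np.array([0.0, 1.0, 2.0])
y = np.array([10.0, 11.0, 12.0])
Y = np.arange(6.0).reshape(3, 2)        # several y series sharing the same x

single = y                              # singleColumnData: one vector
two_cols = np.column_stack((x, y))      # twoColumnsData: shape (n, 2), columns x | y
matrix = np.column_stack((x, Y))        # matrixData: x prepended as column 0

# Writing to disk would then be, e.g., np.save(path, two_cols) (assumed format).
```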

static app_df(df, colname: str, y, fill_value=nan)[source]

Appends the data to the dataframe.

Parameters:
  • df (pd.DataFrame) – The dataframe to append data to.

  • colname (str) – The column name to append data under.

  • y (array-like) – The data to append.

  • fill_value – The value to use for filling if resizing is needed.
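A sketch of appending a column whose length differs from the frame, assuming "resizing" means padding whichever side is shorter with fill_value and that df has the default 0..n-1 index; app_df_sketch is a hypothetical helper and may differ from the library's behavior:

```python
import numpy as np
import pandas as pd


def app_df_sketch(df: pd.DataFrame, colname: str, y, fill_value=np.nan) -> pd.DataFrame:
    """Return a copy of df with y stored under colname, padding the shorter side."""
    y = np.asarray(y)
    n = max(len(df), len(y))
    # Grow the frame to n rows (assumes a default 0..len-1 index);
    # new cells receive fill_value.
    out = df.reindex(range(n), fill_value=fill_value)
    # Pad y to n entries with fill_value.
    col = np.full(n, fill_value, dtype=np.result_type(y, np.asarray(fill_value)))
    col[: len(y)] = y
    out[colname] = col
    return out


df = pd.DataFrame({"a": [1.0, 2.0]})
longer = app_df_sketch(df, "b", [5.0, 6.0, 7.0])   # frame grows to 3 rows
shorter = app_df_sketch(df, "c", [9.0])            # column padded to 2 rows
```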

static app_array(arr, y)[source]

Appends the data to a numpy array.

Parameters:
  • arr (np.ndarray) – The numpy array to append data to.

  • y (np.ndarray) – The data to append.

Returns:

The updated numpy array with the data appended.

Return type:

np.ndarray

class general_python.common.MatrixPrinter[source]

Bases: object

Class for printing matrices and vectors

static print_matrix(matrix: ndarray)[source]

Prints the matrix in a readable form.

static print_vector(vector: ndarray)[source]

Prints the vector in a readable form.

static print_matrices(matrices: list)[source]

Prints a list of matrices in a readable form.

static print_vectors(vectors: list)[source]

Prints a list of vectors in a readable form.

class general_python.common.DataHandler[source]

Bases: object

The DataHandler class provides static methods for handling and processing data arrays, including filtering, interpolating, aggregating, concatenating, and cutting matrices based on specific criteria. Its methods are:

  • _filter_typical_values(current_x, current_y, typical, threshold=1.0) -> tuple

  • _initialize_combined_arrays(y_list, x_list, typical, threshold=1.0) -> tuple – Initializes and combines arrays from the given lists. If typical is True, the combined arrays are filtered to keep only elements where the values in y_combined are less than the threshold.

  • _interpolate_and_update(x_combined, y_combined, current_x, current_y, divider) -> tuple

  • _aggregate_and_update(x_combined, y_combined, current_x, current_y, divider) -> tuple – Aggregates and updates the combined x and y arrays with the current x and y arrays by summing common bins and appending unique bins.

  • concat_and_average(y_list, x_list, typical=False, use_interpolation=True, threshold=1.0) -> tuple

  • concat_and_fill(y_list, x_list, lengths, missing_val=np.nan) -> tuple

  • cut_matrix_bad_vals_zero(M, axis=0, tol=1e-9, check_limit: float | None = 10) -> np.ndarray – Cuts off the slices (along any specified axis) in matrix M where all elements are close to zero.

  • cut_matrix_bad_vals(M, axis=0, threshold=-1e4, check_limit=None) -> np.ndarray – Cuts off the rows or columns in matrix M where the first check_limit elements are all below a threshold.

static concat_and_average(y_list, x_list, typical=False, use_interpolation=True, threshold=1.0)[source]

Concatenates and averages y values across multiple histograms.

Parameters:
  • y_list – List of y matrices (each one corresponding to a realization).

  • x_list – List of x vectors (each one corresponding to a realization).

  • typical – If True, keep only y values below threshold.

  • use_interpolation – If True, interpolate y values for non-matching bins. If False, aggregate only exact matches and append unique bins.

  • threshold – The threshold value for filtering y values (default: 1.0).

Returns:

Combined y values and x bins after averaging.
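For use_interpolation=False, the description above amounts to summing y over identical x bins, keeping bins that appear in only one realization, and averaging. A minimal sketch with 1D y per realization; it divides by the number of realizations that contain each bin, whereas the divider argument of the private helpers suggests the library may divide by a fixed count instead. average_exact_bins is a hypothetical name.

```python
import numpy as np


def average_exact_bins(y_list, x_list):
    """Average y over realizations, matching bins only by exact x value."""
    sums, counts = {}, {}
    for x, y in zip(x_list, y_list):
        for xi, yi in zip(np.asarray(x).tolist(), np.asarray(y).tolist()):
            sums[xi] = sums.get(xi, 0.0) + yi      # common bins are summed
            counts[xi] = counts.get(xi, 0) + 1     # unique bins start a new entry
    x_combined = np.array(sorted(sums))
    y_mean = np.array([sums[xi] / counts[xi] for xi in x_combined.tolist()])
    return y_mean, x_combined


y_mean, x_bins = average_exact_bins([[1.0, 2.0], [3.0, 4.0]],
                                    [[0.0, 1.0], [1.0, 2.0]])
# x_bins -> [0., 1., 2.], y_mean -> [1., 2.5, 4.]
```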

static concat_and_fill(y_list, x_list, lengths, missing_val=nan)[source]

Concatenates y values across multiple histograms, combines x vectors into a single sorted array, and fills missing values.

Parameters:
  • y_list – List of y arrays (each one corresponding to a realization).

  • x_list – List of x arrays (each one corresponding to a realization group).

  • lengths – List indicating how many y arrays correspond to each x array.

  • missing_val – Value to fill for missing data points after interpolation (default: np.nan).

Returns:

A 2D NumPy array of y values interpolated to a common x grid and the combined x bins.
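A sketch of the common-grid step: all x vectors are merged into one sorted grid, each y array is interpolated onto it, and points outside its own x range receive missing_val. Linear interpolation with np.interp is an assumption here, and concat_and_fill_sketch is a hypothetical name:

```python
import numpy as np


def concat_and_fill_sketch(y_list, x_list, lengths, missing_val=np.nan):
    """Interpolate every y onto the union of all x grids.

    lengths[i] consecutive arrays in y_list share x_list[i].
    """
    if sum(lengths) != len(y_list) or len(lengths) != len(x_list):
        raise ValueError("lengths must match y_list and x_list")
    x_all = np.unique(np.concatenate([np.asarray(x, dtype=float) for x in x_list]))
    rows, k = [], 0
    for x, n in zip(x_list, lengths):
        order = np.argsort(x)                      # np.interp needs increasing xp
        xs = np.asarray(x, dtype=float)[order]
        for y in y_list[k:k + n]:
            ys = np.asarray(y, dtype=float)[order]
            # Points outside this group's x range get missing_val.
            rows.append(np.interp(x_all, xs, ys, left=missing_val, right=missing_val))
        k += n
    return np.vstack(rows), x_all


Y, x_all = concat_and_fill_sketch(
    y_list=[[0.0, 1.0], [0.0, 2.0], [5.0, 7.0]],
    x_list=[[0.0, 1.0], [1.0, 2.0]],
    lengths=[2, 1],
)
```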

static cut_matrix_bad_vals_zero(M, axis=0, tol=1e-09, check_limit: float | None = 10)[source]

Cut off the slices (along any specified axis) in matrix M where all elements are close to zero. If a 1D vector is provided, it returns the vector unless all elements are close to zero, in which case it returns an empty array.

Parameters:
  • M (numpy.ndarray) – The input matrix or vector.

  • axis (int) – The axis along which to check for zero elements, for example 0 for rows and 1 for columns. Ignored if M is a 1D vector.

  • tol (float) – The tolerance for considering elements as zero.

  • check_limit (int) – The maximum number of elements along the axis to check for zeros.

Returns:

The matrix with the near-zero slices (along the specified axis) removed. For a 1D input, the vector itself, or an empty array if all of its elements are close to zero.

Return type:

numpy.ndarray
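A sketch of the row case (axis=0), assuming only the first check_limit entries of each row are inspected, which is one reading of the parameter description; cut_zero_rows_sketch is a hypothetical helper:

```python
import numpy as np


def cut_zero_rows_sketch(M, tol=1e-9, check_limit=10):
    """Drop rows whose first `check_limit` entries are all within `tol` of zero.

    A 1D input is returned unchanged unless every element is near zero,
    in which case an empty array is returned.
    """
    M = np.asarray(M)
    if M.ndim == 1:
        return M[:0] if np.all(np.abs(M) <= tol) else M
    head = M if check_limit is None else M[:, : int(check_limit)]
    keep = ~np.all(np.abs(head) <= tol, axis=1)   # rows with a non-zero entry in the head
    return M[keep]
```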

static cut_matrix_bad_vals(M, axis=0, threshold=-10000.0, check_limit=None)[source]

Cut off the rows or columns in matrix M where the first check_limit elements are all below a threshold.

Parameters:
  • M (numpy.ndarray) – The input matrix.

  • axis (int) – The axis along which to check for elements below the threshold (0 for rows, 1 for columns).

  • threshold (float) – The threshold value.

  • check_limit (int, optional) – The number of elements to check from each row or column.

Returns:

The resulting matrix after removing rows or columns where the first check_limit elements are below the threshold.

Return type:

numpy.ndarray
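A sketch covering both axes, assuming axis=0 removes rows and axis=1 removes columns as described; transposing lets one code path serve both. cut_below_threshold_sketch is hypothetical and handles 2D input only:

```python
import numpy as np


def cut_below_threshold_sketch(M, axis=0, threshold=-1e4, check_limit=None):
    """Remove rows (axis=0) or columns (axis=1) whose first `check_limit`
    elements are all below `threshold`."""
    M = np.asarray(M)
    S = M if axis == 0 else M.T          # columns become rows when axis == 1
    head = S if check_limit is None else S[:, :check_limit]
    bad = np.all(head < threshold, axis=1)
    kept = S[~bad]
    return kept if axis == 0 else kept.T
```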

general_python.common.ctz64(x: uint64) int64[source]

Count trailing zeros in a 64-bit unsigned integer (Numba-safe).

Returns 64 if x == 0 (no bits set). Uses binary search - O(log bits).

Parameters:

x – 64-bit unsigned integer

Returns:

Number of trailing zero bits (0-64)

Example

>>> ctz64(np.uint64(8))  # 0b1000 -> 3 trailing zeros
3
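The binary search halves the window at each step: if the low 32 bits are zero, 32 trailing zeros are counted and the value is shifted right, then the same test runs on 16, 8, 4, 2 and 1 bits. A pure-Python sketch of that idea (ctz64_sketch is hypothetical; the library version is Numba-compiled and works on np.uint64):

```python
MASK64 = (1 << 64) - 1


def ctz64_sketch(x: int) -> int:
    """Count trailing zero bits of a 64-bit unsigned value by binary search."""
    x &= MASK64                      # keep only the low 64 bits
    if x == 0:
        return 64                    # no bit set
    n = 0
    for width in (32, 16, 8, 4, 2, 1):
        if (x & ((1 << width) - 1)) == 0:   # low `width` bits are all zero
            n += width
            x >>= width
    return n


print(ctz64_sketch(8))   # 3
```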
general_python.common.popcount64(x: uint64) int64[source]

Count number of set bits in a 64-bit integer (Numba-safe).

Uses parallel bit-counting algorithm - O(1).

Parameters:

x – 64-bit unsigned integer

Returns:

Number of set bits (0-64)

Example

>>> popcount64(np.uint64(0b1011))
3
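The parallel (SWAR) bit count adds neighboring bit fields in place: 2-bit sums, then 4-bit, then 8-bit, and a final multiplication gathers the byte sums into the top byte. A pure-Python sketch of this standard technique, assuming the library uses a variant of it; popcount64_sketch is hypothetical:

```python
M64 = (1 << 64) - 1


def popcount64_sketch(x: int) -> int:
    """Count set bits of a 64-bit value with the SWAR (parallel) method."""
    x &= M64
    x = x - ((x >> 1) & 0x5555555555555555)                          # 2-bit field sums
    x = (x & 0x3333333333333333) + ((x >> 2) & 0x3333333333333333)   # 4-bit field sums
    x = (x + (x >> 4)) & 0x0F0F0F0F0F0F0F0F                           # 8-bit (byte) sums
    return ((x * 0x0101010101010101) & M64) >> 56                     # add all bytes


print(popcount64_sketch(0b1011))   # 3
```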
general_python.common.mask_from_indices(idxs: ndarray) uint64[source]

Convert array of bit indices to a bitmask (Numba-safe).

Parameters:

idxs – Array of indices (int64) indicating which bits to set

Returns:

64-bit mask with bits set at given indices

Example

>>> mask_from_indices(np.array([0, 2, 3], dtype=np.int64))  # 0b1101
np.uint64(13)
general_python.common.indices_from_mask(mask: uint64) ndarray[source]

Convert bitmask to array of set bit indices (Numba-safe).

Returns indices in ascending order. Uses ctz64 for efficiency.

Parameters:

mask – 64-bit mask

Returns:

Array of indices (int64) where bits are set

Example

>>> indices_from_mask(np.uint64(13))  # 0b1101
array([0, 2, 3], dtype=int64)
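Both conversions can be sketched with plain Python integers standing in for np.uint64. The decoder peels off the lowest set bit with mask & -mask, whose position equals ctz64(mask); the helper names are hypothetical:

```python
def mask_from_indices_sketch(idxs) -> int:
    """Set bit i for every i in idxs (0 <= i < 64)."""
    mask = 0
    for i in idxs:
        if not 0 <= i < 64:
            raise ValueError(f"bit index {i} outside [0, 64)")
        mask |= 1 << int(i)
    return mask


def indices_from_mask_sketch(mask: int) -> list:
    """Return set-bit positions in ascending order."""
    out = []
    mask &= (1 << 64) - 1
    while mask:
        low = mask & -mask                # isolates the lowest set bit
        out.append(low.bit_length() - 1)  # its position, same value as ctz64(mask)
        mask ^= low                       # clear that bit
    return out


print(indices_from_mask_sketch(mask_from_indices_sketch([0, 2, 3])))   # [0, 2, 3]
```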
general_python.common.complement_mask(mask: uint64, ns: int) uint64[source]

Return the complement of a mask within ns bits.

Parameters:
  • mask – Original bitmask

  • ns – Number of bits in the system (1-64)

Returns:

Complement mask (bits flipped within range [0, ns))
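The complement keeps only the low ns bits of ~mask. A sketch, assuming bits at positions ns and above come out zero; complement_mask_sketch is hypothetical. The ns == 64 branch matters in fixed-width 64-bit code (for example, shifting a 64-bit integer by 64 bits is undefined behavior in C); Python integers do not overflow, but the branch mirrors that structure.

```python
def complement_mask_sketch(mask: int, ns: int) -> int:
    """Flip bits 0..ns-1 of mask; bits at positions ns and above are cleared."""
    if not 1 <= ns <= 64:
        raise ValueError("ns must be in [1, 64]")
    # Full-width case handled separately, as fixed-width code must do.
    low = 0xFFFFFFFFFFFFFFFF if ns == 64 else (1 << ns) - 1
    return ~mask & low


print(bin(complement_mask_sketch(0b1101, 4)))   # 0b10
```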

general_python.common.complement_indices(n: int, indices: ndarray) ndarray[source]

Return the indices in [0, n) that are not in indices.

Uses an O(n) boolean scratch array with minimal allocations.

Parameters:
  • n – Upper bound of the range (exclusive)

  • indices – Input indices to exclude

Returns:

Array of complementary indices (sorted)

Example

>>> complement_indices(5, np.array([1, 3]))
array([0, 2, 4], dtype=int64)
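A sketch of the scratch-array approach with NumPy: mark every excluded index in a boolean array of length n, then return the unmarked positions. Ignoring out-of-range entries is an assumption, and complement_indices_sketch is hypothetical:

```python
import numpy as np


def complement_indices_sketch(n: int, indices) -> np.ndarray:
    """Indices in [0, n) that are not in `indices`, in ascending order."""
    present = np.zeros(n, dtype=np.bool_)          # O(n) boolean scratch
    idx = np.asarray(indices, dtype=np.int64)
    present[idx[(idx >= 0) & (idx < n)]] = True    # ignore out-of-range entries
    return np.flatnonzero(~present).astype(np.int64)


print(complement_indices_sketch(5, [1, 3]))   # [0 2 4]
```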
class general_python.common.HDF5Manager[source]

Bases: object

A class encapsulating methods for reading, writing, and processing HDF5 files. Methods include:

  • load_file_data: Read data from a single HDF5 file.

  • stream_key_from_loaded_files: Generator to yield specific dataset from loaded data.

static load_file_data(file_path: str, dataset_keys: List[str] | None = None, verbose: bool = False, remove_corrupted_file: bool = False, strict_keys: bool = True, include_missing_keys: bool = False, missing_value: Any = nan) Dict[str, Any][source]

Reads data from an HDF5 file.

  • If dataset_keys is provided:
    • strict_keys=True -> skip file entirely if any key is missing

    • strict_keys=False -> load only available keys

    • include_missing_keys=True -> attach missing keys with missing_value
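The three options combine into a small decision rule over the keys found in a file. A sketch operating on an already-read {key: data} mapping; returning None for a skipped file, and letting include_missing_keys take precedence over strict_keys (following the "instead of skipping" wording in stream_data_from_multiple_files), are assumptions, and select_keys_sketch is hypothetical:

```python
import math


def select_keys_sketch(available: dict, dataset_keys=None, strict_keys=True,
                       include_missing_keys=False, missing_value=math.nan):
    """Apply the documented key options to an already-read {key: data} mapping.

    Returns None when the file would be skipped (an assumption of this sketch).
    """
    if dataset_keys is None:
        return dict(available)                     # no filtering requested
    out = {k: available[k] for k in dataset_keys if k in available}
    missing = [k for k in dataset_keys if k not in available]
    if missing and include_missing_keys:
        # Assumption: filling missing keys takes precedence over strict skipping.
        out.update({k: missing_value for k in missing})
    elif missing and strict_keys:
        return None                                # file skipped entirely
    return out


available = {"energy": [1.0, 2.0], "entropy": [0.1, 0.2]}
keys = ["energy", "gap"]
```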

static read_hdf5(file_path, keys=None, verbose=False, remove_bad=False)[source]
static stream_key_from_loaded_files(loaded_hdf5_data_list: List[Dict[str, Any]], key: str) Generator[ndarray, None, None][source]

Yields data for a specific key from a list of already loaded HDF5 data dictionaries. Each dictionary in loaded_hdf5_data_list is expected to be the output of ‘load_file_data’.

Parameters:
  • loaded_hdf5_data_list – List of dictionaries, where each dict contains data from an HDF5 file.

  • key – The dataset key to extract from each dictionary.

Yields:

numpy.ndarray – The dataset corresponding to the key.

static concatenate_key_from_loaded_files(loaded_hdf5_data_list: List[Dict[str, Any]], key: str, concat_axis: int = 0, target_shape_axis: int | None = None, allow_padding: bool = False, is_vector: bool = False, clean_zeros_params: Dict[str, Any] | None = None, clean_threshold_params: Dict[str, Any] | None = None, verbose: bool = False) ndarray[source]

Concatenates a specific dataset key from a list of loaded HDF5 data. Handles shape mismatches by padding (if enabled) or skipping.

Parameters:
  • loaded_hdf5_data_list – List of dictionaries (from ‘load_file_data’).

  • key – Dataset key to extract and concatenate.

  • concat_axis – Axis along which to concatenate arrays.

  • target_shape_axis – If specified, datasets must have the same size along this axis as the first valid dataset found. Usually this is ‘concat_axis’.

  • allow_padding – If True, pads/truncates datasets to match the target shape on ‘target_shape_axis’.

  • is_vector – If True and data is 2D (1, N), flatten to 1D (N,).

  • clean_zeros_params – Optional dict of parameters for ‘clean_data_remove_zeros’.

  • clean_threshold_params – Optional dict of parameters for ‘clean_data_remove_thresholded’.

  • verbose – If True, log detailed information.

Returns:

A concatenated numpy array. Returns an empty array if no data is found or processed.
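With allow_padding=True, each dataset is made to match the first one along target_shape_axis before concatenation. A NumPy sketch, assuming longer datasets are truncated and shorter ones are padded with NaN at the end; both helper names are hypothetical:

```python
import numpy as np


def fit_along_axis(arr, size, axis, pad_value=np.nan):
    """Truncate or pad `arr` so that arr.shape[axis] == size."""
    arr = np.asarray(arr, dtype=float)
    n = arr.shape[axis]
    if n >= size:
        return np.take(arr, np.arange(size), axis=axis)   # truncate
    pad = [(0, 0)] * arr.ndim
    pad[axis] = (0, size - n)                             # pad at the end only
    return np.pad(arr, pad, constant_values=pad_value)


def concat_with_padding(arrays, concat_axis, target_shape_axis):
    """Concatenate after matching every array to the first on target_shape_axis."""
    target = np.asarray(arrays[0]).shape[target_shape_axis]
    fitted = [fit_along_axis(a, target, target_shape_axis) for a in arrays]
    return np.concatenate(fitted, axis=concat_axis)


stacked = concat_with_padding([[[1, 2, 3]], [[4, 5]], [[6, 7, 8, 9]]],
                              concat_axis=0, target_shape_axis=1)
```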

static stream_data_from_multiple_files(file_paths: List[str], dataset_keys: List[str] | None = None, sort_files: bool = True, verbose: bool = False, strict_keys: bool = True, include_missing_keys: bool = True, missing_value: Any = nan) Generator[Dict[str, Any], None, None][source]

Streams the data dictionary (as returned by ‘load_file_data’) for each HDF5 file found in the specified paths.

Parameters:
  • file_paths – List of HDF5 file paths.

  • dataset_keys – Specific dataset keys to load from each file. If None, loads all datasets. If a single key, it will be converted to a list.

  • sort_files – Whether to sort the found files by name.

  • verbose – If True, log detailed information.

  • strict_keys – If True, skip files that are missing any of the specified keys.

  • include_missing_keys – If True, include missing keys with a specified value instead of skipping.

  • missing_value – The value to use for missing keys if include_missing_keys is True.

Yields:

dict – Data dictionary from each processed HDF5 file.

static load_data_from_multiple_files(file_paths: List[str] | list[str], dataset_keys: List[str] | None = None, sort_files: bool = True, verbose: bool = False, strict_keys: bool = True, include_missing_keys: bool = True, missing_value: Any = nan) List[Dict[str, Any]][source]

Loads data from multiple HDF5 files into a list of dictionaries (eager evaluation). Formerly named ‘read_multiple_hdf5’.

static stream_lazy_from_multiple_files(file_paths: List[str] | list[str], dataset_keys: List[str] | None = None, sort_files: bool = True, verbose: bool = False) Generator[Any, None, None][source]

Stream lazy HDF5 entries, optionally filtered by dataset key presence.

Notes

  • Does not read datasets eagerly.

  • If dataset_keys is provided, files are kept only when all listed keys are present (checked via lazy key listing).

static load_lazy_from_multiple_files(file_paths: List[str] | list[str], dataset_keys: List[str] | None = None, sort_files: bool = True, verbose: bool = False) List[Any][source]

Eagerly collect lazy HDF5 entries into a list.

static save_data_to_file(directory: str | Directories, filename: str, data_to_save: ndarray | List[ndarray] | Dict[str, ndarray], target_shape: Tuple[int, ...] | None = None, dataset_names_config: List[str] | str | None = None, overwrite: bool = True, *args, data: Dict[str, Any] | None = None)[source]

Saves data to an HDF5 file.

Parameters:
  • directory – Directory to save the file.

  • filename – Name of the HDF5 file (extension .h5 or .hdf5 will be ensured).

  • data_to_save – Data to save. Can be a single np.ndarray, a list of np.ndarrays, or a dictionary {name: np.ndarray}.

  • target_shape – If specified, datasets will be reshaped to this shape before saving.

  • dataset_names_config – Names for datasets if ‘data_to_save’ is a list/ndarray. If a string, used as a prefix. If a list, used as names.

  • overwrite – If True (default), overwrites the file if it exists.
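The three accepted input forms can be normalized into one {dataset_name: array} mapping before writing. A sketch of that step; the default prefix "data" and the "{prefix}_{i}" naming pattern are assumptions, not documented behavior, and resolve_dataset_names is hypothetical:

```python
import numpy as np


def resolve_dataset_names(data_to_save, dataset_names_config=None):
    """Map the accepted input forms to {dataset_name: array}."""
    if isinstance(data_to_save, dict):
        return {str(k): np.asarray(v) for k, v in data_to_save.items()}
    # A single ndarray is treated as a one-element list.
    arrays = [data_to_save] if isinstance(data_to_save, np.ndarray) else list(data_to_save)
    if isinstance(dataset_names_config, (list, tuple)):
        if len(dataset_names_config) != len(arrays):
            raise ValueError("need one name per array")
        names = list(dataset_names_config)
    else:
        prefix = dataset_names_config or "data"                # hypothetical default prefix
        names = [f"{prefix}_{i}" for i in range(len(arrays))]  # hypothetical pattern
    return {name: np.asarray(a) for name, a in zip(names, arrays)}
```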

static append_data_to_file(directory: str, filename: str, new_data: ndarray | List[ndarray] | Dict[str, ndarray], dataset_names_config: List[str] | str | None = None, overwrite_existing_datasets: bool = True, allow_dataset_creation: bool = True, *, data: Dict[str, Any] | None = None)[source]

Appends data to an existing HDF5 file or creates it if it doesn’t exist.

Parameters:
  • directory – Directory of the HDF5 file.

  • filename – Name of the HDF5 file.

  • new_data – Data to append.

  • dataset_names_config – Names for datasets if ‘new_data’ is list/ndarray.

  • overwrite_existing_datasets –
    • If True and the dataset exists, it is deleted and recreated.

    • If False and the dataset exists, data is appended (row-wise); this requires the dataset to be resizable.

  • allow_dataset_creation – If True, new datasets are created if they don’t exist.
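The interaction of overwrite_existing_datasets and allow_dataset_creation can be modelled in memory. This is a sketch, assuming each dataset is a list of rows standing in for a resizable HDF5 dataset. append_rows is a hypothetical helper, and silently skipping datasets when creation is disallowed is an assumption. With h5py, a row-wise append would presumably use a dataset created with maxshape=(None, ...) and grown with Dataset.resize() before the new rows are written.

```python
from typing import Dict, List


def append_rows(store: Dict[str, List[list]], new_data: Dict[str, List[list]],
                overwrite_existing_datasets: bool = True,
                allow_dataset_creation: bool = True) -> Dict[str, List[list]]:
    """In-memory model of the documented append semantics (rows = list entries)."""
    for name, rows in new_data.items():
        if name in store:
            if overwrite_existing_datasets:
                store[name] = list(rows)      # delete and recreate
            else:
                store[name].extend(rows)      # row-wise append (needs a resizable dataset)
        elif allow_dataset_creation:
            store[name] = list(rows)          # create the missing dataset
        # otherwise the dataset is skipped (assumption; the real method may raise)
    return store
```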

static update_fields_in_file(directory: str | Directories, filename: str, data_to_update: ndarray | List[ndarray] | Dict[str, ndarray], target_shape: Tuple[int, ...] | None = None, dataset_names_config: List[str] | str | None = None, create_if_missing: bool = True)[source]

Updates only provided datasets in an HDF5 file and keeps all other datasets unchanged.

Example

Existing file has ‘/a’ and ‘/b’. Calling with data_to_update={‘/a’: new_data} updates only ‘/a’ and leaves ‘/b’ intact.

Parameters:
  • directory – Directory where the HDF5 file is located.

  • filename – Name of the HDF5 file.

  • data_to_update – Data to update. Can be dict {name: array}, list of arrays, or single ndarray.

  • target_shape – If provided, each updated dataset is reshaped before writing.

  • dataset_names_config – Names for datasets when data_to_update is list/ndarray.

  • create_if_missing – If True, creates a missing file and/or missing dataset paths as needed. If False, update only existing datasets in an existing file.

static save_hdf5(directory: str | Directories, filename: str, data_to_save: ndarray | List[ndarray] | Dict[str, ndarray], target_shape: Tuple[int, ...] | None = None, dataset_names_config: List[str] | str | None = None, overwrite: bool = True, *args, data: Dict[str, Any] | None = None)

Same as save_data_to_file() (identical signature and docstring); see that entry for parameter details.

static append_hdf5(directory: str, filename: str, new_data: ndarray | List[ndarray] | Dict[str, ndarray], dataset_names_config: List[str] | str | None = None, overwrite_existing_datasets: bool = True, allow_dataset_creation: bool = True, *, data: Dict[str, Any] | None = None)

Same as append_data_to_file() (identical signature and docstring); see that entry for parameter details.

static update_hdf5(directory: str | Directories, filename: str, data_to_update: ndarray | List[ndarray] | Dict[str, ndarray], target_shape: Tuple[int, ...] | None = None, dataset_names_config: List[str] | str | None = None, create_if_missing: bool = True)

Same as update_fields_in_file() (identical signature and docstring); see that entry for parameter details.

static file_list_matching(directories: List | Directories | str, *args, conditions: List[Callable] = [], check_hdf5_condition: bool = True, as_string: bool = True)[source]

Returns a list of HDF5 files in the specified directories that match the given conditions.

Parameters:
  • directories – A list of directory paths (str) or Directories objects, or a single one.

  • *args – Additional arguments passed to the Directories constructor if directories are given as str.

  • conditions – A list of callables that take a filename and return True if it matches the condition.

  • check_hdf5_condition – If True (default), adds a condition that only includes files ending with .h5 or .hdf5.

  • as_string – If True (default), returns file paths as strings. If False, returns Path objects.

Returns:

A sorted list of file paths matching the conditions.
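The matching logic can be approximated with pathlib alone. This is a sketch, not the library's code. It assumes each condition receives the file's base name, it scans only the top level of each directory, and list_matching_files is a hypothetical name.

```python
from pathlib import Path
from typing import Callable, Iterable, List, Union


def list_matching_files(directories: Union[str, Path, Iterable[Union[str, Path]]],
                        conditions: Iterable[Callable[[str], bool]] = (),
                        check_hdf5_condition: bool = True,
                        as_string: bool = True) -> List[Union[str, Path]]:
    """Sorted files from the given directories that satisfy every condition."""
    if isinstance(directories, (str, Path)):
        directories = [directories]          # accept a single directory
    checks = list(conditions)
    if check_hdf5_condition:
        checks.append(lambda name: name.endswith((".h5", ".hdf5")))
    matches = [
        path
        for directory in directories
        for path in Path(directory).iterdir()
        if path.is_file() and all(check(path.name) for check in checks)
    ]
    matches.sort()
    return [str(p) for p in matches] if as_string else matches
```

An immutable default for conditions (here a tuple) avoids the shared-mutable-default pitfall of a list literal.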

static stream_data_from_multiple_folders(directory_paths: List[Directories], file_conditions: List[Any] | None = None, dataset_keys: List[str] | None = None, sort_files: bool = True, verbose: bool = False) Generator[Dict[str, Any], None, None][source]

Streams the data dictionary (as returned by ‘load_file_data’) for each HDF5 file found in the specified directories.

Parameters:
  • directory_paths – List of directories to search for HDF5 files.

  • file_conditions – Conditions passed to ‘Directories.listDirs’ for filtering files.

  • dataset_keys – Specific dataset keys to load from each file.

  • sort_files – Whether to sort the found files by name.

  • verbose – If True, log detailed information.

Yields:

dict – Data dictionary from each processed HDF5 file.

static load_data_from_multiple_folders(directory_paths: List[str], file_conditions: List[Any] | None = None, dataset_keys: List[str] | None = None, sort_files: bool = True, verbose: bool = False) List[Dict[str, Any]][source]

Eagerly loads data from multiple HDF5 files into a list of dictionaries. Formerly named ‘read_multiple_hdf5l’.

static load_and_concatenate_key_from_folders(directory_paths: List[str], key_to_extract: str, file_conditions: List[Any] | None = None, concat_axis: int = 0, target_shape_axis: int | None = None, allow_padding: bool = False, is_vector: bool = False, clean_zeros_params: Dict[str, Any] | None = None, clean_threshold_params: Dict[str, Any] | None = None, sort_files: bool = True, verbose: bool = False) ndarray[source]

Reads a specific dataset key from multiple HDF5 files found in the given directories and concatenates the results. Formerly named ‘read_hdf5_extract_and_concat’.

static load_and_concatenate_key_per_directory(list_of_directory_paths: List[str], key_to_extract: str, file_conditions: List[Any] | None = None, concat_axis: int = 0, target_shape_axis: int | None = None, allow_padding: bool = False, is_vector: bool = False, clean_zeros_params: Dict[str, Any] | None = None, clean_threshold_params: Dict[str, Any] | None = None, verbose: bool = False) List[ndarray][source]

For each directory in ‘list_of_directory_paths’, loads and concatenates the data stored under ‘key_to_extract’. Returns a list of concatenated NumPy arrays, one per input directory. Formerly named ‘read_hdf5_extract_and_concat_list’.

static process_data(data, keys: str | list[str] | tuple[str, ...], throw_if_bad: bool = False, unpack: bool = True, expected_ndim: int | None = None, expected_dim0: int | None = None, expected_dim1: int | None = None, expected_first_val: Any = None, return_skipped: bool = False) ndarray | tuple[ndarray, list[str]][source]

Collects arrays from an iterable of mappings and concatenates them, skipping entries that fail the validation checks below.

Parameters:
  • data – iterable of dict-like objects

  • keys – key or list of possible keys to try (first available is used)

  • throw_if_bad – whether to throw if no valid arrays are found

  • unpack – whether to flatten nested arrays along first axis

  • expected_ndim – enforce specific ndim (1 or 2). If None, auto-infer from first valid array. This is stricter than just checking consistency.

  • expected_dim0 – enforce first dimension length (skip mismatches)

  • expected_dim1 – enforce second dimension length (skip mismatches)

  • expected_first_val – if not None, enforce that the first value of the array is close to this (skip mismatches)

  • return_skipped – if True, return (array, skipped_filenames)

Returns:

ndarray by default, or (ndarray, skipped_filenames) if return_skipped=True.

Example

>>> energies = HDF5Manager.process_data(data, "energies")
>>> energies = HDF5Manager.process_data(data, ["energies", "E"], expected_ndim=1)
>>> obs = HDF5Manager.process_data(data, "observables", expected_ndim=2, expected_dim1=4)
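The "first available key" lookup and skip-on-mismatch behaviour can be sketched for 1D data with plain lists. This is a simplified, hypothetical analogue. collect_arrays, the "filename" label field, and counting entries with no matching key as skipped are all assumptions. It also omits the ndim and dim1 checks and returns a list instead of an ndarray.

```python
import math
from typing import Any, Iterable, List, Mapping, Optional, Sequence, Tuple, Union


def collect_arrays(data: Iterable[Mapping[str, Any]],
                   keys: Union[str, Sequence[str]],
                   expected_dim0: Optional[int] = None,
                   expected_first_val: Optional[float] = None,
                   throw_if_bad: bool = False) -> Tuple[List[float], List[str]]:
    """Concatenate 1D value lists found under the first available key; report skips."""
    candidates = [keys] if isinstance(keys, str) else list(keys)
    collected: List[float] = []
    skipped: List[str] = []
    for entry in data:
        label = str(entry.get("filename", "<unknown>"))        # hypothetical label field
        key = next((k for k in candidates if k in entry), None)  # first available key wins
        if key is None:
            skipped.append(label)
            continue
        values = list(entry[key])
        dim0_ok = expected_dim0 is None or len(values) == expected_dim0
        first_ok = expected_first_val is None or (
            bool(values) and math.isclose(values[0], expected_first_val, abs_tol=1e-12))
        if not (dim0_ok and first_ok):
            skipped.append(label)                               # mismatch: skip, do not fail
            continue
        collected.extend(values)
    if throw_if_bad and not collected:
        raise ValueError("no valid arrays found")
    return collected, skipped
```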

static clean_data_remove_zeros(matrix: ndarray, axis: int = 0, tolerance: float = 1e-09, check_limit: int | None = 10) ndarray[source]

Removes slices (e.g., rows or columns) from a matrix whose initial elements are all close to zero. For a 1D vector, removes elements close to zero from the beginning, up to check_limit elements.

Parameters:
  • matrix – Input numpy array.

  • axis – Axis along which to check for zero elements and remove slices.

  • tolerance – Tolerance for considering an element as zero.

  • check_limit – Max number of elements along the slice (or vector) to check. If None, checks all elements in the slice.

Returns:

Cleaned numpy array.
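For the row case (axis=0), the rule "drop a slice if its first check_limit elements are all near zero" can be sketched with plain lists. This is a hypothetical stand-in for the NumPy implementation. Treating |x| <= tolerance as zero is an assumption. clean_data_remove_thresholded follows the same pattern with the test x < threshold.

```python
from typing import List, Optional, Sequence


def remove_zero_rows(matrix: Sequence[Sequence[float]], tolerance: float = 1e-9,
                     check_limit: Optional[int] = 10) -> List[List[float]]:
    """Drop rows whose first check_limit entries are all within tolerance of zero."""
    kept = []
    for row in matrix:
        head = row if check_limit is None else row[:check_limit]
        if all(abs(x) <= tolerance for x in head):
            continue  # row looks like an unfilled (all-zero) slot
        kept.append(list(row))
    return kept
```

Limiting the check to the leading elements keeps the test cheap. As the example shows, a row with a late nonzero entry is still removed when check_limit is small.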

static clean_data_remove_thresholded(matrix: ndarray, axis: int = 0, threshold: float = -10000.0, check_limit: int | None = None) ndarray[source]

Removes slices from a matrix whose initial elements are all below a threshold. Works along any axis (implemented via np.moveaxis).

Parameters:
  • matrix – Input numpy array.

  • axis – Axis along which to check and remove slices.

  • threshold – Threshold value. Slices are removed if all checked elements are < threshold.

  • check_limit – Max number of elements along the slice to check. If None, checks all.

Returns:

Cleaned numpy array.

static process_file_content(source_directory: str, source_filename: str, key_map: Dict[str, str] | None = None, clean_zeros_axis: int | None = None, clean_values_axis: int | None = None, clean_check_limit: int = 10, output_directory: str | None = None, verbose: bool = False)[source]

Loads an HDF5 file, optionally renames keys, cleans the data, and saves it. Formerly named ‘change_h5_bad’.

Parameters:
  • source_directory – Directory of the source HDF5 file.

  • source_filename – Filename of the source HDF5 file.

  • key_map – Dictionary to rename dataset keys {old_key: new_key}.

  • clean_zeros_axis – Axis for ‘clean_data_remove_zeros’.

  • clean_values_axis – Axis for ‘clean_data_remove_thresholded’.

  • clean_check_limit – ‘check_limit’ for cleaning functions.

  • output_directory – Directory to save the processed file. If None, overwrites original.

  • verbose – If True, log detailed information.

static batch_process_files_in_dirs(source_directories: List[str], file_conditions: List[Any] | None = None, key_map: Dict[str, str] | None = None, clean_zeros_axis: int | None = None, clean_values_axis: int | None = None, clean_check_limit: int = 10, output_directory_base: str | None = None, is_test_run: bool = False, verbose: bool = False, exception_handler: Callable[[Exception, str], None] | None = None)[source]

Processes multiple HDF5 files across directories. Formerly named ‘change_h5_bad_dirs’.

static average_histograms(y_arrays_list: List[ndarray], x_arrays_list: List[ndarray], filter_y_lt_one: bool = False, use_interpolation: bool = True) Tuple[ndarray, ndarray][source]

Combines and averages y-values (e.g., histogram counts) across multiple series, aligning them by their x-values (e.g., bin centers).

Parameters:
  • y_arrays_list – List of Y-value arrays.

  • x_arrays_list – List of corresponding X-value arrays (bins).

  • filter_y_lt_one – If True, y-values < 1.0 (and the corresponding x) are filtered out before averaging. Formerly the ‘typical’ parameter.

  • use_interpolation

    • If True, interpolates Y-values onto a common X-grid.

    • If False, aggregates only at exact X-matches and appends unique X-bins.

Returns:

Tuple (y_combined_averaged, x_common_grid).
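The use_interpolation=False path ("aggregate only at exact X-matches and append unique X-bins") amounts to grouping y-values by their exact x and averaging each group. A minimal sketch with plain lists follows. average_at_exact_bins is hypothetical, and returning a sorted x-grid is an assumption.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def average_at_exact_bins(y_series: Sequence[Sequence[float]],
                          x_series: Sequence[Sequence[float]],
                          filter_y_lt_one: bool = False) -> Tuple[List[float], List[float]]:
    """Average y-values that share an identical x-bin; the union of bins is the grid."""
    sums: Dict[float, float] = defaultdict(float)
    counts: Dict[float, int] = defaultdict(int)
    for ys, xs in zip(y_series, x_series):
        for x, y in zip(xs, ys):
            if filter_y_lt_one and y < 1.0:
                continue  # drop small counts (and their x) before averaging
            sums[x] += y
            counts[x] += 1
    grid = sorted(sums)
    return [sums[x] / counts[x] for x in grid], grid
```

A bin present in only one series keeps that series' value, while a shared bin gets the mean over the series that contain it.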

static align_and_fill_histograms(y_arrays_list: List[ndarray], x_arrays_list: List[ndarray], group_lengths: List[int], fill_value: float = nan) Tuple[List[ndarray], ndarray][source]

Aligns multiple y-value series (histograms) to a common x-grid by interpolation, filling values for x-points not present in an original series.

Parameters:
  • y_arrays_list – A list where each element can be a list of y-arrays (if multiple realizations share an x_array) or a single y-array. The structure should align with group_lengths. Example: [[y1_real1, y1_real2], [y2_real1]]

  • x_arrays_list – List of x-value arrays (bins), one for each group of y_arrays. Example: [x1_bins, x2_bins]

  • group_lengths – List indicating how many y-arrays in ‘y_arrays_list’ correspond to each x-array in ‘x_arrays_list’. Example: [2, 1] means y_arrays_list[0] (a list of 2 y-arrays) uses x_arrays_list[0], and y_arrays_list[1] (a list of 1 y-array) uses x_arrays_list[1]. If y_arrays_list elements are single y-arrays, then group_lengths would be [1, 1, …].

  • fill_value – Value used for points in the common x-grid that are outside an original series’ x-range.

Returns:

Tuple (y_aligned_all, x_common_grid). y_aligned_all: List of 1D numpy arrays, each y-series interpolated to x_common_grid.
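The per-series step of the alignment (linear interpolation onto a common grid, with fill_value outside the series' own x-range) can be sketched without NumPy. This is a hypothetical illustration. It assumes strictly increasing x-values and a common grid built as the sorted union of all x-arrays. Grouping by group_lengths is not shown.

```python
import math
from bisect import bisect_right
from typing import List, Sequence


def interp_with_fill(grid: Sequence[float], xs: Sequence[float], ys: Sequence[float],
                     fill_value: float = math.nan) -> List[float]:
    """Linearly interpolate (xs, ys) onto grid; points outside [xs[0], xs[-1]] get fill_value."""
    out = []
    for g in grid:
        if g < xs[0] or g > xs[-1]:
            out.append(fill_value)           # outside this series' x-range
            continue
        i = bisect_right(xs, g) - 1          # last index with xs[i] <= g
        if i == len(xs) - 1:
            out.append(float(ys[-1]))        # g equals the last x exactly
            continue
        t = (g - xs[i]) / (xs[i + 1] - xs[i])
        out.append(ys[i] + t * (ys[i + 1] - ys[i]))
    return out
```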

class general_python.common.LazyDataEntry(filepath: str, params: Dict[str, Any])[source]

Bases: object

Base class for lazy data entries.

__init__(filepath: str, params: Dict[str, Any])[source]

Initialization of a lazy data entry.

filepath
filename
params
__getitem__(key: str)[source]

Access a dataset by key, loading it if not already cached.

get(key: str, default=None)[source]

Get a dataset by key, returning default if not found.

keys()[source]

Return available dataset keys, using known keys if available, otherwise loading all data.

values()[source]

Return loaded dataset values, loading all data if not already cached.

items()[source]

Return loaded dataset items, loading all data if not already cached.

load(keys: str | Iterable[str] | None = None)[source]

Load one key, multiple keys, or all keys into cache.

is_loaded(key: str | None = None) bool[source]

Check if a specific key or any key is loaded in the cache.

clear_cache(keys: str | Iterable[str] | None = None)[source]

Clear cached data for one key, multiple keys, or all keys.

require(key: str)[source]

Strict key accessor, explicit in client code.

get_many(keys: Iterable[str], default=None) Dict[str, Any][source]

Get multiple datasets by keys, returning a dict of key-value pairs, using default for missing keys.

to_dict(keys: Iterable[str] | None = None, copy: bool = False) Dict[str, Any][source]

Return a dict of datasets for specified keys, loading them if necessary. If keys is None, return all datasets. If copy is True, return a new dict; otherwise return the internal cache dict (which may be shared).

as_array(key: str | None = None)[source]
shape(key: str | None = None)[source]
dtype(key: str | None = None)[source]
load_all()[source]
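The load-on-first-access caching described for LazyDataEntry can be illustrated with a small standalone class. This is a sketch of the pattern, not the class itself. CachedEntry and its loader callable are hypothetical, and raising KeyError for unknown keys is an assumption.

```python
from typing import Any, Callable, Dict, Iterable, Optional, Union


class CachedEntry:
    """Hypothetical minimal lazy entry: the loader runs once per key, then is cached."""

    def __init__(self, loader: Callable[[str], Any], known_keys: Iterable[str]):
        self._loader = loader
        self._known = list(known_keys)
        self._cache: Dict[str, Any] = {}

    def __getitem__(self, key: str) -> Any:
        if key not in self._cache:
            if key not in self._known:
                raise KeyError(key)
            self._cache[key] = self._loader(key)   # load on first access only
        return self._cache[key]

    def get(self, key: str, default: Any = None) -> Any:
        try:
            return self[key]
        except KeyError:
            return default

    def is_loaded(self, key: Optional[str] = None) -> bool:
        return bool(self._cache) if key is None else key in self._cache

    def clear_cache(self, keys: Union[str, Iterable[str], None] = None) -> None:
        if keys is None:
            self._cache.clear()
            return
        for k in ([keys] if isinstance(keys, str) else keys):
            self._cache.pop(k, None)               # ignore keys that were never loaded
```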
class general_python.common.LazyHDF5Entry(filepath: str, params: Dict[str, Any])[source]

Bases: LazyDataEntry

Lazy loader for HDF5 datasets.

load_all()[source]
shape(key: str | None = None)[source]
dtype(key: str | None = None)[source]
class general_python.common.LazyNpzEntry(filepath: str, params: Dict[str, Any])[source]

Bases: LazyDataEntry

Lazy loader for .npz files.

load_all()[source]
class general_python.common.LazyPickleEntry(filepath: str, params: Dict[str, Any])[source]

Bases: LazyDataEntry

Lazy loader for .pkl/.pickle files.

load_all()[source]
class general_python.common.LazyJsonEntry(filepath: str, params: Dict[str, Any])[source]

Bases: LazyDataEntry

Lazy loader for .json files.

load_all()[source]
general_python.common.load_results(data_dir: str, *, filters: dict | Callable[[Any], bool] | None = None, lx=None, ly=None, lz=None, Ns=None, post_process_func: Callable[[dict], None] | None = None, get_params_func: Callable | None = None, logger: Logger = None, recursive: bool = True, sort_files: bool = True, **kwargs) ResultSet[source]

Load lazy entries from a directory (or single file) and apply filters.

This function scans for supported files, extracts parameters from filenames and paths, creates lazy entries, and applies filtering based on provided conditions.

It also supports optional post-processing of parameters and custom parameter extraction for filtering.

Parameters:
  • data_dir – Directory path (or single file) to scan for results.

  • filters

    • None: return all

    • callable: filters(result) -> bool

    • dict: {param: condition} where condition supports:
      • scalar exact match (numeric with tolerance)

      • list/tuple/set membership

      • tuple operators: (‘eq’|’neq’|’lt’|’le’|’gt’|’ge’, value)

      • range operator: (‘between’, (min, max))

      • membership operators: (‘in’|’not_in’, [v1, …])

      • string contains: (‘contains’, ‘sub’)

      • callable: lambda param_value, params: …

  • lx, ly, lz, Ns – Optional shortcuts for filtering by common size parameters. If provided, they are applied as additional filters on top of filters.

  • post_process_func – Optional function f(params: dict) -> None that can modify the extracted parameters in-place. Useful for computing derived parameters or converting units before filtering.

  • get_params_func – Optional function f(result) -> dict to extract parameters from a result entry for filtering. If None, the function will look for a .params attribute or use the entry itself if it’s a dict.

  • logger – Optional logger for progress and error messages. Should have methods like logger.info(msg, color=…) and logger.error(msg, color=…).

  • recursive – Whether to scan directories recursively. Default is True.

  • sort_files – Whether to sort the list of files before processing. Default is True.

  • **kwargs – Additional keyword arguments for future extensions or specific filtering needs.

Returns:

List-like container of lazy entries with convenience methods such as filtered(...), show(...), and show_filtered(...).

Return type:

ResultSet

general_python.common.filter_results(results: Iterable[Any], filters: dict | Callable[[Any], bool] | None = None, get_params_fun: Callable | None = None, *, tol: float = 1e-09) ResultSet[source]

Filter result entries by parameter conditions.

Parameters:
  • results – Iterable of result-like objects (with .params or dict-like).

  • filters

    • None: return all

    • callable: filters(result) -> bool

    • dict: {param: condition} where condition supports:
      • scalar exact match (numeric with tolerance)

      • list/tuple/set membership

      • tuple operators: (‘eq’|’neq’|’lt’|’le’|’gt’|’ge’, value)

      • range operator: (‘between’, (min, max))

      • membership operators: (‘in’|’not_in’, [v1, …])

      • string contains: (‘contains’, ‘sub’)

      • callable: lambda param_value, params: …

  • get_params_fun – Optional extractor f(result) -> dict.

  • tol – Numeric tolerance for equality-like checks.

Returns:

Filtered results as a ResultSet with the same get_params_fun and tol.

Return type:

ResultSet
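How a dict of {param: condition} pairs might be evaluated against one entry's parameters is sketched below. This is a simplified, hypothetical matcher covering the documented condition forms. It is not the library's code: tolerance is applied only to bare numeric scalars here, and a 2-tuple whose first item is a known operator name is treated as an operator rather than as a membership set.

```python
import math
from typing import Any, Callable, Dict, Mapping

# Operator table for ('op', value) tuples; 'between' expects (min, max).
_OPS: Dict[str, Callable[[Any, Any], bool]] = {
    "eq": lambda a, b: a == b,
    "neq": lambda a, b: a != b,
    "lt": lambda a, b: a < b,
    "le": lambda a, b: a <= b,
    "gt": lambda a, b: a > b,
    "ge": lambda a, b: a >= b,
    "between": lambda a, b: b[0] <= a <= b[1],
    "in": lambda a, b: a in b,
    "not_in": lambda a, b: a not in b,
    "contains": lambda a, b: b in str(a),
}


def _is_op(cond: Any) -> bool:
    return (isinstance(cond, tuple) and len(cond) == 2
            and isinstance(cond[0], str) and cond[0] in _OPS)


def matches(params: Mapping[str, Any], filters: Mapping[str, Any], tol: float = 1e-9) -> bool:
    """True if every {param: condition} pair in filters holds for params."""
    for name, cond in filters.items():
        value = params.get(name)
        if callable(cond):
            ok = cond(value, params)                     # lambda param_value, params: ...
        elif _is_op(cond):
            ok = value is not None and _OPS[cond[0]](value, cond[1])
        elif isinstance(cond, (list, tuple, set)):
            ok = value in cond                           # plain membership
        elif isinstance(cond, (int, float)) and isinstance(value, (int, float)):
            ok = math.isclose(value, cond, rel_tol=0.0, abs_tol=tol)  # numeric match with tolerance
        else:
            ok = value == cond                           # exact match for other scalars
        if not ok:
            return False
    return True
```

Filtering a list of entries then reduces to keeping those for which matches(entry_params, filters) is True.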

class general_python.common.ResultSet(iterable: Iterable[Any] = (), *, get_params_fun: Callable | None = None, tol: float = 1e-09, name: str = 'results')[source]

Bases: list

List-like container for result entries with convenience query/preview methods.

Notes

  • Inherits from list to preserve normal list behavior.

  • Slice operations return ResultSet (not plain list).

__init__(iterable: Iterable[Any] = (), *, get_params_fun: Callable | None = None, tol: float = 1e-09, name: str = 'results')[source]
get_params_fun
tol
name
__getitem__(item)[source]

Override to return ResultSet for slices, preserving get_params_fun and tol.

copy() ResultSet[source]

Return a shallow copy of the list.

filtered(filters: dict | Callable[[Any], bool] | None = None, *, get_params_fun: Callable | None = None, tol: float | None = None) ResultSet[source]

Filter results based on provided conditions.

where(filters: dict | Callable[[Any], bool] | None = None, *, get_params_fun: Callable | None = None, tol: float | None = None) ResultSet[source]

Alias for filtered() to allow chaining like results.where(…).show(…)

param_values(key: str, *, default: Any = nan, get_params_fun: Callable | None = None) ndarray[source]

Extract an array of parameter values for a given key across all results.

unique(key: str, *, drop_nan: bool = True, get_params_fun: Callable | None = None) ndarray[source]

Get unique values of a parameter key across all results, with option to drop NaNs.

sort_by(key: str, *, reverse: bool = False, get_params_fun: Callable | None = None) ResultSet[source]

Return a new ResultSet sorted by a parameter key.

first(default: Any = None)[source]

Return the first entry or default if empty.

show(*, fields: Sequence[str] | None = None, limit: int = 20, sort_by: str | None = None, reverse: bool = False, include_filename: bool = True) ResultSet[source]

Display a tabular preview of the results in the console.

show_filtered(filters: dict | Callable[[Any], bool] | None = None, *, get_params_fun: Callable | None = None, tol: float | None = None, fields: Sequence[str] | None = None, limit: int = 20, sort_by: str | None = None, reverse: bool = False, include_filename: bool = True) ResultSet[source]
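The list-subclass design noted for ResultSet (slices stay the same type, query methods return new instances so calls can be chained) can be sketched as follows. MiniResultSet is hypothetical and uses plain dict entries instead of get_params_fun.

```python
from typing import Any, Callable, Mapping


class MiniResultSet(list):
    """Hypothetical list subclass: slicing and filtering return the same type."""

    def __getitem__(self, item):
        result = super().__getitem__(item)
        return MiniResultSet(result) if isinstance(item, slice) else result

    def where(self, predicate: Callable[[Mapping[str, Any]], bool]) -> "MiniResultSet":
        return MiniResultSet(r for r in self if predicate(r))

    def sort_by(self, key: str, reverse: bool = False) -> "MiniResultSet":
        return MiniResultSet(sorted(self, key=lambda r: r[key], reverse=reverse))

    def first(self, default: Any = None) -> Any:
        return self[0] if self else default
```

Returning new instances rather than mutating in place keeps the original collection intact across chained queries.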
class general_python.common.PlotData[source]

Bases: object

Convenience helpers that work with Lazy* entries and ResultProxy.

static from_input(directory: str | None, data_values: np.ndarray | dict | list | tuple | None = None, x_parameters: List[float] = None, y_parameters: List[float] = None, *, x_param: str = 'x', y_param: str = 'y', data_key: str = 'default', filters: dict | Callable[[Any], bool] | None = None, logger: Logger = None, **kwargs) ResultSet[source]

Build a plot-ready result list from either filesystem data or in-memory arrays.

static from_match(results: List[LazyDataEntry | ResultProxy], x_param: str, y_param: str, x_val: float, y_val: float, tol: float = 1e-05) LazyDataEntry | ResultProxy | None[source]

Return the first entry matching two parameter values within tolerance.
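A two-parameter tolerance match like this one can be written as a linear scan over parameter mappings. This is a hypothetical sketch. It assumes an absolute tolerance and skips entries that lack either parameter.

```python
import math
from typing import Any, Iterable, Mapping, Optional


def first_match(entries: Iterable[Mapping[str, Any]], x_param: str, y_param: str,
                x_val: float, y_val: float, tol: float = 1e-5) -> Optional[Mapping[str, Any]]:
    """Return the first params mapping whose x and y parameters lie within tol."""
    for params in entries:
        x, y = params.get(x_param), params.get(y_param)
        if x is None or y is None:
            continue  # entry does not define both parameters
        if (math.isclose(x, x_val, rel_tol=0.0, abs_tol=tol)
                and math.isclose(y, y_val, rel_tol=0.0, abs_tol=tol)):
            return params
    return None
```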

static extract_parameter_arrays(filtered_results: List[LazyDataEntry | ResultProxy], x_param: str = 'J', y_param: str = 'hx', xlim=None, ylim=None)[source]

Extract raw and unique numeric x/y parameter arrays.

static sort_results_by_param(results: List[LazyDataEntry | ResultProxy], param_name: str)[source]
static determine_vmax_vmin(results: List[LazyDataEntry | ResultProxy], param_name: str, param_fun: Callable = <function PlotData.<lambda>>, nstates: int = None)[source]
static savefig(fig, directory: str, *name_parts, suffix: str = '', ext: str = 'png', dpi: int = 250, logger: Logger = None) Path[source]

Save figure to directory with a generated file name.

class general_python.common.Logger(name: str = 'Global', logfile: str | None = None, lvl: int = 20, append_ts: bool = False, use_ts_in_cmd: bool = False)

Bases: object

Logger class for handling console and file logging with verbosity control.

LEVELS = {10: 'debug', 20: 'info', 30: 'warning', 40: 'error'}
LEVELS_R = {'debug': 10, 'error': 40, 'info': 20, 'warning': 30}
__init__(name: str = 'Global', logfile: str | None = None, lvl: int = 20, append_ts: bool = False, use_ts_in_cmd: bool = False)

Initialize the logger instance.

Parameters:
  • name (str) – Name of the logger (default: 'Global').

  • logfile (str) – Name of the log file, without extension. If empty, a timestamp is used.

  • lvl (int) – Logging level (default: logging.INFO).

  • append_ts (bool) – Whether to append a timestamp to the log file name (default: False).

  • use_ts_in_cmd (bool) – Whether to use a timestamp in console output (default: False).

static breakline(n: int)

Print multiple break lines.

Parameters:

n (int) – Number of break lines.

static colorize(txt: str, color: str)

Apply color to the given text (for console output).

Parameters:
  • txt (str) – Text to colorize.

  • color (str) – Color name.

Returns:

Colorized text.

Return type:

str

configure(directory: str)

Configure the logger to use a specific directory for log files.

Parameters:

directory (str) – Path to the directory where log files will be stored.

dbg(msg: str, lvl=0, verbose=True, color=None)

Alias for debug().

debug(msg: str, lvl=0, verbose=True, color=None)

Log a debug message if verbosity is enabled.

Parameters:
  • msg (str) – Message to log.

  • lvl (int) – Indentation level.

  • verbose (bool) – Log if True (default: True).

  • color (str) – Optional color for the message.

classmethod endl(n: int)

Print n blank lines through the logger break-line helper.

err(msg: str, lvl=0, verbose=True, color='red')

Alias for error().

error(msg: str, lvl=0, verbose=True, color='red')

Log an error message if verbosity is enabled.

Parameters:
  • msg (str) – Message to log.

  • lvl (int) – Indentation level.

  • verbose (bool) – Log if True (default: True).

  • color (str) – Optional color for the message.

inf(msg: str, lvl=0, verbose=True, color=None)

Alias for info().

info(msg: str, lvl=0, verbose=True, color=None)

Log an informational message if verbosity is enabled.

Parameters:
  • msg (str) – Message to log.

  • lvl (int) – Indentation level.

  • verbose (bool) – Log if True (default: True).

  • color (str) – Optional color for the message.

static print(msg: str, lvl=0)

Format a message with a timestamp.

Parameters:
  • msg (str) – Message to format.

  • lvl (int) – Indentation level.

Returns:

Formatted message.

Return type:

str

static print_tab(lvl=0)

Generate indentation for message formatting.

Parameters:

lvl (int) – Number of tabulators.

Returns:

Indented string.

Return type:

str

say(*args, end=True, log=20, lvl=0, verbose=True, color=None)

Print and log multiple messages if verbosity is enabled.

Parameters:
  • *args – Messages to log.

  • end (bool) – Append newline (default: True).

  • log (int) – Log level (10: debug, 20: info, 30: warning, 40: error) (default: 20).

  • lvl (int) – Indentation level.

  • verbose (bool) – Log if True (default: True).

  • color (str) – Optional color for the message.

timing(func)

Decorator to measure and log the execution time of functions.

Parameters:

func – Function to be timed.

Use as:

>>> @logger.timing
... def my_function(…):
...     …
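A timing decorator of this kind typically wraps the call with time.perf_counter() and functools.wraps. The standalone sketch below illustrates the pattern. Logger.timing itself is a method that logs through the logger; here a hypothetical log callable stands in for it.

```python
import functools
import time
from typing import Any, Callable


def timing(log: Callable[[str], None]) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Build a decorator that reports each call's wall-clock duration via log."""
    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)              # keep the wrapped function's name and docstring
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            start = time.perf_counter()
            try:
                return func(*args, **kwargs)
            finally:                        # report even if func raises
                elapsed = time.perf_counter() - start
                log(f"{func.__name__} took {elapsed:.6f} s")
        return wrapper
    return decorator
```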

title(tail: str, desired_size: int = 50, fill: str = '=', lvl=0, verbose=True, color=None)

Create a formatted title with filler characters if verbosity is enabled.

Parameters:
  • tail (str) – Text in the middle of the title.

  • desired_size (int) – Total width of the title.

  • fill (str) – Character used for filling.

  • lvl (int) – Indentation level.

  • verbose (bool) – Log if True (default: True).

  • color (str) – Optional color for the title.

warn(msg: str, lvl=0, verbose=True, color='yellow')

Alias for warning().

warning(msg: str, lvl=0, verbose=True, color='yellow')

Log a warning message if verbosity is enabled.

Parameters:
  • msg (str) – Message to log.

  • lvl (int) – Indentation level.

  • verbose (bool) – Log if True (default: True).

  • color (str) – Optional color for the message.

general_python.common.get_global_logger(**kwargs) Logger

Return one Logger wrapper per process (keyed by PID), safe across threads and forks. The banner is printed only once per program, tracked via an environment-variable sentinel.

Parameters:
  • **kwargs – Arguments passed to the Logger constructor:

    • name – Name of the logger (default: “Global”).

    • lvl – Logging level (default: logging.INFO).

    • append_ts – Whether to append timestamps (default: True).

    • use_ts_in_cmd – Whether to use timestamps in console output (default: True).

    • logfile – Path to a logfile (default: None).

Returns:

The global logger instance.

Return type:

Logger

Example

>>> logger = get_global_logger()
>>> logger.info("This is an informational message.")
>>> logger.debug("This is a debug message.", color='blue')
general_python.common.log_memory_status(context: str = '', logger: Logger | None = None, lvl: int = 0, verbose: bool = True, **kwargs) None[source]

Log current memory usage.

general_python.common.check_memory_for_operation(required_gb: float, operation_name: str, safety_factor: float = 0.8, logger: Logger | None = None) bool[source]

Check if there’s enough memory for an operation.

Parameters:
  • required_gb – Estimated memory required in GB

  • operation_name – Name of operation for logging

  • safety_factor – Fraction of available memory to use (default 0.8)

Returns:

True if safe to proceed, False otherwise
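The safety-factor rule reduces to one comparison. For example, with 16 GB available and safety_factor=0.8, the budget is 12.8 GB. The sketch below takes the available memory as an explicit argument. How the library measures it is not documented here; psutil.virtual_memory().available, which reports bytes, is one common source. Using <= rather than < is an assumption.

```python
def enough_memory(required_gb: float, available_gb: float, safety_factor: float = 0.8) -> bool:
    """True if the estimate fits within the allowed fraction of available memory."""
    if not 0.0 < safety_factor <= 1.0:
        raise ValueError("safety_factor must lie in (0, 1]")
    return required_gb <= safety_factor * available_gb  # budget = factor * available
```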

class general_python.common.Timer(name: str | None = None, logger: Logger | None = None, logger_args: Dict[str, Any] | None = None, verbose: bool = False, unit: str = 'auto', deadline_s: float | None = None, synchronizer: Callable[[Any], None] | None = None)[source]

Bases: object

Enhanced timer class for measuring elapsed time.

This class can be used as a context manager, a decorator, or directly to time code. It supports:

  • Starting, stopping, and resetting the timer.

  • Recording multiple laps.

  • Verbose output to automatically print timing information.

name

Optional name to identify the timer.

Type:

str

verbose

If True, prints timing information on stop.

Type:

bool

format

Optional format for the output timing information.

name: str | None
logger: Logger | None
logger_args: Dict[str, Any] | None
verbose: bool
unit: str
deadline_s: float | None
synchronizer: Callable[[Any], None] | None
start() Timer[source]

Start (or resume) the timer; no-op if already running.

pause() Timer[source]

Pause the timer, accumulating elapsed time.

resume() Timer[source]

Resume after pause.

stop() float[source]

Stop and return elapsed time in seconds.

reset() Timer[source]

Clear state (elapsed, laps, marks) and stop.

lap(name: str | None = None) float[source]

Record a lap (time since last lap or start) and return lap in seconds.

mark(name: str | None = None) None[source]

Create or update a named absolute anchor at the current time. Query it later with since(‘name’).

since(name: str | None = None, ts: int | None = None) float[source]

Seconds elapsed since the named mark. Raises KeyError if mark not set.

elapsed_ns() int[source]

Total elapsed nanoseconds (includes current running span).

elapsed_ms() float[source]

Elapsed milliseconds (float).

elapsed_us() float[source]

Elapsed microseconds (float).

elapsed_s() float[source]

Elapsed seconds (float).

laps() Tuple[List[float], List[str]][source]

Recorded laps (seconds) and their names.

remaining_s(buffer_s: float = 0.0) float | None[source]

If deadline_s is set, return remaining seconds (can be negative). Otherwise None.

overtime(buffer_s: float = 0.0) bool[source]

True if elapsed >= deadline_s - buffer_s; False if no deadline is set.

property state: TimerState

Current timer lifecycle state.

format_elapsed() str[source]

Return elapsed time formatted in the configured display unit.

report(include_laps: bool = True) str[source]

Build a human-readable timing report.

Parameters:

include_laps – Include named lap timings when any have been recorded.

classmethod decorator(name: str | None = None, logger: Logger | None = None, verbose: bool = False, unit: str = 'auto', deadline_s: float | None = None, synchronizer: Callable[[Any], None] | None = None)[source]

Decorator for timing a function.

Usage:

@Timer.decorator("block", verbose=True)
def fn(*args, **kwargs):
    ...

Parameters:
  • name (str, optional) – The name of the timer (default: the function name).

  • logger (Logger, optional) – Optional logger for timing output (default: None).

  • verbose (bool) – If True, print timing info (default: False).

  • unit (str) – Time unit for reporting (default: “auto”).

  • deadline_s (float, optional) – Optional deadline in seconds (default: None).

  • synchronizer (Callable, optional) – Optional synchronizer function (default: None).
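A timing decorator of this kind wraps the function and measures wall time around each call. The sketch below is a hypothetical timed decorator using only the standard library; it mirrors the name, logger, and verbose parameters listed above but is not Timer.decorator itself, and the last_elapsed_s attribute it sets on the wrapper is an illustrative invention.

```python
import functools
import logging
import time
from typing import Any, Callable, Optional


def timed(name: Optional[str] = None,
          logger: Optional[logging.Logger] = None,
          verbose: bool = False) -> Callable:
    """Hypothetical timing decorator; not the library's Timer.decorator."""
    def decorate(fn: Callable) -> Callable:
        label = name or fn.__name__  # default to the function name

        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            start = time.perf_counter()
            try:
                return fn(*args, **kwargs)
            finally:
                # Runs even if fn raises, so failed calls are timed too.
                elapsed = time.perf_counter() - start
                wrapper.last_elapsed_s = elapsed
                msg = f"{label}: {elapsed:.6f} s"
                if logger is not None:
                    logger.info(msg)
                if verbose:
                    print(msg)

        wrapper.last_elapsed_s = None
        return wrapper
    return decorate


@timed("block")
def add(a: int, b: int) -> int:
    return a + b
```

Using functools.wraps keeps the wrapped function's name and docstring, which matters when the default timer name is taken from the function.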

__init__(name: str | None = None, logger: Logger | None = None, logger_args: Dict[str, Any] | None = None, verbose: bool = False, unit: str = 'auto', deadline_s: float | None = None, synchronizer: Callable[[Any], None] | None = None) None
general_python.common.dtype_to_name(dtype)[source]

Normalize dtype-like objects to the canonical QES dtype name.

general_python.common.get_module_description(module_name)[source]

Get the description of a specific module in the common package.

Parameters:
  • module_name (str) – The name of the module.

Returns:

The description of the module.

Return type:

str

general_python.common.list_available_modules()[source]

List all available modules in the common package.

Returns:

A list of available module names.

Return type:

list

Path and directory helpers built on pathlib.Path.

The Directories wrapper keeps legacy convenience methods available while exposing a path-like object that can be passed to standard-library APIs. It covers path joining, directory creation, file discovery, copying, and common serialization helpers used by analysis scripts.

class general_python.common.directories.staticproperty(fget=None, fset=None, fdel=None, doc=None)[source]

Bases: property

Descriptor for exposing a zero-argument function as a static property.

class general_python.common.directories.classproperty(fget=None, fset=None, fdel=None, doc=None)[source]

Bases: property

Descriptor for exposing a classmethod-like function as a property.

class general_python.common.directories.Directories(*parts: str | Path)[source]

Bases: object

Directory handler wrapping a pathlib.Path. Instances support path arithmetic with + and /, equality and hashing, and provide helpers for creating, listing, filtering, copying, and moving files.

__init__(*parts: str | Path) None[source]

Initialize with one or more path components.

>>> d = Directories("foo", "bar")  # -> Path("foo/bar")

__len__() int | str[source]

Return the number of items in the directory if it is a directory,

__add__(other: str | Path) Directories[source]

Concatenate with another path component.

>>> d = Directories("foo") + "bar"  # -> Path("foo/bar")

__iadd__(other: str | Path) Directories[source]

In-place concatenation with another path component.

>>> d = Directories("foo")
>>> d += "bar"  # -> Path("foo/bar")

__radd__(other: str | Path) Directories[source]

Concatenate with another path component when the left operand is a string or Path.

>>> d = "foo" + Directories("bar")  # -> Path("foo/bar")

__truediv__(other: str | Path) Directories[source]

Concatenate with another path component using the / operator.

>>> d = Directories("foo") / "bar"  # -> Path("foo/bar")

__rtruediv__(other: str | Path) Directories[source]

Concatenate with another path component using the / operator when the left operand is a string or Path.

>>> d = "foo" / Directories("bar")  # -> Path("foo/bar")

__iter__() Iterator[Path][source]

Iterate over the parts of the path.

>>> d = Directories("foo/bar")  # iterates over ["foo", "bar"]

__eq__(other: str | Path) bool[source]

Check equality with another path component.

>>> Directories("foo") == "foo"
True

__ne__(other: str | Path) bool[source]

Check inequality with another path component.

>>> Directories("foo") != "bar"
True

__hash__() int[source]

Hash the path for use in sets or dictionaries.

>>> hash(Directories("foo"))  # -> hash(Path("foo"))

__repr__() str[source]

Return a string representation of the path.

>>> repr(Directories("foo"))  # -> "Directories(PosixPath('foo'))" on POSIX systems

__str__() str[source]

Return a string representation of the path.

>>> str(Directories("foo"))  # -> "foo"
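The operator behavior documented above can be reproduced with a thin wrapper around pathlib.Path. PathBox below is a hypothetical, minimal illustration of how +, /, equality, and hashing can all delegate to Path; the real Directories class has many more methods and may differ in details such as its repr.

```python
from pathlib import Path
from typing import Union

PathLike = Union[str, Path]


class PathBox:
    """Hypothetical minimal path wrapper; not the library's Directories."""

    def __init__(self, *parts: PathLike) -> None:
        self.path = Path(*parts)

    def __truediv__(self, other: PathLike) -> "PathBox":
        return PathBox(self.path, other)

    def __rtruediv__(self, other: PathLike) -> "PathBox":
        # Called for "foo" / PathBox(...), since str has no "/" operator.
        return PathBox(other, self.path)

    # "+" is treated as path joining, not string concatenation.
    __add__ = __truediv__
    __radd__ = __rtruediv__

    def __eq__(self, other: object) -> bool:
        if isinstance(other, PathBox):
            other = other.path
        if isinstance(other, (str, Path)):
            return self.path == Path(other)
        return NotImplemented

    def __hash__(self) -> int:
        # Defining __eq__ disables the default hash, so restore one from Path.
        return hash(self.path)

    def __str__(self) -> str:
        return str(self.path)

    def __repr__(self) -> str:
        return f"PathBox({self.path!r})"
```

Because equality and hashing both go through Path, PathBox("a") and "a" compare equal and two wrappers for the same path collapse into one set entry.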

static f_h5(p: List[Path]) List[str][source]

Filter for .h5 files.

static f_csv(p: List[Path]) List[str][source]

Filter for .csv files.

static f_nonempty(p: List[Path]) List[str][source]

Filter for non-empty files.

static f_contains(substr: str) Callable[[Path], bool][source]

Return a filter that checks if the filename contains a substring.

join(*parts: str | Path, create: bool = False) Directories[source]

Return a new Directories with parts joined onto this path. If create=True, the joined directory is created with mkdir(parents=True, exist_ok=True).

property parent: Directories

Return Directories for parent directory (..).

classmethod win(raw: str) Directories[source]

Parse a Windows-style backslash path into Directories.

format(*args, **kwargs) Directories[source]

Format the path using str.format() and return a new Directories.

>>> d = Directories("foo/{}").format("bar")  # -> Path("foo/bar")

resolve() Directories[source]

Return a new Directories with the absolute resolved path.

endswith(suffix: str) bool[source]

Check if the path ends with the given suffix.

mkdir(parents: bool = True, exist_ok: bool = True) Directories[source]

Create this directory on disk. Returns self for chaining.

static mkdirs(paths: Iterable[str | Path], parents: bool = True, exist_ok: bool = True) None[source]

Create multiple directories.

list_files(*, include_empty: bool = True, filters: List[Callable[[Path], bool]] = None, sort_key: Callable[[Path], any] | None = None) List[Path][source]

List files (not directories) in this directory.

Parameters:
  • include_empty (bool) – If False, skip files of size zero.

  • filters (list of callables Path -> bool) – All must return True for a file to be included.

  • sort_key (callable, optional) – Key function for sorting the results.
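The filtering rules for list_files (skip empty files when include_empty is False, require every filter to pass, optionally sort) can be sketched with pathlib alone. list_files_sketch is a hypothetical helper written for illustration, not the library implementation.

```python
from pathlib import Path
from typing import Any, Callable, Iterable, List, Optional


def list_files_sketch(directory: Path,
                      include_empty: bool = True,
                      filters: Optional[Iterable[Callable[[Path], bool]]] = None,
                      sort_key: Optional[Callable[[Path], Any]] = None) -> List[Path]:
    """Hypothetical re-implementation of the documented list_files rules."""
    filters = list(filters or [])
    # Files only: subdirectories are skipped.
    files = [p for p in directory.iterdir() if p.is_file()]
    if not include_empty:
        files = [p for p in files if p.stat().st_size > 0]
    # A file is kept only if every filter accepts it.
    files = [p for p in files if all(f(p) for f in filters)]
    if sort_key is not None:
        files.sort(key=sort_key)
    return files
```

Filters such as lambda p: p.suffix == ".h5" play the same role as the f_h5 and f_contains helpers above.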

list_dirs(*, include_empty: bool = True, include_hidden: bool = True, relative: bool = False, as_string: bool = False, filters: List[Callable[[Path], bool]] = [], sort_key: Callable[[Path], Any] | None = None) List[Path][source]

List directories in this directory.

Parameters:
  • include_empty (bool) – If False, skip empty directories; if True, include all directories. A directory counts as empty only if it has no entries at all (files or subdirectories).

  • filters (list of callables Path -> bool) – A list of callables; all must return True for a directory to be included.

  • sort_key (callable, optional) – Key function for sorting the results.

static list_data_roots(base: str | Path, *, sort: bool = True, as_dirs: bool = True) List[Directories] | List[Path][source]

List all first-level directories inside base…

Parameters:
  • base (PathLike) – Root directory (e.g. data_path)

  • sort (bool) – Sort lexicographically (useful for YYYYMMDD)

  • as_dirs (bool) – Return Directories objects instead of Path

Return type:

List of directories

static expand_data_roots(base: str | Path, *subpath: str | Path, require_exist: bool = True) List[Directories][source]

Expand a relative subpath across all first-level directories.

Parameters:
  • base (PathLike) – Root directory (e.g. data_path)

  • subpath (PathLike) – Relative path to append to each root (e.g. hamil/occ/ns/sp)

  • require_exist (bool) – If True, only include paths that exist on disk.

Example

expand_data_roots(data_path, 'data', hamil, …, 'sp')

Returns a list of:

base/<date>/data/…/sp
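A minimal sketch of this expansion, assuming the first-level entries of base are plain directories and are sorted lexicographically. expand_roots is a hypothetical helper and returns Path objects rather than Directories.

```python
from pathlib import Path
from typing import List, Union

PathLike = Union[str, Path]


def expand_roots(base: PathLike, *subpath: PathLike,
                 require_exist: bool = True) -> List[Path]:
    """Hypothetical sketch: append subpath to every first-level directory of base."""
    # Lexicographic order keeps YYYYMMDD-named roots chronological.
    roots = sorted(p for p in Path(base).iterdir() if p.is_dir())
    candidates = [root.joinpath(*subpath) for root in roots]
    if require_exist:
        candidates = [p for p in candidates if p.exists()]
    return candidates
```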

static collect_files(dirs: List[Directories], *, prefix: str = None, suffix: str = None, filters: List[Callable[[Path], bool]] = None, sort: bool = False) List[Path][source]

Collect files from multiple directories.

Parameters:
  • dirs – list of Directories

  • prefix – optional filename prefix filter

  • suffix – optional filename suffix filter

  • filters – additional filters (Path -> bool)

  • sort – global sorting

Return type:

Flat list of Paths
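Collecting across several directories amounts to applying the prefix, suffix, and filter checks per file and concatenating the results, with one optional global sort at the end. collect_files_sketch is a hypothetical helper that accepts plain Path objects instead of Directories.

```python
from pathlib import Path
from typing import Callable, Iterable, List, Optional


def collect_files_sketch(dirs: Iterable[Path], *, prefix: Optional[str] = None,
                         suffix: Optional[str] = None,
                         filters: Optional[List[Callable[[Path], bool]]] = None,
                         sort: bool = False) -> List[Path]:
    """Hypothetical sketch: flatten matching files from several directories."""
    out: List[Path] = []
    for d in dirs:
        for p in Path(d).iterdir():
            if not p.is_file():
                continue
            if prefix is not None and not p.name.startswith(prefix):
                continue
            if suffix is not None and not p.name.endswith(suffix):
                continue
            if filters and not all(f(p) for f in filters):
                continue
            out.append(p)
    if sort:
        out.sort()  # one global sort across all directories
    return out
```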

clear_empty() List[Path][source]

Remove all zero-length files in this directory. Returns list of files left after removal.

walk() Iterator[Path][source]

Walk the directory tree and yield all files.

glob(pattern: str) List[Path][source]

Return a list of all files matching the pattern in this directory.

random_file(condition: Callable[[Path], bool] = <function Directories.<lambda>>) Path[source]

Return a random Path in this directory satisfying condition. Raises ValueError if none match.

copy_files(dest: str | Path, condition: Callable[[Path], bool], overwrite: bool = False) None[source]

Copy all files satisfying condition() from self to dest. Creates dest if needed.

Parameters:
  • dest (PathLike) – Destination directory.

  • condition (Callable[[Path], bool]) – Function that takes a Path and returns True if the file should be copied.

  • overwrite (bool, optional) – If True, overwrite existing files in the destination directory. Default is False.
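The copy rules above (create dest if needed, copy only files accepted by condition, skip existing targets unless overwrite is True) can be sketched with shutil. copy_matching is hypothetical; unlike copy_files, it returns the copied paths so the effect is easy to inspect.

```python
import shutil
from pathlib import Path
from typing import Callable, List, Union


def copy_matching(src: Union[str, Path], dest: Union[str, Path],
                  condition: Callable[[Path], bool],
                  overwrite: bool = False) -> List[Path]:
    """Hypothetical sketch of the documented copy_files behavior."""
    src, dest = Path(src), Path(dest)
    dest.mkdir(parents=True, exist_ok=True)  # "Creates dest if needed"
    copied: List[Path] = []
    for p in src.iterdir():
        if not p.is_file() or not condition(p):
            continue
        target = dest / p.name
        if target.exists() and not overwrite:
            continue  # keep the existing file untouched
        shutil.copy2(p, target)  # copy2 also preserves metadata such as mtime
        copied.append(target)
    return copied
```

Replacing shutil.copy2 with shutil.move gives the analogous behavior for transfer_files.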

transfer_files(dest: str | Path, condition: Callable[[Path], bool]) None[source]

Move all files satisfying condition() from self to dest. Creates dest if needed.

property exists: bool

Check if the path exists.

property as_path: Path

Return the path as a Path object.

property is_empty: bool

Check if the directory is empty.

property is_dir: bool

Check if the path is a directory.

property is_file: bool

Check if the path is a file.

Check if the path is a symlink.

property size: int

Return the size of the directory in bytes.

property size_human: str

Return the size of the directory in a human-readable format.

property disk_usage: str

Return the disk usage of the directory in a human-readable format.

property checksum: str

Return the checksum of the directory.

static temp_dir(prefix: str = 'tmp') Directories[source]

Create and return a temporary directory with the given prefix.

current = Directories(PosixPath('/home/docs/checkouts/readthedocs.org/user_builds/general-python/checkouts/latest/docs'))[source]
home = Directories(PosixPath('/home/docs'))[source]
root = Directories(PosixPath('/'))[source]
static from_env(var_name: str) Directories | None[source]

Create a Directories object from an environment variable. Returns None if the variable is not set or the path does not exist.

static from_config(config: dict, key: str) Directories | None[source]

Create a Directories object from a configuration dictionary. Returns None if the key is not found or the path does not exist.

static from_string(s: str) Directories[source]

Create a Directories object from a string path.

static from_parts(*parts: str | Path) Directories[source]

Create a Directories object from multiple path components.

static from_path(p: str | Path) Directories[source]

Create a Directories object from a Path-like object.

class general_python.common.directories.DirectoriesData(**dirs: str)[source]

Bases: object

Collects directories across multiple machines and stores them in dictionaries. Only directories that exist are included in existing.

Example:

>>> dirs = DirectoriesData(
...     klimak_um_only_f=("/media/…/klimak_um_only_f_t100000/uniform", "503"),
...     klimak_all=("/media/…/klimak_um_plrb_all_t100000/uniform", "503"),
...     locally=("data_project/uniform", "local"),
... )

__init__(**dirs: str)[source]

Initialize with named directory paths. Each value can be either a string (path) or a tuple (path, machine).

all: Dict[str, Directories]
existing: Dict[str, Directories]
machines: Dict[str, List[str]]
get(name: str, only_existing: bool = True) Directories | None[source]

Get a directory by name. Optionally restrict to existing ones.

Parameters:
  • name (str) – The name of the directory to retrieve.

  • only_existing (bool, optional) – If True, only return the directory if it exists. Default is True.

add(name: str, path: str | Path, machine: str = 'default')[source]

Add a new directory.

remove(name: str) None[source]

Remove a directory entry by name.

filter_names(filters: list[str | Callable[[str], bool]], only_existing: bool = True) list[str][source]

Return names that match any filter. Filters can be substrings or callables (e.g. regex matchers, lambdas).

filter_dirs(filters: list[str | Callable[[str], bool]], only_existing: bool = True) dict[str, Directories][source]

Return {name: Directories} for names matching any filter. Filters can be substrings or callables (e.g. regex matchers, lambdas).
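The "match any filter" rule shared by filter_names and filter_dirs can be sketched as follows: strings act as substring tests, callables as predicates, and a name is kept if at least one filter accepts it. match_any is a hypothetical helper operating on plain names.

```python
import re
from typing import Callable, Iterable, List, Union

NameFilter = Union[str, Callable[[str], bool]]


def match_any(names: Iterable[str], filters: List[NameFilter]) -> List[str]:
    """Hypothetical sketch: keep names accepted by at least one filter."""
    def accepts(flt: NameFilter, name: str) -> bool:
        # Strings are substring tests; anything else is called as a predicate.
        return flt in name if isinstance(flt, str) else bool(flt(name))

    return [n for n in names if any(accepts(f, n) for f in filters)]


# A compiled regex's .search method works directly as a callable filter.
pattern = re.compile(r"_t\d+$")
```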

list_existing() List[str][source]

List names of existing directories.

list_existing_dirs() List[Directories][source]

List existing Directories objects.

list_all() List[str][source]

List all directory names provided.

list_all_dirs() List[Directories][source]

List all Directories objects provided.

list_machines() List[str][source]

List all machines.

on(machine: str, only_existing: bool = True) Dict[str, Directories][source]

Get directories for a specific machine.

Parameters:
  • machine (str) – The machine name to filter directories.

  • only_existing (bool, optional) – If True, only return existing directories. Default is True.

Returns:

A dictionary of directory names to Directories objects.

Return type:

Dict[str, Directories]

Raises:

KeyError – If the machine is not known.
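The per-machine lookup can be pictured as grouping entries by machine and then indexing that grouping, which naturally raises KeyError for an unknown machine. dirs_on is a hypothetical sketch; it treats a bare string value as a path on a machine called "default", an assumption borrowed from the default of add().

```python
from pathlib import Path
from typing import Dict, Tuple, Union

Entry = Union[str, Tuple[str, str]]


def dirs_on(entries: Dict[str, Entry], machine: str,
            only_existing: bool = True) -> Dict[str, Path]:
    """Hypothetical sketch of a per-machine lookup; KeyError for unknown machines."""
    by_machine: Dict[str, Dict[str, Path]] = {}
    for name, value in entries.items():
        # A bare string is a path on the assumed machine "default".
        path, host = (value, "default") if isinstance(value, str) else value
        by_machine.setdefault(host, {})[name] = Path(path)
    selected = by_machine[machine]  # raises KeyError if the machine is unknown
    if only_existing:
        selected = {n: p for n, p in selected.items() if p.exists()}
    return selected
```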

register_machine(machine: str)[source]

Register machine as a known machine if it is not already registered. Calling this is optional and mainly serves clarity.

__add__(other: DirectoriesData) DirectoriesData[source]

Return a new DirectoriesData with merged contents.

__iadd__(other: DirectoriesData) DirectoriesData[source]

In-place merge of other into self.

__radd__(other: DirectoriesData) DirectoriesData[source]

Allow sum([…]) to work by reusing __add__.
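sum() starts from the integer 0, so the first addition it performs is 0 + item, which Python routes to item.__radd__(0). The hypothetical Registry class below shows that pattern for a mergeable container; letting right-hand entries win on duplicate names is an assumption, since the merge rule for conflicts is not documented here.

```python
from typing import Dict


class Registry:
    """Hypothetical mergeable container illustrating __add__/__iadd__/__radd__."""

    def __init__(self, **entries: str) -> None:
        self.entries: Dict[str, str] = dict(entries)

    def __add__(self, other: "Registry") -> "Registry":
        merged = Registry(**self.entries)
        merged.entries.update(other.entries)  # assumption: right-hand entries win
        return merged

    def __iadd__(self, other: "Registry") -> "Registry":
        self.entries.update(other.entries)
        return self

    def __radd__(self, other: object) -> "Registry":
        # sum() computes 0 + first_item, so treat 0 as the empty registry.
        if other == 0:
            return Registry(**self.entries)
        return NotImplemented
```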

Plotting Utilities for Scientific Publications

This module provides a comprehensive set of tools for creating publication-quality plots using Matplotlib. It is designed for scientific computing workflows with support for Nature/Science journal formatting.

Features

  • Publication Styles: Pre-configured styles for Nature, Science, and PRL journals

  • Color Management: Colorblind-safe palettes (Tableau, Plastic, Pastel)

  • Scientific Formatters: LaTeX-style axis labels with scientific notation

  • Flexible Colorbars: Full control over position, scale (log/linear), and discretization

  • Grid Layouts: GridSpec-based subplot management with insets

  • Data Filtering: Parameter-based filtering for multi-experiment datasets

Quick Start

>>> from general_python.common.plot import Plotter
>>> fig, axes = Plotter.get_subplots(nrows=2, ncols=2, sizex=8, sizey=6)
>>> Plotter.plot(axes[0], x, y, color='C0', label='Data')
>>> Plotter.semilogy(axes[1], x, y_log, color='C1')
>>> Plotter.set_ax_params(axes[0], xlabel='Time (s)', ylabel='Signal', title='Panel A')
>>> Plotter.set_legend(axes[0], style='publication')
>>> Plotter.save_fig('.', 'figure', format='pdf', dpi=300)

Examples

Creating a figure with error bars:

fig, ax = Plotter.get_subplots(1, 1, sizex=4, sizey=3)
Plotter.errorbar(ax[0], x, y, yerr=sigma, color='C0', label='Measurement')
Plotter.set_ax_params(ax[0], xlabel=r'$x$', ylabel=r'$f(x)$', yscale='log')

Adding a colorbar:

cbar, cax = Plotter.add_colorbar(
    fig, [0.92, 0.15, 0.02, 0.7], data,
    cmap='viridis', scale='log', label=r'$|\psi|^2$'
)

For more examples, call: Plotter.help()

Email : maxgrom97@gmail.com

general_python.common.plot.configure_style(style: str = 'publication', font_size: int = 10, use_latex: bool = False, dpi: int = 150, **overrides)[source]

Configure matplotlib rcParams for publication-quality figures.

This function sets up consistent styling across all plots. Call it once at the start of your script or notebook.

Parameters:
  • style (str, default='publication') – Style preset to apply: 'publication' (compact Nature/Science-like defaults), 'presentation' (larger text and strokes for talks), 'poster' (very large sizes for posters), 'minimal' (stripped-down axis visuals), or 'default' (reset to Matplotlib defaults).

  • font_size (int, default=10) – Base typographic scale. Label/tick/title/legend sizes are derived from this value in each preset.

  • use_latex (bool, default=False) – If True, tries LaTeX-backed rendering via scienceplots profile. If unavailable, falls back gracefully to non-LaTeX settings.

  • dpi (int, default=150) – Screen/display DPI used in interactive rendering.

  • **overrides (dict) – Additional rcParams overrides. Underscores are accepted and converted to dots, e.g. axes_linewidth=1.0 -> axes.linewidth. Typical high-impact keys include 'savefig.dpi', 'axes.prop_cycle', 'figure.constrained_layout.use', and 'font.family'.
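Converting underscores to dots needs care, because some rcParams names legitimately contain underscores (for example figure.constrained_layout.use), so replacing every underscore would produce an invalid key. The hypothetical normalize_overrides below tries each dot/underscore spelling against a set of known keys and falls back to replacing every underscore. It is not necessarily how configure_style resolves keys; with Matplotlib installed, matplotlib.rcParams.keys() could serve as the known set, while the test uses a small literal set to stay dependency-free.

```python
from itertools import product
from typing import Any, Dict, Iterable


def normalize_overrides(overrides: Dict[str, Any],
                        known: Iterable[str]) -> Dict[str, Any]:
    """Hypothetical sketch: map underscore keys onto known dotted rcParams names."""
    known_set = set(known)
    out: Dict[str, Any] = {}
    for key, value in overrides.items():
        if "." in key or key in known_set:
            out[key] = value  # already a dotted (or plain, known) name
            continue
        tokens = key.split("_")
        resolved = key.replace("_", ".")  # fallback: every "_" becomes "."
        # Try every way of joining the tokens with "." or "_"; keep a known match.
        for seps in product((".", "_"), repeat=len(tokens) - 1):
            candidate = tokens[0] + "".join(s + t for s, t in zip(seps, tokens[1:]))
            if candidate in known_set:
                resolved = candidate
                break
        out[resolved] = value
    return out
```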

Examples

>>> # Standard publication setup
>>> configure_style('publication', font_size=10)
>>> # Presentation with larger fonts
>>> configure_style('presentation', font_size=14, dpi=100)
>>> # Custom overrides
>>> configure_style('publication', **{'axes.linewidth': 1.5, 'lines.linewidth': 2})

Notes

This function modifies global matplotlib rcParams. To reset to defaults, use configure_style(‘default’) or mpl.rcParams.update(mpl.rcParamsDefault).

Recommended figure sizes for journals:

  • Nature: single column = 3.5 in, double column = 7 in

  • Science: single column = 3.5 in, double column = 7.25 in

  • PRL: single column = 3.4 in, double column = 7 in
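A column width becomes a Matplotlib figsize once a height is chosen. The hypothetical journal_figsize helper below uses the widths listed above and derives the height from an aspect ratio; the golden-ratio default is an assumption, not a journal requirement.

```python
import math
from typing import Tuple

# Column widths in inches, as listed in the notes above.
COLUMN_WIDTH_IN = {
    ("nature", "single"): 3.5, ("nature", "double"): 7.0,
    ("science", "single"): 3.5, ("science", "double"): 7.25,
    ("prl", "single"): 3.4, ("prl", "double"): 7.0,
}

GOLDEN = (1 + math.sqrt(5)) / 2  # assumed default width/height ratio


def journal_figsize(journal: str, column: str = "single",
                    aspect: float = GOLDEN) -> Tuple[float, float]:
    """Hypothetical helper: (width, height) in inches for figsize=..."""
    width = COLUMN_WIDTH_IN[(journal.lower(), column)]
    return width, width / aspect
```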

general_python.common.plot.get_rcparams_summary() dict[source]

Get a summary of current rcParams relevant to plotting.

Returns:

Dictionary with current settings for fonts, lines, axes, etc.

Return type:

dict

Examples

>>> params = get_rcparams_summary()
>>> print(params['font.size'])
general_python.common.plot.reset_color_cycles(which=None)[source]

Reset the color cycles to the default ones.

Parameters:
  • which – Which color cycle to reset.

Returns:

The color cycle to use.

general_python.common.plot.get_color_cycle(which=None)[source]

Get the color cycle to use.

Parameters:
  • which – Which color cycle to use.

Returns:

The selected color cycle.

general_python.common.plot.markerNorm(x)
general_python.common.plot.linestyleNorm(x)
general_python.common.plot.reset_linestyles(which=None)[source]

Reset the line styles to the default ones.

Parameters:
  • which – Which line style cycle to reset.

Returns:

The line style cycle to use.

general_python.common.plot.get_linestyle_cycle(which=None)[source]

Get the line style cycle to use.

Parameters:
  • which – Which line style cycle to use.

Returns:

The selected line style cycle.

class general_python.common.plot.CustomFormatter(fmt='{x:.2f}')[source]

Bases: Formatter

Matplotlib formatter backed by a Python format string.

__init__(fmt='{x:.2f}')[source]

Initialize the object with a format string.

Args: fmt (str): The format string to be used for formatting.

class general_python.common.plot.PercentFormatter(decimals=2, symbol='%')[source]

Bases: PercentFormatter

Percent formatter with concise defaults for publication plots.

__init__(decimals=2, symbol='%')[source]

Initialize the formatter with a number of decimal places and a percent symbol.

Args:
  • decimals (int) – The number of decimal places to use.

  • symbol (str) – The symbol to use for percentage.

class general_python.common.plot.MathTextSciFormatter(fmt='%1.2e')[source]

Bases: Formatter

Scientific-notation formatter that renders exponents as math text.

__init__(fmt='%1.2e')[source]

Initialize the object with a format string.

Args: fmt (str): The format string to be used for formatting.
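All three formatter classes ultimately map a tick value to a label string; Matplotlib calls a Formatter as formatter(x, pos). The plain functions below are hypothetical sketches of that string logic without Matplotlib. Two points are assumptions: percent_label treats the value as already expressed in percent, and mathtext_sci_label's exact output format may differ from MathTextSciFormatter.

```python
def custom_label(x: float, fmt: str = "{x:.2f}") -> str:
    # str.format-style template with the tick value bound to the name "x".
    return fmt.format(x=x)


def percent_label(x: float, decimals: int = 2, symbol: str = "%") -> str:
    # Assumes x is already in percent, e.g. 12.5 -> "12.50%".
    return f"{x:.{decimals}f}{symbol}"


def mathtext_sci_label(x: float, fmt: str = "%1.2e") -> str:
    # Turn printf output like "1.23e+04" into mathtext "$1.23 \times 10^{4}$".
    text = fmt % x
    if "e" not in text:
        return "$" + text + "$"
    mantissa, exponent = text.split("e")
    return "$" + mantissa + r" \times 10^{" + str(int(exponent)) + "}$"
```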

general_python.common.plot.set_formatter(ax, formatter_type='sci', fmt='%1.2e', axis='xy')[source]

Sets the formatter for the given axis on the plot.

fmt can take the value:
  • “%1.2e”: Scientific notation with 2 decimal places.

  • “%1.2f”: Fixed notation with 2 decimal places.

  • “%1.0f”: Fixed notation with 0 decimal places.

  • “%1.0e”: Scientific notation with 0 decimal places.

  • “%1.0%”: Percentage notation with 0 decimal places.

  • “%1.2%”: Percentage notation with 2 decimal places.

  • “%1.0g”: General notation with 1 significant digit (a precision of 0 is treated as 1).

  • “%1.2g”: General notation with 2 significant digits.

Integer and other conversions:
  • “%d”: Integer notation.

  • “%i”: Integer notation.

  • “%u”: Unsigned integer notation (identical to “%d” in Python).

  • “%o”: Octal notation.

  • “%x”: Hexadecimal notation (lowercase).

  • “%X”: Hexadecimal notation (uppercase).

  • “%c”: Character notation.

  • “%r”: Repr notation.

Parameters:
  • ax (object) – The axis object on which to set the formatter.

  • formatter_type (str) – The type of formatter to use. Options are “sci”, “custom”, “percent”.

  • fmt (str, optional) – The format string for the axis labels. Defaults to “%1.2e”.

  • axis (str, optional) – The axis on which to set the formatter. Defaults to ‘xy’.

Additional formatter options:
  • “sci”: Scientific notation formatter.

  • “custom”: Custom formatter using a provided format string.

  • “percent”: Percentage formatter.

Returns:

None
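The fmt strings listed for set_formatter follow Python's printf-style (%) formatting, so their effect on sample values can be checked directly in plain Python. Percentages are shown here with the str.format "%" presentation type, which multiplies by 100; how set_formatter interprets the “%1.0%” spellings internally is not shown here and may differ.

```python
# Printf-style format strings from the list above, applied to sample values.
samples = {
    "%1.2e": "%1.2e" % 12345.678,    # scientific, 2 decimals
    "%1.2f": "%1.2f" % 3.14159,      # fixed, 2 decimals
    "%1.0f": "%1.0f" % 2.7,          # fixed, 0 decimals (rounded)
    "%1.2g": "%1.2g" % 0.000123456,  # general, 2 significant digits
    "%d": "%d" % 42.9,               # integer (float truncated toward zero)
    "%o": "%o" % 8,                  # octal
    "%x": "%x" % 255,                # lowercase hexadecimal
    "%X": "%X" % 255,                # uppercase hexadecimal
}

# Percentage via the str.format "%" type: the value is multiplied by 100.
percent = "{:.1%}".format(0.256)
```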

class general_python.common.plot.IgnoredAxis(*, index: int | None = None, row_col: Tuple[int, int] | None = None, names: List[str] | None = None, warn: bool = False, reason: str | None = None)[source]

Bases: object

No-op stand-in for a disabled axis.

Any call or chained attribute access is ignored by default. Optionally, warnings can be emitted to make ignored operations explicit.
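A no-op placeholder like this can be built from two hooks: __getattr__ (which runs only for attributes not found normally) returns the object itself, and __call__ does the same, so chains such as ax.xaxis.set_visible(False) silently succeed. NoOpAxis is a hypothetical sketch, not IgnoredAxis itself; it emits one warning per attribute access when warn=True.

```python
import warnings
from typing import Any, Optional


class NoOpAxis:
    """Hypothetical no-op placeholder: every call or attribute returns itself."""

    def __init__(self, warn: bool = False, reason: Optional[str] = None) -> None:
        self._warn = warn
        self._reason = reason

    def _notify(self, what: str) -> None:
        if self._warn:
            note = f" ({self._reason})" if self._reason else ""
            warnings.warn(f"ignored operation on disabled axis: {what}{note}",
                          stacklevel=3)

    def __getattr__(self, name: str) -> "NoOpAxis":
        # Only reached for missing attributes, e.g. ax.set_xlabel or ax.xaxis.
        self._notify(name)
        return self

    def __call__(self, *args: Any, **kwargs: Any) -> "NoOpAxis":
        return self
```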

__init__(*, index: int | None = None, row_col: Tuple[int, int] | None = None, names: List[str] | None = None, warn: bool = False, reason: str | None = None)[source]
property is_disabled: bool

Return True for compatibility with regular axes checks.

set_warn(enabled: bool = True)[source]

Enable or disable warnings for ignored axis operations.

description() str[source]

Return a compact description of the disabled axis target.

class general_python.common.plot.AxesList(axes, nrows: int | None = None, ncols: int | None = None, panel_map: Dict[str, Any] | None = None)[source]

Bases: list

List-like container for subplot axes with optional grid-aware helpers.

Behaviors:
  • Inherits from list (all list operations remain available).

  • Supports 2D indexing when grid metadata is available: axes[row, col].

  • Forwards unknown attribute access to the first axis, enabling single-axis-like usage in quick scripts.
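2D indexing on a flat list needs a layout convention. Assuming row-major order (the order NumPy's .flat yields the axes array returned by plt.subplots), position (row, col) maps to flat index row * ncols + col. GridList below is a hypothetical sketch that uses strings as stand-ins for axes.

```python
from typing import Any, List, Optional


class GridList(list):
    """Hypothetical list with (row, col) indexing and attribute forwarding."""

    def __init__(self, items: List[Any], ncols: Optional[int] = None) -> None:
        super().__init__(items)
        self.ncols = ncols

    def __getitem__(self, key):
        if isinstance(key, tuple):
            if self.ncols is None:
                raise TypeError("2D indexing needs grid metadata (ncols)")
            row, col = key
            # Row-major layout: row r starts at flat index r * ncols.
            return super().__getitem__(row * self.ncols + col)
        return super().__getitem__(key)  # ints and slices behave like a list

    def __getattr__(self, name: str) -> Any:
        # Forward unknown attributes to the first element.
        if not self:
            raise AttributeError(name)
        return getattr(super().__getitem__(0), name)
```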

__init__(axes, nrows: int | None = None, ncols: int | None = None, panel_map: Dict[str, Any] | None = None)[source]

Initialize the AxesList.

property shape: Tuple[int, int] | None

Grid shape (nrows, ncols) when known, otherwise None.

first()[source]

Return the first axis or raise if the container is empty.

property panel_names: List[str]

Names of registered semantic panels.

has_panel(name: str) bool[source]

Return whether a semantic panel name is registered.

panel(name: str)[source]

Return the axis or nested axes registered for name.

panels() Dict[str, Any][source]

Return a copy of the semantic panel map.

rename_panel(old: str, new: str)[source]

Rename a semantic panel while preserving its mapped axes.

select(*names: str)[source]

Return an AxesList containing the named panels.

as_grid()[source]

Return axes as a rectangular object array using stored grid shape.

at(row: int, col: int)[source]

Return the axis at grid location (row, col).

span(rows, cols)[source]

Return axes in a rectangular grid window.

Examples:

>>> axes.span(slice(0, 2), slice(1, 3))
>>> axes.span(0, slice(None))

row(row: int)[source]

Return one grid row as an AxesList.

col(col: int)[source]

Return one grid column as an AxesList.
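Under the same row-major assumption, a rectangular window is the Cartesian product of a row range and a column range; row(r) and col(c) are the special cases span(r, slice(None)) and span(slice(None), c). The span function below is a hypothetical sketch over a flat list, not AxesList.span itself.

```python
from typing import Any, List, Sequence, Union

Index = Union[int, slice]


def _as_range(sel: Index, size: int) -> range:
    # An int selects one row/column (negatives count from the end);
    # a slice is resolved against the grid size.
    if isinstance(sel, int):
        sel = sel + size if sel < 0 else sel
        return range(sel, sel + 1)
    return range(*sel.indices(size))


def span(flat: Sequence[Any], nrows: int, ncols: int,
         rows: Index, cols: Index) -> List[Any]:
    """Hypothetical sketch of a rectangular window over a row-major flat list."""
    return [flat[r * ncols + c]
            for r in _as_range(rows, nrows)
            for c in _as_range(cols, ncols)]
```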

set_title(title, **kwargs)[source]

Set the title for all axes.

apply(fn, *args, **kwargs)[source]

Apply fn to each axis and return self.

disable(target, *, warn: bool = False, hide: bool = True, reason: str | None = None)[source]

Disable one or more axes by replacing them with no-op placeholders.

Parameters:
  • target (int | tuple[int, int] | str | Axes | list-like) – Axis selector. Supported forms: a flat index (int); a grid index (row, col); a panel name (str); an axis instance already contained in this AxesList; or a list/tuple/ndarray of the above (disabled recursively).

  • warn (bool, default=False) – If True, any later operation on the disabled axis emits a warning.

  • hide (bool, default=True) – If True and target is a real matplotlib axis, call set_axis_off before disabling.

  • reason (str, optional) – Optional note included in warning messages.

Returns:

Returns self for chaining.

Return type:

AxesList

collapse(*, redraw: bool = True)[source]

Collapse rows that are fully disabled and reflow remaining axes.

Behavior

  • A row is removed only if all entries in that row are disabled.

  • Partially disabled rows are kept intact, so grid holes remain.

  • Recomputes subplot positions for kept rows to remove vertical whitespace.

Returns:

Returns self for chaining.

Return type:

AxesList
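The row-selection rule of collapse can be sketched on a boolean "disabled" mask: a row is dropped only when every entry is True. The second helper illustrates one possible way to reflow the kept rows into equal-height bands in figure coordinates; the bottom/top/gap values and equal heights are assumptions, not the library's actual positioning.

```python
from typing import List, Sequence, Tuple


def rows_to_keep(disabled: Sequence[Sequence[bool]]) -> List[int]:
    """Indices of rows that still contain at least one enabled axis."""
    return [r for r, row in enumerate(disabled) if not all(row)]


def row_bands(n_rows: int, bottom: float = 0.1, top: float = 0.9,
              gap: float = 0.05) -> List[Tuple[float, float]]:
    """Assumed layout: (y0, height) bands from top to bottom, n_rows >= 1."""
    height = (top - bottom - gap * (n_rows - 1)) / n_rows
    return [(top - (i + 1) * height - i * gap, height) for i in range(n_rows)]
```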

adjust(same: str = 'xy', *, hide: str = 'both', keep_x: str = 'bottom', keep_y: str = 'left', xlabel: str | None = None, ylabel: str | None = None, xlabel_kwargs: dict | None = None, ylabel_kwargs: dict | None = None, x_label_position: str | None = None, y_label_position: str | None = None, x_label_coords: Tuple[float, float] | None = None, y_label_coords: Tuple[float, float] | None = None, x_label_coords_system: str = 'axes', y_label_coords_system: str = 'axes', x_tick_params: dict | None = None, y_tick_params: dict | None = None, interior_x_tick_params: dict | None = None, interior_y_tick_params: dict | None = None)[source]

Remove duplicated axis labels/ticklabels for multi-panel layouts.

Parameters:
  • same ({'x', 'y', 'xy'}, default='xy') – Which directions should be de-duplicated.

  • hide ({'both', 'labels', 'ticklabels'}, default='both') – What to hide on interior panels.

  • keep_x ({'bottom', 'top', 'all', 'vbottom', 'vtop'}, default='bottom') – Which row keeps x-axis labels/ticklabels. vbottom/vtop additionally force visible labels/ticks to be drawn at the corresponding side (not only by grid-edge selection).

  • keep_y ({'left', 'right', 'all', 'vleft', 'vright'}, default='left') – Which column keeps y-axis labels/ticklabels. vleft/vright additionally force visible labels/ticks to be drawn at the corresponding side (not only by grid-edge selection).

  • xlabel (str, optional) – Label text applied to the kept outer x-axes.

  • ylabel (str, optional) – Label text applied to the kept outer y-axes.

  • xlabel_kwargs (dict, optional) – Forwarded to ax.set_xlabel on kept axes.

  • ylabel_kwargs (dict, optional) – Forwarded to ax.set_ylabel on kept axes.

  • x_label_position (str, optional) – Manual x-label side: 'top' or 'bottom'.

  • y_label_position (str, optional) – Manual y-label side: 'left' or 'right'.

  • x_label_coords (tuple(float, float), optional) – Manual coordinates for the x-label.

  • y_label_coords (tuple(float, float), optional) – Manual coordinates for the y-label.

  • x_label_coords_system ({'axes', 'data'}) – Coordinate system used for x_label_coords.

  • y_label_coords_system ({'axes', 'data'}) – Coordinate system used for y_label_coords.

  • x_tick_params (dict, optional) – Tick styling applied to kept x-axes via ax.tick_params.

  • y_tick_params (dict, optional) – Tick styling applied to kept y-axes via ax.tick_params.

  • interior_x_tick_params (dict, optional) – Tick styling for interior x-axes after de-duplication. Useful to keep ticks but hide labels, adjust lengths, etc.

  • interior_y_tick_params (dict, optional) – Tick styling for interior y-axes after de-duplication. Useful to keep ticks but hide labels, adjust lengths, etc.
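For the basic keep_x / keep_y options, deciding which panels keep their labels is a function of grid position alone. label_owners is a hypothetical sketch that maps each (row, col) to a pair (keeps x labels, keeps y labels); it ignores the v* variants and disabled panels, and panels mapped to (False, False) are the interior ones that would have labels hidden.

```python
from typing import Dict, Tuple


def label_owners(nrows: int, ncols: int, keep_x: str = "bottom",
                 keep_y: str = "left") -> Dict[Tuple[int, int], Tuple[bool, bool]]:
    """Hypothetical sketch: map (row, col) -> (keeps_x_labels, keeps_y_labels)."""
    x_rows = {"bottom": {nrows - 1}, "top": {0}, "all": set(range(nrows))}[keep_x]
    y_cols = {"left": {0}, "right": {ncols - 1}, "all": set(range(ncols))}[keep_y]
    return {(r, c): (r in x_rows, c in y_cols)
            for r in range(nrows) for c in range(ncols)}
```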

__getitem__(key)[source]

Support panel name access and 2D grid indexing.

class general_python.common.plot.Plotter(default_cmap='viridis', font_size=12, dpi=200)[source]

Bases: object

Publication-quality plotting utilities for scientific computing.

This class provides static methods for creating, customizing, and saving Matplotlib figures suitable for scientific journals (Nature, Science, PRL, etc.).

Most methods are @staticmethod, so you can use them without instantiation:

>>> Plotter.plot(ax, x, y, color='C0', label='Data')
>>> Plotter.set_legend(ax, style='publication')

Main Categories

  • Plotting methods: plot, scatter, tripcolor_field, semilogy, semilogx, loglog, errorbar, fill_between, histogram

  • Axis setup: set_ax_params, set_tickparams, setup_log_x, setup_log_y

  • Annotations: set_annotate, set_annotate_letter, set_arrow

  • Colorbars: add_colorbar, get_colormap, discrete_colormap

  • Layouts: get_subplots, get_grid, get_inset

  • Legends: set_legend, set_legend_custom

  • Saving: save_fig, savefig

For full documentation, call: Plotter.help()

__init__(default_cmap='viridis', font_size=12, dpi=200)[source]

Initialize the Plotter with default parameters.

Parameters:
  • default_cmap (str, default='viridis') – Default colormap for heatmaps and colorbars.

  • font_size (int, default=12) – Default font size for labels and text.

  • dpi (int, default=200) – Resolution for rasterized output (PNG, TIFF).

Note

Most methods are @staticmethod and don’t require instantiation. Use the class directly: Plotter.plot(ax, x, y)

ax_off(ax: Axes | List[Axes])[source]

Completely turn off the axis (no ticks, labels, spines, data).

static ax(ax: Axes | List[Axes], *args, **kwargs)[source]

Alias for ax method to allow direct calls.

static disable(axes, target, *, warn: bool = False, hide: bool = True, reason: str | None = None) AxesList[source]

Convenience wrapper for AxesList.disable.

Parameters:
  • axes (AxesList or list-like of axes) – Axes container to operate on.

  • target (selector) – Forwarded to AxesList.disable().

  • warn (bool, default=False) – Emit warnings when disabled axes are used.

  • hide (bool, default=True) – Hide underlying axes before disabling.

  • reason (str, optional) – Optional warning context.

Returns:

The modified axes list.

Return type:

AxesList

static help(topic: str = None)[source]

Print help information about available plotting methods.

Parameters:

topic (str, optional) – Specific topic to get help on. Options: 'plot' (basic plotting methods), 'axis' (axis configuration), 'color' (colors and colorbars), 'layout' (subplots and grids), 'save' (saving figures), or None (print an overview of all topics).

Examples

>>> Plotter.help()           # Overview
>>> Plotter.help('plot')     # Plotting methods
>>> Plotter.help('axis')     # Axis configuration
static plot_style(**kwargs) PlotStyle[source]

Return a PlotStyle config instance.

static kspace_config(**kwargs) KSpaceConfig[source]

Return a KSpaceConfig instance.

static kpath_config(**kwargs) KPathConfig[source]

Return a KPathConfig instance.

static spectral_config(**kwargs) SpectralConfig[source]

Return a SpectralConfig instance.

static figure_config(**kwargs) FigureConfig[source]

Return a FigureConfig instance.

static plotters()[source]

Expose the general_python.common.plotters package.

static statistical_fitter()[source]

Backward-compatible alias for fitter().

static fitter()[source]

Expose shared fitting/scaling helpers from general_python.maths.math_utils.

static math(label: str, *, auto_wrap: bool = True, escape_text: bool = True, **values: Any) str[source]

Build a LaTeX-ready math label from a template.

This method extends Python str.format-style placeholders with simple math filters for scientific labels.

Parameters:
  • label (str) – Template string. Standard format fields are supported, e.g. {J:.3g}, and can be combined with filters, e.g. {point|vec:.2f}, {kpoint|sym}. Greek-name values (for example Gamma or omega) are automatically rendered as LaTeX variables.

  • auto_wrap (bool, default=True) – If True, wrap the final string with $...$ when no dollar sign is present in the rendered output.

  • escape_text (bool, default=True) – If True, plain substituted strings are LaTeX-escaped by default. Use |raw or |tex to bypass escaping for a specific field.

  • **values (Any) – Values used by template fields.

Supported filters:

  • raw / tex – Insert value as-is (no escaping).

  • sym – Force symbol conversion for common Greek names (e.g. Gamma -> \Gamma).

  • num – Numeric formatting helper. Uses the format spec if provided, otherwise .6g.

  • vec – Render iterable as \left(v_1, v_2, ...\right).

  • set – Render iterable as \left\{v_1, v_2, ...\right\}.

Returns:

Rendered LaTeX/mathtext-compatible label.

Return type:

str

Examples

>>> Plotter.math(r"\langle S_i^z \rangle = {value|num:.3e}", value=1.2e-4)
'$\\langle S_i^z \\rangle = 1.200e-04$'
>>> Plotter.math(r"{kx|sym}-{ky|sym} path, q={q|vec:.2f}", kx="Gamma", ky="K", q=[0, 1/3])
'$\\Gamma-K path, q=\\left(0, 0.33\\right)$'
>>> Plotter.math(r"E={expr|raw}", expr=r"E_0 + \Delta")
'$E=E_0 + \\Delta$'
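The `{name|filter:spec}` placeholder syntax can be illustrated with a small renderer. `render_label` is hypothetical and is not the library's implementation. It supports only the `raw`/`tex`, `sym`, `num`, and `vec` filters and knows only a handful of Greek names. It also skips LaTeX escaping (`escape_text`), and a literal LaTeX group whose content looks like an identifier (for example `^{x}`) would be mistaken for a field.

```python
import re

# Illustrative subset of Greek names; a real implementation would know many more.
GREEK = {"alpha", "beta", "gamma", "Gamma", "delta", "Delta", "omega", "Omega"}

# Matches {name}, {name|filter}, {name:spec}, and {name|filter:spec}.
FIELD = re.compile(r"\{([A-Za-z_]\w*)(?:\|(\w+))?(?::([^{}]*))?\}")


def render_label(template: str, **values) -> str:
    """Hypothetical sketch of a filter-aware label renderer."""

    def num(v, spec):
        # Use the format spec if one was given, otherwise fall back to .6g.
        return format(v, spec or ".6g")

    def substitute(match):
        name, filt, spec = match.group(1), match.group(2), match.group(3) or ""
        value = values[name]
        if filt in ("raw", "tex"):
            return str(value)
        if filt == "sym":
            return "\\" + value if value in GREEK else str(value)
        if filt == "num":
            return num(value, spec)
        if filt == "vec":
            return r"\left(" + ", ".join(num(v, spec) for v in value) + r"\right)"
        return format(value, spec)

    out = FIELD.sub(substitute, template)
    # Mirror auto_wrap=True: add $...$ only when the result has no dollar sign.
    return out if "$" in out else f"${out}$"
```

In this sketch, `vec` with `.2f` formats every component, so `[0, 1/3]` renders as `0.00, 0.33`. The library's output for the same input may differ.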
static ensure_list(x)[source]

Return x as a list-like container for axis utilities.

static unify_limits(axes, which='y')[source]

Set all axes to the shared x or y limits.

static resolve_planar_limits(points, *, limits: tuple | list | ndarray | None = None, x_limits: tuple | list | ndarray | None = None, y_limits: tuple | list | ndarray | None = None, xmin: float | None = None, xmax: float | None = None, ymin: float | None = None, ymax: float | None = None, limit_to_pi: bool = False, pad_fraction: float = 0.08) Tuple[tuple, tuple][source]

Resolve visible (xlim, ylim) for planar data.

Parameters:
  • points (array-like) – Planar sample points shaped like (N, 2) or (N, D).

  • limits (sequence, optional) – Explicit bounds. Length 2 means shared (min, max) for both axes. Length 4 means (xmin, xmax, ymin, ymax).

  • x_limits (sequence, optional) – Explicit (xmin, xmax) bounds for the x axis.

  • y_limits (sequence, optional) – Explicit (ymin, ymax) bounds for the y axis.

  • xmin (float, optional) – Scalar axis bound overrides. These take precedence over inferred limits and can refine limits / x_limits / y_limits.

  • xmax (float, optional) – Scalar axis bound overrides. These take precedence over inferred limits and can refine limits / x_limits / y_limits.

  • ymin (float, optional) – Scalar axis bound overrides. These take precedence over inferred limits and can refine limits / x_limits / y_limits.

  • ymax (float, optional) – Scalar axis bound overrides. These take precedence over inferred limits and can refine limits / x_limits / y_limits.

  • limit_to_pi (bool, default=False) – If True and limits is not provided, use [-pi, pi] on both axes.

  • pad_fraction (float, default=0.08) – Relative padding applied when limits are inferred from points.
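The inferred-limits path, with padding relative to the data extent, can be sketched as follows. `padded_limits` is hypothetical and covers only two behaviors: inferring padded limits from the points, and applying the scalar overrides. It ignores `limits`, `x_limits`, `y_limits`, and `limit_to_pi`. Treating a zero-width axis as having span 1 is also an assumption.

```python
import numpy as np


def padded_limits(points, pad_fraction=0.08, *, xmin=None, xmax=None, ymin=None, ymax=None):
    """Hypothetical sketch: infer padded (xlim, ylim) from planar points."""
    xy = np.asarray(points, dtype=float)[:, :2]   # keep only the first two components
    lo, hi = xy.min(axis=0), xy.max(axis=0)
    span = np.where(hi > lo, hi - lo, 1.0)        # avoid zero-width limits for degenerate data
    pad = pad_fraction * span
    x0, x1 = lo[0] - pad[0], hi[0] + pad[0]
    y0, y1 = lo[1] - pad[1], hi[1] + pad[1]
    # Scalar overrides take precedence over inferred values.
    x0 = x0 if xmin is None else xmin
    x1 = x1 if xmax is None else xmax
    y0 = y0 if ymin is None else ymin
    y1 = y1 if ymax is None else ymax
    return (float(x0), float(x1)), (float(y0), float(y1))


xlim, ylim = padded_limits([[0, 0, 9], [10, 5, 9]], pad_fraction=0.1)
```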

static markers()[source]

Markers with common options for line and scatter plots.

static markersC()[source]

Markers cycle with common options for line and scatter plots.

static colors()[source]

Colors with common options for line and scatter plots.

static linestyles()[source]

Linestyles with solid and dashed options.

static linestylesC()[source]

Linestyles cycle with solid and dashed options.

static linestylesCE()[source]

Extended linestyles cycle with more options for dashed lines.

static palette(name: str = 'tableau', n: int | None = None) List[str][source]

Return a named color palette as a list of hex strings.

Parameters:
  • name (str, default='tableau') –

    Palette name. Built-in options:

    • wong – Wong (2011) 8-color colorblind-friendly (CBF) palette (Nature Methods)

    • okabe – Okabe & Ito palette, identical to wong

    • tol – Paul Tol’s muted 10-color qualitative set

    • tol_bright – Paul Tol’s bright 7-color high-contrast set

    • ibm – IBM Carbon 5-color accessible palette

    • colorblind – seaborn colorblind 10-color cycle

    • deep – seaborn deep 10-color perceptual palette

    • muted – seaborn muted toned-down palette

    • tableau – Matplotlib default Tableau-10 cycle

    • classic – Matplotlib pre-2.0 default cycle

    • nature – Nature/BioRxiv warm editorial palette

    • science – Science journal-inspired muted palette

    • pastel – Soft pastel tones for presentations

    • sunset – 9-stop cool→warm diverging gradient

  • n (int, optional) – Return exactly n colors. When n > len(palette), colors are repeated cyclically.

Returns:

Hex color strings.

Return type:

list[str]

Examples

>>> Plotter.palette('wong')           # 8 colorblind-safe colors
>>> Plotter.palette('nature', n=4)    # first 4 of nature palette
>>> Plotter.palette('deep', n=12)     # 12 colors (cycles)
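The cyclic behavior of `n` can be sketched with the standard-library `itertools`. `take_colors` is a hypothetical helper. The four hex values are the commonly published first Okabe-Ito/Wong colors, and the library's stored values may differ.

```python
from itertools import cycle, islice

# First four Okabe-Ito (Wong 2011) colors as commonly published; the library's
# stored hex values may differ.
WONG_SUBSET = ["#000000", "#E69F00", "#56B4E9", "#009E73"]


def take_colors(palette, n=None):
    """Hypothetical sketch of the `n` behavior: return n colors, repeating cyclically."""
    if n is None:
        return list(palette)
    return list(islice(cycle(palette), n))


six = take_colors(WONG_SUBSET, 6)   # the first two colors repeat at the end
```

An unbounded `itertools.cycle(palette)` is what palette_cycle() is documented to return.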
static palette_cycle(name: str = 'tableau') cycle[source]

Return an infinite itertools.cycle over a named palette.

Parameters:

name (str) – Same keys as palette().

Examples

>>> cyc   = Plotter.palette_cycle('wong')
>>> color = next(cyc)
static colorsC(palette: str = 'tableau') cycle[source]

Return a color cycle for a named palette.

Alias for palette_cycle().

Examples

>>> cyc = Plotter.colorsC('wong')
>>> c1  = next(cyc)
static colorsN(n: int, palette: str = 'tableau') List[str][source]

Return exactly n colors from a named palette (cycling if needed).

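Parameters:
  • n (int) – Number of colors to return.

  • palette (str, default='tableau') – Palette name (see palette()).

Examples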

>>> c5 = Plotter.colorsN(5, 'wong')
static set_color_cycle(ax, palette: str | List = 'tableau') None[source]

Set the Matplotlib color cycle on one or more axes.

This makes subsequent ax.plot(...) calls auto-pick colors from the selected palette in order.

Parameters:
  • ax (Axes or list of Axes) – Target axis or axes.

  • palette (str or list of colors, default='tableau') – Named palette string (see palette()) or an explicit list of any Matplotlib-compatible color specs.

Examples

>>> Plotter.set_color_cycle(ax, 'wong')
>>> ax.plot(x1, y1)   # first wong color
>>> ax.plot(x2, y2)   # second wong color
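In plain Matplotlib, a per-axes color cycle is installed with `Axes.set_prop_cycle(color=...)`. The sketch below assumes the helper works this way, but that is not confirmed by the source. `set_color_cycle_sketch` is hypothetical and takes an explicit color list rather than a palette name.

```python
import numpy as np
from matplotlib.figure import Figure


def set_color_cycle_sketch(axes, colors):
    """Hypothetical sketch: install `colors` as the color cycle on one or more axes."""
    # Accept a single Axes, a list of Axes, or an ndarray from fig.subplots().
    axes = [axes] if hasattr(axes, "set_prop_cycle") else list(np.ravel(axes))
    for ax in axes:
        ax.set_prop_cycle(color=list(colors))


fig = Figure()                      # no GUI backend needed
ax = fig.add_subplot()
set_color_cycle_sketch(ax, ["#E69F00", "#56B4E9"])
first, = ax.plot([0, 1], [0, 1])    # first color
second, = ax.plot([0, 1], [1, 0])   # second color
third, = ax.plot([0, 1], [0.5, 0.5])  # the cycle wraps around to the first color
```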
static apply_palette(axes, palette: str = 'tableau') None[source]

Apply a named color palette as the default color cycle for one or more axes. Shorthand for set_color_cycle().

Examples

>>> fig, axes = Plotter.get_subplots(1, 3)
>>> Plotter.apply_palette(axes, 'wong')
static to_rgba(color, alpha: float | None = None) Tuple[float, float, float, float][source]

Convert any Matplotlib-compatible color spec to an (r, g, b, a) tuple.

Parameters:
  • color (color spec) – Named string, hex, RGB/RGBA tuple, 'C0'-style, etc.

  • alpha (float, optional) – Override the alpha channel (0–1).

Return type:

tuple[float, float, float, float]

Examples

>>> Plotter.to_rgba('C0')
>>> Plotter.to_rgba('#E64B35', alpha=0.5)
static to_hex(color, keep_alpha: bool = False) str[source]

Convert any Matplotlib-compatible color spec to a hex string.

Parameters:
  • color (color spec)

  • keep_alpha (bool, default=False) – If True, return an 8-character #RRGGBBAA string.

Examples

>>> Plotter.to_hex('C0')              # '#1f77b4'
>>> Plotter.to_hex((0.2, 0.4, 0.6, 0.8), keep_alpha=True)
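Matplotlib ships converters with the same names and keyword arguments: `matplotlib.colors.to_rgba(c, alpha=None)` and `matplotlib.colors.to_hex(c, keep_alpha=False)`. Whether the Plotter helpers delegate to them is an assumption, but the calls below show the kind of output to expect.

```python
from matplotlib.colors import to_hex, to_rgba

# RGBA floats in [0, 1]; the alpha argument overrides the alpha channel.
rgba = to_rgba("#E64B35", alpha=0.5)

# 8-character #RRGGBBAA string (lowercase hex digits).
hex8 = to_hex((0.2, 0.4, 0.6, 0.8), keep_alpha=True)   # '#336699cc'
```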
static adjust_color(color, *, lighten: float = 0.0, darken: float = 0.0, saturate: float = 0.0, desaturate: float = 0.0, alpha: float | None = None) Tuple[float, float, float, float][source]

Perceptually adjust a color in HLS space.

Each parameter shifts the corresponding channel by a fraction of the remaining headroom, so operations compose gracefully and values are always clamped to [0, 1].

Parameters:
  • color (color spec) – Any Matplotlib-compatible color.

  • lighten (float, default=0.0) – Push lightness toward 1 (white). 0 = no change, 1 = white.

  • darken (float, default=0.0) – Push lightness toward 0 (black). 0 = no change, 1 = black.

  • saturate (float, default=0.0) – Push saturation toward 1. 0 = no change, 1 = fully saturated.

  • desaturate (float, default=0.0) – Push saturation toward 0 (grey). 0 = no change, 1 = grey.

  • alpha (float, optional) – Override alpha channel (0–1).

Returns:

Adjusted RGBA color.

Return type:

tuple[float, float, float, float]

Examples

>>> Plotter.adjust_color('C0', lighten=0.3)
>>> Plotter.adjust_color('#E64B35', darken=0.4)
>>> Plotter.adjust_color('C2', desaturate=0.5, alpha=0.7)
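The "fraction of the remaining headroom" rule can be written as l' = l + lighten·(1 − l) and l' = l − darken·l, and likewise for saturation. The sketch below applies those updates with the standard-library `colorsys` module. Three things are assumptions rather than the library's code: the name `adjust_hls`, the RGB-only input (no alpha), and the order in which the four adjustments are applied.

```python
import colorsys


def adjust_hls(rgb, *, lighten=0.0, darken=0.0, saturate=0.0, desaturate=0.0):
    """Hypothetical sketch: shift lightness/saturation by a fraction of the headroom."""
    h, l, s = colorsys.rgb_to_hls(*rgb)
    l += lighten * (1.0 - l)     # toward white: lighten=1 reaches l = 1
    l -= darken * l              # toward black: darken=1 reaches l = 0
    s += saturate * (1.0 - s)    # toward full saturation
    s -= desaturate * s          # toward grey: desaturate=1 reaches s = 0
    # Clamp to [0, 1] before converting back.
    l = min(max(l, 0.0), 1.0)
    s = min(max(s, 0.0), 1.0)
    return colorsys.hls_to_rgb(h, l, s)


pale_red = adjust_hls((1.0, 0.0, 0.0), lighten=0.5)
```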
static lighten(color, amount: float = 0.3) Tuple[float, float, float, float][source]

Return a lightened version of color (push lightness toward white).

Parameters:
  • color (color spec)

  • amount (float, default=0.3) – 0 = no change, 1 = white.

Examples

>>> fill = Plotter.lighten('C0', 0.5)
>>> Plotter.fill_between(ax, x, y1, y2, color=fill)
static darken(color, amount: float = 0.3) Tuple[float, float, float, float][source]

Return a darkened version of color (push lightness toward black).

Parameters:
  • color (color spec)

  • amount (float, default=0.3) – 0 = no change, 1 = black.

Examples

>>> edge = Plotter.darken('C0', 0.25)
static desaturate(color, amount: float = 0.5) Tuple[float, float, float, float][source]

Return a desaturated (greyed-out) version of color.

Parameters:
  • color (color spec)

  • amount (float, default=0.5) – 0 = original, 1 = fully grey.

Examples

>>> faded = Plotter.desaturate('C1', 0.6)
static with_alpha(color, a: float = 0.5) Tuple[float, float, float, float][source]

Return color with a modified alpha channel.

Parameters:
  • color (color spec)

  • a (float, default=0.5) – New alpha value (0–1).

Examples

>>> Plotter.fill_between(ax, x, y1, y2, color=Plotter.with_alpha('C0', 0.25))
static blend(c1, c2, t: float = 0.5, *, n: int | None = None) Tuple[float, float, float, float] | List[Tuple[float, float, float, float]][source]

Linearly interpolate between two colors in linear RGB space.

Parameters:
  • c1 (color spec) – Start color.

  • c2 (color spec) – End color.

  • t (float, default=0.5) – Blend position: 0 → c1, 1 → c2. Ignored when n is set.

  • n (int, optional) – If provided, return n evenly-spaced colors from c1 to c2 (inclusive of both endpoints).

Returns:

Single RGBA tuple when n is None; list of n tuples otherwise.

Return type:

tuple or list[tuple]

Examples

>>> mid      = Plotter.blend('red', 'blue')
>>> gradient = Plotter.blend('#E64B35', '#4DBBD5', n=7)
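The `t`/`n` semantics can be sketched as component-wise linear interpolation. `blend_rgba` is hypothetical and interpolates the stored RGBA components directly, with no gamma decoding. If the library converts to linear-light RGB before mixing, its mid-tones will differ from this sketch.

```python
def blend_rgba(c1, c2, t=0.5, *, n=None):
    """Hypothetical sketch: component-wise linear interpolation of two RGBA tuples."""

    def mix(u):
        return tuple(a + (b - a) * u for a, b in zip(c1, c2))

    if n is None:
        return mix(t)
    if n == 1:
        return [mix(0.0)]
    # n evenly spaced positions including both endpoints: 0, 1/(n-1), ..., 1
    return [mix(i / (n - 1)) for i in range(n)]


red, blue = (1.0, 0.0, 0.0, 1.0), (0.0, 0.0, 1.0, 1.0)
gradient = blend_rgba(red, blue, n=3)
```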
static n_colors(n: int, cmap: str | Colormap = 'viridis', vmin: float = 0.0, vmax: float = 1.0, *, as_hex: bool = False) List[source]

Sample n evenly-spaced colors from a colormap.

Ideal for encoding a continuous parameter (temperature, time, β …) as line colors when you want a smooth gradient rather than a categorical palette.

Parameters:
  • n (int) – Number of colors to sample.

  • cmap (str or Colormap, default='viridis') – Source colormap.

  • vmin (float, default=0.0) – Lower end of the colormap fraction range to sample from.

  • vmax (float, default=1.0) – Upper end of the colormap fraction range to sample from. Using only a sub-range, e.g. vmin=0.1, vmax=0.9, avoids the near-white ends of sequential maps.

  • as_hex (bool, default=False) – If True, return hex strings instead of RGBA tuples.

Return type:

list[tuple] or list[str]

Examples

>>> colors = Plotter.n_colors(5, 'plasma')
>>> for c, (x, y) in zip(colors, datasets):
...     Plotter.plot(ax, x, y, color=c)
>>> # Avoid extreme ends of the colormap
>>> colors = Plotter.n_colors(8, 'RdBu_r', vmin=0.1, vmax=0.9)
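Evenly spaced sampling from a colormap sub-range comes down to calling the colormap on `np.linspace(vmin, vmax, n)`. `sample_colormap` is a hypothetical sketch, not the helper's code, and it uses the `matplotlib.colormaps` registry (Matplotlib 3.5 and later). Hex output, as with `as_hex=True`, could be produced by passing each tuple to `matplotlib.colors.to_hex`.

```python
import matplotlib
import numpy as np


def sample_colormap(n, cmap="viridis", vmin=0.0, vmax=1.0):
    """Hypothetical sketch: n evenly spaced RGBA samples from a colormap sub-range."""
    cm = matplotlib.colormaps[cmap] if isinstance(cmap, str) else cmap
    positions = np.linspace(vmin, vmax, n)   # fractions in [vmin, vmax]
    return [tuple(rgba) for rgba in cm(positions)]


cols = sample_colormap(5, "plasma", vmin=0.1, vmax=0.9)
```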
static cmap_colors(cmap: str | Colormap, values: ndarray, *, vmin: float | None = None, vmax: float | None = None, norm: Normalize | None = None, scale: str = 'linear') List[Tuple][source]

Map an array of scalar values to RGBA colors via a colormap.

Convenience wrapper around get_colormap() when you only need the list of colors (not the full getcolor / norm / mappable bundle).

Parameters:
  • cmap (str or Colormap)

  • values (array-like) – Scalar values to map.

  • vmin (float, optional) – Lower color limit. Defaults to min(values).

  • vmax (float, optional) – Upper color limit. Defaults to max(values).

  • norm (Normalize, optional) – Explicit normalization. Takes precedence over scale.

  • scale ({'linear', 'log', 'symlog'}, default='linear') – Normalization scale used when norm is not given.

Returns:

RGBA tuples, one per value.

Return type:

list[tuple]

Examples

>>> beta_values = np.linspace(0.1, 2.0, 8)
>>> colors      = Plotter.cmap_colors('plasma', beta_values)
>>> for val, c in zip(beta_values, colors):
...     Plotter.plot(ax, x, data[val], color=c, label=rf'$\beta={val:.1f}$')
static filter_results(results, filters=None, get_params_fun: callable = None, *, tol=1e-09)[source]

Backward-compatible wrapper around plotters.data_loader.filter_results.

static get_figsize(columnwidth, wf=0.5, hf=None, aspect_ratio=None)[source]

Compute a figure size matched to a LaTeX column width.

Parameters:
  • columnwidth (float) – Width of the column in LaTeX. Get this from LaTeX using \showthecolumnwidth.

  • wf (float, default=0.5) – Width fraction in columnwidth units.

  • hf (float, optional) – Height fraction in columnwidth units. If None, it is calculated from aspect_ratio.

  • aspect_ratio (float, optional) – Aspect ratio (height/width). If None, defaults to the golden ratio.

Returns:

[fig_width, fig_height], to be passed to Matplotlib (e.g. as figsize).
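A common recipe for this computation converts TeX points to inches (72.27 pt per inch) and uses 1/φ ≈ 0.618 as the default height/width ratio. Both conventions are assumptions about get_figsize, not confirmed by the source. `figsize_from_columnwidth` is a hypothetical name.

```python
import math

TEX_POINTS_PER_INCH = 72.27  # TeX points (as reported by \showthecolumnwidth) per inch


def figsize_from_columnwidth(columnwidth_pt, wf=0.5, hf=None, aspect_ratio=None):
    """Hypothetical sketch of a LaTeX-matched figure size in inches."""
    width = columnwidth_pt * wf / TEX_POINTS_PER_INCH
    if hf is not None:
        # Height given directly as a fraction of the column width.
        height = columnwidth_pt * hf / TEX_POINTS_PER_INCH
    else:
        # Default to the inverse golden ratio (landscape figure).
        ratio = aspect_ratio if aspect_ratio is not None else (math.sqrt(5) - 1) / 2
        height = width * ratio
    return [width, height]


w, h = figsize_from_columnwidth(246.0, wf=1.0)   # e.g. a typical two-column width
```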

static get_color(color, alpha=None, edgecolor=(0, 0, 0, 1), facecolor=(1, 1, 1, 0))[source]

Get a dictionary with color properties for Matplotlib patches.

Parameters:
  • color (str or tuple) – Color to use; a named color or an RGB tuple.

  • alpha (float, optional) – Transparency level (0 to 1).

  • edgecolor (tuple, default=(0, 0, 0, 1)) – Edge color as an RGB(A) tuple.

  • facecolor (tuple, default=(1, 1, 1, 0)) – Face color as an RGB(A) tuple.

Returns:

Dictionary with color properties.

Return type:

dict

static add_colorbar(fig: Figure, pos: List[float], mappable: ndarray | list | _ScalarMappable, cmap: str | Colormap = 'viridis', norm: Normalize | None = None, vmin: float | None = None, vmax: float | None = None, scale: str = 'linear', orientation: str = 'vertical', label: str = '', label_kwargs: dict = None, title: str = '', title_kwargs: dict = None, ticks: List | ndarray | None = None, ticklabels: List[str] | None = None, tick_location: str = 'auto', tick_params: dict = None, extend: str = None, format: str | Formatter | None = None, discrete: bool | int = False, boundaries: List[float] = None, invert: bool = False, remove_pdf_lines: bool = True, **kwargs) Tuple[Colorbar, Axes][source]

Add a fully customizable colorbar to the figure at a specific position.

Parameters:
  • fig (matplotlib.figure.Figure) – Parent figure onto which the colorbar axis is added.

  • pos (list[float] | tuple[float, float, float, float]) – [left, bottom, width, height] in figure coordinates (0..1).

  • mappable (array-like | matplotlib.cm.ScalarMappable) –

    • If array-like: a new ScalarMappable is built from cmap/norm (and scale, vmin, vmax).

    • If ScalarMappable: it is used directly. vmin/vmax update its clim; norm is taken from it when not provided. Note: in this case discrete/boundaries resampling is not applied.

  • cmap (str | Colormap, default='viridis') – Colormap name or object. If mappable is a ScalarMappable, its cmap is used unless cmap is explicitly different from the default and a new mappable is constructed (array-like path).

  • norm (matplotlib.colors.Normalize, optional) – Normalization to map data to 0-1. Ignored if mappable is ScalarMappable and norm is None (then the mappable’s norm is used).

  • vmin (float, optional) – Lower data limit. When scale='log', a non-positive vmin is clamped internally.

  • vmax (float, optional) – Upper data limit.

  • scale ({'linear', 'log', 'symlog'}, default='linear') – Creates a suitable Normalize when mappable is array-like and norm is None:

    • 'linear' -> Normalize

    • 'log' -> LogNorm (vmin <= 0 clamped to ~1e-10)

    • 'symlog' -> SymLogNorm with linthresh=0.1

  • orientation ({'vertical', 'horizontal'}, default='vertical') – Colorbar orientation.

  • label (str, default='') – Axis label along the long side of the colorbar.

  • label_kwargs (dict, optional) – Passed to ColorbarBase.set_label (e.g., dict(fontsize=…, labelpad=…)).

  • title (str, default='') – Title text set at the end/top of the colorbar. For horizontal bars, the title is placed to the side.

  • title_kwargs (dict, optional) – Text properties for the title (e.g., dict(fontsize=…, pad=…)).

  • ticks (list[float] | np.ndarray, optional) – Explicit major tick locations.

  • ticklabels (list[str], optional) – Custom labels for the ticks (same length as ticks).

  • tick_location ({'auto','left','right','top','bottom'}, default='auto') – Side on which to draw ticks/labels (respects orientation).

  • tick_params (dict, optional) – Passed to cbar.ax.tick_params (e.g., dict(length=4, width=1, direction='in')).

  • extend ({'neither', 'both', 'min', 'max', 'neutral'}, optional) – Colorbar extension behavior. Standard Matplotlib values are 'neither', 'both', 'min', 'max'. 'neutral' is treated as a pass-through here and may behave like 'neither' depending on Matplotlib.

  • format (str | matplotlib.ticker.Formatter, optional) – Tick formatting. If str (e.g., ‘%.2e’), uses FormatStrFormatter.

  • discrete (bool | int, default=False) – Discretize the colormap when building from array-like data:

    • True -> 10 bins

    • int N -> N bins

    Ignored when mappable is a ScalarMappable.

  • boundaries (list[float], optional) – Discrete bin edges. Enables BoundaryNorm and passes boundaries to fig.colorbar (default spacing=’proportional’, overridable via kwargs[‘spacing’]).

  • invert (bool, default=False) – If True, invert the colorbar axis direction.

  • remove_pdf_lines (bool, default=True) – Set solids edgecolor to ‘face’ to avoid white hairlines in vector exports (PDF/SVG).

  • **kwargs – Additional arguments forwarded to fig.colorbar, e.g. alpha, spacing ('uniform' | 'proportional'), fraction, pad, shrink, aspect, drawedges.

Returns:

(cbar, cax) – The created colorbar and its axes.

Return type:

tuple[matplotlib.colorbar.Colorbar, matplotlib.axes.Axes]

Notes

  • When mappable is a ScalarMappable, this helper does not modify its colormap discretization. To use discrete/boundaries, pass raw data (array-like) instead.

  • For 'log' scale, ensure your data are strictly positive (this function clamps vmin if needed).

Examples

>>> # Vertical, linear scale from raw data
>>> cbar, cax = Plotter.add_colorbar(fig, [0.92, 0.15, 0.02, 0.7], data, label='Mz')
>>> # Horizontal, log scale with sci formatting and extensions
>>> cbar, cax = Plotter.add_colorbar(
...     fig, [0.2, 0.9, 0.6, 0.03], data, scale='log', orientation='horizontal',
...     format='%.0e', extend='both', tick_location='top', label='Conductance'
... )
>>> # Discrete categorical-like bar with custom tick labels
>>> cbar, cax = Plotter.add_colorbar(
...     fig, [0.85, 0.1, 0.03, 0.8], [0, 1, 2], cmap='Set1', discrete=3,
...     ticklabels=['Insulator', 'Metal', 'SC']
... )
>>> # Non-uniform boundaries
>>> cbar, cax = Plotter.add_colorbar(
...     fig, [0.86, 0.15, 0.02, 0.7], data, boundaries=[0, 0.5, 2.0, 10.0], spacing='proportional'
... )
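The array-like path of add_colorbar plausibly corresponds to the Matplotlib calls below. This is an assumption about the approach, not the helper's code. `colorbar_at` is hypothetical and supports only linear and log scales. It derives limits from the data (for log, from the positive values only) and skips discrete bins, ticks, and title handling.

```python
import numpy as np
from matplotlib.cm import ScalarMappable
from matplotlib.colors import LogNorm, Normalize
from matplotlib.figure import Figure


def colorbar_at(fig, pos, data, cmap="viridis", scale="linear", label=""):
    """Hypothetical sketch: colorbar from raw data on an axis at `pos` (figure coords)."""
    data = np.asarray(data, dtype=float)
    if scale == "log":
        positive = data[data > 0]                   # log scale needs strictly positive limits
        norm = LogNorm(vmin=positive.min(), vmax=positive.max())
    else:
        norm = Normalize(vmin=data.min(), vmax=data.max())
    cax = fig.add_axes(pos)                        # [left, bottom, width, height]
    mappable = ScalarMappable(norm=norm, cmap=cmap)
    mappable.set_array(data)
    cbar = fig.colorbar(mappable, cax=cax)
    cbar.set_label(label)
    cbar.solids.set_edgecolor("face")              # avoid hairlines in PDF/SVG exports
    return cbar, cax


fig = Figure()
cbar, cax = colorbar_at(fig, [0.92, 0.15, 0.02, 0.7], [1.0, 10.0, 100.0], scale="log", label="Mz")
```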

static get_colormap(values: ndarray | None = None, vmin=None, vmax=None, *, cmap='PuBu', elsecolor='blue', get_mappable: bool = False, return_mappable: bool | None = None, norm=None, scale='linear', **kwargs)[source]

Get a colormap for the given values.

Parameters:
  • values (array-like) – The values to map to colors.

  • cmap (str, optional) – The colormap to use. Defaults to 'PuBu'.

  • elsecolor (str, optional) – The color to use if there is only one value. Defaults to 'blue'.

  • get_mappable (bool, optional) – If True, also return a ScalarMappable as the 4th item, ready to pass into Plotter.add_colorbar(..., mappable=...).

  • return_mappable (bool, optional) – Alias for get_mappable.

Returns:

  • getcolor (function) – A function that maps a value to a color.

  • colors (Colormap) – The colormap object.

  • norm (Normalize) – The normalization object.

  • mappable (ScalarMappable, optional) – Returned when get_mappable=True (or return_mappable=True).

Examples

>>> getcolor, colors, norm = Plotter.get_colormap([1, 2, 3], cmap='viridis')
>>> color = getcolor(2.5)
>>> getcolor, colors, norm, mappable = Plotter.get_colormap(
...     [1, 2, 3], cmap='viridis', return_mappable=True
... )
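A `(getcolor, colors, norm)` bundle can be built from a Normalize and a colormap lookup. `make_getcolor` is a hypothetical sketch. It supports linear and log scales only and does not implement the single-value `elsecolor` fallback or the optional mappable.

```python
import matplotlib
import numpy as np
from matplotlib.colors import LogNorm, Normalize


def make_getcolor(values, cmap="PuBu", scale="linear"):
    """Hypothetical sketch of a (getcolor, colormap, norm) bundle."""
    values = np.asarray(values, dtype=float)
    norm_cls = LogNorm if scale == "log" else Normalize
    norm = norm_cls(vmin=values.min(), vmax=values.max())
    colors = matplotlib.colormaps[cmap]

    def getcolor(v):
        # Normalize the value to [0, 1], then look up its RGBA color.
        return colors(norm(v))

    return getcolor, colors, norm


getcolor, colors, norm = make_getcolor([1, 2, 3], cmap="viridis")
color = getcolor(2.5)
```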

static apply_colormap(ax, data, cmap='PuBu', colorbar=True, **kwargs)[source]

Apply a colormap to the given data and plot it on the provided axis.

Parameters:
  • ax (object) – The axis object to plot on.

  • data (array-like) – The data to plot.

  • cmap (str, optional) – The colormap to use. Defaults to 'PuBu'.

  • colorbar (bool, optional) – Whether to add a colorbar. Defaults to True.

Returns:

img (AxesImage) – The image object.

static discrete_colormap(N, base_cmap=None)[source]

Create an N-bin discrete colormap from the specified input map.

Parameters:
  • N (int) – Number of discrete colors.

  • base_cmap (str or Colormap, optional) – The base colormap to use. Defaults to None.

Returns:

cmap (Colormap) – The discrete colormap.
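One common way to build an N-bin colormap is to sample the base map at N evenly spaced positions and wrap the samples in a `ListedColormap`. Whether discrete_colormap follows this recipe is an assumption. `discrete_cmap` is hypothetical, and its default base map of 'viridis' is illustrative.

```python
import matplotlib
import numpy as np
from matplotlib.colors import ListedColormap


def discrete_cmap(n, base_cmap="viridis"):
    """Hypothetical sketch: an n-bin colormap sampled evenly from a base colormap."""
    base = matplotlib.colormaps[base_cmap] if isinstance(base_cmap, str) else base_cmap
    samples = base(np.linspace(0.0, 1.0, n))      # (n, 4) RGBA array
    return ListedColormap(samples, name=f"{base.name}_{n}")


cm3 = discrete_cmap(3, "viridis")
```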

static set_annotate(ax, elem: str, x: float = 0, y: float = 0, fontsize=None, xycoords='axes fraction', cond=True, zorder=50, boxaround=True, **kwargs)[source]

Make an annotation on the plot.

Parameters:
  • ax – Axis to annotate on.

  • elem (str) – Annotation string.

  • x (float) – x coordinate (ignored if xycoords='best').

  • y (float) – y coordinate (ignored if xycoords='best').

  • fontsize – Font size of the annotation.

  • xycoords (str) – How to interpret the coordinates (from Matplotlib), or 'best' to find the best corner.

  • cond (bool) – Condition to make the annotation.

static set_annotate_letter(ax: Axes, iter: int, x: float = 0, y: float = 0, fontsize=12, xycoords='axes fraction', addit='', condition=True, zorder=50, boxaround=False, fontweight='normal', color='black', **kwargs)[source]

Annotate plot with the letter.

Parameters:
  • ax (matplotlib.axes.Axes) – Axis to annotate on.

  • iter (int) – Iteration number used to pick the letter.

  • x (float) – x coordinate.

  • y (float) – y coordinate.

  • fontsize – Font size.

  • xycoords (str) – How to interpret the coordinates (from Matplotlib).

  • addit (str) – Additional string to add after the letter.

  • condition (bool) – Condition to make the annotation.

  • zorder (int) – zorder of the annotation.

  • boxaround (bool) – Whether to put a box around the annotation.

  • fontweight (str) – Weight of the text ('bold', 'normal', etc.).

  • color (str) – Color of the text.

  • **kwargs – Additional arguments for the annotation.

Example:

>>> Plotter.set_annotate_letter(ax, 0, x=0.1, y=0.9, fontsize=14, addit=' Test', color='red')
static set_arrow(ax, start_T: str, end_T: str, xystart: float, xystart_T: float, xyend: float, xyend_T: float, arrowprops={'arrowstyle': '->'}, startcolor='black', endcolor='black', **kwargs)[source]

Make an annotation with an arrow on the plot.

Parameters:
  • ax – Axis to annotate on.

  • start_T – Start text.

  • end_T – End text.

  • xystart – Coordinate of the arrow start.

  • xystart_T – Coordinate of the start text.

  • xyend – Coordinate of the arrow end.

  • xyend_T – Coordinate of the end text.

  • arrowprops – Properties of the arrow.

  • startcolor – Color of the arrow at the start.

  • endcolor – Color of the arrow at the end.

  • **kwargs – Additional arguments for the annotation.

static callout(ax, text: str, xy, xytext=None, *, xycoords: str = 'data', textcoords: str | None = None, arrowstyle: str = '->', color: str = 'black', lw: float = 1.0, boxaround: bool = True, box_alpha: float = 0.85, zorder: int = 20, **kwargs)[source]

Add a compact callout (text + optional arrow) to an axis.

static highlight_box(ax, x: float, y: float, width: float, height: float, *, coords: str = 'data', edgecolor='crimson', facecolor='none', lw: float = 1.3, ls='-', alpha: float = 0.95, zorder: int = 15, **kwargs)[source]

Draw a highlighted rectangular region in data or axes coordinates.

static highlight_circle(ax, x: float, y: float, radius: float, *, coords: str = 'data', edgecolor='darkorange', facecolor='none', lw: float = 1.3, ls='-', alpha: float = 0.95, zorder: int = 15, **kwargs)[source]

Draw a highlighted circular region in data or axes coordinates.

static plot_fit(ax, funct, x, **kwargs)[source]

Plot a user-provided fitting function on the given axis; **kwargs are forwarded to the plotting call.

Parameters:
  • ax – Axis to plot on.

  • funct – Function to use for the fitting.

  • x – Arguments to the function.

static hline(ax: Axes, val: float, ls='--', lw=2.0, color='black', label=None, zorder=10, label_cond=True, **kwargs)[source]

Draw a horizontal line at y = val.

static vline(ax, val: float, ls='--', lw=2.0, color='black', label=None, zorder=10, label_cond=True, **kwargs)[source]

Draw a vertical line at x = val.

static scatter(ax, x, y, *, s=10, c='blue', marker='o', alpha=1.0, label=None, edgecolor=None, zorder=5, label_cond=True, linewidths=1.0, cmap=None, norm=None, vmin=None, vmax=None, plotnonfinite=False, clip_on=True, rasterized=False, **kwargs)[source]

Creates a scatter plot on the provided axis, styled for Nature-like plots.

Parameters:
  • ax (matplotlib.axes.Axes) – The axis on which to draw the scatter plot.

  • x (array-like) – The x-coordinates of the points.

  • y (array-like) – The y-coordinates of the points.

  • s (float or array-like, optional) – The size of the points (default: 10).

  • c (color or array-like, optional) – The color of the points (default: ‘blue’).

  • marker (str, optional) – The shape of the points (default: ‘o’).

  • alpha (float, optional) – The transparency of the points (0.0 to 1.0, default: 1.0).

  • label (str, optional) – The label for the points (default: None).

  • edgecolor (str or array-like, optional) – The edge color of the points (default: ‘white’).

  • zorder (int, optional) – The drawing order of the points (default: 5).

  • **kwargs – Additional keyword arguments passed to matplotlib.axes.Axes.scatter.

Example

>>> Plotter.scatter(ax, x_data, y_data, s=20, c='red', alpha=0.5, label='Sample Data')

static tripcolor_field(ax, points, values, *, triangles=None, mask=None, shading: str = 'gouraud', **kwargs)[source]

Plot a scalar field sampled on irregular planar points using triangulation.

This helper is meant for 2D scattered data where imshow is not appropriate because the samples do not lie on a regular rectangular grid. Matplotlib first builds a triangulation of the point cloud and then interpolates values inside each triangle.

Typical use cases:

  • real-space lattice-site scalar fields

  • Brillouin-zone data on irregular planar k-point sets

  • any scattered 2D measurement data

Parameters:
  • ax (matplotlib.axes.Axes) – Target axis.

  • points (array-like) – Sample positions shaped like (N, 2) or (N, D) with D >= 2. Only the first two Cartesian components are used.

  • values (array-like) – Scalar values of length N.

  • triangles (array-like, optional) – Explicit connectivity passed to matplotlib.tri.Triangulation.

  • mask (array-like of bool, optional) – Triangle mask. True hides the corresponding triangle.

  • shading ({'flat', 'gouraud'}, default='gouraud') – Interpolation mode inside triangles.

  • **kwargs – Forwarded to Axes.tripcolor.

Returns:

The created artist, or None if fewer than three points are given.

Return type:

matplotlib.collections.Collection | None
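The triangulate-then-shade approach described above maps onto `matplotlib.tri.Triangulation` and `Axes.tripcolor`. `tripcolor_sketch` is a hypothetical stand-in, not the helper's code. It keeps the documented early return for fewer than three points.

```python
import matplotlib.tri as mtri
import numpy as np
from matplotlib.figure import Figure


def tripcolor_sketch(ax, points, values, *, triangles=None, mask=None, shading="gouraud", **kwargs):
    """Hypothetical sketch: triangulate scattered planar points and shade the field."""
    pts = np.asarray(points, dtype=float)
    if len(pts) < 3:
        return None                                   # no triangle can be formed
    # Delaunay triangulation unless explicit connectivity is given.
    triang = mtri.Triangulation(pts[:, 0], pts[:, 1], triangles=triangles, mask=mask)
    return ax.tripcolor(triang, np.asarray(values, dtype=float), shading=shading, **kwargs)


fig = Figure()
ax = fig.add_subplot()
rng = np.random.default_rng(0)
pts = rng.uniform(-1.0, 1.0, size=(30, 2))
artist = tripcolor_sketch(ax, pts, np.hypot(pts[:, 0], pts[:, 1]))
```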

static plot(ax, *args, y=None, x=None, ls='-', lw=2.0, color='black', label=None, label_cond=True, marker=None, ms=None, zorder=5, drawstyle='default', markevery=None, clip_on=True, rasterized=False, antialiased=True, solid_capstyle=None, solid_joinstyle=None, **kwargs)[source]

Plot the data on the given axis.

static fill_between(ax, x, y1, y2, color='blue', alpha=0.5, where=None, interpolate=False, step=None, linewidth=0.0, edgecolor=None, zorder=4, clip_on=True, rasterized=False, **kwargs)[source]

Fills the area between two curves on the provided axis.

Parameters:
  • ax (matplotlib.axes.Axes) – The axis on which to fill the area.

  • x (array-like) – The x-coordinates of the points.

  • y1 (array-like) – The y-coordinates of the first curve.

  • y2 (array-like) – The y-coordinates of the second curve.

  • color (str, optional) – The color of the filled area (default: ‘blue’).

  • alpha (float, optional) – The transparency of the filled area (0.0 to 1.0, default: 0.5).

  • **kwargs – Additional keyword arguments passed to matplotlib.axes.Axes.fill_between.

Example

>>> Plotter.fill_between(ax, x_data, y1_data, y2_data, color='red', alpha=0.3)

static semilogy(ax, x, y, ls='-', lw=1.5, color='black', label=None, marker=None, ms=None, label_cond=True, zorder=5, **kwargs)[source]

Plot with logarithmic y-axis.

Parameters:
  • ax (matplotlib.axes.Axes) – Axis to plot on.

  • x (array-like) – Data to plot.

  • y (array-like) – Data to plot.

  • ls (str, default='-') – Line style.

  • lw (float, default=1.5) – Line width.

  • color (str or int, default='black') – Line color. If int, uses colorsList[color].

  • label (str, optional) – Legend label.

  • marker (str, optional) – Marker style.

  • ms (float, optional) – Marker size.

  • **kwargs – Additional arguments passed to ax.semilogy.

Examples

>>> Plotter.semilogy(ax, x, np.exp(-x), color='C0', label=r'$e^{-x}$')
static semilogx(ax, x, y, ls='-', lw=1.5, color='black', label=None, marker=None, ms=None, label_cond=True, zorder=5, **kwargs)[source]

Plot with logarithmic x-axis.

Parameters:
  • ax (matplotlib.axes.Axes) – Axis to plot on.

  • x (array-like) – Data to plot.

  • y (array-like) – Data to plot.

  • ls (str, default='-') – Line style.

  • lw (float, default=1.5) – Line width.

  • color (str or int, default='black') – Line color. If int, uses colorsList[color].

  • label (str, optional) – Legend label.

  • marker (str, optional) – Marker style.

  • ms (float, optional) – Marker size.

  • **kwargs – Additional arguments passed to ax.semilogx.

Examples

>>> Plotter.semilogx(ax, np.logspace(-3, 3, 100), y, color='C1')
static loglog(ax, x, y, ls='-', lw=1.5, color='black', label=None, marker=None, ms=None, label_cond=True, zorder=5, **kwargs)[source]

Plot with logarithmic x and y axes.

Parameters:
  • ax (matplotlib.axes.Axes) – Axis to plot on.

  • x (array-like) – Data to plot (must be positive).

  • y (array-like) – Data to plot (must be positive).

  • ls (str, default='-') – Line style.

  • lw (float, default=1.5) – Line width.

  • color (str or int, default='black') – Line color. If int, uses colorsList[color].

  • label (str, optional) – Legend label.

  • marker (str, optional) – Marker style.

  • ms (float, optional) – Marker size.

  • **kwargs – Additional arguments passed to ax.loglog.

Examples

>>> # Power law: y = x^(-2)
>>> x = np.logspace(0, 3, 50)
>>> Plotter.loglog(ax, x, x**(-2), label=r'$x^{-2}$', color='C2')
static errorbar(ax, x, y, yerr=None, xerr=None, fmt='o', color='black', capsize=2, capthick=1.0, elinewidth=1.0, markersize=5, label=None, label_cond=True, alpha=1.0, zorder=5, ecolor=None, errorevery=1, barsabove=False, uplims=False, lolims=False, xuplims=False, xlolims=False, clip_on=True, rasterized=False, **kwargs)[source]

Plot data with error bars.

Parameters:
  • ax (matplotlib.axes.Axes) – Axis to plot on.

  • x (array-like) – Data points.

  • y (array-like) – Data points.

  • yerr (float or array-like, optional) – Vertical error bars. Can be:

    • scalar: symmetric error for all points

    • 1D array: symmetric errors

    • 2D array (2, N): asymmetric [lower, upper] errors

  • xerr (float or array-like, optional) – Horizontal error bars (same format as yerr).

  • fmt (str, default='o') – Format string for markers ('' for no markers, only error bars).

  • color (str or int, default='black') – Color for markers and error bars.

  • capsize (float, default=2) – Length of error bar caps.

  • capthick (float, default=1.0) – Thickness of error bar caps.

  • elinewidth (float, default=1.0) – Width of error bar lines.

  • markersize (float, default=5) – Size of markers.

  • label (str, optional) – Legend label.

  • alpha (float, default=1.0) – Transparency.

  • **kwargs – Additional arguments passed to ax.errorbar.

Examples

>>> # Symmetric error
>>> Plotter.errorbar(ax, x, y, yerr=sigma, label='Data')
>>> # Asymmetric error
>>> Plotter.errorbar(ax, x, y, yerr=[lower_err, upper_err])
>>> # Error band without markers
>>> Plotter.errorbar(ax, x, y, yerr=sigma, fmt='', elinewidth=2)
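Asymmetric errors are often available as absolute confidence bounds rather than distances from y. A small hypothetical helper, `asymmetric_yerr`, converts such bounds into the (2, N) [lower, upper] layout described above. Matplotlib expects these distances to be non-negative.

```python
import numpy as np


def asymmetric_yerr(y, lower, upper):
    """Hypothetical helper: build a (2, N) yerr array from absolute bounds around y."""
    y, lower, upper = (np.asarray(a, dtype=float) for a in (y, lower, upper))
    # Row 0: distance below y; row 1: distance above y.
    return np.vstack([y - lower, upper - y])


yerr = asymmetric_yerr([1.0, 2.0, 3.0], [0.5, 1.5, 2.0], [2.0, 2.5, 3.5])
```

The result can then be passed as `Plotter.errorbar(ax, x, y, yerr=yerr)`.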
static histogram(ax, data, bins=50, density=True, histtype='stepfilled', alpha=0.7, color='C0', edgecolor='black', linewidth=1.0, label=None, orientation='vertical', cumulative=False, log=False, label_cond=True, zorder=5, weights=None, range=None, align='mid', rwidth=None, stacked=False, hatch=None, **kwargs)[source]

Plot a histogram.

Parameters:
  • ax (matplotlib.axes.Axes) – Axis to plot on.

  • data (array-like) – Input data.

  • bins (int or array-like, default=50) – Number of bins or bin edges.

  • density (bool, default=True) – If True, normalize to form a probability density.

  • histtype (str, default='stepfilled') – Type of histogram: ‘bar’, ‘barstacked’, ‘step’, ‘stepfilled’.

  • alpha (float, default=0.7) – Transparency.

  • color (str, default='C0') – Fill color.

  • edgecolor (str, default='black') – Edge color.

  • linewidth (float, default=1.0) – Edge line width.

  • label (str, optional) – Legend label.

  • orientation (str, default='vertical') – ‘vertical’ or ‘horizontal’.

  • cumulative (bool, default=False) – If True, plot cumulative histogram.

  • log (bool, default=False) – If True, use log scale for counts axis.

  • **kwargs – Additional arguments passed to ax.hist.

Returns:

  • n (array) – Histogram values.

  • bins (array) – Bin edges.

  • patches (list) – Patch objects.

Examples

>>> # Basic histogram
>>> Plotter.histogram(ax, data, bins=30, label='Distribution')
>>> # Step histogram (unfilled)
>>> Plotter.histogram(ax, data, histtype='step', linewidth=2)
>>> # Cumulative distribution
>>> Plotter.histogram(ax, data, cumulative=True, density=True)
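With `density=True`, the bar heights are rescaled so that the histogram integrates to one. The sum of `n * bin_width` equals 1, while the sum of `n` alone generally does not. This sketch assumes `Plotter.histogram` forwards `density` to `ax.hist`, which normalizes the same way as `numpy.histogram`, and checks that property directly with NumPy:

```python
import numpy as np

rng = np.random.default_rng(0)            # fixed seed for reproducibility
data = rng.normal(loc=0.0, scale=1.0, size=1_000)

# Same normalization that ax.hist(..., density=True) applies
n, edges = np.histogram(data, bins=30, density=True)
widths = np.diff(edges)

area = float(np.sum(n * widths))          # integral of the density estimate
raw_counts, _ = np.histogram(data, bins=edges)  # unnormalized counts, same bins
```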
static contourf(ax, x, y, z, **kwargs)[source]

Plot filled contours of z over the (x, y) grid.
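A filled contour plot needs `z` sampled on the grid spanned by `x` and `y`. The following is a sketch in plain matplotlib, assuming `Plotter.contourf` forwards its arguments to `ax.contourf`. Building the coordinates with `numpy.meshgrid` gives `z` the expected `(len(y), len(x))` shape:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-2.0, 2.0, 50)
y = np.linspace(-1.0, 1.0, 40)
X, Y = np.meshgrid(x, y)           # both have shape (len(y), len(x)) = (40, 50)
Z = np.exp(-(X**2 + Y**2))         # Gaussian bump sampled on the grid

fig, ax = plt.subplots()
cs = ax.contourf(X, Y, Z, levels=15, cmap="viridis")
fig.colorbar(cs, ax=ax, label="z")
```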

static grid(ax, **kwargs)[source]

Configure the grid lines of the axis.

Kwargs include:
  • which ({'major', 'minor', 'both'}, optional, default='major') – Which grid lines to apply the settings to.

  • axis ({'both', 'x', 'y'}, optional, default='both') – Which axis to apply the grid settings to.

  • color (color, optional) – Color of the grid lines.

  • linestyle (str, optional) – Style of the grid lines (e.g., ‘-’, ‘--’, ‘-.’, ‘:’).

  • linewidth (float, optional) – Width of the grid lines.

  • alpha (float, optional) – Transparency of the grid lines (0.0 to 1.0).
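These keywords mirror those of matplotlib's `Axes.grid`. Assuming `Plotter.grid` forwards them, this is the plain-matplotlib equivalent of a light, dashed grid on the y-axis only:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([0, 1, 2], [0, 1, 4])

# Light dashed major grid on y only; the x-axis keeps no grid lines.
ax.grid(True, which="major", axis="y", color="gray",
        linestyle="--", linewidth=0.8, alpha=0.3)
ax.set_axisbelow(True)  # draw grid lines behind the data
```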

static set_tickparams(ax, labelsize=None, left=True, right=True, top=True, bottom=True, xticks=None, yticks=None, xticklabels=None, yticklabels=None, maj_tick_l=4, min_tick_l=2, **kwargs)[source]

Set the tick parameters of the axis.

Parameters:
  • ax – Axis to use.

  • labelsize – Font size of the tick labels.

  • left – Whether to show ticks on the left side.

  • right – Whether to show ticks on the right side.

  • top – Whether to show ticks on the top side.

  • bottom – Whether to show ticks on the bottom side.

  • xticks – List of x-tick positions.

  • yticks – List of y-tick positions.
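These options resemble matplotlib's `Axes.tick_params`. The sketch below is a rough plain-matplotlib equivalent of the defaults: ticks on all four sides, plus separate major and minor lengths. Treating the signature's `maj_tick_l=4` and `min_tick_l=2` as major and minor tick lengths is an assumption based on their names.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([0, 1, 2], [0, 1, 4])

# Ticks on all four sides with a given tick-label size
ax.tick_params(axis="both", which="both", left=True, right=True,
               top=True, bottom=True, labelsize=10)
# Separate lengths (in points) for major and minor ticks (assumed meaning of maj_tick_l/min_tick_l)
ax.tick_params(which="major", length=4)
ax.tick_params(which="minor", length=2)
ax.set_xticks([0, 1, 2])  # explicit x-tick positions
```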

static set_ax_params(ax, which: str = 'both', xlabel: str | None = None, ylabel: str | None = None, title: str | None = None, fontsize: int | None = None, labelsize_title: int | None = None, labelsize_tick: int | None = None, labelpad: float | dict = 0.0, title_pad: float = 10.0, xlabel_position: Literal['top', 'bottom'] = 'bottom', ylabel_position: Literal['left', 'right'] = 'left', xlim: tuple | None = None, ylim: tuple | None = None, xscale: Literal['linear', 'log', 'symlog'] = 'linear', yscale: Literal['linear', 'log', 'symlog'] = 'linear', xticks: list | ndarray | None = None, yticks: list | ndarray | None = None, xticklabels: list | None = None, yticklabels: list | None = None, xtickpos: Literal['top', 'bottom', 'both'] = None, ytickpos: Literal['left', 'right', 'both'] = None, tick_length_major: float = 4.0, tick_length_minor: float = 2.0, tick_width: float = 0.8, tick_direction: Literal['in', 'out', 'inout'] = 'in', show_minor_ticks: bool = True, minor_tick_locator: str | None = 'auto', grid: bool = False, grid_axis: Literal['both', 'x', 'y'] = 'both', grid_which: Literal['major', 'minor', 'both'] = 'major', grid_style: str = '--', grid_color: str | None = None, grid_alpha: float = 0.3, grid_linewidth: float = 0.8, show_spines: bool | dict = True, spine_width: float = 1.0, spine_color: str = 'black', aspect: str | float | None = None, tight_layout: bool = False, legend: bool = False, legend_kwargs: dict | None = None, invert_xaxis: bool = False, invert_yaxis: bool = False, auto_formatter: bool = True, label_cond: bool = True, label_pos: dict = None, tick_pos: dict = None, **kwargs)[source]

Comprehensive axis configuration method for publication-quality plots.

This method provides centralized control over all major axis properties with sensible defaults and advanced options for fine-tuning. It integrates with other Plotter methods for a cohesive styling experience.

Parameters:
  • ax (matplotlib.axes.Axes) – The axis object to modify.

  • which ({'both', 'x', 'y'}, default='both') – Specifies which axes to update. Allows independent configuration of x and y axes.

  • **Labels and Titles**

  • xlabel (str, optional) – x-axis label. Set to '' to hide the label while maintaining formatting.

  • ylabel (str, optional) – y-axis label. Set to '' to hide the label while maintaining formatting.

  • title (str, optional) – Axis title.

  • fontsize (int, optional) – Default font size for labels (overridable per-element).

  • labelsize_title (int, optional) – Font size for title. If None, uses fontsize + 2.

  • labelsize_tick (int, optional) – Font size for tick labels. If None, uses fontsize - 2.

  • labelpad (float or dict, default=0.0) – Padding between label and axis. Can be {‘x’: val, ‘y’: val}.

  • title_pad (float, default=10.0) – Vertical padding between title and plot area.

  • **Label Positioning**

  • xlabel_position ({'top', 'bottom'}, default='bottom') – Position of x-axis label.

  • ylabel_position ({'left', 'right'}, default='left') – Position of y-axis label.

  • **Axis Limits and Scales**

  • xlim (tuple, optional) – x-axis limits as (min, max). Use None for auto limits.

  • ylim (tuple, optional) – y-axis limits as (min, max). Use None for auto limits.

  • xscale ({'linear', 'log', 'symlog'}, default='linear') – x-axis scale type. ‘symlog’ uses symmetric log scaling.

  • yscale ({'linear', 'log', 'symlog'}, default='linear') – y-axis scale type. ‘symlog’ uses symmetric log scaling.

  • **Tick Configuration**

  • xticks (list or np.ndarray, optional) – Explicit tick positions. Leave None for matplotlib auto-ticks.

  • yticks (list or np.ndarray, optional) – Explicit tick positions. Leave None for matplotlib auto-ticks.

  • xticklabels (list, optional) – Custom tick labels. Must match length of ticks if provided.

  • yticklabels (list, optional) – Custom tick labels. Must match length of ticks if provided.

  • tick_length_major (float, default=4.0) – Length of major ticks in points.

  • tick_length_minor (float, default=2.0) – Length of minor ticks in points.

  • tick_width (float, default=0.8) – Width of ticks in points.

  • tick_direction ({'in', 'out', 'inout'}, default='in') – Direction ticks point (‘in’ recommended for publication).

  • show_minor_ticks (bool, default=True) – Whether to show minor ticks.

  • minor_tick_locator ({'auto', 'log'}, default='auto') – How to locate minor ticks. ‘log’ uses LogLocator for log scales.

  • **Grid Configuration**

  • grid (bool, default=False) – Enable gridlines.

  • grid_axis ({'both', 'x', 'y'}, default='both') – Which axes to show grid on.

  • grid_which ({'major', 'minor', 'both'}, default='major') – Which ticks to grid on.

  • grid_style (str, default='--') – Line style (‘-’, ‘--’, ‘-.’, ‘:’).

  • grid_color (str, optional) – Grid color. If None, uses current axes color scheme.

  • grid_alpha (float, default=0.3) – Transparency of gridlines (0=transparent, 1=opaque).

  • grid_linewidth (float, default=0.8) – Width of gridlines in points.

  • **Spine Configuration**

  • show_spines (bool or dict, default=True) – Visibility of spines:
    • True: show all spines
    • False: hide all spines
    • dict: {'top': bool, 'bottom': bool, 'left': bool, 'right': bool}

  • spine_width (float, default=1.0) – Width of spines in points.

  • spine_color (str, default='black') – Color of spines.

  • **Appearance**

  • aspect (str or float, optional) – Aspect ratio (‘equal’, ‘auto’) or numeric value.

  • tight_layout (bool, default=False) – Apply tight layout after configuration.

  • **Legend**

  • legend (bool, default=False) – Whether to display legend using set_legend().

  • legend_kwargs (dict, optional) – Arguments to pass to set_legend() if legend=True.

  • **Advanced Options**

  • invert_xaxis (bool, default=False) – Invert the direction of the x-axis.

  • invert_yaxis (bool, default=False) – Invert the direction of the y-axis.

  • auto_formatter (bool, default=True) – Automatically apply scientific notation formatter for large/small numbers.

  • **kwargs – Additional keyword arguments passed to matplotlib functions.

Examples

Example 1: Basic publication-ready plot

>>> ax = plt.gca()
>>> Plotter.set_ax_params(
...     ax,
...     xlabel=r'$x$ (nm)',
...     ylabel=r'Energy (eV)',
...     title='Band Structure',
...     xlim=(0, 10),
...     ylim=(-5, 5),
...     grid=True
... )

Example 2: Log-scale with custom ticks

>>> Plotter.set_ax_params(
...     ax,
...     xlabel='Frequency (Hz)',
...     ylabel='Magnitude',
...     yscale='log',
...     yticks=[1, 10, 100, 1000],
...     yticklabels=['1', '10', '100', '1 k'],
...     grid=True,
...     grid_which='both',
...     minor_tick_locator='log'
... )

Example 3: Detailed styling

>>> Plotter.set_ax_params(
...     ax,
...     xlabel='Temperature (K)',
...     ylabel=r'$\rho$ (Ω·cm)',
...     title='Resistivity vs Temperature',
...     fontsize=12,
...     labelsize_title=14,
...     labelsize_tick=10,
...     xlim=(0, 300),
...     ylim=(0, None),  # auto max
...     grid=True,
...     grid_style='--',
...     grid_alpha=0.4,
...     show_spines={'top': False, 'right': False},
...     spine_width=1.5,
...     tick_length_major=6,
...     legend=True,
...     tight_layout=True
... )

Example 4: Custom tick labels and positions

>>> import numpy as np
>>> Plotter.set_ax_params(
...     ax,
...     xticks=np.linspace(0, 2*np.pi, 5),
...     xticklabels=['0', 'π/2', 'π', '3π/2', '2π'],
...     xlabel=r'Phase',
...     ylabel=r'$\sin(\phi)$'
... )

Example 5: Asymmetric spines (Nature style)

>>> Plotter.set_ax_params(
...     ax,
...     xlabel='Parameter',
...     ylabel='Value',
...     show_spines={'left': True, 'bottom': True, 'top': False, 'right': False},
...     grid=False,
...     tick_direction='out'
... )

Notes

  • Set labelsize_* to None to auto-scale relative to fontsize

  • Grid is best used with light colors and low alpha (0.2-0.4)

  • For log scales, minor_tick_locator='log' is recommended

  • Use which='x' or which='y' for independent axis control

  • Integrates with Plotter.set_legend() for unified styling
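The auto-scaling rule from the parameter list can be written as a small pure function. Under that rule, `labelsize_title` defaults to `fontsize + 2` and `labelsize_tick` to `fontsize - 2`. This sketch covers only the documented rule. `resolve_font_sizes` is a hypothetical name. Returning `None` when `fontsize` is unset, so that matplotlib's rcParams apply, is an assumption.

```python
from typing import Optional, Tuple


def resolve_font_sizes(
    fontsize: Optional[int],
    labelsize_title: Optional[int] = None,
    labelsize_tick: Optional[int] = None,
) -> Tuple[Optional[int], Optional[int], Optional[int]]:
    """Hypothetical helper: return (label, title, tick) sizes using the documented offsets."""
    if fontsize is None:
        # Assumption: with no base size, leave unspecified sizes to rcParams.
        return None, labelsize_title, labelsize_tick
    title = labelsize_title if labelsize_title is not None else fontsize + 2
    tick = labelsize_tick if labelsize_tick is not None else fontsize - 2
    return fontsize, title, tick
```

With `fontsize=12`, the rule yields exactly the title and tick sizes that Example 3 passes explicitly (14 and 10).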

See also

set_legend

Configure legend appearance

set_tickparams

Alternative tick configuration method

grid

Add gridlines to axis

static set_xlabel(ax, xlabel, fontsize=None, labelpad=0, loc=None, x=None, y=None, coords: str = 'axes', transform=None, **kwargs)[source]

Set x-axis label with optional alignment and explicit coordinates.

Parameters:
  • ax (matplotlib.axes.Axes) – Target axis.

  • xlabel (str) – Label text.

  • fontsize (int, optional) – Label font size.

  • labelpad (float, default=0) – Padding in points.

  • loc ({'left', 'center', 'right'}, optional) – Matplotlib label location argument.

  • x (float, optional) – Explicit label coordinates (if either is provided).

  • y (float, optional) – Explicit label coordinates (if either is provided).

  • coords ({'axes', 'data'}, default='axes') – Coordinate system used for x/y when transform is not provided.

  • transform (matplotlib transform, optional) – Explicit transform for label coordinates.

  • **kwargs – Forwarded to ax.set_xlabel.

static set_ylabel(ax, ylabel, fontsize=None, labelpad=0, loc=None, x=None, y=None, coords: str = 'axes', transform=None, **kwargs)[source]

Set y-axis label with optional alignment and explicit coordinates.

Parameters:
  • ax (matplotlib.axes.Axes) – Target axis.

  • ylabel (str) – Label text.

  • fontsize (int, optional) – Label font size.

  • labelpad (float, default=0) – Padding in points.

  • loc ({'bottom', 'center', 'top'}, optional) – Matplotlib label location argument.

  • x (float, optional) – Explicit label coordinates (if either is provided).

  • y (float, optional) – Explicit label coordinates (if either is provided).

  • coords ({'axes', 'data'}, default='axes') – Coordinate system used for x/y when transform is not provided.

  • transform (matplotlib transform, optional) – Explicit transform for label coordinates.

  • **kwargs – Forwarded to ax.set_ylabel.
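Explicit label placement maps onto matplotlib's `Axis.set_label_coords`, whose default coordinate system is axes-relative (0 to 1 across the axes). This is a plain-matplotlib sketch that assumes `coords='data'` corresponds to passing `ax.transData` as the transform:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.set_xlim(0, 10)
ax.set_ylim(0, 5)

ax.set_xlabel("Time (s)", fontsize=12)
ax.set_ylabel("Signal (a.u.)", fontsize=12)

# Axes coordinates: (0, 0) is the lower-left corner of the axes, (1, 1) the upper-right.
ax.xaxis.set_label_coords(1.0, -0.08)
# Data coordinates (assumed equivalent of coords='data'): x = -0.8, y = 2.5 in data units.
ax.yaxis.set_label_coords(-0.8, 2.5, transform=ax.transData)
```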

static set_ax_labels(ax, fontsize=None, xlabel='', ylabel='', title='', xPad=0, yPad=0, xloc=None, yloc=None, xcoords: str = 'axes', ycoords: str = 'axes', x_pos=None, y_pos=None)[source]

Set the x- and y-axis labels and the axis title.

static set_label_cords(ax, which: str, inX=0.0, inY=0.0, **kwargs)[source]

Sets the coordinates of the labels

static setup_log_y(ax: Axes, ylims=(1e-12, 1000000.0), decade_step=4)[source]

Configure clean log-scale y ticks at powers of 10 with LaTeX-like labels.

static setup_log_x(ax: Axes, xlims=(1e-12, 1000000.0), decade_step=4)[source]

Configure clean log-scale x ticks at powers of 10 with LaTeX-like labels.
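The tick placement these two methods describe can be sketched as pure arithmetic. The ticks are powers of ten inside the limits, spaced `decade_step` decades apart, each with a mathtext label. `log_decade_ticks` is hypothetical. Starting from the lowest decade inside the limits is an assumption about how Plotter aligns the ticks.

```python
import math
from typing import List, Tuple


def log_decade_ticks(lims: Tuple[float, float], decade_step: int = 4) -> Tuple[List[float], List[str]]:
    """Hypothetical helper: powers of ten within lims, every decade_step decades."""
    lo, hi = lims
    if lo <= 0 or hi <= lo:
        raise ValueError("log limits must satisfy 0 < lo < hi")
    # Small tolerance so exact powers of ten are not lost to rounding in log10.
    k_min = math.ceil(math.log10(lo) - 1e-9)
    k_max = math.floor(math.log10(hi) + 1e-9)
    exponents = list(range(k_min, k_max + 1, decade_step))
    ticks = [10.0 ** k for k in exponents]
    labels = [rf"$10^{{{k}}}$" for k in exponents]  # e.g. "$10^{-12}$"
    return ticks, labels
```

On an axis, the result would be applied with `ax.set_yscale('log')`, `ax.set_ylim(*lims)`, `ax.set_yticks(ticks)` and `ax.set_yticklabels(labels)`.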

static set_smart_lim(ax, *, which: str = 'both', data: ndarray | None = None, margin_p: float = 0, margin_m: float = 1, xlim: tuple | None = None, ylim: tuple | None = None)[source]

Auto-compute robust axis limits and apply them to ax.
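The exact meaning of `margin_p` and `margin_m` is not documented here. The following is therefore only a generic sketch of what robust limits usually involve. It ignores non-finite values, pads the data range by a fraction of its span, and avoids a zero-width interval. `robust_limits` and its `pad` argument are hypothetical and do not reproduce Plotter's margin parameters.

```python
import math
from typing import Iterable, Tuple


def robust_limits(values: Iterable[float], pad: float = 0.05) -> Tuple[float, float]:
    """Hypothetical helper: finite-data range widened by pad * span on each side."""
    finite = [float(v) for v in values if math.isfinite(v)]  # drop NaN and +/-inf
    if not finite:
        raise ValueError("no finite values to compute limits from")
    lo, hi = min(finite), max(finite)
    span = hi - lo
    if span == 0.0:
        # Constant data: use the magnitude (or 1) so the interval has nonzero width.
        span = abs(lo) if lo != 0.0 else 1.0
    return lo - pad * span, hi + pad * span
```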

static hide_unused_panels(axes: Axes, n_panels: int)[source]

Hide unused panels in a subplot grid.
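A common way to do this in plain matplotlib is sketched below. It assumes `axes` is the array returned by `plt.subplots` and that the surplus panels should be hidden rather than removed. Removing them would use `fig.delaxes`. Which approach Plotter takes is not stated here.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

n_panels = 5
fig, axes = plt.subplots(2, 3)           # 6 slots for 5 panels

for i, ax in enumerate(np.ravel(axes)):  # row-major order
    if i < n_panels:
        ax.plot([0, 1], [0, i])          # real content
    else:
        ax.set_visible(False)            # surplus slot: layout kept, nothing drawn
```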

static labellines(ax, align=False, xvals=None, yoffsets=[], zorder=2, **kwargs)[source]

Add inline labels to the lines of an axis, optionally aligned with each line's slope. Uses the labelLines package.

Parameters:
  • ax – Matplotlib axis object.

  • align – Align each label with the slope of its line.

  • xvals – The x values at which to place the labels.

  • yoffsets – The y offsets for the labels.

  • zorder – The zorder of the labels.

static unset_spines(ax, top: bool = True, right: bool = True, bottom: bool = False, left: bool = False)[source]

Remove specified spines from the axis for cleaner publication-style plots.

Parameters:
  • ax (matplotlib.axes.Axes) – The axes to modify.

  • top (bool, default=True) – If True, REMOVE the top spine. If False, KEEP it.

  • right (bool, default=True) – If True, REMOVE the right spine. If False, KEEP it.

  • bottom (bool, default=False) – If True, REMOVE the bottom spine. If False, KEEP it.

  • left (bool, default=False) – If True, REMOVE the left spine. If False, KEEP it.

Examples

>>> # Nature style (remove top and right, keep left and bottom) - DEFAULT
>>> Plotter.unset_spines(ax)
>>> # Remove all spines (frameless plot)
>>> Plotter.unset_spines(ax, top=True, right=True, bottom=True, left=True)
>>> # Keep all spines
>>> Plotter.unset_spines(ax, top=False, right=False, bottom=False, left=False)
>>> # Only keep bottom spine (minimal style)
>>> Plotter.unset_spines(ax, top=True, right=True, bottom=False, left=True)

Notes

The default settings (top=True, right=True) produce the classic “Nature” or “Science” journal style with only left and bottom spines.
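In plain matplotlib, the same True-means-remove convention comes down to toggling each spine's visibility. `remove_spines` below is a hypothetical stand-in, not Plotter's code. In this sketch a False flag actively shows the spine, which matches "If False, KEEP it".

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt


def remove_spines(ax, top=True, right=True, bottom=False, left=False):
    """Hypothetical helper: hide each spine whose flag is True, show the others."""
    for side, remove in (("top", top), ("right", right),
                         ("bottom", bottom), ("left", left)):
        ax.spines[side].set_visible(not remove)


fig, ax = plt.subplots()
ax.plot([0, 1, 2], [1, 0, 1])
remove_spines(ax)  # Nature style: only left and bottom remain
```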

static unset_ticks(ax, xticks: bool = False, yticks: bool = False, xlabel: bool = False, ylabel: bool = False, remove_labels_only: bool = True)[source]

Remove tick labels (and optionally tick marks) from the axis.

Useful for creating clean shared-axis plots where inner panels don’t need redundant tick labels.

Parameters:
  • ax (matplotlib.axes.Axes) – The axes to modify.

  • xticks (bool, default=False) – If True, REMOVE x-tick labels. If False, KEEP them.

  • yticks (bool, default=False) – If True, REMOVE y-tick labels. If False, KEEP them.

  • xlabel (bool, default=False) – If True, also REMOVE the x-axis label.

  • ylabel (bool, default=False) – If True, also REMOVE the y-axis label.

  • remove_labels_only (bool, default=True) – If True, only remove the text labels, keeping tick marks visible. If False, remove both the tick marks and labels.

Examples

>>> # Remove x-tick labels for stacked plots with shared x-axis
>>> for ax in axes[:-1]:  # All except bottom
...     Plotter.unset_ticks(ax, xticks=True, xlabel=True)
>>> # Remove all tick labels (keep tick marks)
>>> Plotter.unset_ticks(ax, xticks=True, yticks=True)
>>> # Remove tick marks AND labels (completely clean)
>>> Plotter.unset_ticks(ax, xticks=True, yticks=True, remove_labels_only=False)
>>> # Remove y-ticks and y-axis label for side-by-side shared y-axis
>>> for ax in axes[1:]:  # All except leftmost
...     Plotter.unset_ticks(ax, yticks=True, ylabel=True)

Notes

This function is commonly used in combination with sharex/sharey in multi-panel figures to avoid redundant labels.
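Assuming `unset_ticks` is built on `Axes.tick_params`, `remove_labels_only` distinguishes between two cases. One hides only the tick text (`labelbottom=False`); the other also hides the marks (`bottom=False`). Below is a plain-matplotlib sketch for the top panel of a two-panel, shared-x figure. Note that `plt.subplots(..., sharex=True)` already hides the inner x tick labels by itself. Axes created with `fig.add_subplot(..., sharex=...)`, as here, do not.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

fig = plt.figure()
ax_top = fig.add_subplot(2, 1, 1)
ax_bottom = fig.add_subplot(2, 1, 2, sharex=ax_top)
for ax in (ax_top, ax_bottom):
    ax.plot([0, 1, 2], [0, 1, 4])
    ax.set_xlabel("x")

# Labels only (like remove_labels_only=True): tick marks stay, tick text and axis label go.
ax_top.tick_params(axis="x", labelbottom=False)
ax_top.set_xlabel("")

# To drop the marks as well (like remove_labels_only=False):
# ax_top.tick_params(axis="x", bottom=False, labelbottom=False)
```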

static unset_all(ax, spines: bool = True, ticks: bool = True, labels: bool = True)[source]

Completely strip an axis of spines, ticks, and labels.

Useful for image plots, heatmaps, or decorative panels where axis elements are not needed.

Parameters:
  • ax (matplotlib.axes.Axes) – The axes to modify.

  • spines (bool, default=True) – If True, remove all spines.

  • ticks (bool, default=True) – If True, remove all tick marks and labels.

  • labels (bool, default=True) – If True, remove axis labels.

Examples

>>> # Completely clean axis (for images/heatmaps)
>>> Plotter.unset_all(ax)
>>> # Keep only spines (box around plot)
>>> Plotter.unset_all(ax, spines=False)
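For a heatmap, the default call roughly corresponds to the plain-matplotlib steps below, one step per flag. Whether Plotter uses these exact calls or `ax.axis('off')` is an assumption.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
ax.imshow(np.arange(12).reshape(3, 4), cmap="magma")
ax.set_xlabel("column")
ax.set_ylabel("row")

for spine in ax.spines.values():  # spines=True: hide every spine
    spine.set_visible(False)
ax.set_xticks([])                 # ticks=True: no tick marks, no tick labels
ax.set_yticks([])
ax.set_xlabel("")                 # labels=True: clear axis labels
ax.set_ylabel("")
```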

static unset_ticks_and_spines(ax, xticks: bool = True, yticks: bool = True, top: bool = True, right: bool = True, bottom: bool = False, left: bool = False)[source]

Convenience method to remove both ticks and spines in one call.

Parameters:
  • ax (matplotlib.axes.Axes) – The axes to modify.

  • xticks (bool, default=True) – If True, REMOVE x-tick labels.

  • yticks (bool, default=True) – If True, REMOVE y-tick labels.

  • top (bool, default=True) – If True, REMOVE the top spine. The defaults remove top and right (Nature style).

  • right (bool, default=True) – If True, REMOVE the right spine.

  • bottom (bool, default=False) – If True, REMOVE the bottom spine.

  • left (bool, default=False) – If True, REMOVE the left spine.

Examples

>>> # Clean Nature style with no tick labels
>>> Plotter.unset_ticks_and_spines(ax)
>>> # Only remove top/right spines, keep all ticks
>>> Plotter.unset_ticks_and_spines(ax, xticks=False, yticks=False)

static set_formater(ax, formater='%.1e', axis='xy')[source]

Sets the formatter for the given axis on the plot.

Parameters:
  • ax (object) – The axis object on which to set the formatter.

  • formater (str, optional) – The format string for the axis labels. Defaults to '%.1e'.

  • axis (str, optional) – The axis on which to set the formatter. Defaults to 'xy'.

Returns:

None
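The default `'%.1e'` is a printf-style format string, the kind matplotlib's `FormatStrFormatter` consumes. The following sketch assumes the wrapper installs such a formatter as the major formatter of each axis named in `axis`:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter

fig, ax = plt.subplots()
ax.plot([0, 20000, 40000], [1e-5, 2e-5, 3e-5])

axis = "xy"  # same meaning as the wrapper's `axis` argument
if "x" in axis:
    ax.xaxis.set_major_formatter(FormatStrFormatter("%.1e"))
if "y" in axis:
    ax.yaxis.set_major_formatter(FormatStrFormatter("%.1e"))

xfmt = ax.xaxis.get_major_formatter()  # callable: xfmt(value) -> tick text
```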

static set_standard_formater(ax, axis='xy')[source]

Sets the formatter for the given axis on the plot.

Parameters:
  • ax (object) – The axis object on which to set the formatter.

  • axis (str, optional) – The axis on which to set the formatter. Defaults to 'xy'.

Returns:

None

class GridBuilder(figsize=(10, 8))[source]

Bases: object

Builder class for creating complex figure layouts with nested grids.

Use this when you need different numbers of columns in different rows, or complex nested arrangements that can’t be achieved with a simple grid.

Parameters:

figsize (tuple, default=(10, 8)) – Figure size in inches (width, height).

Examples

Create a layout with varying column counts per row:

>>> builder = Plotter.GridBuilder(figsize=(12, 8))
>>> builder.add_row(ncols=1, height_ratio=1)     # Header row (1 panel)
>>> builder.add_row(ncols=3, height_ratio=2)     # Main row (3 panels)
>>> builder.add_row(ncols=2, height_ratio=1.5)   # Footer row (2 panels)
>>> fig, axes = builder.build(wspace=0.2, hspace=0.3)
>>> # axes = [[ax00], [ax10, ax11, ax12], [ax20, ax21]]

Access axes:

>>> header_ax = axes[0][0]
>>> main_left, main_center, main_right = axes[1]
>>> footer_left, footer_right = axes[2]
__init__(figsize=(10, 8))[source]
add_row(ncols: int, height_ratio: float = 1.0, width_ratios: List[float] = None)[source]

Add a row to the layout.

Parameters:
  • ncols (int) – Number of columns in this row.

  • height_ratio (float, default=1.0) – Relative height of this row compared to others.

  • width_ratios (list of float, optional) – Relative widths of columns within this row. If None, columns are equal width.

Returns:

self – For method chaining.

Return type:

GridBuilder

build(wspace: float = 0.2, hspace: float = 0.2, left: float = 0.1, right: float = 0.95, top: float = 0.95, bottom: float = 0.1)[source]

Build the figure with the specified layout.

Parameters:
  • wspace (float, default=0.2) – Horizontal space between columns within rows.

  • hspace (float, default=0.2) – Vertical space between rows.

  • left (float) – Figure margins (fraction of figure size).

  • right (float) – Figure margins (fraction of figure size).

  • top (float) – Figure margins (fraction of figure size).

  • bottom (float) – Figure margins (fraction of figure size).

Returns:

  • fig (matplotlib.figure.Figure) – The created figure.

  • axes (list of lists) – 2D list of axes, where axes[row][col] gives the axis at that position.
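A row-based builder like this can be sketched with an outer GridSpec that holds one row per `add_row` call. Each row is then split by its own `GridSpecFromSubplotSpec`, which lets rows have different column counts and width ratios. `RowLayout` is a hypothetical class for illustration, not Plotter's implementation.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpecFromSubplotSpec


class RowLayout:
    """Hypothetical minimal row-based layout builder."""

    def __init__(self, figsize=(10, 8)):
        self.figsize = figsize
        self.rows = []  # list of (ncols, height_ratio, width_ratios)

    def add_row(self, ncols, height_ratio=1.0, width_ratios=None):
        self.rows.append((ncols, height_ratio, width_ratios))
        return self  # enables method chaining

    def build(self, wspace=0.2, hspace=0.2, left=0.1, right=0.95, top=0.95, bottom=0.1):
        fig = plt.figure(figsize=self.figsize)
        # One outer row per add_row call, sized by its height ratio
        outer = fig.add_gridspec(
            len(self.rows), 1,
            height_ratios=[h for _, h, _ in self.rows],
            hspace=hspace, left=left, right=right, top=top, bottom=bottom,
        )
        axes = []
        for i, (ncols, _, width_ratios) in enumerate(self.rows):
            # Each row gets its own column split inside the outer cell
            inner = GridSpecFromSubplotSpec(
                1, ncols, subplot_spec=outer[i], wspace=wspace, width_ratios=width_ratios
            )
            axes.append([fig.add_subplot(inner[0, j]) for j in range(ncols)])
        return fig, axes


fig, axes = (RowLayout(figsize=(12, 8))
             .add_row(ncols=1, height_ratio=1)
             .add_row(ncols=3, height_ratio=2)
             .add_row(ncols=2, height_ratio=1.5, width_ratios=[2, 1])
             .build(wspace=0.2, hspace=0.3))
```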

static make_grid(nrows: int, ncols: int, figsize: tuple = (10, 8), width_ratios: List[float] = None, height_ratios: List[float] = None, wspace: float = 0.2, hspace: float = 0.2, left: float = 0.1, right: float = 0.95, top: float = 0.95, bottom: float = 0.1, sharex: str = False, sharey: str = False, panel_labels: bool = False, panel_label_style: str = 'parenthesis', despine: bool = False)[source]

Create a figure with a grid of subplots with full control over layout.

This is the recommended method for creating publication-quality multi-panel figures with precise control over spacing and sizing.

Parameters:
  • nrows (int) – Number of rows.

  • ncols (int) – Number of columns.

  • figsize (tuple, default=(10, 8)) – Figure size in inches (width, height).

  • width_ratios (list of float, optional) – Relative widths of columns. Length must equal ncols. Example: [2, 1, 1] makes first column 2x wider.

  • height_ratios (list of float, optional) – Relative heights of rows. Length must equal nrows. Example: [1, 3] makes second row 3x taller.

  • wspace (float, default=0.2) – Horizontal space between columns (fraction of avg width).

  • hspace (float, default=0.2) – Vertical space between rows (fraction of avg height).

  • left (float) – Figure margins (0 to 1, fraction of figure size).

  • right (float) – Figure margins (0 to 1, fraction of figure size).

  • top (float) – Figure margins (0 to 1, fraction of figure size).

  • bottom (float) – Figure margins (0 to 1, fraction of figure size).

  • sharex (str or bool, default=False) – Share x-axis: ‘row’, ‘col’, ‘all’, or False.

  • sharey (str or bool, default=False) – Share y-axis: ‘row’, ‘col’, ‘all’, or False.

  • panel_labels (bool, default=False) – Add panel labels (a), (b), (c), etc.

  • panel_label_style (str, default='parenthesis') – Style for panel labels: ‘parenthesis’, ‘plain’, ‘bold’.

  • despine (bool, default=False) – Remove top and right spines (Nature-style).

Returns:

  • fig (matplotlib.figure.Figure) – The created figure.

  • axes (list of Axes) – Flat list of axes [ax0, ax1, ax2, …], row-major order.

Examples

Basic 2x3 grid:

>>> fig, axes = Plotter.make_grid(2, 3, figsize=(10, 6))
>>> ax0, ax1, ax2, ax3, ax4, ax5 = axes

Unequal column widths:

>>> fig, axes = Plotter.make_grid(1, 2, width_ratios=[3, 1])

Stacked panels with shared x-axis:

>>> fig, axes = Plotter.make_grid(3, 1, sharex='col', hspace=0.05)
>>> for ax in axes[:-1]:
...     Plotter.unset_ticks(ax, xticks=True, xlabel=True)

Publication figure:

>>> fig, axes = Plotter.make_grid(2, 2, figsize=(8, 8),
...                               panel_labels=True, despine=True)
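A helper of this kind can be sketched with `Figure.add_gridspec` plus per-cell axis sharing. Axes are created in row-major order, and each one shares with the first axis of its column, its row, or the whole grid. `make_grid_sketch` is hypothetical. It handles only the 'parenthesis' label style, and its label offsets are arbitrary choices rather than Plotter's values.

```python
import string

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt


def _share_key(mode, r, c):
    """Grid cell whose axis cell (r, c) should share with, or None for no sharing."""
    if mode in (True, "all"):
        return (0, 0)
    if mode == "col":
        return (0, c)
    if mode == "row":
        return (r, 0)
    return None


def make_grid_sketch(nrows, ncols, figsize=(10, 8), sharex=False, sharey=False,
                     panel_labels=False, despine=False, **gs_kw):
    """Hypothetical make_grid-like helper returning (fig, flat row-major axes)."""
    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(nrows, ncols, **gs_kw)  # ratios, spacing, margins
    created = {}
    axes = []
    for r in range(nrows):
        for c in range(ncols):
            # created.get(...) is None for the first axis of a group (nothing to share yet)
            ax = fig.add_subplot(
                gs[r, c],
                sharex=created.get(_share_key(sharex, r, c)),
                sharey=created.get(_share_key(sharey, r, c)),
            )
            created[(r, c)] = ax
            axes.append(ax)
            if panel_labels:
                letter = string.ascii_lowercase[len(axes) - 1]
                ax.text(-0.12, 1.02, f"({letter})", transform=ax.transAxes,
                        ha="right", va="bottom")
            if despine:
                ax.spines["top"].set_visible(False)
                ax.spines["right"].set_visible(False)
    return fig, axes


fig, axes = make_grid_sketch(2, 2, figsize=(8, 8), sharex="col",
                             panel_labels=True, despine=True,
                             wspace=0.3, hspace=0.3)
```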

static figure(figsize: tuple = (10, 8), **kwargs) Figure[source]

Create a Matplotlib figure with specified size and options.

Parameters:
  • figsize (tuple, default=(10, 8)) – Figure size in inches (width, height).

  • **kwargs – Additional keyword arguments passed to plt.figure().

Returns:

The created figure object.

Return type:

matplotlib.figure.Figure

Examples

Basic figure creation:

>>> fig = Plotter.figure(figsize=(12, 6))

With additional options:

>>> fig = Plotter.figure(figsize=(8, 8), dpi=150, facecolor='white')
static get_grid(nrows: int, ncols: int, *, wspace: float = None, hspace: float = None, width_ratios: List[float] = None, height_ratios: List[float] = None, ax_sub=None, left: float = None, right: float = None, top: float = None, bottom: float = None, figure=None, **kwargs) GridSpec[source]

Create a GridSpec for flexible subplot layouts.

This is the foundation for creating complex multi-panel figures with control over panel sizes and spacing.

Parameters:
  • nrows (int) – Number of rows in the grid.

  • ncols (int) – Number of columns in the grid.

  • wspace (float, optional) – Width space between columns (0.0 to 1.0, fraction of average axis width). Recommended: 0.2-0.4 for labels, 0.05-0.1 for tight layouts.

  • hspace (float, optional) – Height space between rows (0.0 to 1.0, fraction of average axis height). Recommended: 0.2-0.4 for titles, 0.05-0.1 for tight layouts.

  • width_ratios (list of float, optional) – Relative widths of columns. E.g., [2, 1, 1] makes first column 2x wider. Length must equal ncols.

  • height_ratios (list of float, optional) – Relative heights of rows. E.g., [1, 2] makes second row 2x taller. Length must equal nrows.

  • ax_sub (SubplotSpec, optional) – If provided, creates a nested GridSpec within this subplot. Use for complex layouts with grids inside grids.

  • left (float, optional) – Figure margins (0.0 to 1.0). Controls space for labels.

  • right (float, optional) – Figure margins (0.0 to 1.0). Controls space for labels.

  • top (float, optional) – Figure margins (0.0 to 1.0). Controls space for labels.

  • bottom (float, optional) – Figure margins (0.0 to 1.0). Controls space for labels.

  • **kwargs – Additional arguments passed to GridSpec.

Returns:

The grid specification object.

Return type:

GridSpec or GridSpecFromSubplotSpec

Examples

Basic 2x3 grid:

>>> fig = plt.figure(figsize=(12, 8))
>>> gs  = Plotter.get_grid(2, 3, wspace=0.3, hspace=0.4)
>>> ax0 = fig.add_subplot(gs[0, 0])  # Row 0, Col 0
>>> ax1 = fig.add_subplot(gs[0, 1:]) # Row 0, Cols 1-2 (span)
>>> ax2 = fig.add_subplot(gs[1, :])  # Row 1, all columns (span)

Unequal widths (main panel + sidebar):

>>> gs = Plotter.get_grid(1, 2, width_ratios=[3, 1])
>>> # First column is 3x wider than second

Nested grid (inset layout):

>>> outer       = Plotter.get_grid(1, 2)
>>> ax_left     = fig.add_subplot(outer[0])
>>> inner       = Plotter.get_grid(2, 2, ax_sub=outer[1], wspace=0.1, hspace=0.1)
>>> ax_inner_00 = fig.add_subplot(inner[0, 0])

Control margins:

>>> gs = Plotter.get_grid(2, 2, left=0.1, right=0.95, top=0.95, bottom=0.1)

See also

get_grid_subplot

Create subplot from GridSpec index

get_subplots

High-level function for simple layouts
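`ax_sub` determines which matplotlib class is returned. A top-level grid is a `GridSpec`, while a nested one is a `GridSpecFromSubplotSpec`. Only the top-level class accepts figure margins, because a nested grid inherits its extent from the parent cell. The following is a sketch of that dispatch, assuming Plotter behaves similarly. `grid_spec` is a hypothetical name.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec


def grid_spec(nrows, ncols, *, ax_sub=None, figure=None, wspace=None, hspace=None,
              width_ratios=None, height_ratios=None, **margins):
    """Hypothetical helper: GridSpec at top level, GridSpecFromSubplotSpec when nested."""
    common = dict(wspace=wspace, hspace=hspace,
                  width_ratios=width_ratios, height_ratios=height_ratios)
    if ax_sub is None:
        # Top level: margins (left/right/top/bottom) are valid here.
        return GridSpec(nrows, ncols, figure=figure, **common, **margins)
    if margins:
        raise TypeError("margins do not apply to a nested grid")
    return GridSpecFromSubplotSpec(nrows, ncols, subplot_spec=ax_sub, **common)


fig = plt.figure(figsize=(10, 4))
outer = grid_spec(1, 2, figure=fig, width_ratios=[3, 1], left=0.08, right=0.97)
ax_main = fig.add_subplot(outer[0])
inner = grid_spec(2, 1, ax_sub=outer[1], hspace=0.1)  # 2 rows inside the right cell
ax_top = fig.add_subplot(inner[0])
ax_bot = fig.add_subplot(inner[1])
```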

static get_grid_subplot(gs, fig, index, sharex=None, sharey=None, **kwargs)[source]

Create a subplot from a GridSpec at the specified index.

Parameters:
  • gs (GridSpec) – The GridSpec object.

  • fig (matplotlib.figure.Figure) – The figure to add the subplot to.

  • index (int, tuple, or slice) – Position in the grid. Can be:
    • int: linear index (0, 1, 2, …)
    • tuple: (row, col) for a single cell
    • slice, or tuple with slices: spans multiple cells

  • sharex (Axes, optional) – Share axis with another subplot. Use for aligned multi-panel figures.

  • sharey (Axes, optional) – Share axis with another subplot. Use for aligned multi-panel figures.

  • **kwargs – Additional arguments passed to fig.add_subplot.

Returns:

The created subplot.

Return type:

matplotlib.axes.Axes

Examples

Single cell by linear index:

ax0 = Plotter.get_grid_subplot(gs, fig, 0)  # First cell
ax1 = Plotter.get_grid_subplot(gs, fig, 1)  # Second cell

Single cell by (row, col):

ax = fig.add_subplot(gs[1, 2])  # Row 1, Col 2

Span multiple cells:

ax_wide = fig.add_subplot(gs[0, :])   # Entire first row
ax_tall = fig.add_subplot(gs[:, 0])   # Entire first column
ax_block = fig.add_subplot(gs[0:2, 1:3])  # 2x2 block

Shared axes (for aligned panels):

ax0 = Plotter.get_grid_subplot(gs, fig, 0)
ax1 = Plotter.get_grid_subplot(gs, fig, 1, sharex=ax0)
ax2 = Plotter.get_grid_subplot(gs, fig, 2, sharex=ax0, sharey=ax0)
# ax1 and ax2 share x-axis with ax0; ax2 also shares y-axis
static get_grid_map(nrows: int, ncols: int) dict[source]

Generate a mapping from panel labels to grid indices.

Useful for referencing panels by name rather than index.

Parameters:
  • nrows (int) – Number of rows.

  • ncols (int) – Number of columns.

Returns:

Mapping with keys:
  • 'by_index': {0: (0, 0), 1: (0, 1), …}
  • 'by_letter': {'a': 0, 'b': 1, …}
  • 'by_rowcol': {(0, 0): 0, (0, 1): 1, …}
  • 'grid': 2D list of indices

Return type:

dict

Examples

>>> gmap = Plotter.get_grid_map(2, 3)
>>> gmap['by_letter']['c']  # Get index for panel 'c'
2
>>> gmap['by_index'][4]  # Get (row, col) for index 4
(1, 1)
>>> gmap['grid']  # 2D layout
[[0, 1, 2], [3, 4, 5]]
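The mapping in the example can be rebuilt with the standard library by assuming row-major order, which the documented values confirm. `grid_map` is a hypothetical name, and its 26-panel limit is this sketch's own restriction.

```python
import string
from typing import Dict, List, Tuple


def grid_map(nrows: int, ncols: int) -> dict:
    """Hypothetical helper: index, letter and (row, col) lookups for a row-major grid."""
    n = nrows * ncols
    if n > len(string.ascii_lowercase):
        raise ValueError("letter labels in this sketch only cover up to 26 panels")
    by_index: Dict[int, Tuple[int, int]] = {i: divmod(i, ncols) for i in range(n)}
    by_letter = {string.ascii_lowercase[i]: i for i in range(n)}
    by_rowcol = {rc: i for i, rc in by_index.items()}
    grid: List[List[int]] = [[r * ncols + c for c in range(ncols)] for r in range(nrows)]
    return {"by_index": by_index, "by_letter": by_letter,
            "by_rowcol": by_rowcol, "grid": grid}
```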
static configure_axes(ax, visible: bool = True, spines: bool | dict | str = True, ticks: bool | dict | str = True, tick_labels: bool | dict | str = True, xlabel: str = None, ylabel: str = None, title: str = None, xscale: str = None, yscale: str = None, xlim: tuple = None, ylim: tuple = None, fontsize: int = None, **kwargs)[source]

Configure axis visibility, spines, ticks, and labels in one call.

This is a convenience function for common axis customizations.

Parameters:
  • ax (matplotlib.axes.Axes) – The axis to configure.

  • visible (bool, default=True) – If False, hide the entire axis (ax.axis(‘off’)).

  • spines (bool, dict, or str, default=True) – Control spine visibility:
    • True: show all spines
    • False: hide all spines
    • 'left': hide all except left
    • 'bottom': hide all except bottom
    • 'minimal': hide top and right (Nature style)
    • dict: {'top': False, 'right': False, …}

  • ticks (bool, dict, or str, default=True) – Control tick visibility:
    • True/False: show/hide all ticks
    • 'x'/'y': show only x/y ticks
    • dict: {'x': True, 'y': False}

  • tick_labels (bool, dict, or str, default=True) – Control tick label visibility (same format as ticks).

  • xlabel (str, optional) – x-axis label.

  • ylabel (str, optional) – y-axis label.

  • title (str, optional) – Axis title.

  • xscale (str, optional) – Axis scale: ‘linear’, ‘log’, ‘symlog’.

  • yscale (str, optional) – Axis scale: ‘linear’, ‘log’, ‘symlog’.

  • xlim (tuple, optional) – Axis limits as (min, max).

  • ylim (tuple, optional) – Axis limits as (min, max).

  • fontsize (int, optional) – Font size for labels.

  • **kwargs – Additional arguments (e.g., labelpad).

Examples

Minimal style (no top/right spines):

Plotter.configure_axes(ax, spines='minimal', xlabel='Time', ylabel='Value')

Hide axis completely (for images/heatmaps):

Plotter.configure_axes(ax, visible=False)

Keep only left spine and y-ticks:

Plotter.configure_axes(ax, spines='left', ticks='y', tick_labels='y')

Log scale with custom limits:

Plotter.configure_axes(ax, yscale='log', ylim=(1e-6, 1e0))

Full configuration:

Plotter.configure_axes(
    ax,
    spines='minimal',
    xlabel=r'$x$ (nm)', ylabel=r'$\rho$ (a.u.)',
    xscale='linear', yscale='log',
    xlim=(0, 100), ylim=(1e-3, 1),
    fontsize=12
)
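The `spines` options listed above can be resolved into per-side visibility flags before they are applied to `ax.spines`. The sketch follows the documented options. `resolve_spines` is hypothetical. Leaving sides that a dict does not mention visible is an assumption.

```python
from typing import Dict, Union

SIDES = ("top", "bottom", "left", "right")


def resolve_spines(spec: Union[bool, str, Dict[str, bool]]) -> Dict[str, bool]:
    """Hypothetical helper: map a `spines` argument to per-side visibility."""
    if spec is True or spec is False:
        return {side: spec for side in SIDES}
    if isinstance(spec, dict):
        # Assumption: sides not mentioned in the dict stay visible.
        return {side: bool(spec.get(side, True)) for side in SIDES}
    if spec == "minimal":
        return {"top": False, "right": False, "bottom": True, "left": True}
    if spec in ("left", "bottom"):
        return {side: side == spec for side in SIDES}  # keep only that side
    raise ValueError(f"unsupported spines spec: {spec!r}")
```

Applying the result is then a loop such as `for side, vis in resolve_spines('minimal').items(): ax.spines[side].set_visible(vis)`.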
static disable_axis(ax, which: str = 'both')[source]

Disable axis components for clean images/heatmaps.

Parameters:
  • ax (matplotlib.axes.Axes) – The axis to modify.

  • which (str, default='both') – What to disable:
    • 'both': disable x and y (full axis off)
    • 'x': disable x-axis only
    • 'y': disable y-axis only
    • 'labels': keep ticks but hide labels
    • 'ticks': keep labels but hide ticks
    • 'spines': hide all spines

Examples

>>> Plotter.disable_axis(ax)  # Completely clean
>>> Plotter.disable_axis(ax, 'x')  # Keep y-axis
>>> Plotter.disable_axis(ax, 'labels')  # Keep ticks, no labels
static get_grid_ax(nrows: int, ncols: int, wspace: float = None, hspace: float = None, width_ratios: List[float] = None, height_ratios: List[float] = None, ax_sub=None, **kwargs) Tuple[GridSpec, list][source]

Get a GridSpec and an empty list for axes (convenience wrapper).

Parameters:
  • nrows (int) – Number of rows in the grid.

  • ncols (int) – Number of columns in the grid.

  • wspace (float, optional) – Horizontal spacing between subplots.

  • hspace (float, optional) – Vertical spacing between subplots.

  • width_ratios (list, optional) – Relative column widths.

  • height_ratios (list, optional) – Relative row heights.

  • ax_sub (SubplotSpec, optional) – For nested grids.

  • **kwargs – Additional GridSpec arguments.

Returns:

(GridSpec, empty_axes_list)

Return type:

tuple

Examples

>>> fig = plt.figure()
>>> gs, axes = Plotter.get_grid_ax(2, 3, wspace=0.3)
>>> for i in range(6):
...     Plotter.app_grid_subplot(axes, gs, fig, i)
static app_grid_subplot(axes: list, gs, fig, index: int, sharex=None, sharey=None, **kwargs)[source]

Append a subplot to an axes list (convenience method).

Parameters:
  • axes (list) – List to append the new axis to.

  • gs (GridSpec) – The GridSpec.

  • fig (Figure) – The figure.

  • index (int) – Grid index.

  • sharex (Axes, optional) – Existing axis whose x-axis the new subplot shares.

  • sharey (Axes, optional) – Existing axis whose y-axis the new subplot shares.

  • **kwargs – Additional arguments.

Examples

>>> gs, axes    = Plotter.get_grid_ax(2, 2)
>>> fig         = plt.figure()
>>> for i in range(4):
...     Plotter.app_grid_subplot(axes, gs, fig, i)
>>> # axes is now [ax0, ax1, ax2, ax3]
static twin_axis(ax, which='y', label='', color='C1', scale='linear', lim=None, fontsize=None, labelpad=0, **kwargs)[source]

Create a twin axis with a secondary scale.

Parameters:
  • ax (matplotlib.axes.Axes) – Primary axis.

  • which (str, default='y') – Which axis to twin: ‘y’ creates twinx(), ‘x’ creates twiny().

  • label (str, default='') – Label for the secondary axis.

  • color (str, default='C1') – Color for the secondary axis (spine, ticks, label).

  • scale (str, default='linear') – Scale for secondary axis: ‘linear’ or ‘log’.

  • lim (tuple, optional) – Limits for the secondary axis.

  • fontsize (int, optional) – Font size for the label.

  • labelpad (float, default=0) – Padding for the label.

  • **kwargs – Additional arguments passed to set_ylabel/set_xlabel.

Returns:

ax2 – The secondary axis.

Return type:

matplotlib.axes.Axes

Examples

>>> ax2 = Plotter.twin_axis(ax, which='y', label='Temperature (K)', color='red')
>>> Plotter.plot(ax2, x, temperature, color='red')
static power_law_guide(ax, x_range, exponent, *, add_label: bool = True, label=None, position='lower right', color='gray', ls='--', lw=1.5, offset_log=0, zorder=3, **kwargs)[source]

Add a power-law guide line to a log-log plot.

Useful for showing scaling behavior (e.g., y ~ x^{-2}).

Parameters:
  • ax (matplotlib.axes.Axes) – Axis with log-log scale.

  • x_range (tuple) – (x_start, x_end) for the guide line.

  • exponent (float) – Power-law exponent (slope in log-log).

  • label (str, optional) – Label (e.g., r'$\sim N^{-2}$'). If None, a label is generated automatically.

  • position (str, default='lower right') – Where to anchor the line: ‘lower right’, ‘upper left’, etc.

  • color (str, default='gray') – Line color.

  • ls (str, default='--') – Line style.

  • lw (float, default=1.5) – Line width.

  • offset_log (float, default=0) – Vertical offset in log10 units.

  • **kwargs – Additional arguments passed to ax.plot.

Returns:

line – The guide line object.

Return type:

Line2D

Examples

>>> # Show y ~ x^{-2} scaling
>>> Plotter.power_law_guide(ax, (10, 1000), -2, label=r'$\sim N^{-2}$')
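A power law y ~ x^p is a straight line of slope p on log-log axes, which is why a two-point guide is enough. The sketch below computes the two endpoints in plain Python, assuming the line passes through a reference point (x0, y0) and that offset_log shifts it vertically by that many decades. The helper names power_law_endpoints and loglog_slope are hypothetical, and the anchoring rule that power_law_guide derives from position is not reproduced here.

```python
import math


def power_law_endpoints(x_range, exponent, x0, y0, offset_log=0.0):
    """Endpoints of y = y0 * (x / x0)**exponent * 10**offset_log (illustrative helper)."""
    x_start, x_end = x_range
    scale = y0 * 10.0 ** offset_log          # vertical shift expressed in log10 units
    y_start = scale * (x_start / x0) ** exponent
    y_end = scale * (x_end / x0) ** exponent
    return (x_start, y_start), (x_end, y_end)


def loglog_slope(p1, p2):
    """Slope of the segment p1 -> p2 measured on log-log axes."""
    (x1, y1), (x2, y2) = p1, p2
    return (math.log10(y2) - math.log10(y1)) / (math.log10(x2) - math.log10(x1))


# Guide for y ~ x^{-2} between x = 10 and x = 1000, anchored at (10, 1.0)
start, end = power_law_endpoints((10, 1000), -2, x0=10, y0=1.0)
slope = loglog_slope(start, end)   # equals the exponent up to rounding
# With Matplotlib, ax.plot([start[0], end[0]], [start[1], end[1]], ls='--') draws the guide.
```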
static get_inset(ax, position=[0.0, 0.0, 1.0, 1.0], add_box=False, box_alpha=0.5, box_ext=0.02, facecolor='white', zorder=1, **kwargs)[source]

Create an inset axis within the given axis.

Parameters:
  • ax (matplotlib.axes.Axes) – The parent axis.

  • position (list) – [x0, y0, width, height] for the inset axis in relative coordinates.

  • add_box (bool, default=False) – Whether to add a semi-transparent box around the inset.

  • box_alpha (float, default=0.5) – Transparency of the box.

  • box_ext (float, default=0.02) – Extension of the box beyond the inset axis.

  • facecolor (str, default='white') – Face color of the box.

  • zorder (int, default=1) – Z-order of the inset axis.

  • **kwargs – Additional arguments passed to fig.add_axes.

Returns:

ax2 – The inset axis.

Return type:

matplotlib.axes.Axes

static set_transparency(ax, alpha=0.0)[source]

Set the background patch transparency for an axis.

static set_legend(ax, fontsize=None, frameon: bool = False, loc: str = 'best', alignment: str = 'left', markerfirst: bool = False, framealpha: float = 1.0, reverse: bool = False, style=None, labelspacing: float = 0.5, handlelength: float = 1.5, handletextpad: float = 0.4, borderpad: float = 0.4, columnspacing: float = 1.0, ncol: int = 1, **kwargs)[source]

Sets the legend with a preferred style for publication-quality plots.

Parameters:
  • ax – Axis to which the legend will be added.

  • fontsize – Font size of the legend labels.

  • frameon (bool, default=False) – Whether to draw a frame around the legend.

  • loc (str, default='best') – Location of the legend (‘best’, ‘upper right’, etc.).

  • alignment (str, default='left') – Text alignment (‘left’, ‘center’, ‘right’).

  • markerfirst (bool, default=False) – If True, the marker appears before the label in each legend entry.

  • framealpha (float, default=1.0) – Transparency of the legend frame (1.0 is opaque).

  • reverse (bool, default=False) – Reverse the order of legend items.

  • style – Predefined style for the legend (‘minimal’, ‘boxed’, etc.).

  • labelspacing (float, default=0.5) – Vertical space between legend entries.

  • handlelength (float, default=1.5) – Length of the legend markers.

  • handletextpad (float, default=0.4) – Space between legend markers and text.

  • borderpad (float, default=0.4) – Padding inside the legend box.

  • columnspacing (float, default=1.0) – Spacing between legend columns.

  • ncol (int, default=1) – Number of columns in the legend.

  • **kwargs – Additional arguments passed to ax.legend().

static set_legend_custom(ax, conditions: list, fontsize=None, frameon=False, loc='best', alignment='left', markerfirst=False, framealpha=1.0, reverse=False, **kwargs)[source]

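Set the legend with custom conditions for the labels.

Parameters:
  • ax – Axis to use.

  • conditions (list) – List of conditions for the labels.

  • fontsize – Font size of the legend labels.

  • frameon (bool, default=False) – Whether to draw a frame around the legend.

  • loc (str, default='best') – Location of the legend.

  • alignment (str, default='left') – Alignment of the legend text.

  • markerfirst (bool, default=False) – Whether the marker appears before the label.

  • framealpha (float, default=1.0) – Transparency of the legend frame.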

static get_subplots(nrows=1, ncols=1, sizex=10.0, sizey=10.0, sizex_def=3, sizey_def=3, annot_x_pos=None, annot_y_pos=None, panel_labels=False, single_if_1=False, share_x=False, share_y=False, width_ratios=None, height_ratios=None, constrained_layout=None, tight_layout=False, layout=None, mosaic=None, spans=None, named_panels=None, **kwargs) Tuple[Figure, AxesList][source]

Create subplot layouts and return a list-like AxesList wrapper.

Parameters:
  • nrows (int, default=1) – Number of grid rows, used for regular subplot creation and for spans.

  • ncols (int, default=1) – Number of grid columns, used for regular subplot creation and for spans.

  • sizex (float or sequence, default=10.0) – Figure width in inches, or a sequence of per-column width ratios.

  • sizey (float or sequence, default=10.0) – Figure height in inches, or a sequence of per-row height ratios.

  • sizex_def (float, default=3) – Inch scaling used when sizex is a ratio sequence.

  • sizey_def (float, default=3) – Inch scaling used when sizey is a ratio sequence.

  • annot_x_pos (float or sequence, optional) – x position(s) of the panel-label annotations, in axes-fraction units.

  • annot_y_pos (float or sequence, optional) – y position(s) of the panel-label annotations, in axes-fraction units.

  • panel_labels (bool or sequence, default=False) – If truthy, annotate each axis with labels (auto or user-provided).

  • single_if_1 (bool, default=False) – If True and only one axis is created, return that axis instead of an AxesList.

  • share_x (bool, default=False) – Share the x-axis across created panels.

  • share_y (bool, default=False) – Share the y-axis across created panels.

  • width_ratios (sequence, optional) – GridSpec column width ratios, overriding ratios inferred from sizex.

  • height_ratios (sequence, optional) – GridSpec row height ratios, overriding ratios inferred from sizey.

  • constrained_layout (bool, optional) – Explicitly control constrained layout engine.

  • tight_layout (bool, default=False) – Call fig.tight_layout() after creation (when compatible).

  • layout (str, optional) – Matplotlib layout engine name (e.g. 'constrained', 'tight').

  • mosaic (subplot-mosaic spec, optional) – Use plt.subplot_mosaic with named panels.

  • spans (dict, optional) – Named span panels on a regular grid. Example: {'main': (0, 2, 0, 3), 'side': (0, 2, 3, 4)}.

  • named_panels (sequence or dict, optional) – Panel aliases for regular grids or mosaic alias remapping.

  • **kwargs (dict) – Forwarded Matplotlib options. Common keys include: dpi, subplot_kw, gridspec_kw, hspace, wspace, left/right/top/bottom, grid, grid_kws, despine, axis_off, suptitle, suptitle_kws, post_hook.

Returns:

  • fig (matplotlib.figure.Figure) – Created figure.

  • axes (AxesList or matplotlib.axes.Axes) – AxesList wrapper (list-compatible, grid-aware, named-panel access). Returns single axis only when single_if_1=True.

Notes

AxesList supports:

  • list operations (iterate, append-like access, slicing)
  • grid indexing: axes[row, col]
  • named access: axes['main']
  • helpers: row(), col(), span(), select(), apply()

Examples

Standard grid:

fig, axes = Plotter.get_subplots(2, 3, sizex=9, sizey=5)
axes[1, 2].plot(x, y)

Named aliases:

fig, axes = Plotter.get_subplots(1, 3, named_panels=['left', 'mid', 'right'])
axes['mid'].set_title('Center')

Mosaic:

fig, axes = Plotter.get_subplots(mosaic=[['A', 'A', 'B'], ['C', 'D', 'D']])
axes['A'].plot(x, y)

Spans:

fig, axes = Plotter.get_subplots(
    nrows=3, ncols=4,
    spans={'main': (0, 2, 0, 3), 'side': (0, 2, 3, 4), 'bottom': (2, 3, 0, 4)},
)
axes['main'].plot(x, y)
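The spans mapping places each named panel on a regular nrows × ncols grid. Assuming each tuple is (row_start, row_stop, col_start, col_stop) with exclusive stops like Python slices (this matches the example, where 'main', 'side', and 'bottom' tile a 3 × 4 grid exactly), the conversion reduces to building slice pairs and checking bounds and overlaps. spans_to_slices is a hypothetical helper for illustration, not part of the Plotter API.

```python
from typing import Dict, Tuple

Span = Tuple[int, int, int, int]


def spans_to_slices(spans: Dict[str, Span], nrows: int, ncols: int) -> Dict[str, Tuple[slice, slice]]:
    """Convert {name: (r0, r1, c0, c1)} into (row_slice, col_slice) pairs (sketch)."""
    occupied = {}  # (row, col) -> panel name, used to detect overlapping spans
    out = {}
    for name, (r0, r1, c0, c1) in spans.items():
        if not (0 <= r0 < r1 <= nrows and 0 <= c0 < c1 <= ncols):
            raise ValueError(f"span {name!r} is outside the {nrows}x{ncols} grid")
        for r in range(r0, r1):
            for c in range(c0, c1):
                if (r, c) in occupied:
                    raise ValueError(f"span {name!r} overlaps {occupied[(r, c)]!r}")
                occupied[(r, c)] = name
        out[name] = (slice(r0, r1), slice(c0, c1))
    return out


slices = spans_to_slices(
    {"main": (0, 2, 0, 3), "side": (0, 2, 3, 4), "bottom": (2, 3, 0, 4)},
    nrows=3, ncols=4,
)
# With Matplotlib, each pair could index a GridSpec: fig.add_subplot(gs[rows, cols])
```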

static subplots(*args, **kwargs)[source]

Alias of Plotter.get_subplots().

static subplot_mosaic(mosaic, *args, **kwargs)[source]

Convenience alias for mosaic layouts.

Equivalent to: Plotter.get_subplots(mosaic=mosaic, *args, **kwargs)

static save_fig(directory: str, filename: str, format='pdf', dpi=200, adjust=True, fig=None, **kwargs)[source]

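Save a figure to a specific directory.

Parameters:
  • directory (str) – Directory in which to save the file.

  • filename (str) – Name of the file.

  • format (str, default='pdf') – File format.

  • dpi (int, default=200) – Resolution in dots per inch.

  • adjust (bool, default=True) – Whether to adjust the figure before saving.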

static savefig(directory, filename, format, dpi, adjust, fig=None, **kwargs)[source]

Alias for save_fig() with the historical lowercase name.

static plot_heatmaps(dfs: list, colormap='viridis', cb_width=0.1, movefirst=True, index=None, columns=None, values=None, sortidx=True, zlabel='', sizemult=3, xvals=True, yvals=True, vmin=None, vmax=None, **kwargs)[source]

Plot a sequence of pivoted DataFrame heatmaps on a shared figure.

class general_python.common.plot.PlotterSave[source]

Bases: object

File-output helpers for simple plot-adjacent data artifacts.

static dict2json(directory: str, fileName: str, data)[source]

Save a dictionary to a JSON file.

Parameters:
  • directory (str) – Directory in which to save the file.

  • fileName (str) – Name of the file.

  • data (dict) – Dictionary to save.

static json2dict(directory: str, fileName: str) dict[source]

Load a dictionary from a JSON file.

static json2dict_multiple(directory: str, keys: list)[source]

Load several dictionaries from JSON files, one per key. Each key is also used as the name of the corresponding file.
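The key-as-file-name pattern can be sketched with the standard json module. This assumes files are stored as <key>.json inside directory; how PlotterSave handles the extension is not documented here, and dict_to_json and json_to_dict_multiple are hypothetical stand-ins for dict2json and json2dict_multiple.

```python
import json
import os
import tempfile


def dict_to_json(directory: str, file_name: str, data: dict) -> str:
    """Write data to <directory>/<file_name>.json and return the path (sketch)."""
    os.makedirs(directory, exist_ok=True)
    path = os.path.join(directory, f"{file_name}.json")
    with open(path, "w", encoding="utf-8") as fh:
        json.dump(data, fh, indent=2)
    return path


def json_to_dict_multiple(directory: str, keys: list) -> dict:
    """Load <directory>/<key>.json for every key; the key doubles as the file name."""
    result = {}
    for key in keys:
        with open(os.path.join(directory, f"{key}.json"), encoding="utf-8") as fh:
            result[key] = json.load(fh)
    return result


# Round trip through a throwaway directory
with tempfile.TemporaryDirectory() as tmp:
    dict_to_json(tmp, "run_a", {"L": 8, "J": 1.0})
    dict_to_json(tmp, "run_b", {"L": 12, "J": 0.5})
    loaded = json_to_dict_multiple(tmp, ["run_a", "run_b"])
```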

static singleColumnData(directory: str, fileName: str, y, typ='.npy')[source]

Stores the values as a single vector

static twoColumnsData(directory: str, fileName: str, x, y, typ='.npy')[source]

Stores the x, y vectors in 2D form (multiple rows and two columns)

static matrixData(directory: str, fileName: str, x, y, typ='.npy')[source]

Stores the x, y data in matrix form, with the x values prepended as the first column.

static app_df(df, colname: str, y, fill_value=nan)[source]

Appends the data to the dataframe.

Parameters:
  • df (pd.DataFrame) – The dataframe to append data to.

  • colname (str) – The column name to append data under.

  • y (array-like) – The data to append.

  • fill_value – The value used for filling if resizing is needed.

static app_array(arr, y)[source]

Appends the data to a numpy array.

Parameters:
  • arr (np.ndarray) – The numpy array to append data to.

  • y (np.ndarray) – The data to append.

Returns:

The updated numpy array with appended data.

Return type:

np.ndarray

class general_python.common.plot.MatrixPrinter[source]

Bases: object

Helper class for printing matrices and vectors.

static print_matrix(matrix: ndarray)[source]

Print a matrix in a readable form.

static print_vector(vector: ndarray)[source]

Print a vector in a readable form.

static print_matrices(matrices: list)[source]

Print a list of matrices in a readable form.

static print_vectors(vectors: list)[source]

Print a list of vectors in a readable form.

This module provides a DataHandler class for handling and processing data arrays. It includes methods for filtering, initializing, interpolating, aggregating, concatenating, and averaging data arrays. Classes:

DataHandler: A class containing static methods for data handling and processing.

general_python.common.datah._filter_typical_values(current_x, current_y, typical, threshold=1.0) tuple
general_python.common.datah._initialize_combined_arrays(y_list, x_list, typical, threshold=1.0) tuple
general_python.common.datah._interpolate_and_update(x_combined, y_combined, current_x, current_y, divider) tuple
general_python.common.datah._aggregate_and_update(x_combined, y_combined, current_x, current_y, divider) tuple

Aggregates and updates combined x and y data arrays with current x and y data arrays.

general_python.common.datah.concat_and_average(y_list, x_list, typical=False, use_interpolation=True, threshold=1.0) tuple
general_python.common.datah.concat_and_fill(y_list, x_list, lengths, missing_val=np.nan) tuple
class general_python.common.datah.DataHandler[source]

Bases: object

DataHandler class provides static methods for handling and processing data arrays, including filtering, interpolating, aggregating, concatenating, and cutting matrices based on specific criteria.

_filter_typical_values(current_x, current_y, typical, threshold=1.0) tuple

_initialize_combined_arrays(y_list, x_list, typical, threshold=1.0) tuple[source]

Initializes and combines arrays from given lists. If the typical flag is set to True, it filters the combined arrays to include only elements where the values in y_combined are less than the threshold.

_interpolate_and_update(x_combined, y_combined, current_x, current_y, divider) tuple[source]
_aggregate_and_update(x_combined, y_combined, current_x, current_y, divider) tuple[source]

Aggregates and updates combined x and y data arrays with current x and y data arrays by summing common bins and appending unique bins.

concat_and_average(y_list, x_list, typical=False, use_interpolation=True, threshold=1.0) tuple[source]
concat_and_fill(y_list, x_list, lengths, missing_val=np.nan) tuple[source]
cut_matrix_bad_vals_zero(M, axis=0, tol=1e-9, check_limit: float | None = 10) np.ndarray

Cuts off the slices (along any specified axis) in matrix M where all elements are close to zero.

cut_matrix_bad_vals(M, axis=0, threshold=-1e4, check_limit=None) np.ndarray[source]

Cuts off the rows or columns in matrix M where the first check_limit elements are all below a threshold.

static concat_and_average(y_list, x_list, typical=False, use_interpolation=True, threshold=1.0)[source]

Concatenates and averages y values across multiple histograms.

Parameters:
  • y_list – List of y matrices (each one corresponding to a realization).

  • x_list – List of x vectors (each one corresponding to a realization).

  • typical – If True, keep only y values less than threshold.

  • use_interpolation – If True, interpolate y values for non-matching bins. If False, aggregate only exact matches and append unique bins.

  • threshold – The threshold value for filtering y values (default: 1.0).

Returns:

Combined y values and x bins after averaging.
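With use_interpolation=False, values at bins shared by several realizations are summed, unique bins are appended, and the result is averaged. The pure-Python sketch below covers that path for one-dimensional y per realization. It assumes each bin is divided by the number of realizations that contributed to it and that the typical filter is applied per value before aggregation; the library's divider argument may follow a different rule. aggregate_exact_bins is a hypothetical name.

```python
from collections import defaultdict
from typing import List, Sequence, Tuple


def aggregate_exact_bins(y_list: List[Sequence[float]], x_list: List[Sequence[float]],
                         typical: bool = False, threshold: float = 1.0) -> Tuple[list, list]:
    """Average y over realizations, matching only identical x bins (illustrative)."""
    sums = defaultdict(float)
    counts = defaultdict(int)
    for x, y in zip(x_list, y_list):
        for xi, yi in zip(x, y):
            if typical and not yi < threshold:
                continue              # keep only values below the threshold
            sums[xi] += yi            # common bins are summed ...
            counts[xi] += 1           # ... and unique bins simply get their first entry
    x_combined = sorted(sums)
    y_combined = [sums[xi] / counts[xi] for xi in x_combined]
    return y_combined, x_combined


y_avg, x_bins = aggregate_exact_bins([[1.0, 3.0], [5.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]])
```

Here bin 1.0 appears in both realizations and is averaged, while bins 0.0 and 2.0 are carried over unchanged.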

static concat_and_fill(y_list, x_list, lengths, missing_val=nan)[source]

Concatenates y values across multiple histograms, combines x vectors into a single sorted array, and fills missing values.

Parameters:
  • y_list – List of y arrays (each one corresponding to a realization).

  • x_list – List of x arrays (each one corresponding to a realization group).

  • lengths – List indicating how many y arrays correspond to each x array.

  • missing_val – Value to fill for missing data points after interpolation (default: np.nan).

Returns:

A 2D NumPy array of y values interpolated to a common x grid and the combined x bins.

static cut_matrix_bad_vals_zero(M, axis=0, tol=1e-09, check_limit: float | None = 10)[source]

Cut off the slices (along any specified axis) in matrix M where all elements are close to zero. If a 1D vector is provided, it returns the vector unless all elements are close to zero, in which case it returns an empty array.

Parameters:
  • M (numpy.ndarray) – The input matrix or vector.

  • axis (int) – The axis along which to check for zero elements, for example 0 for rows and 1 for columns. Ignored if M is a 1D vector.

  • tol (float) – The tolerance for considering elements as zero.

  • check_limit (int) – The maximum number of elements along the axis to check for zeros.

Returns:

The matrix with the near-zero slices (along the specified axis) removed; for a 1D input, the vector itself, or an empty array if all of its elements are close to zero.

Return type:

numpy.ndarray
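The slice test maps naturally onto NumPy reductions. This sketch assumes check_limit caps how many leading elements of each slice are inspected and that a slice is dropped when every inspected element satisfies |value| <= tol; it illustrates the documented behavior and is not the DataHandler source. cut_zero_slices is a hypothetical name.

```python
import numpy as np


def cut_zero_slices(M, axis=0, tol=1e-9, check_limit=10):
    """Drop slices along `axis` whose first `check_limit` entries are all ~0 (sketch)."""
    M = np.asarray(M)
    if M.ndim == 1:
        # 1D input: keep the vector unless every element is close to zero
        return M[:0] if np.all(np.abs(M) <= tol) else M
    # One row per slice along `axis`, remaining dimensions flattened
    flat = np.moveaxis(M, axis, 0).reshape(M.shape[axis], -1)
    inspected = flat if check_limit is None else flat[:, :int(check_limit)]
    is_zero = np.all(np.abs(inspected) <= tol, axis=1)
    return np.compress(~is_zero, M, axis=axis)


M = np.array([[0.0, 0.0, 0.0],
              [1.0, 0.0, 0.0],
              [0.0, 0.0, 1e-12]])
rows_kept = cut_zero_slices(M, axis=0)   # only the middle row survives
cols_kept = cut_zero_slices(M, axis=1)   # only the first column survives
```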

static cut_matrix_bad_vals(M, axis=0, threshold=-10000.0, check_limit=None)[source]

Cut off the rows or columns in matrix M where the first check_limit elements are all below a threshold.

Parameters:
  • M (numpy.ndarray) – The input matrix.

  • axis (int) – The axis along which to check for elements below the threshold (0 for rows, 1 for columns).

  • threshold (float) – The threshold value.

  • check_limit (int, optional) – The number of elements to check from each row or column.

Returns:

The matrix after removing the rows or columns whose first check_limit elements are all below the threshold.

Return type:

numpy.ndarray

Small logging helpers with verbosity, indentation, color, and file output.

The Logger wrapper keeps the public API lightweight while avoiding duplicate notebook handlers and optionally mirroring logs to a file. Console colors can be disabled with PYLOGCOLORS=0; file logging is enabled by setting PYLOGFILE to a non-empty path-like value.

email : maxgrom97@gmail.com

class general_python.common.flog.Logger(name: str = 'Global', logfile: str | None = None, lvl: int = 20, append_ts: bool = False, use_ts_in_cmd: bool = False)

Bases: object

Logger class for handling console and file logging with verbosity control.

LEVELS = {10: 'debug', 20: 'info', 30: 'warning', 40: 'error'}
LEVELS_R = {'debug': 10, 'error': 40, 'info': 20, 'warning': 30}
__init__(name: str = 'Global', logfile: str | None = None, lvl: int = 20, append_ts: bool = False, use_ts_in_cmd: bool = False)

Initialize the logger instance.

Parameters:
  • logfile (str) – Name of the log file, without extension; if empty, a timestamp is used.

  • lvl (int) – Logging level (default: logging.INFO).

  • append_ts (bool) – Whether to append a timestamp to the log file name (default: False).

  • use_ts_in_cmd (bool) – Whether to use a timestamp in console output (default: False).

static breakline(n: int)

Print multiple break lines.

Parameters:

n (int) – Number of break lines.

static colorize(txt: str, color: str)

Apply color to the given text (for console output).

Parameters:
  • txt (str) – Text to colorize.

  • color (str) – Color name.

Returns:

Colorized text.

Return type:

str

configure(directory: str)

Configure the logger to use a specific directory for log files.

Parameters:

directory (str) – Path to the directory where log files will be stored.

dbg(msg: str, lvl=0, verbose=True, color=None)

Alias for debug().

debug(msg: str, lvl=0, verbose=True, color=None)

Log a debug message if verbosity is enabled.

Parameters:
  • msg (str) – Message to log.

  • lvl (int) – Indentation level.

  • verbose (bool) – Log if True (default: True).

  • color (str) – Optional color for the message.

classmethod endl(n: int)

Print n blank lines through the logger break-line helper.

err(msg: str, lvl=0, verbose=True, color='red')

Alias for error().

error(msg: str, lvl=0, verbose=True, color='red')

Log an error message if verbosity is enabled.

Parameters:
  • msg (str) – Message to log.

  • lvl (int) – Indentation level.

  • verbose (bool) – Log if True (default: True).

  • color (str) – Optional color for the message.

inf(msg: str, lvl=0, verbose=True, color=None)

Alias for info().

info(msg: str, lvl=0, verbose=True, color=None)

Log an informational message if verbosity is enabled.

Parameters:
  • msg (str) – Message to log.

  • lvl (int) – Indentation level.

  • verbose (bool) – Log if True (default: True).

static print(msg: str, lvl=0)

Format a message with a timestamp.

Parameters:
  • msg (str) – Message to format.

  • lvl (int) – Indentation level.

Returns:

Formatted message.

Return type:

str

static print_tab(lvl=0)

Generate indentation for message formatting.

Parameters:

lvl (int) – Number of tab characters (indentation level).

Returns:

Indented string.

Return type:

str

say(*args, end=True, log=20, lvl=0, verbose=True, color=None)

Print and log multiple messages if verbosity is enabled.

Parameters:
  • *args – Messages to log.

  • end (bool) – Append newline (default: True).

  • log (int) – Log level (10: debug, 20: info, 30: warning, 40: error) (default: 20).

  • lvl (int) – Indentation level.

  • verbose (bool) – Log if True (default: True).

timing(func)

Decorator to measure and log the execution time of functions.

Parameters:

func – Function to be timed.

Use as:

@logger.timing
def my_function(…):
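A timing decorator of this kind can be sketched with the standard library alone. Unlike Logger.timing, which is a method bound to a Logger instance, this illustration is a module-level function that reports through the stdlib logging module; its message format is an assumption, not the one Logger uses.

```python
import functools
import logging
import time

logger = logging.getLogger("timing-sketch")


def timing(func):
    """Log how long each call to func takes (stdlib sketch of a timing decorator)."""
    @functools.wraps(func)  # keep func.__name__ and __doc__ on the wrapper
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        try:
            return func(*args, **kwargs)
        finally:
            # runs even if func raises, so failed calls are timed too
            elapsed = time.perf_counter() - start
            logger.info("%s took %.6f s", func.__name__, elapsed)
    return wrapper


@timing
def my_function(n):
    return sum(range(n))
```

functools.wraps matters here: without it, every decorated function would report its name as "wrapper".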

title(tail: str, desired_size: int = 50, fill: str = '=', lvl=0, verbose=True, color=None)

Create a formatted title with filler characters if verbosity is enabled.

Parameters:
  • tail (str) – Text in the middle of the title.

  • desired_size (int) – Total width of the title.

  • fill (str) – Character used for filling.

  • lvl (int) – Indentation level.

  • verbose (bool) – Log if True (default: True).

  • color (str) – Optional color for the title.

warn(msg: str, lvl=0, verbose=True, color='yellow')

Alias for warning().

warning(msg: str, lvl=0, verbose=True, color='yellow')

Log a warning message if verbosity is enabled.

Parameters:
  • msg (str) – Message to log.

  • lvl (int) – Indentation level.

  • verbose (bool) – Log if True (default: True).

class general_python.common.flog.Colors(color: str)

Bases: object

Class for defining ANSI colors for console output. This class provides a way to colorize text in the terminal using ANSI escape codes.

black

ANSI escape code for black text.

Type:

str

red

ANSI escape code for red text.

Type:

str

green

ANSI escape code for green text.

Type:

str

yellow

ANSI escape code for yellow text.

Type:

str

blue

ANSI escape code for blue text.

Type:

str

white

ANSI escape code for resetting color to default.

Type:

str

__call__(text: str) str

Apply the color to the given text.

Parameters:

text (str) – Text to colorize.

Returns:

Colorized text.

Return type:

str

__init__(color: str)
__len__() int

Get the length of the color string.

Returns:

Length of the color string.

Return type:

int

black = '\x1b[30m'
blue = '\x1b[34m'
green = '\x1b[32m'
red = '\x1b[31m'
white = '\x1b[0m'
yellow = '\x1b[33m'
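The escape codes above are standard ANSI SGR sequences: 30 to 34 select the foreground color and 0 resets all attributes, which is why white acts as the reset code. The class below is a hypothetical stand-in, not the library's Colors class; it assumes __call__ wraps the text and appends the reset code, and that __len__ measures the escape sequence rather than the color name.

```python
class AnsiColor:
    """Minimal Colors-like helper (hypothetical, for illustration only)."""
    black = "\x1b[30m"
    red = "\x1b[31m"
    green = "\x1b[32m"
    yellow = "\x1b[33m"
    blue = "\x1b[34m"
    white = "\x1b[0m"  # reset code, matching the attribute table above

    def __init__(self, color: str):
        # look the escape code up by name, e.g. "red" -> "\x1b[31m"
        self.color = getattr(type(self), color)

    def __call__(self, text: str) -> str:
        # wrap the text and reset afterwards so later output is unaffected
        return f"{self.color}{text}{self.white}"

    def __len__(self) -> int:
        return len(self.color)


colored = AnsiColor("red")("hi")
```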
general_python.common.flog.printV(what: str, v=True, tabulators=0)

Prints the message only if verbosity is enabled.

general_python.common.flog.printJust(file, sep='\t', elements=[], width=8, endline=True, scientific=False)

Print a list of elements as indented columns. The separator can also be used to clean up the indentation.

Parameters:
  • width – Width of each column.

  • endline – If True, write a newline after the last element.

  • scientific – If True, print numbers in scientific notation.

general_python.common.flog.printDictionary(d: dict) str

Returns a formatted string representation of a dictionary.

general_python.common.flog.print_arguments(parser, logger: Logger | None = None, title: str = 'Options for the script', columnsize: int = 30) None

Print the arguments of a parser in a formatted table using the provided logger.

Parameters:
  • parser (argparse.ArgumentParser) – The argument parser containing the script’s options.

  • logger (Optional[Logger]) – The logger instance used to print the formatted output.

  • title (str, optional) – The title to display above the options table. Defaults to “Options for the script”.

  • columnsize (int, optional) – The width of the “Option” column in the table. Defaults to 30.

Returns:

None

general_python.common.flog.log_timing_summary(logger: Logger, phase_durations: Dict[str, float], total_duration: float | None = None, title: str = 'Timing Summary', phase_col_width: int = 18, duration_col_width: int = 14, duration_precision: int = 4, lvl: int = 0, add_total_row: bool = True, extra_info: List[str] | None = None)

Logs a timing summary in a tabular format using the provided logger.

Parameters:
  • logger (Logger) – Logger instance used to log the timing summary.

  • phase_durations (Dict[str, float]) – Dictionary mapping phase names to their durations.

  • total_duration (float, optional) – Total duration of the process. If provided and add_total_row is True, a “Total” row is added.

  • title (str) – Title for the summary table.

  • phase_col_width (int) – Width of the ‘Phase’ column.

  • duration_col_width (int) – Width of the ‘Duration (s)’ column.

  • duration_precision (int) – Decimal precision for duration values.

  • lvl (int) – Base logging level for the summary.

  • add_total_row (bool) – Whether to include a ‘Total’ row using total_duration.

  • extra_info (List[str], optional) – Optional list of strings to log after the table (e.g., notes, performance).
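The column-width and precision parameters amount to fixed-width string formatting. The sketch below only builds the table lines, which log_timing_summary would then send to a Logger; the rule lines and header text are assumptions about the layout, and lvl and extra_info are left out. format_timing_summary is a hypothetical name.

```python
from typing import Dict, List, Optional


def format_timing_summary(phase_durations: Dict[str, float], total_duration: Optional[float] = None,
                          title: str = "Timing Summary", phase_col_width: int = 18,
                          duration_col_width: int = 14, duration_precision: int = 4,
                          add_total_row: bool = True) -> List[str]:
    """Return the lines of a two-column 'Phase | Duration (s)' table (sketch)."""
    def row(phase: str, seconds: float) -> str:
        # left-align the phase name, right-align the duration with fixed precision
        return f"{phase:<{phase_col_width}}{seconds:>{duration_col_width}.{duration_precision}f}"

    width = phase_col_width + duration_col_width
    lines = [title, "-" * width,
             f"{'Phase':<{phase_col_width}}{'Duration (s)':>{duration_col_width}}",
             "-" * width]
    lines += [row(name, secs) for name, secs in phase_durations.items()]
    if add_total_row and total_duration is not None:
        lines += ["-" * width, row("Total", total_duration)]
    return lines


lines = format_timing_summary({"setup": 0.5, "solve": 2.25}, total_duration=2.75)
```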

general_python.common.flog.get_global_logger(**kwargs) Logger

Return one Logger wrapper per process (PID); safe across threads and forks. The banner is printed only once per program, using an environment-variable sentinel.

Parameters:
  • **kwargs – Arguments to pass to the Logger constructor, for example:

  • name (str) – Name of the logger (default: “Global”).

  • lvl (int) – Logging level (default: logging.INFO).

  • append_ts (bool) – Whether to append timestamps (default: True).

  • use_ts_in_cmd (bool) – Whether to use timestamps in console output (default: True).

  • logfile (str, optional) – Path to a log file (default: None).

Returns:

The global logger instance.

Return type:

Logger

Example

>>> logger = get_global_logger()
>>> logger.info("This is an informational message.")
>>> logger.debug("This is a debug message.", color='blue')
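The per-process singleton can be sketched with a lock-protected dictionary keyed by os.getpid(). After a fork, the child inherits the dictionary but has a new PID, so it creates its own instance, and an environment variable set in the parent is inherited by child processes, which keeps the banner from repeating. The factory argument, the sentinel name EXAMPLE_LOGGER_BANNER_SHOWN, and get_process_logger are hypothetical; the real get_global_logger takes only **kwargs.

```python
import os
import threading

_LOCK = threading.Lock()
_INSTANCES = {}                                # pid -> logger-like object
_BANNER_ENV = "EXAMPLE_LOGGER_BANNER_SHOWN"    # hypothetical sentinel variable


def get_process_logger(factory, **kwargs):
    """Return one object per process id, created with factory(**kwargs) on first use."""
    pid = os.getpid()
    with _LOCK:                                # serialize creation across threads
        instance = _INSTANCES.get(pid)
        if instance is None:
            instance = factory(**kwargs)
            _INSTANCES[pid] = instance
            if not os.environ.get(_BANNER_ENV):
                os.environ[_BANNER_ENV] = "1"  # inherited by children, so the banner shows once
                print("logger initialized")
    return instance


first = get_process_logger(dict, name="Global")
second = get_process_logger(dict, name="Other")  # same process: the cached object is returned
```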

binary.py

This module implements binary-manipulation routines that allow you to work with binary representations of integers or vectors.

It includes functions to:

  • check if a number is a power of 2,
  • check a given bit in an integer or an indexable vector,
  • convert an integer to a base representation (spin: +/- value or binary 0/1),
  • convert a base representation back to an integer, a binary string, or a vector of values,
  • flip bits (all at once or a single bit),
  • reverse the bits of a 64-bit integer,
  • extract bits (by ordinal position or via a mask),
  • prepare a bit mask from a list of positions.

Functions are written so that they work with plain Python integers as well as with NumPy or JAX arrays. You can choose the backend (np or jnp) by passing the corresponding module.



general_python.common.binary.ctz64(x: uint64) int64[source]

Count trailing zeros in a 64-bit unsigned integer (Numba-safe).

Returns 64 if x == 0 (no bits set). Uses binary search, O(log bits).

Parameters:

x – 64-bit unsigned integer

Returns:

Number of trailing zero bits (0-64)

Example

>>> ctz64(np.uint64(8))  # 0b1000 -> 3 trailing zeros
3
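The binary search halves the candidate width at each step: if the lowest 32 bits are all zero, shift them out and add 32, then repeat with 16, 8, 4, 2, and 1. The sketch below uses plain Python integers rather than np.uint64 under Numba, so it illustrates the technique rather than reproducing the library's kernel; ctz64_py is a hypothetical name. The identity (x & -x).bit_length() - 1 serves as a cross-check.

```python
MASK64 = (1 << 64) - 1


def ctz64_py(x: int) -> int:
    """Trailing zeros of a 64-bit value via binary search over bit widths (sketch)."""
    x &= MASK64
    if x == 0:
        return 64                      # no bits set, matching the documented convention
    n = 0
    for width in (32, 16, 8, 4, 2, 1):
        low_mask = (1 << width) - 1
        if (x & low_mask) == 0:        # the lowest `width` bits are all zero
            x >>= width
            n += width
    return n
```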
general_python.common.binary.popcount64(x: uint64) int64[source]

Count the number of set bits in a 64-bit integer (Numba-safe).

Uses a parallel bit-counting algorithm, O(1).

Parameters:

x – 64-bit unsigned integer

Returns:

Number of set bits (0-64)

Example

>>> popcount64(np.uint64(0b1011))
3
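A common parallel bit-counting scheme sums bits in 2-bit, then 4-bit, then 8-bit groups, and finally adds all bytes with one multiplication. The sketch below applies that scheme to plain Python integers, masking to 64 bits wherever a fixed-width type would wrap. It is an illustration of the technique, not necessarily the exact sequence used by popcount64, and popcount64_py is a hypothetical name.

```python
M1 = 0x5555555555555555   # 0101... : every other bit
M2 = 0x3333333333333333   # 0011... : pairs of bits
M4 = 0x0F0F0F0F0F0F0F0F   # 00001111... : nibbles
H01 = 0x0101010101010101  # one in every byte
MASK64 = 0xFFFFFFFFFFFFFFFF


def popcount64_py(x: int) -> int:
    """Parallel (bit-slicing) popcount of a 64-bit value (sketch)."""
    x &= MASK64
    x -= (x >> 1) & M1                     # each 2-bit field holds its own bit count
    x = (x & M2) + ((x >> 2) & M2)         # each 4-bit field holds a count
    x = (x + (x >> 4)) & M4                # each byte holds a count (at most 8)
    return ((x * H01) & MASK64) >> 56      # top byte accumulates the sum of all bytes
```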
general_python.common.binary.mask_from_indices(idxs: ndarray) uint64[source]

Convert array of bit indices to a bitmask (Numba-safe).

Parameters:

idxs – Array of indices (int64) indicating which bits to set

Returns:

64-bit mask with bits set at given indices

Example

>>> mask_from_indices(np.array([0, 2, 3], dtype=np.int64))
np.uint64(13)  # 0b1101
general_python.common.binary.indices_from_mask(mask: uint64) ndarray[source]

Convert bitmask to array of set bit indices (Numba-safe).

Returns indices in ascending order. Uses ctz64 for efficiency.

Parameters:

mask – 64-bit mask

Returns:

Array of indices (int64) where bits are set

Example

>>> indices_from_mask(np.uint64(13))  # 0b1101
array([0, 2, 3], dtype=int64)
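Both conversions reduce to short loops. Building the mask ORs in one bit per index. Recovering the indices repeatedly isolates the lowest set bit with mask & -mask, whose position equals the trailing-zero count (the role ctz64 plays in the library), and then clears it. This sketch uses plain Python integers and lists instead of np.uint64 and int64 arrays; the _py helper names are hypothetical.

```python
from typing import Iterable, List


def mask_from_indices_py(idxs: Iterable[int]) -> int:
    """Set bit i for every index i (sketch of mask_from_indices)."""
    mask = 0
    for i in idxs:
        mask |= 1 << int(i)
    return mask


def indices_from_mask_py(mask: int) -> List[int]:
    """Ascending indices of the set bits (sketch of indices_from_mask)."""
    out = []
    while mask:
        low = mask & -mask                 # isolate the lowest set bit
        out.append(low.bit_length() - 1)   # its index == number of trailing zeros
        mask ^= low                        # clear it and continue
    return out
```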
general_python.common.binary.complement_mask(mask: uint64, ns: int) uint64[source]

Return the complement of a mask within ns bits.

Parameters:
  • mask – Original bitmask

  • ns – Number of bits in the system (1-64)

Returns:

Complement mask (bits flipped within range [0, ns))
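Complementing within ns bits means flipping every bit and then discarding positions at or above ns. With Python integers, ~mask is negative, so the final AND with the ns-bit mask is what brings the result back into range. In a fixed-width 64-bit type, 1 << 64 does not fit, so an implementation like complement_mask presumably special-cases ns == 64; that detail is an assumption. complement_mask_py is a hypothetical name.

```python
def complement_mask_py(mask: int, ns: int) -> int:
    """Flip the lowest ns bits of mask and clear everything above them (sketch)."""
    if not 1 <= ns <= 64:
        raise ValueError("ns must be in 1..64")
    full = (1 << ns) - 1      # ns low bits set
    return ~mask & full       # flip, then discard bits at positions >= ns
```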

general_python.common.binary.complement_indices(n: int, indices: ndarray) ndarray[source]

Return the indices in [0, n) that are not in indices.

Uses an O(n) boolean scratch array with minimal allocations.

Parameters:
  • n – Upper bound of the range (exclusive)

  • indices – Input indices to exclude

Returns:

Array of complementary indices (sorted)

Example

>>> complement_indices(5, np.array([1, 3]))
array([0, 2, 4], dtype=int64)
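The boolean-scratch approach marks every excluded index in a length-n array and then collects the unmarked positions, which yields sorted output in O(n) time. A plain-Python sketch, with complement_indices_py as a hypothetical name and out-of-range indices silently ignored as an assumption:

```python
from typing import Iterable, List


def complement_indices_py(n: int, indices: Iterable[int]) -> List[int]:
    """Indices in [0, n) that are not in `indices`, in ascending order (sketch)."""
    taken = [False] * n                 # O(n) boolean scratch
    for i in indices:
        if 0 <= i < n:                  # ignore indices outside the range
            taken[i] = True
    return [i for i in range(n) if not taken[i]]
```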
general_python.common.binary.check_int(n, k)[source]

Checks if the k-th bit in the binary representation of an integer n is set (1).

Parameters:
  • n (int) – The integer to check.

  • k (int) – The position of the bit to check (0-indexed, from the right).

Returns:

A non-zero value if the k-th bit is set, otherwise 0.

Return type:

int

general_python.common.binary.popcount(n: int, spin: bool = True, backend: str = 'default')[source]

Calculate the number of 1-bits in the binary representation of an integer.

Parameters:

n (int) – The integer whose 1-bits are to be counted.

Returns:

The number of 1-bits in the binary representation of the input integer.

Return type:

int

general_python.common.binary.int2base(n: int, size: int, backend='default', spin: bool = True, spin_value: float = 0.5, out: ndarray | Array | None = None)[source]

Convert an integer to a base representation (spin: +/- value or binary 0/1).

Parameters:
  • n (int) – The integer to convert.

  • size (int) – The number of bits in the binary representation.

  • backend (str or module, default='default') – The backend to use (np, jnp, or 'default').

  • spin (bool) – If True, return spin values (+/- spin_value); otherwise return binary 0/1 values.

  • spin_value (float) – Magnitude of the spin values.

Returns:

The binary representation of the integer.

Return type:

np.ndarray or jnp.ndarray

general_python.common.binary.base2int(vec: ndarray | Array, spin: bool = True, spin_value: float = 0.5) int64[source]

Convert a base representation back to an integer.

Parameters:
  • vec (np.ndarray or jnp.ndarray) – The binary representation of the integer.

  • spin (bool) – A flag to indicate whether to use spin values.

  • spin_value (float) – The spin value to use.

Returns:

The integer representation of the binary vector.

Return type:

int
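The two conversions are inverses of each other. The sketch below uses Python lists and makes two assumptions that the text above does not settle: bits are listed most significant first, and in the spin representation a set bit maps to +spin_value and a cleared bit to -spin_value. The library's actual ordering and sign conventions may differ. int2base_py and base2int_py are hypothetical names.

```python
from typing import List, Sequence


def int2base_py(n: int, size: int, spin: bool = True, spin_value: float = 0.5) -> List[float]:
    """Bits of n, most significant first (assumed), as 0/1 or +/-spin_value."""
    bits = [(n >> (size - 1 - k)) & 1 for k in range(size)]
    if not spin:
        return [float(b) for b in bits]
    return [spin_value if b else -spin_value for b in bits]   # assumed: 1 -> +, 0 -> -


def base2int_py(vec: Sequence[float], spin: bool = True) -> int:
    """Inverse of int2base_py under the same conventions."""
    n = 0
    for v in vec:
        bit = (v > 0) if spin else (v != 0)
        n = (n << 1) | int(bit)         # append the next bit on the right
    return n
```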

general_python.common.binary.flip_all(n: ndarray | Array, size: int, spin: bool = True, spin_value: float = 0.5, backend='default')[source]

Flip all bits in a representation of a state.

Parameters:
  • n (int or array-like) – The state to flip.

  • size (int) – The number of bits, i.e. the size of the state.

  • spin (bool) – A flag to indicate whether to use spin values.

  • spin_value (float) – The spin value to use.

  • backend – The backend to use (np, jnp, or 'default').

Note

The function works for integers as well as NumPy and JAX arrays (np or jnp), in both binary and spin representations.

Returns:

A flipped state.
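Flipping every bit has a simple form in each representation. For an integer, XOR with size ones flips all bits. For a spin vector, each entry is negated, whatever the value of spin_value. For a 0/1 vector, each entry becomes 1 - v. The sketch uses Python lists; the assumption that flip_all follows exactly these rules is labeled, and the helper names are hypothetical.

```python
def flip_all_int(n: int, size: int) -> int:
    """Flip the lowest `size` bits of an integer state (sketch)."""
    return n ^ ((1 << size) - 1)        # XOR with `size` ones flips every bit


def flip_all_vec(vec, spin: bool = True):
    """Flip a vector state: negate spins, or swap 0 <-> 1 in binary form (assumed rules)."""
    if spin:
        return [-v for v in vec]        # +s <-> -s
    return [1 - v for v in vec]         # 0 <-> 1
```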

general_python.common.binary.rev(n: ndarray | Array, size: int, bitsize=64, backend='default')[source]

Reverse the bits of a 64-bit integer.

Parameters:
  • n (int) – The integer to reverse.

  • size (int) – The number of bits in the integer.

Returns:

The integer with the bits reversed.

Return type:

int
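Bit reversal over a fixed width can be sketched as follows. The sketch assumes that only the lowest size bits are reversed (the bitsize argument of rev is not modelled), and rev_sketch is a hypothetical name.

```python
def rev_sketch(n: int, size: int) -> int:
    """Hypothetical: reverse the lowest `size` bits of n."""
    result = 0
    for _ in range(size):
        # Move the lowest remaining bit of n into the lowest position of result.
        result = (result << 1) | (n & 1)
        n >>= 1
    return result
```

For instance, 0b1101 with size=4 becomes 0b1011; applying the reversal twice returns the original value.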

general_python.common.binary.rotate_left(n: ndarray | Array, size: int, backend='default', axis: int | None = None)[source]

Rotate the bits of an integer to the left.

Parameters:
  • n (int) – The integer to rotate.

  • size (int) – The number of bits in the integer.

  • backend (str) – The backend to use (np or jnp or default).

  • axis (int) – The axis along which to rotate the bits.

Returns:

The integer with the bits rotated to the left.

Return type:

int

general_python.common.binary.rotate_right(n: ndarray | Array, size: int, backend='default', axis: int | None = None)[source]

Rotate the bits of an integer to the right.

Parameters:
  • n (int) – The integer to rotate.

  • size (int) – The number of bits in the integer.

  • backend (str) – The backend to use (np or jnp or default).

  • axis (int) – The axis along which to rotate the bits (only for arrays - Optional).

Returns:

The integer with the bits rotated to the right.

Return type:

int

general_python.common.binary.rotate_left_by(n: ndarray | Array, size: int, shift: int = 1, backend='default', axis: int | None = None)[source]

Rotate the bits of n to the left by ‘shift’ positions.

For an integer state, it treats n as having a binary representation of length ‘size’ and performs a cyclic left shift:

result = ((n << shift) | (n >> (size - shift))) & ((1 << size) - 1)

For array-like states (NumPy or JAX arrays), it calls the backend’s roll function along the specified axis with a negative shift.

Parameters:
  • n (int or array-like) – The state to rotate.

  • size (int) – The number of bits in the integer representation.

  • shift (int, optional) – The number of positions to shift. Default is 1.

  • backend (str, optional) – The backend to use (e.g., ‘default’, ‘numpy’, or ‘jax’).

  • axis (int, optional) – The axis along which to rotate for array-like states.

Returns:

Rotated state (same type as n).
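The integer formula and the array roll described above can be checked against each other in a short sketch. The helper names are hypothetical. The sketch assumes 0 <= n < 2**size and 0 <= shift <= size. The right rotation is the standard mirror of the documented left rotation rather than something taken from this page.

```python
import numpy as np


def rotate_left_by_int(n: int, size: int, shift: int = 1) -> int:
    # Cyclic left shift restricted to `size` bits (the formula documented above).
    mask = (1 << size) - 1
    return ((n << shift) | (n >> (size - shift))) & mask


def rotate_right_by_int(n: int, size: int, shift: int = 1) -> int:
    # Mirror image of the left rotation.
    mask = (1 << size) - 1
    return ((n >> shift) | (n << (size - shift))) & mask


state = np.array([1, 0, 0, 1])  # bits of 9, most-significant bit first
rolled = np.roll(state, -1)     # negative shift == left rotation -> bits of 3
```

With MSB-first arrays, np.roll(state, -shift) agrees with rotate_left_by_int: 9 = 0b1001 rotates to 3 = 0b0011.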

general_python.common.timer

Timing utilities for benchmarking and profiling.

This module provides:

  • A Timer class for context-manager and decorator-based timing with support for laps and deadlines.

  • Functional wrappers (timeit) for simple benchmarking.

  • JAX-aware synchronization helpers to ensure accurate timing of asynchronous operations.

Features:

  • High-precision monotonic clock (nanoseconds).

  • Automatic synchronization for JAX arrays (blocks until ready).

  • Detailed reporting with laps and mean/std stats.

email : maxgrom97@gmail.com

class general_python.common.timer.TimerState(*values)[source]

Bases: Enum

Lifecycle states reported by Timer.state.

RUNNING = 'running'
PAUSED = 'paused'
STOPPED = 'stopped'
class general_python.common.timer.Timer(name: str | None = None, logger: Logger | None = None, logger_args: Dict[str, Any] | None = None, verbose: bool = False, unit: str = 'auto', deadline_s: float | None = None, synchronizer: Callable[[Any], None] | None = None)[source]

Bases: object

Enhanced timer class for measuring elapsed time.

This class can be used as a context manager, a decorator, or directly to time code. It supports:

  • Starting, stopping, and resetting the timer.

  • Recording multiple laps.

  • Verbose output to automatically print timing information.

name

Optional name to identify the timer.

Type:

str

verbose

If True, prints timing information on stop.

Type:

bool

format

Optional format for the output timing information.

name: str | None
logger: Logger | None
logger_args: Dict[str, Any] | None
verbose: bool
unit: str
deadline_s: float | None
synchronizer: Callable[[Any], None] | None
start() Timer[source]

Start (or resume) the timer; no-op if already running.

pause() Timer[source]

Pause the timer, accumulating elapsed time.

resume() Timer[source]

Resume after pause.

stop() float[source]

Stop and return elapsed time in seconds.

reset() Timer[source]

Clear state (elapsed, laps, marks) and stop.

lap(name: str | None = None) float[source]

Record a lap (time since the last lap, or since start if no lap exists yet) and return its duration in seconds.

mark(name: str | None = None) None[source]

Create or update a named anchor at the current time; query it later with since(‘name’).

since(name: str | None = None, ts: int | None = None) float[source]

Seconds elapsed since the named mark. Raises KeyError if the mark is not set.

elapsed_ns() int[source]

Total elapsed nanoseconds (includes current running span).

elapsed_ms() float[source]

Elapsed milliseconds (float).

elapsed_us() float[source]

Elapsed microseconds (float).

elapsed_s() float[source]

Elapsed seconds (float).

laps() Tuple[List[float], List[str]][source]

Recorded laps (seconds) and their names.

remaining_s(buffer_s: float = 0.0) float | None[source]

If deadline_s is set, return remaining seconds (can be negative). Otherwise None.

overtime(buffer_s: float = 0.0) bool[source]

True if elapsed >= deadline_s - buffer_s; False if no deadline is set.
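The lap and deadline rules documented above can be illustrated with a tiny stand-alone timer. LapTimerSketch is hypothetical and much simpler than Timer (no pause, marks, or synchronization). It assumes time.perf_counter_ns as the monotonic nanosecond clock.

```python
import time
from typing import List, Optional


class LapTimerSketch:
    """Hypothetical minimal timer illustrating lap() and overtime() semantics."""

    def __init__(self, deadline_s: Optional[float] = None) -> None:
        self.deadline_s = deadline_s
        self._start_ns = time.perf_counter_ns()  # monotonic, nanosecond resolution
        self._last_lap_ns = self._start_ns
        self.laps: List[float] = []

    def elapsed_s(self) -> float:
        return (time.perf_counter_ns() - self._start_ns) / 1e9

    def lap(self) -> float:
        # Lap duration = time since the previous lap (or since start).
        now = time.perf_counter_ns()
        dt = (now - self._last_lap_ns) / 1e9
        self._last_lap_ns = now
        self.laps.append(dt)
        return dt

    def overtime(self, buffer_s: float = 0.0) -> bool:
        # Documented rule: elapsed >= deadline_s - buffer_s; False without a deadline.
        if self.deadline_s is None:
            return False
        return self.elapsed_s() >= self.deadline_s - buffer_s
```

A positive buffer_s makes overtime() fire earlier, which is useful for stopping a long loop with time to spare before a hard deadline.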

property state: TimerState

Current timer lifecycle state.

format_elapsed() str[source]

Return elapsed time formatted in the configured display unit.

report(include_laps: bool = True) str[source]

Build a human-readable timing report.

Parameters:

include_laps – Include named lap timings when any have been recorded.

classmethod decorator(name: str | None = None, logger: Logger | None = None, verbose: bool = False, unit: str = 'auto', deadline_s: float | None = None, synchronizer: Callable[[Any], None] | None = None)[source]

Decorator for timing a function.

Usage:

    @Timer.decorator("block", verbose=True)
    def fn(...):
        ...

Parameters:
  • name (str, optional) – The name of the timer (default: function name).

  • logger (Logger, optional) – Optional logger for timing output (default: None).

  • verbose (bool) – If True, print timing info (default: False).

  • unit (str) – Time unit for reporting (default: “auto”).

  • deadline_s (float, optional) – Optional deadline in seconds (default: None).

  • synchronizer (callable, optional) – Optional synchronizer function (default: None).

__init__(name: str | None = None, logger: Logger | None = None, logger_args: Dict[str, Any] | None = None, verbose: bool = False, unit: str = 'auto', deadline_s: float | None = None, synchronizer: Callable[[Any], None] | None = None) None
general_python.common.timer.timeit(fn: Callable[[...], Any], *args, **kwargs) Tuple[Any, float][source]

Functional wrapper to time a callable.

Usage:

res, dt = timeit(my_function, arg1, arg2)

general_python.common.timer.benchmark(name: str = 'Block', sync: bool = True)[source]

Context manager for timing blocks of code.

Usage:

    with benchmark("Gradient Step") as t:
        train_step()
    print(t.elapsed)
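A context manager with the same shape as benchmark can be sketched with contextlib.contextmanager. This is an assumption-laden illustration: BlockTiming and benchmark_sketch are hypothetical, and sync is modelled as an optional callable rather than the documented bool. With JAX, one might pass something like lambda: out.block_until_ready() so asynchronous work finishes before the clock is read.

```python
import time
from contextlib import contextmanager
from typing import Callable, Iterator, Optional


class BlockTiming:
    """Hypothetical result holder yielded by benchmark_sketch."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.elapsed = 0.0  # seconds, filled in when the block exits


@contextmanager
def benchmark_sketch(name: str = "Block",
                     sync: Optional[Callable[[], None]] = None) -> Iterator[BlockTiming]:
    timing = BlockTiming(name)
    start = time.perf_counter()
    try:
        yield timing
    finally:
        # Asynchronous backends (e.g. JAX) may still be computing; wait before reading the clock.
        if sync is not None:
            sync()
        timing.elapsed = time.perf_counter() - start
```

Reading the clock only after sync() returns is what makes timings of asynchronously dispatched work meaningful.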

Lattices Module

Lattice factory and registry for geometry-aware simulations.

The package provides canonical lattice classes (square, triangular, honeycomb, hexagonal, graph) together with registry helpers for custom lattices.

Input/output contracts

Factory functions return instances of Lattice subclasses with explicit geometry metadata (dimensions, boundary conditions, primitive vectors, and neighbor maps). Typical constructor inputs are integer sizes (lx, ly, lz), a boundary mode, and optional flux or graph descriptors.

Shape and dtype expectations

Coordinate arrays are expected as real-valued arrays with shape (ns, dim). Index-based neighbor structures are integer arrays or lists over site ids in [0, ns). Plotting helpers consume NumPy-compatible arrays.

Numerical stability and determinism

Topology construction is deterministic for fixed parameters. Floating-point roundoff can affect reciprocal-space formatting or plotting labels but should not change connectivity.

email : maxgrom97@gmail.com

license : MIT

version : 1.0

class general_python.lattices.BoundaryFlux(values: Mapping[LatticeDirection, float])[source]

Bases: object

Collection of magnetic fluxes piercing lattice boundary loops.

The value associated with a direction is interpreted as the phase phi (in radians) acquired upon wrapping around the boundary once along that direction. The corresponding hopping phase factor is exp(1j * phi).

The fluxes are stored as a mapping from LatticeDirection to the flux phi in radians; the complex phase factor exp(1j * phi) is computed on demand by phase().

Options for specifying fluxes include:

  • Uniform flux in all directions (single float value).

  • Direction-specific fluxes (mapping from direction to float).

  • Zero flux (empty mapping).

Physically, these fluxes correspond to magnetic fluxes threading the holes of a torus formed by periodic boundary conditions.

Example

>>> flux = BoundaryFlux({LatticeDirection.X: np.pi/2, LatticeDirection.Y: np.pi})
>>> flux.phase(LatticeDirection.X)
(6.123233995736766e-17+1j)
>>> flux.phase(LatticeDirection.Y)
(-1+1.2246467991473532e-16j)
>>> flux.is_trivial
False

For non-abelian gauge fields, more complex structures are needed.

values: Mapping[LatticeDirection, float]
phase(direction: LatticeDirection, winding: int = 1) complex[source]

Return exp(1j * winding * phi_direction).

Parameters:
  • direction (LatticeDirection) – The lattice direction for which to get the phase factor.

  • winding (int, optional) – The winding number for the phase factor. Defaults to 1.

phase_product(wx: int = 0, wy: int = 0, wz: int = 0) complex[source]

Return total phase from combined winding numbers in all directions.

Returns \(\exp(i (w_x \phi_x + w_y \phi_y + w_z \phi_z))\).

get(direction: LatticeDirection) float[source]

Return raw flux (radians) for direction, defaulting to 0.

property is_trivial: bool

True when all fluxes are effectively zero (mod 2π).

A flux equal to \(2\pi n\) for integer \(n\) is considered trivial because the hopping phase reduces to unity.
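The phase conventions of BoundaryFlux can be reproduced with the standard library. In this sketch plain strings stand in for LatticeDirection and a dict stands in for values. The function names are hypothetical, and the triviality tolerance is an assumption.

```python
import cmath
import math
from typing import Dict

# Stand-in for BoundaryFlux.values: direction name -> phi in radians.
Flux = Dict[str, float]


def phase(flux: Flux, direction: str, winding: int = 1) -> complex:
    # exp(1j * winding * phi_direction); a missing direction means zero flux.
    return cmath.exp(1j * winding * flux.get(direction, 0.0))


def phase_product(flux: Flux, wx: int = 0, wy: int = 0, wz: int = 0) -> complex:
    # exp(i (wx*phi_x + wy*phi_y + wz*phi_z)).
    total = wx * flux.get("X", 0.0) + wy * flux.get("Y", 0.0) + wz * flux.get("Z", 0.0)
    return cmath.exp(1j * total)


def is_trivial(flux: Flux, tol: float = 1e-12) -> bool:
    # math.remainder gives the signed distance of phi to the nearest multiple of 2*pi.
    return all(abs(math.remainder(phi, 2 * math.pi)) <= tol for phi in flux.values())
```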

property is_nontrivial: bool

True when any direction carries a non-zero flux (mod 2π).

property total_flux: float

Sum of all flux values (in radians).

k_shift_fractions(Lx: int = 1, Ly: int = 1, Lz: int = 1) Tuple[float, float, float][source]

Return the fractional k-grid offset induced by the boundary fluxes.

With flux \(\phi_\mu\) in direction \(\mu\) of length \(L_\mu\), the allowed Bloch momenta shift from \(f_\mu = n_\mu / L_\mu\) to \(f_\mu + \phi_\mu / (2\pi L_\mu)\).

Returns:

(delta_fx, delta_fy, delta_fz) – Fractional coordinate shifts (add to the standard grid).

Return type:

tuple[float, float, float]
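The momentum shift formula can be checked numerically. The helpers below are hypothetical; they only evaluate phi_mu / (2 pi L_mu) and the shifted one-dimensional grid.

```python
import math
from typing import List, Sequence, Tuple


def k_shift_fractions_sketch(phi: Sequence[float], lengths: Sequence[int]) -> Tuple[float, ...]:
    # delta f_mu = phi_mu / (2 * pi * L_mu), one entry per direction.
    return tuple(p / (2 * math.pi * L) for p, L in zip(phi, lengths))


def allowed_fractions_1d(L: int, phi: float) -> List[float]:
    # Standard grid n / L (n = 0..L-1) plus the flux-induced offset.
    shift = phi / (2 * math.pi * L)
    return [n / L + shift for n in range(L)]
```

A flux of 2π moves every allowed fraction by exactly one grid step (1/L), so the set of momenta modulo 1 is unchanged, consistent with such a flux being trivial.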

as_dict() Dict[str, float][source]

Return a plain {direction_name: flux} dictionary.

as_array() ndarray[source]

Return fluxes as [phi_x, phi_y, phi_z] array (radians).

classmethod zero() BoundaryFlux[source]

Create a BoundaryFlux with zero flux in all directions.

classmethod uniform(phi: float) BoundaryFlux[source]

Create a BoundaryFlux with phi (rad) in every direction.

__bool__() bool[source]

True when the flux is non-trivial (has physical effect).

__init__(values: Mapping[LatticeDirection, float]) None
class general_python.lattices.Lattice(dim: int = None, lx: int = 1, ly: int = 1, lz: int = 1, bc: str = None, adj_mat: ndarray = None, flux: ndarray = None, *args, **kwargs)[source]

Bases: ABC

Abstract Base Class for defining lattice structures.

This class serves as the foundation for all lattice implementations in the lattices module. It handles geometry, connectivity, boundary conditions, and k-space properties.

Indexing Convention

Lattice sites are indexed linearly from 0 to Ns - 1. The mapping from spatial coordinates to linear index depends on the concrete implementation, but typically follows a row-major (lexicographic) order:

  • 1D: Left to right.

  • 2D: Bottom-left to top-right (x varies fastest).

  • 3D: Front-bottom-left to back-top-right.
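For a Bravais lattice with one site per cell, the row-major convention with x varying fastest can be sketched as below. This is only the typical mapping described above; concrete subclasses (especially multi-sublattice ones) may order sites differently, and the helper names are hypothetical.

```python
from typing import Tuple


def site_index(x: int, y: int, z: int, Lx: int, Ly: int) -> int:
    # Hypothetical lexicographic map: x fastest, then y, then z.
    return x + Lx * (y + Ly * z)


def site_coords(i: int, Lx: int, Ly: int) -> Tuple[int, int, int]:
    # Inverse of site_index.
    x = i % Lx
    y = (i // Lx) % Ly
    z = i // (Lx * Ly)
    return x, y, z
```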

Features

  • Geometry: Calculation of real-space coordinates, unit vectors, and basis vectors.

  • Connectivity: Automatic identification of Nearest Neighbors (NN) and Next-Nearest Neighbors (NNN).

  • Boundaries: Support for various boundary conditions:

    • PBC: Periodic Boundary Conditions (torus topology). X-direction periodic, Y-direction periodic, Z-direction periodic.

    • OBC: Open Boundary Conditions (hard edges). X-direction open, Y-direction open, Z-direction open.

    • MBC: Mixed Boundary Conditions (e.g., cylinder topology). X-direction periodic, Y-direction open, Z-direction open.

    • SBC: Switched Boundary Conditions (e.g. twisted cylinder). X-direction open, Y-direction periodic, Z-direction open.

    • TWISTED: Twisted Boundary Conditions with specified fluxes.

  • Reciprocal Space: Automatic calculation of reciprocal lattice vectors and Brillouin Zone paths.

  • Visualization: Integration with plotting utilities via .plot.

Ns

Total number of sites in the lattice.

Type:

int

dim

Spatial dimension of the lattice (1, 2, or 3).

Type:

int

Lx, Ly, Lz

Linear dimensions of the lattice.

Type:

int

bc

Active boundary condition.

Type:

LatticeBC

coordinates

Array of shape (Ns, 3) containing real-space coordinates of all sites.

Type:

np.ndarray

nn

Adjacency list for nearest neighbors. nn[i] is a list of neighbors for site i.

Type:

List[List[int]]

property bad_lattice_site

Bad lattice site

a = 1
b = 1
c = 1
unit_length = 1
__init__(dim: int = None, lx: int = 1, ly: int = 1, lz: int = 1, bc: str = None, adj_mat: ndarray = None, flux: ndarray = None, *args, **kwargs)[source]

Initialize the general lattice model.

Parameters:
  • dim (int, optional) – Dimension of the lattice (1, 2, or 3). If None, inferred from lx, ly, lz.

  • lx (int, optional) – Length of the lattice in the x-direction.

  • ly (int, optional) – Length of the lattice in the y-direction.

  • lz (int, optional) – Length of the lattice in the z-direction.

  • bc (str, optional) – Boundary conditions (e.g., ‘PBC’, ‘OBC’).

  • adj_mat (np.ndarray, optional) – Adjacency matrix for the lattice.

  • flux (float or mapping, optional) – Flux piercing the boundaries: either a dictionary specifying the flux in each direction or a single value applied to all directions. Passing a flux automatically implies TWISTED boundary conditions, so the bc parameter can be left as None or set to ‘TWISTED’ for clarity.

__str__()[source]

String representation of the lattice

__repr__()[source]

Representation of the lattice

__len__()[source]

Length of the lattice (number of sites)

__getitem__(index: int)[source]

Get the site at the given index

__iter__()[source]

Iterate over the lattice sites

__contains__(item: int)[source]

Check if the lattice contains the given site

init(verbose: bool = False, *, force_dft: bool = False, **kwargs)[source]

Initializes the lattice object by calculating coordinates, reciprocal vectors, and neighbor lists.

This method performs the following steps:

  1. Calculates the real-space coordinates, r-vectors, and k-vectors of the lattice.

  2. If the number of sites (self.Ns) is less than 100, computes the discrete Fourier transform (DFT) matrix.

  3. If an adjacency matrix (self._adj_mat) is provided:

      • Determines the number of sites (Ns) from the adjacency matrix.

      • For each site, identifies nearest neighbors (nn) as those connected by the highest weight in the adjacency matrix, and next-nearest neighbors (nnn) as those connected by the next highest distinct weight.

      • Stores forward neighbors (indices greater than the current site) for both nn and nnn.

  4. If no adjacency matrix is provided, calculates nearest and next-nearest neighbors using default methods.

  5. Calculates normalization or symmetry properties of the lattice.

This method sets up all necessary neighbor lists and lattice properties required for further computations.
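Step 3 (neighbor extraction from a weighted adjacency matrix) can be sketched in NumPy. This is a hypothetical re-implementation, not the library's code: it assumes positive weights, treats exactly equal weights as one level, and returns plain nested lists.

```python
from typing import List, Tuple

import numpy as np

Neighbors = List[List[int]]


def neighbors_from_adjacency(adj, tol: float = 1e-12) -> Tuple[Neighbors, Neighbors, Neighbors, Neighbors]:
    adj = np.asarray(adj, dtype=float)
    nn, nnn, nn_fwd, nnn_fwd = [], [], [], []
    for i in range(adj.shape[0]):
        row = adj[i].copy()
        row[i] = 0.0  # ignore self-loops
        # Distinct positive weights of site i, largest first.
        weights = np.unique(row[row > tol])[::-1]
        first = np.flatnonzero(np.isclose(row, weights[0])).tolist() if weights.size > 0 else []
        second = np.flatnonzero(np.isclose(row, weights[1])).tolist() if weights.size > 1 else []
        nn.append(first)
        nnn.append(second)
        # Forward neighbors keep only indices greater than the current site.
        nn_fwd.append([j for j in first if j > i])
        nnn_fwd.append([j for j in second if j > i])
    return nn, nnn, nn_fwd, nnn_fwd
```

On a four-site ring with weight 1 on ring bonds and 0.5 on the diagonals, site 0 gets nn = [1, 3] and nnn = [2].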

get_region(kind: str | RegionType = RegionType.HALF, *, origin: int | List[float] | None = None, radius: float | None = None, direction: str | None = None, sublattice: int | None = None, sites: List[int] | None = None, depth: int | None = None, plaquettes: List[int] | None = None, **kwargs) List[int][source]

Return a list of site indices defining a spatial region.

Parameters:
  • kind (str or RegionType) – Type of region: ‘half’, ‘disk’, ‘sublattice’, ‘graph’, ‘plaquette’, ‘custom’. We also support specific half cuts like ‘half_x’, ‘half_y’, ‘half_z’ for convenience.

  • origin (int or list[float], optional) – Center of the region. Can be a site index or coordinate vector.

  • radius (float, optional) – Radius for ‘disk’ regions.

  • direction (str, optional) – Direction for ‘half’ cuts (‘x’, ‘y’, ‘z’).

  • sublattice (int, optional) – Sublattice index for ‘sublattice’ regions.

  • sites (list[int], optional) – Explicit list of sites for ‘custom’ regions.

  • depth (int, optional) – Depth/distance for ‘graph’ regions.

  • plaquettes (list[int], optional) – List of plaquette indices for ‘plaquette’ regions.

Returns:

Sorted list of site indices belonging to the region.

Return type:

list[int]
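For a 'disk' region, one plausible selection rule is every site whose coordinate lies within radius of origin. The sketch below assumes an inclusive boundary and ignores periodic images; disk_region is a hypothetical helper.

```python
from typing import List

import numpy as np


def disk_region(coords, origin, radius: float, tol: float = 1e-12) -> List[int]:
    coords = np.asarray(coords, dtype=float)
    # origin may be a site index or an explicit coordinate vector.
    if isinstance(origin, (int, np.integer)):
        center = coords[origin]
    else:
        center = np.asarray(origin, dtype=float)
    dist = np.linalg.norm(coords - center, axis=1)
    return sorted(np.flatnonzero(dist <= radius + tol).tolist())
```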

get_entropy_cuts(cut_type: str = 'all', *, include_sublattice: bool = True, sweep_by_unit_cell: bool | None = None) Dict[str, List[int]][source]

Return canonical bipartition cuts for entanglement-entropy workflows.

This is a convenience wrapper around self.regions.get_entropy_cuts().

generate_regions(kind: str | RegionType = RegionType.KITAEV_PRESKILL, **kwargs)[source]

Generate many region candidates for a selected region type.

This is a thin wrapper around self.regions.generate_regions().

property lx
property Lx
property ly
property Ly
property lz
property Lz
property area
property volume
property lxly
property lxlz
property lylz
property lxlylz
property dim
property sites
property size
property nsites
property ns
property Ns
property sites_per_cell: int

Sites per unit cell (1 for Bravais, 2 for honeycomb, etc.).

symmetry_perms(point_group: str = 'full') ndarray[source]

Generate space-group permutation table for this lattice.

Delegates to generate_space_group_perms().

When TWISTED boundary conditions are active, the point-group part is disabled (only translations are returned) because a generic flux breaks point-group symmetry unless the flux respects it.

Parameters:

point_group (str) – 'full' for maximal point group, 'translations' for translations only.

Return type:

ndarray, shape (|G|, Ns)
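For translations only, a permutation table of shape (|G|, Ns) on a periodic square lattice can be sketched as follows. It assumes the x-fastest site index and the convention that perms[g, i] is the image of site i under g. The library's row ordering and convention may differ.

```python
import numpy as np


def translation_perms(Lx: int, Ly: int) -> np.ndarray:
    # Hypothetical convention: perms[g, i] is the image of site i under translation g.
    ns = Lx * Ly
    perms = np.empty((ns, ns), dtype=np.int64)
    for ty in range(Ly):
        for tx in range(Lx):
            g = tx + Lx * ty
            for y in range(Ly):
                for x in range(Lx):
                    # Periodic wrap in both directions; x varies fastest.
                    perms[g, x + Lx * y] = (x + tx) % Lx + Lx * ((y + ty) % Ly)
    return perms
```

Each row is a permutation of range(Ns), row 0 is the identity, and composing any two rows gives another row, as expected for the group Z_Lx × Z_Ly.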

lattice_symmetries() Dict[str, object][source]

Return a dictionary describing the spatial symmetries of this lattice.

The information is consistent for both single-particle and many-body representations. When TWISTED boundary conditions are present the point-group part is absent (flux generically breaks it).

Returns:

Keys:

  • 'lattice_type' : LatticeType enum

  • 'sites_per_cell' : int

  • 'n_cells' : number of unit cells

  • 'dim' : spatial dimension

  • 'bc' : boundary condition enum

  • 'is_periodic' : (bool, bool, bool) per direction

  • 'is_twisted' : bool

  • 'translation_group' : Z_Lx × Z_Ly (as tuple (Lx, Ly))

  • 'point_group' : str or None ('D4' for square Lx==Ly, etc.)

  • 'space_group_order' : total number of space-group elements

  • 'flux' : BoundaryFlux or None

Return type:

dict

symmetry_info() str[source]

Return a human-readable summary of the lattice symmetries.

Consistent for both single-particle (band-structure / Bloch) and many-body (Hilbert-space symmetry sectors) viewpoints.

Return type:

str

property a1
property a2
property a3
property k1
property b1
property k2
property b2
property k3
property b3
property n1
property n2
property n3
property basis
property multipartity
property vectors
property avec
property bvec
property dft

Return the discrete Fourier transform (DFT) matrix for the lattice.
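A lattice DFT matrix can be sketched from the k-vectors and r-vectors. The sign and the 1/sqrt(Ns) normalization are assumptions, chosen so that the matrix is unitary and, for a one-dimensional chain, coincides with NumPy's FFT. lattice_dft is a hypothetical name.

```python
import numpy as np


def lattice_dft(kvectors: np.ndarray, rvectors: np.ndarray) -> np.ndarray:
    # Assumed convention: D[k, r] = exp(-i k . r) / sqrt(Ns).
    ns = rvectors.shape[0]
    phases = kvectors @ rvectors.T
    return np.exp(-1j * phases) / np.sqrt(ns)


L = 6
r = np.arange(L, dtype=float).reshape(-1, 1)           # chain sites at x = 0..L-1
k = (2 * np.pi * np.arange(L) / L).reshape(-1, 1)      # allowed momenta 2*pi*n/L
D = lattice_dft(k, r)
```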

property nn

Return the nearest-neighbor connectivity matrix for the lattice.

property bonds

Return the bond connectivity matrix for the lattice.

property nn_forward

Return the forward nearest-neighbor connectivity matrix for the lattice.

property nnn

Return the next-nearest-neighbor connectivity matrix for the lattice.

property nnn_forward

Return the forward next-nearest-neighbor connectivity matrix for the lattice.

property coordinates

Return the real-space coordinates of the lattice sites.

property subs

Return the sublattice indices of the lattice sites. For a Bravais lattice, this would simply be an array of zeros. For a non-Bravais lattice, this would indicate which sublattice each site belongs to.

property cells

Return the unit cell coordinates of the lattice sites. For a Bravais lattice, this would simply be the integer coordinates of the unit cells. For a non-Bravais lattice, this would include the basis vectors as well.

property fracs

Return fractional coordinates of the lattice sites. For example, for a square lattice these would be (x/Lx, y/Ly, z/Lz) for each site.

property kvectors

Return the allowed k-vectors in reciprocal space for the lattice.

property rvectors

Return the allowed r-vectors in real space for the lattice.

property bc
property bc_x
property bc_y
property bc_z
property cardinality
property name
property type
sublattice(site: int) int[source]

Return the sublattice index for a given site. By default, returns 0 for all sites (single sublattice). Override in subclasses for multi-sublattice lattices.

k_vector(qx, qy=0.0, qz=0.0) ndarray[source]

Return the k-vector in Cartesian coordinates for given (qx, qy, qz) in reciprocal lattice units.

k_grid(n_k: int | Tuple[int, int, int], shift: bool | Tuple[bool, bool, bool] | None = None) ndarray[source]

Generate a full k-point grid for the given lattice.

Parameters:
  • lattice (Lattice) – Lattice object with reciprocal lattice vectors _k1, _k2, _k3.

  • n_k (Iterable[int]) –

    Number of points (Lx, Ly, Lz) along each reciprocal direction.

    We define the k-points as: k = f1 * b1 + f2 * b2 + f3 * b3, where f_i = n_i / N_i, with n_i = 0, 1, …, N_i - 1.

Returns:

k_points – Cartesian coordinates of k-points in reciprocal space.

Return type:

np.ndarray, shape (Nk, dim)
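The grid definition above (k = f1 b1 + f2 b2 + f3 b3 with f_i = n_i / N_i) can be sketched directly. k_grid_sketch is hypothetical: reciprocal vectors are passed as the rows of an array, and the point ordering (first direction slowest) is an assumption.

```python
import numpy as np


def k_grid_sketch(b_vectors, n_k) -> np.ndarray:
    # b_vectors: (dim, dim) array whose rows are the reciprocal vectors b1, b2, ...
    b_vectors = np.asarray(b_vectors, dtype=float)
    dim = b_vectors.shape[0]
    n_k = (n_k,) * dim if isinstance(n_k, int) else tuple(n_k)
    # Fractions f_i = n_i / N_i with n_i = 0, ..., N_i - 1.
    axes = [np.arange(N) / N for N in n_k]
    fracs = np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1).reshape(-1, dim)
    # k = f1 * b1 + f2 * b2 (+ f3 * b3).
    return fracs @ b_vectors


b = 2 * np.pi * np.eye(2)   # square lattice with lattice constant 1
kpts = k_grid_sketch(b, (4, 2))
```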

extract_momentum(eigvecs: ndarray, *, eigvals: ndarray = None, tol: float = 1e-10) ndarray[source]

Extract crystal momentum vectors k from real-space eigenvectors.

wigner_seitz_extend(k_points: ndarray, data: ndarray | None = None, *, copies: int | Iterable[int] | None = None, **kwargs) Tuple[ndarray, ndarray | None][source]

Extend k-space points and optional data across translated Brillouin zones.

The helper works for arbitrary k-space dimensions and any number of reciprocal translation vectors. Legacy b1/b2/b3 with nx/ny/nz remain supported for existing callers.

It can be used to generate extended k-point grids for plotting band structures along high-symmetry paths…

Parameters:
  • k_points (ndarray, shape (N, dim)) – Array of k-points in reciprocal space to be extended.

  • data (ndarray, shape (N, ...) or None) – Optional data associated with each k-point (e.g. eigenvalues) to be extended alongside the k-points. Must have the same leading dimension as k_points.

  • copies (int or iterable of ints, optional) – Number of translated copies to generate in each reciprocal direction. If an integer is provided, the same number of copies will be generated in all directions. If an iterable is provided, it should have a length equal to the number of reciprocal lattice vectors (e.g. 3 for 3D), specifying the number of copies in each direction separately.

  • **kwargs – Additional keyword arguments to pass to the underlying ws_extend function. See its documentation for details.

Returns:

  • extended_k_points (ndarray, shape (M, dim)) – Array of extended k-points in reciprocal space, including the original points and their translated copies

  • extended_data (ndarray, shape (M, …) or None) – Extended data associated with each k-point, if the input data was provided. Otherwise, None
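Extending a set of k-points by reciprocal-lattice translations can be sketched as follows. This assumes the data are periodic in reciprocal space, so each translated copy simply repeats the original data. ws_extend_sketch and its argument layout are hypothetical and do not reproduce the ws_extend options.

```python
import itertools

import numpy as np


def ws_extend_sketch(k_points, b_vectors, data=None, copies=1):
    k_points = np.asarray(k_points, dtype=float)
    b_vectors = np.asarray(b_vectors, dtype=float)
    n_b = b_vectors.shape[0]
    copies = (copies,) * n_b if isinstance(copies, int) else tuple(copies)
    ranges = [range(-c, c + 1) for c in copies]
    # Every integer combination n1*b1 + n2*b2 + ... within the requested copies.
    shifts = np.array([np.dot(n, b_vectors) for n in itertools.product(*ranges)])
    # (S, 1, dim) + (1, N, dim) -> (S, N, dim) -> (S*N, dim)
    ext_k = (shifts[:, None, :] + k_points[None, :, :]).reshape(-1, k_points.shape[1])
    ext_data = None
    if data is not None:
        # Translated copies carry identical data; repeat along the leading axis.
        ext_data = np.concatenate([np.asarray(data)] * len(shifts), axis=0)
    return ext_k, ext_data
```

With copies=1 in two dimensions there are nine zone copies, so N input points become 9N.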

wigner_seitz_mask(Kx, Ky=None, Kz=None, *, shells: int = 1, tol: float = 1e-12, **kwargs) ndarray[source]

Return a boolean mask for the Wigner-Seitz cell in reciprocal space. This can be used to identify which k-points lie within the first Brillouin zone.

Parameters:
  • Kx (array-like) – x-coordinates of the k-point grid to test for membership in the Wigner-Seitz cell.

  • Ky (array-like, optional) – y-coordinates of the same k-point grid.

  • Kz (array-like, optional) – z-coordinates of the same k-point grid.

  • shells (int) – Number of shells of Wigner-Seitz cell to include in the mask.

  • tol (float) – Tolerance for determining if a point is within the Wigner-Seitz cell, accounting for numerical precision issues.

  • **kwargs – Additional keyword arguments to pass to the underlying ws_bz_mask function. See its documentation for details.
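The first-Brillouin-zone test can be sketched with the defining Wigner-Seitz condition |k| <= |k - G| for every nonzero reciprocal vector G in the chosen shells. The sketch takes stacked (..., dim) points instead of separate Kx, Ky, Kz arrays; ws_mask_sketch is hypothetical.

```python
import itertools

import numpy as np


def ws_mask_sketch(K, b_vectors, shells: int = 1, tol: float = 1e-12) -> np.ndarray:
    # K: (..., dim) array of k-points; b_vectors: (dim, dim) reciprocal vectors as rows.
    K = np.asarray(K, dtype=float)
    b_vectors = np.asarray(b_vectors, dtype=float)
    dim = b_vectors.shape[0]
    mask = np.ones(K.shape[:-1], dtype=bool)
    for n in itertools.product(range(-shells, shells + 1), repeat=dim):
        if not any(n):
            continue  # skip G = 0
        G = np.dot(n, b_vectors)
        # Keep k only if it is at least as close to Gamma as to G.
        mask &= np.linalg.norm(K, axis=-1) <= np.linalg.norm(K - G, axis=-1) + tol
    return mask
```

For the square lattice (b = 2π times the identity) this reproduces the square zone |kx|, |ky| <= π, boundary included.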

wigner_seitz_shifts(*, copies: int | Iterable[int] | None = None, include_origin: bool = False, tol: float = 1e-12, **kwargs) ndarray[source]

Return reciprocal-lattice translation vectors for Brillouin-zone copies.

This is the shared geometry helper for selecting or drawing translated Brillouin zones. It returns zone-center shifts only, not an extended k-mesh.

Parameters:
  • copies (int or iterable of int, optional) – Number of translated copies to generate in each reciprocal direction.

  • include_origin (bool, default=False) – Whether to include the central Brillouin zone at Gamma.

  • tol (float, default=1e-12) – Tolerance used when removing numerically duplicated shifts.

  • **kwargs – Additional keyword arguments forwarded to tools.lattice_kspace.ws_bz_shifts.

Returns:

Array of reciprocal-space translation vectors for zone copies.

Return type:

np.ndarray

high_symmetry_points() HighSymmetryPoints | None[source]

Return high-symmetry points for this lattice type.

Override in subclasses to provide lattice-specific high-symmetry points. Returns None if not defined for this lattice type.

Returns:

High-symmetry points with default path, or None if not defined.

Return type:

HighSymmetryPoints or None

Example

>>> lattice = SquareLattice(dim=2, lx=4, ly=4)
>>> pts = lattice.high_symmetry_points()
>>> print(pts.Gamma.frac_coords)  # (0.0, 0.0, 0.0)
>>> print(pts.default_path())     # ['Gamma', 'X', 'M', 'Gamma']
default_bz_path() List[Tuple[str, List[float]]] | None[source]

Return the default Brillouin zone path for this lattice.

Returns:

Default path as list of (label, [f1, f2, f3]) tuples, or None if not defined.

Return type:

List[Tuple[str, List[float]]] or None

default_resolve_path(path: Iterable[tuple[str, Iterable[float]]] | StandardBZPath | str | List[str] | HighSymmetryPoints) List[Tuple[str, List[float]]][source]

Resolve path input to a list of (label, fractional_coord) pairs.

Parameters:
  • path (list[(label, coords)], StandardBZPath, str, List[str], or HighSymmetryPoints) – Path definition (fractional coordinates), standard enum, enum name string, list of point labels, or HighSymmetryPoints object.

  • lattice (Lattice, optional) – Lattice object used to resolve labels if path is a list of strings.

Returns:

resolved_path – Resolved path as a list of (label, fractional_coord) pairs.

Return type:

list[(label, list[float])]

Example

>>> path = _resolve_path_input("SQUARE_2D")
>>> for label, coord in path:
...     print(f"{label}: {coord}")
contains_special_point(point: str | HighSymmetryPoint | Tuple[float, ...] | ndarray, *, tol: float = 1e-12) bool[source]

Return True if the lattice momentum grid contains a special point. This method helps to check whether a finite lattice contains a particular high-symmetry point in the Brillouin zone, which is important for band structure calculations and topological analyses.

Parameters:
  • point – Special point identifier. Accepted forms:

      • label string (e.g. "Gamma", "K", "K'")

      • HighSymmetryPoint

      • explicit fractional coordinate tuple/array

  • tol (float) – Absolute tolerance used in the coordinate match.

Notes

The check is done in fractional reciprocal coordinates and naturally includes flux-induced shifts from twisted boundary conditions because it uses self.kvectors_frac.
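The membership test in fractional coordinates, modulo reciprocal-lattice translations, can be sketched as follows. The helpers are hypothetical. The example uses the square-lattice M point at (1/2, 1/2), which lies on an L x L grid only for even L unless a flux shift of 1/(2L) (phi = π) is applied.

```python
import numpy as np


def contains_point(frac_grid, point, tol: float = 1e-12) -> bool:
    # Compare in fractional reciprocal coordinates, modulo reciprocal-lattice translations.
    diff = np.asarray(frac_grid, dtype=float) - np.asarray(point, dtype=float)
    diff -= np.round(diff)  # wrap each component into [-0.5, 0.5]
    return bool(np.any(np.all(np.abs(diff) <= tol, axis=-1)))


def frac_grid_2d(Lx: int, Ly: int, shift=(0.0, 0.0)) -> np.ndarray:
    # f = (n1/Lx, n2/Ly) plus an optional flux-induced offset.
    f1, f2 = np.meshgrid(np.arange(Lx) / Lx, np.arange(Ly) / Ly, indexing="ij")
    return np.stack([f1.ravel() + shift[0], f2.ravel() + shift[1]], axis=-1)
```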

bz_path(path: List[str] | str | StandardBZPath | None = None, *, points_per_seg: int = 40) Tuple[ndarray, ndarray, List[Tuple[int, str]], ndarray][source]

Generate k-points along a Brillouin zone path.

Parameters:
  • path (list of str, str, StandardBZPath, or None) – Path specification. Can be:

      • List of high-symmetry point names: [‘Gamma’, ‘X’, ‘M’, ‘Gamma’]

      • StandardBZPath enum or string: ‘SQUARE_2D’

      • None: use default path for this lattice

  • points_per_seg (int) – Number of interpolated points per path segment.

Returns:

  • k_path (np.ndarray, shape (Npath, 3)) – Cartesian k-points along the path.

  • k_dist (np.ndarray, shape (Npath,)) – Cumulative distance for plotting x-axis.

  • labels (List[Tuple[int, str]]) – Indices and labels for high-symmetry points.

  • k_path_frac (np.ndarray, shape (Npath, 3)) – Fractional k-coordinates along the path.

Example

>>> lattice = SquareLattice(dim=2, lx=4, ly=4)
>>> k_path, k_dist, labels, k_frac = lattice.bz_path()
>>> # Or with custom path:
>>> k_path, k_dist, labels, k_frac = lattice.bz_path(['Gamma', 'M', 'Gamma'])
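Interpolating a path through high-symmetry points and accumulating the distance used on the plot x-axis can be sketched as follows. bz_path_sketch is hypothetical. It assumes points_per_seg points per segment with the segment end excluded, plus one closing point, so labels land at multiples of points_per_seg.

```python
import numpy as np


def bz_path_sketch(points, b_vectors, points_per_seg: int = 40):
    # points: list of (label, fractional coords); b_vectors rows are b1, b2, ...
    b_vectors = np.asarray(b_vectors, dtype=float)
    fracs = [np.asarray(f, dtype=float) for _, f in points]
    segs, labels = [], [(0, points[0][0])]
    for s in range(len(points) - 1):
        t = np.linspace(0.0, 1.0, points_per_seg, endpoint=False)[:, None]
        segs.append(fracs[s] + t * (fracs[s + 1] - fracs[s]))
        labels.append(((s + 1) * points_per_seg, points[s + 1][0]))
    # Close the path with the final high-symmetry point.
    k_frac = np.vstack(segs + [fracs[-1][None, :]])
    k_path = k_frac @ b_vectors
    # Cumulative Cartesian distance along the path (plot x-axis).
    steps = np.linalg.norm(np.diff(k_path, axis=0), axis=1)
    k_dist = np.concatenate([[0.0], np.cumsum(steps)])
    return k_path, k_dist, labels, k_frac
```

For the square path Gamma, X, M, Gamma the total length is π(2 + √2) when b = 2π times the identity.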
bz_path_points(path: List[str] | str | StandardBZPath | None = None, *, points_per_seg: int = 40, k_vectors: np.ndarray | None = None, k_vectors_frac: np.ndarray | None = None, tol: float = 1e-12, periodic: bool = True) KPathSelection[source]

Build an ideal Brillouin-zone path and optionally match it to an existing k-grid.

If no k-grid is provided, the returned object still contains the continuous path geometry, which is useful for plotting or for constructing a path that is not constrained to the sampled reciprocal mesh. When a sampled grid is provided, reciprocal-lattice copies are generated automatically as needed so paths in extended Brillouin-zone regions can still match the existing data.

Parameters:
  • path (list of str, str, StandardBZPath, or None) – Path specification. Can be:

      • List of high-symmetry point names: [‘Gamma’, ‘X’, ‘M’, ‘Gamma’]

      • StandardBZPath enum or string: ‘SQUARE_2D’

      • None: use default path for this lattice

  • points_per_seg (int) – Number of interpolated points per path segment.

  • k_vectors (np.ndarray, shape (Nk, 3), optional) – Cartesian k-vectors of the existing grid to match against.

  • k_vectors_frac (np.ndarray, shape (Nk, 3), optional) – Fractional k-vectors of the existing grid to match against. Required if k_vectors is provided.

  • tol (float) – Tolerance for matching path points to the existing k-grid. With periodic=True it is interpreted in fractional reciprocal coordinates. With periodic=False it is interpreted in plotted Cartesian reciprocal coordinates.

  • periodic (bool, default=True) – If True, allow reciprocal-translation-equivalent points to match. Set to False for visual matching in the displayed Brillouin-zone copy.

bz_path_data(k_vectors: ndarray, k_vectors_frac: ndarray, values: ndarray, path: List[str] | Literal['CHAIN_1D', 'SQUARE_2D', 'TRIANGULAR_2D', 'CUBIC_3D', 'HONEYCOMB_2D'] | str | StandardBZPath | None = None, *, points_per_seg: int = 40, return_result: bool = True) KPathResult | Tuple[ndarray, ndarray, List[Tuple[int, str]], ndarray][source]

Extract k-path data from a k-grid using fractional coordinate matching.

This function finds the closest k-points on the actual grid to an ideal path through high-symmetry points. It handles periodic boundary conditions in k-space and automatically reuses reciprocal-lattice copies of the sampled grid when the requested path lies in an extended Brillouin-zone region. It can return either a structured KPathResult dataclass or a tuple…

Parameters:
  • lattice (Lattice) – Lattice object with reciprocal lattice vectors

  • k_vectors (np.ndarray, shape (..., 3)) – Cartesian k-points (will be flattened)

  • k_vectors_frac (np.ndarray, shape (..., 3)) – Fractional coordinates of k-points (will be flattened)

  • values (np.ndarray) – Data values sampled on the k-grid. The k-grid axes may appear as (Lx, Ly, Lz, ...) or after leading batch axes such as time or frequency, e.g. (Nw, Lx, Ly, Lz) or (Nw, Lx, Ly, Lz, ...). A single flattened k-grid axis of length Nk is also supported.

  • path (various, optional) – Path specification. Can be:

    • StandardBZPath enum value (e.g., StandardBZPath.SQUARE_2D)

    • String name (e.g., ‘SQUARE_2D’)

    • List of (label, [f1, f2, f3]) tuples

    • HighSymmetryPoints object (uses its default path)

    • None: uses the lattice’s default path, if available

  • points_per_seg (int) – Number of interpolated points per path segment

  • return_result (bool) – If True (default), return KPathResult dataclass. If False, return tuple for backwards compatibility.

Returns:

If return_result=True: KPathResult dataclass with all path data. The returned values preserve any leading batch axes and replace the k-grid axes with a path axis. If return_result=False: (k_cart, k_frac, k_dist, labels, values) tuple

Return type:

KPathResult or tuple
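The matching step described above can be sketched as follows. This is a minimal illustration, not the library's implementation: it assumes linear interpolation between corners in fractional coordinates and a nearest-point search with periodic wrapping, and it does not model the extended-zone copies. The helper match_path_to_grid is hypothetical.

```python
import numpy as np

def match_path_to_grid(k_frac_grid, path_frac, points_per_seg=40):
    """Hypothetical helper: nearest grid index for each point of a piecewise-linear path.

    k_frac_grid : (Nk, 3) fractional k-points of the sampled grid
    path_frac   : (Np, 3) fractional coordinates of the high-symmetry corners
    """
    k_frac_grid = np.asarray(k_frac_grid, dtype=float)
    path_frac = np.asarray(path_frac, dtype=float)

    # Interpolate points_per_seg points per segment (end point excluded) ...
    ideal = []
    for start, stop in zip(path_frac[:-1], path_frac[1:]):
        t = np.linspace(0.0, 1.0, points_per_seg, endpoint=False)
        ideal.append(start + t[:, None] * (stop - start))
    ideal.append(path_frac[-1:])              # ... and close the path at the last corner
    ideal = np.concatenate(ideal, axis=0)     # (Npath, 3)

    # Periodic distance in fractional coordinates: fold differences into [-0.5, 0.5]
    diff = ideal[:, None, :] - k_frac_grid[None, :, :]
    diff -= np.round(diff)
    dist = np.linalg.norm(diff, axis=-1)      # (Npath, Nk)
    return np.argmin(dist, axis=1)
```

With a 4-point grid in fftfreq order ([0, 0.25, -0.5, -0.25]), the path point 0.5 is matched to the grid point -0.5 because the two differ by a reciprocal-lattice translation.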

Examples

>>> # Using the default path from HighSymmetryPoints
>>> result = lattice.bz_path_data(k_grid, k_frac, energies, HighSymmetryPoints.square_2d())
>>> plt.plot(result.k_dist, result.values)
>>> # Using a standard path name
>>> result = lattice.bz_path_data(k_grid, k_frac, energies, 'SQUARE_2D')
>>> # Custom path
>>> custom_path = [('G', [0, 0, 0]), ('X', [0.5, 0, 0]), ('G', [0, 0, 0])]
>>> result = lattice.bz_path_data(k_grid, k_frac, energies, custom_path)

property flux: BoundaryFlux
set_flux(value: float | Mapping[str | LatticeDirection, float] | None, *, reinit: bool = True) None[source]

Set boundary flux and optionally recalculate k-vectors, DFT, and neighbors.

Parameters:
  • value (float, Mapping, or None) – New flux specification (see _normalize_flux_dict()).

  • reinit (bool) – If True (default), recalculate reciprocal vectors, k-vectors, DFT matrix, and neighbor lists to be consistent with the new flux.

property has_flux: bool

True when a non-trivial boundary flux is attached.

property is_twisted: bool

True when the boundary conditions are TWISTED.

property is_topological: bool

True when the lattice carries a non-trivial boundary flux.

A non-trivial flux (mod \(2\pi\)) introduces a measurable Aharonov-Bohm phase and may change the topological sector of the ground state.

flux_summary() str[source]

Return a human-readable summary of the boundary-flux configuration.

boundary_phase(direction: LatticeDirection, winding: int = 1) complex[source]

Return the complex phase accumulated after crossing the boundary along direction.

Parameters:
  • direction (LatticeDirection) – The lattice direction (X, Y, or Z).

  • winding (int) – The winding number (number of times the boundary is crossed).

Returns:

The complex phase factor \(e^{i\,\phi_\mu w}\), where \(\phi_\mu\) is the flux along direction and w is the winding number.

Return type:

complex

boundary_phases() ndarray[source]

Return a lookup table of complex boundary phases.

Returns:

table – table[d, w] is \(\exp(i\,w\,\phi_d)\) for direction d and winding number w.

Return type:

np.ndarray, shape (3, Ns+1)

boundary_phase_from_winding(wx: int, wy: int, wz: int) complex[source]

Return the total complex boundary phase accumulated from the winding numbers (wx, wy, wz). If all winding numbers are zero, returns the real value 1.0.
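The phase table of boundary_phases() and the total phase of boundary_phase_from_winding() can be sketched as below. This is a minimal illustration assuming the per-direction fluxes are given as angles (phi_x, phi_y, phi_z) in radians; the function names are hypothetical and not part of the library.

```python
import numpy as np

def boundary_phase_table(phis, ns):
    """table[d, w] = exp(i * w * phi_d) for d in (x, y, z) and w = 0..ns."""
    phis = np.asarray(phis, dtype=float).reshape(3, 1)
    windings = np.arange(ns + 1).reshape(1, -1)
    return np.exp(1j * windings * phis)          # shape (3, ns + 1)

def phase_from_winding(phis, wx, wy, wz):
    """Total phase exp(i * (wx*phi_x + wy*phi_y + wz*phi_z)); real 1.0 if no winding."""
    if wx == wy == wz == 0:
        return 1.0
    phi_x, phi_y, phi_z = phis
    return complex(np.exp(1j * (wx * phi_x + wy * phi_y + wz * phi_z)))
```

Negative windings (bonds wrapped in the negative direction) are handled by the closed-form phase; the table only stores non-negative windings, matching its documented shape (3, Ns+1).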

bond_winding(i: int, j: int) tuple[int, int, int][source]

Compute how many times a bond (i -> j) crosses the periodic boundary in each lattice direction.

Returns (wx, wy, wz), where each entry is 0 if no crossing, +1 if wrapped positively, -1 if wrapped negatively.

Parameters:
  • i (int) – Index of the starting lattice site.

  • j (int) – Index of the ending lattice site.

Returns:

A tuple (wx, wy, wz) of winding numbers for the bond from site i to site j.

Return type:

tuple[int, int, int]
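One way to obtain such winding numbers is sketched below, assuming integer cell coordinates and bonds shorter than half the system size, so that a raw coordinate jump larger than L/2 signals a boundary crossing. The helpers are hypothetical and may differ from the library's implementation.

```python
def bond_winding_1d(x_i, x_j, length):
    """Winding of the bond x_i -> x_j on a periodic ring of `length` cells."""
    d = x_j - x_i
    if d > length / 2:
        return -1   # e.g. 0 -> L-1: stepped backwards across the boundary
    if d < -length / 2:
        return +1   # e.g. L-1 -> 0: stepped forwards across the boundary
    return 0

def bond_winding_3d(ci, cj, extent, periodic=(True, True, True)):
    """(wx, wy, wz) for cell coordinates ci -> cj; open directions never wind."""
    return tuple(
        bond_winding_1d(a, b, L) if p else 0
        for a, b, L, p in zip(ci, cj, extent, periodic)
    )
```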

is_spanning(sites: Iterable[int]) bool[source]

Check if a set of sites spans the lattice (non-contractible on a torus).

This method uses BFS-based winding-number tracking on the subgraph induced by the given site indices. If a loop with a non-zero winding number along a periodic direction is found, the set is considered spanning.
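The BFS idea can be sketched as follows: each visited site gets an "unwrapped" offset equal to the summed bond windings along the BFS tree, and reaching an already visited site with a different offset reveals a loop with non-zero net winding. This is a minimal sketch; the neighbors mapping (site to a list of (neighbor, winding) pairs, e.g. built from bond_winding()) is a hypothetical input format.

```python
from collections import deque

def is_spanning(sites, neighbors):
    """neighbors[i] -> iterable of (j, w) with w = (wx, wy, wz), the winding of bond i -> j."""
    sites = set(sites)
    offset = {}
    for root in sites:
        if root in offset:
            continue
        offset[root] = (0, 0, 0)
        queue = deque([root])
        while queue:
            i = queue.popleft()
            for j, w in neighbors.get(i, ()):
                if j not in sites:
                    continue                      # stay on the induced subgraph
                cand = tuple(a + b for a, b in zip(offset[i], w))
                if j not in offset:
                    offset[j] = cand
                    queue.append(j)
                elif offset[j] != cand:
                    return True                   # loop with non-zero net winding
    return False
```

On a 4-site ring, the full ring is spanning, while any open arc such as {0, 1, 2} is not.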

bond_phase(i: int, j: int) complex[source]

Return the complex hopping phase factor for the bond \(i \to j\).

For bonds that do not cross a periodic boundary, this is 1. For boundary-crossing bonds under TWISTED BC, the phase is \(\exp(i\,\phi_\mu)\) for each direction \(\mu\) in which the bond wraps.

This is the factor that should multiply the bare hopping amplitude in real-space Hamiltonian construction.

Parameters:
  • i (int) – Source site index.

  • j (int) – Target site index.

Returns:

Phase factor (unit modulus).

Return type:

complex

hopping_matrix_with_flux(*, include_nnn: bool = False) ndarray[source]

Build an \(N_s \times N_s\) matrix of complex hopping amplitudes that includes the Peierls phases from boundary fluxes.

The diagonal is zero. Off-diagonal elements are \(H_{ij} = t_{ij}\,\mathrm{phase}(i \to j)\), where \(t_{ij} = 1\) for all connected pairs and the phase is the product of the boundary phases along every direction the bond wraps.

Parameters:

include_nnn (bool) – If True, include next-nearest-neighbor hoppings as well.

Returns:

H – Complex hopping matrix.

Return type:

np.ndarray, shape (Ns, Ns)
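For intuition, the construction can be sketched for a 1D ring of N sites, where only the bond N-1 → 0 wraps. This is a standalone illustration, not the library's code; ring_hopping_with_flux is hypothetical. With \(H_{ij} = t\,\mathrm{phase}(i \to j)\) the matrix is Hermitian and its eigenvalues are \(2t\cos((\phi + 2\pi n)/N)\), which gives a quick check of the phase convention.

```python
import numpy as np

def ring_hopping_with_flux(n_sites, phi, t=1.0):
    """H[i, j] = t * phase(i -> j) on a 1D ring; only the wrapping bond carries a phase."""
    H = np.zeros((n_sites, n_sites), dtype=complex)
    for i in range(n_sites):
        j = (i + 1) % n_sites
        phase = np.exp(1j * phi) if j == 0 else 1.0   # bond N-1 -> 0 wraps forward
        H[i, j] += t * phase
        H[j, i] += t * np.conj(phase)                  # reverse bond wraps backward
    return H
```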

get_nnn_middle_sites(i: int, j: int, orientation: str | None = None) list[int][source]

Return the list of ‘middle’ sites l that are nearest neighbors of both i and j, i.e., sites forming two-step NNN paths i-l-j.

Works for any lattice that implements get_nn(site, idx) and get_nn_num(site).

Parameters:
  • i (int) – Site indices.

  • j (int) – Site indices.

  • orientation ({'anticlockwise', 'clockwise', None}, optional) – If provided, will sort/choose based on geometric angle. Default: None (return all middle sites).

Returns:

List of middle-site indices (can be 0, 1, or 2 elements).

Return type:

list[int]
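Without the orientation filter, finding the middle sites amounts to intersecting the two nearest-neighbor sets. A minimal sketch, assuming the neighbor lists are given as a plain mapping (the nn argument is hypothetical; the method itself uses get_nn and get_nn_num):

```python
def nnn_middle_sites(i, j, nn):
    """Sites l with l in nn[i] and l in nn[j], i.e. the two-step paths i - l - j."""
    return sorted(set(nn[i]) & set(nn[j]))
```

On a 4-site ring, sites 0 and 2 share both remaining sites as middle sites, while nearest neighbors 0 and 1 share none.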

get_chirality_sign(i: int, j: int, normal: ndarray | None = None, orientation: str | None = None) int[source]

Compute the local orientation (chirality) sign \(\nu_{ij} = \pm 1\) for an NNN pair (i, j), defined by the cross product of the two bond vectors \(i \to l\) and \(l \to j\), where l is the shared middle site.

Works for any 2D or quasi-2D lattice with known site coordinates.

Parameters:
  • i (int) – Site indices (next-nearest neighbors).

  • j (int) – Site indices (next-nearest neighbors).

  • normal (np.ndarray, optional) – Orientation of the lattice plane (default: +z for 2D).

Returns:

+1 for anticlockwise, -1 for clockwise, 0 if not a valid NNN pair.

Return type:

int
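The sign can be sketched as the sign of the projection of \((r_l - r_i) \times (r_j - r_l)\) onto the plane normal. This is a minimal sketch taking explicit positions (a hypothetical interface); the tolerance used to declare collinear bonds invalid is an assumption.

```python
import numpy as np

def chirality_sign(r_i, r_l, r_j, normal=(0.0, 0.0, 1.0), tol=1e-12):
    """nu_ij = sign(((r_l - r_i) x (r_j - r_l)) . normal) for the path i -> l -> j."""
    d1 = np.asarray(r_l, dtype=float) - np.asarray(r_i, dtype=float)
    d2 = np.asarray(r_j, dtype=float) - np.asarray(r_l, dtype=float)
    s = float(np.dot(np.cross(d1, d2), normal))
    if abs(s) < tol:
        return 0          # collinear bonds: no well-defined orientation
    return 1 if s > 0 else -1
```

A left turn (anticlockwise about +z) gives +1, and reversing the path gives -1.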

bond_type(i: int, j: int) str[source]

Determine the bond type between sites i and j.

Parameters:
  • i (int) – Site indices.

  • j (int) – Site indices.

Returns:

‘nn’ for nearest neighbor, ‘nnn’ for next-nearest neighbor, ‘none’ otherwise.

Return type:

str

periodic_flags() Tuple[bool, bool, bool][source]

Return booleans indicating whether (x, y, z) directions are periodic.

TWISTED boundary conditions are topologically equivalent to PBC (the lattice is still a torus), so all three directions are periodic.

is_periodic(direction: LatticeDirection | None = None, allow_twisted: bool = True) bool[source]

Check if a given direction has periodic boundary conditions.

property typek
property spatial_norm
site_index(x: int, y: int, z: int)[source]

Convert (x, y, z) coordinates to a unique site index (row-major).

Default implementation uses standard lexicographic ordering. Override in subclasses if a different indexing convention is needed.
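As an illustration, assuming the x-fastest lexicographic ordering (x + y*Lx + z*Lx*Ly) that calculate_wilson_loops() states it relies on, the index and its inverse can be sketched as follows. The helper names are hypothetical, and subclasses may use a different convention.

```python
def site_index(x, y, z, lx, ly):
    """Lexicographic index with x varying fastest: x + y*Lx + z*Lx*Ly."""
    return x + y * lx + z * lx * ly

def site_coords(i, lx, ly):
    """Inverse of site_index: recover (x, y, z) from the flat index."""
    z, rem = divmod(i, lx * ly)
    y, x = divmod(rem, lx)
    return x, y, z
```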

site_diff(i: int | tuple, j: int | tuple, *, minimum_image: bool = False, real_space: bool = False) Tuple[float, float, float][source]

Return the displacement i -> j with optional PBC minimum-image wrapping.

Parameters:
  • i (int or tuple) – Site indices or explicit coordinates.

  • j (int or tuple) – Site indices or explicit coordinates.

  • minimum_image (bool, default=False) – If True, wrap each periodic direction to the shortest displacement.

  • real_space (bool, default=False) – If True and i, j are site indices, return displacement in real-space vectors (uses displacement()). Otherwise use lattice coordinates.
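The minimum-image option can be sketched as wrapping each periodic component by the nearest multiple of the system size. This is a minimal sketch in lattice coordinates; minimum_image is a hypothetical helper, and the tie-breaking at exactly L/2 (here NumPy's round-half-to-even) may differ from the library.

```python
import numpy as np

def minimum_image(delta, extent, periodic=(True, True, True)):
    """Wrap each periodic component of a displacement into [-L/2, L/2]."""
    delta = np.asarray(delta, dtype=float).copy()
    for ax, (L, p) in enumerate(zip(extent, periodic)):
        if p and L > 0:
            delta[ax] -= L * np.round(delta[ax] / L)
    return delta
```

The corresponding distance, as returned by site_distance(), is then np.linalg.norm of the wrapped displacement.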

site_distance(i: int | tuple, j: int | tuple, *, minimum_image: bool = False, real_space: bool = False) float[source]

Return the Euclidean distance between two sites or coordinate tuples.

Parameters:
  • minimum_image (bool, default=False) – If True, periodic directions use minimum-image convention.

  • real_space (bool, default=False) – If True and inputs are indices, measure in real-space lattice vectors.

calculate_reciprocal_vectors()[source]

Calculates the reciprocal lattice vectors based on the primitive vectors. Always returns 3D vectors (padding with zeros for lower dimensions).

Returns:

k1, k2, k3 – Reciprocal lattice vectors (always 3D).
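The standard construction \(b_i = 2\pi\,(a_j \times a_k) / (a_1 \cdot (a_2 \times a_3))\) satisfies \(a_i \cdot b_j = 2\pi\,\delta_{ij}\). A minimal sketch, assuming three 3D primitive vectors; for a 2D lattice, padding with \(a_3 = \hat z\) is a common choice (an assumption here, not necessarily the library's padding).

```python
import numpy as np

def reciprocal_vectors(a1, a2, a3):
    """b_i = 2*pi * (a_j x a_k) / (a1 . (a2 x a3)), so that a_i . b_j = 2*pi*delta_ij."""
    a1, a2, a3 = (np.asarray(a, dtype=float) for a in (a1, a2, a3))
    volume = np.dot(a1, np.cross(a2, a3))
    b1 = 2 * np.pi * np.cross(a2, a3) / volume
    b2 = 2 * np.pi * np.cross(a3, a1) / volume
    b3 = 2 * np.pi * np.cross(a1, a2) / volume
    return b1, b2, b3
```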

calculate_dft_matrix(phase=False, use_fft: bool = False) ndarray[source]

Bloch-type DFT matrix on the site basis.

Indices:

i = (R, β): real-space cell R and sublattice β; n = (k, α): k-point k and sublattice α.

Elements:

\[F_{(k,\alpha),(R,\beta)} = \frac{1}{\sqrt{N_c}}\,\delta_{\alpha\beta}\,e^{-i k \cdot R}.\]

This matrix is unitary:

\[F^\dagger F = F F^\dagger = I_{N_s},\]

where \(N_s = N_c N_b\) is the total number of sites, \(N_c\) is the number of unit cells, and \(N_b\) is the number of sublattices.

Important

When boundary fluxes are present (TWISTED BC), the k-grid used to build the DFT matrix is shifted by \(\phi_\mu / (2\pi L_\mu)\) in each direction, exactly as in calculate_k_vectors().

Note that this DFT matrix does not include basis-dependent phases (i.e., \(e^{-i k \cdot r_{\text{basis}}}\)).

Calculates the Discrete Fourier Transform (DFT) matrix for the lattice. This method can be optimized using FFT (Fast Fourier Transform) in the future. Reference: https://en.wikipedia.org/wiki/DFT_matrix

Parameters:
  • phase (bool)

  • use_fft (bool, default=False)

Returns:

DFT matrix.

Return type:

ndarray
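The element formula above can be checked on a 1D chain. This minimal sketch assumes unit lattice constant, the site ordering index = cell * Nb + sub (as in calculate_coordinates()), and an fftfreq-ordered k-grid shifted by \(\phi / (2\pi N_c)\); bloch_dft_matrix_1d is hypothetical. The Kronecker product with the identity implements \(\delta_{\alpha\beta}\), and the flux shift keeps the matrix unitary.

```python
import numpy as np

def bloch_dft_matrix_1d(n_cells, n_basis=1, phi=0.0):
    """F[(k, a), (R, b)] = exp(-i k R) * delta_ab / sqrt(Nc) for a 1D chain.

    Rows and columns use index = cell * n_basis + sub; k follows fftfreq order,
    shifted by phi / (2*pi*Nc) in fractional units.
    """
    frac = np.fft.fftfreq(n_cells) + phi / (2 * np.pi * n_cells)
    k = 2 * np.pi * frac                       # lattice constant set to 1
    R = np.arange(n_cells)
    fourier = np.exp(-1j * np.outer(k, R)) / np.sqrt(n_cells)
    return np.kron(fourier, np.eye(n_basis))   # delta over sublattices
```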

get_nei(site: int, **kwargs)[source]

Returns the nearest neighbors of a given site.

Parameters:
  • site (int) – Site index.

  • direction – Optional direction, passed through **kwargs.

get_nei_forward(site: int, num: int = -1)[source]

Returns the forward nearest neighbors of a given site.

Parameters:
  • site (int) – Site index.

  • num (int, default=-1)

Returns:

List of forward nearest neighbors.

abstractmethod get_real_vec(x: int, y: int, z: int)[source]

Returns the real vector given the coordinates. Uses the lattice constants.

abstractmethod get_norm(x: int, y: int, z: int)[source]

Returns the norm of the vector given the coordinates.

abstractmethod get_nn_direction(site: int, direction: LatticeDirection)[source]

Returns the nearest neighbors in a given direction.

get_nnn_direction(site: int, direction: LatticeDirection)[source]

Returns the next nearest neighbors in a given direction.

wrong_nei(nei)[source]

Check if a given neighbor index is invalid.

A neighbor is considered invalid if it is:
  • None

  • Equal to self.bad_lattice_site

  • NaN (not a number)

  • Less than 0

Parameters:

nei (Any) – The neighbor index to check.

Returns:

True if the neighbor index is invalid, False otherwise.

Return type:

bool

get_nn_num(site: int)[source]

Returns the number of nearest neighbors of a given site.

Parameters:

site (int) – Site index.

Returns:

Number of nearest neighbors.

get_nn(site, num: int = -1)[source]

Returns the nearest neighbors of a given site.

Parameters:
  • site (int) – Site index.

  • num (int, default=-1)

Returns:

List of nearest neighbors.

get_nnn_num(site: int)[source]

Returns the number of next nearest neighbors of a given site.

Parameters:

site (int) – Site index.

Returns:

Number of next nearest neighbors.

get_nnn(site, num: int = -1)[source]

Returns the next nearest neighbors of a given site.

Parameters:
  • site (int) – Site index.

  • num (int, default=-1)

Returns:

List of next nearest neighbors.

get_nn_forward_num_max()[source]

Returns the maximum number of forward nearest neighbors in the lattice.

Returns:

Maximum number of forward nearest neighbors.

get_nn_forward_num(site: int)[source]

Returns the number of forward nearest neighbors of a given site.

Parameters:

site (int) – Site index.

Returns:

Number of forward nearest neighbors.

get_nn_forward(site: int, num: int = -1)[source]

Returns the forward nearest neighbors of a given site.

Parameters:
  • site (int) – Site index.

  • num (int, default=-1)

Returns:

List of forward nearest neighbors.

get_nnn_forward_num(site: int)[source]

Returns the number of forward next nearest neighbors of a given site.

Parameters:

site (int) – Site index.

Returns:

Number of forward next nearest neighbors.

get_nnn_forward(site: int, num: int = -1)[source]

Returns the forward next nearest neighbors of a given site.

Parameters:
  • site (int) – Site index.

  • num (int, default=-1)

Returns:

List of forward next nearest neighbors.

neighbors(site: int, order=1)[source]

Return the neighbors of a site by order: 1 for nearest neighbors (all with the highest weight), 2 for next nearest neighbors (all with the second-highest weight), or ‘all’ for both.

neighbors_forward(site: int, order=1)[source]

Return the forward neighbors of a site by order: 1 for nearest neighbors (all with the highest weight), 2 for next nearest neighbors (all with the second-highest weight), or ‘all’ for both.

any_neighbor(site: int, order=1)[source]

Return the first neighbor of the given order, or None if there is none.

any_neighbor_forward(site: int, order=1)[source]

Return the first forward neighbor of the given order, or None if there is none.

property n_nodes: int

Number of nodes (sites) in the lattice — alias for Ns.

property n_edges: int

Number of unique undirected nearest-neighbour edges.

property positions: ndarray

Real-space position vectors (same as rvectors).

property site_offsets: ndarray

Position offsets of sites inside the unit cell (same as basis).

property basis_coords: ndarray

Integer basis coordinates [nx, ny, nz, sub] for every site.

Shape (Ns, 4) — the first three columns are the cell-index triplet and the last column is the sublattice label.

property ndim: int

Spatial dimensionality of the lattice.

property extent: Tuple[int, ...]

Number of unit cells in each direction (Lx, Ly, Lz).

property pbc: Tuple[bool, bool, bool]

Per-axis periodicity flags (alias for periodic_flags()).

edges(*, filter_color: int | None = None, return_color: bool = False) List[source]

Return list of nearest-neighbour edges.

Parameters:
  • filter_color (int, optional) – If given, return only edges whose bond_type equals this colour.

  • return_color (bool) – If True each element is (i, j, color); otherwise (i, j).

Returns:

Unique undirected edges (i, j) with i < j.

Return type:

list[tuple]

property edge_colors: List[int]

Sequence of bond-type colours for every edge in edges(), matching the order returned by edges().

displacement(i: int, j: int, *, minimum_image: bool = True) ndarray[source]

Real-space displacement vector from site i to site j.

Parameters:
  • i (int) – Site indices.

  • j (int) – Site indices.

  • minimum_image (bool) – If True (default) and the lattice is periodic, return the shortest displacement under periodic boundary conditions.

Return type:

np.ndarray shape (3,)

distance(i: int, j: int, *, minimum_image: bool = True) float[source]

Euclidean distance between sites i and j (PBC-aware by default).

get_coordinates(*args)[source]
get_r_vectors(*args)[source]
get_k_vectors(*args)[source]
get_site_diff(i: int, j: int)[source]
get_k_vec_idx(sym=False)[source]
get_dft(*args)[source]

Returns the DFT matrix

get_spatial_norm(*args)[source]

Returns the spatial norm at lattice site i, or for all sites.

get_difference_idx_matrix(cut=True) list[source]

Returns the matrix of indices corresponding to a slice from the QMC. This is useful for reading the position-resolved Green’s function saved by https://github.com/makskliczkowski/DQMC.

The Green’s functions are saved as follows:

  • 1D simulations: 1 column and (2 * Lx - 1) rows for the position differences (-(Lx - 1), …, 0, …, Lx - 1).

  • 2D simulations: (2 * Lx - 1) rows for the position differences (-(Lx - 1), …, 0, …, Lx - 1) and (2 * Ly - 1) columns for the position differences (-(Ly - 1), …, 0, …, Ly - 1).

  • 3D simulations: as in 2D, but once the (2 * Lx - 1) x (2 * Ly - 1) matrix is complete, a new slice for the next z-difference begins in the following columns, giving Lz * (2 * Ly - 1) columns.

Parameters:

cut (bool, default=True) – If True, use all (2L_i - 1) possible position differences; otherwise skip the negative ones and use L_i.

calculate_bonds()[source]

Calculates the bonds for the lattice using forward nn.

calculate_coordinates()[source]

Calculates the coordinates for each lattice site in up to 3D.

Each site index i corresponds to:

cell = i // n_basis, sub = i % n_basis

where n_basis = len(self._basis) (e.g., 2 for honeycomb).

Works for any lattice with defined self._a1, _a2, _a3 and self._basis list.

calculate_r_vectors()[source]

Calculates the real-space vectors (r) for each site. Must match the ordering in calculate_coordinates().
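The position formula implied by this ordering is \(r_i = n_1 a_1 + n_2 a_2 + n_3 a_3 + \text{basis}[\text{sub}]\). A minimal sketch, assuming the cell index is unpacked with n1 varying fastest (an assumption about the cell enumeration); site_positions is hypothetical.

```python
import numpy as np

def site_positions(extent, a_vectors, basis):
    """r_i = n1*a1 + n2*a2 + n3*a3 + basis[sub], with cell = i // Nb and sub = i % Nb."""
    lx, ly, lz = extent
    a = np.asarray(a_vectors, dtype=float)        # (3, 3): rows are a1, a2, a3
    basis = np.asarray(basis, dtype=float)        # (Nb, 3) offsets inside the cell
    n_basis = len(basis)
    ns = lx * ly * lz * n_basis
    r = np.empty((ns, 3))
    for i in range(ns):
        cell, sub = divmod(i, n_basis)
        n3, rem = divmod(cell, lx * ly)           # assumed: n1 varies fastest
        n2, n1 = divmod(rem, lx)
        r[i] = n1 * a[0] + n2 * a[1] + n3 * a[2] + basis[sub]
    return r
```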

calculate_k_vectors()[source]

Calculates the allowed reciprocal-space k-vectors (momentum grid) consistent with the lattice size and primitive reciprocal vectors.

When boundary fluxes are present (TWISTED BC), the fractional coordinates are shifted by \(\phi_\mu / (2\pi L_\mu)\) in each direction, so that the Bloch condition matches the twisted boundary.

The sampling follows the same fftfreq ordering used by the Bloch transform (Γ at index [0,0,0], followed by positive frequencies and finally the negative branch). This keeps the analytic grids aligned with the numerically constructed H(k) blocks.
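The grid described above can be sketched with np.fft.fftfreq, which puts Γ at index 0, followed by the positive and then the negative frequencies. This is a minimal illustration assuming the rows of b_vectors are the reciprocal vectors and the fluxes are given in radians; k_vectors is a hypothetical helper.

```python
import numpy as np

def k_vectors(extent, b_vectors, phis=(0.0, 0.0, 0.0)):
    """Fractional grid in fftfreq order, shifted by phi_mu / (2*pi*L_mu); returns (frac, cart)."""
    axes = [np.fft.fftfreq(L) + phi / (2 * np.pi * L) for L, phi in zip(extent, phis)]
    f1, f2, f3 = np.meshgrid(*axes, indexing="ij")
    frac = np.stack([f1, f2, f3], axis=-1)             # (Lx, Ly, Lz, 3)
    cart = frac @ np.asarray(b_vectors, dtype=float)   # rows of b_vectors: b1, b2, b3
    return frac, cart
```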

filter_k_vectors(qx: int | None = None, qy: int | None = None, qz: int | None = None) ndarray[source]

Filters the k-vectors to find those matching the specified fractional components.

Parameters:
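  • qx (int, optional) – Fractional component in the x-direction. Defaults to None.

  • qy (int, optional) – Fractional component in the y-direction. Defaults to None.

  • qz (int, optional) – Fractional component in the z-direction. Defaults to None.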

Returns:

Array of indices of k-vectors matching the specified components.

Return type:

np.ndarray

translation_operators()[source]

Return translation matrices T1, T2, T3 on the one-hot basis.
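On a one-hot site basis, a translation is a permutation matrix. The sketch below builds T1 for a periodic 1D chain with Nb sublattices, assuming that column i holds the image of site i and that the ordering is index = cell * Nb + sub; both are assumptions, and translation_matrix_1d is hypothetical. Applying it Lx times returns the identity.

```python
import numpy as np

def translation_matrix_1d(n_cells, n_basis=1):
    """T e_i = e_{T(i)}: shift every site by one unit cell (periodic), keeping its sublattice."""
    ns = n_cells * n_basis
    T = np.zeros((ns, ns))
    for i in range(ns):
        cell, sub = divmod(i, n_basis)
        target = ((cell + 1) % n_cells) * n_basis + sub
        T[target, i] = 1.0
    return T
```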

calculate_norm_sym()[source]

Calculate a symmetry-normalization measure for each site.

Default: Euclidean norm of the coordinate vector. Override in subclasses for lattice-specific behaviour.

abstractmethod calculate_nn_in(pbcx: bool, pbcy: bool, pbcz: bool)[source]
calculate_nn()[source]

Calculates the nearest neighbors.

For TWISTED boundary conditions the neighbor connectivity is identical to PBC — the flux phases are applied separately when building the Hamiltonian or the DFT matrix.

calculate_plaquettes(use_obc: bool = True)[source]
calculate_wilson_loops()[source]

Calculates the Wilson loops (non-contractible loops) for the lattice based on its boundary conditions. Returns a list of lists, where each inner list contains the site indices of a Wilson loop.

Assumes standard lexicographic site indexing (x + y*Lx + z*Lx*Ly).
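With this indexing, straight non-contractible loops on a 2D torus are rows (winding along x) and columns (winding along y) of the site grid. The sketch below lists all of them; whether the library returns every parallel copy or one representative per periodic direction is not specified here, and wilson_loops_2d is hypothetical.

```python
def wilson_loops_2d(lx, ly):
    """Straight non-contractible loops on an Lx x Ly torus with index x + y*Lx."""
    x_loops = [[x + y * lx for x in range(lx)] for y in range(ly)]   # wind along x
    y_loops = [[x + y * lx for y in range(ly)] for x in range(lx)]   # wind along y
    return x_loops, y_loops
```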

abstractmethod calculate_nnn_in(pbcx: bool, pbcy: bool, pbcz: bool)[source]
calculate_nnn()[source]

Calculates the next nearest neighbors.

Like calculate_nn(), each calculate_nnn_in implementation is expected to set self._nnn (and optionally self._nnn_forward) directly. The return value—if any—is stored as a fallback.

adjacency_matrix(sparse: bool = False, save: bool = True, *, mode: str = 'binary', include_self: bool = False, include_nnn: bool = False, typed_self_separate: bool = True, n_types: int = 3) ndarray[source]

Construct the adjacency matrix of the lattice; in ‘binary’ mode, A_ij = 1 if i and j are neighbors.

Parameters:
  • save (bool) – save the adjacency matrix in the lattice object for future use.

  • mode (str) –

    ‘binary’ :

    A_ij = 1 if i and j are neighbors, 0 otherwise.

    ‘typed’ :

    A_ij = weight of the bond between i and j (1 for nn, 2 for nnn, etc.), 0 otherwise.

  • include_self (bool) – include self-connections (diagonal elements) if True.

  • include_nnn (bool) – include next-nearest neighbors if True.

  • typed_self_separate (bool) – if True, self-connections are given a unique weight (n_types) to distinguish them from other types of connections.

  • n_types (int) – number of different neighbor types (nn, nnn, etc.) to consider.

  • sparse (bool) – return a scipy.sparse CSR matrix if True.

Returns:

adjacency matrix of size (Ns, Ns).

Return type:

A (ndarray or sparse CSR)
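The ‘binary’ and ‘typed’ modes can be sketched from plain neighbor lists. This minimal dense sketch assumes nn and nnn are lists of neighbor lists, weights 1 (nn) and 2 (nnn) in typed mode, and a separate self weight n_types as described for typed_self_separate; adjacency_from_lists is hypothetical and omits the sparse and caching options.

```python
import numpy as np

def adjacency_from_lists(nn, nnn=None, *, mode="binary", include_self=False, n_types=3):
    """Dense adjacency from neighbor lists: nn -> weight 1, nnn -> weight 2 in 'typed' mode."""
    ns = len(nn)
    A = np.zeros((ns, ns), dtype=int)
    for order, lists in ((1, nn), (2, nnn if nnn is not None else [])):
        for i, neigh in enumerate(lists):
            for j in neigh:
                A[i, j] = 1 if mode == "binary" else order
    if include_self:
        # Self-connections get their own weight n_types in typed mode
        np.fill_diagonal(A, 1 if mode == "binary" else n_types)
    return A
```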

print_neighbors(logger: Logger)[source]

Logs the neighbors of each site in the lattice using the provided logger.

For each site in the lattice, this method retrieves its nearest neighbors and logs their indices. Additionally, for each neighbor, it logs detailed information using a higher verbosity level.

Parameters:

logger – An object with an info method for logging messages. The info method should accept parameters lvl (int) for verbosity level and color (str) for message color.

print_forward(logger: Logger)[source]

Logs the forward nearest neighbors for each site in the lattice.

For each site in the lattice, this method retrieves the number of forward nearest neighbors and logs their indices using the provided logger. The method outputs two levels of information:

  • Level 1 (green): lists the neighbors of each site.

  • Level 2 (blue): details each neighbor’s index for the site.

Parameters:

logger (Logger) – A logging object with an info method that accepts a message, a logging level (lvl), and a color (color).

get_geometric_encoding(*, tol=1e-06)[source]

Map each site i to (cell_idx, sub_idx) purely from geometry.

Returns:

  • cell_idx ((Ns,) int array in [0, Nc-1])

  • sub_idx ((Ns,) int array in [0, Nb-1])

realspace_from_kspace(H_k: ndarray, *, block_diag: bool = True, kgrid: ndarray | None = None)[source]

Inverse Bloch transform: reconstruct real-space matrix from k-space blocks.

This is the exact inverse of kspace_from_realspace(). It reconstructs the real-space Hamiltonian from momentum-space blocks using the inverse Fourier transform:

\[H_{\text{real}} = \sum_k W(k)\, H(k)\, W(k)^\dagger\]

where \(W(k)\) is the Bloch unitary matrix.
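The round trip can be sketched for a periodic SSH chain (Nb = 2). This is a standalone illustration, not the library's code: it assumes unit lattice constant, the site ordering index = cell * 2 + sub, and \(W_{i,a}(k) = e^{-ik r_i}\,\delta_{\text{sub}(i),a}/\sqrt{N_c}\) with k in fftfreq order. The forward step computes \(W^\dagger H W\); the inverse sums \(W H(k) W^\dagger\) over k. All helper names are hypothetical.

```python
import numpy as np

def bloch_unitaries_1d(n_cells, positions, sublattice):
    """W[k][i, a] = exp(-i k r_i) * delta_{sub(i), a} / sqrt(Nc), k in fftfreq order."""
    n_basis = int(sublattice.max()) + 1
    ks = 2 * np.pi * np.fft.fftfreq(n_cells)              # lattice constant 1
    W = np.zeros((n_cells, len(positions), n_basis), dtype=complex)
    for ik, k in enumerate(ks):
        W[ik, np.arange(len(positions)), sublattice] = np.exp(-1j * k * positions)
    return W / np.sqrt(n_cells)

def to_kspace(H_real, W):
    return np.einsum("kia,ij,kjb->kab", W.conj(), H_real, W)   # W^dagger H W

def to_realspace(H_k, W):
    return np.einsum("kia,kab,kjb->ij", W, H_k, W.conj())      # sum_k W H(k) W^dagger

# Periodic SSH chain: intra-cell hopping v, inter-cell hopping w
n_cells, v, w = 4, 1.0, 0.5
ns = 2 * n_cells
H = np.zeros((ns, ns))
for c in range(n_cells):
    a, b, a_next = 2 * c, 2 * c + 1, (2 * (c + 1)) % ns
    H[a, b] = H[b, a] = v
    H[b, a_next] = H[a_next, b] = w
sub = np.arange(ns) % 2
pos = np.arange(ns) // 2 + 0.5 * sub
W = bloch_unitaries_1d(n_cells, pos, sub)
H_k = to_kspace(H, W)                                      # shape (Nc, 2, 2)
```

Because H is translation invariant, to_realspace(H_k, W) reproduces H up to roundoff.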

Parameters:
  • H_k (np.ndarray) –

    K-space Hamiltonian blocks in one of two formats:

    • Grid format: shape (Lx, Ly, Lz, Nb, Nb) for the full BZ grid (as returned by kspace_from_realspace with block_diag=True)

    • List format: shape (Nk, Nb, Nb) for custom k-points

    Must be in fftfreq order (no fftshift applied) to match the forward transform.

  • block_diag (bool, default=True) –

    Mode selector matching the forward transform:

    • If True: expects H_k in block-diagonal format (grid or list of blocks) and returns the reconstructed real-space matrix.

    • If False: expects H_k as the full transformed matrix (Ns, Ns) and applies the inverse DFT directly.

  • kgrid (Optional[np.ndarray], default=None) –

    K-point grid for reference (only used when block_diag=True).

    • If None: assumes H_k is on the full BZ grid in fftfreq order.

    • If provided: must match the k-points used for the forward transform, with shape (Lx, Ly, Lz, 3) or (Nk, 3) in fftfreq order.

Returns:

H_real – Reconstructed real-space matrix with shape (Ns, Ns) where Ns = Nc * Nb is the total number of sites.

Return type:

np.ndarray

Notes

  • Round-trip accuracy: eigenvalues are preserved to machine precision (~1e-15).

  • Both H_k and kgrid must be in fftfreq order (no fftshift).

  • The reconstruction is exact for translationally invariant systems: H_real_reconstructed ≈ H_real_original to numerical precision.

  • For systems with periodic boundary conditions, the forward and inverse transforms form an isometry on the Hilbert space.

Examples

Example 1: Round-trip transform (full grid)

>>> # Forward transform
>>> H_k, k_grid, k_frac = lattice.kspace_from_realspace(H_real, block_diag=True)
>>>
>>> # Inverse transform
>>> H_real_recon = lattice.realspace_from_kspace(H_k, kgrid=k_grid)
>>>
>>> # Verify reconstruction
>>> np.allclose(H_real, H_real_recon)  # True

Example 2: Inverse transform without explicit kgrid

>>> # If kgrid is omitted, it's reconstructed using fftfreq convention
>>> H_real_recon = lattice.realspace_from_kspace(H_k)
>>> np.allclose(H_real, H_real_recon)  # True

Example 3: Full matrix mode (inverse DFT)

>>> H_k_full        = lattice.kspace_from_realspace(H_real,     block_diag=False)
>>> H_real_recon    = lattice.realspace_from_kspace(H_k_full,   block_diag=False)
>>> np.allclose(H_real, H_real_recon)  # True

See also

kspace_from_realspace

Forward Bloch transform (real-space to k-space)

structure_factor

Compute momentum-resolved structure factors


kspace_from_realspace(mat: ndarray, block_diag: bool = False, kpoints: ndarray | None = None, unitary_norm: bool = True, return_transform: bool = False)[source]

Transform a real-space matrix (Hamiltonian, operator, correlator) to momentum space.

This method provides a convenient interface to the Bloch transform for periodic systems. The transform uses the formula:

\[H_{ab}(k) = \sum_{i,j} W^*_{i,a}(k) H_{i,j} W_{j,b}(k)\]

where \(W_{i,a}(k) = \frac{1}{\sqrt{N_c}} e^{-ik \cdot r_i} \delta_{\text{sub}(i),a}\)
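For a single-band periodic chain (Nb = 1) with hopping -t, the formula gives \(H(k) = -2t\cos k\), and the union of the 1x1 blocks is the real-space spectrum. The sketch below checks this directly with np.einsum, assuming unit lattice constant and an fftfreq-ordered k-grid. It is a standalone illustration, not a call into the library.

```python
import numpy as np

n_cells, t = 6, 1.0
H_real = np.zeros((n_cells, n_cells))
for i in range(n_cells):
    j = (i + 1) % n_cells
    H_real[i, j] = H_real[j, i] = -t          # periodic nearest-neighbour chain

k = 2 * np.pi * np.fft.fftfreq(n_cells)       # fftfreq order, Gamma first
r = np.arange(n_cells)
W = np.exp(-1j * np.outer(k, r))[:, :, None] / np.sqrt(n_cells)   # W[k, i, a], Nb = 1
H_k = np.einsum("kia,ij,kjb->kab", W.conj(), H_real, W)           # shape (Nk, 1, 1)
```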

Parameters:
  • mat (np.ndarray) – Real-space matrix with shape (Ns, Ns) where Ns = Nc * Nb is the total number of sites (unit cells x basis sites per cell).

  • block_diag (bool, default=False) –

    Mode selector for different output formats:

    • If False: returns the full transformed matrix H_k_full with shape (Ns, Ns). This is the complete DFT of the real-space matrix, useful for structure factors.

    • If True: returns the block-diagonal form with k-space blocks H_k, the momentum grid, and fractional coordinates. This is the standard mode for band-structure calculations.

    Output: (H_k, k_grid, k_grid_frac), where:

    • H_k: shape (Lx, Ly, Lz, Nb, Nb), Hamiltonian blocks at each k-point

    • k_grid: shape (Lx, Ly, Lz, 3), Cartesian k-point coordinates

    • k_grid_frac: shape (Lx, Ly, Lz, 3), fractional k-point coordinates

  • kpoints (Optional[np.ndarray], default=None) –

    Custom k-point sampling (only used when block_diag=True):

    • If None: uses the automatic full Brillouin-zone grid based on the lattice size (recommended for most use cases).

    • If provided: array of shape (Nk, 3) with custom k-points in Cartesian coordinates; returns (H_k, kpoints) with H_k of shape (Nk, Nb, Nb).

  • unitary_norm (bool, default=True) – Use unitary normalization \(1/\sqrt{N_c}\) for the Bloch transform. If False, uses normalization \(1/N_c\) instead. Keep True for standard quantum mechanics convention preserving operator norms.

  • return_transform (bool, default=False) –

    If True, also return the Bloch unitary matrix W used for the transformation. This is useful for transforming additional operators or computing correlation functions.

    Note: Only available when block_diag=True. The unitary is returned as a 4th output value with shape (Lx, Ly, Lz, Ns, Nb) or (Nk, Ns, Nb) if custom k-points are provided.

Returns:

  • Case 1 (block_diag=False, default):

    H_k_full (np.ndarray) – Full transformed matrix with shape (Ns, Ns). This is the complete DFT of the input matrix, preserving all information.

  • Case 2 (block_diag=True, kpoints=None, full grid):

    H_k (np.ndarray) – K-space Hamiltonian blocks with shape (Lx, Ly, Lz, Nb, Nb), where:

    • Lx, Ly, Lz are the lattice dimensions

    • Nb is the number of basis sites per unit cell

    • H_k[ix, iy, iz] is the Nb x Nb block at k-point [ix, iy, iz]

    k_grid (np.ndarray) – Cartesian k-point coordinates with shape (Lx, Ly, Lz, 3). The Γ-point sits at index [0, 0, 0] in fftfreq order and moves to [Lx//2, Ly//2, Lz//2] after fftshift.

    k_grid_frac (np.ndarray) – Fractional k-point coordinates with shape (Lx, Ly, Lz, 3). Values are in the range [0, 1), corresponding to the first Brillouin zone.

    W (np.ndarray, optional) – Bloch unitary matrix with shape (Lx, Ly, Lz, Ns, Nb). Only returned if return_transform=True. Use it for transforming operators: O_k = W† @ O_real @ W.

  • Case 3 (block_diag=True, kpoints provided, custom sampling):

    H_k (np.ndarray) – K-space Hamiltonian blocks with shape (Nk, Nb, Nb), where Nk is the number of custom k-points provided.

    kpoints_out (np.ndarray) – Echo of the input k-points with shape (Nk, 3).

    W (np.ndarray, optional) – Bloch unitary matrix with shape (Nk, Ns, Nb). Only returned if return_transform=True.

Examples

Example 1: Full matrix transform for structure factor

>>> H_k_full = lattice.kspace_from_realspace(H_real, block_diag=False)
>>> # H_k_full has shape (Ns, Ns)

Example 2: Block-diagonal form for band structure (recommended)

>>> H_k, k_grid, k_frac = lattice.kspace_from_realspace(H_real, block_diag=True)
>>> # H_k has shape (Lx, Ly, Lz, Nb, Nb)
>>> # Diagonalize each block to get bands
>>> energies = np.linalg.eigvalsh(H_k)  # shape (Lx, Ly, Lz, Nb)

Example 3: Custom k-points (e.g., high-symmetry path)

>>> k_path = lattice.generate_kpath(['Γ', 'X', 'M', 'Γ'], npoints=100)
>>> H_k, k_pts = lattice.kspace_from_realspace(
...     H_real, block_diag=True, kpoints=k_path
... )
>>> # H_k has shape (100, Nb, Nb)
>>> energies = np.linalg.eigvalsh(H_k)  # shape (100, Nb)

Example 4: Get Bloch unitary for operator transforms

>>> H_k, k_grid, k_frac, W = lattice.kspace_from_realspace(
...     H_real, block_diag=True, return_transform=True
... )
>>> # Transform another operator using the same W
>>> # '...' covers the k-grid axes, (Lx, Ly, Lz) or (Nk,)
>>> O_k = np.einsum('...ia,ij,...jb->...ab', W.conj(), O_real, W)

Notes

  • Periodic boundary conditions (PBC) are assumed for the Bloch transform.

  • The method assumes translational invariance of the system, which ensures that the spectrum of H_real equals the union of the spectra of the H(k) blocks.

  • For the full grid (kpoints=None), the k-points follow the fftfreq convention, with the Γ-point at index [0, 0, 0]; it moves to the center only after fftshift.

  • Site ordering is arbitrary; the method uses the lattice geometry (coordinates + basis) to identify sublattices and apply the phases.

  • Sparse input matrices are automatically converted to dense format.

See also

realspace_from_kspace

Inverse transform from k-space to real-space

structure_factor

Compute momentum-resolved structure factors with reduction options

generate_kpath

Generate high-symmetry k-point paths for band structure plotting


structure_factor(mat: ndarray, *, reduction: Literal['none', 'sum', 'trace', 'mean', 'diag'] = 'sum', norm: Literal['none', 'cell', 'site'] = 'none')[source]

Convert a real-space correlation matrix into a momentum-resolved structure factor.

This is a convenience wrapper around the basis-aware Bloch projector in QES.general_python.lattices.tools.lattice_kspace.kspace_from_realspace. The real-space input mat is first transformed into the multipartite k-space block representation evaluated on self.kvectors:

\[C_{\alpha\beta}(q) = \frac{1}{N_c} \sum_{R,R'} e^{-i q\cdot(R-R')} \langle O_{R,\alpha} O_{R',\beta} \rangle,\]

where R, R' label unit cells and α, β label basis sites inside the unit cell. The reduction argument then decides how this multipartite object is converted into a scalar structure factor at each sampled momentum q.

Parameters:
  • mat (np.ndarray) – Real-space correlation or operator matrix with shape (Ns, Ns) or batched shape (..., Ns, Ns). Any leading axes, e.g. time, frequency, disorder sample, or state index, are preserved.

  • reduction ({"none", "sum", "trace", "mean", "diag"}, default="sum") –

    How to reduce the multipartite k-space blocks:

    • "none": return the full k-space blocks C(q) with shape (Lx, Ly, Lz, Nb, Nb) (i.e., no reduction).

    • "sum": return \(\sum_{\alpha\beta} C_{\alpha\beta}(q)\), i.e., the sum over all entries of each block.

    • "trace": return \(\sum_\alpha C_{\alpha\alpha}(q)\), i.e., the sum over the diagonal entries of each block.

    • "mean": return the arithmetic mean of all multipartite block entries at each q, i.e., the sum over all entries divided by \(N_b^2\).

    • "diag": return the eigenvalues of each block, which can help identify dominant modes or instabilities. The output shape is (Lx, Ly, Lz, Nb), since each block’s eigenvalues are returned as a vector of length Nb.

  • norm ({"none", "cell", "site"}, default="none") –

    Optional post-normalization of the returned k-space quantity:

    • "none": keep the raw Bloch-projector normalization, i.e. the blocks C(q) defined above with the prefactor 1 / N_c.

    • "cell": alias for "none", kept for readability when you want to emphasize unit-cell normalization.

    • "site": divide the returned blocks or reduced values by the number of basis sites N_b. For scalar reductions such as "sum", this converts the default unit-cell normalization into the more common site normalization 1 / N_s used in \(S(q) = \langle O_{-q} O_q \rangle\).

Returns:

  • values (np.ndarray) – Momentum-resolved structure factor. For input shape (..., Ns, Ns) the output shape is:

    • (..., Lx, Ly, Lz, Nb, Nb) for reduction="none"

    • (..., Lx, Ly, Lz, Nb) for reduction="diag"

    • (..., Lx, Ly, Lz) for "sum", "trace", or "mean"

    For a single input matrix (Ns, Ns), the leading ... is absent.

  • k_grid (np.ndarray) – Cartesian sampled k-grid with shape (Lx, Ly, Lz, 3).

  • k_frac (np.ndarray) – Fractional sampled k-grid with shape (Lx, Ly, Lz, 3).

Notes

Use reduction="none" when sublattice-resolved information matters. Use one of the scalar reductions when you want a single value per momentum that can be fed directly into bz_path_data.

The default norm="none" preserves the existing unit-cell normalization. For comparisons against structure factors built from Fourier-transformed site operators, norm="site" is typically the physically relevant choice.
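The scalar reductions and the optional site normalization amount to simple arithmetic on each Nb x Nb block C(q). The sketch below is a hypothetical helper that mirrors the descriptions of reduction and norm for a single block; "diag" is left out because it needs an eigensolver, and the library's vectorized implementation is not reproduced here.

```python
def reduce_block(block, reduction="sum", norm="none"):
    """Reduce one Nb x Nb block C(q) to a scalar (hypothetical helper)."""
    if norm not in ("none", "cell", "site"):
        raise ValueError(f"unsupported norm: {norm!r}")
    nb = len(block)
    total = sum(sum(row) for row in block)
    if reduction == "sum":
        value = total                                 # sum over all alpha, beta
    elif reduction == "trace":
        value = sum(block[a][a] for a in range(nb))   # sum over alpha == beta only
    elif reduction == "mean":
        value = total / nb**2                         # average over the Nb^2 entries
    else:
        raise ValueError(f"unsupported reduction: {reduction!r}")
    if norm == "site":
        value /= nb                                   # 1/N_c -> 1/(N_c N_b) = 1/N_s
    return value  # "none" and "cell" keep the unit-cell normalization


block = [[1.0, 2.0], [3.0, 4.0]]
print(reduce_block(block, "sum"), reduce_block(block, "trace"), reduce_block(block, "mean"))
# 10.0 5.0 2.5
print(reduce_block(block, "sum", norm="site"))
# 5.0
```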

Examples

>>> Sq, k_grid, k_frac  = lattice.structure_factor(corr_zz, reduction="sum")
>>> path                = lattice.bz_path_data(k_grid, k_frac, Sq, path=['Gamma', 'K', 'M', 'Gamma'])
>>>
>>> # Frequency-resolved data with shape (Nw, Ns, Ns)
>>> Sqw, k_grid, k_frac = lattice.structure_factor(corr_zz_w, reduction="sum")
>>> # Sqw has shape (Nw, Lx, Ly, Lz)
summary_string(*, precision: int = 3) str[source]

Return a textual summary of lattice metadata.

real_space_table(*, max_rows: int = 10, precision: int = 3) str[source]

Return a formatted table of real-space vectors.

reciprocal_space_table(*, max_rows: int = 10, precision: int = 3) str[source]

Return a formatted table of reciprocal-space vectors.

brillouin_zone_overview(*, precision: int = 3) str[source]

Return a textual overview of the sampled Brillouin zone.

describe(*, precision: int = 3, max_rows: int = 10, include_vectors: bool = True, include_reciprocal: bool = True, include_brillouin_zone: bool = True) str[source]

Combine multiple presentation helpers into a single multi-section string.

plot_real_space(**kwargs)[source]

Convenience wrapper returning the matplotlib figure and axes for a real-space scatter plot.

plot_reciprocal_space(**kwargs)[source]

Scatter-plot of reciprocal lattice vectors (k-points).

Parameters mirror plot_real_space().

Parameters:
  • lattice (Lattice) – The lattice object to plot.

  • ax (Axes, optional) – Matplotlib axes to plot on. If None, a new figure is created.

  • show_indices (bool, default=False) – If True, annotate each k-point with its index.

  • show_axes (bool, default=True) – If False, hide the coordinate axes.

  • color (str, default="C1") – Color of the k-point markers.

  • marker (str, default="o") – Marker style.

  • figsize (tuple, optional) – Figure size in inches (width, height).

  • title (str, optional) – Title of the plot.

  • elev, azim (float, optional) – Elevation and azimuth angles for 3D plots.

  • extend_kpoints (bool, default=False) – If True, draw translated reciprocal-space copies around the original mesh.

  • extend_copies (int or iterable of int, default=2) – Number of copies per reciprocal direction used when extend_kpoints=True. A scalar is applied to all active reciprocal directions.

  • extend_tol (float, default=1e-10) – Tolerance used to identify which extended points are already present in the original reciprocal mesh.

  • **scatter_kwargs – Additional keyword arguments, including:

    • point_edgecolor: color of the marker edges (default "white").

    • point_zorder: z-order of the scatter points (default 5).

    • color_extended: color of translated copies (default "C2").

    • edgecolor_extended: edge color of translated copies (default "gray").

    • marker_extended: marker of translated copies (defaults to marker).

    • Any other valid arguments for ax.scatter.

plot_brillouin_zone(**kwargs)[source]

Convenience wrapper returning the matplotlib figure and axes for a Brillouin zone plot.

Parameters:
  • lattice (Lattice) – The lattice object containing k-vectors.

  • ax (Axes, optional) – Matplotlib axes to plot on. If None, a new figure is created.

  • facecolor (str, default="tab:blue") – Color used to fill the Brillouin zone area.

  • edgecolor (str, default="black") – Color of the Brillouin zone boundary.

  • alpha (float, default=0.25) – Transparency level of the Brillouin zone fill.

  • figsize (tuple, optional) – Figure size in inches (width, height).

  • title (str, optional) – Title of the plot.

  • elev (float, optional) – Elevation angle for 3D plots.

  • azim (float, optional) – Azimuth angle for 3D plots.

plot_structure(**kwargs)[source]

Convenience wrapper returning the matplotlib figure and axes for a detailed lattice structure plot.

Parameters:
  • show_indices (bool) – If True, annotates nodes with their site indices.

  • highlight_boundary (bool) – If True, draws boundary nodes with a distinct color/edge.

  • show_axes (bool) – If False, hides the coordinate axes for a cleaner diagram.

  • partition_colors (tuple of str, optional) – Colors to use for bipartite/sublattice coloring. If provided, nodes are colored based on sublattice parity.

  • show_periodic_connections (bool) – If True, indicates wrap-around connections textually or graphically.

  • show_primitive_cell (bool) – If True, overlays the primitive unit cell vectors/box.

  • **kwargs – Other keyword arguments passed to the underlying plotting function (e.g. node size, color map, etc.). See plot_lattice_structure() for details.

plot_high_symmetry(**kwargs)[source]

Convenience wrapper for plotting the Brillouin zone, high-symmetry path, and sampled reciprocal mesh.

Parameters:
  • path (list[str], str, or iterable[(label, frac)], optional) – High-symmetry path specification. If omitted, the lattice default path is used.

  • show_kpoints (bool, default=True) – Draw sampled reciprocal-space mesh points.

  • show_bz (bool, default=True) – Draw the first Brillouin zone.

  • show_path (bool, default=True) – Draw the ideal high-symmetry path.

  • show_matched_kpoints (bool, default=True) – Highlight sampled k-points whose distance to the path is within the matching tolerance.

  • points_per_seg (int, default=40) – Number of interpolation points per path segment for the ideal path.

  • path_match_tol (float, optional) – Distance tolerance used when highlighting mesh points near the drawn path.

  • extend (bool, default=False) – Draw translated copies of the sampled k-mesh.

  • extend_copies (int or iterable[int], optional) – Number of reciprocal-cell copies per direction. In 2D, extend_copies=1 includes the first shell around the first Brillouin zone and extend_copies=2 includes the second shell as well.

  • show_background_bz (bool, default=False) – Draw translated Brillouin-zone copies behind the mesh.

  • hs_plot ({"none", "markers", "labels", "both"}, default="markers") – Whether to draw the exact high-symmetry points as markers, labels, or both; "none" draws neither.

  • legend_kwargs (dict, optional) – Extra keyword arguments passed to axis.legend.

  • **kwargs – Additional style overrides forwarded to plot_high_symmetry_points.

property plot

Access plotting utilities for this lattice.

Returns a LatticePlotter instance providing the methods:

  • real_space(**kwargs): scatter plot of sites.

  • reciprocal_space(**kwargs): scatter plot of reciprocal lattice vectors.

  • brillouin_zone(**kwargs): visualization of the Brillouin zone.

  • structure(**kwargs): detailed connectivity plot with boundaries.

Example

>>> lat.plot.structure(show_indices=True, highlight_boundary=True)
>>> lat.plot.brillouin_zone()
class general_python.lattices.LatticeBC(*values)[source]

Bases: Enum

Enumeration for the boundary conditions in the lattice model.

PBC = 1
OBC = 2
MBC = 3
SBC = 4
TWISTED = 5
class general_python.lattices.LatticeDirection(*values)[source]

Bases: Enum

Enumeration of the lattice directions.

X = 0
Y = 1
Z = 2
class general_python.lattices.LatticeType(*values)[source]

Bases: Enum

Contains all the implemented lattice types for the lattice model.

SQUARE = 1
HEXAGONAL = 2
HONEYCOMB = 3
GRAPH = 4
TRIANGULAR = 5
CHAIN = 6
class general_python.lattices.SquareLattice(lx=1, ly=1, lz=1, dim=None, bc=pbc, **kwargs)[source]

Bases: Lattice

Square Lattice Class for 1D, 2D, and 3D lattices.

The lattice vectors are defined as:

  • a = [1, 0, 0]

  • b = [0, 1, 0]

  • c = [0, 0, 1]

and the reciprocal lattice vectors are:

  • a* = [2*pi, 0, 0]

  • b* = [0, 2*pi, 0]

  • c* = [0, 0, 2*pi]

Input/output contracts

  • Constructor expects integer dimensions lx, ly, lz (as applicable to dim).

  • bc must be a LatticeBC enum or compatible string/int identifier.

  • Coordinates are returned as floating-point arrays of shape (Ns, dim).

  • Neighbor lists are lists of lists, where neighbors[i] contains indices of neighbors of site i.

Shape and dtype expectations

  • coordinates: Real-valued array of shape (Ns, dim).

  • kvectors: Real-valued array of shape (Ns, 3) or (Ns, dim).

  • Neighbor indices are integers in range [0, Ns).

High-symmetry points in the Brillouin zone:

  • 1D: Γ (0) -> X (π) -> Γ (2π)

  • 2D: Γ (0, 0) -> X (π, 0) -> M (π, π) -> Γ (0, 0)

  • 3D: Γ -> X -> M -> Γ -> R -> X

__init__(lx=1, ly=1, lz=1, dim=None, bc=pbc, **kwargs)[source]

Initialize the square lattice.

high_symmetry_points() HighSymmetryPoints | None[source]

Return high-symmetry points for the square/cubic lattice.

Returns:

High-symmetry points with a default path based on dimension:

  • 1D: Γ -> X -> Γ (zone boundary at π)

  • 2D: Γ -> X -> M -> Γ (standard square BZ path)

  • 3D: Γ -> X -> M -> Γ -> R -> X (standard cubic BZ path)

Return type:

HighSymmetryPoints

contains_special_point(point, *, tol: float = 1e-12) bool[source]

Check if a square/cubic special point is present in the current k-grid.

get_k_vec_idx(sym=False)[source]

Returns the indices of kvectors, considering symmetry reduction.

calculate_norm_sym()[source]

Calculates the normalization factors considering symmetric momenta.

site_index(x, y, z)[source]

Convert (x, y, z) coordinates to a site index.

Parameters:
  • x (int) – x-coordinate.

  • y (int) – y-coordinate.

  • z (int) – z-coordinate.

get_real_vec(x: int, y: int, z: int)[source]

Returns the real vector for a given (x, y, z) coordinate.

get_norm(x: int, y: int, z: int)[source]

Returns the Euclidean norm of a real-space vector.

get_nn_direction(site: int, direction: LatticeDirection)[source]

Returns the nearest neighbors in a given direction (X, Y, Z).

Parameters:
  • site (int) – Site index.

  • direction (LatticeDirection) – Direction in which to get the nearest neighbors.

get_nn_forward_num_max()[source]

Maximum number of forward nearest-neighbor bonds per square-lattice site.

get_nn_forward(site: int, num: int = -1)[source]

Returns the forward nearest neighbors of a given site.

get_nnn_forward(site, num: int = -1)[source]

Returns the forward next-nearest neighbors of a given site.

calculate_nn_in(pbcx: bool, pbcy: bool, pbcz: bool)[source]

Calculates the nearest neighbors (NN) for 1D, 2D, and 3D square lattices. Also calculates the forward nearest neighbors (NNF).

Parameters:
  • pbcx (bool) – Periodic boundary condition in the x direction.

  • pbcy (bool) – Periodic boundary condition in the y direction.

  • pbcz (bool) – Periodic boundary condition in the z direction.
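A minimal sketch of nearest-neighbour lookup on a 2D square lattice with periodic or open boundaries. The linear index i = x + Lx*y is an assumption made for illustration; the ordering used by site_index and the library's handling of very small clusters may differ.

```python
def site_index(x, y, lx):
    # Assumed ordering for this sketch: x runs fastest.
    return x + lx * y


def square_nn(lx, ly, pbcx=True, pbcy=True):
    """Nearest neighbours per site for an lx x ly square lattice (hypothetical helper)."""
    nn = []
    for i in range(lx * ly):
        x, y = i % lx, i // lx
        neighbours = []
        for dx, dy in ((1, 0), (-1, 0), (0, 1), (0, -1)):
            nx, ny = x + dx, y + dy
            if not 0 <= nx < lx:
                if not pbcx:
                    continue          # open boundary: drop bonds leaving the cluster
                nx %= lx              # periodic boundary: wrap around
            if not 0 <= ny < ly:
                if not pbcy:
                    continue
                ny %= ly
            neighbours.append(site_index(nx, ny, lx))
        nn.append(neighbours)
    return nn


pbc = square_nn(4, 3)
obc = square_nn(4, 3, pbcx=False, pbcy=False)
print(pbc[0])   # [1, 3, 4, 8]
print(obc[0])   # [1, 4]
```

For L = 2 along a periodic direction, the +x and -x neighbours coincide after wrapping, so this sketch would list the same site twice.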

calculate_nnn_in(pbcx: bool, pbcy: bool, pbcz: bool)[source]

Calculates the next-nearest neighbors (NNN) for 1D, 2D, and 3D square lattices. Also calculates the forward next-nearest neighbors (NNNF).

Parameters:
  • pbcx (bool) – Periodic boundary condition in the x direction.

  • pbcy (bool) – Periodic boundary condition in the y direction.

  • pbcz (bool) – Periodic boundary condition in the z direction.

static dispersion(k)[source]

Simple nearest-neighbour tight-binding/spin-wave-like dispersion for the square lattice. Accepts k of shape (2,) or (…, 2) and returns a scalar energy or an array of energies with the leading shape (…), respectively.

class general_python.lattices.HexagonalLattice(*, dim=2, lx=3, ly=3, lz=1, bc='pbc', **kwargs)[source]

Bases: Lattice

Armchair-oriented hexagonal (honeycomb) lattice up to 3 dimensions.

The lattice is constructed so that armchair edges lie along the horizontal (x) axis, giving a rectangular bounding box aligned with the coordinate system. Two sites per unit cell (A / B sublattices).

Parameters:
  • dim (int) – Lattice dimensionality (1, 2, or 3).

  • lx, ly, lz (int) – Number of unit cells along each lattice-vector direction.

  • bc (str or LatticeBC) – Boundary conditions ('pbc', 'obc', etc.).

  • **kwargs – Forwarded to Lattice (e.g. flux).

__init__(*, dim=2, lx=3, ly=3, lz=1, bc='pbc', **kwargs)[source]

General Lattice class. This class contains the general lattice model.

Parameters:
  • dim (int, optional) – Dimension of the lattice (1, 2, or 3). If None, inferred from lx, ly, lz.

  • lx (int, optional) – Length of the lattice in the x-direction.

  • ly (int, optional) – Length of the lattice in the y-direction.

  • lz (int, optional) – Length of the lattice in the z-direction.

  • bc (str, optional) – Boundary conditions (e.g., 'PBC', 'OBC').

  • adj_mat (np.ndarray, optional) – Adjacency matrix for the lattice.

  • flux (float, dict, or BoundaryFlux, optional) – Flux piercing the boundaries, given either as a dictionary mapping each direction to its flux or as a single value applied to all directions. Setting a flux implies TWISTED boundary conditions, so bc can be left as None or set to 'TWISTED' for clarity.

high_symmetry_points() HighSymmetryPoints | None[source]

Return high-symmetry points for the hexagonal BZ.

contains_special_point(point, *, tol: float = 1e-12) bool[source]

Check if a hexagonal special point is present in the current k-grid.

static dispersion(k, a=1.0)[source]

Hexagonal/honeycomb (armchair) nearest-neighbour dispersion magnitude. Uses the three NN vectors defined in the hexagonal geometry.

get_real_vec(x: int, y: int, z: int)[source]

Real-space position for stored coordinate tuple (x, y, z).

The base class calculate_coordinates already stores proper vectors via _a1, _a2, _a3, _basis. This helper is kept for backwards compatibility and any custom coordinate look-ups.

get_norm(x: int, y: int, z: int)[source]

Euclidean norm of the real-space vector.

get_nn_direction(site: int, direction: LatticeDirection)[source]

Return the nearest-neighbour in the specified bond direction.

Mapping:

  • X -> intra-cell bond (A <-> B within the same cell)

  • Y -> bond along a2

  • Z -> bond along a1

bond_type(s1: int, s2: int) int[source]

Return directional bond type (X_BOND, Y_BOND, Z_BOND) or -1.

calculate_nn_in(pbcx: bool, pbcy: bool, pbcz: bool)[source]

Calculate nearest neighbours for the armchair hexagonal lattice.

Each site has exactly 3 nearest neighbours (honeycomb coordination).

Bond convention for an A-site at cell (cx, cy):

  • [X_BOND] intra-cell -> B(cx, cy)

  • [Y_BOND] along -a2 -> B(cx, cy-1)

  • [Z_BOND] along a1 - a2 -> B(cx+1, cy-1)

For a B-site at cell (cx, cy):

  • [X_BOND] intra-cell -> A(cx, cy)

  • [Y_BOND] along +a2 -> A(cx, cy+1)

  • [Z_BOND] along -a1 + a2 -> A(cx-1, cy+1)
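The bond convention above can be encoded as cell offsets per sublattice. The sketch below assumes periodic boundaries in both directions and a hypothetical site index i = 2*(cx + Lx*cy) + sub (sub = 0 for A, 1 for B); it is not the library's calculate_nn_in, and only checks that the A and B rules are mutual inverses.

```python
# Cell offsets (dcx, dcy) to the bonded partner, per bond type, from the convention above.
A_OFFSETS = {"X": (0, 0), "Y": (0, -1), "Z": (1, -1)}   # A -> B
B_OFFSETS = {"X": (0, 0), "Y": (0, 1), "Z": (-1, 1)}    # B -> A


def site(cx, cy, sub, lx):
    # Hypothetical ordering: two sites per cell, cells ordered with cx fastest.
    return 2 * (cx + lx * cy) + sub


def bond_partner(cx, cy, sub, bond, lx, ly):
    """Site index bonded to (cx, cy, sub) via bond "X", "Y", or "Z" (periodic wrap)."""
    dcx, dcy = (A_OFFSETS if sub == 0 else B_OFFSETS)[bond]
    return site((cx + dcx) % lx, (cy + dcy) % ly, 1 - sub, lx)


lx, ly = 3, 3
# Every A -> B bond must be matched by the B -> A bond of the same type.
for cx in range(lx):
    for cy in range(ly):
        for bond in "XYZ":
            b = bond_partner(cx, cy, 0, bond, lx, ly)
            bcx, bcy = (b // 2) % lx, (b // 2) // lx
            assert bond_partner(bcx, bcy, 1, bond, lx, ly) == site(cx, cy, 0, lx)

print(bond_partner(0, 0, 0, "Z", lx, ly))  # 15, i.e. B(1, 2) after wrapping cy - 1 = -1
```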

calculate_nnn_in(pbcx: bool, pbcy: bool, pbcz: bool)[source]

Calculate next-nearest neighbours for the armchair hexagonal lattice.

NNN connect sites within the same sublattice. Each site has 6 NNN in the full 2D honeycomb; finite clusters may have fewer depending on boundary conditions.

NNN displacements (same sublattice) in cell coordinates:

+a1, -a1, +a2, -a2, +(a1-a2), -(a1-a2)

i.e. (±1, 0), (0, ±1), (±1, ∓1)

get_sym_pos(x, y, z)[source]

Map coordinates to a position in the symmetry norm array.

For the armchair lattice with 2 sublattices, y ranges over 0 .. 2*Ly - 1 (cell index ×2 + sublattice).

get_sym_pos_inv(x, y, z)[source]

Inverse of get_sym_pos().

symmetry_checker(x, y, z)[source]

Always returns True (placeholder for future symmetry calculations).

calculate_plaquettes()[source]

Calculate hexagonal plaquettes (6-site loops) of the armchair lattice.

Each plaquette is a list of 6 site indices forming a closed loop around one hexagonal face. Only unique plaquettes are returned.

class general_python.lattices.HoneycombLattice(lx=3, ly=1, *, lz=1, dim=2, bc='pbc', **kwargs)[source]

Bases: Lattice

Implementation of the Honeycomb Lattice.

The honeycomb lattice is a 2D lattice with a hexagonal structure. The lattice consists of two sublattices (A and B) arranged in a hexagonal pattern. Nearest and next-nearest neighbors are computed based on a hexagonal unit cell.

High-symmetry points in the Brillouin zone:

  • Gamma: zone center at (0, 0)

  • K: Dirac point at (2/3, 1/3), which hosts the linear band crossings in graphene

  • K': the other Dirac point at (1/3, 2/3)

  • M: edge midpoint at (1/2, 0)

Default path: Γ -> K -> M -> Γ

Attributes

Lx, Ly, Lz

Number of lattice sites in x, y, and z directions.

bc

Boundary condition (e.g. PBC or OBC).

a, c

Lattice parameters.

vectors

Primitive lattice vectors.

kvectors

Sampled reciprocal-space k-vectors (momentum grid).

rvectors

Real-space vectors.

__init__(lx=3, ly=1, *, lz=1, dim=2, bc='pbc', **kwargs)[source]

Initialize a honeycomb lattice.

Parameters:
  • dim (int) – Lattice dimension (1, 2, or 3).

  • lx, ly, lz (int) – Lattice sizes in the x, y, and z directions.

  • bc – Boundary condition (e.g. LatticeBC.PBC or LatticeBC.OBC).

high_symmetry_points() HighSymmetryPoints | None[source]

Return high-symmetry points for the honeycomb lattice.

Returns:

High-symmetry points for the hexagonal Brillouin zone:

  • Γ (Gamma): zone center (0, 0)

  • K: Dirac point at (2/3, 1/3), hosting linear band crossings

  • K': the other Dirac point at (1/3, 2/3)

  • M: edge midpoint at (1/2, 0)

Default path: Γ -> K -> M -> Γ

Return type:

HighSymmetryPoints

contains_special_point(point, *, tol: float = 1e-12) bool[source]

Check if a honeycomb special point is present in the current k-grid.

get_real_vec(x: int, y: int, z: int = 0)[source]

Returns the real-space vector for a given (x, y, z) coordinate.

get_norm(x: int, y: int, z: int)[source]

Returns the Euclidean norm of the real-space vector.

get_nn_direction(site, direction)[source]

Returns the nearest neighbor in the specified direction.

For the honeycomb lattice, we choose a mapping:

  • LatticeDirection.X -> neighbor at index 0 of _nn[site]

  • LatticeDirection.Y -> neighbor at index 1 of _nn[site]

  • LatticeDirection.Z -> neighbor at index 2 of _nn[site]

get_nn_forward(site: int, num: int = -1)[source]

Returns the forward nearest neighbor for the given site.

(For honeycomb, this could be defined as the first neighbor in a chosen ordering.)

get_nnn_forward(site: int, num: int = -1)[source]

Returns the forward next-nearest neighbor for the given site.

calculate_nn_in(pbcx: bool, pbcy: bool, pbcz: bool)[source]

Calculates the nearest neighbors (NN) using boundary conditions.

The implementation applies periodic or open boundary conditions through a helper function. In 2D, even and odd site indices (sublattices A and B) are treated differently.

calculate_nnn_in(pbcx: bool, pbcy: bool, pbcz: bool)[source]

Calculates the next-nearest neighbors (NNN) of the honeycomb lattice.

NNN are second-nearest neighbors, connecting sites on the same sublattice. For sublattice A (even sites), the three NNN directions are obtained by composing two consecutive NN hops:

  • NNN_1: Y-bond then X-bond^{-1} → cell shift (0, -1) [down in y]

  • NNN_2: Z-bond then Y-bond^{-1} → cell shift (-1, 0) [left in x]

  • NNN_3: X-bond then Z-bond^{-1} → cell shift (+1, +1) [diagonal]

For sublattice B (odd sites), the shifts are inverted.

calculate_norm_sym()[source]

Uses base implementation.

get_sym_pos(x, y, z)[source]

Returns the symmetry-transformed position.

get_sym_pos_inv(x, y, z)[source]

Returns the inverse symmetry-transformed position.

bond_type(s1: int, s2: int) int[source]

Return directional bond type (X_BOND_NEI, Y_BOND_NEI, Z_BOND_NEI) or -1.

calculate_plaquettes(open_bc: bool | None = None)[source]

Calculate the hexagonal plaquettes of the honeycomb lattice.

static dispersion(k, a=1.0)[source]

Honeycomb (graphene-like) nearest-neighbour dispersion magnitude. Computes |f(k)| where f = sum_{δ} exp(-i k·δ) for the three A->B vectors used by this lattice implementation.
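A sketch of the magnitude |f(k)|. The three A->B vectors used here (length a, one bond along +y) are a common graphene convention assumed for illustration; the vectors in this lattice implementation may be oriented differently, which would rotate the pattern in k-space. The helper name is hypothetical.

```python
import cmath
import math


def honeycomb_abs_f(kx, ky, a=1.0):
    """|sum_delta exp(-i k.delta)| for three assumed A->B vectors of length a."""
    deltas = [
        (0.0, a),
        (math.sqrt(3) / 2 * a, -a / 2),
        (-math.sqrt(3) / 2 * a, -a / 2),
    ]
    f = sum(cmath.exp(-1j * (kx * dx + ky * dy)) for dx, dy in deltas)
    return abs(f)


print(honeycomb_abs_f(0.0, 0.0))                 # 3.0 at Gamma
k_dirac = 4 * math.pi / (3 * math.sqrt(3))       # a Dirac point for these deltas (a = 1)
print(round(honeycomb_abs_f(k_dirac, 0.0), 12))  # 0.0
```

At the Dirac point the three phases are 1, e^{-2πi/3}, and e^{+2πi/3}, which sum to zero.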

class general_python.lattices.TriangularLattice(*, dim=2, lx=3, ly=3, lz=1, bc='pbc', **kwargs)[source]

Bases: Lattice

Implementation of the Triangular Lattice (2D). The triangular lattice is a 2D Bravais lattice with each site having 6 nearest neighbors.

__init__(*, dim=2, lx=3, ly=3, lz=1, bc='pbc', **kwargs)[source]

Initialize a Triangular Lattice.

high_symmetry_points() HighSymmetryPoints | None[source]

Return high-symmetry points for the triangular lattice Brillouin zone.

contains_special_point(point, *, tol: float = 1e-12) bool[source]

Check if a triangular special point is present in the current k-grid.

get_real_vec(x: int, y: int, z: int)[source]

Returns the real-space vector for a given (x, y, z) coordinate.

get_norm(x: int, y: int, z: int)[source]

Return the Euclidean norm of integer coordinate offsets.

get_nn_direction(site, direction)[source]

Return the nearest neighbor associated with a lattice direction.

calculate_nn_in(pbcx: bool, pbcy: bool, pbcz: bool)[source]

Calculates the nearest neighbors (NN) for the triangular lattice.

Each site has 6 nearest neighbors in 2D corresponding to the six lattice-vector displacements:

+a1, -a1, +a2, -a2, +(a1-a2), -(a1-a2)
i.e. cell-coordinate offsets:

(+1,0), (-1,0), (0,+1), (0,-1), (+1,-1), (-1,+1)

Forward bonds are those connecting to a site with strictly higher index so that each bond is counted exactly once.

Parameters:
  • pbcx, pbcy, pbcz (bool) – Whether periodic boundary conditions apply along x, y, and z, respectively.
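The forward-bond rule can be checked on a small periodic cluster: with six distinct neighbours per site, keeping only neighbours of strictly higher index leaves Ns*6/2 unique bonds. The index i = x + Lx*y and the helper name are hypothetical, chosen only for this sketch.

```python
OFFSETS = [(1, 0), (-1, 0), (0, 1), (0, -1), (1, -1), (-1, 1)]


def triangular_forward_bonds(lx, ly):
    """Unique NN bonds (i, j) with j > i on an lx x ly periodic triangular cluster."""
    bonds = set()
    for y in range(ly):
        for x in range(lx):
            i = x + lx * y
            for dx, dy in OFFSETS:
                j = (x + dx) % lx + lx * ((y + dy) % ly)
                if j > i:             # forward bond: each pair is kept from one side only
                    bonds.add((i, j))
    return bonds


print(len(triangular_forward_bonds(3, 3)))  # 27 = 9 sites * 6 neighbours / 2
```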

calculate_nnn_in(pbcx: bool, pbcy: bool, pbcz: bool)[source]

Calculates the next-nearest neighbors (NNN) for the triangular lattice.

NNN are at cell-coordinate offsets:

(+2,0), (-2,0), (0,+2), (0,-2), (+2,-2), (-2,+2),
(+1,+1), (-1,-1), (+2,-1), (-2,+1), (+1,-2), (-1,+2)
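
Converting the listed offsets (n1, n2) to Cartesian displacements n1*a1 + n2*a2, with a1 = (a, 0) and a2 = (a/2, √3 a/2) as stated for dispersion(), shows that the twelve offsets span two distance shells: six at √3 a and six at 2a. This is a quick geometric check under that assumption about the primitive vectors.

```python
import math

NNN_OFFSETS = [
    (2, 0), (-2, 0), (0, 2), (0, -2), (2, -2), (-2, 2),
    (1, 1), (-1, -1), (2, -1), (-2, 1), (1, -2), (-1, 2),
]


def offset_length(n1, n2, a=1.0):
    """Length of n1*a1 + n2*a2 with a1 = (a, 0) and a2 = (a/2, sqrt(3)*a/2)."""
    x = n1 * a + n2 * a / 2
    y = n2 * math.sqrt(3) * a / 2
    return math.hypot(x, y)


lengths = sorted(round(offset_length(n1, n2), 9) for n1, n2 in NNN_OFFSETS)
print(lengths[:6] == [round(math.sqrt(3), 9)] * 6, lengths[6:] == [2.0] * 6)  # True True
```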
site_index(x, y, z)[source]

Convert integer cell coordinates to a linear site index.

static dispersion(k, a=1.0)[source]

Simple triangular-lattice dispersion approximation: ω(k) = 2J * [3 - cos(k·a1) - cos(k·a2) - cos(k·(a1 - a2))] where a1=(a,0), a2=(a/2, √3 a/2). Accepts k as (2,) or (…,2).
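A direct transcription of the formula, useful for checking values at special points. The documented signature has no J argument, so the sketch exposes J as a keyword with an assumed default of 1; treat it as an illustration of the formula rather than of the library function's exact scaling. The point (4π/(3a), 0) used below is a Brillouin-zone corner for these primitive vectors, where each of the three cosines equals -1/2.

```python
import math


def triangular_omega(kx, ky, a=1.0, J=1.0):
    """omega(k) = 2J [3 - cos(k.a1) - cos(k.a2) - cos(k.(a1 - a2))]."""
    a1 = (a, 0.0)
    a2 = (a / 2, math.sqrt(3) * a / 2)
    d = (a1[0] - a2[0], a1[1] - a2[1])   # a1 - a2

    def k_dot(v):
        return kx * v[0] + ky * v[1]

    return 2 * J * (3 - math.cos(k_dot(a1)) - math.cos(k_dot(a2)) - math.cos(k_dot(d)))


print(triangular_omega(0.0, 0.0))                          # 0.0 at Gamma
print(round(triangular_omega(4 * math.pi / 3, 0.0), 12))   # 9.0 at the zone corner K
```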

class general_python.lattices.GraphLattice(adjacency: ndarray | Sequence[Sequence[float]], coords: ndarray | Sequence[Sequence[float]] | None = None, bc: str | LatticeBC | None = obc, flux: float | BoundaryFlux | Mapping[str | LatticeDirection, float] | None = None, metadata: GraphMetadata | None = None, **kwargs)[source]

Bases: Lattice

Lattice backed by an adjacency matrix instead of structural formulas.

Parameters:
  • adjacency – Square (Ns × Ns) adjacency matrix. Non-zero entries denote connected pairs. The absolute value of the weight is used to rank neighbor order (highest -> nearest).

  • coords – Optional vertex embedding of shape (Ns, dim). Defaults to a 1D chain ordering. Used for plotting and distance heuristics.

  • bc – Boundary condition descriptor. Defaults to open boundaries.

  • flux – Optional boundary flux phases - forwarded to Lattice.

  • metadata – Auxiliary metadata (name, tags, free-form info).
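One reading of the ranking rule for adjacency is that the neighbours of a vertex are grouped by distinct |weight|, largest first, so the strongest couplings act as nearest neighbours and the next group as next-nearest neighbours. The sketch below implements that reading as a hypothetical helper; the library's grouping and tie-breaking may differ.

```python
def neighbour_shells(adjacency, site, tol=1e-12):
    """Group neighbours of `site` by |weight|, strongest first (hypothetical helper)."""
    weighted = sorted(
        ((abs(w), j) for j, w in enumerate(adjacency[site]) if j != site and abs(w) > tol),
        key=lambda item: (-item[0], item[1]),
    )
    shells, current = [], None
    for weight, j in weighted:
        if current is None or current - weight > tol:
            shells.append([])          # start a new, weaker shell
            current = weight
        shells[-1].append(j)
    return shells


# 4-site ring with strong (1.0) ring bonds and weak (0.5) diagonals.
adj = [
    [0.0, 1.0, 0.5, 1.0],
    [1.0, 0.0, 1.0, 0.5],
    [0.5, 1.0, 0.0, 1.0],
    [1.0, 0.5, 1.0, 0.0],
]
print(neighbour_shells(adj, 0))  # [[1, 3], [2]]
```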

__init__(adjacency: ndarray | Sequence[Sequence[float]], coords: ndarray | Sequence[Sequence[float]] | None = None, bc: str | LatticeBC | None = obc, flux: float | BoundaryFlux | Mapping[str | LatticeDirection, float] | None = None, metadata: GraphMetadata | None = None, **kwargs)[source]

General Lattice class. This class contains the general lattice model.

Parameters:
  • dim (int, optional) – Dimension of the lattice (1, 2, or 3). If None, inferred from lx, ly, lz.

  • lx (int, optional) – Length of the lattice in the x-direction.

  • ly (int, optional) – Length of the lattice in the y-direction.

  • lz (int, optional) – Length of the lattice in the z-direction.

  • bc (str, optional) – Boundary conditions (e.g., 'PBC', 'OBC').

  • adj_mat (np.ndarray, optional) – Adjacency matrix for the lattice.

  • flux (float, dict, or BoundaryFlux, optional) – Flux piercing the boundaries, given either as a dictionary mapping each direction to its flux or as a single value applied to all directions. Setting a flux implies TWISTED boundary conditions, so bc can be left as None or set to 'TWISTED' for clarity.

site_index(x: int, y: int = 0, z: int = 0) int[source]

Return linear index for coordinates. For graph lattices we interpret the first argument as the explicit vertex index.

get_real_vec(x: int, y: int = 0, z: int = 0)[source]

Returns the real vector given the coordinates. Uses the lattice constants.

get_norm(x: int, y: int = 0, z: int = 0)[source]

Returns the norm of the vector given the coordinates.

get_nn_direction(site, direction)[source]

Returns the nearest neighbors in a given direction.

calculate_coordinates()[source]

Calculates the coordinates for each lattice site in up to 3D.

Each site index i corresponds to:

cell = i // n_basis
sub = i % n_basis

where n_basis = len(self._basis) (e.g., 2 for honeycomb).

Works for any lattice with defined self._a1, _a2, _a3 and self._basis list.

calculate_reciprocal_vectors()[source]

Calculates the reciprocal lattice vectors based on the primitive vectors. Always returns 3D vectors (padding with zeros for lower dimensions).

Returns:

k1, k2, k3 – Reciprocal lattice vectors (always 3D).
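The standard construction b1 = 2π (a2 × a3) / (a1 · (a2 × a3)), with cyclic permutations for b2 and b3, satisfies a_i · b_j = 2π δ_ij. The sketch below assumes that lower-dimensional lattices are padded with a unit vector along each missing axis (e.g. a3 = (0, 0, 1) in 2D) so the triple product is nonzero; that padding choice is an assumption about the implementation, and the helper names are hypothetical.

```python
import math


def cross(u, v):
    return (u[1] * v[2] - u[2] * v[1], u[2] * v[0] - u[0] * v[2], u[0] * v[1] - u[1] * v[0])


def dot(u, v):
    return sum(x * y for x, y in zip(u, v))


def reciprocal_vectors(a1, a2, a3):
    """Return (b1, b2, b3) with a_i . b_j = 2*pi*delta_ij."""
    scale = 2 * math.pi / dot(a1, cross(a2, a3))   # 2*pi / cell volume
    b1 = tuple(scale * c for c in cross(a2, a3))
    b2 = tuple(scale * c for c in cross(a3, a1))
    b3 = tuple(scale * c for c in cross(a1, a2))
    return b1, b2, b3


# Triangular primitive vectors padded to 3D with a3 along z.
a = [(1.0, 0.0, 0.0), (0.5, math.sqrt(3) / 2, 0.0), (0.0, 0.0, 1.0)]
b = reciprocal_vectors(*a)
for i in range(3):
    for j in range(3):
        expected = 2 * math.pi if i == j else 0.0
        assert abs(dot(a[i], b[j]) - expected) < 1e-12
```

For the square lattice, the same construction returns b1 = (2π, 0, 0), b2 = (0, 2π, 0), b3 = (0, 0, 2π), matching the SquareLattice description.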

calculate_r_vectors()[source]

Calculates the real-space vectors (r) for each site. Must match the ordering in calculate_coordinates().

calculate_k_vectors()[source]

Calculates the allowed reciprocal-space k-vectors (momentum grid) consistent with the lattice size and primitive reciprocal vectors.

When boundary fluxes are present (TWISTED BC), the fractional coordinates are shifted by \(\phi_\mu / (2\pi L_\mu)\) in each direction, so that the Bloch condition matches the twisted boundary.

The sampling follows the same fftfreq ordering used by the Bloch transform (Γ at index [0,0,0], followed by positive frequencies and finally the negative branch). This keeps the analytic grids aligned with the numerically constructed H(k) blocks.
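For a single direction, the fftfreq ordering and the twist shift φ/(2πL) can be written out as follows. This plain-Python sketch mirrors the ordering of numpy.fft.fftfreq(L), whose values m/L are exactly the fractional coordinates; the helper name fractional_k_1d is hypothetical.

```python
import math


def fractional_k_1d(length, phi=0.0):
    """Fractional k-coordinates in fftfreq order, shifted by phi / (2*pi*length)."""
    positive = list(range(0, (length - 1) // 2 + 1))   # 0, 1, ..., ceil(L/2) - 1
    negative = list(range(-(length // 2), 0))          # -floor(L/2), ..., -1
    shift = phi / (2 * math.pi * length)               # twisted-boundary offset
    return [m / length + shift for m in positive + negative]


print(fractional_k_1d(4))  # [0.0, 0.25, -0.5, -0.25]
print(fractional_k_1d(3))  # [0.0, 0.3333333333333333, -0.3333333333333333]
```

Γ sits at index 0 whenever φ = 0, and a nonzero flux shifts every point by the same amount, so the relative spacing 1/L is unchanged.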

calculate_dft_matrix()[source]

Bloch-type DFT matrix on the site basis.

Indices:

  • i = (R, beta): real-space cell R and sublattice beta

  • n = (k, alpha): k-point k and sublattice alpha

Elements:

\[F_{(k,\alpha),(R,\beta)} = \frac{1}{\sqrt{N_c}}\, \delta_{\alpha\beta}\, e^{-i k\cdot R}.\]

This matrix is unitary:

\[F^\dagger F = F F^\dagger = I_{N_s},\]

where Ns = Nc * Nb is the total number of sites, Nc is the number of unit cells, and Nb is the number of sublattices.

Important

When boundary fluxes are present (TWISTED BC), the k-grid used to build the DFT matrix is shifted by phi_mu / (2 pi L_mu) in each direction, exactly as in calculate_k_vectors().

Note that this DFT matrix does not include basis-dependent phases (i.e., exp(-i k . r_basis)).

Calculates the Discrete Fourier Transform (DFT) matrix for the lattice. This method can be optimized using FFT (Fast Fourier Transform) in the future. Reference: https://en.wikipedia.org/wiki/DFT_matrix

Parameters:
  • phase (bool)

Returns:

  • DFT matrix (ndarray)
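The unitarity statement can be checked numerically on a small 1D example. The sketch builds F from the element formula (without basis phases), assuming site index i = R*Nb + beta, row index n = m*Nb + alpha, and k_m = (2π m + φ)/Nc for a twist φ. Because every momentum is shifted by the same φ/Nc, the twist cancels in F^† F and F stays unitary. The helper names are hypothetical.

```python
import cmath
import math


def bloch_dft_1d(n_cells, n_basis, phi=0.0):
    """F[(k, alpha), (R, beta)] = delta_{alpha beta} exp(-i k R) / sqrt(Nc) on a 1D chain."""
    ns = n_cells * n_basis
    F = [[0j] * ns for _ in range(ns)]
    for m in range(n_cells):
        k = (2 * math.pi * m + phi) / n_cells
        for r in range(n_cells):
            phase = cmath.exp(-1j * k * r) / math.sqrt(n_cells)
            for alpha in range(n_basis):
                F[m * n_basis + alpha][r * n_basis + alpha] = phase   # delta_{alpha beta}
    return F


def is_unitary(F, tol=1e-12):
    """Check F^dagger F = I by explicit summation over the row index."""
    n = len(F)
    for i in range(n):
        for j in range(n):
            s = sum(F[k][i].conjugate() * F[k][j] for k in range(n))
            if abs(s - (1.0 if i == j else 0.0)) > tol:
                return False
    return True


print(is_unitary(bloch_dft_1d(4, 2)), is_unitary(bloch_dft_1d(4, 2, phi=0.3)))  # True True
```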

calculate_norm_sym()[source]

Calculate a symmetry-normalization measure for each site.

Default: Euclidean norm of the coordinate vector. Override in subclasses for lattice-specific behaviour.

calculate_nn_in(pbcx: bool, pbcy: bool, pbcz: bool)[source]
calculate_nnn_in(pbcx: bool, pbcy: bool, pbcz: bool)[source]
property adjacency: ndarray
embedding() ndarray[source]

Return coordinate embedding with shape (Ns, dim).

edge_list(*, include_weights: bool = False, threshold: float = 0.0) List[Tuple][source]

Return unique undirected edges from adjacency.

set_edge(i: int, j: int, weight: float = 1.0, *, symmetric: bool = True, add: bool = False, reinitialize: bool = True) None[source]

Set or add an edge coupling weight.

add_edge(i: int, j: int, weight: float = 1.0, *, reinitialize: bool = True) None[source]

Set an undirected edge weight.

add_to_edge(i: int, j: int, weight: float, *, reinitialize: bool = True) None[source]

Increment an undirected edge weight by weight.

remove_edge(i: int, j: int, *, reinitialize: bool = True) None[source]

Remove edge by setting weight to zero.

clear_edges(*, reinitialize: bool = True) None[source]

Remove all couplings.

add_couplings(couplings: Iterable[Tuple[int, int, float]], *, mode: str = 'set', reinitialize: bool = True) None[source]

Bulk add/update couplings from (i, j, weight) tuples.

mode:

  • 'set': overwrite values

  • 'add': increment values
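The 'set' and 'add' modes, together with the unique-edge extraction described for edge_list, can be sketched on a plain nested-list adjacency matrix. The helpers apply_couplings and unique_edges are hypothetical; they assume symmetric updates (as for undirected edges) and that an edge is reported when its |weight| exceeds threshold.

```python
def apply_couplings(adj, couplings, mode="set"):
    """Update a symmetric adjacency matrix in place from (i, j, weight) tuples."""
    if mode not in ("set", "add"):
        raise ValueError(f"mode must be 'set' or 'add', got {mode!r}")
    for i, j, w in couplings:
        new = w if mode == "set" else adj[i][j] + w   # overwrite or increment
        adj[i][j] = adj[j][i] = new                   # keep the matrix symmetric


def unique_edges(adj, threshold=0.0):
    """Return (i, j, weight) for i < j with |weight| > threshold."""
    n = len(adj)
    return [(i, j, adj[i][j]) for i in range(n) for j in range(i + 1, n) if abs(adj[i][j]) > threshold]


adj = [[0.0] * 3 for _ in range(3)]
apply_couplings(adj, [(0, 1, 1.0), (1, 2, 0.5)])
apply_couplings(adj, [(0, 1, 0.25)], mode="add")
print(unique_edges(adj))  # [(0, 1, 1.25), (1, 2, 0.5)]
```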

class general_python.lattices.HighSymmetryPoints(points: Dict[str, ~general_python.lattices.tools.lattice_kspace.HighSymmetryPoint]=<factory>, _default_path: List[str] = <factory>)[source]

Bases: object

Collection of high-symmetry points for a lattice type.

Provides named access to standard high-symmetry points and defines default paths through the Brillouin zone.

Example

>>> pts = HighSymmetryPoints.square_2d()
>>> print(pts.Gamma)  # HighSymmetryPoint for Gamma
>>> print(pts.default_path)  # ['Gamma', 'X', 'M', 'Gamma']
>>> print(pts.get_path_points(['Gamma', 'M']))  # Custom path
points: Dict[str, HighSymmetryPoint]
add(point: HighSymmetryPoint) HighSymmetryPoints[source]

Add a high-symmetry point.

get(name: str) HighSymmetryPoint | None[source]

Get a point by name, returns None if not found.

resolve_label(name: str) str | None[source]

Resolve a label or alias to a canonical key in self.points.

Examples: "Γ" -> "Gamma", "K'" -> "Kp".

resolve(name: str) HighSymmetryPoint | None[source]

Resolve a label/alias and return the matching point object.

property default_path: List[str]

Return the default path through high-symmetry points.

get_path_points(path_labels: List[str]) List[Tuple[str, List[float]]][source]

Get path as list of (label, frac_coords) tuples.

Parameters:

path_labels (List[str]) – List of point labels defining the path (e.g., ['Gamma', 'X', 'M', 'Gamma'])

Returns:

Path suitable for brillouin_zone_path() function

Return type:

List[Tuple[str, List[float]]]

get_default_path_points() List[Tuple[str, List[float]]][source]

Get the default path as list of (label, frac_coords) tuples.

classmethod chain_1d() HighSymmetryPoints[source]

High-symmetry points for 1D chain.

classmethod square_2d() HighSymmetryPoints[source]

High-symmetry points for 2D square lattice.

classmethod cubic_3d() HighSymmetryPoints[source]

High-symmetry points for 3D cubic lattice.

classmethod triangular_2d() HighSymmetryPoints[source]

High-symmetry points for 2D triangular lattice.

classmethod honeycomb_2d() HighSymmetryPoints[source]

High-symmetry points for honeycomb/graphene lattice.

classmethod hexagonal_2d() HighSymmetryPoints[source]

High-symmetry points for 2D hexagonal lattice (same as honeycomb).

__init__(points: Dict[str, ~general_python.lattices.tools.lattice_kspace.HighSymmetryPoint]=<factory>, _default_path: List[str] = <factory>) None
class general_python.lattices.HighSymmetryPoint(label: str, frac_coords: Tuple[float, float, float], latex_label: str = '', description: str = '')[source]

Bases: object

A high-symmetry point in the Brillouin zone.

label

Label for the point (e.g., 'Gamma', 'K', 'M', 'X')

Type:

str

frac_coords

Fractional coordinates in reciprocal lattice units (f1, f2, f3). The actual k-vector is: k = f1*b1 + f2*b2 + f3*b3

Type:

Tuple[float, float, float]

latex_label

LaTeX-formatted label for plotting (e.g., r'$\Gamma$')

Type:

str, optional

description

Description of the point

Type:

str, optional

label: str
frac_coords: Tuple[float, float, float]
latex_label: str = ''
description: str = ''
__contains__(coord: Tuple[float, float, float] | str) bool[source]

Check if given fractional coordinates match this point.

to_cartesian(b1: ndarray, b2: ndarray, b3: ndarray) ndarray[source]

Convert fractional coordinates to Cartesian k-vector.
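The conversion follows the formula given for frac_coords, k = f1*b1 + f2*b2 + f3*b3. A minimal NumPy sketch, written as a hypothetical free function rather than the method itself, with a square lattice of spacing a = 1 whose reciprocal vectors are b_i = (2π/a) e_i:

```python
import numpy as np

def to_cartesian(frac_coords, b1, b2, b3) -> np.ndarray:
    """Hypothetical sketch: k = f1*b1 + f2*b2 + f3*b3."""
    f1, f2, f3 = frac_coords
    return (f1 * np.asarray(b1, dtype=float)
            + f2 * np.asarray(b2, dtype=float)
            + f3 * np.asarray(b3, dtype=float))

# Square lattice, a = 1: reciprocal vectors 2*pi/a along each axis.
a = 1.0
b1 = np.array([2 * np.pi / a, 0.0, 0.0])
b2 = np.array([0.0, 2 * np.pi / a, 0.0])
b3 = np.array([0.0, 0.0, 2 * np.pi / a])
k_M = to_cartesian((0.5, 0.5, 0.0), b1, b2, b3)   # M point -> (pi, pi, 0)
```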

as_tuple() Tuple[str, List[float]][source]

Return as (label, [f1, f2, f3]) tuple for path generation.

__init__(label: str, frac_coords: Tuple[float, float, float], latex_label: str = '', description: str = '') None
class general_python.lattices.KPathResult(k_cart: ~numpy.ndarray, k_frac: ~numpy.ndarray, k_dist: ~numpy.ndarray, labels: ~typing.List[~typing.Tuple[int, str]], values: ~numpy.ndarray, indices: ~numpy.ndarray, matched_distances: ~numpy.ndarray = <factory>, path_axis: int = 0)[source]

Bases: object

Result of extracting data along a k-path in the Brillouin zone.

This dataclass holds all information needed for band structure plots and analysis along a high-symmetry path.

k_cart

Cartesian k-vectors along the path

Type:

np.ndarray, shape (Npath, 3)

k_frac

Fractional k-vectors along the path (in reciprocal lattice units)

Type:

np.ndarray, shape (Npath, 3)

k_dist

Cumulative distance along the path for x-axis plotting

Type:

np.ndarray, shape (Npath,)

labels

List of (index, label) pairs for high-symmetry points

Type:

List[Tuple[int, str]]

values

Data values along the path. The axis that runs along the path is given by path_axis. Example shapes: (Npath,), (Npath, n_bands), (Nw, Npath), or (Nw, Npath, n_bands).

Type:

np.ndarray

indices

Indices into the original k-grid for each path point. Use to map path data back to the full k-grid.

Type:

np.ndarray, shape (Npath,), dtype=int

matched_distances

Distance from ideal path point to matched grid point (for quality check)

Type:

np.ndarray, shape (Npath,)

Example

>>> import matplotlib.pyplot as plt
>>> result = lattice.extract_kpath_data(energies, path='SQUARE_2D')
>>> plt.plot(result.k_dist, result.values)
>>> for idx, label in result.labels:
...     plt.axvline(result.k_dist[min(idx, len(result.k_dist)-1)], label=label)
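The k_dist attribute is described as the cumulative distance along the path used for the x-axis. Assuming it is the Euclidean arc length of the Cartesian k-points, it can be computed from k_cart as below; `cumulative_distance` is a hypothetical helper, not part of the documented API.

```python
import numpy as np

def cumulative_distance(k_cart) -> np.ndarray:
    """Hypothetical sketch: cumulative Euclidean arc length along k-points."""
    k_cart = np.asarray(k_cart, dtype=float)
    steps = np.linalg.norm(np.diff(k_cart, axis=0), axis=1)  # |k_{i+1} - k_i|
    return np.concatenate(([0.0], np.cumsum(steps)))

# Gamma -> X -> M on a square lattice with a = 1.
k_path = [[0.0, 0.0, 0.0], [np.pi, 0.0, 0.0], [np.pi, np.pi, 0.0]]
k_dist = cumulative_distance(k_path)
```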
k_cart: ndarray
k_frac: ndarray
k_dist: ndarray
labels: List[Tuple[int, str]]
values: ndarray
indices: ndarray
matched_distances: ndarray
path_axis: int = 0
property n_points: int

Number of points along the path.

property n_bands: int

Number of trailing channels per path point, flattening axes after path_axis.

property label_positions: ndarray

X-axis positions (k_dist values) of the high-symmetry point labels.

property label_texts: List[str]

The label strings of the high-symmetry points, for use as plot tick labels.

unique_indices() ndarray[source]

Return unique k-point indices (no duplicates from path segments).

max_match_distance() float[source]

Return the maximum distance between an ideal path point and its matched grid point.
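The derived properties above can be sketched on a stripped-down stand-in dataclass. `MiniKPath` is hypothetical, and the formulas are assumptions read off the docstrings: n_bands multiplies all axes after path_axis, label positions clamp indices to the last path point (as in the plotting example), and unique indices keep first occurrences in path order.

```python
from dataclasses import dataclass
import numpy as np

@dataclass
class MiniKPath:
    """Hypothetical stand-in holding only the fields the properties need."""
    k_dist: np.ndarray
    labels: list                 # [(index, label), ...]
    values: np.ndarray
    indices: np.ndarray
    matched_distances: np.ndarray
    path_axis: int = 0

    @property
    def n_points(self) -> int:
        return len(self.k_dist)

    @property
    def n_bands(self) -> int:
        # Flatten every axis after path_axis into one channel count (1 if none).
        return int(np.prod(self.values.shape[self.path_axis + 1:]))

    @property
    def label_positions(self) -> np.ndarray:
        last = len(self.k_dist) - 1
        return np.array([self.k_dist[min(i, last)] for i, _ in self.labels])

    def unique_indices(self) -> np.ndarray:
        # First occurrence of each grid index, kept in path order.
        _, first = np.unique(self.indices, return_index=True)
        return self.indices[np.sort(first)]

    def max_match_distance(self) -> float:
        return float(np.max(self.matched_distances))

path = MiniKPath(k_dist=np.array([0.0, 1.0, 2.0, 3.0]),
                 labels=[(0, "G"), (2, "X"), (4, "G")],
                 values=np.zeros((4, 3, 2)),
                 indices=np.array([7, 5, 5, 7]),
                 matched_distances=np.array([0.0, 0.1, 0.05, 0.0]))
```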

__init__(k_cart: ~numpy.ndarray, k_frac: ~numpy.ndarray, k_dist: ~numpy.ndarray, labels: ~typing.List[~typing.Tuple[int, str]], values: ~numpy.ndarray, indices: ~numpy.ndarray, matched_distances: ~numpy.ndarray = <factory>, path_axis: int = 0) None
class general_python.lattices.StandardBZPath(*values)[source]

Bases: Enum

Enumeration of standard high-symmetry paths in the Brillouin zone.

We define the k-space paths in a general representation of momentum vectors:

k = f1*b1 + f2*b2 + f3*b3,

where (b1, b2, b3) are the reciprocal lattice vectors and (f1, f2, f3) are the fractional coordinates, f_i = n_i / N_i, with n_i = 0, 1, …, N_i - 1 for each direction i.

Each value returns a list of (label, fractional_coord) pairs. The fractional coordinates are expressed in units of reciprocal lattice vectors.

Example

>>> path = StandardBZPath.SQUARE_2D.value
>>> for label, coord in path:
...     print(f"{label}: {coord}")
$\Gamma$: [0.0, 0.0, 0.0]
$X$: [0.5, 0.0, 0.0]
$M$: [0.5, 0.5, 0.0]
$\Gamma$: [0.0, 0.0, 0.0]

CHAIN_1D = [('0', [0.0, 0.0, 0.0]), ('\\pi', [0.5, 0.0, 0.0]), ('2\\pi', [1.0, 0.0, 0.0])]
SQUARE_2D = [('$\\Gamma$', [0.0, 0.0, 0.0]), ('$X$', [0.5, 0.0, 0.0]), ('$M$', [0.5, 0.5, 0.0]), ('$\\Gamma$', [0.0, 0.0, 0.0])]
TRIANGULAR_2D = [('$\\Gamma$', [0.0, 0.0, 0.0]), ('$M$', [0.5, 0.0, 0.0]), ('$K$', [0.3333333333333333, 0.3333333333333333, 0.0]), ('$\\Gamma$', [0.0, 0.0, 0.0])]
CUBIC_3D = [('$\\Gamma$', [0.0, 0.0, 0.0]), ('$X$', [0.5, 0.0, 0.0]), ('$M$', [0.5, 0.5, 0.0]), ('$R$', [0.5, 0.5, 0.5]), ('$\\Gamma$', [0.0, 0.0, 0.0])]
HONEYCOMB_2D = [('$\\Gamma$', [0.0, 0.0, 0.0]), ('$K$', [0.6666666666666666, 0.3333333333333333, 0.0]), ('$M$', [0.5, 0.0, 0.0]), ('$\\Gamma$', [0.0, 0.0, 0.0])]
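A path from this enumeration is a list of nodes in fractional coordinates. One way to turn it into dense samples and match them to a finite mesh f_i = n_i / N_i is sketched below. The helpers `sample_path` and `match_to_mesh` are hypothetical; measuring the match distance in fractional coordinates with periodic wrapping is an assumption (a Cartesian metric is more faithful for non-orthogonal lattices), and only 2D meshes are handled.

```python
import numpy as np

def sample_path(path, points_per_seg=10):
    """Hypothetical sketch: linear interpolation between consecutive nodes."""
    nodes = np.array([coord for _, coord in path], dtype=float)
    segments = []
    for start, stop in zip(nodes[:-1], nodes[1:]):
        t = np.linspace(0.0, 1.0, points_per_seg, endpoint=False)[:, None]
        segments.append(start + t * (stop - start))
    segments.append(nodes[-1:])            # close the path on its final node
    return np.vstack(segments)

def match_to_mesh(k_frac, n1, n2):
    """Nearest mesh point (i/n1, j/n2, 0) for each path point; index = i*n2 + j."""
    mesh = np.array([(i / n1, j / n2, 0.0) for i in range(n1) for j in range(n2)])
    diff = k_frac[:, None, :] - mesh[None, :, :]
    diff -= np.round(diff)                 # wrap into one reciprocal cell
    dist = np.linalg.norm(diff, axis=-1)
    idx = np.argmin(dist, axis=1)
    return idx, dist[np.arange(len(k_frac)), idx]

square_path = [("G", [0.0, 0.0, 0.0]), ("X", [0.5, 0.0, 0.0]),
               ("M", [0.5, 0.5, 0.0]), ("G", [0.0, 0.0, 0.0])]
k_frac = sample_path(square_path, points_per_seg=4)
idx, dist = match_to_mesh(k_frac, 4, 4)
```

The returned pair plays the role of KPathResult.indices and KPathResult.matched_distances in this sketch.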
general_python.lattices.register_lattice(name: str, lattice_cls: Type[Any], *aliases: str, overwrite: bool = False)[source]

Register a lattice class under name and optional aliases.

general_python.lattices.available_lattices() Tuple[str, ...][source]

Return tuple of registered lattice identifiers.
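A name-plus-aliases registry of this kind can be sketched with a module-level dict. Everything here is hypothetical: `_REGISTRY`, `lookup_lattice`, the case-insensitive keys, and the rule that re-registering the same class is allowed without overwrite=True are assumptions, not the library's documented behavior.

```python
from typing import Any, Dict, Tuple, Type

# Hypothetical module-level registry: identifier -> lattice class.
_REGISTRY: Dict[str, Type[Any]] = {}

def register_lattice(name: str, lattice_cls: Type[Any], *aliases: str,
                     overwrite: bool = False) -> None:
    """Register lattice_cls under name and every alias (case-insensitive)."""
    keys = [k.lower() for k in (name, *aliases)]
    for key in keys:
        current = _REGISTRY.get(key)
        if current is not None and current is not lattice_cls and not overwrite:
            raise ValueError(f"{key!r} is already registered to {current.__name__}")
    for key in keys:
        _REGISTRY[key] = lattice_cls

def available_lattices() -> Tuple[str, ...]:
    """Sorted tuple of every registered identifier (names and aliases)."""
    return tuple(sorted(_REGISTRY))

def lookup_lattice(name: str) -> Type[Any]:
    return _REGISTRY[name.lower()]

class SquareLattice:        # placeholder classes for the usage example
    pass

class HoneycombLattice:
    pass

register_lattice("square", SquareLattice, "sq")
register_lattice("honeycomb", HoneycombLattice, "graphene")
```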

general_python.lattices.choose_lattice(typek: str | None = 'square', dim: int | None = None, lx: int | None = 1, ly: int | None = 1, lz: int | None = 1, bc: str | LatticeBC | None = None, flux: float | BoundaryFlux | dict | None = None, **kwargs)[source]

Returns an instance of a lattice of the desired type.

Parameters:
  • typek (str) – Type of lattice (“square”, “hexagonal”, or “honeycomb”)

  • dim (int) – Dimension (1, 2, or 3)

  • lx (int) – Number of sites in x-direction

  • ly (int) – Number of sites in y-direction

  • lz (int) – Number of sites in z-direction (ignored if dim < 3)

  • bc – Boundary condition (e.g., LatticeBC.PBC or LatticeBC.OBC)

  • flux – Optional boundary flux specification forwarded to the lattice constructor. Accepts a scalar phase, BoundaryFlux, or a mapping from directions to phases (in radians).

Returns:

An instance of the desired lattice.

Return type:

Lattice

general_python.lattices.plot_bonds(lattice: Lattice, ax: pltAxes = None, **line_kwargs) pltAxes[source]

Plot physical bonds of the lattice using primitive vectors (a1,a2,a3).

Parameters:
  • ax (Axes) – Existing Matplotlib Axes; a new one is created if None.

  • include_nnn (bool) – Include next-nearest-neighbor bonds if True.

  • **line_kwargs – Passed to ax.plot and ax.scatter.

general_python.lattices.format_lattice_summary(lattice: Lattice, *, precision: int = 3) str[source]

Produce a multi-line summary describing key lattice metadata.

general_python.lattices.format_vector_table(vectors: Iterable[Sequence[float]], *, max_rows: int = 10, precision: int = 3, column_labels: Sequence[str] | None = None, index_label: str = '#', indentation: str = '') str[source]

Return a tabular string representation of an array of vectors.

Parameters:
  • vectors – Any iterable producing coordinate sequences.

  • max_rows – Maximum number of rows to include in the formatted output.

  • precision – Number of decimal places to use when printing floating point values.

  • column_labels – Optional axis labels. Defaults to Cartesian axes based on vector dimension.

  • index_label – Label for the index column.

  • indentation – Optional indentation prefix applied to each line of the table.
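A table formatter with these options could look like the sketch below. The column widths, the fallback labels for more than three dimensions, and the "... (N more)" overflow line are illustrative assumptions; the library's exact layout may differ.

```python
from typing import Iterable, List, Optional, Sequence

def format_vector_table(vectors: Iterable[Sequence[float]], *, max_rows: int = 10,
                        precision: int = 3, column_labels: Optional[Sequence[str]] = None,
                        index_label: str = "#", indentation: str = "") -> str:
    """Hypothetical sketch of a fixed-width vector table."""
    rows = [[float(x) for x in v] for v in vectors]
    dim = len(rows[0]) if rows else 0
    if column_labels is not None:
        labels: List[str] = list(column_labels)
    elif dim <= 3:
        labels = ["x", "y", "z"][:dim]             # Cartesian axes by dimension
    else:
        labels = [f"c{j}" for j in range(dim)]     # generic fallback for dim > 3
    width = precision + 6                          # sign, integer digits, point
    lines = [f"{index_label:>4} " + " ".join(f"{lab:>{width}}" for lab in labels)]
    for i, row in enumerate(rows[:max_rows]):
        lines.append(f"{i:>4} " + " ".join(f"{x:>{width}.{precision}f}" for x in row))
    if len(rows) > max_rows:
        lines.append(f"{'...':>4} ({len(rows) - max_rows} more)")
    return "\n".join(indentation + line for line in lines)

table = format_vector_table([[0, 0], [1, 0.5], [2, 1]], max_rows=2,
                            precision=2, indentation="  ")
```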

general_python.lattices.format_real_space_vectors(lattice: Lattice, *, max_rows: int = 10, precision: int = 3, indentation: str = '') str[source]

Format a table of lattice real-space vectors.

general_python.lattices.format_reciprocal_space_vectors(lattice: Lattice, *, max_rows: int = 10, precision: int = 3, indentation: str = '') str[source]

Format a table of reciprocal (k-space) vectors.

general_python.lattices.format_brillouin_zone_overview(lattice: Lattice, *, precision: int = 3) str[source]

Provide a textual overview of the sampled Brillouin zone.

The function reports bounding box limits and attempts to compute the convex hull measure (length/area/volume) when SciPy is available.
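The two quantities in this overview can be sketched with NumPy plus an optional SciPy import. `scipy.spatial.ConvexHull` exposes `.volume`, which is the enclosed area for 2D input and the volume for 3D input. The helper names are hypothetical, and returning None for a missing SciPy or a degenerate (for example collinear) mesh is an assumed fallback.

```python
import numpy as np

def _as_points(kpoints) -> np.ndarray:
    pts = np.asarray(kpoints, dtype=float)
    return pts[:, None] if pts.ndim == 1 else pts   # treat 1D input as (N, 1)

def bounding_box(kpoints):
    """Per-axis (min, max) limits of the sampled k-points."""
    pts = _as_points(kpoints)
    return pts.min(axis=0), pts.max(axis=0)

def hull_measure(kpoints):
    """Length (1D), area (2D) or volume (3D) of the convex hull, or None."""
    pts = _as_points(kpoints)
    if pts.shape[1] == 1:
        return float(pts.max() - pts.min())          # a 1D hull is an interval
    try:
        from scipy.spatial import ConvexHull          # optional dependency
    except ImportError:
        return None
    try:
        return float(ConvexHull(pts).volume)          # .volume is the area in 2D
    except Exception:                                 # e.g. Qhull error on a degenerate mesh
        return None

square = np.array([[0, 0], [1, 0], [0, 1], [1, 1], [0.5, 0.5]], dtype=float)
```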

class general_python.lattices.LatticePlotter(lattice: Lattice)[source]

Bases: object

Convenience wrapper bundling plotting helpers for a single lattice.

Usage:

lattice.plot.real_space()
lattice.plot.structure(show_indices=True)
lattice.plot.regions(regions_dict)

lattice: Lattice
real_space(**kwargs) Tuple[Figure, Axes][source]

Plot real-space sites.

reciprocal_space(**kwargs) Tuple[Figure, Axes][source]

Scatter-plot of reciprocal lattice vectors (k-points).

Parameters mirror plot_real_space().

Parameters:
  • lattice (Lattice) – The lattice object to plot.

  • ax (Axes, optional) – Matplotlib axes to plot on. If None, a new figure is created.

  • show_indices (bool, default=False) – If True, annotate each k-point with its index.

  • show_axes (bool, default=True) – If False, hides the coordinate axes.

  • color (str, default="C1") – Color of the k-point markers.

  • marker (str, default="o") – Marker style.

  • figsize (tuple, optional) – Figure size in inches (width, height).

  • title (str, optional) – Title of the plot.

  • elev, azim (float, optional) – Elevation and azimuth angles for 3D plots.

  • extend_kpoints (bool, default=False) – If True, draw translated reciprocal-space copies around the original mesh.

  • extend_copies (int or iterable of int, default=2) – Number of copies per reciprocal direction used when extend_kpoints=True. Scalars are applied to all active reciprocal directions.

  • extend_tol (float, default=1e-10) – Tolerance used to identify which extended points are already present in the original reciprocal mesh.

  • **scatter_kwargs – Additional options, including:
    • point_edgecolor – Color of the marker edges (default "white").
    • point_zorder – Z-order for the scatter points (default 5).
    • color_extended – Color for translated copies (default "C2").
    • edgecolor_extended – Edge color for translated copies (default "gray").
    • marker_extended – Marker for translated copies (defaults to marker).
    • Any other valid arguments for ax.scatter.

brillouin_zone(**kwargs) Tuple[Figure, Axes][source]

Plot the Brillouin Zone.

structure(**kwargs) Tuple[Figure, Axes][source]

Plot detailed lattice structure with connectivity.

Parameters:
  • ax (Axes, optional) – Matplotlib axes to plot on. If None, a new figure is created.

  • show_indices (bool) – If True, annotates nodes with their site indices.

  • highlight_boundary (bool) – If True, draws boundary nodes with a distinct color/edge.

  • show_axes (bool) – If False, hides the coordinate axes for a cleaner diagram.

  • edge_color (str) – Color of the edges.

  • node_color (str) – Color of the nodes.

  • boundary_node_color (str) – Color of the boundary node edges.

  • periodic_color (str) – Color for periodic boundary annotations.

  • open_color (str) – Color for open boundary annotations.

  • node_size (int) – Size of the node markers.

  • edge_alpha (float) – Transparency of the edges.

  • label_padding (float) – Fractional padding for node index labels.

  • boundary_offset (float) – Fractional offset for boundary annotations.

  • figsize (tuple, optional) – Figure size in inches (width, height).

  • title (str, optional) – Title of the plot.

  • title_kwargs (dict, optional) – Additional keyword arguments for the title.

  • tight_layout (bool) – If True, applies tight layout to the figure.

  • elev (float, optional) – Elevation angle for 3D plots.

  • azim (float, optional) – Azimuth angle for 3D plots.

  • partition_colors (tuple of str, optional) – Colors to use for bipartite/sublattice coloring. If provided, nodes are colored based on sublattice parity.

  • show_periodic_connections (bool) – If True, indicates wrap-around connections textually or graphically.

  • show_primitive_cell (bool) – If True, overlays the primitive unit cell vectors/box.

  • **scatter_kwargs – Additional arguments passed to ax.scatter.

regions(regions: Dict[str, List[int]], **kwargs) Tuple[Figure, Axes][source]

Plot specific regions on the lattice.

Parameters:
  • regions (Dict[str, List[int]]) – Dictionary mapping region names to lists of site indices.

  • show_system (bool) – If True, plot all lattice sites faintly in the background.

  • system_color (str) – Color for background system sites.

  • cmap (str) – Colormap name for distinct regions.

  • blob_radius (float, optional) – If given, draw a translucent circle around each site.

  • show_bonds (bool) – If True, draw nearest-neighbor bonds within each region.

  • Other keyword arguments mirror plot_real_space().

bz_high_symmetry(**kwargs) Tuple[Figure, Axes][source]

Plot the Brillouin zone, high-symmetry path, and sampled reciprocal mesh.

Parameters:
  • path (list[str], str, or iterable[(label, frac)], optional) – High-symmetry path specification. If omitted, the lattice default path is used.

  • show_kpoints (bool, default=True) – Draw sampled reciprocal-space mesh points.

  • show_bz (bool, default=True) – Draw the first Brillouin zone.

  • show_path (bool, default=True) – Draw the ideal high-symmetry path.

  • show_matched_kpoints (bool, default=True) – Highlight sampled k-points whose distance to the path is within the matching tolerance.

  • points_per_seg (int, default=40) – Number of interpolation points per path segment for the ideal path.

  • path_match_tol (float, optional) – Distance tolerance used when highlighting mesh points near the drawn path.

  • extend (bool, default=False) – Draw translated copies of the sampled k-mesh.

  • extend_copies (int or iterable[int], optional) – Number of reciprocal-cell copies per direction. In 2D, extend_copies=1 includes the first shell around the first Brillouin zone and extend_copies=2 includes the second shell as well.

  • show_background_bz (bool, default=False) – Draw translated Brillouin-zone copies behind the mesh.

  • hs_plot ({"none", "markers", "labels", "both"}, default="markers") – Whether to draw exact high-symmetry markers, labels, or both.

  • legend_kwargs (dict, optional) – Extra keyword arguments passed to axis.legend.

  • **kwargs – Additional style overrides forwarded to plot_high_symmetry_points.
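Highlighting mesh points "whose distance to the path is within the matching tolerance" needs the distance from a point to a piecewise-linear path. A standard way to get it is to project onto each segment, clamp the projection parameter to [0, 1], and keep the minimum. The helpers below are hypothetical and only illustrate that geometry; they are not the plotting function's internals.

```python
import numpy as np

def distance_to_path(points, nodes):
    """Hypothetical sketch: minimum Euclidean distance to a polyline."""
    p = np.asarray(points, dtype=float)
    v = np.asarray(nodes, dtype=float)
    best = np.full(len(p), np.inf)
    for a, b in zip(v[:-1], v[1:]):
        ab = b - a
        denom = float(ab @ ab)
        if denom == 0.0:
            t = np.zeros(len(p))                          # degenerate segment
        else:
            t = np.clip((p - a) @ ab / denom, 0.0, 1.0)   # projection parameter
        closest = a + t[:, None] * ab
        best = np.minimum(best, np.linalg.norm(p - closest, axis=1))
    return best

def near_path(points, nodes, tol):
    """Boolean mask of points within tol of the path (cf. path_match_tol)."""
    return distance_to_path(points, nodes) <= tol

nodes = [[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]]
probe = [[0.5, 0.1], [2.0, 0.0], [0.5, 0.5]]
```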

subsystem(sites: List[int], *, show_boundary: bool = True, **kwargs) Tuple[Figure, Axes][source]

Plot a single subsystem with its boundary highlighted.

Parameters:
  • sites (list of int) – Site indices in the subsystem.

  • show_boundary (bool, default=True) – If True, highlight the boundary bonds crossing A/B.

  • **kwargs – Passed to plot_regions.

Returns:

fig, ax

Return type:

Figure, Axes

Examples

>>> lattice.plot.subsystem([0, 1, 4, 5])
>>> lattice.plot.subsystem(range(8), show_bonds=True)
sweep(direction: str | None = None, *, rectangular: bool = False, max_panels: int = 6, figsize: Tuple[float, float] | None = None, **kwargs) Tuple[Figure, ndarray][source]

Plot subsystem sweep showing cuts with different boundary sizes.

Creates a grid of subplots showing subsystems grouped by boundary size |∂A|.

Parameters:
  • direction (str, optional) – Direction for sweep (‘x’, ‘y’, ‘z’). Creates full-width cuts.

  • rectangular (bool, default=False) – If True and direction is None, use rectangular subsystems (various shapes). If False, use lexicographic sweep (sequential site addition).

  • max_panels (int, default=6) – Maximum number of panels to show.

  • figsize (tuple, optional) – Figure size. Auto-computed if None.

  • **kwargs – Passed to plot_regions for each panel.

Returns:

fig, axes

Return type:

Figure, ndarray of Axes

Examples

>>> lattice.plot.sweep(rectangular=True)  # Various rectangular shapes
>>> lattice.plot.sweep(direction='x')     # Full-width column cuts
>>> lattice.plot.sweep()                  # Lexicographic sweep
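The lexicographic sweep grows a subsystem one site at a time and groups the prefixes by boundary size. The sketch below counts |∂A| as the number of nearest-neighbor bonds with exactly one endpoint in A, on an open square lattice with sites indexed x + lx*y. Both the bond-counting definition and the helper names are assumptions made for illustration.

```python
from collections import defaultdict

def square_bonds(lx, ly):
    """Nearest-neighbor bonds of an open lx-by-ly square lattice (site = x + lx*y)."""
    bonds = []
    for y in range(ly):
        for x in range(lx):
            s = x + lx * y
            if x + 1 < lx:
                bonds.append((s, s + 1))
            if y + 1 < ly:
                bonds.append((s, s + lx))
    return bonds

def boundary_size(sites, bonds):
    """|∂A|: number of bonds with exactly one endpoint inside the subsystem."""
    a = set(sites)
    return sum((i in a) != (j in a) for i, j in bonds)

def lexicographic_sweep(n_sites, bonds):
    """Group the prefixes [0..n) for n = 1..n_sites-1 by their boundary size."""
    groups = defaultdict(list)
    for n in range(1, n_sites):
        prefix = list(range(n))
        groups[boundary_size(prefix, bonds)].append(prefix)
    return dict(groups)

bonds = square_bonds(3, 3)
groups = lexicographic_sweep(9, bonds)
```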
__init__(lattice: Lattice) None
general_python.lattices.plot_real_space(lattice: Lattice, *, ax: Axes | None = None, show_indices: bool = False, show_axes: bool = True, color: str = 'C0', marker: str = 'o', figsize: Tuple[float, float] | None = None, fix_aspect: bool = True, title: str | None = None, title_kwargs: Dict[str, object] | None = None, tight_layout: bool = True, elev: float | None = None, azim: float | None = None, **scatter_kwargs) Tuple[Figure, Axes][source]

Scatter-plot of real-space lattice vectors.

Parameters:
  • lattice (Lattice) – The lattice object to plot.

  • ax (Axes, optional) – Matplotlib axes to plot on. If None, a new figure is created.

  • show_indices (bool, default=False) – If True, annotate each site with its index.

  • show_axes (bool, default=True) – If False, hides the coordinate axes.

  • color (str, default="C0") – Color of the site markers.

  • marker (str, default="o") – Marker style.

  • figsize (tuple, optional) – Figure size in inches (width, height).

  • fix_aspect (bool, default=True) – If True, preserve equal axis scaling in 2D plots. Set to False to let the requested figsize control the on-screen aspect.

  • title (str, optional) – Title of the plot.

  • elev (float, optional) – Elevation angle for 3D plots.

  • azim (float, optional) – Azimuth angle for 3D plots.

  • **scatter_kwargs – Additional arguments passed to ax.scatter.

Returns:

fig, ax – The figure and axes objects.

Return type:

Tuple[Figure, Axes]

general_python.lattices.plot_reciprocal_space(lattice: Lattice, *, ax: Axes | None = None, show_indices: bool = False, show_axes: bool = True, color: str = 'C1', marker: str = 'o', figsize: Tuple[float, float] | None = None, fix_aspect: bool = True, title: str | None = None, title_kwargs: Dict[str, object] | None = None, tight_layout: bool = True, elev: float | None = None, azim: float | None = None, extend_kpoints: bool = False, extend_copies: int | Iterable[int] = 2, extend_tol: float = 1e-10, **scatter_kwargs) Tuple[Figure, Axes][source]

Scatter-plot of reciprocal lattice vectors (k-points).

Parameters mirror plot_real_space().

Parameters:
  • lattice (Lattice) – The lattice object to plot.

  • ax (Axes, optional) – Matplotlib axes to plot on. If None, a new figure is created.

  • show_indices (bool, default=False) – If True, annotate each k-point with its index.

  • show_axes (bool, default=True) – If False, hides the coordinate axes.

  • color (str, default="C1") – Color of the k-point markers.

  • marker (str, default="o") – Marker style.

  • figsize (tuple, optional) – Figure size in inches (width, height).

  • fix_aspect (bool, default=True) – If True, preserve equal axis scaling in 2D plots. Set to False to let the requested figsize control the on-screen aspect.

  • title (str, optional) – Title of the plot.

  • elev, azim (float, optional) – Elevation and azimuth angles for 3D plots.

  • extend_kpoints (bool, default=False) – If True, draw translated reciprocal-space copies around the original mesh.

  • extend_copies (int or iterable of int, default=2) – Number of copies per reciprocal direction used when extend_kpoints=True. Scalars are applied to all active reciprocal directions.

  • extend_tol (float, default=1e-10) – Tolerance used to identify which extended points are already present in the original reciprocal mesh.

  • **scatter_kwargs – Additional options, including:
    • point_edgecolor – Color of the marker edges (default "white").
    • point_zorder – Z-order for the scatter points (default 5).
    • color_extended – Color for translated copies (default "C2").
    • edgecolor_extended – Edge color for translated copies (default "gray").
    • marker_extended – Marker for translated copies (defaults to marker).
    • Any other valid arguments for ax.scatter.
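The extend_kpoints, extend_copies, and extend_tol options describe translating the mesh by integer combinations of reciprocal vectors and discarding copies that already coincide with the original mesh. A NumPy sketch of that idea follows. `extend_kpoints` here is a hypothetical helper that returns only the extra points, and the shift range |n_i| <= copies is an assumption about how the copy count is interpreted.

```python
import itertools
import numpy as np

def extend_kpoints(kpoints, recip_vectors, copies=2, tol=1e-10):
    """Hypothetical sketch: translated k-mesh copies not already in the mesh.

    recip_vectors has one reciprocal vector per row; shifts n_i run over
    -copies_i..copies_i, skipping the untranslated cell.
    """
    k = np.asarray(kpoints, dtype=float)
    b = np.asarray(recip_vectors, dtype=float)
    if np.isscalar(copies):
        copies = [int(copies)] * len(b)       # scalar applies to every direction
    ranges = [range(-int(c), int(c) + 1) for c in copies]
    extra = [k + np.asarray(n, dtype=float) @ b
             for n in itertools.product(*ranges) if any(n)]
    if not extra:
        return np.empty((0, k.shape[1]))
    extra = np.vstack(extra)
    # Distance from each translated point to its nearest original point.
    d = np.linalg.norm(extra[:, None, :] - k[None, :, :], axis=-1).min(axis=1)
    return extra[d > tol]

mesh = np.array([[0.0, 0.0], [0.5, 0.0], [0.0, 0.5], [0.5, 0.5]])
shell = extend_kpoints(mesh, np.eye(2), copies=1)
edge_case = extend_kpoints([[0.0, 0.0], [1.0, 0.0]], np.eye(2), copies=1)
```

In the second call, two translated points land on original points and are dropped by the tolerance check.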

general_python.lattices.plot_brillouin_zone(lattice: Lattice, *, ax: Axes | None = None, facecolor: str = 'tab:blue', edgecolor: str = 'black', alpha: float = 0.25, figsize: Tuple[float, float] | None = None, fix_aspect: bool = True, title: str | None = None, title_kwargs: Dict[str, object] | None = None, tight_layout: bool = True, elev: float | None = None, azim: float | None = None) Tuple[Figure, Axes][source]

Plot the Brillouin Zone approximation based on sampled k-vectors.

Parameters:
  • lattice (Lattice) – The lattice object containing k-vectors.

  • ax (Axes, optional) – Matplotlib axes to plot on. If None, a new figure is created.

  • facecolor (str, default="tab:blue") – Color to fill the Brillouin Zone area.

  • edgecolor (str, default="black") – Color for the Brillouin Zone boundary.

  • alpha (float, default=0.25) – Transparency level for the Brillouin Zone fill.

  • figsize (tuple, optional) – Figure size in inches (width, height).

  • fix_aspect (bool, default=True) – If True, preserve equal axis scaling in 2D plots. Set to False to let the requested figsize control the on-screen aspect.

  • title (str, optional) – Title of the plot.

  • elev (float, optional) – Elevation angle for 3D plots.

  • azim (float, optional) – Azimuth angle for 3D plots.

general_python.lattices.plot_lattice_structure(lattice, **kwargs)[source]

Wrapper for the visualization module’s lattice structure plotter.

class general_python.lattices.RegionType(*values)[source]

Bases: Enum

Supported region types for LatticeRegionHandler.get_region().

HALF = 'half'
HALF_X = 'half_x'
HALF_Y = 'half_y'
HALF_Z = 'half_z'
HALF_XY = 'half_xy'
HALF_YX = 'half_yx'
QUARTER = 'quarter'
SWEEP = 'sweep'
DISK = 'disk'
SUBLATTICE = 'sublattice'
GRAPH = 'graph'
PLAQUETTE = 'plaquette'
KITAEV_PRESKILL = 'kitaev_preskill'
LEVIN_WEN = 'levin_wen'
CUSTOM = 'custom'
FRACTION = 'fraction'
class general_python.lattices.LatticeRegionHandler(lattice: Lattice)[source]

Bases: object

Handles region definitions and extractions for a Lattice.

Usage:

handler = LatticeRegionHandler(lattice)
# or equivalently via the lattice shortcut:
sites = lattice.regions.get_region('disk', origin=10, radius=3.0)
kp    = lattice.regions.get_region('kitaev_preskill', radius=5.0)
lw    = lattice.regions.get_region('levin_wen', inner_radius=2, outer_radius=5)
__init__(lattice: Lattice)[source]
lattice: Lattice
get_shortest_displacement(i: int, j: int) ndarray[source]

Compute the shortest displacement vector r_j - r_i respecting PBC.
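A common way to obtain the shortest displacement under periodic boundaries is the minimum-image convention. Express r_j - r_i in units of the periodic supercell vectors, subtract the nearest integer along each periodic direction, and convert back. This is a sketch under stated assumptions: the supercell rows (for example lx*a1, ly*a2) and the per-direction periodic flags are hypothetical inputs. Rounding in fractional coordinates gives the true minimum image only for reasonably orthogonal cells; strongly skewed cells need a search over neighboring images.

```python
import numpy as np

def shortest_displacement(r_i, r_j, supercell, periodic):
    """Hypothetical sketch: minimum-image displacement r_j - r_i."""
    d = np.asarray(r_j, dtype=float) - np.asarray(r_i, dtype=float)
    L = np.asarray(supercell, dtype=float)       # rows are supercell vectors
    frac = np.linalg.solve(L.T, d)               # d = sum_k frac_k * L_k
    mask = np.asarray(periodic, dtype=bool)
    frac[mask] -= np.round(frac[mask])           # wrap only periodic directions
    return frac @ L

box = [[4.0, 0.0], [0.0, 4.0]]                   # 4 x 4 square supercell
```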

adjacency_map(*, include_nnn: bool = False, weight_threshold: float = 0.0, use_abs_weights: bool = True) Dict[int, Set[int]][source]

Return normalized adjacency as dict[int, set[int]].
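One plausible reading of "normalized adjacency" with weight_threshold and use_abs_weights is sketched below, starting from a dense, real-valued coupling matrix. The input format, the strict "> threshold" rule, and the removal of self-loops are assumptions; the method itself reads the lattice's own bond data.

```python
import numpy as np
from typing import Dict, Set

def adjacency_map(weights, *, weight_threshold: float = 0.0,
                  use_abs_weights: bool = True) -> Dict[int, Set[int]]:
    """Hypothetical sketch: N x N real couplings -> symmetric dict[int, set[int]]."""
    w = np.asarray(weights, dtype=float)
    if use_abs_weights:
        w = np.abs(w)
    adj: Dict[int, Set[int]] = {i: set() for i in range(w.shape[0])}
    rows, cols = np.nonzero(w > weight_threshold)
    for i, j in zip(rows.tolist(), cols.tolist()):
        if i != j:                       # drop self-loops
            adj[i].add(j)
            adj[j].add(i)                # symmetrize
    return adj

couplings = [[0.0, 1.0, 0.0],
             [1.0, 0.0, -0.5],
             [0.0, -0.5, 0.0]]
```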

connected_subsets(*, max_size: int, nodes: List[int] | None = None, min_size: int = 1, max_regions: int | None = None, adjacency: Any | None = None) List[List[int]][source]

Enumerate connected subsets using the current lattice adjacency.
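Connected subsets can be enumerated level by level. Every connected set of size k+1 contains a vertex whose removal leaves a connected set of size k (a leaf of a spanning tree), so adding one neighbor to each connected k-set reaches all connected (k+1)-sets. Frozensets remove duplicates. This is a hypothetical sketch of the technique, not the library's algorithm, and it grows exponentially with max_size.

```python
from typing import Dict, Iterable, List, Optional, Set

def connected_subsets(adj: Dict[int, Set[int]], *, max_size: int, min_size: int = 1,
                      nodes: Optional[Iterable[int]] = None,
                      max_regions: Optional[int] = None) -> List[List[int]]:
    """Hypothetical sketch: all connected vertex subsets with min_size <= |S| <= max_size."""
    allowed = set(adj) if nodes is None else set(nodes)
    layer = {frozenset([v]) for v in allowed}        # connected sets of size 1
    found: List[List[int]] = []
    size = 1
    while layer and size <= max_size:
        if size >= min_size:
            for s in sorted(layer, key=sorted):
                found.append(sorted(s))
                if max_regions is not None and len(found) >= max_regions:
                    return found
        # Grow each set by one allowed neighbor; frozensets deduplicate.
        layer = {s | {u} for s in layer for v in s for u in adj[v]
                 if u in allowed and u not in s}
        size += 1
    return found

chain = {0: {1}, 1: {0, 2}, 2: {1}}                  # path graph 0-1-2
```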

generate_regions(kind: str | RegionType = RegionType.KITAEV_PRESKILL, *, adjacency: Any | None = None, nodes: List[int] | None = None, min_size: int = 1, max_size: int = 4, size_a: Tuple[int, int] | None = None, size_b: Tuple[int, int] | None = None, size_c: Tuple[int, int] | None = None, require_connected_parts: bool = True, require_connected_union: bool = True, require_connected_complement: bool = False, require_pairwise_touch: bool | None = None, require_single_triple_junction: bool | None = None, forbid_full_system: bool = True, include_nnn: bool = False, max_regions: int | None = 128, as_region: bool = True, tripartite: bool | None = None, extra: Callable[[Region], bool] | None = None) List[Region] | List[Dict[str, List[int]]][source]

Generate many region candidates of a selected type.

For Kitaev-Preskill (KP)-like constructions, this supports connected or disjoint A, B, C combinations, with optional filtering for a single triple junction.

get_region(kind: str | RegionType = RegionType.HALF, *, origin: int | List[float] | None = None, radius: float | None = None, direction: str | None = None, sublattice: int | None = None, sites: List[int] | None = None, depth: int | None = None, plaquettes: List[int] | None = None, configuration: int | None = None, predefined: bool | int | str | None = None, as_region: bool = True, min_size: int | None = None, **kwargs) List[int] | Dict[str, List[int]] | Dict[str, Any] | List[Dict[str, Any]] | Region | None[source]

Return site indices defining a spatial region.

Parameters:
  • kind (str or RegionType) – 'half', 'half_x', 'half_y', 'half_z', 'half_xy', 'half_yx', 'disk', 'quarter', 'sweep', 'sublattice', 'graph', 'plaquette', 'kitaev_preskill', 'levin_wen', 'custom', 'fraction' (see RegionType).

  • origin (int or list[float], optional) – Center of the region (site index or coordinate).

  • radius (float, optional) – Radius for 'disk' and 'kitaev_preskill' regions.

  • direction (str, optional) – Direction for 'half' cuts ('x', 'y', 'z').

  • sublattice (int, optional) – Sublattice index for 'sublattice' regions.

  • sites (list[int], optional) – Explicit list of sites for 'custom' regions.

  • depth (int, optional) – Hop-distance for 'graph' balls.

  • plaquettes (list[int], optional) – Plaquette indices for 'plaquette' regions.

  • configuration (int, optional) – Legacy predefined configuration index (1-based) for the given lattice and kind.

  • predefined (bool | int | str, optional) – Convenience selector for predefined regions:
    • True: list the available predefined entries for the given lattice/kind.
    • int: select a predefined entry by 0-based index.
    • str: select a predefined entry by label.

  • as_region (bool, optional) – If True (default), return Region-class objects for supported kinds. If False, keep legacy list/dict return shapes.

  • min_size (int, optional) – If provided, return None if the region (or any of its A, B, C parts) is smaller than this value.

Keyword-only arguments (forwarded via **kwargs):

  • inner_radius (float) – For 'levin_wen'.

  • outer_radius (float) – For 'levin_wen'.

  • n_sectors (int) – Number of angular sectors for 'kitaev_preskill' (default 3).

  • rotation (float) – Rotation of sector boundaries (radians) for KP (default 0).

  • use_pbc (bool) – Whether to use PBC-wrapped distances for KP/LW. Default False — regions must not wrap around the torus boundary.

  • region (str) – Which single region key to return for KP/LW (e.g. 'A', 'AB'). If given, a flat list is returned instead of a dict.

Returns:

Sorted site list for simple regions, a dict of labelled site lists for topological partitions, metadata entries when predefined=True, or Region objects when as_region=True.

Return type:

Region or list[int] or dict[str, list[int]] or list[dict]

Examples

>>> lat.regions.get_region('half_x')
HalfRegion(A=[...], B=[...])
>>> lat.regions.get_region('disk', origin=10, radius=2.5)
DiskRegion(A=[...], B=[...], C=[])
>>> lat.regions.get_region('kitaev_preskill', configuration=1)
KitaevPreskillRegion(A=[...], B=[...], C=[...])
>>> lat.regions.get_region('kitaev_preskill', predefined=0)
KitaevPreskillRegion(A=[...], B=[...], C=[...])
>>> lat.regions.get_region('half_x', as_region=False)
[0, 1, 2, ...]
list_predefined(kind: str | None = None, *, lattice_type: str | Any | None = None, lx: int | None = None, ly: int | None = None, lz: int | None = None, include_region: bool = False, labels_only: bool = False) List[Any][source]

List predefined regions with optional filtering by type/size/kind.

Defaults to the current lattice type and size if no filters are provided.

list_predefined_kinds(*, lattice_type: str | Any | None = None, lx: int | None = None, ly: int | None = None, lz: int | None = None) List[str][source]

Return sorted predefined-kind names for the selected type/size filters.

list_predefined_sizes(*, kind: str | None = None, lattice_type: str | Any | None = None) List[Tuple[int, int, int]][source]

Return available (Lx, Ly, Lz) sizes for selected lattice type and optional kind.

get_predefined(kind: str, *, configuration: int | str | None = None, index: int | None = None, label: str | None = None, lattice_type: str | Any | None = None, lx: int | None = None, ly: int | None = None, lz: int | None = None, region: str | None = None, return_meta: bool = False)[source]

Fetch one predefined region by configuration/index/label with optional type/size filters.

show_predefined(kind: str | None = None, *, lattice_type: str | Any | None = None, lx: int | None = None, ly: int | None = None, lz: int | None = None, limit: int | None = 40) List[Dict[str, Any]][source]

Pretty-print predefined entries and return them.

Filters default to the current lattice type and size if not provided.

Parameters:
  • kind (str, optional) – Filter by region kind (e.g. ‘half_x’, ‘kitaev_preskill’, etc.).

  • lattice_type (str or LatticeType, optional) – Filter by lattice type (e.g. ‘square’, ‘honeycomb’).

  • lx (int, optional) – Filter by lattice size along x. Defaults to the current lattice dimensions.

  • ly (int, optional) – Filter by lattice size along y. Defaults to the current lattice dimensions.

  • lz (int, optional) – Filter by lattice size along z. Defaults to the current lattice dimensions.

  • limit (int, optional) – Maximum number of entries to show (default 40). Use None for no limit.

get_entropy_cuts(cut_type: str = 'all', *, include_sublattice: bool = True, sweep_by_unit_cell: bool | None = None) Dict[str, List[int]][source]

Return canonical bipartition cuts for entanglement-entropy studies.

Supported cut types: half_x, half_y, quarter, sublattice_A, sweep, all.

Notes

  • quarter is defined as the intersection of half_x and half_y.

  • sweep returns nested prefixes for scaling analyses.

  • For non-Bravais lattices (e.g. honeycomb), sweep defaults to unit-cell increments; for Bravais lattices, it defaults to site increments.

  • Duplicate cuts are removed when several cut types are requested together (e.g. all).

  • Empty regions are dropped when sublattice cuts are not available for a given lattice.
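The notes above (quarter as an intersection, nested sweep prefixes, removal of duplicates and empty cuts) can be sketched for a 2D coordinate array. The "sweep_m" key names, the "below the median" rule, and site-by-site sweep growth are illustrative assumptions, and `entropy_cuts` is a hypothetical helper rather than the method.

```python
import numpy as np

def entropy_cuts(coords) -> dict:
    """Hypothetical sketch: half_x, half_y, quarter = half_x ∩ half_y, sweep prefixes."""
    c = np.asarray(coords, dtype=float)
    x, y = c[:, 0], c[:, 1]
    half_x = set(np.flatnonzero(x < np.median(x)).tolist())
    half_y = set(np.flatnonzero(y < np.median(y)).tolist())
    cuts = {"half_x": sorted(half_x),
            "half_y": sorted(half_y),
            "quarter": sorted(half_x & half_y)}
    for m in range(1, len(c)):               # nested prefixes in site order
        cuts[f"sweep_{m}"] = list(range(m))
    # Keep the first cut for each distinct site set and drop empty cuts.
    seen, out = set(), {}
    for name, sites in cuts.items():
        key = tuple(sites)
        if sites and key not in seen:
            seen.add(key)
            out[name] = sites
    return out

plaquette = [[0, 0], [1, 0], [0, 1], [1, 1]]   # 2 x 2 square, site = x + 2*y
cuts = entropy_cuts(plaquette)
```

Here sweep_1 and sweep_2 coincide with quarter and half_y, so only sweep_3 survives the deduplication.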

region_fraction(fraction: float | int) List[int][source]

Return a fraction of the system as a contiguous block of sites in index order.

region_half(direction: str = 'x') List[int][source]

Half-system cut along a cardinal or tilted direction.

Useful for area-law scaling studies. Periodic boundaries are handled by cutting on coordinates relative to their median.

Parameters:

direction (str) – Direction to cut (‘x’, ‘y’, ‘z’, ‘xy’, or ‘yx’). ‘xy’ is a diagonal cut along x+y, ‘yx’ is along x-y.

Returns:

The half-region as a CustomRegion object.

Return type:

Region

Example

>>> half_x_sites = lattice.regions.get_region(kind='half_x', as_region=False)
>>> half_x_sites  # sites in the left half of the lattice
[10, 11, 12, 13, 14, 15, ...]
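A median cut along a cardinal or diagonal direction projects each site onto x, y, z, x+y ('xy') or x-y ('yx') and keeps the sites strictly below the median projection. The sketch below is hypothetical and works on raw coordinates without any periodic unwrapping. Sites that tie with the median fall on the upper side, which can unbalance diagonal cuts on small lattices, as the 'xy' case in the usage shows.

```python
import numpy as np

def region_half(coords, direction: str = "x"):
    """Hypothetical sketch: sites whose projection lies strictly below the median."""
    c = np.asarray(coords, dtype=float)
    n = len(c)
    x = c[:, 0]
    y = c[:, 1] if c.shape[1] > 1 else np.zeros(n)
    z = c[:, 2] if c.shape[1] > 2 else np.zeros(n)
    proj = {"x": x, "y": y, "z": z,
            "xy": x + y,          # diagonal cut along x + y
            "yx": x - y}[direction]
    return np.flatnonzero(proj < np.median(proj)).tolist()

# 4 x 2 square lattice, site index = x + 4*y.
grid = [[x, y] for y in range(2) for x in range(4)]
left = region_half(grid, "x")
bottom = region_half(grid, "y")
diag = region_half(grid, "xy")
```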
region_quarter() List[int][source]

Return a quarter-system region as the intersection half_x ∩ half_y.

region_sweep(*, by_unit_cell: bool | None = None) Dict[str, List[int]][source]

Return nested sweep cuts used for entropy scaling analyses.

Parameters:

by_unit_cell (bool or None) –

  • True: grow by unit cells.

  • False: grow by individual sites.

  • None: auto-select (unit cells for multi-sublattice lattices).

region_disk(center: int | List[float], radius: float, pbc: bool = False) List[int][source]

Spherical / circular region (PBC-aware).

Parameters:
  • center (int or array-like) – Site index or coordinate vector.

  • radius (float) – Inclusion radius.

region_sublattice(sub: int) List[int][source]

Return all sites belonging to a specific sublattice.

region_graph_ball(center: int, depth: int) List[int][source]

Graph-distance ball (breadth-first search).

Returns all sites within depth bonds from center (inclusive).
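A breadth-first search that stops expanding at the requested depth gives exactly this ball. The sketch assumes the adjacency is available as dict[int, set[int]] (the form returned by adjacency_map); `region_graph_ball` here is a hypothetical free function.

```python
from collections import deque
from typing import Dict, List, Set

def region_graph_ball(adj: Dict[int, Set[int]], center: int, depth: int) -> List[int]:
    """Hypothetical sketch: all sites within `depth` bonds of `center`, inclusive."""
    dist = {center: 0}
    queue = deque([center])
    while queue:
        v = queue.popleft()
        if dist[v] == depth:
            continue                     # do not expand past the requested depth
        for u in adj[v]:
            if u not in dist:
                dist[u] = dist[v] + 1
                queue.append(u)
    return sorted(dist)

chain5 = {0: {1}, 1: {0, 2}, 2: {1, 3}, 3: {2, 4}, 4: {3}}   # path 0-1-2-3-4
```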

region_plaquettes(plaquette_ids: List[int]) List[int][source]

Region defined by a union of plaquettes.

Requires the lattice to implement calculate_plaquettes() and store _plaquettes.

region_kitaev_preskill(origin: int | List[float] | None = None, radius: float | None = None, n_sectors: int = 3, rotation: float = 0.0, use_pbc: bool = False) Dict[str, List[int]][source]

Divide the lattice into angular sectors meeting at origin.

The disk of the given radius is split into n_sectors equal pie-slices. By default (n_sectors=3, rotation=0):

A : angles in [-π, -π/3)
B : angles in [-π/3, +π/3)
C : angles in [+π/3, +π]

The rotation parameter (radians) rotates all sector boundaries counter-clockwise.

Warning

use_pbc defaults to False. For the TEE construction the regions must not wrap around the boundary — otherwise a single angular sector may pick up sites from the opposite side of the torus and the linear combination becomes ill-defined.

Parameters:
  • origin (int or array-like, optional) – Center of the disk. Defaults to the centroid of all sites.

  • radius (float, optional) – Disk radius. Defaults to min(Lx, Ly)/2 - 0.5 (or 1.0).

  • n_sectors (int) – Number of angular sectors (default 3).

  • rotation (float) – Global rotation of sector boundaries in radians (default 0).

  • use_pbc (bool) – If True, use minimum-image (PBC-wrapped) displacements when computing distances and angles. Default False — raw Euclidean distances are used so the regions cannot wrap around the boundary.

Returns:

Keys: single-sector labels 'A', 'B', … and all pairwise / triple unions ('AB', 'BC', 'ABC', …).

Return type:

dict[str, list[int]]

Notes

For the TEE one computes:

S_topo = S_A + S_B + S_C - S_AB - S_BC - S_AC + S_ABC

The sectors must tile a disk without overlapping interiors so that the linear combination isolates the topological contribution. The rotation parameter lets you sweep the partition to check that the result does not depend on its orientation.
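
The inclusion-exclusion combination in the Notes can be written as a small function over the dictionary of region entropies. This is a sketch: topological_entropy is a hypothetical helper, and the toy entropies below are illustrative, not computed from any state. In the Kitaev-Preskill convention S_topo = -γ; the toy input assigns -γ to every region (no area-law part), so the combination returns 3 - 3 + 1 = 1 copy of -γ.

```python
import math
from typing import Mapping


def topological_entropy(S: Mapping[str, float]) -> float:
    """S_topo = S_A + S_B + S_C - S_AB - S_BC - S_AC + S_ABC."""
    return (S["A"] + S["B"] + S["C"]
            - S["AB"] - S["BC"] - S["AC"]
            + S["ABC"])


# Toy input: every region carries only the constant -gamma (no boundary term).
gamma = math.log(2.0)  # e.g. the toric-code value ln 2
toy = {key: -gamma for key in ("A", "B", "C", "AB", "BC", "AC", "ABC")}
print(topological_entropy(toy))  # approximately -0.693, i.e. -gamma
```

In a real calculation each entry would be an entanglement entropy of the corresponding region returned by region_kitaev_preskill or region_levin_wen; the area-law parts cancel only because of the geometry of those regions.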

region_levin_wen(origin: int | List[float] | None = None, inner_radius: float | None = None, outer_radius: float | None = None, use_pbc: bool = False) Dict[str, List[int]][source]

Define three concentric regions around origin.

  • A (inner disk) : distance < inner_radius

  • B (annulus) : inner_radius ≤ distance < outer_radius

  • C (exterior) : distance ≥ outer_radius
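
The radial classification can be sketched as follows, assuming raw Euclidean distances (use_pbc=False) and exactly the strict and non-strict inequalities listed above. levin_wen_regions is a hypothetical helper, and the five collinear sites are a made-up example.

```python
import numpy as np


def levin_wen_regions(coords, origin, inner_radius, outer_radius):
    """Classify sites into inner disk A, annulus B and exterior C (raw Euclidean distance)."""
    coords = np.asarray(coords, dtype=float)
    r = np.linalg.norm(coords - np.asarray(origin, dtype=float), axis=1)
    regions = {
        "A": np.flatnonzero(r < inner_radius).tolist(),
        "B": np.flatnonzero((r >= inner_radius) & (r < outer_radius)).tolist(),
        "C": np.flatnonzero(r >= outer_radius).tolist(),
    }
    # Pairwise and triple unions used in the TEE combination
    for key in ("AB", "BC", "AC", "ABC"):
        regions[key] = sorted(set().union(*(regions[c] for c in key)))
    return regions


sites = [[float(x), 0.0] for x in range(5)]   # five collinear sites at x = 0..4
print(levin_wen_regions(sites, origin=[0.0, 0.0], inner_radius=1.0, outer_radius=3.0))
```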

Warning

use_pbc defaults to False. Wrapping would make sites from the far side of the torus appear close to the origin, contaminating the annular regions and invalidating the TEE linear combination.

Parameters:
  • origin (int or array-like, optional) – Centre of the annuli. Defaults to the centroid of all sites.

  • inner_radius (float, optional) – Boundary between A and B. Default 1.0.

  • outer_radius (float, optional) – Boundary between B and C. Default min(Lx, Ly)/2 - 0.5.

  • use_pbc (bool) – If True, use minimum-image (PBC-wrapped) displacements. Default False — raw Euclidean distances prevent wrap-around.

Returns:

Keys: 'A', 'B', 'C' and unions 'AB', 'BC', 'AC', 'ABC'.

Return type:

dict[str, list[int]]

Notes

For the TEE:

S_topo = S_A + S_B + S_C - S_AB - S_BC - S_AC + S_ABC

In the original Levin-Wen paper the region outside the outer annulus is called D; here we label it C for consistency with the KP convention.

subsystem_boundary(subsystem: int | List[int] | ndarray, *, include_nnn: bool = False, return_bonds: bool = False) int | Tuple[int, List[Tuple[int, int]]][source]

Compute the boundary size dA of a subsystem.

The boundary dA is the number of bonds (edges) that cross between subsystem A and its complement B. This is the quantity that appears in the area law of entanglement entropy: S ~ dA for gapped systems.

Parameters:
  • subsystem (int, list, or array) – Specification of subsystem A. If an int n, A consists of the first n sites (contiguous from 0); if a list or array, it gives the explicit site indices.

  • include_nnn (bool, default=False) – If True, include next-nearest-neighbor bonds in the boundary count.

  • return_bonds (bool, default=False) – If True, also return the list of boundary bonds.

Returns:

  • dA (int) – Number of bonds crossing the boundary (|dA|).

  • boundary_bonds (list of (int, int), optional) – If return_bonds=True, the list of (i, j) bonds where i in A, j in B.

Examples

>>> # Compute boundary for half-system cut
>>> lattice     = SquareLattice(dim=2, lx=4, ly=4, bc='obc')
>>> half_sites  = lattice.regions.region_half('x').A
>>> dA          = lattice.regions.subsystem_boundary(half_sites)
>>> print(f"Boundary size: {dA}") # Should be Ly for x-cut
>>> # Get boundary bonds explicitly
>>> dA, bonds   = lattice.regions.subsystem_boundary([0, 1, 2], return_bonds=True)
>>> print(f"Boundary bonds: {bonds}")

Notes

For area-law scaling in d dimensions, S ~ L^{d-1} where L is linear size. The boundary dA counts bonds, which for regular lattices is proportional to the surface area of the subsystem.

For periodic boundary conditions, bonds wrapping around the system are counted correctly using the lattice connectivity.
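
The boundary count can be sketched directly from a nearest-neighbor adjacency list: every bond (i, j) with i in A and j outside A is counted once, because only sites of A are scanned. boundary_bonds and square_obc_nn are hypothetical helpers (the latter assumes the x-fastest index x + Lx*y); the check reproduces the dA = Ly expectation of the example above for a 4 x 4 open lattice.

```python
from typing import Iterable, List, Tuple


def boundary_bonds(nn: List[List[int]], subsystem: Iterable[int]) -> Tuple[int, List[Tuple[int, int]]]:
    """Return (dA, bonds) where bonds are the (i, j) pairs with i in A and j in B."""
    A = set(subsystem)
    bonds = [(i, j) for i in sorted(A) for j in nn[i] if j not in A]
    return len(bonds), bonds


def square_obc_nn(Lx: int, Ly: int) -> List[List[int]]:
    """Nearest neighbors of an open Lx x Ly square lattice, index = x + Lx * y."""
    nn: List[List[int]] = [[] for _ in range(Lx * Ly)]
    for y in range(Ly):
        for x in range(Lx):
            i = x + Lx * y
            if x + 1 < Lx:                 # bond to the right
                nn[i].append(i + 1)
                nn[i + 1].append(i)
            if y + 1 < Ly:                 # bond upwards
                nn[i].append(i + Lx)
                nn[i + Lx].append(i)
    return nn


half = [x + 4 * y for y in range(4) for x in range(2)]   # left half, x in {0, 1}
dA, bonds = boundary_bonds(square_obc_nn(4, 4), half)
print(dA, bonds)  # 4 [(1, 2), (5, 6), (9, 10), (13, 14)]
```

With a periodic adjacency list the same function also counts wrapping bonds, since it relies only on connectivity.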

subsystem_boundary_sites(subsystem: int | List[int] | ndarray, *, include_complement: bool = False, include_nnn: bool = False) List[int] | Tuple[List[int], List[int]][source]

Find sites in subsystem A that lie on the boundary.

A site in A is a “boundary site” if it has at least one neighbor in B (the complement of A).

Parameters:
  • subsystem (int, list, or array) – Specification of subsystem A.

  • include_complement (bool, default=False) – If True, also return boundary sites in B (complement).

  • include_nnn (bool, default=False) – If True, include next-nearest-neighbor connectivity.

Returns:

  • boundary_A (list of int) – Sites in A that have neighbors in B.

  • boundary_B (list of int, optional) – If include_complement=True, sites in B that have neighbors in A.

Examples

>>> half = lattice.regions.region_half('x').A
>>> boundary_sites = lattice.regions.subsystem_boundary_sites(half)
>>> print(f"Boundary sites in A: {boundary_sites}")
sweep_subsystems(direction: str | None = None, *, rectangular: bool = False, by_unit_cell: bool = True) Dict[int, List[List[int]]][source]

Generate subsystems grouped by boundary size dA.

Creates subsystems and organizes them by their boundary size (number of bonds cut). Useful for area-law scaling studies.

Parameters:
  • direction (str, optional) – Direction for sweep: ‘x’, ‘y’, ‘z’. Creates full-width cuts. If None, uses rectangular or lexicographic mode.

  • rectangular (bool, default=False) – If True (and direction is None), generate rectangular subsystems of all sizes (1x1, 1x2, 2x1, 2x2, …), giving a variety of shapes. If False, use a lexicographic sweep (sequential site addition).

  • by_unit_cell (bool, default=True) – If True, grow by unit cells (for multi-sublattice lattices).

Returns:

by_dA – Subsystems grouped by boundary size: {dA: [[sites], [sites], …]}

Return type:

Dict[int, List[List[int]]]
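
The grouping by dA in the lexicographic mode can be sketched as follows, assuming a nearest-neighbor adjacency list and site-by-site growth (the by_unit_cell option is not modeled). lexicographic_sweep and square_obc_nn are hypothetical helpers; the 3 x 3 open lattice is a made-up example.

```python
from collections import defaultdict
from typing import Dict, List


def square_obc_nn(Lx: int, Ly: int) -> List[List[int]]:
    """Nearest neighbors of an open Lx x Ly square lattice, index = x + Lx * y."""
    nn: List[List[int]] = [[] for _ in range(Lx * Ly)]
    for y in range(Ly):
        for x in range(Lx):
            i = x + Lx * y
            if x + 1 < Lx:
                nn[i].append(i + 1)
                nn[i + 1].append(i)
            if y + 1 < Ly:
                nn[i].append(i + Lx)
                nn[i + Lx].append(i)
    return nn


def lexicographic_sweep(nn: List[List[int]]) -> Dict[int, List[List[int]]]:
    """Grow A = {0, ..., n-1} one site at a time and group each A by its boundary size dA."""
    by_dA: Dict[int, List[List[int]]] = defaultdict(list)
    for n in range(1, len(nn)):            # skip the empty and the full system
        A = set(range(n))
        dA = sum(1 for i in A for j in nn[i] if j not in A)
        by_dA[dA].append(sorted(A))
    return dict(by_dA)


by_dA = lexicographic_sweep(square_obc_nn(3, 3))
for dA, subs in sorted(by_dA.items()):
    print(f"dA={dA}: {len(subs)} subsystems")
# dA=2: 2 subsystems
# dA=3: 4 subsystems
# dA=4: 2 subsystems
```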

Examples

>>> # Rectangular subsystems (various shapes)
>>> by_dA = lattice.regions.sweep_subsystems(rectangular=True)
>>> for dA, subs in sorted(by_dA.items()):
...     print(f"dA={dA}: {len(subs)} shapes")
>>> # Full-width directional cuts (constant dA)
>>> by_dA = lattice.regions.sweep_subsystems(direction='x')
>>> # Lexicographic (growing blob)
>>> by_dA = lattice.regions.sweep_subsystems()
rectangular_subsystems(*, max_subsystems: int | None = None) Dict[int, List[List[int]]][source]

Generate rectangular subsystems of various sizes, grouped by boundary dA.

Creates subsystems by taking all combinations of x and y extents, giving a variety of shapes with different boundary sizes.

Parameters:

max_subsystems (int, optional) – Maximum total number of subsystems to generate.

Returns:

by_dA – Subsystems grouped by boundary size: {dA: [[sites], …]}

Return type:

Dict[int, List[List[int]]]

Examples

>>> by_dA = lattice.regions.rectangular_subsystems()
>>> for dA, subs in sorted(by_dA.items()):
...     print(f"dA={dA}: {len(subs)} rectangles")
general_python.lattices.run_lattice_tests(dim=2, lx=5, ly=5, lz=1, bc=None, typek='square')[source]

Run automated tests for a lattice in 1D, 2D, or 3D.

Parameters:
  • dim (int) – Lattice dimension (1, 2, or 3)

  • lx (int) – Number of sites in the x-direction

  • ly (int) – Number of sites in the y-direction (ignored if dim=1)

  • lz (int) – Number of sites in the z-direction (ignored if dim < 3)

  • bc – Boundary condition (e.g., LatticeBC.PBC or LatticeBC.OBC)

  • typek (str) – Type of lattice (“square”, “hexagonal”, or “honeycomb”)

Contains the general lattice class hierarchy and helpers.

This module defines the base Lattice API used across general_python, together with utility routines for boundary handling and symmetry metadata.

Currently, up to three spatial dimensions are supported.

Date: 2025-02-01, Version: 2.0

class general_python.lattices.lattice.Lattice(dim: int = None, lx: int = 1, ly: int = 1, lz: int = 1, bc: str = None, adj_mat: ndarray = None, flux: ndarray = None, *args, **kwargs)[source]

Bases: ABC

Abstract Base Class for defining lattice structures.

This class serves as the foundation for all lattice implementations in the lattices module. It handles geometry, connectivity, boundary conditions, and k-space properties.

Indexing Convention

Lattice sites are indexed linearly from 0 to Ns - 1. The mapping from spatial coordinates to linear index depends on the concrete implementation, but typically follows a row-major (lexicographic) order:

  • 1D: Left to right.

  • 2D: Bottom-left to top-right (x varies fastest).

  • 3D: Front-bottom-left to back-top-right.
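
For a single-sublattice lattice, the typical row-major convention above amounts to the index formula i = x + Lx*(y + Ly*z). The helpers below are a hypothetical sketch of that mapping and its inverse; concrete subclasses (especially multi-sublattice ones) may order sites differently.

```python
def site_index(x: int, y: int, z: int, Lx: int, Ly: int) -> int:
    """Linear index with x varying fastest, then y, then z."""
    return x + Lx * (y + Ly * z)


def site_coords(i: int, Lx: int, Ly: int) -> tuple:
    """Inverse of site_index: recover (x, y, z) from the linear index."""
    return i % Lx, (i // Lx) % Ly, i // (Lx * Ly)


print(site_index(1, 1, 0, Lx=3, Ly=2))  # 4
print(site_coords(4, Lx=3, Ly=2))       # (1, 1, 0)
```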

Features

  • Geometry: Calculation of real-space coordinates, unit vectors, and basis vectors.

  • Connectivity: Automatic identification of Nearest Neighbors (NN) and Next-Nearest Neighbors (NNN).

  • Boundaries: Support for various boundary conditions:

    • PBC: Periodic Boundary Conditions (torus topology); x, y, and z periodic.

    • OBC: Open Boundary Conditions (hard edges); x, y, and z open.

    • MBC: Mixed Boundary Conditions (e.g., cylinder topology); x periodic, y and z open.

    • SBC: Switched Boundary Conditions (a cylinder with the roles of x and y exchanged); x open, y periodic, z open.

    • TWISTED: Twisted Boundary Conditions with specified fluxes.

  • Reciprocal Space: Automatic calculation of reciprocal lattice vectors and Brillouin Zone paths.

  • Visualization: Integration with plotting utilities via .plot.

Ns

Total number of sites in the lattice.

Type:

int

dim

Spatial dimension of the lattice (1, 2, or 3).

Type:

int

Lx, Ly, Lz

Linear dimensions of the lattice.

Type:

int

bc

Active boundary condition.

Type:

LatticeBC

coordinates

Array of shape (Ns, 3) containing real-space coordinates of all sites.

Type:

np.ndarray

nn

Adjacency list for nearest neighbors. nn[i] is a list of neighbors for site i.

Type:

List[List[int]]

property bad_lattice_site

Bad lattice site

a = 1
b = 1
c = 1
unit_length = 1
__init__(dim: int = None, lx: int = 1, ly: int = 1, lz: int = 1, bc: str = None, adj_mat: ndarray = None, flux: ndarray = None, *args, **kwargs)[source]

Initialize the general lattice model.

Parameters:
  • dim (int, optional) – Dimension of the lattice (1, 2, or 3). If None, inferred from lx, ly, lz.

  • lx (int, optional) – Length of the lattice in the x-direction.

  • ly (int, optional) – Length of the lattice in the y-direction.

  • lz (int, optional) – Length of the lattice in the z-direction.

  • bc (str, optional) – Boundary conditions (e.g., ‘PBC’, ‘OBC’).

  • adj_mat (np.ndarray, optional) – Adjacency matrix for the lattice.

  • flux (float or mapping, optional) – Flux piercing the boundaries: either a dictionary specifying the flux in each direction, or a single value applied to all directions. Passing a flux automatically implies TWISTED boundary conditions, so the bc parameter can be left as None or set to ‘TWISTED’ for clarity.

__str__()[source]

String representation of the lattice

__repr__()[source]

Representation of the lattice

__len__()[source]

Length of the lattice (number of sites)

__getitem__(index: int)[source]

Get the site at the given index

__iter__()[source]

Iterate over the lattice sites

__contains__(item: int)[source]

Check if the lattice contains the given site

init(verbose: bool = False, *, force_dft: bool = False, **kwargs)[source]

Initializes the lattice object by calculating coordinates, reciprocal vectors, and neighbor lists.

This method performs the following steps:

  1. Calculates the real-space coordinates, r-vectors, and k-vectors of the lattice.

  2. If the number of sites (self.Ns) is less than 100, computes the discrete Fourier transform (DFT) matrix.

  3. If an adjacency matrix (self._adj_mat) is provided:

    • Determines the number of sites (Ns) from the adjacency matrix.

    • For each site, identifies nearest neighbors (nn) as those connected by the highest weight in the adjacency matrix, and next-nearest neighbors (nnn) as those connected by the next highest distinct weight.

    • Stores forward neighbors (indices greater than the current site) for both nn and nnn.

  4. If no adjacency matrix is provided, calculates nearest and next-nearest neighbors using default methods.

  5. Calculates normalization or symmetry properties of the lattice.

Together, these steps set up all neighbor lists and lattice properties required for further computations.
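
The adjacency-matrix branch (step 3) can be sketched with NumPy as follows. neighbors_from_adjacency is a hypothetical helper, not the library's code; it assumes a symmetric weight matrix in which larger weights mean closer neighbors, and compares weights with a small absolute tolerance (an assumption).

```python
import numpy as np


def neighbors_from_adjacency(adj, tol=1e-12):
    """Split each row of a weighted adjacency matrix into nn / nnn lists by weight rank."""
    adj = np.asarray(adj, dtype=float)
    Ns = adj.shape[0]                                   # number of sites from the matrix
    nn, nnn = [], []
    for i in range(Ns):
        row = adj[i].copy()
        row[i] = 0.0                                    # ignore any self-coupling
        weights = sorted({float(w) for w in row if abs(w) > tol}, reverse=True)
        top = weights[0] if len(weights) > 0 else None      # highest weight
        second = weights[1] if len(weights) > 1 else None   # next highest distinct weight
        nn.append([j for j in range(Ns) if top is not None and abs(row[j] - top) <= tol])
        nnn.append([j for j in range(Ns) if second is not None and abs(row[j] - second) <= tol])
    # Forward neighbors: only indices larger than the current site
    nn_forward = [[j for j in nbrs if j > i] for i, nbrs in enumerate(nn)]
    nnn_forward = [[j for j in nbrs if j > i] for i, nbrs in enumerate(nnn)]
    return nn, nnn, nn_forward, nnn_forward


# Hypothetical 4-site ring: weight 1 for nearest neighbors, 0.5 across the diagonal
adj = [[0, 1, 0.5, 1], [1, 0, 1, 0.5], [0.5, 1, 0, 1], [1, 0.5, 1, 0]]
nn, nnn, nn_fw, nnn_fw = neighbors_from_adjacency(adj)
print(nn[0], nnn[0], nn_fw[3])  # [1, 3] [2] []
```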

get_region(kind: str | RegionType = RegionType.HALF, *, origin: int | List[float] | None = None, radius: float | None = None, direction: str | None = None, sublattice: int | None = None, sites: List[int] | None = None, depth: int | None = None, plaquettes: List[int] | None = None, **kwargs) List[int][source]

Return a list of site indices defining a spatial region.

Parameters:
  • kind (str or RegionType) – Type of region: ‘half’, ‘disk’, ‘sublattice’, ‘graph’, ‘plaquette’, ‘custom’. We also support specific half cuts like ‘half_x’, ‘half_y’, ‘half_z’ for convenience.

  • origin (int or list[float], optional) – Center of the region. Can be a site index or coordinate vector.

  • radius (float, optional) – Radius for ‘disk’ regions.

  • direction (str, optional) – Direction for ‘half’ cuts (‘x’, ‘y’, ‘z’).

  • sublattice (int, optional) – Sublattice index for ‘sublattice’ regions.

  • sites (list[int], optional) – Explicit list of sites for ‘custom’ regions.

  • depth (int, optional) – Depth/distance for ‘graph’ regions.

  • plaquettes (list[int], optional) – List of plaquette indices for ‘plaquette’ regions.

Returns:

Sorted list of site indices belonging to the region.

Return type:

list[int]

get_entropy_cuts(cut_type: str = 'all', *, include_sublattice: bool = True, sweep_by_unit_cell: bool | None = None) Dict[str, List[int]][source]

Return canonical bipartition cuts for entanglement-entropy workflows.

This is a convenience wrapper around self.regions.get_entropy_cuts().

generate_regions(kind: str | RegionType = RegionType.KITAEV_PRESKILL, **kwargs)[source]

Generate many region candidates for a selected region type.

This is a thin wrapper around self.regions.generate_regions().

property lx
property Lx
property ly
property Ly
property lz
property Lz
property area
property volume
property lxly
property lxlz
property lylz
property lxlylz
property dim
property sites
property size
property nsites
property ns
property Ns
property sites_per_cell: int

Sites per unit cell (1 for Bravais, 2 for honeycomb, etc.).

symmetry_perms(point_group: str = 'full') ndarray[source]

Generate space-group permutation table for this lattice.

Delegates to generate_space_group_perms().

When TWISTED boundary conditions are active, the point-group part is disabled (only translations are returned) because a generic flux breaks point-group symmetry unless the flux respects it.

Parameters:

point_group (str) – 'full' for maximal point group, 'translations' for translations only.

Return type:

ndarray, shape (|G|, Ns)
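
For the translation part on a fully periodic Bravais square lattice (the point_group='translations' case), a permutation table can be sketched as below. translation_perms is hypothetical; it uses the x-fastest index x + Lx*y and stores the image of site i in perm[g, i], a convention the library may or may not share, and it includes no point-group operations.

```python
import numpy as np


def translation_perms(Lx: int, Ly: int) -> np.ndarray:
    """Row g = (tx, ty) maps site i = x + Lx*y to ((x+tx) % Lx) + Lx*((y+ty) % Ly)."""
    Ns = Lx * Ly
    x = np.arange(Ns) % Lx
    y = np.arange(Ns) // Lx
    rows = []
    for ty in range(Ly):
        for tx in range(Lx):
            rows.append(((x + tx) % Lx) + Lx * ((y + ty) % Ly))
    return np.array(rows)             # shape (|G|, Ns) with |G| = Lx * Ly


perms = translation_perms(3, 2)
print(perms.shape)        # (6, 6)
print(perms[1].tolist())  # [1, 2, 0, 4, 5, 3], translation by one site along x
```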

lattice_symmetries() Dict[str, object][source]

Return a dictionary describing the spatial symmetries of this lattice.

The information is consistent for both single-particle and many-body representations. When TWISTED boundary conditions are present the point-group part is absent (flux generically breaks it).

Returns:

Keys:

  • 'lattice_type' : LatticeType enum

  • 'sites_per_cell' : int

  • 'n_cells' : number of unit cells

  • 'dim' : spatial dimension

  • 'bc' : boundary condition enum

  • 'is_periodic' : (bool, bool, bool) per direction

  • 'is_twisted' : bool

  • 'translation_group' : Z_Lx × Z_Ly (as the tuple (Lx, Ly))

  • 'point_group' : str or None ('D4' for a square lattice with Lx == Ly, etc.)

  • 'space_group_order' : total number of space-group elements

  • 'flux' : BoundaryFlux or None

Return type:

dict

symmetry_info() str[source]

Return a human-readable summary of the lattice symmetries.

Consistent for both single-particle (band-structure / Bloch) and many-body (Hilbert-space symmetry sectors) viewpoints.

Return type:

str

property a1
property a2
property a3
property k1
property b1
property k2
property b2
property k3
property b3
property n1
property n2
property n3
property basis
property multipartity
property vectors
property avec
property bvec
property dft

Return the discrete Fourier transform (DFT) matrix for the lattice.

property nn

Return the nearest-neighbor connectivity matrix for the lattice.

property bonds

Return the bond connectivity matrix for the lattice.

property nn_forward

Return the forward nearest-neighbor connectivity matrix for the lattice.

property nnn

Return the next-nearest-neighbor connectivity matrix for the lattice.

property nnn_forward

Return the forward next-nearest-neighbor connectivity matrix for the lattice.

property coordinates

Return the real-space coordinates of the lattice sites.

property subs

Return the sublattice indices of the lattice sites. For a Bravais lattice, this would simply be an array of zeros. For a non-Bravais lattice, this would indicate which sublattice each site belongs to.

property cells

Return the unit cell coordinates of the lattice sites. For a Bravais lattice, this would simply be the integer coordinates of the unit cells. For a non-Bravais lattice, this would include the basis vectors as well.

property fracs

for a square lattice, these would be (x/Lx, y/Ly, z/Lz) for each site.

Type:

Return fractional coordinates of the lattice sites. Example

property kvectors

Return the allowed k-vectors in reciprocal space for the lattice.

property rvectors

Return the allowed r-vectors in real space for the lattice.

property bc
property bc_x
property bc_y
property bc_z
property cardinality
property name
property type
sublattice(site: int) int[source]

Return the sublattice index for a given site. By default, returns 0 for all sites (single sublattice). Override in subclasses for multi-sublattice lattices.

k_vector(qx, qy=0.0, qz=0.0) ndarray[source]

Return the k-vector in Cartesian coordinates for given (qx, qy, qz) in reciprocal lattice units.

k_grid(n_k: int | Tuple[int, int, int], shift: bool | Tuple[bool, bool, bool] | None = None) ndarray[source]

Generate a full k-point grid for the given lattice.

Parameters:
  • n_k (int or tuple of int) –

    Number of points (N1, N2, N3) along each reciprocal direction, typically (Lx, Ly, Lz).

    We define the k-points as: k = f1 * b1 + f2 * b2 + f3 * b3, where f_i = n_i / N_i, with n_i = 0, 1, …, N_i - 1.

Returns:

k_points – Cartesian coordinates of k-points in reciprocal space.

Return type:

np.ndarray, shape (Nk, dim)
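
The fractional-coordinate definition above maps onto a Cartesian grid through one matrix product. k_grid_sketch is a hypothetical stand-alone version that takes the reciprocal vectors explicitly (as rows b_i) and ignores the shift option; the square-lattice reciprocal vectors 2π·e_x and 2π·e_y assume unit lattice spacing.

```python
import itertools

import numpy as np


def k_grid_sketch(b_vectors, n_k):
    """Cartesian k-points k = sum_i (n_i / N_i) b_i with n_i = 0, ..., N_i - 1."""
    b = np.asarray(b_vectors, dtype=float)            # shape (dim, dim), rows b_1, b_2, ...
    n_k = [n_k] * len(b) if isinstance(n_k, int) else list(n_k)
    fracs = np.array(list(itertools.product(*(np.arange(N) / N for N in n_k))))
    return fracs @ b                                  # shape (prod(N_i), dim)


b_square = 2 * np.pi * np.eye(2)                      # b1 = (2pi, 0), b2 = (0, 2pi)
print(k_grid_sketch(b_square, (2, 2)) / np.pi)        # rows (0,0), (0,1), (1,0), (1,1) in units of pi
```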

extract_momentum(eigvecs: ndarray, *, eigvals: ndarray = None, tol: float = 1e-10) ndarray[source]

Extract crystal momentum vectors k from real-space eigenvectors.

wigner_seitz_extend(k_points: ndarray, data: ndarray | None = None, *, copies: int | Iterable[int] | None = None, **kwargs) Tuple[ndarray, ndarray | None][source]

Extend k-space points and optional data across translated Brillouin zones.

The helper works for arbitrary k-space dimensions and any number of reciprocal translation vectors. Legacy b1/b2/b3 with nx/ny/nz remain supported for existing callers.

This allows generating extended k-point grids, e.g. for plotting band structures along high-symmetry paths.

Parameters:
  • k_points (ndarray, shape (N, dim)) – Array of k-points in reciprocal space to be extended.

  • data (ndarray, shape (N, ...) or None) – Optional data associated with each k-point (e.g. eigenvalues) to be extended alongside the k-points. Must have the same leading dimension as k_points.

  • copies (int or iterable of ints, optional) – Number of translated copies to generate in each reciprocal direction. If an integer is provided, the same number of copies will be generated in all directions. If an iterable is provided, it should have a length equal to the number of reciprocal lattice vectors (e.g. 3 for 3D), specifying the number of copies in each direction separately.

  • **kwargs – Additional keyword arguments to pass to the underlying ws_extend function. See its documentation for details.

Returns:

  • extended_k_points (ndarray, shape (M, dim)) – Array of extended k-points in reciprocal space, including the original points and their translated copies

  • extended_data (ndarray, shape (M, …) or None) – Extended data associated with each k-point, if the input data was provided. Otherwise, None

wigner_seitz_mask(Kx, Ky=None, Kz=None, *, shells: int = 1, tol: float = 1e-12, **kwargs) ndarray[source]

Return a boolean mask for the Wigner-Seitz cell in reciprocal space. This can be used to identify which k-points lie within the first Brillouin zone.

Parameters:
  • Kx (array-like) – x-components of the reciprocal-space grid points to test for membership in the Wigner-Seitz cell.

  • Ky (array-like, optional) – y-components of the grid points.

  • Kz (array-like, optional) – z-components of the grid points.

  • shells (int) – Number of shells of Wigner-Seitz cell to include in the mask.

  • tol (float) – Tolerance for determining if a point is within the Wigner-Seitz cell, accounting for numerical precision issues.

  • **kwargs – Additional keyword arguments to pass to the underlying ws_bz_mask function. See its documentation for details.
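
Membership in the first Brillouin zone (the Wigner-Seitz cell of the reciprocal lattice) can be tested by requiring a point to be at least as close to Γ as to any other reciprocal-lattice vector G. ws_mask_2d is a hypothetical 2D sketch; its n_max argument bounds the vectors G = m b1 + n b2 compared against (|m|, |n| ≤ n_max) and is not claimed to match the shells parameter above.

```python
import itertools

import numpy as np


def ws_mask_2d(Kx, Ky, b1, b2, n_max=1, tol=1e-12):
    """True where (Kx, Ky) is no farther from Gamma than from any nearby G = m b1 + n b2."""
    K = np.stack([np.asarray(Kx, float), np.asarray(Ky, float)], axis=-1)
    b1, b2 = np.asarray(b1, float), np.asarray(b2, float)
    k2 = np.sum(K ** 2, axis=-1)                       # squared distance to Gamma
    mask = np.ones(K.shape[:-1], dtype=bool)
    for m, n in itertools.product(range(-n_max, n_max + 1), repeat=2):
        if m == 0 and n == 0:
            continue
        G = m * b1 + n * b2
        mask &= k2 <= np.sum((K - G) ** 2, axis=-1) + tol
    return mask


b1, b2 = (2 * np.pi, 0.0), (0.0, 2 * np.pi)            # square lattice, unit spacing
Kx = np.array([0.0, 0.5, 1.5, 0.9, 1.1]) * np.pi
Ky = np.array([0.0, 0.0, 0.0, 0.9, 0.0]) * np.pi
print(ws_mask_2d(Kx, Ky, b1, b2))  # [ True  True False  True False]
```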

wigner_seitz_shifts(*, copies: int | Iterable[int] | None = None, include_origin: bool = False, tol: float = 1e-12, **kwargs) ndarray[source]

Return reciprocal-lattice translation vectors for Brillouin-zone copies.

This is the shared geometry helper for selecting or drawing translated Brillouin zones. It returns zone-center shifts only, not an extended k-mesh.

Parameters:
  • copies (int or iterable of int, optional) – Number of translated copies to generate in each reciprocal direction.

  • include_origin (bool, default=False) – Whether to include the central Brillouin zone at Gamma.

  • tol (float, default=1e-12) – Tolerance used when removing numerically duplicated shifts.

  • **kwargs – Additional keyword arguments forwarded to tools.lattice_kspace.ws_bz_shifts.

Returns:

Array of reciprocal-space translation vectors for zone copies.

Return type:

np.ndarray

high_symmetry_points() HighSymmetryPoints | None[source]

Return high-symmetry points for this lattice type.

Override in subclasses to provide lattice-specific high-symmetry points. Returns None if not defined for this lattice type.

Returns:

High-symmetry points with default path, or None if not defined.

Return type:

HighSymmetryPoints or None

Example

>>> lattice = SquareLattice(dim=2, lx=4, ly=4)
>>> pts = lattice.high_symmetry_points()
>>> print(pts.Gamma.frac_coords)  # (0.0, 0.0, 0.0)
>>> print(pts.default_path())     # ['Gamma', 'X', 'M', 'Gamma']
default_bz_path() List[Tuple[str, List[float]]] | None[source]

Return the default Brillouin zone path for this lattice.

Returns:

Default path as list of (label, [f1, f2, f3]) tuples, or None if not defined.

Return type:

List[Tuple[str, List[float]]] or None

default_resolve_path(path: Iterable[tuple[str, Iterable[float]]] | StandardBZPath | str | List[str] | HighSymmetryPoints) List[Tuple[str, List[float]]][source]

Resolve path input to a list of (label, fractional_coord) pairs.

Parameters:
  • path (list[(label, coords)], StandardBZPath, str, List[str], or HighSymmetryPoints) – Path definition (fractional coordinates), standard enum, enum name string, list of point labels, or HighSymmetryPoints object.

Returns:

resolved_path – Resolved path as a list of (label, fractional_coord) pairs.

Return type:

list[(label, list[float])]

Example

>>> path = lattice.default_resolve_path("SQUARE_2D")
>>> for label, coord in path:
...     print(f"{label}: {coord}")

contains_special_point(point: str | HighSymmetryPoint | Tuple[float, ...] | ndarray, *, tol: float = 1e-12) bool[source]

Return True if the lattice momentum grid contains a special point. This method helps to check whether a finite lattice contains a particular high-symmetry point in the Brillouin zone, which is important for band structure calculations and topological analyses.

Parameters:
  • point (str, HighSymmetryPoint, or array-like) – Special point identifier. Accepted forms: a label string (e.g. "Gamma", "K", "K'"), a HighSymmetryPoint, or an explicit fractional coordinate tuple/array.

  • tol (float) – Absolute tolerance used in the coordinate match.

Notes

The check is done in fractional reciprocal coordinates and naturally includes flux-induced shifts from twisted boundary conditions because it uses self.kvectors_frac.
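
The fractional-coordinate check can be sketched by folding differences back into [-0.5, 0.5] before comparing. grid_contains and frac_grid are hypothetical helpers for a flux-free 2D grid with fractions n_i / L_i; they show, for instance, that the square-lattice M point (0.5, 0.5) is present on a 4 x 4 grid but not on a 3 x 3 grid.

```python
import numpy as np


def grid_contains(k_frac, point_frac, tol=1e-12):
    """True if some grid point equals point_frac up to a reciprocal-lattice translation."""
    diff = np.asarray(k_frac, dtype=float) - np.asarray(point_frac, dtype=float)
    diff -= np.round(diff)                       # fold each component into [-0.5, 0.5]
    return bool(np.any(np.all(np.abs(diff) <= tol, axis=-1)))


def frac_grid(L1, L2):
    """Fractional coordinates n_i / L_i of an L1 x L2 momentum grid (no flux)."""
    return np.array([[n1 / L1, n2 / L2] for n1 in range(L1) for n2 in range(L2)])


M = (0.5, 0.5)
print(grid_contains(frac_grid(4, 4), M), grid_contains(frac_grid(3, 3), M))  # True False
```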

bz_path(path: List[str] | str | StandardBZPath | None = None, *, points_per_seg: int = 40) Tuple[ndarray, ndarray, List[Tuple[int, str]], ndarray][source]

Generate k-points along a Brillouin zone path.

Parameters:
  • path (list of str, str, StandardBZPath, or None) – Path specification: a list of high-symmetry point names (e.g. [‘Gamma’, ‘X’, ‘M’, ‘Gamma’]), a StandardBZPath enum or its name as a string (e.g. ‘SQUARE_2D’), or None to use the default path for this lattice.

  • points_per_seg (int) – Number of interpolated points per path segment.

Returns:

  • k_path (np.ndarray, shape (Npath, 3)) – Cartesian k-points along the path.

  • k_dist (np.ndarray, shape (Npath,)) – Cumulative distance for plotting x-axis.

  • labels (List[Tuple[int, str]]) – Indices and labels for high-symmetry points.

  • k_path_frac (np.ndarray, shape (Npath, 3)) – Fractional k-coordinates along the path.
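
A piecewise-linear version of the path construction can be sketched as below, assuming equally many points per segment and fractional endpoints. bz_path_sketch is hypothetical and is not the library's algorithm; it returns vectors with as many components as rows of b_vectors (the documented method returns 3-component vectors).

```python
import numpy as np


def bz_path_sketch(points, b_vectors, points_per_seg=40):
    """Linear path through labelled fractional points -> (k_cart, k_dist, ticks, k_frac)."""
    b = np.asarray(b_vectors, dtype=float)              # rows are b_1, b_2, ...
    labels, corners = zip(*points)
    corners = np.asarray(corners, dtype=float)
    t = np.linspace(0.0, 1.0, points_per_seg, endpoint=False)[:, None]
    segments = [corners[s] + t * (corners[s + 1] - corners[s])
                for s in range(len(corners) - 1)]
    segments.append(corners[-1:])                       # close the path on the last point
    k_frac = np.vstack(segments)
    k_cart = k_frac @ b                                 # fractional -> Cartesian
    steps = np.linalg.norm(np.diff(k_cart, axis=0), axis=1)
    k_dist = np.concatenate([[0.0], np.cumsum(steps)])  # x-axis for band plots
    ticks = [(s * points_per_seg, lab) for s, lab in enumerate(labels)]
    return k_cart, k_dist, ticks, k_frac


square_path = [("Gamma", [0.0, 0.0]), ("X", [0.5, 0.0]), ("M", [0.5, 0.5]), ("Gamma", [0.0, 0.0])]
k_cart, k_dist, ticks, k_frac = bz_path_sketch(square_path, 2 * np.pi * np.eye(2))
print(ticks)  # [(0, 'Gamma'), (40, 'X'), (80, 'M'), (120, 'Gamma')]
```

For the square lattice with unit spacing, the total path length is (2 + √2)π, since Γ-X and X-M each have length π and M-Γ has length √2 π.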

Example

>>> lattice = SquareLattice(dim=2, lx=4, ly=4)
>>> k_path, k_dist, labels, k_frac = lattice.bz_path()
>>> # Or with custom path:
>>> k_path, k_dist, labels, k_frac = lattice.bz_path(['Gamma', 'M', 'Gamma'])
bz_path_points(path: List[str] | str | StandardBZPath | None = None, *, points_per_seg: int = 40, k_vectors: np.ndarray | None = None, k_vectors_frac: np.ndarray | None = None, tol: float = 1e-12, periodic: bool = True) KPathSelection[source]

Build an ideal Brillouin-zone path and optionally match it to an existing k-grid.

If no k-grid is provided, the returned object still contains the continuous path geometry, which is useful for plotting or for constructing a path that is not constrained to the sampled reciprocal mesh. When a sampled grid is provided, reciprocal-lattice copies are generated automatically as needed so paths in extended Brillouin-zone regions can still match the existing data.

Parameters:
  • path (list of str, str, StandardBZPath, or None) – Path specification: a list of high-symmetry point names (e.g. [‘Gamma’, ‘X’, ‘M’, ‘Gamma’]), a StandardBZPath enum or its name as a string (e.g. ‘SQUARE_2D’), or None to use the default path for this lattice.

  • points_per_seg (int) – Number of interpolated points per path segment.

  • k_vectors (np.ndarray, shape (Nk, 3), optional) – Cartesian k-vectors of the existing grid to match against.

  • k_vectors_frac (np.ndarray, shape (Nk, 3), optional) – Fractional k-vectors of the existing grid to match against. Required if k_vectors is provided.

  • tol (float) – Tolerance for matching path points to the existing k-grid. With periodic=True it is interpreted in fractional reciprocal coordinates. With periodic=False it is interpreted in plotted Cartesian reciprocal coordinates.

  • periodic (bool, default=True) – If True, allow reciprocal-translation-equivalent points to match. Set to False for visual matching in the displayed Brillouin-zone copy.

bz_path_data(k_vectors: ndarray, k_vectors_frac: ndarray, values: ndarray, path: List[str] | Literal['CHAIN_1D', 'SQUARE_2D', 'TRIANGULAR_2D', 'CUBIC_3D', 'HONEYCOMB_2D'] | str | StandardBZPath | None = None, *, points_per_seg: int = 40, return_result: bool = True) KPathResult | Tuple[ndarray, ndarray, List[Tuple[int, str]], ndarray][source]

Extract k-path data from a k-grid using fractional coordinate matching.

This method finds the closest k-points on the actual grid to an ideal path through high-symmetry points. It handles periodic boundary conditions in k-space and automatically reuses reciprocal-lattice copies of the sampled grid when the requested path lies in an extended Brillouin-zone region. It can return either a structured KPathResult dataclass or a plain tuple.

Parameters:
  • k_vectors (np.ndarray, shape (..., 3)) – Cartesian k-points (will be flattened)

  • k_vectors_frac (np.ndarray, shape (..., 3)) – Fractional coordinates of k-points (will be flattened)

  • values (np.ndarray) – Data values sampled on the k-grid. The k-grid axes may appear as (Lx, Ly, Lz, ...) or after leading batch axes such as time or frequency, e.g. (Nw, Lx, Ly, Lz) or (Nw, Lx, Ly, Lz, ...). A single flattened k-grid axis of length Nk is also supported.

  • path (various, optional) – Path specification. Accepted forms: a StandardBZPath enum value (e.g., StandardBZPath.SQUARE_2D), its string name (e.g., ‘SQUARE_2D’), a list of (label, [f1, f2, f3]) tuples, a HighSymmetryPoints object (uses its default path), or None to use the lattice’s default path if available.

  • points_per_seg (int) – Number of interpolated points per path segment

  • return_result (bool) – If True (default), return KPathResult dataclass. If False, return tuple for backwards compatibility.

Returns:

If return_result=True: KPathResult dataclass with all path data. The returned values preserve any leading batch axes and replace the k-grid axes with a path axis. If return_result=False: (k_cart, k_frac, k_dist, labels, values) tuple

Return type:

KPathResult or tuple

Examples

>>> # Using the default path from HighSymmetryPoints
>>> result = lattice.bz_path_data(k_grid, k_frac, energies, HighSymmetryPoints.square_2d())
>>> plt.plot(result.k_dist, result.values)
>>> # Using a standard path name
>>> result = lattice.bz_path_data(k_grid, k_frac, energies, 'SQUARE_2D')
>>> # Custom path
>>> custom_path = [('G', [0, 0, 0]), ('X', [0.5, 0, 0]), ('G', [0, 0, 0])]
>>> result      = lattice.bz_path_data(k_grid, k_frac, energies, custom_path)

property flux: BoundaryFlux
set_flux(value: float | Mapping[str | LatticeDirection, float] | None, *, reinit: bool = True) None[source]

Set boundary flux and optionally recalculate k-vectors, DFT, and neighbors.

Parameters:
  • value (float, Mapping, or None) – New flux specification (see _normalize_flux_dict()).

  • reinit (bool) – If True (default), recalculate reciprocal vectors, k-vectors, DFT matrix, and neighbor lists to be consistent with the new flux.

property has_flux: bool

True when a non-trivial boundary flux is attached.

property is_twisted: bool

True when the boundary conditions are TWISTED.

property is_topological: bool

True when the lattice carries a non-trivial boundary flux.

A non-trivial flux (mod \(2\pi\)) introduces a measurable Aharonov-Bohm phase and may change the topological sector of the ground state.

flux_summary() str[source]

Return a human-readable summary of the boundary-flux configuration.

boundary_phase(direction: LatticeDirection, winding: int = 1) complex[source]

Return the complex phase accumulated after crossing the boundary along direction.

Parameters:
  • direction (LatticeDirection) – The lattice direction (X, Y, or Z).

  • winding (int) – The winding number (number of times the boundary is crossed).

Returns:

The complex phase factor e^{i * flux * winding}.

Return type:

complex

boundary_phases() ndarray[source]

Return a lookup table of complex boundary phases.

Returns:

table – table[d, w] is exp(i * w * phi_d) for direction d and winding number w.

Return type:

np.ndarray, shape (3, Ns+1)

boundary_phase_from_winding(wx: int, wy: int, wz: int) complex[source]

Return total complex boundary phase accumulated from winding numbers. If no winding (all zero), returns real 1.0.
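
Both lookups reduce to exponentials of the per-direction flux phi_d. boundary_phase_table and phase_from_winding are hypothetical helpers written from the formulas stated in boundary_phases and boundary_phase_from_winding; choosing max_winding = Ns reproduces the documented (3, Ns+1) shape.

```python
import numpy as np


def boundary_phase_table(phis, max_winding):
    """table[d, w] = exp(i * w * phi_d) for directions d = 0, 1, 2 and w = 0..max_winding."""
    phis = np.asarray(phis, dtype=float).reshape(3, 1)
    w = np.arange(max_winding + 1).reshape(1, -1)
    return np.exp(1j * w * phis)                       # shape (3, max_winding + 1)


def phase_from_winding(phis, wx, wy, wz):
    """Total phase exp(i * (wx*phi_x + wy*phi_y + wz*phi_z)); real 1.0 for zero winding."""
    if wx == 0 and wy == 0 and wz == 0:
        return 1.0
    return complex(np.exp(1j * (wx * phis[0] + wy * phis[1] + wz * phis[2])))


table = boundary_phase_table([np.pi, 0.0, 0.0], max_winding=4)   # Ns = 4
print(table.shape)  # (3, 5)
```

The table only covers non-negative windings; a negative winding -w corresponds to the complex conjugate of table[d, w], which phase_from_winding handles directly.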

bond_winding(i: int, j: int) tuple[int, int, int][source]

Compute how many times a bond (i -> j) crosses the periodic boundary in each lattice direction.

Returns (wx, wy, wz), where each entry is 0 if no crossing, +1 if wrapped positively, -1 if wrapped negatively.

Parameters:
  • i (int) – Index of the starting lattice site.

  • j (int) – Index of the ending lattice site.

Returns:

A tuple of winding numbers (wx, wy, wz) for the bond from site i to site j.

Return type:

tuple[int, int, int]
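The wrapping rule above can be illustrated with a minimum-image sketch. It works on integer cell coordinates rather than site indices and assumes every bond is shorter than half the system size along each periodic axis (for L = 2 the wrapping direction is ambiguous and the sketch reports 0). The helper bond_winding_sketch is hypothetical, not the library's implementation.

```python
import numpy as np


def bond_winding_sketch(ri, rj, extent, periodic=(True, True, True)):
    """Hypothetical helper: winding numbers (wx, wy, wz) of the bond i -> j.

    ri, rj   -- integer cell coordinates (x, y, z) of sites i and j
    extent   -- number of unit cells (Lx, Ly, Lz)
    periodic -- per-axis periodicity flags
    """
    winding = []
    for a, b, L, pbc in zip(ri, rj, extent, periodic):
        if not pbc or L <= 1:
            winding.append(0)          # no boundary to cross in this direction
            continue
        raw = b - a                    # naive coordinate difference
        # The minimum-image difference is raw - L * round(raw / L); the bond wraps
        # +1 when the short path crosses the L-1 -> 0 edge in the + direction.
        winding.append(int(-np.round(raw / L)))
    return tuple(winding)


print(bond_winding_sketch((3, 0, 0), (0, 0, 0), (4, 1, 1)))  # (1, 0, 0)
print(bond_winding_sketch((0, 0, 0), (3, 0, 0), (4, 1, 1)))  # (-1, 0, 0)
```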

is_spanning(sites: Iterable[int]) bool[source]

Check if a set of sites spans the lattice (non-contractible on a torus).

This method tracks winding numbers with a breadth-first search (BFS) over the subgraph induced by the given site indices. If it finds a loop with a non-zero winding number along a periodic direction, the set is considered spanning.

bond_phase(i: int, j: int) complex[source]

Return the complex hopping phase factor for the bond \(i \to j\).

For bonds that do not cross a periodic boundary, this is 1. For boundary-crossing bonds under TWISTED BC, the phase is \(\exp(i\,\phi_\mu)\) for each direction \(\mu\) in which the bond wraps.

This is the factor that should multiply the bare hopping amplitude in real-space Hamiltonian construction.

Parameters:
  • i (int) – Source site index.

  • j (int) – Target site index.

Returns:

Phase factor (unit modulus).

Return type:

complex

hopping_matrix_with_flux(*, include_nnn: bool = False) ndarray[source]

Build an \(N_s \times N_s\) matrix of complex hopping amplitudes that includes the Peierls phases from boundary fluxes.

Diagonal is zero. Off-diagonal H[i,j] = t_{ij} * phase(i->j) where t_{ij} = 1 for all connected pairs and phase is the product of boundary phases along directions that the bond wraps.

Parameters:

include_nnn (bool) – If True, include next-nearest-neighbor hoppings as well.

Returns:

H – Complex hopping matrix.

Return type:

np.ndarray, shape (Ns, Ns)
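As a hedged illustration of the Peierls construction, the sketch below builds the 1D analogue for a ring of L >= 3 sites with t_ij = 1 and a single flux phi, where only the bond L-1 -> 0 wraps. Its spectrum should be 2 cos q with q = (2πn + φ)/L, i.e. momenta shifted by φ/L, which agrees with the fractional shift φ_μ/(2π L_μ) described for calculate_k_vectors(). The helper ring_hopping_with_flux is hypothetical.

```python
import numpy as np


def ring_hopping_with_flux(L, phi):
    """Hypothetical 1D analogue of hopping_matrix_with_flux for a ring of L >= 3 sites."""
    H = np.zeros((L, L), dtype=complex)
    for i in range(L):
        j = (i + 1) % L
        # Only the bond L-1 -> 0 crosses the boundary (winding +1).
        phase = np.exp(1j * phi) if j < i else 1.0
        H[i, j] = phase               # t_ij = 1 times the boundary phase
        H[j, i] = np.conj(phase)      # reverse bond, winding -1
    return H


L, phi = 6, 0.7
H = ring_hopping_with_flux(L, phi)
q = (2 * np.pi * np.arange(L) + phi) / L       # twisted momenta
print(np.allclose(np.linalg.eigvalsh(H), np.sort(2 * np.cos(q))))  # True
```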

get_nnn_middle_sites(i: int, j: int, orientation: str | None = None) list[int][source]

Return the list of ‘middle’ sites l that are nearest neighbors of both i and j, i.e., sites forming two-step NNN paths i-l-j.

Works for any lattice that implements get_nn(site, idx) and get_nn_num(site).

Parameters:
  • i (int) – Site indices.

  • j (int) – Site indices.

  • orientation ({'anticlockwise', 'clockwise', None}, optional) – If provided, will sort/choose based on geometric angle. Default: None (return all middle sites).

Returns:

List of middle-site indices (can be 0, 1, or 2 elements).

Return type:

list[int]
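The "middle site" definition reduces to an intersection of neighbor lists. A minimal sketch, assuming the neighbor table is a plain Python dict; the hypothetical square_nn helper builds one for a periodic square lattice with index x + y*Lx.

```python
def square_nn(Lx, Ly):
    """Nearest-neighbor table of a periodic Lx x Ly square lattice, index x + y*Lx."""
    nn = {}
    for y in range(Ly):
        for x in range(Lx):
            nn[x + y * Lx] = [(x + 1) % Lx + y * Lx, (x - 1) % Lx + y * Lx,
                              x + ((y + 1) % Ly) * Lx, x + ((y - 1) % Ly) * Lx]
    return nn


def nnn_middle_sites_sketch(nn, i, j):
    """Hypothetical helper: sites l that neighbor both i and j (paths i-l-j)."""
    return sorted((set(nn[i]) & set(nn[j])) - {i, j})


nn = square_nn(4, 4)
print(nnn_middle_sites_sketch(nn, 0, 5))  # [1, 4]
```

On the square lattice a diagonal NNN pair has two middle sites, while a nearest-neighbor pair has none.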

get_chirality_sign(i: int, j: int, normal: ndarray | None = None, orientation: str | None = None) int[source]

Compute the local orientation (chirality) sign \(\nu_{ij} = \pm 1\) for an NNN pair (i, j), defined by the cross product of the two bond vectors i-l and l-j, where l is the middle site.

Works for any 2D or quasi-2D lattice with known site coordinates.

Parameters:
  • i (int) – Site indices (next-nearest neighbors).

  • j (int) – Site indices (next-nearest neighbors).

  • normal (np.ndarray, optional) – Orientation of the lattice plane (default: +z for 2D).

Returns:

+1 for anticlockwise, -1 for clockwise, 0 if not a valid NNN pair.

Return type:

int
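The sign convention can be sketched directly from coordinates. This hypothetical helper takes the positions of i, the middle site l, and j, and returns the sign of \(((r_l - r_i) \times (r_j - r_l)) \cdot n\). Unlike the documented method, it returns 0 only for collinear bonds and does not check that (i, j) is a valid NNN pair.

```python
import numpy as np


def chirality_sign_sketch(r_i, r_l, r_j, normal=(0.0, 0.0, 1.0)):
    """Hypothetical helper: sign of ((r_l - r_i) x (r_j - r_l)) . normal."""
    d1 = np.asarray(r_l, dtype=float) - np.asarray(r_i, dtype=float)  # bond i -> l
    d2 = np.asarray(r_j, dtype=float) - np.asarray(r_l, dtype=float)  # bond l -> j
    return int(np.sign(np.dot(np.cross(d1, d2), normal)))


print(chirality_sign_sketch((0, 0, 0), (1, 0, 0), (1, 1, 0)))   # 1 (left turn, anticlockwise)
print(chirality_sign_sketch((0, 0, 0), (1, 0, 0), (1, -1, 0)))  # -1 (right turn, clockwise)
```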

bond_type(i: int, j: int) str[source]

Determine the bond type between sites i and j.

Parameters:
  • i (int) – Site indices.

  • j (int) – Site indices.

Returns:

‘nn’ for nearest neighbor, ‘nnn’ for next-nearest neighbor, ‘none’ otherwise.

Return type:

str

periodic_flags() Tuple[bool, bool, bool][source]

Return booleans indicating whether (x, y, z) directions are periodic.

TWISTED boundary conditions are topologically equivalent to PBC (the lattice is still a torus), so all three directions are periodic.

is_periodic(direction: LatticeDirection | None = None, allow_twisted: bool = True) bool[source]

Check if a given direction has periodic boundary conditions.

property typek
property spatial_norm
site_index(x: int, y: int, z: int)[source]

Convert (x, y, z) coordinates to a unique site index (row-major).

Default implementation uses standard lexicographic ordering. Override in subclasses if a different indexing convention is needed.

site_diff(i: int | tuple, j: int | tuple, *, minimum_image: bool = False, real_space: bool = False) Tuple[float, float, float][source]

Return the displacement i -> j with optional PBC minimum-image wrapping.

Parameters:
  • i (int or tuple) – Site indices or explicit coordinates.

  • j (int or tuple) – Site indices or explicit coordinates.

  • minimum_image (bool, default=False) – If True, wrap each periodic direction to the shortest displacement.

  • real_space (bool, default=False) – If True and i, j are site indices, return displacement in real-space vectors (uses displacement()). Otherwise use lattice coordinates.

site_distance(i: int | tuple, j: int | tuple, *, minimum_image: bool = False, real_space: bool = False) float[source]

Return Euclidean distance between two sites/coordinates.

Parameters:
  • minimum_image (bool, default=False) – If True, periodic directions use minimum-image convention.

  • real_space (bool, default=False) – If True and inputs are indices, measure in real-space lattice vectors.

calculate_reciprocal_vectors()[source]

Calculates the reciprocal lattice vectors based on the primitive vectors. Always returns 3D vectors (padding with zeros for lower dimensions).

Returns:

k1, k2, k3 – Reciprocal lattice vectors (always 3D).
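For reference, the sketch below assumes the standard convention \(b_i \cdot a_j = 2\pi\,\delta_{ij}\) (the source does not state whether the 2π factor is included). It pads lower-dimensional primitive vectors with zero components and uses unit vectors along y and z for missing primitive vectors. reciprocal_vectors_sketch is a hypothetical name.

```python
import numpy as np


def reciprocal_vectors_sketch(a1, a2=(0.0, 1.0, 0.0), a3=(0.0, 0.0, 1.0)):
    """Hypothetical helper: b_i with b_i . a_j = 2*pi*delta_ij, always 3D."""
    # Pad 1D/2D primitive vectors with zero components up to length 3.
    a1, a2, a3 = (np.pad(np.asarray(a, dtype=float), (0, 3 - len(a))) for a in (a1, a2, a3))
    volume = np.dot(a1, np.cross(a2, a3))         # signed unit-cell volume
    b1 = 2 * np.pi * np.cross(a2, a3) / volume
    b2 = 2 * np.pi * np.cross(a3, a1) / volume
    b3 = 2 * np.pi * np.cross(a1, a2) / volume
    return b1, b2, b3


# Triangular/hexagonal primitive vectors (example values).
b1, b2, b3 = reciprocal_vectors_sketch((1.0, 0.0), (0.5, np.sqrt(3) / 2))
```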

calculate_dft_matrix(phase=False, use_fft: bool = False) ndarray[source]

Bloch-type DFT matrix on the site basis.

Indices:

  • \(i = (R, \beta)\) – real-space cell R and sublattice \(\beta\).

  • \(n = (k, \alpha)\) – k-point k and sublattice \(\alpha\).

Elements:

\[F_{(k,\alpha),(R,\beta)} = \frac{1}{\sqrt{N_c}}\, \delta_{\alpha\beta}\, e^{-i k \cdot R}.\]

This matrix is unitary:

\[F^\dagger F = F F^\dagger = I_{N_s},\]

where \(N_s = N_c N_b\) is the total number of sites, \(N_c\) is the number of unit cells, and \(N_b\) is the number of sublattices.

Important

When boundary fluxes are present (TWISTED BC), the k-grid used to build the DFT matrix is shifted by \(\phi_\mu / (2\pi L_\mu)\) (in fractional coordinates) in each direction, exactly as in calculate_k_vectors().

Note that this DFT matrix does not include basis-dependent phases (i.e., \(e^{-i k \cdot r_{\text{basis}}}\)).

Calculates the Discrete Fourier Transform (DFT) matrix for the lattice. This method can be optimized using the Fast Fourier Transform (FFT) in the future. Reference: https://en.wikipedia.org/wiki/DFT_matrix

Parameters:

phase (bool)

Returns:

DFT matrix.

Return type:

ndarray
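A minimal 1D sketch of the element formula above, assuming a lattice constant of 1, site index R * Nb + beta, and an fftfreq-ordered k-grid with the optional flux shift. It only checks unitarity; bloch_dft_1d is a hypothetical helper, not calculate_dft_matrix() itself.

```python
import numpy as np


def bloch_dft_1d(Nc, Nb, phi=0.0):
    """Hypothetical 1D Bloch DFT matrix F[(k, alpha), (R, beta)] without basis phases."""
    R = np.arange(Nc)                                    # cell positions (a = 1)
    # fftfreq-ordered k-grid, shifted by phi / (2*pi*Nc) in fractional units.
    k = 2 * np.pi * (np.fft.fftfreq(Nc) + phi / (2 * np.pi * Nc))
    phases = np.exp(-1j * np.outer(k, R)) / np.sqrt(Nc)  # [k, R] block
    # Site index R * Nb + beta; the Kronecker product supplies delta_{alpha, beta}.
    return np.kron(phases, np.eye(Nb))


F = bloch_dft_1d(Nc=5, Nb=2, phi=0.3)
print(np.allclose(F.conj().T @ F, np.eye(10)))  # True
```

The flux shift moves every k by the same amount, so the k-differences stay multiples of 2π/Nc and unitarity is preserved.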

get_nei(site: int, **kwargs)[source]

Returns the nearest neighbors of a given site.

Parameters:

direction – Passed as a keyword argument (via **kwargs).

get_nei_forward(site: int, num: int = -1)[source]

Returns the forward nearest neighbors of a given site.

Parameters:
  • site (int) – Site index.

  • num (int, default=-1)

Returns:

List of forward nearest neighbors.

abstractmethod get_real_vec(x: int, y: int, z: int)[source]

Returns the real-space vector for the given coordinates, using the lattice constants.

abstractmethod get_norm(x: int, y: int, z: int)[source]

Returns the norm of the vector given the coordinates.

abstractmethod get_nn_direction(site: int, direction: LatticeDirection)[source]

Returns the nearest neighbors in a given direction.

get_nnn_direction(site: int, direction: LatticeDirection)[source]

Returns the next nearest neighbors in a given direction.

wrong_nei(nei)[source]

Check if a given neighbor index is invalid.

A neighbor is considered invalid if it is:
  • None

  • Equal to self.bad_lattice_site

  • NaN (not a number)

  • Less than 0

Parameters:

nei (Any) – The neighbor index to check.

Returns:

True if the neighbor index is invalid, False otherwise.

Return type:

bool

get_nn_num(site: int)[source]

Returns the number of nearest neighbors of a given site.

Parameters:

site (int) – Site index.

Returns:

Number of nearest neighbors.

get_nn(site, num: int = -1)[source]

Returns the nearest neighbors of a given site.

Parameters:
  • site (int) – Site index.

  • num (int, default=-1)

Returns:

List of nearest neighbors.

get_nnn_num(site: int)[source]

Returns the number of next nearest neighbors of a given site.

Parameters:

site (int) – Site index.

Returns:

Number of next-nearest neighbors.

get_nnn(site, num: int = -1)[source]

Returns the next nearest neighbors of a given site.

Parameters:
  • site (int) – Site index.

  • num (int, default=-1)

Returns:

List of next-nearest neighbors.

get_nn_forward_num_max()[source]

Returns the maximum number of forward nearest neighbors in the lattice.

Returns:

Maximum number of forward nearest neighbors.

get_nn_forward_num(site: int)[source]

Returns the number of forward nearest neighbors of a given site.

Parameters:

site (int) – Site index.

Returns:

Number of forward nearest neighbors.

get_nn_forward(site: int, num: int = -1)[source]

Returns the forward nearest neighbors of a given site.

Parameters:
  • site (int) – Site index.

  • num (int, default=-1)

Returns:

List of forward nearest neighbors.

get_nnn_forward_num(site: int)[source]

Returns the number of forward next nearest neighbors of a given site.

Parameters:

site (int) – Site index.

Returns:

Number of forward next-nearest neighbors.

get_nnn_forward(site: int, num: int = -1)[source]

Returns the forward next nearest neighbors of a given site.

Parameters:
  • site (int) – Site index.

  • num (int, default=-1)

Returns:

List of forward next-nearest neighbors.

neighbors(site: int, order=1)[source]

Return the neighbors of a site: order=1 gives the nearest neighbors (all with the highest weight), order=2 the next-nearest neighbors (all with the second-highest weight), and order='all' both.

neighbors_forward(site: int, order=1)[source]

Return the forward neighbors of a site: order=1 gives the nearest neighbors (all with the highest weight), order=2 the next-nearest neighbors (all with the second-highest weight), and order='all' both.

any_neighbor(site: int, order=1)[source]

Return the first neighbor of the given order, or None if there is none.

any_neighbor_forward(site: int, order=1)[source]

Return the first forward neighbor of the given order, or None if there is none.

property n_nodes: int

Number of nodes (sites) in the lattice — alias for Ns.

property n_edges: int

Number of unique undirected nearest-neighbour edges.

property positions: ndarray

Real-space position vectors (same as rvectors).

property site_offsets: ndarray

Position offsets of sites inside the unit cell (same as basis).

property basis_coords: ndarray

Integer basis coordinates [nx, ny, nz, sub] for every site.

Shape (Ns, 4) — the first three columns are the cell-index triplet and the last column is the sublattice label.

property ndim: int

Spatial dimensionality of the lattice.

property extent: Tuple[int, ...]

Number of unit cells in each direction (Lx, Ly, Lz).

property pbc: Tuple[bool, bool, bool]

Per-axis periodicity flags (alias for periodic_flags()).

edges(*, filter_color: int | None = None, return_color: bool = False) List[source]

Return list of nearest-neighbour edges.

Parameters:
  • filter_color (int, optional) – If given, return only edges whose bond_type equals this colour.

  • return_color (bool) – If True each element is (i, j, color); otherwise (i, j).

Returns:

Unique undirected edges (i, j) with i < j.

Return type:

list[tuple]

property edge_colors: List[int]

Sequence of bond-type colours for every edge in edges(), matching the order returned by edges().

displacement(i: int, j: int, *, minimum_image: bool = True) ndarray[source]

Real-space displacement vector from site i to site j.

Parameters:
  • i (int) – Site indices.

  • j (int) – Site indices.

  • minimum_image (bool) – If True (default) and the lattice is periodic, return the shortest displacement under periodic boundary conditions.

Return type:

np.ndarray shape (3,)

distance(i: int, j: int, *, minimum_image: bool = True) float[source]

Euclidean distance between sites i and j (PBC-aware by default).

get_coordinates(*args)[source]
get_r_vectors(*args)[source]
get_k_vectors(*args)[source]
get_site_diff(i: int, j: int)[source]
get_k_vec_idx(sym=False)[source]
get_dft(*args)[source]

Returns the DFT matrix

get_spatial_norm(*args)[source]

Returns the spatial norm at lattice site i, or at all sites.

get_difference_idx_matrix(cut=True) list[source]

Returns the matrix with indices corresponding to a slice from the QMC. This is useful for reading the position Green’s functions saved by https://github.com/makskliczkowski/DQMC, which are stored as follows:

  • 1D simulation: 1 column and (2Lx - 1) rows for the position differences (-(Lx - 1), …, 0, …, Lx - 1).

  • 2D simulation: (2Lx - 1) rows for the position differences (-(Lx - 1), …, 0, …, Lx - 1) and (2Ly - 1) columns for the position differences (-(Ly - 1), …, 0, …, Ly - 1).

  • 3D simulation: as in 2D, but once the (2Lx - 1) x (2Ly - 1) matrix is complete, a new slice for Lz follows in the next columns, giving Lz * (2Ly - 1) columns.

Parameters:

cut (bool) – If True, the data has (2L_i - 1) possible position differences in direction i; otherwise the negative differences are skipped and L_i values are used.

calculate_bonds()[source]

Calculates the bonds for the lattice using the forward nearest neighbors.

calculate_coordinates()[source]

Calculates the coordinates for each lattice site in up to 3D.

Each site index i corresponds to:

cell = i // n_basis, sub = i % n_basis

where n_basis = len(self._basis) (e.g., 2 for honeycomb).

Works for any lattice with defined self._a1, _a2, _a3 and self._basis list.
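The cell/sublattice decomposition can be sketched as follows. The lexicographic cell ordering nx + ny*Lx + nz*Lx*Ly is an assumption borrowed from the indexing stated for calculate_wilson_loops(), the helper site_positions_sketch is hypothetical, and the honeycomb vectors in the usage lines are example values only.

```python
import numpy as np


def site_positions_sketch(extent, a1, a2, a3, basis):
    """Hypothetical helper: r_i = nx*a1 + ny*a2 + nz*a3 + basis[sub] for every site."""
    Lx, Ly, Lz = extent
    n_basis = len(basis)
    A = np.array([a1, a2, a3], dtype=float)       # rows are primitive vectors
    offsets = np.asarray(basis, dtype=float)
    positions = np.empty((Lx * Ly * Lz * n_basis, 3))
    for i in range(len(positions)):
        cell, sub = divmod(i, n_basis)            # cell = i // n_basis, sub = i % n_basis
        # Assumed lexicographic cell order: cell = nx + ny*Lx + nz*Lx*Ly.
        nx, ny, nz = cell % Lx, (cell // Lx) % Ly, cell // (Lx * Ly)
        positions[i] = np.array([nx, ny, nz]) @ A + offsets[sub]
    return positions


# Honeycomb example values: two sites per cell.
a1, a2, a3 = (1.5, np.sqrt(3) / 2, 0.0), (1.5, -np.sqrt(3) / 2, 0.0), (0.0, 0.0, 1.0)
pos = site_positions_sketch((2, 2, 1), a1, a2, a3, basis=[(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)])
print(pos.shape)  # (8, 3)
```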

calculate_r_vectors()[source]

Calculates the real-space vectors (r) for each site. Must match the ordering in calculate_coordinates().

calculate_k_vectors()[source]

Calculates the allowed reciprocal-space k-vectors (momentum grid) consistent with the lattice size and primitive reciprocal vectors.

When boundary fluxes are present (TWISTED BC), the fractional coordinates are shifted by \(\phi_\mu / (2\pi L_\mu)\) in each direction, so that the Bloch condition matches the twisted boundary.

The sampling follows the same fftfreq ordering used by the Bloch transform (Γ at index [0,0,0], followed by positive frequencies and finally the negative branch). This keeps the analytic grids aligned with the numerically constructed H(k) blocks.
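A sketch of the fractional grid in the fftfreq ordering described above, including the flux shift \(\phi_\mu / (2\pi L_\mu)\). The helper name is hypothetical; Cartesian k-vectors would then follow as frac @ B, assuming the rows of B are the reciprocal vectors.

```python
import numpy as np


def fractional_kgrid_sketch(extent, phis=(0.0, 0.0, 0.0)):
    """Hypothetical helper: fractional k-grid, shape (Lx, Ly, Lz, 3), fftfreq order."""
    axes = [np.fft.fftfreq(L) + phi / (2 * np.pi * L) for L, phi in zip(extent, phis)]
    fx, fy, fz = np.meshgrid(*axes, indexing="ij")
    return np.stack([fx, fy, fz], axis=-1)


frac = fractional_kgrid_sketch((4, 1, 1))
print(frac[:, 0, 0, 0])  # values 0, 0.25, -0.5, -0.25 (Gamma first)
```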

filter_k_vectors(qx: int | None = None, qy: int | None = None, qz: int | None = None) ndarray[source]

Filters the k-vectors to find those matching the specified fractional components.

Parameters:
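  • qx (int, optional) – Fractional component in the x-direction. Defaults to None.

  • qy (int, optional) – Fractional component in the y-direction. Defaults to None.

  • qz (int, optional) – Fractional component in the z-direction. Defaults to None.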

Returns:

Array of indices of k-vectors matching the specified components.

Return type:

np.ndarray

translation_operators()[source]

Return translation matrices T1, T2, T3 on the one-hot basis.
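On the one-hot basis a translation is a permutation matrix. The sketch below assumes a periodic 2D lattice with one site per unit cell and lexicographic index x + y*Lx, ignores any twisted-boundary phases, and builds only T1 and T2; translation_matrices_sketch is hypothetical rather than the library routine.

```python
import numpy as np


def translation_matrices_sketch(Lx, Ly):
    """Hypothetical T1, T2 for a periodic Lx x Ly lattice, one site per cell.

    Site index i = x + y*Lx; T1 @ e_i is the one-hot vector of site (x + 1, y).
    """
    Ns = Lx * Ly
    T1 = np.zeros((Ns, Ns), dtype=int)
    T2 = np.zeros((Ns, Ns), dtype=int)
    for y in range(Ly):
        for x in range(Lx):
            i = x + y * Lx
            T1[(x + 1) % Lx + y * Lx, i] = 1      # shift by a1
            T2[x + ((y + 1) % Ly) * Lx, i] = 1    # shift by a2
    return T1, T2


T1, T2 = translation_matrices_sketch(3, 4)
print(np.array_equal(T1 @ T2, T2 @ T1))  # True: translations commute
```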

calculate_norm_sym()[source]

Calculate a symmetry-normalization measure for each site.

Default: Euclidean norm of the coordinate vector. Override in subclasses for lattice-specific behaviour.

abstractmethod calculate_nn_in(pbcx: bool, pbcy: bool, pbcz: bool)[source]
calculate_nn()[source]

Calculates the nearest neighbors.

For TWISTED boundary conditions the neighbor connectivity is identical to PBC — the flux phases are applied separately when building the Hamiltonian or the DFT matrix.

calculate_plaquettes(use_obc: bool = True)[source]
calculate_wilson_loops()[source]

Calculates the Wilson loops (non-contractible loops) for the lattice based on its boundary conditions. Returns a list of lists, where each inner list contains the site indices of a Wilson loop.

Assumes standard lexicographic site indexing (x + y*Lx + z*Lx*Ly).
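With that lexicographic indexing, one plausible choice of Wilson loops is a straight line of sites through the origin along each periodic direction. The sketch below assumes every direction with more than one cell is periodic; wilson_loops_sketch is a hypothetical illustration, and the library may choose different loops.

```python
def wilson_loops_sketch(Lx, Ly=1, Lz=1):
    """Hypothetical straight non-contractible loops through site 0,
    one per direction with more than one cell; index x + y*Lx + z*Lx*Ly."""
    loops = []
    if Lx > 1:
        loops.append(list(range(Lx)))                       # along x at y = z = 0
    if Ly > 1:
        loops.append([y * Lx for y in range(Ly)])           # along y at x = z = 0
    if Lz > 1:
        loops.append([z * Lx * Ly for z in range(Lz)])      # along z at x = y = 0
    return loops


print(wilson_loops_sketch(3, 2))  # [[0, 1, 2], [0, 3]]
```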

abstractmethod calculate_nnn_in(pbcx: bool, pbcy: bool, pbcz: bool)[source]
calculate_nnn()[source]

Calculates the next nearest neighbors.

Like calculate_nn(), each calculate_nnn_in implementation is expected to set self._nnn (and optionally self._nnn_forward) directly. The return value—if any—is stored as a fallback.

adjacency_matrix(sparse: bool = False, save: bool = True, *, mode: str = 'binary', include_self: bool = False, include_nnn: bool = False, typed_self_separate: bool = True, n_types: int = 3) ndarray[source]

Construct the adjacency matrix A, where A_ij = 1 if i and j are neighbors (see mode for the weighted variant).

Parameters:
  • save (bool) – save the adjacency matrix in the lattice object for future use.

  • mode (str) –

    ‘binary’ :

    A_ij = 1 if i and j are neighbors, 0 otherwise.

    ‘typed’ :

    A_ij = weight of the bond between i and j (1 for nn, 2 for nnn, etc.), 0 otherwise.

  • include_self (bool) – include self-connections (diagonal elements) if True.

  • include_nnn (bool) – include next-nearest neighbors if True.

  • typed_self_separate (bool) – if True, self-connections are given a unique weight (n_types) to distinguish them from other types of connections.

  • n_types (int) – number of different neighbor types (nn, nnn, etc.) to consider.

  • sparse (bool) – return a scipy.sparse CSR matrix if True.

Returns:

adjacency matrix of size (Ns, Ns).

Return type:

A (ndarray or sparse CSR)
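The 'binary' and 'typed' modes can be sketched from plain neighbor tables (dicts, an assumption). Self-connections and the n_types handling are omitted, and adjacency_sketch is hypothetical. For a sparse result, the dense array could be wrapped with scipy.sparse.csr_matrix.

```python
import numpy as np


def adjacency_sketch(nn, nnn=None, mode="binary"):
    """Hypothetical helper mirroring the 'binary' and 'typed' modes.

    nn, nnn -- dicts mapping each site to a list of its neighbors.
    'binary': A_ij = 1 for every included bond; 'typed': 1 for nn, 2 for nnn.
    """
    A = np.zeros((len(nn), len(nn)), dtype=int)
    for weight, table in ((1, nn), (2, nnn or {})):
        for i, neighbors in table.items():
            for j in neighbors:
                A[i, j] = 1 if mode == "binary" else weight
    return A


# Four-site ring: nearest and next-nearest neighbors.
nn = {0: [1, 3], 1: [0, 2], 2: [1, 3], 3: [2, 0]}
nnn = {0: [2], 1: [3], 2: [0], 3: [1]}
A_typed = adjacency_sketch(nn, nnn, mode="typed")
print(A_typed[0])  # [0 1 2 1]
```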

print_neighbors(logger: Logger)[source]

Logs the neighbors of each site in the lattice using the provided logger.

For each site in the lattice, this method retrieves its nearest neighbors and logs their indices. Additionally, for each neighbor, it logs detailed information using a higher verbosity level.

Parameters:

logger – An object with an info method for logging messages. The info method should accept parameters lvl (int) for verbosity level and color (str) for message color.

print_forward(logger: Logger)[source]

Logs the forward nearest neighbors for each site in the lattice.

For each site in the lattice, this method retrieves the number of forward nearest neighbors and logs their indices using the provided logger. The method outputs two levels of information:

  • Level 1 (green): lists the neighbors of each site.

  • Level 2 (blue): details each neighbor’s index for the site.

Parameters:

logger (Logger) – A logging object with an info method that accepts a message, a logging level (lvl), and a color (color).

get_geometric_encoding(*, tol=1e-06)[source]

Map each site i to (cell_idx, sub_idx) purely from geometry.

Returns:

  • cell_idx ((Ns,) int array in [0, Nc-1])

  • sub_idx ((Ns,) int array in [0, Nb-1])

realspace_from_kspace(H_k: ndarray, *, block_diag: bool = True, kgrid: ndarray | None = None)[source]

Inverse Bloch transform: reconstruct real-space matrix from k-space blocks.

This is the exact inverse of kspace_from_realspace(). It reconstructs the real-space Hamiltonian from momentum-space blocks using the inverse Fourier transform:

\[H_{\text{real}} = \sum_k W(k)\, H(k)\, W(k)^\dagger\]

where \(W(k)\) is the Bloch unitary matrix.

Parameters:
  • H_k (np.ndarray) –

    K-space Hamiltonian blocks in one of two formats:

    • Grid format: shape (Lx, Ly, Lz, Nb, Nb) for the full BZ grid (as returned by kspace_from_realspace with block_diag=True).

    • List format: shape (Nk, Nb, Nb) for custom k-points.

    Must be in fftfreq order (no fftshift applied) to match the forward transform.

  • block_diag (bool, default=True) –

    Mode selector matching the forward transform:

    • If True: expects H_k in block-diagonal format (grid or list of blocks) and returns the reconstructed real-space matrix.

    • If False: expects H_k as the full transformed matrix (Ns, Ns) and applies the inverse DFT directly.

  • kgrid (Optional[np.ndarray], default=None) –

    K-point grid for reference (only used when block_diag=True).

    • If None: assumes H_k is on the full BZ grid in fftfreq order.

    • If provided: must match the k-points used for the forward transform, with shape (Lx, Ly, Lz, 3) or (Nk, 3) in fftfreq order.

Returns:

H_real – Reconstructed real-space matrix with shape (Ns, Ns) where Ns = Nc * Nb is the total number of sites.

Return type:

np.ndarray

Notes

  • Round-trip accuracy: eigenvalues are preserved to machine precision (~1e-15).

  • Both H_k and kgrid must be in fftfreq order (no fftshift).

  • The reconstruction is exact for translationally invariant systems: H_real_reconstructed ≈ H_real_original to numerical precision.

  • For systems with periodic boundary conditions, the forward and inverse transforms form a perfect isometry on the Hilbert space.

Examples

Example 1: Round-trip transform (full grid)

>>> # Forward transform
>>> H_k, k_grid, k_frac = lattice.kspace_from_realspace(H_real, block_diag=True)
>>>
>>> # Inverse transform
>>> H_real_recon = lattice.realspace_from_kspace(H_k, kgrid=k_grid)
>>>
>>> # Verify reconstruction
>>> np.allclose(H_real, H_real_recon)  # True

Example 2: Inverse transform without explicit kgrid

>>> # If kgrid is omitted, it's reconstructed using fftfreq convention
>>> H_real_recon = lattice.realspace_from_kspace(H_k)
>>> np.allclose(H_real, H_real_recon)  # True

Example 3: Full matrix mode (inverse DFT)

>>> H_k_full        = lattice.kspace_from_realspace(H_real,     block_diag=False)
>>> H_real_recon    = lattice.realspace_from_kspace(H_k_full,   block_diag=False)
>>> np.allclose(H_real, H_real_recon)  # True

See also

kspace_from_realspace

Forward Bloch transform (real-space to k-space)

structure_factor

Compute momentum-resolved structure factors


kspace_from_realspace(mat: ndarray, block_diag: bool = False, kpoints: ndarray | None = None, unitary_norm: bool = True, return_transform: bool = False)[source]

Transform a real-space matrix (Hamiltonian, operator, correlator) to momentum space.

This method provides a convenient interface to the Bloch transform for periodic systems. The transform uses the formula:

\[H_{ab}(k) = \sum_{i,j} W^*_{i,a}(k) H_{i,j} W_{j,b}(k)\]

where \(W_{i,a}(k) = \frac{1}{\sqrt{N_c}} e^{-ik \cdot r_i} \delta_{\text{sub}(i),a}\)

Parameters:
  • mat (np.ndarray) – Real-space matrix with shape (Ns, Ns) where Ns = Nc * Nb is the total number of sites (unit cells x basis sites per cell).

  • block_diag (bool, default=False) –

    Mode selector for different output formats:

    • If False: returns the full transformed matrix H_k_full with shape (Ns, Ns). This is the complete DFT of the real-space matrix, useful for structure factors.

    • If True: returns the block-diagonal form with k-space blocks H_k, the momentum grid, and fractional coordinates. This is the standard mode for band-structure calculations.

    Output: (H_k, k_grid, k_grid_frac) where:

    • H_k: shape (Lx, Ly, Lz, Nb, Nb) - Hamiltonian blocks at each k-point

    • k_grid: shape (Lx, Ly, Lz, 3) - Cartesian k-point coordinates

    • k_grid_frac: shape (Lx, Ly, Lz, 3) - Fractional k-point coordinates

  • kpoints (Optional[np.ndarray], default=None) –

    Custom k-point sampling (only used when block_diag=True):

    • If None: uses the automatic full Brillouin zone grid based on the lattice size (recommended for most use cases).

    • If provided: array of shape (Nk, 3) with custom k-points in Cartesian coordinates. Returns (H_k, kpoints) with H_k of shape (Nk, Nb, Nb).

  • unitary_norm (bool, default=True) – Use unitary normalization \(1/\sqrt{N_c}\) for the Bloch transform. If False, uses normalization \(1/N_c\) instead. Keep True for standard quantum mechanics convention preserving operator norms.

  • return_transform (bool, default=False) –

    If True, also return the Bloch unitary matrix W used for the transformation. This is useful for transforming additional operators or computing correlation functions.

    Note: Only available when block_diag=True. The unitary is returned as a 4th output value with shape (Lx, Ly, Lz, Ns, Nb) or (Nk, Ns, Nb) if custom k-points are provided.

Returns:

  • Case 1: block_diag=False (default) –

    H_k_full (np.ndarray) – Full transformed matrix with shape (Ns, Ns). This is the complete DFT of the input matrix, preserving all information.

  • Case 2: block_diag=True, kpoints=None (full grid) –

    H_k (np.ndarray) – K-space Hamiltonian blocks with shape (Lx, Ly, Lz, Nb, Nb), where:

    • Lx, Ly, Lz are the lattice dimensions

    • Nb is the number of basis sites per unit cell

    • H_k[ix, iy, iz] is the Nb x Nb block at k-point [ix, iy, iz]

    k_grid (np.ndarray) – Cartesian k-point coordinates with shape (Lx, Ly, Lz, 3). The Γ-point is at index [Lx//2, Ly//2, Lz//2] after fftshift.

    k_grid_frac (np.ndarray) – Fractional k-point coordinates with shape (Lx, Ly, Lz, 3). Values are in the range [0, 1) corresponding to the first Brillouin zone.

    W (np.ndarray, optional) – Bloch unitary matrix with shape (Lx, Ly, Lz, Ns, Nb). Only returned if return_transform=True. Use it for transforming operators: O_k = W† @ O_real @ W.

  • Case 3: block_diag=True, kpoints provided (custom sampling) –

    H_k (np.ndarray) – K-space Hamiltonian blocks with shape (Nk, Nb, Nb), where Nk is the number of custom k-points provided.

    kpoints_out (np.ndarray) – Echo of the input k-points with shape (Nk, 3).

    W (np.ndarray, optional) – Bloch unitary matrix with shape (Nk, Ns, Nb). Only returned if return_transform=True.

Examples

Example 1: Full matrix transform for structure factor

>>> H_k_full = lattice.kspace_from_realspace(H_real, block_diag=False)
>>> # H_k_full has shape (Ns, Ns)

Example 2: Block-diagonal form for band structure (recommended)

>>> H_k, k_grid, k_frac = lattice.kspace_from_realspace(H_real, block_diag=True)
>>> # H_k has shape (Lx, Ly, Lz, Nb, Nb)
>>> # Diagonalize each block to get bands
>>> energies = np.linalg.eigvalsh(H_k)  # shape (Lx, Ly, Lz, Nb)

Example 3: Custom k-points (e.g., high-symmetry path)

>>> k_path = lattice.generate_kpath(['Γ', 'X', 'M', 'Γ'], npoints=100)
>>> H_k, k_pts = lattice.kspace_from_realspace(
...     H_real, block_diag=True, kpoints=k_path
... )
>>> # H_k has shape (100, Nb, Nb)
>>> energies = np.linalg.eigvalsh(H_k)  # shape (100, Nb)

Example 4: Get Bloch unitary for operator transforms

>>> H_k, k_grid, k_frac, W = lattice.kspace_from_realspace(
...     H_real, block_diag=True, return_transform=True
... )
>>> # Transform another operator using the same W
>>> # '...' spans the (Lx, Ly, Lz) grid axes of W
>>> O_k = np.einsum('...ia,ij,...jb->...ab', W.conj(), O_real, W)  # shape (Lx, Ly, Lz, Nb, Nb)

Notes

  • Periodic boundary conditions (PBC) are assumed for the Bloch transform.

  • The method assumes translational invariance of the system, which ensures that the spectrum of H_real equals the union of the spectra of the H(k) blocks.

  • For the full grid (kpoints=None), the k-points follow the fftfreq convention with the Γ-point initially at index [0, 0, 0], then shifted to the center.

  • Site ordering is arbitrary; the method uses the lattice geometry (coordinates + basis) to correctly identify sublattices and apply phases.

  • For sparse input matrices, automatic conversion to dense format is performed.
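To make the W(k) formula concrete, the sketch below builds W by hand for a hypothetical two-site (SSH-type) chain with lattice constant 1 and an fftfreq k-grid, applies the forward transform \(W^\dagger H W\), and reconstructs H as \(\sum_k W(k) H(k) W(k)^\dagger\). It does not call the library; it only checks exact reconstruction and that the spectrum is the union of the block spectra.

```python
import numpy as np

# Hypothetical SSH-type chain: Nc cells, Nb = 2 sites per cell, lattice constant 1.
Nc, Nb, t1, t2 = 6, 2, 1.0, 0.4
Ns = Nc * Nb
r = np.array([R + 0.5 * s for R in range(Nc) for s in range(Nb)])  # site i = R*Nb + s
sub = np.arange(Ns) % Nb

H = np.zeros((Ns, Ns))
for R in range(Nc):
    a, b, a_next = R * Nb, R * Nb + 1, ((R + 1) % Nc) * Nb
    H[a, b] = H[b, a] = t1             # intra-cell bond
    H[b, a_next] = H[a_next, b] = t2   # inter-cell bond (periodic)

# W[k, i, a] = exp(-i k r_i) delta_{sub(i), a} / sqrt(Nc), k-grid in fftfreq order.
ks = 2 * np.pi * np.fft.fftfreq(Nc)
W = (np.exp(-1j * np.outer(ks, r))[:, :, None]
     * (sub[None, :, None] == np.arange(Nb))) / np.sqrt(Nc)

Hk = np.einsum("kia,ij,kjb->kab", W.conj(), H, W)       # forward: W^dagger H W
H_back = np.einsum("kia,kab,kjb->ij", W, Hk, W.conj())  # inverse: sum_k W H(k) W^dagger

print(np.allclose(H_back, H))  # True
print(np.allclose(np.sort(np.linalg.eigvalsh(H)),
                  np.sort(np.linalg.eigvalsh(Hk).ravel())))  # True
```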

See also

realspace_from_kspace

Inverse transform from k-space to real-space

structure_factor

Compute momentum-resolved structure factors with reduction options

generate_kpath

Generate high-symmetry k-point paths for band structure plotting


structure_factor(mat: ndarray, *, reduction: Literal['none', 'sum', 'trace', 'mean', 'diag'] = 'sum', norm: Literal['none', 'cell', 'site'] = 'none')[source]

Convert a real-space correlation matrix into a momentum-resolved structure factor.

This is a convenience wrapper around the basis-aware Bloch projector in QES.general_python.lattices.tools.lattice_kspace.kspace_from_realspace. The real-space input mat is first transformed into the multipartite k-space block representation evaluated on self.kvectors:

\[C_{\alpha\beta}(q) = \frac{1}{N_c} \sum_{R,R'} e^{-i q\cdot(R-R')} \langle O_{R,\alpha} O_{R',\beta} \rangle,\]

where R, R' label unit cells and alpha, beta label basis sites inside the unit cell. The reduction argument then decides how this multipartite object is converted into a scalar structure factor at each sampled momentum q.

Parameters:
  • mat (np.ndarray) – Real-space correlation or operator matrix with shape (Ns, Ns) or batched shape (..., Ns, Ns). Any leading axes, e.g. time, frequency, disorder sample, or state index, are preserved.

  • reduction ({"none", "sum", "trace", "mean", "diag"}, default="sum") –

    How to reduce the multipartite k-space blocks:

    • "none": return the full k-space blocks C(q) with shape (Lx, Ly, Lz, Nb, Nb) (i.e., no reduction).

    • "sum": return \(\sum_{\alpha\beta} C_{\alpha\beta}(q)\) (i.e., the sum over all entries of each block).

    • "trace": return \(\sum_\alpha C_{\alpha\alpha}(q)\) (i.e., the sum over the diagonal entries of each block).

    • "mean": return the arithmetic mean of all multipartite block entries at each q (i.e., the sum over all entries divided by Nb^2).

    • "diag": return the eigenvalues of each block, which can be useful for identifying dominant modes or instabilities. The output shape is (Lx, Ly, Lz, Nb), since the eigenvalues of each block are returned as a vector of length Nb.

  • norm ({"none", "cell", "site"}, default="none") –

    Optional post-normalization of the returned k-space quantity:

    • "none": keep the raw Bloch-projector normalization, i.e. the blocks C(q) defined above with the prefactor 1 / N_c.

    • "cell": alias for "none", kept for readability when you want to emphasize unit-cell normalization.

    • "site": divide the returned blocks or reduced values by the number of basis sites N_b. For scalar reductions such as "sum", this converts the default unit-cell normalization into the more common site normalization 1 / N_s used in \(S(q) = \langle O_{-q} O_q \rangle\).

Returns:

  • values (np.ndarray) – Momentum-resolved structure factor. For input shape (..., Ns, Ns) the output shape is:

    • (..., Lx, Ly, Lz, Nb, Nb) for reduction="none"

    • (..., Lx, Ly, Lz, Nb) for reduction="diag"

    • (..., Lx, Ly, Lz) for "sum", "trace", or "mean"

    For a single input matrix (Ns, Ns), the leading ... is absent.

  • k_grid (np.ndarray) – Cartesian sampled k-grid with shape (Lx, Ly, Lz, 3).

  • k_frac (np.ndarray) – Fractional sampled k-grid with shape (Lx, Ly, Lz, 3).

Notes

Use reduction="none" when sublattice-resolved information matters. Use one of the scalar reductions when you want a single value per momentum that can be fed directly into bz_path_data.

The default norm="none" preserves the existing unit-cell normalization. For comparisons against structure factors built from Fourier-transformed site operators, norm="site" is typically the physically relevant choice.
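A minimal sketch of the "sum" reduction with the default unit-cell normalization, assuming a 1D chain with one site per cell (so Nb = 1 and "site" normalization changes nothing) and an fftfreq-ordered q-grid. For a Néel-type correlation \(C_{RR'} = (-1)^{R-R'}\), all weight should sit at q = ±π, which fftfreq lists as -π. structure_factor_1d is hypothetical.

```python
import numpy as np


def structure_factor_1d(C):
    """Hypothetical 1D, Nb = 1 analogue of structure_factor(reduction='sum').

    C has shape (..., Nc, Nc); returns S with shape (..., Nc) and the q-grid.
    """
    Nc = C.shape[-1]
    R = np.arange(Nc)
    q = 2 * np.pi * np.fft.fftfreq(Nc)            # fftfreq order, Gamma first
    phase = np.exp(-1j * np.outer(q, R))          # e^{-i q R}, shape (Nc, Nc)
    # S(q) = (1/Nc) * sum_{R,R'} e^{-i q (R - R')} C_{R R'}
    S = np.einsum("qr,...rs,qs->...q", phase, C, phase.conj()) / Nc
    return S, q


s = (-1.0) ** np.arange(4)                        # Neel pattern +1, -1, +1, -1
S, q = structure_factor_1d(np.outer(s, s))       # weight 4 at q = -pi (index 2), 0 elsewhere
```

Leading batch axes (time, frequency, samples) pass through the ellipsis unchanged, mirroring the documented (..., Ns, Ns) input.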

Examples

>>> Sq, k_grid, k_frac  = lattice.structure_factor(corr_zz, reduction="sum")
>>> path                = lattice.bz_path_data(k_grid, k_frac, Sq, path=['Gamma', 'K', 'M', 'Gamma'])
>>>
>>> # Frequency-resolved data with shape (Nw, Ns, Ns)
>>> Sqw, k_grid, k_frac = lattice.structure_factor(corr_zz_w, reduction="sum")
>>> # Sqw has shape (Nw, Lx, Ly, Lz)
summary_string(*, precision: int = 3) str[source]

Return a textual summary of lattice metadata.

real_space_table(*, max_rows: int = 10, precision: int = 3) str[source]

Return a formatted table of real-space vectors.

reciprocal_space_table(*, max_rows: int = 10, precision: int = 3) str[source]

Return a formatted table of reciprocal-space vectors.

brillouin_zone_overview(*, precision: int = 3) str[source]

Return a textual overview of the sampled Brillouin zone.

describe(*, precision: int = 3, max_rows: int = 10, include_vectors: bool = True, include_reciprocal: bool = True, include_brillouin_zone: bool = True) str[source]

Combine multiple presentation helpers into a single multi-section string.

plot_real_space(**kwargs)[source]

Convenience wrapper returning the matplotlib figure and axes for a real-space scatter plot.

plot_reciprocal_space(**kwargs)[source]

Scatter-plot of reciprocal lattice vectors (k-points).

Parameters mirror plot_real_space():
  • lattice (Lattice) – The lattice object to plot.

  • ax (Axes, optional) – Matplotlib axes to plot on. If None, a new figure is created.

  • show_indices (bool, default=False) – If True, annotate each k-point with its index.

  • show_axes (bool, default=True) – If False, hides the coordinate axes.

  • color (str, default="C1") – Color of the k-point markers.

  • marker (str, default="o") – Marker style.

  • figsize (tuple, optional) – Figure size in inches (width, height).

  • title (str, optional) – Title of the plot.

  • elev, azim (float, optional) – Elevation and azimuth angles for 3D plots.

  • extend_kpoints (bool, default=False) – If True, draw translated reciprocal-space copies around the original mesh.

  • extend_copies (int or iterable of int, default=2) – Number of copies per reciprocal direction used when extend_kpoints=True. Scalars are applied to all active reciprocal directions.

  • extend_tol (float, default=1e-10) – Tolerance used to identify which extended points are already present in the original reciprocal mesh.

  • **scatter_kwargs – Supported keys include:

    • point_edgecolor: color of the marker edges (default "white").

    • point_zorder: z-order for the scatter points (default 5).

    • color_extended: color for translated copies (default "C2").

    • edgecolor_extended: edge color for translated copies (default "gray").

    • marker_extended: marker for translated copies (default: marker).

    • Any other valid arguments for ax.scatter.

plot_brillouin_zone(**kwargs)[source]

Convenience wrapper returning the matplotlib figure and axes for a Brillouin zone plot.

Parameters:
  • lattice (Lattice) – The lattice object containing k-vectors.

  • ax (Axes, optional) – Matplotlib axes to plot on. If None, a new figure is created.

  • facecolor (str, default="tab:blue") – Color used to fill the Brillouin zone area.

  • edgecolor (str, default="black") – Color of the Brillouin zone boundary.

  • alpha (float, default=0.25) – Transparency level of the Brillouin zone fill.

  • figsize (tuple, optional) – Figure size in inches (width, height).

  • title (str, optional) – Title of the plot.

  • elev (float, optional) – Elevation angle for 3D plots.

  • azim (float, optional) – Azimuth angle for 3D plots.

plot_structure(**kwargs)[source]

Convenience wrapper returning the matplotlib figure and axes for a detailed lattice structure plot.

Parameters:
  • show_indices (bool) – If True, annotates nodes with their site indices.

  • highlight_boundary (bool) – If True, draws boundary nodes with a distinct color/edge.

  • show_axes (bool) – If False, hides the coordinate axes for a cleaner diagram.

  • partition_colors (tuple of str, optional) – Colors to use for bipartite/sublattice coloring. If provided, nodes are colored based on sublattice parity.

  • show_periodic_connections (bool) – If True, indicates wrap-around connections textually or graphically.

  • show_primitive_cell (bool) – If True, overlays the primitive unit cell vectors/box.

  • **kwargs – Other keyword arguments passed to the underlying plotting function (e.g. node size, color map); see plot_lattice_structure() for details.

plot_high_symmetry(**kwargs)[source]

Convenience wrapper for plotting the Brillouin zone, high-symmetry path, and sampled reciprocal mesh.

Parameters:
  • path (list[str], str, or iterable[(label, frac)], optional) – High-symmetry path specification. If omitted, the lattice default path is used.

  • show_kpoints (bool, default=True) – Draw sampled reciprocal-space mesh points.

  • show_bz (bool, default=True) – Draw the first Brillouin zone.

  • show_path (bool, default=True) – Draw the ideal high-symmetry path.

  • show_matched_kpoints (bool, default=True) – Highlight sampled k-points whose distance to the path is within the matching tolerance.

  • points_per_seg (int, default=40) – Number of interpolation points per path segment for the ideal path.

  • path_match_tol (float, optional) – Distance tolerance used when highlighting mesh points near the drawn path.

  • extend (bool, default=False) – Draw translated copies of the sampled k-mesh.

  • extend_copies (int or iterable[int], optional) – Number of reciprocal-cell copies per direction. In 2D, extend_copies=1 includes the first shell around the first Brillouin zone and extend_copies=2 includes the second shell as well.

  • show_background_bz (bool, default=False) – Draw translated Brillouin-zone copies behind the mesh.

  • hs_plot ({"none", "markers", "labels", "both"}, default="markers") – Whether to draw exact high-symmetry markers, labels, or both.

  • legend_kwargs (dict, optional) – Extra keyword arguments passed to axis.legend.

  • **kwargs – Additional style overrides forwarded to plot_high_symmetry_points.

property plot

Access plotting utilities for this lattice.

Returns a LatticePlotter instance providing the methods:
  • real_space(**kwargs) – Scatter plot of sites.

  • reciprocal_space(**kwargs) – Scatter plot of reciprocal lattice vectors.

  • brillouin_zone(**kwargs) – Visualization of the Brillouin zone.

  • structure(**kwargs) – Detailed connectivity plot with boundaries.

Example

>>> lat.plot.structure(show_indices=True, highlight_boundary=True)
>>> lat.plot.brillouin_zone()
general_python.lattices.lattice.save_bonds(lattice: Lattice, directory: str, filename: str)[source]

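Save the bonds of the lattice to a file.

Parameters:
  • lattice (Lattice) – Lattice whose bonds are saved.

  • directory (str) – Directory in which the file is written.

  • filename (str) – Name of the output file.

Returns: True if the file has been saved, False otherwise.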

Square Lattice Class…

Author: Maksymilian Kliczkowski

Email: maksymilian.kliczkowski@pwr.edu.pl

Date: 2025-02-01

class general_python.lattices.square.SquareLattice(lx=1, ly=1, lz=1, dim=None, bc=pbc, **kwargs)[source]

Bases: Lattice

Square Lattice Class for 1D, 2D, and 3D lattices.

The lattice vectors are defined as:
  • a = [1, 0, 0]

  • b = [0, 1, 0]

  • c = [0, 0, 1]

and the reciprocal lattice vectors are:
  • a* = [2*pi, 0, 0]

  • b* = [0, 2*pi, 0]

  • c* = [0, 0, 2*pi]

Input/output contracts

  • Constructor expects integer dimensions lx, ly, lz (as applicable to dim).

  • bc must be a LatticeBC enum or compatible string/int identifier.

  • Coordinates are returned as floating-point arrays of shape (Ns, dim).

  • Neighbor lists are lists of lists, where neighbors[i] contains indices of neighbors of site i.

Shape and dtype expectations

  • coordinates: Real-valued array of shape (Ns, dim).

  • kvectors: Real-valued array of shape (Ns, 3), or (Ns, dim).

  • Neighbor indices are integers in range [0, Ns).

High-symmetry points in the Brillouin zone:
  • 1D: Γ (0) -> X (π) -> Γ (2π)

  • 2D: Γ (0, 0) -> X (π, 0) -> M (π, π) -> Γ (0, 0)

  • 3D: Γ -> X -> M -> Γ -> R -> X

__init__(lx=1, ly=1, lz=1, dim=None, bc=pbc, **kwargs)[source]

Initialize a square lattice.

high_symmetry_points() HighSymmetryPoints | None[source]

Return high-symmetry points for the square/cubic lattice.

Returns:

High-symmetry points with a default path that depends on the dimension:
  • 1D: Γ -> X -> Γ (zone boundary at π)

  • 2D: Γ -> X -> M -> Γ (standard square BZ path)

  • 3D: Γ -> X -> M -> Γ -> R -> X (standard cubic BZ path)

Return type:

HighSymmetryPoints

contains_special_point(point, *, tol: float = 1e-12) bool[source]

Check if a square/cubic special point is present in the current k-grid.

get_k_vec_idx(sym=False)[source]

Returns the indices of kvectors, considering symmetry reduction.

calculate_norm_sym()[source]

Calculates the normalization factors considering symmetric momenta.

site_index(x, y, z)[source]

Convert (x, y, z) coordinates to a site index.

Parameters:
  • x (int) – x-coordinate.

  • y (int) – y-coordinate.

  • z (int) – z-coordinate.

get_real_vec(x: int, y: int, z: int)[source]

Returns the real vector for a given (x, y, z) coordinate.

get_norm(x: int, y: int, z: int)[source]

Returns the Euclidean norm of a real-space vector.

get_nn_direction(site: int, direction: LatticeDirection)[source]

Returns the nearest neighbors in a given direction (X, Y, Z).

Parameters:
  • site (int) – Site index.

  • direction (LatticeDirection) – Direction in which to get the nearest neighbors.

get_nn_forward_num_max()[source]

Maximum number of forward nearest-neighbor bonds per square-lattice site.

get_nn_forward(site: int, num: int = -1)[source]

Returns the forward nearest neighbors of a given site.

get_nnn_forward(site, num: int = -1)[source]

Returns the forward next-nearest neighbors of a given site.

calculate_nn_in(pbcx: bool, pbcy: bool, pbcz: bool)[source]

Calculates the nearest neighbors (NN) for 1D, 2D, and 3D square lattices. Also calculates the forward nearest neighbors (NNF).

Parameters:
  • pbcx (bool) – Periodic boundary condition in the x direction.

  • pbcy (bool) – Periodic boundary condition in the y direction.

  • pbcz (bool) – Periodic boundary condition in the z direction.
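A minimal sketch of the 2D neighbour construction, assuming sites are ordered x-fastest (i = x + lx*y) and that "forward" means the +x and +y hops, so each bond is stored once in the forward list. Both conventions and the helper name square_nn_sketch are hypothetical; this is not the code of calculate_nn_in().

```python
def square_nn_sketch(lx, ly, pbcx=True, pbcy=True):
    """Return (nn, nnf) for an lx-by-ly square lattice (hypothetical helper).

    Sites are indexed as i = x + lx * y; forward bonds are the +x and +y hops.
    """
    def index(x, y):
        return x + lx * y

    nn = [[] for _ in range(lx * ly)]
    nnf = [[] for _ in range(lx * ly)]
    for y in range(ly):
        for x in range(lx):
            i = index(x, y)
            # (step, coordinate along the step, extent, periodic flag)
            for (dx, dy), coord, size, periodic in (((1, 0), x, lx, pbcx),
                                                    ((0, 1), y, ly, pbcy)):
                if coord + 1 < size:
                    j = index(x + dx, y + dy)
                elif periodic and size > 2:
                    # Wrap around; skipped for size <= 2, where the wrapped bond
                    # would duplicate an existing bond or become a self-loop.
                    j = index((x + dx) % lx, (y + dy) % ly)
                else:
                    continue  # open boundary: no bond leaves the cluster
                nnf[i].append(j)
                nn[i].append(j)
                nn[j].append(i)
    return nn, nnf


nn, nnf = square_nn_sketch(3, 3)
print(sorted(nn[0]), sum(len(f) for f in nnf))  # [1, 2, 3, 6] 18
```

On a periodic 3 x 3 cluster every site gets four neighbours and the forward lists hold 2N = 18 bonds; with open boundaries the corner sites keep only two.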

calculate_nnn_in(pbcx: bool, pbcy: bool, pbcz: bool)[source]

Calculates the next-nearest neighbors (NNN) for 1D, 2D, and 3D square lattices. Also calculates the forward next-nearest neighbors (NNNF).

Parameters:
  • pbcx (bool) – Periodic boundary condition in the x direction.

  • pbcy (bool) – Periodic boundary condition in the y direction.

  • pbcz (bool) – Periodic boundary condition in the z direction.

static dispersion(k)[source]

Simple nearest-neighbour tight-binding/spin-wave-like dispersion for the square lattice. Accepts k as a (2,) or (…, 2) array and returns a scalar energy for a single k-point, or an array of energies with the leading shape (…) for a batch of k-points.

Armchair Hexagonal (Honeycomb) Lattice implementation.

This module provides the HexagonalLattice class which implements an armchair-oriented honeycomb lattice. Unlike the zig-zag HoneycombLattice, the primitive vectors here are chosen so that the lattice grows vertically and horizontally (aligned with x and y coordinate axes) with armchair edges. There are no dangling bonds at the boundary when periodic boundary conditions are used.

Geometry

The unit cell contains 2 sites (A and B sublattices). Primitive lattice vectors (a = 1 by default):

\[\mathbf{a}_1 = (\sqrt{3}\,a,\; 0,\; 0), \qquad \mathbf{a}_2 = (\tfrac{\sqrt{3}}{2}\,a,\; \tfrac{3}{2}\,a,\; 0).\]

Basis positions inside each unit cell:

\[\mathbf{d}_A = (0,\; 0,\; 0), \qquad \mathbf{d}_B = (0,\; a,\; 0).\]

Nearest neighbours (coordination z = 3 per site):

  • A at cell (n_x, n_y): → B(n_x, n_y), B(n_x, n_y−1), B(n_x+1, n_y−1)

  • B at cell (n_x, n_y): → A(n_x, n_y), A(n_x, n_y+1), A(n_x−1, n_y+1)
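The neighbour table above can be checked numerically against the primitive vectors and basis given just before it: with a = 1 every listed bond should have unit length, and the A and B tables should be mutual inverses. The helper names below are illustrative only and are not part of the library.

```python
import math

SQRT3 = math.sqrt(3.0)
A1 = (SQRT3, 0.0)                  # a1 with a = 1
A2 = (SQRT3 / 2.0, 1.5)            # a2 with a = 1
BASIS = {"A": (0.0, 0.0), "B": (0.0, 1.0)}


def position(nx, ny, sub):
    """Cartesian position of sublattice `sub` in unit cell (nx, ny)."""
    dx, dy = BASIS[sub]
    return (nx * A1[0] + ny * A2[0] + dx, nx * A1[1] + ny * A2[1] + dy)


def nn_cells(nx, ny, sub):
    """Neighbours as (cell x, cell y, sublattice), copied from the table above."""
    if sub == "A":
        return [(nx, ny, "B"), (nx, ny - 1, "B"), (nx + 1, ny - 1, "B")]
    return [(nx, ny, "A"), (nx, ny + 1, "A"), (nx - 1, ny + 1, "A")]


def bond_lengths(nx, ny, sub):
    """Distances from site (nx, ny, sub) to its three listed neighbours."""
    origin = position(nx, ny, sub)
    return [math.dist(origin, position(*nb)) for nb in nn_cells(nx, ny, sub)]


print(bond_lengths(0, 0, "A"))  # three values equal to 1 up to rounding
```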

High-symmetry points in the Brillouin zone:
  • Γ (Gamma): Zone center (0, 0)

  • K: Corner of hexagonal BZ (2/3, 1/3)

  • K’: Inequivalent corner (1/3, 2/3)

  • M: Edge midpoint (1/2, 0)

Default path: Γ -> K -> M -> Γ

Date : 2025-02-13

class general_python.lattices.hexagonal.HexagonalLattice(*, dim=2, lx=3, ly=3, lz=1, bc='pbc', **kwargs)[source]

Bases: Lattice

Armchair-oriented hexagonal (honeycomb) lattice up to 3 dimensions.

The lattice is constructed so that armchair edges lie along the horizontal (x) axis, giving a rectangular bounding box aligned with the coordinate system. Two sites per unit cell (A / B sublattices).

Parameters:
  • dim (int) – Lattice dimensionality (1, 2, or 3).

  • lx (int) – Number of unit cells along the a1 direction.

  • ly (int) – Number of unit cells along the a2 direction.

  • lz (int) – Number of unit cells along the third lattice-vector direction.

  • bc (str or LatticeBC) – Boundary conditions ('pbc', 'obc', etc.).

  • **kwargs – Forwarded to Lattice (e.g. flux).

__init__(*, dim=2, lx=3, ly=3, lz=1, bc='pbc', **kwargs)[source]

General Lattice class. This class contains the general lattice model.

Parameters:
  • dim (int, optional) – Dimension of the lattice (1, 2, or 3). If None, inferred from lx, ly, lz.

  • lx (int, optional) – Length of the lattice in the x-direction.

  • ly (int, optional) – Length of the lattice in the y-direction.

  • lz (int, optional) – Length of the lattice in the z-direction.

  • bc (str, optional) – Boundary conditions (e.g., ‘PBC’, ‘OBC’).

  • adj_mat (np.ndarray, optional) – Adjacency matrix for the lattice.

  • flux (np.ndarray, optional) – Flux piercing the boundaries. This can be a dictionary specifying the flux in each direction, or a single value applied to all directions. Importantly, this automatically implies TWISTED boundary conditions, so the bc parameter can be left as None or set to ‘TWISTED’ for clarity.

high_symmetry_points() HighSymmetryPoints | None[source]

Return high-symmetry points for the hexagonal BZ.

contains_special_point(point, *, tol: float = 1e-12) bool[source]

Check if a hexagonal special point is present in the current k-grid.

static dispersion(k, a=1.0)[source]

Hexagonal/honeycomb (armchair) nearest-neighbour dispersion magnitude. Uses the three NN vectors defined in the hexagonal geometry.
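A minimal sketch of |f(k)|, assuming the three A->B vectors implied by the module's geometry (d_B - d_A = (0, a) and its two images, all of length a) and the sign convention f(k) = Σ_δ exp(-i k·δ). The library's exact phase convention is not stated here, but |f| does not depend on the overall sign of the exponent. In nearest-neighbour tight binding the two bands are ±t|f(k)|.

```python
import cmath
import math


def honeycomb_abs_f(kx, ky, a=1.0):
    """|f(k)| with f(k) = sum over delta of exp(-i k . delta) (hypothetical helper)."""
    s3 = math.sqrt(3.0)
    # A -> B vectors for a1 = (sqrt3 a, 0), a2 = (sqrt3 a / 2, 3a / 2), d_B = (0, a).
    deltas = [(0.0, a), (-s3 * a / 2.0, -a / 2.0), (s3 * a / 2.0, -a / 2.0)]
    f = sum(cmath.exp(-1j * (kx * dx + ky * dy)) for dx, dy in deltas)
    return abs(f)


# For these primitive vectors K = (2/3, 1/3) in reciprocal-basis units is k = (4pi / (3 sqrt3 a), 0).
k_dirac = (4.0 * math.pi / (3.0 * math.sqrt(3.0)), 0.0)
print(honeycomb_abs_f(0.0, 0.0))   # 3.0 at the zone centre
print(honeycomb_abs_f(*k_dirac))   # close to 0 at the Dirac point
```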

get_real_vec(x: int, y: int, z: int)[source]

Real-space position for stored coordinate tuple (x, y, z).

The base class calculate_coordinates already stores proper vectors via _a1, _a2, _a3, _basis. This helper is kept for backwards compatibility and any custom coordinate look-ups.

get_norm(x: int, y: int, z: int)[source]

Euclidean norm of the real-space vector.

get_nn_direction(site: int, direction: LatticeDirection)[source]

Return the nearest-neighbour in the specified bond direction.

Mapping:

  • X -> intra-cell bond (A<->B within the same cell)

  • Y -> bond along a2

  • Z -> bond along a1

bond_type(s1: int, s2: int) int[source]

Return directional bond type (X_BOND, Y_BOND, Z_BOND) or -1.

calculate_nn_in(pbcx: bool, pbcy: bool, pbcz: bool)[source]

Calculate nearest neighbours for the armchair hexagonal lattice.

Each site has exactly 3 nearest neighbours (honeycomb coordination).

Bond convention for an A site at cell (cx, cy):
  • [X_BOND] intra-cell -> B(cx, cy)

  • [Y_BOND] along -a2 -> B(cx, cy-1)

  • [Z_BOND] along a1 - a2 -> B(cx+1, cy-1)

For a B site at cell (cx, cy):
  • [X_BOND] intra-cell -> A(cx, cy)

  • [Y_BOND] along +a2 -> A(cx, cy+1)

  • [Z_BOND] along -a1 + a2 -> A(cx-1, cy+1)

calculate_nnn_in(pbcx: bool, pbcy: bool, pbcz: bool)[source]

Calculate next-nearest neighbours for the armchair hexagonal lattice.

NNN connect sites within the same sublattice. Each site has 6 NNN in the full 2D honeycomb; finite clusters may have fewer depending on boundary conditions.

NNN displacements (same sublattice) in cell coordinates:

+a1, -a1, +a2, -a2, +(a1-a2), -(a1-a2)

i.e. (±1, 0), (0, ±1), (±1, ∓1)
get_sym_pos(x, y, z)[source]

Map coordinates to a position in the symmetry norm array.

For the armchair lattice with 2 sublattices, y ranges over 0 .. 2*Ly - 1 (cell index ×2 + sublattice).

get_sym_pos_inv(x, y, z)[source]

Inverse of get_sym_pos().

symmetry_checker(x, y, z)[source]

Always returns True (placeholder for future symmetry calculations).

calculate_plaquettes()[source]

Calculate hexagonal plaquettes (6-site loops) of the armchair lattice.

Each plaquette is a list of 6 site indices forming a closed loop around one hexagonal face. Only unique plaquettes are returned.
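A minimal sketch of the plaquette construction, assuming periodic boundaries and a hypothetical site index 2*(cx + lx*cy) + s with s = 0 for A and s = 1 for B. Following the bond table of calculate_nn_in() around one face gives the loop A(cx, cy) -> B(cx, cy) -> A(cx, cy+1) -> B(cx+1, cy) -> A(cx+1, cy) -> B(cx+1, cy-1) -> A(cx, cy), which alternates X, Y and Z bonds. The helper names are illustrative, not library functions.

```python
def site(cx, cy, sub, lx, ly):
    """Hypothetical index: two sites per cell, A even and B odd, periodic wrap."""
    return 2 * ((cx % lx) + lx * (cy % ly)) + (0 if sub == "A" else 1)


def nn_bonds_sketch(lx, ly):
    """Set of nearest-neighbour bonds built from the A-site convention (X, Y, Z)."""
    bonds = set()
    for cy in range(ly):
        for cx in range(lx):
            a = site(cx, cy, "A", lx, ly)
            for bx, by in ((cx, cy), (cx, cy - 1), (cx + 1, cy - 1)):
                bonds.add(frozenset((a, site(bx, by, "B", lx, ly))))
    return bonds


def hex_plaquettes_sketch(lx, ly):
    """One hexagon per unit cell, each returned once as a closed 6-site loop."""
    plaquettes, seen = [], set()
    for cy in range(ly):
        for cx in range(lx):
            loop = [site(cx, cy, "A", lx, ly), site(cx, cy, "B", lx, ly),
                    site(cx, cy + 1, "A", lx, ly), site(cx + 1, cy, "B", lx, ly),
                    site(cx + 1, cy, "A", lx, ly), site(cx + 1, cy - 1, "B", lx, ly)]
            key = frozenset(loop)
            if key not in seen:  # keep only unique plaquettes
                seen.add(key)
                plaquettes.append(loop)
    return plaquettes


print(hex_plaquettes_sketch(3, 3)[0])  # [0, 1, 6, 3, 2, 15]
```

On a 3 x 3 periodic cluster this yields nine hexagons, one per unit cell, and every consecutive pair in each loop is one of the nearest-neighbour bonds.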

Triangular Lattice Class.

Implements a 2D triangular lattice for general_python.

Date : 2025-12-22

class general_python.lattices.triangular.TriangularLattice(*, dim=2, lx=3, ly=3, lz=1, bc='pbc', **kwargs)[source]

Bases: Lattice

Implementation of the Triangular Lattice (2D). The triangular lattice is a 2D Bravais lattice with each site having 6 nearest neighbors.

__init__(*, dim=2, lx=3, ly=3, lz=1, bc='pbc', **kwargs)[source]

Initialize a Triangular Lattice.

high_symmetry_points() HighSymmetryPoints | None[source]

Return high-symmetry points for the triangular lattice Brillouin zone.

contains_special_point(point, *, tol: float = 1e-12) bool[source]

Check if a triangular special point is present in the current k-grid.

get_real_vec(x: int, y: int, z: int)[source]

Returns the real-space vector for a given (x, y, z) coordinate.

get_norm(x: int, y: int, z: int)[source]

Return the Euclidean norm of integer coordinate offsets.

get_nn_direction(site, direction)[source]

Return the nearest neighbor associated with a lattice direction.

calculate_nn_in(pbcx: bool, pbcy: bool, pbcz: bool)[source]

Calculates the nearest neighbors (NN) for the triangular lattice.

Each site has 6 nearest neighbors in 2D corresponding to the six lattice-vector displacements:

+a1, -a1, +a2, -a2, +(a1-a2), -(a1-a2)
i.e. cell-coordinate offsets:

(+1,0), (-1,0), (0,+1), (0,-1), (+1,-1), (-1,+1)

Forward bonds are those connecting to a site with strictly higher index so that each bond is counted exactly once.

Parameters:
  • pbcx (bool) – Whether periodic boundary conditions apply along x.

  • pbcy (bool) – Whether periodic boundary conditions apply along y.

  • pbcz (bool) – Whether periodic boundary conditions apply along z.
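A minimal sketch of the neighbour and forward-bond lists described above, assuming periodic boundaries in both directions and a hypothetical x-fastest site index i = x + lx*y. Neighbours are collected in a set so that very small clusters, where two offsets can wrap onto the same site, do not produce duplicates.

```python
# Cell-coordinate offsets of the six nearest neighbours listed above.
OFFSETS = [(1, 0), (-1, 0), (0, 1), (0, -1), (1, -1), (-1, 1)]


def triangular_nn_sketch(lx, ly):
    """Return (nn, nnf) on a periodic lx-by-ly triangular lattice (hypothetical helper)."""
    def index(x, y):
        return (x % lx) + lx * (y % ly)

    nn, nnf = [], []
    for y in range(ly):
        for x in range(lx):  # i = x + lx*y increases in this loop order
            i = index(x, y)
            neigh = sorted({index(x + dx, y + dy) for dx, dy in OFFSETS} - {i})
            nn.append(neigh)
            # Forward bonds: only neighbours with a strictly higher index.
            nnf.append([j for j in neigh if j > i])
    return nn, nnf


nn, nnf = triangular_nn_sketch(4, 4)
print(len(nn[0]), sum(len(f) for f in nnf))  # 6 48
```

Because each bond is kept only from its lower-index end, the forward lists contain 6N/2 = 3N bonds.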

calculate_nnn_in(pbcx: bool, pbcy: bool, pbcz: bool)[source]

Calculates the next-nearest neighbors (NNN) for the triangular lattice.

NNN are at cell-coordinate offsets:

(+2,0), (-2,0), (0,+2), (0,-2), (+2,-2), (-2,+2),
(+1,+1), (-1,-1), (+2,-1), (-2,+1), (+1,-2), (-1,+2)
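Assuming the geometry quoted in dispersion() (a1 = (a, 0), a2 = (a/2, √3 a/2), so a1·a2 = a²/2), the length of n1·a1 + n2·a2 is a·sqrt(n1² + n1·n2 + n2²). Evaluating it for the twelve listed offsets is a quick consistency check: the last six lie at √3 a (the second-neighbour shell) and the first six at 2a (the third-neighbour shell), so the documented NNN set spans two distance shells.

```python
import math

# The twelve offsets listed above, in the same order.
NNN_OFFSETS = [(2, 0), (-2, 0), (0, 2), (0, -2), (2, -2), (-2, 2),
               (1, 1), (-1, -1), (2, -1), (-2, 1), (1, -2), (-1, 2)]


def offset_length(n1, n2, a=1.0):
    """|n1*a1 + n2*a2| for a1 = (a, 0), a2 = (a/2, sqrt(3)*a/2)."""
    return a * math.sqrt(n1 * n1 + n1 * n2 + n2 * n2)


# Group offsets by distance, rounding to absorb floating-point noise.
shells = {}
for n1, n2 in NNN_OFFSETS:
    shells.setdefault(round(offset_length(n1, n2), 9), []).append((n1, n2))

for dist in sorted(shells):
    print(dist, shells[dist])
```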
site_index(x, y, z)[source]

Convert integer cell coordinates to a linear site index.

static dispersion(k, a=1.0)[source]

Simple triangular-lattice dispersion approximation: ω(k) = 2J * [3 - cos(k·a1) - cos(k·a2) - cos(k·(a1 - a2))] where a1=(a,0), a2=(a/2, √3 a/2). Accepts k as (2,) or (…,2).
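A direct transcription of the formula above, with the coupling J exposed as a hypothetical keyword argument (dispersion() itself takes only k and a). The sketch vanishes at Γ and reaches 9J at the zone corner k = (4π/(3a), 0), where all three cosines equal -1/2.

```python
import math


def triangular_omega(kx, ky, a=1.0, J=1.0):
    """omega(k) = 2J [3 - cos(k.a1) - cos(k.a2) - cos(k.(a1 - a2))] (hypothetical helper)."""
    a1 = (a, 0.0)
    a2 = (a / 2.0, math.sqrt(3.0) * a / 2.0)
    a1_minus_a2 = (a1[0] - a2[0], a1[1] - a2[1])

    def k_dot(v):
        return kx * v[0] + ky * v[1]

    return 2.0 * J * (3.0 - math.cos(k_dot(a1)) - math.cos(k_dot(a2))
                      - math.cos(k_dot(a1_minus_a2)))


print(triangular_omega(0.0, 0.0))                 # 0.0 at the zone centre
print(triangular_omega(4.0 * math.pi / 3.0, 0.0)) # about 9.0 at the zone corner
```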

Contains the Honeycomb lattice implementation. This module defines the HoneycombLattice class, which extends the base Lattice class to represent a 2D honeycomb lattice structure. It includes methods for calculating nearest and next-nearest neighbors, as well as lattice vectors and coordinates.

Date : 2025-11-01

License : MIT

class general_python.lattices.honeycomb.HoneycombLattice(lx=3, ly=1, *, lz=1, dim=2, bc='pbc', **kwargs)[source]

Bases: Lattice

Implementation of the Honeycomb Lattice.

The honeycomb lattice is a 2D lattice with a hexagonal structure. The lattice consists of two sublattices (A and B) arranged in a hexagonal pattern. Nearest and next-nearest neighbors are computed based on a hexagonal unit cell.

High-symmetry points in the Brillouin zone:
  • Γ (Gamma): Zone center at (0, 0)

  • K: Dirac point at (2/3, 1/3), which hosts linear band crossings in graphene

  • K’: Other Dirac point at (1/3, 2/3)

  • M: Edge midpoint at (1/2, 0)

Default path: Γ -> K -> M -> Γ

Attributes

Lx, Ly, Lz

Number of lattice sites in x, y, and z directions.

bc

Boundary condition (e.g. PBC or OBC).

a, c

Lattice parameters.

vectors

Primitive lattice vectors.

kvectors

Reciprocal lattice vectors.

rvectors

Real-space vectors.

__init__(lx=3, ly=1, *, lz=1, dim=2, bc='pbc', **kwargs)[source]

Initialize a honeycomb lattice.

Parameters:
  • lx (int) – Lattice size in the x direction.

  • ly (int) – Lattice size in the y direction.

  • lz (int) – Lattice size in the z direction.

  • dim (int) – Lattice dimension (1, 2, or 3).

  • bc – Boundary condition (e.g. LatticeBC.PBC or LatticeBC.OBC).

high_symmetry_points() HighSymmetryPoints | None[source]

Return high-symmetry points for the honeycomb lattice.

Returns:

High-symmetry points for the hexagonal Brillouin zone:
  • Γ (Gamma): Zone center (0, 0)

  • K: Dirac point at (2/3, 1/3), hosting linear band crossings

  • K’: Other Dirac point at (1/3, 2/3)

  • M: Edge midpoint at (1/2, 0)

Default path: Γ -> K -> M -> Γ

Return type:

HighSymmetryPoints

contains_special_point(point, *, tol: float = 1e-12) bool[source]

Check if a honeycomb special point is present in the current k-grid.

get_real_vec(x: int, y: int, z: int = 0)[source]

Returns the real-space vector for a given (x, y, z) coordinate.

get_norm(x: int, y: int, z: int)[source]

Returns the Euclidean norm of the real-space vector.

get_nn_direction(site, direction)[source]

Returns the nearest neighbor in the specified direction.

For the honeycomb lattice, we choose a mapping:

  • LatticeDirection.X -> neighbor at index 0 of _nn[site]

  • LatticeDirection.Y -> neighbor at index 1 of _nn[site]

  • LatticeDirection.Z -> neighbor at index 2 of _nn[site]

get_nn_forward(site: int, num: int = -1)[source]

Returns the forward nearest neighbor for the given site.

(For honeycomb, this could be defined as the first neighbor in a chosen ordering.)

get_nnn_forward(site: int, num: int = -1)[source]

Returns the forward next-nearest neighbor for the given site.

calculate_nn_in(pbcx: bool, pbcy: bool, pbcz: bool)[source]

Calculates the nearest neighbors (NN) using boundary conditions.

The implementation uses a helper function to apply periodic or open boundary conditions. In 2D, even and odd site indices (sublattices A and B) are treated differently.

calculate_nnn_in(pbcx: bool, pbcy: bool, pbcz: bool)[source]

Calculates the next-nearest neighbors (NNN) of the honeycomb lattice.

NNN are second-nearest neighbors, connecting sites on the same sublattice. For sublattice A (even sites), the three NNN directions are obtained by composing two consecutive NN hops:

  • NNN_1: Y-bond then X-bond^{-1} → cell shift (0, -1) [down in y]

  • NNN_2: Z-bond then Y-bond^{-1} → cell shift (-1, 0) [left in x]

  • NNN_3: X-bond then Z-bond^{-1} → cell shift (+1, +1) [diagonal]

For sublattice B (odd sites), the shifts are inverted.

calculate_norm_sym()[source]

Uses base implementation.

get_sym_pos(x, y, z)[source]

Returns the symmetry-transformed position.

get_sym_pos_inv(x, y, z)[source]

Returns the inverse symmetry-transformed position.

bond_type(s1: int, s2: int) int[source]

Return directional bond type (X_BOND_NEI, Y_BOND_NEI, Z_BOND_NEI) or -1.

calculate_plaquettes(open_bc: bool | None = None)[source]

Calculate the hexagonal plaquettes of the honeycomb lattice.

static dispersion(k, a=1.0)[source]

Honeycomb (graphene-like) nearest-neighbour dispersion magnitude. Computes |f(k)| where f = sum_{δ} exp(-i k·δ) for the three A->B vectors used by this lattice implementation.

Mathematics Module

Mathematical helper package with statistics and random tools.

This package collects utility functions used by physics, algebra, and ML modules. Submodules are lazy-loaded to keep top-level import overhead low.

Scope

  • math_utils: numerical helper routines.

  • statistics: smoothing, aggregation, and summary statistics.

  • random: random-matrix-oriented helpers; for backend-wide RNG streams use general_python.algebra.ran_wrapper.

Input/output, dtype, and shape guidance

Most routines accept NumPy array-like inputs and return NumPy arrays or scalars. APIs typically operate on 1D or 2D arrays where axis conventions are documented per function. Use floating dtypes for interpolation and filtering paths.

Determinism and stability

Pure algebraic or statistical transforms are deterministic for fixed inputs. Randomized routines require explicit seeding to be reproducible.

general_python.maths.get_module_description(module_name: str) str[source]

Return a description for a maths submodule.

general_python.maths.list_available_modules() List[str][source]

Return the list of available maths submodules.

Math utilities for various mathematical operations, fittings, and distributions. It includes functions for finding maxima, nearest values, and fitting data to various models.

Imports:
  • scipy.optimize.curve_fit

  • scipy.interpolate.splrep

  • scipy.interpolate.BSpline

Classes:
  • Fitter

  • FitterParams

File : general_python/maths/math_utils.py

Version : 0.1.0

Author : Maksymilian Kliczkowski

License : MIT

general_python.maths.math_utils.find_maximum_idx(x)[source]

Find the index of the maximum value.
  • x : DataFrame, numpy array, or JAX array

general_python.maths.math_utils.find_nearest_val(x, val, col)[source]

Find the value in x nearest to val.
  • x : DataFrame or numpy array to search

  • val : target scalar

  • col : name of the column in which to search for the nearest value

general_python.maths.math_utils.find_nearest_idx(x, val: float, **kwargs)[source]

Find the index of the value in x nearest to val.
  • x : DataFrame or numpy array to search

  • val : target scalar

  • col : name of the column in which to search for the nearest value

Returns the index of the value nearest to val.
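For plain sequences the search reduces to an argmin of |x - val|; with NumPy arrays the equivalent expression is np.argmin(np.abs(x - val)). The sketch below is a hypothetical stand-in for the array case only and does not cover the DataFrame/col path.

```python
def find_nearest_idx_sketch(values, target):
    """Index of the element of `values` closest to `target`; the first one wins on ties."""
    if len(values) == 0:
        raise ValueError("values must be non-empty")
    return min(range(len(values)), key=lambda i: abs(values[i] - target))


print(find_nearest_idx_sketch([0.0, 0.5, 1.2, 3.0], 1.0))  # 2
```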

class general_python.maths.math_utils.FitterParams(funct, popt, pcov)[source]

Bases: object

Class that stores only the parameters of the fit function.
  • popt : parameters of the fit

  • pcov : covariance matrix of the fit

  • funct : function of the fit

__init__(funct, popt, pcov)[source]

Initialize the class.
  • funct : function of the fit

  • popt : parameters of the fit

  • pcov : covariance matrix of the fit

get_popt()[source]

Optimized fit parameters.

get_pcov()[source]

Estimated covariance matrix of the fitted parameters.

get_fun()[source]

Fitted callable.

property popt

Optimized fit parameters.

property pcov

Estimated covariance matrix of the fitted parameters.

property funct

Fitted callable.

class general_python.maths.math_utils.Fitter(x: ndarray, y: ndarray)[source]

Bases: object

Class that contains the fit functions and their general usage.
  • x : arguments

  • y : values

  • fitter : FitterParams object

__init__(x: ndarray, y: ndarray)[source]

Initialize the class.
  • x : arguments

  • y : values

apply(x: ndarray)[source]

Evaluate the current fitted function at x.

static skip(x, y, skipF=0, skipL=0)[source]

Skip part of the values before fitting.
  • x : arguments to trim

  • y : values to trim

  • skipF : number of leading elements to skip

  • skipL : number of trailing elements to skip
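A minimal sketch of the trimming rule, assuming skip() removes skipF leading and skipL trailing entries from both arrays. The explicit stop index matters: slicing with x[skipF:-skipL] returns an empty result when skipL = 0.

```python
def skip_sketch(x, y, skipF=0, skipL=0):
    """Drop skipF leading and skipL trailing entries from both x and y (hypothetical helper)."""
    stop = len(x) - skipL  # x[skipF:-0] would be empty, so slice to an explicit stop
    return x[skipF:stop], y[skipF:stop]


print(skip_sketch([1, 2, 3, 4, 5], [10, 20, 30, 40, 50], skipF=1, skipL=2))
# ([2, 3], [20, 30])
```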

static aggregate(values, *, mean_type: str = 'arithmetic', weights=None, trim_fraction: float = 0.0, eps: float = 1e-300)[source]

Aggregate samples with configurable averaging convention.

Parameters:
  • values (array-like) – Input samples.

  • mean_type (str) – One of arithmetic/avg/mean, typical/typ/geometric, median, harmonic.

  • weights (array-like, optional) – Optional weights for arithmetic averaging only.

  • trim_fraction (float, default=0.0) – Fraction trimmed from each tail for arithmetic mean (0 <= trim < 0.5).

  • eps (float, default=1e-300) – Positivity cutoff for logarithmic/harmonic aggregations.
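A minimal sketch of the averaging conventions listed above, for plain Python sequences. Two behaviours are assumptions not stated in this reference: trimming drops int(trim_fraction * n) sorted samples from each tail, and eps clips values from below before the logarithm or reciprocal is taken. Weights are omitted for brevity.

```python
import math
import statistics


def aggregate_sketch(values, mean_type="arithmetic", trim_fraction=0.0, eps=1e-300):
    """Average `values` under one of the conventions listed above (hypothetical helper)."""
    vals = [float(v) for v in values]
    if not vals:
        raise ValueError("values must be non-empty")
    kind = mean_type.lower()
    if kind in ("arithmetic", "avg", "mean"):
        if not 0.0 <= trim_fraction < 0.5:
            raise ValueError("trim_fraction must satisfy 0 <= trim < 0.5")
        k = int(trim_fraction * len(vals))  # samples dropped from each tail
        kept = sorted(vals)[k:len(vals) - k]
        return sum(kept) / len(kept)
    if kind in ("typical", "typ", "geometric"):
        # exp(mean(log v)), with values clipped to eps so the log stays finite
        return math.exp(sum(math.log(max(v, eps)) for v in vals) / len(vals))
    if kind == "median":
        return statistics.median(vals)
    if kind == "harmonic":
        return len(vals) / sum(1.0 / max(v, eps) for v in vals)
    raise ValueError(f"unknown mean_type: {mean_type!r}")


print(aggregate_sketch([1, 2, 4], "geometric"))                  # about 2.0
print(aggregate_sketch([1, 2, 3, 100], trim_fraction=0.25))      # 2.5
```

The typical (geometric) mean is much less sensitive than the arithmetic mean to rare large samples, which is why it is often preferred for broadly distributed quantities.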

static fit_loglog_linear(x, y)[source]

Fit log(y) = a + b log(x) and return (a, b, r2).
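A minimal sketch of the log-log fit as ordinary least squares on (log x, log y), assuming strictly positive inputs and reporting r² of the fit in log space. In this picture a power law y = A x^beta corresponds to A = exp(a) and beta = b. The helper name is hypothetical.

```python
import math


def fit_loglog_linear_sketch(x, y):
    """Least-squares fit of log(y) = a + b log(x); returns (a, b, r2) (hypothetical helper)."""
    u = [math.log(v) for v in x]
    w = [math.log(v) for v in y]
    n = len(u)
    mu, mw = sum(u) / n, sum(w) / n
    sxx = sum((ui - mu) ** 2 for ui in u)
    sxy = sum((ui - mu) * (wi - mw) for ui, wi in zip(u, w))
    b = sxy / sxx
    a = mw - b * mu
    ss_res = sum((wi - (a + b * ui)) ** 2 for ui, wi in zip(u, w))
    ss_tot = sum((wi - mw) ** 2 for wi in w)
    r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else 1.0
    return a, b, r2


xs = [1.0, 2.0, 4.0, 8.0]
a, b, r2 = fit_loglog_linear_sketch(xs, [3.0 * v ** 2 for v in xs])
print(math.exp(a), b, r2)  # about 3.0, 2.0, 1.0
```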

static fit_power_scaling(x, y)[source]

Fit y = A x^beta and return (A, beta, r2_log).

static fit_inverse_power_scaling(x, y)[source]

Fit y = A x^{-alpha} and return (A, alpha, r2_log).

static fit_ipr_scaling(dimensions, iprs, q: float = 2.0)[source]

Fit IPR(N) = A N^{-alpha} and infer D_q = alpha/(q-1).

Returns:

(A, alpha, D_q, r2_log).

Return type:

tuple
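A minimal sketch of the IPR scaling fit, assuming the fit is done in log-log space: log IPR = log A - alpha log N, so alpha is minus the slope, and D_q = alpha/(q - 1) follows from IPR_q ~ N^{-(q-1) D_q}. The case q = 1 needs the Shannon-entropy limit and is rejected here. The helper name is hypothetical.

```python
import math


def fit_ipr_scaling_sketch(dimensions, iprs, q=2.0):
    """Fit IPR(N) = A * N**(-alpha); return (A, alpha, D_q, r2_log) (hypothetical helper)."""
    if q == 1.0:
        raise ValueError("q = 1 needs a separate (Shannon) treatment")
    u = [math.log(n) for n in dimensions]
    v = [math.log(p) for p in iprs]
    m = len(u)
    mu, mv = sum(u) / m, sum(v) / m
    slope = (sum((ui - mu) * (vi - mv) for ui, vi in zip(u, v))
             / sum((ui - mu) ** 2 for ui in u))
    intercept = mv - slope * mu
    alpha = -slope                 # log IPR = log A - alpha * log N
    amplitude = math.exp(intercept)
    d_q = alpha / (q - 1.0)        # IPR_q ~ N^{-(q-1) D_q}
    ss_res = sum((vi - (intercept + slope * ui)) ** 2 for ui, vi in zip(u, v))
    ss_tot = sum((vi - mv) ** 2 for vi in v)
    r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else 1.0
    return amplitude, alpha, d_q, r2


dims = [16, 32, 64, 128, 256]
print(fit_ipr_scaling_sketch(dims, [2.0 * n ** -0.5 for n in dims]))  # A about 2, alpha and D_2 about 0.5
```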

fit_linear(skipF=0, skipL=0)[source]

Fits a linear function.
  • skipF : number of leading elements to skip

  • skipL : number of trailing elements to skip

static fitLinear(x, y, skipF=0, skipL=0)[source]

Fits a linear function.
  • x : arguments

  • y : values

  • skipF : number of leading elements to skip

  • skipL : number of trailing elements to skip

fit_exp(skipF=0, skipL=0)[source]

Fits a * exp(-b * x).
  • skipF : number of leading elements to skip

  • skipL : number of trailing elements to skip

static fitExp(x, y, skipF=0, skipL=0)[source]

Fits a * exp(-b * x).
  • x : arguments

  • y : values

  • skipF : number of leading elements to skip

  • skipL : number of trailing elements to skip

fit_x_plus_x2(skipF=0, skipL=0)[source]

Fits a * x + b * x ** 2.
  • skipF : number of elements to skip on the left

  • skipL : number of elements to skip on the right

static fitXPlusX2(x, y, skipF=0, skipL=0)[source]

Fits a * x + b * x ** 2.
  • x : arguments to the fit

  • y : values to fit

  • skipF : number of elements to skip on the left

  • skipL : number of elements to skip on the right

fit_power(skipF=0, skipL=0)[source]

Fits a * x ** b to the stored data.
  • skipF : number of elements to skip on the left

  • skipL : number of elements to skip on the right

static fitPower(x, y, skipF=0, skipL=0)[source]

Fits a * x ** b.
  • x : arguments to the fit

  • y : values to fit

  • skipF : number of elements to skip on the left

  • skipL : number of elements to skip on the right

fit_any(funct, skipF=0, skipL=0)[source]

Fits an arbitrary function.
  • funct : function to fit

  • skipF : number of elements to skip on the left

  • skipL : number of elements to skip on the right

static fitAny(x, y, funct, skipF=0, skipL=0, bounds=[])[source]

Fits an arbitrary function.
  • x : arguments to fit

  • y : values to fit

  • funct : function to fit

  • skipF : number of elements to skip on the left

  • skipL : number of elements to skip on the right

  • bounds : optional bounds on the fit parameters

static gen_cauchy(x, v=1.0, gamma=1.0, alpha=1.0, beta=1.0)[source]

Generalized Cauchy distribution - v is the normalization factor - alpha is the stability parameter, often referred to as the shape parameter, - beta is the scale parameter, - gamma is a scale parameter related to the width of the distribution.

static cauchy(x, x0=0.0, gamma=1.0, v=1.0)[source]

Cauchy distribution - x : arguments - x0 : x0 parameter - gamma : gamma parameter

static pareto(x, v=1.0, alpha=1.0, xm=1.0, mu=0.0)[source]

Pareto distribution - x : arguments - alpha : alpha parameter - xm : xm parameter

static poisson(x, lambd=1.0, v=1.0)[source]

Poisson distribution.
  • x : arguments

  • lambd : rate parameter λ

static chi2(x, k=1.0, v=1.0, z=1.0)[source]

Chi-squared distribution.
  • x : arguments

  • k : number of degrees of freedom

static gaussian(x, mu=0.0, sigma=1.0)[source]

Gaussian distribution - x : arguments - mu : mean - sigma : standard deviation

static laplace(x, lambd=1.0, v=1.0, mu=0.0)[source]

Laplace distribution.
  • x : arguments

  • mu : mean (location)

  • lambd : scale parameter (b)

static exponential(x, lambd, sigma)[source]

Exponential distribution - x : arguments - lambd : lambda parameter

static lorentzian(x, v=1.0, g=1.0)[source]

Lorentzian distribution - x : arguments - v : multiplication constant - g : gamma parameter

static two_lorentzian(x, v=1.0, g1=1.0, g2=1.0, v2=1.0)[source]

Two-Lorentzian distribution.
  • x : arguments

  • v, v2 : multiplication constants of the two components

  • g1, g2 : gamma parameters of the two components

static lorentzian_system_size(param)[source]

Return a Lorentzian model whose width is scaled by system size.

static fit_histogram(edges, counts, typek='gaussian', skipF=0, skipL=0, centers=[], params=[], bounds=None)[source]

Fit a parametric distribution to histogram bin edges and counts.

static get_histogram(typek='gaussian')[source]

Get the histogram function based on the type of distribution

general_python.maths.math_utils.next_power(x: float, base: int = 2)[source]

Get the next power of base that is greater than x.
  • x : number whose next power is sought

  • base : base of the power (default 2 for binary; can be 10 for decimal)

general_python.maths.math_utils.prev_power(x: float, base: int = 2)[source]

Get the previous power of base that is smaller than x.
  • x : number whose previous power is sought

  • base : base of the power (default 2 for binary; can be 10 for decimal)
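A minimal integer sketch of the two helpers, following the wording above literally: the next power is the smallest power of base strictly greater than x, and the previous power is the largest one strictly smaller than x. Whether the library treats an exact power (for example x = 8 with base 2) inclusively is not stated here, so that edge case is an assumption worth checking. Only non-negative integer exponents are considered.

```python
def next_power_sketch(x, base=2):
    """Smallest power of `base` strictly greater than x (hypothetical helper, base >= 2)."""
    p = 1
    while p <= x:
        p *= base
    return p


def prev_power_sketch(x, base=2):
    """Largest power of `base` strictly smaller than x (hypothetical helper, x > 1)."""
    if x <= 1:
        raise ValueError("x must be greater than 1")
    p = 1
    while p * base < x:
        p *= base
    return p


print(next_power_sketch(5), prev_power_sketch(5))  # 8 4
print(next_power_sketch(50, base=10))              # 100
```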

general_python.maths.math_utils.mod_euc(a: int, b: int) int[source]

Compute the modified Euclidean remainder of a divided by b.

This function ensures that the result has the same sign as b.

  • a : integer dividend

  • b : integer divisor

general_python.maths.math_utils.mod_floor(a: int, b: int) int[source]

Compute the modified floor division of a divided by b.

This function ensures that the result has the same sign as b.

  • a : integer dividend

  • b : integer divisor

general_python.maths.math_utils.mod_ceil(a: int, b: int) int[source]

Compute the modified ceiling division of a divided by b.

This function ensures that the result has the same sign as b.

  • a : integer dividend

  • b : integer divisor

general_python.maths.math_utils.mod_trunc(a: int, b: int) int[source]

Compute the modified truncation division of a divided by b.

This function ensures that the result has the same sign as a.

  • a : integer dividend

  • b : integer divisor

general_python.maths.math_utils.mod_round(a: int, b: int) int[source]

Compute the modified rounding division of a divided by b.

This function ensures that the result has the same sign as a.

  • a : integer dividend

  • b : integer divisor
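The five helpers above differ only in how the quotient of a / b is rounded. For comparison, the textbook conventions write a = b*q + r and choose q by floor, ceiling, truncation, rounding, or the Euclidean rule (0 <= r < |b|). The floored remainder takes the sign of b, the truncated one the sign of a, the ceiling remainder is zero or opposite in sign to b, and the rounded one satisfies |r| <= |b|/2. The sketch below computes all five (q, r) pairs as a reference point; it is not a statement of what each mod_* function returns.

```python
from fractions import Fraction


def divmod_conventions(a, b):
    """Return {convention: (q, r)} with a == b*q + r for integers a, b (hypothetical helper)."""
    if b == 0:
        raise ZeroDivisionError("b must be non-zero")
    q_floor = a // b                 # Python's // rounds toward -infinity
    q_ceil = -((-a) // b)            # ceiling via a negated floor
    q_trunc = abs(a) // abs(b)       # magnitude of the truncated quotient
    if (a < 0) != (b < 0):
        q_trunc = -q_trunc
    r_euc = a % abs(b)               # Euclidean remainder, always in [0, |b|)
    q_euc = (a - r_euc) // b         # exact division
    q_round = round(Fraction(a, b))  # Fraction rounds halves to even
    return {name: (q, a - b * q) for name, q in (("floor", q_floor), ("ceil", q_ceil),
                                                  ("trunc", q_trunc), ("euclid", q_euc),
                                                  ("round", q_round))}


print(divmod_conventions(-7, 3))
# {'floor': (-3, 2), 'ceil': (-2, -1), 'trunc': (-2, -1), 'euclid': (-3, 2), 'round': (-2, -1)}
```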

Statistical analysis tools and data aggregation utilities.

This module contains:
  • Binning and averaging routines (bin_avg, rebin).

  • Histogram classes with support for merging and weighted averages.

  • Statistical moments, fluctuations, and distribution fitting.

  • Helpers for spectral function analysis (fractional statistics).

Numerical Stability

Numba-optimized routines (_bin_avg_numba) are provided for performance on large datasets. Robust handling of NaN values and empty bins is included in averaging functions.

class general_python.maths.statistics.Statistics[source]

Bases: object

Class for statistical operations - mean, median, etc.

static bin_avg(data, x, centers, delta=0.05, typical=False, cutoffNum=10, func=<function <lambda>>, verbose=False)[source]

Compute the bin average of data over multiple realizations.

This method aggregates data points that fall within [center - delta, center + delta] for each specified center. It supports both arithmetic mean and typical average (geometric mean).

Parameters:
  • data (np.ndarray) – Input data array. Shape: (n_realizations, n_points).

  • x (np.ndarray) – X-coordinates corresponding to the data. Shape: (n_realizations, n_points), matching data.

  • centers (np.ndarray) – Array of bin centers to compute averages at. Shape: (n_centers,).

  • delta (float, default=0.05) – Half-width of the bin interval. Bins are [c - delta, c + delta].

  • typical (bool, default=False) – If True, computes the typical average (geometric mean), exp(mean(log(data))). The input data is then assumed to be strictly positive.

  • cutoffNum (int, default=10) – Minimum number of data points required in a bin to consider it valid without expansion. If the count is lower, the method attempts to expand the window to find nearest neighbors.

  • func (callable, optional) – Aggregation function to apply to values in a bin. Signature: func(values: array) -> scalar. Defaults to np.mean. Note: The Numba optimized path is only used if func is the default.

  • verbose (bool, default=False) – If True, prints debug information (currently unused).

Returns:

Array of averaged values corresponding to each center. Shape: (n_valid_centers,). Note: Invalid centers (where no data could be found even after expansion) are skipped in the output logic of the Numba path, but np.nan_to_num is applied at the end.

Return type:

np.ndarray
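The windowed averaging described above can be illustrated with a simplified NumPy sketch. It pools all realizations and omits the window-expansion fallback controlled by cutoffNum. Empty bins yield NaN here rather than being skipped. bin_average_sketch is a hypothetical name, not the library's Numba path.

```python
import numpy as np


def bin_average_sketch(data, x, centers, delta=0.05, typical=False):
    """Average data whose x-coordinate lies in [c - delta, c + delta] for each center c."""
    data = np.asarray(data, dtype=float).ravel()  # pool all realizations
    x = np.asarray(x, dtype=float).ravel()
    out = np.full(len(centers), np.nan)
    for k, c in enumerate(centers):
        mask = (x >= c - delta) & (x <= c + delta)
        vals = data[mask]
        if vals.size == 0:
            continue  # empty bin stays NaN in this sketch (no window expansion)
        # Typical average = geometric mean; assumes strictly positive values.
        out[k] = np.exp(np.mean(np.log(vals))) if typical else np.mean(vals)
    return out


x = np.array([[0.0, 0.1, 0.2, 0.3]])
y = np.array([[1.0, 4.0, 2.0, 8.0]])
print(bin_average_sketch(y, x, centers=[0.05, 0.25], delta=0.06))
```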

static rebin(arr, av_num: int, d: int, rng=None)[source]

Re-bin an array by averaging blocks of data.

This method reshapes the input array and computes the mean over blocks of size av_num. It randomly shuffles the array before binning to ensure unbiased sampling if the data is ordered.

Parameters:
  • arr (np.ndarray) – Input array to rebin. Shape depends on d.

  • av_num (int) – Number of elements to average into a single bin.

  • d (int) – Dimensionality of the array (1, 2, or 3). If d=1: expects shape (N,). If d=2: expects shape (N, M). If d=3: expects shape (N, M, K).

  • rng (np.random.Generator, optional) – Random number generator for shuffling. If None, a default generator is used.

Returns:

Re-binned array. The first dimension size is divided by av_num.

Return type:

np.ndarray
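The shuffle-then-average procedure can be sketched generically for any number of dimensions. This sketch assumes that shuffling happens along the first axis and that trailing rows which do not fill a complete block are dropped. rebin_sketch is a hypothetical helper, and the library's handling of leftovers is not documented here.

```python
import numpy as np


def rebin_sketch(arr, av_num, rng=None):
    """Shuffle along axis 0, then average consecutive blocks of av_num rows."""
    rng = np.random.default_rng() if rng is None else rng
    arr = np.asarray(arr, dtype=float)
    arr = arr[rng.permutation(arr.shape[0])]  # break any ordering in the data
    n_blocks = arr.shape[0] // av_num
    trimmed = arr[: n_blocks * av_num]  # drop rows that do not fill a block (assumed)
    # Group rows into (n_blocks, av_num, ...) and average within each block.
    return trimmed.reshape(n_blocks, av_num, *arr.shape[1:]).mean(axis=1)


data = np.arange(12.0).reshape(6, 2)
print(rebin_sketch(data, av_num=3, rng=np.random.default_rng(0)).shape)
```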

static permute(*args, rng=None)[source]

Apply the same random permutation to multiple arrays simultaneously.

Parameters:
  • *args (np.ndarray) – One or more arrays to permute. All arrays must have the same length along the first dimension.

  • rng (np.random.Generator, optional) – Random number generator.

Returns:

Tuple of permuted arrays.

Return type:

tuple
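The key idea is to draw one permutation and reuse it for every array, which keeps corresponding rows aligned. The following is a minimal sketch of that idea. permute_together is a hypothetical name.

```python
import numpy as np


def permute_together(*arrays, rng=None):
    """Apply one shared random permutation along axis 0 of every array."""
    rng = np.random.default_rng() if rng is None else rng
    n = len(arrays[0])
    if any(len(a) != n for a in arrays):
        raise ValueError("all arrays must have the same length along axis 0")
    perm = rng.permutation(n)  # a single permutation reused for every array
    return tuple(np.asarray(a)[perm] for a in arrays)


x, y = permute_together(np.arange(5), np.arange(5) * 10, rng=np.random.default_rng(1))
print(x, y)  # pairs (x[i], y[i]) stay aligned
```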

static calculate_fluctuations(signals, bin_size, axis=1)[source]

Calculate the fluctuations of the signals within bins of size bin_size along axis, handling NaN values and preserving the original dimensions.

Parameters:
  • signals (numpy.ndarray) – Input signals array of any shape.

  • bin_size (int) – Size of the bin for computing fluctuations.

  • axis (int) – Axis along which to calculate fluctuations.

Returns:

Fluctuations for each signal, same shape as input.

Return type:

numpy.ndarray

static get_cdf(x, y, gammaval=0.5, BINVAL=21)[source]

Calculate the cumulative distribution function (CDF) and find the value of x at which it reaches gammaval.

Parameters:
  • x (array-like) – The independent variable values.

  • y (array-like) – The dependent variable values, which may contain NaNs.

  • gammaval (float, optional) – The target CDF value whose corresponding x value is returned. Default is 0.5.

Returns:

A tuple containing:

  • x (array-like): The input independent variable values.

  • y (array-like): The input dependent variable values with NaNs removed.

  • cdf (array-like): The cumulative distribution function values.

  • gammaf (float): The value of the independent variable corresponding to the target CDF value.

static find_peak_and_interpolate(alphas, values)[source]

Find the peak value in the given data and interpolate to locate the peak more precisely.

This function removes NaN values from the input arrays, performs spline interpolation, and finds the maximum value and its corresponding alpha. A fine-grained search around the peak then refines the maximum.

Parameters:
  • alphas (array-like) – The array of alpha values.

  • values (array-like) – The array of corresponding values.

Returns:

A tuple containing the refined alpha value at the peak and the refined peak value.

general_python.maths.statistics.avgBin(myArray, N=2)[source]

Calculate the bin average of an array.

  • myArray : array to average into bins

  • N : number of bins

general_python.maths.statistics.moveAverage(a, n: int)[source]

Moving average with cumsum or sliding window. This is applied along the first axis of the array.

Parameters:
  • a – Input data, can be a numpy array, pandas DataFrame, or list.

  • n – Window size for the moving average.

Returns:

A numpy array, pandas DataFrame, or list with the moving averages.
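The cumulative-sum approach can be illustrated with a sketch for NumPy input along axis 0. It assumes the "valid" convention, where the output has length len(a) - n + 1. How moveAverage pads or aligns its output is not stated above. moving_average_cumsum is a hypothetical name.

```python
import numpy as np


def moving_average_cumsum(a, n):
    """Mean over each length-n window along axis 0 ('valid' windows only)."""
    a = np.asarray(a, dtype=float)
    c = np.cumsum(a, axis=0)
    # Prepend a zero row so that each window sum is a difference of two cumsums.
    c = np.concatenate([np.zeros((1,) + a.shape[1:]), c], axis=0)
    return (c[n:] - c[:-n]) / n


print(moving_average_cumsum([1, 2, 3, 4, 5], 2))
```

The cumulative-sum form costs O(len(a)) regardless of n, at the price of some roundoff accumulation for very long series.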

general_python.maths.statistics.fluctAboveAverage(a, n: int)[source]

Calculate fluctuations above the moving average.

Parameters:
  • a – Input data, can be a numpy array, pandas DataFrame, or list.

  • n – Window size for the moving average.

Returns:

Fluctuations above the moving average.

general_python.maths.statistics.removeMean(a, n: int, moving_average=[])[source]

Subtract the moving average (window size n) from the data, leaving only the fluctuations.

general_python.maths.statistics.gauss(x: ndarray, mu, sig, *args)[source]

Gaussian probability density function with mean mu and width sig, evaluated at x.

class general_python.maths.statistics.Histogram(n_bins: int | None = None, edges: Sequence[float] | None = None, dtype=None)[source]

Bases: object

A histogram class that stores the bin edges and the bin counts.

Convention:
  • The bins are defined by an array bin_edges of length (n_bins+1).

  • The first bin (index 0) collects all values below bin_edges[0] (underflow).

  • For 1 <= i < n_bins, bin i collects values in [bin_edges[i], bin_edges[i+1]).

  • The last bin (index n_bins) collects values greater than or equal to bin_edges[-1] (overflow).

__init__(n_bins: int | None = None, edges: Sequence[float] | None = None, dtype=None)[source]

Initialize the histogram with either a specified number of bins or specific edges.

Parameters:
  • n_bins – Number of bins (if edges is None).

  • edges – Specific bin edges (if n_bins is None).

  • dtype – Data type for the bin edges.

Raises:

ValueError – If both n_bins and edges are None, or if edges is not a one-dimensional array with at least two elements.

Notes

  • If both n_bins and edges are None, a histogram with one bin (0 to 0) is created.

  • If edges are provided, the number of bins is determined from the length of edges.

  • The bin counts are initialized to zero.

set_histogram_counts(values: ndarray | Sequence[float | complex], set_bins: bool = True) None[source]

For the specified values, set the histogram counts. If set_bins is True, the bin edges will be determined from the minimum and maximum of the data. For complex-valued inputs, only the real part is used.

set_edges(edges: ndarray | Sequence[float]) None[source]

Set the bin edges from an array or list. The number of bins is set to len(edges)-1.

Parameters:

edges – A one-dimensional array or list of bin edges.

Raises:

ValueError – If edges is not a one-dimensional array or list with at least two elements.

property edges: ndarray

Return the bin edges.

counts(i: int | None = None) uint64 | ndarray[source]

If i is provided, return the count for that bin index. Otherwise, return the full counts array.

counts_col() ndarray[source]

Return the counts as a column vector.

static iqr(data: ndarray | Sequence[float]) float[source]

Calculate the interquartile range (IQR) of the data. Splits the sorted data into two halves and computes the difference between the medians.
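The median-of-halves rule can be written in a few lines. This sketch assumes that, for odd-length data, the middle element is excluded from both halves. That is one common convention, and the method may differ. iqr_median_of_halves is a hypothetical name.

```python
import numpy as np


def iqr_median_of_halves(data):
    """IQR as median(upper half) - median(lower half) of the sorted data."""
    s = np.sort(np.asarray(data, dtype=float))
    half = len(s) // 2
    lower = s[:half]
    # For odd lengths, skip the middle element (assumed convention).
    upper = s[half + len(s) % 2:]
    return float(np.median(upper) - np.median(lower))


print(iqr_median_of_halves([1, 2, 3, 4, 5, 6, 7, 8]))
```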

static freedman_diaconis_rule(n_obs: int, iqr_val: float, max_val: float, min_val: float = 0) int[source]

Calculate the number of bins using the Freedman-Diaconis rule.
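The textbook Freedman-Diaconis rule uses the bin width h = 2 * IQR / n^(1/3) and divides the data range by h. The following sketch applies that standard formula to the same arguments as the method. The fallback to a single bin for degenerate input is an assumption, and freedman_diaconis_bins is a hypothetical name.

```python
import math


def freedman_diaconis_bins(n_obs, iqr_val, max_val, min_val=0.0):
    """Number of bins from the Freedman-Diaconis width h = 2 * IQR / n**(1/3)."""
    if n_obs <= 0 or iqr_val <= 0:
        return 1  # degenerate input: fall back to a single bin (assumed)
    width = 2.0 * iqr_val / n_obs ** (1.0 / 3.0)
    return max(1, math.ceil((max_val - min_val) / width))


print(freedman_diaconis_bins(n_obs=1000, iqr_val=1.35, max_val=4.0, min_val=-4.0))
```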

reset(nbins: int = None) None[source]

Reset the histogram counts and, optionally, the bin edges to zero.

Parameters:

nbins – If provided, reset the histogram with this number of bins.

static uniform(n_bins: int, v_max: float, v_min: float = 0) None[source]

Create uniformly spaced bins between v_min and v_max.

Parameters:
  • n_bins – Number of bins.

  • v_max – Maximum value for the histogram.

  • v_min – Minimum value for the histogram.

static uniform_log(n_bins: int, v_max: float, v_min: float = 1e-05, base: int = 10) ndarray[source]

Create logarithmically spaced bin edges between v_min and v_max.

Parameters:
  • n_bins – Number of bins.

  • v_max – Maximum value for the histogram.

  • v_min – Minimum value for the histogram.

  • base – Logarithm base (default 10).

Returns:

bin_edges – Array of bin edges (length n_bins + 1).
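Both edge layouts reduce to standard NumPy spacing functions. The following sketch assumes n_bins + 1 edges in both cases. uniform_edges and uniform_log_edges are hypothetical names that return arrays. Note that uniform is annotated as returning None above, so it may store its edges differently.

```python
import numpy as np


def uniform_edges(n_bins, v_max, v_min=0.0):
    """n_bins + 1 equally spaced edges from v_min to v_max."""
    return np.linspace(v_min, v_max, n_bins + 1)


def uniform_log_edges(n_bins, v_max, v_min=1e-5, base=10):
    """n_bins + 1 edges equally spaced in log_base between v_min and v_max."""
    lo = np.log(v_min) / np.log(base)  # exponent of the first edge
    hi = np.log(v_max) / np.log(base)  # exponent of the last edge
    return np.logspace(lo, hi, n_bins + 1, base=base)


print(uniform_log_edges(4, v_max=10.0, v_min=1e-3))
```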

append(values) int[source]

Append one or more values to the histogram by determining their bins and incrementing the corresponding counts.

Parameters:

values – The value or values to append to the histogram.

Returns:

The indices of the bins where the values were appended.

merge(other: Histogram) None[source]

Merge another histogram into this one. The histograms must have the same number of bins and matching bin edges.

class general_python.maths.statistics.HistogramAverage(n_bins: int | None = None, edges: Sequence[float] | None = None, dtype=None)[source]

Bases: Histogram

Histogram subclass that additionally tracks per-bin averages of a function f(x). The stored bin averages are the sums of the function values evaluated in each bin; they can be normalized by the bin counts (see averages_av).

__init__(n_bins: int | None = None, edges: Sequence[float] | None = None, dtype=None)[source]

Initialize the histogram with either a specified number of bins or specific edges.

Parameters:
  • n_bins – Number of bins (if edges is None).

  • edges – Specific bin edges (if n_bins is None).

  • dtype – Data type for the bin edges.

Raises:

ValueError – If both n_bins and edges are None, or if edges is not a one-dimensional array with at least two elements.

Notes

  • If both n_bins and edges are None, a histogram with one bin (0 to 0) is created.

  • If edges are provided, the number of bins is determined from the length of edges.

  • The bin counts and averages are initialized to zero.

averages(i: int | None = None) float | ndarray[source]

Return the bin averages. If i is provided, return the average for that bin. Otherwise, return the full averages array.

averages_av(is_typical: bool = False) ndarray[source]

Get the average of the function over the bins normalized by the counts. If is_typical is True, exponentiate the normalized averages (useful if the averages represent logarithms).
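The normalization and the optional exponentiation can be sketched with NumPy. This sketch assumes that empty bins are reported as NaN instead of raising a division-by-zero error. normalized_bin_averages is a hypothetical name. When the accumulated values are logarithms, exponentiating the normalized sums gives the per-bin geometric mean (typical value).

```python
import numpy as np


def normalized_bin_averages(sums, counts, is_typical=False):
    """Divide per-bin sums by counts; optionally exponentiate (typical average)."""
    sums = np.asarray(sums, dtype=float)
    counts = np.asarray(counts, dtype=float)
    # Empty bins are left at NaN instead of dividing by zero (assumed convention).
    avg = np.divide(sums, counts, out=np.full_like(sums, np.nan), where=counts > 0)
    return np.exp(avg) if is_typical else avg


# Two samples per bin: values (2, 2) in bin 0 and (8, 2) in bin 1, accumulated as logs.
log_sums = np.log([2.0, 8.0]) + np.log([2.0, 2.0])
print(normalized_bin_averages(log_sums, [2, 2], is_typical=True))
```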

reset(nbins=None) None[source]

Reset both the histogram counts and the bin averages.

append(values, elements) int[source]

Append a value to the histogram and add the corresponding element to the bin average.

Parameters:
  • values – The value or values to append to the histogram.

  • elements – The element or elements to add to the bin average.

Returns:

bin_idx – The indices of the bins where the values were appended.

merge(other: HistogramAverage) None[source]

Merge another HistogramAverage into this one. Warning: The histograms must have the same number of bins and matching bin edges.

add(sums: ndarray, counts: ndarray) None[source]

Add precomputed sums and counts to the histogram averages and counts.

Parameters:
  • sums – Array of sums to add to the bin averages.

  • counts – Array of counts to add to the bin counts.

Raises:

ValueError – If the shapes of sums or counts do not match the histogram’s bin counts.

remove(sums: ndarray, counts: ndarray) None[source]

Remove precomputed sums and counts from the histogram averages and counts.

Parameters:
  • sums – Array of sums to remove from the bin averages.

  • counts – Array of counts to remove from the bin counts.

Raises:

ValueError – If the shapes of sums or counts do not match the histogram’s bin counts.

class general_python.maths.statistics.Fraction[source]

Bases: object

Helpers for selecting fractions of a spectrum or Hilbert space and for testing energy windows.

static diag_cut(fraction: float, size: int) int[source]

Calculate the number of states to take based on a fraction of the total size.

Parameters:
  • fraction – Fraction of the total size to take.

  • size – Total size of the Hilbert space.

Returns:

The number of states to take.

static around_idx(l: int, r: int, idx: int, size: int) Tuple[int, int][source]

Get the index range around a given index in the Hilbert space, clipped to the valid boundaries.

Parameters:
  • l – Number of elements to the left of idx.

  • r – Number of elements to the right of idx.

  • idx – Center index.

  • size – Total size of the Hilbert space.

Returns:

A tuple (min_index, max_index) with the allowed index range.

static take_fraction(frac: float, data: ndarray, around=None, fraction_left=0.5, fraction_right=0.5, around_idx=None) list[source]

Take a fraction of the data.

Parameters:
  • frac (float) – The fraction of the data to take. If frac is less than 1.0, it is treated as a fraction of the total data size; if it is greater than 1.0, it is treated as the number of elements to take.

  • data (list) – The list of data from which to take the fraction.

  • around (float, optional) – The index around which to take the fraction. If None, it defaults to half the size of the data.

  • fraction_left (float, optional) – The fraction of the left side to take. Default is 0.5.

  • fraction_right (float, optional) – The fraction of the right side to take. Default is 0.5.

  • around_idx (int, optional) – The index around which to take the fraction. If None, it defaults to half the size of the data.

Returns: list: A list containing the central portion of the original data, based on the specified fraction.

If the calculated number of elements to take is less than or equal to 1, or equal to the size of the data, the original data is returned.

static is_close_target(l: float, r: float, target: float = 0.0, tol: float = 0.0015) bool[source]

Check whether the average of two energies (l and r) lies within tol of a target energy.

Parameters:
  • l – First energy.

  • r – Second energy.

  • target – Target energy (default 0.0).

  • tol – Tolerance (default 0.0015).

Returns:

True if the average is close to the target, False otherwise.

static is_difference_close_target(l: float, r: float, target: float = 0.0, tol: float = 0.0015) bool[source]

Check if the absolute energy difference between l and r is within tol of a target difference.

static is_fraction_difference_between(l: float, r: float, min_val: float, max_val: float) bool[source]

Check if the absolute energy difference between l and r lies between min_val and max_val.

static hs_fraction_offdiag(mn: int, max_val: int, hilbert_size: int, energies: ndarray, target_en: float = 0.0, tol: float = 0.0015, sort: bool = True) List[Tuple[float, int, int]][source]

Get the off-diagonal Hilbert-space fraction information.

Iterates over the energy spectrum (from index mn to max_val) and for each pair (i, j) with j > i, if the average energy is within tol of target_en then store a tuple of (energy difference, j, i) in the output list. Finally, sort the list by the energy difference (first element).

Parameters:
  • mn – Starting index (inclusive).

  • max_val – Ending index (exclusive).

  • hilbert_size – Size of the Hilbert space (not used in the computation, kept for consistency).

  • energies – 1D NumPy array of energies.

  • target_en – Target energy for the mean.

  • tol – Tolerance for closeness.

  • sort – Whether to sort the output list by the energy difference.

Returns: A list of tuples (omega, j, i) sorted by omega if sort is True.
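The pair search described above can be expressed as a plain double loop. This sketch assumes that omega = E_j - E_i, since the text only says "energy difference", and that closeness uses a strict < tol comparison. offdiag_pairs is a hypothetical name. The cost is O(m^2) in the window size m = max_val - mn.

```python
import numpy as np


def offdiag_pairs(mn, max_val, energies, target_en=0.0, tol=0.0015, sort=True):
    """Collect (omega, j, i) for i < j whose mean energy is within tol of target_en."""
    energies = np.asarray(energies, dtype=float)
    pairs = []
    for i in range(mn, max_val):
        for j in range(i + 1, max_val):
            if abs(0.5 * (energies[i] + energies[j]) - target_en) < tol:
                # omega assumed to be E_j - E_i
                pairs.append((energies[j] - energies[i], j, i))
    if sort:
        pairs.sort(key=lambda t: t[0])  # sort by energy difference only
    return pairs


E = np.array([-1.0, -0.5, 0.0, 0.5, 1.0])
print(offdiag_pairs(0, len(E), E, target_en=0.0, tol=0.1))
```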

static spectral_function_fraction(x, target, tolerance, tolerance_function=<staticmethod(<function Fraction.is_close_target>)>)[source]

Calculate the spectral function fraction based on the given target and tolerance.

static find_nearest_idx(array, value)[source]

Find the index of the nearest value in an array.

Random Matrix Theory and specialized sampling utilities.

This module complements algebra.ran_wrapper by providing specific ensembles like Circular Unitary Ensemble (CUE) matrices via QR decomposition.

Input/Output Contracts

  • CUE_QR: Returns a unitary complex matrix of shape (n, n).

Numerical Stability

The QR method for CUE is generally stable but phase adjustment is needed for true Haar measure compliance (controlled by simple=False).

general_python.maths.random.CUE_QR(n: int, simple=True, rng=None)[source]

Create a CUE matrix using QR decomposition.

  • n : size of the matrix (n x n)

  • simple : if True, use the straightforward QR method without the phase adjustment needed for exact Haar-measure sampling
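The phase adjustment mentioned in the stability note can be illustrated with NumPy. The sketch follows the well-known recipe (Mezzadri's method): draw a complex Ginibre matrix, compute its QR decomposition, and rescale each column of Q by the phase of the matching diagonal entry of R. cue_qr_sketch is a hypothetical name, and the library's CUE_QR may differ in normalization or RNG handling.

```python
import numpy as np


def cue_qr_sketch(n, simple=False, rng=None):
    """Sample an n x n unitary matrix via QR of a complex Ginibre matrix."""
    rng = np.random.default_rng() if rng is None else rng
    # Ginibre matrix: i.i.d. standard complex Gaussian entries.
    z = (rng.standard_normal((n, n)) + 1j * rng.standard_normal((n, n))) / np.sqrt(2.0)
    q, r = np.linalg.qr(z)
    if simple:
        return q  # unitary, but not exactly Haar-distributed
    # Phase fix: multiply column k of Q by diag(R)_k / |diag(R)_k|.
    d = np.diagonal(r)
    return q * (d / np.abs(d))


u = cue_qr_sketch(4, rng=np.random.default_rng(0))
print(np.allclose(u.conj().T @ u, np.eye(4)))
```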

Machine Learning Module

Machine-learning entry points for neural-network workflows.

The package collects model registries, scheduler utilities, and concrete network implementations used in supervised and variational experiments.

Purpose

Use this namespace to obtain model constructors and training-time utilities without importing every backend-specific implementation eagerly.

Input/output contracts

Public factories typically accept model identifiers, shape metadata (for example input_shape=(n_features,)), and optional dtype and seed arguments. Returned objects are model instances or callables compatible with the training helpers in general_python.ml.training_phases.

dtype and shape expectations

Input batches are conventionally rank-2 arrays with shape (batch, features) unless a model documents an image or sequence layout. For stable optimization, float32 is the practical default on accelerators, while float64 may be required for high-precision experiments.

Numerical stability and determinism

Training trajectories depend on initialization, optimizer state, and operation ordering. For reproducibility, fix random seeds, keep backend/device constant, and avoid mixing precision policies within one experiment.

general_python.ml.networks

Network Factory and Registry.

This module provides a centralized factory function, choose_network, for instantiating various neural network architectures used in the general_python framework. It uses a lazy-loading mechanism to improve startup performance.

Usage

Import and use the factory to create a network. The factory takes the network type and common parameters like input_shape and dtype. Network-specific parameters are passed as keyword arguments.

    from general_python.ml.networks import choose_network

    # Create an RBM using 'alpha' (hidden unit density)
    rbm_net = choose_network(
        'rbm',
        input_shape=(10,),
        alpha=2.0,          # creates 2 * 10 = 20 hidden units
        dtype='complex64',
    )

    # Create a CNN for an 8x8 lattice
    cnn_net = choose_network(
        'cnn',
        input_shape=(64,),
        reshape_dims=(8, 8),
        features=[8, 16],
        kernel_sizes=[3, 3],
    )

Date: 01.10.2025. Description: Factory for creating neural network instances.

class general_python.ml.networks.Networks(*values)[source]

Bases: str, Enum

Enum class for available standard network architectures. Inherits from str to allow string comparison.

SIMPLE = 'simple'
RBM = 'rbm'
CNN = 'cnn'
AR = 'ar'
RESNET = 'resnet'
general_python.ml.networks.choose_network(network_type: str | Networks | Type[Any] | Any, input_shape: tuple | None = None, backend: str = 'jax', dtype: Any = None, param_dtype: Any = None, seed: int | None = None, **kwargs) GeneralNet[source]

Smart factory to instantiate a network.

This factory can create networks by name (e.g., ‘rbm’, ‘cnn’), wrap raw Flax modules, or instantiate custom GeneralNet subclasses. It handles network-specific arguments passed via **kwargs and provides conveniences like alpha for RBMs.

Parameters:
  • network_type (Union[str, Networks, Type, Any]) –

    • String/Enum : ‘rbm’, ‘cnn’, ‘simple’, ‘ar’, ‘resnet’. The factory will lazy-load and instantiate the corresponding class.

    • Flax Module Class : A raw flax.linen.Module class. The factory will automatically wrap it in a FlaxInterface.

    • GeneralNet Class : A class inheriting from GeneralNet. The factory will instantiate it.

    • Instance : If an already-initialized network instance is passed, it is returned as-is.

  • input_shape (Optional[tuple]) – The shape of the input to the network, e.g., (n_spins,).

  • backend (str) – The computational backend to use (‘jax’ or ‘numpy’). Defaults to ‘jax’.

  • dtype (Any) – The data type for the network’s computations (e.g., ‘float32’, ‘complex64’).

  • param_dtype (Any) – The data type for the network’s parameters. If None, defaults to dtype.

  • seed (Optional[int]) – Random seed for network initialization, if applicable.

  • **kwargs – Network-specific keyword arguments. See below for details on each network type.

  • Using custom Flax modules: You can pass your own flax.linen.Module class as network_type. The factory will wrap it in a general_python.ml.net_impl.interface_net_flax.FlaxInterface to make it compatible with the general_python ecosystem.

    Requirements for your custom module:

      1. It must be a valid nn.Module.

      2. It must accept a JAX array as input, typically with shape (batch, n_visible).

      3. It should return the log-amplitude of the wavefunction, with shape (batch,).

    Example:

    >>> import flax.linen as nn
    >>> import jax.numpy as jnp
    >>>
    >>> class MyCustomNet(nn.Module):
    ...     @nn.compact
    ...     def __call__(self, s):
    ...         # A simple dense layer
    ...         x = nn.Dense(features=128)(s)
    ...         x = nn.relu(x)
    ...         log_psi = nn.Dense(features=1)(x)
    ...         return jnp.squeeze(log_psi, axis=-1)
    >>>
    >>> # The factory handles the wrapping
    >>> custom_net = choose_network(
    ...     MyCustomNet,
    ...     input_shape=(100,),
    ...     dtype='complex64' # Passed to the interface
    ... )
    >>> print(custom_net)
    

Example

The following are examples of how to create different network types using the factory.

  1. Restricted Boltzmann Machine (RBM) for NQS.

The RBM is a single-layer dense network that connects all visible spins to a layer of hidden units. It is the standard baseline for NQS.

    from general_python.ml.networks import choose_network

    # 1. Define RBM parameters.
    # An alpha (density) of 2 means n_hidden = 2 * n_visible.
    rbm_params = {
        'input_shape': (100,),   # 100 spins
        'alpha': 2,              # density of hidden units
        'use_bias': True,
        'dtype': 'complex128',   # essential for quantum phases
    }

    # 2. Create the network; the 'rbm' key selects the RBM class.
    net = choose_network('rbm', **rbm_params)

    # 3. Initialize and run (the random key is handled internally or explicitly).
    # params = net.init(jax.random.PRNGKey(0))
    # log_psi = net(params, sample_configuration)

  2. Convolutional Neural Network (CNN) for lattice systems.

A deep architecture that respects the locality of physical interactions. It is essential for 2D frustrated systems (such as Kitaev or J1-J2 models), where local correlations are complex.

Features:

  • Periodic boundary conditions (torus geometry).

  • Sum pooling: ensures energy is extensive (scales with N).

  • Complex weights: capture the sign structure of the wavefunction.

    from general_python.ml.networks import choose_network

    L = 8              # linear lattice size (illustrative value)
    n_sites = L * L

    # 1. Define CNN parameters.
    cnn_params = {
        'input_shape': (n_sites,),
        'reshape_dims': (L, L),                    # reshape 1D input to a 2D grid
        'features': (16, 32, 64),                  # deep network with increasing channels
        'kernel_sizes': ((3, 3), (3, 3), (3, 3)),
        'activations': ['lncosh'] * 3,             # holomorphic activation
        'periodic': True,                          # wrap edges (torus)
        'sum_pooling': True,                       # sum output over all spatial sites
        'dtype': 'complex128',
    }

    # 2. Create the network.
    net = choose_network('cnn', **cnn_params)

    # 3. Debug/check.
    print(f"Total Parameters: {net.nparams}")
    # > Total Parameters: ~25k (Complex)

Keyword Arguments (by network_type):
  • input_activation (Optional[Union[str, Callable]]) – Activation function applied to the input layer. Useful for preprocessing inputs (e.g., scaling or encoding).

  • For 'rbm': alpha (float), depth (int), sum_pooling (bool), and one further bool option.

  • For 'cnn': reshape_dims (Tuple[int, ...]), features (Sequence[int]), strides (Sequence[Union[int, Tuple]]), two further sequence- or tuple-valued options, activations (Union[str, Sequence[Union[str, Callable]]]), and two further bool options.

  • For 'simple': two tuple-valued options and act_fun (Tuple[Union[str, Callable], ...]).

  • For 'ar' (autoregressive): two int options and rnn_type (str).

  • For 'res' or 'resnet' (residual network): one tuple-valued option, two int options, and kernel_size (Union[int, Tuple[int, ...]]).

Returns:

An initialized or wrapped network instance compatible with the general_python framework. The returned object is callable with signature: log_psi = net(params, inputs) where inputs has shape (batch, features) and output has shape (batch,).

Return type:

GeneralNet

Scheduler implementations for machine learning training. It includes various learning rate schedulers and an early stopping mechanism.

Namely, it provides:

  • ConstantScheduler

  • ExponentialDecayScheduler

  • StepDecayScheduler

  • CosineAnnealingScheduler

  • LinearScheduler

  • AdaptiveScheduler (ReduceLROnPlateau)

To use a scheduler, either create an instance directly or call the choose_scheduler factory function.

>>> # Example: create an exponential decay scheduler
>>> from general_python.ml.schedulers import choose_scheduler
>>> scheduler = choose_scheduler('exponential', initial_lr=0.01, max_epochs=100, lr_decay=0.1)
>>> for epoch in range(10):
...     lr = scheduler(epoch)
...     print(f"Epoch {epoch}: Learning Rate = {lr:.6f}")

email : maksymilian.kliczkowski@pwr.edu.pl

class general_python.ml.schedulers.BaseSchedulerLogger(logger: Logger | None)[source]

Bases: ABC

Abstract Base Class providing logging capabilities to schedulers.

__init__(logger: Logger | None)[source]
property logger: Logger | None

Logger used for scheduler diagnostics, if one is configured.

class general_python.ml.schedulers.EarlyStopping(patience: int = 0, min_delta: float = 0.001, logger: Logger | None = None)[source]

Bases: BaseSchedulerLogger

Monitors a metric and determines if training should stop.

__init__(patience: int = 0, min_delta: float = 0.001, logger: Logger | None = None)[source]
__call__(_metric: float | complex | number) bool[source]
Parameters:

_metric – The metric value (e.g. loss). Real part used if complex.

classmethod from_kwargs(**kwargs)[source]

Create an EarlyStopping instance from common keyword names.

reset()[source]

Reset the stored best metric and patience counter.

property best_metric: float

Best metric observed since initialization or last reset.
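The patience mechanism can be illustrated with a minimal sketch. It assumes that "improvement" means the metric decreases by more than min_delta. It also assumes that a stop is requested once the number of consecutive non-improving calls reaches patience, and that patience=0 disables stopping. Neither convention is documented above. EarlyStoppingSketch is a hypothetical class, not the library's EarlyStopping.

```python
import math


class EarlyStoppingSketch:
    """Patience-based early stopping on a metric to be minimized."""

    def __init__(self, patience=0, min_delta=1e-3):
        self.patience = patience
        self.min_delta = min_delta
        self.reset()

    def reset(self):
        self.best_metric = math.inf
        self.counter = 0

    def __call__(self, metric):
        # Use the real part for complex metrics, mirroring the documented behavior.
        value = metric.real if isinstance(metric, complex) else float(metric)
        if value < self.best_metric - self.min_delta:
            self.best_metric = value  # meaningful improvement: reset the counter
            self.counter = 0
        else:
            self.counter += 1
        # patience=0 is treated as "never stop" in this sketch (assumption).
        return self.patience > 0 and self.counter >= self.patience


es = EarlyStoppingSketch(patience=2)
print([es(v) for v in [1.0, 0.5, 0.4999, 0.4998, 0.3]])
```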

class general_python.ml.schedulers.SchedulerType(*values)[source]

Bases: Enum

Supported learning-rate scheduler families.

CONSTANT = 0
EXPONENTIAL = 1
STEP = 2
COSINE = 3
ADAPTIVE = 4
LINEAR = 5
class general_python.ml.schedulers.Parameters(initial_lr: float, max_epochs: int, lr_decay: float, lr_clamp: float | None = None, logger: Logger | None = None, es: EarlyStopping | None = None)[source]

Bases: BaseSchedulerLogger, ABC

Base class for stateful learning-rate schedules.

__init__(initial_lr: float, max_epochs: int, lr_decay: float, lr_clamp: float | None = None, logger: Logger | None = None, es: EarlyStopping | None = None)[source]

Base class for learning rate schedulers.

Parameters:
  • initial_lr – Initial learning rate.

  • max_epochs – Maximum number of epochs.

  • lr_decay – Decay rate (meaning depends on scheduler type).

  • lr_clamp – Minimum learning rate clamp.

  • logger – Optional logger for logging.

  • es – Optional EarlyStopping instance.

abstractmethod __call__(_epoch: int, _metric: Any | None = None) float[source]

Calculate LR for the epoch.

property lr: float

Most recently emitted learning rate.

property history: List[float]

Learning-rate values emitted so far.

property early_stopping: EarlyStopping | None

Attached early-stopping monitor, if configured.

set_early_stopping(patience: int, min_delta: float = 0.001)[source]

Attach a new early-stopping monitor to this scheduler.

check_stop(_metric) bool[source]

Return whether the attached early-stopping monitor requests stop.

class general_python.ml.schedulers.ConstantScheduler(initial_lr: float, max_epochs: int, lr_clamp=None, logger=None, es=None, **kwargs)[source]

Bases: Parameters

Scheduler that always returns the initial learning rate.

__init__(initial_lr: float, max_epochs: int, lr_clamp=None, logger=None, es=None, **kwargs)[source]

Initialize the scheduler. Arguments are as for Parameters.__init__; because the learning rate stays at initial_lr, the signature has no lr_decay argument.

class general_python.ml.schedulers.ExponentialDecayScheduler(initial_lr: float, max_epochs: int, lr_decay: float = 0.99, lr_clamp=None, logger=None, es=None, **kwargs)[source]

Bases: Parameters

Multiplicative exponential decay: lr = initial_lr * (gamma ^ epoch)

This follows PyTorch’s ExponentialLR convention, with gamma given by lr_decay:

  • gamma = 0.99 means 1% decay per epoch.

  • gamma = 0.999 means 0.1% decay per epoch.

For the old exp(-rate * epoch) behavior, use gamma = exp(-rate).

__init__(initial_lr: float, max_epochs: int, lr_decay: float = 0.99, lr_clamp=None, logger=None, es=None, **kwargs)[source]

Initialize the scheduler. Arguments are as for Parameters.__init__, with lr_decay acting as the per-epoch multiplicative factor gamma (default 0.99).

class general_python.ml.schedulers.LinearScheduler(initial_lr: float, max_epochs: int, min_lr: float = 0.0, lr_clamp=None, logger=None, es=None, **kwargs)[source]

Bases: Parameters

Linearly decays LR from initial_lr to min_lr (default 0) over max_epochs.

__init__(initial_lr: float, max_epochs: int, min_lr: float = 0.0, lr_clamp=None, logger=None, es=None, **kwargs)[source]

Initialize the scheduler. Arguments are as for Parameters.__init__, except that min_lr (default 0.0) replaces lr_decay and sets the learning rate reached at max_epochs.

class general_python.ml.schedulers.StepDecayScheduler(initial_lr: float, max_epochs: int, lr_decay: float, step_size: int, lr_clamp=None, logger=None, es=None, **kwargs)[source]

Bases: Parameters

Step decay: lr = initial_lr * lr_decay ^ floor(epoch / step_size)

__init__(initial_lr: float, max_epochs: int, lr_decay: float, step_size: int, lr_clamp=None, logger=None, es=None, **kwargs)[source]

Initialize the scheduler. Arguments are as for Parameters.__init__; lr_decay is the multiplicative factor applied once every step_size epochs.

class general_python.ml.schedulers.CosineAnnealingScheduler(initial_lr: float, max_epochs: int, min_lr: float = 0.0, lr_clamp=None, logger=None, es=None, **kwargs)[source]

Bases: Parameters

Cosine annealing schedule from initial_lr to min_lr.

__init__(initial_lr: float, max_epochs: int, min_lr: float = 0.0, lr_clamp=None, logger=None, es=None, **kwargs)[source]

Initialize the scheduler. Arguments are as for Parameters.__init__, except that min_lr (default 0.0) replaces lr_decay and sets the final learning rate of the annealing.
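For comparison, the schedules in this section can be written as closed-form functions of the epoch. The exponential and step formulas follow the definitions given above. The linear and cosine forms are the common textbook versions and are assumptions about the implementation. The sketch omits lr_clamp and early stopping, and all function names are hypothetical.

```python
import math


def exponential_lr(epoch, initial_lr, lr_decay=0.99):
    # lr = initial_lr * gamma ** epoch
    return initial_lr * lr_decay ** epoch


def step_lr(epoch, initial_lr, lr_decay, step_size):
    # lr = initial_lr * lr_decay ** floor(epoch / step_size)
    return initial_lr * lr_decay ** (epoch // step_size)


def linear_lr(epoch, initial_lr, max_epochs, min_lr=0.0):
    t = min(epoch, max_epochs) / max_epochs  # training progress in [0, 1]
    return initial_lr + (min_lr - initial_lr) * t


def cosine_lr(epoch, initial_lr, max_epochs, min_lr=0.0):
    t = min(epoch, max_epochs) / max_epochs
    return min_lr + 0.5 * (initial_lr - min_lr) * (1.0 + math.cos(math.pi * t))


print([round(step_lr(e, 0.1, 0.5, step_size=10), 4) for e in (0, 9, 10, 25)])
```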

class general_python.ml.schedulers.AdaptiveScheduler(initial_lr: float, max_epochs: int, lr_decay: float = 0.1, patience: int = 100, min_lr: float = 1e-05, cooldown: int = 0, min_delta: float = 0.0001, lr_clamp=None, logger=None, es=None, **kwargs)[source]

Bases: Parameters

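ReduceLROnPlateau logic: the learning rate is multiplied by lr_decay when the loss has not improved by at least min_delta for patience epochs. A reduction may be followed by a cooldown period, and the rate is never reduced below min_lr.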

__init__(initial_lr: float, max_epochs: int, lr_decay: float = 0.1, patience: int = 100, min_lr: float = 1e-05, cooldown: int = 0, min_delta: float = 0.0001, lr_clamp=None, logger=None, es=None, **kwargs)[source]

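Initialize the adaptive scheduler.

Parameters:
  • initial_lr (float) – Initial learning rate.

  • max_epochs (int) – Maximum number of epochs.

  • lr_decay (float, default=0.1) – Factor by which the learning rate is multiplied when a plateau is detected.

  • patience (int, default=100) – Epochs without improvement before the learning rate is reduced.

  • min_lr (float, default=1e-05) – Lower bound on the learning rate.

  • cooldown (int, default=0) – Cooldown period (in epochs) after a reduction.

  • min_delta (float, default=0.0001) – Minimum improvement threshold.

  • lr_clamp (optional) – Minimum learning rate clamp.

  • logger (optional) – Optional logger for logging.

  • es (optional) – Optional EarlyStopping instance.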

reset()[source]

Reset plateau-tracking state while keeping scheduler configuration.
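The plateau logic can be illustrated with a small hypothetical tracker. It follows the common ReduceLROnPlateau convention, in which the rate is reduced once the number of non-improving epochs exceeds patience. The class PlateauTracker and its step method are illustrative only. The exact counting rules, the cooldown handling, and whether reset() also restores the initial learning rate are assumptions and may differ in AdaptiveScheduler.

```python
class PlateauTracker:
    """Hypothetical ReduceLROnPlateau-style helper; not the library's AdaptiveScheduler."""

    def __init__(self, initial_lr: float, lr_decay: float = 0.1, patience: int = 100,
                 min_lr: float = 1e-5, cooldown: int = 0, min_delta: float = 1e-4):
        self.lr_decay = lr_decay
        self.patience = patience
        self.min_lr = min_lr
        self.cooldown = cooldown
        self.min_delta = min_delta
        self.lr = initial_lr
        self.reset()

    def reset(self) -> None:
        """Clear plateau-tracking state; configuration and the current lr are kept."""
        self.best = float("inf")
        self.bad_epochs = 0
        self.cooldown_left = 0

    def step(self, loss: float) -> float:
        """Update with the latest loss and return the (possibly reduced) learning rate."""
        if loss < self.best - self.min_delta:
            self.best = loss          # significant improvement
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1      # no significant improvement
        if self.cooldown_left > 0:
            self.cooldown_left -= 1
            self.bad_epochs = 0       # stagnation is ignored while cooling down
        if self.bad_epochs > self.patience:
            self.lr = max(self.lr * self.lr_decay, self.min_lr)
            self.cooldown_left = self.cooldown
            self.bad_epochs = 0
        return self.lr
```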

general_python.ml.schedulers.choose_scheduler(scheduler_type: str | SchedulerType | Parameters, initial_lr: float, max_epochs: int, logger: Logger | None = None, **kwargs) Parameters[source]

Factory to create a scheduler instance.

This function accepts either:

  • A string or SchedulerType enum specifying the type of scheduler to create.

  • An existing Parameters instance to reconfigure.

Parameters:
  • scheduler_type (str | SchedulerType | Parameters) – Type of scheduler or existing instance.

  • initial_lr (float) – Initial learning rate.

  • max_epochs (int) – Maximum number of epochs.

  • logger (Logger, optional) – Optional logger for the scheduler.

  • **kwargs – Additional arguments specific to the scheduler type:

    • lr_decay: Decay rate for exponential/step/adaptive schedulers.

    • step_size: Step size for step scheduler.

    • min_lr: Minimum learning rate for cosine/linear/adaptive schedulers.

    • patience: Patience for adaptive scheduler.

    • cooldown: Cooldown period for adaptive scheduler.

    • min_delta: Minimum improvement for adaptive scheduler.

    • early_stopping_patience: Patience for early stopping.

    • early_stopping_min_delta: Minimum improvement for early stopping.

Returns:

An instance of Parameters (scheduler).

Raises:

ValueError – If the scheduler type is unknown.

Learning phase framework for Neural Quantum State training.

This module implements a multi-phase training system for NQS, allowing:

  • Phase transitions with configurable parameters

  • Phase-specific callbacks and hooks

  • Adaptive learning rates per phase

  • Regularization scheduling per phase

  • Progress tracking and reporting

Learning phases represent different stages of optimization:

  1. Pre-training: Initialize network with simple loss, high learning rate

  2. Main Optimization: Full Hamiltonian, adaptive learning rate

  3. Refinement: Fine-tune observables, low learning rate, high regularization

Quick Start

Using Presets:

>>> from general_python.ml.training_phases import create_phase_schedulers
>>> lr_sched, reg_sched = create_phase_schedulers('default')
>>> # Pass to NQSTrainer: phases=(lr_sched, reg_sched)

Creating Custom Phases:

>>> from general_python.ml.training_phases import LearningPhase, PhaseType, PhaseScheduler
>>>
>>> my_phases = [
...     LearningPhase(
...         name="warmup", epochs=50,
...         lr=0.1, lr_schedule="exponential", lr_kwargs={'lr_decay': 0.05},
...         reg=0.01
...     ),
...     LearningPhase(
...         name="main", epochs=300,
...         lr=0.02, lr_schedule="adaptive", lr_kwargs={'patience': 20, 'lr_decay': 0.5},
...         reg=0.001
...     ),
... ]
>>> lr_sched = PhaseScheduler(my_phases, param_type='lr')
>>> reg_sched = PhaseScheduler(my_phases, param_type='reg')

Available Scheduler Types

  • 'constant': Fixed value

  • 'exponential': Exponential decay: lr * exp(-lr_decay * epoch)

  • 'step': Step decay: lr * lr_decay^floor(epoch/step_size)

  • 'cosine': Cosine annealing to min_lr

  • 'linear': Linear decay to min_lr

  • 'adaptive': ReduceLROnPlateau (requires loss)
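The two simplest entries of this list can be written out directly. This sketch assumes the epoch counter starts at 0. It names the decay constant lr_decay, matching the lr_kwargs key used in the examples. The function names are hypothetical.

```python
import math


def exponential_lr(epoch: int, lr: float, lr_decay: float) -> float:
    """'exponential': lr * exp(-lr_decay * epoch)."""
    return lr * math.exp(-lr_decay * epoch)


def constant_lr(epoch: int, lr: float) -> float:
    """'constant': the value does not depend on the epoch."""
    return lr
```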

Available Presets

  • 'default': 3-phase (pre_training: 50, main: 200, refinement: 100)

  • 'kitaev': Specialized for frustrated spin systems (pre: 100, main: 300, fine: 150)

Email : maksymilian.kliczkowski@pwr.edu.pl Date : November 1, 2025

class general_python.ml.training_phases.PhaseType(*values)[source]

Bases: Enum

Semantic categories for training phases.

PRE_TRAINING = 1
MAIN = 2
REFINEMENT = 3
CUSTOM = 4
class general_python.ml.training_phases.LearningPhase(name: str = 'phase', epochs: int = 100, phase_type: PhaseType = PhaseType.MAIN, lr: float = 0.01, lr_schedule: str = 'constant', lr_kwargs: Dict[str, ~typing.Any]=<factory>, reg: float = 0.001, reg_schedule: str = 'constant', reg_kwargs: Dict[str, ~typing.Any]=<factory>, loss_type: str = 'energy', beta_penalty: float = 0.0, on_phase_start: Callable | None = None, on_phase_end: Callable | None = None)[source]

Bases: object

Configuration for a specific training phase.

Each phase defines learning rate and regularization schedules that are active for a specific number of epochs. Phases are processed sequentially by the PhaseScheduler.

name

Human-readable phase identifier (e.g., ‘warmup’, ‘main’, ‘fine’).

Type:

str

epochs

Number of epochs this phase lasts.

Type:

int

phase_type

Semantic type (PRE_TRAINING, MAIN, REFINEMENT, CUSTOM).

Type:

PhaseType

lr

Initial learning rate for this phase.

Type:

float

lr_schedule

Scheduler type for LR. Options:

  • ‘constant’: Fixed lr throughout the phase

  • ‘exponential’: lr * exp(-lr_decay * local_epoch)

  • ‘step’: lr * lr_decay^floor(local_epoch/step_size)

  • ‘cosine’: Cosine annealing from lr to min_lr

  • ‘linear’: Linear decay from lr to min_lr

  • ‘adaptive’: ReduceLROnPlateau (requires loss)

Type:

str

lr_kwargs

Extra arguments for the LR scheduler. Common keys:

  • ‘lr_decay’: Decay rate (exponential, step, adaptive)

  • ‘step_size’: Steps between decays (step scheduler)

  • ‘min_lr’: Minimum LR (cosine, linear, adaptive)

  • ‘patience’: Epochs before reduction (adaptive)

  • ‘min_delta’: Minimum improvement threshold (adaptive)

Type:

Dict[str, Any]

reg

Initial regularization (diagonal shift) for this phase.

Type:

float

reg_schedule

Scheduler type for regularization. Same options as lr_schedule.

Type:

str

reg_kwargs

Extra arguments for the regularization scheduler.

Type:

Dict[str, Any]

loss_type

Loss function type (default: ‘energy’).

Type:

str

beta_penalty

Penalty coefficient for excited state targeting.

Type:

float

on_phase_start

Callback executed when phase begins.

Type:

Callable, optional

on_phase_end

Callback executed when phase ends.

Type:

Callable, optional

Examples

>>> # Exponential decay warmup
>>> warmup = LearningPhase(
...     name='warmup', epochs=50,
...     lr=0.1, lr_schedule='exponential', lr_kwargs={'lr_decay': 0.05},
...     reg=0.01
... )
>>>
>>> # Adaptive main phase (ReduceLROnPlateau)
>>> main = LearningPhase(
...     name='main', epochs=300,
...     lr=0.02, lr_schedule='adaptive',
...     lr_kwargs={'patience': 20, 'lr_decay': 0.5, 'min_lr': 1e-4},
...     reg=0.001
... )
>>>
>>> # Cosine annealing refinement
>>> refine = LearningPhase(
...     name='fine', epochs=100,
...     lr=0.01, lr_schedule='cosine', lr_kwargs={'min_lr': 1e-5},
...     reg=0.005
... )
name: str = 'phase'
epochs: int = 100
phase_type: PhaseType = PhaseType.MAIN
lr: float = 0.01
lr_schedule: str = 'constant'
lr_kwargs: Dict[str, Any]
reg: float = 0.001
reg_schedule: str = 'constant'
reg_kwargs: Dict[str, Any]
loss_type: str = 'energy'
beta_penalty: float = 0.0
on_phase_start: Callable | None = None
on_phase_end: Callable | None = None
__init__(name: str = 'phase', epochs: int = 100, phase_type: PhaseType = PhaseType.MAIN, lr: float = 0.01, lr_schedule: str = 'constant', lr_kwargs: Dict[str, ~typing.Any]=<factory>, reg: float = 0.001, reg_schedule: str = 'constant', reg_kwargs: Dict[str, ~typing.Any]=<factory>, loss_type: str = 'energy', beta_penalty: float = 0.0, on_phase_start: Callable | None = None, on_phase_end: Callable | None = None) None
class general_python.ml.training_phases.PhaseScheduler(phases: List[LearningPhase], param_type: str = 'lr', logger=None)[source]

Bases: object

Manages transitions between training phases.

The PhaseScheduler orchestrates multi-phase training by:

  1. Tracking the current phase based on the global epoch count.

  2. Instantiating the appropriate low-level scheduler for each phase.

  3. Firing callbacks on phase transitions.

  4. Returning scheduled values via __call__.

Parameters:
  • phases (List[LearningPhase]) – Ordered list of training phases to execute.

  • param_type (str, default='lr') – Which parameter to schedule (‘lr’ or ‘reg’).

  • logger (Logger, optional) – Logger for phase transition messages.

current_phase

Currently active phase.

Type:

LearningPhase

history

All values returned by the scheduler so far.

Type:

List[float]

Examples

>>> from general_python.ml.training_phases import LearningPhase, PhaseScheduler
>>>
>>> phases = [
...     LearningPhase(name='warmup', epochs=50, lr=0.1, lr_schedule='exponential',
...                   lr_kwargs={'lr_decay': 0.05}),
...     LearningPhase(name='main', epochs=200, lr=0.02, lr_schedule='constant'),
... ]
>>>
>>> lr_scheduler = PhaseScheduler(phases, param_type='lr')
>>> reg_scheduler = PhaseScheduler(phases, param_type='reg')
>>>
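>>> # Use in training loop; train_one_epoch is a hypothetical placeholder
>>> current_loss = None
>>> for epoch in range(250):
...     lr = lr_scheduler(epoch, loss=current_loss)  # Automatic phase transition
...     reg = reg_scheduler(epoch, loss=current_loss)
...     current_loss = train_one_epoch(lr, reg)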
__init__(phases: List[LearningPhase], param_type: str = 'lr', logger=None)[source]
property current_phase: LearningPhase

Currently active learning phase.

If all configured phases are exhausted, the final phase is returned so callers can continue querying terminal settings.

__call__(global_epoch: int, loss: float = None) float[source]

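Return the scheduled value (learning rate or regularization, depending on param_type) for global_epoch. The calculation is delegated to the low-level scheduler of the active phase. The loss argument is used by the 'adaptive' schedule.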
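The phase lookup behind __call__ can be sketched as a cumulative walk over the phase lengths. This is an assumption about how global epochs map to phases, not the library's code. Phase is a hypothetical stand-in for LearningPhase. Once the last phase is exhausted, the local epoch is assumed to keep counting, which is consistent with current_phase returning the final phase.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Phase:
    """Hypothetical stand-in for LearningPhase (only the fields needed here)."""
    name: str
    epochs: int


def locate_phase(phases: List[Phase], global_epoch: int) -> Tuple[int, int]:
    """Return (phase_index, local_epoch) for a global epoch."""
    start = 0
    for idx, phase in enumerate(phases):
        if global_epoch < start + phase.epochs:
            return idx, global_epoch - start
        start += phase.epochs
    # Past the last phase: stay in the final phase and keep counting locally.
    last = len(phases) - 1
    return last, global_epoch - (start - phases[last].epochs)


phases = [Phase("warmup", 50), Phase("main", 200)]
idx, local_epoch = locate_phase(phases, 120)
print(phases[idx].name, local_epoch)
```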

general_python.ml.training_phases.create_phase_schedulers(preset: str = 'default', logger=None)[source]

Factory function to create LR and Reg schedulers from a preset.

Parameters:
  • preset (str, default='default') – Preset name. Available:

    • ‘default’: 3-phase training (350 total epochs)

    • ‘kitaev’: Specialized for frustrated systems (550 total epochs)

  • logger (Logger, optional) – Logger for scheduler messages.

Returns:

(lr_scheduler, reg_scheduler) tuple.

Return type:

Tuple[PhaseScheduler, PhaseScheduler]

Raises:

ValueError – If preset name is not recognized.

Examples

>>> lr_sched, reg_sched = create_phase_schedulers('default')
>>>
>>> # Pass to NQSTrainer
>>> trainer = NQSTrainer(nqs, phases=(lr_sched, reg_sched))
>>>
>>> # Or use preset string directly
>>> trainer = NQSTrainer(nqs, phases='default')  # Equivalent

Physics Module

Physics toolkit for quantum and statistical computations.

The package groups modules for density matrices, entropy, operators, response functions, spectral functions, and thermal/statistical observables.

Input/output contracts

Functions generally accept array-like states, operators, or spectra and return NumPy or JAX-compatible arrays or scalar observables. Shape requirements follow physics conventions, e.g. state vectors (d,), operators (d, d), and grids for momentum or frequency response evaluations.

Backend expectations

NumPy implementations are broadly available. JAX-specific modules are optional and loaded lazily when dependencies are installed.

Numerical stability and determinism

Entropy and spectral routines include tolerance or regularization knobs to reduce instability near zero eigenvalues or narrow broadenings. Reproducibility depends on deterministic eigensolver settings and fixed random seeds in upstream code.

general_python.physics.list_capabilities()[source]

List available physics capabilities and modules.

This module contains functions for manipulating and analyzing density matrices in quantum mechanics. It provides optimized implementations using NumPy and Numba for computing reduced density matrices, Schmidt decompositions, and entanglement spectra.

QES Convention:

  • State vector index i = s0 + d*s1 + d^2*s2 + … (little-endian / Fortran order)

  • Subsystem A site order: [a0, a1, a2, …] -> row index I = sa0 + d*sa1 + …

  • Subsystem B site order: [b0, b1, b2, …] -> column index J = sb0 + d*sb1 + …

where d is the local_dim.

Fermionic Systems: For fermionic systems mapped via Jordan-Wigner transformation, non-local string operators create additional correlations between non-contiguous subsystem sites. The fermionic=True flag applies sign corrections that account for the fermionic exchange statistics when reordering sites. This ensures correct reduced density matrices for arbitrary subsystem geometries.

version : 2.1 copyright : (c) 2026 by Maksymilian Kliczkowski. All rights reserved.

general_python.physics.density_matrix.mask_subsystem(va: int | ndarray | List[int], ns: int, local_dim: int = 2, contiguous: bool = False) Tuple[Tuple[int, int], Tuple[int, ...]][source]

Process the subsystem specification to extract site indices and the permutation order. The order tuple specifies how to permute the state vector to bring subsystem A sites to the front.

Parameters:
  • va (Union[int, np.ndarray, List[int]]) – Subsystem specification. Can be:

    • An integer (if contiguous=True) specifying the number of contiguous sites in A starting from site 0.

    • A bitmask integer where bits set to 1 indicate sites in A (if contiguous=False).

    • A list or array of site indices in A.

  • ns (int) – Total number of sites in the system.

  • local_dim (int) – Local Hilbert space dimension (default is 2 for qubits).

  • contiguous (bool) – If True, treat va as the number of contiguous sites in A starting from site 0. If False, treat va as a bitmask or list of site indices.
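The sketch below illustrates the three ways of specifying va. It converts each form into a list of A's sites and a permutation order that brings those sites to the front. The helper subsystem_order is hypothetical. Sorting explicit site lists is an assumption. The sketch does not reproduce the first element of mask_subsystem's return value, whose meaning is not documented here.

```python
from typing import List, Sequence, Tuple, Union


def subsystem_order(va: Union[int, Sequence[int]], ns: int,
                    contiguous: bool = False) -> Tuple[List[int], Tuple[int, ...]]:
    """Return (sites_in_A, order) with A's sites moved to the front of order."""
    if isinstance(va, int):
        if contiguous:
            sites_a = list(range(va))                           # first `va` sites
        else:
            sites_a = [i for i in range(ns) if (va >> i) & 1]   # bit i set -> site i in A
    else:
        sites_a = sorted(int(i) for i in va)                    # explicit site list
    sites_b = [i for i in range(ns) if i not in sites_a]
    return sites_a, tuple(sites_a + sites_b)
```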

general_python.physics.density_matrix.psi_numpy(state: ndarray, order: Tuple[int, ...], size_a: int, ns: int, local_dim: int = 2, fermionic: bool = False) ndarray[source]

Reshape and reorder a quantum state vector into a matrix Psi_{A,B} using NumPy. This representation is used to compute the reduced density matrix rho_A = Psi @ Psi^dagger.

Parameters:
  • state (np.ndarray) – The input state vector of shape (local_dim**ns,).

  • order (Tuple[int, ...]) – The permutation order of sites to bring subsystem A sites to the front.

  • size_a (int) – The number of sites in subsystem A.

  • ns (int) – Total number of sites in the system.

  • local_dim (int) – Local Hilbert space dimension (default is 2 for qubits).

  • fermionic (bool) – If True, apply fermionic sign corrections for site permutation. This accounts for the anticommutation of fermionic operators when reordering sites, essential for correct RDMs of non-contiguous subsystems in fermionic systems mapped via Jordan-Wigner.

Returns:

Reshaped state matrix Psi of shape (dA, dB) where dA = local_dim^size_a and dB = local_dim^(ns - size_a).

Return type:

np.ndarray

Notes

For fermionic systems (fermionic=True):

When we permute the site ordering, fermionic operators anticommute:

c_i c_j = -c_j c_i

For a basis state |n_0, n_1, …, n_{ns-1}> represented as:

c_0^{n_0} c_1^{n_1} … c_{ns-1}^{n_{ns-1}} |vacuum>

Reordering the sites requires swapping creation operators; each swap of two occupied sites contributes a factor of -1.

The total sign is (-1)^{number of inversions in occupied sites}.
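The sign rule can be computed by counting inversions among the occupied sites after reordering. The function name is hypothetical. In this sketch, occupations[s] is n_s for original site s. order[k] is the original site placed at new position k, the same meaning as the order tuple in psi_numpy.

```python
from typing import Sequence


def fermionic_permutation_sign(occupations: Sequence[int], order: Sequence[int]) -> int:
    """Return +1 or -1, the sign picked up when occupied modes are reordered."""
    # Original labels of the occupied sites, listed in their new positions.
    occupied = [site for site in order if occupations[site]]
    # Each pair of occupied sites whose relative order is reversed is one swap.
    inversions = sum(
        1
        for a in range(len(occupied))
        for b in range(a + 1, len(occupied))
        if occupied[a] > occupied[b]
    )
    return -1 if inversions % 2 else 1
```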

general_python.physics.density_matrix.rho_numpy(state: ndarray, size_a: int, ns: int, local_dim: int = 2, order: Tuple[int, ...] | None = None, fermionic: bool = False) ndarray[source]

Compute reduced density matrix using NumPy with Fortran-order convention. Works for any local_dim (fermionic mode requires local_dim=2).

Parameters:
  • state (np.ndarray) – The input state vector of shape (local_dim**ns,).

  • size_a (int) – The number of sites in subsystem A.

  • ns (int) – Total number of sites in the system.

  • local_dim (int) – Local Hilbert space dimension (default is 2 for qubits).

  • order (Optional[Tuple[int, ...]]) – The permutation order of sites to bring subsystem A sites to the front. If None, assumes natural order.

  • fermionic (bool) – If True, apply fermionic sign corrections. See psi_numpy for details.

Returns:

Reduced density matrix rho_A of shape (dA, dA).

Return type:

np.ndarray
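NumPy makes the Fortran-order convention concrete in three steps:

  1. Reshaping the state in Fortran order yields a tensor indexed by site occupations.

  2. Transposing the tensor moves subsystem A to the front.

  3. A second Fortran-order reshape produces Psi[I, J].

This is a minimal sketch without fermionic sign corrections. reduced_density_matrix is a hypothetical name, not the library's rho_numpy.

```python
from typing import Optional, Sequence

import numpy as np


def reduced_density_matrix(state: np.ndarray, size_a: int, ns: int, local_dim: int = 2,
                           order: Optional[Sequence[int]] = None) -> np.ndarray:
    """rho_A = Psi @ Psi^dagger with the little-endian (Fortran-order) convention."""
    d = local_dim
    if order is None:
        order = tuple(range(ns))  # natural order: A = sites 0..size_a-1
    # Fortran-order reshape gives tensor[s0, s1, ...] = state[s0 + d*s1 + d^2*s2 + ...].
    tensor = np.asarray(state).reshape((d,) * ns, order="F")
    # Axis k of the transposed tensor is site order[k], so A's sites come first.
    tensor = np.transpose(tensor, axes=tuple(order))
    # Psi[I, J] with I = sa0 + d*sa1 + ... and J = sb0 + d*sb1 + ...
    psi = tensor.reshape((d ** size_a, d ** (ns - size_a)), order="F")
    return psi @ psi.conj().T
```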

general_python.physics.density_matrix.rho(state: ndarray, va: int | List[int] | ndarray, ns: int | None = None, local_dim: int = 2, contiguous: bool = False, fermionic: bool = False, *, la: int | None = None, order: Tuple[int, ...] | None = None) ndarray[source]

Compute the reduced density matrix rho_A of a subsystem A.

Parameters:
  • state (np.ndarray) – The input state vector of shape (local_dim**ns,).

  • va (Union[int, List[int], np.ndarray]) –

    Subsystem specification. Can be:

    • An integer: if contiguous=True, the number of sites in A starting from site 0; if contiguous=False, a bitmask where bit i = 1 means site i is in A.

    • A list/array of site indices in subsystem A (any geometry).

  • ns (Optional[int]) – Total number of sites. If None, inferred from state size.

  • local_dim (int) – Local Hilbert space dimension (default 2 for qubits/fermions).

  • contiguous (bool) – If True, treat integer va as number of contiguous sites from 0.

  • fermionic (bool) –

    If True, apply fermionic sign corrections for non-contiguous subsystems. Essential for correct entanglement entropy of Jordan-Wigner mapped fermions.

    For fermionic systems, permuting sites requires accounting for the anticommutation of creation operators. Each pair of occupied sites that gets inverted in the permutation contributes a factor of -1.

    Use fermionic=True when:

    • Computing the RDM of a non-contiguous subsystem in a fermionic system

    • Working with Slater determinants or their superpositions

    • Comparing with correlation matrix entropy results

    For contiguous subsystems (sites 0,1,…,k-1), the fermionic flag has no effect since no site permutation is needed.

  • la (Optional[int]) – Deprecated alias for specifying contiguous subsystem size.

  • order (Optional[Tuple[int, ...]]) – Explicit site permutation order. If provided, overrides va.

Returns:

Reduced density matrix rho_A of shape (dA, dA) where dA = local_dim^|A|.

Return type:

np.ndarray

Examples

>>> # Contiguous subsystem (first 3 sites)
>>> rho_A = rho(psi, va=3, ns=8, contiguous=True)
>>> # Non-contiguous subsystem [0, 2, 4] for fermions
>>> rho_A = rho(psi, va=[0, 2, 4], ns=8, fermionic=True)
>>> # Bitmask specification: sites 0 and 2 (binary 101 = 5)
>>> rho_A = rho(psi, va=5, ns=4)
general_python.physics.density_matrix.schmidt(state: ndarray, va: int | List[int] | ndarray, ns: int | None = None, local_dim: int = 2, contiguous: bool = False, fermionic: bool = False, eig: bool = False, square: bool = True, *, sub_size: int | None = None, order: Tuple[int, ...] | None = None, return_vecs: bool = True) Tuple[ndarray, Any][source]

Compute the Schmidt decomposition of a state vector.

For a bipartition of the system into subsystems A and B, the Schmidt decomposition expresses the state as:

|psi> = sum_k lambda_k |phi_k>_A |chi_k>_B

Parameters:
  • state (np.ndarray) – The input state vector of shape (local_dim**ns,).

  • va (Union[int, List[int], np.ndarray]) – Subsystem A specification (see rho() for details).

  • ns (Optional[int]) – Total number of sites. If None, inferred from state size.

  • local_dim (int) – Local Hilbert space dimension (default 2).

  • contiguous (bool) – If True, treat integer va as number of contiguous sites.

  • fermionic (bool) – If True, apply fermionic sign corrections for site permutation. See rho() for detailed explanation.

  • eig (bool) – If True, use RDM eigendecomposition instead of SVD. SVD (default) is generally faster and more numerically stable.

  • square (bool) – If True (default), return squared singular values (= RDM eigenvalues). If False, return singular values directly.

  • sub_size (Optional[int]) – Deprecated alias for contiguous subsystem size.

  • order (Optional[Tuple[int, ...]]) – Explicit site permutation order.

  • return_vecs (bool) – If True, return Schmidt vectors along with values.

Returns:

  • If return_vecs=True – Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray], np.ndarray] (schmidt_values, (U, Vh), psi_matrix)

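    For the SVD path, the second element is (U, Vh), where U[i, k] are the left and Vh[k, j] the right singular vectors, and the third element is psi_matrix. For the eig path, the second element is (eigenvectors, None) and the third is rho_A.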

  • If return_vecs=False – np.ndarray Schmidt values only (squared singular values if square=True).

Examples

>>> # Get Schmidt spectrum for fermionic non-contiguous subsystem
>>> s_sq = schmidt(psi, va=[0, 2, 4], ns=8, fermionic=True, return_vecs=False)
>>> entropy = -np.sum(s_sq * np.log(s_sq + 1e-15))
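For a contiguous cut A = {0, ..., size_a - 1} no permutation is needed. The Schmidt spectrum then follows from a single SVD of the Fortran-order reshaped state. The helpers below are hypothetical sketches that assume this case and apply no fermionic corrections. The entropy helper drops numerically zero probabilities instead of adding an offset inside the logarithm.

```python
import numpy as np


def schmidt_probabilities(state: np.ndarray, size_a: int, ns: int, local_dim: int = 2) -> np.ndarray:
    """Squared Schmidt values for the contiguous cut A = sites 0..size_a-1."""
    d = local_dim
    # With little-endian indexing, i = I + d**size_a * J, i.e. a Fortran-order reshape.
    psi = np.asarray(state).reshape((d ** size_a, d ** (ns - size_a)), order="F")
    singular_values = np.linalg.svd(psi, compute_uv=False)  # sorted in descending order
    return singular_values ** 2  # equal to the eigenvalues of rho_A


def entanglement_entropy(probs: np.ndarray, eps: float = 1e-15) -> float:
    """Von Neumann entropy -sum p log p, skipping numerically zero entries."""
    p = probs[probs > eps]
    return float(-np.sum(p * np.log(p)))
```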
general_python.physics.density_matrix.rho_spectrum(rho_mat: ndarray, eps: float = 1e-15) ndarray[source]

Compute the eigenvalue spectrum of a density matrix.

Parameters:
  • rho_mat (np.ndarray) – Density matrix (Hermitian, positive semi-definite).

  • eps (float) – Threshold for filtering small eigenvalues.

Returns:

Sorted eigenvalues (descending) above threshold eps.

Return type:

np.ndarray
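Here is a sketch of the spectrum computation. It assumes a Hermitian input so that numpy.linalg.eigvalsh applies. The name rho_spectrum_np is hypothetical.

```python
import numpy as np


def rho_spectrum_np(rho_mat: np.ndarray, eps: float = 1e-15) -> np.ndarray:
    """Eigenvalues of a Hermitian density matrix above eps, sorted in descending order."""
    vals = np.linalg.eigvalsh(rho_mat)  # real eigenvalues in ascending order
    return vals[vals > eps][::-1]       # drop (near-)zero or negative noise, then reverse
```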

general_python.physics.density_matrix.rho_single_site(state: ndarray, site: int, ns: int, local_dim: int = 2, fermionic: bool = False) ndarray[source]

Compute the single-site reduced density matrix.

Parameters:
  • state (np.ndarray) – Many-body state vector.

  • site (int) – Index of the site to keep; all other sites are traced out.

  • ns (int) – Total number of sites.

  • local_dim (int) – Local Hilbert space dimension (default 2).

  • fermionic (bool) – If True, apply fermionic sign corrections.

Returns:

Single-site RDM of shape (local_dim, local_dim).

Return type:

np.ndarray

general_python.physics.density_matrix.rho_two_sites(state: ndarray, site_i: int, site_j: int, ns: int, local_dim: int = 2, fermionic: bool = False) ndarray[source]

Compute the two-site reduced density matrix.

Parameters:
  • state (np.ndarray) – Many-body state vector.

  • site_i (int) – Index of the first site of the two-site subsystem.

  • site_j (int) – Index of the second site of the two-site subsystem.

  • ns (int) – Total number of sites.

  • local_dim (int) – Local Hilbert space dimension (default 2).

  • fermionic (bool) – If True, apply fermionic sign corrections. Important when site_i and site_j are non-adjacent.

Returns:

Two-site RDM of shape (local_dim^2, local_dim^2).

Return type:

np.ndarray

Allows one to parse physical operators created in the simulation.

class general_python.physics.operators.Spectral[source]

Bases: object

Helpers for spectral statistics, such as taking a fraction of the data and computing its mean.

static diagonal_cutoff(Nh, nu, minimal_frac=0.1)[source]

The diagonal cutoff for the spectral statistics.

Parameters:
  • Nh – Hilbert space dimension.

  • nu – Number of eigenvalues; if less than 1, it is interpreted as a fraction of the Hilbert space dimension.

  • minimal_frac (default=0.1) – Minimal fraction of the Hilbert space dimension.

Returns:

The diagonal cutoff for the spectral statistics.

static take_fraction(nu: float, data, middle=None)[source]

Take a fraction of the data for the spectral statistics.

Parameters:
  • nu (float) – Number of eigenvalues; if less than 1, it is interpreted as a fraction of the Hilbert space dimension.

  • data – Data to take the fraction from.

  • middle (optional) – Center of the selection; if None, the middle of the data is used.

Returns:

The selected fraction of the data.

static take_fraction_arr(nu: float, data, middle=None)[source]

Take a fraction of the data for the spectral statistics.

Parameters:
  • nu (float) – Number of eigenvalues; if less than 1, it is interpreted as a fraction of the Hilbert space dimension.

  • data – Data to take the fraction from.

  • middle (optional) – Center of the selection; if None, the middle of the data is used.

Returns:

The selected fraction of the data.

static mean_fraction(nu: float, data, middle=None, axis=0)[source]

Take a fraction of the data for the spectral statistics and compute its mean.

Parameters:
  • nu (float) – Number of eigenvalues; if less than 1, it is interpreted as a fraction of the Hilbert space dimension.

  • data – Data to take the fraction from.

  • middle (optional) – Center of the selection; if None, the middle of the data is used.

  • axis (default=0) – Axis along which the mean is calculated.

Returns:

The mean of the selected fraction of the data.

class general_python.physics.operators.Operators[source]

Bases: object

Operator class with methods for parsing operator names in both the many-body and the single-particle sectors.

OPERATOR_SEP = '/'
OPERATOR_SEP_CORR = '-'
OPERATOR_SEP_MULT = ','
OPERATOR_SEP_DIFF = 'm'
OPERATOR_SEP_RANGE = ':'
OPERATOR_SEP_RANDOM = 'r'
OPERATOR_SEP_DIV = '_'
OPERATOR_PI = 'pi'
OPERATOR_SITE = 'l'
OPERATOR_SITEU = 'L'
OPERATOR_SITE_M_1 = True
OPERATOR_ED_LIMIT = 16
static resolve_hilbert(Ns, local_hilbert=2)[source]

Resolve the Hilbert space dimension for Ns sites with local dimension local_hilbert, subject to the internal limit OPERATOR_ED_LIMIT.

static resolveSite(site: str, _dimension=1)[source]

Parses and resolves a site identifier string into a numeric index or value.

This method handles special symbolic representations used in operator strings, such as ‘L’ (system size), ‘pi’, and arithmetic operations like division or subtraction.

Parameters:
  • site (str) – String representation of the site (e.g., “0”, “L”, “L_2”, “L-1”).

  • _dimension (int, default=1) – The dimension or size of the system (L), used to resolve ‘L’ or ‘l’.

Returns:

The resolved numeric value of the site or parameter.

Return type:

int or float

Examples

>>> Operators.resolveSite("0", 10)
0
>>> Operators.resolveSite("L", 10)  # If OPERATOR_SITE_M_1 is True (default)
9
>>> Operators.resolveSite("L_2", 10) # L / 2
5
static resolve_operator(f_name, dimension)[source]

Resolves a full operator string by parsing the site component.

For an operator string like “Sz/L_2”, this splits the string by the operator separator (‘/’), resolves the site part (“L_2” -> dimension/2), and reconstructs the string.

Parameters:
  • f_name (str) – Operator name string.

  • dimension (int) – System dimension/size.

Returns:

Resolved operator string (e.g., “Sz/5”).

Return type:

str

static name2title(f_name)[source]

Convert an encoded operator name into a LaTeX-style title.

general_python/physics/statistical.py

Statistical analysis utilities for quantum systems.

This module provides tools for:

  • Finite window averages and time series analysis

  • Local density of states (LDOS) and strength functions

  • Spectral histograms and binning

  • Windowed matrix element calculations

  • Generic histogram and scatter analysis

Author: Maksymilian Kliczkowski Email: maksymilian.kliczkowski@pwr.edu.pl

general_python.physics.statistical.moving_average(data: ndarray | Array, window_size: int, mode: str = 'valid') ndarray | Array[source]

Compute moving average with specified window size.

Parameters:
  • data (array-like) – Input data array.

  • window_size (int) – Size of the moving window.

  • mode (str, optional) – Convolution mode: ‘valid’ (default), ‘same’, or ‘full’.

Returns:

Moving average array.

Return type:

Array

Examples

>>> data = np.random.randn(100)
>>> smooth = moving_average(data, window_size=5)
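A box-car moving average is a convolution with a normalized window. One plausible NumPy-only version is shown below. It uses the hypothetical name moving_average_np and is not necessarily how the library computes it.

```python
import numpy as np


def moving_average_np(data, window_size: int, mode: str = "valid") -> np.ndarray:
    """Box-car moving average via convolution with a normalized window."""
    kernel = np.ones(window_size) / window_size  # equal weights summing to 1
    return np.convolve(np.asarray(data, dtype=float), kernel, mode=mode)
```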
general_python.physics.statistical.windowed_variance(data: ndarray | Array, window_size: int, ddof: int = 1) Tuple[ndarray | Array, ndarray | Array][source]

Compute moving mean and variance over a sliding window.

Parameters:
  • data (array-like) – Input data array.

  • window_size (int) – Size of the moving window.

  • ddof (int, optional) – Delta degrees of freedom for variance calculation (default: 1).

Returns:

  • means (Array) – Moving mean.

  • variances (Array) – Moving variance.
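Assuming NumPy 1.20 or newer, sliding_window_view gives a compact sketch of the moving mean and variance. windowed_mean_var is a hypothetical name. The output length N - window_size + 1 (the same as 'valid' mode) is an assumption about the library's behavior.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def windowed_mean_var(data, window_size: int, ddof: int = 1):
    """Return (means, variances) over every full window of length window_size."""
    # Each row of `windows` is one window; no data is copied.
    windows = sliding_window_view(np.asarray(data, dtype=float), window_size)
    return windows.mean(axis=-1), windows.var(axis=-1, ddof=ddof)
```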

general_python.physics.statistical.exponential_moving_average(data: ndarray | Array, alpha: float) ndarray | Array[source]

Compute exponential moving average.

EMA[i] = alpha * data[i] + (1 - alpha) * EMA[i-1]

Parameters:
  • data (array-like) – Input data array.

  • alpha (float) – Smoothing factor in (0, 1]. Higher values give more weight to recent data.

Returns:

Exponential moving average.

Return type:

Array

Examples

>>> data    = np.random.randn(100)
>>> ema     = exponential_moving_average(data, alpha=0.3)
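This is a direct loop implementation of the recurrence. The function name is hypothetical. The initial condition is not specified above, so seeding the average with EMA[0] = data[0] is an assumption.

```python
import numpy as np


def exponential_moving_average_np(data, alpha: float) -> np.ndarray:
    """EMA[i] = alpha * data[i] + (1 - alpha) * EMA[i-1], seeded with EMA[0] = data[0]."""
    if not 0.0 < alpha <= 1.0:
        raise ValueError("alpha must lie in (0, 1]")
    x = np.asarray(data, dtype=float)
    out = np.empty_like(x)
    if x.size == 0:
        return out
    out[0] = x[0]  # seeding choice is an assumption of this sketch
    for i in range(1, len(x)):
        out[i] = alpha * x[i] + (1.0 - alpha) * out[i - 1]
    return out
```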
general_python.physics.statistical.window_mask(values: ndarray | Array, center: float, width: float) ndarray | Array[source]

Create a boolean mask for values within [center - width/2, center + width/2].

Parameters:
  • values (array-like) – Array of values.

  • center (float) – Center of the window.

  • width (float) – Width of the window.

Returns:

Boolean mask where True indicates values within the window.

Return type:

Array (bool)
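Assuming the closed interval stated above (both endpoints included), the mask reduces to two comparisons. window_mask_np is a hypothetical name.

```python
import numpy as np


def window_mask_np(values, center: float, width: float) -> np.ndarray:
    """True where center - width/2 <= value <= center + width/2."""
    v = np.asarray(values)
    half = width / 2.0
    return (v >= center - half) & (v <= center + half)
```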

general_python.physics.statistical.centered_window(data: ndarray | Array, center_idx: int, window_size: int) ndarray | Array[source]

Extract a centered window from data around a given index.

Parameters:
  • data (array-like) – Input data array.

  • center_idx (int) – Center index of the window.

  • window_size (int) – Total size of the window (must be odd for symmetric centering).

Returns:

Windowed data array.

Return type:

Array

general_python.physics.statistical.fractional_window(data: ndarray | Array, fraction: float = 0.3, around: float | None = None) ndarray | Array[source]

Extract a fractional window of data centered around a value.

Parameters:
  • data (array-like) – Sorted input data.

  • fraction (float) – Fraction of data to extract (0 < fraction <= 1).

  • around (float, optional) – Value to center the window around. If None, uses median.

Returns:

Windowed subset of data.

Return type:

Array

Examples

>>> energies = np.linspace(-10, 10, 1000)
>>> window = fractional_window(energies, fraction=0.2, around=0.0)
>>> # Returns ~200 energies centered near 0
general_python.physics.statistical.extract_indices_window(start: int, stop: int, eigvals: ndarray, energy_target: float = 0.0, bandwidth: float = 1.0, energy_diff_cut: float = 0.015, whole_spectrum: bool = False) Tuple[ndarray, int][source]

Extract indices of eigenvalue pairs (i, j) where |(E_i + E_j)/2 - E_target| < tolerance.

Optimized for computing matrix elements within energy windows, e.g., for structure factors, transition amplitudes, or response functions.

Parameters:
  • start (int) – Lower bound of the index range of eigvals to consider.

  • stop (int) – Upper bound of the index range of eigvals to consider.

  • eigvals (ndarray) – Sorted eigenvalues (ascending or descending).

  • energy_target (float) – Target energy for the window center.

  • bandwidth (float) – Bandwidth scale factor.

  • energy_diff_cut (float) – Relative tolerance: actual tolerance = bandwidth * energy_diff_cut.

  • whole_spectrum (bool) – If True, skip windowing and return empty indices.

Returns:

  • indices_alloc (ndarray of shape (N, 3)) – Each row: (i, j_start, j_end) where j in [j_start, j_end) satisfies the window.

  • count (int) – Number of valid index triplets.

Notes

Assumes eigvals is sorted. The function efficiently finds pairs within the energy window by exploiting sorted order, avoiding O(N^2) naive search.
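For fixed i, the condition on (E_i + E_j)/2 is a condition on E_j alone. The matching j therefore form a contiguous range in a sorted array, which numpy.searchsorted can locate. The sketch assumes ascending eigenvalues and ignores start, stop, and whole_spectrum. It uses the hypothetical name energy_window_pairs and illustrates the sorted-order idea rather than reproducing the library routine.

```python
import numpy as np


def energy_window_pairs(eigvals, energy_target: float = 0.0, bandwidth: float = 1.0,
                        energy_diff_cut: float = 0.015) -> np.ndarray:
    """Rows (i, j_start, j_end): |(E_i + E_j)/2 - target| < tol for j in [j_start, j_end)."""
    e = np.asarray(eigvals, dtype=float)  # assumed sorted ascending
    tol = bandwidth * energy_diff_cut
    rows = []
    for i, e_i in enumerate(e):
        # |(E_i + E_j)/2 - target| < tol  <=>  2(target - tol) - E_i < E_j < 2(target + tol) - E_i
        lo = 2.0 * (energy_target - tol) - e_i
        hi = 2.0 * (energy_target + tol) - e_i
        j_start = np.searchsorted(e, lo, side="right")  # first E_j > lo
        j_end = np.searchsorted(e, hi, side="left")     # first E_j >= hi
        if j_end > j_start:
            rows.append((i, j_start, j_end))
    return np.array(rows, dtype=np.int64).reshape(-1, 3)
```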

general_python.physics.statistical.ldos(energies: ndarray | Array, overlaps: ndarray | Array, degenerate: bool = False, tol: float = 1e-08) ndarray | Array[source]

Local Density of States (LDOS) or strength function.

LDOS_i = |<i|psi>|^2  (non-degenerate)

LDOS_i = sum_{j : |E_j - E_i| < tol} |<j|psi>|^2  (degenerate)

Parameters:
  • energies (array-like) – Eigenenergies E_n, shape (N,).

  • overlaps (array-like) – Overlap amplitudes <n|psi>, shape (N,).

  • degenerate (bool, optional) – Whether to sum over nearly degenerate levels (default: False).

  • tol (float, optional) – Tolerance for degeneracy grouping (default: 1e-8).

Returns:

LDOS for each energy index.

Return type:

Array

Notes

Use the JAX version (ldos_jax) for better performance on GPU/TPU.
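Both LDOS formulas can be sketched with NumPy. For the degenerate case, the hypothetical ldos_np builds a dense N x N closeness matrix. That approach is simple but uses O(N^2) memory, and the library routine may group levels differently.

```python
import numpy as np


def ldos_np(energies, overlaps, degenerate: bool = False, tol: float = 1e-8) -> np.ndarray:
    """|<n|psi>|^2, optionally summed over levels within tol of each other."""
    weights = np.abs(np.asarray(overlaps)) ** 2
    if not degenerate:
        return weights
    e = np.asarray(energies, dtype=float)
    # close[i, j] is 1.0 when |E_j - E_i| < tol, so row i sums the degenerate weights.
    close = (np.abs(e[:, None] - e[None, :]) < tol).astype(float)
    return close @ weights
```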

general_python.physics.statistical.ldos_jax(energies: ndarray | Array, overlaps: ndarray | Array, degenerate: bool = False, tol: float = 1e-08) ndarray | Array[source]

JAX-optimized Local Density of States (LDOS).

Parameters:
  • energies (Array) – Eigenenergies.

  • overlaps (Array) – Overlap amplitudes <n|psi >.

  • degenerate (bool) – If True, sum over nearly degenerate levels.

  • tol (float) – Tolerance for degeneracy grouping.

Returns:

LDOS for each energy index.

Return type:

Array

general_python.physics.statistical.create_bins(n_bins: int, range_min: float, range_max: float, log_scale: bool = False) ndarray[source]

Create bin edges for histogramming.

Parameters:
  • n_bins (int) – Number of bins.

  • range_min (float) – Lower edge of the binned range.

  • range_max (float) – Upper edge of the binned range.

  • log_scale (bool, optional) – If True, create logarithmically spaced bins (default: False).

Returns:

Bin edges of length n_bins + 1.

Return type:

ndarray

Examples

>>> bins = create_bins(50, 0.0, 10.0)
>>> log_bins = create_bins(50, 1e-3, 10.0, log_scale=True)
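One possible way to build such edges with NumPy, sketched as a hypothetical `create_bins_sketch`. Logarithmic edges only make sense for a strictly positive lower edge; whether create_bins itself checks this is not documented here, so the explicit check is an assumption of the sketch.

```python
import numpy as np


def create_bins_sketch(n_bins, range_min, range_max, log_scale=False):
    """Hypothetical stand-in for create_bins: returns n_bins + 1 edges."""
    if log_scale:
        if range_min <= 0:
            raise ValueError("log-spaced bins require range_min > 0")
        # Constant ratio between neighbouring edges (equal spacing in log)
        return np.geomspace(range_min, range_max, n_bins + 1)
    return np.linspace(range_min, range_max, n_bins + 1)


edges = create_bins_sketch(50, 0.0, 10.0)
log_edges = create_bins_sketch(50, 1e-3, 10.0, log_scale=True)
counts, _ = np.histogram(np.array([0.5, 1.5, 9.9]), bins=edges)
```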

general_python/physics/thermal.py

Thermal physics utilities for quantum systems.

Provides general-purpose functions for:

  • Partition functions and statistical sums

  • Thermal averages and expectation values

  • Thermodynamic quantities (free energy, entropy, heat capacity)

  • Magnetic and charge susceptibilities

  • Boltzmann weights and probability distributions

Author: Maksymilian Kliczkowski

Email: maksymilian.kliczkowski@pwr.edu.pl

general_python.physics.thermal.partition_function(energies: ndarray | Array, beta: float) float[source]

Compute the canonical partition function Z(beta) = sum_n exp(-beta E_n).

Parameters:
  • energies (array-like) – Eigenenergies E_n.

  • beta (float) – Inverse temperature beta = 1/(k_B T).

Returns:

Partition function Z(beta).

Return type:

float

Examples

>>> energies = np.array([0.0, 1.0, 2.0, 3.0])
>>> Z = partition_function(energies, beta=1.0)
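For large beta or large |E_n|, evaluating exp(-beta E_n) directly can overflow or underflow. A common remedy, sketched here with a hypothetical `log_partition_function`, is to shift by the lowest energy before exponentiating (the log-sum-exp trick). Whether partition_function does this internally is not stated in this reference.

```python
import numpy as np


def log_partition_function(energies, beta):
    """Hypothetical helper: ln Z with a ground-state shift to avoid overflow."""
    e = np.asarray(energies, dtype=float)
    e0 = e.min()
    # Z = exp(-beta*E0) * sum_n exp(-beta*(E_n - E0)); every exponent is <= 0
    return -beta * e0 + np.log(np.sum(np.exp(-beta * (e - e0))))


energies = np.array([0.0, 1.0, 2.0, 3.0])
lnZ = log_partition_function(energies, beta=1.0)
# With E ~ -1000, exp(-beta*E) alone would overflow to inf
lnZ_big = log_partition_function(np.array([-1000.0, -999.0]), beta=1.0)
```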
general_python.physics.thermal.boltzmann_weights(energies: ndarray | Array, beta: float, normalize: bool = True) ndarray | Array[source]

Compute Boltzmann weights rho_n = exp(-beta E_n) / Z.

Parameters:
  • energies (array-like) – Eigenenergies E_n.

  • beta (float) – Inverse temperature beta = 1/(k_B T).

  • normalize (bool, optional) – If True, normalize by partition function Z (default: True).

Returns:

Boltzmann weights (probabilities if normalized).

Return type:

Array

Examples

>>> energies = np.array([0.0, 1.0, 2.0])
>>> rho = boltzmann_weights(energies, beta=1.0)
>>> print(np.sum(rho))  # 1.0 up to floating-point roundoff
general_python.physics.thermal.thermal_average_diagonal(energies: ndarray | Array, observable_diagonal: ndarray | Array, beta: float) Tuple[float, float][source]

Compute thermal average of an operator diagonal in the energy basis.

<O>_beta = Tr[exp(-beta H) O] / Z = sum_n O_nn exp(-beta E_n) / Z

Parameters:
  • energies (array-like) – Eigenenergies E_n.

  • observable_diagonal (array-like) – Diagonal matrix elements O_nn of the observable.

  • beta (float) – Inverse temperature beta = 1/(k_B T).

Returns:

  • average (float) – Thermal average <O>_beta.

  • partition_func (float) – Partition function Z(beta).

Examples

>>> energies = np.array([0.0, 1.0, 2.0])
>>> magnetization = np.array([1.0, 0.5, -0.5])
>>> avg_M, Z = thermal_average_diagonal(energies, magnetization, beta=1.0)
general_python.physics.thermal.thermal_average_general(energies: ndarray | Array, eigenvectors: ndarray | Array, observable_matrix: ndarray | Array | spmatrix, beta: float) Tuple[float, float][source]

Compute thermal average of a general operator.

<O>_beta = sum_n <n|O|n> exp(-beta E_n) / Z

where |n> are energy eigenstates.

Parameters:
  • energies (array-like) – Eigenenergies E_n.

  • eigenvectors (array-like) – Matrix of eigenvectors (columns are eigenstates).

  • observable_matrix (array-like or sparse matrix) – Operator matrix in the original basis.

  • beta (float) – Inverse temperature beta = 1/(k_B T).

Returns:

  • average (float) – Thermal average <O>_beta.

  • partition_func (float) – Partition function Z(beta).

Notes

This function transforms the observable to the energy basis and computes the diagonal elements <n|O|n>.
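The basis change described in the note can be written as a single einsum over the eigenvector matrix. This is a minimal dense-matrix sketch (`thermal_average_general_sketch` is hypothetical, not the library function); a sparse observable would first be applied as O @ V and then contracted with V.conj().

```python
import numpy as np


def thermal_average_general_sketch(energies, eigenvectors, observable, beta):
    """Hypothetical sketch: thermal average of a dense operator via <n|O|n>."""
    e = np.asarray(energies, dtype=float)
    V = np.asarray(eigenvectors)
    O = np.asarray(observable)
    # Diagonal of V^dag O V: <n|O|n> = sum_ij conj(V_in) O_ij V_jn
    diag = np.einsum("in,ij,jn->n", V.conj(), O, V).real
    w = np.exp(-beta * (e - e.min()))        # shifted Boltzmann factors
    Z_shifted = w.sum()
    average = float(np.dot(w, diag) / Z_shifted)
    Z = float(Z_shifted * np.exp(-beta * e.min()))
    return average, Z


H = np.array([[0.0, 1.0], [1.0, 0.0]])     # eigenvalues -1 and +1
energies, V = np.linalg.eigh(H)
Sz = np.diag([0.5, -0.5])
avg_H, Z = thermal_average_general_sketch(energies, V, H, beta=1.0)
avg_Sz, _ = thermal_average_general_sketch(energies, V, Sz, beta=1.0)
```

Taking O = H recovers the internal energy, here -tanh(beta), which is a convenient sanity check.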

general_python.physics.thermal.free_energy(energies: ndarray | Array, beta: float) float[source]

Compute Helmholtz free energy F = -k_B T ln Z = -(1/beta) ln Z.

Parameters:
  • energies (array-like) – Eigenenergies E_n.

  • beta (float) – Inverse temperature beta = 1/(k_B T).

Returns:

Free energy F.

Return type:

float

Notes

We set k_B = 1 (natural units).

general_python.physics.thermal.internal_energy(energies: ndarray | Array, beta: float) float[source]

Compute internal energy U = <H> = sum_n E_n exp(-beta E_n) / Z.

Parameters:
  • energies (array-like) – Eigenenergies E_n.

  • beta (float) – Inverse temperature beta = 1/(k_B T).

Returns:

Internal energy U.

Return type:

float

general_python.physics.thermal.heat_capacity(energies: ndarray | Array, beta: float) float[source]

Compute heat capacity C_V = beta^2 (<H^2> - <H>^2).

Parameters:
  • energies (array-like) – Eigenenergies E_n.

  • beta (float) – Inverse temperature beta = 1/(k_B T).

Returns:

Heat capacity C_V.

Return type:

float

Notes

Uses the fluctuation-dissipation relation: C_V is obtained from the energy variance <H^2> - <H>^2.

general_python.physics.thermal.entropy_thermal(energies: ndarray | Array, beta: float) float[source]

Compute thermal entropy S = k_B (ln Z + beta U) = beta(U - F).

Parameters:
  • energies (array-like) – Eigenenergies E_n.

  • beta (float) – Inverse temperature beta = 1/(k_B T).

Returns:

Thermal entropy S.

Return type:

float

Notes

We set k_B = 1 (natural units).
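The identities above can be checked against the closed-form two-level system with gap Delta, for which U = Delta / (e^{beta Delta} + 1) and C_V = (beta Delta)^2 e^{beta Delta} / (1 + e^{beta Delta})^2. The helper `thermo` is a hypothetical sketch with k_B = 1, not the library API.

```python
import numpy as np


def thermo(energies, beta):
    """Hypothetical helper returning F, U, S, C_V with k_B = 1."""
    e = np.asarray(energies, dtype=float)
    e0 = e.min()
    w = np.exp(-beta * (e - e0))
    Zs = w.sum()
    p = w / Zs                                  # normalized Boltzmann weights
    U = float(np.dot(p, e))
    U2 = float(np.dot(p, e ** 2))
    F = float(e0 - np.log(Zs) / beta)           # F = -(1/beta) ln Z
    S = beta * (U - F)                          # S = beta (U - F)
    C_V = beta ** 2 * (U2 - U ** 2)             # energy-fluctuation formula
    return F, U, S, C_V


gap, beta = 1.0, 2.0
F, U, S, C_V = thermo([0.0, gap], beta)
```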

general_python.physics.thermal.magnetic_susceptibility(energies: ndarray | Array, magnetization_diagonal: ndarray | Array, beta: float) float[source]

Compute magnetic susceptibility chi_M = beta (<M^2> - <M>^2).

Parameters:
  • energies (array-like) – Eigenenergies E_n.

  • magnetization_diagonal (array-like) – Diagonal matrix elements M_nn of magnetization operator.

  • beta (float) – Inverse temperature beta = 1/(k_B T).

Returns:

Magnetic susceptibility chi_M.

Return type:

float

Notes

This is the linear response of the magnetization to an applied field.
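Because the susceptibility is a linear response, beta (<M^2> - <M>^2) should match the derivative d<M>/dh obtained by shifting the levels as E_n(h) = E_n - h M_n. This sketch checks that numerically; it assumes M commutes with H (so only its diagonal matters), and `thermal_mean` is a hypothetical helper.

```python
import numpy as np


def thermal_mean(energies, diag, beta):
    """Hypothetical helper: Boltzmann-weighted mean of a diagonal observable."""
    e = np.asarray(energies, dtype=float)
    w = np.exp(-beta * (e - e.min()))
    return float(np.dot(w, diag) / w.sum())


energies = np.array([0.0, 0.3, 0.7, 1.2])
M = np.array([1.0, 0.0, -1.0, 2.0])   # diagonal of M, assumed to commute with H
beta = 1.5

# Fluctuation formula: chi = beta * (<M^2> - <M>^2)
chi_fluct = beta * (thermal_mean(energies, M ** 2, beta)
                    - thermal_mean(energies, M, beta) ** 2)

# Linear response: chi = d<M>/dh with E_n(h) = E_n - h * M_n (central difference)
h = 1e-5
chi_resp = (thermal_mean(energies - h * M, M, beta)
            - thermal_mean(energies + h * M, M, beta)) / (2 * h)
```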

general_python.physics.thermal.charge_susceptibility(energies: ndarray | Array, charge_diagonal: ndarray | Array, beta: float) float[source]

Compute charge susceptibility chi_c = beta (<N^2> - <N>^2).

Parameters:
  • energies (array-like) – Eigenenergies E_n.

  • charge_diagonal (array-like) – Diagonal matrix elements N_nn of particle number operator.

  • beta (float) – Inverse temperature beta = 1/(k_B T).

Returns:

Charge susceptibility chi_c.

Return type:

float

Notes

Related to the compressibility; equivalently, chi_c = beta <(delta N)^2> with delta N = N - <N>.

general_python.physics.thermal.specific_heat_from_moments(avg_H: float, avg_H2: float, beta: float) float[source]

Compute specific heat from energy moments: C_V = beta^2 (<H^2> - <H>^2).

Parameters:
  • avg_H (float) – Average energy <H>.

  • avg_H2 (float) – Average energy squared <H^2>.

  • beta (float) – Inverse temperature beta = 1/(k_B T).

Returns:

Specific heat C_V.

Return type:

float

general_python.physics.thermal.susceptibility_from_moments(avg_O: float, avg_O2: float, beta: float) float[source]

Generic susceptibility from moments: chi = beta (<O^2> - <O>^2).

Parameters:
  • avg_O (float) – Average observable <O>.

  • avg_O2 (float) – Average observable squared <O^2>.

  • beta (float) – Inverse temperature beta = 1/(k_B T).

Returns:

Susceptibility chi.

Return type:

float

general_python.physics.thermal.thermal_scan(energies: ndarray | Array, temperatures: ndarray | Array, observables: dict | None = None) dict[source]

Scan thermal quantities over a range of temperatures.

Parameters:
  • energies (array-like) – Eigenenergies E_n.

  • temperatures (array-like) – Array of temperatures T.

  • observables (dict, optional) – Dictionary mapping observable names to diagonal elements, e.g. {'M_z': magnetization_diagonal, 'N': charge_diagonal}.

Returns:

Dictionary containing:

  • 'T' : temperatures

  • 'beta' : inverse temperatures

  • 'F' : free energies

  • 'U' : internal energies

  • 'S' : entropies

  • 'C_V' : heat capacities

  • for each observable: its average and susceptibility

Return type:

dict

Examples

>>> energies = np.array([0.0, 1.0, 2.0])
>>> temps = np.linspace(0.1, 10.0, 100)
>>> observables = {'M': np.array([1.0, 0.0, -1.0])}
>>> results = thermal_scan(energies, temps, observables)
>>> plt.plot(results['T'], results['C_V'])
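A temperature scan can be vectorized by broadcasting the Boltzmann factors into an (n_T, N) array. This hypothetical `thermal_scan_sketch` returns the same keys as listed above except the per-observable entries, which are omitted for brevity; it is not the library implementation.

```python
import numpy as np


def thermal_scan_sketch(energies, temperatures):
    """Hypothetical vectorized scan (k_B = 1); observables are omitted."""
    e = np.asarray(energies, dtype=float)
    T = np.asarray(temperatures, dtype=float)
    beta = 1.0 / T
    de = e - e.min()
    # Shape (n_T, N): one row of shifted Boltzmann factors per temperature
    w = np.exp(-beta[:, None] * de[None, :])
    Zs = w.sum(axis=1)
    p = w / Zs[:, None]
    U = p @ e
    U2 = p @ e ** 2
    F = e.min() - np.log(Zs) / beta
    S = beta * (U - F)
    C_V = beta ** 2 * (U2 - U ** 2)
    return {"T": T, "beta": beta, "F": F, "U": U, "S": S, "C_V": C_V}


res = thermal_scan_sketch(np.array([0.0, 1.0, 2.0]), np.linspace(0.1, 10.0, 100))
```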

Eigenlevel statistics and entropy calculators for quantum systems.

This module contains tools for:

  • Reduced density matrix calculation (direct or Schmidt decomposition).

  • Entanglement entropy (von Neumann).

  • Level statistics (gap ratios).

  • Statistical measures of eigenstates (participation ratio, moments).

Input/Output Contracts

  • States are typically NumPy arrays of shape (basis size,) for a single state or (basis size, number of states) for several states.

  • Reduced density matrices return shape (dimA, dimA).

  • Entropies return scalar floats.

  • Gap ratios return a dictionary with mean, std, and raw values.

Numerical Stability

Entropy calculations handle small eigenvalues by clipping or conditional checks to avoid log(0). Schmidt decomposition is preferred for reduced density matrices when possible for stability and efficiency.

general_python.physics.eigenlevels.reduced_density_matrix(state: ndarray, A_size: int, L: int)[source]

Calculate the reduced density matrix of a state.

general_python.physics.eigenlevels.reduced_density_matrix_schmidt(state: ndarray, L: int, La: int)[source]

Calculate the reduced density matrix via the Schmidt decomposition.

general_python.physics.eigenlevels.entropy_vonNeuman(state: ndarray, L: int, La: int, TYP='SCHMIDT')[source]

Calculate the bipartite entanglement entropy.
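A minimal sketch of the reshape route to the reduced density matrix and of the log(0) guard mentioned under Numerical Stability. The helper names are hypothetical, and the bit ordering (first La sites as the most significant bits of the basis index) is an assumption that may differ from the library's convention.

```python
import numpy as np


def reduced_density_matrix_sketch(state, L, La):
    """Hypothetical helper; assumes the first La sites are the most significant bits."""
    psi = np.asarray(state).reshape(2 ** La, 2 ** (L - La))
    return psi @ psi.conj().T                   # rho_A = Tr_B |psi><psi|


def von_neumann_entropy(rho, eps=1e-14):
    lam = np.linalg.eigvalsh(rho)
    lam = lam[lam > eps]                        # drop (numerically) zero eigenvalues to avoid log(0)
    return float(-np.sum(lam * np.log(lam)))


bell = np.zeros(4)
bell[0] = bell[3] = 1 / np.sqrt(2)              # (|00> + |11>)/sqrt(2)
product = np.zeros(4)
product[0] = 1.0                                # |00>
rho_bell = reduced_density_matrix_sketch(bell, L=2, La=1)
```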

general_python.physics.eigenlevels.gap_ratio(en: ndarray, fraction=0.3, use_mean_lvl_spacing=True)[source]
Calculate the gap ratio of the eigenvalues as:

$\gamma_n = \frac{\min(\Delta_n, \Delta_{n+1})}{\max(\Delta_n, \Delta_{n+1})}$

where $\Delta_n = E_{n+1} - E_n$ is the spacing between consecutive levels.

  • en : eigenvalues

  • fraction : fraction of the eigenvalues to use

  • use_mean_lvl_spacing : divide by mean level spacing
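A sketch of the gap-ratio statistic with a hypothetical `mean_gap_ratio` that keeps the central `fraction` of the sorted spectrum; how gap_ratio itself selects that fraction is an assumption here. Since r is a ratio of neighbouring gaps, it does not require spectral unfolding. Uncorrelated (Poisson) levels give <r> = 2 ln 2 - 1 ≈ 0.386, while GOE spectra give about 0.53. Degenerate levels (zero gaps) produce 0/0 and would have to be removed first.

```python
import numpy as np


def mean_gap_ratio(energies, fraction=0.3):
    """Hypothetical sketch: mean gap ratio over the central `fraction` of the spectrum."""
    e = np.sort(np.asarray(energies, dtype=float))
    n = len(e)
    keep = max(int(fraction * n), 3)
    lo = (n - keep) // 2
    e = e[lo:lo + keep]
    gaps = np.diff(e)                          # Delta_n = E_{n+1} - E_n
    r = np.minimum(gaps[:-1], gaps[1:]) / np.maximum(gaps[:-1], gaps[1:])
    return float(r.mean()), float(r.std()), r


rng = np.random.default_rng(0)
poisson_levels = np.cumsum(rng.exponential(size=20000))   # uncorrelated spectrum
mean_r, std_r, r = mean_gap_ratio(poisson_levels, fraction=1.0)
```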

general_python.physics.eigenlevels.mean_entropy(df: DataFrame, row: int)[source]

Calculate the average entropy in a given DataFrame.

  • df : DataFrame with entropies

  • row : row number (-1 selects the half-system bipartition)

class general_python.physics.eigenlevels.HamiltonianProperties[source]

Bases: object

Namespace for analytic Hamiltonian and spectral-property helpers.

static hilbert_schmidt_norm(mat: ndarray)[source]

Compute the Hilbert-Schmidt norm of the matrix.

Parameters:

mat (np.ndarray) – Matrix to calculate the norm of.

Returns:

The Hilbert-Schmidt norm of the matrix.

class general_python.physics.eigenlevels.StatMeasures[source]

Bases: object

Namespace for statistical measures of spectra and eigenstates.

static moments(arr: ndarray, axis=None)[source]

Calculate the moments of the array.

  • arr : array to calculate the moments of

  • axis : axis along which to calculate the moments

static gaussianity(arr: ndarray, axis=None)[source]

Calculate the gaussianity <|Oab|^2>/<|Oab|>^2, which equals pi/2 for a zero-mean normal distribution.

  • arr : array to calculate the gaussianity of

  • axis : axis along which to calculate the gaussianity

static binder_cumulant(arr: ndarray, axis=None)[source]

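Calculate the Binder cumulant from the ratio <|Oab|^4>/(3 * <|Oab|^2>^2). For a zero-mean normal distribution this ratio equals 1, so the conventional cumulant U = 1 - <|Oab|^4>/(3 * <|Oab|^2>^2) vanishes; U = 2/3 corresponds to a sharply bimodal (ordered) distribution.

  • arr : array to calculate the Binder cumulant of

  • axis : axis along which to calculate the Binder cumulant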

static modulus_fidelity(states: ndarray)[source]

Calculate the modulus fidelity, which equals 2/pi for Gaussian-distributed coefficients.

  • states : np.ndarray of eigenstates
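The reference values quoted for StatMeasures can be checked on Gaussian samples. This sketch recomputes the ratios directly with NumPy rather than calling the library; the modulus fidelity is assumed here to be <|c|>^2 / <|c|^2>, the inverse of the gaussianity ratio, which gives 2/pi for Gaussian coefficients. For zero-mean Gaussian data the ratio <x^4> / (3 <x^2>^2) is 1.

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(size=200_000)          # stand-in for matrix elements O_ab

abs_x = np.abs(x)
gaussianity = np.mean(abs_x ** 2) / np.mean(abs_x) ** 2               # expected pi/2
binder_ratio = np.mean(abs_x ** 4) / (3 * np.mean(abs_x ** 2) ** 2)   # expected 1
# Assumed definition of the modulus fidelity (not taken from the library)
modulus_fid = np.mean(abs_x) ** 2 / np.mean(abs_x ** 2)               # expected 2/pi
```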

general_python.physics.eigenlevels.info_entropy(states: ndarray, model_info: str)[source]

Calculate the information entropy for given states.

This file contains the EntanglementModule class.

email : maksymilian.kliczkowski@pwr.edu.pl

Unified entanglement calculation module for both quadratic and many-body Hamiltonians.

Features

  • Single-particle correlation matrix methods (fast, for quadratic/non-interacting Hamiltonians)

  • Many-body reduced density matrix methods (exact, for any state)

  • Arbitrary bipartitions (contiguous and non-contiguous subsystems)

  • Multipartite entropy calculations (topological entanglement entropy)

  • Wick’s theorem verification for quadratic systems

  • JAX backend for GPU acceleration

  • Mask generation utilities for subsystem selection

Theoretical Background

For quadratic (non-interacting) Hamiltonians, the entanglement entropy can be computed efficiently from the single-particle correlation matrix C_ij = <c_i^dag c_j>:

S = -Tr[C log C + (1-C) log(1-C)]

This scales as O(L_A^3) for a subsystem of L_A sites, compared with the exponential O(2^L) scaling of many-body approaches.

For many-body states, we use Schmidt decomposition of the wavefunction:

|psi> = sum_i sqrt(lambda_i) |i_A> |i_B>

S = -sum_i lambda_i log(lambda_i)
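The two formulas above can be compared on the smallest nontrivial example: one fermion delocalized over two sites, with subsystem A = site 0. Both routes should give ln 2. This is a hedged sketch with hypothetical helpers; it uses occupations in [0, 1] (no spin factor of 2) and a basis ordering with site 0 as the most significant bit, both assumptions of the example rather than documented conventions.

```python
import numpy as np


def entropy_from_correlation(C_A, eps=1e-12):
    # Hypothetical helper; assumes occupations n_k in [0, 1] (no spin factor of 2)
    n = np.clip(np.linalg.eigvalsh(C_A), eps, 1 - eps)
    return float(-np.sum(n * np.log(n) + (1 - n) * np.log(1 - n)))


def entropy_schmidt_contiguous(psi, L, La):
    # Hypothetical helper; site 0 is assumed to be the most significant bit
    s = np.linalg.svd(psi.reshape(2 ** La, 2 ** (L - La)), compute_uv=False)
    lam = s[s > 1e-12] ** 2                     # Schmidt weights lambda_i
    return float(-np.sum(lam * np.log(lam)))


# One fermion in the orbital phi = (1, 1)/sqrt(2) on two sites
phi = np.array([1.0, 1.0]) / np.sqrt(2)
C = np.outer(phi.conj(), phi)                   # C_ij = <c_i^dag c_j> = conj(phi_i) phi_j
S_corr = entropy_from_correlation(C[:1, :1])    # subsystem A = site 0

# Same state in the occupation basis |n0 n1>: (|10> + |01>)/sqrt(2)
psi = np.zeros(4)
psi[0b01] = psi[0b10] = 1 / np.sqrt(2)
S_mb = entropy_schmidt_contiguous(psi, L=2, La=1)
```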

Topological Entanglement Entropy (TEE)

For topological phases, the entanglement entropy follows:

S(A) = alpha * L - gamma + O(1/L)

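where L is the boundary length of region A and gamma is the topological entanglement entropy. In the Kitaev-Preskill construction (three regions A, B, C meeting at a point; the Levin-Wen construction uses annular regions instead), the boundary terms cancel in the combination

-gamma = S_A + S_B + S_C - S_AB - S_BC - S_AC + S_ABC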

Examples

Basic usage with quadratic Hamiltonians:
>>> hamil   = QuadraticHamiltonian(ns=12, ...)
>>> hamil.diagonalize()
>>> ent     = hamil.entanglement
>>>
>>> # Define bipartition and calculate entropy
>>> bipart      = ent.bipartition([0, 1, 2, 3])
>>> orbitals    = [0, 1, 2, 3, 4, 5]  # occupied states
>>> S = ent.entropy_correlation(bipart, orbitals)
Access correlation matrices:
>>> C_full      = ent.correlation_matrix(orbitals)
>>> C_A         = ent.correlation_matrix(orbitals, bipartition=bipart)
Batch calculations:
>>> results     = ent.entropy_multipartition(
...     bipartitions=[[0,1], [0,1,2], [0,1,2,3]],
...     occupied_orbitals=orbitals
... )
>>> entropies               = results['entropies']
>>> correlation_matrices    = results['correlation_matrices']
JAX backend for GPU acceleration:
>>> S_jax = ent.entropy_correlation(bipart, orbitals, backend='jax')
>>> C_jax = ent.correlation_matrix(orbitals, backend='jax')
Mask generation utilities:
>>> masks = MaskGenerator.contiguous(ns=12, size_a=4)  # First 4 sites
>>> masks = MaskGenerator.alternating(ns=12)           # Even/odd sites
>>> masks = MaskGenerator.random(ns=12, size_a=6)      # Random 6 sites
>>> masks = MaskGenerator.kitaev_preskill(ns=12)       # ABC regions for TEE
Topological entanglement entropy:
>>> gamma = ent.topological_entropy(orbitals, construction='kitaev_preskill')
Wick’s theorem verification:
>>> is_valid, error = ent.verify_wicks_theorem(orbitals, state)
Manual many-body entropy calculations:
>>> bipart      = ent.bipartition([0, 1, 2, 3])
>>> S_manual    = ent.entropy_many_body(bipart, state=state)
class general_python.physics.entanglement_module.MaskGenerator[source]

Bases: object

Utility class for generating subsystem masks for entanglement calculations.

Provides convenient methods to create site masks for various bipartition geometries, including contiguous, alternating, random, and topological (Kitaev-Preskill) constructions.

Examples

Basic contiguous mask:
>>> mask_a = MaskGenerator.contiguous(ns=12, size_a=4)
>>> mask_a               # array([0, 1, 2, 3])
Alternating (even/odd) sites:
>>> mask_even, mask_odd = MaskGenerator.alternating(ns=12)
>>> mask_even            # array([0, 2, 4, 6, 8, 10])
Random subsystem:
>>> mask = MaskGenerator.random(ns=12, size_a=6, seed=42)
For topological entanglement entropy (Kitaev-Preskill construction):
>>> regions = MaskGenerator.kitaev_preskill(ns=12)
>>> A, B, C = regions['A'], regions['B'], regions['C']
static contiguous(ns: int, size_a: int, start: int = 0) ndarray[source]

Create a contiguous subsystem mask [start, start+1, …, start+size_a-1].

Parameters:
  • ns (int) – Total number of sites

  • size_a (int) – Size of subsystem A

  • start (int) – Starting site index (default: 0)

Returns:

Array of site indices in subsystem A

Return type:

np.ndarray

static alternating(ns: int, offset: int = 0) Tuple[ndarray, ndarray][source]

Create alternating (even/odd) site masks.

Parameters:
  • ns (int) – Total number of sites

  • offset (int) – Offset for even sites (0 = sites 0,2,4,…; 1 = sites 1,3,5,…)

Returns:

(mask_even, mask_odd) site index arrays

Return type:

Tuple[np.ndarray, np.ndarray]

static every_n(ns: int, n: int, start: int = 0) ndarray[source]

Create a mask selecting every n-th site.

Parameters:
  • ns (int) – Total number of sites

  • n (int) – Step size (select every n-th site)

  • start (int) – Starting site index (default: 0)

Returns:

Array containing every n-th site index, beginning at start

Return type:

np.ndarray

static random(ns: int, size_a: int, seed: int | None = None) ndarray[source]

Create a random subsystem mask.

Parameters:
  • ns (int) – Total number of sites

  • size_a (int) – Size of subsystem A

  • seed (int, optional) – Random seed for reproducibility

Returns:

Sorted array of randomly selected site indices

Return type:

np.ndarray

static periodic_interval(ns: int, start: int, size_a: int) ndarray[source]

Create a contiguous mask with periodic boundary conditions.

Parameters:
  • ns (int) – Total number of sites

  • start (int) – Starting site index

  • size_a (int) – Size of subsystem A

Returns:

Sorted array of site indices (wrapped around if necessary)

Return type:

np.ndarray
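Minimal NumPy sketches of periodic_interval and sublattice for a 1D site ordering. The helpers are hypothetical; in particular, assigning site i to sublattice i % n_sublattices is an assumption about how sublattices are labelled, not documented behaviour.

```python
import numpy as np


def periodic_interval_sketch(ns, start, size_a):
    # Hypothetical helper: wrap indices around the ring, then sort them
    return np.sort((start + np.arange(size_a)) % ns)


def sublattice_sketch(ns, sublattice_id=0, n_sublattices=2):
    # Hypothetical helper: assumes site i belongs to sublattice i % n_sublattices
    return np.arange(sublattice_id, ns, n_sublattices)


wrapped = periodic_interval_sketch(ns=8, start=6, size_a=4)
b_sites = sublattice_sketch(ns=8, sublattice_id=1)
```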

static sublattice(ns: int, sublattice_id: int = 0, n_sublattices: int = 2) ndarray[source]

Create a sublattice mask (e.g., A/B sublattices in bipartite lattices).

Parameters:
  • ns (int) – Total number of sites

  • sublattice_id (int) – Which sublattice (0, 1, …, n_sublattices-1)

  • n_sublattices (int) – Total number of sublattices (default: 2 for bipartite)

Returns:

Array of site indices in the specified sublattice

Return type:

np.ndarray

static kitaev_preskill(ns: int, center: int | None = None) Dict[str, ndarray][source]

Generate regions A, B, C for Kitaev-Preskill topological entanglement entropy.

The Kitaev-Preskill construction divides the system into three regions meeting at a point. Assuming S(R) = alpha * L_R - gamma for every region R, the boundary terms cancel and the topological entanglement entropy follows from:

-gamma = S_A + S_B + S_C - S_AB - S_BC - S_AC + S_ABC

Parameters:
  • ns (int) – Total number of sites (should be divisible by 3 for equal regions)

  • center (int, optional) – Central site index (default: ns // 2)

Returns:

Dictionary with keys ‘A’, ‘B’, ‘C’, ‘AB’, ‘BC’, ‘AC’, ‘ABC’ containing site index arrays for each region

Return type:

Dict[str, np.ndarray]

Notes

For 1D chains, regions are consecutive thirds of the chain. For 2D systems, you should define regions based on geometry.

References

  • Kitaev & Preskill, PRL 96, 110404 (2006)

  • Levin & Wen, PRL 96, 110405 (2006)

static levin_wen_disk(ns: int, n_annuli: int = 3) Dict[str, ndarray][source]

Generate annular regions for Levin-Wen construction.

For a disk geometry, creates concentric annuli to extract topological entanglement entropy with area law subtraction.

Parameters:
  • ns (int) – Total number of sites

  • n_annuli (int) –

    Number of concentric annuli (default: 3). The annuli represent nested regions of increasing size:

    • 1 annulus: inner region only

    • 2 annuli: inner + middle regions

    • 3 annuli: inner + middle + outer regions

    Example:

    • n_annuli=1: ‘inner’ region
      • S_inner = alpha * L_inner - gamma

    • n_annuli=2: ‘inner’, ‘middle’, ‘inner_middle’ regions
      • S_inner = alpha * L_inner - gamma

      • S_middle = alpha * L_middle - gamma

      • S_inner_middle = alpha * (L_inner + L_middle) - gamma

Returns:

Dictionary with ‘inner’, ‘middle’, ‘outer’, and combined regions

Return type:

Dict[str, np.ndarray]

static from_bitmask(mask_int: int, ns: int) ndarray[source]

Convert an integer bitmask to an array of site indices.

Parameters:
  • mask_int (int) – Integer whose bits indicate included sites

  • ns (int) – Total number of sites

Returns:

Array of site indices where bits are set

Return type:

np.ndarray

Example

>>> MaskGenerator.from_bitmask(0b1010, ns=4)
array([1, 3])
static to_bitmask(indices: ndarray) int[source]

Convert an array of site indices to an integer bitmask.

Parameters:

indices (np.ndarray) – Array of site indices

Returns:

Integer bitmask with bits set at specified positions

Return type:

int

Example

>>> MaskGenerator.to_bitmask(np.array([1, 3]))
10  # = 0b1010
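Both conversions follow from the documented examples (bit i set means site i is included). A short sketch with hypothetical helper names, not the library's implementation:

```python
import numpy as np


def from_bitmask_sketch(mask_int, ns):
    # Hypothetical helper: bit i of mask_int set <=> site i is included
    return np.array([i for i in range(ns) if (mask_int >> i) & 1], dtype=int)


def to_bitmask_sketch(indices):
    # Hypothetical helper: set bit i for every site index i
    mask = 0
    for i in np.asarray(indices, dtype=int):
        mask |= 1 << int(i)
    return mask


sites = from_bitmask_sketch(0b1010, ns=4)
mask = to_bitmask_sketch(sites)
```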
class general_python.physics.entanglement_module.BipartitionInfo(mask_a: ndarray, mask_b: ndarray, size_a: int, size_b: int, order: tuple, extractor_a: Callable, extractor_b: Callable)[source]

Bases: object

Information about a bipartition of the system.

mask_a: ndarray
mask_b: ndarray
size_a: int
size_b: int
order: tuple
extractor_a: Callable
extractor_b: Callable
__init__(mask_a: ndarray, mask_b: ndarray, size_a: int, size_b: int, order: tuple, extractor_a: Callable, extractor_b: Callable) None
class general_python.physics.entanglement_module.EntanglementModule(operator)[source]

Bases: object

Entanglement calculation module for Hamiltonians.

Provides a unified interface for calculating entanglement entropy using:

  • Single-particle correlation matrices (quadratic Hamiltonians, fast). This path is not yet optimized for very large systems, where building the full site index array np.arange(ns) can become costly.

  • Many-body reduced density matrices (any state, exact).

Features:

  • Wick’s theorem verification

  • Topological entanglement entropy (Kitaev-Preskill, Levin-Wen)

  • Manual bipartition handling

  • Multipartite entropy calculations

  • Symmetry sector support

  • JAX backend for GPU acceleration

  • Batch calculations for multiple bipartitions

Arbitrary bipartitions are handled automatically, including non-contiguous subsystems, which are harder to treat than contiguous ones.

Examples

  1. Quadratic Hamiltonian (non-interacting):
    • Basic entropy calculation:

>>> hamil       = QuadraticHamiltonian(ns=12, ...)
>>> hamil.diagonalize()
>>> ent         = hamil.entanglement
>>> bipart      = ent.bipartition([0, 1, 2, 3])                             # subsystem A
>>> orbitals    = [0, 1, 2, 3, 4, 5]                                        # occupied quasi-particle states
>>> S           = ent.entropy_correlation(bipart, orbitals)                 # entropy from correlation matrix
  • Access correlation matrices themselves:

>>> C_full      = ent.correlation_matrix(orbitals)                          # (ns, ns)
>>> C_A         = ent.correlation_matrix(orbitals, bipartition=bipart)      # (4, 4)
  • Batch calculations:

>>> results     = ent.entropy_multipartition(
...             [[0,1], [0,1,2], [0,1,2,3]],
...             orbitals
...             )                                                           # computes entropies for 3 bipartitions
>>> entropies   = results['entropies']                                      # array of 3 entropies for each bipartition
>>> C_matrices  = results['correlation_matrices']                           # list of 3 matrices
  • JAX backend:

>>> S_jax       = ent.entropy_correlation(bipart, orbitals, backend='jax')
>>> results_jax = ent.entropy_multipartition(
...             [[0,1], [0,1,2]], orbitals, backend='jax'
...             )
  • Mutual information:

>>> I_AB        = ent.mutual_information([0,1,2], [3,4,5], occupied_orbitals=orbitals)  # I(A:B) = S_A + S_B - S_AB with A=[0,1,2], B=[3,4,5]
  • Entropy scaling:

>>> results     = ent.entropy_scan(occupied_orbitals=orbitals, subsystem_sizes=[1,2,3,4,5])  # subsystem sizes 1 to 5, consecutive sites starting from 0
  2. Many-body Hamiltonian (interacting):
    • Manual bipartition entropy:

>>> hamil       = ManyBodyHamiltonian(ns=8, ...)
>>> hamil.diagonalize()
>>> ent         = hamil.entanglement
>>> bipart      = ent.bipartition([0, 1, 2, 3])                             # subsystem A
>>> state       = hamil.eig_vec[:, 0]                                       # ground state wavefunction
>>> S_manual    = ent.entropy_many_body(bipart, state=state)                # the reduced density matrix is built internally
  • Density matrix access:

>>> rho_A       = ent.reduced_density_matrix(bipart, state)                 # reduced density matrix for subsystem A
>>> rho_B       = ent.reduced_density_matrix(bipart, state, subsystem='B')  # for subsystem B
__init__(operator)[source]

Initialize entanglement module for a Hamiltonian.

Parameters:

operator (object) – The Operator object (quadratic or many-body)

bipartition(mask_a: List[int] | ndarray | int, *, cache: bool = True) BipartitionInfo[source]

Create bipartition information for subsystem A.

Parameters:
  • mask_a (array-like or int) – Indices of sites in subsystem A, or number of sites in A (takes first N sites).

  • cache (bool) – Whether to cache the bipartition for reuse

Returns:

Information about the bipartition

Return type:

BipartitionInfo

Examples

>>> # Contiguous partition
>>> bipart = ent.bipartition(5)                 # First 5 sites
>>>
>>> # Non-contiguous partition
>>> bipart = ent.bipartition([0, 2, 4, 6, 8])   # Even sites
correlation_matrix(occupied_orbitals: List[int] | ndarray, *, bipartition: BipartitionInfo | None = None, subtract_identity: bool = False, raw: bool = True, mode: Literal['slater', 'BdG'] = 'slater', backend: str = 'numpy', **kwargs) ndarray[source]

Get single-particle correlation matrix C_ij = <c_i^dag c_j>.

Computes the correlation matrix for a free-fermion state defined by occupied orbitals. Uses the spin-unpolarized convention (factor of 2).

Parameters:
  • occupied_orbitals (array-like) – Indices of occupied orbitals (in eigenstate basis). For ground state, use [0, 1, …, N-1] for N particles.

  • bipartition (BipartitionInfo, optional) – If provided, returns correlation matrix restricted to subsystem A. If None, returns full correlation matrix for all sites.

  • subtract_identity (bool) – Whether to subtract identity from the correlation matrix.

  • backend (str) – ‘numpy’ or ‘jax’ for GPU acceleration.

Returns:

Correlation matrix C_ij = <c_i^dag c_j>. Shape: (size_a, size_a) if bipartition is given, else (ns, ns). Diagonal elements are site occupations, in the range [0, 2] because of the factor 2.

Return type:

np.ndarray or jax.numpy.ndarray

Examples

  • Full correlation matrix for ground state:

>>> hamil   = QuadraticHamiltonian(ns=8, dtype=np.complex128)
>>> # ... add hopping terms ...
>>> hamil.diagonalize()
>>> ent     = hamil.entanglement                        # Get entanglement module
>>>
>>> # Half-filling: occupy lowest 4 orbitals
>>> orbitals    = [0, 1, 2, 3]
>>> C_full      = ent.correlation_matrix(orbitals)
>>> print(C_full.shape)                                 # (8, 8)
>>> print(np.trace(C_full))                             # Should be 2*4 = 8 (spin-unpolarized)
  • Subsystem correlation matrix:

>>> bipart      = ent.bipartition([0, 1, 2])            # First 3 sites
>>> C_A = ent.correlation_matrix(orbitals, bipartition=bipart)
>>> print(C_A.shape)                                    # (3, 3)
>>> # Use for entropy: eigenvalues -> occupations -> entropy
  • JAX backend for GPU, same result as NumPy:

>>> C_jax       = ent.correlation_matrix(orbitals, backend='jax')
>>> # Matches the NumPy result up to floating-point roundoff; runs on GPU when one is available
  • Verify correlation matrix properties:

>>> C           = ent.correlation_matrix(orbitals)
>>> # Hermitian
>>> assert np.allclose(C, C.conj().T)
>>> # Occupations in [0, 2]
>>> assert np.all(np.diag(C) >= 0) and np.all(np.diag(C) <= 2)
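A minimal sketch of how a Slater-determinant correlation matrix can be assembled from single-particle eigenvectors, using an open tight-binding chain as an assumed example Hamiltonian. It reproduces the properties checked above (trace 2N, Hermiticity); the explicit factor 2 mirrors the spin-unpolarized convention described here, but the library's internal code path is not shown.

```python
import numpy as np

ns = 8
# Assumed example: open tight-binding chain with H_ij = -1 for nearest neighbours
H = -(np.eye(ns, k=1) + np.eye(ns, k=-1))
eps, V = np.linalg.eigh(H)             # columns of V are single-particle orbitals

occ = [0, 1, 2, 3]                      # half filling: lowest 4 orbitals
V_occ = V[:, occ]
# C_ij = sum_{k in occ} conj(V_ik) V_jk, times 2 for the spin-unpolarized convention
C = 2.0 * (V_occ.conj() @ V_occ.T)
# Restriction to subsystem A; np.ix_ also works for non-contiguous site lists
C_A = C[np.ix_([0, 1, 2], [0, 1, 2])]
```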
entropy_correlation(bipartition: BipartitionInfo, occupied_orbitals: List[int] | ndarray, *, q: float = 1.0, C_A: ndarray | Array | None = None, subtract_identity: bool = False, backend: str = 'numpy', **kwargs) float[source]

Calculate entanglement entropy from single-particle correlation matrix.

SINGLE-PARTICLE METHOD - Fast O(L_A³) method for non-interacting (quadratic) Hamiltonians. Computes entropy from correlation matrix eigenvalues.

Works for ANY bipartition (contiguous or non-contiguous) of free fermion states.

Parameters:
  • bipartition (BipartitionInfo) – Bipartition of the system (use ent.bipartition() to create). Works for both contiguous and non-contiguous subsystems.

  • occupied_orbitals (array-like) – Indices of occupied orbitals (in eigenstate basis). For the ground state with N particles, use [0, 1, …, N-1].

  • q (float) – Renyi index (default: 1.0 for von Neumann entropy)

  • C_A (np.ndarray or jax.numpy.ndarray, optional) – Precomputed correlation matrix for subsystem A. If provided, uses this instead of computing from occupied_orbitals.

  • subtract_identity (bool) – Whether to subtract identity from correlation matrix (advanced)

  • backend (str) – ‘numpy’ or ‘jax’ for computation backend

Returns:

Entanglement entropy in natural-log units (non-negative)

Return type:

float

Notes

Algorithm:

  1. Compute the full correlation matrix C_ij = <c_i^dag c_j> from the occupied orbitals (for BdG states, the anomalous correlators <c_i c_j> are also needed).

  2. Extract the subblock C_A for the sites in subsystem A (handles non-contiguous subsystems).

  3. Diagonalize C_A to get its eigenvalues (occupations in [0, 1]).

  4. Apply the single-particle entropy formula:

S = - sum_k [ n_k log(n_k) + (1-n_k) log(1-n_k) ]

This gives the EXACT entanglement entropy for ANY bipartition of non-interacting (quadratic) Hamiltonians and matches entropy_many_body().

Limitations:

  • Requires a diagonalized Hamiltonian.

  • Only works for quadratic (non-interacting) Hamiltonians.

  • For interacting systems, use entropy_many_body().
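The signature also accepts a Renyi index q. For free fermions, each eigenvalue n_k of C_A contributes log(n_k^q + (1 - n_k)^q) / (1 - q), which reduces to the formula above as q -> 1. Whether entropy_correlation uses exactly this expression is an assumption; `renyi_from_occupations` is a hypothetical helper.

```python
import numpy as np


def renyi_from_occupations(n, q=1.0, eps=1e-12):
    """Hypothetical helper: entropy from correlation-matrix eigenvalues n_k in [0, 1]."""
    n = np.clip(np.asarray(n, dtype=float), eps, 1 - eps)
    if np.isclose(q, 1.0):
        # von Neumann limit: -sum_k [n_k log n_k + (1 - n_k) log(1 - n_k)]
        return float(-np.sum(n * np.log(n) + (1 - n) * np.log(1 - n)))
    # Each mode contributes log(n^q + (1 - n)^q) / (1 - q)
    return float(np.sum(np.log(n ** q + (1 - n) ** q)) / (1.0 - q))


S1 = renyi_from_occupations([0.5, 0.2], q=1.0)
S2 = renyi_from_occupations([0.5, 0.2], q=2.0)
```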

entropy_many_body(bipartition: BipartitionInfo, *, rho_a: ndarray | None = None, state: ndarray | None = None, q: float = 1.0, method: str = 'auto', use_eig: bool = False, hilbert=None, occupied_orbitals: List[int] | ndarray | None = None) float[source]

Calculate entanglement entropy from many-body state.

MANY-BODY METHOD - Exact method that works for ANY quantum state, including interacting systems. Performs Schmidt decomposition of the many-body wavefunction.

Parameters:
  • bipartition (BipartitionInfo) – Bipartition of the system (use ent.bipartition() to create)

  • q (float) – Renyi index (default: 1.0 for von Neumann entropy)

  • rho_a (np.ndarray, optional) – Precomputed reduced density matrix for subsystem A. If provided, uses this instead of computing from state.

  • state (np.ndarray, optional) – Many-body state vector (length 2^ns). If None, occupied_orbitals must be provided to construct the state (for free fermions).

  • method (str) – Computation method:
    • 'auto' : choose the best method based on the bipartition geometry
    • 'schmidt' : Schmidt decomposition with a mask (for non-contiguous subsystems)
    • 'numpy' : direct NumPy Schmidt decomposition (for contiguous subsystems, faster)

  • use_eig (bool) – Whether to use eigenvalue decomposition (True) or SVD (False)

  • hilbert (HilbertSpace, optional) – Hilbert space with symmetries. If provided and it has symmetries, the reduced density matrix is computed with symmetry-based methods. This requires the package that provides HilbertSpace, which is not part of General Python.

  • occupied_orbitals (array-like, optional) – Indices of occupied orbitals. Required only if state is None.

Returns:

Renyi-q entanglement entropy (von Neumann entropy for q = 1); always non-negative

Return type:

float
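For a non-contiguous subsystem, one way to perform the Schmidt decomposition is to view the state as a tensor with one axis per site, move the axes of A to the front, and take an SVD. This sketch (hypothetical `entropy_schmidt`, two-level sites, site 0 as the most significant bit) illustrates the idea; it is not necessarily how the 'schmidt' method is implemented.

```python
import numpy as np


def entropy_schmidt(state, mask_a, ns, q=1.0):
    """Hypothetical sketch; assumes site 0 is the most significant bit of the index."""
    psi = np.asarray(state).reshape((2,) * ns)
    mask_a = list(mask_a)
    mask_b = [s for s in range(ns) if s not in mask_a]
    # Bring subsystem A's axes to the front, then flatten into a (dim_A, dim_B) matrix
    psi = np.transpose(psi, mask_a + mask_b).reshape(2 ** len(mask_a), -1)
    p = np.linalg.svd(psi, compute_uv=False) ** 2   # Schmidt weights
    p = p[p > 1e-14]
    if np.isclose(q, 1.0):
        return float(-np.sum(p * np.log(p)))
    return float(np.log(np.sum(p ** q)) / (1.0 - q))


# Bell pair between sites 0 and 2, site 1 in |0>: (|000> + |101>)/sqrt(2)
state = np.zeros(8)
state[0b000] = state[0b101] = 1 / np.sqrt(2)
```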

entropy_scan(*, state: ndarray | None = None, occupied_orbitals: List[int] | ndarray | None = None, subsystem_sizes: List[int] | None = None, q: float = 1.0, method: str = 'auto', contiguous: bool = True) dict[source]

Calculate entanglement entropy for multiple subsystem sizes.

Parameters:
  • occupied_orbitals (array-like, optional) – Occupied orbitals for the state (for free fermions).

  • subsystem_sizes (list of int, optional) – Sizes of subsystem A to scan. If None, scans all sizes from 1 to ns-1.

  • q (float) – Renyi index (default: 1.0 for von Neumann entropy)

  • method (str) – Computation method:
    • 'auto' : use the correlation-matrix method for quadratic Hamiltonians and the many-body method otherwise
    • 'correlation' : force the correlation-matrix method
    • 'many_body' : force the many-body method

  • contiguous (bool) – If True, use contiguous partitions [0:size_a]; if False, use random partitions.

  • state (np.ndarray, optional) – Many-body state vector. Required if occupied_orbitals is None and method is ‘many_body’ or ‘auto’ (for interacting systems).

Returns:

Dictionary with keys:
  • ‘sizes’: subsystem sizes
  • ‘entropies’: entanglement entropies
  • ‘method’: method used

Return type:

dict

Examples

Free-fermion state from occupied orbitals:
>>> results = ent.entropy_scan(occupied_orbitals=[0,1,2,3,4])
>>> plt.plot(results['sizes'], results['entropies'])
Many-body state of an interacting system:
>>> results = ent.entropy_scan(state=ground_state_vector)
>>> plt.plot(results['sizes'], results['entropies'])
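For Slater-determinant (free-fermion) states, a correlation-matrix method avoids the exponentially large many-body vector: the entropy of region A follows from the eigenvalues ν of the restricted matrix C_A, with C_ij = <c_i† c_j>, via S = -Σ [ν log ν + (1 - ν) log(1 - ν)] and S_q = (1/(1 - q)) Σ log[ν^q + (1 - ν)^q]. The sketch below runs such a scan in plain NumPy; the helper names, the orbital-column layout, and the hopping chain are illustrative assumptions, not this module's implementation.

```python
import numpy as np


def correlation_matrix(orbitals, occupied):
    """C[i, j] = <c_i^dagger c_j> for a Slater determinant.

    `orbitals` holds single-particle orbitals as columns (assumed layout).
    """
    phi = orbitals[:, list(occupied)]
    return phi.conj() @ phi.T


def entropy_from_correlation(c_a, q=1.0, eps=1e-12):
    """Renyi-q (q=1: von Neumann) entropy from the spectrum of C_A."""
    nu = np.clip(np.linalg.eigvalsh(c_a), eps, 1.0 - eps)
    if q == 1.0:
        return float(-np.sum(nu * np.log(nu) + (1 - nu) * np.log(1 - nu)))
    return float(np.sum(np.log(nu**q + (1 - nu) ** q)) / (1.0 - q))


def entropy_scan_sketch(orbitals, occupied, sizes=None, q=1.0):
    c = correlation_matrix(orbitals, occupied)  # built once, sliced per size
    ns = c.shape[0]
    sizes = list(range(1, ns)) if sizes is None else list(sizes)
    ents = np.array([entropy_from_correlation(c[:s, :s], q) for s in sizes])
    return {"sizes": sizes, "entropies": ents, "method": "correlation"}


# Open chain of 8 sites with nearest-neighbour hopping, filled to half filling.
ns = 8
h = -(np.eye(ns, k=1) + np.eye(ns, k=-1))
_, orbitals = np.linalg.eigh(h)  # columns sorted by single-particle energy
res = entropy_scan_sketch(orbitals, occupied=range(ns // 2))
print(res["sizes"])
print(np.round(res["entropies"], 4))
```

Clipping ν away from 0 and 1 keeps the logarithms finite; the bias it introduces is of order 1e-11 per eigenvalue.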
mutual_information(mask_a: List[int] | ndarray, mask_b: List[int] | ndarray, *, q: float = 1.0, occupied_orbitals: List[int] | ndarray | None = None, method: str = 'auto', state: ndarray | None = None, **kwargs) → float[source]

Calculate the mutual information I(A:B) = S(A) + S(B) - S(AB). Regions A and B need not cover the whole system. A common use case is the mutual information between single sites, I(i:j), for which A = {i}, B = {j}, and AB = {i, j}.

Parameters:
  • mask_a (array-like) – Site indices of subsystem A.

  • mask_b (array-like) – Site indices of subsystem B.

  • q (float) – Rényi index (default: 1.0 for von Neumann entropy).

  • occupied_orbitals (array-like, optional) – Occupied orbitals (for free fermions).

  • method (str) – ‘auto’, ‘correlation’, or ‘many_body’.

  • state (np.ndarray, optional) – Many-body state vector.

Examples

Free-fermion state from occupied orbitals:
>>> I_AB = ent.mutual_information(
...             mask_a              = [0,1,2],
...             mask_b              = [3,4,5],
...             occupied_orbitals   = [0,1,2,3,4,5]
...         )
Using a many-body state:
>>> I_AB = ent.mutual_information(
...             mask_a  = [0,1,2],
...             mask_b  = [3,4,5],
...             state   = ground_state_vector,
...             method  = 'many_body'
...         )
Returns:

Mutual information

Return type:

float
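The definition I(A:B) = S(A) + S(B) - S(AB) can be checked on small pure states with a self-contained sketch built on Schmidt decompositions. It assumes local dimension 2, site 0 as the most significant bit, and disjoint masks; the helper names are hypothetical and this is not the module's implementation.

```python
import numpy as np


def entropy_of_sites(state, ns, sites, eps=1e-14):
    """Von Neumann entropy of the reduced state on `sites` (pure input state)."""
    sites = list(sites)
    rest = [i for i in range(ns) if i not in sites]
    psi = np.transpose(np.asarray(state).reshape([2] * ns), sites + rest)
    p = np.linalg.svd(psi.reshape(2 ** len(sites), -1), compute_uv=False) ** 2
    p = p[p > eps]
    return float(-np.sum(p * np.log(p)))


def mutual_information_sketch(state, ns, mask_a, mask_b):
    """I(A:B) = S(A) + S(B) - S(AB) for disjoint site sets A and B."""
    s_a = entropy_of_sites(state, ns, mask_a)
    s_b = entropy_of_sites(state, ns, mask_b)
    s_ab = entropy_of_sites(state, ns, list(mask_a) + list(mask_b))
    return s_a + s_b - s_ab


# 3-site GHZ state (|000> + |111>) / sqrt(2): single sites and pairs all have
# S = log 2, so I(i:j) = log 2.
ghz = np.zeros(8)
ghz[0] = ghz[7] = 1 / np.sqrt(2)
print(mutual_information_sketch(ghz, 3, [0], [1]))

# Bell pair on sites (0, 1) times |0> on site 2: I(0:1) = 2 log 2, I(0:2) = 0.
bell = np.zeros(8)
bell[0b000] = bell[0b110] = 1 / np.sqrt(2)
print(mutual_information_sketch(bell, 3, [0], [1]))
print(mutual_information_sketch(bell, 3, [0], [2]))
```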

entropy_multipartition(bipartitions: List[BipartitionInfo | List[int] | ndarray], occupied_orbitals: List[int] | ndarray | None = None, *, method: str = 'auto', backend: str = 'numpy', state: ndarray | None = None) → dict[source]

Calculate entanglement entropy for multiple bipartitions simultaneously.

Efficient batch calculation: the correlation matrix is computed once and reused for all bipartitions (correlation method), or the many-body state is computed once and reused for all Schmidt decompositions (many-body method).

Parameters:
  • bipartitions (list) – List of BipartitionInfo objects or site masks (masks are converted to BipartitionInfo objects).

  • occupied_orbitals (array-like, optional) – Occupied orbitals for correlation method.

  • method (str) – ‘auto’, ‘correlation’, or ‘many_body’.

  • backend (str) – ‘numpy’ or ‘jax’ (for correlation method).

  • state (np.ndarray, optional) – Pre-computed many-body state (for many_body method). If None and method is many_body, will be computed from occupied_orbitals.

Returns:

Results dictionary containing:
  • ‘entropies’: array of entropies, one per bipartition
  • ‘bipartitions’: list of BipartitionInfo objects
  • ‘method’: method used (‘correlation’ or ‘many_body’)
  • ‘correlation_matrices’: list of C_A matrices (only if method=‘correlation’)

Return type:

dict

Examples

Basic usage:
>>> masks = [[0,1], [0,1,2], [0,1,2,3]]
>>> orbitals = [0,1,2,3,4]
>>> results = ent.entropy_multipartition(masks, occupied_orbitals=orbitals)
>>> print(results['entropies'])  # array([S_1, S_2, S_3])
Access correlation matrices:
>>> C_matrices = results['correlation_matrices']
>>> for i, C_A in enumerate(C_matrices):
...     print(f"Bipartition {i}: C_A shape = {C_A.shape}")
JAX backend:
>>> results_jax = ent.entropy_multipartition(
...     masks, orbitals, backend='jax'
... )
Many-body method:
>>> state = hamil.many_body_state(orbitals)
>>> results_mb = ent.entropy_multipartition(
...     masks, orbitals, method='many_body', state=state
... )
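The batching idea described above (build C once, then restrict it to each bipartition) can be sketched in NumPy for a Slater determinant. Masks may be non-contiguous because C_A is simply C restricted to the rows and columns in the mask. The function name, the orbital-column layout, the hopping chain, and the returned keys (mirroring the documented dictionary) are illustrative assumptions.

```python
import numpy as np


def multipartition_sketch(orbitals, occupied, masks, eps=1e-12):
    """Entropies for many site masks from one correlation matrix (hypothetical)."""
    phi = orbitals[:, list(occupied)]
    c = phi.conj() @ phi.T  # C[i, j] = <c_i^dagger c_j>, computed only once
    entropies, c_blocks = [], []
    for mask in masks:
        idx = np.asarray(mask)
        c_a = c[np.ix_(idx, idx)]  # restriction of C to the sites in A
        nu = np.clip(np.linalg.eigvalsh(c_a), eps, 1.0 - eps)
        entropies.append(-np.sum(nu * np.log(nu) + (1 - nu) * np.log(1 - nu)))
        c_blocks.append(c_a)
    return {"entropies": np.array(entropies), "correlation_matrices": c_blocks,
            "method": "correlation"}


ns = 8
h = -(np.eye(ns, k=1) + np.eye(ns, k=-1))  # open hopping chain (assumption)
_, orbitals = np.linalg.eigh(h)
masks = [[0, 1], [0, 1, 2], [0, 1, 2, 3], [0, 2, 4, 6]]
res = multipartition_sketch(orbitals, range(ns // 2), masks)
for mask, s, c_a in zip(masks, res["entropies"], res["correlation_matrices"]):
    print(mask, c_a.shape, round(float(s), 4))
```

For a pure Gaussian state a region and its complement have the same entropy, which gives a quick sanity check on such batched results.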
topological_entropy(*, q: float = 1.0, state: ndarray | None = None, occupied_orbitals: List[int] | ndarray = None, construction: Literal['kitaev_preskill', 'levin_wen'] = 'kitaev_preskill', method: str = 'auto', regions: Dict[str, ndarray] | None = None, hilbert=None) → Dict[str, float][source]

Calculate topological entanglement entropy (TEE).

The topological entanglement entropy γ characterizes topological order. For topologically ordered states, S(A) = αL - γ + O(1/L), where L is the length of the boundary of A and γ > 0.

Parameters:
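  • q (float) – Rényi index (default: 1.0 for von Neumann entropy).

  • state (np.ndarray, optional) – Many-body state vector.

  • occupied_orbitals (array-like) – Occupied orbitals defining the state (for free fermions).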

  • construction (str) – ‘kitaev_preskill’: -γ = S_A + S_B + S_C - S_AB - S_BC - S_AC + S_ABC; ‘levin_wen’: alternative construction with disk geometry.

  • method (str) – Entropy calculation method (‘auto’, ‘correlation’, ‘many_body’)

  • regions (dict, optional) – Custom region definitions. If None, uses MaskGenerator.

  • hilbert (HilbertSpace, optional) – Hilbert space object for symmetry-aware calculations.

Returns:

Dictionary containing:
  • ‘gamma’: topological entanglement entropy
  • ‘entropies’: individual region entropies
  • ‘regions’: region masks used

Return type:

dict

Notes

For the Kitaev-Preskill construction:

S_A + S_B + S_C - S_AB - S_BC - S_AC + S_ABC = -γ

With S(R) = αL_R - γ for each region R, this combination cancels the area-law contributions and leaves only the universal topological term, so γ is minus the combination. For topological phases such as the toric code, γ = log(D), where D is the total quantum dimension (D = 2 for the toric code).
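The cancellation can be checked with a toy model. Assume each region's entropy is exactly S(R) = α·L_R - γ, and cut a disk into three sectors A, B, C whose boundaries consist of outer arcs a, b, c and three internal interfaces ab, bc, ca (corner terms ignored). The segment lengths, α, and γ below are arbitrary illustrative numbers, not outputs of this module. Under the convention S(A) = αL - γ, the Kitaev-Preskill combination evaluates to -γ.

```python
# Toy boundary segments of a disk cut into three sectors A, B, C (assumed lengths):
# outer arcs a, b, c and internal interfaces ab, bc, ca between neighbouring sectors.
seg = {"a": 3.0, "b": 4.5, "c": 2.0, "ab": 1.7, "bc": 2.2, "ca": 1.1}
outer = {"A": "a", "B": "b", "C": "c"}
interfaces = {frozenset("AB"): "ab", frozenset("BC"): "bc", frozenset("CA"): "ca"}


def boundary_length(region):
    """Outer arcs of the region plus interfaces shared with sectors outside it."""
    length = sum(seg[outer[x]] for x in region)
    for pair, name in interfaces.items():
        if len(pair & set(region)) == 1:  # interface separates region from outside
            length += seg[name]
    return length


def toy_entropy(region, alpha=0.8, gamma=0.693):
    """Pure area law plus a constant topological correction: alpha * L - gamma."""
    return alpha * boundary_length(region) - gamma


S = {r: toy_entropy(r) for r in ["A", "B", "C", "AB", "BC", "AC", "ABC"]}
combo = S["A"] + S["B"] + S["C"] - S["AB"] - S["BC"] - S["AC"] + S["ABC"]
print(round(combo, 12))  # area-law terms cancel, leaving -gamma
```

Each interface appears twice with a plus sign (in two single sectors) and twice with a minus sign (in two pairs), and each outer arc appears once in a single sector, twice in pairs, and once in ABC, which is why only the constant term survives.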

References

  • Kitaev & Preskill, PRL 96, 110404 (2006)

  • Levin & Wen, PRL 96, 110405 (2006)

Examples

>>> result = ent.topological_entropy(occupied_orbitals=orbitals, construction='kitaev_preskill')
>>> print(f"Topological entropy: γ = {result['gamma']:.4f}")
verify_wicks_theorem(occupied_orbitals: List[int] | ndarray, state: ndarray | None = None, *, test_sites: Tuple[int, int, int, int] | None = None, tolerance: float = 1e-10, hilbert=None) → Dict[str, bool | float | ndarray][source]

Verify Wick’s theorem for a state: check if it’s a valid Slater determinant.

For free fermion (quadratic) Hamiltonians, all correlation functions factorize according to Wick’s theorem. This method verifies this property.

Wick’s theorem states that for a Slater determinant:

<c_i† c_j† c_l c_k> = <c_i† c_k><c_j† c_l> - <c_i† c_l><c_j† c_k>

Parameters:
  • occupied_orbitals (array-like) – Occupied orbitals defining the expected Slater determinant

  • state (np.ndarray, optional) – Many-body state to verify. If None, constructs from occupied_orbitals.

  • test_sites (tuple, optional) – Specific (i, j, k, l) sites to test. If None, tests random sites.

  • tolerance (float) – Tolerance for numerical comparison

  • hilbert (HilbertSpace, optional) – Hilbert space object for symmetry-aware calculations.

Returns:

Dictionary containing:
  • ‘is_valid’ (bool): True if Wick’s theorem is satisfied
  • ‘max_error’ (float): maximum deviation from Wick’s theorem
  • ‘errors’ (np.ndarray): error matrix for all tested site combinations
  • ‘correlation_matrix’ (np.ndarray): single-particle correlation matrix

Return type:

dict

Examples

>>> result = ent.verify_wicks_theorem(orbitals)
>>> if result['is_valid']:
...     print("State satisfies Wick's theorem (is a Slater determinant)")
... else:
...     print(f"Max error: {result['max_error']:.2e}")

Notes

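A pure state with fixed particle number satisfies Wick’s theorem if and only if it is a Slater determinant. More generally, Wick’s theorem characterizes fermionic Gaussian states, which also include mixed states such as thermal states of quadratic Hamiltonians at finite temperature; an arbitrary mixture of Slater determinants is in general not Gaussian.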
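The check can be reproduced on a tiny system with dense Jordan-Wigner matrices: build a Slater determinant, compute C_ij = <c_i† c_j> and the four-point function <c_i† c_j† c_l c_k>, and compare with C_ik C_jl - C_il C_jk. The helper names, the 4-site size, and the random orbitals are illustrative assumptions, not this module's implementation. Unlike the default random-site test, the sketch loops over all (i, j, k, l), and it includes the pair superposition (|1100> + |0011>)/√2 as a number-conserving counterexample that violates Wick’s theorem.

```python
import itertools

import numpy as np


def jw_annihilators(ns):
    """Dense Jordan-Wigner annihilators c_0..c_{ns-1} (site 0 = leftmost kron factor)."""
    a = np.array([[0.0, 1.0], [0.0, 0.0]])  # |0><1| on a single mode
    z = np.diag([1.0, -1.0])                 # parity string on modes to the left
    ops = []
    for j in range(ns):
        op = np.array([[1.0]])
        for f in [z] * j + [a] + [np.eye(2)] * (ns - j - 1):
            op = np.kron(op, f)
        ops.append(op)
    return ops


def slater_state(phi, ops):
    """Apply b_k^dagger = sum_j phi[j, k] c_j^dagger for each column k to the vacuum."""
    psi = np.zeros(ops[0].shape[0], dtype=complex)
    psi[0] = 1.0  # vacuum |0...0>
    for k in range(phi.shape[1]):
        b_dag = sum(phi[j, k] * ops[j].conj().T for j in range(len(ops)))
        psi = b_dag @ psi
    return psi / np.linalg.norm(psi)


def wick_max_error(psi, ops):
    """Max |<c_i^+ c_j^+ c_l c_k> - (C_ik C_jl - C_il C_jk)| over all site tuples."""
    n = len(ops)
    cd = [c.conj().T for c in ops]
    C = np.array([[psi.conj() @ cd[i] @ ops[j] @ psi for j in range(n)] for i in range(n)])
    err = 0.0
    for i, j, k, l in itertools.product(range(n), repeat=4):
        lhs = psi.conj() @ cd[i] @ cd[j] @ ops[l] @ ops[k] @ psi
        err = max(err, abs(lhs - (C[i, k] * C[j, l] - C[i, l] * C[j, k])))
    return err, C


ns = 4
ops = jw_annihilators(ns)
rng = np.random.default_rng(0)
u, _ = np.linalg.qr(rng.normal(size=(ns, ns)))  # orthonormal single-particle orbitals
slater = slater_state(u[:, :2], ops)            # two occupied orbitals
err_slater, _ = wick_max_error(slater, ops)

# (|1100> + |0011>) / sqrt(2): fixed particle number but not a Slater determinant.
vac = np.zeros(2**ns, dtype=complex)
vac[0] = 1.0
cd = [c.conj().T for c in ops]
pair = (cd[0] @ cd[1] @ vac + cd[2] @ cd[3] @ vac) / np.sqrt(2)
err_pair, _ = wick_max_error(pair, ops)
print(f"Slater determinant: {err_slater:.1e}, pair state: {err_pair:.2f}")
```

For the pair state C = I/2, so C_00 C_11 = 0.25 while <n_0 n_1> = 0.5, which already gives a deviation of at least 0.25.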

help()[source]

Print usage help for the entanglement module.

general_python.physics.entanglement_module.get_entanglement_module(hamiltonian) → EntanglementModule[source]

Factory function to create entanglement module for a Hamiltonian.

Parameters:

hamiltonian (Hamiltonian) – The Hamiltonian object

Returns:

Entanglement module instance

Return type:

EntanglementModule