# Source code for general_python.algebra.backend_linalg

"""
general_python/algebra/backend_linalg.py

Unified linear algebra backend providing matrix operations, decompositions,
and transformations with support for multiple backends (NumPy, JAX).

This module consolidates linalg.py and linalg_sparse.py functionality into
a single backend-aware interface.

Features:
  - Matrix transformations and basis changes
  - Outer and Kronecker products (dense and sparse)
  - Inner products and overlaps
  - Matrix properties (trace, Hilbert-Schmidt and Frobenius norms)
  - Matrix creation (dense and sparse identity)
  - Format conversion between dense and sparse representations
  - Eigendecomposition (dense, sparse, Lanczos)
  - State manipulation (Givens rotations)
  - Backend support: NumPy, SciPy, JAX

Type Safety:
  - Proper dtype handling and promotion
  - Complex values preserved through computations
  - Backend-specific type conversions handled transparently

Author: Maksymilian Kliczkowski
Email: maksymilian.kliczkowski@pwr.edu.pl
"""

from typing import Optional, Union, Tuple, Literal, Any, Callable
import numpy as np
import scipy as sp
import scipy.sparse
import scipy.linalg
import scipy.sparse.linalg
from numpy.typing import NDArray

# Backend imports
try:
    from .utils import JAX_AVAILABLE, get_backend, JIT
except ImportError:
    JAX_AVAILABLE = False
    get_backend = lambda x="default": np
    JIT = lambda f: f

if JAX_AVAILABLE:
    import jax
    import jax.numpy as jnp
    import jax.scipy as jsp
    from jax.experimental.sparse import BCOO
    from functools import partial
else:
    jnp = None
    jsp = None
    BCOO = None
    partial = lambda f, **kwargs: f

# Type alias
# Generic array annotation used throughout this module. NumPy arrays are named
# explicitly; `Any` lets JAX arrays (and other array-likes) through without
# importing backend-specific types. Because `Any` is a member of the union,
# static type checkers effectively treat this alias as unchecked.
Array = Union[np.ndarray, Any]  # Any allows JAX arrays

# =============================================================================
# Basis Transformations
# =============================================================================

def change_basis(
    unitary_matrix: Array,
    state_vector: Array,
    backend: str = "default",
) -> Array:
    r"""
    Express a state vector in a new basis, V' = U V.

    Both operands are cast to a shared dtype (chosen by NumPy's promotion
    rules) before the product, so mixed real/complex inputs multiply cleanly.

    Parameters
    ----------
    unitary_matrix : array-like, shape (N, N)
        Unitary transformation matrix U.
    state_vector : array-like, shape (N,) or (N, 1)
        State vector V to transform.
    backend : str, optional
        Numerical backend name passed to ``get_backend`` (default: 'default').

    Returns
    -------
    Array
        The transformed vector ``U @ V``.
    """
    xp = get_backend(backend)
    u_arr = xp.asarray(unitary_matrix)
    v_arr = xp.asarray(state_vector)

    # Pick one dtype for both operands; complex128 is the fallback when the
    # converted matrix exposes no dtype.
    if hasattr(u_arr, 'dtype'):
        shared = np.result_type(u_arr.dtype, v_arr.dtype)
    else:
        shared = np.complex128

    u_arr = xp.asarray(u_arr, dtype=shared)
    v_arr = xp.asarray(v_arr, dtype=shared)
    return u_arr @ v_arr
def change_basis_matrix(
    unitary_matrix: Array,
    matrix: Array,
    direction: Literal['forward', 'backward'] = 'forward',
    backend: str = "default",
) -> Array:
    r"""
    Change the basis of a matrix with a unitary transformation.

    Forward:  A' = U A U^dagger
    Backward: A' = U^dagger A U

    Parameters
    ----------
    unitary_matrix : array-like, shape (N, N)
        Unitary transformation matrix U.
    matrix : array-like, shape (N, N)
        Matrix A to transform.
    direction : {'forward', 'backward'}, optional
        Transformation direction (default: 'forward').
    backend : str, optional
        Numerical backend name passed to ``get_backend`` (default: 'default').

    Returns
    -------
    Array
        Transformed matrix, in the dtype NumPy promotes U and A to.

    Raises
    ------
    ValueError
        If ``direction`` is neither 'forward' nor 'backward'.
    """
    # Reject typos explicitly; previously any unknown value silently meant 'backward'.
    if direction not in ('forward', 'backward'):
        raise ValueError(
            f"direction must be 'forward' or 'backward', got {direction!r}"
        )

    be = get_backend(backend)
    U = be.asarray(unitary_matrix)
    A = be.asarray(matrix)

    # Common dtype so real/complex mixtures multiply without implicit downcasts
    common_dtype = np.result_type(U.dtype, A.dtype) if hasattr(U, 'dtype') else np.complex128
    U = be.asarray(U, dtype=common_dtype)
    A = be.asarray(A, dtype=common_dtype)

    U_H = be.conj(U).T
    if direction == 'forward':
        return U @ A @ U_H
    return U_H @ A @ U
# ============================================================================= # Outer and Kronecker Products # =============================================================================
def outer(A: Array, B: Array, backend: str = "default") -> Array:
    r"""
    Outer product of two vectors, C[i, j] = A[i] * B[j].

    No conjugation is applied; see ``ket_bra`` for the Hermitian variant.

    Parameters
    ----------
    A : array-like, shape (N,)
        First vector.
    B : array-like, shape (M,)
        Second vector.
    backend : str, optional
        Numerical backend name passed to ``get_backend`` (default: 'default').

    Returns
    -------
    Array
        Outer product, shape (N, M).
    """
    xp = get_backend(backend)
    product = xp.outer(A, B)
    return product
def kron(A: Array, B: Array, backend: str = "default") -> Array:
    r"""
    Dense Kronecker product, C = A ⊗ B.

    Parameters
    ----------
    A : array-like, shape (N, M)
        Left factor.
    B : array-like, shape (P, Q)
        Right factor.
    backend : str, optional
        Numerical backend name passed to ``get_backend`` (default: 'default').

    Returns
    -------
    Array
        Kronecker product, shape (N*P, M*Q).
    """
    xp = get_backend(backend)
    product = xp.kron(A, B)
    return product
def kron_sparse(A: Array, B: Array, backend: str = "default") -> Array:
    r"""
    Kronecker product C = A ⊗ B that keeps the result sparse.

    Parameters
    ----------
    A : array-like or sparse, shape (N, M)
        Left factor. Dense inputs are converted to a sparse format first.
    B : array-like or sparse, shape (P, Q)
        Right factor. Dense inputs are converted to a sparse format first.
    backend : str, optional
        'jax' selects the JAX BCOO implementation when JAX is installed;
        any other value uses SciPy (default: 'default').

    Returns
    -------
    sparse matrix, shape (N*P, M*Q)
        A SciPy sparse matrix on the SciPy path, or a JAX ``BCOO`` matrix
        when ``backend == "jax"`` and JAX is available. The result is never
        densified.
    """
    # Dispatch on the backend name alone; the helpers handle input conversion.
    if backend == "jax" and JAX_AVAILABLE:
        return _kron_sparse_jax(A, B)
    return _kron_sparse_numpy(A, B)
def _kron_sparse_numpy(A: Array, B: Array) -> scipy.sparse.csr_matrix:
    """
    Kronecker product using SciPy sparse matrices.

    Dense inputs are wrapped as CSR first. The result format is requested
    explicitly as CSR so it matches the annotated return type;
    ``scipy.sparse.kron`` otherwise chooses a format on its own.
    """
    if not sp.sparse.issparse(A):
        A = sp.sparse.csr_matrix(A)
    if not sp.sparse.issparse(B):
        B = sp.sparse.csr_matrix(B)
    return sp.sparse.kron(A, B, format="csr")


def _kron_sparse_jax(A: Array, B: Array) -> "BCOO":
    """
    Kronecker product using JAX BCOO sparse matrices.

    Non-BCOO inputs are converted with ``BCOO.fromdense``. The index
    arithmetic below reads ``indices[:, 0]`` / ``indices[:, 1]``, so it
    assumes plain 2D BCOO operands (no batch or dense dimensions) — TODO
    confirm for callers that pass batched BCOO.
    """
    if not isinstance(A, BCOO):
        A = BCOO.fromdense(A, index_dtype=jnp.int64)
    if not isinstance(B, BCOO):
        B = BCOO.fromdense(B, index_dtype=jnp.int64)

    m, n = A.shape
    p, q = B.shape
    A_idx, A_data = A.indices, A.data
    B_idx, B_data = B.indices, B.data
    nnz_A = A_idx.shape[0]
    nnz_B = B_idx.shape[0]

    # Pair every nonzero of A with every nonzero of B: entry (a, b) lands at
    # row a_i * p + b_i, column a_j * q + b_j. `repeat` walks A slowly and
    # `tile` walks B quickly, matching the row-major reshape of the data outer product.
    new_i = jnp.repeat(A_idx[:, 0], nnz_B) * p + jnp.tile(B_idx[:, 0], nnz_A)
    new_j = jnp.repeat(A_idx[:, 1], nnz_B) * q + jnp.tile(B_idx[:, 1], nnz_A)
    new_indices = jnp.stack([new_i, new_j], axis=1)
    new_data = (A_data[:, None] * B_data[None, :]).reshape(-1)
    new_shape = (m * p, n * q)
    return BCOO((new_data, new_indices), shape=new_shape)


# =============================================================================
# Inner Products and Overlaps
# =============================================================================
def inner(vec1: Array, vec2: Array, backend: str = "default") -> Array:
    r"""
    Inner product <v1|v2> = v1^dagger · v2.

    The first argument is complex-conjugated, as for a bra.

    Parameters
    ----------
    vec1 : array-like, shape (N,)
        Bra-side vector (conjugated).
    vec2 : array-like, shape (N,)
        Ket-side vector.
    backend : str, optional
        Numerical backend name passed to ``get_backend`` (default: 'default').

    Returns
    -------
    scalar
        Inner product (complex for complex inputs, real otherwise).
    """
    xp = get_backend(backend)
    bra = xp.conj(vec1)
    return xp.dot(bra, vec2)
def ket_bra(vec: Array, backend: str = "default") -> Array:
    r"""
    Projector-like outer product |psi><psi| = psi psi^dagger.

    Parameters
    ----------
    vec : array-like, shape (N,)
        State vector psi.
    backend : str, optional
        Numerical backend name passed to ``get_backend`` (default: 'default').

    Returns
    -------
    Array
        Matrix of shape (N, N) with entries psi[i] * conj(psi[j]).
    """
    xp = get_backend(backend)
    bra = xp.conj(vec)
    return xp.outer(vec, bra)
def bra_ket(vec: Array, backend: str = "default") -> Array:
    r"""
    Self-overlap <psi|psi> = ||psi||^2.

    Only the real part of psi^dagger · psi is returned; its imaginary part
    vanishes mathematically.

    Parameters
    ----------
    vec : array-like, shape (N,)
        State vector psi.
    backend : str, optional
        Numerical backend name passed to ``get_backend`` (default: 'default').

    Returns
    -------
    scalar
        Squared norm of ``vec`` (real).
    """
    xp = get_backend(backend)
    norm_sq = xp.dot(xp.conj(vec), vec)
    return xp.real(norm_sq)
def overlap(
    a: Array,
    O: Array,
    b: Optional[Array] = None,
    backend: str = "default",
) -> Array:
    r"""
    Compute matrix elements <a|O|b>.

    Output rank follows the input ranks:

    - 1D ``a`` and 1D ``b``: scalar
    - one 1D and one 2D input: vector
    - 2D ``a`` and 2D ``b`` (columns are states): matrix, even when a batch
      happens to contain a single state

    Parameters
    ----------
    a : array-like, shape (dim,) or (dim, n)
        Left states. Columns are individual states if 2D.
    O : array-like or sparse, shape (dim, dim)
        Operator; only ``O @ b`` is required of it.
    b : array-like, shape (dim,) or (dim, m), optional
        Right states. If None, defaults to ``a``.
    backend : str, optional
        Numerical backend used to convert inputs that are not already arrays
        (default: 'default').

    Returns
    -------
    scalar, 1D or 2D Array
        <a|O|b>; 2D results have shape (n, m).

    Examples
    --------
    >>> s = overlap(psi, H, phi)              # shape ()
    >>> v = overlap(psi, H, phi_array)        # shape (m,)
    >>> M = overlap(psi_array, H, phi_array)  # shape (n, m)
    """
    be = get_backend(backend)
    # Only convert plain array-likes (e.g. lists); existing arrays keep their type.
    if not hasattr(a, 'ndim'):
        a = be.asarray(a)
    if b is None:
        b = a
    elif not hasattr(b, 'ndim'):
        b = be.asarray(b)

    # Promote 1D states to single columns so one matmul path serves all cases
    a_mat = a[:, None] if a.ndim == 1 else a
    b_mat = b[:, None] if b.ndim == 1 else b

    # res[i, j] = a_i^dagger O b_j
    res = a_mat.conj().T @ (O @ b_mat)

    # Squeeze only the axes introduced for 1D inputs, so genuine 2D batches of
    # size one keep their matrix shape.
    if a.ndim == 1 and b.ndim == 1:
        return res[0, 0]
    if a.ndim == 1:
        return res[0, :]
    if b.ndim == 1:
        return res[:, 0]
    return res
def overlap_diagonal(
    a: Array,
    O: Array,
    b: Optional[Array] = None,
    backend: str = "default",
) -> Array:
    r"""
    Compute only the paired matrix elements <a_i|O|b_i>.

    Cheaper than ``overlap`` when only the diagonal of the full overlap
    matrix is needed.

    Parameters
    ----------
    a : array-like, shape (dim,) or (dim, n)
        Left states (columns are states if 2D).
    O : array-like or sparse, shape (dim, dim)
        Operator; only ``O @ b`` is required of it.
    b : array-like, shape (dim,) or (dim, n), optional
        Right states. If None, defaults to ``a``.
    backend : str, optional
        Numerical backend name passed to ``get_backend`` (default: 'default').

    Returns
    -------
    scalar or 1D Array
        <a|O|b> for 1D ``a``; otherwise the length-n vector of <a_i|O|b_i>.

    Raises
    ------
    ValueError
        If 2D ``a`` and ``b`` have different shapes.
    """
    be = get_backend(backend)
    # Only convert plain array-likes (e.g. lists); existing arrays keep their type.
    if not hasattr(a, 'ndim'):
        a = be.asarray(a)
    if b is None:
        b = a
    elif not hasattr(b, 'ndim'):
        b = be.asarray(b)

    if a.ndim == 1:
        # Single state: an ordinary inner product <a|O b>
        return be.dot(be.conj(a), O @ b)

    a_mat = be.atleast_2d(a)
    b_mat = be.atleast_2d(b)
    if a_mat.shape != b_mat.shape:
        raise ValueError("a and b must have same shape for diagonal overlap")

    Ob = O @ b_mat
    # Column-wise contraction over the Hilbert-space index:
    # result[j] = sum_i conj(a[i, j]) * (O b)[i, j] = a_j^dagger O b_j
    return be.einsum('ij,ij->j', be.conj(a_mat), Ob)
# ============================================================================= # Matrix Properties # =============================================================================
def trace(matrix: Array, backend: str = "default") -> Any:
    r"""
    Sum of the diagonal entries, Tr(A) = sum_i A_ii.

    Parameters
    ----------
    matrix : array-like, shape (N, N)
        Input matrix.
    backend : str, optional
        Numerical backend name passed to ``get_backend`` (default: 'default').

    Returns
    -------
    scalar
        Trace of ``matrix``.
    """
    xp = get_backend(backend)
    diag_sum = xp.trace(matrix)
    return diag_sum
def hilbert_schmidt_norm(matrix: Array, backend: str = "default") -> Any:
    r"""
    Compute the Hilbert-Schmidt norm of a matrix.

    ||A||_HS = sqrt(Tr(A^dagger A)) = sqrt(sum_ij |A_ij|^2)

    The elementwise form is used: it needs no matrix product (O(N*M) instead
    of O(N^3)), is real by construction, and also works for rectangular input.

    Parameters
    ----------
    matrix : array-like, shape (N, M)
        Matrix whose norm is computed.
    backend : str, optional
        Numerical backend name passed to ``get_backend`` (default: 'default').

    Returns
    -------
    scalar (real)
        Hilbert-Schmidt norm.
    """
    be = get_backend(backend)
    A = be.asarray(matrix)
    # The previous form conj(A) @ A lacked the transpose, so it did not compute A^dagger A.
    return be.sqrt(be.sum(be.abs(A) ** 2))
def frobenius_norm(matrix: Array, backend: str = "default") -> Any:
    r"""
    Frobenius norm, ||A||_F = sqrt(sum_ij |A_ij|^2).

    Delegates to the backend's ``linalg.norm`` with ``ord='fro'``.

    Parameters
    ----------
    matrix : array-like, shape (N, M)
        Input matrix.
    backend : str, optional
        Numerical backend name passed to ``get_backend`` (default: 'default').

    Returns
    -------
    scalar (real)
        Frobenius norm.
    """
    xp = get_backend(backend)
    return xp.linalg.norm(matrix, ord='fro')
# ============================================================================= # Matrix Creation # =============================================================================
def identity(
    n: int,
    dtype: Optional[Union[str, np.dtype]] = None,
    backend: str = "default",
) -> Array:
    r"""
    Dense identity matrix I_n.

    Parameters
    ----------
    n : int
        Number of rows and columns.
    dtype : data-type, optional
        Element type; the backend's ``float64`` is used when omitted.
    backend : str, optional
        Numerical backend name passed to ``get_backend`` (default: 'default').

    Returns
    -------
    Array
        Identity matrix, shape (n, n).
    """
    xp = get_backend(backend)
    element_type = xp.float64 if dtype is None else dtype
    return xp.eye(n, dtype=element_type)
def identity_sparse(
    n: int,
    dtype: Optional[Union[str, np.dtype]] = None,
    backend: str = "default",
) -> Union[scipy.sparse.csr_matrix, "BCOO"]:
    r"""
    Create a sparse identity matrix.

    Parameters
    ----------
    n : int
        Size of the identity matrix.
    dtype : data-type, optional
        Element data type (default: float64).
    backend : str, optional
        'jax' returns a BCOO matrix when JAX is installed; any other value
        returns a SciPy matrix (default: 'default').

    Returns
    -------
    sparse matrix
        ``scipy.sparse`` CSR matrix, or ``BCOO`` for the JAX backend.
    """
    if dtype is None:
        dtype = np.float64

    if backend == "jax" and JAX_AVAILABLE:
        # Diagonal coordinates (k, k); index dtype set on arange rather than
        # via a `dtype=` keyword on jnp.stack.
        diag = jnp.arange(n, dtype=jnp.int32)
        indices = jnp.stack([diag, diag], axis=1)
        data = jnp.ones(n, dtype=dtype)
        return BCOO((data, indices), shape=(n, n))

    # sp.sparse.eye defaults to DIA format; request CSR as documented.
    return sp.sparse.eye(n, dtype=dtype, format="csr")
# ============================================================================= # Backend Format Conversion # =============================================================================
def to_dense(
    matrix: Array,
    backend: str = "default",
) -> Array:
    r"""
    Return a dense array for SciPy sparse, JAX BCOO, or array-like input.

    Parameters
    ----------
    matrix : array-like or sparse
        Matrix to densify.
    backend : str, optional
        Numerical backend used for non-sparse input (default: 'default').

    Returns
    -------
    Array
        Dense array.
    """
    xp = get_backend(backend)
    if sp.sparse.issparse(matrix):
        return matrix.toarray()
    if JAX_AVAILABLE and isinstance(matrix, BCOO):
        return matrix.todense()
    return xp.asarray(matrix)
def to_sparse(
    matrix: Array,
    backend: str = "default",
    format: str = "csr",
) -> Union[scipy.sparse.csr_matrix, "BCOO"]:
    r"""
    Convert a matrix to a sparse format.

    Parameters
    ----------
    matrix : array-like or sparse, shape (N, M)
        Matrix to convert. SciPy sparse input is re-formatted; BCOO input is
        returned unchanged on the JAX/'bcoo' path.
    backend : str, optional
        Numerical backend (default: 'default').
    format : str, optional
        SciPy format ('csr', 'csc', 'coo', 'bsr', ...) or 'bcoo' together
        with ``backend='jax'``. Default: 'csr'.

    Returns
    -------
    sparse matrix
        Matrix in the requested format.
    """
    if backend == "jax" and JAX_AVAILABLE and format == "bcoo":
        return matrix if isinstance(matrix, BCOO) else BCOO.fromdense(matrix)

    if sp.sparse.issparse(matrix):
        return matrix.asformat(format)

    # csr_matrix's constructor has no `format` argument (passing one raised
    # TypeError); build CSR first, then convert to the requested format.
    return sp.sparse.csr_matrix(matrix).asformat(format)
# ============================================================================= # Eigendecomposition # =============================================================================
def eig(matrix: Array, backend: str = "default", **kwargs) -> Tuple[Array, Array]:
    r"""
    Full eigendecomposition of a general (not necessarily Hermitian) matrix.

    Parameters
    ----------
    matrix : array-like, shape (N, N)
        Matrix to diagonalize.
    backend : str, optional
        'jax' uses ``jnp.linalg.eig`` when JAX is installed; otherwise
        ``np.linalg.eig`` is used (default: 'default').
    **kwargs
        Accepted for interface compatibility; currently ignored.

    Returns
    -------
    eigenvalues : Array, shape (N,)
        Eigenvalues.
    eigenvectors : Array, shape (N, N)
        Eigenvectors as columns.
    """
    # Resolved like the other helpers; the returned module is not used here.
    get_backend(backend)

    if backend == "jax" and JAX_AVAILABLE:
        solver = jnp.linalg.eig
    else:
        solver = np.linalg.eig

    values, vectors = solver(matrix)
    return values, vectors
def eigh(
    matrix: Array,
    backend: str = "default",
    **kwargs,
) -> Tuple[Array, Array]:
    r"""
    Full eigendecomposition of a dense Hermitian (or real symmetric) matrix.

    Preferred over ``eig`` for Hermitian input: eigenvalues come back real
    and in ascending order.

    Parameters
    ----------
    matrix : array-like, shape (N, N)
        Hermitian matrix to diagonalize.
    backend : str, optional
        'jax' uses ``jnp.linalg.eigh`` when JAX is installed; otherwise
        ``np.linalg.eigh`` is used (default: 'default').
    **kwargs
        Accepted for interface compatibility; currently ignored.

    Returns
    -------
    eigenvalues : Array, shape (N,), real
        Eigenvalues in ascending order.
    eigenvectors : Array, shape (N, N)
        Eigenvectors as columns.
    """
    # Resolved like the other helpers; the returned module is not used here.
    get_backend(backend)

    use_jax = backend == "jax" and JAX_AVAILABLE
    solver = jnp.linalg.eigh if use_jax else np.linalg.eigh
    values, vectors = solver(matrix)
    return values, vectors
def eigsh(
    matrix: Array,
    k: int = 6,
    which: Literal['smallest', 'largest', 'LM', 'SM', 'LA', 'SA', 'BE'] = 'smallest',
    backend: str = "default",
    **kwargs,
) -> Tuple[Array, Array]:
    r"""
    Partial eigendecomposition of a Hermitian matrix via SciPy's ARPACK wrapper.

    Computes ``k`` eigenvalues and eigenvectors selected by ``which``.

    Parameters
    ----------
    matrix : array-like, sparse matrix or LinearOperator, shape (N, N)
        Hermitian matrix to diagonalize.
    k : int, optional
        Number of eigenpairs to compute (default: 6).
    which : str, optional
        'smallest' (algebraic, SciPy 'SA') or 'largest' (algebraic, SciPy
        'LA'). SciPy's native codes 'LM', 'SM', 'LA', 'SA', 'BE' are passed
        through unchanged. Default: 'smallest'.
    backend : str, optional
        Accepted for interface compatibility; the computation always uses SciPy.
    **kwargs
        Additional arguments passed to ``scipy.sparse.linalg.eigsh``.

    Returns
    -------
    eigenvalues : Array, shape (k,)
        Selected eigenvalues.
    eigenvectors : Array, shape (N, k)
        Corresponding eigenvectors as columns.

    Raises
    ------
    ValueError
        If ``which`` is not a recognised selector.
    """
    aliases = {'smallest': 'SA', 'largest': 'LA'}
    native_codes = ('LM', 'SM', 'LA', 'SA', 'BE')

    # Unknown selectors used to fall back to 'SA' silently (even 'LM'); be explicit.
    if which in aliases:
        which_sp = aliases[which]
    elif which in native_codes:
        which_sp = which
    else:
        raise ValueError(
            f"which must be one of {sorted(aliases) + list(native_codes)}, got {which!r}"
        )

    evals, evecs = sp.sparse.linalg.eigsh(matrix, k=k, which=which_sp, **kwargs)
    return evals, evecs
# ============================================================================= # State Manipulation # =============================================================================
def givens_rotation(
    V: Array,
    i: int,
    j: int,
    theta: float,
    backend: str = "default",
) -> Array:
    r"""
    Apply a Givens rotation in the (i, j) plane by angle theta.

    For a matrix, columns i and j are rotated; for a 1D vector, entries i and j are:

        new_i = cos(theta) * v_i - sin(theta) * v_j
        new_j = sin(theta) * v_i + cos(theta) * v_j

    The input is never modified; a rotated copy is returned.

    Parameters
    ----------
    V : array-like, shape (N,) or (M, N)
        Vector or matrix to rotate.
    i, j : int
        Indices of the rotation plane (entries for 1D, columns for 2D).
    theta : float
        Rotation angle in radians.
    backend : str, optional
        Numerical backend name passed to ``get_backend`` (default: 'default').

    Returns
    -------
    Array
        Rotated vector/matrix. Integer/bool input is promoted to float64 so the
        rotation is not truncated.
    """
    be = get_backend(backend)
    V_rot = be.array(V)  # independent copy for NumPy; JAX arrays are immutable anyway

    # Writing cos/sin combinations back into an integer array would truncate them.
    if not be.issubdtype(V_rot.dtype, be.inexact):
        V_rot = V_rot.astype(be.float64)

    c = be.cos(theta)
    s = be.sin(theta)

    # Address entries for 1D vectors, columns for matrices (the old `[:, i]` failed on 1D).
    if V_rot.ndim == 1:
        idx_i, idx_j = (i,), (j,)
    else:
        idx_i, idx_j = (slice(None), i), (slice(None), j)

    v_i = V_rot[idx_i]
    v_j = V_rot[idx_j]
    # Both results are computed before any write, so views of V_rot are safe here.
    new_i = c * v_i - s * v_j
    new_j = s * v_i + c * v_j

    if hasattr(V_rot, "at"):
        # JAX arrays do not support item assignment; use the functional update API.
        return V_rot.at[idx_i].set(new_i).at[idx_j].set(new_j)

    V_rot[idx_i] = new_i
    V_rot[idx_j] = new_j
    return V_rot
# ============================================================================= # Exports # ============================================================================= __all__ = [ # Basis transformations 'change_basis', 'change_basis_matrix', # Products 'outer', 'kron', 'kron_sparse', # Inner products and overlaps 'inner', 'ket_bra', 'bra_ket', 'overlap', 'overlap_diagonal', # Matrix properties 'trace', 'hilbert_schmidt_norm', 'frobenius_norm', # Matrix creation 'identity', 'identity_sparse', # Format conversion 'to_dense', 'to_sparse', # Eigendecomposition 'eig', 'eigh', 'eigsh', # State manipulation 'givens_rotation', ] # ============================================================================= # End of File # =============================================================================