Source code for general_python.algebra.preconditioners

"""
file:       general_python/algebra/preconditioners.py
author:     Maksymilian Kliczkowski

This module contains the implementation of preconditioners for iterative solvers
of linear systems Ax = b. Preconditioners transform the system into an
equivalent one,

M^{-1}Ax = M^{-1}b (left preconditioning) 

or

A M^{-1} y = b, x = M^{-1}y (right preconditioning),

where M is the preconditioner matrix.

The goal is for the transformed system to have more
favorable spectral properties (e.g., eigenvalues clustered around 1, lower
condition number), leading to faster convergence of iterative methods like CG,
MINRES, GMRES. The matrix M should approximate A in some sense, while the
operation M^{-1}r should be computationally inexpensive.

----------------------------------------------------------------------------
File        : general_python/algebra/preconditioners.py
Author      : Maksymilian Kliczkowski
Date        : 2025-02-02
Version     : 1.0
"""

# Import the required modules
from abc import ABC, abstractmethod
from typing import Union, Callable, Optional, Any, Type, Tuple, Dict, TYPE_CHECKING
from enum import Enum, auto, unique
import inspect
import numpy as np

# Add sparse imports at the top of the file
import  scipy.sparse as sps
import  scipy.sparse.linalg as spsla
import  scipy.linalg as sla

try:
    from .utils import JAX_AVAILABLE, get_backend, Array
except ImportError:
    raise ImportError("Failed to import from .utils. Ensure the module is in the correct package structure.")

if TYPE_CHECKING:
    from common.flog import Logger

# ---------------------------------------------------------------------

# Optional JAX imports. All three aliases (jax, jnp, jsp) are always bound so
# later code can test them against None instead of risking a NameError.
try:
    if JAX_AVAILABLE:
        import jax
        import jax.numpy as jnp
        import jax.scipy as jsp
    else:
        jnp   = None
        jax   = None # Define jax as None if not available
        jsp   = None # Keep the scipy-like alias defined as well
except ImportError:
    jnp   = None
    jax   = None
    jsp   = None

# ---------------------------------------------------------------------

# Interface of an instance-level apply function: takes the residual r and
# returns the preconditioned vector.
PreconitionerApplyFun   = Callable[[Array], Array]
# Setup kernels return a dictionary of precomputed data.
StaticSetupKernel       = Callable[..., Dict[str, Any]]
# Argument layouts noted by the author (presumably for the apply, apply_mat
# and apply_gram entry points respectively - TODO confirm):
#   - r, backend_mod, sigma, precomputed_data
#   - r, a, sigma, backend_mod, precomputed_data
#   - r, s, sp, sigma, backend_mod, precomputed_data
# Apply kernels consume the precomputed data dictionary.
StaticApplyKernel       = Callable[[Array, Any, float, Dict[str, Any]], Array] # r, backend_mod, sigma, precomputed_data

# Module-wide defaults: threshold below which a magnitude is treated as zero,
# and the large replacement value used in its place.
_TOLERANCE_SMALL        = 1e-13
_TOLERANCE_BIG          = 1e13

# Example apply function: the identity map (no preconditioning).
def preconditioner_idn(r: Array) -> Array:
    """
    Identity apply function: hand the residual back untouched.

    Parameters:
        r (Array):
            The input array.

    Returns:
        Array:
            The very same object that was passed in (no copy is made).
    """
    return r
@unique
class PreconditionersType(Enum):
    """
    Symmetry class of a preconditioner.

    Member values are assigned automatically in declaration order.
    """

    SYMMETRIC       = auto()
    NONSYMMETRIC    = auto()
@unique
class PreconditionersTypeSym(Enum):
    """
    Concrete preconditioner kinds from the symmetric family.
    """

    IDENTITY            = 0
    JACOBI              = 1
    INCOMPLETE_CHOLESKY = 2
    COMPLETE_CHOLESKY   = 3
    SSOR                = 4
@unique
class PreconditionersTypeNoSym(Enum):
    """
    Concrete preconditioner kinds that may be non-symmetric.
    """

    IDENTITY        = 0
    INCOMPLETE_LU   = 1
# Add others like Gauss-Seidel etc. # --------------------------------------------------------------------- #! Preconditioners # ---------------------------------------------------------------------
class Preconditioner(ABC):
    """
    Abstract base class for preconditioners M used in iterative solvers.

    Provides a framework for setting up the preconditioner based on a matrix A
    (or its factors S, Sp for Gram matrices) and applying the inverse operation
    M^{-1}r efficiently. Supports different computational backends (NumPy, JAX).

    Concrete subclasses implement three static kernels:

        _setup_standard_kernel(a, sigma, backend_mod, **kwargs)      -> dict
        _setup_gram_kernel(s, sp, sigma, backend_mod, **kwargs)      -> dict
        _apply_kernel(r, backend_mod, sigma, **precomputed_data)     -> Array

    The dictionary returned by a setup kernel is stored on the instance by
    `set()` and unpacked as keyword arguments into the apply kernel.

    Attributes:
        is_positive_semidefinite (bool):
            Indicates if the original matrix A (and potentially M) is assumed
            to be positive semi-definite. Important for methods like Cholesky.
        is_gram (bool):
            True if the preconditioner setup uses factors S, Sp such that
            A = Sp @ S / N.
        sigma (float):
            Regularization parameter sigma added during setup, effectively
            forming M based on A + sigma*I.
        type (Enum):
            The specific type of the preconditioner (e.g., JACOBI, ILU).
            Set by subclass.
        stype (PreconditionersType):
            Symmetry type (SYMMETRIC/NONSYMMETRIC).
        backend_str (str):
            The name of the current backend ('numpy', 'jax').
    """

    _type : Optional[Union[PreconditionersTypeNoSym, PreconditionersTypeSym]] = None
    _name : str = "General Preconditioner"
    _dcol : str = "yellow"

    # -----------------------------------------------------------------

    def __init__(self,
                is_positive_semidefinite            = False,
                is_gram                             = False,
                backend                             = 'default',
                apply_func: Optional[PreconitionerApplyFun] = None,
                **kwargs):
        """
        Initialize the preconditioner.

        Parameters:
            is_positive_semidefinite (bool):
                True if the matrix is positive semidefinite.
            is_gram (bool):
                True if the preconditioner setup uses factors S, Sp such that
                A = Sp @ S / N.
            backend (optional):
                The computational backend to be used by the preconditioner.
            apply_func (PreconitionerApplyFun, optional):
                The apply function for the preconditioner.
            **kwargs:
                Only 'logger' is read here; it is used by `log()`.
        """
        self._backend_str   : str
        self._backend       : Any   # The numpy-like module (np or jnp)
        self._backends      : Any   # The scipy-like module (sp or jax.scipy)
        self._isjax         : bool  # True if using JAX backend
        self._logger: 'Logger'      = kwargs.get('logger', None)

        self._TOLERANCE_SMALL       = _TOLERANCE_SMALL
        self._TOLERANCE_BIG         = _TOLERANCE_BIG
        self._zero                  = _TOLERANCE_BIG
        self._sigma                 = 0.0   # Regularization parameter

        # Must come after _logger/_sigma are set: reset_backend() logs and
        # builds the first apply closure, which captures the current sigma.
        self.reset_backend(backend)

        self._is_positive_semidefinite  = is_positive_semidefinite
        self._is_gram                   = is_gram

        if self._is_positive_semidefinite:
            self._stype = PreconditionersType.SYMMETRIC
        else:
            #! May still be symmetric, but we don't assume based on this flag alone
            self._stype = PreconditionersType.NONSYMMETRIC

        # Reference to the apply logic (class kernel unless apply_func given).
        # NOTE: _update_instance_apply_func() always dispatches to the class's
        # _apply_kernel; this attribute is only stored.
        self._base_apply_logic : Callable[..., Array] = \
            self.__class__._apply_kernel if not apply_func else apply_func

        # Filled in by set()
        self._precomputed_data_instance : Optional[Dict[str, Any]]          = None
        # Compiled/wrapped apply function for instance usage
        self._apply_func_instance       : Optional[Callable[[Array], Array]] = None
        self._update_instance_apply_func()  # Create initial apply(r)

    # -----------------------------------------------------------------
    #! Logging
    # -----------------------------------------------------------------

    def log(self,
            msg         : str,
            log         : Union[int, str]   = 'info',
            lvl         : int               = 0,
            color       : str               = "white",
            append_msg                      = True):
        """
        Log the message. Does nothing when no logger was supplied.

        Args:
            msg (str):
                The message to log.
            log (Union[int, str]):
                The flag to log the message (default is 'info').
            lvl (int):
                The level of the message.
            color (str):
                The color of the message.
            append_msg (bool):
                If True, the message is prefixed with the preconditioner name.
        """
        if not self._logger:
            return
        if isinstance(log, str):
            # Unknown level names fall back to 'info'
            log = self._logger.LEVELS_R.get(log.lower(), self._logger.LEVELS_R['info'])
        if append_msg:
            msg = f"[{self._name}] {msg}"
        msg = self._logger.colorize(msg, color)
        self._logger.say(msg, log=log, lvl=lvl)

    # -----------------------------------------------------------------
    #! Backend Management
    # -----------------------------------------------------------------

    def reset_backend(self, backend: str):
        """
        Resets the backend and recompiles the internal apply function.

        Nothing happens when the requested backend is already active.

        Parameters:
            backend (str):
                The name of the new backend ('numpy', 'jax'); 'default'
                resolves to 'numpy'.
        """
        new_backend_str = backend if backend != 'default' else 'numpy'
        if not hasattr(self, '_backend_str') or self._backend_str != new_backend_str:
            self.log(f"Resetting backend to: {new_backend_str}", lvl=1, color=self._dcol)
            self._backend_str               = new_backend_str
            self._backend, self._backends   = get_backend(self._backend_str, scipy=True)
            self._isjax                     = JAX_AVAILABLE and self._backend is not np
            # Re-create the wrapped/compiled apply function for the new backend
            self._update_instance_apply_func()

    # -----------------------------------------------------------------
    #! Closure for the apply function
    # -----------------------------------------------------------------

    def _update_instance_apply_func(self):
        """
        Creates/updates the instance's `apply(r)` func using stored data.

        The backend module and the current sigma are captured by the closure;
        the precomputed data is looked up on the instance when the closure
        runs.
        """
        static_apply    = self.__class__._apply_kernel
        backend_mod     = self._backend
        current_sigma   = self._sigma   # Capture current sigma for the closure
        instance_self   = self

        def wrapped_apply_instance(r: Array) -> Array:
            precomputed_data = instance_self._get_precomputed_data_instance()
            return static_apply(r            = r,
                                backend_mod  = backend_mod,
                                sigma        = current_sigma,
                                **precomputed_data)

        if self._isjax and jax is not None:
            self._apply_func_instance = jax.jit(wrapped_apply_instance)
        else:
            self._apply_func_instance = wrapped_apply_instance

    # -----------------------------------------------------------------
    #! Getters for apply functions
    # -----------------------------------------------------------------

    def get_apply(self) -> Callable[[Array], Array]:
        """
        Returns the potentially JIT-compiled function `apply(r)`.

        Uses precomputed data stored by the last call to `set()`.

        Raises:
            RuntimeError:
                If called before `set()`, or if the apply function could not
                be created.
        """
        # Ensure precomputed data exists before returning apply function
        if self._precomputed_data_instance is None:
            raise RuntimeError("Preconditioner apply function could not be initialized before set().")
        if self._apply_func_instance is None:
            self._update_instance_apply_func()
        if self._apply_func_instance is None:
            raise RuntimeError("Preconditioner apply(r) function failed to initialize.")
        return self._apply_func_instance

    def get_apply_mat(self, **default_setup_kwargs) -> Callable[[Array, Array, float], Array]:
        """
        Returns a potentially JIT-compiled function
        `apply_mat(r, A, sigma=None, **override_kwargs)` that computes
        preconditioner data from `A` and applies it on the fly.

        Params:
            **default_setup_kwargs:
                Default keyword arguments for the setup kernel
                (e.g., tol_small for Jacobi). Call-time keyword arguments
                override them.

        Returns:
            Callable:
                The function. Its optional third argument `sigma` overrides
                the instance sigma captured when `get_apply_mat` was called;
                when omitted (None) the captured value is used.
        """
        static_setup        = self.__class__._setup_standard_kernel
        static_apply        = self.__class__._apply_kernel
        backend_mod         = self._backend
        instance_defaults   = default_setup_kwargs
        default_sigma       = self._sigma   # Capture current sigma for the closure

        def wrapped_apply_mat(r: Array, a: Array, sigma: Optional[float] = None, **call_time_kwargs) -> Array:
            # Explicit sigma wins; otherwise fall back to the captured one
            sigma_eff           = default_sigma if sigma is None else sigma
            # Merge default setup kwargs with call-time kwargs (call-time overrides)
            setup_kwargs        = {**instance_defaults, **call_time_kwargs}
            # Perform setup on the fly, then apply using the computed data
            precomputed_data    = static_setup(a, sigma_eff, backend_mod, **setup_kwargs)
            return static_apply(r, backend_mod, sigma_eff, **precomputed_data)

        if self._isjax and jax is not None:
            self.log("JIT compiling apply_mat(r, A, sigma, ...) function...", log='info', lvl=2, color=self._dcol)
            return jax.jit(wrapped_apply_mat)
        self.log("Using numpy backend for apply_mat(r, A, sigma, ...), no JIT.", log='info', lvl=2, color=self._dcol)
        return wrapped_apply_mat

    def get_apply_gram(self, **default_setup_kwargs) -> Callable[[Array, Array, Array, float], Array]:
        """
        Returns a function `apply_gram(r, S, Sp)` that computes preconditioner
        data from `S`, `Sp` and applies it on the fly. The instance sigma is
        captured when this method is called.

        Params:
            **default_setup_kwargs:
                Default keyword arguments for the setup kernel.

        Returns:
            Callable:
                The (un-jitted) function.

        Note:
            For JAX compatibility, kwargs are frozen at function creation time.
            The returned function does NOT accept runtime kwargs to avoid
            dictionary operations inside JIT-traced code.
        """
        static_setup    = self.__class__._setup_gram_kernel
        static_apply    = self.__class__._apply_kernel
        backend_mod     = self._backend
        frozen_kwargs   = dict(default_setup_kwargs)    # Freeze kwargs at creation time
        sigma           = self._sigma                   # Capture current sigma for the closure

        def wrapped_apply_gram(r: Array, s: Array, sp: Array) -> Array:
            precomputed_data = static_setup(s, sp, sigma, backend_mod, **frozen_kwargs)
            return static_apply(r, backend_mod, sigma, **precomputed_data)

        # Don't JIT here - the preconditioner is called from inside the solver's
        # JIT-compiled function, so it will be traced and compiled as part of that.
        # Nested JIT causes issues with traced arrays.
        if self._isjax and jax is not None:
            self.log("Returning apply_gram(r, S, Sp) for JAX (will be traced by solver's JIT)", log='info', lvl=2, color=self._dcol)
        else:
            self.log("Using numpy backend for apply_gram(r, S, Sp, sigma, ...)", log='info', lvl=2, color=self._dcol)
        return wrapped_apply_gram

    # -----------------------------------------------------------------
    #! Properties: General Attributes
    # -----------------------------------------------------------------

    @property
    def name(self) -> str:
        """Name of the preconditioner."""
        return self._name

    @property
    def dcol(self) -> str:
        """Color for logging messages."""
        return self._dcol

    @property
    def backend_str(self) -> str:
        """Name of the current backend ('numpy', 'jax')."""
        return self._backend_str

    @property
    def backend(self) -> Any:
        """The backend module (e.g., np, jnp)."""
        return self._backend

    @property
    def backends(self) -> Any:
        """The backend module for scipy-like operations."""
        return self._backends

    @property
    def isjax(self) -> bool:
        """True if using JAX backend."""
        return self._isjax

    @property
    def precomputed_data(self) -> dict:
        """
        Data dictionary produced by the last `set()` call.

        Raises:
            RuntimeError:
                If `set()` has not been called yet.
        """
        return self._get_precomputed_data()

    # -----------------------------------------------------------------
    #! Properties: Preconditioner Type and Symmetry
    # -----------------------------------------------------------------

    @property
    def type(self) -> Optional[Union[PreconditionersTypeNoSym, PreconditionersTypeSym]]:
        """Specific preconditioner type (e.g., JACOBI, ILU)."""
        return self._type

    @property
    def stype(self) -> PreconditionersType:
        """Symmetry type (SYMMETRIC/NONSYMMETRIC)."""
        return self._stype

    @property
    def is_positive_semidefinite(self) -> bool:
        """True if the matrix A (and potentially M) is positive semidefinite."""
        return self._is_positive_semidefinite

    # -----------------------------------------------------------------
    #! Properties: Gram Matrix Setup
    # -----------------------------------------------------------------

    @property
    def is_gram(self) -> bool:
        """True if the preconditioner is set up from Gram matrix factors S, Sp."""
        return self._is_gram

    @is_gram.setter
    def is_gram(self, value: bool):
        """Set the is_gram flag. Requires re-running set()."""
        if self._is_gram != value:
            self.log(f"({self._name}) Changed is_gram to {value}. Remember to call set() again.",
                    log='warning', lvl=1, color=self._dcol)
        self._is_gram = value

    # -----------------------------------------------------------------
    #! Regularization parameter
    # -----------------------------------------------------------------

    @property
    def sigma(self):
        """Regularization parameter."""
        return self._sigma

    @sigma.setter
    def sigma(self, value):
        """Set the diagonal regularization and refresh cached apply state."""
        self._sigma = value
        # The apply closure captures sigma by value, so it must be rebuilt
        self._update_instance_apply_func()

    # -----------------------------------------------------------------
    #! Tolerances
    # -----------------------------------------------------------------

    @property
    def tol_big(self):
        """Tolerance for big values."""
        return self._TOLERANCE_BIG

    @tol_big.setter
    def tol_big(self, value):
        """Set the large sentinel value used by safe inverse kernels."""
        self._TOLERANCE_BIG = value

    @property
    def tol_small(self):
        """Tolerance for small values."""
        return self._TOLERANCE_SMALL

    @tol_small.setter
    def tol_small(self, value):
        """Set the threshold below which values are treated as numerically zero."""
        self._TOLERANCE_SMALL = value

    @property
    def zero(self):
        """Value treated as zero."""
        return self._zero

    @zero.setter
    def zero(self, value):
        """Set the replacement value used for zero or near-zero pivots."""
        self._zero = value

    # -----------------------------------------------------------------
    #! KERNELS
    # -----------------------------------------------------------------

    @staticmethod
    @abstractmethod
    def _setup_standard_kernel(a: Array, sigma: float, backend_mod: Any, **kwargs) -> Dict[str, Any]:
        """Static Kernel: Computes precond data dict from matrix A."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def _setup_gram_kernel(s: Array, sp: Array, sigma: float, backend_mod: Any, **kwargs) -> Dict[str, Any]:
        """Static Kernel: Computes precond data dict from factors S, Sp."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def _apply_kernel(r: Array, backend_mod: Any, sigma: float, **precomputed_data: Any) -> Array:
        """Static Kernel: Applies M^{-1}r using precomputed data."""
        raise NotImplementedError

    # -----------------------------------------------------------------
    #! Instance Setup Methods
    # -----------------------------------------------------------------

    def _set_standard(self, a: Array, sigma: float, **kwargs):
        """Instance: Calls static setup kernel and stores result."""
        self._precomputed_data_instance = self.__class__._setup_standard_kernel(
            a, sigma, self._backend, **kwargs
        )
        self._update_instance_apply_func()  # Recompile apply(r) if data changes

    def _set_gram(self, s: Array, sp: Array, sigma: float, **kwargs):
        """Instance: Calls static setup kernel and stores result."""
        self._precomputed_data_instance = self.__class__._setup_gram_kernel(
            s, sp, sigma, self._backend, **kwargs
        )
        self._update_instance_apply_func()  # Recompile apply(r) if data changes

    def _get_precomputed_data_instance(self) -> Dict[str, Any]:
        """
        Instance: Returns the stored precomputed data.

        Raises:
            RuntimeError:
                If no data has been stored yet (set() not called).
        """
        if self._precomputed_data_instance is None:
            # Include standardized substring expected by tests
            raise RuntimeError(f"Preconditioner data not available - ({self._name}) not set up. Call set() first.")
        return self._precomputed_data_instance

    def _get_precomputed_data(self) -> Dict[str, Any]:
        """
        Hook behind the `precomputed_data` property.

        Defined here so the property works for every subclass; subclasses may
        override it.
        """
        return self._get_precomputed_data_instance()

    # -----------------------------------------------------------------
    #! General Setup Method
    # -----------------------------------------------------------------

    def set(self,
            a       : Array,
            sigma   : float             = 0.0,
            ap      : Optional[Array]   = None,
            backend : Optional[str]     = None,
            **kwargs):
        """
        Sets up the preconditioner using the provided matrix A and optional
        parameters. This method computes the preconditioner data and prepares
        the apply function.

        Params:
            a (Array):
                The matrix to be used for setting up the preconditioner
                (the factor S when `is_gram` is True).
            sigma (float, optional):
                The regularization parameter. Defaults to 0.0. Passing None
                keeps the instance's current sigma.
            ap (Optional[Array], optional):
                An optional second matrix for Gram setup. When omitted in
                Gram mode, the conjugate transpose of `a` is used.
            backend (Optional[str], optional):
                The backend to use for computations. Defaults to None
                (keep the current backend).
            **kwargs:
                Additional keyword arguments forwarded to the setup kernel.

        Raises:
            ValueError:
                "Shape mismatch" if the Gram factors are not transposes of
                each other in shape; "Needs square matrix" if A is not a
                square 2-D array in standard mode.
        """
        if backend is not None and backend != self.backend_str:
            self.reset_backend(backend)  # Will trigger _update_instance_apply_func

        # Ensure backend consistency for inputs.
        # For sparse matrices, np.asarray wraps them in a 0-d array object,
        # which breaks shape checks - so they are passed through untouched.
        if sps.issparse(a) and self._backend is np:
            a_be = a
        else:
            a_be = self._backend.asarray(a)

        ap_be = None
        if ap is not None:
            if sps.issparse(ap) and self._backend is np:
                ap_be = ap
            else:
                ap_be = self._backend.asarray(ap)

        # Use provided sigma or instance sigma, update instance sigma via setter
        self.sigma = sigma if sigma is not None else self._sigma
        self.log(f"Setting up preconditioner state with sigma={self.sigma} using backend='{self.backend_str}'...",
                log='info', lvl=1, color=self._dcol)

        if self.is_gram:
            s_mat, sp_mat = a_be, ap_be
            if sp_mat is None:
                sp_mat = self._backend.conjugate(s_mat).T
            # Sp must have the transposed shape of S
            if s_mat.shape[1] != sp_mat.shape[0] or s_mat.shape[0] != sp_mat.shape[1]:
                raise ValueError("Shape mismatch")
            self._set_gram(s_mat, sp_mat, self.sigma, **kwargs)
        else:
            if a_be.ndim != 2 or a_be.shape[0] != a_be.shape[1]:
                raise ValueError("Needs square matrix")
            self._set_standard(a_be, self.sigma, **kwargs)

    # -----------------------------------------------------------------
    #! Apply Method
    # -----------------------------------------------------------------

    def __call__(self, r: Array) -> Array:
        """
        Apply the configured preconditioner instance M^{-1} to vector r using
        precomputed data.
        """
        apply_func = self.get_apply()
        return apply_func(r)

    # -----------------------------------------------------------------
    #! String Representation
    # -----------------------------------------------------------------

    def __repr__(self) -> str:
        """
        Returns the name and configuration of the preconditioner.
        """
        return f"{self._name}(sigma={self.sigma}, backend='{self.backend_str}', type={self.type})"

    def __str__(self) -> str:
        """
        Returns the name of the preconditioner.
        """
        return self.__repr__()
# ----------------------------------------------------------------- # ===================================================================== #! Identity preconditioner # =====================================================================
class IdentityPreconditioner(Preconditioner):
    """
    Identity preconditioner M = I. Applying M^{-1} simply returns the input
    vector.

    This is the simplest preconditioner and has no effect on the system.
    It serves as a baseline or placeholder.

    Math:
        M       = I
        M^{-1}r = I^{-1}r = r
    """

    _name = "Identity Preconditioner"
    _type = PreconditionersTypeSym.IDENTITY  # Can be Sym or NoSym

    def __init__(self, backend: str = 'default'):
        """
        Initialize the Identity preconditioner.

        Args:
            backend (str):
                The computational backend ('numpy', 'jax', 'default').
        """
        # is_positive_semidefinite doesn't matter, is_gram=False
        super().__init__(is_positive_semidefinite=True, is_gram=False, backend=backend)
        # Identity is always symmetric
        self._stype = PreconditionersType.SYMMETRIC

    # -----------------------------------------------------------------
    #! Kernels
    # -----------------------------------------------------------------

    @staticmethod
    def _setup_standard_kernel(a: Array, sigma: float, backend_mod: Any, **kwargs) -> Dict[str, Any]:
        """Static Setup Kernel for Identity (no-op): no data is needed."""
        return {}

    @staticmethod
    def _setup_gram_kernel(s: Array, sp: Array, sigma: float, backend_mod: Any, **kwargs) -> Dict[str, Any]:
        """Static Setup Kernel for Identity (no-op): no data is needed."""
        return {}

    @staticmethod
    def _apply_kernel(r: Array, backend_mod: Any, sigma: float, **precomputed_data: Any) -> Array:
        """Static Apply Kernel for Identity: returns r as a backend array."""
        return backend_mod.asarray(r)  # Ensure correct backend type

    @staticmethod
    def apply(r: Array, backend_mod: Any, sigma: float = 0.0, **precomputed_data: Any) -> Array:
        """
        Static apply convenience wrapper used by tests.
        Mirrors the signature expected in test files.
        """
        return IdentityPreconditioner._apply_kernel(r=r, backend_mod=backend_mod, sigma=sigma, **precomputed_data)

    # -----------------------------------------------------------------
    #! Instance helpers
    # -----------------------------------------------------------------
    # _set_standard / _set_gram are inherited unchanged from Preconditioner;
    # the former overrides here were verbatim copies of the base versions.

    def _get_precomputed_data(self) -> dict:
        """Return the data stored by set() (an empty dict for Identity)."""
        return self._get_precomputed_data_instance()

    def __repr__(self) -> str:
        """
        Returns the name and configuration of the Identity preconditioner.

        The base representation already reports `type=...`, so it is returned
        as is instead of appending the type a second time.
        """
        return super().__repr__()
# ===================================================================== #! Jacobi preconditioner # =====================================================================
class JacobiPreconditioner(Preconditioner):
    """
    Jacobi (Diagonal) Preconditioner. M = diag(A + sigma*I).

    Uses the diagonal of the (potentially regularized) matrix A as the
    preconditioner M. Applying the inverse M^{-1}r involves element-wise
    division by the diagonal entries.

    Math:
        M           = diag(A) + sigma*I = D + sigma*I
        (M^{-1}r)_i = r_i / (A_ii + sigma)

    Diagonal entries whose magnitude (after regularization) is below
    `tol_small` get an inverse of exactly zero.

    References:
        - Saad, Y. (2003). Iterative Methods for Sparse Linear Systems
          (2nd ed.). SIAM. Chapter 10.
    """

    _name = "Jacobi Preconditioner"
    _type = PreconditionersTypeSym.JACOBI

    def __init__(self,
                is_positive_semidefinite    : bool  = False,
                is_gram                     : bool  = False,
                backend                     : str   = 'default',
                # Tolerances specific to Jacobi
                tol_small                   : float = _TOLERANCE_SMALL,
                zero_replacement            : float = _TOLERANCE_BIG,
                **kwargs):
        """
        Initialize the Jacobi preconditioner.

        Args:
            is_positive_semidefinite (bool):
                If A is positive semi-definite.
            is_gram (bool):
                If setting up from Gram matrix factors.
            backend (str):
                The computational backend.
            tol_small (float):
                Values on diagonal smaller than this (in magnitude) after
                regularization are considered zero.
            zero_replacement (float):
                Value used to replace division by zero (effectively setting
                the result component to zero, 1 / large_number -> 0).
            **kwargs:
                Forwarded to `Preconditioner.__init__` (e.g. `logger`).
        """
        # Forward extra keyword arguments so options such as `logger`
        # actually reach the base class instead of being dropped.
        super().__init__(is_positive_semidefinite   = is_positive_semidefinite,
                        is_gram                     = is_gram,
                        backend                     = backend,
                        **kwargs)

        # Tolerances / constants for safe division
        self._TOLERANCE_SMALL   = tol_small
        self._zero              = zero_replacement
        # Precomputed data storage
        self._inv_diag : Optional[Array] = None  # Stores 1 / (diag(A) + sigma)

    # -----------------------------------------------------------------

    @staticmethod
    def _static_compute_inv_diag(diag_a             : Array,
                                sigma               : float,
                                backend_mod         : Any,
                                tol_small           : float,
                                zero_replacement    : float) -> Array:
        """
        Static helper to compute the inverse diagonal 1 / (A_ii + sigma)
        safely, handling small values and avoiding division by zero.

        Parameters:
            diag_a (Array):
                Diagonal of the matrix A.
            sigma (float):
                Regularization parameter.
            backend_mod (Any):
                The backend module (e.g., np, jnp).
            tol_small (float):
                Tolerance for small values.
            zero_replacement (float):
                Value to replace small values with before inversion.

        Returns:
            Array:
                The inverse diagonal; entries where |A_ii + sigma| < tol_small
                are exactly zero.
        """
        be              = backend_mod
        reg_diag        = diag_a + sigma
        abs_reg_diag    = be.abs(reg_diag)
        is_small        = abs_reg_diag < tol_small

        # Use where for conditional assignment (JAX compatible).
        # Assign large magnitude for small values before inversion.
        safe_diag       = be.where(is_small, be.sign(reg_diag) * zero_replacement, reg_diag)
        # sign(0) == 0, so an exact zero can survive the step above
        safe_diag       = be.where(safe_diag == 0.0, zero_replacement, safe_diag)
        inv_diag        = 1.0 / safe_diag
        # Ensure inverse is zero where original was small
        inv_diag        = be.where(is_small, 0.0, inv_diag)
        return inv_diag

    # -----------------------------------------------------------------

    @staticmethod
    def _setup_standard_kernel(a: Array, sigma: float, backend_mod: Any, **kwargs) -> Dict[str, Any]:
        """
        Static Setup Kernel for Jacobi from matrix A.

        Parameters:
            a (Array):
                Square matrix A; dense backend array or SciPy sparse matrix.
            sigma (float):
                Regularization added to the diagonal.
            backend_mod (Any):
                The backend module (e.g., np, jnp).
            **kwargs:
                Optional 'tol_small' and 'zero_replacement' overrides.

        Returns:
            Dict[str, Any]:
                {'inv_diag': 1 / (diag(A) + sigma)} with safe handling of
                small entries.
        """
        tol_small           = kwargs.get('tol_small', _TOLERANCE_SMALL)
        zero_replacement    = kwargs.get('zero_replacement', _TOLERANCE_BIG)

        if sps.issparse(a):
            # Preconditioner.set() hands sparse matrices through un-densified;
            # backend diag() cannot read them, the sparse API can.
            diag_a = backend_mod.asarray(a.diagonal())
        else:
            diag_a = backend_mod.diag(a)

        inv_diag = JacobiPreconditioner._static_compute_inv_diag(
            diag_a, sigma, backend_mod, tol_small, zero_replacement
        )
        return {'inv_diag': inv_diag}

    @staticmethod
    def _setup_gram_kernel(s: Array, sp: Array, sigma: float, backend_mod: Any, **kwargs) -> Dict[str, Any]:
        """
        Static Setup Kernel for Jacobi from Gram factors S, Sp.

        The diagonal of Sp @ S is computed without forming the product and is
        divided by the number of rows of S.

        Returns:
            Dict[str, Any]:
                {'inv_diag': 1 / (diag(Sp @ S) / N + sigma)} with safe handling
                of small entries.
        """
        tol_small           = kwargs.get('tol_small', _TOLERANCE_SMALL)
        zero_replacement    = kwargs.get('zero_replacement', _TOLERANCE_BIG)
        be                  = backend_mod

        # Use shape directly - avoid float() which breaks JAX tracing
        n                   = s.shape[0] if s.shape[0] > 0 else 1
        diag_s_dag_s        = be.einsum('ij,ji->i', sp, s)
        diag_a_approx       = diag_s_dag_s / n
        inv_diag            = JacobiPreconditioner._static_compute_inv_diag(
            diag_a_approx, sigma, backend_mod, tol_small, zero_replacement
        )
        return {'inv_diag': inv_diag}

    # Backwards-compat alias expected by tests (instance method using internal
    # backend and tolerances)
    def _compute_inv_diag(self, diag_a: Array, sigma: float) -> Array:
        """Instance wrapper around `_static_compute_inv_diag`."""
        return JacobiPreconditioner._static_compute_inv_diag(
            diag_a, sigma, self._backend, self._TOLERANCE_SMALL, self._zero
        )

    # Convenience static apply used in tests
    @staticmethod
    def apply(r: Array, backend_mod: Any, sigma: float = 0.0, **precomputed_data: Any) -> Array:
        """Static apply wrapper matching test signature."""
        return JacobiPreconditioner._apply_kernel(r=r, backend_mod=backend_mod, sigma=sigma, **precomputed_data)

    @staticmethod
    def _apply_kernel(r: Array, backend_mod: Any, sigma: float, **precomputed_data: Any) -> Array:
        """
        Static Apply Kernel for Jacobi using precomputed data.
        Applies the preconditioner M^{-1}r using the inverse diagonal.

        Parameters:
            r (Array):
                The residual vector to precondition (1-D).
            backend_mod (Any):
                The backend module (e.g., np, jnp).
            sigma (float):
                Regularization parameter (ignored here).
            **precomputed_data (Any):
                Precomputed data from the setup kernel.
                Must include 'inv_diag' key.

        Returns:
            Array:
                The preconditioned vector M^{-1}r.

        Raises:
            ValueError:
                If 'inv_diag' is missing, or if r and inv_diag are not 1-D
                arrays of equal length.
        """
        inv_diag = precomputed_data.get('inv_diag', None)
        if inv_diag is None:
            raise ValueError("Jacobi apply kernel requires 'inv_diag' in precomputed_data.")
        if r.ndim != 1 or inv_diag.ndim != 1 or r.shape[0] != inv_diag.shape[0]:
            raise ValueError(f"Shape mismatch in Jacobi apply: r={r.shape}, inv_diag={inv_diag.shape}")
        return inv_diag * r

    # -----------------------------------------------------------------
    #! Instance Setup Methods
    # -----------------------------------------------------------------

    def _set_standard(self, a: Array, sigma: float, **kwargs):
        """Run the standard setup kernel with the instance tolerances as defaults."""
        kwargs.setdefault('tol_small', self._TOLERANCE_SMALL)
        kwargs.setdefault('zero_replacement', self._zero)
        self._precomputed_data_instance = self.__class__._setup_standard_kernel(a, sigma, self._backend, **kwargs)
        self._update_instance_apply_func()

    def _set_gram(self, s: Array, sp: Array, sigma: float, **kwargs):
        """Run the Gram setup kernel with the instance tolerances as defaults."""
        kwargs.setdefault('tol_small', self._TOLERANCE_SMALL)
        kwargs.setdefault('zero_replacement', self._zero)
        self._precomputed_data_instance = self.__class__._setup_gram_kernel(s, sp, sigma, self._backend, **kwargs)
        self._update_instance_apply_func()

    def _get_precomputed_data(self) -> dict:
        """Return the data stored by set()."""
        return self._get_precomputed_data_instance()

    # -----------------------------------------------------------------

    def __repr__(self) -> str:
        """
        Returns the name and configuration of the Jacobi preconditioner.
        """
        base_repr = super().__repr__()
        # Drop the closing parenthesis of the base repr and append tol_small
        return f"{base_repr[:-1]}, tol_small={self._TOLERANCE_SMALL})"

    # Expose zero_replacement for tests/consumers expecting that name
    @property
    def zero_replacement(self) -> float:
        """Value substituted for unsafe diagonal entries in Jacobi setup."""
        return self._zero

    @zero_replacement.setter
    def zero_replacement(self, value: float):
        """Set the Jacobi replacement value for zero or tiny diagonal entries."""
        self._zero = value
# ----------------------------------------------------------------- # ===================================================================== #! Complete Cholesky factorization # =====================================================================
class CholeskyPreconditioner(Preconditioner):
    """
    Preconditioner built from a complete Cholesky factorization.

    Intended for symmetric (Hermitian) positive-definite matrices A. During
    setup the regularized matrix is factorized as

        A + sigma*I = L @ L^H        (L^H == L.T for real data),

    and the preconditioner is M = L @ L^H. Applying M^{-1} r is done with
    two triangular solves:

        1. L   y = r   (forward substitution)
        2. L^H z = y   (backward substitution)

    If the factorization raises, the stored factor is ``None`` and applying
    the preconditioner returns its input unchanged.

    Note:
        This is a *complete* factorization. For *incomplete* Cholesky on
        large sparse matrices, specialized sparse routines are required.

    References:
        - Golub, G. H., & Van Loan, C. F. (2013). Matrix Computations (4th ed.). JHU Press. Chapter 4.
        - Saad, Y. (2003). Iterative Methods for Sparse Linear Systems (2nd ed.). SIAM. Chapter 10.
    """

    _name = "Cholesky Preconditioner"
    _type = PreconditionersTypeSym.COMPLETE_CHOLESKY

    def __init__(self, backend: str = 'default'):
        """
        Create an (unfactorized) Cholesky preconditioner.

        Args:
            backend (str): The computational backend ('numpy', 'jax', 'default').
        """
        # The (regularized) matrix must be positive definite; the factorization
        # acts on A directly, hence is_gram=False.
        super().__init__(is_positive_semidefinite=True, is_gram=False, backend=backend)
        self._stype = PreconditionersType.SYMMETRIC
        self._l: Optional[Array] = None  # Stores the lower Cholesky factor

    @staticmethod
    def _setup_standard_kernel(a: Array, sigma: float, backend_mod: Any, **kwargs) -> Dict[str, Any]:
        """
        Static setup kernel: lower Cholesky factor of ``A + sigma*I``.

        Parameters:
            a (Array): The input matrix A.
            sigma (float): Regularization; ``sigma*I`` is added to A.
            backend_mod (Any): Array module used for the computation.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            Dict[str, Any]: ``{'l': L}``; ``'l'`` is ``None`` when the
            factorization raised.
        """
        xp = backend_mod
        tag = f"({CholeskyPreconditioner._name})"
        factor = None  # stays None when anything below raises

        print(f"{tag} Performing Cholesky decomposition...")
        try:
            shifted = a + sigma * xp.eye(a.shape[0], dtype=a.dtype)
            # Pick the factorization routine that belongs to the backend
            if xp is np:
                factor = sla.cholesky(shifted, lower=True, check_finite=False)
            elif jnp is not None and xp is jnp:
                factor = jsp.linalg.cholesky(shifted, lower=True)
            else:
                factor = xp.linalg.cholesky(shifted)
            print(f"{tag} Cholesky decomposition successful.")
        except Exception as exc:
            print(f"{tag} Cholesky decomposition failed: {exc}")
            print(f"{tag} Matrix might not be positive definite after regularization (sigma={sigma}).")
        return {'l': factor}

    @staticmethod
    def _setup_gram_kernel(s: Array, sp: Array, sigma: float, backend_mod: Any, **kwargs) -> Dict[str, Any]:
        """
        Static setup kernel: build ``A = Sp @ S / N`` and factorize it.

        Parameters:
            s (Array): The factor S; ``N`` is its first dimension.
            sp (Array): The factor Sp multiplied from the left.
            sigma (float): Regularization; ``sigma*I`` is added to A.
            backend_mod (Any): Array module used for the computation.
            **kwargs: Forwarded to the standard setup kernel.

        Returns:
            Dict[str, Any]: ``{'l': L}``; ``'l'`` is ``None`` on failure.
        """
        # Plain shape arithmetic only (no float()) so the value stays static under JAX tracing
        rows = s.shape[0]
        n = rows if rows > 0 else 1
        print(f"({CholeskyPreconditioner._name}) Warning: Forming explicit Gram matrix A = Sp @ S / N for Cholesky setup (N={n}).")
        gram = (sp @ s) / n
        return CholeskyPreconditioner._setup_standard_kernel(gram, sigma, backend_mod, **kwargs)

    @staticmethod
    def _apply_kernel(r: Array, backend_mod: Any, sigma: float, **precomputed_data: Any) -> Array:
        """
        Static apply kernel: compute ``M^{-1} r`` from the Cholesky factor.

        Parameters:
            r (Array): The vector to precondition.
            backend_mod (Any): Array module used for the computation.
            sigma (float): Unused here; regularization is applied at setup.
            **precomputed_data (Any): Must carry the factor under ``'l'``.

        Returns:
            Array: ``M^{-1} r``; the unchanged ``r`` when the factor is
            missing or a triangular solve raises.

        Raises:
            ValueError: If the leading dimensions of ``r`` and L differ.
        """
        factor = precomputed_data.get('l')
        if factor is None:
            print(f"Warning: ({CholeskyPreconditioner._name}) Cholesky factor 'l' is None in apply, returning original vector.")
            return r

        xp = backend_mod
        if r.shape[0] != factor.shape[0]:
            raise ValueError(f"Shape mismatch in Cholesky apply: r ({r.shape[0]}) vs L ({factor.shape[0]})")

        try:
            # L^H: conjugate only when the factor is complex
            if xp.iscomplexobj(factor):
                factor_h = xp.conjugate(factor).T
            else:
                factor_h = factor.T

            def tri_solve(mat, rhs, lower):
                # SciPy for NumPy data, jax.scipy for everything else
                if xp is np:
                    return sla.solve_triangular(mat, rhs, lower=lower, check_finite=False)
                return jsp.linalg.solve_triangular(mat, rhs, lower=lower)

            forward = tri_solve(factor, r, True)            # L y = r
            return tri_solve(factor_h, forward, False)      # L^H z = y
        except Exception as exc:
            print(f"({CholeskyPreconditioner._name}) Cholesky triangular solve failed during apply: {exc}")
            return r  # best effort: fall back to the unpreconditioned vector

    # -----------------------------------------------------------------

    def _set_standard(self, a: Array, sigma: float, **kwargs):
        """
        Factorize ``A + sigma*I`` and refresh the instance apply function.

        Parameters:
            a (Array): The input matrix A.
            sigma (float): Regularization; ``sigma*I`` is added to A.
            **kwargs: Forwarded to the static setup kernel.
        """
        kernel = self.__class__._setup_standard_kernel
        self._precomputed_data_instance = kernel(a, sigma, self._backend, **kwargs)
        self._update_instance_apply_func()

    def _set_gram(self, s: Array, sp: Array, sigma: float, **kwargs):
        """
        Factorize the Gram matrix built from S and Sp and refresh apply.

        Parameters:
            s (Array): The factor S.
            sp (Array): The factor Sp.
            sigma (float): Regularization; ``sigma*I`` is added to A.
            **kwargs: Forwarded to the static setup kernel.
        """
        kernel = self.__class__._setup_gram_kernel
        self._precomputed_data_instance = kernel(s, sp, sigma, self._backend, **kwargs)
        self._update_instance_apply_func()

    def _get_precomputed_data(self) -> dict:
        """
        Return the data produced by the last setup call.

        Returns:
            dict: Whatever the instance-level accessor yields (the Cholesky
            factor is stored under ``'l'`` by the setup kernels).
        """
        return self._get_precomputed_data_instance()

    def __repr__(self) -> str:
        """
        Return the base representation extended with the factorization status.

        Returns:
            str: Base repr with ``status='Factorized'`` or
            ``status='Not Factorized/Failed'`` appended.
        """
        data = self._precomputed_data_instance
        if data and data.get('l') is not None:
            status = "Factorized"
        else:
            status = "Not Factorized/Failed"
        base_repr = super().__repr__()
        # Drop the closing parenthesis of the base repr and re-close after the extra field
        return f"{base_repr[:-1]}, status='{status}')"
# ===================================================================== #! SSOR Preconditioner # =====================================================================
class SSORPreconditioner(Preconditioner):
    """
    Symmetric Successive Over-Relaxation (SSOR) Preconditioner.

    Suitable for symmetric matrices A, particularly positive-definite ones.
    It performs a forward and a backward Gauss-Seidel-like sweep with a
    relaxation parameter omega.

    Math:
        Let A = D + L + U with D = diag(A), L strictly lower, U strictly upper.
        The textbook SSOR matrix is

            M      = omega / (2 - omega) * (D/omega + L) D^{-1} (D/omega + U)
            M^{-1} = (2 - omega) / omega * (D/omega + U)^{-1} D (D/omega + L)^{-1}

        This class applies the two sweeps

            1. Solve (D/omega + L) y = r
            2. Solve (D/omega + U) z = (D/omega) y

        i.e. z = M^{-1} r / (2 - omega). The constant positive factor
        (2 - omega) is not applied; it equals 1 for omega = 1 (Symmetric
        Gauss-Seidel).

        Diagonal entries whose magnitude is below ``tol_small`` are replaced
        by ``zero_replacement`` during setup so that both triangular systems
        stay non-singular.

    The choice of omega (0 < omega < 2) is crucial for performance.

    References:
        - Saad, Y. (2003). Iterative Methods for Sparse Linear Systems (2nd ed.). SIAM. Chapter 10.
        - Axelsson, O. (1996). Iterative Solution Methods. Cambridge University Press.
    """

    # -----------------------------------------------------------------

    _name = "SSOR Preconditioner"
    # Although often used for SPD, the setup doesn't strictly require it
    _type = PreconditionersTypeSym.JACOBI  # No specific SSOR enum yet, Jacobi is placeholder

    def __init__(self,
                omega                    : float = 1.0,
                is_positive_semidefinite : bool  = False,               # User hint
                is_gram                  : bool  = False,               # Typically used with explicit A
                backend                  : str   = 'default',           # Backend for computations
                tol_small                : float = _TOLERANCE_SMALL,    # For inverting diagonal
                zero_replacement         : float = _TOLERANCE_BIG):
        """
        Initialize the SSOR preconditioner.

        Args:
            omega (float): Relaxation parameter (0 < omega < 2). Default is 1.0 (SGS).
            is_positive_semidefinite (bool): If A is assumed positive semi-definite.
            is_gram (bool): If setting up from Gram matrix factors (less common for SSOR).
            backend (str): The computational backend.
            tol_small (float): Diagonal entries with magnitude below this are treated as unsafe.
            zero_replacement (float): Value substituted for unsafe diagonal entries.

        Raises:
            ValueError: If omega is not strictly between 0 and 2.
        """
        if not (0 < omega < 2):
            raise ValueError("SSOR relaxation parameter omega must be between 0 and 2.")
        super().__init__(is_positive_semidefinite=is_positive_semidefinite,
                         is_gram=is_gram, backend=backend)
        # Assume symmetric application if A is symmetric
        self._stype             = PreconditionersType.SYMMETRIC
        self._omega             = omega
        self._TOLERANCE_SMALL   = tol_small
        self._zero              = zero_replacement

    # -----------------------------------------------------------------
    #! Properties
    # -----------------------------------------------------------------

    @property
    def omega(self) -> float:
        """ Relaxation parameter omega (0 < omega < 2). """
        return self._omega

    @omega.setter
    def omega(self, value: float):
        """Set the SSOR relaxation parameter.

        Changing this value affects future setup calls; call :meth:`set` again
        to rebuild cached factors for an already configured preconditioner.

        Raises:
            ValueError: If value is not strictly between 0 and 2.
        """
        if not (0 < value < 2):
            raise ValueError("omega must be in (0, 2)")
        self._omega = value
        # Note: Need to call set() again to recompute scaled factors

    # -----------------------------------------------------------------
    #! Setup Methods
    # -----------------------------------------------------------------

    @staticmethod
    def _setup_standard_kernel(a: Array, sigma: float, backend_mod: Any, **kwargs) -> Dict[str, Any]:
        """
        Static setup for SSOR: split ``A + sigma*I`` into diagonal and strict triangles.

        Parameters:
            a (Array): The input matrix A.
            sigma (float): Regularization; ``sigma*I`` is added to A.
            backend_mod (Any): Array module used for the computation.
            **kwargs: Optional settings:
                omega (float): Relaxation parameter stored in the result (default 1.0).
                tol_small (float), zero_replacement (float): When BOTH are given,
                    diagonal entries with ``|d| < tol_small`` are replaced by
                    ``zero_replacement``. When either is missing the diagonal
                    is left untouched.

        Returns:
            Dict[str, Any]: Keys ``d_diag``, ``L``, ``U``, ``omega``.
        """
        be                  = backend_mod
        omega               = kwargs.get('omega', 1.0)
        tol_small           = kwargs.get('tol_small')
        zero_replacement    = kwargs.get('zero_replacement')

        a_reg               = a + sigma * be.eye(a.shape[0], dtype=a.dtype)
        diag_a_reg          = be.diag(a_reg)

        # A (near-)zero diagonal entry makes both (D/omega + L) and (D/omega + U)
        # singular, so guard it here once instead of failing in every apply call.
        if tol_small is not None and zero_replacement is not None:
            diag_a_reg      = be.where(be.abs(diag_a_reg) < tol_small, zero_replacement, diag_a_reg)

        return {
            'd_diag'    : diag_a_reg,
            'L'         : be.tril(a_reg, k=-1),
            'U'         : be.triu(a_reg, k=1),
            'omega'     : omega,
        }

    @staticmethod
    def _setup_gram_kernel(s: Array, sp: Array, sigma: float, backend_mod: Any, **kwargs) -> Dict[str, Any]:
        """
        Static setup for SSOR from Gram factors by forming ``A = Sp @ S / n``.

        Parameters:
            s (Array): The factor S; ``n`` is its first dimension.
            sp (Array): The factor Sp multiplied from the left.
            sigma (float): Regularization; ``sigma*I`` is added to A.
            backend_mod (Any): Array module used for the computation.
            **kwargs: Forwarded to the standard setup kernel.

        Returns:
            Dict[str, Any]: Keys ``d_diag``, ``L``, ``U``, ``omega``.
        """
        # Use shape directly - avoid float() which breaks JAX tracing
        n       = s.shape[0] if s.shape[0] > 0 else 1
        a_gram  = (sp @ s) / n
        return SSORPreconditioner._setup_standard_kernel(a_gram, sigma, backend_mod, **kwargs)

    def _with_instance_settings(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """ Fill in omega and the diagonal-safety settings unless the caller overrode them. """
        kwargs.setdefault('omega', self._omega)
        kwargs.setdefault('tol_small', self._TOLERANCE_SMALL)
        kwargs.setdefault('zero_replacement', self._zero)
        return kwargs

    def _set_standard(self, a: Array, sigma: float, **kwargs):
        """
        Instance setup from an explicit matrix A.

        Passes the instance's omega, tol_small and zero_replacement to the
        static kernel (explicit kwargs win), stores the result and refreshes
        the instance apply function.
        """
        kwargs = self._with_instance_settings(kwargs)
        self._precomputed_data_instance = self.__class__._setup_standard_kernel(
            a, sigma, self._backend, **kwargs
        )
        self._update_instance_apply_func()  # Recompile apply(r) if data changes

    def _set_gram(self, s: Array, sp: Array, sigma: float, **kwargs):
        """
        Instance setup from Gram factors S and Sp.

        Passes the instance's omega, tol_small and zero_replacement to the
        static kernel (explicit kwargs win), stores the result and refreshes
        the instance apply function.
        """
        kwargs = self._with_instance_settings(kwargs)
        self._precomputed_data_instance = self.__class__._setup_gram_kernel(
            s, sp, sigma, self._backend, **kwargs
        )
        self._update_instance_apply_func()  # Recompile apply(r) if data changes

    # --- Static Apply ---

    @staticmethod
    def _apply_kernel(r: Array, backend_mod: Any, sigma: float, **precomputed_data) -> Array:
        """
        Static apply method for SSOR: forward and backward triangular solves.

        Computes

            1. (D/omega + L) y = r
            2. (D/omega + U) z = (D/omega) y

        which equals ``M^{-1} r / (2 - omega)`` for the textbook SSOR matrix
        ``M = omega/(2-omega) (D/omega + L) D^{-1} (D/omega + U)``.

        Parameters:
            r (Array): Vector (n,) or block of column vectors (n, k).
            backend_mod (Any): Array module used for the computation.
            sigma (float): Unused here; regularization is applied at setup.
            **precomputed_data: Must carry ``d_diag``, ``L``, ``U``; ``omega``
                defaults to 1.0.

        Returns:
            Array: The result z, or the unchanged ``r`` when factors are
            missing or a triangular solve raises.
        """
        d_diag  = precomputed_data.get('d_diag')
        L       = precomputed_data.get('L')
        U       = precomputed_data.get('U')
        omega   = precomputed_data.get('omega', 1.0)

        if d_diag is None or L is None or U is None:
            print("Warning: SSOR factors missing in apply, returning original vector.")
            return r

        be = backend_mod

        def tri_solve(mat, rhs, lower):
            # SciPy for NumPy data, jax.scipy for everything else
            if be is np:
                return sla.solve_triangular(mat, rhs, lower=lower, check_finite=False)
            return jsp.linalg.solve_triangular(mat, rhs, lower=lower)

        try:
            diag_scaled = d_diag / omega
            d_mat       = be.diag(diag_scaled)      # built once, shared by both sweeps

            # 1. Forward sweep: (D/omega + L) y = r
            y_temp      = tri_solve(d_mat + L, r, True)

            # 2. Right-hand side of the backward sweep: scale row i by d_i/omega.
            #    For a block of columns the scale must broadcast along rows.
            row_scale   = diag_scaled if y_temp.ndim == 1 else diag_scaled[:, None]
            rhs_bwd     = row_scale * y_temp

            # 3. Backward sweep: (D/omega + U) z = rhs_bwd
            return tri_solve(d_mat + U, rhs_bwd, False)
        except Exception as e:
            # e.g. LinAlgError for a singular triangular matrix; stay best-effort
            print(f"SSOR triangular solve failed during apply: {e}")
            return r  # Return original vector if solve fails

    # -----------------------------------------------------------------
    #! String Representation
    # -----------------------------------------------------------------

    def __repr__(self) -> str:
        """ Return the base representation extended with the relaxation parameter. """
        base_repr = super().__repr__()
        # Drop the closing parenthesis of the base repr and re-close after the extra field
        return f"{base_repr[:-1]}, omega={self.omega})"
# ----------------------------------------------------------------- # ===================================================================== #! Incomplete Cholesky Preconditioner (using ILU Proxy) # =====================================================================
class IncompleteCholeskyPreconditioner(Preconditioner):
    """
    Incomplete Cholesky Preconditioner (approximated by an incomplete LU).

    Suitable for large, sparse, symmetric positive-definite matrices A.
    SciPy/NumPy/JAX offer no direct IC(0), so this class uses
    ``scipy.sparse.linalg.spilu`` as a proxy: it computes sparse factors
    L and U such that L @ U approximates ``A + sigma*I``.

    Applying the inverse M^{-1}r is delegated to the ``solve`` method of the
    object returned by ``spilu``, which handles the permutations internally.

    **Note:**
        - Requires SciPy and the NumPy backend. Not JAX compatible.
        - Designed for sparse matrices. Dense input matrices will be converted.
        - The amount of fill-in is governed by ``fill_factor`` / ``drop_tol``;
          ``drop_tol=None`` leaves the choice to SciPy's default.
        - If the factorization fails, applying the preconditioner returns its
          input unchanged.

    Args for `set`:
        fill_factor (float): See `scipy.sparse.linalg.spilu`. Default 1.
        drop_tol (float): See `scipy.sparse.linalg.spilu`. Default None.

    References:
        - Saad, Y. (2003). Iterative Methods for Sparse Linear Systems (2nd ed.). SIAM.
          Chapter 10 (discusses both IC and ILU).
        - SciPy documentation for `scipy.sparse.linalg.spilu`.
    """

    _name = "Incomplete Cholesky Preconditioner (ILU Proxy)"
    _type = PreconditionersTypeSym.INCOMPLETE_CHOLESKY

    def __init__(self, backend: str = 'default'):
        """
        Initialize the Incomplete Cholesky (ILU Proxy) preconditioner.

        Args:
            backend (str): The computational backend. Must be 'numpy'.

        Raises:
            NotImplementedError: If the resolved backend is not NumPy.
        """
        # Requires positive definite matrix for true IC, ILU is more general but works best for SPD.
        super().__init__(is_positive_semidefinite=True, is_gram=False, backend=backend)
        self._stype = PreconditionersType.SYMMETRIC

        # Check for backend compatibility
        if self._backend != np:
            raise NotImplementedError(f"{self._name} requires the 'numpy' backend (uses SciPy sparse). "
                                      f"Current backend: '{self.backend_str}'.")

        # Precomputed data storage
        self._ilu_obj: Optional[spsla.SuperLU] = None  # Stores the SuperLU object from spilu
        self._fill_factor = 1.0
        self._drop_tol = None

    # -----------------------------------------------------------------
    #! Properties
    # -----------------------------------------------------------------

    @property
    def fill_factor(self) -> float:
        """ The fill factor passed to ``spilu`` (default 1.0). """
        return self._fill_factor

    @fill_factor.setter
    def fill_factor(self, value: float):
        """Set the maximum fill-in multiplier for sparse incomplete factors.

        Raises:
            ValueError: If value is not positive.
        """
        if value <= 0:
            raise ValueError("fill_factor must be positive.")
        self._fill_factor = value

    @property
    def drop_tol(self) -> Optional[float]:
        """ The drop tolerance passed to ``spilu`` (default None). """
        return self._drop_tol

    @drop_tol.setter
    def drop_tol(self, value: Optional[float]):
        """Set the sparse factor drop tolerance.

        Raises:
            ValueError: If value is negative.
        """
        if value is not None and value < 0:
            raise ValueError("drop_tol must be non-negative.")
        self._drop_tol = value

    # -----------------------------------------------------------------
    #! Setup Methods
    # -----------------------------------------------------------------

    def _set_standard(self, a: Array, sigma: float, **kwargs):
        """
        Sets up the incomplete factorization of A + sigma*I.

        Dense input is converted to sparse CSC format. After the factorization
        attempt the result is published to the instance-level precomputed data
        and the instance apply function is refreshed, so that subsequent
        apply calls use the new factors (or fall back to identity on failure).

        Parameters:
            a (Array): The matrix A.
            sigma (float): Regularization parameter.
            kwargs (dict): Optional ``fill_factor`` / ``drop_tol`` overriding the
                instance properties for this call.

        Raises:
            RuntimeError: If the backend is not NumPy.
        """
        # Ensure NumPy backend
        if self._backend is not np:
            raise RuntimeError(f"{self._name} internal error: Backend is not NumPy despite check.")

        print(f"({self._name}) Setting up ILU factorization...")

        # Convert dense input to sparse CSC format (efficient for spilu)
        if not sps.issparse(a):
            print(f"({self._name}) Warning: Input matrix is dense. Converting to sparse CSC format.")
            a_sparse = sps.csc_matrix(a)
        else:
            # Ensure it's CSC format
            a_sparse = a.tocsc()

        # Apply regularization (sparse identity needed)
        if sigma != 0.0:
            print(f"({self._name}) Applying regularization sigma={sigma} to sparse matrix.")
            eye_sparse      = sps.identity(a.shape[0], dtype=a.dtype, format='csc')
            a_reg_sparse    = a_sparse + sigma * eye_sparse
        else:
            a_reg_sparse    = a_sparse

        # Per-call overrides win over the instance properties
        fill_factor = kwargs.get('fill_factor', self._fill_factor)
        drop_tol    = kwargs.get('drop_tol', self._drop_tol)

        try:
            # Perform ILU decomposition using SuperLU through spilu
            self._ilu_obj = spsla.spilu(a_reg_sparse,
                                        drop_tol    = drop_tol,
                                        fill_factor = fill_factor)
            print(f"({self._name}) ILU decomposition successful.")
        except RuntimeError as e:
            print(f"({self._name}) ILU decomposition failed: {e}")
            print(f"({self._name}) Matrix might be singular or factorization numerically difficult.")
            self._ilu_obj = None  # Ensure object is None if decomposition fails

        # Publish the result the same way the other preconditioners in this
        # module do; without this the apply path never sees the new factors.
        self._precomputed_data_instance = {'ilu_obj': self._ilu_obj}
        self._update_instance_apply_func()

    # -----------------------------------------------------------------

    def _set_gram(self, s: Array, sp: Array, sigma: float, **kwargs):
        """
        Set up ILU by forming A = Sp @ S / N first (sparse recommended).

        Parameters:
            s (Array): The factor S; ``N`` is its first dimension.
            sp (Array): The factor Sp multiplied from the left.
            sigma (float): Regularization parameter.
            kwargs (dict): Forwarded to :meth:`_set_standard`.

        Raises:
            RuntimeError: If the backend is not NumPy.
        """
        # Ensure NumPy backend
        if self._backend is not np:
            raise RuntimeError(f"{self._name} internal error: Backend is not NumPy despite check.")

        # Guard against an empty first dimension
        n = s.shape[0] if s.shape[0] > 0 else 1

        # Check if s/sp are sparse
        s_is_sparse     = sps.issparse(s)
        sp_is_sparse    = sps.issparse(sp)

        if not s_is_sparse or not sp_is_sparse:
            print(f"({self._name}) Warning: Forming explicit Gram matrix A = Sp @ S / N for ILU setup (N={n}). "
                  "Input factors should ideally be sparse.")

        # Convert to sparse CSC where necessary before the product
        s_mat   = s.tocsc() if s_is_sparse else sps.csc_matrix(s)
        sp_mat  = sp.tocsc() if sp_is_sparse else sps.csc_matrix(sp)
        a_gram  = (sp_mat @ s_mat) / n

        # Now call the standard setup
        self._set_standard(a_gram, sigma, **kwargs)

    # -----------------------------------------------------------------

    def _get_precomputed_data(self) -> dict:
        """ Returns the computed ILU object under the key ``'ilu_obj'`` (may be None). """
        # Let apply handle None case
        return {'ilu_obj': self._ilu_obj}

    # -----------------------------------------------------------------
    #! Static Apply Method
    # -----------------------------------------------------------------

    @staticmethod
    def apply(r: Array, sigma: float, backend_module: Any, ilu_obj: Optional[spsla.SuperLU]) -> Array:
        """
        Static apply method for ILU: solves Mz = r using the LU factors.

        Args:
            r (Array): The residual vector.
            sigma (float): Regularization (used during setup).
            backend_module (Any): The backend numpy module (must be numpy).
            ilu_obj (Optional[SuperLU]): The precomputed ILU object from spilu.

        Returns:
            Array: The preconditioned vector M^{-1}r, or r if ilu_obj is None
            or the solve raises.

        Raises:
            RuntimeError: If backend_module is not NumPy.
        """
        if backend_module is not np:
            # This check is important because spsla.SuperLU is SciPy/NumPy specific
            raise RuntimeError("IncompleteCholeskyPreconditioner.apply requires NumPy backend.")

        if ilu_obj is None:
            # Factorization failed or not performed, return original vector
            print("Warning: ILU object is None in apply, returning original vector.")
            return r

        try:
            # The SuperLU object's solve handles the permutations internally
            return ilu_obj.solve(r)  # Input r should be a NumPy array
        except Exception as e:
            print(f"ILU solve failed during apply: {e}")
            # Return original vector if solve fails
            return r

    @staticmethod
    def _apply_kernel(r: Array, backend_mod: Any, sigma: float, **precomputed_data: Any) -> Array:
        """
        Static kernel adapter that forwards to the public static :meth:`apply`.

        Takes the ILU object from ``precomputed_data['ilu_obj']`` (None when
        absent) so the class can be driven through the same kernel interface
        as the other preconditioners in this module.
        """
        ilu_obj = precomputed_data.get('ilu_obj')
        return IncompleteCholeskyPreconditioner.apply(r, sigma, backend_mod, ilu_obj)

    # -----------------------------------------------------------------

    def __repr__(self) -> str:
        """
        Returns the base representation extended with the factorization status.
        """
        # Check if the ILU object is None (not factorized) or not
        status      = "Factorized" if self._ilu_obj is not None else "Not Factorized/Failed"
        base_repr   = super().__repr__()
        # Drop the closing parenthesis of the base repr and re-close after the extra field
        return f"{base_repr[:-1]}, ILU_status='{status}')"
# ----------------------------------------------------------------- # ===================================================================== #! Incomplete LU Preconditioner # =====================================================================
[docs] class ILUPreconditioner(Preconditioner): """ Incomplete LU Preconditioner (ILU(0)). Suitable for large, sparse, non-symmetric matrices A. This implementation uses SciPy's sparse ILU(0) decomposition (`scipy.sparse.linalg.spilu`). ILU(0) computes factors L and U such that L @ U approximates A, maintaining the sparsity pattern of A (zero fill-in). Applying the inverse M^{-1}r involves solving L y = P r and U z = y, where P is a permutation matrix handled internally by the ILU object. **Note:** - Requires SciPy and the NumPy backend. Not JAX compatible. - Designed for sparse matrices. Dense input matrices will be converted. Args for `set`: fill_factor (float): See `scipy.sparse.linalg.spilu`. Default 1. drop_tol (float): See `scipy.sparse.linalg.spilu`. Default None. References: - Saad, Y. (2003). Iterative Methods for Sparse Linear Systems (2nd ed.). SIAM. Chapter 10. - SciPy documentation for `scipy.sparse.linalg.spilu`. """ _name = "Incomplete LU Preconditioner" _type = PreconditionersTypeNoSym.INCOMPLETE_LU
[docs] def __init__(self, is_positive_semidefinite: bool = False, is_gram : bool = False, backend: str = 'default'): """ Initialize the Incomplete LU preconditioner. Args: is_positive_semidefinite (bool): If A is assumed positive semi-definite (not required for ILU). is_gram (bool): If setting up from Gram matrix factors. backend (str): The computational backend. Must be 'numpy'. """ super().__init__(is_positive_semidefinite=is_positive_semidefinite, is_gram=is_gram, backend=backend) self._stype = PreconditionersType.NONSYMMETRIC # Check for backend compatibility if self._backend != np: raise NotImplementedError(f"{self._name} requires the 'numpy' backend (uses SciPy sparse). " f"Current backend: '{self.backend_str}'.") # Precomputed data storage self._ilu_obj : Optional[spsla.SuperLU] = None # Stores the SuperLU object from spilu self._fill_factor = 1.0 self._drop_tol = None
# ----------------------------------------------------------------- #! Properties # ----------------------------------------------------------------- @property def fill_factor(self) -> float: """ The fill factor for ILU(0) (default 1.0). """ return self._fill_factor @fill_factor.setter def fill_factor(self, value: float): """Set the maximum fill-in multiplier for sparse ILU factors.""" if value <= 0: raise ValueError("fill_factor must be positive.") self._fill_factor = value @property def drop_tol(self) -> Optional[float]: """ The drop tolerance for ILU(0) (default None). """ return self._drop_tol @drop_tol.setter def drop_tol(self, value: Optional[float]): """Set the sparse ILU drop tolerance.""" if value is not None and value < 0: raise ValueError("drop_tol must be non-negative.") self._drop_tol = value # ----------------------------------------------------------------- #! Setup Methods # ----------------------------------------------------------------- @staticmethod def _setup_standard_kernel(a: Array, sigma: float, backend_mod: Any, **kwargs) -> Dict[str, Any]: """ Static Setup Kernel: Computes ILU factorization of A + sigma*I. """ if backend_mod is not np: # Should be caught earlier, but safe guard raise RuntimeError("ILUPreconditioner requires NumPy backend.") # Convert dense input to sparse CSC format (efficient for spilu) if not sps.issparse(a): print(f"({ILUPreconditioner._name}) Warning: Input matrix is dense. 
Converting to sparse CSC format.") a_sparse = sps.csc_matrix(a) else: # Ensure it's CSC format a_sparse = a.tocsc() # Apply regularization (sparse identity needed) if sigma != 0.0: print(f"({ILUPreconditioner._name}) Applying regularization sigma={sigma} to sparse matrix.") eye_sparse = sps.identity(a.shape[0], dtype=a.dtype, format='csc') a_reg_sparse = a_sparse + sigma * eye_sparse else: a_reg_sparse = a_sparse # Get ILU parameters from kwargs (defaults defined in class are not accessible in static context directly unless passed) # We assume defaults are passed via kwargs or we rely on spilu defaults if not present fill_factor = kwargs.get('fill_factor', 1.0) drop_tol = kwargs.get('drop_tol', None) ilu_obj = None try: # Perform ILU decomposition using SuperLU through spilu ilu_obj = spsla.spilu(a_reg_sparse, drop_tol = drop_tol, fill_factor = fill_factor, ) print(f"({ILUPreconditioner._name}) ILU decomposition successful.") except RuntimeError as e: print(f"({ILUPreconditioner._name}) ILU decomposition failed: {e}") print(f"({ILUPreconditioner._name}) Matrix might be singular or factorization numerically difficult.") return {'ilu_obj': ilu_obj} @staticmethod def _setup_gram_kernel(s: Array, sp: Array, sigma: float, backend_mod: Any, **kwargs) -> Dict[str, Any]: """ Static Setup Kernel: Computes ILU factorization from Gram factors S, Sp. Forms A = Sp @ S / N. """ if backend_mod is not np: raise RuntimeError("ILUPreconditioner requires NumPy backend.") # Use shape directly n = s.shape[0] if s.shape[0] > 0 else 1.0 if n <= 0.0: n = 1.0 # Check if s/sp are sparse s_is_sparse = sps.issparse(s) sp_is_sparse = sps.issparse(sp) if not s_is_sparse or not sp_is_sparse: print(f"({ILUPreconditioner._name}) Warning: Forming explicit Gram matrix A = Sp @ S / N for ILU setup (N={n}). 
" "Input factors should ideally be sparse.") # Convert to sparse if necessary before matmul s_mat = sps.csc_matrix(s) if not s_is_sparse else s.tocsc() sp_mat = sps.csc_matrix(sp) if not sp_is_sparse else sp.tocsc() a_gram = (sp_mat @ s_mat) / n else: # Perform sparse matrix multiplication a_gram = (sp.tocsc() @ s.tocsc()) / n return ILUPreconditioner._setup_standard_kernel(a_gram, sigma, backend_mod, **kwargs) def _set_standard(self, a: Array, sigma: float, **kwargs): # Merge instance properties with kwargs to pass to static kernel kwargs.setdefault('fill_factor', self._fill_factor) kwargs.setdefault('drop_tol', self._drop_tol) self._precomputed_data_instance = self.__class__._setup_standard_kernel(a, sigma, self._backend, **kwargs) # Update instance properties if they were modified/returned? # Actually _setup_standard_kernel returns data dict, not properties. # But for ILU, ilu_obj is in data dict. self._ilu_obj = self._precomputed_data_instance.get('ilu_obj') self._update_instance_apply_func() def _set_gram(self, s: Array, sp: Array, sigma: float, **kwargs): kwargs.setdefault('fill_factor', self._fill_factor) kwargs.setdefault('drop_tol', self._drop_tol) self._precomputed_data_instance = self.__class__._setup_gram_kernel(s, sp, sigma, self._backend, **kwargs) self._ilu_obj = self._precomputed_data_instance.get('ilu_obj') self._update_instance_apply_func() def _get_precomputed_data(self) -> dict: """ Returns the computed ILU object. """ # Let apply handle None case return {'ilu_obj': self._ilu_obj} # ----------------------------------------------------------------- #! Static Apply Method # -----------------------------------------------------------------
[docs] @staticmethod def apply(r: Array, sigma: float, backend_module: Any, ilu_obj: Optional[spsla.SuperLU]) -> Array: """ Static apply method for ILU: solves Mz = r using the LU factors. Args: r (Array): The residual vector. sigma (float): Regularization (used during setup). backend_module (Any): The backend numpy module (must be numpy). ilu_obj (Optional[SuperLU]): The precomputed ILU object from spilu. Returns: Array: The preconditioned vector M^{-1}r, or r if ilu_obj is None. """ if backend_module is not np: # This check is important because spsla.SuperLU is SciPy/NumPy specific raise RuntimeError("ILUPreconditioner.apply requires NumPy backend.") if ilu_obj is None: # Factorization failed or not performed, return original vector print("Warning: ILU object is None in apply, returning original vector.") return r try: # Use the solve method of the SuperLU object # This handles permutations and solves L(U(x)) = Pr return ilu_obj.solve(r) # Input r should be a NumPy array except Exception as e: # Catch specific linalg errors if possible print(f"ILU solve failed during apply: {e}") # Return original vector if solve fails return r
@staticmethod
def _apply_kernel(r: Array, backend_mod: Any, sigma: float, **precomputed_data: Any) -> Array:
    """
    Adapter between the kernel calling convention and the public static apply.

    Looks up the ``'ilu_obj'`` entry of the precomputed data (None when the
    key is absent) and forwards it, together with the reordered positional
    arguments, to ``ILUPreconditioner.apply``.

    Args:
        r (Array):
            Residual vector.
        backend_mod (Any):
            Numerical backend module.
        sigma (float):
            Regularization value, forwarded unchanged.
        **precomputed_data (Any):
            Setup data; only the ``'ilu_obj'`` key is read.

    Returns:
        Array: Whatever ``ILUPreconditioner.apply`` returns.
    """
    return ILUPreconditioner.apply(r, sigma, backend_mod, precomputed_data.get('ilu_obj'))

# -----------------------------------------------------------------
def __repr__(self) -> str:
    """
    Describe the preconditioner, including whether an ILU object is stored.

    The parent representation is reused: its last character is dropped and
    an ``ILU_status`` field is appended before a closing parenthesis.

    Returns:
        str: Parent repr extended with ``ILU_status='...'``.
    """
    # A missing ILU object means setup was never run or did not produce one.
    if self._ilu_obj is None:
        status = "Not Factorized/Failed"
    else:
        status = "Factorized"

    parent_repr = super().__repr__()
    # NOTE(review): slicing assumes the parent repr ends with ')' - confirm in the base class.
    return parent_repr[:-1] + f", ILU_status='{status}')"
# -----------------------------------------------------------------

# =====================================================================
#! Choose wisely
# =====================================================================

def _resolve_precond_type(precond_id: Any) -> Union[PreconditionersTypeSym, PreconditionersTypeNoSym]:
    """
    Translate a string / int / Enum identifier into a preconditioner Enum member.

    Strings are stripped, '-' and ' ' are replaced by '_', and the result is
    upper-cased before being looked up by member name. Integers are looked up
    by member value. In both cases the symmetric enum is tried first and the
    non-symmetric enum second. Enum members are returned as they are.

    Args:
        precond_id (Any):
            Identifier to resolve (Enum member, str, or int).

    Returns:
        The matching ``PreconditionersTypeSym`` or ``PreconditionersTypeNoSym`` member.

    Raises:
        ValueError: If a string or integer matches neither enum.
        TypeError:  If the identifier is not a str, int, or one of the two enums.
    """
    if isinstance(precond_id, str):
        member_name = precond_id.strip().replace('-', '_').replace(' ', '_').upper()
        try:
            return PreconditionersTypeSym[member_name]
        except KeyError:
            # Not a symmetric type; fall back to the non-symmetric family.
            try:
                return PreconditionersTypeNoSym[member_name]
            except KeyError as err:
                raise ValueError(f"Unknown preconditioner name: '{precond_id}'.") from err

    if isinstance(precond_id, int):
        try:
            return PreconditionersTypeSym(precond_id)
        except ValueError:
            # Same fallback order as for names: symmetric first, then non-symmetric.
            try:
                return PreconditionersTypeNoSym(precond_id)
            except ValueError as err:
                raise ValueError(f"Unknown preconditioner value: {precond_id}.") from err

    if isinstance(precond_id, (PreconditionersTypeSym, PreconditionersTypeNoSym)):
        return precond_id

    # Anything else is an unsupported identifier type.
    raise TypeError(f"Unsupported type for precond_id: {type(precond_id)}. Expected Enum, str, or int.")

# =====================================================================

def _get_precond_class_and_defaults(precond_type: Union[PreconditionersTypeSym, PreconditionersTypeNoSym]) -> Tuple[Type[Preconditioner], dict]:
    """
    Map a resolved preconditioner type to its class and default constructor kwargs.

    Only the identity preconditioner receives a default
    (``is_positive_semidefinite``: True for the symmetric variant and for the
    plain integer 0, False for the non-symmetric variant). Every other type
    gets an empty defaults dict.

    Args:
        precond_type:
            A member of one of the two preconditioner enums. A plain int is
            also accepted: 0 selects the identity and 1 selects Jacobi.

    Returns:
        target_class (Type[Preconditioner]):
            The preconditioner class to instantiate.
        defaults (dict):
            Default keyword arguments for that class' constructor.

    Raises:
        ValueError: If the enum member or integer has no mapping, or the type is None.
        TypeError:  If ``precond_type`` is of any other type.
    """
    if isinstance(precond_type, PreconditionersTypeSym):
        if precond_type == PreconditionersTypeSym.IDENTITY:
            return IdentityPreconditioner, {'is_positive_semidefinite': True}
        if precond_type == PreconditionersTypeSym.JACOBI:
            return JacobiPreconditioner, {}
        if precond_type == PreconditionersTypeSym.INCOMPLETE_CHOLESKY:
            return IncompleteCholeskyPreconditioner, {}
        if precond_type == PreconditionersTypeSym.COMPLETE_CHOLESKY:
            return CholeskyPreconditioner, {}
        if precond_type == PreconditionersTypeSym.SSOR:
            return SSORPreconditioner, {}
        raise ValueError(f"Symmetric type {precond_type} not handled.")

    if isinstance(precond_type, PreconditionersTypeNoSym):
        if precond_type == PreconditionersTypeNoSym.IDENTITY:
            return IdentityPreconditioner, {'is_positive_semidefinite': False}
        if precond_type == PreconditionersTypeNoSym.INCOMPLETE_LU:
            return ILUPreconditioner, {}
        raise ValueError(f"Non-Symmetric type {precond_type} not handled.")

    if isinstance(precond_type, int):
        # Bare integers: only 0 (identity) and 1 (Jacobi) are understood for now.
        if precond_type == 0:
            return IdentityPreconditioner, {'is_positive_semidefinite': True}
        if precond_type == 1:
            return JacobiPreconditioner, {}
        raise ValueError(f"Unknown preconditioner integer value: {precond_type}.")

    if precond_type is None:
        raise ValueError("Preconditioner type could not be resolved (None).")

    raise TypeError("Internal error: Invalid precond_type.")

# =====================================================================
#! Main Factory Function
# =====================================================================
def choose_precond(precond_id: Any, **kwargs) -> Optional["Preconditioner"]:
    """
    Factory function to select and instantiate a preconditioner.

    Accepts various identifiers (Enum, str, int, instance) and passes kwargs
    to the specific preconditioner's constructor.

    String identifiers are normalized first (stripped, '-' and ' ' replaced
    by '_', upper-cased). The normalized names 'DEFAULT' and
    'SYMMETRIC_DEFAULT' select the Jacobi preconditioner, and 'NONE' selects
    no preconditioner at all.

    Args:
        precond_id (Any):
            Identifier (instance, Enum, str, int) or None.
        **kwargs:
            Additional arguments for the constructor (e.g., backend='jax').
            They override the per-type defaults. Keys the target constructor
            does not accept are dropped, unless it takes ``**kwargs``.

    Returns:
        Optional[Preconditioner]:
            An instance of the selected preconditioner, the given instance
            itself when one is passed in, or None when ``precond_id`` is None
            or the string 'none'.

    Raises:
        ValueError: If the identifier does not match a known preconditioner.
        TypeError:  If the identifier is of an unsupported type.
        Exception:  Any error raised while building the instance is printed
                    and re-raised unchanged.
    """
    if precond_id is None:
        return None

    if isinstance(precond_id, str):
        precond_id = precond_id.strip().replace('-', '_').replace(' ', '_').upper()
        if precond_id == 'NONE':
            return None
        if precond_id in ('DEFAULT', 'SYMMETRIC_DEFAULT'):
            precond_id = PreconditionersTypeSym.JACOBI

    # 1. Handle Instance Passthrough
    if isinstance(precond_id, Preconditioner):
        return precond_id

    # 2. Resolve ID to Enum Type (ValueError / TypeError propagate to the caller)
    precond_type = _resolve_precond_type(precond_id)

    # 3. Get Target Class and Default Kwargs
    target_class, default_kwargs = _get_precond_class_and_defaults(precond_type)

    # 4. Combine Defaults and User Kwargs (User kwargs override defaults)
    final_kwargs = {**default_kwargs, **kwargs}

    # 5. Filter Kwargs for Constructor and Instantiate
    # Bound before the try block so the error report below can always read it,
    # even when signature inspection itself is what fails.
    filtered_kwargs: dict = {}
    try:
        valid_params = inspect.signature(target_class.__init__).parameters
        # A constructor taking **kwargs (VAR_KEYWORD) accepts every key.
        has_varkw = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in valid_params.values())
        # Without **kwargs, unknown keys (e.g. 'logger') must be dropped,
        # otherwise the constructor call would raise a TypeError.
        filtered_kwargs = {
            key: value
            for key, value in final_kwargs.items()
            if has_varkw or key in valid_params
        }
        return target_class(**filtered_kwargs)
    except Exception as e:
        print(f"Error instantiating {target_class.__name__} with kwargs {filtered_kwargs}: {e}")
        raise
# ===================================================================== #! End of File # =====================================================================