r"""
general_python.algebra.utils
===========================
This module provides utilities for importing and managing the linear algebra backend (NumPy/JAX),
random number generation, JIT compilation, and backend configuration for the general_python package.
Features
--------
- BackendManager: Class to detect and manage the active backend (NumPy/JAX), including linear algebra, random number generation, and SciPy modules.
- Functions to retrieve backend components (`get_backend`, `get_global_backend`).
- Global access to the active backend via the `backend_mgr` instance.
- Utilities for JIT compilation (`maybe_jit`), hardware info (`get_hardware_info`), and array padding (`pad_array`).
- Clear version reporting and backend status printing.
Usage
-----
Import the backend manager and utilities:
.. code-block:: python
from ..algebra.utils import get_backend, backend_mgr, maybe_jit
xp = backend_mgr.np # NumPy or JAX numpy
rng = backend_mgr.default_rng
@maybe_jit
def my_func(x):
return xp.sum(x)
Testing
-------
To run tests for this module and the general_python package:
.. code-block:: bash
pytest
# or
python -m unittest discover -s Python/test
Test coverage includes:
- Singleton identity and initialization
- Operator algebra correctness
- Monte Carlo sampling
- NQS training and evaluation
See the `test/` directory for details.
------------------------------------------------------------------------
File : algebra/utils.py
Author : Maks Kliczkowski
Email : maksymilian.kliczkowski@pwr.edu.pl
------------------------------------------------------------------------
"""
# Import the required modules
import os
import inspect
import logging
import multiprocessing
import random as py_random
from functools import wraps
from contextlib import contextmanager
from typing import Union, Optional, TypeAlias, Type, Tuple, Any, Callable, List, Dict, Literal, TYPE_CHECKING
from dataclasses import dataclass
# ---------------------------------------------------------------------
#! Global logger placeholder (resolved lazily through qes_globals)
# ---------------------------------------------------------------------
if TYPE_CHECKING:  # pragma: no cover - branch is only taken by static type checkers
    # Imported solely so the annotation below can name the project logger class.
    from ..common.flog import Logger as QESLogger

# Module-wide logger handle. It starts as None; the original note says it is
# assigned during _qes_initialize_utils (that routine is outside this view).
log: Optional[Union["QESLogger", logging.Logger]] = None
# ---------------------------------------------------------------------
import numpy as np
import numpy.random as np_random
import scipy as sp
# ---------------------------------------------------------------------
#! Environment variable names
# CPU count as reported by the OS; os.cpu_count() may return None when it cannot be determined.
num_cores = os.cpu_count()
PY_NUM_CORES_STR : str = "PY_NUM_CORES"
#! os environment variables
# Every constant below holds the *name* of an os.environ key, not the setting itself.
PY_JAX_AVAILABLE_STR : str = "PY_JAX_AVAILABLE"
PY_JAX_DONT_USE_STR : str = "PY_JAX_DONT_USE"
PY_FLOATING_POINT_STR : str = "PY_FLOATING_POINT"
PY_BACKEND_STR : str = "PY_BACKEND"
PY_GLOBAL_SEED_STR : str = "PY_GLOBAL_SEED"
PY_INFO_VERBOSE : str = "PY_BACKEND_INFO"
PY_SPIN_VALUE_STR : str = "PY_SPIN_VALUE"
PY_UTILS_INIT_DONE_STR : str = "PY_UTILS_INIT_DONE"
PY_BACKEND_REPR : str = "PY_BACKEND_REPR"
PY_BACKEND_SPIN : str = "PY_BACKEND_DEF_SPIN"
# ---------------------------------------------------------------------
# Built-in defaults used when nothing else has been configured.
JIT : Callable = lambda x: x # Default JIT function (identity)
DEFAULT_SEED : int = 42
DEFAULT_BACKEND : str = "numpy"
DEFAULT_BACKEND_KEY : Optional[str] = None
DEFAULT_NP_INT_TYPE : Type = np.int64
DEFAULT_NP_FLOAT_TYPE : Type = np.float64
DEFAULT_NP_CPX_TYPE : Type = np.complex128
BACKEND_REPR : float = 0.5
BACKEND_DEF_SPIN : bool = True
# NOTE: the two assignments below run unconditionally on import, so any value
# already present in the environment under these keys is overwritten.
os.environ[PY_BACKEND_REPR] = "0.5" # default to 0.5
os.environ[PY_BACKEND_SPIN] = "1" # default to spin systems
# JAX dtype defaults are unknown at this point and start out as None.
DEFAULT_JP_INT_TYPE : Optional[Type] = None
DEFAULT_JP_FLOAT_TYPE : Optional[Type] = None
DEFAULT_JP_CPX_TYPE : Optional[Type] = None
# ---------------------------------------------------------------------
def _log_message(msg, lvl = 0, **kwargs):
    """
    Emit an informational message when verbose backend output is enabled.

    ``PY_INFO_VERBOSE`` stores the *name* of an environment variable
    ("PY_BACKEND_INFO"). Output is produced only when that environment
    variable is set to ``"1"``.

    Parameters:
        msg (str):
            The message to log.
        lvl (int):
            The indentation level; the message is prefixed with ``lvl`` tabs.
        **kwargs:
            Forwarded unchanged to ``log.info`` when a logger is set.
    """
    # The original check compared the variable *name* with "1", which can never
    # match, so nothing was ever emitted. Look the setting up in os.environ.
    # The literal comparison is kept as well so that code which rebinds
    # PY_INFO_VERBOSE to "1" (the only way the old check could pass) still works.
    verbose = PY_INFO_VERBOSE == "1" or os.environ.get(PY_INFO_VERBOSE, "0") == "1"
    if not verbose:
        return
    text = "\t" * lvl + msg
    if log is None:
        # No logger resolved yet: fall back to stdout, keeping the indentation.
        print(text)
    else:
        log.info(text, **kwargs)
def _initialize_logger() -> Union["QESLogger", logging.Logger]:
    """
    Find a logger for this module, trying progressively weaker sources.

    Resolution order:
    1) ``get_global_logger`` imported relative to this package
    2) ``get_global_logger`` imported via the absolute ``general_python`` path
    3) a stdlib logger named after this module, silenced with a ``NullHandler``

    Any exception raised by a source (import failure or otherwise) moves on
    to the next one.
    """
    def _from_package():
        from ..common.flog import get_global_logger
        return get_global_logger()

    def _from_absolute_path():
        from general_python.common.flog import get_global_logger
        return get_global_logger()

    for resolve in (_from_package, _from_absolute_path):
        try:
            return resolve()
        except Exception:
            continue

    # Last resort: a stdlib logger that stays quiet unless the host configures it.
    silent = logging.getLogger(__name__)
    if not silent.handlers:
        silent.addHandler(logging.NullHandler())
    return silent
# ---------------------------------------------------------------------
#! SET VARIABLES
# ---------------------------------------------------------------------
# Each setting is read from the environment (falling back to a default) and
# then written back, so the environment always reflects the effective value.
PY_GLOBAL_SEED : int = int(os.environ.get(PY_GLOBAL_SEED_STR, DEFAULT_SEED))
os.environ[PY_GLOBAL_SEED_STR] = str(PY_GLOBAL_SEED)
# os.cpu_count() may return None; without the fallback the default would be
# the string "None" and int() would raise at import time.
PY_NUM_CORES : int = int(os.environ.get(PY_NUM_CORES_STR, str(num_cores or 1)))
os.environ[PY_NUM_CORES_STR] = str(PY_NUM_CORES)
# 32-bit precision is opted into by any of the listed spellings (case-insensitive).
PREFER_32BIT : bool = os.environ.get(PY_FLOATING_POINT_STR, "64bit").lower() in ["32bit", "32", "float32", "float"]
PY_FLOATING_POINT : str = os.environ.get(PY_FLOATING_POINT_STR, "float32" if PREFER_32BIT else "float64")
PY_USE_32BIT : bool = PREFER_32BIT
os.environ[PY_FLOATING_POINT_STR] = PY_FLOATING_POINT
PY_NP_INT_TYPE : Type = np.int32 if PY_USE_32BIT else np.int64
PY_NP_FLOAT_TYPE : Type = np.float32 if PY_USE_32BIT else np.float64
PY_NP_CPX_TYPE : Type = np.complex64 if PY_USE_32BIT else np.complex128
PY_BACKEND : str = os.environ.get(PY_BACKEND_STR, DEFAULT_BACKEND).lower()
os.environ[PY_BACKEND_STR] = PY_BACKEND
# Define PY_JAX_DONT_USE before using it
PY_JAX_DONT_USE : bool = os.environ.get(PY_JAX_DONT_USE_STR, "0") in ("1", "true", "True")
# by default, use numpy (can override with PY_JAX_DONT_USE)
PREFER_JAX : bool = (PY_BACKEND == "jax") and not PY_JAX_DONT_USE
PREFER_64BIT : bool = not PREFER_32BIT
#! Backend Detection
# --- JAX-related placeholders ---
# JAX is not imported in this section; every handle starts out empty and the
# availability flag starts out False.
JAX_AVAILABLE: bool = False
jax: Optional[Any] = None
jnp: Optional[Any] = None
jsp: Optional[Any] = None
jrn: Optional[Any] = None
jcfg: Optional[Any] = None
jax_jit = lambda x: x  # identity stand-in for jax.jit
# --- Type Aliases (with defaults) ---
Array : TypeAlias = np.ndarray # Default to NumPy array
PRNGKey : TypeAlias = Any # Keep as 'Any' to avoid import errors if JAX is not present
JaxDevice : TypeAlias = Any # Keep as 'Any' to avoid import errors if JAX is not present
# ---------------------------------------------------------------------
# Mirror the availability flag into the environment as "1"/"0".
PY_JAX_AVAILABLE: bool = JAX_AVAILABLE
os.environ[PY_JAX_AVAILABLE_STR] = "1" if PY_JAX_AVAILABLE else "0"
# Note: PY_JAX_DONT_USE is already defined earlier
# ---------------------------------------------------------------------
#! Type defaults
# NumPy defaults follow the precision chosen above; JAX defaults stay unset.
DEFAULT_NP_INT_TYPE, DEFAULT_NP_FLOAT_TYPE, DEFAULT_NP_CPX_TYPE = (
    PY_NP_INT_TYPE,
    PY_NP_FLOAT_TYPE,
    PY_NP_CPX_TYPE,
)
DEFAULT_JP_INT_TYPE = DEFAULT_JP_FLOAT_TYPE = DEFAULT_JP_CPX_TYPE = None
#! Type Aliases
if not (JAX_AVAILABLE and jnp):
    # NumPy-only: arrays are plain ndarrays and the JAX-specific aliases are None.
    Array : TypeAlias = np.ndarray
    PRNGKey : TypeAlias = None
    JaxDevice : TypeAlias = None
else:
    # JAX usable: an array may come from either library.
    Array : TypeAlias = Union[np.ndarray, jnp.ndarray]
    PRNGKey : TypeAlias = Any # jax.random.PRNGKeyArray
    JaxDevice : TypeAlias = Any # Placeholder for jax device type
#! These will be updated by the backend_mgr after initialization.
# Module-level snapshot of the active backend. Everything starts on NumPy.
ACTIVE_BACKEND_NAME : str = "numpy"
ACTIVE_NP_MODULE : Any = np
ACTIVE_RANDOM : Any = np_random.default_rng(DEFAULT_SEED) # Start with a default numpy RNG
ACTIVE_SCIPY_MODULE : Any = sp
ACTIVE_JIT : Callable = JIT
ACTIVE_JAX_KEY : Optional[PRNGKey] = None
# NOTE(review): these three are hard-coded to 64-bit here and do not follow
# PY_USE_32BIT; presumably they are refreshed once the manager is initialised - confirm.
ACTIVE_INT_TYPE : Type = np.int64
ACTIVE_FLOAT_TYPE : Type = np.float64
ACTIVE_COMPLEX_TYPE : Type = np.complex128
# Placeholder for the global manager instance; None until it is created elsewhere.
backend_mgr : 'BackendManager' = None # Will be set after BackendManager is defined
# ---------------------------------------------------------------------
#! Global methods
# ---------------------------------------------------------------------
[docs]
def is_jax_array(x: Any) -> bool:
    '''
    Report whether an object looks like a JAX array (traced values included).

    Parameters
    ----------
    x : Any
        The object to check.

    Returns
    -------
    bool
        Always False when ``JAX_AVAILABLE`` is False. Otherwise the result of
        an ``isinstance`` check against ``jax.Array``; if that check cannot be
        performed, falls back to "has an ``aval`` attribute or is a
        ``jax.core.Tracer``", and finally to the ``aval`` test alone.
    '''
    if not JAX_AVAILABLE:
        return False
    try:
        # Preferred path: the public array type of modern JAX.
        from jax import Array as JaxArray # type: ignore
        verdict = isinstance(x, JaxArray)
    except Exception:
        try:
            # Older releases / tracers: rely on the tracer type or its ``aval``.
            from jax.core import Tracer as JaxTracer # type: ignore
            verdict = hasattr(x, "aval") or isinstance(x, JaxTracer)
        except Exception:
            verdict = hasattr(x, "aval")
    return verdict
# ----
is_traced_jax = is_jax_array
# ---------------------------------------------------------------------
[docs]
def get_backend(backend_spec : Union[str, Any, None] = None,
                random : bool = False,
                seed : Optional[int] = None,
                scipy : bool = False) -> Union[Any, Tuple[Any, ...]]:
    """
    Fetch backend modules for a given backend specifier.

    Thin wrapper around ``backend_mgr.get_backend_modules``; the keyword names
    are translated (``random`` -> ``use_random``, ``scipy`` -> ``use_scipy``).

    Parameters
    ----------
    backend_spec : str or module or None, optional
        Backend specifier ("numpy", "jax", `np`, `jnp`, "default", None).
        None selects the globally active backend.
    random : bool, optional
        If True, a random component is included in the result: for JAX a
        ``(jax.random, key)`` pair, for NumPy an ``(rng, None)`` pair.
        **With JAX, split the returned PRNG key before using it.**
    seed : int, optional
        Seed for the random component created for this call. When None the
        manager's default seed is used.
    scipy : bool, optional
        If True, the associated SciPy module is appended to the result.

    Returns
    -------
    module or tuple
        The array module alone, or a tuple starting with it followed by the
        requested extras. See `BackendManager.get_backend_modules`.

    **Example for using JAX backend with random number generation:**
    >>> import general_python.algebra.utils as abu
    >>> import numpy as np
    >>> if abu.JAX_AVAILABLE:
    ...     jax_np, (jax_rnd, key), jax_sp = abu.get_backend("jax", random=True, scipy=True, seed=42)
    ...     key, subkey = jax_rnd.split(key) # Split the key!
    ...     random_vector = jax_rnd.uniform(subkey, shape=(5,)) # Use subkey
    ...     print(random_vector)
    """
    options = {"use_random": random, "seed": seed, "use_scipy": scipy}
    return backend_mgr.get_backend_modules(backend_spec, **options)
# ---------------------------------------------------------------------
[docs]
def get_global_backend(random: bool = False, seed: Optional[int] = None, scipy: bool = False) -> Union[Any, Tuple[Any, ...]]:
    """
    Fetch the modules of the globally configured default backend.

    Thin wrapper around ``backend_mgr.get_global_backend_modules``; the keyword
    names are translated (``random`` -> ``use_random``, ``scipy`` -> ``use_scipy``).

    Parameters
    ----------
    random : bool, optional
        If True, request the random component as well.
    seed : int, optional
        Seed forwarded for this request's random component.
    scipy : bool, optional
        If True, request the associated SciPy module as well.

    Returns
    -------
    module or tuple
        Whatever the manager returns for the global default backend.
        See `BackendManager.get_backend_modules`.
    """
    options = {"use_random": random, "seed": seed, "use_scipy": scipy}
    return backend_mgr.get_global_backend_modules(**options)
# ---------------------------------------------------------------------
[docs]
def maybe_jit(func):
    """
    Conditionally wrap ``func`` with JAX JIT compilation.

    The function is returned untouched when ``JAX_AVAILABLE`` is False or the
    environment variable ``QES_JIT`` is "0", "false" or "False" (both checked
    once, at decoration time). Otherwise ``func`` must declare a ``backend``
    parameter, which is passed to ``jax.jit`` as a static argument.

    At call time the wrapper runs the plain Python function when the caller
    asks for NumPy (``backend`` is "np"/"numpy" in any case, or the ``numpy``
    module) and the jitted function in every other case, including when
    ``backend`` is omitted or None.

    Raises
    ------
    ValueError
        If JIT is enabled and ``func`` has no ``backend`` parameter.
    """
    if not JAX_AVAILABLE or os.getenv("QES_JIT", "1") in ("0", "false", "False"):
        return func
    from jax import jit as _jit
    sig = inspect.signature(func)
    if 'backend' not in sig.parameters:
        raise ValueError(f"@maybe_jit: '{func.__name__}' must accept 'backend' kwarg.")
    jitted = _jit(func, static_argnames=("backend",))

    def _requested_backend(args, kwargs):
        # Find the backend the caller supplied, by keyword or by position.
        if "backend" in kwargs:
            return kwargs["backend"]
        if not args:
            return None
        try:
            # Previously only kwargs were inspected, so f(x, np) was still jitted.
            return sig.bind_partial(*args, **kwargs).arguments.get("backend")
        except TypeError:
            # Arguments do not fit the signature; let the real call report it.
            return None

    @wraps(func)
    def wrapper(*args, **kwargs):
        b = _requested_backend(args, kwargs)
        if (isinstance(b, str) and b.lower() in ("np", "numpy")) or (b is np):
            return func(*args, **kwargs) # no JIT for NumPy
        return jitted(*args, **kwargs)
    return wrapper
# ---------------------------------------------------------------------
#! Types
# ---------------------------------------------------------------------
#! Define a registry for NumPy and JAX dtypes
# A dtype is either a NumPy scalar type or (loosely typed) a JAX dtype object.
DType = Union[Type[np.generic], Any]
# name -> {backend -> dtype}; declared empty here.
_DTYPE_REGISTRY: Dict[str, Dict[Literal['numpy', 'jax'], DType]] = {}
#! Reverse mapping from dtype to canonical name, seeded from the registry
_TYPE_TO_NAME: Dict[Any, str] = {
    registered: label
    for label, per_backend in _DTYPE_REGISTRY.items()
    for registered in per_backend.values()
    if registered is not None
}
# NumPy scalar types and the Python builtins they correspond to.
_TYPE_TO_NAME.update({
    np.complex64: 'complex64',
    np.complex128: 'complex128',
    np.float32: 'float32',
    np.float64: 'float64',
    np.int32: 'int32',
    np.int64: 'int64',
    int: 'int64',
    float: 'float64',
    complex: 'complex128',
})
# Also accept dtype instances (e.g., np.dtype('float64'))
_TYPE_TO_NAME.update({
    np.dtype('complex64'): 'complex64',
    np.dtype('complex128'): 'complex128',
    np.dtype('float32'): 'float32',
    np.dtype('float64'): 'float64',
    np.dtype('int32'): 'int32',
    np.dtype('int64'): 'int64',
})
[docs]
def distinguish_type(typek: Any, backend: Literal['numpy', 'jax'] = 'numpy') -> DType:
    """
    Translate a type into the matching dtype object of the chosen backend.

    The input is first mapped to a canonical name through ``_TYPE_TO_NAME``
    and that name is then looked up in ``_DTYPE_REGISTRY`` for ``backend``.

    Parameters
    ----------
    typek
        A key of ``_TYPE_TO_NAME`` (e.g. ``np.float32``, ``np.dtype('int64')``
        or the builtin ``int``).
    backend
        'numpy' or 'jax' - which library the returned dtype should belong to.

    Returns
    -------
    dtype
        The dtype registered for the resolved name and ``backend``.

    Raises
    ------
    ValueError
        If ``typek`` is not a supported (or not even hashable) type, if no
        dtype is registered for the resolved name and backend, or if JAX is
        requested while ``JAX_AVAILABLE`` is False.
    """
    try:
        name = _TYPE_TO_NAME[typek]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs; both become the documented ValueError.
        raise ValueError(f"Unsupported state type: {typek!r}") from None
    # A name missing from the registry used to leak a KeyError; treat it the
    # same as "no dtype registered for this backend".
    entry = _DTYPE_REGISTRY.get(name)
    dtype = entry.get(backend) if entry is not None else None
    if dtype is None:
        if backend == 'jax' and not JAX_AVAILABLE:
            raise ValueError("JAX not installed or not available")
        raise ValueError(f"Type {name!r} not defined for backend {backend!r}")
    return dtype
# ---------------------------------------------------------------------
#! Functions
# ---------------------------------------------------------------------
[docs]
def get_hardware_info() -> Tuple[int, int]:
    """
    Report how many JAX devices and CPU cores are visible.

    Returns:
        n_devices :
            ``jax.device_count()`` when the manager reports JAX as available
            and the ``jax`` handle is set; 0 otherwise or if the query fails.
        n_threads :
            ``multiprocessing.cpu_count()``.
    """
    device_total = 0
    if backend_mgr.is_jax_available and jax:
        try:
            device_total = jax.device_count()
        except Exception as err:
            log.warning(f"Could not get JAX device count: {err}")
            device_total = 0 # Detection failed: report no devices
    core_total = multiprocessing.cpu_count()
    _log_message(f"Detected CPU cores: {core_total}", 1)
    device_note = f"Detected devices: {device_total}" if device_total > 0 else "No device detected."
    _log_message(device_note, 1)
    return device_total, core_total
# ---------------------------------------------------------------------
[docs]
@dataclass
class RNGManager:
"""Container for synchronized random-number generator state.
Attributes
----------
np_rng
NumPy random generator used by NumPy-backed helpers.
jax_rng
JAX PRNG key or key-like state used by JAX-backed helpers.
py_rng
Python ``random.Random`` instance for standard-library randomness.
"""
np_rng : np.random.Generator | None
jax_rng : Any | None
py_rng : py_random.Random | None
# ---------------------------------------------------------------------
[docs]
class BackendManager:
"""
Manages the numerical backend (NumPy or JAX) state.
Provides access to the appropriate linear algebra module (np/jnp),
random number generator, SciPy module, JIT compiler, and backend info.
Attributes:
is_jax_available (bool):
True if JAX was successfully imported.
name (str):
Name of the active backend ("numpy" or "jax").
np (module):
The active array module (numpy or jax.numpy).
random (module):
The active random module (numpy.random or jax.random).
scipy (module):
The active SciPy module (scipy or jax.scipy).
key (Optional[PRNGKey]):
The default JAX PRNG key (if JAX is active).
jit (Callable):
The JIT compiler function (jax.jit or identity).
default_seed (int):
The seed used for default RNG initialization.
default_rng (np.random.Generator | np_random):
Default NumPy RNG instance.
default_jax_key (Optional[PRNGKey]):
Default JAX key instance.
int_dtype (Type):
Default integer type for the *active* backend.
float_dtype (Type):
Default float type for the *active* backend.
complex_dtype (Type):
Default complex type for the *active* backend.
"""
[docs]
def __init__(self, default_seed: int = DEFAULT_SEED, prefer_jax: bool = PREFER_JAX):
    """
    Set up the manager with NumPy active, then adopt JAX if present and preferred.

    Args:
        default_seed:
            The seed for initializing default random generators.
        prefer_jax:
            If True and JAX is available, use JAX as the default.
            Otherwise, use NumPy.
    """
    self.default_seed : int = default_seed
    self.is_jax_available : bool = JAX_AVAILABLE
    #! Initialize NumPy components first as fallback
    self._np_module = np
    self._sp_module = sp
    self._np_random_module = np_random # Store the base module
    self.default_rng = self._create_numpy_rng(self.default_seed)
    #! Active backend defaults (start with NumPy)
    self.name : str = "numpy"
    self.np : Any = self._np_module
    self.random : Any = self.default_rng # Use the Generator instance by default
    self.scipy : Any = self._sp_module # SciPy module
    self.key : Optional[PRNGKey] = None # Key for the random module
    self.jit : Callable = lambda x: x # Identity function
    #! JAX specific components (if available)
    self._jax_module = None
    self._jnp_module = None
    self._jsp_module = None
    self._jrn_module = None
    self._jax_jit = None
    self.default_jax_key : Optional[PRNGKey] = None
    # Detection results start as "nothing detected" so every attribute exists
    # on NumPy-only setups too; previously _jax_functional was only created
    # inside _update_device, which never runs without JAX.
    self.detected_jax_backend: Optional[str] = None
    self.detected_jax_devices: Optional[List[JaxDevice]] = None
    self._jax_functional: bool = False
    if self.is_jax_available and jax and jnp and jsp and jrn and jax_jit and jcfg:
        self._jax_module = jax
        self._jnp_module = jnp
        self._jsp_module = jsp
        self._jrn_module = jrn
        self._jax_jit = jax_jit # The imported jax.jit
        self.default_jax_key = self._create_jax_key(self.default_seed)
        self._update_device()
        if prefer_jax:
            log.debug("Setting JAX as the active backend.")
            self.set_active_backend("jax")
    #! Set active dtypes based on the chosen backend
    self._update_dtypes()
    # A seed supplied through the environment is applied via reseed(); an
    # unparsable value is reported and ignored.
    env_seed = os.getenv(PY_GLOBAL_SEED_STR, "").strip()
    if env_seed:
        try:
            self.reseed(int(env_seed))
        except Exception as e:
            log.warning(f"Ignoring PY_GLOBAL_SEED={env_seed!r}: {e}")
# ---------------------------------------------------------------------
def _update_device(self):
    '''
    Query JAX for its default platform and device list and store the outcome.

    Writes ``detected_jax_backend``, ``detected_jax_devices`` and
    ``_jax_functional``. If JAX is not available the attributes are left at
    their reset values (None / None / False) and a warning is logged.
    '''
    def _mark_failed(label):
        # Shared bookkeeping for both error paths below.
        self.detected_jax_backend = label
        self.detected_jax_devices = []
        self._jax_functional = False

    # Start from a clean slate before every detection attempt.
    self.detected_jax_backend = None
    self.detected_jax_devices = None
    self._jax_functional = False
    if not self.is_jax_available or not jax or not jax.lib: # Extra safety check
        log.warning("Attempted _update_device without JAX being available/imported.")
        return
    try:
        self.detected_jax_backend = jax.default_backend() # 'cpu'/'gpu'/'tpu'
        # jax.devices() is tried first; jax.local_devices() is the fallback.
        try:
            found = jax.devices()
        except Exception:
            found = jax.local_devices()
        self.detected_jax_devices = found
        if found:
            # A platform and at least one device: JAX is usable.
            self._jax_functional = True
        else:
            log.warning("JAX backend detected, but no devices found!")
            self._jax_functional = False # No devices = not functional for most purposes
    except AttributeError as ae:
        # A piece of the JAX API that is used above is missing.
        log.error(f"AttributeError during JAX backend/device detection: {ae}. JAX likely not fully functional.", exc_info=True)
        _mark_failed("Detection Error (Attribute)")
    except Exception as e:
        log.error(f"An unexpected error occurred during JAX backend/device detection: {e}. "
                  f"JAX backend might not be functional.", exc_info=True)
        _mark_failed("Detection Error (Exception)")
# ---------------------------------------------------------------------
def _update_dtypes(self):
    """
    Refresh ``int_dtype``, ``float_dtype`` and ``complex_dtype`` for the active backend.

    With JAX active and available, each JAX default is used when it is set
    and the matching NumPy default otherwise. In every other case the NumPy
    defaults are used. The chosen types are reported at debug level.
    """
    jax_active = self.name == "jax" and self.is_jax_available
    if jax_active:
        # An unset (falsy) JAX default falls back to its NumPy counterpart.
        chosen = (
            DEFAULT_JP_INT_TYPE or DEFAULT_NP_INT_TYPE,
            DEFAULT_JP_FLOAT_TYPE or DEFAULT_NP_FLOAT_TYPE,
            DEFAULT_JP_CPX_TYPE or DEFAULT_NP_CPX_TYPE,
        )
    else:
        chosen = (DEFAULT_NP_INT_TYPE, DEFAULT_NP_FLOAT_TYPE, DEFAULT_NP_CPX_TYPE)
    self.int_dtype, self.float_dtype, self.complex_dtype = chosen
    log.debug(f"Active dtypes set for backend '{self.name}': "
              f"int={getattr(self.int_dtype, '__name__', 'N/A')}, "
              f"float={getattr(self.float_dtype, '__name__', 'N/A')}, "
              f"complex={getattr(self.complex_dtype, '__name__', 'N/A')}")
# ----------------------------------------------------------------------
#! Active Backend Management
# ----------------------------------------------------------------------
[docs]
def set_active_backend(self, name: str):
    """
    Make the named backend the active one for this manager.

    Switching rebinds ``name``, ``np``, ``random``, ``scipy``, ``key`` and
    ``jit`` and then refreshes the active dtypes. Requesting the backend that
    is already active returns immediately without changing anything.

    Args:
        name :
            Case-insensitive backend name: "numpy"/"npy"/"np" for NumPy, or
            "jax"/"jnp"/"jaxpy" for JAX (the same aliases that
            ``get_backend_modules`` accepts).

    Raises:
        ValueError:
            If JAX is requested but its components are not all available,
            or if the name is not recognised.
    """
    def _announce(message: str, marker: str) -> None:
        # Process-wide log guard to prevent duplicates across re-imports
        if os.environ.get("PY_BACKEND_SWITCH_LOGGED") == marker:
            return
        try:
            log.info(message, color="green")
        except TypeError:
            # The stdlib fallback logger does not accept a ``color`` keyword.
            log.info(message)
        os.environ["PY_BACKEND_SWITCH_LOGGED"] = marker

    name = name.lower()
    if name in ("numpy", "npy", "np"):
        if self.name == "numpy":
            return
        self.name = "numpy"
        self.np = self._np_module
        self.random = self.default_rng
        self.scipy = self._sp_module
        self.key = None
        self.jit = lambda x: x
        _announce("Switched active backend to NumPy.", "numpy")
    elif name in ("jax", "jnp", "jaxpy"):
        if self.name == "jax":
            return
        if not self.is_jax_available or not self._jnp_module or not self._jrn_module or not self._jsp_module or not self._jax_jit:
            raise ValueError("Cannot set 'jax' backend: JAX components not fully available.")
        self.name = "jax"
        self.np = self._jnp_module
        self.random = self._jrn_module
        self.scipy = self._jsp_module
        self.key = self.default_jax_key
        self.jit = self._jax_jit
        _announce("Switched active backend to JAX.", "jax")
    else:
        raise ValueError(f"Invalid backend name: {name}. Choose 'numpy' or 'jax'.")
    self._update_dtypes()
# ---------------------------------------------------------------------
#! Random Number Generation Initialization
# ---------------------------------------------------------------------
@staticmethod
def _create_numpy_rng(seed: Optional[int]) -> Union[np_random.Generator, np_random.RandomState]:
"""
Creates a NumPy random number generator instance.
If NumPy >= 1.17, uses the Generator API.
Otherwise, falls back to the legacy RandomState API.
Parameters:
seed (int or None):
Seed for the random number generator. If None, uses the default seed.
Returns:
Union[np_random.Generator, np_random.RandomState]:
A NumPy random number generator instance.
"""
if hasattr(np_random, 'default_rng'):
if seed is not None:
# Seed legacy global state only if a specific seed is given
# Avoids potentially unwanted side effects if seed is None
try:
np.random.seed(seed)
except ValueError: # Handle potential large seed issues for legacy
log.warning(f"Could not seed legacy np.random with seed {seed}. Using default.")
np.random.seed(DEFAULT_SEED)
return np_random.default_rng(seed)
else:
#! Legacy RandomState API for NumPy < 1.17
log.warning(f"NumPy version {np.__version__} < 1.17. Using legacy np.random state.")
rng_instance = np_random.RandomState(seed)
# Monkey patch default_rng onto the instance if it doesn't exist, for potential API consistency attempts
# This is somewhat fragile and mainly for internal consistency here.
if not hasattr(rng_instance, 'default_rng'):
rng_instance.default_rng = lambda s=seed: np_random.RandomState(s)
return rng_instance
@staticmethod
def _create_jax_key(seed: int) -> Optional[PRNGKey]:
    """
    Build a JAX PRNG key from an integer seed.

    Parameters:
        seed (int):
            Seed for the PRNG key; converted with ``int()`` before use.

    Returns:
        Optional[PRNGKey]:
            The key from ``jrn.PRNGKey``. None when JAX is unavailable, the
            ``jrn`` handle is unset, or key creation raises (a warning is
            logged in that last case).
    """
    if not (JAX_AVAILABLE and jrn):
        return None
    try:
        return jrn.PRNGKey(int(seed))
    except Exception as e:
        log.warning(f"Failed to create JAX PRNGKey with seed {seed}: {e}")
        return None
# ---------------------------------------------------------------------
[docs]
def print_info(self):
    """
    Report backend configuration, library versions and detected hardware.

    Everything is emitted through ``_log_message``: NumPy/SciPy/JAX versions,
    the active backend with its modules, the active integer/float/complex
    types, the CPU core count and, when JAX is available, the detected
    platform and a one-line device summary.
    """
    def _type_label(dtype):
        # Same tolerant lookup as _update_dtypes: not every dtype object has __name__.
        return getattr(dtype, '__name__', 'N/A')

    def _device_kind(device):
        # Prefer the client's platform string, else the device's own platform.
        client = getattr(device, 'client', None)
        if client is not None and hasattr(client, 'platform'):
            return client.platform
        platform = getattr(device, 'platform', None)
        return platform.upper() if platform is not None else "UNKNOWN"

    # Collect version information for each library.
    backend_versions = {
        "NumPy": getattr(np, '__version__', 'Unknown'),
        "SciPy": getattr(sp, '__version__', 'Unknown'),
        "JAX": "Not Available"
    }
    # If JAX is available, get its actual version
    if self.is_jax_available and self._jax_module:
        backend_versions["JAX"] = getattr(self._jax_module, '__version__', 'Unknown')
    # Print header.
    _log_message("*"*50, 0)
    _log_message("Backend Configuration:", 0)
    # Log version info.
    for lib, version in backend_versions.items():
        _log_message(f"{lib} Version: {version}", 2)
    # Log active backend details.
    _log_message(f"Active Backend: {self.name}", 2)
    _log_message(f"JAX Available: {self.is_jax_available}", 3)
    _log_message(f"Default Seed: {self.default_seed}", 3)
    # Log current backend modules.
    if self.name == "jax":
        _log_message("JAX Backend Details:", 2)
        _log_message(f"\tMain Module: {self.np.__name__}", 3)
        _log_message(f"\tRandom Module: {self.random.__name__} (+ PRNGKey)", 3)
        _log_message(f"\tSciPy Module: {self.scipy.__name__}", 3)
        _log_message(f"\tDefault JAX Key: PRNGKey({self.default_seed})", 3)
    elif self.name == "numpy":
        _log_message("NumPy Backend Details:", 2)
        _log_message(f"\tMain Module: {self.np.__name__}", 3)
        _log_message(f"\tRandom Module: {self.random.__class__.__name__}", 3)
        _log_message(f"\tSciPy Module: {self.scipy.__name__}", 3)
    # Log active data types.
    _log_message("Active Data Types:", 2)
    _log_message(f"\tInteger Type: {_type_label(self.int_dtype)}", 3)
    _log_message(f"\tFloat Type: {_type_label(self.float_dtype)}", 3)
    _log_message(f"\tComplex Type: {_type_label(self.complex_dtype)}", 3)
    #! Format device detection results
    _log_message("Hardware & Device Detection:", 2)
    n_threads = multiprocessing.cpu_count()
    _log_message(f"CPU Cores: {n_threads}", 3)
    if self.is_jax_available:
        detected_backend_str = (self.detected_jax_backend or "Detection Failed").upper()
        _log_message(f"Detected JAX Platform: {detected_backend_str}", 3)
        devices = self.detected_jax_devices
        if devices is None:
            # None means detection never produced a device list.
            device_summary = "Detection Failed or Not Run"
        elif not devices:
            device_summary = "No JAX devices found!"
        else:
            try:
                # One label per device. The earlier code indexed a filtered
                # platform list with an unfiltered position, which could go out
                # of range for devices lacking a ``platform`` attribute.
                kinds = [_device_kind(d) for d in devices]
                device_summary = f"{len(devices)} devices ({', '.join(kinds)})"
            except Exception:
                device_summary = f"{len(devices)} devices (Details Error)"
        _log_message(f"JAX Devices Found: {device_summary}", 2)
    else:
        _log_message("JAX Platform: Not Applicable", 2)
        _log_message("JAX Devices Found: Not Applicable", 2)
    # Footer.
    _log_message("*" * 50 + "\n\n\n", 0)
# ---------------------------------------------------------------------
#! Backend Module Retrieval
# ---------------------------------------------------------------------
def _get_numpy_modules(self, use_random: bool = False, seed: Optional[int] = None, use_scipy: bool = False) -> Union[Any, Tuple[Any, ...]]:
"""
Returns NumPy backend modules.
- If use_random is True, returns the random module and a key (rng, key=None).
- If use_scipy is True, returns the SciPy module.
Args:
use_random (bool):
If True, include the random module/state.
seed (int or None):
Seed for the random number generator.
If None, uses the manager's default seed.
use_scipy (bool):
If True, include the SciPy module.
"""
# get the main NumPy-like module
main_module = self._np_module
results: list[Any] = [main_module]
if use_random:
# If a specific seed is requested, create a new RNG for that seed.
# Otherwise, use the manager's default RNG instance.
current_seed = seed if seed is not None else self.default_seed
# Always create a new RNG instance when requested via get_backend,
# even if seed matches default, to ensure independence.
rng_instance = self._create_numpy_rng(current_seed)
#! Tuple format (rng, key=None)
results.append((rng_instance, None))
if use_scipy:
results.append(self._sp_module)
return tuple(results) if len(results) > 1 else main_module
def _get_jax_modules(self, use_random: bool = False, seed: Optional[int] = None, use_scipy: bool = False) -> Union[Any, Tuple[Any, ...]]:
"""
Returns JAX backend modules.
- If use_random is True, returns the random module and a key (jax.random, key).
- If use_scipy is True, returns the SciPy module.
- If JAX is not available, raises a ValueError.
Args:
use_random (bool):
If True, include the random module/state.
seed (int or None):
Seed for the random number generator.
If None, uses the manager's default seed.
use_scipy (bool):
If True, include the SciPy module.
Raises:
ValueError:
If JAX is not available or if required JAX components are missing.
Returns:
Union[Any, Tuple[Any, ...]]:
The requested module(s) as a single module or a tuple.
Format: (main_module) or (main_module, random_part, scipy_module)
where random_part is (jax.random, key) for JAX.
"""
if not self.is_jax_available or not self._jnp_module or not self._jrn_module or not self._jsp_module:
raise ValueError("JAX backend requested but required JAX components are not available.")
main_module = self._jnp_module
results: list[Any] = [main_module]
if use_random:
# If a specific seed is requested, create a new key.
# Otherwise, use the manager's default key.
current_seed = seed if seed is not None else self.default_seed
# Always create a new key when requested via get_backend.
current_key = self._create_jax_key(current_seed)
# Tuple format (module, key)
results.append((self._jrn_module, current_key))
if use_scipy:
results.append(self._jsp_module)
return tuple(results) if len(results) > 1 else main_module
# ---------------------------------------------------------------------
[docs]
def get_backend_modules(self,
                        backend_spec    : Union[str, Any, None],
                        use_random      : bool          = False,
                        seed            : Optional[int] = None,
                        use_scipy       : bool          = False) -> Union[Any, Tuple[Any, ...]]:
    """
    Returns backend modules based on the specifier.

    Args:
        backend_spec :
            Backend identifier. Can be:
            - String (case-insensitive, surrounding whitespace ignored):
              "numpy", "np", "npy", "jax", "jnp", "jaxpy", "default".
            - Module: `numpy`, `jax` or `jax.numpy`.
            - None: Uses the manager's active backend.
        use_random :
            If True, include the random module/state. For JAX,
            returns (jax.random, key). For NumPy, returns
            (rng_instance, None).
        seed :
            Seed for the random number generator. If None, uses
            the manager's default seed. If provided, creates a
            *new* RNG/Key for this request, independent of the
            manager's default state.
        use_scipy :
            If True, include the SciPy module.

    Raises:
        ValueError:
            If the specifier is neither a recognised string nor one of the
            known backend modules, or if the JAX getter reports that JAX
            components are unavailable.

    Returns:
        The requested module(s) as a single module or a tuple.
        Format:
            (main_module) or (main_module, random_part, scipy_module)
        where random_part is (rng_instance, None) for numpy or (jrn_module, key) for jax.
    """
    if backend_spec is None:
        backend_name = self.name
    elif isinstance(backend_spec, str):
        # Normalise first so that "Default" / " NumPy " behave like their
        # lower-case forms; previously only the exact string "default" was
        # recognised as the active-backend alias.
        backend_name = backend_spec.strip().lower()
        if backend_name == "default":
            backend_name = self.name
    elif backend_spec is np or backend_spec is self._np_module:
        backend_name = "numpy"
    elif self.is_jax_available and (backend_spec is jax or backend_spec is self._jax_module):
        backend_name = "jax"
    elif self.is_jax_available and self._jnp_module and (backend_spec is jnp or backend_spec is self._jnp_module):
        backend_name = "jax"
    else:
        raise ValueError(f"Unsupported backend specification: {backend_spec}")

    #! Dispatch based on name
    if backend_name in ("numpy", "np", "npy"):
        return self._get_numpy_modules(use_random=use_random, seed=seed, use_scipy=use_scipy)
    if backend_name in ("jax", "jnp", "jaxpy"):
        # _get_jax_modules performs its own availability check
        return self._get_jax_modules(use_random=use_random, seed=seed, use_scipy=use_scipy)
    raise ValueError(f"Unknown backend name derived: {backend_name}")
# ---------------------------------------------------------------------
[docs]
def get_global_backend_modules(self,
                               use_random   : bool          = False,
                               seed         : Optional[int] = None,
                               use_scipy    : bool          = False) -> Union[Any, Tuple[Any, ...]]:
    """
    Fetch the modules of the backend that is currently active on the manager.

    This is a convenience wrapper: it forwards `self.name` together with the
    given options to `get_backend_modules` and returns whatever that call
    produces.

    Args:
        use_random :
            If True, the random module/state is part of the result.
        seed :
            Optional seed forwarded for this request's random component.
        use_scipy :
            If True, the SciPy module is part of the result.

    Returns:
        The module(s) of the active backend, in the format produced by
        `get_backend_modules`.
    """
    # Collect the pass-through options once, then delegate to the main getter.
    request = dict(use_random=use_random, seed=seed, use_scipy=use_scipy)
    return self.get_backend_modules(self.name, **request)
# ---------------------------------------------------------------------
#! RANDOMNESS
# ---------------------------------------------------------------------
[docs]
def reseed(self, seed: int) -> RNGManager:
    """
    Re-initialise every random source owned by the manager from one seed.

    The seed is coerced with ``int()`` and stored as ``default_seed``. A new
    NumPy generator is created, the stdlib ``random`` module is seeded, and,
    when JAX is available, a new default JAX key is created. Finally the
    ``random`` / ``key`` attributes are pointed at the components of the
    active backend.

    Args:
        seed (int):
            The new default seed.

    Returns:
        RNGManager:
            Built from the new NumPy generator, the new JAX key (``None``
            when JAX is unavailable) and the stdlib ``random`` state
            captured right after seeding.
    """
    new_seed = int(seed)
    self.default_seed = new_seed

    # NumPy: a dedicated Generator; the legacy global np.random state is not touched here.
    self.default_rng = self._create_numpy_rng(new_seed)

    # Python stdlib random: seed it and remember the resulting state.
    py_random.seed(new_seed)
    stdlib_state = py_random.getstate()

    # JAX: replace the default key only when the backend can be used.
    jax_ready = self.is_jax_available
    if jax_ready:
        fresh_key = self._create_jax_key(new_seed)
        self.default_jax_key = fresh_key
        self.key = fresh_key

    # Point the "active" attributes at the backend currently in use.
    numpy_active = (self.name == "numpy")
    self.random = self.default_rng if numpy_active else self._jrn_module
    self.key = None if numpy_active else self.default_jax_key

    jax_key_out = self.default_jax_key if jax_ready else None
    return RNGManager(self.default_rng, jax_key_out, stdlib_state)
[docs]
def next_key(self) -> PRNGKey:
    """
    Split the manager's key once: keep one half, hand out the other.

    Returns:
        The second element produced by ``self._jrn_module.split(self.key)``.
        The first element replaces ``self.key``.

    Raises:
        RuntimeError:
            If JAX is not available or ``self.key`` is ``None``.
    """
    key_ready = self.is_jax_available and self.key is not None
    if not key_ready:
        raise RuntimeError("JAX key not available; ensure JAX backend and call reseed() first.")

    carry, subkey = self._jrn_module.split(self.key)
    self.key = carry  # advance the internal state
    return subkey
[docs]
def split_keys(self, n: int) -> Any:
    """
    Produce `n` subkeys while advancing the manager's key a single time.

    The internal key is split once into a new internal key and a seed key;
    the seed key is then split into `n` keys, which are returned.

    Args:
        n (int):
            Number of subkeys to generate.

    Returns:
        The result of ``self._jrn_module.split(seed_key, n)``.

    Raises:
        RuntimeError:
            If JAX is not available or ``self.key`` is ``None``.
    """
    key_ready = self.is_jax_available and self.key is not None
    if not key_ready:
        raise RuntimeError("JAX key not available; ensure JAX backend and call reseed() first.")

    splitter = self._jrn_module.split
    carry, seed_key = splitter(self.key)
    self.key = carry  # advance the internal state exactly once
    return splitter(seed_key, n)
[docs]
@contextmanager
def seed_scope(self, seed: int, *, touch_numpy_global: bool = False, touch_python_random: bool = True):
    """
    Temporarily set deterministic seeds and restore previous states on exit.

    Use it as:
    ```
    with backend_mgr.seed_scope(seed) as rngs:
        # Your code here
    ```

    On entry the manager is reseeded through ``self.reseed(seed)``. On exit
    (normal or via an exception) every manager attribute that ``reseed``
    rewrites (``default_seed``, ``default_rng``, ``default_jax_key``,
    ``random`` and ``key``) is put back to the value it had before the scope.

    Parameters
    ----------
    seed : int
        The seed to set.
    touch_numpy_global : bool
        Whether to seed the legacy global NumPy random state for the
        duration of the scope. Its previous state is restored on exit.
    touch_python_random : bool
        Whether to seed the Python ``random`` module for the duration of
        the scope. When True, the previous state is restored on exit. When
        False, the state from before the scope is reinstated on entry
        (undoing the seeding done by ``reseed``) and left alone afterwards.

    Yields
    ------
    The object returned by ``self.reseed(seed)``.
    """
    missing = object()

    # Snapshot everything reseed() overwrites. Attributes that do not exist
    # yet are remembered as "missing" and simply not restored.
    saved_attrs = {
        attr: getattr(self, attr, missing)
        for attr in ("default_seed", "default_rng", "default_jax_key", "random", "key")
    }
    saved_py_state = py_random.getstate()
    saved_np_global = np.random.get_state() if touch_numpy_global else None

    entered = False
    try:
        suite = self.reseed(seed)
        if touch_numpy_global:
            np.random.seed(seed)  # legacy global
        if touch_python_random:
            py_random.seed(seed)
        else:
            # reseed() seeds the stdlib generator unconditionally; undo that
            # so the caller's request not to touch it is honoured.
            py_random.setstate(saved_py_state)
        entered = True
        yield suite
    finally:
        for attr, value in saved_attrs.items():
            if value is not missing:
                setattr(self, attr, value)
        # Restore the stdlib state if the scope owned it, or if setup failed
        # part-way through and may have left it seeded.
        if touch_python_random or not entered:
            py_random.setstate(saved_py_state)
        if saved_np_global is not None:
            np.random.set_state(saved_np_global)
# Multithreaded jobs
[docs]
def spawn_np_generators(root_seed: int, n: int) -> list[np.random.Generator]:
    """
    Build `n` NumPy generators derived from a single root seed.

    A ``SeedSequence`` created from `root_seed` spawns `n` child sequences;
    each child seeds its own ``PCG64`` bit generator. Hand one generator to
    each worker/process so that their streams are not correlated.
    """
    child_sequences = np.random.SeedSequence(root_seed).spawn(n)
    generators = []
    for child in child_sequences:
        bit_generator = np.random.PCG64(child)
        generators.append(np.random.Generator(bit_generator))
    return generators
[docs]
def spawn_jax_keys(root_key: PRNGKey, n: int):
    """
    Derive `n` JAX keys from `root_key` in a deterministic way.

    When JAX is not available (``JAX_AVAILABLE`` is falsy) a list of `n`
    ``None`` placeholders is returned instead.
    """
    if not JAX_AVAILABLE:
        return [None] * n
    return jrn.split(root_key, n)
[docs]
def pad_array(x, target_size: int, pad_value, *, backend=None):
    """Pad a one-dimensional array to ``target_size`` with ``pad_value``.

    Parameters
    ----------
    x
        Input one-dimensional NumPy or JAX array.
    target_size
        Desired output length. Must be at least ``x.shape[0]``.
    pad_value
        Value used to initialize positions beyond the input length.
    backend
        Optional array module. If omitted, the backend is inferred from ``x``.

    Returns
    -------
    array-like
        Array with shape ``(target_size,)`` and the same dtype as ``x``.

    Raises
    ------
    ValueError
        If ``target_size`` is smaller than the length of ``x``.
    """
    xp = backend or (jnp if (JAX_AVAILABLE and is_jax_array(x)) else np)

    n_in = x.shape[0]
    # Validate explicitly: otherwise the failure surfaces as a backend-specific
    # broadcasting error, and a length-1 input with target_size == 0 would be
    # silently truncated by NumPy broadcasting.
    if target_size < n_in:
        raise ValueError(
            f"target_size ({target_size}) must be at least the input length ({n_in})."
        )

    out = xp.full((target_size,), pad_value, dtype=x.dtype)
    if xp is np:
        # NumPy arrays are mutable: write the prefix in place.
        out[:n_in] = x
        return out
    # JAX path: functional update through the .at[] indexer.
    return out.at[:n_in].set(x)
# ---------------------------------------------------------------------
# ---------------------------------------------------------------------
# ONE-TIME INITIALIZATION LOGIC
#
# This function modifies the global variables declared previously.
# It runs only once per Python session.
# ---------------------------------------------------------------------
def _qes_initialize_utils() -> None:
    """
    Perform the one-time setup of the backend environment.

    The function rebinds the module-level globals named in the ``global``
    statements below. In order, it:

    1. Resolves the logger through ``_initialize_logger()``.
    2. Publishes the CPU core count in ``os.environ[PY_NUM_CORES_STR]``
       (an already present value is kept) and, if the ``PY_SPIN_VALUE_STR``
       environment variable holds a float, stores it in ``BACKEND_REPR``.
    3. Imports JAX and its ``numpy`` / ``scipy`` / ``random`` submodules
       unless ``PY_JAX_DONT_USE`` is set. Any exception raised in this step
       is caught and logged, and ``JAX_AVAILABLE`` is set to False.
    4. Sets the type aliases (``PRNGKey``, ``JaxDevice``, ``Array``), the
       ``DEFAULT_JP_*`` dtypes and the dtype registries.
    5. Instantiates ``BackendManager`` as ``backend_mgr``.
    6. Copies the manager's active components into the ``ACTIVE_*`` globals.
    7. Optionally prints backend info and sets
       ``os.environ[PY_UTILS_INIT_DONE_STR]`` to ``'1'``.
    """
    # Tell this function we are modifying the module-level (global) variables
    global log, JAX_AVAILABLE, jax, jnp, jsp, jrn, jcfg, JIT
    global Array, PRNGKey, JaxDevice
    global DEFAULT_JP_INT_TYPE, DEFAULT_JP_FLOAT_TYPE, DEFAULT_JP_CPX_TYPE
    global BACKEND_REPR, BACKEND_DEF_SPIN
    global PREFER_JAX, PREFER_64BIT, PREFER_32BIT
    global _DTYPE_REGISTRY, _TYPE_TO_NAME
    global backend_mgr
    global ACTIVE_BACKEND_NAME, ACTIVE_NP_MODULE, ACTIVE_RANDOM
    global ACTIVE_SCIPY_MODULE, ACTIVE_JIT, ACTIVE_JAX_KEY
    global ACTIVE_INT_TYPE, ACTIVE_FLOAT_TYPE, ACTIVE_COMPLEX_TYPE

    # 1. Setup logger (singleton when available, silent stdlib fallback otherwise)
    log = _initialize_logger()
    _log_message("Initializing general_python.algebra.utils...")

    # 2. Environment and Core Settings
    num_cores = os.cpu_count() or 1
    # Keep a value that is already set; otherwise publish the detected count.
    os.environ[PY_NUM_CORES_STR] = os.getenv(PY_NUM_CORES_STR, str(num_cores))

    # Configure spin value from environment
    env_spin_value = os.getenv(PY_SPIN_VALUE_STR, "").strip()
    if env_spin_value:
        try:
            BACKEND_REPR = float(env_spin_value)
            _log_message(f"Spin value set from environment PY_SPIN_VALUE={BACKEND_REPR}", 1)
        except ValueError:
            # float() failed, so BACKEND_REPR still holds its previous value.
            _log_message(f"Invalid PY_SPIN_VALUE '{env_spin_value}', using default {BACKEND_REPR}", 0)
    else:
        _log_message(f"Using default spin value {BACKEND_REPR}", 1)

    # 3. JAX Detection and Import
    if PY_JAX_DONT_USE:
        JAX_AVAILABLE = False
        jax = jnp = jsp = jrn = jcfg = None
        os.environ[PY_JAX_AVAILABLE_STR] = '0'
        _log_message("JAX import skipped because PY_JAX_DONT_USE is enabled.", 0)
    else:
        # Avoid JAX CUDA initialization error if GPU is not present or disabled
        force_cpu = (
            os.environ.get("CUDA_VISIBLE_DEVICES") == ""
            or os.environ.get("PY_FORCE_CPU", "0").lower() in ("1", "true", "yes")
            or os.environ.get("PY_JAX_CPU_ONLY", "0").lower() in ("1", "true", "yes")
        )
        if force_cpu:
            _log_message("GPU disabled or CPU-only requested. Setting JAX_PLATFORMS=cpu.", 1)
            # setdefault: a platform explicitly chosen by the user is not overridden.
            os.environ.setdefault("JAX_PLATFORMS", "cpu")
            os.environ.setdefault("JAX_PLATFORM_NAME", "cpu")

        try:
            try:
                import jax
            except RuntimeError as e:
                # Handle CUDA initialization failure gracefully by falling back to CPU
                if "CUDA_ERROR_NO_DEVICE" in str(e) or "cuInit" in str(e):
                    _log_message("JAX CUDA initialization failed. Retrying with JAX_PLATFORMS=cpu...", 1)
                    # Unlike the setdefault above, this overwrites any existing value.
                    os.environ["JAX_PLATFORMS"] = "cpu"
                    os.environ["JAX_PLATFORM_NAME"] = "cpu"
                    import jax
                else:
                    raise e

            from jax import config as jax_config
            jcfg = jax_config

            if PREFER_64BIT:
                # Enable 64-bit precision when JAX is available
                jcfg.update("jax_enable_x64", True)
                _log_message("JAX 64-bit precision enabled.", 1)

            # Raise the threshold of JAX's own loggers to WARNING.
            logging.getLogger('jax._src.xla_bridge').setLevel(logging.WARNING)
            logging.getLogger('jax').setLevel(logging.WARNING)

            try:
                import jax.numpy as jnp
                import jax.scipy as jsp
                import jax.random as jrn
            except ImportError as ie:
                _log_message(f"JAX submodules could not be imported: {ie}", 0)
                # Re-raised so that the outer handler marks JAX as unavailable.
                raise ie

            JAX_AVAILABLE = True
            if PREFER_JAX:
                JIT = jax.jit  # Overwrite the global JIT function
            os.environ[PY_JAX_AVAILABLE_STR] = '1'
            _log_message("JAX backend available and successfully imported.", 0)
        except Exception as e:
            # NOTE(review): names bound before the failure (e.g. `jax`) are not
            # reset to None here - confirm that callers only rely on JAX_AVAILABLE.
            JAX_AVAILABLE = False
            os.environ[PY_JAX_AVAILABLE_STR] = '0'
            _log_message(f"JAX backend not available: {e}", 0)

    if JAX_AVAILABLE:
        # Type aliases for JAX
        PRNGKey = jrn.PRNGKey

        # Modern JAX (>=0.4.0): Device is at jax.Device
        if hasattr(jax, 'Device'):
            JaxDevice = jax.Device
        # Older JAX: Device is at jax.lib.xla_client.Device
        elif hasattr(jax, 'lib') and hasattr(jax.lib, 'xla_client') and hasattr(jax.lib.xla_client, 'Device'):
            JaxDevice = jax.lib.xla_client.Device
        else:
            JaxDevice = Any

        # Type alias for arrays
        Array = Union[np.ndarray, jnp.ndarray]

        if PREFER_32BIT:
            # Runs after the PREFER_64BIT update above, so it takes precedence
            # when both flags are set.
            jcfg.update("jax_enable_x64", False)
            _log_message("JAX 32-bit precision enforced.", 1)

        # NOTE(review): getattr only falls back to the narrower type when the
        # wider attribute is missing from jnp; PREFER_32BIT / jax_enable_x64
        # are not consulted here - confirm this is intended.
        DEFAULT_JP_INT_TYPE = getattr(jnp, 'int64', getattr(jnp, 'int32'))  # Prefer 64bit if available
        DEFAULT_JP_FLOAT_TYPE = getattr(jnp, 'float64', getattr(jnp, 'float32'))  # Prefer 64bit if available
        DEFAULT_JP_CPX_TYPE = getattr(jnp, 'complex128', getattr(jnp, 'complex64'))  # Prefer 128bit if available
        _log_message(f"JAX default types: int={DEFAULT_JP_INT_TYPE.__name__}, float={DEFAULT_JP_FLOAT_TYPE.__name__}, complex={DEFAULT_JP_CPX_TYPE.__name__}", 2)
    else:
        # Type aliases for NumPy only
        PRNGKey = Any
        JaxDevice = Any
        Array = np.ndarray
        DEFAULT_JP_INT_TYPE = None
        DEFAULT_JP_FLOAT_TYPE = None
        DEFAULT_JP_CPX_TYPE = None

    # 4. Update Type Aliases and Registries based on JAX status
    if JAX_AVAILABLE and jnp:
        # Same alias as assigned in the JAX branch above.
        Array = Union[np.ndarray, jnp.ndarray]
        # Map the JAX scalar types to their canonical dtype names.
        _TYPE_TO_NAME[jnp.complex64] = 'complex64'
        _TYPE_TO_NAME[jnp.complex128] = 'complex128'
        _TYPE_TO_NAME[jnp.float32] = 'float32'
        _TYPE_TO_NAME[jnp.float64] = 'float64'
        _TYPE_TO_NAME[jnp.int32] = 'int32'
        _TYPE_TO_NAME[jnp.int64] = 'int64'
        _log_message(f"Type registries updated. Supported types: {list(_TYPE_TO_NAME.values())}")

    # Register NumPy types as well
    # Each entry maps a dtype name to its per-backend type; the 'jax' slot is
    # None when JAX is unavailable (jnp is not dereferenced in that case).
    _DTYPE_REGISTRY['float32'] = {'numpy': np.float32, 'jax': jnp.float32 if JAX_AVAILABLE else None}
    _DTYPE_REGISTRY['float64'] = {'numpy': np.float64, 'jax': jnp.float64 if JAX_AVAILABLE else None}
    _DTYPE_REGISTRY['int32'] = {'numpy': np.int32, 'jax': jnp.int32 if JAX_AVAILABLE else None}
    _DTYPE_REGISTRY['int64'] = {'numpy': np.int64, 'jax': jnp.int64 if JAX_AVAILABLE else None}
    _DTYPE_REGISTRY['complex64'] = {'numpy': np.complex64, 'jax': jnp.complex64 if JAX_AVAILABLE else None}
    _DTYPE_REGISTRY['complex128'] = {'numpy': np.complex128, 'jax': jnp.complex128 if JAX_AVAILABLE else None}
    _log_message(f"Data type registry populated with NumPy and JAX types.", 2)

    # 5. Instantiate and Configure the Backend Manager
    backend_mgr = BackendManager(default_seed=DEFAULT_SEED, prefer_jax=PREFER_JAX)
    _log_message(f"BackendManager instantiated with default seed {DEFAULT_SEED}.", 1)

    # 6. Update the Global ACTIVE_* Mirrors from the Manager
    # This makes the active backend components directly accessible.
    # These are snapshots taken once, here.
    ACTIVE_BACKEND_NAME = backend_mgr.name
    ACTIVE_NP_MODULE = backend_mgr.np
    ACTIVE_RANDOM = backend_mgr.random
    ACTIVE_SCIPY_MODULE = backend_mgr.scipy
    ACTIVE_JIT = backend_mgr.jit
    ACTIVE_JAX_KEY = backend_mgr.key
    ACTIVE_INT_TYPE = backend_mgr.int_dtype
    ACTIVE_FLOAT_TYPE = backend_mgr.float_dtype
    ACTIVE_COMPLEX_TYPE = backend_mgr.complex_dtype

    # 7. Final Info Printout
    if os.getenv(PY_INFO_VERBOSE, "0").lower() in ("1", "true", "yes", "on"):
        backend_mgr.print_info()

    # Environment marker checked by the module-level execution guard.
    os.environ[PY_UTILS_INIT_DONE_STR] = '1'
    _log_message("'[General Python].algebra.utils initialization complete.", 0)
# ---------------------------------------------------------------------
# EXECUTION GUARD
#
# This code runs when the module is imported. It ensures that the
# initialization function is called only once.
# ---------------------------------------------------------------------
# Initialization counts as done only when both markers agree: the module-level
# flag exists AND the environment variable is set to '1'.
if "PY_UTILS_INIT_DONE" in globals() and os.environ.get(PY_UTILS_INIT_DONE_STR) == '1':
    # Helpful when debugging re-import issues.
    _log_message("general_python.algebra.utils already initialized; skipping re-initialization.", 0)
else:
    # Set the flag before running, so it is in place even if setup fails.
    PY_UTILS_INIT_DONE = True
    try:
        _qes_initialize_utils()
    except Exception as init_error:
        log.error(f"CRITICAL ERROR during backend initialization: {init_error}")
        # The library is unusable without a backend: terminate the process.
        os._exit(1)
_log_message("---------------------------------------------------------------------------------", 0)
# ---------------------------------------------------------------------
#! EOF
# ---------------------------------------------------------------------