"""Initial-value ODE integrators with NumPy and optional JAX backends.
The module provides a compact interface for one-step integration methods used
in simulation loops in which the caller owns the time-evolution state.
Integrators return ``(y_next, dt, info)`` and impose neither a global solver
object nor trajectory storage.
Use :func:`choose_ode` for string-based construction or instantiate
:class:`Euler`, :class:`Heun`, :class:`AdaptiveHeun`, :class:`RK`, and
:class:`ScipyRK` directly when full configuration is needed.
-----------------------------------------------
File : general_python/algebra/ode.py
Author : Maksymilian Kliczkowski
email : maxgrom97@gmail.com
-----------------------------------------------
"""
import time
import numpy as np
import warnings
import inspect
from typing import Union, Any, Tuple, Callable
from abc import ABC, abstractmethod
try:
from scipy.integrate import solve_ivp
except ImportError as e:
raise ImportError("Failed to import scipy.integrate module. Ensure general_python package is correctly installed.") from e
try:
import jax
import jax.numpy as jnp
from jax import jit
JAX_AVAILABLE = True
except ImportError:
jnp = None
jit = None
jax = None
JAX_AVAILABLE = False
########################################################################
#! General class for ODE integration
########################################################################
class IVP(ABC):
    r"""
    Abstract initial value problem solver interface.

    Concrete integrators advance ``dy/dt = rhs_prefactor * f(y, t)`` by one
    step per call and return ``(y_next, dt, info)`` from :meth:`step`.

    Methods
    -------
    step(f, t, y, **rhs_args)
        Compute one integration step; returns ``(y_next, dt, info)``.
    update(y, h, f, t, **rhs_args)
        Advance ``y`` by one step and return only the new state.
    dt(h, i)
        Return the time step used (may depend on h or step index i).

    Attributes
    ----------
    xp
        Array module (numpy or jax.numpy) selected by backend.
    """

    def __init__(self, backend: str = 'numpy', rhs_prefactor: float = 1.0, dt: float = 1e-3):
        """
        Initialize the ODE solver with a specified backend.

        Parameters
        ----------
        backend : str
            Backend to use for numerical operations ('numpy' or 'jax').
        rhs_prefactor : float
            Constant factor multiplying the right-hand side of the ODE.
        dt : float
            Step size returned by :meth:`dt`.

        Raises
        ------
        ValueError
            If the backend lookup yields no array module.
        """
        try:
            from ..algebra.utils import get_backend
        except ImportError:
            # Standalone fallback when the package helper is not importable.
            def get_backend(backend_str):
                name = backend_str.lower() if isinstance(backend_str, str) else ''
                # JAX_AVAILABLE is only consulted when JAX was explicitly requested.
                if name == 'jax' and JAX_AVAILABLE:
                    return jnp
                return np

        self.backend = backend
        self.xp = get_backend(backend)
        # Validate before deriving any flags from the module.
        if self.xp is None:
            raise ValueError(f"Backend '{backend}' is not supported. Choose 'numpy' or 'jax'.")
        self.backendstr = 'numpy' if self.xp is np else 'jax'
        self._isjax = self.xp is not np
        self._isnpy = not self._isjax
        self._dt = dt
        self._rhs_prefactor = rhs_prefactor

    def _call_rhs(self, f, t: float, y, int_step: int = 0, **rhs_args):
        """
        Call the user-provided right-hand side ``f`` and normalize its output.

        JAX backend: ``f`` is called as ``f(y, t, **rhs_args, int_step=int_step)``.

        NumPy backend: arguments are bound by name from ``f``'s signature:
        the state as ``y`` (or ``y0``), the time as ``t``, the stage index as
        ``int_step`` and every entry of ``rhs_args`` whose name ``f`` declares.
        If ``f`` accepts ``**kwargs``, all ``rhs_args`` are forwarded.

        Returns
        -------
        tuple
            ``(dy, info, other)``. A 3-tuple from ``f`` is returned unchanged,
            a 2-tuple is padded with ``None``, any other tuple keeps only its
            first element, and a non-tuple is taken as ``dy``.
        """
        if self._isjax:
            #! assume f is a jax function with signature f(y, t, **rhs_args, int_step=int_step)
            out = f(y, t, **rhs_args, int_step=int_step)
        else:
            params = inspect.signature(f).parameters
            kwargs = {}
            # bind state
            if 'y' in params:
                kwargs['y'] = y
            elif 'y0' in params:
                kwargs['y0'] = y
            if 't' in params:
                kwargs['t'] = t
            if 'int_step' in params:
                kwargs['int_step'] = int_step
            # An RHS declared with **kwargs should still receive the extra arguments.
            accepts_var_kw = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())
            for name, val in rhs_args.items():
                if name in params or accepts_var_kw:
                    kwargs[name] = val
            out = f(**kwargs)

        #! normalize outputs
        if not isinstance(out, tuple):
            return out, None, None
        if len(out) == 3:
            return out
        if len(out) == 2:
            return out[0], out[1], None
        return out[0], None, None

    # -------------------------------------------------------

    def dt(self, h: float = 0.0, i: int = 0) -> float:
        """Return the step size used for step index ``i``.

        Parameters
        ----------
        h
            Optional external step-size proposal. Fixed-step integrators ignore
            it and return their configured ``dt``.
        i
            Optional integration step index.
        """
        return self._dt

    def set_dt(self, dt: float):
        """
        Set the time step for the integration.

        Parameters
        ----------
        dt : float
            The new time step to set.
        """
        self._dt = dt

    # -------------------------------------------------------

    @abstractmethod
    def step(self, f, t: float, y, **rhs_args):
        '''
        Compute one integration step and return ``(y_next, dt, info)``.

        This method must be implemented by subclasses.
        '''
        raise NotImplementedError

    def update(self, y, h: float, f, t: float, **rhs_args):
        """Advance ``y`` by one step and return only the updated state.

        ``h`` is accepted for interface compatibility and ignored; the step
        size is the one chosen by :meth:`step`. The ``dt`` and auxiliary info
        returned by :meth:`step` are discarded.
        """
        # ``step`` returns ``(y_next, dt, info)``; keep only the state.
        return self.step(f, t, y, **rhs_args)[0]

    @property
    def order(self) -> int:
        '''
        Return the order of the integration method (1 unless overridden).
        '''
        return 1

    @property
    def is_jax(self) -> bool:
        """
        Check if the backend is JAX.

        Returns
        -------
        bool
            True if the backend is JAX, False otherwise.
        """
        return self._isjax

    @property
    def is_numpy(self) -> bool:
        """
        Check if the backend is NumPy.

        Returns
        -------
        bool
            True if the backend is NumPy, False otherwise.
        """
        return self._isnpy

    def __repr__(self):
        """
        Return a string representation of the IVP object.
        """
        return f"SimpleIVP(backend={self.backendstr})"

    def __str__(self):
        """
        Return a string representation of the IVP object.
        """
        return self.__repr__()

    def __call__(self, f, t, y, **rhs_args):
        """
        Call the step method to compute one integration step.

        Parameters
        ----------
        f : callable
            Function representing the right-hand side of the ODE.
        t : float
            Current time.
        y : array-like
            Current state.
        **rhs_args : keyword arguments
            Additional arguments to pass to the function f.

        Returns
        -------
        tuple
            Whatever :meth:`step` returns, i.e. ``(y_next, dt, info)``.
        """
        return self.step(f, t, y, **rhs_args)

    def __len__(self):
        """
        Return the method order (see :attr:`order`).
        """
        return self.order
#######################################################################
#! Euler integration
#######################################################################
class Euler(IVP):
    r"""
    Explicit (forward) Euler integrator with a fixed step.

    One step computes
    :math:`y_{n+1} = y_n + \Delta t \, \alpha \, f(y_n, t_n)`,
    where :math:`\alpha` is the right-hand-side prefactor.

    Parameters
    ----------
    dt : float
        Fixed step size for the integration.
    backend : str
        'numpy' or 'jax'
    rhs_prefactor : float
        Constant factor multiplying the right-hand side.
    """

    def __init__(self, dt: float = 1e-3, backend: str = 'numpy', rhs_prefactor: float = 1.0):
        """
        Store the step size, right-hand-side prefactor and array backend.

        Args:
            dt (float, optional): Fixed time step. Defaults to 1e-3.
            backend (str, optional): Array backend name. Defaults to 'numpy'.
            rhs_prefactor (float, optional): Factor applied to ``f``. Defaults to 1.0.
        """
        super().__init__(backend=backend, rhs_prefactor=rhs_prefactor, dt=dt)

    def step(self, f, t: float, y, **rhs_args):
        r"""
        Advance the state by a single forward-Euler step.

        Returns
        -------
        tuple
            ``(y_next, dt, (info, other))``; ``info`` and ``other`` are the
            auxiliary values reported by ``f`` (``None`` when ``f`` returns
            only the derivative).
        """
        slope, info, extra = self._call_rhs(f, t, y, int_step=0, **rhs_args)
        increment = self._dt * self._rhs_prefactor
        return y + increment * slope, self._dt, (info, extra)

    def __repr__(self):
        """Describe the integrator's configuration."""
        return f"Euler(dt={self._dt}, backend={self.backendstr}, rhs_p={self._rhs_prefactor})"
########################################################################
#! Heun integration
########################################################################
class Heun(IVP):
    r"""
    Second-order Heun (explicit trapezoidal) integrator with a fixed step.

    Parameters
    ----------
    dt : float
        Step size :math:`\Delta t` (can be changed externally via ``set_dt``).
    backend : str
        'numpy' or 'jax'
    rhs_prefactor : float
        Constant factor multiplying the right-hand side.
    """

    def __init__(self, dt: float = 1e-3, backend: str = 'numpy', rhs_prefactor: float = 1.0):
        """Store the step size, right-hand-side prefactor and array backend."""
        super().__init__(backend, dt=dt, rhs_prefactor=rhs_prefactor)

    def step(self, f, t: float, y, **rhs_args):
        r"""
        Compute one Heun (predictor-corrector) step.

        .. math::

            k_0 = f(y_n, t), \qquad
            k_1 = f(y_n + h k_0, t + \Delta t), \qquad
            y_{n+1} = y_n + \tfrac{h}{2} (k_0 + k_1),

        with :math:`h = \alpha \Delta t` and :math:`\alpha` the right-hand-side
        prefactor. This is a second-order Runge-Kutta method.

        Parameters
        ----------
        f : callable
            Function representing the right-hand side of the ODE.
        t : float
            Current time.
        y : array-like
            Current state.

        Returns
        -------
        tuple
            ``(y_next, dt, (info, other))``; the auxiliary values are those
            returned by the corrector evaluation of ``f``.
        """
        dt = self._dt
        h = self._rhs_prefactor * dt
        # Predictor: slope at the start, Euler trial state at t + dt.
        predictor_slope, _, _ = self._call_rhs(f, t, y, int_step=0, **rhs_args)
        y_trial = y + h * predictor_slope
        # Corrector: slope at the trial state, then trapezoidal average.
        corrector_slope, info, extra = self._call_rhs(f, t + dt, y_trial, int_step=1, **rhs_args)
        y_next = y + 0.5 * h * (predictor_slope + corrector_slope)
        return y_next, dt, (info, extra)

    def __repr__(self):
        """Describe the integrator's configuration."""
        return f"Heun(dt={self._dt}, backend={self.backendstr}, rhs_p={self._rhs_prefactor})"
########################################################################
#! Adaptive Heun integration
########################################################################
class AdaptiveHeun(IVP):
    r"""
    Adaptive second-order Heun integrator with step-doubling error control.

    Each attempt compares one Heun step of size ``h`` with two Heun steps of
    size ``h / 2``. The norm of their difference is the local error estimate.
    An attempt whose estimate exceeds ``tol`` is retried with a smaller ``h``.
    After acceptance, the proposed size for the *next* step is stored and is
    returned by :meth:`dt`.

    Parameters
    ----------
    dt : float
        Initial time step.
    tol : float
        Absolute tolerance on the norm of the local error estimate.
    max_step : float
        Maximum allowed time step.
    backend : str
        'numpy' or 'jax'
    rhs_prefactor : float
        Constant factor multiplying the right-hand side.
    """

    def __init__(self,
                 dt            : float = 1e-3,
                 tol           : float = 1e-8,
                 max_step      : float = 1.0,
                 backend       : str   = 'numpy',
                 rhs_prefactor : float = 1.0):
        super().__init__(backend, dt=dt, rhs_prefactor=rhs_prefactor)
        self.tolerance = tol
        self.max_step = max_step

    def _heun_increment(self, f, t, y, h, k0, int_step, rhs_args):
        """Return ``(dy, info, other)`` for one Heun step of size ``h`` from ``(t, y)`` given the start slope ``k0``."""
        mult = self._rhs_prefactor * h
        k1, info, other = self._call_rhs(f, t + h, y + mult * k0, int_step=int_step, **rhs_args)
        return 0.5 * mult * (k0 + k1), info, other

    def step(self, f, t: float, y, norm_fun=None, **rhs_args):
        """Perform one adaptive Heun step with local error control.

        ``f`` is evaluated through :meth:`_call_rhs` with ``int_step`` values
        0-4 labelling the five right-hand-side evaluations of an attempt.
        The evaluation at ``int_step=0`` is shared by all retries.

        Parameters
        ----------
        f
            Right-hand side callable.
        t
            Current time.
        y
            Current state.
        norm_fun
            Optional norm function used for the local error estimate. Defaults
            to the active backend's vector norm.
        **rhs_args
            Extra keyword arguments forwarded to ``f``.

        Returns
        -------
        tuple
            ``(y_next, dt_used, info)`` where ``dt_used`` is the step actually
            taken and ``info`` contains the right-hand-side metadata of the
            last evaluation of ``f``.
        """
        if norm_fun is None:
            norm_fun = self.xp.linalg.norm

        # The start slope does not depend on the trial step size.
        k0, step_info, other = self._call_rhs(f, t, y, int_step=0, **rhs_args)

        dt = self._dt
        while True:
            h = dt
            half = 0.5 * h
            # Reference: one Heun step of size h.
            dy_full, _, _ = self._heun_increment(f, t, y, h, k0, 1, rhs_args)
            # Refined: two consecutive Heun steps of size h/2.
            dy_first, _, _ = self._heun_increment(f, t, y, half, k0, 2, rhs_args)
            y_mid = y + dy_first
            k_mid, _, _ = self._call_rhs(f, t + half, y_mid, int_step=3, **rhs_args)
            dy_second, step_info, other = self._heun_increment(f, t + half, y_mid, half, k_mid, 4, rhs_args)
            dy_half = dy_first + dy_second

            #! Error estimate and step size control (local error ~ h^3)
            err = norm_fun(dy_half - dy_full)
            fe = self.tolerance / (err + 1e-15)
            fac = self.xp.clip(0.9 * fe ** (1 / 3), 0.2, 2.0)
            dt = min(h * fac, self.max_step)
            # Accept unless the error exceeds the tolerance; a NaN estimate
            # compares False here and ends the loop instead of retrying forever.
            if not fe < 1.0:
                break

        # Store the proposal for the next step; report the step just taken.
        self._dt = dt
        return y + dy_half, h, (step_info, other)

    def __repr__(self):
        """
        Return a string representation of the AdaptiveHeun object.
        """
        return (f"AdaptiveHeun(dt={self._dt}, tol={self.tolerance}, max_step={self.max_step}, "
                f"backend={self.backendstr}, rhs_p={self._rhs_prefactor})")
#########################################################################
#! General Runge-Kutta integration
#########################################################################
class RK(IVP):
    r"""
    General explicit Runge-Kutta integrator with arbitrary Butcher tableau.

    Parameters
    ----------
    a : array-like of shape (s, s)
        Stage coefficients; only the strictly lower triangle is used.
    b : array-like of length s
        Weights for final combination.
    c : array-like of length s
        Nodes (usually the row sums of a).
    dt : float
        Fixed step size dt.
    backend : str
        'numpy' or 'jax'
    rhs_prefactor : float
        Constant factor multiplying the right-hand side.

    Raises
    ------
    ValueError
        If ``b`` is empty or not 1-D, or the shapes of ``a`` / ``c`` do not
        match the number of stages.

    Notes
    -----
    For common orders, use the `from_order` classmethod.
    """

    def __init__(self,
                 a             : list,   # Butcher tableau
                 b             : list,   # Weights
                 c             : list,   # Nodes
                 dt            : float = 1e-3,
                 backend       : str   = 'numpy',
                 rhs_prefactor : float = 1.0):
        # Initialize base with rhs_prefactor for API consistency
        super().__init__(backend, rhs_prefactor=rhs_prefactor, dt=dt)
        xp = self.xp
        # Convert tableau to arrays
        self.a = xp.array(a, dtype=xp.float64)
        self.b = xp.array(b, dtype=xp.float64)
        self.c = xp.array(c, dtype=xp.float64)
        # Validate early: a malformed tableau would otherwise fail deep inside step().
        if self.b.ndim != 1 or self.b.shape[0] == 0:
            raise ValueError("RK: weights `b` must be a non-empty 1-D sequence.")
        self.stages = self.b.shape[0]
        if tuple(self.a.shape) != (self.stages, self.stages):
            raise ValueError(f"RK: `a` must have shape ({self.stages}, {self.stages}), got {tuple(self.a.shape)}.")
        if tuple(self.c.shape) != (self.stages,):
            raise ValueError(f"RK: `c` must have shape ({self.stages},), got {tuple(self.c.shape)}.")

    @property
    def order(self) -> int:
        """Return the number of stages in the configured Butcher tableau."""
        return len(self.b)

    @classmethod
    def from_order(cls, order: int, dt: float = 1e-3, backend: str = 'numpy', rhs_prefactor: float = 1.0):
        """
        Create a Runge-Kutta method instance from a specified order.

        Parameters:
            order (int):
                The order of the Runge-Kutta method. Supported values are 1 (Euler), 2 (RK2), and 4 (RK4).
            dt (float, optional):
                The time step size. Defaults to 1e-3.
            backend (str, optional):
                The computational backend to use (e.g., 'numpy'). Defaults to 'numpy'.
            rhs_prefactor (float, optional):
                Constant factor multiplying the right-hand side. Defaults to 1.0.
        Returns:
            cls:
                An instance of the class initialized with the appropriate Butcher tableau for the specified order.
        Raises:
            ValueError: If the specified order is not supported.
        """
        # Define tableau for orders 1,2,4
        if order == 1:
            a = [[0.0]]
            b = [1.0]
            c = [0.0]
        elif order == 2:
            a = [[0.0, 0.0], [1.0, 0.0]]
            b = [0.5, 0.5]
            c = [0.0, 1.0]
        elif order == 4:
            a = [[0.0, 0.0, 0.0, 0.0],
                 [0.5, 0.0, 0.0, 0.0],
                 [0.0, 0.5, 0.0, 0.0],
                 [0.0, 0.0, 1.0, 0.0]]
            b = [1/6, 1/3, 1/3, 1/6]
            c = [0.0, 0.5, 0.5, 1.0]
        else:
            raise ValueError(f"Unsupported order: {order}")
        return cls(a, b, c, dt=dt, backend=backend, rhs_prefactor=rhs_prefactor)

    @staticmethod
    def _unpack_rhs_output(out):
        """Normalize an RHS return value to ``(dy, info, other)`` (plain dy, 2-tuple, 3-tuple or other tuple)."""
        if not isinstance(out, tuple):
            return out, None, None
        if len(out) == 3:
            return out
        if len(out) == 2:
            return out[0], out[1], None
        return out[0], None, None

    def step(self, f, t: float, y, **rhs_args):
        r"""
        Perform one explicit Runge-Kutta step.

        The right-hand-side prefactor :math:`\alpha` is treated as part of the
        ODE, :math:`dy/dt = \alpha f(y, t)`, not as a final post-processing
        multiplier. This matters for complex real-time evolution in TDVP,
        where the stage states must be evaluated at

        .. math::

            y_i = y_n + \alpha h \sum_{j < i} a_{ij} k_j

        in order for RK2 / RK4 to integrate the same physical time variable as
        Euler and exact diagonalization.

        ``f`` is called positionally as ``f(y_i, t_i, **rhs_args, int_step=i)``
        and may return ``dy`` alone, ``(dy, info)`` or ``(dy, info, other)``.

        Returns
        -------
        tuple
            ``(yout, h, (info, other))`` with the auxiliary values of the last
            stage.
        """
        h = self._dt
        scale = self._rhs_prefactor * h
        k = [None] * self.stages
        step_info = other = None
        #! Compute stages
        for i in range(self.stages):
            ti = t + self.c[i] * h
            yi = y
            for j in range(i):
                yi = yi + scale * self.a[i, j] * k[j]
            # Positional call kept for callers whose RHS parameters are not named y / t.
            k[i], step_info, other = self._unpack_rhs_output(f(yi, ti, **rhs_args, int_step=i))
        #! Combine
        yout = y
        for i in range(self.stages):
            yout = yout + scale * self.b[i] * k[i]
        return yout, h, (step_info, other)

    def __repr__(self):
        """
        Return a string representation of the RK object.
        """
        return f"RK(order={self.order}, dt={self._dt}, backend={self.backendstr})"
#########################################################################
#! From scipy.integrate import solve_ivp
#########################################################################
class ScipyRK(IVP):
    """
    Wrapper around :func:`scipy.integrate.solve_ivp` that advances the state
    over ``[t, t + dt]`` per call.

    Parameters
    ----------
    dt : float
        Length of the interval integrated per call; also the maximum internal
        step when ``max_step`` is not given.
    tol : float
        Used as both the relative and absolute tolerance of the solver.
    max_step : float or None
        Maximum internal step size; if None (or 0) ``dt`` is used.
    method : str
        Solver method: one of 'RK45', 'RK23', 'DOP853', 'Radau', 'BDF', 'LSODA'.
    backend : str
        'numpy' or 'jax'. JAX is not supported: a warning is issued and NumPy
        is used.
    rhs_prefactor : float
        Constant factor multiplying the right-hand side.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported methods.
    """

    def __init__(self,
                 dt           : float = 1e-3,
                 tol          : float = 1e-6,
                 max_step     : float = None,
                 method       : str   = 'RK45',
                 backend      : str   = 'numpy',
                 rhs_prefactor: float = 1.0):
        super().__init__(backend, rhs_prefactor=rhs_prefactor, dt=dt)
        self._dt = float(dt)
        self.tol = float(tol)
        self.max_step = float(max_step) if max_step is not None else None
        self.method = method
        self.supported_methods = ['RK45', 'RK23', 'DOP853', 'Radau', 'BDF', 'LSODA']
        if method not in self.supported_methods:
            raise ValueError(f"Method '{method}' not supported. Choose from {self.supported_methods}.")
        if backend == 'jax':
            warnings.warn("ScipyRK does not support JAX backend; using NumPy internally.")
        # SciPy operates on NumPy arrays only; keep all backend flags consistent with that.
        self.xp = np
        self.backendstr = 'numpy'
        self._isjax = False
        self._isnpy = True

    def step(self, f, t: float, y, **rhs_args):
        """
        Integrate over ``[t, t + dt]`` with ``solve_ivp``.

        The state may have any shape: it is flattened for ``solve_ivp`` (which
        requires a 1-D ``y0``) and reshaped back. ``f`` is called as
        ``f(y, t, **rhs_args)`` and may return ``dy`` alone or a tuple whose
        first element is ``dy``.

        Returns
        -------
        yout : ndarray
            State at end of interval, with the shape of ``y``.
        dt_actual : float
            Actual time step taken.
        info : tuple
            Always ``(None, None)``.

        Raises
        ------
        RuntimeError
            If the SciPy solver reports failure.
        """
        y_arr = np.asarray(y)
        shape = y_arr.shape
        prefactor = self._rhs_prefactor

        def _ode_system(t_i, y_flat):
            out = f(y_flat.reshape(shape), t_i, **rhs_args)
            # Only unwrap tuples: indexing a bare array would pick its first component.
            dy = out[0] if isinstance(out, tuple) else out
            # Multiply RHS by prefactor for consistency with other integrators
            return prefactor * np.asarray(dy).ravel()

        sol = solve_ivp(
            fun      = _ode_system,
            t_span   = (t, t + self._dt),
            y0       = y_arr.ravel(),
            method   = self.method,
            rtol     = self.tol,
            atol     = self.tol,
            max_step = self.max_step or self._dt
        )
        if not sol.success:
            raise RuntimeError(f"SciPy solver failed: {sol.message}")
        yout = sol.y[:, -1].reshape(shape)
        # Not stored back into self._dt: (t + dt) - t differs from dt by
        # roundoff, which would make the configured step drift over time.
        dt_actual = sol.t[-1] - t
        return yout, dt_actual, (None, None)

    def __repr__(self):
        """
        Return a string representation of the ScipyRK object.
        """
        return f"ScipyRK(method={self.method}, dt={self._dt}, tol={self.tol}, max_step={self.max_step}, backend={self.backendstr})"
#########################################################################
class OdeTypes:
    """
    Names of the integrators understood by ``choose_ode``.

    Each attribute is a plain lowercase string, so it compares equal to the
    corresponding user-supplied name.
    """
    # Fixed-step explicit schemes
    EULER: str = 'euler'
    HEUN: str = 'heun'
    RK2: str = 'rk2'
    RK4: str = 'rk4'
    # Step-size-controlled / SciPy-backed schemes
    ADAPTIVE: str = 'adaptive'
    SCIPY: str = 'scipy'
def choose_ode(ode_type: Union[str, int, OdeTypes], *, dt: float = 1e-1, rhs_prefactor: float = 1.0, backend: Any = 'numpy', **kwargs) -> IVP:
    """
    Choose an ODE solver based on the specified type.

    Parameters
    ----------
    ode_type : str or int
        Integrator name, case-insensitive and ignoring surrounding whitespace:
        one of the ``OdeTypes`` values ('euler', 'heun', 'rk2', 'rk4',
        'adaptive', 'scipy'). An integer is interpreted as the order of a
        fixed-step Runge-Kutta method: 1 -> Euler, 2 -> RK2, 4 -> RK4.
    dt : float, optional
        Time step size. Default is 1e-1.
    rhs_prefactor : float, optional
        Constant factor multiplying the right-hand side. Default is 1.0.
    backend : str, optional
        Computational backend to use. Default is 'numpy'.
    **kwargs : keyword arguments
        Additional constructor arguments for the selected solver
        (``RK.from_order`` accepts none).

    Returns
    -------
    IVP
        An instance of the selected ODE solver.

    Raises
    ------
    ValueError
        If ``ode_type`` does not name a known integrator.
    """
    # OdeTypes is a plain constants class (not an Enum), so ``OdeTypes(n)``
    # cannot convert integers; map supported orders explicitly.
    by_order = {1: OdeTypes.EULER, 2: OdeTypes.RK2, 4: OdeTypes.RK4}
    if isinstance(ode_type, int) and not isinstance(ode_type, bool):
        if ode_type not in by_order:
            raise ValueError(f"Unknown ODE type: {ode_type}")
        ode_type = by_order[ode_type]
    elif isinstance(ode_type, str):
        ode_type = ode_type.strip().lower()

    common = dict(dt=dt, backend=backend, rhs_prefactor=rhs_prefactor)
    if ode_type == OdeTypes.EULER:
        return Euler(**common, **kwargs)
    if ode_type == OdeTypes.HEUN:
        return Heun(**common, **kwargs)
    if ode_type == OdeTypes.RK2:
        return RK.from_order(2, **common, **kwargs)
    if ode_type == OdeTypes.RK4:
        return RK.from_order(4, **common, **kwargs)
    if ode_type == OdeTypes.ADAPTIVE:
        return AdaptiveHeun(**common, **kwargs)
    if ode_type == OdeTypes.SCIPY:
        return ScipyRK(**common, **kwargs)
    raise ValueError(f"Unknown ODE type: {ode_type}")
#########################################################################
#! End of file
#########################################################################