"""
general_python.ml.networks
==============================
Network Factory and Registry.
This module provides a centralized factory function, `choose_network`, for
instantiating various neural network architectures used in the general_python framework.
Network classes are registered by module path and imported lazily, on first request, so importing this module does not pull in every architecture.
Usage
-----
Import and use the factory to create a network. The factory takes the
network type and common parameters like `input_shape` and `dtype`.
Network-specific parameters are passed as keyword arguments.
from general_python.ml.networks import choose_network
# Create an RBM using 'alpha' (hidden unit density)
rbm_net = choose_network(
'rbm',
input_shape=(10,),
alpha=2.0, # Creates 2*10=20 hidden units
dtype='complex64'
)
# Create a CNN for an 8x8 lattice
cnn_net = choose_network(
'cnn',
input_shape=(64,),
reshape_dims=(8, 8),
features=[8, 16],
kernel_sizes=[3, 3]
)
----------------------------------------------------------
Author : Maksymilian Kliczkowski
Email : maksymilian.kliczkowski@pwr.edu.pl
Date : 01.10.2025
Description : Factory for creating neural network instances.
----------------------------------------------------------
"""
import importlib
import numpy as np
from typing import Union, Optional, Any, Type, Dict, Tuple, TYPE_CHECKING
from enum import Enum, auto
import math
try:
from .net_impl.net_general import GeneralNet, CallableNet
except ImportError as e:
raise ImportError(f"Could not import GeneralNet or CallableNet. "
f"Ensure that all dependencies are installed. Original error: {e}")
# Type checking import only (does not trigger runtime import)
if TYPE_CHECKING:
from .net_impl.interface_net_flax import FlaxInterface
from .net_impl.net_simple import SimpleNet
from flax import linen as nn
######################################################################
[docs]
class Networks(str, Enum):
    """
    Names of the standard network architectures accepted by `choose_network`.

    Because members also subclass `str`, they compare equal to their plain
    string value (``Networks.RBM == 'rbm'``), and `str()` returns that value.
    """
    SIMPLE = 'simple'
    RBM    = 'rbm'
    CNN    = 'cnn'
    AR     = 'ar'
    RESNET = 'resnet'

    def __str__(self) -> str:
        # Yield the raw key ('rbm') rather than Enum's default 'Networks.RBM'.
        return self.value
######################################################################
# LAZY REGISTRIES
# Maps 'string_key' -> ('module.path', 'ClassName')
######################################################################
_BACKBONE_REGISTRY: Dict[str, Tuple[str, str]] = {
'simple' : ('.net_impl.net_simple', 'SimpleNet'),
'rbm' : ('.net_impl.networks.net_rbm', 'RBM'),
'cnn' : ('.net_impl.networks.net_cnn', 'CNN'),
'res' : ('.net_impl.networks.net_res', 'ResNet'),
'resnet' : ('.net_impl.networks.net_res', 'ResNet'),
'mlp' : ('.net_impl.networks.net_mlp', 'MLP'),
'gcnn' : ('.net_impl.networks.net_gcnn', 'GCNN'),
'transformer' : ('.net_impl.networks.net_transformer', 'Transformer'),
# Add future networks here without importing them!
}
_ANSATZ_REGISTRY: Dict[str, Tuple[str, str]] = {
'ar' : ('.net_impl.ansatze.autoregressive', 'ComplexAR'),
'pp' : ('.net_impl.ansatze.pair_product', 'PairProduct'),
'rbmpp' : ('.net_impl.ansatze.pair_product', 'PairProduct'),
'approx_symmetric' : ('.net_impl.ansatze.approx_symmetric', 'AnsatzApproxSymmetric'),
'approxsym' : ('.net_impl.ansatze.approx_symmetric', 'AnsatzApproxSymmetric'),
'asym' : ('.net_impl.ansatze.approx_symmetric', 'AnsatzApproxSymmetric'),
'jastrow' : ('.net_impl.ansatze.jastrow', 'Jastrow'),
'mps' : ('.net_impl.ansatze.mps', 'MPS'),
'amplitude_phase' : ('.net_impl.ansatze.amplitude_phase', 'AmplitudePhase'),
}
def _lazy_load_class(key: str) -> Type[GeneralNet]:
    """
    Import and return the network class registered under ``key``.

    The backbone registry is consulted first, then the ansatz registry. The
    module is imported only now, relative to this package.

    Parameters
    ----------
    key : str
        Registry key. It is matched exactly; callers are expected to lower-case it.

    Returns
    -------
    type
        The class object named in the registry entry.

    Raises
    ------
    ValueError
        If ``key`` is in neither registry. The message lists the valid keys.
    ImportError
        If the module cannot be imported or does not define the class. The
        original exception is chained as ``__cause__``.
    """
    if key in _BACKBONE_REGISTRY:
        mod_path, cls_name = _BACKBONE_REGISTRY[key]
    elif key in _ANSATZ_REGISTRY:
        mod_path, cls_name = _ANSATZ_REGISTRY[key]
    else:
        available = sorted({*_BACKBONE_REGISTRY, *_ANSATZ_REGISTRY})
        raise ValueError(f"Network '{key}' is not registered in general_python. "
                         f"Available: {available}")
    try:
        # The relative import needs the package context. __package__ makes it work
        # whether this is imported as QES.general_python.ml or general_python.ml.
        module = importlib.import_module(mod_path, package=__package__)
        return getattr(module, cls_name)
    except (ImportError, AttributeError) as e:
        raise ImportError(f"Failed to lazy load '{cls_name}' from '{mod_path}'.\nError: {e}") from e
def _compact_spatial_dims(dims: Tuple[int, ...]) -> Tuple[int, ...]:
compact = tuple(int(d) for d in dims if int(d) > 1)
return compact if compact else (1,)
def _extract_axis_hints(kwargs: Dict[str, Any]) -> Optional[Tuple[int, int, int]]:
"""Extract axis hints from kwargs without mutating them."""
x_dim = kwargs.get("x_dim", kwargs.get("lx", None))
y_dim = kwargs.get("y_dim", kwargs.get("ly", None))
z_dim = kwargs.get("z_dim", kwargs.get("lz", None))
if x_dim is None and y_dim is None and z_dim is None:
return None
x = 1 if x_dim is None else int(x_dim)
y = 1 if y_dim is None else int(y_dim)
z = 1 if z_dim is None else int(z_dim)
return x, y, z
def _extract_lattice_hints(kwargs: Dict[str, Any]) -> Optional[Tuple[int, int, int]]:
"""Extract lattice extents from an optional lattice object."""
lattice = kwargs.get("lattice", None)
if lattice is None:
return None
x = getattr(lattice, "lx", None)
y = getattr(lattice, "ly", None)
z = getattr(lattice, "lz", None)
if x is None and y is None and z is None:
return None
return (
int(1 if x is None else x),
int(1 if y is None else y),
int(1 if z is None else z),
)
def _factorized_default_reshape(n_visible: int) -> Tuple[int, ...]:
"""Infer a stable reshape when no spatial hints are provided."""
n_visible = int(n_visible)
if n_visible <= 1:
return (1,)
root = int(math.sqrt(n_visible))
if root * root == n_visible:
return (root, root)
for factor in range(root, 1, -1):
if n_visible % factor == 0:
return (factor, n_visible // factor)
return (n_visible,)
def _resolve_conv_reshape_dims(input_shape: Optional[tuple], kwargs: Dict[str, Any]) -> Optional[Tuple[int, ...]]:
    """
    Resolve spatial dims for convolution-like models in a model-agnostic way.

    Resolution order (``kwargs`` is only read, never modified):

    1. ``kwargs['reshape_dims']`` if given. It may be a sequence of positive
       extents or a bare int. Trailing singleton axes are dropped (at least one
       axis is kept), and the product must equal ``prod(input_shape)``. An empty
       sequence means the flat layout ``(n_visible,)``.
    2. Axis hints (``x_dim``/``lx``, ``y_dim``/``ly``, ``z_dim``/``lz``), or failing
       that the ``lx``/``ly``/``lz`` of ``kwargs['lattice']``.
       - If ``n_visible`` is a multiple of the hinted cell count, the multiplicity
         is folded into the x-axis.
       - Otherwise, if it is divisible by ``y * z``, x is recomputed.
       - Singleton axes are then dropped.
    3. Otherwise, fall back to ``_factorized_default_reshape(n_visible)``.

    Returns
    -------
    Optional[Tuple[int, ...]]
        ``None`` when ``input_shape`` is ``None`` or has no sites. Otherwise the
        spatial dims.

    Raises
    ------
    ValueError
        If an explicit ``reshape_dims`` contains a non-positive extent, or its
        product does not match the input length.
    """
    if input_shape is None:
        return None
    n_visible = int(np.prod(input_shape))
    if n_visible <= 0:
        return None

    explicit = kwargs.get("reshape_dims", None)
    if explicit is not None:
        # Accept a bare integer as a 1D layout, e.g. reshape_dims=16.
        if isinstance(explicit, (int, np.integer)):
            explicit = (explicit,)
        compact = tuple(int(d) for d in explicit)
        if not compact:
            return (n_visible,)
        # A matching product is not enough, e.g. (-1, -8) for 8 sites.
        if any(d <= 0 for d in compact):
            raise ValueError(f"reshape_dims {compact} must contain only positive extents")
        while len(compact) > 1 and compact[-1] == 1:
            compact = compact[:-1]
        if int(np.prod(compact)) != n_visible:
            raise ValueError(f"reshape_dims {compact} product != input length {n_visible}")
        return compact

    hints = _extract_axis_hints(kwargs)
    if hints is None:
        hints = _extract_lattice_hints(kwargs)
    if hints is not None:
        # Clamp to >= 1, so every product below is >= 1 and safe to divide by.
        bx, by, bz = (max(1, int(d)) for d in hints)
        cell_count = bx * by * bz
        if n_visible % cell_count == 0:
            # More sites than hinted cells: absorb the multiplicity on the x-axis.
            mult = n_visible // cell_count
            return _compact_spatial_dims((bx * mult, by, bz))
        yz = by * bz
        if n_visible % yz == 0:
            return _compact_spatial_dims((n_visible // yz, by, bz))
    return _factorized_default_reshape(n_visible)
def _consume_conv_shape_hints(kwargs: Dict[str, Any]) -> None:
"""Remove generic shape-hint aliases before passing kwargs to model constructors."""
for key in ("x_dim", "y_dim", "z_dim", "lx", "ly", "lz", "lattice"):
kwargs.pop(key, None)
######################################################################
[docs]
def choose_network(network_type : Union[str, Networks, Type[Any], Any],
                   input_shape  : Optional[tuple] = None,
                   backend      : str = 'jax',
                   dtype        : Any = None,
                   param_dtype  : Any = None,
                   seed         : Optional[int] = None,
                   **kwargs) -> GeneralNet:
    r"""
    Smart factory to instantiate a network.

    This factory can:
    - create networks by name (e.g., 'rbm', 'cnn');
    - wrap raw Flax modules;
    - instantiate custom `GeneralNet` subclasses.

    Network-specific arguments are passed via `**kwargs`. The factory also
    provides conveniences such as `alpha` for RBMs.

    Parameters
    ----------
    network_type : Union[str, Networks, Type, Any]
        - **String/Enum** : A registered key, matched case-insensitively with
          surrounding whitespace ignored. The class is lazy-loaded and instantiated.
          - Backbones: 'simple', 'rbm', 'cnn', 'res'/'resnet', 'mlp', 'gcnn', 'transformer'.
          - Ansatze: 'ar', 'pp', 'rbmpp', 'approx_symmetric' (aliases 'approxsym', 'asym'),
            'jastrow', 'mps', 'amplitude_phase'.
        - **Flax Module (class or instance)** : A `flax.linen.Module` subclass or instance.
          It is automatically wrapped in a `FlaxInterface`.
        - **GeneralNet Class** : A class inheriting from `GeneralNet`. The factory instantiates it.
        - **Instance** : A `GeneralNet` instance, or any non-class object exposing both
          `get_params` and `apply`, is returned as-is.
        - **Other callable** : Wrapped in a `CallableNet`.
    input_shape : Optional[tuple]
        The shape of the input to the network, e.g., `(n_spins,)`.
    backend : str
        The computational backend to use ('jax' or 'numpy'). Defaults to 'jax'.
    dtype : Any
        The data type for the network's computations (e.g., 'float32', 'complex64').
    param_dtype : Any
        The data type for the network's parameters.
        - Passed to registered and `GeneralNet` classes; defaulting (presumably to `dtype`
          when `None`) is up to those classes.
        - Not forwarded to the Flax or callable wrappers.
    seed : Optional[int]
        Random seed for network initialization, if applicable.
        - Forwarded to registered / `GeneralNet` classes and to `FlaxInterface`.
        - Not forwarded to `CallableNet`.
    **kwargs :
        Network-specific keyword arguments. See below for details on each network type.

    Using Custom Flax Modules
    -------------------------
    You can pass your own `flax.linen.Module` class as the `network_type`.
    The factory will wrap it in a `general_python.ml.net_impl.interface_net_flax.FlaxInterface`
    to make it compatible with the general_python ecosystem. The remaining `**kwargs` are
    passed as the module's constructor kwargs, except `input_activation`, which configures
    the wrapper.

    **Requirements for your custom module:**
    1. It must be a valid `nn.Module`.
    2. Its `__call__` method should accept a `(batch, n_visible)` JAX array as input.
    3. It should return the log-amplitude of the wavefunction, typically with shape `(batch,)`.

    **Example:**
    >>> import flax.linen as nn
    >>> import jax.numpy as jnp
    >>>
    >>> class MyCustomNet(nn.Module):
    ...     @nn.compact
    ...     def __call__(self, s):
    ...         # A simple dense layer
    ...         x = nn.Dense(features=128)(s)
    ...         x = nn.relu(x)
    ...         log_psi = nn.Dense(features=1)(x)
    ...         return jnp.squeeze(log_psi, axis=-1)
    >>>
    >>> # The factory handles the wrapping
    >>> custom_net = choose_network(
    ...     MyCustomNet,
    ...     input_shape=(100,),
    ...     dtype='complex64'  # Passed to the interface
    ... )
    >>> print(custom_net)

    Example
    -------
    The following are examples of how to create different network types using the factory.

    1. Restricted Boltzmann Machine (RBM) for NQS.
    The RBM is a single-layer dense network. It connects all visible spins
    to a layer of hidden units. It is the standard baseline for NQS.

    Usage
    -----
        from general_python.ml.networks import choose_network
        # 1. Define RBM Parameters
        # ------------------------
        # An alpha (density) of 2 means: n_hidden = 2 * n_visible
        rbm_params = {
            'input_shape': (100,),      # 100 spins
            'alpha': 2,                 # Density of hidden units
            'use_bias': True,
            'dtype': 'complex128'       # Essential for quantum phases
        }
        # 2. Create the Network
        # ---------------------
        # 'rbm' key triggers the RBM class factory
        net = choose_network('rbm', **rbm_params)
        # 3. Initialize & Run
        # -------------------
        # Initialize with a random key (handled internally or explicitly)
        # params = net.init(jax.random.PRNGKey(0))
        # log_psi = net(params, sample_configuration)

    2. Convolutional Neural Network (CNN) for Lattice Systems.
    A deep architecture that respects the locality of physical interactions.
    It is essential for 2D frustrated systems (like Kitaev or J1-J2 models),
    where local correlations are complex.
    Features:
    - Periodic Boundary Conditions (Torus geometry).
    - Sum Pooling: Ensures energy is extensive (scales with N).
    - Complex Weights: Captures the sign structure of the wavefunction.

    Usage
    -----
        from general_python.ml.networks import choose_network
        import jax.numpy as jnp
        # 1. Define CNN Parameters
        # ------------------------
        cnn_params = {
            'input_shape': (n_sites,),
            'reshape_dims': (L, L),            # Reshape 1D input to 2D grid
            'features': (16, 32, 64),          # Deep network with increasing channels
            'kernel_sizes': ((3,3), (3,3), (3,3)),
            'activations': ['lncosh'] * 3,     # Holomorphic activation
            'periodic': True,                  # Wrap edges (Torus)
            'sum_pooling': True,               # Sum output over all spatial sites
            'dtype': 'complex128'
        }
        # 2. Create the Network
        # ---------------------
        net = choose_network('cnn', **cnn_params)
        # 3. Debug/Check
        # --------------
        print(f"Total Parameters: {net.nparams}")
        # > Total Parameters: ~25k (Complex)

    Keyword Arguments (by network_type)
    -----------------------------------
    - input_activation (Optional[Union[str, Callable]]) :
        Activation function applied to the input layer.
        Useful for preprocessing inputs (e.g., scaling or encoding).

    **For 'rbm'**:
    - `alpha` (float) : Hidden unit density. `n_hidden` will be `int(alpha * n_visible)`,
      with `n_visible = prod(input_shape)`; `input_shape` is then required.
      `hidden_density` is accepted as an alias. `alpha` wins if both are given.
    - `n_hidden` (int) : Number of hidden units. If given explicitly, it takes precedence
      and `alpha` is ignored.
    - `bias` (bool) : Whether to use a bias for the hidden layer. Default: `True`.
    - `visible_bias` (bool) : Whether to use a bias for the visible layer. Default: `True`.

    **For 'rbmpp'**:
    - `use_rbm` is always set to `True`; otherwise the arguments are the same as 'pp'.

    **For 'cnn'**:
    - `reshape_dims` (Tuple[int, ...]) : The spatial dimensions to reshape the 1D input into (e.g., `(8, 8)`).
    - `features` (Sequence[int]) : Number of output channels for each convolutional layer.
    - `kernel_sizes` (Sequence[Union[int, Tuple]]) : Size of the kernel for each conv layer.
    - `strides` (Sequence[Union[int, Tuple]]) : Stride for each conv layer. Defaults to 1.
    - `output_shape` (Tuple[int, ...]) : Shape of the final output. Default: `(1,)`.
    - `activations` (Union[str, Sequence[Union[str, Callable]]]) : Activation function(s) for each conv layer.
    - `periodic` (bool) : Whether to use periodic boundary conditions. Default: `True`.
    - `sum_pooling` (bool) : Whether to sum pool the final output over spatial dimensions. Default: `True`.

    **Shape hints for 'cnn', 'res' and 'resnet'**:
    - If `reshape_dims` is omitted, it is inferred from `x_dim`/`lx`, `y_dim`/`ly` and `z_dim`/`lz`,
      or from a `lattice` object with `lx`/`ly`/`lz`.
    - If no hints are given either, it is inferred from a near-square factorization of the input size.
    - These hint keys are consumed and not forwarded to the model.

    **For 'simple'**:
    - `layers` (Tuple[int, ...]) : A tuple defining the number of neurons in each hidden layer.
    - `output_shape` (Tuple[int, ...]) : Shape of the final output. Default: `(1,)`.
    - `act_fun` (Tuple[Union[str, Callable],...]) : Activation functions for each layer.

    **For 'ar' (Autoregressive)**:
    - `depth` (int) : Number of layers in the model.
    - `num_hidden` (int) : Number of hidden units in each layer.
    - `rnn_type` (str) : Type of recurrent cell, if applicable (e.g., 'lstm', 'gru').

    **For 'res' or 'resnet' (Residual Network)**:
    - `reshape_dims` (Tuple[int, ...]) : Lattice dimensions for reshaping the input, e.g., `(Lx, Ly)`.
    - `features` (int) : Number of feature channels (network width). Default: 32.
    - `depth` (int) : Number of residual blocks. Default: 4.
    - `kernel_size` (Union[int, Tuple[int,...]]) : Spatial kernel size. Default: 3 (becomes (3,3) for 2D).

    Returns
    -------
    GeneralNet
        An initialized or wrapped network instance compatible with the general_python framework.
        The returned object is callable with signature
        ``log_psi = net(params, inputs)``, where inputs has shape ``(batch, features)``
        and the output has shape ``(batch,)``.

    Raises
    ------
    ValueError
        Raised in any of these cases:
        - a string key is not registered;
        - `alpha`/`hidden_density` is given for 'rbm' without `input_shape`;
        - an explicit `reshape_dims` does not match the input size;
        - `network_type` is of an unsupported kind.
    ImportError
        If a registered network's module or class cannot be imported.
    """
    # 1. Strings and enums: registry lookup with lazy import.
    if isinstance(network_type, (str, Networks)):
        key = str(network_type).strip().lower()
        # `kwargs` is this call's own dict, so in-place preprocessing is safe.
        _preprocess_named_kwargs(key, input_shape, kwargs)
        net_cls = _lazy_load_class(key)
        return net_cls(input_shape=input_shape, backend=backend, dtype=dtype,
                       param_dtype=param_dtype, seed=seed, **kwargs)

    is_class = isinstance(network_type, type)

    # 2. Already-built network instances are returned unchanged.
    if not is_class and (
        isinstance(network_type, GeneralNet)
        or (hasattr(network_type, 'get_params') and hasattr(network_type, 'apply'))
    ):
        return network_type

    # 3. GeneralNet subclasses (e.g. a manually imported RBM) are instantiated directly.
    if is_class and issubclass(network_type, GeneralNet):
        return network_type(input_shape=input_shape, backend=backend, dtype=dtype,
                            param_dtype=param_dtype, seed=seed, **kwargs)

    # 4. Flax modules (class or instance) are wrapped automatically.
    if _is_flax_module(network_type):
        return _wrap_flax_module(network_type, input_shape, backend, dtype, seed, kwargs)

    # 5. Any other callable is treated as a network factory.
    if callable(network_type):
        return CallableNet(callable_fun=network_type, input_shape=input_shape,
                           backend=backend, dtype=dtype, **kwargs)

    raise ValueError(f"Unknown network type: {type(network_type)}")


def _preprocess_named_kwargs(key: str, input_shape: Optional[tuple], kwargs: Dict[str, Any]) -> None:
    """Apply per-network convenience translations to ``kwargs`` in place before construction."""
    if key == 'rbm':
        # Pop both aliases unconditionally so neither leaks into the RBM constructor
        # (a short-circuiting `or` would leave `hidden_density` behind).
        alpha = kwargs.pop('alpha', None)
        hidden_density = kwargs.pop('hidden_density', None)
        if alpha is None:
            alpha = hidden_density
        # An explicit n_hidden takes precedence; alpha is then discarded.
        if alpha is not None and 'n_hidden' not in kwargs:
            if not input_shape:
                raise ValueError("`input_shape` must be provided when using `alpha` for RBM.")
            n_visible = np.prod(input_shape)
            kwargs['n_hidden'] = int(alpha * n_visible)
    elif key == 'rbmpp':
        kwargs['use_rbm'] = True
    elif key in ('cnn', 'res', 'resnet'):
        reshape_dims = _resolve_conv_reshape_dims(input_shape, kwargs)
        if reshape_dims is not None:
            kwargs['reshape_dims'] = reshape_dims
        # Hint aliases (lx, lattice, ...) are not constructor arguments.
        _consume_conv_shape_hints(kwargs)


def _is_flax_module(obj: Any) -> bool:
    """Return True if ``obj`` is a ``flax.linen.Module`` subclass or instance; False if flax is missing."""
    try:
        import flax.linen as nn
    except ImportError:
        return False
    if isinstance(obj, type):
        return issubclass(obj, nn.Module)
    return isinstance(obj, nn.Module)


def _wrap_flax_module(net_module: Any, input_shape: Optional[tuple], backend: str,
                      dtype: Any, seed: Optional[int], kwargs: Dict[str, Any]) -> GeneralNet:
    """Wrap a Flax module class or instance in ``FlaxInterface``."""
    # Lazy import: only needed (and only requires flax) on this path.
    from .net_impl.interface_net_flax import FlaxInterface
    net_kwargs = dict(kwargs)
    # `input_activation` configures the wrapper. net_kwargs is meant to hold only
    # module-constructor kwargs, so it must not stay in there as well.
    input_activation = net_kwargs.pop('input_activation', None)
    # NOTE(review): param_dtype is not forwarded to FlaxInterface (same as before);
    # confirm whether FlaxInterface accepts it.
    return FlaxInterface(
        net_module       = net_module,
        net_kwargs       = net_kwargs,
        input_shape      = input_shape,
        backend          = backend,
        dtype            = dtype,
        seed             = seed,
        input_activation = input_activation,
    )
######################################################################
#! END OF FILE
######################################################################