# Source code for general_python.ml.networks

"""
general_python.ml.networks
==============================

Network Factory and Registry.

This module provides a centralized factory function, `choose_network`, for
instantiating various neural network architectures used in the general_python framework.
It uses a lazy-loading mechanism to improve startup performance.

Usage
-----
Import and use the factory to create a network. The factory takes the
network type and common parameters like `input_shape` and `dtype`.
Network-specific parameters are passed as keyword arguments.

    from general_python.ml.networks import choose_network
    
    # Create an RBM using 'alpha' (hidden unit density)
    rbm_net = choose_network(
        'rbm',
        input_shape=(10,),
        alpha=2.0,          # Creates 2*10=20 hidden units
        dtype='complex64'
    )
    
    # Create a CNN for an 8x8 lattice
    cnn_net = choose_network(
        'cnn',
        input_shape=(64,),
        reshape_dims=(8, 8),
        features=[8, 16],
        kernel_sizes=[3, 3]
    )

----------------------------------------------------------
Author          : Maksymilian Kliczkowski
Email           : maksymilian.kliczkowski@pwr.edu.pl
Date            : 01.10.2025
Description     : Factory for creating neural network instances.
----------------------------------------------------------
"""
import importlib
import numpy as np
from typing import Union, Optional, Any, Type, Dict, Tuple, TYPE_CHECKING
from enum import Enum, auto
import math

try:
    from .net_impl.net_general          import GeneralNet, CallableNet
except ImportError as e:
    raise ImportError(f"Could not import GeneralNet or CallableNet. "
                      f"Ensure that all dependencies are installed. Original error: {e}")

# Type checking import only (does not trigger runtime import)
if TYPE_CHECKING:
    from .net_impl.interface_net_flax   import FlaxInterface
    from .net_impl.net_simple           import SimpleNet
    from flax                           import linen as nn

######################################################################

class Networks(str, Enum):
    """
    Enum class for available standard network architectures.

    Inherits from ``str`` so that members compare equal to their plain string
    values (e.g. ``Networks.RBM == 'rbm'``).
    """

    SIMPLE  = 'simple'
    RBM     = 'rbm'
    CNN     = 'cnn'
    AR      = 'ar'
    RESNET  = 'resnet'

    def __str__(self):
        # Return the bare value ('rbm') instead of the default 'Networks.RBM'
        # so that ``str(member)`` can be used directly as a registry key.
        return self.value
######################################################################
# LAZY REGISTRIES
# Maps 'string_key' -> ('module.path', 'ClassName')
######################################################################

_BACKBONE_REGISTRY: Dict[str, Tuple[str, str]] = {
    'simple'            : ('.net_impl.net_simple',                  'SimpleNet'),
    'rbm'               : ('.net_impl.networks.net_rbm',            'RBM'),
    'cnn'               : ('.net_impl.networks.net_cnn',            'CNN'),
    'res'               : ('.net_impl.networks.net_res',            'ResNet'),
    'resnet'            : ('.net_impl.networks.net_res',            'ResNet'),
    'mlp'               : ('.net_impl.networks.net_mlp',            'MLP'),
    'gcnn'              : ('.net_impl.networks.net_gcnn',           'GCNN'),
    'transformer'       : ('.net_impl.networks.net_transformer',    'Transformer'),
    # Add future networks here without importing them!
}

_ANSATZ_REGISTRY: Dict[str, Tuple[str, str]] = {
    'ar'                : ('.net_impl.ansatze.autoregressive',      'ComplexAR'),
    'pp'                : ('.net_impl.ansatze.pair_product',        'PairProduct'),
    'rbmpp'             : ('.net_impl.ansatze.pair_product',        'PairProduct'),
    'approx_symmetric'  : ('.net_impl.ansatze.approx_symmetric',    'AnsatzApproxSymmetric'),
    'approxsym'         : ('.net_impl.ansatze.approx_symmetric',    'AnsatzApproxSymmetric'),
    'asym'              : ('.net_impl.ansatze.approx_symmetric',    'AnsatzApproxSymmetric'),
    'jastrow'           : ('.net_impl.ansatze.jastrow',             'Jastrow'),
    'mps'               : ('.net_impl.ansatze.mps',                 'MPS'),
    'amplitude_phase'   : ('.net_impl.ansatze.amplitude_phase',     'AmplitudePhase'),
}


def _lazy_load_class(key: str) -> "Type[GeneralNet]":
    """
    Import and return the class registered under ``key``.

    The backbone registry is consulted first, then the ansatz registry. The
    target module is imported only at this point, relative to this package.

    Parameters
    ----------
    key : str
        Registry key, e.g. ``'rbm'`` or ``'ar'``.

    Returns
    -------
    type
        The class object named in the registry entry.

    Raises
    ------
    ValueError
        If ``key`` is present in neither registry.
    ImportError
        If the module cannot be imported or does not expose the expected
        class. The underlying exception is attached as ``__cause__``.
    """
    registry = _BACKBONE_REGISTRY if key in _BACKBONE_REGISTRY else _ANSATZ_REGISTRY
    if key not in registry:
        raise ValueError(f"Network '{key}' is not registered in general_python.")

    mod_path, cls_name = registry[key]
    try:
        # Relative import requires the package context.
        # We use __package__ to support importing as QES.general_python.ml or general_python.ml
        module = importlib.import_module(mod_path, package=__package__)
        return getattr(module, cls_name)
    except (ImportError, AttributeError) as e:
        # Chain explicitly so the original traceback is reported as the cause.
        raise ImportError(f"Failed to lazy load '{cls_name}' from '{mod_path}'.\nError: {e}") from e


def _compact_spatial_dims(dims: Tuple[int, ...]) -> Tuple[int, ...]:
    """Drop every axis of extent <= 1; return ``(1,)`` if nothing is left."""
    compact = tuple(int(d) for d in dims if int(d) > 1)
    return compact if compact else (1,)


def _extract_axis_hints(kwargs: Dict[str, Any]) -> Optional[Tuple[int, int, int]]:
    """
    Extract axis hints from kwargs without mutating them.

    Each axis is read from ``x_dim``/``y_dim``/``z_dim`` and, when that entry is
    missing or ``None``, from the alias ``lx``/``ly``/``lz``.

    Returns
    -------
    Optional[Tuple[int, int, int]]
        ``(x, y, z)`` with unspecified axes set to 1, or ``None`` when no axis
        hint is given at all.
    """
    def _first_set(primary: str, alias: str) -> Any:
        # An explicit ``None`` under the primary name must not hide the alias.
        value = kwargs.get(primary, None)
        return kwargs.get(alias, None) if value is None else value

    x_dim = _first_set("x_dim", "lx")
    y_dim = _first_set("y_dim", "ly")
    z_dim = _first_set("z_dim", "lz")
    if x_dim is None and y_dim is None and z_dim is None:
        return None

    x = 1 if x_dim is None else int(x_dim)
    y = 1 if y_dim is None else int(y_dim)
    z = 1 if z_dim is None else int(z_dim)
    return x, y, z


def _extract_lattice_hints(kwargs: Dict[str, Any]) -> Optional[Tuple[int, int, int]]:
    """
    Extract lattice extents from an optional ``lattice`` object in kwargs.

    Reads the attributes ``lx``, ``ly`` and ``lz`` of ``kwargs['lattice']``;
    missing attributes count as 1. Returns ``None`` when there is no lattice or
    it exposes none of the three attributes.
    """
    lattice = kwargs.get("lattice", None)
    if lattice is None:
        return None

    x = getattr(lattice, "lx", None)
    y = getattr(lattice, "ly", None)
    z = getattr(lattice, "lz", None)
    if x is None and y is None and z is None:
        return None

    return (
        int(1 if x is None else x),
        int(1 if y is None else y),
        int(1 if z is None else z),
    )


def _factorized_default_reshape(n_visible: int) -> Tuple[int, ...]:
    """
    Infer a stable reshape when no spatial hints are provided.

    Returns ``(r, r)`` for a perfect square, otherwise ``(f, n // f)`` where
    ``f`` is the largest divisor of ``n`` with ``1 < f <= isqrt(n)``, and
    ``(n,)`` when no such divisor exists. ``n <= 1`` yields ``(1,)``.
    """
    n_visible = int(n_visible)
    if n_visible <= 1:
        return (1,)

    # Exact integer square root (no float rounding).
    root = math.isqrt(n_visible)
    if root * root == n_visible:
        return (root, root)

    for factor in range(root, 1, -1):
        if n_visible % factor == 0:
            return (factor, n_visible // factor)
    return (n_visible,)


def _resolve_conv_reshape_dims(input_shape: Optional[tuple], kwargs: Dict[str, Any]) -> Optional[Tuple[int, ...]]:
    """
    Resolve spatial dims for convolution-like models in a model-agnostic way.

    Resolution order:

    1. an explicit ``reshape_dims`` entry in ``kwargs`` (trailing axes of size 1
       are stripped, and the product must equal the flattened input length);
    2. axis hints (``x_dim``/``lx`` ...) and, failing that, a ``lattice`` object;
    3. a factorization of the flattened input length.

    ``kwargs`` is only read here, never modified.

    Returns
    -------
    Optional[Tuple[int, ...]]
        The resolved dims, or ``None`` when ``input_shape`` is ``None`` or its
        product is not positive.

    Raises
    ------
    ValueError
        If an explicit ``reshape_dims`` does not multiply to the input length.
    """
    if input_shape is None:
        return None

    n_visible = int(np.prod(input_shape))
    if n_visible <= 0:
        return None

    explicit = kwargs.get("reshape_dims", None)
    if explicit is not None:
        raw = tuple(int(d) for d in explicit)
        if not raw:
            return (n_visible,)
        compact = raw
        while len(compact) > 1 and compact[-1] == 1:
            compact = compact[:-1]
        if np.prod(compact) != n_visible:
            raise ValueError(f"reshape_dims {compact} product != input length {n_visible}")
        return compact

    dims = _extract_axis_hints(kwargs)
    if dims is None:
        dims = _extract_lattice_hints(kwargs)

    if dims is not None:
        x_dim, y_dim, z_dim = dims
        base = (max(1, x_dim), max(1, y_dim), max(1, z_dim))
        base_prod = int(np.prod(base))
        # Defensive: every component is clamped to >= 1 above, so this is not
        # expected to trigger for ordinary lattice sizes.
        if base_prod <= 0:
            return _factorized_default_reshape(n_visible)

        # Apply multipartity mismatch on x-axis first.
        if n_visible % base_prod == 0:
            mult = n_visible // base_prod
            reshaped = (base[0] * mult, base[1], base[2])
            return _compact_spatial_dims(reshaped)

        # Otherwise keep y/z as hinted and let x absorb the remainder.
        yz = max(1, base[1] * base[2])
        if n_visible % yz == 0:
            return _compact_spatial_dims((n_visible // yz, base[1], base[2]))

    return _factorized_default_reshape(n_visible)


def _consume_conv_shape_hints(kwargs: Dict[str, Any]) -> None:
    """Remove generic shape-hint aliases (in place) before passing kwargs to model constructors."""
    for key in ("x_dim", "y_dim", "z_dim", "lx", "ly", "lz", "lattice"):
        kwargs.pop(key, None)

######################################################################
def choose_network(network_type : Union[str, Networks, Type[Any], Any],
                   input_shape  : Optional[tuple]   = None,
                   backend      : str               = 'jax',
                   dtype        : Any               = None,
                   param_dtype  : Any               = None,
                   seed         : Optional[int]     = None,
                   **kwargs) -> GeneralNet:
    r"""
    Smart factory to instantiate a network.

    This factory can create networks by name (e.g., 'rbm', 'cnn'), wrap raw Flax modules,
    or instantiate custom `GeneralNet` subclasses. It handles network-specific
    arguments passed via `**kwargs` and provides conveniences like `alpha` for RBMs.

    Parameters
    ----------
    network_type : Union[str, Networks, Type, Any]
        - **String/Enum**       : 'rbm', 'cnn', 'simple', 'ar'. The factory will lazy-load and
                                  instantiate the corresponding class.
        - **Flax Module Class** : A raw `flax.linen.nn.Module` class. The factory will automatically
                                  wrap it in a `FlaxInterface`.
        - **GeneralNet Class**  : A class inheriting from `GeneralNet`. The factory will instantiate it.
        - **Instance**          : If an already-initialized network instance is passed, it is returned as-is.
        - **Callable**          : Any other callable is wrapped in a `CallableNet`.
    input_shape : Optional[tuple]
        The shape of the input to the network, e.g., `(n_spins,)`.
    backend : str
        The computational backend to use ('jax' or 'numpy'). Defaults to 'jax'.
    dtype : Any
        The data type for the network's computations (e.g., 'float32', 'complex64').
    param_dtype : Any
        The data type for the network's parameters. If `None`, defaults to `dtype`.
    seed : Optional[int]
        Random seed for network initialization, if applicable.
    **kwargs :
        Network-specific keyword arguments. See below for details on each network type.

    Using Custom Flax Modules
    -------------------------
    You can pass your own `flax.linen.nn.Module` class as the `network_type`. The factory
    will wrap it in a `general_python.ml.net_impl.interface_net_flax.FlaxInterface` to make it
    compatible with the general_python ecosystem.

    **Requirements for your custom module:**

    1. It must be a valid `nn.Module`.
    2. Its `__call__` method should accept a `(batch, n_visible)` JAX array as input.
    3. It should return the log-amplitude of the wavefunction, typically with shape `(batch,)`.

    **Example:**

    >>> import flax.linen as nn
    >>> import jax.numpy as jnp
    >>>
    >>> class MyCustomNet(nn.Module):
    ...     @nn.compact
    ...     def __call__(self, s):
    ...         # A simple dense layer
    ...         x = nn.Dense(features=128)(s)
    ...         x = nn.relu(x)
    ...         log_psi = nn.Dense(features=1)(x)
    ...         return jnp.squeeze(log_psi, axis=-1)
    >>>
    >>> # The factory handles the wrapping
    >>> custom_net = choose_network(
    ...     MyCustomNet,
    ...     input_shape=(100,),
    ...     dtype='complex64'  # Passed to the interface
    ... )
    >>> print(custom_net)

    Example
    -------
    The following are examples of how to create different network types using the factory.

    1. Restricted Boltzmann Machine (RBM) for NQS.

    The RBM is a single-layer dense network. It connects all visible spins
    to a layer of hidden units. It is the standard baseline for NQS.

    Usage
    -----
        from general_python.ml.networks import choose_network

        # 1. Define RBM Parameters
        # ------------------------
        # An alpha (density) of 2 means: n_hidden = 2 * n_visible
        rbm_params = {
            'input_shape': (100,),      # 100 spins
            'alpha': 2,                 # Density of hidden units
            'use_bias': True,
            'dtype': 'complex128'       # Essential for quantum phases
        }

        # 2. Create the Network
        # ---------------------
        # 'rbm' key triggers the RBM class factory
        net = choose_network('rbm', **rbm_params)

        # 3. Initialize & Run
        # -------------------
        # Initialize with a random key (handled internally or explicitly)
        # params = net.init(jax.random.PRNGKey(0))
        # log_psi = net(params, sample_configuration)

    2. Convolutional Neural Network (CNN) for Lattice Systems.

    A deep architecture that respects the locality of physical interactions.
    Essential for 2D frustrated systems (like Kitaev or J1-J2 models) where
    local correlations are complex.

    Features:
    - Periodic Boundary Conditions (Torus geometry).
    - Sum Pooling: Ensures energy is extensive (scales with N).
    - Complex Weights: Captures the sign structure of the wavefunction.

    Usage
    -----
        from general_python.ml.networks import choose_network
        import jax.numpy as jnp

        # 1. Define CNN Parameters
        # ------------------------
        cnn_params = {
            'input_shape': (n_sites,),
            'reshape_dims': (L, L),                 # Reshape 1D input to 2D grid
            'features': (16, 32, 64),               # Deep network with increasing channels
            'kernel_sizes': ((3,3), (3,3), (3,3)),
            'activations': ['lncosh'] * 3,          # Holomorphic activation
            'periodic': True,                       # Wrap edges (Torus)
            'sum_pooling': True,                    # Sum output over all spatial sites
            'dtype': 'complex128'
        }

        # 2. Create the Network
        # ---------------------
        net = choose_network('cnn', **cnn_params)

        # 3. Debug/Check
        # --------------
        print(f"Total Parameters: {net.nparams}")
        # > Total Parameters: ~25k (Complex)

    Keyword Args (by network_type)
    ------------------------------
    - input_activation (Optional[Union[str, Callable]]) :
        Activation function applied to the input layer. Useful for preprocessing inputs (e.g., scaling or encoding).

    **For 'rbm'**:
        - `alpha` (float)       : Hidden unit density. `n_hidden` will be `int(alpha * n_visible)`.
                                  `hidden_density` is accepted as an alias. Both names are consumed by the
                                  factory and are not forwarded to the RBM constructor.
        - `n_hidden` (int)      : Number of hidden units. If given explicitly, it takes precedence and
                                  `alpha` is ignored.
        - `bias` (bool)         : Whether to use a bias for the hidden layer. Default: `True`.
        - `visible_bias` (bool) : Whether to use a bias for the visible layer. Default: `True`.

    **For 'cnn'**:
        - `reshape_dims` (Tuple[int, ...])                              : The spatial dimensions to reshape the 1D input into (e.g., `(8, 8)`).
                                                                          When omitted, it is resolved from `x_dim`/`y_dim`/`z_dim`
                                                                          (aliases `lx`/`ly`/`lz`), a `lattice` object, or a factorization
                                                                          of the input length; those hint keys are removed before the
                                                                          constructor is called. The same applies to 'res'/'resnet'.
        - `features` (Sequence[int])                                    : Number of output channels for each convolutional layer.
        - `kernel_sizes` (Sequence[Union[int, Tuple]])                  : Size of the kernel for each conv layer.
        - `strides` (Sequence[Union[int, Tuple]])                       : Stride for each conv layer. Defaults to 1.
        - `output_shape` (Tuple[int, ...])                              : Shape of the final output. Default: `(1,)`.
        - `activations` (Union[str, Sequence[Union[str, Callable]]])    : Activation function(s) for each conv layer.
        - `periodic` (bool)                                             : Whether to use periodic boundary conditions. Default: `True`.
        - `sum_pooling` (bool)                                          : Whether to sum pool the final output over spatial dimensions. Default: `True`.

    **For 'simple'**:
        - `layers` (Tuple[int, ...])                    : A tuple defining the number of neurons in each hidden layer.
        - `output_shape` (Tuple[int, ...])              : Shape of the final output. Default: `(1,)`.
        - `act_fun` (Tuple[Union[str, Callable],...])   : Activation functions for each layer.

    **For 'ar' (Autoregressive)**:
        - `depth` (int)         : Number of layers in the model.
        - `num_hidden` (int)    : Number of hidden units in each layer.
        - `rnn_type` (str)      : Type of recurrent cell, if applicable (e.g., 'lstm', 'gru').

    **For 'res' or 'resnet' (Residual Network)**:
        - `reshape_dims` (Tuple[int, ...])              : Lattice dimensions for reshaping the input, e.g., `(Lx, Ly)`.
        - `features` (int)                              : Number of feature channels (network width). Default: 32.
        - `depth` (int)                                 : Number of residual blocks. Default: 4.
        - `kernel_size` (Union[int, Tuple[int,...]])    : Spatial kernel size. Default: 3 (becomes (3,3) for 2D).

    Returns
    -------
    GeneralNet
        An initialized or wrapped network instance compatible with the general_python framework.
        The returned object is callable with signature:
        ``log_psi = net(params, inputs)`` where inputs has shape ``(batch, features)``
        and output has shape ``(batch,)``.

    Raises
    ------
    ValueError
        If `alpha` is used for an RBM without `input_shape`, if a string key is not
        registered, or if `network_type` is of an unsupported kind.
    ImportError
        If the implementation module of a registered network cannot be loaded.
    """

    # 1. Handle Strings and Enums (Lazy Load Path)
    if isinstance(network_type, (str, Networks)):
        key = str(network_type).lower()

        # Argument Pre-processing for convenience
        if key == 'rbm':
            # Allow `alpha` or `hidden_density` to define `n_hidden`.
            # Pop BOTH names unconditionally: with a short-circuiting `or`, a truthy
            # `alpha` would leave `hidden_density` in kwargs and forward it to the
            # RBM constructor.
            alpha_arg   = kwargs.pop('alpha', None)
            density_arg = kwargs.pop('hidden_density', None)
            alpha       = alpha_arg or density_arg
            if alpha is not None and 'n_hidden' not in kwargs:
                if not input_shape:
                    raise ValueError("`input_shape` must be provided when using `alpha` for RBM.")
                n_visible           = np.prod(input_shape)
                kwargs['n_hidden']  = int(alpha * n_visible)
        elif key == 'rbmpp':
            kwargs['use_rbm'] = True
        elif key in ('cnn', 'res', 'resnet'):
            reshape_dims = _resolve_conv_reshape_dims(input_shape, kwargs)
            if reshape_dims is not None:
                kwargs['reshape_dims'] = reshape_dims
            # The generic hints are not constructor arguments; strip them.
            _consume_conv_shape_hints(kwargs)

        net_cls = _lazy_load_class(key)
        return net_cls(input_shape=input_shape, backend=backend, dtype=dtype,
                       param_dtype=param_dtype, seed=seed, **kwargs)

    # 2. Handle Existing Instances (Return as-is)
    if not isinstance(network_type, type) and (
        isinstance(network_type, GeneralNet)
        or (hasattr(network_type, 'get_params') and hasattr(network_type, 'apply'))
    ):
        return network_type

    # 3. Handle Types/Classes
    if isinstance(network_type, type):
        # It is a subclass of GeneralNet (e.g. user imported RBM manually)
        if issubclass(network_type, GeneralNet):
            return network_type(input_shape=input_shape, backend=backend, dtype=dtype,
                                param_dtype=param_dtype, seed=seed, **kwargs)

    # It is a Flax Module (Auto-Wrap Logic).
    # flax is imported here rather than at module level, and its absence is
    # tolerated: in that case the input simply cannot be a Flax module.
    is_flax         = False
    flax_instance   = None
    try:
        import flax.linen as nn
    except ImportError:
        nn = None

    if nn is not None:
        if isinstance(network_type, type) and issubclass(network_type, nn.Module):
            is_flax         = True
        elif isinstance(network_type, nn.Module):
            is_flax         = True
            flax_instance   = network_type

    if is_flax:
        # Lazy import the interface wrapper
        from .net_impl.interface_net_flax import FlaxInterface

        # We pass only module-constructor kwargs to FlaxInterface.
        all_kwargs = kwargs.copy()
        for reserved in ('input_shape', 'backend', 'dtype', 'param_dtype'):
            all_kwargs.pop(reserved, None)  # Remove if present in kwargs

        net_module = flax_instance if flax_instance is not None else network_type
        return FlaxInterface(
            net_module          = net_module,
            net_kwargs          = all_kwargs,
            input_shape         = input_shape,
            backend             = backend,
            dtype               = dtype,
            seed                = seed,
            input_activation    = kwargs.get('input_activation', None)
        )

    # Handle generic Callables (Factories)
    if callable(network_type):
        return CallableNet(callable_fun=network_type, input_shape=input_shape,
                           backend=backend, dtype=dtype, **kwargs)

    raise ValueError(f"Unknown network type: {type(network_type)}")
###################################################################### #! END OF FILE ######################################################################