Source code for general_python.ml.training_phases

"""
Learning phase framework for Neural Quantum State training.

This module implements a multi-phase training system for NQS, allowing:
- Phase transitions with configurable parameters
- Phase-specific callbacks and hooks
- Adaptive learning rates per phase
- Regularization scheduling per phase
- Progress tracking and reporting

Learning phases represent different stages of optimization:

1. **Pre-training**: Initialize network with simple loss, high learning rate
2. **Main Optimization**: Full Hamiltonian, adaptive learning rate
3. **Refinement**: Fine-tune observables, low learning rate, high regularization

Quick Start
-----------
**Using Presets:**

>>> from general_python.ml.training_phases import create_phase_schedulers
>>> lr_sched, reg_sched = create_phase_schedulers('default')
>>> # Pass to NQSTrainer: phases=(lr_sched, reg_sched)

**Creating Custom Phases:**

>>> from general_python.ml.training_phases import LearningPhase, PhaseType, PhaseScheduler
>>> 
>>> my_phases = [
...     LearningPhase(
...         name="warmup", epochs=50,
...         lr=0.1, lr_schedule="exponential", lr_kwargs={'lr_decay': 0.05},
...         reg=0.01
...     ),
...     LearningPhase(
...         name="main", epochs=300,
...         lr=0.02, lr_schedule="adaptive", lr_kwargs={'patience': 20, 'lr_decay': 0.5},
...         reg=0.001
...     ),
... ]
>>> lr_sched = PhaseScheduler(my_phases, param_type='lr')
>>> reg_sched = PhaseScheduler(my_phases, param_type='reg')

Available Scheduler Types
-------------------------
- ``'constant'``: Fixed value
- ``'exponential'``: Exponential decay: lr * exp(-lr_decay * local_epoch)
- ``'step'``: Step decay: lr * lr_decay^floor(local_epoch / step_size)
- ``'cosine'``: Cosine annealing to min_lr
- ``'linear'``: Linear decay to min_lr
- ``'adaptive'``: ReduceLROnPlateau (requires loss)

Available Presets
-----------------
- ``'default'``: 3-phase (pre_training: 50, main: 200, refinement: 100)
- ``'kitaev'``: Specialized for frustrated spin systems (pre: 100, main: 300, fine: 150)

----------------------------------------
File        : NQS/src/learning_phases.py
Author      : Maksymilian Kliczkowski
Email       : maksymilian.kliczkowski@pwr.edu.pl
Date        : November 1, 2025
----------------------------------------
"""

from dataclasses import dataclass, field
from typing import Optional, Callable, List, Dict, Any
from enum import Enum, auto
import numpy as np

try:
    from .schedulers import choose_scheduler, Parameters, SchedulerType
except ImportError as e:
    raise ImportError("Failed to import schedulers module. Ensure general_python package is correctly installed.") from e

from dataclasses import dataclass, field
from typing import Optional, Callable, List, Dict, Any, Union, Tuple
from enum import Enum, auto
import numpy as np

# ----------------------------------------------------------------------
# Core Data Structures
# ----------------------------------------------------------------------

class PhaseType(Enum):
    """
    Semantic label attached to a training phase.

    The label only categorizes a phase; the numeric values carry no
    meaning beyond identity (they match what ``auto()`` would assign).
    """

    # Network initialization stage (simple loss, high learning rate).
    PRE_TRAINING = 1
    # Main optimization stage (full Hamiltonian, adaptive learning rate).
    MAIN = 2
    # Fine-tuning stage (low learning rate, high regularization).
    REFINEMENT = 3
    # Presumably for user-defined phases outside the three categories above.
    CUSTOM = 4
@dataclass
class LearningPhase:
    """
    Settings for one stage of a multi-phase training run.

    A phase bundles a learning-rate schedule and a regularization schedule
    that stay active for ``epochs`` epochs. A ``PhaseScheduler`` walks
    through a list of phases in order.

    Attributes
    ----------
    name : str
        Human-readable phase identifier (e.g., 'warmup', 'main', 'fine').
    epochs : int
        Number of epochs this phase lasts.
    phase_type : PhaseType
        Semantic type (PRE_TRAINING, MAIN, REFINEMENT, CUSTOM).
    lr : float
        Learning rate at the start of the phase.
    lr_schedule : str
        Name of the learning-rate scheduler. Options:

        - 'constant': Fixed lr throughout phase
        - 'exponential': lr * exp(-lr_decay * local_epoch)
        - 'step': lr * lr_decay^floor(local_epoch/step_size)
        - 'cosine': Cosine annealing from lr to min_lr
        - 'linear': Linear decay from lr to min_lr
        - 'adaptive': ReduceLROnPlateau (requires loss)
    lr_kwargs : Dict[str, Any]
        Extra keyword arguments forwarded to the LR scheduler. Common keys:

        - 'lr_decay': Decay rate (exponential, step, adaptive)
        - 'step_size': Steps between decays (step scheduler)
        - 'min_lr': Minimum LR (cosine, linear, adaptive)
        - 'patience': Epochs before reduction (adaptive)
        - 'min_delta': Minimum improvement threshold (adaptive)
    reg : float
        Regularization (diagonal shift) at the start of the phase.
    reg_schedule : str
        Name of the regularization scheduler; same options as ``lr_schedule``.
    reg_kwargs : Dict[str, Any]
        Extra keyword arguments forwarded to the regularization scheduler.
    loss_type : str
        Loss function type (default: 'energy').
    beta_penalty : float
        Penalty coefficient for excited state targeting.
    on_phase_start : Callable, optional
        Zero-argument callback executed when the phase begins.
    on_phase_end : Callable, optional
        Zero-argument callback executed when the phase ends.

    Examples
    --------
    >>> # Exponential decay warmup
    >>> warmup = LearningPhase(
    ...     name='warmup', epochs=50,
    ...     lr=0.1, lr_schedule='exponential', lr_kwargs={'lr_decay': 0.05},
    ...     reg=0.01
    ... )
    >>>
    >>> # Adaptive main phase (ReduceLROnPlateau)
    >>> main = LearningPhase(
    ...     name='main', epochs=300,
    ...     lr=0.02, lr_schedule='adaptive',
    ...     lr_kwargs={'patience': 20, 'lr_decay': 0.5, 'min_lr': 1e-4},
    ...     reg=0.001
    ... )
    >>>
    >>> # Cosine annealing refinement
    >>> refine = LearningPhase(
    ...     name='fine', epochs=100,
    ...     lr=0.01, lr_schedule='cosine', lr_kwargs={'min_lr': 1e-5},
    ...     reg=0.005
    ... )
    """

    # --- Identity -----------------------------------------------------
    name: str = "phase"
    epochs: int = 100
    phase_type: PhaseType = PhaseType.MAIN

    # --- Learning rate ------------------------------------------------
    # Starting value; the scheduler named below evolves it within the phase.
    lr: float = 1e-2
    lr_schedule: str = "constant"
    # Forwarded to the scheduler factory, e.g. {'lr_decay': 0.9, 'step_size': 10}.
    lr_kwargs: Dict[str, Any] = field(default_factory=dict)

    # --- Regularization -----------------------------------------------
    # Starting value; scheduled the same way as the learning rate.
    reg: float = 1e-3
    reg_schedule: str = "constant"
    # Forwarded to the scheduler factory.
    reg_kwargs: Dict[str, Any] = field(default_factory=dict)

    # --- Physics / loss -----------------------------------------------
    loss_type: str = "energy"
    beta_penalty: float = 0.0

    # --- Hooks --------------------------------------------------------
    on_phase_start: Optional[Callable] = None
    on_phase_end: Optional[Callable] = None
# ----------------------------------------------------------------------
# Presets
# ----------------------------------------------------------------------

def _get_presets() -> Dict[str, List[LearningPhase]]:
    """
    Build the built-in phase presets.

    Every call constructs a new dictionary holding newly created
    ``LearningPhase`` instances, so callers may mutate the result without
    affecting other callers.

    Returns
    -------
    Dict[str, List[LearningPhase]]
        Mapping from preset name (``'default'``, ``'kitaev'``) to the
        ordered list of phases for that preset.
    """
    return {
        "default": [
            # 1. Pre-training: Initialize network with simple loss, high learning rate
            LearningPhase(
                name="pre_training",
                epochs=50,
                phase_type=PhaseType.PRE_TRAINING,
                lr=1e-1,
                lr_schedule="exponential",
                lr_kwargs={'lr_decay': 1e-2},  # decay rate
                reg=5e-2,
                reg_schedule="constant",
            ),
            # 2. Main Optimization: Full Hamiltonian, adaptive learning rate
            LearningPhase(
                name="main",
                epochs=200,
                phase_type=PhaseType.MAIN,
                lr=3e-2,
                lr_schedule="adaptive",
                lr_kwargs={'patience': 20, 'lr_decay': 0.5},
                reg=1e-3,
                reg_schedule="constant",
            ),
            # 3. Refinement: Fine-tune observables with low learning rate
            LearningPhase(
                name="refinement",
                epochs=100,
                phase_type=PhaseType.REFINEMENT,
                lr=1e-2,
                lr_schedule="cosine",
                lr_kwargs={'min_lr': 1e-5},
                reg=5e-3,
                reg_schedule="constant",
            ),
        ],
        # Example of a specialized preset for frustrated systems
        "kitaev": [
            LearningPhase(
                name="pre",
                # Previously left to the dataclass default (100); spelled out
                # so the preset does not silently change with the default.
                epochs=100,
                # Previously omitted, which made this phase default to
                # PhaseType.MAIN even though it is the pre-training stage.
                phase_type=PhaseType.PRE_TRAINING,
                lr=6e-2,
                lr_schedule="step",
                # NOTE(review): step_size (150) exceeds epochs (100); if the
                # factory follows lr * lr_decay^floor(epoch/step_size) the LR
                # never steps down within this phase -- confirm intent.
                lr_kwargs={'step_size': 150, 'lr_decay': 0.5},
                reg=5e-2,
            ),
            LearningPhase(
                name="main",
                epochs=300,
                phase_type=PhaseType.MAIN,
                lr=3e-2,
                lr_schedule="adaptive",
                lr_kwargs={'patience': 100, 'min_delta': 1e-4},
                reg=1e-3,
            ),
            # In refinement, we might want to increase regularization (annealing)
            # NOTE(review): reg_kwargs is empty, so the 'linear' schedule runs
            # with the factory's defaults -- verify it produces the intended
            # regularization trajectory.
            LearningPhase(
                name="fine",
                epochs=150,
                phase_type=PhaseType.REFINEMENT,
                lr=5e-3,
                lr_schedule="cosine",
                lr_kwargs={'min_lr': 1e-4, 'lr_decay': 1e-2},
                reg=1e-3,
                reg_schedule="linear",
                reg_kwargs={},
            ),
        ],
    }

# ----------------------------------------------------------------------
# The Smart Scheduler (The Orchestrator)
# ----------------------------------------------------------------------
class PhaseScheduler:
    """
    Manages transitions between training phases.

    The PhaseScheduler orchestrates multi-phase training by:

    1. Tracking the current phase based on global epoch count
    2. Instantiating appropriate low-level schedulers for each phase
    3. Firing callbacks on phase transitions
    4. Returning scheduled values via __call__

    Parameters
    ----------
    phases : List[LearningPhase]
        Ordered, non-empty list of training phases to execute.
    param_type : str, default='lr'
        Which parameter to schedule. ``'lr'`` and ``'dt'`` read the phase's
        learning-rate settings; any other value reads the regularization
        settings.
    logger : Logger, optional
        Logger for phase transition messages. Its ``info`` method is called
        with an extra ``color`` keyword argument.

    Attributes
    ----------
    current_phase : LearningPhase
        Currently active phase.
    history : List[float]
        All scheduled values returned so far, in call order.

    Raises
    ------
    ValueError
        If ``phases`` is empty.

    Notes
    -----
    - ``on_phase_start`` is fired only when a phase is entered through a
      transition; it is not fired for the first phase at construction.
    - ``on_phase_end`` is never fired for the final phase, because the
      scheduler never advances past it. Once its epochs are used up, the
      final phase's scheduler keeps being queried with a growing local epoch.

    Examples
    --------
    >>> from general_python.ml.training_phases import LearningPhase, PhaseScheduler
    >>>
    >>> phases = [
    ...     LearningPhase(name='warmup', epochs=50, lr=0.1, lr_schedule='exponential',
    ...                   lr_kwargs={'lr_decay': 0.05}),
    ...     LearningPhase(name='main', epochs=200, lr=0.02, lr_schedule='constant'),
    ... ]
    >>>
    >>> lr_scheduler = PhaseScheduler(phases, param_type='lr')
    >>> reg_scheduler = PhaseScheduler(phases, param_type='reg')
    >>>
    >>> # Use in training loop
    >>> for epoch in range(250):
    ...     lr = lr_scheduler(epoch, loss=current_loss)  # Auto phase transition
    ...     reg = reg_scheduler(epoch, loss=current_loss)
    """

    def __init__(self, phases: "List[LearningPhase]", param_type: str = 'lr', logger=None):
        # Fail early with a clear message; otherwise the first access to
        # current_phase would die with an opaque IndexError on phases[-1].
        if not phases:
            raise ValueError("PhaseScheduler requires at least one LearningPhase.")

        self.phases = phases
        self.param_type = param_type  # 'lr' or 'reg'
        self.logger = logger
        self.history = []

        # State
        self._current_phase_idx = 0
        self._epochs_completed_in_prev_phases = 0

        # The active engine (instance of Parameters from schedulers.py)
        self._active_engine: "Optional[Parameters]" = None
        self._init_current_phase_engine()

    @property
    def current_phase(self) -> "LearningPhase":
        """Currently active learning phase.

        If all configured phases are exhausted, the final phase is returned
        so callers can continue querying terminal settings.
        """
        if self._current_phase_idx >= len(self.phases):
            return self.phases[-1]
        return self.phases[self._current_phase_idx]

    # ---------

    def _init_current_phase_engine(self):
        """Create the low-level scheduler for the current phase via the factory."""
        phase = self.current_phase

        # 'dt' is treated as an alias of 'lr'; everything else is regularization.
        if self.param_type in ('lr', 'dt'):
            init_val = phase.lr
            sched_type = phase.lr_schedule
            kwargs = phase.lr_kwargs
        else:
            init_val = phase.reg
            sched_type = phase.reg_schedule
            kwargs = phase.reg_kwargs

        # Logging
        if self.logger:
            self.logger.info(
                f"Phase '{phase.name}': Init {self.param_type.upper()} scheduler "
                f"'{sched_type}', val={init_val:.2e}, epochs={phase.epochs}",
                color='cyan',
            )

        # Instantiate via the factory
        self._active_engine = choose_scheduler(
            scheduler_type=sched_type,
            initial_lr=init_val,  # Factory expects 'initial_lr', works for Reg too
            max_epochs=phase.epochs,
            logger=self.logger,
            **kwargs
        )

    def _update_phase_state(self, global_epoch: int):
        """
        Advance to the phase that contains ``global_epoch`` and return the
        epoch index relative to the start of that phase.

        Several phases may be skipped in one call; for each transition the
        outgoing phase's ``on_phase_end`` and the incoming phase's
        ``on_phase_start`` are fired and a new engine is created. The
        scheduler never moves backwards and never advances past the final
        phase. The returned local epoch is clamped to be non-negative.
        """
        local_epoch = global_epoch - self._epochs_completed_in_prev_phases

        # Check transition condition
        while (local_epoch >= self.current_phase.epochs
               and self._current_phase_idx < len(self.phases) - 1):
            # Fire end callback
            if self.current_phase.on_phase_end:
                self.current_phase.on_phase_end()

            # Advance
            self._epochs_completed_in_prev_phases += self.current_phase.epochs
            self._current_phase_idx += 1
            local_epoch = global_epoch - self._epochs_completed_in_prev_phases

            # Fire start callback
            if self.current_phase.on_phase_start:
                self.current_phase.on_phase_start()

            # Re-initialize the engine for the new phase
            self._init_current_phase_engine()

        return max(0, local_epoch)

    def __call__(self, global_epoch: int, loss: Optional[float] = None) -> float:
        """
        Return the scheduled value for ``global_epoch``.

        Parameters
        ----------
        global_epoch : int
            Epoch counted from the start of the whole training run.
        loss : float, optional
            Current loss, forwarded to the active scheduler as ``_metric``.

        Returns
        -------
        float
            Value produced by the active phase's scheduler; it is also
            appended to ``history``.
        """
        # Switch phases first so the value comes from the right engine.
        local_epoch = self._update_phase_state(global_epoch)
        val = self._active_engine(local_epoch, _metric=loss)
        self.history.append(val)
        return val
# ---------------------------------------------------------------------- # Factory # ----------------------------------------------------------------------
def create_phase_schedulers(preset: str = 'default', logger=None):
    """
    Build the learning-rate and regularization schedulers for a named preset.

    Both schedulers are created from the same list of phases, so they
    switch phases at the same global epochs.

    Parameters
    ----------
    preset : str, default='default'
        Preset name. Available:

        - 'default': 3-phase training (350 total epochs)
        - 'kitaev': Specialized for frustrated systems (550 total epochs)
    logger : Logger, optional
        Logger handed to both schedulers.

    Returns
    -------
    Tuple[PhaseScheduler, PhaseScheduler]
        (lr_scheduler, reg_scheduler) tuple.

    Raises
    ------
    ValueError
        If preset name is not recognized.

    Examples
    --------
    >>> lr_sched, reg_sched = create_phase_schedulers('default')
    >>>
    >>> # Pass to NQSTrainer
    >>> trainer = NQSTrainer(nqs, phases=(lr_sched, reg_sched))
    >>>
    >>> # Or use preset string directly
    >>> trainer = NQSTrainer(nqs, phases='default')  # Equivalent
    """
    available = _get_presets()
    selected = available.get(preset)

    # Preset values are always lists, so None can only mean "unknown name".
    if selected is None:
        raise ValueError(f"Unknown preset '{preset}'. Available: {list(available.keys())}")

    # Learning-rate scheduler first, then regularization, sharing one phase list.
    return tuple(PhaseScheduler(selected, kind, logger) for kind in ('lr', 'reg'))
# Module-level snapshot of the built-in presets, built once at import time.
# create_phase_schedulers() calls _get_presets() again and therefore works on
# newly created LearningPhase objects, not on the ones stored here.
PRESETS = _get_presets()

# -----------------------------------
#! End of file
# -----------------------------------