Source code for general_python.ml.schedulers

'''

Scheduler implementations for machine learning training.
It includes various learning rate schedulers and an early stopping mechanism.

Namely, it provides:
- ConstantScheduler
- ExponentialDecayScheduler
- LinearScheduler
- StepDecayScheduler
- CosineAnnealingScheduler
- AdaptiveScheduler (ReduceLROnPlateau)
- EarlyStopping (monitors a metric and signals when training should stop)

To use them, either create scheduler instances directly or use the
`choose_scheduler` factory function.

>>> # Example: exponential decay, lr = initial_lr * lr_decay ** epoch
>>> from general_python.ml.schedulers import choose_scheduler
>>> scheduler = choose_scheduler('exponential', initial_lr=0.01, max_epochs=100, lr_decay=0.95)
>>> for epoch in range(10):
...     lr = scheduler(epoch)
...     print(f"Epoch {epoch}: Learning Rate = {lr:.6f}")

---------------------------------------------------------------
file    : general_python/ml/schedulers.py
author  : Maksymilian Kliczkowski
email   : maksymilian.kliczkowski@pwr.edu.pl
---------------------------------------------------------------
'''

import enum
import numpy as np
from abc import ABC, abstractmethod
from typing import List, Optional, Union, Any, TYPE_CHECKING

if TYPE_CHECKING:
    from ..common.flog import Logger

# ##############################################################################
#! Constants
# ##############################################################################

_INF = float('inf')

# ##############################################################################
#! Early Stopping
# ##############################################################################

class BaseSchedulerLogger(ABC):
    """
    Abstract Base Class providing logging capabilities to schedulers.

    Messages go to an optional logger object that exposes ``say``; ``colorize``
    and a ``LEVELS_R`` name->level mapping are used when that object has them.
    Without such a logger, only messages whose level is not debug/info are
    printed to stdout.
    """

    def __init__(self, logger: Optional["Logger"]):
        super().__init__()
        self._logger = logger

    def _log(self, message: str, log: Union[int, str] = 'info', lvl: int = 0, color: str = "white", **kwargs):
        """
        Internal logging helper.

        Args:
            message: Text to log; it is prefixed with ``[<ClassName>]``.
            log: Log level, either a string name (e.g. ``'info'``) or an int.
            lvl: Forwarded to ``logger.say`` as ``lvl``.
            color: Passed to ``logger.colorize`` when truthy and supported.
            **kwargs: Extra keyword arguments forwarded to ``logger.say``.
        """
        # Format the message with class name
        full_msg = f"[{self.__class__.__name__}] {message}"
        logger = self._logger

        if logger is not None and hasattr(logger, "say"):
            if color and hasattr(logger, "colorize"):
                full_msg = logger.colorize(full_msg, color)

            # ``Logger`` is imported only under TYPE_CHECKING, so the class name
            # does not exist at runtime. Read the level map from the logger
            # object itself (class attributes are visible through the instance).
            levels = getattr(logger, "LEVELS_R", None)
            if isinstance(log, str) and levels is not None:
                log_val = levels.get(log, 20)  # Default to INFO
            else:
                log_val = log
            logger.say(full_msg, lvl=lvl, log=log_val, **kwargs)
        elif log not in ['debug', 'info', 10, 20]:
            # Fallback if no logger provided: surface only non-debug/info messages.
            print(f"{str(log).upper()}: {full_msg}")

    @property
    def logger(self) -> Optional['Logger']:
        """Logger used for scheduler diagnostics, if one is configured."""
        return self._logger

    @logger.setter
    def logger(self, logger: 'Logger'):
        """Set the logger used for scheduler diagnostics."""
        self._logger = logger
# ##############################################################################
class EarlyStopping(BaseSchedulerLogger):
    """
    Monitors a metric and determines if training should stop.

    Lower metric values count as better. A stop is requested when the real
    part of the metric has not dropped by more than ``min_delta`` below the
    best value for ``patience`` consecutive calls. A stop is also requested
    immediately when the metric (real or imaginary part) is NaN or Inf. A
    falsy ``patience`` (0 or None) disables the patience check only.
    """

    def __init__(self, patience: int = 0, min_delta: float = 1e-3, logger: Optional['Logger'] = None):
        """
        Args:
            patience: Number of non-improving calls tolerated before stopping.
                ``0`` or ``None`` disables patience-based stopping.
            min_delta: Minimum decrease of the metric counted as improvement.
            logger: Optional logger for diagnostics.

        Raises:
            ValueError: If ``patience`` or ``min_delta`` is negative.
        """
        super().__init__(logger=logger)
        if patience is not None and patience < 0:
            raise ValueError("Patience must be non-negative.")
        if min_delta < 0.0:
            raise ValueError("min_delta must be non-negative.")
        self._patience = patience
        self._min_delta = min_delta
        self._best_metric = _INF
        self._epoch_since_best = 0
        self._stop_training = False

    @staticmethod
    def _split_metric(_metric) -> tuple:
        """Return ``(real, imag)`` of a numeric metric as floats; raise TypeError otherwise."""
        try:
            # np.iscomplexobj also recognises numpy complex64 scalars and 0-d
            # complex arrays, which are not instances of Python ``complex``;
            # plain float() would drop their imaginary part unchecked.
            if np.iscomplexobj(_metric):
                return float(np.real(_metric)), float(np.imag(_metric))
            return float(_metric), 0.0
        except (ValueError, TypeError) as err:
            raise TypeError("Metric must be numeric.") from err

    def __call__(self, _metric: Union[float, complex, np.number]) -> bool:
        """
        Feed a new metric value to the monitor.

        Args:
            _metric: The metric value (e.g. loss). Real part used if complex;
                the imaginary part is only checked for NaN/Inf.

        Returns:
            True if training should stop, False otherwise.

        Raises:
            TypeError: If the metric cannot be interpreted as a number.
        """
        val_r, val_i = self._split_metric(_metric)

        # Check NaN/Inf in either component
        if np.isnan(val_r) or np.isnan(val_i) or np.isinf(val_r) or np.isinf(val_i):
            self._log("Received NaN or Inf metric. Stopping.", log='error', color='red')
            return True

        # Check Disabled
        if not self._patience:
            return False

        if val_r < (self._best_metric - self._min_delta):
            self._log(f"Metric improved to {val_r:.4e}.", log='debug', lvl=1)
            self._best_metric = val_r
            self._epoch_since_best = 0
            self._stop_training = False
        else:
            self._epoch_since_best += 1
            self._log(f"No improvement for {self._epoch_since_best} epoch(s). Best: {self._best_metric:.4e}", log='debug', lvl=1)
            if self._epoch_since_best >= self._patience:
                self._log(f"Patience ({self._patience}) exceeded. Stopping.", log='info', color='yellow')
                self._stop_training = True

        return self._stop_training

    @classmethod
    def from_kwargs(cls, **kwargs):
        """Create an :class:`EarlyStopping` instance from common keyword names.

        Accepted spellings, in priority order:
        ``patience`` / ``early_stopping_patience`` / ``es_patience`` / ``early_patience``
        (default 0) and ``min_delta`` / ``early_stopping_min_delta`` /
        ``es_min_delta`` / ``early_min_delta`` (default 1e-3).
        """
        patience = kwargs.get(
            "patience",
            kwargs.get(
                "early_stopping_patience",
                kwargs.get("es_patience", kwargs.get("early_patience", 0)),
            ),
        )
        min_delta = kwargs.get(
            "min_delta",
            kwargs.get(
                "early_stopping_min_delta",
                kwargs.get("es_min_delta", kwargs.get("early_min_delta", 1e-3)),
            ),
        )
        logger = kwargs.get("logger", None)
        return cls(patience=patience, min_delta=min_delta, logger=logger)

    def reset(self):
        """Reset the stored best metric and patience counter."""
        self._best_metric = _INF
        self._epoch_since_best = 0
        self._stop_training = False

    @property
    def best_metric(self) -> float:
        """Best metric observed since initialization or last reset."""
        return self._best_metric
# ############################################################################## #! Base Parameters & Schedulers # ##############################################################################
class SchedulerType(enum.Enum):
    """Supported learning-rate scheduler families.

    Both ``str()`` and ``repr()`` of a member give its lowercase name
    (e.g. ``'cosine'``).
    """

    CONSTANT    = 0
    EXPONENTIAL = 1
    STEP        = 2
    COSINE      = 3
    ADAPTIVE    = 4
    LINEAR      = 5

    def __str__(self):
        member_name = self.name
        return member_name.lower()

    def __repr__(self):
        # Deliberately identical to str() so members print as plain names.
        return str(self)
class Parameters(BaseSchedulerLogger, ABC):
    """
    Base class for stateful learning-rate schedules.

    Concrete schedulers implement :meth:`__call__` and pass each computed
    learning rate through :meth:`_update_and_log_lr`. That method applies the
    optional lower clamp, stores the result as the current LR and appends it
    to :attr:`history`. An optional :class:`EarlyStopping` monitor can be
    attached and queried through :meth:`check_stop`.
    """

    def __init__(self,
                 initial_lr : float,
                 max_epochs : int,
                 lr_decay   : float,
                 lr_clamp   : Optional[float]         = None,
                 logger     : Optional['Logger']      = None,
                 es         : Optional[EarlyStopping] = None):
        '''
        Base class for learning rate schedulers.

        Args:
            initial_lr: Initial learning rate; must be positive.
            max_epochs: Maximum number of epochs; must be positive.
            lr_decay: Decay rate (meaning depends on scheduler type).
            lr_clamp: Minimum learning rate clamp (``None`` disables clamping).
            logger: Optional logger for logging.
            es: Optional EarlyStopping instance.

        Raises:
            ValueError: If ``initial_lr`` or ``max_epochs`` is not positive.
        '''
        super().__init__(logger=logger)
        checks = ((initial_lr, "Initial LR must be positive."),
                  (max_epochs, "Max epochs must be positive."))
        for value, message in checks:
            if value <= 0:
                raise ValueError(message)

        # Configuration
        self._initial_lr     = initial_lr
        self._max_epochs     = max_epochs
        self._lr_decay       = lr_decay
        self._lr_clamp       = lr_clamp
        self._early_stopping = es
        # Running state
        self._lr             = initial_lr
        self._lr_history     = []
        self._typek          = None  # concrete subclasses assign their SchedulerType

    @abstractmethod
    def __call__(self, _epoch: int, _metric: Optional[Any] = None) -> float:
        """Calculate LR for the epoch."""
        pass

    def _update_and_log_lr(self, new_lr: float) -> float:
        """Clamp ``new_lr`` from below (if configured), store it, record it and return it."""
        floor = self._lr_clamp
        self._lr = new_lr if floor is None else np.maximum(new_lr, floor)
        self._lr_history.append(self._lr)
        return self._lr

    #! Common Properties

    @property
    def lr(self) -> float:
        """Most recently emitted learning rate."""
        return self._lr

    @property
    def history(self) -> List[float]:
        """Learning-rate values emitted so far."""
        return self._lr_history

    @property
    def early_stopping(self) -> Optional[EarlyStopping]:
        """Attached early-stopping monitor, if configured."""
        return self._early_stopping

    #! ES Proxies

    def set_early_stopping(self, patience: int, min_delta: float = 1e-3):
        """Attach a new early-stopping monitor to this scheduler."""
        self._early_stopping = EarlyStopping(patience=patience, min_delta=min_delta, logger=self._logger)

    def check_stop(self, _metric) -> bool:
        """Return whether the attached early-stopping monitor requests stop."""
        monitor = self._early_stopping
        return monitor(_metric) if monitor else False
# ############################################################################## #! CONCRETE SCHEDULERS # ##############################################################################
class ConstantScheduler(Parameters):
    """Scheduler that always returns the initial learning rate (subject to ``lr_clamp``)."""

    def __init__(self, initial_lr: float, max_epochs: int, lr_clamp=None, logger=None, es=None, **kwargs):
        # A constant schedule never decays; 1.0 is recorded as the neutral decay factor.
        super().__init__(
            initial_lr,
            max_epochs,
            lr_decay=1.0,
            lr_clamp=lr_clamp,
            logger=logger,
            es=es,
        )
        self._typek = SchedulerType.CONSTANT

    def __call__(self, _epoch: int, _metric=None) -> float:
        """Emit ``initial_lr``; the epoch and metric are ignored."""
        constant_lr = self._initial_lr
        return self._update_and_log_lr(constant_lr)
# ------------------------------------------------------------------------------
class ExponentialDecayScheduler(Parameters):
    """
    Multiplicative exponential decay: lr = initial_lr * (gamma ^ epoch)

    This follows PyTorch's ExponentialLR convention where:
    - gamma = 0.99 means 1% decay per epoch
    - gamma = 0.999 means 0.1% decay per epoch

    For the old exp(-rate * epoch) behavior, use gamma = exp(-rate).
    Here ``gamma`` is the ``lr_decay`` argument; negative epochs are treated as 0.
    """

    def __init__(self, initial_lr: float, max_epochs: int, lr_decay: float = 0.99, lr_clamp=None, logger=None, es=None, **kwargs):
        super().__init__(
            initial_lr=initial_lr,
            max_epochs=max_epochs,
            lr_decay=lr_decay,
            lr_clamp=lr_clamp,
            logger=logger,
            es=es,
        )
        self._typek = SchedulerType.EXPONENTIAL

    def __call__(self, _epoch: int, _metric=None) -> float:
        """Return ``initial_lr * lr_decay ** max(0, epoch)`` (clamped if configured)."""
        elapsed = max(0, _epoch)
        decay_factor = np.power(self._lr_decay, elapsed)
        return self._update_and_log_lr(self._initial_lr * decay_factor)
# ------------------------------------------------------------------------------
class LinearScheduler(Parameters):
    """
    Linearly decays LR from initial_lr to min_lr (default 0) over max_epochs.

    Epochs outside ``[0, max_epochs]`` are clipped to that range, so the LR
    stays at ``initial_lr`` before the start and at ``min_lr`` afterwards.
    """

    def __init__(self, initial_lr: float, max_epochs: int, min_lr: float = 0.0, lr_clamp=None, logger=None, es=None, **kwargs):
        # lr_decay is unused by this schedule; 1.0 is passed as a placeholder.
        super().__init__(
            initial_lr,
            max_epochs,
            lr_decay=1.0,
            lr_clamp=lr_clamp,
            logger=logger,
            es=es,
        )
        self._min_lr = min_lr
        self._typek = SchedulerType.LINEAR

    def __call__(self, _epoch: int, _metric=None) -> float:
        """Return the linear interpolation between ``initial_lr`` and ``min_lr``."""
        progress = max(0, min(_epoch, self._max_epochs)) / self._max_epochs
        start_lr, end_lr = self._initial_lr, self._min_lr
        return self._update_and_log_lr((1 - progress) * start_lr + progress * end_lr)
# ------------------------------------------------------------------------------
class StepDecayScheduler(Parameters):
    """
    lr = initial_lr * decay_factor ^ floor(epoch / step_size)

    ``decay_factor`` is the ``lr_decay`` argument; negative epochs are treated as 0.
    """

    def __init__(self, initial_lr: float, max_epochs: int, lr_decay: float, step_size: int, lr_clamp=None, logger=None, es=None, **kwargs):
        super().__init__(
            initial_lr=initial_lr,
            max_epochs=max_epochs,
            lr_decay=lr_decay,
            lr_clamp=lr_clamp,
            logger=logger,
            es=es,
        )
        if not step_size > 0:
            raise ValueError("Step size must be positive.")
        self.step_size = step_size
        self._typek = SchedulerType.STEP

    def __call__(self, _epoch: int, _metric=None) -> float:
        """Decay the LR by ``lr_decay`` once every ``step_size`` epochs."""
        completed_steps = np.floor(max(0, _epoch) / self.step_size)
        scale = np.power(self._lr_decay, completed_steps)
        return self._update_and_log_lr(self._initial_lr * scale)
# ------------------------------------------------------------------------------
class CosineAnnealingScheduler(Parameters):
    """Cosine annealing schedule from ``initial_lr`` to ``min_lr``.

    lr = min_lr + 0.5 * (initial_lr - min_lr) * (1 + cos(pi * epoch / max_epochs)),
    with ``epoch`` clipped to ``[0, max_epochs]``.
    """

    def __init__(self, initial_lr: float, max_epochs: int, min_lr: float = 0.0, lr_clamp=None, logger=None, es=None, **kwargs):
        # lr_decay is unused by this schedule; 0.0 is passed as a placeholder.
        super().__init__(
            initial_lr,
            max_epochs,
            lr_decay=0.0,
            lr_clamp=lr_clamp,
            logger=logger,
            es=es,
        )
        self.min_lr = min_lr
        self._typek = SchedulerType.COSINE

    def __call__(self, _epoch: int, _metric=None) -> float:
        """Return the cosine-annealed LR for ``_epoch``."""
        clipped_epoch = np.clip(_epoch, 0, self._max_epochs)
        # Defensive: the base class already rejects non-positive max_epochs.
        cosine_term = -1.0 if self._max_epochs <= 0 else np.cos(np.pi * clipped_epoch / self._max_epochs)
        half_range = 0.5 * (self._initial_lr - self.min_lr)
        return self._update_and_log_lr(self.min_lr + half_range * (1.0 + cosine_term))
# ------------------------------------------------------------------------------
class AdaptiveScheduler(Parameters):
    """
    ReduceLROnPlateau logic.

    Lower metric values count as better. Each call either records an
    improvement (the real part of the metric dropped by more than
    ``min_delta`` below the best value) or counts a bad epoch. Once more than
    ``patience`` consecutive bad epochs accumulate, the LR is multiplied by
    ``lr_decay`` (floored at ``min_lr``). Bad epochs are then ignored for the
    next ``cooldown`` calls.
    """

    def __init__(self, initial_lr: float, max_epochs: int, lr_decay: float=0.1, patience: int=100, min_lr: float = 1e-5, cooldown: int = 0, min_delta: float = 1e-4, lr_clamp=None, logger=None, es=None, **kwargs):
        """
        Args:
            initial_lr: Starting learning rate.
            max_epochs: Maximum number of epochs (validated by the base class).
            lr_decay: Multiplicative factor applied on a plateau.
            patience: Bad epochs tolerated before a reduction (reduction when exceeded).
            min_lr: Lower bound for reductions performed by this scheduler.
            cooldown: Calls after a reduction during which bad epochs are ignored.
            min_delta: Minimum decrease of the metric counted as improvement.
            lr_clamp, logger, es: Forwarded to :class:`Parameters`.
        """
        super().__init__(initial_lr, max_epochs, lr_decay, lr_clamp, logger, es)
        self.patience = patience
        self.min_lr = min_lr
        self.cooldown = cooldown
        self.min_delta = min_delta
        self._cooldown_counter = 0
        self._best_metric = _INF
        self._num_bad_epochs = 0
        self._typek = SchedulerType.ADAPTIVE

    def __call__(self, _epoch: int, _metric: Optional[Any] = None) -> float:
        """
        Return the LR for this step, updating plateau state with ``_metric``.

        Raises:
            ValueError: If no metric is supplied (the default ``None`` keeps the
                signature compatible with :meth:`Parameters.__call__`).
        """
        if _metric is None:
            raise ValueError("AdaptiveScheduler requires a metric.")

        # Safe metric extraction; np.iscomplexobj also covers numpy complex64
        # scalars and 0-d complex arrays, which are not Python ``complex``.
        if np.iscomplexobj(_metric):
            metric_val = float(np.real(_metric))
        else:
            metric_val = float(_metric)

        # Non-finite metrics leave the LR and all plateau state untouched.
        if np.isnan(metric_val) or np.isinf(metric_val):
            return self._update_and_log_lr(self._lr)

        # Improvement check runs even during cooldown (as in ReduceLROnPlateau)
        # so an improvement seen while cooling down is not forgotten.
        if metric_val < self._best_metric - self.min_delta:
            self._best_metric = metric_val
            self._num_bad_epochs = 0
        else:
            self._num_bad_epochs += 1

        # Cooldown: ignore bad epochs and never reduce in this call.
        if self._cooldown_counter > 0:
            self._cooldown_counter -= 1
            self._num_bad_epochs = 0
            return self._update_and_log_lr(self._lr)

        # Reduction Logic
        if self._num_bad_epochs > self.patience:
            new_lr = max(self._lr * self._lr_decay, self.min_lr)
            if new_lr < self._lr:
                self._log(f"Reducing LR to {new_lr:.2e} (Plateau).", log='info', color='yellow')
                self._lr = new_lr
                self._cooldown_counter = self.cooldown
            self._num_bad_epochs = 0

        return self._update_and_log_lr(self._lr)

    def reset(self):
        """Reset plateau-tracking state while keeping scheduler configuration."""
        self._cooldown_counter = 0
        self._best_metric = _INF
        self._num_bad_epochs = 0
# ##############################################################################
#! FACTORY
# ##############################################################################

# Maps each SchedulerType to the class to build and the scheduler-specific
# keyword arguments that ``choose_scheduler`` forwards from its **kwargs.
SCHEDULER_CLASS_MAP = {
    SchedulerType.CONSTANT:     {'class': ConstantScheduler,         'args': []},
    SchedulerType.EXPONENTIAL:  {'class': ExponentialDecayScheduler, 'args': ['lr_decay']},
    SchedulerType.STEP:         {'class': StepDecayScheduler,        'args': ['lr_decay', 'step_size']},
    SchedulerType.COSINE:       {'class': CosineAnnealingScheduler,  'args': ['min_lr']},
    SchedulerType.ADAPTIVE:     {'class': AdaptiveScheduler,         'args': ['lr_decay', 'patience', 'min_lr', 'cooldown', 'min_delta']},
    SchedulerType.LINEAR:       {'class': LinearScheduler,           'args': ['min_lr']},
}

# Add lowercase string aliases ("constant", "exponential", ...) that share
# the very same config dicts as their enum keys.
for _sched_type in list(SCHEDULER_CLASS_MAP):
    SCHEDULER_CLASS_MAP[_sched_type.name.lower()] = SCHEDULER_CLASS_MAP[_sched_type]
del _sched_type
def choose_scheduler(scheduler_type : Union[str, SchedulerType, Parameters],
                     initial_lr     : float,
                     max_epochs     : int,
                     logger         : Optional['Logger'] = None,
                     **kwargs) -> Parameters:
    """
    Factory to create a scheduler instance.

    This function can either accept:
    - A string or SchedulerType enum to specify the type of scheduler to create.
    - An existing Parameters instance to reconfigure (logger, ``lr_clamp`` and
      early stopping only); it is returned as-is otherwise.

    Args:
        scheduler_type: Type of scheduler or existing instance. Strings are
            lower-cased before lookup ('constant', 'exponential', 'step',
            'cosine', 'adaptive', 'linear').
        initial_lr: Initial learning rate (overridden by ``lr_initial``).
        max_epochs: Maximum number of epochs (overridden by ``lr_max_epochs``).
        logger: Optional logger for the scheduler.
        **kwargs: Additional arguments specific to the scheduler type
            (aliases in parentheses):
            - lr_decay (lr_decay_rate): Decay rate for exponential/step/adaptive
              schedulers. If omitted, 0.99 is passed for every scheduler that
              takes it, including 'adaptive' (whose own class default is 0.1).
            - step_size (lr_step_size): Step size for step scheduler (required).
            - min_lr (lr_min, lr_final): Minimum learning rate for
              cosine/linear/adaptive schedulers; ``lr_final`` wins over ``lr_min``.
            - patience (lr_patience): Patience for adaptive scheduler.
            - cooldown (lr_cooldown): Cooldown period for adaptive scheduler.
            - min_delta (lr_min_delta): Minimum improvement for adaptive scheduler.
            - lr_clamp: Lower clamp applied to every emitted learning rate.
            - early_stopping_patience (lr_es): Patience for early stopping.
            - early_stopping_min_delta (lr_es_min): Minimum improvement for
              early stopping (defaults to 1e-4 here).

    Returns:
        An instance of Parameters (scheduler).

    Raises:
        ValueError: If the scheduler type is unknown.
        TypeError: If the scheduler constructor rejects the arguments
            (e.g. 'step' without ``step_size``).
    """
    # Normalise alternative keyword names onto canonical ones. All checks use
    # ``is not None`` so that falsy-but-valid values such as 0 / 0.0 are kept.
    kwarg_aliases = (
        ('lr_step_size',  'step_size'),
        ('lr_min',        'min_lr'),
        ('lr_final',      'min_lr'),       # applied after lr_min, so it wins
        ('lr_patience',   'patience'),
        ('lr_min_delta',  'min_delta'),
        ('lr_cooldown',   'cooldown'),
        ('lr_decay_rate', 'lr_decay'),
        ('lr_es',         'early_stopping_patience'),
        ('lr_es_min',     'early_stopping_min_delta'),
    )
    for alias, canonical in kwarg_aliases:
        if kwargs.get(alias, None) is not None:
            kwargs[canonical] = kwargs.pop(alias)

    if kwargs.get('lr_initial', None) is not None:
        initial_lr = kwargs.pop('lr_initial')
    if kwargs.get('lr_max_epochs', None) is not None:
        max_epochs = kwargs.pop('lr_max_epochs')

    # 1. Handle Existing Instance
    if isinstance(scheduler_type, Parameters):
        if logger:
            scheduler_type.logger = logger
        if kwargs.get('lr_clamp') is not None:
            scheduler_type._lr_clamp = kwargs['lr_clamp']
        # Reconfigure ES if args present
        if 'early_stopping_patience' in kwargs:
            scheduler_type.set_early_stopping(kwargs['early_stopping_patience'],
                                              kwargs.get('early_stopping_min_delta', 1e-4))
        return scheduler_type

    # 2. Resolve Type
    key = scheduler_type.lower() if isinstance(scheduler_type, str) else scheduler_type
    config = SCHEDULER_CLASS_MAP.get(key)
    if not config:
        raise ValueError(f"Unknown scheduler type: {scheduler_type}")
    cls = config['class']

    # 3. Setup Early Stopping
    es = None
    if 'early_stopping_patience' in kwargs:
        es = EarlyStopping(kwargs['early_stopping_patience'],
                           kwargs.get('early_stopping_min_delta', 1e-4),
                           logger)

    # 4. Build Args: base args plus only the scheduler-specific ones listed in the map.
    build_kwargs = {
        'initial_lr' : initial_lr,
        'max_epochs' : max_epochs,
        'logger'     : logger,
        'es'         : es,
        'lr_clamp'   : kwargs.get('lr_clamp'),
    }
    for arg in config['args']:
        if arg in kwargs:
            build_kwargs[arg] = kwargs[arg]
        elif arg == 'lr_decay':
            # NOTE: overrides class defaults (e.g. AdaptiveScheduler's 0.1) and
            # supplies the value StepDecayScheduler requires.
            build_kwargs['lr_decay'] = 0.99 # Default

    if logger:
        logger.info(f"Creating scheduler '{scheduler_type}' with args: {build_kwargs}", lvl=3, color='cyan')
    try:
        return cls(**build_kwargs)
    except TypeError as e:
        if logger:
            logger.say(f"Error creating scheduler '{scheduler_type}': {str(e)}", log='error', color='red')
        raise
############################################################################### #! End of file ###############################################################################