'''
Scheduler implementations for machine learning training.
It includes various learning rate schedulers and an early stopping mechanism.
Namely, it provides:
- ConstantScheduler
- ExponentialDecayScheduler
- StepDecayScheduler
- CosineAnnealingScheduler
- AdaptiveScheduler (ReduceLROnPlateau)
- LinearScheduler
For the usage, either create scheduler instances directly or use the
`choose_scheduler` factory function.
>>> # Example: Create an exponential decay scheduler
>>> from general_python.ml.schedulers import choose_scheduler
>>> scheduler = choose_scheduler('exponential', initial_lr=0.01, max_epochs=100, lr_decay=0.1)
>>> for epoch in range(10):
...     lr = scheduler(epoch)
...     print(f"Epoch {epoch}: Learning Rate = {lr:.6f}")
---------------------------------------------------------------
file : general_python/ml/schedulers.py
author : Maksymilian Kliczkowski
email : maksymilian.kliczkowski@pwr.edu.pl
---------------------------------------------------------------
'''
import enum
import numpy as np
from abc import ABC, abstractmethod
from typing import List, Optional, Union, Any, TYPE_CHECKING
if TYPE_CHECKING:
from ..common.flog import Logger
# ##############################################################################
#! Constants
# ##############################################################################
_INF = float('inf')
# ##############################################################################
#! Early Stopping
# ##############################################################################
[docs]
class BaseSchedulerLogger(ABC):
    """
    Abstract base class providing logging capabilities to schedulers.

    The logger is used in a duck-typed way: it is only required to expose a
    ``say(message, lvl=..., log=..., **kwargs)`` method. If it also exposes
    ``colorize(message, color)`` the message is colorized, and if it exposes a
    ``LEVELS_R`` mapping, string log levels are translated through it.
    """

    def __init__(self, logger: Optional["Logger"]):
        """
        Args:
            logger: Logger used for diagnostics, or ``None`` to fall back to
                ``print`` for messages above the info level.
        """
        super().__init__()
        self._logger = logger

    def _log(self, message: str, log: Union[int, str] = 'info', lvl: int = 0, color: str = "white", **kwargs):
        """
        Internal logging helper.

        Args:
            message: Text to emit; it is prefixed with the concrete class name.
            log: Log level, either a string name or an integer value.
            lvl: Forwarded to the logger's ``say`` as ``lvl``.
            color: Color name passed to the logger's ``colorize``, if available.
            **kwargs: Forwarded unchanged to the logger's ``say``.
        """
        # Format the message with class name
        full_msg = f"[{self.__class__.__name__}] {message}"

        logger = self._logger
        if logger is None or not hasattr(logger, "say"):
            # Fallback if no usable logger: only surface messages that are
            # neither debug nor info.
            if log not in ['debug', 'info', 10, 20]:
                print(f"{str(log).upper()}: {full_msg}")
            return

        if color and hasattr(logger, "colorize"):
            full_msg = logger.colorize(full_msg, color)

        # The `Logger` class is imported only under TYPE_CHECKING, so the name
        # does not exist at runtime; read the level table from the logger
        # object itself instead of from the class name.
        levels = getattr(logger, 'LEVELS_R', None)
        if isinstance(log, str) and levels is not None:
            log_val = levels.get(log, 20)  # Default to INFO
        else:
            log_val = log
        logger.say(full_msg, lvl=lvl, log=log_val, **kwargs)

    @property
    def logger(self) -> Optional['Logger']:
        """Logger used for scheduler diagnostics, if one is configured."""
        return self._logger

    @logger.setter
    def logger(self, logger: 'Logger'):
        """Set the logger used for scheduler diagnostics."""
        self._logger = logger
# ##############################################################################
[docs]
class EarlyStopping(BaseSchedulerLogger):
    """
    Monitors a metric and determines if training should stop.

    The metric is treated as "lower is better". A call counts as an
    improvement when the real part of the metric is smaller than the best
    value seen so far by more than ``min_delta``. After ``patience``
    consecutive calls without improvement the monitor requests a stop.
    A falsy ``patience`` (``0`` or ``None``) disables the patience logic.
    """

    def __init__(self, patience: int = 0, min_delta: float = 1e-3, logger: Optional['Logger'] = None):
        """
        Args:
            patience: Number of non-improving calls tolerated; ``0``/``None`` disables.
            min_delta: Minimum decrease of the metric that counts as improvement.
            logger: Optional logger for diagnostics.

        Raises:
            ValueError: If ``patience`` or ``min_delta`` is negative.
        """
        super().__init__(logger=logger)
        if patience is not None and patience < 0:
            raise ValueError("Patience must be non-negative.")
        if min_delta < 0.0:
            raise ValueError("min_delta must be non-negative.")
        self._patience = patience
        self._min_delta = min_delta
        self._best_metric = _INF
        self._epoch_since_best = 0
        self._stop_training = False

    def __call__(self, _metric: Union[float, complex, np.number]) -> bool:
        """
        Record a new metric value and report whether training should stop.

        Args:
            _metric: The metric value (e.g. loss). Real part used if complex.

        Returns:
            ``True`` if the metric is NaN/Inf or patience has been exhausted,
            ``False`` otherwise (always ``False`` for finite metrics when
            patience is disabled).

        Raises:
            TypeError: If the metric cannot be converted to a number.
        """
        # Type check and conversion
        if not isinstance(_metric, (float, complex, np.number)):
            try:
                _metric = float(_metric)
            except (ValueError, TypeError) as err:
                raise TypeError("Metric must be numeric.") from err

        # Split into real/imaginary parts. NumPy complex scalars such as
        # np.complex64 are not subclasses of the builtin `complex`, so they
        # are matched explicitly; otherwise their imaginary part would never
        # be inspected for NaN/Inf.
        if isinstance(_metric, (complex, np.complexfloating)):
            val_r = float(_metric.real)
            val_i = float(_metric.imag)
        else:
            val_r = float(_metric)
            val_i = 0.0

        # Check NaN/Inf
        if not (np.isfinite(val_r) and np.isfinite(val_i)):
            self._log("Received NaN or Inf metric. Stopping.", log='error', color='red')
            return True

        # Check Disabled
        if not self._patience:
            return False

        # Logic
        if val_r < (self._best_metric - self._min_delta):
            self._log(f"Metric improved to {val_r:.4e}.", log='debug', lvl=1)
            self._best_metric = val_r
            self._epoch_since_best = 0
            self._stop_training = False
        else:
            self._epoch_since_best += 1
            self._log(f"No improvement for {self._epoch_since_best} epoch(s). Best: {self._best_metric:.4e}", log='debug', lvl=1)
            if self._epoch_since_best >= self._patience:
                self._log(f"Patience ({self._patience}) exceeded. Stopping.", log='info', color='yellow')
                self._stop_training = True
        return self._stop_training

    @classmethod
    def from_kwargs(cls, **kwargs):
        """
        Create an :class:`EarlyStopping` instance from common keyword names.

        The first key present wins, in this order:
        ``patience`` / ``early_stopping_patience`` / ``es_patience`` /
        ``early_patience`` (default ``0``) and ``min_delta`` /
        ``early_stopping_min_delta`` / ``es_min_delta`` / ``early_min_delta``
        (default ``1e-3``). ``logger`` is taken from ``kwargs`` if present.
        """
        patience_keys = ("patience", "early_stopping_patience", "es_patience", "early_patience")
        delta_keys = ("min_delta", "early_stopping_min_delta", "es_min_delta", "early_min_delta")
        # Presence-based lookup: a key that is present is used even if its value is None.
        patience = next((kwargs[k] for k in patience_keys if k in kwargs), 0)
        min_delta = next((kwargs[k] for k in delta_keys if k in kwargs), 1e-3)
        logger = kwargs.get("logger", None)
        return cls(patience=patience, min_delta=min_delta, logger=logger)

    def reset(self):
        """Reset the stored best metric and patience counter."""
        self._best_metric = _INF
        self._epoch_since_best = 0
        self._stop_training = False

    @property
    def best_metric(self) -> float:
        """Best metric observed since initialization or last reset."""
        return self._best_metric
# ##############################################################################
#! Base Parameters & Schedulers
# ##############################################################################
[docs]
class SchedulerType(enum.Enum):
    """
    Enumeration of the learning-rate scheduler families known to this module.

    Both ``str()`` and ``repr()`` of a member give its lower-case name
    (for example ``'cosine'``).
    """
    CONSTANT = 0
    EXPONENTIAL = 1
    STEP = 2
    COSINE = 3
    ADAPTIVE = 4
    LINEAR = 5

    def __str__(self):
        """Return the member name in lower case."""
        lowered = self.name.lower()
        return lowered

    def __repr__(self):
        """Return the same text as ``str(self)``."""
        return str(self)
[docs]
class Parameters(BaseSchedulerLogger, ABC):
    """
    Common state and helpers shared by all learning-rate schedulers.

    Subclasses implement ``__call__`` to compute the learning rate for an
    epoch and route the result through ``_update_and_log_lr`` so that the
    optional lower clamp is applied and the value is recorded in the history.
    """

    def __init__(self,
                 initial_lr : float,
                 max_epochs : int,
                 lr_decay   : float,
                 lr_clamp   : Optional[float] = None,
                 logger     : Optional['Logger'] = None,
                 es         : Optional[EarlyStopping] = None):
        '''
        Store the shared scheduler configuration.

        Args:
            initial_lr:
                Starting learning rate; must be strictly positive.
            max_epochs:
                Total number of epochs; must be strictly positive.
            lr_decay:
                Decay parameter whose interpretation is left to the subclass.
            lr_clamp:
                Lower bound applied to every emitted learning rate, or None.
            logger:
                Optional logger used for diagnostics.
            es:
                Optional EarlyStopping monitor attached to the scheduler.

        Raises:
            ValueError: If ``initial_lr`` or ``max_epochs`` is not positive.
        '''
        super().__init__(logger=logger)

        if initial_lr <= 0:
            raise ValueError("Initial LR must be positive.")
        if max_epochs <= 0:
            raise ValueError("Max epochs must be positive.")

        # Immutable configuration
        self._initial_lr        = initial_lr
        self._max_epochs        = max_epochs
        self._lr_decay          = lr_decay
        self._lr_clamp          = lr_clamp
        self._early_stopping    = es

        # Mutable running state
        self._lr                = initial_lr
        self._lr_history        = []
        self._typek             = None  # Filled in by the concrete subclass

    @abstractmethod
    def __call__(self, _epoch: int, _metric: Optional[Any] = None) -> float:
        """Return the learning rate for ``_epoch`` (subclasses must implement)."""

    def _update_and_log_lr(self, new_lr: float) -> float:
        """Apply the lower clamp (if any), store the value and append it to the history."""
        floor = self._lr_clamp
        self._lr = new_lr if floor is None else np.maximum(new_lr, floor)
        self._lr_history.append(self._lr)
        return self._lr

    #! Common Properties
    @property
    def lr(self) -> float:
        """The learning rate most recently stored by the scheduler."""
        return self._lr

    @property
    def history(self) -> List[float]:
        """List of every learning rate emitted so far (the internal list itself)."""
        return self._lr_history

    @property
    def early_stopping(self) -> Optional[EarlyStopping]:
        """The attached early-stopping monitor, or ``None``."""
        return self._early_stopping

    #! ES Proxies
    def set_early_stopping(self, patience: int, min_delta: float = 1e-3):
        """Replace the early-stopping monitor with a new one sharing this scheduler's logger."""
        monitor = EarlyStopping(patience=patience, min_delta=min_delta, logger=self._logger)
        self._early_stopping = monitor

    def check_stop(self, _metric) -> bool:
        """Ask the attached monitor whether to stop; ``False`` when none is attached."""
        monitor = self._early_stopping
        if not monitor:
            return False
        return monitor(_metric)
# ##############################################################################
#! CONCRETE SCHEDULERS
# ##############################################################################
[docs]
class ConstantScheduler(Parameters):
    """
    Schedule that emits the initial learning rate at every epoch.

    The base-class clamp still applies, and every call is recorded in the
    history. Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, initial_lr: float, max_epochs: int, lr_clamp=None, logger=None, es=None, **kwargs):
        # The decay value is unused by this schedule; 1.0 is stored as a neutral placeholder.
        super().__init__(initial_lr=initial_lr,
                         max_epochs=max_epochs,
                         lr_decay=1.0,
                         lr_clamp=lr_clamp,
                         logger=logger,
                         es=es)
        self._typek = SchedulerType.CONSTANT

    def __call__(self, _epoch: int, _metric=None) -> float:
        """Return the (clamped) initial learning rate; both arguments are ignored."""
        fixed_lr = self._initial_lr
        return self._update_and_log_lr(fixed_lr)
# ------------------------------------------------------------------------------
[docs]
class ExponentialDecayScheduler(Parameters):
    """
    Multiplicative exponential decay: lr = initial_lr * (gamma ^ epoch)

    ``lr_decay`` plays the role of gamma, following PyTorch's ExponentialLR
    convention:
    - gamma = 0.99  means 1% decay per epoch
    - gamma = 0.999 means 0.1% decay per epoch
    To reproduce the old exp(-rate * epoch) behavior, pass gamma = exp(-rate).
    Negative epochs are treated as epoch 0. Extra keyword arguments are
    accepted and ignored.
    """

    def __init__(self, initial_lr: float, max_epochs: int, lr_decay: float = 0.99, lr_clamp=None, logger=None, es=None, **kwargs):
        super().__init__(initial_lr=initial_lr,
                         max_epochs=max_epochs,
                         lr_decay=lr_decay,
                         lr_clamp=lr_clamp,
                         logger=logger,
                         es=es)
        self._typek = SchedulerType.EXPONENTIAL

    def __call__(self, _epoch: int, _metric=None) -> float:
        """Return ``initial_lr * lr_decay ** max(0, _epoch)`` after clamping; ``_metric`` is ignored."""
        steps = _epoch if _epoch > 0 else 0
        decay_factor = np.power(self._lr_decay, steps)
        return self._update_and_log_lr(self._initial_lr * decay_factor)
# ------------------------------------------------------------------------------
[docs]
class LinearScheduler(Parameters):
    """
    Linearly decays LR from initial_lr to min_lr (default 0) over max_epochs.

    Epochs outside ``[0, max_epochs]`` are clamped to that range, so the
    schedule stays at ``min_lr`` once ``max_epochs`` is reached. Extra keyword
    arguments are accepted and ignored.
    """

    def __init__(self, initial_lr: float, max_epochs: int, min_lr: float = 0.0, lr_clamp=None, logger=None, es=None, **kwargs):
        """
        Args:
            initial_lr: Learning rate at epoch 0.
            max_epochs: Epoch at which ``min_lr`` is reached.
            min_lr: Final learning rate of the linear ramp.
            lr_clamp: Optional lower bound applied to every emitted value.
            logger: Optional logger for diagnostics.
            es: Optional EarlyStopping instance.
        """
        # We store min_lr but pass 1.0 as dummy decay
        super().__init__(initial_lr, max_epochs, lr_decay=1.0, lr_clamp=lr_clamp, logger=logger, es=es)
        self._min_lr = min_lr
        self._typek = SchedulerType.LINEAR

    @property
    def min_lr(self) -> float:
        """Final learning rate of the ramp (public alias of ``_min_lr``).

        Exposed so this scheduler offers the same ``min_lr`` attribute as the
        cosine and adaptive schedulers.
        """
        return self._min_lr

    @min_lr.setter
    def min_lr(self, value: float):
        """Set the final learning rate of the ramp."""
        self._min_lr = value

    def __call__(self, _epoch: int, _metric=None) -> float:
        """Return the linearly interpolated learning rate for ``_epoch``; ``_metric`` is ignored."""
        current_epoch = max(0, min(_epoch, self._max_epochs))
        alpha = current_epoch / self._max_epochs
        # Interpolate: (1-alpha)*start + alpha*end
        new_lr = (1 - alpha) * self._initial_lr + alpha * self._min_lr
        return self._update_and_log_lr(new_lr)
# ------------------------------------------------------------------------------
[docs]
class StepDecayScheduler(Parameters):
    """
    Piecewise-constant decay:
    lr = initial_lr * decay_factor ^ floor(epoch / step_size)

    Negative epochs are treated as epoch 0. Extra keyword arguments are
    accepted and ignored.
    """

    def __init__(self, initial_lr: float, max_epochs: int, lr_decay: float, step_size: int, lr_clamp=None, logger=None, es=None, **kwargs):
        # Base validation runs first, then the step size is checked.
        super().__init__(initial_lr=initial_lr,
                         max_epochs=max_epochs,
                         lr_decay=lr_decay,
                         lr_clamp=lr_clamp,
                         logger=logger,
                         es=es)
        if step_size <= 0:
            raise ValueError("Step size must be positive.")
        self.step_size = step_size
        self._typek = SchedulerType.STEP

    def __call__(self, _epoch: int, _metric=None) -> float:
        """Return the learning rate after the number of completed steps; ``_metric`` is ignored."""
        elapsed = max(0, _epoch)
        completed_steps = np.floor(elapsed / self.step_size)
        scale = np.power(self._lr_decay, completed_steps)
        return self._update_and_log_lr(self._initial_lr * scale)
# ------------------------------------------------------------------------------
[docs]
class CosineAnnealingScheduler(Parameters):
    """
    Half-cosine schedule that moves from ``initial_lr`` (epoch 0) down to
    ``min_lr`` (epoch ``max_epochs``).

    Epochs are clipped to ``[0, max_epochs]``. Extra keyword arguments are
    accepted and ignored.
    """

    def __init__(self, initial_lr: float, max_epochs: int, min_lr: float = 0.0, lr_clamp=None, logger=None, es=None, **kwargs):
        # The decay value is unused by this schedule; 0.0 is stored as a placeholder.
        super().__init__(initial_lr=initial_lr,
                         max_epochs=max_epochs,
                         lr_decay=0.0,
                         lr_clamp=lr_clamp,
                         logger=logger,
                         es=es)
        self.min_lr = min_lr
        self._typek = SchedulerType.COSINE

    def __call__(self, _epoch: int, _metric=None) -> float:
        """Return the cosine-annealed learning rate for ``_epoch``; ``_metric`` is ignored."""
        total = self._max_epochs
        cur_ep = np.clip(_epoch, 0, total)
        # A non-positive horizon is mapped straight to the end of the schedule (cos = -1).
        cosine_term = np.cos(np.pi * cur_ep / total) if total > 0 else -1.0
        half_range = 0.5 * (self._initial_lr - self.min_lr)
        new_lr = self.min_lr + half_range * (1.0 + cosine_term)
        return self._update_and_log_lr(new_lr)
# ------------------------------------------------------------------------------
[docs]
class AdaptiveScheduler(Parameters):
    """
    ReduceLROnPlateau logic.

    The metric is treated as "lower is better". When the metric has failed to
    improve by more than ``min_delta`` for more than ``patience`` consecutive
    calls, the learning rate is multiplied by ``lr_decay`` (never going below
    ``min_lr``). After a reduction, the next ``cooldown`` calls leave the
    learning rate untouched and do not count bad epochs. Extra keyword
    arguments are accepted and ignored.
    """

    def __init__(self, initial_lr: float, max_epochs: int, lr_decay: float = 0.1, patience: int = 100,
                 min_lr: float = 1e-5, cooldown: int = 0, min_delta: float = 1e-4,
                 lr_clamp=None, logger=None, es=None, **kwargs):
        """
        Args:
            initial_lr: Starting learning rate.
            max_epochs: Maximum number of epochs.
            lr_decay: Multiplicative factor applied on each reduction.
            patience: Number of non-improving calls tolerated before reducing.
            min_lr: Lower bound for the reduced learning rate.
            cooldown: Number of calls to wait after a reduction.
            min_delta: Minimum decrease of the metric that counts as improvement.
            lr_clamp: Optional lower bound applied to every emitted value.
            logger: Optional logger for diagnostics.
            es: Optional EarlyStopping instance.
        """
        super().__init__(initial_lr, max_epochs, lr_decay, lr_clamp, logger, es)
        self.patience = patience
        self.min_lr = min_lr
        self.cooldown = cooldown
        self.min_delta = min_delta
        self._cooldown_counter = 0
        self._best_metric = _INF
        self._num_bad_epochs = 0
        self._typek = SchedulerType.ADAPTIVE

    def __call__(self, _epoch: int, _metric: Optional[Any] = None) -> float:
        """
        Update plateau tracking with ``_metric`` and return the learning rate.

        Args:
            _epoch: Current epoch (not used by the plateau logic).
            _metric: Monitored metric; required. The real part is used if complex.

        Returns:
            The current (possibly reduced) learning rate.

        Raises:
            ValueError: If ``_metric`` is ``None``.
        """
        # `_metric` defaults to None to match the base-class signature, so a
        # call without a metric reaches this explicit, descriptive error.
        if _metric is None:
            raise ValueError("AdaptiveScheduler requires a metric.")

        # Safe metric extraction (NumPy complex scalars are not always
        # subclasses of the builtin `complex`, so match them explicitly).
        if isinstance(_metric, (complex, np.complexfloating)):
            metric_val = _metric.real
        else:
            metric_val = float(_metric)

        # Non-finite metrics carry no information: keep the current LR.
        if np.isnan(metric_val) or np.isinf(metric_val):
            return self._update_and_log_lr(self._lr)

        # Cooldown check
        if self._cooldown_counter > 0:
            self._cooldown_counter -= 1
            self._num_bad_epochs = 0
            return self._update_and_log_lr(self._lr)

        # Improvement Check
        if metric_val < self._best_metric - self.min_delta:
            self._best_metric = metric_val
            self._num_bad_epochs = 0
        else:
            self._num_bad_epochs += 1

        # Reduction Logic
        if self._num_bad_epochs > self.patience:
            new_lr = max(self._lr * self._lr_decay, self.min_lr)
            # Only act if the LR can actually go lower (i.e. not already at min_lr).
            if new_lr < self._lr:
                self._log(f"Reducing LR to {new_lr:.2e} (Plateau).", log='info', color='yellow')
                self._lr = new_lr
                self._cooldown_counter = self.cooldown
                self._num_bad_epochs = 0
        return self._update_and_log_lr(self._lr)

    def reset(self):
        """Reset plateau-tracking state while keeping scheduler configuration."""
        self._cooldown_counter = 0
        self._best_metric = _INF
        self._num_bad_epochs = 0
# ##############################################################################
#! FACTORY
# ##############################################################################
# Registry used by the factory: each entry names the scheduler class and the
# scheduler-specific keyword arguments that may be forwarded to it.
SCHEDULER_CLASS_MAP = {
    SchedulerType.CONSTANT: {
        'class': ConstantScheduler,
        'args': [],
    },
    SchedulerType.EXPONENTIAL: {
        'class': ExponentialDecayScheduler,
        'args': ['lr_decay'],
    },
    SchedulerType.STEP: {
        'class': StepDecayScheduler,
        'args': ['lr_decay', 'step_size'],
    },
    SchedulerType.COSINE: {
        'class': CosineAnnealingScheduler,
        'args': ['min_lr'],
    },
    SchedulerType.ADAPTIVE: {
        'class': AdaptiveScheduler,
        'args': ['lr_decay', 'patience', 'min_lr', 'cooldown', 'min_delta'],
    },
    SchedulerType.LINEAR: {
        'class': LinearScheduler,
        'args': ['min_lr'],
    },
}
# Add string aliases: the lower-case member name ("constant", "exponential",
# "step", "cosine", "adaptive", "linear") maps to the very same config dict
# as the corresponding enum key.
SCHEDULER_CLASS_MAP.update(
    {member.name.lower(): SCHEDULER_CLASS_MAP[member] for member in SchedulerType}
)
[docs]
def choose_scheduler(scheduler_type : Union[str, SchedulerType, Parameters],
                     initial_lr     : float,
                     max_epochs     : int,
                     logger         : Optional['Logger'] = None,
                     **kwargs) -> Parameters:
    """
    Factory to create a scheduler instance.

    This function can either accept:
    - A string or SchedulerType enum to specify the type of scheduler to create.
    - An existing Parameters instance to reconfigure (only its logger,
      ``lr_clamp`` and early stopping are updated; it is returned as is otherwise).

    Args:
        scheduler_type: Type of scheduler or existing instance.
        initial_lr: Initial learning rate.
        max_epochs: Maximum number of epochs.
        logger: Optional logger for the scheduler.
        **kwargs: Additional arguments specific to the scheduler type.
            - lr_decay: Decay rate for exponential/step/adaptive schedulers
              (defaults to 0.99 when the scheduler takes it and none is given).
            - step_size: Step size for step scheduler.
            - min_lr: Minimum learning rate for cosine/linear/adaptive schedulers.
            - patience: Patience for adaptive scheduler.
            - cooldown: Cooldown period for adaptive scheduler.
            - min_delta: Minimum improvement for adaptive scheduler.
            - lr_clamp: Lower clamp applied to every emitted learning rate.
            - early_stopping_patience: Patience for early stopping.
            - early_stopping_min_delta: Minimum improvement for early stopping
              (defaults to 1e-4 here).

            The following aliases are also understood when their value is not
            ``None``: ``lr_step_size`` (step_size), ``lr_min`` and ``lr_final``
            (min_lr; ``lr_final`` wins if both are given), ``lr_initial``
            (overrides ``initial_lr``), ``lr_patience`` (patience),
            ``lr_min_delta`` (min_delta), ``lr_cooldown`` (cooldown),
            ``lr_decay_rate`` (lr_decay), ``lr_max_epochs`` (overrides
            ``max_epochs``), ``lr_es`` (early_stopping_patience) and
            ``lr_es_min`` (early_stopping_min_delta).

    Returns:
        An instance of Parameters (scheduler).

    Raises:
        ValueError: If the scheduler type is unknown.
        TypeError: Re-raised if the scheduler constructor rejects the arguments.
    """
    # Normalise aliases to canonical names. Order matters for `min_lr`:
    # `lr_final` is applied after `lr_min` and therefore takes precedence.
    # All aliases use an explicit `is not None` test so that a legitimate
    # zero value (e.g. `lr_es_min=0.0`) is not silently discarded.
    alias_to_canonical = (
        ('lr_step_size',    'step_size'),
        ('lr_min',          'min_lr'),
        ('lr_final',        'min_lr'),
        ('lr_patience',     'patience'),
        ('lr_min_delta',    'min_delta'),
        ('lr_cooldown',     'cooldown'),
        ('lr_decay_rate',   'lr_decay'),
        ('lr_es',           'early_stopping_patience'),
        ('lr_es_min',       'early_stopping_min_delta'),
    )
    for alias, canonical in alias_to_canonical:
        if kwargs.get(alias) is not None:
            kwargs[canonical] = kwargs.pop(alias)

    # These two aliases override the positional arguments rather than kwargs.
    if kwargs.get('lr_initial') is not None:
        initial_lr = kwargs.pop('lr_initial')
    if kwargs.get('lr_max_epochs') is not None:
        max_epochs = kwargs.pop('lr_max_epochs')

    es_requested = 'early_stopping_patience' in kwargs
    es_min_delta = kwargs.get('early_stopping_min_delta', 1e-4)

    # 1. Handle Existing Instance
    if isinstance(scheduler_type, Parameters):
        if logger is not None:
            scheduler_type.logger = logger
        # `is not None` so that an explicit clamp of 0.0 is honoured.
        if kwargs.get('lr_clamp') is not None:
            scheduler_type._lr_clamp = kwargs['lr_clamp']
        # Reconfigure ES if args present
        if es_requested:
            scheduler_type.set_early_stopping(kwargs['early_stopping_patience'], es_min_delta)
        return scheduler_type

    # 2. Resolve Type
    key = scheduler_type.lower() if isinstance(scheduler_type, str) else scheduler_type
    config = SCHEDULER_CLASS_MAP.get(key)
    if not config:
        raise ValueError(f"Unknown scheduler type: {scheduler_type}")
    cls = config['class']

    # 3. Setup Early Stopping
    es = None
    if es_requested:
        es = EarlyStopping(kwargs['early_stopping_patience'], es_min_delta, logger)

    # 4. Build Args: base arguments plus only what this scheduler declares.
    build_kwargs = {
        'initial_lr'    : initial_lr,
        'max_epochs'    : max_epochs,
        'logger'        : logger,
        'es'            : es,
        'lr_clamp'      : kwargs.get('lr_clamp'),
    }
    for arg in config['args']:
        value = kwargs.get(arg)
        if value is not None:
            # An explicit None is treated as "not provided" (same rule as the
            # aliases above) so the scheduler's own default applies.
            build_kwargs[arg] = value
        elif arg == 'lr_decay':
            # NOTE(review): this factory default also applies to the adaptive
            # scheduler, overriding that class's own default of 0.1 — confirm
            # this is intended.
            build_kwargs['lr_decay'] = 0.99  # Default

    try:
        if logger:
            logger.info(f"Creating scheduler '{scheduler_type}' with args: {build_kwargs}", lvl=3, color='cyan')
        return cls(**build_kwargs)
    except TypeError as e:
        if logger:
            logger.say(f"Error creating scheduler '{scheduler_type}': {str(e)}", log='error', color='red')
        # Bare raise keeps the original traceback intact.
        raise
###############################################################################
#! End of file
###############################################################################