Callbacks
Callbacks allow you to hook into every part of the training loop without modifying model code.
Built-in Callbacks
ModelCheckpoint
Saves model checkpoints based on a monitored metric.
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",
    # The placeholder must be the exact name the metric is logged under
    # ("val/acc"); a name that was never logged would not pick up the real
    # value. Automatic metric-name insertion is disabled because the "/" in
    # the metric name would otherwise leak into the file name as a path
    # separator, so the "epoch=" / "val_acc=" labels are written by hand.
    filename="epoch={epoch:02d}-val_acc={val/acc:.4f}",
    auto_insert_metric_name=False,
    monitor="val/acc",
    mode="max",  # "min" or "max"
    save_top_k=3,  # keep best 3
    save_last=True,  # always save last epoch
    every_n_epochs=1,
    verbose=True,
)

trainer = Trainer(callbacks=[checkpoint_callback])
# `model` and the remaining fit() arguments are placeholders for your own
# LightningModule and data.
trainer.fit(model, ...)

# Read these after fit() has run.
print(checkpoint_callback.best_model_path)
print(checkpoint_callback.best_model_score)
EarlyStopping
Stops training when a metric stops improving.
from lightning.pytorch.callbacks import EarlyStopping
early_stop = EarlyStopping(
    monitor="val/loss",
    # Counted in validation checks without improvement, which equals epochs
    # only when validation runs exactly once per epoch.
    patience=10,
    min_delta=1e-4,  # minimum improvement
    mode="min",
    strict=True,  # fail if metric not found
    check_finite=True,  # stop if NaN/Inf
    stopping_threshold=0.001,  # stop immediately if metric below this
)
LearningRateMonitor
Logs learning rate(s) automatically:
from lightning.pytorch.callbacks import LearningRateMonitor
# Callback that logs the learning rate(s) at step granularity.
lr_monitor = LearningRateMonitor(logging_interval="step")
RichProgressBar
A beautiful progress bar using the rich library:
from lightning.pytorch.callbacks import RichProgressBar
# Use the rich-based progress bar by passing it in the Trainer's callbacks.
# NOTE(review): leave=True presumably keeps finished bars on screen — confirm
# against the RichProgressBar API reference.
trainer = Trainer(callbacks=[RichProgressBar(leave=True)])
Custom Callbacks
Subclass Callback and override any hook:
from lightning.pytorch.callbacks import Callback
class GradientNormLogger(Callback):
    """Log the combined L2 norm of the module's gradients before each optimizer step."""

    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
        """Compute the gradient norm over all parameters and log it.

        Parameters without a gradient are skipped. The value is logged under
        the key ``"grad/total_norm"`` through ``pl_module.log``.
        """
        # Sum of the squared per-parameter L2 norms.
        squared_sum = sum(
            param.grad.data.norm(2).item() ** 2
            for param in pl_module.parameters()
            if param.grad is not None
        )
        # The square root of the summed squares is the norm over all gradients.
        pl_module.log("grad/total_norm", squared_sum ** 0.5)
class LRWarmupCallback(Callback):
    """Linearly ramp the first optimizer's learning rate over the first steps.

    During the first ``warmup_steps`` global steps, every parameter group of
    ``trainer.optimizers[0]`` gets ``base_lr * (step + 1) / warmup_steps``,
    so the first step trains with a non-zero learning rate and the last
    warm-up step restores the full base learning rate. Afterwards the
    callback leaves the learning rate untouched.

    Args:
        warmup_steps: Number of global steps to warm up over. Values less
            than or equal to zero disable the warm-up.
    """

    def __init__(self, warmup_steps: int = 1000):
        super().__init__()
        self.warmup_steps = warmup_steps

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        """Set the scaled learning rate for the upcoming batch while warming up."""
        step = trainer.global_step
        if step >= self.warmup_steps:
            # Warm-up finished (or disabled); this guard also prevents a
            # division by zero when warmup_steps is 0.
            return

        # step + 1 so that step 0 is not trained with lr == 0 and the final
        # warm-up step reaches a scale of exactly 1.0.
        lr_scale = min(1.0, (step + 1) / self.warmup_steps)
        for pg in trainer.optimizers[0].param_groups:
            # "initial_lr" is only present when an LR scheduler has been
            # attached to the optimizer. Without one, remember the current
            # learning rate as the base on first use instead of raising
            # KeyError.
            base_lr = pg.setdefault("initial_lr", pg["lr"])
            pg["lr"] = base_lr * lr_scale
All Available Hooks
class Callback:
    """Signature listing of the hooks a callback can override.

    The bodies are ``...`` placeholders. Every hook takes ``trainer`` and
    ``pl_module`` as its first two arguments; the remaining arguments differ
    per hook as shown below.

    NOTE(review): this listing may not be exhaustive — check it against the
    library's Callback API reference.
    """

    # Fit
    def on_fit_start(self, trainer, pl_module): ...
    def on_fit_end(self, trainer, pl_module): ...

    # Train epoch
    def on_train_epoch_start(self, trainer, pl_module): ...
    def on_train_epoch_end(self, trainer, pl_module): ...

    # Train batch
    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): ...
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): ...

    # Validation epoch
    def on_validation_epoch_start(self, trainer, pl_module): ...
    def on_validation_epoch_end(self, trainer, pl_module): ...

    # Optimizer
    def on_before_optimizer_step(self, trainer, pl_module, optimizer): ...
    def on_before_zero_grad(self, trainer, pl_module, optimizer): ...

    # Checkpointing
    def on_save_checkpoint(self, trainer, pl_module, checkpoint): ...
    def on_load_checkpoint(self, trainer, pl_module, checkpoint): ...

    # Exception
    def on_exception(self, trainer, pl_module, exception): ...