Callbacks

Callbacks allow you to hook into every part of the training loop without modifying model code.

Built-in Callbacks

ModelCheckpoint

Saves model checkpoints based on a monitored metric.

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",
    # Template keys must name the *logged* metric exactly ("val/acc", not
    # "val_acc"). Because the name contains "/", write the "name=" part by hand
    # and disable automatic insertion; otherwise the "/" ends up in the file
    # name and creates an extra sub-directory.
    filename="epoch={epoch:02d}-val_acc={val/acc:.4f}",
    auto_insert_metric_name=False,
    monitor="val/acc",
    mode="max",          # "min" or "max"
    save_top_k=3,        # keep the 3 best checkpoints ranked by "val/acc"
    save_last=True,      # additionally maintain "last.ckpt" with the latest state
    every_n_epochs=1,
    verbose=True,
)

trainer = Trainer(callbacks=[checkpoint_callback])
trainer.fit(model, ...)  # `...` stands for your dataloaders / datamodule

print(checkpoint_callback.best_model_path)
print(checkpoint_callback.best_model_score)

EarlyStopping

Stops training when a metric stops improving.

from lightning.pytorch.callbacks import EarlyStopping

# Register it like any other callback: Trainer(callbacks=[early_stop]).
early_stop = EarlyStopping(
    monitor="val/loss",
    patience=10,          # validation checks (not necessarily epochs) with no improvement
    min_delta=1e-4,       # an absolute change <= this does not count as an improvement
    mode="min",
    strict=True,          # raise if "val/loss" is not among the logged metrics
    check_finite=True,    # stop if the monitored value becomes NaN/Inf
    stopping_threshold=0.001,  # stop immediately once the metric reaches this value
)

LearningRateMonitor

Automatically logs the learning rate of every optimizer param group. The Trainer must have a logger configured:

from lightning.pytorch.callbacks import LearningRateMonitor

# Log learning rates on a per-step basis rather than once per epoch.
lr_monitor = LearningRateMonitor(
    logging_interval="step",
)

RichProgressBar

A progress bar rendered with the Rich library (requires `rich` to be installed, e.g. `pip install rich`):

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import RichProgressBar

# leave=True keeps each finished epoch's progress bar in the terminal
# instead of clearing it when the next epoch starts.
trainer = Trainer(callbacks=[RichProgressBar(leave=True)])

Custom Callbacks

Subclass Callback and override any hook:

from lightning.pytorch.callbacks import Callback

class GradientNormLogger(Callback):
    """Log the global L2 norm of all parameter gradients.

    Runs right before each optimizer step and logs the value under
    ``"grad/total_norm"`` via ``pl_module.log``. Parameters without a gradient
    are skipped; if none have one, 0.0 is logged.
    """

    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
        # Accumulate squared per-parameter norms as tensors so everything stays
        # on the device; calling .item() per parameter would force one
        # host/device synchronization for every parameter tensor.
        # `.detach()` replaces the discouraged `.data` access.
        squared_sum = sum(
            p.grad.detach().norm(2) ** 2
            for p in pl_module.parameters()
            if p.grad is not None
        )
        # With no gradients, sum() returns int 0 and 0 ** 0.5 == 0.0.
        total_norm = squared_sum ** 0.5
        pl_module.log("grad/total_norm", total_norm)

class LRWarmupCallback(Callback):
    """Linearly warm up the learning rate of the first optimizer.

    For the first ``warmup_steps`` optimizer steps (``trainer.global_step``),
    each param group's learning rate is set to
    ``base_lr * (step + 1) / warmup_steps``. The rate therefore starts at
    ``base_lr / warmup_steps`` rather than 0, and reaches exactly ``base_lr`` on
    the last warmup step. Afterwards the callback no longer touches the
    learning rate.

    Args:
        warmup_steps: Number of optimizer steps to warm up over. A value
            <= 0 disables warmup.
    """

    def __init__(self, warmup_steps: int = 1000):
        super().__init__()
        self.warmup_steps = warmup_steps
        # Target (post-warmup) learning rate of each param group; captured on
        # the first warmup batch and persisted through checkpoints.
        self._base_lrs = None

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        step = trainer.global_step
        if step >= self.warmup_steps:
            return
        param_groups = trainer.optimizers[0].param_groups
        if self._base_lrs is None:
            # "initial_lr" is only present once an LR scheduler has been attached
            # to the optimizer; without one, fall back to the configured "lr".
            self._base_lrs = [pg.get("initial_lr", pg["lr"]) for pg in param_groups]
        lr_scale = (step + 1) / self.warmup_steps
        for pg, base_lr in zip(param_groups, self._base_lrs):
            pg["lr"] = base_lr * lr_scale

    def state_dict(self):
        # Persist base LRs so resuming mid-warmup does not mistake an already
        # scaled-down "lr" for the target learning rate.
        return {"base_lrs": self._base_lrs}

    def load_state_dict(self, state_dict):
        self._base_lrs = state_dict.get("base_lrs")

Callback Hook Reference

class Callback:
    """Reference listing of the hook signatures on Lightning's ``Callback``.

    This mirrors ``lightning.pytorch.callbacks.Callback`` (Lightning 2.x API).
    Every hook here is a no-op; subclass the real ``Callback`` and override only
    the hooks you need. Do not define this class in your own code: it would
    shadow the real one.
    """

    # Setup / teardown (stage is "fit", "validate", "test" or "predict")
    def setup(self, trainer, pl_module, stage): ...
    def teardown(self, trainer, pl_module, stage): ...

    # Fit
    def on_fit_start(self, trainer, pl_module): ...
    def on_fit_end(self, trainer, pl_module): ...

    # Sanity check (validation steps run before training starts)
    def on_sanity_check_start(self, trainer, pl_module): ...
    def on_sanity_check_end(self, trainer, pl_module): ...

    # Train
    def on_train_start(self, trainer, pl_module): ...
    def on_train_end(self, trainer, pl_module): ...

    # Train epoch
    def on_train_epoch_start(self, trainer, pl_module): ...
    def on_train_epoch_end(self, trainer, pl_module): ...

    # Train batch
    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): ...
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): ...

    # Validation
    def on_validation_start(self, trainer, pl_module): ...
    def on_validation_end(self, trainer, pl_module): ...

    # Validation epoch
    def on_validation_epoch_start(self, trainer, pl_module): ...
    def on_validation_epoch_end(self, trainer, pl_module): ...

    # Validation batch
    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx=0): ...
    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0): ...

    # Test
    def on_test_start(self, trainer, pl_module): ...
    def on_test_end(self, trainer, pl_module): ...
    def on_test_epoch_start(self, trainer, pl_module): ...
    def on_test_epoch_end(self, trainer, pl_module): ...
    def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx=0): ...
    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0): ...

    # Predict
    def on_predict_start(self, trainer, pl_module): ...
    def on_predict_end(self, trainer, pl_module): ...
    def on_predict_epoch_start(self, trainer, pl_module): ...
    def on_predict_epoch_end(self, trainer, pl_module): ...
    def on_predict_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx=0): ...
    def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0): ...

    # Backward
    def on_before_backward(self, trainer, pl_module, loss): ...
    def on_after_backward(self, trainer, pl_module): ...

    # Optimizer
    def on_before_optimizer_step(self, trainer, pl_module, optimizer): ...
    def on_before_zero_grad(self, trainer, pl_module, optimizer): ...

    # Checkpointing
    def on_save_checkpoint(self, trainer, pl_module, checkpoint): ...
    def on_load_checkpoint(self, trainer, pl_module, checkpoint): ...

    # Callback state (stored inside the checkpoint)
    def state_dict(self): ...
    def load_state_dict(self, state_dict): ...

    # Exception
    def on_exception(self, trainer, pl_module, exception): ...