Callbacks
=========

Callbacks allow you to hook into every part of the training loop without modifying model code.

Built-in Callbacks
------------------

ModelCheckpoint
~~~~~~~~~~~~~~~

Saves model checkpoints based on a monitored metric.

.. code-block:: python

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import ModelCheckpoint

    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints/",
        # The key inside {} must match the logged metric name ("val/acc").
        filename="epoch{epoch:02d}-val_acc{val/acc:.4f}",
        # Without this, the "/" in the auto-inserted metric name would
        # create a subdirectory.
        auto_insert_metric_name=False,
        monitor="val/acc",
        mode="max",          # "min" or "max"
        save_top_k=3,        # keep the best 3 checkpoints
        save_last=True,      # also keep the most recent checkpoint as last.ckpt
        every_n_epochs=1,
        verbose=True,
    )

    trainer = Trainer(callbacks=[checkpoint_callback])
    trainer.fit(model, ...)

    print(checkpoint_callback.best_model_path)   # path of the best checkpoint
    print(checkpoint_callback.best_model_score)  # monitored value at that checkpoint

EarlyStopping
~~~~~~~~~~~~~

Stops training when a monitored metric stops improving.

.. code-block:: python

    from lightning.pytorch.callbacks import EarlyStopping

    early_stop = EarlyStopping(
        monitor="val/loss",
        patience=10,               # validation checks without improvement before stopping
        min_delta=1e-4,            # smallest change that counts as an improvement
        mode="min",
        strict=True,               # raise an error if the monitored metric is missing
        check_finite=True,         # stop if the metric becomes NaN or Inf
        stopping_threshold=0.001,  # stop immediately once the metric falls below this
    )

LearningRateMonitor
~~~~~~~~~~~~~~~~~~~

Logs the learning rate(s) to the trainer's logger:

.. code-block:: python

    from lightning.pytorch.callbacks import LearningRateMonitor

    lr_monitor = LearningRateMonitor(logging_interval="step")

RichProgressBar
~~~~~~~~~~~~~~~

A progress bar rendered with the ``rich`` library:

.. code-block:: python

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import RichProgressBar

    trainer = Trainer(callbacks=[RichProgressBar(leave=True)])

Custom Callbacks
----------------

Subclass :class:`~lightning.pytorch.callbacks.Callback` and override any hook:

.. code-block:: python

    from lightning.pytorch.callbacks import Callback


    class GradientNormLogger(Callback):
        """Logs the global L2 norm of all gradients before each optimizer step."""

        def on_before_optimizer_step(self, trainer, pl_module, optimizer):
            squared_sum = 0.0
            for p in pl_module.parameters():
                if p.grad is not None:
                    squared_sum += p.grad.detach().norm(2).item() ** 2
            pl_module.log("grad/total_norm", squared_sum ** 0.5)


    class LRWarmupCallback(Callback):
        """Linearly scales the learning rate up to its configured value."""

        def __init__(self, warmup_steps: int = 1000):
            super().__init__()
            self.warmup_steps = warmup_steps
            self.base_lrs = []

        def on_train_start(self, trainer, pl_module):
            # Cache the configured learning rates. A param group only has an
            # "initial_lr" key once an LR scheduler is attached to the optimizer.
            self.base_lrs = [pg["lr"] for pg in trainer.optimizers[0].param_groups]

        def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
            step = trainer.global_step
            if step < self.warmup_steps:
                # (step + 1) keeps the first step non-zero and lets the last
                # warmup step reach the full learning rate.
                lr_scale = (step + 1) / self.warmup_steps
                for pg, base_lr in zip(trainer.optimizers[0].param_groups, self.base_lrs):
                    pg["lr"] = base_lr * lr_scale

Available Hooks
---------------

The most commonly overridden hooks are listed below, grouped by stage. The ``Callback`` base class defines further hooks, such as ``setup``, ``teardown``, and ``on_train_start``, that are not shown here.

.. code-block:: python

    class Callback:
        # Fit
        def on_fit_start(self, trainer, pl_module): ...
        def on_fit_end(self, trainer, pl_module): ...

        # Train epoch
        def on_train_epoch_start(self, trainer, pl_module): ...
        def on_train_epoch_end(self, trainer, pl_module): ...

        # Train batch
        def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): ...
        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): ...

        # Validation epoch
        def on_validation_epoch_start(self, trainer, pl_module): ...
        def on_validation_epoch_end(self, trainer, pl_module): ...

        # Optimizer
        def on_before_optimizer_step(self, trainer, pl_module, optimizer): ...
        def on_before_zero_grad(self, trainer, pl_module, optimizer): ...

        # Checkpointing
        def on_save_checkpoint(self, trainer, pl_module, checkpoint): ...
        def on_load_checkpoint(self, trainer, pl_module, checkpoint): ...

        # Exception
        def on_exception(self, trainer, pl_module, exception): ...
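
The ``patience`` and ``min_delta`` arguments of ``EarlyStopping`` are easy to misread: ``patience`` counts validation checks without improvement, and a change no larger than ``min_delta`` does not count as an improvement. The plain-Python sketch below mimics that bookkeeping without importing Lightning. ``PatienceTracker`` and its method names are hypothetical, invented for this illustration; the real callback handles more cases (for example ``stopping_threshold``), so treat this as a rough model rather than the library's implementation.

```python
import math


class PatienceTracker:
    """Toy model of the patience / min_delta bookkeeping (hypothetical helper)."""

    def __init__(self, patience: int, min_delta: float = 0.0, mode: str = "min"):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        # Start from the worst possible score so the first value always improves.
        self.best = math.inf if mode == "min" else -math.inf
        self.wait_count = 0

    def _improved(self, current: float) -> bool:
        # The metric must beat the best score by more than min_delta.
        if self.mode == "min":
            return current < self.best - self.min_delta
        return current > self.best + self.min_delta

    def update(self, current: float) -> bool:
        """Record one validation result; return True when training should stop."""
        if not math.isfinite(current):
            return True  # same idea as check_finite=True
        if self._improved(current):
            self.best = current
            self.wait_count = 0
            return False
        self.wait_count += 1
        return self.wait_count >= self.patience


tracker = PatienceTracker(patience=2, min_delta=0.01, mode="min")
losses = [0.50, 0.40, 0.395, 0.399, 0.30]  # made-up validation losses

stopped_at = None
for check_idx, loss in enumerate(losses):
    if tracker.update(loss):
        stopped_at = check_idx
        break

print(f"stopped at check {stopped_at}, best loss {tracker.best}")
```

In this run the drops to 0.395 and 0.399 are smaller than ``min_delta`` relative to the best loss of 0.40, so both count against ``patience`` and the loop stops before ever seeing the later value of 0.30. A larger ``patience`` trades longer training for robustness to such plateaus.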