# LightningModule

`LightningModule` is the base class for all PyTorch Lightning models.
It combines a `torch.nn.Module` with its training logic, organizing code into
self-contained, reusable units.

## Core Hooks

To train a model, implement these methods:

### training_step

Called for each batch during training; return the loss Lightning should optimize.
def training_step(self, batch, batch_idx: int):
    """Compute and log the training loss for one batch.

    The batch is unpacked as an ``(inputs, targets)`` pair, the inputs are run
    through the module's forward, and the logits are scored with cross-entropy.

    Returns:
        The scalar loss tensor, which Lightning backpropagates under automatic
        optimization.
    """
    inputs, targets = batch
    logits = self(inputs)
    loss = F.cross_entropy(logits, targets)
    self.log("train/loss", loss)
    return loss
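### configure_optimizers

Return one or more optimizers and, optionally, learning-rate schedulers. Three common forms follow: a single optimizer, an optimizer with a scheduler config dict, and several optimizers.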
def configure_optimizers(self):
    """Build a single AdamW optimizer over all of the module's parameters."""
    params = self.parameters()
    optimizer = torch.optim.AdamW(params, lr=1e-4, weight_decay=0.01)
    return optimizer
def configure_optimizers(self):
    """Return Adam plus a per-step OneCycleLR schedule as a config dict.

    ``total_steps`` is taken from ``self.trainer.estimated_stepping_batches`` so
    the one-cycle schedule spans the estimated number of optimizer steps.
    ``"interval": "step"`` makes Lightning advance the scheduler after every
    optimizer step rather than once per epoch.
    """
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    one_cycle = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=1e-3,
        total_steps=self.trainer.estimated_stepping_batches,
    )
    scheduler_config = dict(scheduler=one_cycle, interval="step", frequency=1)
    return dict(optimizer=optimizer, lr_scheduler=scheduler_config)
def configure_optimizers(self):
    """Return separate Adam optimizers for a GAN's generator and discriminator.

    Returning more than one optimizer is only supported with manual
    optimization. Set ``self.automatic_optimization = False`` in ``__init__``,
    then fetch and step each optimizer yourself in ``training_step``
    (``g_opt, d_opt = self.optimizers()``). With automatic optimization left
    on, Lightning 2.x raises an error for multiple optimizers.

    Returns:
        An ``(optimizers, lr_schedulers)`` tuple: ``[g_opt, d_opt]`` and an
        empty scheduler list.
    """
    g_opt = torch.optim.Adam(self.generator.parameters(), lr=2e-4)
    d_opt = torch.optim.Adam(self.discriminator.parameters(), lr=2e-4)
    # Order matters: self.optimizers() hands them back in this same order.
    return [g_opt, d_opt], []
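## Optional Hooks

| Hook | Description |
|---|---|
| `validation_step` | Called for each validation batch. |
| `test_step` | Called for each test batch. |
| `predict_step` | Called for each batch during prediction. |
| `on_train_epoch_end` | Called at the end of each training epoch. |
| `on_validation_epoch_end` | Called at the end of each validation epoch. |
| `on_fit_start` | Called at the very beginning of fit. |
| `on_save_checkpoint` | Add custom data to the checkpoint dict before it is saved. |
| `on_load_checkpoint` | Restore custom data from a loaded checkpoint. |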
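## Logging

Call `self.log()` from the step and epoch-level hooks (e.g. `training_step`, `validation_step`, `on_validation_epoch_end`) to record a metric: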
# Log an epoch-level validation metric, e.g. from inside validation_step.
# `f1_score` and `batch` are assumed to be defined by the surrounding hook.
self.log(
    name="val/f1",
    value=f1_score,
    on_step=False,  # do NOT log at each step
    on_epoch=True,  # accumulate across the epoch and log once at epoch end
    prog_bar=True,  # also show in the progress bar
    logger=True,  # send to the configured loggers (wandb, mlflow, etc.)
    # Reduce the value across devices before logging. Recommended for
    # epoch-level metrics under DDP; adds communication overhead per call.
    sync_dist=True,
    # Number of samples in this batch, used to weight the epoch-level mean.
    # Assumes `batch` is a single tensor; for an (x, y) tuple use x.size(0).
    batch_size=batch.size(0),
)
## Hyperparameters

Call `self.save_hyperparameters()` in `__init__` to record the constructor arguments. They are stored in every checkpoint and restored automatically on load:
class MyModel(L.LightningModule):
    """Example model whose constructor arguments are saved as hyperparameters."""

    def __init__(self, lr=1e-3, dropout=0.1, num_layers=4):
        super().__init__()
        # Capture every __init__ argument (lr, dropout, num_layers) into
        # self.hparams -- e.g. self.hparams.lr -- and into saved checkpoints.
        self.save_hyperparameters()


# Restoring re-instantiates the class with the hyperparameters stored in the
# checkpoint, so the constructor arguments need not be passed again.
model = MyModel.load_from_checkpoint(checkpoint_path="checkpoint.ckpt")
## LightningDataModule

A `LightningDataModule` encapsulates all data-loading logic: downloading, dataset construction, and dataloaders:
import lightning as L
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
class CIFAR10DataModule(L.LightningDataModule):
    """CIFAR-10 training and validation data for a Lightning ``Trainer``.

    Args:
        data_dir: Directory the dataset is downloaded to and read from.
        batch_size: Batch size used by both dataloaders.

    The CIFAR-10 test split serves as the validation set.
    """

    def __init__(self, data_dir: str = "./data", batch_size: int = 32):
        super().__init__()
        # Exposes the arguments as self.hparams.data_dir / self.hparams.batch_size.
        self.save_hyperparameters()

    def prepare_data(self):
        # Download only. Lightning calls this from a single process per node
        # (local rank 0), so don't assign state (self.x = ...) here: other
        # processes would never see it.
        CIFAR10(self.hparams.data_dir, download=True)

    def setup(self, stage: str):
        # Called on every process (every GPU) after prepare_data.
        # Local import keeps this snippet self-contained.
        from torchvision.transforms import ToTensor

        # CIFAR10 yields PIL images; the default DataLoader collate_fn cannot
        # batch those, so convert each image to a tensor.
        to_tensor = ToTensor()
        if stage == "fit":
            self.train_ds = CIFAR10(self.hparams.data_dir, train=True, transform=to_tensor)
        # trainer.validate() calls setup("validate"); build val_ds for that
        # stage too so val_dataloader() does not hit a missing attribute.
        if stage in ("fit", "validate"):
            self.val_ds = CIFAR10(self.hparams.data_dir, train=False, transform=to_tensor)

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.hparams.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.hparams.batch_size)