LightningModule

LightningModule is the base class for PyTorch Lightning models. It subclasses torch.nn.Module and adds the training, validation, test, and optimizer logic, so a model and the code that trains it live in one self-contained, reusable unit.

Core Hooks

To train with Trainer.fit(), you must implement these two methods:

training_step

Called for each batch during training.

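import torch.nn.functional as F

def training_step(self, batch, batch_idx: int):
    x, y = batch
    y_hat = self(x)                      # calls forward()
    loss = F.cross_entropy(y_hat, y)
    self.log("train/loss", loss)
    return loss  # scalar loss tensor (or a dict with a "loss" key); Lightning runs backward() and the optimizer step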

configure_optimizers

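Returns the optimizer(s) and, optionally, the learning-rate scheduler(s). The three definitions below are alternatives, and a module implements only one of them.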

def configure_optimizers(self):
    # Single optimizer
    return torch.optim.AdamW(self.parameters(), lr=1e-4, weight_decay=0.01)

def configure_optimizers(self):
    # Optimizer + scheduler dict
    opt = torch.optim.Adam(self.parameters(), lr=1e-3)
    sched = {
        "scheduler": torch.optim.lr_scheduler.OneCycleLR(
            opt, max_lr=1e-3, total_steps=self.trainer.estimated_stepping_batches
        ),
        "interval": "step",  # step the scheduler after every optimizer step (default: "epoch")
        "frequency": 1,      # step it once per interval
    }
    return {"optimizer": opt, "lr_scheduler": sched}
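The "interval": "step" setting matters because OneCycleLR plans its schedule over a fixed number of steps. The plain-PyTorch sketch below uses no Lightning; the single parameter and the 10-step budget are arbitrary assumptions. It steps the scheduler once per optimizer step, which is what Lightning does with this configuration, and shows that stepping past total_steps raises a ValueError. That is why the example above passes self.trainer.estimated_stepping_batches as total_steps.

```python
import torch

# Hypothetical setup: one parameter and a budget of 10 optimizer steps
param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([param], lr=1e-3)
total_steps = 10  # in Lightning: self.trainer.estimated_stepping_batches
sched = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=1e-3, total_steps=total_steps)

lrs = []
for _ in range(total_steps):
    opt.step()    # with "interval": "step", each optimizer step...
    sched.step()  # ...is followed by one scheduler step
    lrs.append(opt.param_groups[0]["lr"])

# One step more than planned: OneCycleLR refuses to continue
overflow_error = None
try:
    opt.step()
    sched.step()
except ValueError as err:
    overflow_error = err
print(overflow_error)
```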

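def configure_optimizers(self):
    # Multiple optimizers (e.g. GAN). In Lightning 2.x this requires manual
    # optimization: set self.automatic_optimization = False in __init__ and
    # step each optimizer yourself in training_step via self.optimizers().
    g_opt = torch.optim.Adam(self.generator.parameters(), lr=2e-4)
    d_opt = torch.optim.Adam(self.discriminator.parameters(), lr=2e-4)
    return [g_opt, d_opt], []  # (list of optimizers, list of schedulers)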

Optional Hooks

Hook                                Description
validation_step(batch, batch_idx)   Called for each validation batch.
test_step(batch, batch_idx)         Called for each test batch.
predict_step(batch, batch_idx)      Called for each batch during prediction (Trainer.predict).
on_train_epoch_end()                Called at the end of each training epoch.
on_validation_epoch_end()           Called at the end of each validation epoch.
on_fit_start()                      Called at the very beginning of fit.
on_save_checkpoint(checkpoint)      Adds custom data to the checkpoint dict before it is saved.
on_load_checkpoint(checkpoint)      Restores custom data from the checkpoint dict when loading.

Logging

Call self.log() from the step methods and most other LightningModule hooks to record a metric:

# Inside validation_step, after `x, y = batch` and computing f1_score:
self.log(
    name="val/f1",
    value=f1_score,
    on_step=False,       # don't log a value for every step
    on_epoch=True,       # aggregate and log once at epoch end
    prog_bar=True,       # show in progress bar
    logger=True,         # send to loggers (wandb, mlflow, etc.)
    sync_dist=True,      # reduce the epoch value across processes (DDP)
    batch_size=x.size(0),  # weights this value in the epoch-level mean
)
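The batch_size argument matters because, with on_epoch=True and the default mean reduction, Lightning weights each logged value by its batch size when it computes the epoch value. The plain-Python sketch below mirrors that arithmetic. It uses made-up per-batch F1 values and makes the last batch smaller, to show how far an unweighted average can drift from the weighted one.

```python
def epoch_mean(values, batch_sizes):
    """Batch-size-weighted mean of per-batch values."""
    total = sum(v * n for v, n in zip(values, batch_sizes))
    return total / sum(batch_sizes)


# Hypothetical per-batch F1 values; the last batch is a partial batch of 8
values = [0.5, 0.7, 0.1]
batch_sizes = [32, 32, 8]

weighted = epoch_mean(values, batch_sizes)  # 39.2 / 72, about 0.544
unweighted = sum(values) / len(values)      # 1.3 / 3, about 0.433
print(round(weighted, 4), round(unweighted, 4))
```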

Hyperparameters

Call self.save_hyperparameters() in __init__ to record the constructor arguments in self.hparams and in every checkpoint, so the model can be rebuilt from a checkpoint alone:

import lightning as L

class MyModel(L.LightningModule):
    def __init__(self, lr=1e-3, dropout=0.1, num_layers=4):
        super().__init__()
        self.save_hyperparameters()
        # now available as self.hparams.lr, self.hparams.dropout, ...

# Load from checkpoint -- no need to pass __init__ args again
model = MyModel.load_from_checkpoint("checkpoint.ckpt")

LightningDataModule

LightningDataModule bundles all data-handling steps (download, dataset setup, and DataLoader creation) into one reusable class:

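import lightning as L
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10

class CIFAR10DataModule(L.LightningDataModule):
    def __init__(self, data_dir: str = "./data", batch_size: int = 32):
        super().__init__()
        self.save_hyperparameters()
        # Convert PIL images to tensors so the default collate_fn can batch them
        self.transform = transforms.ToTensor()

    def prepare_data(self):
        # Download only; called from a single process (local rank 0 of each node by default)
        CIFAR10(self.hparams.data_dir, train=True, download=True)
        CIFAR10(self.hparams.data_dir, train=False, download=True)

    def setup(self, stage: str):
        # Called on every process; assign datasets here, not in prepare_data
        if stage in ("fit", "validate"):
            self.train_ds = CIFAR10(self.hparams.data_dir, train=True, transform=self.transform)
            # The CIFAR10 test split doubles as the validation set here
            self.val_ds   = CIFAR10(self.hparams.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.hparams.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.hparams.batch_size)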