LightningModule
===============

:class:`~lightning.pytorch.LightningModule` is the base class for all PyTorch Lightning models. It combines a ``torch.nn.Module`` with training logic, organizing code into self-contained, reusable units.

Core Hooks
----------

You must implement these two methods to train a model:

training_step
~~~~~~~~~~~~~

Called for each batch during training.

.. code-block:: python

    # Assumes: import torch.nn.functional as F
    def training_step(self, batch, batch_idx: int):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("train/loss", loss)
        return loss  # scalar loss tensor (or a dict containing a "loss" key)

configure_optimizers
~~~~~~~~~~~~~~~~~~~~

Return optimizer(s) and optional schedulers. The three definitions below are alternatives; a module defines only one of them.

.. code-block:: python

    # Option 1: single optimizer
    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-4, weight_decay=0.01)

    # Option 2: optimizer + scheduler dict
    def configure_optimizers(self):
        opt = torch.optim.Adam(self.parameters(), lr=1e-3)
        sched = {
            "scheduler": torch.optim.lr_scheduler.OneCycleLR(
                opt, max_lr=1e-3, total_steps=self.trainer.estimated_stepping_batches
            ),
            "interval": "step",  # step the scheduler after every optimizer step
            "frequency": 1,
        }
        return {"optimizer": opt, "lr_scheduler": sched}

    # Option 3: multiple optimizers (e.g. GAN)
    # In Lightning 2.x this requires manual optimization
    # (set self.automatic_optimization = False in __init__).
    def configure_optimizers(self):
        g_opt = torch.optim.Adam(self.generator.parameters(), lr=2e-4)
        d_opt = torch.optim.Adam(self.discriminator.parameters(), lr=2e-4)
        return [g_opt, d_opt], []

Optional Hooks
--------------

.. list-table::
   :header-rows: 1
   :widths: 35 65

   * - Hook
     - Description
   * - ``validation_step(batch, batch_idx)``
     - Called for each validation batch.
   * - ``test_step(batch, batch_idx)``
     - Called for each test batch.
   * - ``predict_step(batch, batch_idx)``
     - Called for each batch during prediction.
   * - ``on_train_epoch_end()``
     - Called at the end of each training epoch.
   * - ``on_validation_epoch_end()``
     - Called at the end of each validation epoch.
   * - ``on_fit_start()``
     - Called when ``fit`` begins.
   * - ``on_save_checkpoint(ckpt)``
     - Add custom data to the checkpoint dict before it is saved.
   * - ``on_load_checkpoint(ckpt)``
     - Restore custom data from a loaded checkpoint dict.

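
The core and optional hooks above can be combined into one small module. The following is a minimal sketch, not a reference implementation: ``TinyClassifier``, ``make_loader``, the layer sizes, and the random data are all hypothetical and exist only to show how ``training_step``, ``validation_step``, and ``configure_optimizers`` fit together.

.. code-block:: python

    import torch
    import torch.nn.functional as F
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    import lightning as L


    class TinyClassifier(L.LightningModule):
        """Hypothetical classifier used only to illustrate the hooks."""

        def __init__(self, in_dim: int = 8, num_classes: int = 3, lr: float = 1e-3):
            super().__init__()
            self.lr = lr
            self.net = nn.Sequential(
                nn.Linear(in_dim, 16), nn.ReLU(), nn.Linear(16, num_classes)
            )

        def forward(self, x):
            return self.net(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = F.cross_entropy(self(x), y)
            self.log("train/loss", loss)
            return loss

        def validation_step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            self.log("val/loss", F.cross_entropy(logits, y))
            self.log("val/acc", (logits.argmax(dim=1) == y).float().mean())

        def configure_optimizers(self):
            return torch.optim.AdamW(self.parameters(), lr=self.lr)


    def make_loader(n: int = 32) -> DataLoader:
        # Assumption: random features and labels stand in for a real dataset.
        ds = TensorDataset(torch.randn(n, 8), torch.randint(0, 3, (n,)))
        return DataLoader(ds, batch_size=8)


    if __name__ == "__main__":
        # fast_dev_run executes a single train and validation batch as a smoke test.
        trainer = L.Trainer(
            fast_dev_run=True, accelerator="cpu", logger=False, enable_progress_bar=False
        )
        trainer.fit(TinyClassifier(), make_loader(), make_loader())

Running with ``fast_dev_run=True`` is a quick way to check that every hook executes before committing to a full training run.
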
Logging
-------

Use ``self.log()`` inside hooks such as ``training_step`` and ``validation_step``:

.. code-block:: python

    self.log(
        name="val/f1",
        value=f1_score,
        on_step=False,             # if True, log the value at every step
        on_epoch=True,             # aggregate over the epoch and log at epoch end
        prog_bar=True,             # show in progress bar
        logger=True,               # send to loggers (wandb, mlflow, etc.)
        sync_dist=True,            # reduce the value across processes in distributed runs
        batch_size=batch.size(0),  # weights the epoch-level mean; assumes batch is a tensor
    )

Hyperparameters
---------------

Use ``self.save_hyperparameters()`` to automatically store and restore constructor args:

.. code-block:: python

    import lightning as L

    class MyModel(L.LightningModule):
        def __init__(self, lr=1e-3, dropout=0.1, num_layers=4):
            super().__init__()
            self.save_hyperparameters()
            # now available as self.hparams.lr, self.hparams.dropout, ...

    # Load from checkpoint; the saved __init__ args are restored automatically
    model = MyModel.load_from_checkpoint("checkpoint.ckpt")

LightningDataModule
-------------------

:class:`~lightning.pytorch.LightningDataModule` encapsulates all data loading logic:

.. code-block:: python

    import lightning as L
    from torch.utils.data import DataLoader
    from torchvision import transforms
    from torchvision.datasets import CIFAR10

    class CIFAR10DataModule(L.LightningDataModule):
        def __init__(self, data_dir: str = "./data", batch_size: int = 32):
            super().__init__()
            self.save_hyperparameters()
            # Convert PIL images to tensors so the default collate function can batch them
            self.transform = transforms.ToTensor()

        def prepare_data(self):
            # Download only; called from a single process (local rank 0 per node by default)
            CIFAR10(self.hparams.data_dir, download=True)

        def setup(self, stage: str):
            # Called on every process
            if stage == "fit":
                self.train_ds = CIFAR10(self.hparams.data_dir, train=True, transform=self.transform)
                # The CIFAR-10 test split serves as the validation set here
                self.val_ds = CIFAR10(self.hparams.data_dir, train=False, transform=self.transform)

        def train_dataloader(self):
            return DataLoader(self.train_ds, batch_size=self.hparams.batch_size, shuffle=True)

        def val_dataloader(self):
            return DataLoader(self.val_ds, batch_size=self.hparams.batch_size)
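
The same ``setup`` / ``*_dataloader`` pattern can be tried without downloading anything. The sketch below is illustrative only: ``RandomDataModule``, its ``num_samples`` argument, and the random tensors are hypothetical stand-ins for a real dataset.

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader, TensorDataset, random_split

    import lightning as L


    class RandomDataModule(L.LightningDataModule):
        """Hypothetical in-memory data module; stands in for a real dataset."""

        def __init__(self, num_samples: int = 64, batch_size: int = 16):
            super().__init__()
            self.save_hyperparameters()

        def setup(self, stage: str):
            # Assumption: synthetic features (8 dims) and labels (3 classes).
            n = self.hparams.num_samples
            full = TensorDataset(torch.randn(n, 8), torch.randint(0, 3, (n,)))
            n_val = n // 4  # hold out a quarter of the samples for validation
            self.train_ds, self.val_ds = random_split(full, [n - n_val, n_val])

        def train_dataloader(self):
            return DataLoader(self.train_ds, batch_size=self.hparams.batch_size, shuffle=True)

        def val_dataloader(self):
            return DataLoader(self.val_ds, batch_size=self.hparams.batch_size)


    # Calling setup() by hand is only for inspection; the Trainer calls it during fit.
    dm = RandomDataModule()
    dm.setup("fit")
    x, y = next(iter(dm.train_dataloader()))

A data module is passed to the trainer with ``trainer.fit(model, datamodule=dm)``, which keeps the model code independent of how the data is loaded.
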