Quick Start

This guide walks you through training your first model with PyTorch Lightning in under 5 minutes.

Defining a LightningModule

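Subclass `LightningModule` and implement its hooks. Only `training_step` and `configure_optimizers` are required; `forward` and `validation_step` are optional but used in this guide: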

import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning as L


class MNISTClassifier(L.LightningModule):
    """A small MLP classifier for 28x28 MNIST digits.

    Args:
        hidden_size: Width of the single hidden layer.
        learning_rate: Learning rate passed to the Adam optimizer.
    """

    def __init__(self, hidden_size: int = 256, learning_rate: float = 1e-3):
        super().__init__()
        # Records the constructor arguments in ``self.hparams`` (read in
        # configure_optimizers) and in checkpoints.
        self.save_hyperparameters()

        self.encoder = nn.Sequential(
            nn.Flatten(),                    # flatten every dim after the batch dim
            nn.Linear(28 * 28, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, 10),      # one logit per digit class
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalized class logits, shape ``(batch, 10)``."""
        return self.encoder(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one training batch."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log("train/loss", loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and accuracy for one batch."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        # sync_dist=True reduces these epoch-level values across processes so
        # EarlyStopping / ModelCheckpoint monitor the same numbers on every
        # rank; it is a no-op on a single device.
        self.log("val/loss", loss, prog_bar=True, sync_dist=True)
        self.log("val/acc", acc, prog_bar=True, sync_dist=True)

    def configure_optimizers(self):
        """Adam with a cosine learning-rate schedule over the whole run."""
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.learning_rate,
        )
        # One cosine period spanning the configured number of epochs; this
        # assumes the Trainer was given a finite, positive max_epochs.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.trainer.max_epochs
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

Training

from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint

# Data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
train_data = datasets.MNIST(".", train=True, download=True, transform=transform)
# download=True is a cheap existence check, so the validation split also works
# on its own rather than relying on the training split having downloaded it.
val_data   = datasets.MNIST(".", train=False, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True, num_workers=4)
val_loader   = DataLoader(val_data, batch_size=256, num_workers=4)

# Callbacks
# Keep a handle on the checkpoint callback so we can read the best score later.
checkpoint_callback = ModelCheckpoint(monitor="val/acc", mode="max", save_top_k=3)
callbacks = [
    EarlyStopping(monitor="val/loss", patience=5),
    checkpoint_callback,
]

# Trainer
trainer = Trainer(
    max_epochs=50,
    accelerator="auto",       # auto-detects GPU/CPU/MPS
    devices="auto",
    callbacks=callbacks,
    log_every_n_steps=10,
)

model = MNISTClassifier(hidden_size=512, learning_rate=3e-4)
trainer.fit(model, train_loader, val_loader)

# best_model_score is the best monitored "val/acc" seen during the run;
# trainer.callback_metrics would only report the most recent epoch.
print(f"Best val accuracy: {checkpoint_callback.best_model_score:.4f}")

Inference

# Load the best checkpoint written by ModelCheckpoint during trainer.fit().
# In a fresh session, pass the path to any saved .ckpt file instead.
model = MNISTClassifier.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
model.freeze()  # switches to eval mode and sets requires_grad=False on all parameters

# Classify one validation sample; unsqueeze(0) adds the batch dimension.
test_image, _ = val_data[0]
with torch.no_grad():
    prediction = model(test_image.unsqueeze(0)).argmax(dim=1)

Multi-GPU Training

PyTorch Lightning handles distributed training transparently:

# Four GPUs on one machine, kept in sync by DistributedDataParallel.
trainer = Trainer(
    accelerator="gpu",
    devices=4,
    strategy="ddp",
    max_epochs=100,
)

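No other changes to the LightningModule code are required. Epoch-level metrics such as `val/loss` and `val/acc` should be logged with `sync_dist=True`, so they are reduced across all processes instead of being reported from a single device.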

See also

Trainer — full Trainer API reference.