Quick Start
This guide walks you through training your first model with PyTorch Lightning in under 5 minutes.
Defining a LightningModule
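Subclass `LightningModule` and implement its hooks. `training_step` and `configure_optimizers` are the minimum needed for training; `validation_step` adds a validation loop: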
import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning as L
class MNISTClassifier(L.LightningModule):
    """Two-layer MLP that classifies MNIST digits into 10 classes.

    Args:
        hidden_size: Width of the hidden linear layer.
        learning_rate: Initial learning rate handed to the Adam optimizer.
    """

    def __init__(self, hidden_size: int = 256, learning_rate: float = 1e-3):
        super().__init__()
        # Records the constructor arguments in ``self.hparams`` (and in every
        # checkpoint), so ``load_from_checkpoint`` can rebuild the model
        # without the caller re-passing them.
        self.save_hyperparameters()
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, 10),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalized class logits of shape ``(batch, 10)``.

        ``x`` is flattened from dim 1 onward, so its non-batch dimensions must
        hold 28 * 28 elements (e.g. ``(batch, 1, 28, 28)``).
        """
        return self.encoder(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one training batch."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        # Log both the per-step value and the epoch average.
        self.log("train/loss", loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and accuracy for one batch."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log("val/loss", loss, prog_bar=True)
        self.log("val/acc", acc, prog_bar=True)

    def configure_optimizers(self):
        """Return Adam plus a cosine-annealing LR schedule spanning the run."""
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.learning_rate,
        )
        # The cosine period is the trainer's max_epochs, so this assumes a
        # finite, positive max_epochs (not -1 / "train forever").
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.trainer.max_epochs
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
Training
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint

# Data: convert images to tensors and normalize them (0.1307 / 0.3081 are the
# commonly used MNIST mean / std).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
train_data = datasets.MNIST(".", train=True, download=True, transform=transform)
val_data = datasets.MNIST(".", train=False, download=True, transform=transform)

# num_workers > 0 starts worker processes. On platforms that use the "spawn"
# start method (macOS, Windows), put this script's top-level code under
# `if __name__ == "__main__":`, or set num_workers=0.
train_loader = DataLoader(train_data, batch_size=64, shuffle=True, num_workers=4)
val_loader = DataLoader(val_data, batch_size=256, num_workers=4)

# Callbacks: stop once val/loss stops improving for 5 epochs, and keep the
# 3 checkpoints with the highest val/acc. Keep a handle on the checkpoint
# callback so we can read the best score after training.
checkpoint_callback = ModelCheckpoint(monitor="val/acc", mode="max", save_top_k=3)
callbacks = [
    EarlyStopping(monitor="val/loss", patience=5),
    checkpoint_callback,
]

# Trainer
trainer = Trainer(
    max_epochs=50,
    accelerator="auto",  # auto-detects GPU/CPU/MPS
    devices="auto",
    callbacks=callbacks,
    log_every_n_steps=10,
)

model = MNISTClassifier(hidden_size=512, learning_rate=3e-4)
trainer.fit(model, train_loader, val_loader)

# trainer.callback_metrics only holds the most recently logged values; the
# best monitored score is tracked by the ModelCheckpoint callback.
print(f"Best val accuracy: {float(checkpoint_callback.best_model_score):.4f}")
Inference
# Load the best checkpoint from the training run above. By default
# ModelCheckpoint writes under the logger's directory
# (e.g. lightning_logs/version_0/checkpoints/), so ask the trainer for the
# real path instead of hard-coding a filename.
model = MNISTClassifier.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
# freeze() switches to eval mode (disabling dropout) and sets
# requires_grad=False on every parameter.
model.freeze()

# Classify one validation image: add a batch dimension and move it to the
# device the loaded model lives on.
test_image, _ = val_data[0]
with torch.no_grad():
    prediction = model(test_image.unsqueeze(0).to(model.device)).argmax(dim=1)
Multi-GPU Training
PyTorch Lightning handles distributed training transparently:
# Four GPUs under DistributedDataParallel: one process per device, each with
# its own DataLoader, so the effective batch size is 4x the loader's batch_size.
trainer = Trainer(
    accelerator="gpu",
    devices=4,        # number of GPUs to use
    strategy="ddp",   # DistributedDataParallel
    max_epochs=100,
)
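No changes to the LightningModule code are required. For epoch-level metrics such as `val/loss` and `val/acc`, consider passing `sync_dist=True` to `self.log` so the logged values are reduced across all processes instead of reflecting only each process's shard of the data.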
See also
Trainer — full Trainer API reference.