Trainer
The Trainer orchestrates the full training loop, handling device
placement, gradient accumulation, mixed precision, checkpointing, and logging.
Constructor

```python
lightning.pytorch.Trainer(
    # Training duration
    max_epochs: Optional[int] = None,  # None -> 1000 if max_steps is also unset
    max_steps: int = -1,               # -1 -> no step limit
    min_epochs: Optional[int] = None,
    # Hardware
    accelerator: str = "auto",  # "cpu", "gpu", "cuda", "mps", "tpu", "auto"
    devices: Union[int, str, List[int]] = "auto",
    strategy: str = "auto",     # "ddp", "fsdp", "deepspeed_stage_2", ...
    precision: Union[int, str] = "32-true",  # "16-mixed", "bf16-mixed", "16-true", ...
    # Callbacks & logging
    callbacks: Optional[List[Callback]] = None,
    logger: Optional[Union[Logger, List[Logger], bool]] = None,  # None -> default logger; False disables logging
    log_every_n_steps: int = 50,
    enable_progress_bar: bool = True,
    # Gradient management
    accumulate_grad_batches: int = 1,
    gradient_clip_val: Optional[float] = None,
    gradient_clip_algorithm: str = "norm",  # "norm" or "value"
    # Validation
    val_check_interval: Union[int, float] = 1.0,  # float: fraction of an epoch; int: every N training batches
    check_val_every_n_epoch: int = 1,
    num_sanity_val_steps: int = 2,
    # Checkpointing
    enable_checkpointing: bool = True,
    default_root_dir: Optional[str] = None,
    # Reproducibility
    deterministic: bool = False,
)
```
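To make the gradient-management arguments concrete, here is a minimal pure-Python sketch of one optimizer step with `accumulate_grad_batches=k` and `gradient_clip_algorithm="norm"`. It assumes plain SGD and gradients stored as lists of floats; `accumulated_step` and `clip_by_global_norm` are hypothetical helpers, not Lightning APIs. Lightning's automatic optimization divides each micro-batch loss by `accumulate_grad_batches`, so the accumulated gradient is the mean over the k micro-batches and the effective batch size becomes `batch_size * accumulate_grad_batches` (multiplied again by the number of devices under DDP). The `"norm"` algorithm rescales all gradients together when their global L2 norm exceeds `gradient_clip_val`, in the spirit of `torch.nn.utils.clip_grad_norm_`; `"value"` would instead clamp each element to `[-gradient_clip_val, gradient_clip_val]`.

```python
import math
from typing import List, Optional

Vector = List[float]


def clip_by_global_norm(grads: List[Vector], max_norm: float, eps: float = 1e-6) -> List[Vector]:
    """Rescale every gradient by max_norm / total_norm if the global L2 norm exceeds max_norm."""
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    coef = min(1.0, max_norm / (total_norm + eps))  # 1.0 means "leave unchanged"
    return [[g * coef for g in vec] for vec in grads]


def accumulated_step(
    micro_batch_grads: List[List[Vector]],  # k entries: gradients from each micro-batch
    params: List[Vector],
    lr: float,
    clip_val: Optional[float] = None,  # plays the role of gradient_clip_val
) -> List[Vector]:
    """One optimizer step after accumulating k micro-batches (hypothetical sketch, plain SGD)."""
    k = len(micro_batch_grads)  # plays the role of accumulate_grad_batches
    acc = [[0.0] * len(p) for p in params]
    for grads in micro_batch_grads:
        for acc_vec, g_vec in zip(acc, grads):
            for i, g in enumerate(g_vec):
                acc_vec[i] += g / k  # loss / k  =>  gradient / k
    if clip_val is not None:
        acc = clip_by_global_norm(acc, clip_val)
    # A single parameter update per k micro-batches
    return [[p - lr * g for p, g in zip(p_vec, g_vec)] for p_vec, g_vec in zip(params, acc)]


# Two micro-batches with gradients 3.0 and 5.0: the mean gradient is 4.0
params = [[0.0, 0.0]]
grads = [[[3.0, 0.0]], [[5.0, 0.0]]]
print(accumulated_step(grads, params, lr=1.0))                # [[-4.0, 0.0]]
print(accumulated_step(grads, params, lr=1.0, clip_val=2.0))  # first entry close to -2.0
```

Because clipping happens after accumulation, the threshold applies to the averaged gradient of the whole effective batch, not to each micro-batch separately.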
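Key Parameters
accelerator
Controls which hardware backend is used.

| Value | Description |
|---|---|
| "auto" | Picks the best available accelerator, preferring a TPU or GPU (CUDA or MPS) and falling back to CPU. |
| "gpu" / "cuda" | CUDA GPU; requires a PyTorch build with CUDA support. On Apple Silicon, "gpu" selects MPS instead. |
| "mps" | Apple Silicon GPU via Metal Performance Shaders. |
| "cpu" | CPU only. |
| "tpu" | Google TPU. |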
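The `"auto"` and `"gpu"` values are resolved to a concrete backend when the Trainer is constructed. The pure-Python sketch below models that resolution; `resolve_accelerator`, its priority order, and passing availability as a set of strings (instead of probing hardware) are all assumptions for illustration, not Lightning's source. Like Lightning, it raises an error when an explicitly requested backend is missing instead of silently falling back to CPU.

```python
from typing import Iterable

# Assumed priority for "auto"; a real machine rarely has both MPS and CUDA.
AUTO_PRIORITY = ("tpu", "mps", "cuda", "cpu")


def resolve_accelerator(requested: str, available: Iterable[str]) -> str:
    """Map an accelerator= value to a concrete backend (hypothetical sketch)."""
    avail = set(available) | {"cpu"}  # CPU is always usable
    if requested == "auto":
        return next(name for name in AUTO_PRIORITY if name in avail)
    if requested == "gpu":
        # Generic GPU request: whichever GPU backend this machine has
        for backend in ("mps", "cuda"):
            if backend in avail:
                return backend
        raise RuntimeError("accelerator='gpu' was requested, but no GPU is available")
    if requested not in avail:
        raise RuntimeError(f"accelerator={requested!r} was requested, but it is not available")
    return requested


print(resolve_accelerator("auto", {"cuda"}))  # cuda
print(resolve_accelerator("auto", set()))     # cpu
print(resolve_accelerator("gpu", {"mps"}))    # mps
```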
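precision
Controls the numeric precision used for weights and computation.

| Value | Description |
|---|---|
| "32-true" | Full 32-bit float (default). |
| "16-mixed" | AMP with float16 and gradient scaling. Typically faster on GPUs with Tensor Cores (Volta and newer). |
| "bf16-mixed" | AMP with bfloat16. Numerically more stable than float16; on GPUs, requires Ampere or newer; also available on TPU. |
| "16-true" | Weights and computation entirely in float16, with no float32 master weights. Prone to overflow and underflow; use with caution. |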
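The trade-off between the two 16-bit formats comes from how they split their bits: float16 has 5 exponent and 10 mantissa bits, while bfloat16 keeps float32's 8 exponent bits and only 7 mantissa bits. The sketch below uses hypothetical helpers `to_fp16` and `to_bf16`, built on the standard `struct` module, to round a few values to each format. bfloat16 survives values that overflow or underflow float16, but loses fine precision near 1.0. That range problem is why "16-mixed" relies on loss scaling, while "bf16-mixed" usually does not need it.

```python
import struct


def to_fp16(x: float) -> float:
    """Round x to IEEE float16 via struct's 'e' format; out-of-range values become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return float("inf") if x > 0 else float("-inf")


def to_bf16(x: float) -> float:
    """Round x to bfloat16: keep the top 16 bits of its float32 encoding,
    rounding to nearest-even on the dropped bits (finite inputs within float32 range only)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


for value in (70000.0, 1e-8, 1.0 + 2**-10):
    print(f"{value!r:>14}  fp16={to_fp16(value)!r:<14} bf16={to_bf16(value)!r}")
# 70000.0:      fp16 overflows to inf, bf16 rounds to 70144.0
# 1e-08:        fp16 underflows to 0.0, bf16 stays close to 1e-08
# 1.0009765625: fp16 keeps it exactly, bf16 rounds to 1.0
```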
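strategy
Selects how training is distributed across devices:

```python
from lightning.pytorch import Trainer

# Distributed Data Parallel: one full model replica per device (most common)
Trainer(strategy="ddp", devices=4)

# Fully Sharded Data Parallel: shards parameters, gradients, and optimizer state (large models)
Trainer(strategy="fsdp", devices=8)

# DeepSpeed ZeRO stage 3: for models too large to fit on one device (requires the deepspeed package)
Trainer(strategy="deepspeed_stage_3", devices=8, precision="bf16-mixed")
```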
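A rough way to choose among these strategies is to estimate per-device memory for model states. The sketch below uses the accounting from the ZeRO paper for mixed-precision Adam: 2 bytes per parameter for 16-bit weights, 2 for 16-bit gradients, and 12 for float32 optimizer state (master weights, momentum, variance), i.e. 16 bytes per parameter when everything is replicated as in plain DDP. ZeRO stage 1 shards the optimizer state across devices, stage 2 (as in "deepspeed_stage_2") also shards gradients, and stage 3 (as in "deepspeed_stage_3", and similar in spirit to FSDP's full sharding) also shards the weights. The function name is hypothetical, and activations, buffers, and fragmentation are ignored, so real usage is higher.

```python
def model_state_gb(num_params: float, num_devices: int, zero_stage: int) -> float:
    """Per-device memory (GB) for weights, gradients, and Adam state under mixed precision.

    Bytes per parameter follow the ZeRO paper's accounting:
    2 (fp16 weights) + 2 (fp16 grads) + 12 (fp32 master weights, momentum, variance).
    zero_stage=0 stands for plain DDP (everything replicated on every device).
    """
    weights, grads, optim = 2.0, 2.0, 12.0
    n = num_devices
    if zero_stage == 0:
        per_param = weights + grads + optim
    elif zero_stage == 1:  # shard optimizer state
        per_param = weights + grads + optim / n
    elif zero_stage == 2:  # also shard gradients
        per_param = weights + (grads + optim) / n
    elif zero_stage == 3:  # also shard weights
        per_param = (weights + grads + optim) / n
    else:
        raise ValueError("zero_stage must be 0, 1, 2, or 3")
    return num_params * per_param / 1e9


# A 7.5B-parameter model on 64 devices
for stage in range(4):
    print(stage, round(model_state_gb(7.5e9, 64, stage), 1))
# 0 120.0 / 1 31.4 / 2 16.6 / 3 1.9
```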
Methods
fit

```python
trainer.fit(
    model: LightningModule,
    train_dataloaders=None,
    val_dataloaders=None,
    datamodule=None,
    ckpt_path=None,  # resume training from this checkpoint
)
```
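`fit` keeps running epochs until one of the duration limits passed to the constructor is reached. The hypothetical `fit_is_done` below is a simplified model of that check, not Lightning's actual implementation. It assumes that `max_steps` counts optimizer steps (so with `accumulate_grad_batches=4`, `max_steps=100` consumes 400 training batches), that `-1` disables a limit, and that a stop request from a callback such as early stopping is honored only once `min_epochs` has been reached.

```python
from typing import Optional


def fit_is_done(
    epochs_completed: int,
    global_step: int,                # optimizer steps taken so far
    max_epochs: Optional[int] = None,
    max_steps: int = -1,
    min_epochs: Optional[int] = None,
    stop_requested: bool = False,    # e.g. raised by an early-stopping callback
) -> bool:
    """Simplified stopping rule for the fit loop (hypothetical sketch)."""
    if max_epochs is None:
        # With no limit at all, fall back to 1000 epochs; otherwise max_steps governs
        max_epochs = 1000 if max_steps == -1 else -1
    if max_steps != -1 and global_step >= max_steps:
        return True
    if max_epochs != -1 and epochs_completed >= max_epochs:
        return True
    reached_min = min_epochs is None or epochs_completed >= min_epochs
    return stop_requested and reached_min


print(fit_is_done(3, 300, max_steps=300))                                     # True
print(fit_is_done(2, 200, max_epochs=10, min_epochs=5, stop_requested=True))  # False
print(fit_is_done(1000, 5000))                                                # True
```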
validate

```python
results = trainer.validate(model, dataloaders=val_loader)
# Returns one dict of logged metrics per dataloader, e.g.:
# [{"val/loss": 0.0312, "val/acc": 0.9918}]
```
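Each value in the returned dict is an epoch-level aggregate of what was logged with `self.log(...)` in `validation_step`. By default Lightning averages those per-batch values weighted by batch size, so a smaller final batch counts for less. The sketch below (a hypothetical `epoch_mean` taking `(batch_mean, batch_size)` pairs) shows why this differs from a plain mean of batch means.

```python
from typing import List, Tuple


def epoch_mean(batch_stats: List[Tuple[float, int]]) -> float:
    """Batch-size-weighted mean: sum(mean_i * n_i) / sum(n_i)."""
    total = sum(mean * n for mean, n in batch_stats)
    count = sum(n for _, n in batch_stats)
    return total / count


# Two full batches of 32 and a final partial batch of 4
stats = [(0.5, 32), (0.5, 32), (2.0, 4)]
print(epoch_mean(stats))                      # 40 / 68, about 0.588
print(sum(m for m, _ in stats) / len(stats))  # 1.0 (unweighted mean of batch means)
```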
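test

```python
# ckpt_path="best" loads the best checkpoint recorded by the ModelCheckpoint callback during fit
results = trainer.test(model, dataloaders=test_loader, ckpt_path="best")
```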
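When the `ModelCheckpoint` callback monitors a metric (through its `monitor`, `mode`, and `save_top_k` arguments), "best" refers to the checkpoint with the best score seen during `fit`. The hypothetical `TopKCheckpoints` class below sketches that bookkeeping in pure Python: it keeps the k best-scoring checkpoint paths, evicts the worst when a better one arrives, and reports the best one. It does not touch the filesystem and is not Lightning's implementation.

```python
from typing import Dict, Optional


class TopKCheckpoints:
    """Track the k best checkpoint paths by a monitored score (hypothetical sketch)."""

    def __init__(self, k: int, mode: str = "min") -> None:
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.k = k
        self.mode = mode
        self.kept: Dict[str, float] = {}  # path -> score

    def _rank(self, path: str) -> float:
        # Lower rank is better regardless of mode
        score = self.kept[path]
        return score if self.mode == "min" else -score

    def update(self, path: str, score: float) -> None:
        self.kept[path] = score
        if len(self.kept) > self.k:
            worst = max(self.kept, key=self._rank)
            del self.kept[worst]  # a real callback would also delete the file

    @property
    def best_path(self) -> Optional[str]:
        return min(self.kept, key=self._rank) if self.kept else None


tracker = TopKCheckpoints(k=2, mode="min")
for epoch, val_loss in enumerate([0.9, 0.5, 0.7, 0.8]):
    tracker.update(f"epoch={epoch}.ckpt", val_loss)
print(sorted(tracker.kept))  # ['epoch=1.ckpt', 'epoch=2.ckpt']
print(tracker.best_path)     # epoch=1.ckpt
```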
predict

```python
# Returns a list with one predict_step output per batch
predictions = trainer.predict(model, dataloaders=predict_loader)
```
Hook Call Order
During fit, the Trainer calls Callback and LightningModule hooks in the following order. The tree is simplified: setup, the sanity check, the backward pass, and the validation loop are omitted.

```text
on_fit_start
├── on_train_start
│   └── [for each epoch]
│       ├── on_train_epoch_start
│       │   └── [for each batch]
│       │       ├── on_before_batch_transfer
│       │       ├── on_train_batch_start
│       │       ├── training_step
│       │       ├── on_before_optimizer_step
│       │       ├── optimizer.step()
│       │       └── on_train_batch_end
│       └── on_train_epoch_end
└── on_train_end
on_fit_end
```
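A Callback only implements the hooks it cares about, and the Trainer invokes each hook on every callback that defines it, at the points shown in the tree. The toy, dependency-free model below illustrates that dispatch. `ToyTrainer` and `BatchCounter` are hypothetical, and their hook signatures are simplified: real `Callback` hooks also receive the `LightningModule` and, for batch hooks, the outputs, batch, and batch index.

```python
from typing import List


class ToyTrainer:
    """Dispatches hooks to callbacks in the order of the hook tree (simplified)."""

    def __init__(self, callbacks: List[object]) -> None:
        self.callbacks = callbacks

    def _call(self, hook: str, *args) -> None:
        for cb in self.callbacks:
            fn = getattr(cb, hook, None)  # skip callbacks that don't implement this hook
            if fn is not None:
                fn(self, *args)

    def fit(self, num_epochs: int, num_batches: int) -> None:
        self._call("on_fit_start")
        self._call("on_train_start")
        for _ in range(num_epochs):
            self._call("on_train_epoch_start")
            for batch_idx in range(num_batches):
                self._call("on_train_batch_start", batch_idx)
                # training_step, backward, and optimizer.step() would run here
                self._call("on_before_optimizer_step")
                self._call("on_train_batch_end", batch_idx)
            self._call("on_train_epoch_end")
        self._call("on_train_end")
        self._call("on_fit_end")


class BatchCounter:
    """Implements only two hooks; every other hook is skipped for it."""

    def __init__(self) -> None:
        self.batches = 0
        self.epochs = 0

    def on_train_batch_end(self, trainer: ToyTrainer, batch_idx: int) -> None:
        self.batches += 1

    def on_train_epoch_end(self, trainer: ToyTrainer) -> None:
        self.epochs += 1


counter = BatchCounter()
ToyTrainer([counter]).fit(num_epochs=3, num_batches=4)
print(counter.epochs, counter.batches)  # 3 12
```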