Trainer

The Trainer orchestrates the full training loop, handling device placement, gradient accumulation, mixed precision, checkpointing, and logging.
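Here is a pure-Python sketch (not Lightning code) of two pieces of bookkeeping the Trainer automates: gradient accumulation and gradient clipping by norm. It fits the single weight w in y = w * x. The function train_linear and its arguments are hypothetical, named after the accumulate_grad_batches and gradient_clip_val Trainer arguments. As we understand Lightning's behavior, the sketch divides each batch's gradient by the accumulation factor and also takes an optimizer step on the last batch of an epoch.

```python
def train_linear(xs, ys, lr=0.05, epochs=100, accumulate_grad_batches=2, gradient_clip_val=1.0):
    """Fit y = w * x with one-sample batches, accumulating gradients over
    `accumulate_grad_batches` batches and clipping the accumulated gradient by norm."""
    w = 0.0
    grad = 0.0  # accumulated gradient buffer
    num_batches = len(xs)
    for _ in range(epochs):
        for batch_idx, (x, y) in enumerate(zip(xs, ys)):
            # Backward pass for the loss (w * x - y) ** 2, divided by the
            # accumulation factor so the summed gradient is a window average.
            grad += 2.0 * (w * x - y) * x / accumulate_grad_batches
            window_done = (batch_idx + 1) % accumulate_grad_batches == 0
            if window_done or batch_idx == num_batches - 1:
                # Clip by norm; with a single parameter the norm is |grad|.
                norm = abs(grad)
                if gradient_clip_val is not None and norm > gradient_clip_val:
                    grad *= gradient_clip_val / norm
                w -= lr * grad  # optimizer step
                grad = 0.0      # zero_grad
    return w


xs = [1.0, 2.0, 3.0, 4.0]
ys = [3.0 * x for x in xs]
print(round(train_linear(xs, ys), 4))
```

Clipping caps each early update at lr * gradient_clip_val, so training moves slowly at first and then converges once the gradient norm drops below the threshold.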

Constructor

lightning.pytorch.Trainer(
    # Training duration
    max_epochs: int = 1000,
    max_steps: int = -1,
    min_epochs: int = 1,

    # Hardware
    accelerator: str = "auto",    # "cpu", "gpu", "mps", "tpu", "auto"
    devices: Union[int, str, List[int]] = "auto",
    strategy: str = "auto",       # "ddp", "fsdp", "deepspeed_stage_2", ...
    precision: Union[int, str] = 32,  # 32, "16-mixed", "bf16-mixed", "16-true", ...

    # Callbacks & logging
    callbacks: Optional[List[Callback]] = None,
    logger: Optional[Union[Logger, List[Logger]]] = None,
    log_every_n_steps: int = 50,
    enable_progress_bar: bool = True,

    # Gradient management
    accumulate_grad_batches: int = 1,
    gradient_clip_val: Optional[float] = None,
    gradient_clip_algorithm: str = "norm",

    # Validation
    val_check_interval: Union[int, float] = 1.0,
    check_val_every_n_epoch: int = 1,
    num_sanity_val_steps: int = 2,

    # Checkpointing
    enable_checkpointing: bool = True,
    default_root_dir: Optional[str] = None,

    # Reproducibility
    deterministic: bool = False,
)
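val_check_interval and check_val_every_n_epoch together decide when validation runs. The sketch below is built around a hypothetical helper val_check_batches that is not part of Lightning. It shows our reading of the rule within a single epoch when check_val_every_n_epoch=1: a float is a fraction of the epoch, and an int is a fixed number of training batches. Edge cases, such as whether a trailing partial interval triggers a run, may differ between Lightning versions.

```python
def val_check_batches(num_training_batches, val_check_interval=1.0):
    """Return the 0-based batch indices in one epoch after which validation would
    run, assuming check_val_every_n_epoch=1. Hypothetical helper, not Lightning API."""
    if isinstance(val_check_interval, float):
        # A float is a fraction of the epoch: 0.25 means four checks per epoch.
        if not 0.0 < val_check_interval <= 1.0:
            raise ValueError("a float val_check_interval must be in (0.0, 1.0]")
        every = max(1, int(num_training_batches * val_check_interval))
    else:
        # An int is an absolute number of training batches between checks.
        every = val_check_interval
    return [i for i in range(num_training_batches) if (i + 1) % every == 0]


print(val_check_batches(100))        # [99]
print(val_check_batches(100, 0.25))  # [24, 49, 74, 99]
print(val_check_batches(100, 30))    # [29, 59, 89]
```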

Key Parameters

accelerator

Controls which hardware backend is used.

Value     Description
"auto"    Selects the first available backend, checking TPU, then GPU (CUDA or Apple MPS), then CPU.
"gpu"     CUDA GPU. Requires a PyTorch CUDA build.
"mps"     Apple Silicon GPU via Metal Performance Shaders.
"cpu"     CPU only.
"tpu"     Google TPU via torch_xla.
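Here is a minimal sketch of how accelerator="auto" together with devices="auto" could resolve. It uses hypothetical availability flags instead of real hardware probes, and resolve_auto is not a Lightning function. The check order and device counts reflect our reading of Lightning 2.x and may change between releases: every visible CUDA GPU, one MPS device, and one CPU process. A single machine never exposes both CUDA and MPS, so their relative order does not matter in practice.

```python
def resolve_auto(tpu_count=0, mps_available=False, cuda_count=0):
    """Resolve accelerator="auto", devices="auto" from hypothetical availability
    flags. Returns (accelerator, number_of_devices)."""
    if tpu_count > 0:
        return "tpu", tpu_count
    if mps_available:
        return "mps", 1            # Apple Silicon exposes a single MPS device
    if cuda_count > 0:
        return "cuda", cuda_count  # "auto" uses every visible CUDA GPU
    return "cpu", 1                # CPU training defaults to one process


print(resolve_auto())                           # ('cpu', 1)
print(resolve_auto(cuda_count=4))               # ('cuda', 4)
print(resolve_auto(tpu_count=8, cuda_count=2))  # ('tpu', 8)
```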

precision

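Value          Description
32             Full 32-bit float (default); equivalent to "32-true".
"16-mixed"     Automatic mixed precision (AMP) with float16 and dynamic loss scaling. Faster on Volta or newer GPUs.
"bf16-mixed"   AMP with bfloat16. More numerically stable than float16 because bfloat16 keeps float32's exponent range, so no loss scaling is needed. Requires an Ampere or newer GPU, or a TPU.
"16-true"      Weights and computation entirely in float16, with no float32 master weights. Prone to overflow and precision loss; use with caution.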
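You can see the range difference behind the "16-mixed" versus "bf16-mixed" advice without a GPU. This pure-Python sketch emulates the two 16-bit formats with the standard struct module. The "e" format code is IEEE float16, and bfloat16 is approximated by rounding a float32 to its top 16 bits. It illustrates the number formats only and is not how PyTorch or Lightning casts tensors. The helpers to_fp16 and to_bf16 are hypothetical.

```python
import struct


def to_fp16(x: float) -> float:
    """Round x to IEEE 754 float16 and back. Raises OverflowError for values too
    large for float16 (largest finite value: 65504)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x: float) -> float:
    """Round x to bfloat16 and back: keep the top 16 bits of the float32 encoding,
    rounding to nearest even (NaN and infinity handling omitted)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round-to-nearest-even bias
    bits = (bits >> 16) << 16            # drop 16 of float32's 23 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


print(to_fp16(1e-8))     # 0.0: underflows in float16
print(to_bf16(1e-8))     # close to 1e-8: bfloat16 has float32's exponent range
print(to_bf16(70000.0))  # 70144.0: in range, but only 8 significant bits
try:
    to_fp16(70000.0)
except OverflowError:
    print("70000.0 overflows float16")
```

float16 tops out at 65504 and rounds values below about 3e-8 to zero. That is why float16 mixed precision relies on loss scaling to keep small gradients representable. bfloat16 trades mantissa precision for float32's range.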

strategy

Distributed training strategies:

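from lightning.pytorch import Trainer

# Distributed Data Parallel (most common): one process per device, each with a full model replica
Trainer(strategy="ddp", devices=4)

# Fully Sharded Data Parallel (large models): shards parameters, gradients, and optimizer state
Trainer(strategy="fsdp", devices=8)

# DeepSpeed ZeRO stage 3 (very large models): full sharding, requires the deepspeed package
Trainer(strategy="deepspeed_stage_3", devices=8, precision="bf16-mixed")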
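Here is a back-of-the-envelope sketch of why the sharded strategies matter for large models. It uses the per-parameter accounting from the ZeRO paper (Rajbhandari et al.) for mixed-precision Adam: 2 bytes of weights, 2 bytes of gradients and 12 bytes of float32 optimizer state. DDP replicates all 16 bytes on every device, and ZeRO stage 2 shards gradients and optimizer state. ZeRO stage 3 or fully sharded FSDP shards the weights as well. The helper model_state_gb_per_device is hypothetical. It ignores activations, buffers and allocator overhead, which come on top of these figures.

```python
# Per-parameter bytes for mixed-precision Adam, following the ZeRO paper's accounting.
PARAM_BYTES = 2       # fp16/bf16 weights
GRAD_BYTES = 2        # fp16/bf16 gradients
OPTIMIZER_BYTES = 12  # fp32 master weights + Adam momentum + Adam variance


def model_state_gb_per_device(num_params, num_devices, shard):
    """Approximate per-device memory (GB) for weights, gradients and optimizer state.

    shard: "none" (DDP), "grads+optimizer" (ZeRO stage 2) or "all" (ZeRO stage 3 /
    fully sharded FSDP). Activations and buffers are ignored.
    """
    params = PARAM_BYTES * num_params
    grads = GRAD_BYTES * num_params
    optim = OPTIMIZER_BYTES * num_params
    if shard == "none":
        total = params + grads + optim
    elif shard == "grads+optimizer":
        total = params + (grads + optim) / num_devices
    elif shard == "all":
        total = (params + grads + optim) / num_devices
    else:
        raise ValueError(f"unknown shard mode: {shard!r}")
    return total / 1e9


for mode in ("none", "grads+optimizer", "all"):
    print(mode, model_state_gb_per_device(7_000_000_000, 8, mode))
# none 112.0
# grads+optimizer 26.25
# all 14.0
```

Under these assumptions, model state for a 7B-parameter model on 8 devices takes about 112 GB per device with DDP. It drops to 26.25 GB with ZeRO stage 2 and 14 GB with full sharding.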

Methods

fit

trainer.fit(
    model: LightningModule,
    train_dataloaders=None,
    val_dataloaders=None,
    datamodule=None,
    ckpt_path=None,     # resume from checkpoint
)
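Passing ckpt_path restores the model, optimizer state and loop progress before training continues. As a result, max_epochs still counts total epochs rather than additional ones. The toy loop below illustrates that bookkeeping in plain Python. Its run function and checkpoint dict are hypothetical stand-ins, not Lightning's checkpoint format.

```python
import copy


def run(max_epochs, batches_per_epoch, ckpt=None):
    """Toy training loop that resumes from a checkpoint dict (hypothetical format)
    and returns a new checkpoint describing where it stopped."""
    if ckpt is None:
        state = {"epoch": 0, "global_step": 0, "weight": 0.0}
    else:
        state = copy.deepcopy(ckpt)  # never mutate the caller's checkpoint
    for epoch in range(state["epoch"], max_epochs):
        for _ in range(batches_per_epoch):
            state["weight"] += 0.1     # stand-in for an optimizer step
            state["global_step"] += 1
        state["epoch"] = epoch + 1     # number of completed epochs
    return state


straight = run(max_epochs=4, batches_per_epoch=10)
partial = run(max_epochs=2, batches_per_epoch=10)
resumed = run(max_epochs=4, batches_per_epoch=10, ckpt=partial)
print(resumed == straight)  # True: resuming continues up to 4 epochs in total
```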

validate

results = trainer.validate(model, dataloaders=val_loader)
# One dict per dataloader with the logged metrics, e.g. [{"val/loss": 0.0312, "val/acc": 0.9918}]
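validate returns a list with one dict per dataloader. Each dict holds the epoch-level value of every metric the model logged with self.log. For a metric logged on every batch, the default reduction is a mean weighted by batch size. The following plain-Python sketch shows that reduction for one dataloader, using a hypothetical epoch_metrics helper and (batch_size, metrics) input format.

```python
from collections import defaultdict


def epoch_metrics(step_logs):
    """Reduce per-batch logs to epoch-level means weighted by batch size.

    step_logs: list of (batch_size, {metric_name: value}) pairs (hypothetical format).
    Returns a one-element list, mirroring validate()'s one dict per dataloader.
    """
    weighted_sums = defaultdict(float)
    sample_counts = defaultdict(int)
    for batch_size, metrics in step_logs:
        for name, value in metrics.items():
            weighted_sums[name] += value * batch_size
            sample_counts[name] += batch_size
    return [{name: weighted_sums[name] / sample_counts[name] for name in weighted_sums}]


logs = [(32, {"val/loss": 0.5}), (32, {"val/loss": 0.25}), (16, {"val/loss": 0.125})]
print(epoch_metrics(logs))  # [{'val/loss': 0.325}]
```

The smaller final batch counts half as much as the full ones, so the result differs from the plain mean of the three batch losses.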

test

results = trainer.test(model, dataloaders=test_loader, ckpt_path="best")
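With ckpt_path="best", the Trainer loads the checkpoint that its ModelCheckpoint callback recorded as best_model_path. It therefore only works once checkpoints have been saved, typically by an earlier fit call. When ModelCheckpoint monitors a metric, "best" reduces to a min or max over the saved scores, chosen by its mode argument. The sketch below shows only that selection rule. The best_checkpoint helper and the file names are hypothetical.

```python
def best_checkpoint(saved_scores, mode="min"):
    """Return the checkpoint path with the best monitored score.

    saved_scores maps checkpoint path -> monitored metric value (hypothetical input);
    mode follows ModelCheckpoint's convention: "min" for losses, "max" for accuracies.
    """
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    pick = min if mode == "min" else max
    return pick(saved_scores, key=saved_scores.get)


scores = {"epoch=0.ckpt": 0.41, "epoch=1.ckpt": 0.29, "epoch=2.ckpt": 0.33}
print(best_checkpoint(scores))              # epoch=1.ckpt
print(best_checkpoint(scores, mode="max"))  # epoch=0.ckpt
```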

predict

predictions = trainer.predict(model, dataloaders=predict_loader)

Hook Order

During trainer.fit(), the Trainer calls Callback and LightningModule hooks in the following order (simplified):

on_fit_start
├── on_train_start
│   └── [for each epoch]
│       ├── on_train_epoch_start
│       │   └── [for each batch]
│       │       ├── on_before_batch_transfer
│       │       ├── on_train_batch_start
│       │       ├── training_step
│       │       ├── on_before_backward
│       │       ├── backward
│       │       ├── on_after_backward
│       │       ├── on_before_optimizer_step
│       │       ├── optimizer.step()
│       │       └── on_train_batch_end
│       └── on_train_epoch_end
└── on_train_end
on_fit_end
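A callback only needs to implement the hooks it cares about. The Trainer invokes each hook on every registered callback at its point in the sequence. The hypothetical MiniTrainer below imitates that dispatch with getattr; it is not Lightning's implementation. Its hook signatures are simplified: real Callback hooks also receive the LightningModule and the batch, and batch-end hooks also receive the step outputs.

```python
class MiniTrainer:
    """Hypothetical dispatcher that calls callback hooks in the documented order
    for a toy run (no model, data or optimizer)."""

    def __init__(self, callbacks, max_epochs=1, num_batches=2):
        self.callbacks = callbacks
        self.max_epochs = max_epochs
        self.num_batches = num_batches

    def _call(self, hook, *args):
        # Call the hook on every callback that defines it; skip the rest.
        for callback in self.callbacks:
            method = getattr(callback, hook, None)
            if method is not None:
                method(self, *args)

    def fit(self):
        self._call("on_fit_start")
        self._call("on_train_start")
        for _ in range(self.max_epochs):
            self._call("on_train_epoch_start")
            for batch_idx in range(self.num_batches):
                # batch transfer hooks (on_before_batch_transfer) would run here
                self._call("on_train_batch_start", batch_idx)
                # training_step, backward and the optimizer step would run here
                self._call("on_train_batch_end", batch_idx)
            self._call("on_train_epoch_end")
        self._call("on_train_end")
        self._call("on_fit_end")


class Recorder:
    """A callback that implements only the hooks it needs."""

    def __init__(self):
        self.calls = []

    def on_train_epoch_start(self, trainer):
        self.calls.append("on_train_epoch_start")

    def on_train_batch_end(self, trainer, batch_idx):
        self.calls.append(f"on_train_batch_end[{batch_idx}]")

    def on_fit_end(self, trainer):
        self.calls.append("on_fit_end")


recorder = Recorder()
MiniTrainer([recorder]).fit()
print(recorder.calls)
# ['on_train_epoch_start', 'on_train_batch_end[0]', 'on_train_batch_end[1]', 'on_fit_end']
```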