Trainer
=======

The :class:`~lightning.pytorch.Trainer` orchestrates the full training loop, handling device placement, gradient accumulation, mixed precision, checkpointing, and logging.

.. contents:: On this page
   :local:
   :depth: 2

Constructor
-----------

.. code-block:: python

    lightning.pytorch.Trainer(
        # Training duration
        max_epochs: Optional[int] = None,  # defaults to 1000 if max_steps is also unset
        max_steps: int = -1,
        min_epochs: Optional[int] = None,
        # Hardware
        accelerator: str = "auto",  # "cpu", "gpu", "cuda", "mps", "tpu", "auto"
        devices: Union[int, str, List[int]] = "auto",
        strategy: str = "auto",  # "ddp", "fsdp", "deepspeed_stage_2", ...
        precision: Union[int, str] = "32-true",  # "16-mixed", "bf16-mixed", "16-true", ...
        # Callbacks & logging
        callbacks: Optional[List[Callback]] = None,
        logger: Optional[Union[Logger, List[Logger]]] = None,
        log_every_n_steps: int = 50,
        enable_progress_bar: bool = True,
        # Gradient management
        accumulate_grad_batches: int = 1,
        gradient_clip_val: Optional[float] = None,
        gradient_clip_algorithm: str = "norm",
        # Validation
        val_check_interval: Union[int, float] = 1.0,
        check_val_every_n_epoch: int = 1,
        num_sanity_val_steps: int = 2,
        # Checkpointing
        enable_checkpointing: bool = True,
        default_root_dir: Optional[str] = None,
        # Reproducibility
        deterministic: bool = False,
    )

Key Parameters
--------------

accelerator
~~~~~~~~~~~

Controls which hardware backend is used.

.. list-table::
   :header-rows: 1
   :widths: 20 80

   * - Value
     - Description
   * - ``"auto"``
     - Selects an available accelerator automatically: a TPU if one is detected, otherwise a GPU (CUDA or MPS), otherwise the CPU.
   * - ``"gpu"``
     - CUDA GPU (resolves to MPS on Apple Silicon). CUDA requires a PyTorch build with CUDA support.
   * - ``"mps"``
     - Apple Silicon GPU via Metal Performance Shaders.
   * - ``"cpu"``
     - CPU only.
   * - ``"tpu"``
     - Google TPU via ``torch_xla``.

precision
~~~~~~~~~

.. list-table::
   :header-rows: 1
   :widths: 20 80

   * - Value
     - Description
   * - ``"32-true"``
     - Full 32-bit float (default). The integer ``32`` is accepted as an alias.
   * - ``"16-mixed"``
     - Automatic mixed precision (AMP) with float16. Faster on NVIDIA Volta or newer GPUs, which have Tensor Cores.
   * - ``"bf16-mixed"``
     - AMP with bfloat16. Numerically more stable than float16 because bfloat16 keeps the float32 exponent range; requires an NVIDIA Ampere or newer GPU, or a TPU.
   * - ``"16-true"``
     - Pure float16 weights and activations, with no float32 master copy of the weights. Use with caution: it is prone to overflow and underflow.

strategy
~~~~~~~~

Distributed training strategies:

.. code-block:: python

    from lightning.pytorch import Trainer

    # Distributed Data Parallel: one full model replica per device (most common)
    Trainer(strategy="ddp", devices=4)

    # Fully Sharded Data Parallel: shards parameters, gradients, and optimizer state (large models)
    Trainer(strategy="fsdp", devices=8)

    # DeepSpeed ZeRO stage 3 (very large models; requires the deepspeed package)
    Trainer(strategy="deepspeed_stage_3", devices=8, precision="bf16-mixed")

Methods
-------

fit
~~~

.. code-block:: python

    trainer.fit(
        model: LightningModule,
        train_dataloaders=None,
        val_dataloaders=None,
        datamodule=None,  # pass either dataloaders or a datamodule, not both
        ckpt_path=None,   # path of a checkpoint to resume training from
    )

validate
~~~~~~~~

Runs one pass over the validation data and returns a list with one dictionary of logged metrics per dataloader.

.. code-block:: python

    results = trainer.validate(model, dataloaders=val_loader)
    # e.g. [{"val/loss": 0.0312, "val/acc": 0.9918}]; keys depend on what the model logs

test
~~~~

``ckpt_path="best"`` loads the best checkpoint saved by :class:`~lightning.pytorch.callbacks.ModelCheckpoint` during the preceding ``fit`` call.

.. code-block:: python

    results = trainer.test(model, dataloaders=test_loader, ckpt_path="best")

predict
~~~~~~~

Returns the outputs of ``predict_step``, one element per batch.

.. code-block:: python

    predictions = trainer.predict(model, dataloaders=predict_loader)

Hook Order
----------

During training, the Trainer calls callback and ``LightningModule`` hooks in the following order (abridged: validation and ``zero_grad`` hooks are omitted):

.. code-block:: text

    on_fit_start
    ├── on_train_start
    │   └── [for each epoch]
    │       ├── on_train_epoch_start
    │       │   └── [for each batch]
    │       │       ├── on_before_batch_transfer
    │       │       ├── on_train_batch_start
    │       │       ├── training_step
    │       │       ├── on_before_backward
    │       │       ├── backward
    │       │       ├── on_after_backward
    │       │       ├── on_before_optimizer_step
    │       │       ├── optimizer.step()
    │       │       └── on_train_batch_end
    │       └── on_train_epoch_end
    └── on_train_end
    on_fit_end
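
To make the nesting concrete, here is a minimal sketch in plain Python with no Lightning dependency. ``ToyTrainer`` and ``BatchCounter`` are hypothetical names invented for this illustration, not Lightning classes. The toy trainer logs every hook and step name in call order and forwards each callback hook to any callback object that defines a method of that name, which is roughly how :class:`~lightning.pytorch.callbacks.Callback` subclasses are dispatched. It also records the backward step (``on_before_backward``, ``backward``, ``on_after_backward``) that runs between ``training_step`` and ``on_before_optimizer_step`` under automatic optimization. Validation, gradient accumulation, and ``zero_grad`` are left out.

.. code-block:: python

    from typing import Any, List, Sequence


    class ToyTrainer:
        """Hypothetical, simplified trainer that only models hook ordering."""

        def __init__(self, callbacks: Sequence[Any] = (), max_epochs: int = 1) -> None:
            self.callbacks = list(callbacks)
            self.max_epochs = max_epochs
            self.log: List[str] = []  # names of hooks and steps, in call order

        def _hook(self, name: str, *args: Any) -> None:
            """Record a callback hook and forward it to every callback that implements it."""
            self.log.append(name)
            for callback in self.callbacks:
                method = getattr(callback, name, None)
                if method is not None:
                    method(self, *args)

        def _step(self, name: str) -> None:
            """Record a step that is not sent to callbacks (module-only hooks, backward, ...)."""
            self.log.append(name)

        def fit(self, batches: Sequence[Any]) -> None:
            self._hook("on_fit_start")
            self._hook("on_train_start")
            for _epoch in range(self.max_epochs):
                self._hook("on_train_epoch_start")
                for batch_idx, batch in enumerate(batches):
                    # LightningModule-only hook: logged but not sent to callbacks
                    self._step("on_before_batch_transfer")
                    self._hook("on_train_batch_start", batch, batch_idx)
                    self._step("training_step")
                    self._hook("on_before_backward")
                    self._step("backward")
                    self._hook("on_after_backward")
                    self._hook("on_before_optimizer_step")
                    self._step("optimizer.step()")
                    self._hook("on_train_batch_end", batch, batch_idx)
                self._hook("on_train_epoch_end")
            self._hook("on_train_end")
            self._hook("on_fit_end")


    class BatchCounter:
        """Hypothetical callback that counts finished training batches."""

        def __init__(self) -> None:
            self.batches_seen = 0

        def on_train_batch_end(self, trainer: ToyTrainer, batch: Any, batch_idx: int) -> None:
            self.batches_seen += 1


    counter = BatchCounter()
    trainer = ToyTrainer(callbacks=[counter], max_epochs=2)
    trainer.fit(batches=["b0", "b1", "b2"])
    print(counter.batches_seen)  # 6
    print(trainer.log[:3])  # ['on_fit_start', 'on_train_start', 'on_train_epoch_start']

Adding more ``on_*`` methods to ``BatchCounter`` shows which hooks fire and how often. In real Lightning code, the equivalent is subclassing :class:`~lightning.pytorch.callbacks.Callback` and passing an instance through ``Trainer(callbacks=[...])``.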