Checkpointing

Prerequisites

Make sure to read the following sections of the documentation before using this example:

* The single-GPU training example (distributed/single_gpu), since this example is presented as a diff against it.

The full source code for this example is available on the mila-docs GitHub repository.

job.sh

 # distributed/single_gpu/job.sh -> good_practices/checkpointing/job.sh
old mode 100644
new mode 100755
 #!/bin/bash
-#SBATCH --gpus-per-task=rtx8000:1
+#SBATCH --gpus-per-task=1
 #SBATCH --cpus-per-task=4
 #SBATCH --ntasks-per-node=1
 #SBATCH --mem=16G
 #SBATCH --time=00:15:00
-
+#SBATCH --requeue           # allows SLURM to requeue this job (e.g. after preemption)
+#SBATCH --signal=B:TERM@300 # tells the controller to send SIGTERM to the job 5
+                            # min before its time ends to give it a chance for
+                            # better cleanup. If you cancel the job manually,
+                            # make sure that you specify the signal as TERM like
+                            # so `scancel --signal=TERM <jobid>`.
+                            # https://dhruveshp.com/blog/2021/signal-propagation-on-slurm/
 
 # Echo time and hostname into log
 echo "Date:     $(date)"
 echo "Hostname: $(hostname)"
 
 
 # Ensure only anaconda/3 module loaded.
 module --quiet purge
 # This example uses Conda to manage package dependencies.
 # See https://docs.mila.quebec/Userguide.html#conda for more information.
 module load anaconda/3
 module load cuda/11.7
 
+
 # Creating the environment for the first time:
 # conda create -y -n pytorch python=3.9 pytorch torchvision torchaudio \
-#     pytorch-cuda=11.7 -c pytorch -c nvidia
+#     pytorch-cuda=11.7 scipy -c pytorch -c nvidia
 # Other conda packages:
 # conda install -y -n pytorch -c conda-forge rich tqdm
 
 # Activate pre-existing environment.
 conda activate pytorch
 
 
 # Stage dataset into $SLURM_TMPDIR
 mkdir -p $SLURM_TMPDIR/data
-cp /network/datasets/cifar10/cifar-10-python.tar.gz $SLURM_TMPDIR/data/
+# Use --update to only copy newer files (since this might have already been executed)
+cp --update /network/datasets/cifar10/cifar-10-python.tar.gz $SLURM_TMPDIR/data/
 # General-purpose alternatives combining copy and unpack:
 #     unzip   /network/datasets/some/file.zip -d $SLURM_TMPDIR/data/
 #     tar -xf /network/datasets/some/file.tar -C $SLURM_TMPDIR/data/
 
 
 # Fixes issues with MIG-ed GPUs with versions of PyTorch < 2.0
 unset CUDA_VISIBLE_DEVICES
 
 # Execute Python script
-python main.py
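+# `exec` replaces the shell process with the Python process, so the SIGTERM requested
+# by `--signal=B:TERM@300` above is delivered directly to Python rather than to bash
+# (see the blog post linked above on signal propagation).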
+exec python main.py

main.py

 # distributed/single_gpu/main.py -> good_practices/checkpointing/main.py
-"""Single-GPU training example."""
+"""Checkpointing example."""
+from __future__ import annotations
+
 import argparse
 import logging
 import os
+import random
+import shutil
+import signal
+import uuid
+import warnings
+from logging import getLogger as get_logger
 from pathlib import Path
+from types import FrameType
+from typing import Any, TypedDict
 
+import numpy
 import rich.logging
 import torch
 from torch import Tensor, nn
 from torch.nn import functional as F
 from torch.utils.data import DataLoader, random_split
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
 from torchvision.models import resnet18
 from tqdm import tqdm
 
+SCRATCH = Path(os.environ["SCRATCH"])
+SLURM_TMPDIR = Path(os.environ["SLURM_TMPDIR"])
+SLURM_JOBID = os.environ["SLURM_JOBID"]
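+# NOTE: A requeued job keeps the same job ID, so the default --run-dir below (which
+# includes the job ID) points to the same checkpoint directory after a restart.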
+
+CHECKPOINT_FILE_NAME = "checkpoint.pth"
+
+logger = get_logger(__name__)
+
+
+class RunState(TypedDict):
+    """Typed dictionary containing the state of the training run which is saved at each epoch.
+
+    Using type hints helps prevent bugs and makes your code easier to read for both humans and
+    machines (e.g. Copilot). This leads to less time spent debugging and better code suggestions.
+    """
+
+    epoch: int
+    best_acc: float
+    model_state: dict[str, Tensor]
+    optimizer_state: dict[str, Tensor]
+
+    random_state: tuple[Any, ...]
+    numpy_random_state: dict[str, Any]
+    torch_random_state: Tensor
+    torch_cuda_random_state: list[Tensor]
+
 
 def main():
     # Use an argument parser so we can pass hyperparameters from the command line.
     parser = argparse.ArgumentParser(description=__doc__)
     parser.add_argument("--epochs", type=int, default=10)
     parser.add_argument("--learning-rate", type=float, default=5e-4)
     parser.add_argument("--weight-decay", type=float, default=1e-4)
     parser.add_argument("--batch-size", type=int, default=128)
+    parser.add_argument(
+        "--run-dir", type=Path, default=SCRATCH / "checkpointing_example" / SLURM_JOBID
+    )
+    parser.add_argument("--random-seed", type=int, default=123)
     args = parser.parse_args()
 
     epochs: int = args.epochs
     learning_rate: float = args.learning_rate
     weight_decay: float = args.weight_decay
     batch_size: int = args.batch_size
+    run_dir: Path = args.run_dir
+    random_seed: int = args.random_seed
+
+    checkpoint_dir = run_dir / "checkpoints"
+    start_epoch: int = 0
+    best_acc: float = 0.0
 
     # Check that the GPU is available
     assert torch.cuda.is_available() and torch.cuda.device_count() > 0
     device = torch.device("cuda", 0)
 
+    # Seed the random number generators as early as possible.
+    random.seed(random_seed)
+    numpy.random.seed(random_seed)
+    torch.random.manual_seed(random_seed)
+    torch.cuda.manual_seed_all(random_seed)
+
     # Setup logging (optional, but much better than using print statements)
     logging.basicConfig(
         level=logging.INFO,
+        format="%(message)s",
         handlers=[rich.logging.RichHandler(markup=True)],  # Very pretty, uses the `rich` package.
     )
 
-    logger = logging.getLogger(__name__)
-
-    # Create a model and move it to the GPU.
+    # Create a model.
     model = resnet18(num_classes=10)
+
+    # Move the model to the GPU.
     model.to(device=device)
 
     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
 
-    # Setup CIFAR10
+    # Try to resume from a checkpoint, if one exists.
+    checkpoint: RunState | None = load_checkpoint(checkpoint_dir, map_location=device)
+    if checkpoint:
+        start_epoch = checkpoint["epoch"] + 1  # +1 to start at the next epoch.
+        best_acc = checkpoint["best_acc"]
+        model.load_state_dict(checkpoint["model_state"])
+        optimizer.load_state_dict(checkpoint["optimizer_state"])
+        random.setstate(checkpoint["random_state"])
+        numpy.random.set_state(checkpoint["numpy_random_state"])
+        # NOTE: Need to move those tensors to CPU before they can be loaded.
+        torch.random.set_rng_state(checkpoint["torch_random_state"].cpu())
+        torch.cuda.random.set_rng_state_all(t.cpu() for t in checkpoint["torch_cuda_random_state"])
+        logger.info(f"Resuming training at epoch {start_epoch} (best_acc={best_acc:.2%}).")
+    else:
+        logger.info(f"No checkpoints found in {checkpoint_dir}. Training from scratch.")
+
+    # Setup the dataset
     num_workers = get_num_workers()
-    dataset_path = Path(os.environ.get("SLURM_TMPDIR", ".")) / "data"
+    dataset_path = SLURM_TMPDIR / "data"
+
     train_dataset, valid_dataset, test_dataset = make_datasets(str(dataset_path))
     train_dataloader = DataLoader(
         train_dataset,
         batch_size=batch_size,
         num_workers=num_workers,
         shuffle=True,
+        # generator=torch.Generator().manual_seed(random_seed),
     )
     valid_dataloader = DataLoader(
         valid_dataset,
         batch_size=batch_size,
         num_workers=num_workers,
         shuffle=False,
+        # generator=torch.Generator().manual_seed(random_seed),
     )
-    test_dataloader = DataLoader(  # NOTE: Not used in this example.
+    test_dataloader = DataLoader(  # NOTE: Not used in this example.  # noqa
         test_dataset,
         batch_size=batch_size,
         num_workers=num_workers,
         shuffle=False,
     )
 
-    # Checkout the "checkpointing and preemption" example for more info!
-    logger.debug("Starting training from scratch.")
+    def signal_handler(signum: int, frame: FrameType | None):
+        """Called before the job gets pre-empted or reaches the time-limit.
 
-    for epoch in range(epochs):
+        This should run quickly. Performing a full checkpoint here mid-epoch is not recommended.
+        """
+        signal_enum = signal.Signals(signum)
+        logger.error(f"Job received a {signal_enum.name} signal!")
+        # Perform quick actions that will help the job resume later.
+        # If you use Weights & Biases: https://docs.wandb.ai/guides/runs/resuming#preemptible-sweeps
+        # if wandb.run:
+        #     wandb.mark_preempting()
+
+    signal.signal(signal.SIGTERM, signal_handler)  # Before getting pre-empted and requeued.
+    signal.signal(signal.SIGUSR1, signal_handler)  # Before reaching the end of the time limit.
+
+    for epoch in range(start_epoch, epochs):
         logger.debug(f"Starting epoch {epoch}/{epochs}")
 
-        # Set the model in training mode (important for e.g. BatchNorm and Dropout layers)
+        # Set the model in training mode (this is important for e.g. BatchNorm and Dropout layers)
         model.train()
 
-        # NOTE: using a progress bar from tqdm because it's nicer than using `print`.
+        # NOTE: using a progress bar from tqdm (much nicer than using `print`s).
         progress_bar = tqdm(
             total=len(train_dataloader),
             desc=f"Train epoch {epoch}",
+            unit_scale=train_dataloader.batch_size or 1,
+            unit="samples",
         )
 
         # Training loop
+        batch: tuple[Tensor, Tensor]
         for batch in train_dataloader:
             # Move the batch to the GPU before we pass it to the model
             batch = tuple(item.to(device) for item in batch)
             x, y = batch
 
             # Forward pass
             logits: Tensor = model(x)
 
             loss = F.cross_entropy(logits, y)
 
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
 
             # Calculate some metrics:
             n_correct_predictions = logits.detach().argmax(-1).eq(y).sum()
             n_samples = y.shape[0]
             accuracy = n_correct_predictions / n_samples
 
             logger.debug(f"Accuracy: {accuracy.item():.2%}")
             logger.debug(f"Average Loss: {loss.item()}")
 
-            # Advance the progress bar one step and update the progress bar text.
+            # Advance the progress bar one step, and update the text displayed in the progress bar.
             progress_bar.update(1)
             progress_bar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
         progress_bar.close()
 
         val_loss, val_accuracy = validation_loop(model, valid_dataloader, device)
         logger.info(f"Epoch {epoch}: Val loss: {val_loss:.3f} accuracy: {val_accuracy:.2%}")
 
+        # remember best accuracy and save the current state.
+        is_best = val_accuracy > best_acc
+        best_acc = max(val_accuracy, best_acc)
+
+        if checkpoint_dir is not None:
+            save_checkpoint(
+                checkpoint_dir,
+                is_best,
+                RunState(
+                    epoch=epoch,
+                    model_state=model.state_dict(),
+                    optimizer_state=optimizer.state_dict(),
+                    random_state=random.getstate(),
+                    numpy_random_state=numpy.random.get_state(legacy=False),
+                    torch_random_state=torch.random.get_rng_state(),
+                    torch_cuda_random_state=torch.cuda.random.get_rng_state_all(),
+                    best_acc=best_acc,
+                ),
+            )
+
     print("Done!")
 
 
 @torch.no_grad()
 def validation_loop(model: nn.Module, dataloader: DataLoader, device: torch.device):
     model.eval()
 
     total_loss = 0.0
     n_samples = 0
     correct_predictions = 0
 
     for batch in dataloader:
         batch = tuple(item.to(device) for item in batch)
         x, y = batch
 
         logits: Tensor = model(x)
         loss = F.cross_entropy(logits, y)
 
         batch_n_samples = x.shape[0]
-        batch_correct_predictions = logits.argmax(-1).eq(y).sum()
+        batch_correct_predictions = logits.argmax(-1).eq(y).sum().item()
 
         total_loss += loss.item()
         n_samples += batch_n_samples
-        correct_predictions += batch_correct_predictions
+        correct_predictions += int(batch_correct_predictions)
 
     accuracy = correct_predictions / n_samples
     return total_loss, accuracy
 
 
 def make_datasets(
     dataset_path: str,
     val_split: float = 0.1,
     val_split_seed: int = 42,
 ):
     """Returns the training, validation, and test splits for CIFAR10.
 
     NOTE: We don't use image transforms here for simplicity.
     Having different transformations for train and validation would complicate things a bit.
     Later examples will show how to do the train/val/test split properly when using transforms.
     """
     train_dataset = CIFAR10(
         root=dataset_path, transform=transforms.ToTensor(), download=True, train=True
     )
     test_dataset = CIFAR10(
         root=dataset_path, transform=transforms.ToTensor(), download=True, train=False
     )
     # Split the training dataset into a training and validation set.
-    n_samples = len(train_dataset)
-    n_valid = int(val_split * n_samples)
-    n_train = n_samples - n_valid
     train_dataset, valid_dataset = random_split(
-        train_dataset, (n_train, n_valid), torch.Generator().manual_seed(val_split_seed)
+        train_dataset, ((1 - val_split), val_split), torch.Generator().manual_seed(val_split_seed)
     )
     return train_dataset, valid_dataset, test_dataset
 
 
 def get_num_workers() -> int:
-    """Gets the optimal number of DatLoader workers to use in the current job."""
+    """Gets the optimal number of DataLoader workers to use in the current job."""
     if "SLURM_CPUS_PER_TASK" in os.environ:
         return int(os.environ["SLURM_CPUS_PER_TASK"])
     if hasattr(os, "sched_getaffinity"):
         return len(os.sched_getaffinity(0))
     return torch.multiprocessing.cpu_count()
 
 
+def load_checkpoint(checkpoint_dir: Path, **torch_load_kwargs) -> RunState | None:
+    """Loads the latest checkpoint if possible, otherwise returns `None`."""
+    checkpoint_file = checkpoint_dir / CHECKPOINT_FILE_NAME
+    restart_count = int(os.environ.get("SLURM_RESTART_COUNT", 0))
+    if restart_count:
+        logger.info(f"NOTE: This job has been restarted {restart_count} times by SLURM.")
+
+    if not checkpoint_file.exists():
+        logger.debug(f"No checkpoint found in checkpoints dir ({checkpoint_dir}).")
+        if restart_count:
+            logger.warning(
+                RuntimeWarning(
+                    f"This job has been restarted {restart_count} times by SLURM, but no "
+                    "checkpoint was found! This either means that your checkpointing code is "
+                    "broken, or that the job did not reach the checkpointing portion of your "
+                    "training loop."
+                )
+            )
+        return None
+
+    checkpoint_state: dict = torch.load(checkpoint_file, **torch_load_kwargs)
+
+    missing_keys = RunState.__required_keys__ - checkpoint_state.keys()
+    if missing_keys:
+        warnings.warn(
+            RuntimeWarning(
+                f"Checkpoint at {checkpoint_file} is missing the following keys: {missing_keys}. "
+                f"Ignoring this checkpoint."
+            )
+        )
+        return None
+
+    logger.debug(f"Resuming from the checkpoint file at {checkpoint_file}")
+    state: RunState = checkpoint_state  # type: ignore
+    return state
+
+
+def save_checkpoint(checkpoint_dir: Path, is_best: bool, state: RunState):
+    """Saves a checkpoint with the current state of the run in the checkpoint dir.
+
+    The best checkpoint is also updated if `is_best` is `True`.
+
+    Parameters
+    ----------
+    checkpoint_dir: The checkpoint directory.
+    is_best: Whether this is the best checkpoint so far.
+    state: The dictionary containing all the things to save.
+    """
+    checkpoint_dir.mkdir(parents=True, exist_ok=True)
+    checkpoint_file = checkpoint_dir / CHECKPOINT_FILE_NAME
+
+    # Use a unique ID to avoid any potential collisions.
+    unique_id = uuid.uuid1()
+    temp_checkpoint_file = checkpoint_file.with_suffix(f".temp{unique_id}")
+
+    torch.save(state, temp_checkpoint_file)
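+    # `os.replace` is atomic (when both paths are on the same filesystem), so a job
+    # killed mid-save never leaves a partially written checkpoint.pth behind.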
+    os.replace(temp_checkpoint_file, checkpoint_file)
+
+    if is_best:
+        best_checkpoint_file = checkpoint_file.with_name("model_best.pth")
+        temp_best_checkpoint_file = best_checkpoint_file.with_suffix(f".temp{unique_id}")
+        shutil.copyfile(checkpoint_file, temp_best_checkpoint_file)
+        os.replace(temp_best_checkpoint_file, best_checkpoint_file)
+
+
 if __name__ == "__main__":
     main()

Running this example

$ sbatch job.sh
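
To see resuming in action, let the job run for at least one epoch and then requeue it, which is one way to simulate a preemption (the job is requeueable thanks to the --requeue flag in job.sh; replace <jobid> with the job ID printed by sbatch):

$ scontrol requeue <jobid>

The requeued job keeps the same job ID, so when it restarts, main.py finds the latest checkpoint under $SCRATCH/checkpointing_example/<jobid>/checkpoints and resumes from there instead of training from scratch.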