Skip to content

Checkpoint Mechanism and Usage

In practical deep learning scenarios, model training typically lasts for a period, which places higher demands on the stability and efficiency of distributed training tasks. Moreover, during actual training, unexpected interruptions can cause the loss of the model state, requiring the training process to start over. This not only wastes time and resources, which is particularly evident in LLM training, but also affects the training effectiveness of the model.

The ability to save the model state during training, so that it can be restored in case of an interruption, becomes crucial. Checkpointing is the mainstream solution to this problem. This article will introduce the basic concepts of the Checkpoint mechanism and its usage in PyTorch and TensorFlow.

What is a Checkpoint?

A checkpoint is a mechanism for saving the state of a model during training. By periodically saving checkpoints, you can restore the model in the following situations:

  • Training interruption (e.g., system crash or manual interruption)
  • Need to evaluate at a certain stage of training
  • Reuse the model in different experiments

PyTorch

In PyTorch, torch.save and torch.load are the basic functions used for saving and loading models.

Save Checkpoints in PyTorch

In PyTorch, the state_dict is typically used to save the model's parameters. Here is a simple example:

import torch
import torch.nn as nn

# Assume you have a simple neural network
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

# Initialize model and optimizer
model = SimpleModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Train the model...
# Save checkpoint
checkpoint_path = 'model_checkpoint.pth'
torch.save({
    'epoch': 10,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': 0.02,
}, checkpoint_path)

Restore Checkpoints in PyTorch

When loading the model, you need to restore the model parameters and optimizer state, and then continue training or inference:

# Restore checkpoint
checkpoint = torch.load('model_checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

# Continue training or inference...
  • model_state_dict: Model parameters
  • optimizer_state_dict: Optimizer state
  • epoch: Current training epoch
  • loss: Loss value
  • learning_rate: Learning rate
  • best_accuracy: Best accuracy

TensorFlow

TensorFlow provides the tf.train.Checkpoint class to manage the saving and restoring of models and optimizers.

Save Checkpoints in TensorFlow

Here is an example of saving a checkpoint in TensorFlow:

import tensorflow as tf

# Assume you have a simple model
model = tf.keras.Sequential([
    tf.keras.layers.Dense(2, input_shape=(10,))
])
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

# Define checkpoint
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
checkpoint_dir = './checkpoints'
checkpoint_prefix = f'{checkpoint_dir}/ckpt'

# Train the model...
# Save checkpoint
checkpoint.save(file_prefix=checkpoint_prefix)

Note

Users of DCE 5.0 AI Lab can directly mount high-performance storage as the checkpoint directory to improve the speed of saving and restoring checkpoints.

Restore Checkpoints in TensorFlow

Load the checkpoint and restore the model and optimizer state:

# Restore checkpoint
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
checkpoint.restore(latest_checkpoint)

# Continue training or inference...

Manage Checkpoints in Distributed Training with TensorFlow

In distributed training, TensorFlow manages checkpoints primarily through the following methods:

  • Using tf.train.Checkpoint and tf.train.CheckpointManager

    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    manager = tf.train.CheckpointManager(checkpoint, directory='/tmp/model', max_to_keep=3)
    
  • Saving checkpoints within a distributed strategy

    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        manager = tf.train.CheckpointManager(checkpoint, directory='/tmp/model', max_to_keep=3)
    
  • Saving checkpoints only on the chief worker node

    if strategy.cluster_resolver.task_type == 'chief':
        manager.save()
    
  • Special handling when using MultiWorkerMirroredStrategy

    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    with strategy.scope():
        # Define model
        ...
        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        manager = tf.train.CheckpointManager(checkpoint, '/tmp/model', max_to_keep=3)
    
    def _chief_worker(task_type, task_id):
        return task_type is None or task_type == 'chief' or (task_type == 'worker' and task_id == 0)
    
    if _chief_worker(strategy.cluster_resolver.task_type, strategy.cluster_resolver.task_id):
        manager.save()
    
  • Using a distributed file system

    Ensure all worker nodes can access the same checkpoint directory, typically using a distributed file system such as HDFS or GCS.

  • Asynchronous saving

    Use tf.keras.callbacks.ModelCheckpoint and set the save_freq parameter to asynchronously save checkpoints during training.

  • Checkpoint restoration

    status = checkpoint.restore(manager.latest_checkpoint)
    status.assert_consumed()  # (1)!
    
    1. Ensure all variables are restored
  • Performance optimization

    • Enable mixed precision training using tf.train.experimental.enable_mixed_precision_graph_rewrite()
    • Adjust saving frequency to avoid too frequent I/O operations
    • Consider using tf.saved_model.save() to save the entire model, not just the weights

Considerations

  1. Regular Saving : Determine a suitable saving frequency based on training time and resource consumption, such as every epoch or every few training steps.

  2. Save Multiple Checkpoints : Keep the latest few checkpoints to prevent issues like file corruption or inapplicability.

  3. Record Metadata : Save additional information in the checkpoint, such as the epoch number and loss value, to better restore the training state.

  4. Use Version Control : Save checkpoints for different experiments to facilitate comparison and reuse.

  5. Validation and Testing : Use checkpoints for validation and testing at different training stages to ensure model performance and stability.

Conclusion

The checkpoint mechanism plays a crucial role in deep learning training. By effectively using the checkpoint features in PyTorch and TensorFlow, you can significantly improve the reliability and efficiency of training. The methods and best practices described in this article should help you better manage the training process of deep learning models.

Comments