Introduction

Welcome to our session on saving and loading PyTorch models. After investing hours or days in training a model, saving your progress is crucial. You may want to use the trained model on another device, share it, or just load it in the future. This lesson focuses on how to achieve that, so let's get started!

Brief Recap: Model Training

Before that, let's refresh our memory of the previous lessons. We trained a PyTorch model to classify wines using the nn.Sequential class, along with CrossEntropyLoss and the Adam optimizer. Here's a quick snapshot:
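The original listing is not reproduced here, so the following is only a minimal sketch of that kind of training setup. The layer sizes, learning rate, and epoch count are assumptions, and random stand-in tensors with the Wine dataset's shape (13 features, 3 classes) take the place of the real preprocessed splits so the block runs on its own.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(42)

# Stand-in training data shaped like the Wine dataset (assumption):
# 13 features per sample, 3 possible classes
X_train = torch.randn(124, 13)
y_train = torch.randint(0, 3, (124,))

# A small feed-forward classifier; the layer sizes are illustrative
model = nn.Sequential(
    nn.Linear(13, 10),
    nn.ReLU(),
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 3),
)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Full-batch training loop
model.train()
losses = []
for epoch in range(150):
    optimizer.zero_grad()                       # reset accumulated gradients
    loss = criterion(model(X_train), y_train)   # forward pass and loss
    loss.backward()                             # backpropagation
    optimizer.step()                            # parameter update
    losses.append(loss.item())
```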

Saving a Model with PyTorch

After model training, it's time to preserve the model for later use. Here's where PyTorch's handy torch.save() function comes into play. It writes the model to a file that, by convention, ends in .pth (or .pt), signaling that the file holds a serialized PyTorch object.

Serialization is a process where an object in memory (like our PyTorch model) is converted into a format that can be saved on disk or sent over a network. Hence, when we save a model, we're serializing the entire module.

Here's the code that saves our model:
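A minimal sketch of that step might look like the following. The untrained stand-in model, the file name wine_model.pth, and the temporary directory are assumptions made so the example runs on its own; in the lesson, model would be the trained wine classifier.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in for the trained wine classifier (13 features in, 3 classes out)
model = nn.Sequential(
    nn.Linear(13, 10),
    nn.ReLU(),
    nn.Linear(10, 3),
)

# Hypothetical file name, placed in a temporary directory for this example
save_path = os.path.join(tempfile.mkdtemp(), "wine_model.pth")

# Serialize the entire module (architecture and weights) to disk
torch.save(model, save_path)
```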

Loading a Model with PyTorch

To use the saved model, we must load it back into memory with torch.load(). Note that the weights_only parameter is set to False, which means the entire pickled model object is loaded, including the architecture, weights, and other parameters. Because this unpickles arbitrary Python objects, only load files from sources you trust. After loading, it's essential to set the model to evaluation mode with model.eval().

Calling model.eval() switches certain layers to evaluation mode (as opposed to training mode), ensuring consistent behavior during inference: dropout layers stop dropping neurons, and batch normalization layers use their running statistics. This is good practice whenever you're using the model for predictions rather than training it.
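To make the difference between the two modes concrete, here is a small illustration on a standalone dropout layer. The layer, its dropout probability, and the input are made up for this example and are not part of the wine model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

layer = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

# Training mode: each entry is zeroed with probability p,
# and the surviving entries are scaled by 1 / (1 - p) = 2
layer.train()
train_out = layer(x)

# Evaluation mode: dropout is disabled and the input passes through unchanged
layer.eval()
eval_out = layer(x)

print("train mode:", train_out)
print("eval mode: ", eval_out)
```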

Here's how to load a saved model:
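The sketch below shows one way this could look. It first saves an untrained stand-in model so that there is a file to load; the file name wine_model.pth and the temporary directory are assumptions for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Create and save a stand-in model so the example is self-contained
model = nn.Sequential(nn.Linear(13, 10), nn.ReLU(), nn.Linear(10, 3))
save_path = os.path.join(tempfile.mkdtemp(), "wine_model.pth")
torch.save(model, save_path)

# weights_only=False unpickles the full module object,
# so use it only with files from a trusted source
loaded_model = torch.load(save_path, weights_only=False)

# Switch to evaluation mode before making predictions
loaded_model.eval()
```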

Evaluating a Loaded Model

To verify the loaded model, you should evaluate it on the test data and compare its performance with the original model. Here's how you can do it:
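One possible way to run this comparison is sketched below. The accuracy helper is a hypothetical function written for this example, and the random test tensors and untrained model are stand-ins, so the accuracy value itself is meaningless; what matters is that both models report the same number.

```python
import os
import tempfile

import torch
import torch.nn as nn

torch.manual_seed(42)

# Stand-in test data shaped like the Wine dataset (assumption)
X_test = torch.randn(54, 13)
y_test = torch.randint(0, 3, (54,))

# Save a stand-in model and load it back
model = nn.Sequential(nn.Linear(13, 10), nn.ReLU(), nn.Linear(10, 3))
save_path = os.path.join(tempfile.mkdtemp(), "wine_model.pth")
torch.save(model, save_path)
loaded_model = torch.load(save_path, weights_only=False)


def accuracy(net, X, y):
    """Return the fraction of samples in X that net classifies as y."""
    net.eval()
    with torch.no_grad():                 # no gradients needed for inference
        predictions = net(X).argmax(dim=1)
    return (predictions == y).float().mean().item()


original_accuracy = accuracy(model, X_test, y_test)
loaded_accuracy = accuracy(loaded_model, X_test, y_test)

print(f"Original model accuracy: {original_accuracy:.4f}")
print(f"Loaded model accuracy:   {loaded_accuracy:.4f}")
```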

By comparing the accuracy of the original and loaded models, you can confirm that the model was saved and loaded correctly. The two accuracies should be identical, since the loaded model carries exactly the same weights as the original. Matching values verify that the saving and loading process preserved the model accurately.

Lesson Summary and Practice

Great job on completing this session! You've learned how to save and load PyTorch models. Such a capability is crucial for efficient machine learning projects. Now, it's time for some hands-on practice. Your task is to train a model, save it, load it back into memory, and evaluate its performance on test data.

Remember: Practice is to learning what refinement is to crude oil. Keep up the hard work!

Additional Resources

For your convenience, here is a helpful code snippet for loading and preprocessing the Wine dataset, which you can use to ensure your data is properly prepared for training and evaluation in PyTorch:
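The original snippet is not reproduced here, so what follows is a sketch of one common approach using scikit-learn. The function name load_preprocessed_data, the 70/30 split, and the random seed are assumptions.

```python
import torch
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


def load_preprocessed_data():
    """Load the Wine dataset, split it, scale it, and return PyTorch tensors."""
    wine = load_wine()

    # Stratified split keeps the class proportions similar in both sets
    X_train, X_test, y_train, y_test = train_test_split(
        wine.data, wine.target, test_size=0.3, random_state=42, stratify=wine.target
    )

    # Fit the scaler on the training set only to avoid leaking test statistics
    scaler = StandardScaler().fit(X_train)
    X_train = torch.tensor(scaler.transform(X_train), dtype=torch.float32)
    X_test = torch.tensor(scaler.transform(X_test), dtype=torch.float32)

    # CrossEntropyLoss expects integer class labels of type long
    y_train = torch.tensor(y_train, dtype=torch.long)
    y_test = torch.tensor(y_test, dtype=torch.long)

    return X_train, X_test, y_train, y_test


X_train, X_test, y_train, y_test = load_preprocessed_data()
```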
