Welcome to our session on saving and loading PyTorch models. After investing hours or days in training a model, saving your progress is crucial. You may want to use the trained model on another device, share it, or just load it in the future. This lesson focuses on how to achieve that, so let's get started!
But first, let's refresh our memory from previous lessons. We trained a PyTorch model to classify wines using PyTorch's nn.Sequential class, along with CrossEntropyLoss and the Adam optimizer. Here's a quick snapshot:
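A minimal sketch of such a training setup follows. It assumes the Wine data comes from scikit-learn's load_wine. The layer sizes, learning rate, and epoch count are illustrative choices, not the lesson's exact values.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Load and standardize the Wine dataset (13 features, 3 classes)
X, y = load_wine(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
scaler = StandardScaler().fit(X_train)
X_train = torch.tensor(scaler.transform(X_train), dtype=torch.float32)
X_test = torch.tensor(scaler.transform(X_test), dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
y_test = torch.tensor(y_test, dtype=torch.long)

torch.manual_seed(42)
# Hypothetical layer sizes: 13 inputs -> 10 hidden units -> 3 class logits
model = nn.Sequential(
    nn.Linear(13, 10),
    nn.ReLU(),
    nn.Linear(10, 3),
)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Full-batch training loop; the epoch count is an assumption
for epoch in range(150):
    model.train()
    optimizer.zero_grad()
    loss = criterion(model(X_train), y_train)
    loss.backward()
    optimizer.step()

print(f"Final training loss: {loss.item():.4f}")
```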
After training, it's time to preserve the model for later use. This is where PyTorch's handy torch.save() function comes into play. It serializes the model and writes it to a file. By convention, such files end in .pth (or .pt) to signal that they hold a serialized PyTorch object.
Serialization is the process of converting an object in memory (like our PyTorch model) into a format that can be saved on disk or sent over a network. When we pass the model itself to torch.save(), the entire module is serialized using Python's pickle, not just its weights.
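To make "serialization" concrete, the sketch below writes a model into an in-memory byte buffer (io.BytesIO) instead of a file and then rebuilds it from those bytes. The small stand-in model and its layer sizes are illustrative assumptions.

```python
import io
import torch
import torch.nn as nn

# A small stand-in model (layer sizes are illustrative only)
model = nn.Sequential(nn.Linear(13, 10), nn.ReLU(), nn.Linear(10, 3))

# Serialize the whole module into an in-memory byte buffer instead of a file
buffer = io.BytesIO()
torch.save(model, buffer)
serialized = buffer.getvalue()  # raw bytes that could be written to disk or sent over a network
print(f"Serialized size: {len(serialized)} bytes")

# Deserialize: rewind the buffer and rebuild the module from the bytes
buffer.seek(0)
restored = torch.load(buffer, weights_only=False)
print(type(restored).__name__)  # Sequential
```

The same bytes could just as well be sent over a socket. Writing them to a .pth file is simply the most common destination.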
Here's the code that saves our model:
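A minimal sketch, assuming a hypothetical file name wine_model.pth and an untrained stand-in for the trained classifier. In practice you would pass your trained model.

```python
import torch
import torch.nn as nn

# Stand-in for the trained wine classifier (hypothetical layer sizes)
model = nn.Sequential(nn.Linear(13, 10), nn.ReLU(), nn.Linear(10, 3))

# Pickle the entire module (architecture + parameters) into a .pth file
torch.save(model, "wine_model.pth")
```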
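To use the saved model, we must load it back into memory with torch.load. Note that we set the weights_only parameter to False. Because the file contains a whole pickled module rather than only tensors, torch.load must be allowed to unpickle arbitrary Python objects to rebuild the model, including its architecture, weights, and other parameters. Recent PyTorch versions default to weights_only=True, which refuses to load such a file. Since unpickling can execute arbitrary code, load files this way only if you trust their source. After loading, it's essential to set the model to evaluation mode with model.eval().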
Calling model.eval() switches certain layers to evaluation mode (as opposed to training mode) so that they behave consistently during inference. For example, dropout layers stop dropping neurons, and batch normalization layers use their running statistics instead of per-batch statistics. It's good practice to call it whenever you use the model for predictions rather than training.
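A quick way to see the difference is to pass the same input through a standalone nn.Dropout layer in both modes. The dropout rate and input size below are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dropout = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

dropout.train()   # training mode: random elements are zeroed, survivors scaled by 1/(1-p)
train_out = dropout(x)

dropout.eval()    # evaluation mode: dropout passes the input through unchanged
eval_out = dropout(x)

print(train_out)  # every value is either 0.0 or 2.0
print(eval_out)   # all ones
```

Calling model.eval() on a full model sets this mode on every submodule at once, which is why it matters before inference.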
Here's how to load a saved model:
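A minimal sketch, assuming the model was saved as wine_model.pth (a hypothetical name). The first statement only creates a placeholder file so the snippet runs on its own. Skip it if wine_model.pth already holds your trained model, or it will be overwritten.

```python
import torch
import torch.nn as nn

# Placeholder: create and save a stand-in model so this snippet is self-contained
torch.save(nn.Sequential(nn.Linear(13, 10), nn.ReLU(), nn.Linear(10, 3)), "wine_model.pth")

# weights_only=False lets torch.load unpickle the full nn.Module object.
# Only do this with files from a source you trust.
loaded_model = torch.load("wine_model.pth", weights_only=False)
loaded_model.eval()  # switch to evaluation mode before inference

with torch.no_grad():  # no gradient tracking needed for predictions
    logits = loaded_model(torch.randn(4, 13))
print(logits.shape)  # torch.Size([4, 3])
```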
To verify the loaded model, you should evaluate it on the test data and compare its performance with the original model. Here's how you can do it:
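The sketch below is a self-contained version of this check. It assumes scikit-learn's load_wine for the data, a hypothetical accuracy helper, and illustrative architecture and training settings. It trains a small model, saves and reloads it, and then compares test accuracy.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# --- Setup (assumed): prepare data and train a small classifier ---
X, y = load_wine(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
scaler = StandardScaler().fit(X_train)
X_train = torch.tensor(scaler.transform(X_train), dtype=torch.float32)
X_test = torch.tensor(scaler.transform(X_test), dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
y_test = torch.tensor(y_test, dtype=torch.long)

torch.manual_seed(42)
model = nn.Sequential(nn.Linear(13, 10), nn.ReLU(), nn.Linear(10, 3))
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
for _ in range(150):
    optimizer.zero_grad()
    criterion(model(X_train), y_train).backward()
    optimizer.step()

# --- Save and reload the whole model ---
torch.save(model, "wine_model.pth")
loaded_model = torch.load("wine_model.pth", weights_only=False)

def accuracy(m, X, y):
    """Hypothetical helper: fraction of samples whose predicted class matches the label."""
    m.eval()
    with torch.no_grad():
        preds = m(X).argmax(dim=1)
    return (preds == y).float().mean().item()

original_acc = accuracy(model, X_test, y_test)
loaded_acc = accuracy(loaded_model, X_test, y_test)
print(f"Original model accuracy: {original_acc:.4f}")
print(f"Loaded model accuracy:   {loaded_acc:.4f}")
```

Because the reloaded module carries exactly the same parameters, both accuracies should match.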
By comparing the accuracy of the original and loaded models, you can confirm that the model was saved and loaded correctly and performs consistently. As shown in the output of the code above:
This indicates that the loaded model's performance is identical to the original model, verifying that the saving and loading process preserved the model accurately.
Great job on completing this session! You've learned how to save and load PyTorch models. Such a capability is crucial for efficient machine learning projects. Now, it's time for some hands-on practice. Your task is to train a model, save it, load it back into memory, and evaluate its performance on test data.
Remember: Practice is to learning what refinement is to crude oil. Keep up the hard work!
For your convenience, here is a helpful code snippet for loading and preprocessing the Wine dataset, which you can use to make sure your data is properly prepared for training and evaluation in PyTorch:
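A minimal sketch of such a snippet, assuming scikit-learn's load_wine, an 80/20 stratified split, and standardization with StandardScaler. The function name load_preprocessed_wine is hypothetical.

```python
import torch
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

def load_preprocessed_wine(test_size=0.2, random_state=42):
    """Hypothetical helper: return standardized Wine train/test splits as PyTorch tensors."""
    X, y = load_wine(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state, stratify=y
    )
    # Fit the scaler on the training split only to avoid leaking test statistics
    scaler = StandardScaler().fit(X_train)
    X_train = torch.tensor(scaler.transform(X_train), dtype=torch.float32)
    X_test = torch.tensor(scaler.transform(X_test), dtype=torch.float32)
    # CrossEntropyLoss expects integer (long) class indices as targets
    y_train = torch.tensor(y_train, dtype=torch.long)
    y_test = torch.tensor(y_test, dtype=torch.long)
    return X_train, X_test, y_train, y_test

X_train, X_test, y_train, y_test = load_preprocessed_wine()
print(X_train.shape, X_test.shape)  # torch.Size([142, 13]) torch.Size([36, 13])
```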
