Welcome back! In the last lesson, you learned how to use random search to tune hyperparameters for your machine learning models. As a quick reminder, both grid search and random search use cross-validation to evaluate how well different parameter settings perform. Cross-validation is a key technique in machine learning because it helps you get a more reliable estimate of your model’s performance. Instead of just splitting your data into a single training and test set, cross-validation splits your data into several parts, trains the model on some parts, and tests it on the others. This way, you can be more confident that your results are not just due to a lucky or unlucky split.
In this lesson, you will learn about a specific type of cross-validation called StratifiedKFold. This method is especially useful for classification problems, where you want to make sure that each fold has a similar distribution of classes. By the end of this lesson, you will know how to use StratifiedKFold with scikit-learn to evaluate your models more effectively.
StratifiedKFold is a variation of the standard K-Fold cross-validation technique. In regular K-Fold, the data is split into k roughly equal parts, or "folds." The model is trained on k-1 folds and tested on the remaining fold, and this process is repeated k times so that each fold is used as the test set exactly once. However, regular K-Fold ignores the labels when it splits. If your dataset is imbalanced (for example, if you have many more examples of one class than another), some folds can end up with far more or far fewer examples of a class than the dataset as a whole, or none at all. This is especially likely when the rows are ordered by class and the data is not shuffled before splitting.
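To see the problem concretely, here is a minimal sketch using scikit-learn's KFold on a hypothetical toy dataset: 40 samples of class 0 followed by 10 samples of class 1, stored in sorted order. The features are placeholders (all zeros) because KFold only looks at row positions when it splits.

```python
import numpy as np
from sklearn.model_selection import KFold

# Hypothetical imbalanced labels (80% class 0, 20% class 1), sorted by class.
y = np.array([0] * 40 + [1] * 10)
X = np.zeros((len(y), 1))  # placeholder features; KFold ignores their values

kf = KFold(n_splits=5)  # no shuffling: each test fold is a consecutive block of rows

# Count how many samples of each class land in every test fold.
test_counts = [np.bincount(y[test_idx], minlength=2).tolist()
               for _, test_idx in kf.split(X)]
for fold, counts in enumerate(test_counts):
    print(f"fold {fold}: test class counts = {counts}")
# fold 0 to fold 3 contain only class 0 ([10, 0]); fold 4 contains only class 1 ([0, 10])
```

In the last fold, the model is tested only on class 1 after being trained without a single class 1 example, so that fold's score says little about real-world performance.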
StratifiedKFold solves this problem by making sure that each fold has roughly the same proportion of each class as the whole dataset. This is especially important in classification tasks, where you want your model to be tested on data that looks like the real-world data it will see in practice. Using stratification helps you avoid misleading results that can happen if some folds have too many or too few examples of a certain class.
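Here is a minimal sketch of that guarantee, assuming a hypothetical toy dataset with 40 samples of class 0 and 10 of class 1 (an 80/20 split). Unlike KFold, StratifiedKFold needs the labels, so y is passed to split().

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Hypothetical imbalanced labels (80% class 0, 20% class 1), sorted by class.
y = np.array([0] * 40 + [1] * 10)
X = np.zeros((len(y), 1))  # placeholder features

skf = StratifiedKFold(n_splits=5)

# split() uses y to keep the class proportions the same in every fold.
test_counts = [np.bincount(y[test_idx], minlength=2).tolist()
               for _, test_idx in skf.split(X, y)]
for fold, counts in enumerate(test_counts):
    print(f"fold {fold}: test class counts = {counts}")
# every fold: [8, 2], the same 80/20 mix as the full dataset
```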
Before you can use StratifiedKFold, you need to choose a model to evaluate. In this example, we will use a Support Vector Classifier (SVC) from scikit-learn. SVC is a popular model for classification tasks, and like any scikit-learn estimator it can be passed directly to scikit-learn's cross-validation tools. Using SVC will help you see how StratifiedKFold can be applied to real-world classification problems.
Let’s walk through a practical example of using StratifiedKFold with scikit-learn’s cross_val_score function. First, let’s assume you have your features in X and your labels in y. If you don’t already have data loaded, you can use a sample dataset from scikit-learn for demonstration purposes. For example, here’s how you can load the Iris dataset:
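A minimal way to do this, assuming scikit-learn is installed, is load_iris with return_X_y=True, which returns the features and labels directly as NumPy arrays instead of a dataset object.

```python
from sklearn.datasets import load_iris

# return_X_y=True gives (features, labels) instead of a Bunch object.
X, y = load_iris(return_X_y=True)
print(X.shape, y.shape)  # (150, 4) (150,)
```

Iris has 150 samples, 4 numeric features, and 3 classes with 50 samples each.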
Now, you want to evaluate how well an SVC model performs using 5-fold stratified cross-validation. Here is how you can do it:
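A minimal sketch of this evaluation, assuming the Iris data and default SVC settings. It follows the steps described in the next paragraph, but treat it as one possible version rather than the lesson's exact code.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)

model = SVC()                      # support vector classifier, default settings
cv = StratifiedKFold(n_splits=5)   # 5 folds, each with about the same class mix as y

# For each fold: fit a copy of the model on the other 4 folds, score it on this one.
scores = cross_val_score(model, X, y, cv=cv)
print("Fold scores:", scores)
print("Mean score:", scores.mean())
```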
In this code, you first import the necessary classes and functions. SVC() creates a support vector classifier with default settings. The StratifiedKFold(n_splits=5) object tells scikit-learn to split the data into 5 folds, making sure each fold has a similar class distribution. The cross_val_score function then, for each fold, trains a fresh copy of the model on the other 4 folds and evaluates it on the held-out fold, returning an array of 5 scores, one per fold. For a classifier such as SVC, the default score is accuracy.
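Two related details are worth knowing. When cv is an integer and the estimator is a classifier with binary or multiclass labels, cross_val_score already uses StratifiedKFold (without shuffling) under the hood. Passing the StratifiedKFold object yourself is still useful when you want to control the split, for example by shuffling with a fixed random_state. A minimal sketch on Iris:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)

explicit = cross_val_score(SVC(), X, y, cv=StratifiedKFold(n_splits=5))
# An integer cv with a classifier and multiclass y becomes StratifiedKFold(5),
# so the folds, and therefore the scores, are identical.
implicit = cross_val_score(SVC(), X, y, cv=5)
print(np.allclose(explicit, implicit))  # True

# Shuffle before splitting; random_state makes the shuffled folds reproducible.
shuffled_cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
shuffled = cross_val_score(SVC(), X, y, cv=shuffled_cv)
print(shuffled)
```

Shuffling matters most when the rows are stored in a meaningful order; stratification still holds within each shuffled fold.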
For example, suppose the 5 fold scores all fall between 0.89 and 0.93, with an average of 0.91. That average is your cross-validated estimate of how well the model is likely to perform on new, unseen data, and the narrow spread across folds suggests the estimate does not depend heavily on which split was used.
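To compute the average yourself and see how much the scores vary between folds, you can summarize the returned array. A minimal sketch on the Iris data, assuming the same 5-fold stratified setup; reporting the standard deviation alongside the mean is a common convention rather than something this lesson requires.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)
scores = cross_val_score(SVC(), X, y, cv=StratifiedKFold(n_splits=5))

# Mean: overall performance estimate. Std and min/max: fold-to-fold variability.
print(f"mean = {scores.mean():.3f}, std = {scores.std():.3f}")
print(f"range = {scores.min():.3f} to {scores.max():.3f}")
```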
In this lesson, you learned how to use StratifiedKFold to perform cross-validation in a way that respects the class distribution of your data. This is especially important for classification problems, where imbalanced classes can lead to misleading results if not handled properly. You saw how to set up your data and model, and how to use cross_val_score with a stratified cross-validator to get a reliable estimate of your model’s performance.
Next, you will get a chance to practice these steps yourself. In the upcoming exercises, you will use StratifiedKFold to evaluate different models and see how cross-validation can help you build more reliable machine learning solutions. Take your time to understand the output and think about how you can use these techniques in your own projects. Good luck!
