Introduction

Welcome back to the fourth lesson of JAX in Action: Building an Image Classifier! We've built a solid foundation over the previous lessons: efficient data loading with preprocessing pipelines, a sophisticated CNN architecture using Flax, and the essential training and evaluation utilities that form the computational core of our learning system. Now comes the exciting culmination, where we orchestrate all these components into a complete, functioning training pipeline.

In this lesson, we'll build the main orchestration layer that brings everything together in a seamless training and evaluation workflow. We'll implement the full pipeline in our main module, complete with hyperparameter configuration, systematic epoch-based training loops, and comprehensive evaluation cycles that track our model's learning progress. This represents the final piece of our machine learning puzzle, transforming our modular components into a cohesive system capable of training a CNN from random initialization to high accuracy on MNIST digit classification.

Setting Up Hyperparameters

Let's begin our implementation by establishing the fundamental configuration that will support our entire training process:
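Here is a minimal sketch of what this configuration block could look like; the constant names (NUM_EPOCHS, BATCH_SIZE, LEARNING_RATE, SEED) and the seed value of 42 are illustrative assumptions rather than an exact listing:

```python
# Hyperparameter configuration (names and seed value are illustrative).
NUM_EPOCHS = 5        # short run, since we limit the dataset size below
BATCH_SIZE = 64       # balances computational efficiency and gradient noise
LEARNING_RATE = 1e-3  # common starting point for Adam
SEED = 42             # PRNG seed for reproducible runs
```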

This configuration section establishes the hyperparameter foundation for our training process. We set a modest 5 epochs since we're limiting our dataset size to keep execution time reasonable for demonstration purposes. The batch size of 64 provides a good balance between computational efficiency and gradient noise, while the learning rate of 1e-3 represents a proven starting point for Adam optimization. The PRNG seed ensures reproducible results across different runs, which is crucial for debugging and comparison purposes.

Setting Up Data Loading

Let's now focus on the data infrastructure:
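A sketch of this setup might look like the following, using TensorFlow Datasets directly; the preprocess helper here is a stand-in for the preprocessing pipeline built in the earlier data-loading lesson:

```python
import math

import tensorflow as tf
import tensorflow_datasets as tfds

NUM_TRAIN_SAMPLES = 1000
NUM_TEST_SAMPLES = 200

def preprocess(image, label):
    # Stand-in for the earlier preprocessing pipeline: scale pixels to [0, 1].
    return tf.cast(image, tf.float32) / 255.0, label

train_ds = (
    tfds.load('mnist', split=f'train[:{NUM_TRAIN_SAMPLES}]', as_supervised=True)
    .map(preprocess)
    .shuffle(NUM_TRAIN_SAMPLES, seed=SEED)
    .batch(BATCH_SIZE)
)
test_ds = (
    tfds.load('mnist', split=f'test[:{NUM_TEST_SAMPLES}]', as_supervised=True)
    .map(preprocess)
    .batch(BATCH_SIZE)
)

# Number of batches per epoch, computed upfront for the training loop.
num_train_batches = math.ceil(NUM_TRAIN_SAMPLES / BATCH_SIZE)
num_test_batches = math.ceil(NUM_TEST_SAMPLES / BATCH_SIZE)
```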

The data loading configuration deliberately limits our dataset size to 1,000 training samples and 200 test samples. This limitation serves a practical purpose: keeping our training demonstration fast and focused while still providing meaningful learning dynamics. We calculate the exact number of batches upfront, which will be essential for our training loop logic and ensures we know precisely how many iterations to perform in each epoch.

Initializing Model and Optimizer Components

With our data pipeline established, we now initialize the core learning components that will drive our training process:
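A sketch of this step is shown below; CNN stands in for the Flax module defined in the architecture lesson, and the exact variable names are assumptions:

```python
import jax
import jax.numpy as jnp
import optax

model = CNN()  # the Flax module built in the earlier architecture lesson

# Modern typed PRNG key, split so initialization gets its own stream.
key = jax.random.key(SEED)
key, init_key = jax.random.split(key)

# A dummy batch with MNIST's shape triggers Flax's lazy shape inference.
dummy_input = jnp.ones((1, 28, 28, 1))
variables = model.init(init_key, dummy_input)
current_params = variables['params']

# Adam's state tracks per-parameter momentum and variance estimates.
optimizer = optax.adam(LEARNING_RATE)
current_opt_state = optimizer.init(current_params)
```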

This initialization sequence demonstrates standard JAX practices for model setup. We create a proper JAX random key using the modern jax.random.key function, then split it to ensure clean separation between different random operations. The model initialization requires a dummy input tensor matching our expected data shape, which triggers Flax's lazy initialization mechanism to determine all parameter dimensions automatically.

The optimizer initialization creates the internal state that Adam requires to track momentum and variance estimates for each parameter. This state will evolve throughout training, accumulating the historical information that enables Adam's adaptive learning rate behavior. By storing these components in current_params and current_opt_state, we establish the mutable state that our training loop will continuously update.

Creating a Reusable Evaluation Function

Now, let's implement a reusable evaluation function that will serve both our initial baseline assessment and our epoch-by-epoch monitoring:
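One way to write this function is sketched below; it assumes the jitted eval_step from the previous lesson returns a (loss, accuracy) pair for a single batch:

```python
def evaluate_model(params, dataset):
    """Return sample-weighted average loss and accuracy over a dataset."""
    total_loss, total_accuracy, total_samples = 0.0, 0.0, 0

    for images, labels in dataset.as_numpy_iterator():
        batch_size = images.shape[0]
        loss, accuracy = eval_step(params, images, labels)

        # Weight each batch's metrics by its actual size so that a smaller
        # final batch does not skew the averages.
        total_loss += float(loss) * batch_size
        total_accuracy += float(accuracy) * batch_size
        total_samples += batch_size

    return total_loss / total_samples, total_accuracy / total_samples
```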

This function encapsulates the complete evaluation workflow with proper metric aggregation. By extracting this logic into a standalone function, we eliminate code duplication and ensure consistent evaluation behavior throughout our training pipeline. The function handles weighted averaging correctly by multiplying each batch's metrics by its actual size before accumulating, then dividing by the total number of samples. This approach gracefully handles variable batch sizes and provides accurate overall metrics.

Establishing the Initial Baseline

Before beginning training, it's crucial to establish a performance baseline that shows our model's capabilities with random initialization:
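With evaluate_model in place, the baseline check could be as simple as this sketch:

```python
# Evaluate the randomly initialized model; roughly 10% accuracy is expected
# for a 10-class problem like MNIST.
initial_loss, initial_accuracy = evaluate_model(current_params, test_ds)
print(f"Initial evaluation - loss: {initial_loss:.4f}, "
      f"accuracy: {initial_accuracy * 100:.2f}%")
```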

This initial evaluation serves multiple important purposes. First, it confirms that our evaluation pipeline works correctly before we begin the training process. Second, it provides a crucial baseline that helps us understand the magnitude of improvement achieved through training. For a 10-class classification problem like MNIST, we expect random performance to hover around 10% accuracy; any significant deviation from this baseline could indicate initialization issues or implementation bugs that should be addressed before proceeding with training.

Implementing the Main Training Loop

Now we implement the core training loop that orchestrates the learning process across multiple epochs:
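A sketch of the loop is shown below; it assumes the jitted train_step from the previous lesson returns updated parameters, updated optimizer state, and the batch loss:

```python
import time

for epoch in range(NUM_EPOCHS):
    epoch_start = time.time()
    epoch_loss = 0.0

    # Inner loop: one pass over the training batches.
    for images, labels in train_ds.as_numpy_iterator():
        # Convert TensorFlow dataset outputs to JAX arrays.
        images, labels = jnp.asarray(images), jnp.asarray(labels)

        current_params, current_opt_state, loss = train_step(
            current_params, current_opt_state, images, labels
        )
        epoch_loss += float(loss)

    avg_train_loss = epoch_loss / num_train_batches
    epoch_duration = time.time() - epoch_start
    # Evaluation and logging for this epoch follow below.
```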

The training loop structure follows a nested iteration pattern: the outer loop progresses through epochs, while the inner loop processes individual batches within each epoch. We measure epoch duration to provide insight into training performance, which becomes crucial when scaling to larger datasets.

Each batch processing cycle converts the TensorFlow dataset outputs to JAX arrays, then calls our train_step function with the current model state. Notice how we continuously update current_params and current_opt_state — this functional approach ensures that each training step builds upon the previous one's learning. The loss accumulation allows us to compute average training loss per epoch, providing a smooth metric for monitoring training progress.

Continuous Evaluation with Our Reusable Function

The evaluation phase leverages our dedicated evaluation function for clean, consistent metric computation:
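Inside the epoch loop, immediately after the batch loop finishes, this amounts to a single call (continuing the sketch above):

```python
    # Still inside the epoch loop: evaluate on the held-out test subset
    # with the same function used for the initial baseline.
    eval_loss, eval_accuracy = evaluate_model(current_params, test_ds)
```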

By using our evaluate_model function, we ensure that both our initial baseline and epoch-by-epoch evaluations follow identical logic and produce comparable results. This consistency is crucial for tracking learning progress accurately and identifying potential issues like overfitting or convergence problems. The function handles all the complexity of batch processing and metric aggregation, allowing our main training loop to focus on the high-level orchestration.

Logging the Progress

Let's complete our training pipeline by adding comprehensive logging:
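Still inside the epoch loop, a per-epoch progress report might look like this sketch (the exact message format is an assumption):

```python
    # Per-epoch summary: training loss, evaluation metrics, and timing.
    print(
        f"Epoch {epoch + 1}/{NUM_EPOCHS} - "
        f"train loss: {avg_train_loss:.4f} - "
        f"eval loss: {eval_loss:.4f} - "
        f"eval accuracy: {eval_accuracy * 100:.2f}% - "
        f"duration: {epoch_duration:.2f}s"
    )
```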

When we execute this complete training pipeline, we observe the following progression:

This output reveals compelling learning dynamics: starting from the expected random baseline of 9.38% accuracy, our model demonstrates dramatic improvement to 45.31% after just the first epoch. The training loss decreases steadily from 2.33 to 0.58, while evaluation accuracy continues its remarkable climb to 81.25%. The drop in epoch duration after the first epoch reflects JAX's JIT compilation: the first epoch pays the one-time compilation cost, while later epochs reuse the compiled functions. The consistent improvement in both training and evaluation metrics indicates healthy learning without overfitting, validating our architecture and hyperparameter choices.

Conclusion and Next Steps

Congratulations! You've successfully implemented a complete end-to-end training pipeline that orchestrates all the components we've built throughout this course. Your implementation demonstrates professional machine learning practices: modular design with clear separation of concerns, proper state management across training steps, comprehensive metric aggregation, and systematic progress monitoring.

The remarkable journey from 9.38% to 81.25% accuracy showcases the power of modern deep learning frameworks like JAX, Flax, and Optax working in harmony. You've experienced firsthand how a randomly initialized CNN can learn to recognize handwritten digits through the elegant interplay of forward propagation, backpropagation, and parameter optimization. Now it's time to apply these concepts in practice and solidify your understanding by implementing variations and enhancements to this foundation.
