Introduction

Welcome to the third lesson of JAX in Action: Building an Image Classifier! Your journey from data preprocessing to neural architecture definition has laid a strong foundation, and now we stand at a pivotal moment. With MNIST data flowing through our pipeline and a sophisticated CNN architecture ready to learn, it's time to breathe life into our system by implementing the mechanisms that will transform random parameters into learned knowledge.

This lesson dives into the core engine of machine learning: the training and evaluation functions that orchestrate the learning process. We'll craft JIT-compiled functions that handle loss computation, automatic differentiation, parameter updates, and performance assessment — the essential components that separate a static model from an active learning system. JAX's functional programming paradigm and powerful compilation capabilities will enable us to build these components with both elegance and efficiency. By lesson's end, you'll have implemented a complete training utilities module that can guide our CNN from random initialization to accurate digit recognition.

Understanding Training vs Evaluation Modes

Before we implement our training engine, it's crucial to understand the fundamental distinction between training and evaluation modes in neural networks. This distinction goes beyond simple terminology: it represents two different computational pathways that serve complementary purposes in the machine learning workflow.

During training mode, our system operates as an active learner, constantly adjusting its understanding based on feedback. The training process follows a specific computational flow: forward propagation generates predictions from input data, a loss function quantifies the prediction errors, automatic differentiation computes gradients showing how to reduce these errors, and an optimizer translates gradients into parameter updates. This cycle requires maintaining additional state (like momentum in the Adam optimizer) and performing expensive gradient computations, but it's the only way our model can learn from data.

Evaluation mode strips away the learning machinery to focus purely on assessment. Here, we perform only forward propagation to generate predictions, then compute metrics like accuracy and loss to gauge performance. No gradients are calculated, no parameters are updated, and no optimizer state is maintained. This streamlined approach makes evaluation much faster than training, allowing us to efficiently monitor learning progress during training or assess final model quality on test data. Think of training mode as a student actively learning and adjusting their understanding, while evaluation mode is like taking a test where knowledge is assessed but not modified.

Building the Loss Function Foundation

The heart of any supervised learning system is its loss function, that is, the mathematical expression that quantifies how wrong our predictions are. For multi-class classification tasks like MNIST digit recognition, softmax cross-entropy is considered the gold standard:
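
A minimal sketch of this loss helper is shown below. The function name cross_entropy_loss and its signature are assumptions about how the lesson's training utilities module is laid out, but the core call to optax.softmax_cross_entropy_with_integer_labels follows the description that comes next.

```python
import jax.numpy as jnp
import optax


def cross_entropy_loss(logits, labels):
    """Mean softmax cross-entropy between raw logits and integer class labels."""
    # Optax applies a numerically stable softmax internally and returns the
    # per-example negative log-likelihood of the true class.
    per_example = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=labels
    )
    # Averaging over the batch yields the single scalar the optimizer needs.
    return jnp.mean(per_example)
```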

This deceptively simple function encapsulates sophisticated numerical computation. The optax.softmax_cross_entropy_with_integer_labels function expects raw logits (the unnormalized outputs from our CNN's final layer) and integer labels (0 through 9 for our digits). Internally, it applies the softmax function to convert logits into probabilities, then computes the negative log-likelihood of the true class. This computation is numerically stabilized to avoid overflow issues that plague naive implementations. By taking the mean across the batch, we obtain a single scalar that represents the average prediction error, which is exactly what our optimizer needs to guide learning. The beauty of this loss function lies in its gradient properties: it provides strong learning signals when predictions are wrong while naturally tapering off as the model improves.

Implementing JIT-Compiled Training Steps

Now we reach the centerpiece of our training engine: the function that orchestrates a complete parameter update cycle. This implementation showcases JAX's ability to combine multiple computational patterns into a single, efficient operation:
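
The sketch below shows one way such a train step can be written. The argument order (model function at position 0, optimizer at position 3) matches the static_argnums=(0, 3) discussed next; the function and argument names are assumptions, and cross_entropy_loss is the loss helper sketched earlier, assumed to live in the same module.

```python
import functools

import jax
import optax


@functools.partial(jax.jit, static_argnums=(0, 3))
def train_step(model_fn, params, opt_state, optimizer, images, labels):
    """Run one complete update cycle: loss, gradients, optimizer update, new params."""

    def compute_loss_for_grad(current_params):
        # Forward pass with the candidate parameters, reduced to a scalar loss.
        logits = model_fn(current_params, images)
        return cross_entropy_loss(logits, labels)  # helper defined earlier in this module

    # Loss value (for monitoring) and gradients (for learning) in a single pass.
    loss, grads = jax.value_and_grad(compute_loss_for_grad)(params)
    # The optimizer turns raw gradients into parameter updates (momentum, adaptive
    # learning rates, and so on, depending on the chosen Optax transformation).
    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss
```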

The JIT compilation strategy deserves special attention. By using functools.partial with static_argnums=(0, 3), we tell JAX that the model function and optimizer are static, that is, they define the computation graph structure rather than flowing through it as data. This prevents costly recompilation when only parameter values change between training steps. The nested compute_loss_for_grad function creates a clean interface for automatic differentiation: it takes parameters as its first argument and returns a scalar loss, exactly what jax.value_and_grad expects. The transformed function simultaneously computes the loss value (for monitoring) and its gradients (for learning) in a single forward-backward pass, maximizing computational efficiency. The Optax optimizer then transforms raw gradients into sophisticated parameter updates (incorporating momentum, adaptive learning rates, or other techniques), which we apply to obtain our evolved parameters. This functional decomposition, which separates gradient computation, update calculation, and parameter application, provides both clarity and flexibility.

Creating Evaluation Functions for Model Assessment

While training drives learning, evaluation provides the compass for assessing the model's behavior. Our evaluation function distills model assessment to its essential components:
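
A corresponding sketch of the evaluation step, under the same naming assumptions as the training step above:

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnums=(0,))
def eval_step(model_fn, params, images, labels):
    """Forward pass only: return loss and accuracy for a single batch."""
    logits = model_fn(params, images)
    # Reuse the same cross-entropy helper for a metric consistent with training.
    loss = cross_entropy_loss(logits, labels)
    # Predicted class = index of the largest logit; accuracy = fraction correct.
    predictions = jnp.argmax(logits, axis=-1)
    accuracy = jnp.mean(predictions == labels)
    return loss, accuracy
```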

This streamlined function demonstrates evaluation's computational simplicity compared to training. We perform a single forward pass to obtain logits, then compute two key metrics. The loss calculation reuses our cross-entropy function, providing a consistent measure across training and evaluation. For accuracy, we extract predicted classes using jnp.argmax (finding the digit with the highest confidence) and compare them against the true labels. The element-wise equality check produces a boolean array, whose mean gives us the fraction of correct predictions — an intuitive metric that complements the more abstract loss value. Notice that we only mark the model function as static since evaluation doesn't involve optimizers, resulting in even more efficient JIT compilation.

Integrating Training Utilities in Main Pipeline

With our core functions implemented, let's integrate them into a cohesive system that brings together data, model, and optimization components:
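
The sketch below shows one plausible wiring of these pieces in a main module. The module names (data_loader, model, training_utils), the CNN being a Flax module with init and apply, and the exact return values of load_mnist_dataset are all assumptions about how the earlier lessons organized the project.

```python
import jax
import jax.numpy as jnp
import optax

from data_loader import load_mnist_dataset         # assumed module name
from model import CNN                               # assumed module and class name
from training_utils import train_step, eval_step    # assumed module name


def main():
    # Hyperparameters as constants at the top, making experimentation easy.
    LEARNING_RATE = 1e-3
    BATCH_SIZE = 64   # used by the training loop in the next lesson
    SEED = 42

    # Load MNIST and convert the NumPy arrays into JAX arrays.
    train_images, train_labels, test_images, test_labels = load_mnist_dataset()
    train_images, train_labels = jnp.asarray(train_images), jnp.asarray(train_labels)
    test_images, test_labels = jnp.asarray(test_images), jnp.asarray(test_labels)

    # Initialize the CNN parameters and the Adam optimizer state.
    model = CNN()
    params = model.init(jax.random.PRNGKey(SEED), train_images[:1])
    optimizer = optax.adam(LEARNING_RATE)
    opt_state = optimizer.init(params)

    # train_step and eval_step are now ready to be exercised (see the test
    # below); the full multi-epoch training loop arrives in the next lesson.
    return model, params, optimizer, opt_state


if __name__ == "__main__":
    main()
```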

This integration exemplifies modular design principles in machine learning systems. By importing our training utilities from a dedicated module, we maintain a clear separation of concerns: data loading handles input processing, the model defines architecture, and training utilities manage the learning process. The hyperparameters are defined as constants at the function's beginning, making experimentation straightforward. We convert the NumPy arrays returned by load_mnist_dataset to JAX arrays to ensure compatibility with our JAX-based training functions.

Testing Our Training Engine

Let's validate our training engine with a comprehensive test that exercises both training and evaluation paths:
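
One possible shape for such a test is sketched below. It uses a small synthetic batch so the script is self-contained; the lesson's actual test may instead pull a real batch from the MNIST pipeline, and the module names are the same assumptions as before.

```python
import jax
import jax.numpy as jnp
import optax

from model import CNN                               # assumed module and class name
from training_utils import train_step, eval_step    # assumed module name


def test_training_engine():
    # A synthetic MNIST-shaped batch: 32 images with random integer labels.
    key_img, key_lbl, key_init = jax.random.split(jax.random.PRNGKey(0), 3)
    images = jax.random.normal(key_img, (32, 28, 28, 1))
    labels = jax.random.randint(key_lbl, (32,), 0, 10)

    model = CNN()
    params = model.init(key_init, images)
    optimizer = optax.adam(1e-3)
    opt_state = optimizer.init(params)

    # Training path: one parameter-update cycle.
    new_params, opt_state, loss = train_step(
        model.apply, params, opt_state, optimizer, images, labels
    )
    print(f"Initial training loss: {float(loss):.4f}")  # roughly 2.3 at random init

    # Verify that training actually changed the parameters: tree_map compares
    # each leaf, tree_all collapses the comparisons into one boolean.
    unchanged = jax.tree_util.tree_all(
        jax.tree_util.tree_map(lambda old, new: jnp.all(old == new), params, new_params)
    )
    print(f"Parameters changed: {not unchanged}")

    # Evaluation path: loss and accuracy from a single forward pass.
    eval_loss, accuracy = eval_step(model.apply, new_params, images, labels)
    print(f"Eval loss: {float(eval_loss):.4f}, accuracy: {float(accuracy):.4f}")


if __name__ == "__main__":
    test_training_engine()
```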

Running this test produces output that confirms our implementation's correctness.

The parameter change verification using JAX's tree utilities confirms that training actually updates our model. Remember that tree_map applies an element-wise comparison across our entire parameter tree, while tree_all aggregates these comparisons into a single boolean. The initial loss of around 2.3 aligns perfectly with theoretical expectations: for 10-class classification with random initialization, we expect a loss of -ln(1/10) ≈ 2.3. The evaluation step also executes correctly. These metrics confirm our training engine is ready to guide the model toward meaningful learning.

Conclusion and Next Steps

You've successfully implemented the computational heart of a modern deep learning system: JIT-compiled training and evaluation functions that efficiently orchestrate the entire learning pipeline. Your modular design separates concerns elegantly: dedicated utilities for training operations, clean integration with data and model components, and efficient use of JAX's compilation capabilities for maximum performance.

The stage is now perfectly set for the exciting culmination of our work: implementing complete training loops that will transform our randomly initialized CNN into an accurate digit classifier. In our next lesson, we'll build upon this foundation to create full training runs with multiple epochs, progress monitoring, and systematic evaluation, bringing us to the thrilling moment when our model achieves high accuracy on the MNIST dataset!
