Introduction

Welcome back to the final lesson of JAX in Action: Neural Networks from Scratch! What an incredible journey we've taken together through the fundamentals of building neural networks from the ground up. We started by tackling the classic XOR problem and establishing our MLP architecture with proper parameter initialization. We then implemented forward propagation to see our network generate predictions, and in our previous lesson, we developed the crucial ability to measure prediction quality using Binary Cross Entropy loss and compute gradients with JAX's automatic differentiation.

Today marks the culmination of our neural network construction project. We'll implement a complete training loop that iteratively improves our MLP's performance using gradient descent. You'll learn to orchestrate the training process by repeatedly computing predictions, calculating loss and gradients, and updating parameters to minimize error. By the end of this lesson, you'll witness your network actually learning to solve the XOR problem through experience — a truly magical moment in machine learning!

The Core of Machine Learning: Iterative Improvement

At its heart, machine learning is about iterative improvement. Think of learning to play a musical instrument: you don't master it overnight, but through countless repetitions, gradually adjusting your technique based on feedback. Neural network training follows this same principle.

Our training process operates in cycles called epochs. During each epoch, we present our network with the training data and perform these key steps: compute predictions using forward propagation, calculate how wrong these predictions are using our loss function, determine the direction to adjust parameters using gradients, and finally update parameters to (hopefully) improve performance. This cycle repeats hundreds or thousands of times until our network performs well.

The beauty of this approach lies in its simplicity — we don't need to manually figure out what each parameter should be. Instead, we let the mathematical framework of gradient descent guide us toward better solutions automatically. Each iteration brings us closer to a network that can accurately solve our XOR problem.

Gradient Descent: The Update Rule

The mathematical foundation of our training process is gradient descent, an optimization algorithm that systematically adjusts parameters to minimize our loss function. The core update rule is elegantly simple:

$$W_{new} = W_{old} - \alpha \cdot \frac{\partial \mathcal{L}}{\partial W}$$
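
To make the rule concrete, here is a minimal sketch that applies it to a single scalar parameter. The toy quadratic loss, the starting value, and the learning rate `alpha` are illustrative assumptions, not part of our XOR network; the gradient comes from `jax.grad`.

```python
import jax

def loss(w):
    # Toy quadratic loss (an assumption for illustration); minimum at w = 3.0
    return (w - 3.0) ** 2

grad_fn = jax.grad(loss)  # returns dL/dw

alpha = 0.1  # learning rate (illustrative value)
w = 0.0      # starting parameter value

for step in range(50):
    # W_new = W_old - alpha * dL/dW
    w = w - alpha * grad_fn(w)

print(f"w after 50 steps: {float(w):.4f}")
```

Each step moves `w` a fraction of the way toward the minimum, so the parameter settles near 3.0. A larger `alpha` takes bigger steps but can overshoot.
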
Implementing Parameter Updates with tree_map

JAX provides powerful utilities for working with PyTrees through the jax.tree_util module. Recall tree_map, which we encountered previously: it applies an operation leaf by leaf across the corresponding leaves of one or more PyTrees that share the same structure.

Let's say we have our current parameters (current_params) and their corresponding gradients, both structured as PyTrees. We want to apply the gradient descent update to each parameter. We can define a simple function for the update rule:
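
A minimal sketch of such an update, using the names from this section (`apply_gradient_step`, `current_params`, `gradients`, `learning_rate`). The parameter layout (one dict with "W" and "b" per layer), the stand-in gradient values, the learning rate, and the name `updated_params` are assumptions made for illustration.

```python
import jax.numpy as jnp
import jax.tree_util as jtu

learning_rate = 0.1  # illustrative value

# Assumed layout: a list with one {"W", "b"} dict per layer
current_params = [
    {"W": jnp.ones((2, 3)), "b": jnp.zeros((3,))},
    {"W": jnp.ones((3, 1)), "b": jnp.zeros((1,))},
]

# Stand-in gradients with the same structure (normally produced by jax.grad)
gradients = jtu.tree_map(jnp.ones_like, current_params)

def apply_gradient_step(param_leaf, grad_leaf):
    # Gradient descent update for a single leaf array
    return param_leaf - learning_rate * grad_leaf

# updated_params is a hypothetical name for the resulting PyTree
updated_params = jtu.tree_map(apply_gradient_step, current_params, gradients)
```

Because every gradient here is 1, each weight drops from 1.0 to 0.9 and each bias from 0.0 to -0.1, while the list-of-dicts structure is preserved.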

Here's what makes this approach powerful:

  • apply_gradient_step takes a single parameter value (param_leaf) and its corresponding gradient (grad_leaf). It uses learning_rate (which we'll define in our training setup) to compute the updated parameter.
  • jtu.tree_map(apply_gradient_step, current_params, gradients) then traverses the current_params PyTree and the gradients PyTree in parallel. For every matching pair of parameter leaf and gradient leaf, it calls apply_gradient_step and collects the results into a new PyTree with the same structure, holding the updated values. This means we don't need to iterate manually through layers or dictionaries; tree_map handles the PyTree structure for us.

Building the Complete Training Loop

Now we'll construct the main training loop that orchestrates the entire learning process. This loop will cycle through epochs, and in each epoch, it will compute loss and gradients, then apply updates to our network's parameters.

First, let's set up our training configuration and recall our loss_and_grad_fn from the previous lesson, which gives us both the loss and the gradients.
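
The sketch below shows one way this setup and loop could fit together end to end. Only the names `learning_rate`, `num_epochs`, `loss_and_grad_fn`, `current_params`, `xor_X`, `xor_y`, `initialize_mlp_params`, and `compute_loss_for_mlp` come from this course; their signatures, the `mlp_forward` helper, the 2-3-1 layer sizes, the activations, the initialization scheme, and the learning rate value are all assumptions.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu

# XOR inputs and targets
xor_X = jnp.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
xor_y = jnp.array([[0.], [1.], [1.], [0.]])

def initialize_mlp_params(key, layer_sizes):
    # Assumed signature: returns one {"W", "b"} dict per layer
    keys = jax.random.split(key, len(layer_sizes) - 1)
    params = []
    for k, n_in, n_out in zip(keys, layer_sizes[:-1], layer_sizes[1:]):
        W = jax.random.normal(k, (n_in, n_out)) * jnp.sqrt(2.0 / n_in)
        params.append({"W": W, "b": jnp.zeros((n_out,))})
    return params

def mlp_forward(params, x):
    # Hypothetical forward pass: tanh hidden layers, sigmoid output
    for layer in params[:-1]:
        x = jnp.tanh(x @ layer["W"] + layer["b"])
    return jax.nn.sigmoid(x @ params[-1]["W"] + params[-1]["b"])

def compute_loss_for_mlp(params, x, y):
    # Mean Binary Cross Entropy; clipping avoids log(0)
    preds = jnp.clip(mlp_forward(params, x), 1e-6, 1.0 - 1e-6)
    return -jnp.mean(y * jnp.log(preds) + (1.0 - y) * jnp.log(1.0 - preds))

# JIT-compiled function returning (loss, gradients w.r.t. params)
loss_and_grad_fn = jax.jit(jax.value_and_grad(compute_loss_for_mlp))

learning_rate = 0.1  # assumed value
num_epochs = 2000

current_params = initialize_mlp_params(jax.random.PRNGKey(0), [2, 3, 1])

def apply_gradient_step(param_leaf, grad_leaf):
    return param_leaf - learning_rate * grad_leaf

# Loss before any update, kept for a before/after comparison
initial_loss, _ = loss_and_grad_fn(current_params, xor_X, xor_y)

for epoch in range(num_epochs):
    loss, gradients = loss_and_grad_fn(current_params, xor_X, xor_y)
    current_params = jtu.tree_map(apply_gradient_step, current_params, gradients)
```

Note that `current_params` is rebound on every iteration rather than mutated in place: JAX arrays are immutable, so each epoch produces a fresh PyTree of parameters.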

Let's break down this crucial piece of code:

  • We set a learning_rate and the num_epochs (how many times we'll iterate through the data).
  • loss_and_grad_fn is our JIT-compiled function from the last lesson that efficiently calculates both the loss and the gradients for current_params given the xor_X data and xor_y targets.
  • The for loop runs num_epochs times; each iteration computes the loss and gradients, then applies the update to current_params.

Monitoring Training Progress

To understand how well our training is progressing, we need to monitor and display the loss values. This feedback helps us verify that learning is occurring and provides insights into the training dynamics. We can simply add a print statement inside our loop.
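
As a rough sketch of the logging pattern, the loop below prints the loss on the first epoch and then at a fixed interval. To stay short it trains a toy scalar problem rather than our MLP; `toy_loss`, `log_every`, the `logged` list, and the print format are hypothetical choices for illustration.

```python
import jax

def toy_loss(w):
    # Stand-in for the MLP loss so the example stays self-contained
    return (w - 3.0) ** 2

loss_and_grad = jax.value_and_grad(toy_loss)

learning_rate = 0.1
num_epochs = 100
log_every = 20  # hypothetical logging interval

w = 0.0
logged = []  # (epoch, loss) pairs, handy for plotting later
for epoch in range(num_epochs):
    loss, grad = loss_and_grad(w)
    w = w - learning_rate * grad

    # Report on the first epoch and on every log_every-th epoch
    if epoch == 0 or (epoch + 1) % log_every == 0:
        logged.append((epoch + 1, float(loss)))
        print(f"Epoch {epoch + 1:4d}/{num_epochs}, Loss: {float(loss):.4f}")
```

Logging at an interval keeps the output readable, and converting the loss with `float(...)` pulls the value out of the JAX array for formatting.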

When we integrate this into our training loop and run it (assuming all previous functions like initialize_mlp_params, compute_loss_for_mlp, and the XOR data are set up), we should see output similar to this:

This progress log tells a beautiful story of learning! Notice how the loss starts around 0.778, slightly worse than the ln 2 ≈ 0.693 that a model would score on mean BCE by always predicting 0.5, and decreases to about 0.019 by the end of 2000 epochs. The most dramatic improvement often happens once the network starts to "understand" the underlying pattern, as seen by the significant drop around epoch 1000. This downward trend in loss is exactly what we want to see during successful training, indicating our MLP is getting better at predicting the XOR outputs.

Evaluating the Trained Model

After training is complete, we need to evaluate how well our network has learned to solve the XOR problem. This involves running the forward pass with the trained parameters, calculating the final loss, and, importantly, checking the actual predictions against the targets to compute accuracy.
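
The accuracy check can be sketched as follows: threshold the sigmoid outputs at 0.5 and compare them with the targets. The `binary_accuracy` helper and the prediction values are placeholders for illustration, not output from our trained network.

```python
import jax.numpy as jnp

xor_y = jnp.array([[0.], [1.], [1.], [0.]])

# Placeholder predictions for illustration only, NOT real model output
example_preds = jnp.array([[0.1], [0.9], [0.8], [0.2]])

def binary_accuracy(preds, targets, threshold=0.5):
    # Hypothetical helper: convert probabilities to 0/1 classes, then compare
    predicted_classes = (preds > threshold).astype(jnp.float32)
    return jnp.mean(predicted_classes == targets)

accuracy = binary_accuracy(example_preds, xor_y)
print(f"Accuracy: {float(accuracy) * 100:.1f}%")
```

In the real evaluation, `example_preds` would be replaced by the forward pass evaluated with the trained parameters.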

Running this evaluation code after our training loop should produce results like this:

These results are remarkable! Our network, trained from scratch using JAX, has achieved 100% accuracy on the XOR problem. Notice how the raw prediction values are very close to the target binary values (e.g., 0.0131 for a target of 0.0, and 0.9829 for a target of 1.0). This confirms that our MLP has successfully learned the non-linear XOR logical function through iterative gradient descent optimization.

Conclusion and What's Next

Congratulations on this momentous milestone! You've successfully completed JAX in Action: Neural Networks from Scratch, building a neural network training system entirely from scratch. You've mastered forward propagation, loss functions, automatic differentiation, and now, the complete gradient descent training loop using pure JAX.

The foundational principles you've learned here are pivotal for modern deep learning. Get ready for your next adventure, Beyond Pure JAX: Flax & Optax for Elegant ML! In this upcoming course, you'll leverage powerful JAX ecosystem libraries like Flax and Optax to construct sophisticated models with greater ease and efficiency.
