Welcome back to the final lesson of JAX in Action: Neural Networks from Scratch! What an incredible journey we've taken together through the fundamentals of building neural networks from the ground up. We started by tackling the classic XOR problem and establishing our MLP architecture with proper parameter initialization. We then implemented forward propagation to see our network generate predictions, and in our previous lesson, we developed the crucial ability to measure prediction quality using Binary Cross Entropy loss and compute gradients with JAX's automatic differentiation.
Today marks the culmination of our neural network construction project. We'll implement a complete training loop that iteratively improves our MLP's performance using gradient descent. You'll learn to orchestrate the training process by repeatedly computing predictions, calculating loss and gradients, and updating parameters to minimize error. By the end of this lesson, you'll witness your network actually learning to solve the XOR problem through experience — a truly magical moment in machine learning!
At its heart, machine learning is about iterative improvement. Think of learning to play a musical instrument: you don't master it overnight, but through countless repetitions, gradually adjusting your technique based on feedback. Neural network training follows this same principle.
Our training process operates in cycles called epochs. During each epoch, we present our network with the training data and perform these key steps: compute predictions using forward propagation, calculate how wrong these predictions are using our loss function, determine the direction to adjust parameters using gradients, and finally update parameters to (hopefully) improve performance. This cycle repeats hundreds or thousands of times until our network performs well.
The beauty of this approach lies in its simplicity — we don't need to manually figure out what each parameter should be. Instead, we let the mathematical framework of gradient descent guide us toward better solutions automatically. Each iteration brings us closer to a network that can accurately solve our XOR problem.
The mathematical foundation of our training process is gradient descent, an optimization algorithm that systematically adjusts parameters to minimize our loss function. The core update rule is elegantly simple:
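`new_parameter = old_parameter - learning_rate * gradient`

The gradient tells us the direction in which the loss increases, so taking a small step against it, scaled by the `learning_rate`, moves each parameter toward lower loss.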
JAX provides powerful utilities for working with PyTrees through the `jax.tree_util` module; recall that the `tree_map` function we encountered previously is particularly useful for applying an operation element-wise across corresponding leaves of one or more PyTrees that share the same structure.
Let's say we have our current parameters (`current_params`) and their corresponding `gradients`, both structured as PyTrees. We want to apply the gradient descent update to each parameter. We can define a simple function for the update rule:
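A minimal sketch of that function, together with the `tree_map` call that applies it, might look like the following (here `jtu` is assumed to be an alias for `jax.tree_util`, and `learning_rate` is the value we define shortly in the training setup):

```python
import jax.tree_util as jtu

def apply_gradient_step(param_leaf, grad_leaf):
    # One gradient descent step for a single leaf (a weight matrix or bias vector).
    return param_leaf - learning_rate * grad_leaf

# Apply the update to every matching pair of leaves in the two PyTrees.
updated_params = jtu.tree_map(apply_gradient_step, current_params, gradients)
```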
Here's what makes this approach powerful:

- `apply_gradient_step` takes a single parameter value (`param_leaf`) and its corresponding gradient (`grad_leaf`). It uses `learning_rate` (which we'll define in our training setup) to compute the updated parameter.
- `jtu.tree_map(apply_gradient_step, current_params, gradients)` then intelligently traverses both the `current_params` PyTree and the `gradients` PyTree. For every matching pair of parameter leaf and gradient leaf, it calls `apply_gradient_step` and constructs a new PyTree with the updated values. This means we don't need to manually iterate through layers or dictionaries; `tree_map` handles the PyTree structure for us.
Now we'll construct the main training loop that orchestrates the entire learning process. This loop will cycle through epochs, and in each epoch, it will compute loss and gradients, then apply updates to our network's parameters.
First, let's set up our training configuration and recall our `loss_and_grad_fn` from the previous lesson, which gives us both the loss and the gradients.
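The lesson's exact code isn't reproduced here, but a sketch consistent with the breakdown below could look like this. It reuses `apply_gradient_step` from the previous section and assumes `loss_and_grad_fn` returns `(loss, gradients)` in the style of `jax.value_and_grad`; the learning rate value is an illustrative choice, while the 2000 epochs match the training run discussed later:

```python
import jax.tree_util as jtu

# Training configuration.
learning_rate = 0.1   # illustrative value; a small positive step size
num_epochs = 2000     # number of passes over the XOR data

# current_params starts as the freshly initialized MLP parameters,
# and xor_X / xor_y are the XOR inputs and targets from earlier lessons.
for epoch in range(num_epochs):
    # Loss and gradients for the current parameters (JIT-compiled, previous lesson).
    loss_value, gradients = loss_and_grad_fn(current_params, xor_X, xor_y)

    # Gradient descent update, applied leaf-by-leaf across the parameter PyTree.
    current_params = jtu.tree_map(apply_gradient_step, current_params, gradients)
```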
Let's break down this crucial piece of code:
- We set a `learning_rate` and the `num_epochs` (how many times we'll iterate through the data).
- `loss_and_grad_fn` is our JIT-compiled function from the last lesson that efficiently calculates both the loss and the gradients for `current_params` given the `xor_X` data and `xor_y` targets.
- The `for` loop iterates for `num_epochs`.
To understand how well our training is progressing, we need to monitor and display the loss values. This feedback helps us verify that learning is occurring and provides insights into the training dynamics. We can simply add a print statement inside our loop.
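For example, something along these lines inside the loop body (printing every 100 epochs is just one reasonable choice of interval):

```python
# Inside the training loop, right after the parameter update:
if epoch % 100 == 0 or epoch == num_epochs - 1:
    print(f"Epoch {epoch}, Loss: {float(loss_value):.4f}")
```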
When we integrate this into our training loop and run it (assuming all previous functions like `initialize_mlp_params`, `compute_loss_for_mlp`, and the XOR data are set up), we should see output similar to this:
This progress log tells a beautiful story of learning! Notice how the loss starts around 0.778 (quite high for our normalized BCE loss) and steadily decreases to about 0.019 by the end of 2000 epochs. The most dramatic improvement often happens once the network starts to "understand" the underlying pattern, as seen by the significant drop around epoch 1000. This steady decrease in loss is exactly what we want to see during successful training, indicating our MLP is getting better at predicting the XOR outputs.
After training is complete, we need to evaluate how well our network has learned to solve the XOR problem. This involves running the forward pass with the trained parameters, calculating the final loss, and, importantly, checking the actual predictions against the targets to compute accuracy.
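The evaluation code itself isn't reproduced here, but a sketch of the idea is below. It assumes a forward-pass function, called `mlp_forward` here as a placeholder for whatever you named your forward propagation function, that maps parameters and inputs to output probabilities, and that `compute_loss_for_mlp` takes `(params, inputs, targets)` like the loss function from our previous lessons:

```python
import jax.numpy as jnp

# Probabilities predicted by the trained network for the four XOR inputs.
predictions = mlp_forward(current_params, xor_X)

# Final loss with the trained parameters.
final_loss = compute_loss_for_mlp(current_params, xor_X, xor_y)

# Threshold at 0.5 to turn probabilities into hard 0/1 class predictions.
predicted_classes = (predictions > 0.5).astype(jnp.float32)
accuracy = jnp.mean(predicted_classes == xor_y)

print("Raw predictions:", predictions.ravel())
print("Targets:        ", xor_y.ravel())
print(f"Final loss: {float(final_loss):.4f}")
print(f"Accuracy: {float(accuracy) * 100:.1f}%")
```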
Running this evaluation code after our training loop should produce results like this:
These results are remarkable! Our network, trained from scratch using JAX, has achieved 100% accuracy on the XOR problem. Notice how the raw prediction values are very close to the target binary values (e.g., 0.0131 for a target of 0.0, and 0.9829 for a target of 1.0). This confirms that our MLP has successfully learned the non-linear XOR logical function through iterative gradient descent optimization.
Congratulations on this momentous milestone! You've successfully completed JAX in Action: Neural Networks from Scratch, building an entire neural network training system from the ground up. You've mastered forward propagation, loss functions, automatic differentiation, and now the complete gradient descent training loop, all using pure JAX.
The foundational principles you've learned here are pivotal for modern deep learning. Get ready for your next adventure, Beyond Pure JAX: Flax & Optax for Elegant ML! In this upcoming course, you'll leverage powerful JAX ecosystem libraries like Flax and Optax to construct sophisticated models with greater ease and efficiency.
