Welcome back to another exciting chapter in our JAX in Action: Neural Networks from Scratch journey! You've made excellent progress so far. In our first lesson, we tackled the XOR challenge and established our MLP architecture with proper parameter initialization. Then, in our second lesson, we brought our network to life by implementing forward propagation, watching as data flowed through layers to produce predictions.
Now we're ready for a crucial next step: measuring how wrong our predictions are and figuring out how to improve them. Today, we'll dive into loss functions and JAX's automatic differentiation capabilities. We'll implement a Binary Cross Entropy (BCE) loss function to quantify the difference between our network's predictions and the true XOR targets. More importantly, we'll harness the power of `jax.value_and_grad` to compute both the loss value and the gradients we need for training. By the end of this lesson, you'll have the essential tools to measure your network's performance and understand exactly how to adjust each parameter to make it better!
Before we can train our neural network, we need a way to measure how wrong our predictions are. This is where loss functions come into play. Think of a loss function as a coach evaluating an athlete's performance — it provides a single number that captures how far we are from our goal.
A loss function takes two inputs: the predictions our model makes and the true targets (the correct answers we want). It then computes a scalar value representing the "error" or "loss." The smaller this value, the better our model is performing. During training, our goal will be to adjust the network's parameters to minimize this loss.
For binary classification problems like XOR, where the output is typically 0 or 1, Binary Cross Entropy (BCE) is a very common and effective choice. Its mathematical formulation is:

$$\mathcal{L}_{\text{BCE}} = -\frac{1}{N}\sum_{i=1}^{N}\Big[y_i \log(\hat{y}_i) + (1 - y_i)\log(1 - \hat{y}_i)\Big]$$

where $N$ is the number of samples, $y_i$ is the true target (0 or 1), and $\hat{y}_i$ is the model's predicted probability.
Let's implement our `bce_loss` function. This function will be the cornerstone of our training process, providing the error signal that drives learning.
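A minimal sketch of such a function is shown below; the variable names and the mean reduction are assumptions, so your version in the course environment may differ slightly.

```python
import jax.numpy as jnp

def bce_loss(predictions, targets):
    # Clip predictions away from exactly 0 or 1 so jnp.log never receives 0.
    eps = 1e-8
    predictions = jnp.clip(predictions, eps, 1 - eps)
    # Per-sample binary cross entropy terms.
    per_sample = targets * jnp.log(predictions) + (1 - targets) * jnp.log(1 - predictions)
    # Negative mean collapses them into a single scalar loss to minimize.
    return -jnp.mean(per_sample)
```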
Let's break this down:
- The `bce_loss` function takes `predictions` from our model and the true `targets`.
- A crucial numerical stability trick is `eps = 1e-8` and `jnp.clip(predictions, eps, 1 - eps)`. This prevents `predictions` from being exactly 0 or 1, which would lead to `jnp.log(0)` (negative infinity) or `jnp.log(1 - 1)` (also negative infinity), causing issues. Clipping keeps predictions within a safe range like `[1e-8, 1 - 1e-8]`.
- The core BCE computation `targets * jnp.log(predictions) + (1 - targets) * jnp.log(1 - predictions)` calculates the loss for each sample.
Now we need to create a bridge between our MLP's forward pass (which you built in the previous lesson!) and our new loss calculation. We'll build a function that takes our network parameters and a batch of data, runs the forward pass to get predictions, and then computes the loss based on those predictions and the true labels.
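Assuming the `mlp_forward_pass(params_list, x_batch)` signature from the previous lesson, a sketch of this bridge function might look like this:

```python
def compute_loss_for_mlp(params_list, x_batch, y_batch):
    # Forward pass: turn the batch of inputs into predictions.
    predictions = mlp_forward_pass(params_list, x_batch)
    # Scalar BCE loss between the predictions and the true labels.
    return bce_loss(predictions, y_batch)
```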
This function, `compute_loss_for_mlp`, is straightforward but very important. It serves as the complete evaluation pipeline for our network's performance on a given batch of data:

- It accepts the current `params_list` (our PyTree of weights and biases), a batch of inputs `x_batch`, and the corresponding true labels `y_batch`.
- It calls `mlp_forward_pass` (which you developed earlier) to get the network's `predictions` for `x_batch`.
- It then uses our `bce_loss` function to calculate the scalar `loss` between these `predictions` and `y_batch`.
As you recall from our earlier work with JAX, one of its most powerful features is automatic differentiation through functions like `jax.grad`; here, we'll apply these familiar tools to our neural network training pipeline. Since we need both the loss value (to monitor training progress) and the gradients (to update parameters), `jax.value_and_grad` is perfect for our needs. As you remember, it transforms a scalar-valued function to return both the original output and the gradients with respect to the first argument:
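A sketch of how this transformation might be applied to our loss function (using the names defined above):

```python
import jax

# Differentiates compute_loss_for_mlp with respect to its first argument
# (params_list), returns (loss, gradients), and jit-compiles the result.
loss_and_gradients_fn = jax.jit(jax.value_and_grad(compute_loss_for_mlp))
```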
Notice how we're stacking transformations here — a common and powerful pattern in JAX. We first apply `jax.value_and_grad` to get both loss and gradients, then wrap the result with `jax.jit` for optimized performance. This gives us a highly efficient function that, when called like `loss_and_gradients_fn(mlp_params, xor_X, xor_y)`, returns both the `loss_value` and a `gradients_pytree` with the exact same structure as our parameters.
The gradients will tell us exactly how to adjust each weight and bias to reduce the loss — the foundation of gradient-based training that we'll fully implement in our next lesson!
Let's put all these pieces together: our XOR data, initialized parameters, and the forward pass from the previous lessons, as well as our new BCE loss, the combined loss computation function, and the `value_and_grad` transformation. We'll see the initial loss of our untrained network and inspect the structure of the gradients.
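A sketch of how these pieces might fit together is below; `mlp_params` and `mlp_forward_pass` are assumed to come from the earlier lessons, so the exact names and initialization may differ in your setup.

```python
# XOR inputs and targets (shapes (4, 2) and (4, 1)).
xor_X = jnp.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
xor_y = jnp.array([[0.], [1.], [1.], [0.]])

# mlp_params: the PyTree of weights and biases initialized in the first lesson.
initial_loss = compute_loss_for_mlp(mlp_params, xor_X, xor_y)
print("Initial Loss (BCE):", initial_loss)

# Loss and gradients in a single call.
loss_value, gradients_pytree = loss_and_gradients_fn(mlp_params, xor_X, xor_y)
print("Loss value (from value_and_grad):", loss_value)

# Inspect the gradient PyTree by mapping each leaf to its shape.
print("Gradients PyTree Structure:")
print(jax.tree_util.tree_map(lambda g: g.shape, gradients_pytree))
```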
When you run this code, you should see output similar to this:
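Illustratively (the exact numbers depend on your random initialization, and the shapes depend on your chosen architecture), it could look roughly like this:

```
Initial Loss (BCE): 0.7782
Loss value (from value_and_grad): 0.7782
Gradients PyTree Structure:
...one (weights, biases) shape pair per layer, mirroring mlp_params...
```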
Let's interpret this output:
- The "Initial Loss (BCE)" is around 0.7782. Since our network parameters are random, this loss value reflects how far off its initial guesses are from the true XOR outputs.
- The "Loss value (from value_and_grad)" is identical, confirming that
jax.value_and_grad
correctly computes the loss value alongside the gradients. - The "Gradients PyTree Structure" shows the shapes of the gradients. Notice how they perfectly match the shapes of our parameters:
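For example, assuming a hypothetical 2-4-1 network (two inputs, four hidden units, one output) stored as a list of (weights, biases) pairs, the two PyTrees would share a structure of shapes like the following; your actual shapes depend on the layer sizes you chose in the first lesson.

```
Parameters: [((2, 4), (4,)), ((4, 1), (1,))]
Gradients:  [((2, 4), (4,)), ((4, 1), (1,))]
```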
Excellent work, space explorer! You've successfully implemented the critical components for measuring your neural network's error using Binary Cross Entropy loss. Crucially, you've learned how to use `jax.value_and_grad` to automatically compute the gradients of this loss with respect to all network parameters.
These gradients are the key to learning, as they tell us the direction in which to adjust our parameters to reduce the loss. With the ability to perform a forward pass, calculate loss, and get gradients, you now possess the fundamental toolkit for training. In our next lesson, we'll put these tools into action and implement the gradient descent optimization algorithm, finally enabling our MLP to learn from data!
