Introduction

Welcome back to another exciting chapter in our JAX in Action: Neural Networks from Scratch journey! You've made excellent progress so far. In our first lesson, we tackled the XOR challenge and established our MLP architecture with proper parameter initialization. Then, in our second lesson, we brought our network to life by implementing forward propagation, watching as data flowed through layers to produce predictions.

Now we're ready for a crucial next step: measuring how wrong our predictions are and figuring out how to improve them. Today, we'll dive into loss functions and JAX's automatic differentiation capabilities. We'll implement a Binary Cross Entropy (BCE) loss function to quantify the difference between our network's predictions and the true XOR targets. More importantly, we'll harness the power of jax.value_and_grad to compute both the loss value and the gradients we need for training. By the end of this lesson, you'll have the essential tools to measure your network's performance and determine in which direction to adjust each parameter to make it better!

Understanding Loss Functions: The Heart of Learning

Before we can train our neural network, we need a way to measure how wrong our predictions are. This is where loss functions come into play. Think of a loss function as a coach evaluating an athlete's performance — it provides a single number that captures how far we are from our goal.

A loss function takes two inputs: the predictions our model makes and the true targets (the correct answers we want). It then computes a scalar value representing the "error" or "loss." The smaller this value, the better our model is performing. During training, our goal will be to adjust the network's parameters to minimize this loss.

For binary classification problems like XOR, where each target is either 0 or 1 and the network outputs a probability between 0 and 1, Binary Cross Entropy (BCE) is a very common and effective choice. Its mathematical formulation is:

$$\text{BCE}(y, \hat{y}) = -\frac{1}{N}\sum_{i=1}^{N} \left[ y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i) \right]$$

where N is the number of samples, y_i is the true label (0 or 1) and ŷ_i is the predicted probability for sample i.

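As a quick worked example (the numbers are chosen purely for illustration), take a batch of two samples with targets y = (1, 0) and the same prediction ŷ = 0.9 for both. For each sample only one of the two log terms survives, because the other is multiplied by zero:

```latex
\begin{aligned}
\text{sample 1 } (y_1 = 1,\ \hat{y}_1 = 0.9):&\quad -\log(0.9) \approx 0.105 \\
\text{sample 2 } (y_2 = 0,\ \hat{y}_2 = 0.9):&\quad -\log(1 - 0.9) = -\log(0.1) \approx 2.303 \\
\text{BCE} &= \tfrac{1}{2}\,(0.105 + 2.303) \approx 1.204
\end{aligned}
```

The confidently wrong prediction on sample 2 contributes more than twenty times as much loss as the mostly right one on sample 1, which is exactly the behavior we want from a classification loss.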
Implementing Binary Cross Entropy Loss

Let's implement our bce_loss function. This function will be the cornerstone of our training process, providing the error signal that drives learning.
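Here is a minimal sketch of bce_loss that follows the breakdown below and the mean-reduced formula above. One deliberate deviation is flagged in the code: it uses eps = 1e-7 rather than 1e-8, because JAX computes in float32 by default and 1 - 1e-8 rounds to exactly 1.0 in float32, which would leave the upper end of the clip without effect.

```python
import jax.numpy as jnp

def bce_loss(predictions, targets, eps=1e-7):
    """Mean binary cross-entropy between predicted probabilities and 0/1 targets."""
    # eps = 1e-7 instead of 1e-8: in float32, 1 - 1e-8 rounds to 1.0,
    # so the upper clip bound would not move predictions away from 1.0
    predictions = jnp.clip(predictions, eps, 1 - eps)
    # Per-sample term of the sum: y * log(p) + (1 - y) * log(1 - p)
    per_sample = targets * jnp.log(predictions) + (1 - targets) * jnp.log(1 - predictions)
    # Negate and average over the batch (the -1/N * sum in the formula)
    return -jnp.mean(per_sample)

# Usage: two samples, both predicted 0.9, with targets 1 and 0
print(bce_loss(jnp.array([0.9, 0.9]), jnp.array([1.0, 0.0])))  # ≈ 1.204
```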

Let's break this down:

  • The bce_loss function takes predictions from our model and the true targets.
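  • A crucial numerical stability trick is eps = 1e-8 combined with jnp.clip(predictions, eps, 1 - eps). This prevents predictions from being exactly 0 or 1, which would lead to jnp.log(0) when a prediction is 0, or jnp.log(1 - 1) = jnp.log(0) when it is 1; both evaluate to negative infinity and turn the loss into inf or nan. Clipping keeps predictions within the safe range [eps, 1 - eps]. One caveat: JAX computes in float32 by default, and in float32 1 - 1e-8 rounds to exactly 1.0, so with eps = 1e-8 only the lower bound has any effect. A value of about 1e-7 or larger protects both ends.
  • The core BCE computation targets * jnp.log(predictions) + (1 - targets) * jnp.log(1 - predictions) gives each sample's term in the sum; negating it and taking the mean over the batch (the -1/N factor in the formula) produces the single scalar loss.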
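To see why the clip matters, and why the size of eps matters under JAX's default float32 precision, the sketch below scores a perfectly confident and perfectly correct prediction (exactly 1.0 and 0.0) three ways. The helper names bce_unclipped and bce_clipped are illustrative only, and the nan results assume 64-bit mode has not been enabled.

```python
import jax.numpy as jnp

targets = jnp.array([1.0, 0.0])
preds = jnp.array([1.0, 0.0])  # saturated but correct predictions (float32 by default)

def bce_unclipped(p, y):
    return -jnp.mean(y * jnp.log(p) + (1 - y) * jnp.log(1 - p))

def bce_clipped(p, y, eps):
    p = jnp.clip(p, eps, 1 - eps)
    return -jnp.mean(y * jnp.log(p) + (1 - y) * jnp.log(1 - p))

# Sample 1 computes (1 - 1) * log(1 - 1.0) = 0 * -inf, which is nan
print(bce_unclipped(preds, targets))      # nan
# 1 - 1e-8 rounds to 1.0 in float32, so the prediction 1.0 is not clipped
print(bce_clipped(preds, targets, 1e-8))  # nan
# 1 - 1e-7 is representable below 1.0 in float32, so both ends are protected
print(bce_clipped(preds, targets, 1e-7))  # a small finite value
```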
Connecting Loss to Model Predictions

Now we need to create a bridge between our MLP's forward pass (which you built in the previous lesson!) and our new loss calculation. We'll build a function that takes our network parameters and a batch of data, runs the forward pass to get predictions, and then computes the loss based on those predictions and the true labels.

This function, compute_loss_for_mlp, is straightforward but very important. It serves as the complete evaluation pipeline for our network's performance on a given batch of data:

  1. It accepts the current params_list (our PyTree of weights and biases), a batch of inputs x_batch, and the corresponding true labels y_batch.
  2. It calls mlp_forward_pass (which you developed earlier) to get the network's predictions for x_batch.
  3. It then uses our bce_loss function to calculate the scalar loss between these predictions and y_batch.
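The three steps above can be sketched as follows. Because mlp_forward_pass comes from the previous lesson and isn't reproduced here, the version below is a hypothetical stand-in: it assumes params_list is a list of {"W", "b"} dicts, with tanh on hidden layers and a sigmoid on the output. compute_loss_for_mlp itself only relies on the forward pass returning probabilities with the same shape as y_batch.

```python
import jax
import jax.numpy as jnp

def bce_loss(predictions, targets, eps=1e-7):
    predictions = jnp.clip(predictions, eps, 1 - eps)
    return -jnp.mean(targets * jnp.log(predictions) + (1 - targets) * jnp.log(1 - predictions))

def mlp_forward_pass(params_list, x):
    """Hypothetical stand-in for the previous lesson's forward pass:
    tanh on hidden layers, sigmoid on the output layer."""
    h = x
    for layer in params_list[:-1]:
        h = jnp.tanh(h @ layer["W"] + layer["b"])
    last = params_list[-1]
    return jax.nn.sigmoid(h @ last["W"] + last["b"])

def compute_loss_for_mlp(params_list, x_batch, y_batch):
    """Forward pass followed by BCE: one scalar loss for the whole batch."""
    predictions = mlp_forward_pass(params_list, x_batch)  # shape (batch, 1)
    return bce_loss(predictions, y_batch)                 # y_batch must also be (batch, 1)

# Usage: XOR data with all-zero parameters (a 2-3-1 layout chosen for illustration)
xor_X = jnp.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
xor_y = jnp.array([[0.], [1.], [1.], [0.]])
zero_params = [{"W": jnp.zeros((2, 3)), "b": jnp.zeros((3,))},
               {"W": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}]
print(compute_loss_for_mlp(zero_params, xor_X, xor_y))  # ≈ 0.6931 (ln 2): every prediction is 0.5
```

Keeping xor_y shaped (4, 1) rather than (4,) matters: a (4, 1) prediction array combined with a (4,) target array would broadcast to a (4, 4) matrix, and the mean would silently be computed over the wrong pairs.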
JAX's Automatic Differentiation Magic: value_and_grad

As you recall from our earlier work with JAX, one of its most powerful features is automatic differentiation through transformations like jax.grad; here, we'll apply these familiar tools to our neural network training pipeline. Since we need both the loss value (to monitor training progress) and the gradients (to update parameters), jax.value_and_grad fits our needs. It transforms a scalar-valued function into one that returns a (value, gradients) pair, differentiating with respect to the first argument by default (the argnums parameter selects other arguments). For compute_loss_for_mlp, the first argument is params_list, so we get a gradient for every weight and bias.
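As a small refresher before applying it to the MLP, here is jax.value_and_grad on a toy two-argument function (f, w and x are illustrative names, not part of the lesson code). It shows the (value, gradient) pair and how argnums switches the argument being differentiated.

```python
import jax
import jax.numpy as jnp

def f(w, x):
    # Scalar-valued: sum of squares of the elementwise product w * x
    return jnp.sum((w * x) ** 2)

w = jnp.array([1.0, 2.0])
x = jnp.array([3.0, 4.0])

# Default argnums=0: gradient with respect to w, i.e. 2 * w * x**2
value, grad_w = jax.value_and_grad(f)(w, x)
print(value)   # 73.0 (9 + 64)
print(grad_w)  # [18. 64.]

# argnums=1: gradient with respect to x instead, i.e. 2 * w**2 * x
value, grad_x = jax.value_and_grad(f, argnums=1)(w, x)
print(grad_x)  # [ 6. 32.]
```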

Notice how we're stacking transformations here — a common and powerful pattern in JAX. We first apply jax.value_and_grad to get both loss and gradients, then wrap the result with jax.jit for optimized performance. This gives us a highly efficient function that, when called like loss_and_gradients_fn(mlp_params, xor_X, xor_y), returns both the loss_value and a gradients_pytree with the exact same structure as our parameters.
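To see the stacking pattern in isolation, here is a sketch on a deliberately tiny, hypothetical model (a single logistic unit with dict parameters) rather than the lesson's MLP. It shows that the jitted function returns a scalar loss plus a gradient PyTree whose structure mirrors the parameters.

```python
import jax
import jax.numpy as jnp

def logistic_loss(params, x, y):
    """Hypothetical one-layer model: sigmoid(x @ W + b) scored with mean BCE."""
    p = jax.nn.sigmoid(x @ params["W"] + params["b"])
    p = jnp.clip(p, 1e-7, 1 - 1e-7)
    return -jnp.mean(y * jnp.log(p) + (1 - y) * jnp.log(1 - p))

# Inner transformation: value_and_grad; outer transformation: jit (compiled on first call)
loss_and_grad_fn = jax.jit(jax.value_and_grad(logistic_loss))

params = {"W": jnp.zeros((2, 1)), "b": jnp.zeros((1,))}
x = jnp.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
y = jnp.array([[0.], [1.], [1.], [0.]])

loss, grads = loss_and_grad_fn(params, x, y)
# The gradient PyTree has the same keys and leaf shapes as the parameter PyTree
print(jax.tree_util.tree_structure(grads) == jax.tree_util.tree_structure(params))  # True
print(jax.tree_util.tree_map(lambda g: g.shape, grads))  # {'W': (2, 1), 'b': (1,)}
```

Because the gradients share the parameters' structure, a later update step can combine the two leaf by leaf with jax.tree_util.tree_map.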

The gradients tell us how the loss changes as each weight and bias changes, and therefore in which direction to nudge each parameter to reduce the loss. This is the foundation of gradient-based training, which we'll fully implement in our next lesson!

Testing Our Complete Implementation

Let's put all these pieces together: our XOR data, initialized parameters and the forward pass from the previous lessons, as well as our new BCE loss, the combined loss computation function, and the value_and_grad transformation. We'll see the initial loss of our untrained network and inspect the structure of the gradients.
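Below is one way the full script could look. The previous lessons' initializer and forward pass aren't reproduced here, so init_mlp_params and mlp_forward_pass are hypothetical stand-ins, and the 2-4-1 architecture, the 0.5 weight scale and PRNGKey(0) are all assumptions. The printed loss will therefore differ from the 0.7782 quoted below, but the two loss values should agree with each other and the gradient shapes should match the parameters.

```python
import jax
import jax.numpy as jnp

# XOR dataset: 4 samples, 2 features; targets shaped (4, 1) to match the output layer
xor_X = jnp.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
xor_y = jnp.array([[0.], [1.], [1.], [0.]])

def init_mlp_params(key, layer_sizes):
    """Hypothetical initializer: scaled Gaussian weights, zero biases, one dict per layer."""
    params = []
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        key, subkey = jax.random.split(key)
        params.append({"W": jax.random.normal(subkey, (n_in, n_out)) * 0.5,
                       "b": jnp.zeros((n_out,))})
    return params

def mlp_forward_pass(params_list, x):
    """Hypothetical forward pass: tanh hidden layers, sigmoid output."""
    h = x
    for layer in params_list[:-1]:
        h = jnp.tanh(h @ layer["W"] + layer["b"])
    last = params_list[-1]
    return jax.nn.sigmoid(h @ last["W"] + last["b"])

def bce_loss(predictions, targets, eps=1e-7):
    predictions = jnp.clip(predictions, eps, 1 - eps)
    return -jnp.mean(targets * jnp.log(predictions) + (1 - targets) * jnp.log(1 - predictions))

def compute_loss_for_mlp(params_list, x_batch, y_batch):
    return bce_loss(mlp_forward_pass(params_list, x_batch), y_batch)

mlp_params = init_mlp_params(jax.random.PRNGKey(0), [2, 4, 1])
loss_and_gradients_fn = jax.jit(jax.value_and_grad(compute_loss_for_mlp))

initial_loss = compute_loss_for_mlp(mlp_params, xor_X, xor_y)
loss_value, gradients_pytree = loss_and_gradients_fn(mlp_params, xor_X, xor_y)

print(f"Initial Loss (BCE): {float(initial_loss):.4f}")
print(f"Loss value (from value_and_grad): {float(loss_value):.4f}")
print("Gradients PyTree Structure:")
for i, layer_grads in enumerate(gradients_pytree):
    print(f"  Layer {i}: W {layer_grads['W'].shape}, b {layer_grads['b'].shape}")
```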

When you run this code, you should see output similar to this:

Let's interpret this output:

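  • The "Initial Loss (BCE)" is around 0.7782 (the exact value depends on the random initialization). Since our network parameters are random, this value reflects how far the untrained network's initial guesses are from the true XOR outputs. For reference, a network that predicted 0.5 for every sample would score ln 2 ≈ 0.693, so our random starting point is slightly worse than uninformed guessing.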
  • The "Loss value (from value_and_grad)" is identical, confirming that jax.value_and_grad correctly computes the loss value alongside the gradients.
  • The "Gradients PyTree Structure" shows the shapes of the gradients. Notice that each gradient array has exactly the same shape as the parameter it corresponds to.
Conclusion and Next Steps

Excellent work, space explorer! You've successfully implemented the critical components for measuring your neural network's error using Binary Cross Entropy loss. Crucially, you've learned how to use jax.value_and_grad to automatically compute the gradients of this loss with respect to all network parameters.

These gradients are the key to learning, as they tell us the direction in which to adjust our parameters to reduce the loss. With the ability to perform a forward pass, calculate loss, and get gradients, you now possess the fundamental toolkit for training. In our next lesson, we'll put these tools into action and implement the gradient descent optimization algorithm, finally enabling our MLP to learn from data!
