Introduction

Welcome back, intrepid explorer! I'm thrilled to see you for the second lesson in our JAX in Action: Neural Networks from Scratch course. In our previous transmission, we laid the critical groundwork: we prepared our XOR dataset and skillfully crafted the initialize_mlp_params function to set up our network's weights and biases using the Xavier/Glorot method. Our parameters are now neatly organized in a PyTree, ready for action!

Today, we're shifting gears to something truly dynamic: implementing forward propagation. This is the very heart of how a neural network makes predictions. We'll build a function that takes our carefully prepared inputs and parameters and channels the data through the network's layers and activation functions to produce an output. By the end of this lesson, you'll have a working JAX function that performs the complete forward pass for our Multi-Layer Perceptron (MLP), and you'll see it generate its very first (untrained) predictions for the XOR problem!

The Journey of Data: Understanding Forward Propagation

Forward propagation, often called a "forward pass," is the process by which input data flows through a neural network, layer by layer, until it produces an output. Imagine it as a sophisticated assembly line. At each station (or layer), the raw materials (data) undergo specific transformations.

For a typical feedforward neural network like our MLP, this journey involves two main steps at each layer:

  1. Affine Transformation: The input data (or the output from the previous layer) is linearly transformed using the layer's weights and biases. If x is the input, W are the weights, and b is the bias, this step computes z = xW + b. Here, z holds the pre-activation values: the raw numerical results before any non-linear transformation is applied. Think of z as the "candidate outputs" that capture the weighted influence of all inputs but haven't yet been passed through the activation function.
  2. Activation: A non-linear activation function is applied element-wise to z, producing the layer's output (its activations), which becomes the input to the next layer. In our network, the output layer uses a sigmoid, so each prediction lies between 0 and 1.
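The two steps can be sketched for a single layer. The helper name dense_layer and the toy values of x, W, and b are hypothetical. They were chosen so the arithmetic is easy to check by hand. jax.nn.relu simply stands in for a hidden-layer activation.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper: one layer = affine step followed by an activation.
def dense_layer(x, W, b, activation):
    z = jnp.dot(x, W) + b      # affine transformation -> pre-activation values
    return activation(z)       # element-wise non-linearity

x = jnp.array([[1.0, 0.0]])            # one sample with two features, shape (1, 2)
W = jnp.array([[0.5, -1.0, 2.0],
               [1.5,  0.3, -0.7]])     # shape (2, 3)
b = jnp.array([0.0, 0.5, -1.0])        # shape (3,)

z = jnp.dot(x, W) + b                  # z == [[0.5, -0.5, 1.0]]
a = dense_layer(x, W, b, jax.nn.relu)  # a == [[0.5, 0.0, 1.0]] (negative entry clipped)
print(z, a)
```

Because x selects only the first row of W, z is that row plus b. ReLU then zeroes the single negative entry.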

Building Blocks: Affine Transformations and Activations

The core mathematical operations within each layer of our MLP are the affine transformation and the application of an activation function. JAX, with its jax.numpy library, provides efficient tools for these.

The affine transformation, z = xW + b, involves:

  • A matrix multiplication between the input x and the weights W. In JAX, we'll use jnp.dot(x, W).
  • An addition of the bias term b. JAX handles broadcasting, so if xW is a matrix of outputs for a batch of inputs (one row per sample), b is added to every row.

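The following small sketch uses the four XOR inputs as a batch to show the shapes and broadcasting at work. The weight and bias values are illustrative only. A bias of shape (3,) is added to every row of the (4, 3) product, with the same result as tiling it explicitly.

```python
import jax.numpy as jnp

# The four XOR inputs form a batch of shape (4, 2).
x = jnp.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
W = jnp.ones((2, 3))                  # illustrative weights, shape (2, 3)
b = jnp.array([0.1, 0.2, 0.3])        # illustrative bias, shape (3,)

xw = jnp.dot(x, W)                    # shape (4, 3): one row per sample
z = xw + b                            # b is broadcast across all 4 rows

# Broadcasting matches explicitly tiling b into a (4, 3) matrix.
z_tiled = xw + jnp.tile(b, (4, 1))
print(z.shape, jnp.allclose(z, z_tiled))
```
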
Crafting the MLP's Forward Pass Function

Now, let's translate this understanding into a JAX function. We'll create mlp_forward_pass, which takes the list of parameters (our PyTree of weights and biases for each layer) and an input x_input. It will then guide the input data through the network.

Let's break this down:

  • We initialize activations with the x_input. This variable will hold the output of the current layer, which becomes the input to the next.
  • We loop through params_list up to the second-to-last element. These are our hidden layers.
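    • Inside the loop, we extract the weights and biases for the current layer, compute the affine transformation, apply the hidden-layer activation function, and store the result back in activations.
  • After the loop, the last element of params_list (the output layer) gets its own affine transformation followed by a sigmoid activation, which squashes each prediction into the range (0, 1).
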
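The steps above fit together in the following minimal sketch of mlp_forward_pass. Two details are assumptions, not the course's confirmed code. First, each entry of params_list is taken to be a dict with keys "W" and "b". Second, the hidden layers use jax.nn.relu, because the lesson doesn't name its hidden activation. The output-layer sigmoid matches the lesson's description.

```python
import jax
import jax.numpy as jnp

def mlp_forward_pass(params_list, x_input):
    """Forward pass through an MLP.

    Assumes params_list is a list of {"W": ..., "b": ...} dicts, one per layer.
    """
    activations = x_input
    # Hidden layers: affine transformation, then ReLU (assumed hidden activation).
    for layer in params_list[:-1]:
        W, b = layer["W"], layer["b"]
        activations = jax.nn.relu(jnp.dot(activations, W) + b)
    # Output layer: affine transformation, then sigmoid to land in (0, 1).
    final_layer = params_list[-1]
    logits = jnp.dot(activations, final_layer["W"]) + final_layer["b"]
    return jax.nn.sigmoid(logits)

# Usage with all-zero parameters: every pre-activation is 0, so each output is sigmoid(0) = 0.5.
zero_params = [
    {"W": jnp.zeros((2, 3)), "b": jnp.zeros(3)},
    {"W": jnp.zeros((3, 1)), "b": jnp.zeros(1)},
]
xor_X = jnp.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
print(mlp_forward_pass(zero_params, xor_X))
```
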
Setting the Stage: Inputs and Initialized Parameters

Before we can see our mlp_forward_pass function in action, we need two things: our input data (the XOR examples) and the initialized parameters for our network. We prepared the XOR data in the previous lesson, and we also developed a function, initialize_mlp_params, to create the network's weights and biases.

Let's bring in the initialize_mlp_params function. You'll recall its detailed construction from our last lesson; we include it here as it's essential for our current task of running the forward pass.

Here, we've defined our xor_X data. Then, we specified layer_sizes for an MLP with an input layer of 2 neurons, a hidden layer of 3 neurons, and an output layer of 1 neuron. Finally, we used jax.random.key(0) to create a PRNG key and called initialize_mlp_params to get our mlp_params PyTree. With our inputs and parameters ready, we can now perform the forward pass!
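The previous lesson's code isn't reproduced here, so the following is a hypothetical stand-in for initialize_mlp_params, together with the setup just described. It assumes the Glorot *uniform* variant, with weights drawn from U(-limit, limit) and limit = sqrt(6 / (fan_in + fan_out)). It also assumes zero biases and a list-of-dicts layout with keys "W" and "b". The original function may differ in any of these details.

```python
import jax
import jax.numpy as jnp

def initialize_mlp_params(layer_sizes, key):
    """Hypothetical re-creation: Glorot-uniform weights, zero biases."""
    num_layers = len(layer_sizes) - 1
    keys = jax.random.split(key, num_layers)   # one independent key per layer
    params = []
    for i in range(num_layers):
        fan_in, fan_out = layer_sizes[i], layer_sizes[i + 1]
        limit = jnp.sqrt(6.0 / (fan_in + fan_out))   # Glorot uniform bound
        W = jax.random.uniform(keys[i], (fan_in, fan_out), minval=-limit, maxval=limit)
        b = jnp.zeros(fan_out)
        params.append({"W": W, "b": b})
    return params

xor_X = jnp.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
layer_sizes = [2, 3, 1]          # 2 inputs -> 3 hidden neurons -> 1 output
key = jax.random.key(0)
mlp_params = initialize_mlp_params(layer_sizes, key)

# Inspect the PyTree structure: one (W, b) pair per layer.
print(jax.tree_util.tree_map(lambda p: p.shape, mlp_params))
```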

Witnessing Initial Predictions: Testing the Forward Pass

It's time for the exciting part: let's use our mlp_forward_pass function with the xor_X data and the mlp_params we just prepared. This will give us the initial predictions of our untrained network.
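The following self-contained sketch covers that step. It repeats the assumed initializer and forward pass so it runs on its own. The printed labels loosely follow the ones discussed below. The initializer and hidden activation are assumptions, so the exact numbers will not match the course's own output.

```python
import jax
import jax.numpy as jnp

# Assumed parameter layout: a list of {"W", "b"} dicts, Glorot-uniform weights, zero biases.
def initialize_mlp_params(layer_sizes, key):
    keys = jax.random.split(key, len(layer_sizes) - 1)
    params = []
    for i in range(len(layer_sizes) - 1):
        fan_in, fan_out = layer_sizes[i], layer_sizes[i + 1]
        limit = jnp.sqrt(6.0 / (fan_in + fan_out))
        W = jax.random.uniform(keys[i], (fan_in, fan_out), minval=-limit, maxval=limit)
        params.append({"W": W, "b": jnp.zeros(fan_out)})
    return params

def mlp_forward_pass(params_list, x_input):
    activations = x_input
    for layer in params_list[:-1]:
        # ReLU hidden activation is an assumption.
        activations = jax.nn.relu(jnp.dot(activations, layer["W"]) + layer["b"])
    out = params_list[-1]
    return jax.nn.sigmoid(jnp.dot(activations, out["W"]) + out["b"])

xor_X = jnp.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
mlp_params = initialize_mlp_params([2, 3, 1], jax.random.key(0))

# Full batch of four XOR inputs.
predictions = mlp_forward_pass(mlp_params, xor_X)
print("Initial Predictions:\n", predictions)
print("Shape of predictions:", predictions.shape)

# A batch containing a single sample, shape (1, 2).
single_sample = jnp.array([[0., 0.]])
single_prediction = mlp_forward_pass(mlp_params, single_sample)
print("Prediction for single sample", single_sample, ":", single_prediction)
```

This sketch assumes zero biases, so the [0, 0] input always maps to exactly 0.5: every pre-activation along its path is zero. The course's own initializer may not behave this way.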

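When you run this code, it prints the initial predictions for the four XOR inputs, the shape of the predictions array, and the prediction for a single sample.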

Let's analyze this output:

  • The "Initial Predictions" are the network's outputs for each of the four XOR input pairs. Since the network's weights were initialized randomly and it hasn't been trained, these predictions are essentially random values (between 0 and 1 due to the sigmoid activation in the output layer). They don't match the true XOR outputs yet (which should be [[0.], [1.], [1.], [0.]]).
  • The Shape of predictions: (4, 1) confirms that our network is producing one output for each of the four input samples, which is exactly what we expect.
  • The test with a single_sample ([[0. 0.]]) shows that our mlp_forward_pass function correctly handles inputs with a batch size of 1, producing a single prediction of shape (1, 1). This demonstrates the flexibility of our implementation.

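Because each row of the input is processed independently, running the four XOR samples as one batch should give the same result as running them one at a time. This sketch checks that claim directly. It reuses the assumed list-of-dicts layout and ReLU hidden activation, with hypothetical hand-picked parameters instead of random ones. It also wraps the function in jax.jit, which traces the parameter list as a PyTree.

```python
import jax
import jax.numpy as jnp

def mlp_forward_pass(params_list, x_input):
    # Same structure as before; the ReLU hidden activation is an assumption.
    activations = x_input
    for layer in params_list[:-1]:
        activations = jax.nn.relu(jnp.dot(activations, layer["W"]) + layer["b"])
    out = params_list[-1]
    return jax.nn.sigmoid(jnp.dot(activations, out["W"]) + out["b"])

# Hypothetical hand-picked parameters so the check is deterministic.
params = [
    {"W": jnp.array([[1.0, -1.0, 0.5], [0.5, 1.0, -1.0]]), "b": jnp.array([0.0, 0.1, 0.2])},
    {"W": jnp.array([[1.0], [-1.0], [0.5]]), "b": jnp.array([-0.5])},
]
xor_X = jnp.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])

batch_preds = mlp_forward_pass(params, xor_X)

# Feed each sample as its own batch of size 1, shape (1, 2), then stack the results.
one_at_a_time = jnp.concatenate(
    [mlp_forward_pass(params, xor_X[i:i + 1]) for i in range(xor_X.shape[0])]
)

# The compiled version should agree with the uncompiled one.
jit_preds = jax.jit(mlp_forward_pass)(params, xor_X)

print(jnp.allclose(batch_preds, one_at_a_time))
print(jnp.allclose(batch_preds, jit_preds))
```
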
Conclusion and Next Steps

Fantastic work, cosmic coder! You've successfully implemented the forward propagation mechanism for our MLP, allowing it to take inputs and generate predictions. This mlp_forward_pass function is a cornerstone of any neural network, representing how it processes information.

While our network currently makes random guesses, we've built the essential pathway for information flow. In our next lesson, we'll tackle the other half of the learning puzzle: calculating how wrong our predictions are (the loss) and figuring out how to adjust our network's parameters to improve them using backpropagation and JAX's powerful automatic differentiation.
