Introduction

Welcome back to Beyond Pure JAX: Flax & Optax for Elegant ML! In the previous lessons, you've mastered the essential building blocks: Flax modules for elegant neural network construction, powerful MLP architectures with dense layers, and sophisticated Optax optimizers with JIT-compiled training steps. Today, in the final lesson of the course, we're bringing everything together to solve the classic XOR problem using our complete Flax and Optax toolkit.

By the end of this lesson, you'll have assembled a complete, production-ready training pipeline that seamlessly integrates model initialization, optimization state management, training loops, and model evaluation. This represents the culmination of everything we've learned and serves as a template for tackling real-world machine learning problems. We'll also compare our elegant Flax and Optax solution with the manual JAX implementation we built back in Course 3. Get ready!

Assembling the Complete Training Pipeline

Creating a robust training pipeline requires orchestrating several components that we've explored individually. We need to coordinate our Flax MLP model, the Optax Adam optimizer, and the JIT-compiled training step we mastered previously. The beauty of this integration lies in how these components work together seamlessly, despite being developed independently.

Our training pipeline follows a clear structure: initialize the model and create its parameters, set up the optimizer and its internal state, then iterate through training epochs while calling our compiled training step function. Each component maintains its own state that gets passed through the training loop, ensuring we can track both parameter evolution and optimizer momentum terms throughout the learning process.

The key insight is that modern deep learning frameworks like Flax and Optax are designed for composability. We can mix and match different model architectures, loss functions, and optimizers without rewriting our core training logic. This modularity makes our code both more maintainable and more experiment-friendly, allowing us to rapidly prototype different configurations.
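For example, because every Optax optimizer exposes the same interface, switching optimizers is a one-line change. The snippet below is purely illustrative rather than code from the lesson:

```python
import optax

# Any of these can be handed to the same training loop unchanged, because
# they all implement the same GradientTransformation interface (init/update).
optimizer = optax.adam(learning_rate=0.05)
# optimizer = optax.sgd(learning_rate=0.1)       # plain gradient descent
# optimizer = optax.rmsprop(learning_rate=0.01)  # a different adaptive method
```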

Model Initialization

Let's begin by setting up our core components. We'll initialize our MLP model with the same architecture we used previously, then create our optimizer with carefully chosen hyperparameters:
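A minimal sketch of this setup is shown below; the hidden-layer size, activation choices, and PRNG seed are illustrative assumptions rather than the exact values from the earlier lessons.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

# The full XOR dataset: four input pairs and their targets.
X = jnp.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
y = jnp.array([[0.], [1.], [1.], [0.]])

# A small MLP standing in for the architecture built in earlier lessons:
# one hidden layer with a tanh nonlinearity and a sigmoid output unit.
class MLP(nn.Module):
    hidden_features: int = 8

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden_features)(x)
        x = nn.tanh(x)
        x = nn.Dense(1)(x)
        return nn.sigmoid(x)

model = MLP()
optimizer = optax.adam(learning_rate=0.05)

# Lazy initialization: Flax infers every parameter shape from the sample input.
key = jax.random.PRNGKey(42)
initial_variables = model.init(key, X)

# Show each parameter's shape without printing the values themselves.
print(jax.tree_util.tree_map(lambda p: p.shape, initial_variables['params']))
```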


The model.init() call is where Flax performs its lazy initialization magic. By passing our XOR input data, Flax can infer all the necessary shapes and create appropriately sized weight matrices and bias vectors. The returned initial_variables contains our parameter tree under the 'params' key, which we'll use throughout training. The jax.tree_util.tree_map function elegantly displays the shape of each parameter in our nested structure without printing the actual values.

Optimizer Initialization

Next, we initialize our Optax optimizer and its state:
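Continuing from the previous sketch (reusing the optimizer and initial_variables names), this step is brief:

```python
# Extract the parameter tree and build the matching Adam state.
params = initial_variables['params']
opt_state = optimizer.init(params)

# The state mirrors the parameter tree and carries a step counter (`count`)
# that starts at zero.
print(opt_state)
```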

The optimizer.init() method creates internal state structures that Adam needs to track momentum and adaptive learning rates for each parameter. This state mirrors our parameter tree structure, ensuring every weight and bias has its corresponding optimization state. The printed count shows Adam's internal step counter, starting from zero, which will increment with each optimization step.

Structuring the Training Loop

Now we'll create the heart of our training pipeline: the main training loop that iterates through epochs and applies our JIT-compiled training step. The loop structure is elegantly simple, yet handles all the complexity of parameter updates and state management:
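The sketch below reuses the names from the earlier snippets (model, optimizer, params, opt_state, X, y). The train_step shown here is a plausible reconstruction of the function built in the previous lesson, so details such as the choice of loss are assumptions:

```python
from functools import partial
import jax
import jax.numpy as jnp
import optax

@partial(jax.jit, static_argnums=(0, 3))  # model and optimizer are static
def train_step(model, params, opt_state, optimizer, x, y):
    def loss_fn(p):
        preds = model.apply({'params': p}, x)
        return jnp.mean((preds - y) ** 2)  # MSE as a stand-in loss

    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

num_epochs = 2000
for epoch in range(num_epochs):
    params, opt_state, loss = train_step(model, params, opt_state, optimizer, X, y)
    if (epoch + 1) % 200 == 0:
        print(f"Epoch {epoch + 1:4d} | loss: {float(loss):.6f}")
```

Marking the model and optimizer as static arguments lets JAX trace around them; only the parameters, optimizer state, and data flow through the compiled function as arrays.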

Notice how clean this loop appears compared to manual JAX implementations. We're calling our train_step function that we developed in the previous lesson, passing all necessary components: the model for forward passes, current parameters, optimizer state, the optimizer itself, and our XOR dataset. The function returns updated parameters, updated optimizer state, and the current loss value.

The state threading pattern is crucial here. Each iteration receives the updated parameters and optimizer state from the previous iteration, ensuring continuous learning progression. Since our XOR dataset is small (only 4 examples), we treat the entire dataset as a single batch, which is perfectly reasonable for this demonstration problem. The periodic logging every 200 epochs provides visibility into training progress without overwhelming the output.

Training Execution and Progress Monitoring

When we execute our training loop, we see the learning progression unfold in the periodic log output. The model starts with random parameters and gradually learns the XOR mapping through iterative optimization.

The loss progression demonstrates effective learning: starting from an initial loss of 0.257154, the model rapidly decreases the error and converges to approximately 0.000045 after 2000 epochs. This exponential decay pattern is characteristic of successful neural network training on well-conditioned problems like XOR. The steady decrease in loss magnitude shows that our Adam optimizer with a learning rate of 0.05 is well-tuned for this problem, neither too aggressive (which would cause instability) nor too conservative (which would slow convergence).

Model Evaluation and Results Analysis

After training is complete, we evaluate our model's performance by making predictions on the XOR inputs and analyzing the results:
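Again reusing the names from the earlier snippets, the evaluation might look like the following; the printing format is illustrative:

```python
import jax.numpy as jnp

# Forward pass with the trained parameters.
preds = model.apply({'params': params}, X)

# Binarize the continuous outputs with a 0.5 threshold.
binary_preds = (preds > 0.5).astype(jnp.int32)

for inputs, raw, pred, target in zip(X, preds, binary_preds, y):
    print(f"Input {inputs} -> raw: {float(raw[0]):.4f}, "
          f"predicted: {int(pred[0])}, target: {int(target[0])}")

# Vectorized accuracy over all four examples.
accuracy = jnp.mean(binary_preds == y.astype(jnp.int32)) * 100
print(f"Accuracy: {float(accuracy):.1f}%")
```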

The evaluation code demonstrates how we use the trained parameters with model.apply() to generate predictions. We compare both the raw continuous outputs (which should be close to 0 or 1) and the binarized predictions using a 0.5 threshold. The final accuracy calculation uses vectorized operations to efficiently compute the percentage of correct predictions across all examples.

The evaluation results demonstrate perfect learning.

Our model has successfully learned the XOR function with 100% accuracy. The continuous predictions are very close to the target binary values: false cases (0,0) and (1,1) produce outputs near 0, while true cases (0,1) and (1,0) produce outputs near 1. This clear separation demonstrates that our MLP has learned meaningful internal representations that capture the XOR logic perfectly.

Comparing Flax & Optax with Pure JAX

Reflecting on this implementation compared to the pure JAX version from Course 3 reveals the dramatic power of high-level abstractions. Our Flax and Optax solution achieves equivalent computational results with significantly cleaner, more maintainable code:

  • Parameter Management: Flax's @nn.compact and automatic shape inference eliminate the custom initialize_mlp_params function and manual Xavier initialization
  • Forward Pass: Declarative layer definitions replace the explicit mlp_forward_pass function with its manual loops and activation applications
  • Optimization: The built-in Adam optimizer, with momentum and adaptive learning rates, replaces simple gradient descent with manual jtu.tree_map updates (see the sketch after this list)
  • Training Loop: A clean, reusable train_step function encapsulates the optimization logic instead of explicit jax.value_and_grad calls and manual gradient application
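As a rough, schematic contrast (not the exact Course 3 code), the optimization bullet above boils down to the difference between these two update helpers:

```python
import jax.tree_util as jtu
import optax

# Pure JAX (Course 3 style): walk the parameter tree by hand and apply
# plain gradient descent to every leaf.
def manual_sgd_update(params, grads, lr=0.1):
    return jtu.tree_map(lambda p, g: p - lr * g, params, grads)

# Flax & Optax style: the optimizer object owns the update rule (here Adam,
# with momentum and adaptive learning rates) and its own state.
optimizer = optax.adam(learning_rate=0.05)

def optax_update(params, grads, opt_state):
    updates, opt_state = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state
```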

Note that the performance characteristics remain identical because Flax and Optax compile down to the same underlying JAX operations. We still benefit from JIT compilation, automatic differentiation, and hardware acceleration. However, our abstracted approach provides superior code organization, debugging capabilities, and extensibility. The pure JAX version would require significant refactoring to experiment with different architectures or optimizers, while our current structure supports these changes through simple parameter modifications. Most importantly, this structure scales effortlessly to production scenarios with complex architectures like convolutional networks and transformers.

Conclusion and Next Steps

Congratulations on successfully implementing your first complete Flax and Optax training pipeline! You've demonstrated how these powerful libraries work together to create elegant, efficient, and maintainable deep learning solutions. The XOR problem that once required extensive manual parameter management and gradient calculations is now solved cleanly in fewer than 100 lines of readable code.

This training pipeline serves as a template for real-world applications. Whether you're tackling image classification, natural language processing, or reinforcement learning, the same fundamental pattern of model initialization, optimizer setup, training loops, and evaluation applies. In our next course of this path, we'll scale this foundation to tackle a real-world image classification project, demonstrating how the principles you've mastered today apply to practical machine learning challenges with larger datasets and more complex model architectures.
