Introduction

Welcome back to the second lesson of Beyond Pure JAX: Flax & Optax for Elegant ML! You've made excellent progress so far. Having mastered the foundational concepts of Flax modules with setup() and __call__() methods in our previous lesson, you're now ready to explore one of Flax's most essential building blocks: the Dense layer.

While creating custom modules gives us deep understanding and flexibility, real-world machine learning relies heavily on proven, optimized components. The flax.linen.Dense layer is precisely such a component: a fully connected layer that handles weight initialization, bias terms, and the linear transformation for you. Today, we'll construct a complete Multi-Layer Perceptron (MLP) using these built-in Dense layers, learning how to compose them elegantly while exploring two distinct approaches to module definition.

By the end of this lesson, you'll confidently build neural networks using Flax's high-level abstractions, understanding both the setup() method and the powerful @nn.compact decorator. We'll apply these techniques to construct an MLP capable of learning the XOR function, demonstrating how Flax simplifies complex model construction while maintaining JAX's functional programming principles.

Understanding Flax's Built-in Dense Layer

The flax.linen.Dense layer represents a fundamental building block in neural networks, implementing the ubiquitous linear transformation:

$$y = Wx + b$$

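where $W$ is the weight matrix, $x$ is the input vector, and $b$ is the bias term. Unlike our custom implementations from the previous lesson, Flax's built-in layers come with sensible default initializers (for Dense, LeCun-normal kernels and zero biases), automatic shape inference (you specify only the number of output features, and the input size is read from the sample data at initialization), and a ready-made, efficient implementation. In Flax the weight matrix is called the kernel and is stored with shape (in_features, out_features), so on a batch of row vectors the layer computes $xW + b$, the row-vector form of the same equation.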

Building MLPs with the setup() Method

Let's construct our first MLP using the familiar setup() method pattern you learned previously. This approach explicitly defines all layers during module initialization, providing clear separation between structure definition and forward pass logic. Here's how we define an MLP class using this methodology:

In this setup() method, we create two Dense layers and store them as instance attributes. hidden_features and output_features are dataclass fields of the module (declared as annotated class attributes) that parameterize it, making it reusable for different layer sizes. Notice that we explicitly name our layers with the name parameter: this sets the keys used in the model's parameter tree ("HiddenLayer" and "OutputLayer" rather than the attribute names), which makes debugging and analysis much easier.

Defining the Forward Pass with __call__

The forward pass logic comes in the __call__ method, where we chain our layers together with activation functions:

This implementation demonstrates the power of composition in Flax. Each Dense layer call, such as self.dense_hidden(x), performs the linear transformation, while nn.sigmoid applies the element-wise activation. The flow is intuitive: input → hidden layer → activation → output layer → final activation. Because the sequence of calls mirrors the mathematical formulation of the network, the code is both readable and maintainable.

The @nn.compact Alternative

Flax offers a more concise approach through the @nn.compact decorator, which allows us to define layers inline within the __call__ method. This approach can be more readable for certain architectures and enables dynamic layer creation based on input properties:

The @nn.compact decorator transforms our __call__ method into both a layer definer and a forward pass executor. When we write nn.Dense(features=self.hidden_features, name="HiddenLayer")(x), we're simultaneously creating the Dense layer and applying it to input x. This pattern is particularly useful for models where the architecture depends on input shapes or for rapid prototyping.

Both approaches define the same computation, so the choice between setup() and @nn.compact often comes down to personal preference and the use case. setup() separates structure definition from the forward pass, which helps in complex architectures with shared layers or modules with several methods, while @nn.compact keeps each layer next to where it is used, giving concise code that reads like the underlying math and suits straightforward sequential models.

Initializing Parameters and Understanding Structure

Now let's instantiate our models and examine how Flax handles parameter initialization. We'll again use the XOR dataset, creating both MLP variants and initializing them with the same random key so the comparison is reproducible:

The init() method performs the crucial task of parameter initialization. It takes a PRNG key and sample input data, uses the input to infer shapes, and creates the initial parameter values. Let's examine what these initialized parameters look like:

The parameter structure reveals Flax's automatic organization. Each Dense layer contributes a kernel (weight matrix) and a bias vector. The HiddenLayer has a kernel with shape (2, 3), transforming 2 input features to 3 hidden units, and a bias of shape (3,). Similarly, the OutputLayer transforms 3 hidden units to 1 output, with a kernel of shape (3, 1) and a bias of shape (1,). This hierarchical organization makes it easy to inspect, modify, or save specific parts of the model.

Forward Passes and Model Verification

With parameters initialized, we can perform forward passes using the apply() method. This method takes the parameter dictionary and input data, executing the forward computation we defined in __call__. Semantically, apply() represents Flax's functional approach to neural networks: it treats our module as a pure function and applies it to specific parameter values and input data, maintaining JAX's stateless programming paradigm.

Let's examine the predictions from both models to verify they produce identical results:

The identical parameter structures and predictions confirm that the setup() and @nn.compact versions are functionally equivalent. Since the models haven't been trained yet, the predictions are just the outputs of a randomly initialized network and say nothing yet about XOR. What they do show is that the MLP is correctly wired (a batch of shape (4, 2) produces outputs of shape (4, 1)) and ready for training.

Conclusion and Next Steps

Congratulations on mastering MLPs with Flax's Dense layers! You've successfully learned two powerful approaches for building neural networks — the explicit setup() method and the concise @nn.compact decorator — understanding how each offers unique advantages while producing identical functionality. The automatic parameter management, meaningful naming conventions, and clean composition patterns you've explored today form the foundation for building sophisticated deep learning models.

In our next lesson, we'll bring these static architectures to life by integrating Optax optimizers to train your MLP on the XOR problem. You'll discover how Flax and Optax work together seamlessly, transforming random predictions into perfect classification through gradient descent, and finally witnessing the power of the JAX ecosystem in action.
