Welcome back to the second lesson of Beyond Pure JAX: Flax & Optax for Elegant ML! You've made excellent progress so far. Having mastered the foundational concepts of Flax modules with setup() and __call__() methods in our previous lesson, you're now ready to explore one of Flax's most essential building blocks: the Dense layer.
While creating custom modules gives us deep understanding and flexibility, real-world machine learning relies heavily on proven, optimized components. The flax.linen.Dense layer is precisely such a component — a fully connected layer that handles weight initialization, bias terms, and efficient linear transformations automatically. Today, we'll construct a complete Multi-Layer Perceptron (MLP) using these built-in Dense layers, learning how to compose them elegantly while exploring two distinct approaches to module definition.
By the end of this lesson, you'll confidently build neural networks using Flax's high-level abstractions, understanding both the setup() method and the powerful @nn.compact decorator. We'll apply these techniques to construct an MLP capable of learning the XOR function, demonstrating how Flax simplifies complex model construction while maintaining JAX's functional programming principles.
The flax.linen.Dense layer represents a fundamental building block in neural networks, implementing the ubiquitous linear transformation y = Wx + b, where W is the weight matrix, x is the input vector, and b is the bias term. Unlike our custom implementations from the previous lesson, Dense layers come with sophisticated initialization strategies, automatic shape inference, and optimized computation paths.
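As a quick illustration (not one of the lesson's original listings), here is a minimal sketch of a standalone Dense layer; the feature sizes and random seed are arbitrary choices:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# A standalone Dense layer producing 3 output features (sizes chosen arbitrarily).
layer = nn.Dense(features=3)
x = jnp.ones((1, 2))                              # one sample with 2 input features
variables = layer.init(jax.random.PRNGKey(0), x)  # shapes are inferred from x

# Flax stores the weights as a "kernel" of shape (in_features, out_features)
# and a "bias" of shape (out_features,).
print(jax.tree_util.tree_map(jnp.shape, variables))
# -> kernel of shape (2, 3) and bias of shape (3,) under the 'params' collection

out = layer.apply(variables, x)  # computes x @ kernel + bias
```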
Let's construct our first MLP using the familiar setup() method pattern you learned previously. This approach explicitly defines all layers during module initialization, providing clear separation between structure definition and forward pass logic. Here's how we define an MLP class using this methodology:
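The lesson's original code listing isn't reproduced here, so the following is a minimal sketch of how such a class might begin. The class name MLPSetup is an assumption; the attribute names and layer names follow the description below:

```python
import flax.linen as nn


class MLPSetup(nn.Module):
    # Dataclass-style attributes that parameterize the architecture.
    hidden_features: int
    output_features: int

    def setup(self):
        # Layers are created once, as instance attributes, when the module is set up.
        # Note: depending on your Flax version, explicitly passed names may be
        # replaced by the attribute names in the parameter tree.
        self.dense_hidden = nn.Dense(features=self.hidden_features, name="HiddenLayer")
        self.dense_output = nn.Dense(features=self.output_features, name="OutputLayer")

    # The forward pass (__call__) is added in the next snippet.
```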
In this setup() method, we create two Dense layers as instance attributes. The hidden_features and output_features are dataclass-style attributes that parameterize our module, making it reusable for different architectures. Notice how we explicitly name our layers using the name parameter — this creates meaningful parameter names in our model's parameter tree, making debugging and analysis much easier.
The forward pass logic comes in the __call__ method, where we chain our layers together with activation functions:
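The original listing isn't shown here; a sketch of the complete class, adding the __call__ method to the setup() definition above (the layer order and sigmoid activations follow the description below), might look like this:

```python
import flax.linen as nn


class MLPSetup(nn.Module):
    hidden_features: int
    output_features: int

    def setup(self):
        self.dense_hidden = nn.Dense(features=self.hidden_features, name="HiddenLayer")
        self.dense_output = nn.Dense(features=self.output_features, name="OutputLayer")

    def __call__(self, x):
        x = self.dense_hidden(x)   # input -> hidden (linear transformation)
        x = nn.sigmoid(x)          # hidden activation
        x = self.dense_output(x)   # hidden -> output (linear transformation)
        return nn.sigmoid(x)       # final activation
```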
This implementation demonstrates the power of composition in Flax. Each Dense layer call self.dense_hidden(x) performs the linear transformation, while nn.sigmoid(x) applies the activation function. The flow is intuitive: input → hidden layer → activation → output layer → final activation. The sequential application of transformations mirrors the mathematical formulation of neural networks, making the code both readable and maintainable.
Flax offers a more concise approach through the @nn.compact decorator, which allows us to define layers inline within the __call__ method. This approach can be more readable for certain architectures and enables dynamic layer creation based on input properties:
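Again as a sketch (the class name MLPCompact is an assumption; the inline Dense calls match the snippet quoted below):

```python
import flax.linen as nn


class MLPCompact(nn.Module):
    hidden_features: int
    output_features: int

    @nn.compact
    def __call__(self, x):
        # Layers are defined inline and applied to x in the same expression.
        x = nn.Dense(features=self.hidden_features, name="HiddenLayer")(x)
        x = nn.sigmoid(x)
        x = nn.Dense(features=self.output_features, name="OutputLayer")(x)
        return nn.sigmoid(x)
```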
The @nn.compact decorator transforms our __call__ method into both a layer definer and a forward pass executor. When we write nn.Dense(features=self.hidden_features, name="HiddenLayer")(x), we're simultaneously creating the Dense layer and applying it to input x. This pattern is particularly useful for models where the architecture depends on input shapes or for rapid prototyping.
Both approaches produce identical functionality — the choice between setup() and @nn.compact often comes down to personal preference and specific use cases. The setup() method provides explicit structure definition, which is useful for complex architectures with shared layers, while @nn.compact offers more concise, mathematical-notation-like code that's perfect for straightforward sequential models.
Now let's instantiate our models and examine how Flax handles parameter initialization. We'll once again use the XOR dataset to demonstrate the process, creating both MLP variants and initializing them with identical random keys to ensure reproducible comparisons:
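A sketch of what this step might look like, assuming the MLPSetup and MLPCompact classes from the snippets above; the seed of 42 and the variable names are assumptions:

```python
import jax
import jax.numpy as jnp

# XOR inputs (4 samples, 2 features) and targets (used for training in the next lesson).
X = jnp.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
y = jnp.array([[0.0], [1.0], [1.0], [0.0]])

key = jax.random.PRNGKey(42)

model_setup = MLPSetup(hidden_features=3, output_features=1)
model_compact = MLPCompact(hidden_features=3, output_features=1)

# Same PRNG key and same sample input -> identical initial parameters.
params_setup = model_setup.init(key, X)
params_compact = model_compact.init(key, X)
```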
The init() method performs the crucial task of parameter initialization. It takes a PRNG key and sample input data, uses the input to infer shapes, and creates the initial parameter values. Let's examine what these initialized parameters look like:
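One way to inspect the initialized parameters is to map over the tree and print each leaf's shape (a sketch, not the lesson's exact output):

```python
# Print the shape of every parameter in each tree.
print(jax.tree_util.tree_map(jnp.shape, params_setup))
print(jax.tree_util.tree_map(jnp.shape, params_compact))
# Per the lesson, each tree contains a HiddenLayer kernel of shape (2, 3) with bias (3,)
# and an OutputLayer kernel of shape (3, 1) with bias (1,), under the 'params' collection.
```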
The parameter structure reveals Flax's automatic organization. Each Dense layer contributes a kernel (weight matrix) and a bias vector. The HiddenLayer has a kernel with shape (2, 3) — transforming 2 input features to 3 hidden units — and a bias of shape (3,). Similarly, the OutputLayer transforms 3 hidden units to 1 output with shapes (3, 1) and (1,), respectively. This hierarchical organization makes it easy to inspect, modify, or save specific parts of the model.
With parameters initialized, we can perform forward passes using the apply() method. This method takes the parameter dictionary and input data, executing the forward computation we defined in __call__. Semantically, apply() represents Flax's functional approach to neural networks: it treats our module as a pure function and applies it to specific parameter values and input data, maintaining JAX's stateless programming paradigm.
Let's examine the predictions from both models to verify they produce identical results:
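A sketch of the comparison, continuing with the names introduced above:

```python
preds_setup = model_setup.apply(params_setup, X)
preds_compact = model_compact.apply(params_compact, X)

print(preds_setup)
print(preds_compact)
# With identical parameters, both arrays should match element for element,
# though the untrained values themselves are arbitrary.
```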
The identical parameter structures and predictions confirm that both setup() and @nn.compact approaches are functionally equivalent. The models haven't been trained yet, so the predictions are essentially random outputs from the initialized network; this demonstrates that our MLP architecture is correctly constructed and ready for training.
Congratulations on mastering MLPs with Flax's Dense layers! You've successfully learned two powerful approaches for building neural networks — the explicit setup() method and the concise @nn.compact decorator — understanding how each offers unique advantages while producing identical functionality. The automatic parameter management, meaningful naming conventions, and clean composition patterns you've explored today form the foundation for building sophisticated deep learning models.
In our next lesson, we'll bring these static architectures to life by integrating Optax optimizers to train your MLP on the XOR problem. You'll discover how Flax and Optax work together seamlessly, transforming random predictions into perfect classification through gradient descent, and you'll finally witness the power of the JAX ecosystem in action.
