Welcome back to JAX in Action: Building an Image Classifier! In our first lesson, we established a solid foundation by implementing robust data loading for the MNIST dataset using tfds. Now that we have clean, normalized image data flowing through our pipeline, we're ready to tackle the next exciting challenge: building our first convolutional neural network.
Today's lesson focuses on crafting a CNN architecture using Flax's elegant Linen API. We'll discover how to combine convolutional layers, activation functions, and pooling operations to create a network specifically designed for image recognition tasks. As you may recall from our earlier courses, Flax provides a clean, functional approach to building neural networks, and CNNs showcase this elegance beautifully. By the end of this lesson, you'll have implemented a complete CNN that can process our MNIST images and produce classification logits, setting the stage for training in upcoming lessons.
Before we write any code, let's build intuition about why CNNs revolutionized computer vision. Traditional dense neural networks treat images as flat vectors, losing critical spatial relationships between pixels. Imagine trying to recognize a handwritten "8" by examining individual pixel values randomly scattered without their positions: the curved loops that define an "8" would be completely lost in this chaos!
Convolutional layers preserve these vital spatial relationships by applying small filters (kernels) that scan across the image systematically. Think of each filter as a pattern detector: one might specialize in finding vertical edges, another in detecting curves, and yet another in identifying corners. As these filters slide across your image, they produce feature maps that light up wherever their specific pattern appears. The magic happens when you stack multiple filters: the first layer might detect simple edges, but by combining these edge detectors, deeper layers can recognize complete shapes like loops or lines.
Pooling layers act as intelligent summarizers that complement convolutions perfectly. After detecting features, pooling reduces the spatial resolution while keeping the most important information. Max pooling, for instance, looks at small regions (like 2×2 pixels) and keeps only the strongest signal: if any part of that region detected an edge strongly, that's what matters. This makes our network robust to small shifts in the input (a "7" is still a "7" even if shifted slightly) while dramatically reducing computation. This hierarchical feature extraction mimics how our own visual system works: from detecting edges to recognizing shapes to understanding complete objects.
You may recall that Flax offers two elegant approaches for defining neural network modules: the `setup()` method and the `@nn.compact` decorator. For our CNN, we'll embrace the `@nn.compact` approach, which allows us to define the entire network architecture within a single, flowing `__call__` method. This approach shines particularly brightly for feedforward architectures like our CNN, where data flows naturally from input to output.
This foundation establishes our `CNN` class as a Flax module with a configurable number of output classes (defaulting to 10 for MNIST's digits 0-9). The `@nn.compact` decorator transforms our `__call__` method into a complete module definition, enabling us to build and apply layers in one seamless step.
Now let's implement the heart of our CNN: the convolutional blocks that will extract meaningful features from our handwritten digits. We'll craft two convolutional blocks, each combining convolution, activation, and pooling to progressively refine our understanding of the image:
Let's focus on the first convolutional layer:
- Uses 32 different 3×3 pixel filters, with each filter learning to detect distinct low-level patterns (like edges or corners).
- The 3×3 kernel size is chosen because it is large enough to capture meaningful patterns, while being small enough to be computationally efficient.
- `padding='SAME'` ensures output spatial dimensions match the input by adding zeros around the edges, preserving border information.
- After convolution, `nn.relu` introduces non-linearity, enabling the network to learn complex, non-linear patterns.
After convolution and activation, we apply a max pooling operation:
- `window_shape=(2, 2)` defines the pooling window size: each 2×2 pixel region is examined to find the maximum value.
- `strides=(2, 2)` controls how far the pooling window moves between operations; in this case, we step 2 pixels in each direction, which means no overlap between pooling windows.
- Together, these parameters halve both the height and width of the feature maps (since we move the 2×2 window by 2 pixels each time).
- Retains the strongest activations within each region, making the network robust to small shifts in the input.
The second convolutional block mirrors the first block’s structure but doubles the feature count to 64, allowing the network to combine features from the first block into more sophisticated patterns. This progressive deepening and spatial reduction, where each block increases the number of features while reducing spatial resolution, forms the backbone of effective CNN design. By trading spatial detail for increasingly abstract and powerful feature representations, the network becomes capable of recognizing complex patterns necessary for accurate image classification.
After our convolutional blocks have extracted rich spatial features, we need to transform these distributed patterns into concrete classification decisions. This transformation requires flattening our 2D feature maps and passing them through fully connected layers that can learn complex decision boundaries:
The reshape operation `jnp.reshape(x, (x.shape[0], -1))` performs a crucial transformation: it takes our 4D tensor of feature maps (batch × height × width × channels) and flattens everything except the batch dimension into a long vector. The `-1` is JAX's clever way of saying "figure out this dimension automatically," which is incredibly useful when our spatial dimensions might vary.
Our 128-neuron dense layer acts as a sophisticated pattern combiner, learning to synthesize the spatial features into higher-level concepts. Think of it as a committee of 128 experts, each looking at all the extracted features and voting on what digit they represent. The final output layer produces raw logits — one score per digit class. Notice the deliberate absence of an activation function here; we'll let our loss function handle the conversion from logits to probabilities during training, maintaining numerical stability and giving us maximum flexibility.
With our CNN architecture complete, let's bring it to life by initializing its parameters and performing a test forward pass. This crucial step verifies our design before we invest time in training:
This initialization showcases Flax's functional philosophy beautifully. Unlike frameworks where models carry mutable state, Flax cleanly separates architecture from parameters. The `model.init()` call traces through our entire network with dummy data, automatically determining parameter shapes and initializing them using our random key. The resulting `variables` dictionary organizes all trainable parameters by layer name — a clean, inspectable structure that makes debugging and analysis straightforward.
Let's complete our architectural validation by examining the initialized parameters and confirming our forward pass produces the expected output:
Running the complete test confirms that our architecture behaves as designed.
Our first convolutional layer's kernel has shape `(3, 3, 1, 32)`, which is exactly what we specified: 3×3 spatial dimensions, 1 input channel for grayscale images, and 32 output feature maps. The forward pass yields logits with shape `(1, 10)` — one example in our batch producing 10 class scores, perfect for our digit classification task. This successful test confirms our CNN correctly transforms 28×28 grayscale images into classification decisions, ready for the training phase ahead.
Congratulations on implementing a complete convolutional neural network using Flax's expressive Linen API! You've successfully combined the fundamental building blocks of modern computer vision: convolutional layers for hierarchical feature extraction, pooling for translation invariance and efficiency, and dense layers for final classification. Our modular CNN architecture follows industry best practices while remaining clear and maintainable.
The journey from raw pixels to classification logits is now complete, but our CNN remains untrained, like a musical instrument waiting to be played. In our next lesson, we'll breathe life into this architecture by implementing a full training pipeline with loss functions, optimization strategies, and evaluation metrics. The transformation from static architecture to learning system awaits, bringing us one step closer to a CNN that can actually recognize handwritten digits with impressive accuracy!
