Introduction

Welcome to our course "JAX Fundamentals: NumPy Power-Up"! We're excited to embark on this journey with you as we explore one of the most powerful and elegant numerical computing libraries available today. This is the very first lesson in our comprehensive learning path, where we'll transform you from a JAX newcomer into someone who can confidently build and train neural networks using this remarkable framework.

This learning path consists of five courses that will take you from the fundamentals all the way to building real-world machine learning applications:

  1. JAX Fundamentals: NumPy Power-Up (our current course) — We'll master JAX's NumPy-like API, immutability, pure functions, automatic differentiation, and just-in-time (JIT) compilation.
  2. Advanced JAX: Transformations for Speed & Scale — We'll explore functional random number generation, vectorization with vmap, multi-device parallelism, and PyTrees.
  3. JAX in Action: Neural Networks from Scratch — We'll build a complete Multi-Layer Perceptron from scratch to solve the XOR problem.
  4. Beyond Pure JAX: Flax & Optax for Elegant ML — We'll refactor our neural networks using industry-standard libraries such as Flax and Optax for cleaner, more maintainable code.
  5. JAX in Action: Building an Image Classifier — We'll culminate our journey by building a real-world Convolutional Neural Network for image classification.

To get the most out of this learning path, we expect you to be comfortable with Python programming, basic NumPy operations (array creation, indexing, and mathematical operations), and fundamental linear algebra and calculus concepts (matrix multiplication, vectors, and basic calculus for understanding gradients). If you're solid on these foundations, you're ready to unlock the power of JAX!

Today, we're starting with the cornerstone of JAX: arrays. Think of JAX arrays as NumPy arrays' more sophisticated, immutable cousins that come with superpowers like automatic differentiation and lightning-fast compilation.

What is JAX and Why Should You Care?

Before starting, you might be wondering: why should I care about JAX? JAX stands as one of the most exciting developments in numerical computing and machine learning. Developed by Google Research, this framework brings together the familiar NumPy API we know and love with cutting-edge performance optimizations and functional programming principles that make it incredibly powerful for scientific computing and deep learning.

At its core, JAX is built on three fundamental pillars that set it apart from traditional NumPy:

  1. First, JAX embraces immutability — once we create an array, we cannot modify it in place. While this might seem restrictive at first, immutability enables powerful optimizations and makes our code more predictable and easier to reason about, especially in parallel computing environments.
  2. Second, JAX provides automatic differentiation (autodiff) that can compute gradients of any function we write. This capability is essential for machine learning, where we need to compute gradients to train our models. Unlike manual differentiation or finite differences, JAX's autodiff is both fast and numerically stable.
  3. Third, JAX includes advanced compilation capabilities that can dramatically accelerate our computations. We'll explore these performance features throughout our course, building from simple concepts to powerful optimizations.

What makes JAX particularly compelling is that it maintains NumPy's familiar interface while adding these superpowers. If you can write NumPy code, you're already 80% of the way to writing JAX code! The learning curve is gentle, but the performance gains and capabilities you'll unlock are transformational.

Creating JAX Arrays

Let's start our hands-on journey by learning how to create JAX arrays. Just as NumPy serves as the foundation for scientific computing in Python, JAX arrays serve as the foundation for everything we'll build in this course.

JAX provides the jax.numpy module (commonly imported as jnp) that mirrors NumPy's functionality. We can create JAX arrays from Python lists, NumPy arrays, or by using any of the familiar array creation functions:
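Since the lesson's snippet isn't reproduced here, a minimal sketch of array creation might look like this (the variable names are illustrative):

```python
import numpy as np
import jax.numpy as jnp

# Create a JAX array from a Python list
arr_from_list = jnp.array([1, 2, 3, 4, 5])

# Create a JAX array from an existing NumPy array
np_arr = np.array([1.0, 2.0, 3.0])
arr_from_numpy = jnp.array(np_arr)

# Familiar NumPy-style creation functions work as well
zeros = jnp.zeros((2, 3))
evens = jnp.arange(0, 10, 2)

print(arr_from_list)   # [1 2 3 4 5]
print(arr_from_numpy)  # [1. 2. 3.]
print(zeros)
print(evens)           # [0 2 4 6 8]
```

Note that by default JAX stores floating-point data as 32-bit (`float32`), even when converting from a 64-bit NumPy array.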

When we run this code, we'll see that JAX arrays look and behave very similarly to NumPy arrays in terms of their string representation. The key difference lies under the hood — JAX arrays are designed for functional programming and come with the immutability guarantee we'll explore shortly.

Notice that we're using jnp.array() instead of np.array(). This is our gateway into the JAX ecosystem. The function signature and behavior are nearly identical to NumPy's array() function, making the transition seamless for anyone familiar with NumPy.

Basic Operations with JAX Arrays

Now that we can create JAX arrays, let's explore how to perform computations with them. JAX maintains the same intuitive syntax as NumPy for mathematical operations, making our transition smooth and natural.
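A small sketch of such operations, assuming illustrative variable names, could be:

```python
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([10.0, 20.0, 30.0])

# Element-wise addition with the + operator
elementwise_sum = a + b

# Matrix multiplication with jnp.dot
m1 = jnp.array([[1.0, 2.0], [3.0, 4.0]])
m2 = jnp.array([[5.0, 6.0], [7.0, 8.0]])
product = jnp.dot(m1, m2)

print(elementwise_sum)  # [11. 22. 33.]
print(product)          # [[19. 22.]
                        #  [43. 50.]]
```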

When we execute operations like these, the printed results look exactly like NumPy's output.

The beauty of JAX lies in this familiar interface. If you've written NumPy code before, this feels completely natural. We can perform element-wise addition with the + operator and matrix multiplication using jnp.dot(), just as we would in NumPy.

Behind the scenes, however, JAX is preparing these operations for potential optimization and making them compatible with automatic differentiation. This means that any function we write using these basic operations can later be automatically differentiated or optimized for enhanced performance.

The Power of Immutability

Here's where JAX diverges significantly from NumPy, and this difference is crucial to understand. JAX arrays are immutable, meaning that once created, they cannot be modified. This design choice might feel unfamiliar at first, but it's fundamental to JAX's functional programming paradigm and enables many of its powerful features.

Let's see what happens when we try to modify a JAX array in the traditional NumPy way:
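A minimal sketch of such an attempt (catching the exception so the script keeps running):

```python
import jax.numpy as jnp

arr = jnp.array([1, 2, 3, 4, 5])

try:
    arr[0] = 10  # in-place assignment, as we would do in NumPy
except TypeError as e:
    print(f"Error: {e}")
```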

When we attempt such an in-place assignment, JAX raises a helpful TypeError telling us that arrays are immutable and pointing us toward the correct approach.

As suggested by the error message, instead of modifying arrays in place, JAX provides the .at[] method for functional updates:
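A sketch of the functional-update pattern might look like this:

```python
import jax.numpy as jnp

arr = jnp.array([1, 2, 3, 4, 5])

# .at[index].set(value) returns a NEW array; the original is untouched
updated = arr.at[0].set(10)

print("Original:", arr)      # Original: [1 2 3 4 5]
print("Updated: ", updated)  # Updated:  [10  2  3  4  5]
```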

This approach creates a new array with our desired changes while leaving the original array completely unchanged, as the printed output demonstrates.

Immutable JAX versus Mutable NumPy

To contrast this with NumPy's mutable behavior, let's recall how traditional NumPy arrays behave:
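For comparison, a sketch of the equivalent NumPy code:

```python
import numpy as np

np_arr = np.array([1, 2, 3, 4, 5])
np_arr[0] = 10  # in-place modification succeeds silently

print(np_arr)  # [10  2  3  4  5]
```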

The NumPy array changes in place; no new array is created.

This mutable behavior can lead to unexpected side effects in complex programs. JAX's immutability eliminates these surprises and makes our code more predictable and safer for parallel execution.

Performance Insights: A Glimpse of JAX's Potential

One of JAX's most notable performance optimizations is just-in-time (JIT) compilation. While we'll explore JAX's JIT mechanism in much greater detail later in our course, let's get a taste of what's possible with a simple benchmark. This will help us understand why JAX has become so popular in the scientific computing community.
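A benchmark along these lines might be sketched as follows (the function, array size, and timing approach here are illustrative; exact timings will vary by machine):

```python
import time
import numpy as np
import jax
import jax.numpy as jnp

def sum_of_squares_np(x):
    return np.sum(x ** 2)

@jax.jit  # ask JAX to JIT-compile this function on first call
def sum_of_squares_jax(x):
    return jnp.sum(x ** 2)

size = 1_000_000
np_data = np.random.rand(size).astype(np.float32)
jax_data = jnp.array(np_data)

# NumPy baseline
start = time.perf_counter()
sum_of_squares_np(np_data)
print(f"NumPy:          {time.perf_counter() - start:.6f}s")

# First JAX call: includes compilation overhead
start = time.perf_counter()
sum_of_squares_jax(jax_data).block_until_ready()
print(f"JAX (1st call): {time.perf_counter() - start:.6f}s")

# Second JAX call: reuses the already-compiled code
start = time.perf_counter()
sum_of_squares_jax(jax_data).block_until_ready()
print(f"JAX (2nd call): {time.perf_counter() - start:.6f}s")
```

Because JAX dispatches work asynchronously, `.block_until_ready()` is what guarantees the computation has actually finished before we stop the timer.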

Such a benchmark reveals something fascinating about JAX's execution model and performance characteristics.

Notice the interesting pattern here! The first JAX run is significantly slower because it includes compilation overhead — JAX is analyzing our function and optimizing it for the specific hardware. However, the second run is actually faster than NumPy, demonstrating the power of JIT compilation.

This pattern is crucial to understand: JAX trades some initial setup time for dramatically improved performance on subsequent runs. For machine learning workloads where we repeatedly call the same functions (like forward passes through a neural network), this trade-off is incredibly beneficial. The .block_until_ready() call ensures our timing is accurate because JAX operations are asynchronous by default. This asynchronous execution is part of what enables JAX's performance optimizations.

Conclusion and Next Steps

Congratulations! We've taken our first steps into the JAX ecosystem and discovered the fundamental building blocks that make JAX so powerful. We've learned how to create JAX arrays from various sources, perform basic mathematical operations, and, most importantly, we've grasped the concept of immutability that sets JAX apart from NumPy. The .at[] method has shown us how functional programming principles can replace traditional in-place modifications.

As we continue our journey through this course, we'll build upon these fundamentals to explore automatic differentiation, dive deep into compilation techniques, and learn about JAX's pure function requirements. Each concept we learn will bring us closer to understanding why JAX has become the go-to choice for high-performance numerical computing and modern machine learning research. In our next lesson, we'll explore how JAX's emphasis on pure functions creates a more predictable and optimizable programming environment. But before moving forward, get ready for some practice! Happy coding!
