Welcome back to the third lesson of "JAX Fundamentals: NumPy Power-Up"! We're making fantastic progress on our journey to mastering JAX's most powerful capabilities. In our previous lessons, we explored JAX arrays and their immutable nature, then discovered how pure functions form the cornerstone of JAX's design. These concepts weren't just theoretical exercises — they were laying the groundwork for the truly transformative feature we'll explore today.
Today, we're diving into one of JAX's most celebrated features: automatic differentiation. This is where JAX begins to show its true power beyond just being a NumPy replacement. Automatic differentiation is the mathematical foundation that enables machine learning, optimization algorithms, and scientific computing applications to compute gradients effortlessly and accurately.
As you may recall from our previous lesson, pure functions are essential because they enable JAX's transformations to work reliably. Today, we'll see this principle in action as we use `jax.grad` to automatically compute derivatives of our pure functions. By the end of this lesson, you'll understand how to use `jax.grad` to compute gradients of scalar-output functions, evaluate these gradients at specific points, and even handle functions with multiple variables.
Before we start computing gradients with code, let's understand what automatic differentiation actually is and why it's so revolutionary for numerical computing. In calculus, we learned to compute derivatives by hand using rules like the power rule, product rule, and chain rule. For a simple function like a quadratic polynomial, finding the derivative is straightforward. But imagine trying to compute derivatives by hand for the complex functions found in modern machine learning models: functions with millions of parameters and hundreds of layers of computations!
Now let's experience automatic differentiation in action with our first example. We'll start with a simple polynomial function and use `jax.grad` to compute its derivative automatically. This will demonstrate the basic workflow and show how JAX handles what we would normally do by hand in calculus.
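The snippet below is a minimal sketch of that workflow; the specific polynomial, f(x) = x^2 + 3x + 1, is an assumed example.

```python
import jax

# A pure, scalar-output function: f(x) = x^2 + 3x + 1 (assumed example polynomial)
def f(x):
    return x ** 2 + 3 * x + 1

# jax.grad(f) does not compute a number; it returns a new function that
# evaluates the derivative f'(x) = 2x + 3 at whatever point we pass in
grad_f = jax.grad(f)

print("f(2.0) =", f(2.0))
print("f'(2.0) =", grad_f(2.0))
```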
For the sketch above, the code outputs:
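```
f(2.0) = 11.0
f'(2.0) = 7.0
```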
Let's break down what's happening here step by step. First, we define our function `f(x)` as a pure function that takes a single input and returns a scalar value. Notice that this function satisfies all the requirements for automatic differentiation: it's pure, deterministic, and returns a scalar output.
The magic happens with `jax.grad(f)`. This doesn't immediately compute a gradient; instead, it returns a new function `grad_f` that computes the derivative of our original function. When we call `grad_f(2.0)`, JAX automatically applies the differentiation rules and returns the value of the derivative at x = 2.0.
In many practical applications, we need both the function's value and its gradient at the same point. Computing these separately would be inefficient, so JAX provides `jax.value_and_grad` to compute both simultaneously in a single, optimized pass through the computation:
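Here is a sketch using the same assumed polynomial as before:

```python
import jax

# Same assumed example polynomial: f(x) = x^2 + 3x + 1
def f(x):
    return x ** 2 + 3 * x + 1

# value_and_grad returns a function that yields (f(x), f'(x)) as a tuple
value_and_grad_f = jax.value_and_grad(f)

value, grad = value_and_grad_f(3.0)
print("value =", value)
print("grad =", grad)
```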
For this sketch, the output is:
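```
value = 19.0
grad = 9.0
```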
The `value_and_grad` function is particularly important for optimization algorithms, where we typically need both the loss value (to monitor training progress) and the gradient (to update parameters). Instead of making two separate function calls, we get both pieces of information from a single, efficient computation.
Notice that `value_and_grad_f(3.0)` returns a tuple containing both the function value and its gradient at x = 3.0. For the assumed polynomial, the output shows that the function evaluates to 19.0 with a gradient of 9.0. We can verify this analytically: f'(x) = 2x + 3, so f'(3.0) = 9.0, confirming JAX's calculation once again.
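As a sketch of that pattern, a toy gradient-descent loop could use `jax.value_and_grad` like this, assuming a simple quadratic loss L(w) = (w - 5)^2:

```python
import jax

# Assumed toy loss, minimized at w = 5
def loss(w):
    return (w - 5.0) ** 2

loss_and_grad = jax.value_and_grad(loss)

w = 0.0
learning_rate = 0.1
for step in range(3):
    value, grad = loss_and_grad(w)   # one pass gives both the loss and its gradient
    w = w - learning_rate * grad     # gradient-descent parameter update
    print("step", step, "loss =", value, "w =", w)
```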
Real-world functions often depend on multiple variables, and we need to compute partial derivatives with respect to different inputs. JAX handles this elegantly through the `argnums` parameter, which specifies which arguments to differentiate with respect to. By default, `jax.grad` differentiates with respect to the first argument (`argnums=0`).
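The sketch below assumes a two-variable example, f(x, y) = x^2 * y + y^3, to illustrate `argnums`:

```python
import jax

# Assumed two-variable example: f(x, y) = x^2 * y + y^3
def f(x, y):
    return x ** 2 * y + y ** 3

# Partial derivative with respect to x (the first argument, argnums=0)
df_dx = jax.grad(f, argnums=0)
# Partial derivative with respect to y (the second argument, argnums=1)
df_dy = jax.grad(f, argnums=1)

print("f(2.0, 3.0) =", f(2.0, 3.0))
print("df/dx at (2.0, 3.0) =", df_dx(2.0, 3.0))  # 2xy
print("df/dy at (2.0, 3.0) =", df_dy(2.0, 3.0))  # x^2 + 3y^2
```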
For the assumed example, this snippet results in the following output:
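```
f(2.0, 3.0) = 39.0
df/dx at (2.0, 3.0) = 12.0
df/dy at (2.0, 3.0) = 31.0
```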
The `argnums` parameter gives us fine-grained control over which partial derivatives to compute. When `argnums=0` (the default), `jax.grad` computes the partial derivative with respect to the first argument. When `argnums=1`, it computes the partial derivative with respect to the second argument.
Looking at our output, the assumed function evaluates to 39.0 at x = 2.0 and y = 3.0. The partial derivatives are df/dx = 2xy = 12.0 and df/dy = x^2 + 3y^2 = 31.0, which match JAX's computed gradients exactly.
Congratulations! You've just mastered the fundamentals of automatic differentiation with `jax.grad`. You've learned how to transform any pure, scalar-output function into a gradient function using `jax.grad`, compute both function values and gradients efficiently with `jax.value_and_grad`, and handle multi-variable functions using the `argnums` parameter to specify which arguments to differentiate. Most importantly, you've seen how JAX transforms complex differentiation tasks that would be tedious by hand into simple, reliable computations that produce mathematically exact results.
This foundation in automatic differentiation opens the door to JAX's more advanced transformations and real-world applications. In our upcoming lessons, we'll explore how to combine `jax.grad` with other JAX transformations like `jit` for performance optimization, and we'll see how automatic differentiation forms the backbone of modern machine learning algorithms. The pure functions and gradient computations you've mastered today will be the building blocks for more sophisticated optimization and machine learning workflows ahead.
