Introduction

Welcome back to the third lesson of "JAX Fundamentals: NumPy Power-Up"! We're making fantastic progress on our journey to mastering JAX's most powerful capabilities. In our previous lessons, we explored JAX arrays and their immutable nature, then discovered how pure functions form the cornerstone of JAX's design. These concepts weren't just theoretical exercises — they were laying the groundwork for the truly transformative feature we'll explore today.

Today, we're diving into one of JAX's most celebrated features: automatic differentiation. This is where JAX begins to show its true power beyond just being a NumPy replacement. Automatic differentiation is the mathematical foundation that enables machine learning, optimization algorithms, and scientific computing applications to compute gradients effortlessly and accurately.

As you may recall from our previous lesson, pure functions are essential because they enable JAX's transformations to work reliably. Today, we'll see this principle in action as we use jax.grad to automatically compute derivatives of our pure functions. By the end of this lesson, you'll understand how to use jax.grad to compute gradients of scalar-output functions, evaluate these gradients at specific points, and even handle functions with multiple variables.

What is Automatic Differentiation?

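Before we start computing gradients with code, let's understand what automatic differentiation actually is and why it's so revolutionary for numerical computing. In calculus, we learned to compute derivatives by hand using rules like the power rule, product rule, and chain rule. For simple functions like $f(x) = x^2$, finding the derivative $f'(x) = 2x$ is straightforward. But imagine trying to compute derivatives by hand for the complex functions found in modern machine learning models: functions with millions of parameters and hundreds of layers of computations! Automatic differentiation removes that burden. It breaks a program into elementary operations (additions, multiplications, sin, exp, and so on) whose derivatives are known, then combines them with the chain rule. The resulting derivatives are exact up to floating-point precision, not numerical approximations.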
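To make the chain-rule point concrete, here is a small sketch that compares jax.grad against a derivative worked out by hand for a composed function, $h(x) = \sin(x^2)$. Its chain-rule derivative is $h'(x) = 2x\cos(x^2)$. The names h and h_prime_manual are hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp

# A composed function: the outer sin applied to the inner x**2.
def h(x):
    return jnp.sin(x ** 2)

# The chain rule applied by hand: h'(x) = cos(x**2) * 2x
def h_prime_manual(x):
    return jnp.cos(x ** 2) * 2 * x

x = 1.5
auto = jax.grad(h)(x)       # derivative computed by JAX
manual = h_prime_manual(x)  # derivative derived by hand
print(auto, manual)
print(jnp.allclose(auto, manual))  # True
```

JAX never sees the formula $2x\cos(x^2)$. It only knows the derivatives of sin and of squaring, and it chains them together automatically.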

Your First Gradient with jax.grad

Now let's experience automatic differentiation in action with our first example. We'll start with a simple polynomial function and use jax.grad to compute its derivative automatically. This will demonstrate the basic workflow and show how JAX handles what we would normally do by hand in calculus.

The above code outputs:

Let's break down what's happening here step by step. First, we define our function f(x) as a pure function that takes a single input and returns a scalar value. Notice that this function satisfies all the requirements for automatic differentiation: it's pure, deterministic, and returns a scalar output.
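jax.grad enforces the scalar-output requirement itself, and it also expects floating-point inputs. The sketch below uses two hypothetical functions, vector_output and summed_output, to show that a non-scalar output or an integer input raises a TypeError. It also shows that reducing the output to a single number fixes the first problem.

```python
import jax
import jax.numpy as jnp

# jax.grad only works on functions that return a single scalar.
def vector_output(x):
    return jnp.array([x, x ** 2])  # shape (2,), not a scalar

try:
    jax.grad(vector_output)(1.0)
except TypeError as err:
    print("Non-scalar output rejected:", err)

# Reducing the output to a scalar (here with jnp.sum) makes it differentiable.
def summed_output(x):
    return jnp.sum(jnp.array([x, x ** 2]))

print(jax.grad(summed_output)(1.0))  # d/dx (x + x^2) = 1 + 2x = 3.0 at x = 1

# Inputs must be floating point: an integer argument is rejected as well.
try:
    jax.grad(summed_output)(1)
except TypeError as err:
    print("Integer input rejected:", err)
```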

The magic happens with jax.grad(f). This doesn't immediately compute a gradient; instead, it returns a new function, grad_f, that computes the derivative of our original function. When we call grad_f(2.0), JAX automatically applies the differentiation rules and returns the value of the derivative at $x = 2.0$.
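As a minimal sketch of this workflow, assume a hypothetical polynomial $f(x) = 3x^2 + 2x + 1$. This is not necessarily the function from the lesson's original example. By hand, $f'(x) = 6x + 2$, so $f'(2) = 14$.

```python
import jax

# Hypothetical polynomial for illustration: f(x) = 3x^2 + 2x + 1
def f(x):
    return 3.0 * x ** 2 + 2.0 * x + 1.0

# jax.grad(f) does not compute anything yet; it returns a new function.
grad_f = jax.grad(f)

# Calling grad_f evaluates f'(x) at a specific point.
print(grad_f(2.0))  # 14.0, since f'(x) = 6x + 2

# grad_f is an ordinary function, so it can be reused at any point.
for x in [0.0, 1.0, 2.0]:
    print(x, grad_f(x), 6.0 * x + 2.0)  # JAX result next to the hand-derived one
```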

Computing Values and Gradients Together

In many practical applications, we need both the function's value and its gradient at the same point. Calling the function and its gradient function separately would evaluate the function twice, so JAX provides jax.value_and_grad, which returns both from a single call and reuses the function evaluation it already performs while computing the gradient:

Which outputs:

The value_and_grad function is particularly important for optimization algorithms, where we typically need both the loss value (to monitor training progress) and the gradient (to update parameters). Instead of making two separate function calls, we get both pieces of information from a single, efficient computation.
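The optimization use case can be sketched as a tiny gradient-descent loop. The loss function, starting point, learning rate, and step count are all hypothetical choices made for illustration. The point is that each iteration gets the loss (for monitoring) and the gradient (for the update) from one value_and_grad call.

```python
import jax

# Hypothetical loss with its minimum at w = 3.0: loss(w) = (w - 3)^2
def loss(w):
    return (w - 3.0) ** 2

# One function that returns (loss value, d loss / d w) for a given w.
loss_and_grad = jax.value_and_grad(loss)

w = 0.0              # hypothetical starting parameter
learning_rate = 0.1  # hypothetical step size

for step in range(50):
    value, grad_w = loss_and_grad(w)  # loss for monitoring, gradient for the update
    w = w - learning_rate * grad_w    # plain gradient-descent step
    if step % 10 == 0:
        print(f"step {step:2d}: loss = {float(value):.6f}")

print(f"final w = {float(w):.4f}")  # approaches the minimum at 3.0
```

With this loss, each step shrinks the distance to the minimum by a factor of 0.8, so the printed loss falls steadily toward zero.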

Notice that value_and_grad_f(3.0) returns a tuple, (value, gradient), containing both the function value and its gradient at $x = 3.0$. Differentiating the function by hand and evaluating the result at $x = 3.0$ gives the same gradient, confirming JAX's calculation once again.
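As a concrete sketch, assume a hypothetical polynomial $f(x) = 3x^2 + 2x + 1$. This is not necessarily the lesson's original function. By hand, $f(3) = 27 + 6 + 1 = 34$ and $f'(3) = 6 \cdot 3 + 2 = 20$. Unpacking the tuple and checking it against separate calls looks like this:

```python
import jax

# Hypothetical polynomial for illustration: f(x) = 3x^2 + 2x + 1
def f(x):
    return 3.0 * x ** 2 + 2.0 * x + 1.0

value_and_grad_f = jax.value_and_grad(f)

value, gradient = value_and_grad_f(3.0)  # unpack the (value, gradient) tuple
print(value, gradient)  # 34.0 20.0

# The tuple matches calling f and jax.grad(f) separately.
assert value == f(3.0)
assert gradient == jax.grad(f)(3.0)
```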

Gradients of Multi-Variable Functions

Real-world functions often depend on multiple variables, and we need to compute partial derivatives with respect to different inputs. JAX handles this elegantly through the argnums parameter, which specifies which arguments to differentiate with respect to. By default, jax.grad differentiates with respect to the first argument (argnums=0).

This snippet results in the following output:

The argnums parameter gives us fine-grained control over which partial derivatives to compute. When argnums=0 (the default), jax.grad computes the partial derivative with respect to the first argument. When argnums=1, it computes the partial derivative with respect to the second argument.

Looking at our output, we can verify the results analytically: differentiating the function by hand with respect to each argument and evaluating at the chosen point gives partial derivatives that match JAX's computed gradients exactly.
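A minimal sketch with a hypothetical function $g(x, y) = x^2 y + 3y$ (not the lesson's original example) shows argnums in practice. By hand, $\partial g/\partial x = 2xy$ and $\partial g/\partial y = x^2 + 3$, so at $(x, y) = (2, 3)$ we expect $12$ and $7$. The last lines also pass a tuple to argnums. jax.grad accepts this and returns one partial derivative per listed argument.

```python
import jax

# Hypothetical two-variable function: g(x, y) = x^2 * y + 3y
def g(x, y):
    return x ** 2 * y + 3.0 * y

dg_dx = jax.grad(g)             # argnums=0 by default: dg/dx = 2xy
dg_dy = jax.grad(g, argnums=1)  # dg/dy = x^2 + 3

x, y = 2.0, 3.0
print(g(x, y))      # 2^2 * 3 + 3 * 3 = 21.0
print(dg_dx(x, y))  # 2 * 2 * 3 = 12.0
print(dg_dy(x, y))  # 2^2 + 3 = 7.0

# A tuple of argnums returns a tuple of partial derivatives from one call.
both = jax.grad(g, argnums=(0, 1))(x, y)
print(both)
```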

Conclusion and Next Steps

Congratulations! You've just mastered the fundamentals of automatic differentiation with jax.grad. You've learned how to transform any differentiable, pure, scalar-output function into a gradient function using jax.grad, compute both function values and gradients efficiently with jax.value_and_grad, and handle multi-variable functions using the argnums parameter to specify which arguments to differentiate. Most importantly, you've seen how JAX turns differentiation tasks that would be tedious by hand into simple, reliable computations whose results are exact up to floating-point precision.

This foundation in automatic differentiation opens the door to JAX's more advanced transformations and real-world applications. In our upcoming lessons, we'll explore how to combine jax.grad with other JAX transformations like jit for performance optimization, and we'll see how automatic differentiation forms the backbone of modern machine learning algorithms. The pure functions and gradient computations you've mastered today will be the building blocks for more sophisticated optimization and machine learning workflows ahead.
