Introduction

Welcome to the second lesson of our course "Advanced JAX: Transformations for Speed & Scale"! I'm delighted to continue this journey with you as we dive deeper into JAX's powerful transformation capabilities. In our previous lesson, you mastered JAX's functional approach to randomness, learning how to create, split, and manage PRNG keys for reproducible computations. That foundation of explicit state management, as opposed to relying on hidden global state, sets the stage perfectly for today's topic.

In this lesson, we'll explore automatic vectorization with jax.vmap, one of JAX's most elegant and powerful transformations. Just as you learned to manage randomness explicitly through keys, vmap lets you control explicitly how a computation is batched across data. You write a function that handles a single example, and vmap turns it into a function that processes an entire batch, expressing the batch as an extra array dimension inside the underlying operations rather than as a Python loop. This transformation is fundamental to high-performance scientific computing and machine learning, where we routinely need to apply the same operation to thousands or millions of data points.
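To make the idea concrete, here is a minimal sketch of vmap turning a single-example function into a batched one. The function names `dot` and `batched_dot` are illustrative choices, not part of the lesson's own code; the sketch relies only on vmap's default behavior of mapping over the leading axis (axis 0) of every argument, and checks the result against an equivalent Python loop.

```python
import jax
import jax.numpy as jnp

# Hypothetical single-example function: the dot product of two 1-D vectors.
def dot(x, y):
    return jnp.dot(x, y)

# By default, vmap maps over axis 0 of every argument and stacks the
# per-example results along axis 0 of the output.
batched_dot = jax.vmap(dot)

xs = jnp.arange(6.0).reshape(3, 2)  # batch of 3 vectors, each of length 2
ys = jnp.ones((3, 2))               # matching batch of 3 vectors

result = batched_dot(xs, ys)
print(result)  # [1. 5. 9.]

# The same computation written as an explicit Python loop, for comparison.
looped = jnp.stack([dot(x, y) for x, y in zip(xs, ys)])
```

Both versions produce the same values, but the vmapped one hands JAX a single batched computation instead of three separate calls.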

By the end of this lesson, you'll understand how to use jax.vmap to automatically vectorize your functions, control batching behavior with the in_axes and out_axes parameters, and achieve significant performance improvements over manual Python loops. You'll also see how vmap seamlessly integrates with other JAX transformations like jit, creating a powerful toolkit for scalable computations.
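As a preview of the in_axes and out_axes parameters and of composing vmap with jit, the sketch below batches over one argument while sharing the others. The function `affine` and its arguments are hypothetical, chosen only for illustration. `in_axes=(0, None, None)` maps over axis 0 of `x` while passing `w` and `b` unchanged to every example, and `out_axes=1` places the batch dimension on axis 1 of the result instead of the default axis 0.

```python
import jax
import jax.numpy as jnp

# Hypothetical single-example function: elementwise scale-and-shift of a vector.
def affine(x, w, b):
    return x * w + b

# Map over axis 0 of x only; w and b (in_axes=None) are shared by every example.
# out_axes=1 stacks the per-example outputs along axis 1.
# jax.jit then compiles the whole batched function.
batched_affine = jax.jit(jax.vmap(affine, in_axes=(0, None, None), out_axes=1))

xs = jnp.arange(6.0).reshape(3, 2)  # 3 examples with 2 features each
w = jnp.array([2.0, 10.0])          # shared weights
b = jnp.array([1.0, 1.0])           # shared bias

out = batched_affine(xs, w, b)
print(out.shape)  # (2, 3): each example's 2 outputs sit in one column
```

Because vmap and jit are both function transformations, they compose freely; wrapping the vmapped function in jit is a common pattern.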

Understanding the Batching Challenge