Intermediate
Introduction to JAX for Deep Learning
Machine Learning
5 courses
91 practices
12 hours
Master JAX from the ground up! This path takes you from NumPy-like basics and automatic differentiation, through advanced batching and PyTrees, to building and training deep neural networks with Flax and Optax—culminating in a real-world image classification project.
Verified skills you'll gain
INTERMEDIATE
Coding and Data Algorithms
ADVANCED
Deep Learning and Neural Networks
Tools you'll use
Flax
JAX
Optax
Python