December 18, 2025
For the machine learning (ML) community, it is typical to rely on the PyTorch (Torch) framework to interact with, and even build, all kinds of modern ML models, since Torch already has strong documentation and a dedicated open-source community. Thus, one may ask: Is there a need for an alternative framework?
The answer is yes! Although Torch makes it easy to deploy, train, and interact with ML models, there are some aspects in which it falls short and could be improved. Moreover, Torch heavily favors object-oriented design, but some people would rather write their code in a more functional way. Such people may prefer Jax, the ML framework created by Google.
In this post, I will introduce some general concepts that people should know when using Jax, and also highlight some pitfalls of both Torch and Jax.
Disclaimer: I am not an expert in either framework. I like them both. This post and subsequent ones will serve as reminders for my forgetful brain…
The first essential Jax tool you should know is jax.vmap. Although Torch handles many things (like batching inputs) behind the scenes, it is still essential to be able to write code that handles more complex batching requirements yourself.
For example, let’s assume we are working with a bunch of dynamical variables and want to apply a function independently at different times and to different variables. How do we handle this flexibly without writing a slow, explicit for-loop?
Let’s assume that our input is \(\mathbf{x} \in \mathbb{R}^{T \times B \times D}\), where \(T\) is the number of timesteps, \(B\) is the number of samples, and \(D\) is their dimensionality.
```python
import jax
import numpy as np
import jax.numpy as jnp
from functools import partial

def mat_prod(x, w):
    '''Args:
        x: a vector of shape D
        w: a learnable weight of shape D x K
    Output:
        y: a vector of shape K
    '''
    return jnp.einsum('d, d k -> k', x, w)

T = 100
B = 16
D = 32
K = 10
x = np.random.normal(0, 1, [T, B, D])
w = np.random.normal(0, 1, [D, K]) / (D ** 0.5)
# The outer vmap maps over the T axis, the inner vmap maps over the B axis
y = jax.vmap(jax.vmap(partial(mat_prod, w=w)))(x)
print(y.shape)  # (100, 16, 10)
```
Yes, just wrap jax.vmap twice to handle the batching over the \(T\) and \(B\) axes. The caveat is that, by default, jax.vmap maps over the leading axis of every positional argument it receives, so it would also try to map over our weight matrix \(w \in \mathbb{R}^{D \times K}\). To avoid this, we use functools.partial to fix w in advance: the vmapped function then only receives x, and the same w is shared across the whole batch.
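If you would rather not use partial, the in_axes argument of jax.vmap does the same job: None marks an argument that is passed through unchanged instead of being mapped. Here is a minimal sketch comparing both approaches; the seeded generator and the names y_in_axes and y_partial are just for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial


def mat_prod(x, w):
    """x: vector of shape D, w: matrix of shape D x K -> vector of shape K."""
    return jnp.einsum('d, d k -> k', x, w)


T, B, D, K = 100, 16, 32, 10
rng = np.random.default_rng(0)
x = rng.normal(0, 1, [T, B, D])
w = rng.normal(0, 1, [D, K]) / (D ** 0.5)

# Inner vmap maps x over B; in_axes=None means w is not mapped.
inner = jax.vmap(mat_prod, in_axes=(0, None))
# Outer vmap maps x over T; w is again passed through unchanged.
y_in_axes = jax.vmap(inner, in_axes=(0, None))(x, w)

# The same computation using partial, as in the text above.
y_partial = jax.vmap(jax.vmap(partial(mat_prod, w=w)))(x)

print(y_in_axes.shape)  # (100, 16, 10)
print(np.allclose(y_in_axes, y_partial, atol=1e-5))  # True
```

in_axes keeps w as an explicit argument, which is handy when you later want to differentiate with respect to it or swap in different weights without rebuilding the partial.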
Side note: Jax automatically moves your NumPy arrays to the default device (e.g., your GPU) when you pass them to a Jax function.
Just-in-Time (JIT) compilation is a form of dynamic compilation that happens while a program is running rather than before it is executed. This makes it a good fit for an interpreted programming language like Python, whose run-time performance is otherwise limited compared with ahead-of-time compiled languages such as C/C++. So, why should we care about JIT in ML?
Well, if we can compile our ML models, they will run faster and, hopefully, use memory more efficiently. Jax has offered JIT compilation (via XLA) as a core feature since its early releases, long before Torch introduced torch.compile in version 2.0. Moreover, because Jax favors functional programming, in my experience its JIT compiler is much more flexible and less prone to compilation failures than its Torch counterpart.
Assume we are performing a typical matrix product:
```python
def mat_prod(x, w):
    '''Args:
        x: a vector of shape D (representing a data sample)
        w: a learnable weight of shape D x K
    Output:
        y: a vector of shape K
    '''
    return jnp.einsum('d, d k -> k', x, w)

B = 16
D = 32
K = 10
x = np.random.normal(0, 1, [B, D])
w = np.random.normal(0, 1, [D, K]) / (D ** 0.5)
# vmap over the batch axis, JIT-compile the vectorized function, then call it on x
print(jax.jit(jax.vmap(partial(mat_prod, w=w)))(x))
```
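On the first call, Jax traces the function, compiles it for the given input shapes and dtypes, and then runs the compiled version; later calls with the same shapes and dtypes reuse the cached compiled function. Meanwhile, in the case of Torch (v2.0.0+), our code would be: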
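One way to see this compile-once, reuse-later behavior is to count how often Jax traces a jitted function. The sketch below relies on the fact that Python side effects inside a jitted function run only during tracing; the function scaled_sum and the global trace_count are hypothetical names used for illustration.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def scaled_sum(x):
    global trace_count
    # Python side effects run only while Jax traces the function,
    # not when the cached compiled version is executed.
    trace_count += 1
    return 2.0 * jnp.sum(x)


scaled_sum(jnp.ones(3))  # first call: trace + compile
scaled_sum(jnp.ones(3))  # same shape/dtype: reuses the compiled function
print(trace_count)  # 1
scaled_sum(jnp.ones(4))  # new shape: traced and compiled again
print(trace_count)  # 2
```

This is also why it pays to keep input shapes fixed inside a training or sampling loop: every new shape triggers a fresh compilation.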
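```python
import torch
import torch.nn as nn

B = 16
D = 32
K = 10
x = torch.randn(B, D)
# A bias-free linear layer computes the same product as mat_prod;
# nn.Linear handles the batch dimension itself (no vmap needed)
mat_prod = nn.Linear(D, K, bias=False)
compiled_mat_prod = torch.compile(mat_prod)
print(compiled_mat_prod(x))
```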
Although Torch’s JIT compiler keeps getting better, one example that illustrates the power of Jax is a simple Langevin dynamics MCMC sampler driven by a neural network, such as an energy-based model.
The whole point of this program is to compute the gradient \(\nabla_{x_t} E_\theta (x_t) \in \mathbb{R}^N\) of the model \(E_\theta (x_t) : \mathbb{R}^N \rightarrow \mathbb{R}\) with respect to the dynamical variable \(x_t\), and to perform noisy gradient descent on the energy landscape for a number of steps. The more steps we take (with a small enough step size), the closer the particles get to the equilibrium distribution \(p_\theta(x) \propto \exp(-E_\theta(x))\). You can also adapt this code for the reverse process of diffusion models.
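For reference, the standard (unadjusted) Langevin update, written with the symbols above and a step size \(\eta\), is

\[
x_{t+1} = x_t - \eta \, \nabla_{x_t} E_\theta(x_t) + \sqrt{2\eta} \, \xi_t, \qquad \xi_t \sim \mathcal{N}(0, I_N).
\]

Note that the noise scales with \(\sqrt{2\eta}\), not with \(\eta\) or \(\eta^2\). If the noise is made much smaller than this, the update degenerates into plain gradient descent and the particles collapse toward local minima of \(E_\theta\) instead of approximately sampling from \(p_\theta(x) \propto \exp(-E_\theta(x))\).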
```python
import jax
import jax.numpy as jnp
import numpy as np
from typing import Callable
from functools import partial

def langevin(
    x: jax.Array,
    params: dict,
    fn: Callable,
    key: jax.Array,
    step_size: float = 1e-4,
    steps: int = 1000,
):
    """Args:
        x: Batch of samples or particles (of shape N x ...)
        params: FrozenDict object of weights (obtained from initializing the Flax module)
        fn: Apply function of the Flax module (e.g., model.apply); it must map
            a single sample to a scalar energy
        key: PRNG key (e.g., jax.random.PRNGKey(42))
    Output:
        y: Updated particles (same shape as x)
    """
    # Gradient of the energy w.r.t. the input x (argnums=1), vmapped over the
    # particles while all particles share the same weights (in_axes=(None, 0))
    grad_fn = jax.vmap(jax.grad(fn, argnums=1), in_axes=(None, 0))

    def sample(i, vals):
        x, key = vals
        key, subkey = jax.random.split(key)  # fresh randomness at every step
        dEdx = grad_fn({'params': params}, x)
        noise = jnp.sqrt(2 * step_size) * jax.random.normal(subkey, shape=x.shape)
        x = x - step_size * dEdx + noise
        return x, key

    y, key = jax.lax.fori_loop(0, steps, sample, (x, key))
    return y

key = jax.random.PRNGKey(42)
init_key, sample_key = jax.random.split(key)
x = np.zeros((3, 32, 32))  # A single input to determine the shape
model = Unet(...)  # Create your model (Flax module); it must output a scalar energy
params = model.init(init_key, x)["params"]  # Initialize the weights (no vmap)
trained_params = train(model, params, ...)  # Train the model
x = np.random.normal(0, 1, [128, 3, 32, 32])  # We have 128 priors to sample from...
# Use partial to fix the non-array arguments, then JIT-compile the whole sampler
sampler = jax.jit(partial(langevin, fn=model.apply, step_size=1e-4, steps=100))
y = sampler(x, trained_params, key=sample_key)
```
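To sanity-check the sampler logic without Flax or a trained model, here is a self-contained toy version. It assumes a hand-written quadratic energy \(E(x) = \|x\|^2 / 2\), chosen because its equilibrium distribution is a standard normal, so we know what the samples should look like. The names energy and langevin_toy are hypothetical, and the noise uses the \(\sqrt{2 \cdot \text{step\_size}}\) scale of the standard Langevin update.

```python
from functools import partial

import jax
import jax.numpy as jnp


def energy(x):
    """Toy energy E(x) = ||x||^2 / 2 for one sample; exp(-E) is a standard normal."""
    return 0.5 * jnp.sum(x ** 2)


@partial(jax.jit, static_argnames=("step_size", "steps"))
def langevin_toy(x, key, step_size, steps):
    # Per-sample gradient of the energy, vectorized over the particle axis
    grad_energy = jax.vmap(jax.grad(energy))

    def body(i, vals):
        x, key = vals
        key, subkey = jax.random.split(key)
        noise = jax.random.normal(subkey, shape=x.shape)
        x = x - step_size * grad_energy(x) + jnp.sqrt(2.0 * step_size) * noise
        return x, key

    x, _ = jax.lax.fori_loop(0, steps, body, (x, key))
    return x


x0 = jnp.full((10_000, 2), 5.0)  # all particles start far from equilibrium
samples = langevin_toy(x0, jax.random.PRNGKey(0), step_size=1e-2, steps=1000)
print(samples.mean(), samples.var())  # should be close to 0 and 1
```

If you replace the noise scale with step_size ** 2, the sample variance in this toy drops to nearly zero, which is an easy way to spot a mis-scaled noise term.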
In the above code, I assume we are using Flax, but please consider alternative libraries like Equinox too. Also, check out Patrick Kidger’s blog!
Anyhow, my point in showing you this example is what happens when you do more complicated things, like taking a Jacobian or some other kind of automatic differentiation during inference. A Torch version of this algorithm might break, because we would need to JIT-compile code that calls function transforms such as torch.func.jvp, torch.func.grad, and many more.
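For contrast, composing these transformations in Jax is routine: jax.jacfwd and jax.jvp can be wrapped in jax.jit directly. Here is a small sketch, assuming a hypothetical one-layer map \(f(x) = \tanh(Wx)\) in place of a real model, with the result checked against the analytical Jacobian.

```python
import jax
import jax.numpy as jnp
import numpy as np


def layer(x, W):
    """A tiny nonlinear map f(x) = tanh(W x) standing in for a network."""
    return jnp.tanh(W @ x)


W = jnp.array([[1.0, -2.0], [0.5, 3.0], [0.0, 1.0]])  # shape 3 x 2
x = jnp.array([0.3, -0.1])
v = jnp.array([1.0, 2.0])

# Full Jacobian df/dx (shape 3 x 2), compiled end to end
jac_fn = jax.jit(jax.jacfwd(layer, argnums=0))
J = jac_fn(x, W)

# Jacobian-vector product J @ v, computed without materializing J
jvp_fn = jax.jit(lambda x, v: jax.jvp(lambda x_: layer(x_, W), (x,), (v,))[1])
Jv = jvp_fn(x, v)

# Analytical check: J = diag(1 - tanh(Wx)^2) @ W
J_exact = (1.0 - jnp.tanh(W @ x) ** 2)[:, None] * W
print(np.allclose(J, J_exact, atol=1e-6))  # True
print(np.allclose(Jv, J @ v, atol=1e-6))  # True
```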
Overall, I enjoy writing Jax code since it forces me to know a lot of details (which Torch hides from you) and gives me much greater control over manipulating and training my ML models. Although Jax intimidates many people, at the end of the day this framework is a roided-up version of NumPy that takes advantage of GPUs. I think you can prototype new ideas easily and run them at a much faster pace than with Torch; it may save you time, or it may not. Use whichever framework you are comfortable with!