Autograd in JAX and PyTorch

Categories: ML, JAX, PyTorch
Author: Nipun Batra
Published: February 9, 2022

Basic Imports

import torch
from jax import grad
import jax.numpy as jnp

Creating scalar variables in PyTorch

# torch.autograd.Variable is deprecated; torch.tensor(1., requires_grad=True) is the modern equivalent
x_torch = torch.autograd.Variable(torch.tensor(1.), requires_grad=True)
y_torch = torch.autograd.Variable(torch.tensor(1.), requires_grad=True)

Creating scalar variables in JAX

x_jax = jnp.array(1.)
y_jax = jnp.array(1.)
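
Note that, unlike PyTorch tensors, JAX arrays carry no requires_grad flag: rather than tracking gradients on the data, JAX differentiates the function itself via transformations such as grad.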

Defining a loss on scalar inputs

def loss(x, y):
    return x*x + y*y
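
Since loss(x, y) = x*x + y*y, the analytic partial derivatives are ∂loss/∂x = 2x and ∂loss/∂y = 2y, so both gradients should evaluate to 2 at x = y = 1.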

Computing the loss on PyTorch input

l_torch  = loss(x_torch, y_torch)
l_torch
tensor(2., grad_fn=<AddBackward0>)

Computing the loss on JAX input

l_jax = loss(x_jax, y_jax)
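
As in PyTorch, this evaluates to 2.0, since 1² + 1² = 2.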

Computing the gradient on PyTorch input

l_torch.backward()
x_torch.grad, y_torch.grad
(tensor(2.), tensor(2.))
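
As an aside, PyTorch also offers a functional interface, torch.autograd.grad, which returns the gradients directly instead of accumulating them into .grad. A minimal sketch with fresh leaf tensors (illustrative only, not part of the original run):

x = torch.tensor(1., requires_grad=True)
y = torch.tensor(1., requires_grad=True)
l = x*x + y*y

# torch.autograd.grad returns a tuple of gradients, one per input,
# without accumulating anything into x.grad or y.grad
dx, dy = torch.autograd.grad(l, (x, y))
dx, dy  # (tensor(2.), tensor(2.))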

Computing the gradient on JAX input

grad_loss = grad(loss, argnums=[0, 1])
grad_loss(x_jax, y_jax)
(DeviceArray(2., dtype=float32, weak_type=True),
 DeviceArray(2., dtype=float32, weak_type=True))
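
When both the loss value and its gradient are needed, JAX provides jax.value_and_grad, which returns the two in a single call. A minimal sketch, reusing the loss and inputs defined above:

from jax import value_and_grad

# Returns the loss value together with the tuple of gradients
value, grads = value_and_grad(loss, argnums=[0, 1])(x_jax, y_jax)
value, grads  # value is 2.0; each gradient is 2.0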

Repeating the same procedure as above for both libraries, but this time with a loss defined on a vector input

def loss(theta):
    return theta.T@theta

theta_torch = torch.autograd.Variable(torch.tensor([1., 1.]), requires_grad=True)
theta_torch
tensor([1., 1.], requires_grad=True)

l = loss(theta_torch)
l
tensor(2., grad_fn=<DotBackward0>)

l.backward()
theta_torch.grad
tensor([2., 2.])

theta_jax = jnp.array([1., 1.])
loss(theta_jax)
DeviceArray(2., dtype=float32)

grad_loss = grad(loss, argnums=[0])
grad_loss(theta_jax)
(DeviceArray([2., 2.], dtype=float32),)
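
As a sanity check, the analytic gradient of theta.T@theta is 2 * theta, which matches the [2., 2.] above. Note also that with the default argnums=0, jax.grad returns the gradient array itself rather than a one-element tuple; a quick sketch at a different theta (chosen purely for illustration):

# Default argnums=0: grad returns the gradient array, not a tuple
grad_loss_default = grad(loss)
grad_loss_default(jnp.array([1., 2., 3.]))  # 2 * theta, i.e. [2., 4., 6.]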