Autograd in JAX and PyTorch

ML
JAX
PyTorch
Author

Nipun Batra

Published

February 9, 2022

Basic Imports

Code
import torch
from jax import grad
import jax.numpy as jnp

Creating scalar variables in PyTorch

Code
x_torch = torch.tensor(1., requires_grad=True)
y_torch = torch.tensor(1., requires_grad=True)

Creating scalar variables in JAX

Code
x_jax = jnp.array(1.)
y_jax = jnp.array(1.)

Defining a loss on scalar inputs

Code
def loss(x, y):
    return x*x + y*y

Computing the loss on PyTorch input

Code
l_torch = loss(x_torch, y_torch)
l_torch
tensor(2., grad_fn=<AddBackward0>)
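The `grad_fn=<AddBackward0>` in the output is PyTorch recording the operation that produced the tensor; each intermediate result carries a node of the backward graph. A minimal sketch of inspecting that graph:

```python
import torch

x = torch.tensor(1., requires_grad=True)
y = torch.tensor(1., requires_grad=True)
l = x * x + y * y

# The last op in the graph (the addition) is stored on the result
print(l.grad_fn)

# Its parents: the two multiplication nodes that fed the addition
print(l.grad_fn.next_functions)
```

Calling `backward()` walks this graph in reverse to accumulate gradients into the leaf tensors.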

Computing the loss on JAX input

Code
l_jax = loss(x_jax, y_jax)

Computing the gradient on PyTorch input

Code
l_torch.backward()
x_torch.grad, y_torch.grad
(tensor(2.), tensor(2.))
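One PyTorch detail worth flagging here: a second `backward()` call accumulates into `.grad` rather than overwriting it, so gradients are typically reset between steps. A minimal sketch:

```python
import torch

# Leaf tensor with requires_grad=True (the modern replacement for Variable)
x = torch.tensor(1., requires_grad=True)

loss = x * x
loss.backward()
first = x.grad.clone()   # dl/dx = 2x = 2

# A second backward pass on a fresh loss ADDS to the existing .grad
loss = x * x
loss.backward()
second = x.grad.clone()  # now 4, not 2

# Reset before computing an unrelated gradient
x.grad = None
```

In a training loop this reset is what `optimizer.zero_grad()` does for you.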

Computing the gradient on JAX input

Code
grad_loss = grad(loss, argnums=[0, 1])
grad_loss(x_jax, y_jax)
(DeviceArray(2., dtype=float32, weak_type=True),
 DeviceArray(2., dtype=float32, weak_type=True))
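JAX can also return the loss value and its gradients together via `jax.value_and_grad`, which avoids calling the function twice; a small sketch using the same scalar loss:

```python
import jax
import jax.numpy as jnp

def loss(x, y):
    return x * x + y * y

# value_and_grad returns (loss value, gradients w.r.t. the listed argnums)
value, grads = jax.value_and_grad(loss, argnums=[0, 1])(jnp.array(1.), jnp.array(1.))
```

Here `value` is the loss (2.0) and `grads` is the tuple of partial derivatives, matching the `grad` output above.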

Repeating the same procedure for both libraries, this time with a vector-valued input

Code
def loss(theta):
    return theta.T@theta
Code
theta_torch = torch.tensor([1., 1.], requires_grad=True)
Code
theta_torch
tensor([1., 1.], requires_grad=True)
Code
l = loss(theta_torch)
l
tensor(2., grad_fn=<DotBackward0>)
Code
l.backward()
theta_torch.grad
tensor([2., 2.])
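As a sanity check, the analytic gradient of the squared norm is d/dθ (θᵀθ) = 2θ, which agrees with the PyTorch result above; a quick verification:

```python
import torch

theta = torch.tensor([1., 1.], requires_grad=True)
l = theta @ theta  # same scalar loss; @ avoids the .T deprecation warning on 1-D tensors
l.backward()

# Gradient of theta^T theta is 2 * theta
assert torch.allclose(theta.grad, 2 * theta.detach())
```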
Code
theta_jax = jnp.array([1., 1.])
Code
loss(theta_jax)
DeviceArray(2., dtype=float32)
Code
grad_loss = grad(loss, argnums=[0])
Code
grad_loss(theta_jax)
(DeviceArray([2., 2.], dtype=float32),)
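The result above is a one-element tuple because `argnums` was passed as a list. With the default `argnums=0`, `grad` returns the gradient array directly, again matching the analytic 2θ; a minimal sketch:

```python
import jax.numpy as jnp
from jax import grad

def loss(theta):
    return theta @ theta  # same squared-norm loss

theta = jnp.array([1., 1.])
g = grad(loss)(theta)  # default argnums=0 returns the array itself, not a tuple
```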