Basic Imports

import torch
from jax import grad
import jax.numpy as jnp

Creating scalar variables in PyTorch
# Note: torch.autograd.Variable is deprecated; torch.tensor(1., requires_grad=True) is equivalent
x_torch = torch.autograd.Variable(torch.tensor(1.), requires_grad=True)
y_torch = torch.autograd.Variable(torch.tensor(1.), requires_grad=True)

Creating scalar variables in JAX
x_jax = jnp.array(1.)
y_jax = jnp.array(1.)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

Defining a loss on scalar inputs
def loss(x, y):
    return x*x + y*y
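
As a quick sanity check on what follows: since loss(x, y) = x^2 + y^2, the partial derivatives are d(loss)/dx = 2x and d(loss)/dy = 2y, so both gradients should evaluate to 2 at x = y = 1.
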
Computing the loss on PyTorch input

l_torch = loss(x_torch, y_torch)
l_torch
tensor(2., grad_fn=<AddBackward0>)

Computing the loss on JAX input
l_jax = loss(x_jax, y_jax)

Computing the gradient on PyTorch input
l_torch.backward()
x_torch.grad, y_torch.grad
(tensor(2.), tensor(2.))
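
Note that backward() accumulates gradients into .grad, so repeated calls add up unless the gradients are zeroed first. For a functional interface closer to jax.grad, PyTorch also provides torch.autograd.grad, which returns the gradients as a tuple instead of writing them into .grad; a minimal sketch, not part of the original walkthrough:

# Build a fresh graph and return the gradients as a tuple (no .grad mutation)
grads = torch.autograd.grad(loss(x_torch, y_torch), (x_torch, y_torch))
# grads -> (tensor(2.), tensor(2.))
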
Computing the gradient on JAX input
grad_loss = grad(loss, argnums=[0, 1])
grad_loss(x_jax, y_jax)
(DeviceArray(2., dtype=float32, weak_type=True),
 DeviceArray(2., dtype=float32, weak_type=True))
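
JAX arrays carry no requires_grad flag; differentiation is requested at call time, and the loss value and the gradients can be obtained together with jax.value_and_grad. A minimal sketch in the same spirit as the cell above:

from jax import value_and_grad

# value is the loss at (x_jax, y_jax); grads is a tuple of gradients w.r.t. both arguments
value, grads = value_and_grad(loss, argnums=[0, 1])(x_jax, y_jax)
# value -> 2.0, grads -> (2.0, 2.0)
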
Repeating the same procedure as above for both libraries, this time with a loss that takes a vector input
def loss(theta):
    return theta.T @ theta

theta_torch = torch.autograd.Variable(torch.tensor([1., 1.]), requires_grad=True)
theta_torch
tensor([1., 1.], requires_grad=True)

l = loss(theta_torch)
l
tensor(2., grad_fn=<DotBackward0>)

l.backward()
theta_torch.grad
tensor([2., 2.])
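
Again this matches the closed form: for loss(theta) = theta.T @ theta = sum_i theta_i^2, the gradient is 2*theta, which is [2., 2.] at theta = [1., 1.].
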
theta_jax = jnp.array([1., 1.])
loss(theta_jax)
DeviceArray(2., dtype=float32)

grad_loss = grad(loss, argnums=[0])
grad_loss(theta_jax)
(DeviceArray([2., 2.], dtype=float32),)
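
Because argnums is passed as a list here, grad returns a one-element tuple; with the default argnums=0 it returns the gradient array directly. The gradient function can also be compiled with jax.jit; a minimal sketch, not part of the original walkthrough:

from jax import jit

# Compile the gradient of the vector loss; argnums defaults to 0, so an array is returned directly
fast_grad_loss = jit(grad(loss))
fast_grad_loss(theta_jax)  # -> DeviceArray([2., 2.], dtype=float32)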