Nipun Batra
February 9, 2022
import jax.numpy as jnp
import torch
from jax import grad
# Scalar leaf inputs for PyTorch autograd.
# torch.autograd.Variable is deprecated (PyTorch >= 0.4); a plain tensor
# with requires_grad=True is the modern equivalent and behaves identically.
x_torch = torch.tensor(1., requires_grad=True)
y_torch = torch.tensor(1., requires_grad=True)

# Matching scalar inputs for JAX. JAX arrays carry no grad flag —
# gradients come from transforming the function with jax.grad instead.
x_jax = jnp.array(1.)
y_jax = jnp.array(1.)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
def loss(x, y):
    """Sum-of-squares loss for two scalars: x*x + y*y.

    Works on plain floats, PyTorch tensors, and JAX arrays alike,
    since it only uses multiplication and addition.
    """
    x_sq = x * x
    y_sq = y * y
    return x_sq + y_sq
# Evaluate the loss on the PyTorch inputs; the bare name on the last
# line is the notebook display expression (tensor(2., grad_fn=...)).
l_torch = loss(x_torch, y_torch)
l_torch
tensor(2., grad_fn=<AddBackward0>)
# Evaluate the same loss on the JAX inputs.
l_jax = loss(x_jax, y_jax)

# Backpropagate through the PyTorch graph; gradients land in .grad.
# The trailing tuple is the notebook display expression.
l_torch.backward()
x_torch.grad, y_torch.grad
(tensor(2.), tensor(2.))
# jax.grad with argnums=[0, 1] builds a function returning
# (d(loss)/dx, d(loss)/dy) as a tuple; last line displays it.
grad_loss = grad(loss, argnums=[0, 1])
grad_loss(x_jax, y_jax)
(DeviceArray(2., dtype=float32, weak_type=True), DeviceArray(2., dtype=float32, weak_type=True))
def loss(theta):
    """Return the quadratic form theta.T @ theta (squared L2 norm)."""
    theta_t = theta.T
    return theta_t @ theta
# Vector parameter for PyTorch. torch.autograd.Variable is deprecated
# (PyTorch >= 0.4); torch.tensor(..., requires_grad=True) is equivalent.
theta_torch = torch.tensor([1., 1.], requires_grad=True)
theta_torch
tensor([1., 1.], requires_grad=True)
# Evaluate the vector loss; the bare name is the notebook display
# expression (tensor(2., grad_fn=<DotBackward0>)).
l = loss(theta_torch)
l
tensor(2., grad_fn=<DotBackward0>)
# Backpropagate; the gradient of theta.T @ theta is 2 * theta,
# accumulated into theta_torch.grad. Last line displays it.
l.backward()
theta_torch.grad
tensor([2., 2.])
# The same quadratic loss runs unchanged on a JAX array.
theta_jax = jnp.array([1.0, 1.0])
loss(theta_jax)
DeviceArray(2., dtype=float32)
# Passing argnums as a list (even of one element) makes jax.grad return
# a tuple of gradients — hence the one-element tuple in the output below.
grad_loss = grad(loss, argnums=[0])
grad_loss(theta_jax)
(DeviceArray([2., 2.], dtype=float32),)