import torch
import torch.autograd.functional as F
import torch.distributions as dist
import jax.numpy as jnp
import jax
%matplotlib inline
# Retina display
%config InlineBackend.figure_format = 'retina'
Calculus
Derivative
def f1(x):
    return 3*x**2
x = torch.tensor(2.0, requires_grad=True)
# Torch version 1 (using .backward)
z = f1(x)
z.backward()
print("Using backwards in torch", x.grad)
# Torch version 2 (using autograd.grad)
print("Using autograd.grad in torch", torch.autograd.grad(f1(x), x)[0])
# Jax version
print("Using grad in jax", jax.grad(f1)(jnp.array(2.0)))Using backwards in torch tensor(12.)
Using autograd.grad in torch tensor(12.)
Using grad in jax 12.0
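As a sanity check, the analytic derivative of 3*x**2 is 6*x, so all three results should be 12 at x = 2. A minimal sketch of that comparison, reusing f1 and x from above:
# Sanity check against the analytic derivative 6x (illustrative names only)
analytic = 6 * x.detach()                    # d/dx 3x^2 = 6x
autodiff = torch.autograd.grad(f1(x), x)[0]  # reverse-mode result
print("Match:", torch.isclose(autodiff, analytic).item(), analytic.item())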
Partial Derivative
def f2(x, y):
    return 2*x**2 + 3*y
x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(1.5, requires_grad=True)
# Torch version 1 (using .backward)
z = f2(x, y)
z.backward()
print("\nUsing Method 1 Torch")
print("Partial wrt x: ", x.grad)
print("Partial wrt y: ", y.grad)
# Torch version 2 (using autograd.grad)
print("\nUsing Method 2 Torch")
print("Partial wrt x: ", torch.autograd.grad(f2(x, y), x)[0])
print("Partial wrt y: ", torch.autograd.grad(f2(x, y), y)[0])
# Jax version
print("\nUsing Jax")
print("Partial wrt x: ", jax.grad(f2, argnums=0)(jnp.array(2.0), jnp.array(1.5)))
print("Partial wrt y: ", jax.grad(f2, argnums=1)(jnp.array(2.0), jnp.array(1.5)))
Using Method 1 Torch
Partial wrt x: tensor(8.)
Partial wrt y: tensor(3.)
Using Method 2 Torch
Partial wrt x: tensor(8.)
Partial wrt y: tensor(3.)
Using Jax
Partial wrt x: 8.0
Partial wrt y: 3.0
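A PyTorch caveat worth keeping in mind: .backward() accumulates into .grad, so calling it again without clearing the gradients would double-count. A small sketch, reusing f2, x and y from above:
# .grad accumulates across .backward() calls; zero it before re-running
x.grad.zero_()
y.grad.zero_()
f2(x, y).backward()
print("After zeroing and one backward:", x.grad, y.grad)  # tensor(8.), tensor(3.)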
Gradient
# Torch version 1 (using .backward)
grad_f2_v1 = torch.tensor([x.grad, y.grad])
print("\nUsing Method 1 Torch")
print(grad_f2_v1)
# Torch version 2 (using autograd.grad)
grad_f2_v2 = torch.tensor(torch.autograd.grad(f2(x, y), (x, y)))
print("\nUsing Method 2 Torch")
print(grad_f2_v2)
# Jax version
grad_f2_jax = jax.grad(f2, argnums=[0, 1])(jnp.array(2.0), jnp.array(1.5))
print("\nUsing Jax")
print(grad_f2_jax)
Using Method 1 Torch
tensor([8., 3.])
Using Method 2 Torch
tensor([8., 3.])
Using Jax
(Array(8., dtype=float32, weak_type=True), Array(3., dtype=float32, weak_type=True))
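With a list passed to argnums, jax.grad returns a tuple with one entry per argument rather than a single array; if a vector is more convenient, the entries can be stacked. A small sketch:
# Stack the per-argument derivatives into one gradient vector
print(jnp.stack(grad_f2_jax))  # [8. 3.]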
Gradient (vectorized)
def f2_vectorized(input):
    x, y = input
    return 2*x**2 + 3*y
input = torch.tensor([2.0, 1.5], requires_grad=True)
# Torch version 1 (using .backward)
z = f2_vectorized(input)
z.backward()
print("\nUsing Method 1 Torch")
print("Gradient: ", input.grad)
# Torch version 2 (using autograd.grad)
print("\nUsing Method 2 Torch")
print("Gradient: ", torch.autograd.grad(f2_vectorized(input), input))
# Jax version
print("\nUsing Jax")
print("Gradient: ", jax.grad(f2_vectorized)(jnp.array([2.0, 1.5])))
Using Method 1 Torch
Gradient: tensor([8., 3.])
Using Method 2 Torch
Gradient: (tensor([8., 3.]),)
Using Jax
Gradient: [8. 3.]
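When both the function value and its gradient are needed, JAX's jax.value_and_grad returns the pair in a single call. A brief sketch with f2_vectorized:
# Value and gradient in one call (illustrative variable names)
value, grad = jax.value_and_grad(f2_vectorized)(jnp.array([2.0, 1.5]))
print("Value:", value)    # 12.5
print("Gradient:", grad)  # [8. 3.]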
Jacobian
# We take the Jacobian of the function f(x, y, z) = [x**2 + y**2, y - z]
# The Jacobian analytically is [[2x, 2y, 0], [0, 1, -1]]
def f1(x, y, z):
    return x**2 + y**2
def f2(x, y, z):
    return y - z
def f(x, y, z):
    return torch.stack([f1(x, y, z), f2(x, y, z)])
x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(1.0, requires_grad=True)
z = torch.tensor(3.0, requires_grad=True)
output = f(x, y, z)
output
tensor([ 5., -2.], grad_fn=<StackBackward0>)
If we call output.backward() directly, we get the following error.
try:
    output.backward()
except Exception as e:
    print(e)
grad can be implicitly created only for scalar outputs
However, output.backward() has a parameter called gradient. From the PyTorch documentation:
The graph is differentiated using the chain rule. If the tensor is non-scalar (i.e. its data has more than one element) and requires gradient, the function additionally requires specifying gradient. It should be a tensor of matching type and location, that contains the gradient of the differentiated function w.r.t. self.
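In other words, passing gradient computes a vector-Jacobian product rather than a full Jacobian. A minimal sketch, reusing f, x, y and z from above: with gradient = [1, 0], the backward pass returns the gradient of the first output component only.
# Vector-Jacobian product: [1, 0] selects the first row of the Jacobian
out = f(x, y, z)
out.backward(gradient=torch.tensor([1.0, 0.0]))
print(x.grad, y.grad, z.grad)  # 2x = 4, 2y = 2, and 0 for z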
import torch.autograd.functional as F
jacobian = torch.vstack(F.jacobian(f, (x, y, z))).T
print("Jacobian with Functional method:")
print(jacobian)
# use torch.autograd.grad
jacobian = torch.zeros(2, 3)
for output_index in range(2):
    jacobian[output_index] = torch.vstack(torch.autograd.grad(f(x, y, z)[output_index], (x, y, z))).ravel()
print("Jacobian with grad method:")
print(jacobian)
Jacobian with Functional method:
tensor([[ 4., 2., -0.],
[ 0., 1., -1.]])
Jacobian with grad method:
tensor([[ 4., 2., -0.],
[ 0., 1., -1.]])
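As a check, the analytic Jacobian [[2x, 2y, 0], [0, 1, -1]] evaluated at (x, y, z) = (2, 1, 3) is [[4, 2, 0], [0, 1, -1]], which both methods reproduce. A quick comparison (illustrative variable name):
# Compare against the analytic Jacobian at (x, y, z) = (2, 1, 3)
analytic_jacobian = torch.tensor([[4., 2., 0.], [0., 1., -1.]])
print(torch.allclose(jacobian, analytic_jacobian))  # True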
# Jax version using jax.jacobian
def f_jax(x, y, z):
    return jnp.stack([f1(x, y, z), f2(x, y, z)])
print("Jax Jacobian using jax.jacobian:")
print(jnp.array(jax.jacobian(f_jax, argnums=[0, 1, 2])(jnp.array(2.0), jnp.array(1.0), jnp.array(3.0))).T)
g1 = jnp.array(jax.grad(f1, argnums=[0, 1, 2])(jnp.array(2.0), jnp.array(1.0), jnp.array(3.0)))
g2 = jnp.array(jax.grad(f2, argnums=[0, 1, 2])(jnp.array(2.0), jnp.array(1.0), jnp.array(3.0)))
print("Jax Jacobian using jax.grad done manually:")
print(jnp.vstack([g1.T, g2.T]))
Jax Jacobian using jax.jacobian:
[[ 4. 2. -0.]
[ 0. 1. -1.]]
Jax Jacobian using jax.grad done manually:
[[ 4. 2. 0.]
[ 0. 1. -1.]]
Jacobian (vectorized)
def f_vectorized(input):
    x, y, z = input
    return torch.stack([f1(x, y, z), f2(x, y, z)])
print("Torch Functional method")
print(F.jacobian(f_vectorized, torch.tensor([2.0, 1.0, 3.0])))
def f_vectorized_jax(input):
    x, y, z = input
    return jnp.stack([f1(x, y, z), f2(x, y, z)])
print("Jax Jacobian using jax.jacobian:")
print(jax.jacobian(f_vectorized_jax)(jnp.array([2.0, 1.0, 3.0])))
Torch Functional method
tensor([[ 4., 2., -0.],
[ 0., 1., -1.]])
Jax Jacobian using jax.jacobian:
[[ 4. 2. 0.]
[ 0. 1. -1.]]
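jax.jacobian uses reverse-mode differentiation (equivalent to jax.jacrev); JAX also provides jax.jacfwd, a forward-mode variant that is often preferable when a function has more outputs than inputs. A minimal sketch with the same vectorized function:
# Forward-mode Jacobian; matches the reverse-mode result above
print(jax.jacfwd(f_vectorized_jax)(jnp.array([2.0, 1.0, 3.0])))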
Hessian
def f(x, y, z):
    return x**2 + y**2 + x * y * z
x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(1.0, requires_grad=True)
z = torch.tensor(3.0, requires_grad=True)
# Torch version using autograd.functional.hessian
print("Hessian using autograd.functional.hessian:")
torch_v1_hessian = torch.tensor(F.hessian(f, (x, y, z)))
print(torch_v1_hessian)
# Jax version using jax.hessian
jax_hessian = jnp.array(jax.hessian(f, argnums=[0, 1, 2])(jnp.array(2.0), jnp.array(1.0), jnp.array(3.0)))
print("Jax Hessian using jax.hessian:")
print(jax_hessian)
# Jax version using jax.jacobian
jacobian_fn = jax.jacobian(f, argnums=[0, 1, 2])
hessian_fn = jax.jacobian(jacobian_fn, argnums=[0, 1, 2])
jax_hessian = jnp.array(hessian_fn(jnp.array(2.0), jnp.array(1.0), jnp.array(3.0)))
print("Jax Hessian using jax.jacobian:")
print(jax_hessian)
Hessian using autograd.functional.hessian:
tensor([[2., 3., 1.],
[3., 2., 2.],
[1., 2., 0.]])
Jax Hessian using jax.hessian:
[[2. 3. 1.]
[3. 2. 2.]
[1. 2. 0.]]
Jax Hessian using jax.jacobian:
[[2. 3. 1.]
[3. 2. 2.]
[1. 2. 0.]]
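For reference, the analytic Hessian of f = x**2 + y**2 + x*y*z is [[2, z, y], [z, 2, x], [y, x, 0]], which at (x, y, z) = (2, 1, 3) is exactly the matrix above. A quick check against the torch result (illustrative variable name):
# Analytic Hessian [[2, z, y], [z, 2, x], [y, x, 0]] at (x, y, z) = (2, 1, 3)
analytic_hessian = torch.tensor([[2., 3., 1.], [3., 2., 2.], [1., 2., 0.]])
print(torch.allclose(torch_v1_hessian, analytic_hessian))  # True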
Hessian (vectorized)
def f_vectorized(input):
    x, y, z = input
    return x**2 + y**2 + x * y * z
print("Torch Functional method")
print(F.hessian(f_vectorized, torch.tensor([2.0, 1.0, 3.0])))
print("Jax Hessian using jax.hessian:")
print(jax.hessian(f_vectorized)(jnp.array([2.0, 1.0, 3.0])))
print("Jax Hessian using jax.jacobian:")
print(jax.jacobian(jax.jacobian(f_vectorized))(jnp.array([2.0, 1.0, 3.0])))
Torch Functional method
tensor([[2., 3., 1.],
[3., 2., 2.],
[1., 2., 0.]])
Jax Hessian using jax.hessian:
[[2. 3. 1.]
[3. 2. 2.]
[1. 2. 0.]]
Jax Hessian using jax.jacobian:
[[2. 3. 1.]
[3. 2. 2.]
[1. 2. 0.]]
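When only Hessian-vector products are needed (for example inside a Newton-CG style optimizer), the full Hessian never has to be materialized. A common JAX recipe composes jax.jvp with jax.grad (forward-over-reverse); a minimal sketch using f_vectorized from above, with illustrative names:
# Hessian-vector product without forming the full Hessian
def hvp(fun, primal, tangent):
    # forward-over-reverse: differentiate the gradient along the tangent direction
    return jax.jvp(jax.grad(fun), (primal,), (tangent,))[1]

point = jnp.array([2.0, 1.0, 3.0])
v = jnp.array([1.0, 0.0, 0.0])
print(hvp(f_vectorized, point, v))           # first column of the Hessian: [2. 3. 1.]
print(jax.hessian(f_vectorized)(point) @ v)  # same result via the explicit Hessian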