import torch
import torch.autograd.functional as F
import torch.distributions as dist
import jax.numpy as jnp
import jax
%matplotlib inline
# Retina display
%config InlineBackend.figure_format = 'retina'
Calculus
Derivative
def f1(x):
    return 3*x**2

x = torch.tensor(2.0, requires_grad=True)

# Torch version 1 (using .backward)
z = f1(x)
z.backward()
print("Using backwards in torch", x.grad)
# Torch version 2 (using autograd.grad)
print("Using autograd.grad in torch", torch.autograd.grad(f1(x), x)[0])
# Jax version
print("Using grad in jax", jax.grad(f1)(jnp.array(2.0)))
Using backwards in torch tensor(12.)
Using autograd.grad in torch tensor(12.)
Using grad in jax 12.0
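As a quick check (not in the original), the analytic derivative of 3*x**2 is 6*x, which is 12 at x = 2, matching all three results. JAX can also return the value and the gradient together via jax.value_and_grad; a minimal sketch:

# Sketch: jax.value_and_grad returns (f(x), df/dx) in a single call
value, grad = jax.value_and_grad(f1)(jnp.array(2.0))
print("Value and grad in jax:", value, grad)  # expected 12.0 and 12.0, since f1(2) = 12 and 6*2 = 12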
Partial Derivative
def f2(x, y):
    return 2*x**2 + 3*y

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(1.5, requires_grad=True)

# Torch version 1 (using .backward)
z = f2(x, y)
z.backward()
print("\nUsing Method 1 Torch")
print("Partial wrt x: ", x.grad)
print("Partial wrt y: ", y.grad)
# Torch version 2 (using autograd.grad)
print("\nUsing Method 2 Torch")
print("Partial wrt x: ", torch.autograd.grad(f2(x, y), x)[0])
print("Partial wrt y: ", torch.autograd.grad(f2(x, y), y)[0])
# Jax version
print("\nUsing Jax")
print("Partial wrt x: ", jax.grad(f2, argnums=0)(jnp.array(2.0), jnp.array(1.5)))
print("Partial wrt y: ", jax.grad(f2, argnums=1)(jnp.array(2.0), jnp.array(1.5)))
Using Method 1 Torch
Partial wrt x: tensor(8.)
Partial wrt y: tensor(3.)
Using Method 2 Torch
Partial wrt x: tensor(8.)
Partial wrt y: tensor(3.)
Using Jax
Partial wrt x: 8.0
Partial wrt y: 3.0
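As a sanity check (not part of the original notebook), the analytic partials are df2/dx = 4*x = 8 at x = 2 and df2/dy = 3 everywhere; a central finite difference reproduces them numerically:

# Sketch: central finite differences as a numerical cross-check of the partials
eps = 1e-4
fd_x = (f2(torch.tensor(2.0 + eps), torch.tensor(1.5)) - f2(torch.tensor(2.0 - eps), torch.tensor(1.5))) / (2 * eps)
fd_y = (f2(torch.tensor(2.0), torch.tensor(1.5 + eps)) - f2(torch.tensor(2.0), torch.tensor(1.5 - eps))) / (2 * eps)
print(fd_x, fd_y)  # approximately 8 and 3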
Gradient
# Torch version 1 (using .backward)
grad_f2_v1 = torch.tensor([x.grad, y.grad])
print("\nUsing Method 1 Torch")
print(grad_f2_v1)

# Torch version 2 (using autograd.grad)
grad_f2_v2 = torch.tensor(torch.autograd.grad(f2(x, y), (x, y)))
print("\nUsing Method 2 Torch")
print(grad_f2_v2)

# Jax version
grad_f2_jax = jax.grad(f2, argnums=[0, 1])(jnp.array(2.0), jnp.array(1.5))
print("\nUsing Jax")
print(grad_f2_jax)
Using Method 1 Torch
tensor([8., 3.])
Using Method 2 Torch
tensor([8., 3.])
Using Jax
(Array(8., dtype=float32, weak_type=True), Array(3., dtype=float32, weak_type=True))
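Note that jax.grad with a sequence of argnums returns a tuple with one gradient per argument rather than a single array; the tuple can be stacked into a flat vector if desired (a small sketch, not in the original):

# Sketch: collapse the per-argument gradient tuple into one vector
print(jnp.stack(grad_f2_jax))  # expected: [8. 3.]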
Gradient (vectorized)
def f2_vectorized(input):
    x, y = input
    return 2*x**2 + 3*y

input = torch.tensor([2.0, 1.5], requires_grad=True)

# Torch version 1 (using .backward)
z = f2_vectorized(input)
z.backward()
print("\nUsing Method 1 Torch")
print("Gradient: ", input.grad)
# Torch version 2 (using autograd.grad)
print("\nUsing Method 2 Torch")
print("Gradient: ", torch.autograd.grad(f2_vectorized(input), input))
# Jax version
print("\nUsing Jax")
print("Gradient: ", jax.grad(f2_vectorized)(jnp.array([2.0, 1.5])))
Using Method 1 Torch
Gradient: tensor([8., 3.])
Using Method 2 Torch
Gradient: (tensor([8., 3.]),)
Using Jax
Gradient: [8. 3.]
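One PyTorch caveat worth noting as an aside (not in the original): .backward() accumulates into .grad, so the gradient should be zeroed before reusing the same leaf tensor:

# Sketch: .grad accumulates across backward() calls, so reset it before reusing `input`
input.grad.zero_()
z = f2_vectorized(input)
z.backward()
print(input.grad)  # back to tensor([8., 3.]) rather than an accumulated sum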
Jacobian
# We take the Jacobian of the function f(x, y, z) = [x**2 + y**2, y - z]
# The Jacobian analytically is [[2x, 2y, 0], [0, 1, -1]]
def f1(x, y, z):
    return x**2 + y**2

def f2(x, y, z):
    return y - z

def f(x, y, z):
    return torch.stack([f1(x, y, z), f2(x, y, z)])

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(1.0, requires_grad=True)
z = torch.tensor(3.0, requires_grad=True)

output = f(x, y, z)
output
tensor([ 5., -2.], grad_fn=<StackBackward0>)
If we directly call output.backward(), we get the following error:
try:
    output.backward()
except Exception as e:
    print(e)
grad can be implicitly created only for scalar outputs
But output.backward() has a parameter called gradient. As the PyTorch documentation explains: "The graph is differentiated using the chain rule. If the tensor is non-scalar (i.e. its data has more than one element) and requires gradient, the function additionally requires specifying gradient. It should be a tensor of matching type and location, that contains the gradient of the differentiated function w.r.t. self."
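For example, passing a one-hot gradient vector computes a vector-Jacobian product, which extracts one row of the Jacobian. A minimal sketch (not in the original notebook), recomputing output so the graph is fresh:

# Sketch: a one-hot `gradient` yields one row of the Jacobian (a vector-Jacobian product)
output = f(x, y, z)
output.backward(gradient=torch.tensor([1.0, 0.0]))  # weight 1 on f1, weight 0 on f2
# Expect the first Jacobian row [4., 2., 0.] in the leaves' .grad fields (the zero may print as -0.)
print(x.grad, y.grad, z.grad)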
import torch.autograd.functional as F
jacobian = torch.vstack(F.jacobian(f, (x, y, z))).T
print("Jacobian with Functional method:")
print(jacobian)
# use torch.autograd.grad
jacobian = torch.zeros(2, 3)
for output_index in range(2):
    jacobian[output_index] = torch.vstack(torch.autograd.grad(f(x, y, z)[output_index], (x, y, z))).ravel()
print("Jacobian with grad method:")
print(jacobian)
Jacobian with Functional method:
tensor([[ 4., 2., -0.],
[ 0., 1., -1.]])
Jacobian with grad method:
tensor([[ 4., 2., -0.],
[ 0., 1., -1.]])
# Jax version using jax.jacobian
def f_jax(x, y, z):
    return jnp.stack([f1(x, y, z), f2(x, y, z)])

print("Jax Jacobian using jax.jacobian:")
print(jnp.array(jax.jacobian(f_jax, argnums=[0, 1, 2])(jnp.array(2.0), jnp.array(1.0), jnp.array(3.0))).T)

g1 = jnp.array(jax.grad(f1, argnums=[0, 1, 2])(jnp.array(2.0), jnp.array(1.0), jnp.array(3.0)))
g2 = jnp.array(jax.grad(f2, argnums=[0, 1, 2])(jnp.array(2.0), jnp.array(1.0), jnp.array(3.0)))
print("Jax Jacobian using jax.grad done manually:")
print(jnp.vstack([g1.T, g2.T]))
Jax Jacobian using jax.jacobian:
[[ 4. 2. -0.]
[ 0. 1. -1.]]
Jax Jacobian using jax.grad done manually:
[[ 4. 2. 0.]
[ 0. 1. -1.]]
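jax.jacobian is an alias for reverse-mode jax.jacrev; forward-mode jax.jacfwd builds the same matrix column by column and is typically more efficient for "tall" Jacobians (more outputs than inputs). A sketch (not in the original):

# Sketch: forward-mode Jacobian; the result should match the reverse-mode one above
print(jnp.array(jax.jacfwd(f_jax, argnums=[0, 1, 2])(jnp.array(2.0), jnp.array(1.0), jnp.array(3.0))).T)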
Jacobian (vectorized)
def f_vectorized(input):
    x, y, z = input
    return torch.stack([f1(x, y, z), f2(x, y, z)])

print("Torch Functional method")
print(F.jacobian(f_vectorized, torch.tensor([2.0, 1.0, 3.0])))

def f_vectorized_jax(input):
    x, y, z = input
    return jnp.stack([f1(x, y, z), f2(x, y, z)])
print("Jax Jacobian using jax.jacobian:")
print(jax.jacobian(f_vectorized_jax)(jnp.array([2.0, 1.0, 3.0])))
Torch Functional method
tensor([[ 4., 2., -0.],
[ 0., 1., -1.]])
Jax Jacobian using jax.jacobian:
[[ 4. 2. 0.]
[ 0. 1. -1.]]
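For newer PyTorch versions, the torch.func namespace offers a functional, JAX-like Jacobian as well; a hedged sketch assuming PyTorch >= 2.0 so that torch.func is available:

# Sketch (assumes PyTorch >= 2.0): reverse-mode Jacobian via torch.func.jacrev
from torch.func import jacrev
print(jacrev(f_vectorized)(torch.tensor([2.0, 1.0, 3.0])))  # should match the matrices above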
Hessian
def f(x, y, z):
    return x**2 + y**2 + x * y * z

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(1.0, requires_grad=True)
z = torch.tensor(3.0, requires_grad=True)
# Torch version using autograd.functional.hessian
print("Hessian using autograd.functional.hessian:")
torch_v1_hessian = torch.tensor(F.hessian(f, (x, y, z)))
print(torch_v1_hessian)

# Jax version using jax.hessian
jax_hessian = jnp.array(jax.hessian(f, argnums=[0, 1, 2])(jnp.array(2.0), jnp.array(1.0), jnp.array(3.0)))
print("Jax Hessian using jax.hessian:")
print(jax_hessian)

# Jax version using jax.jacobian
jacobian_fn = jax.jacobian(f, argnums=[0, 1, 2])
hessian_fn = jax.jacobian(jacobian_fn, argnums=[0, 1, 2])
jax_hessian = jnp.array(hessian_fn(jnp.array(2.0), jnp.array(1.0), jnp.array(3.0)))
print("Jax Hessian using jax.jacobian:")
print(jax_hessian)
Hessian using autograd.functional.hessian:
tensor([[2., 3., 1.],
[3., 2., 2.],
[1., 2., 0.]])
Jax Hessian using jax.hessian:
[[2. 3. 1.]
[3. 2. 2.]
[1. 2. 0.]]
Jax Hessian using jax.jacobian:
[[2. 3. 1.]
[3. 2. 2.]
[1. 2. 0.]]
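As a cross-check (not in the original), the analytic Hessian of f(x, y, z) = x**2 + y**2 + x*y*z is [[2, z, y], [z, 2, x], [y, x, 0]]; evaluating it at (x, y, z) = (2, 1, 3):

# Analytic Hessian [[2, z, y], [z, 2, x], [y, x, 0]] at (2, 1, 3)
print(torch.tensor([[2.0, 3.0, 1.0], [3.0, 2.0, 2.0], [1.0, 2.0, 0.0]]))  # matches the autodiff results above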
Hessian (vectorized)
def f_vectorized(input):
    x, y, z = input
    return x**2 + y**2 + x * y * z
print("Torch Functional method")
print(F.hessian(f_vectorized, torch.tensor([2.0, 1.0, 3.0])))
print("Jax Hessian using jax.hessian:")
print(jax.hessian(f_vectorized)(jnp.array([2.0, 1.0, 3.0])))
print("Jax Hessian using jax.jacobian:")
print(jax.jacobian(jax.jacobian(f_vectorized))(jnp.array([2.0, 1.0, 3.0])))
Torch Functional method
tensor([[2., 3., 1.],
[3., 2., 2.],
[1., 2., 0.]])
Jax Hessian using jax.hessian:
[[2. 3. 1.]
[3. 2. 2.]
[1. 2. 0.]]
Jax Hessian using jax.jacobian:
[[2. 3. 1.]
[3. 2. 2.]
[1. 2. 0.]]
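jax.hessian is documented as the composition jacfwd(jacrev(f)) (forward-over-reverse), which is usually the most efficient ordering; the same composition can be written explicitly (a sketch, not in the original):

# Sketch: forward-over-reverse Hessian, the composition jax.hessian uses internally
print(jax.jacfwd(jax.jacrev(f_vectorized))(jnp.array([2.0, 1.0, 3.0])))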