Calculus

import torch
import torch.autograd.functional as F
import torch.distributions as dist
import jax.numpy as jnp
import jax

%matplotlib inline

# Retina display
%config InlineBackend.figure_format = 'retina'

Derivative

def f1(x):
    """Return 3 * x**2, the single-variable function differentiated below."""
    x_squared = x ** 2
    return 3 * x_squared

# Point at which the derivative is evaluated; requires_grad makes autograd track it
x = torch.tensor(2.0, requires_grad=True)


# Torch version 1 (using .backward)
# After backward(), the derivative of z with respect to x is read from x.grad
z = f1(x)
z.backward()
print("Using backwards in torch", x.grad)

# Torch version 2 (using autograd.grad)
# autograd.grad returns a tuple; unpack its single entry
(torch_derivative,) = torch.autograd.grad(f1(x), x)
print("Using autograd.grad in torch", torch_derivative)

# Jax version
# jax.grad(f1) is the derivative function, evaluated here at 2.0
df1_dx = jax.grad(f1)
print("Using grad in jax", df1_dx(jnp.array(2.0)))
Using backwards in torch tensor(12.)
Using autograd.grad in torch tensor(12.)
Using grad in jax 12.0

Partial Derivative

def f2(x, y):
    """Return 2*x**2 + 3*y, the two-variable function used for partial derivatives."""
    quadratic_term = 2 * x ** 2
    linear_term = 3 * y
    return quadratic_term + linear_term
# Evaluation point (x, y) = (2.0, 1.5); both inputs are tracked by autograd
x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(1.5, requires_grad=True)

# Torch version 1 (using .backward)
# A single backward() call fills in both x.grad and y.grad
z = f2(x, y)
z.backward()
print("\nUsing Method 1 Torch")
print("Partial wrt x: ", x.grad)
print("Partial wrt y: ", y.grad)


# Torch version 2 (using autograd.grad)
# Each call differentiates a fresh evaluation of f2 with respect to one input
print("\nUsing Method 2 Torch")
(partial_x,) = torch.autograd.grad(f2(x, y), x)
print("Partial wrt x: ", partial_x)
(partial_y,) = torch.autograd.grad(f2(x, y), y)
print("Partial wrt y: ", partial_y)

# Jax version
# argnums selects which positional argument to differentiate with respect to
print("\nUsing Jax")
jax_point = (jnp.array(2.0), jnp.array(1.5))
print("Partial wrt x: ", jax.grad(f2, argnums=0)(*jax_point))
print("Partial wrt y: ", jax.grad(f2, argnums=1)(*jax_point))

Using Method 1 Torch
Partial wrt x:  tensor(8.)
Partial wrt y:  tensor(3.)

Using Method 2 Torch
Partial wrt x:  tensor(8.)
Partial wrt y:  tensor(3.)

Using Jax
Partial wrt x:  8.0
Partial wrt y:  3.0

Gradient

# Torch version 1 (using .backward)
# Relies on x.grad and y.grad already having been filled in by an earlier
# .backward() call on f2(x, y).
# torch.stack joins the gradient tensors directly; torch.tensor(...) on a
# sequence of tensors only works while every entry is a one-element tensor.
grad_f2_v1 = torch.stack([x.grad, y.grad])
print("\nUsing Method 1 Torch")
print(grad_f2_v1)
# Torch version 2 (using autograd.grad)
# Passing (x, y) as inputs returns both partial derivatives in one tuple
grad_f2_v2 = torch.stack(torch.autograd.grad(f2(x, y), (x, y)))
print("\nUsing Method 2 Torch")
print(grad_f2_v2)

# Jax version
# argnums=[0, 1] makes jax.grad return one partial derivative per argument
grad_f2_jax = jax.grad(f2, argnums=[0, 1])(jnp.array(2.0), jnp.array(1.5))
print("\nUsing Jax")
print(grad_f2_jax)

Using Method 1 Torch
tensor([8., 3.])

Using Method 2 Torch
tensor([8., 3.])

Using Jax
(Array(8., dtype=float32, weak_type=True), Array(3., dtype=float32, weak_type=True))

Gradient (vectorized)

def f2_vectorized(input):
    """Evaluate 2*x**2 + 3*y where ``input`` unpacks into exactly (x, y)."""
    first, second = input
    quadratic_term = 2 * first ** 2
    linear_term = 3 * second
    return quadratic_term + linear_term

# Pack (x, y) = (2.0, 1.5) into one tensor so the gradient comes back as one vector
input = torch.tensor([2.0, 1.5], requires_grad=True)

# Torch version 1 (using .backward)
z = f2_vectorized(input)
z.backward()
print("\nUsing Method 1 Torch")
print("Gradient: ", input.grad)

# Torch version 2 (using autograd.grad)
print("\nUsing Method 2 Torch")
# The result of autograd.grad is printed as returned, without indexing into it
grad_result = torch.autograd.grad(f2_vectorized(input), input)
print("Gradient: ", grad_result)

# Jax version
print("\nUsing Jax")
jax_grad_fn = jax.grad(f2_vectorized)
print("Gradient: ", jax_grad_fn(jnp.array([2.0, 1.5])))

Using Method 1 Torch
Gradient:  tensor([8., 3.])

Using Method 2 Torch
Gradient:  (tensor([8., 3.]),)

Using Jax
Gradient:  [8. 3.]

Jacobian

# We take the Jacobian of the function f(x, y, z) = [x**2 + y**2, y - z]
# The Jacobian analytically is [[2x, 2y, 0], [0, 1, -1]]
def f1(x, y, z):
    """Return x**2 + y**2; ``z`` is accepted but not used."""
    sum_of_squares = x ** 2 + y ** 2
    return sum_of_squares
def f2(x, y, z):
    """Return y - z; ``x`` is accepted but not used."""
    difference = y - z
    return difference

def f(x, y, z):
    """Stack f1(x, y, z) and f2(x, y, z) into a single two-element torch tensor."""
    components = [f1(x, y, z), f2(x, y, z)]
    return torch.stack(components)

# Evaluation point (x, y, z) = (2.0, 1.0, 3.0); every input is tracked by autograd
x, y, z = (torch.tensor(value, requires_grad=True) for value in (2.0, 1.0, 3.0))

# Vector-valued result of f at that point
output = f(x, y, z)
output
tensor([ 5., -2.], grad_fn=<StackBackward0>)

If we call output.backward() directly, we get the following error.

# Calling backward() on the non-scalar `output` without a `gradient` argument
# raises a RuntimeError. Catch only that type so unrelated failures are not
# silently printed as if they were the expected error.
try:
    output.backward()
except RuntimeError as e:
    print(e)
grad can be implicitly created only for scalar outputs

However, output.backward() accepts a parameter called gradient. The PyTorch documentation for Tensor.backward says:

The graph is differentiated using the chain rule. If the tensor is non-scalar (i.e. its data has more than one element) and requires gradient, the function additionally requires specifying gradient. It should be a tensor of matching type and location, that contains the gradient of the differentiated function w.r.t. self.

import torch.autograd.functional as F

# With a tuple of inputs, F.jacobian returns one tensor per input. Stacking them
# puts the inputs along the rows, so transpose to get one row per output.
per_input_jacobians = F.jacobian(f, (x, y, z))
jacobian = torch.vstack(per_input_jacobians).T
print("Jacobian with Functional method:")
print(jacobian)

# use torch.autograd.grad
# Fill a 2 x 3 matrix one row at a time: row i holds the gradient of output i
# with respect to (x, y, z). f is re-evaluated on every iteration.
jacobian = torch.zeros(2, 3)
for output_index in range(2):
    row_gradients = torch.autograd.grad(f(x, y, z)[output_index], (x, y, z))
    jacobian[output_index] = torch.vstack(row_gradients).ravel()
print("Jacobian with grad method:")
print(jacobian)
Jacobian with Functional method:
tensor([[ 4.,  2., -0.],
        [ 0.,  1., -1.]])
Jacobian with grad method:
tensor([[ 4.,  2., -0.],
        [ 0.,  1., -1.]])
# Jax version using jax.jacobian
def f_jax(x, y, z):
    """Stack f1(x, y, z) and f2(x, y, z) into a single two-element jax array."""
    components = [f1(x, y, z), f2(x, y, z)]
    return jnp.stack(components)

# Evaluation point (x, y, z) = (2.0, 1.0, 3.0), differentiated w.r.t. all three arguments
jax_point = (jnp.array(2.0), jnp.array(1.0), jnp.array(3.0))
all_argnums = [0, 1, 2]

print("Jax Jacobian using jax.jacobian:")
# One result per argument; stacked into an array and transposed before printing
jacobian_per_argument = jax.jacobian(f_jax, argnums=all_argnums)(*jax_point)
print(jnp.array(jacobian_per_argument).T)

# Manual construction: the gradient of each scalar component becomes one row
g1 = jnp.array(jax.grad(f1, argnums=all_argnums)(*jax_point))
g2 = jnp.array(jax.grad(f2, argnums=all_argnums)(*jax_point))
print("Jax Jacobian using jax.grad done manually:")
print(jnp.vstack([g1.T, g2.T]))
Jax Jacobian using jax.jacobian:
[[ 4.  2. -0.]
 [ 0.  1. -1.]]
Jax Jacobian using jax.grad done manually:
[[ 4.  2.  0.]
 [ 0.  1. -1.]]

Jacobian (vectorized)

def f_vectorized(input):
    """Unpack ``input`` into (x, y, z) and stack f1 and f2 into a torch tensor."""
    x, y, z = input
    components = [f1(x, y, z), f2(x, y, z)]
    return torch.stack(components)

print("Torch Functional method")
# A single tensor input gives a single Jacobian tensor, printed as returned
torch_point = torch.tensor([2.0, 1.0, 3.0])
print(F.jacobian(f_vectorized, torch_point))

def f_vectorized_jax(input):
    """Unpack ``input`` into (x, y, z) and stack f1 and f2 into a jax array."""
    x, y, z = input
    components = [f1(x, y, z), f2(x, y, z)]
    return jnp.stack(components)

print("Jax Jacobian using jax.jacobian:")
# Build the Jacobian function first, then evaluate it at (2.0, 1.0, 3.0)
jacobian_of_f = jax.jacobian(f_vectorized_jax)
print(jacobian_of_f(jnp.array([2.0, 1.0, 3.0])))
Torch Functional method
tensor([[ 4.,  2., -0.],
        [ 0.,  1., -1.]])
Jax Jacobian using jax.jacobian:
[[ 4.  2.  0.]
 [ 0.  1. -1.]]

Hessian

def f(x, y, z):
    """Return x**2 + y**2 + x*y*z, the scalar function whose Hessian is computed."""
    quadratic_part = x ** 2 + y ** 2
    coupling_part = x * y * z
    return quadratic_part + coupling_part

# Evaluation point (x, y, z) = (2.0, 1.0, 3.0); every input is tracked by autograd
x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(1.0, requires_grad=True)
z = torch.tensor(3.0, requires_grad=True)

# Torch version using autograd.functional.hessian
# With a tuple of inputs, F.hessian returns a tuple of tuples of tensors (one
# entry per pair of inputs). Stack each row, then stack the rows, instead of
# passing nested tensors to torch.tensor(), which only works while every entry
# is a one-element tensor.
print("Hessian using autograd.functional.hessian:")
hessian_blocks = F.hessian(f, (x, y, z))
torch_v1_hessian = torch.stack([torch.stack(row) for row in hessian_blocks])
print(torch_v1_hessian)

# Jax version using jax.hessian
jax_hessian = jnp.array(jax.hessian(f, argnums=[0, 1, 2])(jnp.array(2.0), jnp.array(1.0), jnp.array(3.0)))
print("Jax Hessian using jax.hessian:")
print(jax_hessian)

# Jax version using jax.jacobian
# The Hessian is obtained by applying jax.jacobian twice (Jacobian of the Jacobian)
jacobian_fn = jax.jacobian(f, argnums=[0, 1, 2])
hessian_fn = jax.jacobian(jacobian_fn, argnums=[0, 1, 2])
jax_hessian = jnp.array(hessian_fn(jnp.array(2.0), jnp.array(1.0), jnp.array(3.0)))
print("Jax Hessian using jax.jacobian:")
print(jax_hessian)
Hessian using autograd.functional.hessian:
tensor([[2., 3., 1.],
        [3., 2., 2.],
        [1., 2., 0.]])
Jax Hessian using jax.hessian:
[[2. 3. 1.]
 [3. 2. 2.]
 [1. 2. 0.]]
Jax Hessian using jax.jacobian:
[[2. 3. 1.]
 [3. 2. 2.]
 [1. 2. 0.]]

Hessian (vectorized)

def f_vectorized(input):
    """Evaluate x**2 + y**2 + x*y*z where ``input`` unpacks into exactly (x, y, z)."""
    x, y, z = input
    quadratic_part = x ** 2 + y ** 2
    coupling_part = x * y * z
    return quadratic_part + coupling_part

print("Torch Functional method")
# A single tensor input gives a single Hessian tensor, printed as returned
torch_point = torch.tensor([2.0, 1.0, 3.0])
print(F.hessian(f_vectorized, torch_point))

print("Jax Hessian using jax.hessian:")
hessian_of_f = jax.hessian(f_vectorized)
print(hessian_of_f(jnp.array([2.0, 1.0, 3.0])))

print("Jax Hessian using jax.jacobian:")
# Same Hessian, built by applying jax.jacobian twice
jacobian_of_jacobian = jax.jacobian(jax.jacobian(f_vectorized))
print(jacobian_of_jacobian(jnp.array([2.0, 1.0, 3.0])))
Torch Functional method
tensor([[2., 3., 1.],
        [3., 2., 2.],
        [1., 2., 0.]])
Jax Hessian using jax.hessian:
[[2. 3. 1.]
 [3. 2. 2.]
 [1. 2. 0.]]
Jax Hessian using jax.jacobian:
[[2. 3. 1.]
 [3. 2. 2.]
 [1. 2. 0.]]