AutoDiff in JAX and PyTorch

AutoDiff in JAX and PyTorch
Author

Nipun Batra

Published

March 1, 2023

import jax.numpy as jnp
import jax

import torch
# Record the library versions this notebook was executed with.
print(torch.__version__)
print(jax.__version__)
Array(1., dtype=float32, weak_type=True)
def f(x):
    """Return the sine of ``x``, computed with ``jax.numpy``."""
    sine = jnp.sin(x)
    return sine
Array(1., dtype=float32, weak_type=True)
# PyTorch side: a scalar tensor at 0.0 that tracks gradients.
z = torch.tensor(0.0, requires_grad=True)
# Backpropagate through sin(z); this populates z.grad.
torch.sin(z).backward()
# Print JAX's derivative of f at 0.0 next to PyTorch's z.grad.
print(jax.grad(f)(0.0), z.grad)
tensor(1.)
def f(x):
    """Return the absolute value of ``x``, computed with ``jax.numpy``."""
    magnitude = jnp.abs(x)
    return magnitude


# PyTorch gradient of |x| slightly to the right of the kink at 0.
z1 = torch.tensor(0.0001, requires_grad=True)
torch.abs(z1).backward()

# PyTorch gradient of |x| slightly to the left of the kink at 0.
z2 = torch.tensor(-0.0001, requires_grad=True)
torch.abs(z2).backward()

# PyTorch gradient of |x| exactly at 0, where |x| is not differentiable.
z3 = torch.tensor(0.0, requires_grad=True)
torch.abs(z3).backward()

# JAX's gradient of f at 0.0, followed by the three PyTorch gradients.
print(jax.grad(f)(0.0), z1.grad, z2.grad, z3.grad)
1.0 tensor(1.) tensor(-1.) tensor(0.)
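# Use functorch (function transforms for PyTorch).
# NOTE: the installed functorch build must match the installed torch build;
# a mismatch makes this import fail with a missing-symbol error.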

import functorch
ImportError: dlopen(/Users/nipun/miniconda3/lib/python3.9/site-packages/functorch/_C.cpython-39-darwin.so, 0x0002): Symbol not found: __ZN2at4_ops10as_strided4callERKNS_6TensorEN3c108ArrayRefINS5_6SymIntEEES8_NS5_8optionalIS7_EE
  Referenced from: <12715304-4308-3E9B-A374-E4ADB3345E65> /Users/nipun/miniconda3/lib/python3.9/site-packages/functorch/_C.cpython-39-darwin.so
  Expected in:     <22ECBAD5-EEDD-3C80-9B5A-0564B60B6811> /Users/nipun/miniconda3/lib/python3.9/site-packages/torch/lib/libtorch_cpu.dylib
'1.12.1'