import jax.numpy as jnp
import jax
import torch
# Report the library versions this comparison was run with.
print(torch.__version__)
print(jax.__version__)
# Stray cell output preserved from the original export (it was fused onto the
# previous line and broke parsing; origin cell not visible here):
# Array(1., dtype=float32, weak_type=True)
# Nipun Batra
# March 1, 2023
# Stray cell output preserved from the original export (it was fused onto the
# `def` line and broke parsing):
# Array(1., dtype=float32, weak_type=True) tensor(1.)
def f(x):
    """Return the element-wise absolute value of ``x`` using JAX.

    Differentiated with ``jax.grad`` to inspect the derivative of |x|.
    In the run recorded in this file, ``jax.grad(f)(0.0)`` printed 1.0,
    whereas PyTorch's ``torch.abs`` gave a gradient of 0 at x == 0.
    """
    return jnp.abs(x)
def _abs_grad_leaf(value):
    """Create a scalar leaf tensor at ``value`` and backprop ``torch.abs`` through it.

    Returns the leaf tensor; its ``.grad`` then holds d|x|/dx at ``value``.
    """
    leaf = torch.tensor(value, requires_grad=True)
    torch.abs(leaf).backward()
    return leaf


# Compare the derivative of |x| just above, just below and exactly at 0:
# JAX (via f) versus PyTorch autograd.
z1 = _abs_grad_leaf(0.0001)
z2 = _abs_grad_leaf(-0.0001)
z3 = _abs_grad_leaf(0.0)
print(jax.grad(f)(0.0), z1.grad, z2.grad, z3.grad)
# Recorded output (JAX returns 1.0 at 0, PyTorch returns 0):
# 1.0 tensor(1.) tensor(-1.) tensor(0.)
#
# Stray output preserved from the original export. It does not come from the
# code above (nothing here imports functorch); it looks like a separate cell
# failed to load a functorch build that mismatches libtorch -- TODO confirm.
# ImportError: dlopen(/Users/nipun/miniconda3/lib/python3.9/site-packages/functorch/_C.cpython-39-darwin.so, 0x0002): Symbol not found: __ZN2at4_ops10as_strided4callERKNS_6TensorEN3c108ArrayRefINS5_6SymIntEEES8_NS5_8optionalIS7_EE
#   Referenced from: <12715304-4308-3E9B-A374-E4ADB3345E65> /Users/nipun/miniconda3/lib/python3.9/site-packages/functorch/_C.cpython-39-darwin.so
#   Expected in:     <22ECBAD5-EEDD-3C80-9B5A-0564B60B6811> /Users/nipun/miniconda3/lib/python3.9/site-packages/torch/lib/libtorch_cpu.dylib
# '1.12.1'