import jax.numpy as jnp
import jax
import torch
# Report the installed library versions: PyTorch on the first line, JAX on the second.
print(torch.__version__, jax.__version__, sep="\n")
# Out: Array(1., dtype=float32, weak_type=True)
# Nipun Batra
# March 1, 2023
# Out: Array(1., dtype=float32, weak_type=True)
# Out: tensor(1.)
def f(x):
    """Return the absolute value of ``x`` computed with ``jax.numpy``.

    This is a thin wrapper so that ``jax.grad(f)`` can be evaluated at the
    kink of |x|. In this notebook's recorded run, ``jax.grad(f)(0.0)``
    returned 1.0.

    Args:
        x: A scalar or array accepted by ``jnp.abs``.

    Returns:
        ``jnp.abs(x)``, a JAX array.
    """
    return jnp.abs(x)
# Probe PyTorch's gradient of |x| just above zero, just below zero, and at
# exactly zero. Each backward() call populates the leaf tensor's .grad.
# The recorded run gave 1, -1 and 0, consistent with d|x|/dx = sign(x).
z1, z2, z3 = (
    torch.tensor(value, requires_grad=True) for value in (0.0001, -0.0001, 0.0)
)
for leaf in (z1, z2, z3):
    torch.abs(leaf).backward()
# Show JAX's gradient of |x| at 0 next to PyTorch's gradients at +eps, -eps and 0.
print(jax.grad(f)(0.0), *(leaf.grad for leaf in (z1, z2, z3)))
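# Out: 1.0 tensor(1.) tensor(-1.) tensor(0.)
# |x| is not differentiable at x = 0: JAX returned 1.0 there while PyTorch
# returned 0, i.e. the two libraries choose different subgradients at the kink.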
# Out (error):
# ImportError: dlopen(/Users/nipun/miniconda3/lib/python3.9/site-packages/functorch/_C.cpython-39-darwin.so, 0x0002): Symbol not found: __ZN2at4_ops10as_strided4callERKNS_6TensorEN3c108ArrayRefINS5_6SymIntEEES8_NS5_8optionalIS7_EE
# Referenced from: <12715304-4308-3E9B-A374-E4ADB3345E65> /Users/nipun/miniconda3/lib/python3.9/site-packages/functorch/_C.cpython-39-darwin.so
# Expected in: <22ECBAD5-EEDD-3C80-9B5A-0564B60B6811> /Users/nipun/miniconda3/lib/python3.9/site-packages/torch/lib/libtorch_cpu.dylib
# NOTE(review): functorch's compiled extension needs a libtorch symbol that the
# installed torch does not export, which suggests a functorch/torch version mismatch. Confirm this.
# Out: '1.12.1'