NN Vectorization with vmap, jaxtyping, and beartype
Zeel
An interactive tutorial on neural network vectorization with torch.vmap, jaxtyping, and beartype, with practical implementations and visualizations.
Introduction to Neural Networks
Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
Convention
n0 = 3  # number of input features (neurons feeding into the layer)
n1 = 2  # number of output features (neurons in the layer)
layer = nn.Linear(n0, n1)
layer
Linear(in_features=3, out_features=2, bias=True)
layer.weight.shape
torch.Size([2, 3])
layer.bias.shape
torch.Size([2])
for i in range(n0):
    print(layer.weight[0, i])
tensor(-0.0279, grad_fn=<SelectBackward0>)
tensor(0.0743, grad_fn=<SelectBackward0>)
tensor(-0.1339, grad_fn=<SelectBackward0>)
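nn.Linear stores its weight with shape [out_features, in_features] and computes x @ W.T + b. A minimal check of this convention (a sketch; x, manual_out, and layer_out are illustrative names not used elsewhere in the tutorial):
x = torch.rand(n0)  # a single input with n0 features
manual_out = x @ layer.weight.T + layer.bias  # shape [n1]
layer_out = layer(x)  # shape [n1]
torch.allclose(manual_out, layer_out)  # expected: True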
A 2-layer network
mlp_2layers = nn.Sequential(
    nn.Linear(4, 3),
    nn.ReLU(),
    nn.Linear(3, 2),
    nn.ReLU(),
    nn.Linear(2, 1)
)
params = dict(mlp_2layers.named_parameters())
{name: param.shape for name, param in params.items()}
{'0.weight': torch.Size([3, 4]),
 '0.bias': torch.Size([3]),
 '2.weight': torch.Size([2, 3]),
 '2.bias': torch.Size([2]),
 '4.weight': torch.Size([1, 2]),
 '4.bias': torch.Size([1])}
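To see how these parameters are consumed, here is a hand-rolled forward pass through mlp_2layers that should match calling the model directly (a sketch; x is an illustrative input, not part of the original code):
x = torch.rand(5, 4)  # 5 examples, 4 input features
h = F.relu(x @ params["0.weight"].T + params["0.bias"])  # [5, 3]
h = F.relu(h @ params["2.weight"].T + params["2.bias"])  # [5, 2]
out = h @ params["4.weight"].T + params["4.bias"]  # [5, 1]
torch.allclose(out, mlp_2layers(x))  # expected: True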
Vectorization with torch.vmap
from jaxtyping import Float
from torch import Tensor
from beartype import beartype
Quick intro to jaxtyping
Scalars
scalar_type = Float[Tensor, ""]
scalar = torch.tensor(1.0)
scalar.shape
torch.Size([])
isinstance(scalar, scalar_type)
True
non_scalar = torch.tensor([1.0])
non_scalar.shape
torch.Size([1])
isinstance(non_scalar, scalar_type)
False
Vectors
vector_type = Float[Tensor, "n0"]
isinstance(non_scalar, vector_type)
True
vector = torch.tensor([1.0, 2.0])
vector.shape
torch.Size([2])
isinstance(vector, vector_type)
True
Matrices
matrix_type = Float[Tensor, "n0 n1"]
isinstance(vector, matrix_type)
False
matrix = torch.tensor([[1.0, 2.0]])
matrix.shape
torch.Size([1, 2])
isinstance(matrix, matrix_type)
True
another_matrix = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
another_matrix.shape
torch.Size([2, 3])
isinstance(another_matrix, matrix_type)
True
Tensors
tensor_type = Float[Tensor, "n0 n1 n2"]
tensor = torch.tensor([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]])
tensor.shape
torch.Size([2, 2, 2])
isinstance(tensor, tensor_type)
True
Quick intro to beartype
def call_my_name(name):
    return f"Hello, {name}!"
call_my_name("John")
'Hello, John!'
call_my_name(123)
'Hello, 123!'
@beartype
def secured_call_my_name(name: str) -> str:
    return f"Hello, {name}!"
secured_call_my_name("John")
'Hello, John!'
secured_call_my_name(123)
BeartypeCallHintParamViolation: Function __main__.secured_call_my_name() parameter name=123 violates type hint <class 'str'>, as int 123 not instance of str.
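jaxtyping and beartype compose: wrapping a function with jaxtyping's jaxtyped decorator (with beartype as the typechecker) validates shapes at call time and enforces that dimension names are consistent across arguments. A small sketch, assuming a recent jaxtyping version that supports the typechecker argument:
from jaxtyping import jaxtyped

@jaxtyped(typechecker=beartype)
def dot(u: Float[Tensor, "d"], v: Float[Tensor, "d"]) -> Float[Tensor, ""]:
    # Both arguments must share the same length d; the result is a scalar.
    return (u * v).sum()

dot(torch.rand(3), torch.rand(3))  # passes
# dot(torch.rand(3), torch.rand(4))  # raises: dimension "d" is inconsistent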
Vectorization
On which dimensions should we apply the vectorization?
- The current layer's number of neurons
- The number of examples
n = 50
a = torch.rand(n, n0)
a.shape
torch.Size([50, 3])
activation = F.relu
@beartype
def forward(a: Float[Tensor, "n0"], w: Float[Tensor, "n0"], b: Float[Tensor, ""]) -> Float[Tensor, ""]:
    z = (a * w).sum() + b  # () + () -> ()
    a = activation(z)  # () -> ()
    return a  # ()
dummy_a = torch.rand(n0)
dummy_w = torch.rand(n0)
dummy_b = torch.rand(())
print(dummy_a.shape, dummy_w.shape, dummy_b.shape)
torch.Size([3]) torch.Size([3]) torch.Size([])
forward(dummy_a, dummy_w, dummy_b).shape
torch.Size([])
forward(a[0], layer.weight[0], layer.bias[0])
tensor(0.1544, grad_fn=<ReluBackward0>)
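This scalar forward computes a single neuron: the dot product of the input with one weight row, plus one bias, followed by ReLU. As a sanity check (not in the original), it should match the first output of the torch layer:
torch.allclose(forward(a[0], layer.weight[0], layer.bias[0]), F.relu(layer(a[0]))[0])  # expected: True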
Vectorization over the current layer's number of neurons
| input | shape in forward | shape in vectorized forward |
|---|---|---|
| a | [n0=3] | [n0=3] |
| w | [n0=3] | [n1=2, n0=3] |
| b | [] | [n1=2] |
| output | [] | [n1=2] |
v1_forward = torch.vmap(forward, in_dims=(None, 0, 0), out_dims=0)
layer.weight.shape
torch.Size([2, 3])
layer.bias.shape
torch.Size([2])
out = v1_forward(a[0], layer.weight, layer.bias)
out.shape
torch.Size([2])
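v1_forward behaves like calling forward once per output neuron and stacking the results; vmap simply avoids the Python loop. A quick equivalence check (a sketch, not in the original):
looped = torch.stack([forward(a[0], layer.weight[j], layer.bias[j]) for j in range(n1)])
torch.allclose(out, looped)  # expected: True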
Vectorization over the number of examples
| input | shape in forward | shape in vectorized forward |
|---|---|---|
| a | [n0=3] | [n=50, n0=3] |
| w | [n1=2, n0=3] | [n1=2, n0=3] |
| b | [n1=2] | [n1=2] |
| output | [n1=2] | [n=50, n1=2] |
v2_forward = torch.vmap(v1_forward, in_dims=(0, None, None), out_dims=0)
final_out = v2_forward(a, layer.weight, layer.bias)
final_out.shape
torch.Size([50, 2])
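Likewise, v2_forward adds a batch dimension on top of v1_forward, as if looping over the n examples. A quick equivalence check (a sketch, not in the original):
looped = torch.stack([v1_forward(a[i], layer.weight, layer.bias) for i in range(n)])
torch.allclose(final_out, looped)  # expected: True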
Comparing with torch model forward pass
layer_out = F.relu(layer(a))
torch.allclose(final_out, layer_out)
True
XOR example
Define inputs, outputs, weights and biases
xor_x = torch.tensor([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
print(xor_x)
print(xor_x.shape)
tensor([[0., 0.],
        [0., 1.],
        [1., 0.],
        [1., 1.]])
torch.Size([4, 2])
xor_y = torch.tensor([[0.0], [1.0], [1.0], [0.0]])
print(xor_y)
print(xor_y.shape)
tensor([[0.],
        [1.],
        [1.],
        [0.]])
torch.Size([4, 1])
W_1 = torch.tensor([[1.0, 1.0], [1.0, 1.0]])
b_1 = torch.tensor([0.0, -1.0])
print(W_1)
print(b_1)
print(W_1.shape, b_1.shape)
tensor([[1., 1.],
        [1., 1.]])
tensor([ 0., -1.])
torch.Size([2, 2]) torch.Size([2])
W_2 = torch.tensor([[1.0, -2.0]])
b_2 = torch.tensor([0.0])
print(W_2)
print(b_2)
print(W_2.shape, b_2.shape)
tensor([[ 1., -2.]])
tensor([0.])
torch.Size([1, 2]) torch.Size([1])
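These hand-picked weights implement XOR: the first hidden unit computes relu(x1 + x2), the second computes relu(x1 + x2 - 1) (which fires only when both inputs are 1), and the output combines them as h1 - 2*h2. A plain-Python truth-table check (a sketch, not in the original):
for x1, x2 in [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (1.0, 1.0)]:
    h1 = max(x1 + x2, 0.0)  # relu(x1 + x2)
    h2 = max(x1 + x2 - 1.0, 0.0)  # relu(x1 + x2 - 1), nonzero only for (1, 1)
    print(x1, x2, "->", h1 - 2.0 * h2)  # prints 0, 1, 1, 0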
Forward pass
a1 = v2_forward(xor_x, W_1, b_1)
a1
tensor([[0., 0.],
        [1., 0.],
        [1., 0.],
        [2., 1.]])
a2 = v2_forward(a1, W_2, b_2)
a2
tensor([[0.],
        [1.],
        [1.],
        [0.]])
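The output of the two vectorized layers matches the XOR targets exactly (a quick check, not in the original):
torch.allclose(a2, xor_y)  # expected: True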
Forward pass with torch neural network
model = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1))
print(model[0])
print(model[1])
print(model[2])
Linear(in_features=2, out_features=2, bias=True)
ReLU()
Linear(in_features=2, out_features=1, bias=True)
model[0].weight.data = W_1
model[0].bias.data = b_1
model[2].weight.data = W_2
model[2].bias.data = b_2
model_out = model(xor_x)
model_out
tensor([[0.],
        [1.],
        [1.],
        [0.]], grad_fn=<AddmmBackward0>)
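The torch model, loaded with the same weights, agrees with the vmap-based forward pass; the extra ReLU applied by v2_forward on the last layer is a no-op here because the outputs are already non-negative (a quick check, not in the original):
torch.allclose(model_out, a2)  # expected: True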