NN Vectorization with vmap, jaxtyping, and beartype (Zeel)
An interactive tutorial on neural network vectorization with torch.vmap, jaxtyping, and beartype, with practical implementations and visualizations.
Introduction to Neural Networks
Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
Convention
n0 = 3
n1 = 2
layer = nn.Linear(n0, n1)
layer
Linear(in_features=3, out_features=2, bias=True)
layer.weight.shape
torch.Size([2, 3])
layer.bias.shape
torch.Size([2])
for i in range(n0):
print(layer.weight[0, i])
tensor(-0.0279, grad_fn=<SelectBackward0>)
tensor(0.0743, grad_fn=<SelectBackward0>)
tensor(-0.1339, grad_fn=<SelectBackward0>)
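To make the weight convention concrete: nn.Linear stores its weight as [out_features, in_features], so a forward pass computes x @ W.T + b. A quick sanity check (a minimal sketch reusing the layer defined above; the names x and manual are just illustrative):
x = torch.rand(n0)  # a single input with n0 = 3 features
manual = x @ layer.weight.T + layer.bias  # [n0] @ [n0, n1] + [n1] -> [n1]
torch.allclose(manual, layer(x))  # expected: True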
A 2-layer network
mlp_2layers = nn.Sequential(
    nn.Linear(4, 3),
    nn.ReLU(),
    nn.Linear(3, 2),
    nn.ReLU(),
    nn.Linear(2, 1)
)
params = dict(mlp_2layers.named_parameters())
{name: param.shape for name, param in params.items()}
{'0.weight': torch.Size([3, 4]),
'0.bias': torch.Size([3]),
'2.weight': torch.Size([2, 3]),
'2.bias': torch.Size([2]),
'4.weight': torch.Size([1, 2]),
'4.bias': torch.Size([1])}
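The prefixes '0', '2', and '4' are the indices of the Linear modules inside nn.Sequential (the ReLU modules at indices 1 and 3 have no parameters), so the same tensors can be reached by indexing the model directly. A small check (a sketch using the mlp_2layers defined above):
mlp_2layers[0].weight.shape  # same tensor as params['0.weight'] -> torch.Size([3, 4])
mlp_2layers[4].bias.shape    # same tensor as params['4.bias']   -> torch.Size([1])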
Vectorization with torch.vmap
from jaxtyping import Float
from torch import Tensor
from beartype import beartype
Quick intro to jaxtyping
Scalars
scalar_type = Float[Tensor, ""]
scalar = torch.tensor(1.0)
scalar.shape
torch.Size([])
isinstance(scalar, scalar_type)
True
non_scalar = torch.tensor([1.0])
non_scalar.shape
torch.Size([1])
isinstance(non_scalar, scalar_type)
False
Vectors
vector_type = Float[Tensor, "n0"]
isinstance(non_scalar, vector_type)
True
vector = torch.tensor([1.0, 2.0])
vector.shape
torch.Size([2])
isinstance(vector, vector_type)
True
Matrices
matrix_type = Float[Tensor, "n0 n1"]
isinstance(vector, matrix_type)
False
matrix = torch.tensor([[1.0, 2.0]])
matrix.shape
torch.Size([1, 2])
isinstance(matrix, matrix_type)
True
another_matrix = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
another_matrix.shape
torch.Size([2, 3])
isinstance(another_matrix, matrix_type)
True
Tensors
tensor_type = Float[Tensor, "n0 n1 n2"]
tensor = torch.tensor([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]])
tensor.shape
torch.Size([2, 2, 2])
isinstance(tensor, tensor_type)
True
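Note that these isinstance checks look at the dtype as well as the number of dimensions: Float only matches floating-point tensors. A quick illustration (a sketch reusing vector_type from above; int_vector is just an illustrative name):
int_vector = torch.tensor([1, 2])  # int64 tensor with the right shape
isinstance(int_vector, vector_type)  # expected: False, dtype is not float
isinstance(int_vector.float(), vector_type)  # expected: True after casting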
Quick intro to beartype
def call_my_name(name):
return f"Hello, {name}!"
"John") call_my_name(
'Hello, John!'
call_my_name(123)
'Hello, 123!'
@beartype
def secured_call_my_name(name: str) -> str:
return f"Hello, {name}!"
"John") secured_call_my_name(
'Hello, John!'
secured_call_my_name(123)
---------------------------------------------------------------------------
BeartypeCallHintParamViolation            Traceback (most recent call last)
Cell In[39], line 1
----> 1 secured_call_my_name(123)

File <@beartype(__main__.secured_call_my_name) at 0x7fdbb0613ce0>:28, in secured_call_my_name(__beartype_get_violation, __beartype_conf, __beartype_func, *args, **kwargs)

BeartypeCallHintParamViolation: Function __main__.secured_call_my_name() parameter name=123 violates type hint <class 'str'>, as int 123 not instance of str.
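beartype also understands the jaxtyping annotations from the previous section, so a decorated function can reject tensors of the wrong shape or dtype at call time. A minimal sketch (the dot function and its names are illustrative, not part of the tutorial):
@beartype
def dot(x: Float[Tensor, "n"], y: Float[Tensor, "n"]) -> Float[Tensor, ""]:
    return (x * y).sum()

dot(torch.rand(3), torch.rand(3))  # OK: both arguments are 1-D float tensors
# dot(torch.rand(3, 2), torch.rand(3))  # raises BeartypeCallHintParamViolation: x is 2-D
On its own, beartype validates each argument's rank and dtype independently; checking that a dimension name such as "n" is consistent across arguments additionally requires jaxtyping's jaxtyped decorator.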
Vectorization
On which dimensions should we apply the vectorization?

- Current layer's number of neurons
- Number of examples
n = 50
a = torch.rand(n, n0)
a.shape
torch.Size([50, 3])
activation = F.relu
@beartype
def forward(a: Float[Tensor, "n0"], w: Float[Tensor, "n0"], b: Float[Tensor, ""]) -> Float[Tensor, ""]:
    z = (a * w).sum() + b  # () + () -> ()
    a = activation(z)  # () -> ()
    return a  # ()
dummy_a = torch.rand(n0)
dummy_w = torch.rand(n0)
dummy_b = torch.rand(())
print(dummy_a.shape, dummy_w.shape, dummy_b.shape)
torch.Size([3]) torch.Size([3]) torch.Size([])
forward(dummy_a, dummy_w, dummy_b).shape
torch.Size([])
forward(a[0], layer.weight[0], layer.bias[0])
tensor(0.1544, grad_fn=<ReluBackward0>)
Vectorization over the current layer's number of neurons
input | shape in forward | shape in vectorized forward |
---|---|---|
a | [n0=3] | [n0=3] |
w | [n0=3] | [n1=2, n0=3] |
b | [] | [n1=2] |
output | [] | [n1=2] |
v1_forward = torch.vmap(forward, in_dims=(None, 0, 0), out_dims=0)
layer.weight.shape
torch.Size([2, 3])
layer.bias.shape
torch.Size([2])
out = v1_forward(a[0], layer.weight, layer.bias)
out.shape
torch.Size([2])
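As a quick check (a sketch reusing layer and a from above), vectorizing over neurons should reproduce the layer's own output for a single example:
torch.allclose(out, F.relu(layer(a[0])))  # expected: True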
Vectorization over number of examples
input | shape in forward | shape in vectorized forward |
---|---|---|
a | [n0=3] | [n=50, n0=3] |
w | [n1=2, n0=3] | [n1=2, n0=3] |
b | [n1=2] | [n1=2] |
output | [n1=2] | [n=50, n1=2] |
v2_forward = torch.vmap(v1_forward, in_dims=(0, None, None), out_dims=0)
final_out = v2_forward(a, layer.weight, layer.bias)
final_out.shape
torch.Size([50, 2])
Comparing with torch model forward pass
layer_out = F.relu(layer(a))
torch.allclose(final_out, layer_out)
True
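For intuition, the nested vmap does the same work as two explicit Python loops, one over examples and one over neurons. A minimal sketch of that un-vectorized version (loop_out is just an illustrative name):
loop_out = torch.stack([
    torch.stack([forward(a[i], layer.weight[j], layer.bias[j]) for j in range(n1)])
    for i in range(n)
])
torch.allclose(loop_out, final_out)  # expected: True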
XOR example
Define inputs, outputs, weights and biases
xor_x = torch.tensor([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
print(xor_x)
print(xor_x.shape)
tensor([[0., 0.],
[0., 1.],
[1., 0.],
[1., 1.]])
torch.Size([4, 2])
xor_y = torch.tensor([[0.0], [1.0], [1.0], [0.0]])
print(xor_y)
print(xor_y.shape)
tensor([[0.],
[1.],
[1.],
[0.]])
torch.Size([4, 1])
W_1 = torch.tensor([[1.0, 1.0], [1.0, 1.0]])
b_1 = torch.tensor([0.0, -1.0])
print(W_1)
print(b_1)
print(W_1.shape, b_1.shape)
tensor([[1., 1.],
[1., 1.]])
tensor([ 0., -1.])
torch.Size([2, 2]) torch.Size([2])
W_2 = torch.tensor([[1.0, -2.0]])
b_2 = torch.tensor([0.0])
print(W_2)
print(b_2)
print(W_2.shape, b_2.shape)
tensor([[ 1., -2.]])
tensor([0.])
torch.Size([1, 2]) torch.Size([1])
Forward pass
a1 = v2_forward(xor_x, W_1, b_1)
a1
tensor([[0., 0.],
[1., 0.],
[1., 0.],
[2., 1.]])
a2 = v2_forward(a1, W_2, b_2)
a2
tensor([[0.],
[1.],
[1.],
[0.]])
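The hand-picked weights therefore solve XOR exactly. A quick check against the targets defined above (a one-line sketch):
torch.allclose(a2, xor_y)  # expected: True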
Forward pass with torch neural network
model = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1))
print(model[0])
print(model[1])
print(model[2])
Linear(in_features=2, out_features=2, bias=True)
ReLU()
Linear(in_features=2, out_features=1, bias=True)
model[0].weight.data = W_1
model[0].bias.data = b_1
model[2].weight.data = W_2
model[2].bias.data = b_2
model_out = model(xor_x)
model_out
tensor([[0.],
[1.],
[1.],
[0.]], grad_fn=<AddmmBackward0>)
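As a final check (a one-line sketch), the torch model reproduces the hand-vectorized result. The only structural difference is that v2_forward applies ReLU after the last layer as well, which happens to be a no-op here because the final pre-activations (0, 1, 1, 0) are already non-negative:
torch.allclose(model_out, a2)  # expected: True, despite the extra ReLU in v2_forward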