NN Vectorization with vmap, jaxtyping, and beartype

Interactive tutorial on neural-network vectorization with torch.vmap, jaxtyping, and beartype, with practical implementations and visualizations
Author

Nipun Batra

Published

July 24, 2025


Introduction to Neural Networks

Imports

import torch
import torch.nn as nn
import torch.nn.functional as F

Convention

n0 = 3
n1 = 2
layer = nn.Linear(n0, n1)
layer
Linear(in_features=3, out_features=2, bias=True)
layer.weight.shape
torch.Size([2, 3])
layer.bias.shape
torch.Size([2])
for i in range(n0):
    print(layer.weight[0, i])
tensor(-0.0279, grad_fn=<SelectBackward0>)
tensor(0.0743, grad_fn=<SelectBackward0>)
tensor(-0.1339, grad_fn=<SelectBackward0>)
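
The convention to keep in mind: nn.Linear stores its weight as [out_features, in_features] (here [n1, n0]), so each row of layer.weight holds the incoming weights of one output neuron. As a quick sanity check (a minimal sketch; x is just a made-up input vector), the layer's output should match the row-wise dot products plus bias:

x = torch.rand(n0)
manual = layer.weight @ x + layer.bias  # [n1, n0] @ [n0] + [n1] -> [n1]
torch.allclose(layer(x), manual)        # should be True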

A network with two hidden layers

mlp_2layers = nn.Sequential(
    nn.Linear(4, 3),
    nn.ReLU(),
    nn.Linear(3, 2),
    nn.ReLU(),
    nn.Linear(2, 1)
)
params = dict(mlp_2layers.named_parameters())
{name: param.shape for name, param in params.items()}
{'0.weight': torch.Size([3, 4]),
 '0.bias': torch.Size([3]),
 '2.weight': torch.Size([2, 3]),
 '2.bias': torch.Size([2]),
 '4.weight': torch.Size([1, 2]),
 '4.bias': torch.Size([1])}
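
To see how these shapes fit together, we can push a dummy batch through the network and confirm the output shape (a quick check; the batch below is made up):

dummy_batch = torch.rand(5, 4)   # 5 examples, 4 input features
mlp_2layers(dummy_batch).shape   # expected: torch.Size([5, 1])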

Vectorization with torch.vmap

from jaxtyping import Float
from torch import Tensor
from beartype import beartype

Quick intro to jaxtyping

Scalars

scalar_type = Float[Tensor, ""]
scalar = torch.tensor(1.0)
scalar.shape
torch.Size([])
isinstance(scalar, scalar_type)
True
non_scalar = torch.tensor([1.0])
non_scalar.shape
torch.Size([1])
isinstance(non_scalar, scalar_type)
False

Vectors

vector_type = Float[Tensor, "n0"]
isinstance(non_scalar, vector_type)
True
vector = torch.tensor([1.0, 2.0])
vector.shape
torch.Size([2])
isinstance(vector, vector_type)
True

Matrices

matrix_type = Float[Tensor, "n0 n1"]
isinstance(vector, matrix_type)
False
matrix = torch.tensor([[1.0, 2.0]])
matrix.shape
torch.Size([1, 2])
isinstance(matrix, matrix_type)
True
another_matrix = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
another_matrix.shape
torch.Size([2, 3])
isinstance(another_matrix, matrix_type)
True

Tensors

tensor_type = Float[Tensor, "n0 n1 n2"]
tensor = torch.tensor([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]])
tensor.shape
torch.Size([2, 2, 2])
isinstance(tensor, tensor_type)
True
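
The dimension names become really useful when the same name appears across several arguments. A minimal sketch (matvec is a made-up example): wrapping a function with jaxtyping's jaxtyped decorator, using beartype as the runtime checker, enforces that dimensions with the same name agree across the arguments and the return value.

from jaxtyping import jaxtyped

@jaxtyped(typechecker=beartype)
def matvec(m: Float[Tensor, "rows cols"], v: Float[Tensor, "cols"]) -> Float[Tensor, "rows"]:
    return m @ v

matvec(torch.rand(2, 3), torch.rand(3))    # OK: cols=3 matches in both arguments
# matvec(torch.rand(2, 3), torch.rand(4))  # raises: cols mismatch (3 vs 4)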

Quick intro to beartype

def call_my_name(name):
    return f"Hello, {name}!"
call_my_name("John")
'Hello, John!'
call_my_name(123)
'Hello, 123!'
@beartype
def secured_call_my_name(name: str) -> str:
    return f"Hello, {name}!"
secured_call_my_name("John")
'Hello, John!'
secured_call_my_name(123)
---------------------------------------------------------------------------
BeartypeCallHintParamViolation            Traceback (most recent call last)
Cell In[39], line 1
----> 1 secured_call_my_name(123)

File <@beartype(__main__.secured_call_my_name) at 0x7fdbb0613ce0>:28, in secured_call_my_name(__beartype_get_violation, __beartype_conf, __beartype_func, *args, **kwargs)

BeartypeCallHintParamViolation: Function __main__.secured_call_my_name() parameter name=123 violates type hint <class 'str'>, as int 123 not instance of str.
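
beartype also understands the jaxtyping annotations from above, since they support isinstance checks. A minimal sketch (vector_only is a made-up example): a function annotated with Float[Tensor, "n0"] accepts a 1-D float tensor and rejects anything else.

@beartype
def vector_only(x: Float[Tensor, "n0"]) -> Float[Tensor, ""]:
    return x.sum()

vector_only(torch.rand(3))      # OK: 1-D float tensor
# vector_only(torch.rand(2, 3)) # raises BeartypeCallHintParamViolation: not a 1-D tensor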

Vectorization

On which dimensions should we apply the vectorization?

- The current layer's number of neurons (n1)
- The number of examples (n)

n = 50
a = torch.rand(n, n0)
a.shape
torch.Size([50, 3])
activation = F.relu

@beartype
def forward(a: Float[Tensor, "n0"], w: Float[Tensor, "n0"], b: Float[Tensor, ""]) -> Float[Tensor, ""]:
    z = (a * w).sum() + b  # [n0] * [n0] -> [n0], sum -> (), () + () -> ()
    a = activation(z)  # () -> ()
    return a  # ()
dummy_a = torch.rand(n0)
dummy_w = torch.rand(n0)
dummy_b = torch.rand(())
print(dummy_a.shape, dummy_w.shape, dummy_b.shape)
torch.Size([3]) torch.Size([3]) torch.Size([])
forward(dummy_a, dummy_w, dummy_b).shape
torch.Size([])
forward(a[0], layer.weight[0], layer.bias[0])
tensor(0.1544, grad_fn=<ReluBackward0>)
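
As a sanity check (a small sketch using the objects already defined), the scalar forward for the first neuron should agree with the direct dot-product computation:

torch.allclose(
    forward(a[0], layer.weight[0], layer.bias[0]),
    F.relu(a[0] @ layer.weight[0] + layer.bias[0]),
)  # should be True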

Vectorization over the current layer's neurons

input   | shape in forward | shape in vectorized forward
a       | [n0=3]           | [n0=3]
w       | [n0=3]           | [n1=2, n0=3]
b       | []               | [n1=2]
output  | []               | [n1=2]
v1_forward = torch.vmap(forward, in_dims=(None, 0, 0), out_dims=0)
layer.weight.shape
torch.Size([2, 3])
layer.bias.shape
torch.Size([2])
out = v1_forward(a[0], layer.weight, layer.bias)
out.shape
torch.Size([2])
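
Again we can verify (a quick sketch) that vectorizing over the neuron dimension reproduces the usual matrix-vector form of the layer for a single example:

torch.allclose(out, F.relu(layer.weight @ a[0] + layer.bias))  # should be True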

Vectorization over the number of examples

input   | shape in forward | shape in vectorized forward
a       | [n0=3]           | [n=50, n0=3]
w       | [n1=2, n0=3]     | [n1=2, n0=3]
b       | [n1=2]           | [n1=2]
output  | [n1=2]           | [n=50, n1=2]
v2_forward = torch.vmap(v1_forward, in_dims=(0, None, None), out_dims=0)
final_out = v2_forward(a, layer.weight, layer.bias)
final_out.shape
torch.Size([50, 2])

Comparing with torch model forward pass

layer_out = F.relu(layer(a))
torch.allclose(final_out, layer_out)
True

XOR example

Define inputs, outputs, weights and biases

xor_x = torch.tensor([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
print(xor_x)
print(xor_x.shape)
tensor([[0., 0.],
        [0., 1.],
        [1., 0.],
        [1., 1.]])
torch.Size([4, 2])
xor_y = torch.tensor([[0.0], [1.0], [1.0], [0.0]])
print(xor_y)
print(xor_y.shape)
tensor([[0.],
        [1.],
        [1.],
        [0.]])
torch.Size([4, 1])
W_1 = torch.tensor([[1.0, 1.0], [1.0, 1.0]])
b_1 = torch.tensor([0.0, -1.0])
print(W_1)
print(b_1)
print(W_1.shape, b_1.shape)
tensor([[1., 1.],
        [1., 1.]])
tensor([ 0., -1.])
torch.Size([2, 2]) torch.Size([2])
W_2 = torch.tensor([[1.0, -2.0]])
b_2 = torch.tensor([0.0])
print(W_2)
print(b_2)
print(W_2.shape, b_2.shape)
tensor([[ 1., -2.]])
tensor([0.])
torch.Size([1, 2]) torch.Size([1])

Forward pass

a1 = v2_forward(xor_x, W_1, b_1)
a1
tensor([[0., 0.],
        [1., 0.],
        [1., 0.],
        [2., 1.]])
a2 = v2_forward(a1, W_2, b_2)
a2
tensor([[0.],
        [1.],
        [1.],
        [0.]])
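
Reading off the hidden activations explains why these hand-picked weights implement XOR: the first hidden unit computes relu(x1 + x2), which counts the active inputs (0, 1, or 2), while the second computes relu(x1 + x2 - 1), which is 1 only when both inputs are active. The output weights then compute count - 2*AND, giving 0, 1, 1, and 2 - 2 = 0, which is exactly XOR.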

Forward pass with torch neural network

model = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1))
print(model[0])
print(model[1])
print(model[2])
Linear(in_features=2, out_features=2, bias=True)
ReLU()
Linear(in_features=2, out_features=1, bias=True)
model[0].weight.data = W_1
model[0].bias.data = b_1
model[2].weight.data = W_2
model[2].bias.data = b_2
model_out = model(xor_x)
model_out
tensor([[0.],
        [1.],
        [1.],
        [0.]], grad_fn=<AddmmBackward0>)
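
To close the loop, a quick check (a small sketch) that the hand-vectorized forward pass and the torch model agree. One subtlety: v2_forward applies ReLU after every layer, including the last one, while the torch model has no final ReLU; the outputs still match here because the final pre-activations are already non-negative.

torch.allclose(a2, model_out)  # should be True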