Lecture 0 · MLE & MAP from scratch · in PyTorch

Goal · turn the lecture’s pen-and-paper derivations into runnable code.

Same recipe in five settings:

  1. MLE for a coin — recover \(\hat p = h/N\) by optimizing the Bernoulli log-likelihood.
  2. MLE for linear regression — show that maximizing the Gaussian log-likelihood = the OLS closed form.
  3. MLE for logistic regression — show BCEWithLogitsLoss is exactly the Bernoulli NLL.
  4. MLE for multiclass — softmax + categorical = cross_entropy.
  5. MAP for linear regression — Gaussian prior → L2 (ridge), Laplace prior → L1 (lasso). Watch sparsity emerge as \(\lambda\) grows.

Throughout, we use torch.distributions so the probabilistic structure is explicit. The only recurring code pattern is:

loss = -dist.log_prob(observation).sum()
loss.backward()
opt.step()

The log_prob line is all of MLE. The rest of the notebook is bookkeeping and visualization.

import torch
import torch.nn.functional as F
from torch.distributions import Bernoulli, Normal, Categorical, Laplace
import matplotlib.pyplot as plt
import numpy as np

torch.manual_seed(0)
np.random.seed(0)

# House style
plt.rcParams['figure.facecolor'] = 'white'
plt.rcParams['axes.spines.top'] = False
plt.rcParams['axes.spines.right'] = False

1 · MLE for a coin

We observe H, H, T, H, T, T, H, H, T, H — 6 heads, 4 tails. The MLE for the Bernoulli is \(\hat p = h/N = 0.6\) analytically. Let’s verify with gradient ascent on the log-likelihood.
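
For reference, the closed form drops out of one derivative: with \(h\) heads in \(N\) tosses,

\[\log L(p) = h \log p + (N - h)\log(1 - p), \qquad \frac{d}{dp}\log L(p) = \frac{h}{p} - \frac{N - h}{1 - p} = 0 \;\Rightarrow\; \hat p = \frac{h}{N}.\]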

data = torch.tensor([1, 1, 0, 1, 0, 0, 1, 1, 0, 1], dtype=torch.float)
print(f"observations: {data.tolist()}")
print(f"#H = {int(data.sum())}, N = {len(data)}")
print(f"analytical MLE: p = h/N = {data.mean().item():.3f}")

Step 1 · parametrize \(p\) such that \(p \in (0, 1)\).

Direct optimization of \(p\) would need a constraint. Instead, we optimize an unconstrained logit \(\theta\) and use \(p = \sigma(\theta)\). (Same trick as in logistic regression.)

theta = torch.zeros(1, requires_grad=True)   # logit; sigmoid(0) = 0.5
opt = torch.optim.Adam([theta], lr=0.1)

losses, ps = [], []
for step in range(300):
    p = torch.sigmoid(theta)
    nll = -Bernoulli(probs=p).log_prob(data).sum()
    opt.zero_grad(); nll.backward(); opt.step()
    losses.append(nll.item()); ps.append(p.item())

print(f"converged p = {ps[-1]:.4f}    (analytical 0.6)")
fig, axes = plt.subplots(1, 2, figsize=(11, 3.6))
axes[0].plot(losses); axes[0].set_xlabel('step'); axes[0].set_ylabel('NLL')
axes[0].set_title('NLL converges')

# likelihood surface
p_grid = torch.linspace(0.001, 0.999, 200)
loglik = (data.sum() * torch.log(p_grid) +
          (len(data) - data.sum()) * torch.log(1 - p_grid))
axes[1].plot(p_grid, loglik); axes[1].axvline(0.6, color='r', ls='--', label='MLE = 0.6')
axes[1].set_xlabel('p'); axes[1].set_ylabel('log-likelihood'); axes[1].legend()
axes[1].set_title('likelihood surface for the coin')
plt.tight_layout(); plt.show()

Optimization recovers \(\hat p \approx 0.6\), exactly matching the closed form. The right-hand plot shows the (concave) log-likelihood surface — gradient ascent walks up to the peak.

2 · MLE for linear regression = OLS

Generate synthetic data with true parameters \(\theta^* = [2, -3]\), intercept \(b^* = 1\), and Gaussian noise \(\sigma = 0.5\). Then fit by maximizing the Gaussian log-likelihood and check we recover the OLS closed form.

N, d = 200, 2
X = torch.randn(N, d)
theta_true = torch.tensor([2.0, -3.0])
b_true = 1.0
sigma_true = 0.5
y = X @ theta_true + b_true + sigma_true * torch.randn(N)

# Closed-form OLS · stack a column of 1s for the bias
X_ = torch.cat([X, torch.ones(N, 1)], dim=1)
theta_ols = torch.linalg.lstsq(X_, y).solution
print(f"OLS solution:  theta = {theta_ols[:-1].tolist()},  bias = {theta_ols[-1].item():.3f}")
# MLE via gradient descent on the Gaussian NLL
theta = torch.zeros(d, requires_grad=True)
b = torch.zeros(1, requires_grad=True)
log_sigma = torch.zeros(1, requires_grad=True)   # learnable noise scale

opt = torch.optim.Adam([theta, b, log_sigma], lr=0.05)
losses = []
for step in range(800):
    mu = X @ theta + b
    sigma = torch.exp(log_sigma)
    nll = -Normal(mu, sigma).log_prob(y).sum()
    opt.zero_grad(); nll.backward(); opt.step()
    losses.append(nll.item())

print(f"MLE solution:  theta = {theta.tolist()},  bias = {b.item():.3f},  sigma = {torch.exp(log_sigma).item():.3f}")

The MLE estimate matches OLS to ~3 decimal places. The learned noise scale also recovers the true \(\sigma \approx 0.5\).

Why they agree · maximizing the Gaussian log-likelihood is exactly minimizing \(\sum (y_i - \hat\mu_i)^2\) (after dropping constants in \(\theta\)). Different code path, same answer.
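
Written out (a one-step expansion of the Normal log-density used above):

\[-\log p(y \mid X, \theta, b, \sigma) = \frac{N}{2}\log(2\pi\sigma^2) + \frac{1}{2\sigma^2}\sum_{i=1}^{N}\left(y_i - \theta^\top x_i - b\right)^2.\]

For any fixed \(\sigma\) the first term is a constant and the second is a positive multiple of the squared error, so the minimizing \(\theta, b\) are exactly the OLS estimates.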

3 · MLE for logistic regression = BCE

Generate a binary classification dataset. Fit by minimizing two equivalent losses — F.binary_cross_entropy_with_logits and the Bernoulli NLL via torch.distributions. They should give identical numbers.

N, d = 300, 2
X = torch.randn(N, d)
theta_true = torch.tensor([2.0, -1.5])
b_true = 0.3
logits_true = X @ theta_true + b_true
y = torch.bernoulli(torch.sigmoid(logits_true))     # binary labels
# Path A · BCEWithLogitsLoss
theta_a = torch.zeros(d, requires_grad=True)
b_a = torch.zeros(1, requires_grad=True)
opt_a = torch.optim.Adam([theta_a, b_a], lr=0.05)
loss_a_history = []
for _ in range(600):
    logits = X @ theta_a + b_a
    loss = F.binary_cross_entropy_with_logits(logits, y, reduction='sum')
    opt_a.zero_grad(); loss.backward(); opt_a.step()
    loss_a_history.append(loss.item())

# Path B · explicit Bernoulli NLL via torch.distributions
theta_b = torch.zeros(d, requires_grad=True)
b_b = torch.zeros(1, requires_grad=True)
opt_b = torch.optim.Adam([theta_b, b_b], lr=0.05)
loss_b_history = []
for _ in range(600):
    p_hat = torch.sigmoid(X @ theta_b + b_b)
    loss = -Bernoulli(probs=p_hat).log_prob(y).sum()
    opt_b.zero_grad(); loss.backward(); opt_b.step()
    loss_b_history.append(loss.item())

print(f"BCE-with-logits final: theta = {theta_a.tolist()},  bias = {b_a.item():.3f}")
print(f"Bernoulli NLL final  : theta = {theta_b.tolist()},  bias = {b_b.item():.3f}")
print(f"\ntrue                 : theta = {theta_true.tolist()},  bias = {b_true}")

Both paths converge to the same parameters, near the truth. BCEWithLogitsLoss is just the Bernoulli NLL with a fused log-sigmoid for numerical stability.
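
A two-line check makes the equivalence concrete. (Small sketch · this uses the logits= parametrization of Bernoulli so both sides take raw logits; logits_check and y_check are throwaway names.)

# sanity check · BCE-with-logits and the Bernoulli log-prob give the same number
logits_check = torch.randn(5)
y_check = torch.randint(0, 2, (5,)).float()
bce = F.binary_cross_entropy_with_logits(logits_check, y_check, reduction='sum')
nll = -Bernoulli(logits=logits_check).log_prob(y_check).sum()
print(torch.allclose(bce, nll))   # True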

4 · MLE for multiclass = cross-entropy

3 classes in 2D. Fit a linear model with softmax output by minimizing the categorical NLL. Verify it agrees with F.cross_entropy.

K = 3
N, d = 300, 2
centers = torch.tensor([[-2., 0.], [2., 0.], [0., 2.5]])
class_assign = torch.randint(0, K, (N,))
X = centers[class_assign] + 0.6 * torch.randn(N, d)
y = class_assign

# Linear model · weights W (d, K), bias b (K,)
W = torch.zeros(d, K, requires_grad=True)
b = torch.zeros(K, requires_grad=True)

opt = torch.optim.Adam([W, b], lr=0.05)
for step in range(600):
    logits = X @ W + b
    # Path A · F.cross_entropy
    loss_ce = F.cross_entropy(logits, y, reduction='sum')
    # Path B · explicit Categorical NLL
    loss_cat = -Categorical(logits=logits).log_prob(y).sum()
    # They should be identical
    if step == 0:
        print(f"step 0 · F.cross_entropy = {loss_ce.item():.4f},  Categorical NLL = {loss_cat.item():.4f}")
    opt.zero_grad(); loss_ce.backward(); opt.step()

print(f"final accuracy: {(logits.argmax(-1) == y).float().mean().item():.3f}")
# Decision boundaries
xx, yy = torch.meshgrid(torch.linspace(-4, 4, 200), torch.linspace(-3, 5, 200), indexing='ij')
grid = torch.stack([xx.flatten(), yy.flatten()], dim=-1)
with torch.no_grad():
    pred = (grid @ W + b).argmax(-1).reshape(xx.shape)

plt.figure(figsize=(5.5, 5))
plt.contourf(xx, yy, pred, alpha=0.25, levels=[-0.5, 0.5, 1.5, 2.5], colors=['#d97757', '#7b9e89', '#4a6670'])
for c, color in zip(range(K), ['#d97757', '#7b9e89', '#4a6670']):
    plt.scatter(X[y == c, 0], X[y == c, 1], color=color, edgecolor='k', s=30, label=f'class {c}')
plt.legend(); plt.title('multiclass MLE · linear softmax classifier')
plt.tight_layout(); plt.show()

Clean linear decision boundaries · the same F.cross_entropy you’ve used dozens of times is MLE under a categorical conditional distribution.

5 · MAP for linear regression with Gaussian prior = ridge

Add a Gaussian prior \(\theta_j \sim \mathcal{N}(0, \sigma_p^2)\) to the linear regression of section 2. The MAP objective is

\[L_{\text{MAP}}(\theta) = -\log p(\mathcal{D} \mid \theta) - \log p(\theta) = \text{NLL} + \lambda \|\theta\|_2^2\]

where \(\lambda = 1/(2\sigma_p^2)\). We sweep \(\lambda\) and watch the weights shrink smoothly toward zero.
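
For reference, the two priors' negative log-densities are (up to additive constants)

\[-\log p(\theta) = \frac{1}{2\sigma_p^2}\sum_j \theta_j^2 \quad (\text{Gaussian}), \qquad -\log p(\theta) = \frac{1}{b}\sum_j |\theta_j| \quad (\text{Laplace, scale } b),\]

which is exactly where the \(\lambda\|\theta\|_2^2\) and \(\lambda\|\theta\|_1\) penalties come from.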

# Fresh linear-regression dataset · higher-dimensional this time so shrinkage is visible
N, d = 80, 20
X = torch.randn(N, d)
theta_true = torch.zeros(d)
theta_true[:5] = torch.tensor([2.0, -1.5, 1.0, 0.7, -0.4])   # only first 5 features matter
y = X @ theta_true + 0.3 * torch.randn(N)

def fit_map(X, y, lam, prior='gaussian', n_steps=4000, lr=0.01):
    # L2 (Gaussian prior) via plain gradient descent.
    # L1 (Laplace prior) via proximal gradient (ISTA) -- soft-thresholding produces exact zeros.
    N = X.shape[0]
    theta = torch.zeros(X.shape[1], requires_grad=True)
    opt = torch.optim.SGD([theta], lr=lr)
    for _ in range(n_steps):
        mu = X @ theta
        nll = (1.0 / N) * ((y - mu) ** 2).sum()        # mean squared error · same minimizer as the Gaussian NLL for fixed sigma
        if prior == 'gaussian':
            loss = nll + lam * (theta ** 2).sum()
            opt.zero_grad(); loss.backward(); opt.step()
        elif prior == 'laplace':
            opt.zero_grad(); nll.backward(); opt.step()
            with torch.no_grad():
                theta.data = torch.sign(theta.data) * torch.clamp(theta.data.abs() - lr * lam, min=0.0)
        else:
            opt.zero_grad(); nll.backward(); opt.step()
    return theta.detach()

lambdas = [0.0, 0.05, 0.2, 0.5, 1.0]
gaussian_fits = [fit_map(X, y, lam, 'gaussian') for lam in lambdas]
laplace_fits  = [fit_map(X, y, lam, 'laplace')  for lam in lambdas]
fig, axes = plt.subplots(1, 2, figsize=(13, 4.2), sharey=True)

for fit, lam in zip(gaussian_fits, lambdas):
    axes[0].plot(fit.numpy(), 'o-', label=f'$\\lambda$={lam}', alpha=0.7)
axes[0].axhline(0, color='k', lw=0.5)
axes[0].set_title('Gaussian prior · L2 (ridge) · weights shrink smoothly')
axes[0].set_xlabel('feature index'); axes[0].set_ylabel('weight'); axes[0].legend()

for fit, lam in zip(laplace_fits, lambdas):
    axes[1].plot(fit.numpy(), 'o-', label=f'$\\lambda$={lam}', alpha=0.7)
axes[1].axhline(0, color='k', lw=0.5)
axes[1].set_title('Laplace prior · L1 (lasso) · weights snap to zero')
axes[1].set_xlabel('feature index'); axes[1].legend()

plt.tight_layout(); plt.show()

Compare the two panels.

  • L2 (Gaussian prior) — every weight shrinks smoothly toward zero as \(\lambda\) grows, but few are exactly zero.
  • L1 (Laplace prior) — many weights snap to exactly zero as \(\lambda\) grows. Sparsity emerges.

The data was generated such that only the first 5 features matter. L1 correctly identifies this — the “feature selection” property of lasso, derived end-to-end from a Laplace prior on the weights.

6 · Sparsity in numbers

How many weights are exactly zero (within numerical tolerance) under L1 vs L2?

print(f"{'lambda':>8s} | {'L2 zeros':>8s} | {'L1 zeros':>8s}")
print('-' * 32)
for lam, g, l in zip(lambdas, gaussian_fits, laplace_fits):
    n_zero_g = (g.abs() < 1e-3).sum().item()
    n_zero_l = (l.abs() < 1e-3).sum().item()
    print(f"{lam:8.2f} | {n_zero_g:8d} | {n_zero_l:8d}")

L1 drives most of the 20 weights to exactly zero at the largest \(\lambda\), recovering the underlying sparsity (only 5 features were truly active). L2 never hits zero — it just shrinks.

Why L1 zeros out and L2 doesn’t · the gradient of \(\lambda\|\theta\|_1\) is \(\lambda\,\text{sign}(\theta_j)\) — a constant push toward zero. The gradient of \(\lambda\|\theta\|_2^2\) is \(2\lambda\theta_j\), proportional to the weight, so it gets weaker as the weight gets smaller and never quite reaches zero.
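
The same argument in code · the L1 update inside fit_map is soft-thresholding, and the threshold maps small weights to exactly zero. (Standalone illustration, not part of the fit above; soft_threshold is a name introduced here.)

def soft_threshold(theta, t):
    # proximal operator of t * ||theta||_1 · shrink magnitudes by t, clip at zero
    return torch.sign(theta) * torch.clamp(theta.abs() - t, min=0.0)

print(soft_threshold(torch.tensor([0.03, -0.5, 1.2]), t=0.1))
# tensor([ 0.0000, -0.4000,  1.1000]) · anything within the threshold lands exactly on zero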

7 · Visualize · likelihood and posterior in 2D

For a tiny 2-parameter linear regression, plot the log-likelihood surface, the log-prior surface, and the log-posterior surface (= sum). The MAP estimate sits at the peak of the third surface — a compromise between the data and the prior.

# Tiny dataset · 2 parameters (theta_1, theta_2)
N = 12
X_tiny = torch.randn(N, 2)
theta_star = torch.tensor([2.5, 1.0])
y_tiny = X_tiny @ theta_star + 0.4 * torch.randn(N)

# Grid over (theta_1, theta_2)
g = torch.linspace(-1, 4, 80)
T1, T2 = torch.meshgrid(g, g, indexing='ij')
ll, lprior, lpost = torch.zeros_like(T1), torch.zeros_like(T1), torch.zeros_like(T1)

lam = 0.5
for i in range(g.numel()):
    for j in range(g.numel()):
        theta = torch.tensor([T1[i, j], T2[i, j]])
        mu = X_tiny @ theta
        ll_ij = -0.5 * ((y_tiny - mu) ** 2).sum()
        lp_ij = -lam * (theta ** 2).sum()       # Gaussian prior
        ll[i, j] = ll_ij
        lprior[i, j] = lp_ij
        lpost[i, j] = ll_ij + lp_ij

mle_idx = ll.flatten().argmax(); map_idx = lpost.flatten().argmax()
mle = (T1.flatten()[mle_idx].item(), T2.flatten()[mle_idx].item())
map_ = (T1.flatten()[map_idx].item(), T2.flatten()[map_idx].item())

fig, axes = plt.subplots(1, 3, figsize=(15, 4.5))
for ax, surface, title, peak in zip(
    axes,
    [ll, lprior, lpost],
    ['log-likelihood', 'log-prior (Gaussian, $\\lambda$=0.5)', 'log-posterior = sum'],
    [mle, (0, 0), map_]
):
    cs = ax.contourf(T1, T2, surface, levels=20, cmap='viridis')
    ax.contour(T1, T2, surface, levels=10, colors='white', alpha=0.3, linewidths=0.5)
    ax.plot(*peak, 'r*', markersize=18, markeredgecolor='white', markeredgewidth=1.5)
    ax.set_title(title); ax.set_xlabel('$\\theta_1$'); ax.set_ylabel('$\\theta_2$')

print(f"MLE: theta = ({mle[0]:.2f}, {mle[1]:.2f})")
print(f"MAP: theta = ({map_[0]:.2f}, {map_[1]:.2f})  ← pulled toward zero by the prior")
plt.tight_layout(); plt.show()

The MAP estimate sits between the MLE peak and the prior peak (origin). The stronger you make the prior (larger \(\lambda\)), the more MAP shifts toward zero — and in the limit \(\lambda \to \infty\), MAP collapses to the prior, ignoring the data entirely. In the limit \(\lambda \to 0\), MAP becomes MLE.
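
As a cross-check (a sketch, assuming the same objective as the grid, \(\tfrac{1}{2}\|y - X\theta\|^2 + \lambda\|\theta\|_2^2\)), the ridge MAP has a closed form that interpolates between OLS (\(\lambda \to 0\)) and the origin (\(\lambda \to \infty\)):

# setting the gradient to zero gives (X^T X + 2*lam*I) theta = X^T y
lam = 0.5
A = X_tiny.T @ X_tiny + 2 * lam * torch.eye(2)
theta_map_closed = torch.linalg.solve(A, X_tiny.T @ y_tiny)
theta_mle_closed = torch.linalg.lstsq(X_tiny, y_tiny).solution
print(f"closed-form MLE: {theta_mle_closed.tolist()}")
print(f"closed-form MAP: {theta_map_closed.tolist()}    (compare with the grid peaks above)")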

8 · Try this

Each of these takes 2-3 lines to change. Predict the answer first, then run.

  1. MLE blows up — set the coin data to all heads (data = torch.ones(10)). Run section 1. Where does \(\hat p\) go? Why is this a problem?
  2. Tiny prior fixes it — re-run with a Beta(2, 2) prior in MAP form: add + lam * (-torch.log(p) - torch.log(1 - p)) to the loss (with lam = 1 this is exactly the Beta(2, 2) negative log-prior, up to a constant). Does \(\hat p\) stay sane?
  3. Heteroscedastic regression — replace sigma_true with a per-example value (e.g. 0.1 + 0.5 * torch.abs(X[:, 0])). Does the MLE objective in section 2 still match OLS? (Hint · no — different distribution = different loss. This motivates weighted least squares.)
  4. Tiny data, big λ — in section 5, use only N=10 examples. How does L1 behave at large \(\lambda\)? Does it still find the true 5 features?
  5. Sweep the prior — in section 7, re-run the visualization with \(\lambda = 0.05\) and \(\lambda = 5.0\). Watch the MAP point slide between the MLE peak and the origin.

Same recipe, every time · pick the conditional distribution, write the NLL, optionally add the log-prior. The deep models in this course (VAEs, diffusion, RLHF) are all instances of this pattern with fancier distributions.