24 Lab 08 — Verify Proposition 3.4
Anchor chapter: Chapter 8 — The Forward Pass as Harmonic Extension.
Goal. Show numerically that the forward pass of a small ReLU network equals the harmonic extension on the neural sheaf built in Lab 07.
Build the restricted coboundary \(\delta_\Omega(\sigma)\) for a small MLP (e.g., a [2, 8, 1] regressor) as a dense matrix. The code below uses randomly initialised weights, but a pretrained network works the same way. Verify Lem. 8.2 numerically (\(\det \delta_\Omega = 1\) up to floating-point noise), solve the pinned Dirichlet problem with a single triangular solve (forward substitution, since \(\delta_\Omega\) is lower-triangular), and compare the solution component by component against forward(x) from PyTorch. Repeat across 100 random inputs to check that Prop. 8.6 holds whichever activation pattern is realised.
This lab imports torch; PyTorch is not yet available in Pyodide. Run locally with the Plan B notebook below.
Prefer a local Jupyter environment? Download lab-08-verify-prop-3-4.ipynb
Install dependencies: pip install torch numpy scipy matplotlib
24.1 Setup
```python
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from scipy.linalg import solve_triangular

rng = np.random.default_rng(0)
torch.manual_seed(0)
plt.rcParams['figure.figsize'] = (6, 4)
```

24.2 1. Build the object
We initialise a [2, 8, 1] ReLU MLP, extract its weight matrices as NumPy arrays, and, for any input \(x\), build the restricted coboundary \(\delta_\Omega(\sigma)\) as a dense \((n_1 + n_2) \times (n_1 + n_2)\) lower-triangular matrix. The key identity is \((\delta_\Omega)_{\ell\ell} = I_{n_\ell}\) for \(\ell = 1, 2\): every diagonal block is an identity matrix, so \(\delta_\Omega\) is unit lower-triangular and therefore invertible with \(\det \delta_\Omega = 1\), whatever the activation pattern \(\sigma\). The pattern enters only through the off-diagonal block \(-W^{(2)}\Sigma^{(1)}\).
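In block form, the pinned system and its forward substitution read as follows. This is a short worked derivation that uses the same symbols as build_delta_omega below, with \(\Sigma^{(1)} = \operatorname{diag}(\mathbf{1}[z^{(1)} > 0])\) and \(\tilde b\) the right-hand side (rhs in the code):

```latex
\begin{aligned}
\delta_\Omega(\sigma) &=
  \begin{pmatrix} I_{n_1} & 0 \\ -W^{(2)}\Sigma^{(1)} & I_{n_2} \end{pmatrix},
\qquad
\tilde b = \begin{pmatrix} W^{(1)}x + b^{(1)} \\ b^{(2)} \end{pmatrix},
\qquad
\det \delta_\Omega(\sigma) = \det I_{n_1}\,\det I_{n_2} = 1, \\[6pt]
c^{(1)} &= W^{(1)}x + b^{(1)} = z^{(1)}, \\
c^{(2)} &= b^{(2)} + W^{(2)}\Sigma^{(1)} c^{(1)}
        = W^{(2)}\,\mathrm{ReLU}\bigl(z^{(1)}\bigr) + b^{(2)} = z^{(2)}.
\end{aligned}
```

The last step uses \(\Sigma^{(1)} z^{(1)} = \mathrm{ReLU}(z^{(1)})\). This holds exactly because \(\Sigma^{(1)}\) zeroes precisely the non-positive entries of \(z^{(1)}\), so solving the triangular system reproduces the forward pass layer by layer.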
```python
# ── Network and weight extraction ────────────────────────────────────────
dims = [2, 8, 1]  # [n0, n1, n2]
net = nn.Sequential(nn.Linear(dims[0], dims[1]), nn.ReLU(),
                    nn.Linear(dims[1], dims[2]))
W1 = net[0].weight.detach().numpy()  # (8, 2)
b1 = net[0].bias.detach().numpy()    # (8,)
W2 = net[2].weight.detach().numpy()  # (1, 8)
b2 = net[2].bias.detach().numpy()    # (1,)


def build_delta_omega(x, W1, b1, W2, b2):
    """
    Build δ_Ω(σ) for input x. Returns (delta, rhs) where
      delta : (n1+n2, n1+n2) unit lower-triangular, det = 1
      rhs   : (n1+n2,) boundary forcing (the harmonic extension solves delta @ c = rhs)
    """
    n1, n2 = W1.shape[0], W2.shape[0]
    z1 = W1 @ x + b1                          # pre-activation of layer 1
    sig1 = np.diag((z1 > 0).astype(float))    # Σ^(1), (n1, n1)
    N = n1 + n2
    delta = np.zeros((N, N))
    # Edge e1, rows 0..n1-1: head = z^(1), tail = z^(0) = x (pinned)
    delta[:n1, :n1] = np.eye(n1)              # F_{1 <= e1} = I
    # Edge e2, rows n1..N-1: head = z^(2), tail = z^(1)
    delta[n1:, :n1] = -W2 @ sig1              # -F_{1 <= e2} = -W2 Σ1
    delta[n1:, n1:] = np.eye(n2)              # F_{2 <= e2} = I
    # Boundary RHS: δ_Ω c_Ω = rhs  ⟺  c_Ω = harmonic extension
    rhs = np.zeros(N)
    rhs[:n1] = W1 @ x + b1                    # from pinned z^(0) = x via edge e1
    rhs[n1:] = b2                             # bias on edge e2
    return delta, rhs


x_test = np.array([0.5, -1.2])
delta, rhs = build_delta_omega(x_test, W1, b1, W2, b2)
print("δ_Ω shape:", delta.shape)
print(f"det(δ_Ω) = {np.linalg.det(delta):.8f}  (Lemma 8.2: should be 1)")
print(f"Is lower-triangular: {np.allclose(delta, np.tril(delta))}")
```

24.3 2. Verify a theorem / run an experiment
We solve for the harmonic extension with a single triangular solve: since \(\delta_\Omega\) is lower-triangular, this is forward substitution (scipy.linalg.solve_triangular with lower=True). The components of the resulting cochain are the hidden-layer pre-activations \(z^{(1)}\) and the output pre-activation \(z^{(2)}\); they should match PyTorch's net(x) and the first-layer output net[0](x). PyTorch computes in float32 while the NumPy solve runs in float64, so expect agreement to float32 round-off rather than exact equality. We then repeat the comparison across 100 random inputs and plot the maximum absolute error per input, checking that Proposition 3.4 holds for every activation pattern those inputs realise.
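To see what the triangular solve does, here is a minimal forward-substitution loop, checked against scipy.linalg.solve_triangular on a small toy system with the same block shape as \(\delta_\Omega\) (identity diagonal, one off-diagonal block \(-W^{(2)}\Sigma^{(1)}\)). The toy sizes, the fixed activation pattern sig1 and the helper name forward_substitution are illustrative assumptions, not part of the lab.

```python
import numpy as np
from scipy.linalg import solve_triangular


def forward_substitution(L, b):
    """Solve L @ c = b for lower-triangular L, one row at a time.

    Row i involves only c[0..i]; c[0..i-1] are already known, so
    c[i] = (b[i] - L[i, :i] @ c[:i]) / L[i, i].
    """
    n = L.shape[0]
    c = np.zeros(n)
    for i in range(n):
        c[i] = (b[i] - L[i, :i] @ c[:i]) / L[i, i]
    return c


# Toy system shaped like delta_Omega for a hypothetical [2, 3, 1] network.
rng = np.random.default_rng(1)
n1, n2 = 3, 1
W2 = rng.standard_normal((n2, n1))
sig1 = np.diag([1.0, 0.0, 1.0])        # assumed activation pattern Σ^(1)
delta = np.eye(n1 + n2)                # identity diagonal blocks
delta[n1:, :n1] = -W2 @ sig1           # off-diagonal block -W2 Σ1
rhs = rng.standard_normal(n1 + n2)     # stands in for (W1 x + b1, b2)

c_loop = forward_substitution(delta, rhs)
c_scipy = solve_triangular(delta, rhs, lower=True, unit_diagonal=True)
print(np.allclose(c_loop, c_scipy))    # True
```

Because the diagonal is all ones, solve_triangular's unit_diagonal=True flag gives the same answer without reading the diagonal. The first \(n_1\) entries are copied straight from the right-hand side, and the last \(n_2\) entries are one matrix-vector product away from them, which is exactly the layer-by-layer forward pass.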
```python
# ── Forward substitution (exploit lower-triangular structure) ────────────
c_harm = solve_triangular(delta, rhs, lower=True)
z1_harm = c_harm[:dims[1]]
z2_harm = c_harm[dims[1]:]

# Compare to the PyTorch forward pass
with torch.no_grad():
    x_t = torch.tensor(x_test, dtype=torch.float32)
    z2_pt = net(x_t).numpy()
    # Internal z1: call the first Linear layer directly (pre-activation of layer 1)
    z1_pt = net[0](x_t).numpy()

print("Pre-activations z^(1):")
print(f"  Harmonic ext.: {np.round(z1_harm, 4)}")
print(f"  PyTorch:       {np.round(z1_pt, 4)}")
print(f"  Max |error|:   {np.abs(z1_harm - z1_pt).max():.2e}")
print("\nOutput z^(2):")
print(f"  Harmonic ext.: {z2_harm}")
print(f"  PyTorch:       {z2_pt.flatten()}")
print(f"  Max |error|:   {np.abs(z2_harm - z2_pt.flatten()).max():.2e}")
```
```python
# ── Sweep 100 random inputs ──────────────────────────────────────────────
errors_z1, errors_z2 = [], []
dets = []
for _ in range(100):
    xr = rng.standard_normal(dims[0])
    dr, rr = build_delta_omega(xr, W1, b1, W2, b2)
    dets.append(np.linalg.det(dr))
    cr = solve_triangular(dr, rr, lower=True)
    with torch.no_grad():
        xrt = torch.tensor(xr, dtype=torch.float32)
        z1_pt = net[0](xrt).numpy()
        z2_pt = net(xrt).numpy()
    errors_z1.append(np.abs(cr[:dims[1]] - z1_pt).max())
    errors_z2.append(np.abs(cr[dims[1]:] - z2_pt.flatten()).max())

fig, axes = plt.subplots(1, 2, figsize=(11, 4))
axes[0].semilogy(errors_z1, 'o', ms=4, label='z^(1) error')
axes[0].semilogy(errors_z2, 's', ms=4, label='z^(2) error')
axes[0].set_xlabel('input index'); axes[0].set_ylabel('max |error|')
axes[0].set_title('Harmonic ext. vs forward pass (100 inputs)')
axes[0].axhline(1e-6, color='gray', linestyle='--', alpha=0.5)
axes[0].legend()
axes[1].hist(dets, bins=10, color='steelblue', edgecolor='white')
axes[1].axvline(1.0, color='red', linestyle='--', label='det=1')
axes[1].set_xlabel('det(δ_Ω)'); axes[1].set_ylabel('count')
axes[1].set_title('det(δ_Ω) across 100 inputs\n(Lemma 8.2)')
axes[1].legend()
plt.tight_layout(); plt.show()

print(f"\nMax z^(1) error: {max(errors_z1):.2e}")
print(f"Max z^(2) error: {max(errors_z2):.2e}")
print(f"Prop. 3.4 confirmed on all 100 inputs: {max(errors_z1 + errors_z2) < 1e-5}")
print(f"det(δ_Ω) range: [{min(dets):.8f}, {max(dets):.8f}]")
```

24.4 Exercises
Deeper network. Extend the construction to a [2, 8, 4, 1] network (two hidden layers). Build the \(13 \times 13\) lower-triangular \(\delta_\Omega\), verify \(\det = 1\), solve by triangular substitution, and confirm the match to PyTorch across 50 random inputs.
Condition number. Compute \(\kappa(\delta_\Omega) = \|\delta_\Omega\| \cdot \|\delta_\Omega^{-1}\|\) as a function of weight magnitude \(\alpha\) (scale all weights by \(\alpha \in \{0.1, 0.5, 1, 2, 5\}\)). Plot \(\kappa\) vs \(\alpha\) and explain the dependence.
Verification via Dirichlet energy. The harmonic extension minimises \(\|\delta_\sigma c\|^2\) subject to \(c^{(0)} = x\). Verify this numerically by computing \(\|\delta_\sigma c\|^2\) for the harmonic extension and for 20 random perturbations around it, confirming the harmonic extension achieves the minimum.
Bias absorption. The bias vector \(b^{(\ell)}\) appears in \(\tilde{b}\) (the RHS) rather than in \(\delta_\Omega\) itself. Reformulate the problem by absorbing the bias into an augmented cochain \(\tilde{c}^{(\ell)} = (c^{(\ell)}, 1)\) and modified weight matrices \(\tilde{W}^{(\ell)} = [W^{(\ell)}, b^{(\ell)}]\). What does \(\delta_\Omega\) look like in this augmented formulation?
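For the first exercise, here is one possible sketch of the general-depth construction. It assumes the same block and sign conventions as build_delta_omega (identity diagonal blocks, \(-W^{(k+1)}\Sigma^{(k)}\) just below the diagonal, biases in the right-hand side) and reads each activation pattern \(\Sigma^{(k)}\) off a plain NumPy forward pass. The helper names build_delta_omega_deep and forward_preacts are hypothetical. Random NumPy weights stand in for a PyTorch model so the sketch runs without torch.

```python
import numpy as np
from scipy.linalg import solve_triangular


def forward_preacts(x, Ws, bs):
    """Plain NumPy ReLU forward pass; returns all pre-activations z^(1..L), stacked."""
    zs, h = [], x
    for W, b in zip(Ws, bs):
        z = W @ h + b
        zs.append(z)
        h = np.maximum(z, 0.0)
    return np.concatenate(zs)


def build_delta_omega_deep(x, Ws, bs):
    """Hypothetical generalisation of build_delta_omega to L layers.

    Ws[k], bs[k] belong to layer k+1. Unknowns z^(1), ..., z^(L) are stacked in order.
    Returns (delta, rhs) with delta unit lower-triangular.
    """
    sizes = [W.shape[0] for W in Ws]
    offsets = np.cumsum([0] + sizes)          # start row/column of each block
    N = int(offsets[-1])
    delta = np.eye(N)                         # identity diagonal blocks
    rhs = np.zeros(N)
    z = Ws[0] @ x + bs[0]                     # z^(1), fed by the pinned input
    rhs[:sizes[0]] = z
    for k in range(1, len(Ws)):
        sig = np.diag((z > 0).astype(float))  # Σ^(k) from the realised pattern
        r0, r1 = offsets[k], offsets[k + 1]   # rows of z^(k+1)
        c0, c1 = offsets[k - 1], offsets[k]   # columns of z^(k)
        delta[r0:r1, c0:c1] = -Ws[k] @ sig
        rhs[r0:r1] = bs[k]
        z = Ws[k] @ np.maximum(z, 0.0) + bs[k]  # next pre-activation, for the next Σ
    return delta, rhs


rng = np.random.default_rng(0)
dims = [2, 8, 4, 1]
Ws = [rng.standard_normal((dims[k + 1], dims[k])) for k in range(len(dims) - 1)]
bs = [rng.standard_normal(dims[k + 1]) for k in range(len(dims) - 1)]

max_err, dets = 0.0, []
for _ in range(50):
    x = rng.standard_normal(dims[0])
    delta, rhs = build_delta_omega_deep(x, Ws, bs)
    dets.append(np.linalg.det(delta))
    c = solve_triangular(delta, rhs, lower=True)
    max_err = max(max_err, np.abs(c - forward_preacts(x, Ws, bs)).max())

print(delta.shape, max_err, min(dets), max(dets))
```

Replacing the random Ws and bs with the detached weights of a [2, 8, 4, 1] nn.Sequential gives the PyTorch comparison the exercise asks for.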