24  Lab 08 — Verify Proposition 3.4

Anchor chapter: Chapter 8 — The Forward Pass as Harmonic Extension.

Goal. Show numerically that the forward pass of a small ReLU network equals the harmonic extension on the neural sheaf built in Lab 07.

Build the restricted coboundary \(\delta_\Omega(\sigma)\) for a small MLP (e.g., a [2, 8, 1] regressor) as a dense matrix. The identity being tested does not depend on the weight values, so the randomly initialised network used below is enough. Verify Lem. 8.2 numerically (\(\det \delta_\Omega = 1\) up to floating-point noise), solve the pinned Dirichlet problem with a single triangular (forward) substitution, and compare the solution component by component against forward(x) from PyTorch. Repeat across 100 random inputs to confirm that Prop. 8.6 holds regardless of which activation pattern is realised.
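The procedure above can be sketched in NumPy alone by substituting a hand-written forward pass for PyTorch's forward(x). That substitution is an assumption made so the check runs where torch is unavailable. The weights are hypothetical random stand-ins for the lab's network, and the triangular solve is written as an explicit forward-substitution loop so the unit diagonal is visible: no division is ever needed.

```python
import numpy as np

def relu_forward(x, W1, b1, W2, b2):
    """Reference ReLU forward pass in NumPy (stand-in for PyTorch's forward(x))."""
    z1 = W1 @ x + b1
    z2 = W2 @ np.maximum(z1, 0.0) + b2
    return z1, z2

def harmonic_extension(x, W1, b1, W2, b2):
    """Build delta_Omega and solve delta_Omega c = rhs by forward substitution."""
    n1, n2 = W1.shape[0], W2.shape[0]
    z1 = W1 @ x + b1                       # used only to read off the pattern sigma
    sig1 = np.diag((z1 > 0).astype(float))
    N = n1 + n2
    delta = np.eye(N)                      # identity diagonal blocks
    delta[n1:, :n1] = -W2 @ sig1           # off-diagonal block -W2 Sigma1
    rhs = np.concatenate([W1 @ x + b1, b2])
    c = np.zeros(N)
    for i in range(N):                     # unit diagonal, so no division is needed
        c[i] = rhs[i] - delta[i, :i] @ c[:i]
    return delta, c[:n1], c[n1:]

# Hypothetical random weights standing in for the lab's [2, 8, 1] network
rng = np.random.default_rng(1)
W1, b1 = rng.standard_normal((8, 2)), rng.standard_normal(8)
W2, b2 = rng.standard_normal((1, 8)), rng.standard_normal(1)

max_err, dets = 0.0, []
for _ in range(100):
    x = rng.standard_normal(2)
    delta, c1, c2 = harmonic_extension(x, W1, b1, W2, b2)
    z1, z2 = relu_forward(x, W1, b1, W2, b2)
    max_err = max(max_err, np.abs(c1 - z1).max(), np.abs(c2 - z2).max())
    dets.append(np.linalg.det(delta))

print(f"max |harmonic - forward| = {max_err:.2e}")
print(f"det(delta) range = [{min(dets):.12f}, {max(dets):.12f}]")
```

Because everything runs in float64 here, the two computations agree to round-off; the lab's comparison against PyTorch is looser only because PyTorch defaults to float32.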

This lab imports torch, and PyTorch is not yet available in Pyodide, so the lab cannot run in the browser. Run it locally with the Plan B notebook below.

Prefer a local Jupyter environment? Download lab-08-verify-prop-3-4.ipynb

Install dependencies: pip install torch numpy scipy matplotlib

24.1 Setup

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from scipy.linalg import solve_triangular

rng = np.random.default_rng(0)
torch.manual_seed(0)
plt.rcParams['figure.figsize'] = (6, 4)

24.2 1. Build the object

We initialise a [2, 8, 1] ReLU MLP, extract its weight matrices as NumPy arrays, and, for any input \(x\), build the restricted coboundary \(\delta_\Omega(\sigma)\) as a dense \((n_1 + n_2) \times (n_1 + n_2)\) lower-triangular matrix (here \(9 \times 9\)). The key identity is \((\delta_\Omega)_{\ell\ell} = I_{n_\ell}\): each diagonal block is an identity matrix, so \(\delta_\Omega\) is unitriangular. Its determinant is the product of its diagonal entries, so \(\det \delta_\Omega = 1\) and \(\delta_\Omega\) is invertible regardless of the activation pattern \(\sigma\).
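Since \(\sigma\) only enters \(\delta_\Omega\) through the off-diagonal block \(-W^{(2)}\Sigma^{(1)}\), Lemma 8.2 can be checked exhaustively for one hidden layer of width 8 by enumerating all \(2^8 = 256\) diagonal patterns \(\Sigma^{(1)}\). The sketch below uses a hypothetical random \(W^{(2)}\) and also checks the closed-form inverse that the block structure gives:

\[
\begin{bmatrix} I & 0 \\ -W^{(2)}\Sigma^{(1)} & I \end{bmatrix}^{-1}
= \begin{bmatrix} I & 0 \\ W^{(2)}\Sigma^{(1)} & I \end{bmatrix}.
\]

```python
import itertools
import numpy as np

rng = np.random.default_rng(2)
n1, n2 = 8, 1
W2 = rng.standard_normal((n2, n1))   # hypothetical weights; any values work

max_det_dev, max_inv_err = 0.0, 0.0
for pattern in itertools.product([0.0, 1.0], repeat=n1):   # all 2^8 activation patterns
    A = W2 @ np.diag(pattern)              # W2 Sigma1 for this pattern
    delta = np.eye(n1 + n2)
    delta[n1:, :n1] = -A                   # only the off-diagonal block depends on sigma
    inv_closed = np.eye(n1 + n2)           # closed-form inverse [[I, 0], [A, I]]
    inv_closed[n1:, :n1] = A
    max_det_dev = max(max_det_dev, abs(np.linalg.det(delta) - 1.0))
    max_inv_err = max(max_inv_err, np.abs(delta @ inv_closed - np.eye(n1 + n2)).max())

print(f"patterns checked: {2**n1}")
print(f"max |det - 1|: {max_det_dev:.2e}")
print(f"max |delta @ inv - I|: {max_inv_err:.2e}")
```

The closed-form inverse also shows why the solve reproduces the forward pass: applying it to the right-hand side adds \(W^{(2)}\Sigma^{(1)} z^{(1)} = W^{(2)}\,\mathrm{ReLU}(z^{(1)})\) to the bias term of the second edge.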

# ── Network and weight extraction ────────────────────────────────────────
dims = [2, 8, 1]   # [n0, n1, n2]
net  = nn.Sequential(nn.Linear(dims[0], dims[1]), nn.ReLU(),
                     nn.Linear(dims[1], dims[2]))

W1 = net[0].weight.detach().numpy()   # (8, 2)
b1 = net[0].bias.detach().numpy()     # (8,)
W2 = net[2].weight.detach().numpy()   # (1, 8)
b2 = net[2].bias.detach().numpy()     # (1,)


def build_delta_omega(x, W1, b1, W2, b2):
    """
    Build δ_Ω(σ) for input x.  Returns (delta, rhs) where
      delta : (n1+n2, n1+n2) lower-triangular, det=1
      rhs   : (n1+n2,) boundary forcing   (so harmonic ext. solves delta @ c = rhs)
    """
    n0, n1, n2 = len(x), W1.shape[0], W2.shape[0]
    z1   = W1 @ x + b1                          # pre-activation layer 1
    sig1 = np.diag((z1 > 0).astype(float))      # Σ^(1),  (n1, n1)

    N = n1 + n2
    delta = np.zeros((N, N))
    # Edge e1 rows 0..n1-1: head = z^(1), tail = z^(0)=x (pinned)
    delta[:n1, :n1] = np.eye(n1)                # F_{1 <= e1} = I
    # Edge e2 rows n1..N-1: head = z^(2), tail = z^(1)
    delta[n1:, :n1] = -W2 @ sig1               # -F_{1 <= e2} = -W2 Σ1
    delta[n1:, n1:] = np.eye(n2)               #  F_{2 <= e2} = I

    # Boundary RHS: δ_Ω c_Ω = rhs  ⟺  c_Ω = harmonic extension
    rhs = np.zeros(N)
    rhs[:n1] = W1 @ x + b1                     # from pinned z^(0)=x via edge e1
    rhs[n1:] = b2                               # bias on edge e2

    return delta, rhs


x_test = np.array([0.5, -1.2])
delta, rhs = build_delta_omega(x_test, W1, b1, W2, b2)

print("δ_Ω shape:", delta.shape)
print(f"det(δ_Ω) = {np.linalg.det(delta):.8f}  (Lemma 8.2: should be 1)")
print(f"Is lower-triangular: {np.allclose(delta, np.tril(delta))}")

24.3 2. Verify a theorem / run an experiment

We solve for the harmonic extension with a single triangular solve; because \(\delta_\Omega\) is lower-triangular, this is forward substitution. The components of the resulting cochain are the hidden-layer pre-activations \(z^{(1)}\) and the output \(z^{(2)}\). They should match the corresponding PyTorch values (net[0](x) for \(z^{(1)}\), net(x) for \(z^{(2)}\)) to within floating-point precision. We then repeat the comparison across 100 random inputs and plot the maximum absolute error per input, checking Proposition 3.4 on every activation pattern those inputs realise.
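The tolerance in this comparison is set by PyTorch's default float32 arithmetic, not by the solve: torch.tensor(..., dtype=torch.float32) rounds x to single precision, while build_delta_omega promotes the float32 weights to float64 because x is float64. The sketch below emulates both precisions in NumPy, with hypothetical random float32 weights standing in for net, to separate the two effects. It also passes unit_diagonal=True to scipy's solve_triangular, which is valid here because every diagonal entry of \(\delta_\Omega\) is 1.

```python
import numpy as np
from scipy.linalg import solve_triangular

# Hypothetical float32 weights, mimicking what .detach().numpy() returns for nn.Linear
rng = np.random.default_rng(3)
W1 = rng.standard_normal((8, 2)).astype(np.float32)
b1 = rng.standard_normal(8).astype(np.float32)
W2 = rng.standard_normal((1, 8)).astype(np.float32)
b2 = rng.standard_normal(1).astype(np.float32)

def forward(x, dtype):
    """ReLU forward pass evaluated entirely in `dtype` (float32 mimics PyTorch's default)."""
    x = x.astype(dtype)
    z1 = W1.astype(dtype) @ x + b1.astype(dtype)
    return W2.astype(dtype) @ np.maximum(z1, 0) + b2.astype(dtype)

def harmonic_z2(x):
    """Output component of the harmonic extension, computed in float64."""
    z1 = W1 @ x + b1                     # float32 weights are promoted by the float64 x
    n1 = z1.size
    delta = np.eye(n1 + 1)
    delta[n1:, :n1] = -W2 * (z1 > 0)     # -W2 Sigma1 via broadcasting over columns
    rhs = np.concatenate([z1, b2])
    c = solve_triangular(delta, rhs, lower=True, unit_diagonal=True)
    return c[n1:]

gap32, gap64 = [], []
for _ in range(100):
    x = rng.standard_normal(2)
    h = harmonic_z2(x)
    gap32.append(np.abs(h - forward(x, np.float32)).max())
    gap64.append(np.abs(h - forward(x, np.float64)).max())

print(f"max gap vs float32 forward: {max(gap32):.1e}")
print(f"max gap vs float64 forward: {max(gap64):.1e}")
```

The float64 gap should sit near double-precision round-off, while the float32 gap reflects single-precision rounding; this is why the lab's thresholds (1e-6 in the plot, 1e-5 in the final check) are far above float64 round-off.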

# ── Forward substitution (δ_Ω is lower-triangular) ───────────────────────
c_harm = solve_triangular(delta, rhs, lower=True)
z1_harm = c_harm[:dims[1]]
z2_harm = c_harm[dims[1]:]

# Compare to PyTorch forward pass
with torch.no_grad():
    x_t   = torch.tensor(x_test, dtype=torch.float32)
    z2_pt = net(x_t).numpy()
    # Internal z1: call the first Linear layer directly (no hook needed)
    z1_pt = net[0](x_t).numpy()   # pre-activation of layer 1

print("Pre-activations z^(1):")
print(f"  Harmonic ext.: {np.round(z1_harm, 4)}")
print(f"  PyTorch:       {np.round(z1_pt, 4)}")
print(f"  Max |error|:   {abs(z1_harm - z1_pt).max():.2e}")

print(f"\nOutput z^(2):")
print(f"  Harmonic ext.: {z2_harm}")
print(f"  PyTorch:       {z2_pt.flatten()}")
print(f"  Max |error|:   {abs(z2_harm - z2_pt.flatten()).max():.2e}")

# ── Sweep 100 random inputs ──────────────────────────────────────────────
errors_z1, errors_z2 = [], []
dets = []
for _ in range(100):
    xr  = rng.standard_normal(dims[0])
    dr, rr = build_delta_omega(xr, W1, b1, W2, b2)
    dets.append(np.linalg.det(dr))
    cr   = solve_triangular(dr, rr, lower=True)
    with torch.no_grad():
        xrt   = torch.tensor(xr, dtype=torch.float32)
        z1_pt = net[0](xrt).numpy()
        z2_pt = net(xrt).numpy()
    errors_z1.append(abs(cr[:dims[1]] - z1_pt).max())
    errors_z2.append(abs(cr[dims[1]:] - z2_pt.flatten()).max())

fig, axes = plt.subplots(1, 2, figsize=(11, 4))
axes[0].semilogy(errors_z1, 'o', ms=4, label='z^(1) error')
axes[0].semilogy(errors_z2, 's', ms=4, label='z^(2) error')
axes[0].set_xlabel('input index'); axes[0].set_ylabel('max |error|')
axes[0].set_title('Harmonic ext. vs forward pass (100 inputs)')
axes[0].legend(); axes[0].axhline(1e-6, color='gray', linestyle='--', alpha=0.5)

axes[1].hist(dets, bins=10, color='steelblue', edgecolor='white')
axes[1].axvline(1.0, color='red', linestyle='--', label='det=1')
axes[1].set_xlabel('det(δ_Ω)'); axes[1].set_ylabel('count')
axes[1].set_title('det(δ_Ω) across 100 inputs\n(Lemma 8.2)')
axes[1].legend()
plt.tight_layout(); plt.show()

print(f"\nMax z^(1) error: {max(errors_z1):.2e}")
print(f"Max z^(2) error: {max(errors_z2):.2e}")
print(f"Prop. 3.4 confirmed: {max(max(errors_z1), max(errors_z2)) < 1e-5}")
print(f"det(δ_Ω) range: [{min(dets):.8f}, {max(dets):.8f}]")
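Sampling 100 inputs only exercises the activation patterns those inputs fall into; the unitriangular argument is what covers the rest. For \(n_0 = 2\) the set of reachable patterns is small: the 8 hyperplanes \(W^{(1)}x + b^{(1)} = 0\) are lines in the plane, and 8 lines cut the plane into at most \(1 + 8 + \binom{8}{2} = 37\) regions, so at most 37 of the \(2^8 = 256\) patterns can ever occur. The sketch below counts how many distinct patterns a sweep like the one above hits, using hypothetical random W1, b1 in place of the network's weights.

```python
import numpy as np

rng = np.random.default_rng(4)
W1 = rng.standard_normal((8, 2))   # hypothetical stand-ins for the lab's W1, b1
b1 = rng.standard_normal(8)

X = rng.standard_normal((100, 2))              # 100 random inputs, one per row
patterns = (X @ W1.T + b1 > 0)                 # (100, 8) boolean activation patterns
distinct = {tuple(row) for row in patterns.astype(int)}

print(f"distinct patterns realised: {len(distinct)} of {2**8} possible")
```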

24.4 Exercises

  1. Deeper network. Extend the construction to a [2, 8, 4, 1] network (two hidden layers). Build the \(13 \times 13\) lower-triangular \(\delta_\Omega\), verify \(\det = 1\), solve by triangular substitution, and confirm the match to PyTorch across 50 random inputs.

  2. Condition number. Compute \(\kappa(\delta_\Omega) = \|\delta_\Omega\| \cdot \|\delta_\Omega^{-1}\|\) as a function of weight magnitude \(\alpha\) (scale all weights by \(\alpha \in \{0.1, 0.5, 1, 2, 5\}\)). Plot \(\kappa\) vs \(\alpha\) and explain the dependence.

  3. Verification via Dirichlet energy. The harmonic extension minimises \(\|\delta_\sigma c\|^2\) subject to \(c^{(0)} = x\). Verify this numerically by computing \(\|\delta_\sigma c\|^2\) for the harmonic extension and for 20 random perturbations around it, confirming the harmonic extension achieves the minimum.

  4. Bias absorption. The bias vector \(b^{(\ell)}\) appears in the right-hand side \(\tilde{b}\) (rhs in build_delta_omega) rather than in \(\delta_\Omega\) itself. Reformulate the problem by absorbing the bias into an augmented cochain \(\tilde{c}^{(\ell)} = (c^{(\ell)}, 1)\) and modified weight matrices \(\tilde{W}^{(\ell)} = [W^{(\ell)}, b^{(\ell)}]\). What does \(\delta_\Omega\) look like in this augmented formulation?