"""
Capacity Analysis for Disentangled Transformer on Erdős–Rényi graphs

This script performs multiple training runs with different edge probabilities (p) to analyze
how the proportion of beyond-capacity graphs affects model performance. For each p value,
we train a model and evaluate it on:
1. In-distribution Erdős-Rényi graphs
2. Two-chain graphs (out-of-distribution)

We analyze:
- Accuracy on ER graphs (in-distribution)
- Accuracy on two-chain graphs (out-of-distribution)
- Final energy share between A and B channels
- Proportion of within/beyond capacity graphs
- Loss dynamics and convergence

Outputs:
- Model checkpoints: ./analysis_capacity/p_{p_value:.3f}/
- Analysis plots: ./analysis_capacity/analysis/
- Metrics CSV: ./analysis_capacity/capacity_metrics.csv
"""

import argparse
import csv
import math
import os
import pickle
from collections import deque
from dataclasses import dataclass
from typing import List, Tuple, Dict, Any, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm
from tqdm import tqdm


# ----------------------------
# Utilities: graph generation
# ----------------------------


def erdos_renyi_augmented(n: int, p: float, rng: np.random.Generator) -> np.ndarray:
    """Sample an undirected G(n, p) graph and return its adjacency with self-loops.

    Only the strict upper triangle of the random draw decides which edges
    exist; it is mirrored to make the matrix symmetric and the diagonal is
    then set to one (A := A + I). The result is a float32 (n, n) array.
    """
    # Draw the full n×n uniform block so the generator advances by n*n samples.
    edge_mask = rng.random((n, n)) < p
    upper = np.triu(edge_mask, k=1).astype(np.float32)
    adjacency = upper + upper.T
    # The strict-upper mask leaves the diagonal empty, so writing ones there
    # is the same as adding the identity.
    np.fill_diagonal(adjacency, 1.0)
    return adjacency


def compute_graph_diameter(A: np.ndarray) -> int:
    """
    Compute the diameter of a graph (longest shortest path between any two vertices).

    For disconnected graphs, returns the diameter of the largest connected
    component (the first one found when several components tie in size).

    Args:
        A: Square adjacency matrix; an entry > 0 off the diagonal is an edge.
           Self-loops on the diagonal are ignored by the helpers used here.

    Returns:
        The diameter as an int; 0 when the relevant component is a single vertex.
    """
    comps = connected_components_from_augmented(A)

    # Connected (or empty) graph: measure it directly.
    if len(comps) <= 1:
        return _compute_diameter_bfs(A)

    largest_comp = max(comps, key=len)
    if len(largest_comp) == 1:
        return 0  # Only isolated vertices: diameter 0

    # Slice out the sub-adjacency of the largest component in one indexing
    # step instead of copying it entry by entry.
    members = np.asarray(largest_comp, dtype=np.intp)
    return _compute_diameter_bfs(A[np.ix_(members, members)])


def _compute_diameter_bfs(A: np.ndarray) -> int:
    """
    Compute diameter using BFS from each vertex.

    A should be the adjacency matrix of a connected component. Diagonal
    entries (self-loops) are ignored. Vertices that cannot be reached from a
    given start keep distance -1 and therefore never contribute to the result.

    Returns:
        The largest BFS distance found over all start vertices (0 for n <= 1).
    """
    n = A.shape[0]
    if n <= 1:
        return 0

    # Build adjacency lists once so every BFS walks real neighbours only,
    # instead of rescanning a full matrix row for each dequeued vertex.
    neighbours = [
        [v for v in range(n) if v != u and A[u, v] > 0] for u in range(n)
    ]

    max_distance = 0
    for start in range(n):
        distances = [-1] * n
        distances[start] = 0
        queue = deque([start])  # O(1) pops from the left, unlike list.pop(0)

        while queue:
            u = queue.popleft()
            for v in neighbours[u]:
                if distances[v] == -1:
                    distances[v] = distances[u] + 1
                    queue.append(v)

        # distances[start] == 0, so the max is always >= 0.
        max_distance = max(max_distance, max(distances))

    return max_distance


def connected_components_from_augmented(A: np.ndarray) -> List[List[int]]:
    """Connected components for an undirected graph given augmented adjacency (self-loops present).

    Args:
        A: Square adjacency matrix; off-diagonal entries > 0 are edges. The
           diagonal is skipped, so self-loops do not affect the result.

    Returns:
        One list of vertex indices per component. Components are ordered by
        their smallest vertex, and vertices inside a component appear in BFS
        discovery order starting from that vertex.
    """
    n = A.shape[0]
    seen = [False] * n
    comps: List[List[int]] = []
    for s in range(n):
        if seen[s]:
            continue
        seen[s] = True
        cur = [s]
        queue = deque([s])  # O(1) pops from the left, unlike list.pop(0)
        while queue:
            u = queue.popleft()
            for v in range(n):
                # Cheap visited check first; matrix lookup only if needed.
                if v != u and not seen[v] and A[u, v] > 0:
                    seen[v] = True
                    queue.append(v)
                    cur.append(v)
        comps.append(cur)
    return comps


def reachability_matrix(n: int, comps: List[List[int]]) -> np.ndarray:
    """Build the float32 (n, n) matrix R with R[i, j] = 1 iff i and j share a component."""
    reach = np.zeros((n, n), dtype=np.float32)
    for members in comps:
        idx = np.asarray(members, dtype=np.intp)
        # Column-vector × row-vector indexing selects the members × members block.
        reach[idx[:, None], idx] = 1.0
    return reach


def batch_er_graphs(
    n: int,
    p: float,
    batch_size: int,
    seed: int = 0,
    restrict_diam: Union[int, bool] = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return batched augmented adjacency A (B×n×n) and reachability R (B×n×n) as torch float tensors.

    Graphs are drawn with a generator seeded by ``seed`` and, when a diameter
    restriction applies, rejected until their diameter is at most the threshold.

    Args:
        n: Number of vertices per graph.
        p: Edge probability.
        batch_size: Number of accepted graphs to return.
        seed: Seed for the numpy generator.
        restrict_diam: ``True`` restricts to diameter <= 3, ``False`` (or
            ``None``) disables the restriction, and an int is used directly as
            the maximum allowed diameter.

    Returns:
        ``(A_batch, R_batch)``, both float32 tensors of shape (B, n, n). For
        ``batch_size <= 0`` two empty (0, n, n) tensors are returned.

    Note:
        Rejection sampling loops until enough graphs are accepted, so it does
        not terminate if no graph can satisfy the threshold.
    """
    rng = np.random.default_rng(seed)

    # bool must be tested before int because True/False are also ints.
    if isinstance(restrict_diam, bool):
        diam_threshold = 3 if restrict_diam else None
    else:
        diam_threshold = restrict_diam

    if batch_size <= 0:
        # np.stack cannot stack an empty list; return well-shaped empties.
        return (
            torch.zeros((0, n, n), dtype=torch.float32),
            torch.zeros((0, n, n), dtype=torch.float32),
        )

    As, Rs = [], []
    while len(As) < batch_size:
        A = erdos_renyi_augmented(n, p, rng)
        # Only pay for the diameter computation when it can reject the sample.
        if diam_threshold is not None and compute_graph_diameter(A) > diam_threshold:
            continue
        comps = connected_components_from_augmented(A)
        As.append(A)
        Rs.append(reachability_matrix(n, comps))

    A_batch = torch.from_numpy(np.stack(As, axis=0))  # (B, n, n)
    R_batch = torch.from_numpy(np.stack(Rs, axis=0))  # (B, n, n)
    return A_batch, R_batch


def generate_two_chains_graph(
    n: int, k: Union[int, None] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Generate a two-chains graph with n nodes.

    The first chain uses nodes 0..k-1. The second chain uses nodes k..2k-1 and
    is only created when it fits (2*k <= n). Every node not on a chain is
    isolated. All nodes carry a self-loop in A and are reachable from
    themselves in R.

    Args:
        n: Total number of nodes.
        k: Length of each chain; if None, k = n // 2.

    Returns:
        (A, R): augmented adjacency matrix and connectivity (reachability)
        matrix, both (n, n) float tensors.
    """
    if k is None:
        k = n // 2
    has_second_chain = 2 * k <= n

    # Adjacency: path edges inside each chain, then self-loops.
    A = torch.zeros((n, n))
    for i in range(k - 1):
        A[i, i + 1] = A[i + 1, i] = 1
    if has_second_chain:
        for i in range(k, 2 * k - 1):
            A[i, i + 1] = A[i + 1, i] = 1
    A = A + torch.eye(n)

    # Connectivity: start from the identity so that *every* node reaches
    # itself, including nodes k..n-1 when the second chain does not fit, then
    # fill the block of each chain.
    R = torch.eye(n)
    R[0:k, 0:k] = 1  # First chain
    if has_second_chain:
        R[k : 2 * k, k : 2 * k] = 1  # Second chain

    return A, R


def batch_two_chains_graphs(
    n: int, batch_size: int, k: int = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build ``batch_size`` two-chains graphs, each relabelled by a fresh random permutation.

    The same permutation is applied to the rows and columns of both the
    adjacency and the connectivity matrix of a sample. Permutations come from
    ``torch.randperm``, i.e. torch's global random state.

    Returns:
        (A_batch, R_batch), each stacked along a new leading axis: (B, n, n).
    """
    adjacency_list, reach_list = [], []
    for _ in range(batch_size):
        base_A, base_R = generate_two_chains_graph(n, k)
        order = torch.randperm(n)
        # Broadcasting a column of indices against a row of indices picks
        # entry [order[i], order[j]] for every (i, j).
        rows, cols = order.unsqueeze(1), order.unsqueeze(0)
        adjacency_list.append(base_A[rows, cols])
        reach_list.append(base_R[rows, cols])

    return torch.stack(adjacency_list, dim=0), torch.stack(reach_list, dim=0)


def analyze_capacity(
    A_batch: torch.Tensor, capacity_threshold: int = 3
) -> Dict[str, float]:
    """
    Summarise how many graphs of a batch lie within the capacity threshold.

    A graph is "within capacity" when its diameter is at most
    ``capacity_threshold`` and "beyond capacity" otherwise.

    Returns:
        Dict with ``within_capacity_prop``, ``beyond_capacity_prop`` and
        ``avg_diameter`` (mean diameter over the batch).
    """
    batch_size = A_batch.shape[0]
    diameters = [
        compute_graph_diameter(A_batch[idx].numpy()) for idx in range(batch_size)
    ]

    n_within = sum(1 for d in diameters if d <= capacity_threshold)
    n_beyond = batch_size - n_within

    return {
        "within_capacity_prop": n_within / batch_size,
        "beyond_capacity_prop": n_beyond / batch_size,
        "avg_diameter": np.mean(diameters),
    }


def analyze_diameter_distribution(A_batch: torch.Tensor, n: int) -> Dict[int, int]:
    """
    Histogram the graph diameters found in a batch.

    ``n`` is accepted for interface compatibility and is not used here.

    Returns:
        Dict mapping each observed diameter to the number of graphs having it.
    """
    histogram: Dict[int, int] = {}
    for idx in range(A_batch.shape[0]):
        diameter = compute_graph_diameter(A_batch[idx].numpy())
        histogram[diameter] = histogram.get(diameter, 0) + 1
    return histogram


def plot_diameter_distribution(
    diameter_counts: Dict[int, int], n: int, p: float, save_dir: str
):
    """
    Plot the diameter distribution for a given p value.

    Draws two bar charts (raw counts and proportions) over the diameters
    1..n-1 and saves them as ``diameter_distribution.png`` inside ``save_dir``
    (created if missing). Proportions are relative to *all* counted graphs, so
    graphs with a diameter outside 1..n-1 (e.g. 0) count in the denominator
    but get no bar.

    Args:
        diameter_counts: Mapping diameter -> number of graphs.
        n: Number of vertices; fixes the x-axis range.
        p: Edge probability, used in the titles only.
        save_dir: Output directory.

    Returns:
        Path of the saved image.
    """
    diameters = list(range(1, n))  # Possible diameter range from 1 to n-1
    counts = [diameter_counts.get(d, 0) for d in diameters]
    total_graphs = sum(diameter_counts.values())
    # An empty histogram yields all-zero bars instead of a ZeroDivisionError.
    proportions = [c / total_graphs if total_graphs else 0.0 for c in counts]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    panels = (
        (ax1, counts, "Count", "Counts", "skyblue"),
        (ax2, proportions, "Proportion", "Proportions", "lightcoral"),
    )
    for ax, heights, ylabel, title_suffix, color in panels:
        ax.bar(diameters, heights, alpha=0.7, color=color, edgecolor="black")
        ax.set_xlabel("Diameter")
        ax.set_ylabel(ylabel)
        ax.set_title(f"Diameter Distribution (p={p:.3f}) - {title_suffix}")
        ax.set_xticks(diameters)
        ax.grid(True, alpha=0.3)

    fig.tight_layout()

    # Save and close this specific figure rather than whatever pyplot
    # currently considers the active one.
    os.makedirs(save_dir, exist_ok=True)
    plot_path = os.path.join(save_dir, "diameter_distribution.png")
    fig.savefig(plot_path, dpi=200, bbox_inches="tight")
    plt.close(fig)

    return plot_path


def analyze_eval_er_diameter_distribution(
    config, analysis_dir: str, eval_p: float = 0.20
):
    """
    Sample unrestricted evaluation ER graphs and plot their diameter histogram.

    Reads ``config.n``, ``config.eval_batch_size`` and ``config.seed``. The
    figure (counts and proportions over diameters 1..n-1) is written to
    ``eval_er_diameter_distribution.png`` inside ``analysis_dir``.

    Returns:
        (plot_path, diameter_distribution) where the latter maps each observed
        diameter to its count.
    """
    print(f"Generating evaluation ER graphs diameter distribution (p={eval_p:.2f})...")

    # Evaluation graphs: offset seed, no diameter restriction.
    A_eval_er, R_eval_er = batch_er_graphs(
        config.n,
        eval_p,
        config.eval_batch_size,
        seed=config.seed + 99999999,
        restrict_diam=False,
    )

    diameter_distribution = analyze_diameter_distribution(A_eval_er, config.n)

    x_values = list(range(1, config.n))  # diameters 1..n-1
    count_values = [diameter_distribution.get(d, 0) for d in x_values]
    n_graphs = sum(diameter_distribution.values())
    share_values = [c / n_graphs for c in count_values]

    fig, (ax_counts, ax_props) = plt.subplots(1, 2, figsize=(12, 5))

    panel_specs = (
        (ax_counts, count_values, "lightgreen", "Count", "Counts"),
        (ax_props, share_values, "lightsteelblue", "Proportion", "Proportions"),
    )
    for axis, heights, bar_color, ylabel, suffix in panel_specs:
        axis.bar(x_values, heights, alpha=0.7, color=bar_color, edgecolor="black")
        axis.set_xlabel("Diameter")
        axis.set_ylabel(ylabel)
        axis.set_title(
            f"Evaluation ER Graphs Diameter Distribution (p={eval_p:.2f}) - {suffix}"
        )
        axis.set_xticks(x_values)
        axis.grid(True, alpha=0.3)

    plt.tight_layout()

    # Persist the figure in the analysis directory.
    os.makedirs(analysis_dir, exist_ok=True)
    plot_path = os.path.join(analysis_dir, "eval_er_diameter_distribution.png")
    plt.savefig(plot_path, dpi=200, bbox_inches="tight")
    plt.close()

    print(f"  Saved evaluation ER diameter distribution plot: {plot_path}")
    return plot_path, diameter_distribution


# ----------------------------
# Model: 1-layer disentangled TF (equivariant W = A⊗I + B⊗J)
# ----------------------------


class OneLayerDisentangledTF(nn.Module):
    """
    One-layer Disentangled Transformer whose weight is constrained to
        W = A_ell ⊗ I_n + B_ell ⊗ J_n,
    with A_ell, B_ell ∈ R^{K×K} and K = 2 because the input is h0 = [I | A].
    """

    def __init__(
        self,
        n: int,
        A_init_scale: float = 0.1,
        B_init_scale: float = 0.1,
        seed: int = 0,
    ):
        """Seed torch's global RNG with ``seed`` and draw Gaussian A_ell, B_ell.

        A_ell is sampled before B_ell; both are unconstrained (signed) K×K
        parameters scaled by their respective init scale.
        """
        super().__init__()
        torch.manual_seed(seed)
        self.n = n
        self.K = 2  # h0 has two n-wide blocks: [I | A]
        block_shape = (self.K, self.K)
        self.A_ell = nn.Parameter(A_init_scale * torch.randn(*block_shape))
        self.B_ell = nn.Parameter(B_init_scale * torch.randn(*block_shape))

    def forward(self, A: torch.Tensor, return_intermediates: bool = False):
        """
        A: (B, n, n)
        Returns: Z: (B, n, n), or (Z, {"W", "h0", "S"}) when
        ``return_intermediates`` is true.
        """
        batch, n, _ = A.shape
        like_input = {"dtype": A.dtype, "device": A.device}
        eye_n = torch.eye(n, **like_input)
        ones_n = torch.ones(n, n, **like_input)

        # Input features: identity block next to the adjacency block -> (B, n, 2n)
        h0 = torch.cat([eye_n.expand(batch, n, n), A], dim=-1)

        # Equivariant weight assembled from the two K×K channel matrices -> (2n, 2n)
        W = torch.kron(self.A_ell, eye_n) + torch.kron(self.B_ell, ones_n)

        # ReLU "attention" scores -> (B, n, n)
        S = F.relu(h0 @ W @ h0.transpose(-1, -2))

        # Scores applied to the features, scaled by 1/n -> (B, n, 2n)
        Attn = (1.0 / n) * (S @ h0)

        # Residual concatenation -> (B, n, 4n)
        h1 = torch.cat([h0, Attn], dim=-1)

        # Read-out: add up the four n-wide blocks -> (B, n, n)
        Z = h1.view(batch, n, 4, n).sum(dim=2)

        if not return_intermediates:
            return Z
        return Z, {"W": W, "h0": h0, "S": S}


class OneLayerWholeMatrixTF(nn.Module):
    """
    One-layer Transformer with whole matrix W (not decomposed).

    W is a free (2n × 2n) parameter. For analysis, W can be projected onto
    the span of {E_ij ⊗ I_n, E_ij ⊗ J_n} to recover channel matrices
    A_ell, B_ell comparable to the decomposed model.
    """

    def __init__(
        self,
        n: int,
        W_init_scale: float = 0.1,
        seed: int = 0,
    ):
        """Seed torch's global RNG with ``seed`` and draw a Gaussian W of size (2n, 2n)."""
        super().__init__()
        torch.manual_seed(seed)
        self.n = n
        self.K = 2  # because h0 = [I | A]
        self.W = nn.Parameter(W_init_scale * torch.randn(2 * n, 2 * n))

    def forward(self, A: torch.Tensor, return_intermediates: bool = False):
        """
        A: (B, n, n)
        Returns: Z: (B, n, n), or (Z, {"W", "h0", "S"}) when
        ``return_intermediates`` is true.
        """
        Bsz, n, _ = A.shape
        I = torch.eye(n, dtype=A.dtype, device=A.device).expand(Bsz, n, n)
        h0 = torch.cat([I, A], dim=-1)  # (B, n, 2n)

        # Core S = ReLU( h0 W h0^T )
        S = F.relu(h0 @ self.W @ h0.transpose(-1, -2))  # (B, n, n)

        # Attn = (1/n) S h0 -> (B, n, 2n)
        Attn = (1.0 / n) * (S @ h0)

        # h1 = [h0 | Attn] -> (B, n, 4n)
        h1 = torch.cat([h0, Attn], dim=-1)

        # Z = sum of the 4 blocks (each n×n) along the block axis
        Z = h1.view(Bsz, n, 4, n).sum(dim=2)  # (B, n, n)

        if return_intermediates:
            return Z, {"W": self.W, "h0": h0, "S": S}
        return Z

    def project_W_to_channels(self):
        """
        Project the whole matrix W onto I and J channels using least-squares.

        Returns A_ell, B_ell (each 2×2) such that A_ell ⊗ I + B_ell ⊗ J
        minimizes ||W - (A_ell ⊗ I + B_ell ⊗ J)||_F^2.

        The solver chain is: ``lstsq`` (driver "gels"), then the normal
        equations, then the pseudo-inverse. Only solver failures
        (``RuntimeError``, which includes ``torch.linalg.LinAlgError``) trigger
        a fallback; interrupts and unrelated errors propagate.
        """
        n = self.n
        device = self.W.device
        dtype = self.W.dtype

        I = torch.eye(n, dtype=dtype, device=device)
        J = torch.ones(n, n, dtype=dtype, device=device)

        W_flat = self.W.flatten()  # Shape: (4n^2,)

        # One design-matrix column per unknown, in the order
        # A_ell[0,0], A_ell[0,1], A_ell[1,0], A_ell[1,1], then the same for B_ell.
        basis_matrices = []
        for channel in (I, J):
            for i in range(2):
                for j in range(2):
                    selector = torch.zeros(2, 2, dtype=dtype, device=device)
                    selector[i, j] = 1.0
                    basis_matrices.append(torch.kron(selector, channel).flatten())

        design_matrix = torch.stack(basis_matrices, dim=1)  # Shape: (4n^2, 8)

        # Solve least squares: design_matrix @ coeffs = W_flat
        try:
            coeffs = torch.linalg.lstsq(design_matrix, W_flat, driver="gels").solution
        except RuntimeError:
            # Fallback to normal equations if lstsq fails
            DTD = design_matrix.T @ design_matrix
            DTW = design_matrix.T @ W_flat
            try:
                coeffs = torch.linalg.solve(DTD, DTW)
            except RuntimeError:
                # Final fallback: use pseudo-inverse (handles singular DTD)
                coeffs = torch.pinverse(design_matrix) @ W_flat

        A_ell = coeffs[:4].view(2, 2)  # First 4 coefficients for A_ell
        B_ell = coeffs[4:].view(2, 2)  # Last 4 coefficients for B_ell

        return A_ell, B_ell

    def compute_projection_error(self):
        """
        Compute the Frobenius norm error between W and its I+J projection.

        Returns:
            (error, relative_error) as Python floats, where relative_error is
            error / ||W||_F.
        """
        A_ell, B_ell = self.project_W_to_channels()
        n = self.n
        device = self.W.device
        dtype = self.W.dtype

        I = torch.eye(n, dtype=dtype, device=device)
        J = torch.ones(n, n, dtype=dtype, device=device)

        # Reconstruct W from projections
        W_reconstructed = torch.kron(A_ell, I) + torch.kron(B_ell, J)

        error = torch.norm(self.W - W_reconstructed, "fro")
        relative_error = error / torch.norm(self.W, "fro")

        return error.item(), relative_error.item()


def plot_W_and_projection(model, step, save_dir, p_value):
    """
    Save a side-by-side heat map of W and of its I+J projection.

    Only called when use_decomposition=False: ``model`` must provide ``W``,
    ``n`` and ``project_W_to_channels()``. ``p_value`` is not used by the
    body. The figure is written to ``W_visualization_step_{step}.pdf`` in
    ``save_dir``.

    Returns:
        Path of the saved PDF.
    """
    A_ell, B_ell = model.project_W_to_channels()
    W = model.W
    size = model.n
    like_W = {"dtype": W.dtype, "device": W.device}

    # Rebuild the structured approximation A_ell ⊗ I + B_ell ⊗ J.
    projected = torch.kron(A_ell, torch.eye(size, **like_W)) + torch.kron(
        B_ell, torch.ones(size, size, **like_W)
    )

    # Relative Frobenius error, shown in the figure title.
    relative_error = torch.norm(W - projected, "fro") / torch.norm(W, "fro")

    # Fixed, asymmetric colour range centred on zero.
    color_norm = TwoSlopeNorm(vmin=-0.5, vcenter=0.0, vmax=2.0)

    stage = "(Initialization)" if step == 1 else f"(Step {step})"
    panels = (("Original W", W), ("Projected W", projected))

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    image = None
    for axis, (label, matrix) in zip(axes, panels):
        image = axis.imshow(
            matrix.detach().cpu().numpy(),
            cmap="RdBu_r",
            aspect="equal",
            norm=color_norm,
        )
        axis.set_title(f"{label} {stage}", fontsize=14, fontweight="bold")
        axis.axis("off")  # no ticks or frame

    # Thin horizontal colour bar centred below both panels;
    # coordinates are [left, bottom, width, height] in figure units.
    cbar_ax = fig.add_axes([0.25, -0.05, 0.5, 0.04])
    cbar = plt.colorbar(image, cax=cbar_ax, orientation="horizontal")
    cbar.ax.tick_params(labelsize=14)
    plt.setp(cbar.ax.get_xticklabels(), fontweight="bold")

    fig.suptitle(
        f"W Matrix Visualization (Relative Projection Error={relative_error:.4f})",
        fontsize=16,
        fontweight="bold",
    )

    plt.tight_layout()

    os.makedirs(save_dir, exist_ok=True)
    fig_path = os.path.join(save_dir, f"W_visualization_step_{step}.pdf")
    fig.savefig(fig_path, dpi=300, bbox_inches="tight", facecolor="white")
    plt.close(fig)

    return fig_path


# ----------------------------
# Channel metrics
# ----------------------------


@torch.no_grad()
def channel_metrics(model, n: int, use_decomposition: bool = True):
    """Return channel norms, their inner product and the energy shares of W.

    With ``use_decomposition=True`` the model exposes ``A_ell`` and ``B_ell``
    and ||W||_F^2 is expanded analytically for W = A⊗I + B⊗J. Otherwise the
    channels come from ``model.project_W_to_channels()`` and the shares are
    the inner products of the actual ``model.W`` with each projected part;
    ``share_cross`` is then whatever energy is left over.

    Returns:
        Dict of floats: normA, normB, dotAB, W_norm, share_A, share_B,
        share_cross.
    """

    def squared_norm(M):
        return float(torch.linalg.norm(M).pow(2))

    if use_decomposition:
        A_ell, B_ell = model.A_ell, model.B_ell
    else:
        A_ell, B_ell = model.project_W_to_channels()

    A2 = squared_norm(A_ell)
    B2 = squared_norm(B_ell)
    AB = float(torch.sum(A_ell * B_ell))

    if use_decomposition:
        # ||I_n||_F^2 = n, ||J_n||_F^2 = n^2, <I_n, J_n>_F = n
        nI2, nJ2, nIJ = float(n), float(n * n), float(n)
        energy_A = A2 * nI2
        energy_B = B2 * nJ2
        energy_cross = 2.0 * AB * nIJ
        # Floor the total so the ratios below stay finite.
        W2 = float(max(energy_A + energy_B + energy_cross, 1e-12))
        share_A = energy_A / W2
        share_B = energy_B / W2
        cross_ratio = energy_cross / W2
    else:
        W_full = model.W
        like_W = {"dtype": W_full.dtype, "device": W_full.device}
        eye_n = torch.eye(n, **like_W)
        ones_n = torch.ones(n, n, **like_W)

        total_energy = float(torch.norm(W_full, "fro").pow(2))
        # <W, A⊗I> and <W, B⊗J>
        energy_A = float(torch.sum(W_full * torch.kron(A_ell, eye_n)))
        energy_B = float(torch.sum(W_full * torch.kron(B_ell, ones_n)))

        W2 = float(max(total_energy, 1e-12))
        share_A = energy_A / W2
        share_B = energy_B / W2
        # Remainder of the total energy after both channel contributions.
        cross_ratio = (W2 - energy_A - energy_B) / W2

    return {
        "normA": float(np.sqrt(A2)),
        "normB": float(np.sqrt(B2)),
        "dotAB": float(AB),
        "W_norm": float(np.sqrt(W2)),
        "share_A": float(share_A),
        "share_B": float(share_B),
        "share_cross": float(cross_ratio),
    }


# ----------------------------
# Plotting functions (from analysis_v2.py)
# ----------------------------


def plot_heat(ax, M: torch.Tensor, title: str, vmin=None, vmax=None):
    """Draw tensor ``M`` as a heat map on ``ax`` with its own colour bar.

    Args:
        ax: Matplotlib axes to draw on.
        M: Tensor to display; it is detached and moved to host memory first,
           so tensors living on an accelerator can be plotted as well.
        title: Axes title.
        vmin, vmax: Optional colour limits passed to ``imshow``.
    """
    im = ax.imshow(
        # .cpu() is a no-op for CPU tensors and required before .numpy() otherwise.
        M.detach().cpu().numpy(),
        aspect="equal",
        interpolation="nearest",
        vmin=vmin,
        vmax=vmax,
        cmap="RdBu_r",
    )
    ax.set_title(title, fontsize=10)
    ax.set_xticks([])
    ax.set_yticks([])
    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)


def plot_step_figure(
    step_idx: int,
    out: dict,
    out_dir: str,
    vlims: dict,
):
    """
    Write a 4×6 snapshot figure for one training step.

    Row 0 shows the first sample's input/target/output, the loss gradients and
    a text summary; rows 1-3 show before / after / Δ / gradient for A_ell,
    B_ell and the full W. ``vlims`` supplies symmetric colour limits (±value)
    shared across steps so colours stay comparable. The image is saved as
    ``dt_step_{step_idx:03d}.png`` in ``out_dir``.
    """
    fig = plt.figure(figsize=(20, 14))
    grid = fig.add_gridspec(4, 6, wspace=0.4, hspace=0.6)

    def heat_panel(row, col, key, title, scale_key=None):
        # One heat map; scale_key selects a symmetric range from vlims.
        axis = fig.add_subplot(grid[row, col])
        if scale_key is None:
            plot_heat(axis, out[key], title)
        else:
            plot_heat(
                axis, out[key], title, vmin=-vlims[scale_key], vmax=vlims[scale_key]
            )

    def blank_panel(row, col):
        # Empty cell that still occupies its grid slot.
        axis = fig.add_subplot(grid[row, col])
        axis.axis("off")
        return axis

    # Row 0: inputs, outputs, grads, text
    heat_panel(0, 0, "A0", "A (sample 1)")
    heat_panel(0, 1, "R0", "R (sample 1)")
    heat_panel(0, 2, "Z0", f"Z (step {step_idx})")
    heat_panel(0, 3, "dLdZ0", "∂L/∂Z (sample 1)")
    heat_panel(0, 4, "gW", "∂L/∂W (full)", "gW")

    info = blank_panel(0, 5)
    chan = out["chan"]
    info.text(0, 0.92, f"Step {step_idx}", fontsize=12, weight="bold")
    info.text(0, 0.78, f"Loss: {out['loss']:.4f}", fontsize=11)
    summary_lines = (
        (0.66, f"||A_ell||_F = {chan['normA']:.4f}"),
        (0.56, f"||B_ell||_F = {chan['normB']:.4f}"),
        (0.46, f"<A_ell,B_ell> = {chan['dotAB']:.4f}"),
        (0.34, f"A-share in ||W||^2 = {100*chan['share_A']:.2f}%"),
        (0.24, f"B-share in ||W||^2 = {100*chan['share_B']:.2f}%"),
        (0.14, f"LR_A = {out['current_lr_A']:.4f}"),
        (0.04, f"LR_B = {out['current_lr_B']:.4f}"),
    )
    for y_pos, line in summary_lines:
        info.text(0, y_pos, line, fontsize=10)

    # Rows 1-3: (key in `out`, title, key in `vlims`) for columns 0..3
    matrix_rows = (
        (
            1,
            (
                ("A_before", "A_ell before", "A"),
                ("A_after", "A_ell after", "A"),
                ("dAch", "ΔA_ell", "dA"),
                ("gAch", "∂L/∂A_ell", "gA"),
            ),
        ),
        (
            2,
            (
                ("B_before", "B_ell before", "B"),
                ("B_after", "B_ell after", "B"),
                ("dBch", "ΔB_ell", "dB"),
                ("gBch", "∂L/∂B_ell", "gB"),
            ),
        ),
        (
            3,
            (
                ("W_before", "W = A⊗I + B⊗J (before)", "W"),
                ("W_after", "W = A⊗I + B⊗J (after)", "W"),
                ("dW", "ΔW", "dW"),
                ("gW", "∂L/∂W (full)", "gW"),
            ),
        ),
    )
    for row, panels in matrix_rows:
        for col, (key, title, scale_key) in enumerate(panels):
            heat_panel(row, col, key, title, scale_key)
        # The last two columns of these rows stay empty.
        blank_panel(row, 4)
        blank_panel(row, 5)

    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f"dt_step_{step_idx:03d}.png")
    fig.savefig(out_path, dpi=190, bbox_inches="tight")
    plt.close(fig)


def plot_training_dynamics(
    loss_hist,
    normA_hist,
    normB_hist,
    jshare_hist,
    ishare_hist,
    num_steps,
    out_dir,
    p_value,
    projection_error_hist=None,  # Keep parameter for backward compatibility but ignore it
):
    """Plot loss, energy shares and normalized channel norms over training steps.

    Three stacked panels are drawn against steps 1..num_steps and saved as
    ``dt_dynamics.png`` in ``out_dir``. ``projection_error_hist`` is accepted
    but not used.

    Returns:
        Path of the saved image.
    """
    steps = np.arange(1, num_steps + 1)

    fig = plt.figure(figsize=(12, 10))
    ax_loss, ax_share, ax_norm = (fig.add_subplot(3, 1, k) for k in (1, 2, 3))

    # Panel 1: loss
    ax_loss.plot(steps, loss_hist, marker="o")
    ax_loss.set_title(f"Training dynamics (p={p_value:.3f})")
    ax_loss.set_ylabel("Loss (batch mean)")
    ax_loss.grid(True, alpha=0.3)

    # Panel 2: energy shares
    # NOTE(review): jshare_hist is labelled "A-channel" and ishare_hist
    # "B-channel"; whether that matches what callers pass cannot be verified
    # from this function alone - confirm against the call site.
    share_series = (
        (jshare_hist, "o", "blue", "A-channel share in ||W||^2"),
        (ishare_hist, "s", "orange", "B-channel share in ||W||^2"),
    )
    for values, marker, color, label in share_series:
        ax_share.plot(steps, values, marker=marker, color=color, label=label)
    ax_share.set_ylabel("Energy share")
    ax_share.legend()
    ax_share.grid(True, alpha=0.3)

    # Panel 3: each Frobenius norm as a fraction of their sum
    norm_A = np.array(normA_hist)
    norm_B = np.array(normB_hist)
    total_norm = norm_A + norm_B
    norm_series = (
        (norm_A / total_norm, "o", "blue", "||A||_F / (||A||_F + ||B||_F)"),
        (norm_B / total_norm, "s", "orange", "||B||_F / (||A||_F + ||B||_F)"),
    )
    for values, marker, color, label in norm_series:
        ax_norm.plot(steps, values, marker=marker, color=color, label=label)
    ax_norm.set_xlabel("Step")
    ax_norm.set_ylabel("Normalized Frobenius norm")
    ax_norm.legend()
    ax_norm.grid(True, alpha=0.3)

    os.makedirs(out_dir, exist_ok=True)
    dyn_path = os.path.join(out_dir, "dt_dynamics.png")
    fig.savefig(dyn_path, dpi=200, bbox_inches="tight")
    plt.close(fig)

    return dyn_path


def plot_projection_error_standalone(
    projection_error_hist,
    loss_history,
    num_steps,
    out_dir,
    p_value,
):
    """Plot relative projection error and training loss on twin y-axes.

    ``projection_error_hist`` is a sequence of pairs; only element ``[1]`` of
    each pair (the relative error) is drawn. Both curves share a log-scaled
    step axis running ``1..num_steps``. The figure is saved as
    ``<out_dir>/dt_projection_error.pdf``.

    Returns the saved path, or ``None`` when ``projection_error_hist`` is
    ``None``. ``p_value`` is accepted but not used in the figure.
    """
    if projection_error_hist is None:
        return None

    x_steps = np.arange(1, num_steps + 1)
    relative_errors = [entry[1] for entry in projection_error_hist]

    # Reset matplotlib to its default style before drawing
    plt.style.use("default")

    fig, err_ax = plt.subplots(1, 1, figsize=(16, 8))
    fig.patch.set_facecolor("white")

    # Palette and text/curve settings shared by both axes
    error_color = "#2E86AB"  # blue: projection error
    loss_color = "#F24236"  # red: loss
    axis_label_font = {"fontsize": 28, "fontweight": "bold"}
    curve_look = {"linewidth": 5, "alpha": 0.8}

    # Primary axis: relative projection error
    error_lines = err_ax.plot(
        x_steps,
        relative_errors,
        color=error_color,
        label=r"Relative projection error $\frac{\|W - (\hat{A} \otimes I_n + \hat{B} \otimes J_n)\|_F}{\|W\|_F}$",
        **curve_look,
    )
    err_ax.set_xlabel("Training Step", **axis_label_font)
    err_ax.set_ylabel(
        "Relative Projection Error", color=error_color, **axis_label_font
    )
    err_ax.tick_params(axis="y", labelcolor=error_color, labelsize=22)
    err_ax.tick_params(axis="x", labelsize=22)
    # Bold tick labels
    plt.setp(err_ax.get_xticklabels(), fontweight="bold")
    plt.setp(err_ax.get_yticklabels(), fontweight="bold")
    err_ax.set_xscale("log")
    err_ax.grid(True, alpha=0.3, linestyle="-", linewidth=0.5)
    err_ax.set_ylim(bottom=0)

    # Secondary axis: training loss
    loss_ax = err_ax.twinx()
    loss_lines = loss_ax.plot(
        x_steps,
        loss_history,
        color=loss_color,
        label="Training loss",
        **curve_look,
    )
    loss_ax.set_ylabel("Training Loss", color=loss_color, **axis_label_font)
    loss_ax.tick_params(axis="y", labelcolor=loss_color, labelsize=22)
    plt.setp(loss_ax.get_yticklabels(), fontweight="bold")
    loss_ax.set_ylim(bottom=0)

    err_ax.set_title(
        "Training Dynamics of 1-Layer Disentangled Transformer",
        fontsize=24,
        fontweight="bold",
        pad=20,
    )

    # One legend covering the curves from both axes
    handles = error_lines + loss_lines
    legend = err_ax.legend(
        handles,
        [handle.get_label() for handle in handles],
        loc="best",
        frameon=True,
        fancybox=True,
        shadow=False,
        framealpha=0.9,
        fontsize=16,
    )
    for entry in legend.get_texts():
        entry.set_fontweight("bold")

    plt.tight_layout()

    os.makedirs(out_dir, exist_ok=True)
    target = os.path.join(out_dir, "dt_projection_error.pdf")
    fig.savefig(target, dpi=300, bbox_inches="tight", facecolor="white")
    plt.close(fig)

    return target


def plot_channel_shares_standalone(
    jshare_hist,
    bshare_hist,
    loss_history,
    num_steps,
    out_dir,
    p_value,
):
    """Plot the A- and B-channel share curves on a single axis.

    Both histories are drawn against a log-scaled step axis ``1..num_steps``
    with the y-range fixed to ``[0, 1]``. The figure is saved as
    ``<out_dir>/dt_channel_shares.pdf``.

    Returns the saved path, or ``None`` if either share history is ``None``.
    ``loss_history`` and ``p_value`` are accepted but not used (no loss
    overlay is drawn).
    """
    if jshare_hist is None or bshare_hist is None:
        return None

    x_steps = np.arange(1, num_steps + 1)

    # Reset matplotlib to its default style before drawing
    plt.style.use("default")

    fig, share_ax = plt.subplots(1, 1, figsize=(16, 8))
    fig.patch.set_facecolor("white")

    curve_look = {"linewidth": 5, "alpha": 0.8}
    # (history, colour, legend label) for each channel
    channel_specs = (
        (jshare_hist, "#2E86AB", r"$\hat{A} \otimes I_n$-channel share"),  # blue
        (bshare_hist, "#A23B72", r"$\hat{B} \otimes J_n$-channel share"),  # purple
    )
    handles = []
    for history, color, label in channel_specs:
        handles += share_ax.plot(
            x_steps, history, color=color, label=label, **curve_look
        )

    share_ax.set_xlabel("Training Step", fontsize=28, fontweight="bold")
    share_ax.set_ylabel(
        r"I/J Channel Share in $\|W\|^2$",
        fontsize=28,
        fontweight="bold",
    )
    share_ax.tick_params(axis="y", labelsize=22)
    share_ax.tick_params(axis="x", labelsize=22)
    # Bold tick labels
    plt.setp(share_ax.get_xticklabels(), fontweight="bold")
    plt.setp(share_ax.get_yticklabels(), fontweight="bold")
    share_ax.set_xscale("log")
    share_ax.grid(True, alpha=0.3, linestyle="-", linewidth=0.5)
    share_ax.set_ylim([0, 1])  # shares live in [0, 1]

    share_ax.set_title(
        "Channel Share Dynamics of 1-Layer Disentangled Transformer",
        fontsize=24,
        fontweight="bold",
        pad=20,
    )

    legend = share_ax.legend(
        handles,
        [handle.get_label() for handle in handles],
        loc="best",
        frameon=True,
        fancybox=True,
        shadow=False,
        framealpha=0.9,
        fontsize=20,
    )
    for entry in legend.get_texts():
        entry.set_fontweight("bold")

    plt.tight_layout()

    os.makedirs(out_dir, exist_ok=True)
    target = os.path.join(out_dir, "dt_channel_shares.pdf")
    fig.savefig(target, dpi=300, bbox_inches="tight", facecolor="white")
    plt.close(fig)

    return target


def read_projection_error_from_csv(csv_path: str):
    """
    Read projection error data from a dt_metrics.csv file.

    Args:
        csv_path: Path to a CSV whose header row names its columns.

    Returns:
        List of (abs_error, rel_error) float tuples taken from the
        "proj_error_abs" and "proj_error_rel" columns, one per data row.
        Returns None when the file is missing or empty, when either column
        is absent, when no usable rows exist, or when the file cannot be
        read/parsed (in which case a warning is printed).

    Rows too short to contain both columns are skipped.
    """
    if not os.path.exists(csv_path):
        return None

    try:
        # newline="" is the mode the csv module documents for file objects.
        with open(csv_path, "r", newline="") as f:
            reader = csv.reader(f)
            header = next(reader, None)
            if header is None:
                # Empty file: nothing to read (handled explicitly rather than
                # via StopIteration falling into the except clause).
                return None

            # Check if projection error columns exist
            if "proj_error_abs" not in header or "proj_error_rel" not in header:
                return None

            abs_idx = header.index("proj_error_abs")
            rel_idx = header.index("proj_error_rel")
            last_idx = max(abs_idx, rel_idx)

            projection_error_hist = [
                (float(row[abs_idx]), float(row[rel_idx]))
                for row in reader
                if len(row) > last_idx
            ]
    except (OSError, ValueError, csv.Error) as e:
        # OSError: unreadable file; ValueError: non-numeric cell or undecodable
        # bytes; csv.Error: malformed CSV. Best-effort: warn and report "no data".
        print(f"Warning: Could not read projection error from {csv_path}: {e}")
        return None

    return projection_error_hist if projection_error_hist else None


def read_loss_from_csv(csv_path: str):
    """
    Read loss data from a dt_metrics.csv file.

    Args:
        csv_path: Path to a CSV whose header row names its columns.

    Returns:
        List of floats from the "loss" column, one per data row. Returns None
        when the file is missing or empty, when the column is absent, when no
        usable rows exist, or when the file cannot be read/parsed (in which
        case a warning is printed).

    Rows too short to contain the loss column are skipped.
    """
    if not os.path.exists(csv_path):
        return None

    try:
        # newline="" is the mode the csv module documents for file objects.
        with open(csv_path, "r", newline="") as f:
            reader = csv.reader(f)
            header = next(reader, None)
            if header is None:
                # Empty file: nothing to read (handled explicitly rather than
                # via StopIteration falling into the except clause).
                return None

            # Check if loss column exists
            if "loss" not in header:
                return None

            loss_idx = header.index("loss")

            loss_history = [
                float(row[loss_idx]) for row in reader if len(row) > loss_idx
            ]
    except (OSError, ValueError, csv.Error) as e:
        # OSError: unreadable file; ValueError: non-numeric cell or undecodable
        # bytes; csv.Error: malformed CSV. Best-effort: warn and report "no data".
        print(f"Warning: Could not read loss from {csv_path}: {e}")
        return None

    return loss_history if loss_history else None


def read_channel_shares_from_csv(csv_path: str):
    """
    Read channel share data from a dt_metrics.csv file.

    Args:
        csv_path: Path to a CSV whose header row names its columns.

    Returns:
        Tuple (A_share_hist, B_share_hist) of equal-length float lists taken
        from the "A_share" and "B_share" columns. Returns (None, None) when
        the file is missing or empty, when either column is absent, when no
        usable rows exist, or when the file cannot be read/parsed (in which
        case a warning is printed).

    Rows too short to contain both columns are skipped.
    """
    if not os.path.exists(csv_path):
        return None, None

    A_share_hist = []
    B_share_hist = []
    try:
        # newline="" is the mode the csv module documents for file objects.
        with open(csv_path, "r", newline="") as f:
            reader = csv.reader(f)
            header = next(reader, None)
            if header is None:
                # Empty file: nothing to read (handled explicitly rather than
                # via StopIteration falling into the except clause).
                return None, None

            # Check if channel share columns exist
            if "A_share" not in header or "B_share" not in header:
                return None, None

            A_idx = header.index("A_share")
            B_idx = header.index("B_share")
            last_idx = max(A_idx, B_idx)

            for row in reader:
                if len(row) > last_idx:
                    # Parse both cells before appending so the two lists can
                    # never get out of step on a bad value.
                    A_share = float(row[A_idx])
                    B_share = float(row[B_idx])
                    A_share_hist.append(A_share)
                    B_share_hist.append(B_share)
    except (OSError, ValueError, csv.Error) as e:
        # OSError: unreadable file; ValueError: non-numeric cell or undecodable
        # bytes; csv.Error: malformed CSV. Best-effort: warn and report "no data".
        print(f"Warning: Could not read channel shares from {csv_path}: {e}")
        return None, None

    if A_share_hist and B_share_hist:
        return A_share_hist, B_share_hist
    return None, None


# ----------------------------
# Training config & helpers
# ----------------------------


@dataclass
class Config:
    """Hyper-parameters for a single training run.

    Field order is part of the interface (positional construction), so new
    fields should be appended at the end.
    """

    # Size parameter ``n`` forwarded to the graph samplers and the models.
    n: int = 8
    # Link temperature α in phi(z) = 1 - exp(-α z).
    alpha: float = 1.0
    # Training batch size; the gradient is averaged over this (large) batch.
    batch_size: int = 4096
    # Step size for A_ell (also used for the whole-matrix W update).
    lr_A: float = 0.5
    # Step size for B_ell.
    lr_B: float = 0.5
    # Gradient-clipping max norm; any value <= 0 (default -1) disables clipping.
    clip_norm: float = -1
    # Base random seed.
    seed: int = 42
    # Loss selector: True -> MSE on phi, False -> clamped BCE.
    use_mse: bool = False
    # Number of training steps.
    num_steps: int = 200
    # Whether to draw a fresh batch of ER graphs at every step.
    resample_each_step: bool = False
    # Diameter restriction: True -> 3, int -> that value, False -> no restriction.
    restrict_diam: Union[int, bool] = False
    # Batch size used for evaluation.
    eval_batch_size: int = 1024
    # Whether to apply cosine learning-rate decay.
    use_lr_decay: bool = True
    # Whether to save per-step plots.
    plot_steps: bool = True

def phi(z: torch.Tensor, alpha: float) -> torch.Tensor:
    """Link function phi(z) = 1 - exp(-alpha * z), applied element-wise."""
    decay = torch.exp(-alpha * z)
    return 1.0 - decay


def loss_mse_on_phi(Z: torch.Tensor, R: torch.Tensor, alpha: float) -> torch.Tensor:
    """Mean squared error between phi(Z, alpha) and the targets R."""
    residual = phi(Z, alpha) - R
    return torch.mean(residual ** 2)


def loss_bce_on_phi_clamped(
    Z: torch.Tensor, R: torch.Tensor, alpha: float, eps: float = 1e-4
) -> torch.Tensor:
    """Binary cross-entropy between phi(Z, alpha) and R, averaged over all entries.

    phi's output is clamped to [eps, 1 - eps] before the logarithms so that
    neither log term is evaluated at 0.
    """
    prob = torch.clamp(phi(Z, alpha), eps, 1 - eps)
    positive_term = R * torch.log(prob)
    negative_term = (1 - R) * torch.log(1 - prob)
    return -(positive_term + negative_term).mean()


def evaluate_accuracy(
    model: nn.Module, A_batch: torch.Tensor, R_batch: torch.Tensor, alpha: float
) -> Dict[str, float]:
    """Evaluate model accuracy on a batch.

    Predictions are ``phi(model(A_batch), alpha) > 0.5`` cast to float and
    compared against ``R_batch``.

    Args:
        model: Module called as ``model(A_batch)``.
        A_batch: Batch of model inputs.
        R_batch: Targets; the first dimension indexes samples.
        alpha: Link temperature passed to ``phi``.

    Returns:
        Dict with
          - "avg_accuracy": fraction of matching entries over the whole batch;
          - "exact_accuracy": fraction of samples whose prediction equals the
            target in every entry.

    The model is switched to eval mode for the forward pass and then put back
    into whatever mode it was in on entry (also when an exception is raised),
    so evaluating a model that the caller already set to eval does not
    silently flip it back to training mode.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            Z = model(A_batch)
            P = phi(Z, alpha)
            # Binary predictions
            predictions = (P > 0.5).float()

            # Average accuracy (element-wise)
            avg_accuracy = (predictions == R_batch).float().mean().item()

            # Exact match accuracy (all elements in a sample must match)
            batch_size = R_batch.shape[0]
            exact_matches = sum(
                1 for i in range(batch_size) if torch.equal(predictions[i], R_batch[i])
            )
            exact_accuracy = exact_matches / batch_size
    finally:
        # Restore the caller's mode instead of forcing train mode.
        model.train(was_training)

    return {"avg_accuracy": avg_accuracy, "exact_accuracy": exact_accuracy}


def _capacity_subset_metrics(prefix, indices, predictions, targets, exact_flags):
    """Return ``{prefix}_avg_accuracy / _exact_accuracy / _count`` for the graphs at ``indices``."""
    count = len(indices)
    if count == 0:
        # Empty partition: report zeros rather than dividing by zero.
        return {
            f"{prefix}_avg_accuracy": 0.0,
            f"{prefix}_exact_accuracy": 0.0,
            f"{prefix}_count": 0,
        }

    # Average accuracy: element-wise across all graphs in the partition
    avg_acc = (predictions[indices] == targets[indices]).float().mean().item()
    # Exact accuracy: fraction of graphs in the partition that match entirely
    exact_acc = int(exact_flags[indices].sum()) / count

    return {
        f"{prefix}_avg_accuracy": avg_acc,
        f"{prefix}_exact_accuracy": exact_acc,
        f"{prefix}_count": count,
    }


def evaluate_accuracy_by_capacity(
    model: nn.Module,
    A_batch: torch.Tensor,
    R_batch: torch.Tensor,
    alpha: float,
    capacity_threshold: int = 3,
) -> Dict[str, float]:
    """Evaluate model accuracy partitioned by graph capacity (diameter).

    A graph is "within capacity" when ``compute_graph_diameter`` of its
    adjacency matrix is ``<= capacity_threshold`` and "beyond capacity"
    otherwise.

    Args:
        model: Module called as ``model(A_batch)``.
        A_batch: Batch of adjacency matrices; the first dimension indexes graphs.
        R_batch: Targets, indexed the same way.
        alpha: Link temperature passed to ``phi``.
        capacity_threshold: Largest diameter still counted as within capacity.

    Returns:
        Dict with overall "avg_accuracy" / "exact_accuracy" (same definitions
        as ``evaluate_accuracy``), plus for each of the prefixes
        "within_capacity" and "beyond_capacity" the keys
        ``<prefix>_avg_accuracy``, ``<prefix>_exact_accuracy`` and
        ``<prefix>_count``. An empty partition reports 0.0 / 0.0 / 0.

    The model is switched to eval mode for the forward pass and then put back
    into whatever mode it was in on entry (also when an exception is raised).
    """
    was_training = model.training
    model.eval()
    batch_size = A_batch.shape[0]

    try:
        with torch.no_grad():
            Z = model(A_batch)
            P = phi(Z, alpha)
            predictions = (P > 0.5).float()

            # True = within capacity, False = beyond capacity
            within_flags = np.array(
                [
                    compute_graph_diameter(A_batch[i].cpu().numpy())
                    <= capacity_threshold
                    for i in range(batch_size)
                ],
                dtype=bool,
            )

            # Per-graph exact-match flags, computed once and shared by the
            # overall and per-partition exact accuracies.
            exact_flags = np.array(
                [
                    torch.equal(predictions[i], R_batch[i])
                    for i in range(batch_size)
                ],
                dtype=bool,
            )

            # Overall accuracy (same definitions as evaluate_accuracy)
            results = {
                "avg_accuracy": (predictions == R_batch).float().mean().item(),
                "exact_accuracy": int(exact_flags.sum()) / batch_size,
            }

            results.update(
                _capacity_subset_metrics(
                    "within_capacity",
                    np.where(within_flags)[0],
                    predictions,
                    R_batch,
                    exact_flags,
                )
            )
            results.update(
                _capacity_subset_metrics(
                    "beyond_capacity",
                    np.where(~within_flags)[0],
                    predictions,
                    R_batch,
                    exact_flags,
                )
            )
    finally:
        # Restore the caller's mode instead of forcing train mode.
        model.train(was_training)

    return results


# ----------------------------
# Single training run
# ----------------------------


def train_single_run(
    p: float, config: Config, save_dir: str, use_decomposition: bool = True
) -> Dict[str, Any]:
    """
    Train a single model with given edge probability p.
    Returns metrics and saves model checkpoint.

    Args:
        p: Edge probability passed to batch_er_graphs for the training graphs.
        config: Run hyper-parameters (see Config).
        save_dir: Directory (created if missing) that receives the plots,
            dt_metrics.csv and checkpoint.pkl written by this function.
        use_decomposition: True trains OneLayerDisentangledTF (separate A_ell /
            B_ell parameters); False trains OneLayerWholeMatrixTF (single W)
            and additionally tracks the projection error at every step.

    Returns:
        Dict of summary metrics: p, final loss, ER / two-chain accuracies
        (overall and capacity-partitioned for ER), final channel shares and
        norms, capacity proportions, average diameter and the full loss
        history. For the whole-matrix model the final projection error and its
        history are included as well.

    Side effects:
        Writes files into save_dir and prints progress/summary lines.
    """
    # Training is pinned to CPU; the CUDA selection below is intentionally disabled.
    # device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    device = torch.device("cpu")
    os.makedirs(save_dir, exist_ok=True)

    print(f"Training with p={p:.3f}")

    # Generate training data
    # (R_batch is used as the target of the loss below — presumably the
    # reachability/connectivity labels for A_batch; confirm in batch_er_graphs.)
    A_batch, R_batch = batch_er_graphs(
        config.n,
        p,
        config.batch_size,
        seed=config.seed,
        restrict_diam=config.restrict_diam,
    )
    A_batch = A_batch.to(device)
    R_batch = R_batch.to(device)

    # Analyze capacity of training graphs
    # (threshold 3 is hard-coded here; it equals the default capacity_threshold
    # of evaluate_accuracy_by_capacity)
    capacity_analysis = analyze_capacity(A_batch, capacity_threshold=3)

    # Analyze diameter distribution and create plot
    diameter_distribution = analyze_diameter_distribution(A_batch, config.n)
    diameter_plot_path = plot_diameter_distribution(
        diameter_distribution, config.n, p, save_dir
    )

    # Model
    if use_decomposition:
        model = OneLayerDisentangledTF(
            n=config.n,
            A_init_scale=0.1,
            B_init_scale=1.0,
            seed=config.seed,
        ).to(device)
    else:
        model = OneLayerWholeMatrixTF(
            n=config.n,
            W_init_scale=0.1,
            seed=config.seed,
        ).to(device)

    # Dynamics storage
    loss_history = []
    jshare_hist, ishare_hist = [], []  # A/B-channel shares
    normA_hist, normB_hist = [], []
    # Projection error is only tracked for the whole-matrix model; None marks
    # "not applicable" for the decomposed model and is checked further down.
    projection_error_hist = [] if not use_decomposition else None

    # Running absolute maxima for fixed color scales across steps (for plotting)
    # (seeded with a tiny positive value — presumably so a color range is never
    # exactly zero; confirm against plot_step_figure)
    vlims = {
        "A": 1e-6,
        "B": 1e-6,
        "W": 1e-6,
        "gA": 1e-6,
        "gB": 1e-6,
        "gW": 1e-6,
        "dA": 1e-6,
        "dB": 1e-6,
        "dW": 1e-6,
    }

    # Training loop with dynamics tracking
    for step in tqdm(range(1, config.num_steps + 1), desc=f"Training p={p:.3f}"):
        # Optionally resample graphs
        # (seed is offset by the step index so each step gets a different seed)
        if config.resample_each_step:
            A_batch, R_batch = batch_er_graphs(
                config.n,
                p,
                config.batch_size,
                seed=config.seed + step,
                restrict_diam=config.restrict_diam,
            )
            A_batch = A_batch.to(device)
            R_batch = R_batch.to(device)

        # -------- BEFORE update snapshots
        with torch.no_grad():
            if use_decomposition:
                A_before = model.A_ell.detach().clone()
                B_before = model.B_ell.detach().clone()
                # Compute full W before
                # W = A ⊗ I_n + B ⊗ J_n, with J_n the all-ones n×n matrix
                I = torch.eye(config.n, dtype=A_batch.dtype, device=device)
                J = torch.ones(config.n, config.n, dtype=A_batch.dtype, device=device)
                W_before = torch.kron(A_before, I) + torch.kron(B_before, J)
            else:
                W_before = model.W.detach().clone()
                A_before, B_before = model.project_W_to_channels()

        # Forward pass
        # (`intermediates` is not used anywhere in this function)
        Z, intermediates = model(A_batch, return_intermediates=True)

        # Compute gradients for visualization
        if config.use_mse:
            loss = loss_mse_on_phi(Z, R_batch, alpha=config.alpha)
        else:
            loss = loss_bce_on_phi_clamped(Z, R_batch, alpha=config.alpha, eps=1e-4)

        # Compute gradient with respect to Z before backward pass (for visualization)
        # First [0] picks the gradient tensor out of the returned tuple, second
        # [0] picks the first sample of the batch. retain_graph keeps the graph
        # alive for the loss.backward() call below.
        dLdZ0 = torch.autograd.grad(loss, Z, retain_graph=True)[0][0].detach().clone()

        # Backward pass
        model.zero_grad(set_to_none=True)
        loss.backward()

        # Store gradients before clipping
        # (the clones keep the unclipped values for plotting; clip_grad_norm_
        # rescales .grad in place, and the parameter update below reads .grad,
        # so the update itself uses the clipped gradient)
        if use_decomposition:
            gA_ch = model.A_ell.grad.detach().clone()
            gB_ch = model.B_ell.grad.detach().clone()
            # Clip gradients
            # (A_ell and B_ell are clipped independently, each to clip_norm)
            if config.clip_norm > 0:
                torch.nn.utils.clip_grad_norm_([model.A_ell], max_norm=config.clip_norm)
                torch.nn.utils.clip_grad_norm_([model.B_ell], max_norm=config.clip_norm)
        else:
            gW_full = model.W.grad.detach().clone()
            # Project the gradient onto A and B channels for visualization
            # The gradient is read as a 2×2 grid of n×n blocks (assumes W is
            # 2n × 2n — TODO confirm against OneLayerWholeMatrixTF). Per block:
            # A-channel entry = mean of the block's diagonal,
            # B-channel entry = mean over the whole block.
            gA_ch = torch.zeros(2, 2, dtype=A_batch.dtype, device=device)
            gB_ch = torch.zeros(2, 2, dtype=A_batch.dtype, device=device)
            for i in range(2):
                for j in range(2):
                    block = gW_full[
                        i * config.n : (i + 1) * config.n,
                        j * config.n : (j + 1) * config.n,
                    ]
                    gA_ch[i, j] = torch.diag(block).mean()
                    gB_ch[i, j] = block.mean()
            # Clip gradients
            if config.clip_norm > 0:
                torch.nn.utils.clip_grad_norm_([model.W], max_norm=config.clip_norm)

        # Update with optional learning rate decay
        # (plain gradient descent applied by hand; no torch optimizer is used)
        with torch.no_grad():
            if config.use_lr_decay:
                # Cosine decay: starts at 1, goes to 0
                # NOTE(review): progress is step / num_steps with step starting
                # at 1, so the factor is already slightly below 1 on the first
                # step and is exactly 0 at step == num_steps — the final
                # iteration therefore leaves the parameters unchanged.
                progress = step / config.num_steps
                decay_factor = 0.5 * (1 + math.cos(math.pi * progress))
                current_lr_A = config.lr_A * decay_factor
                current_lr_B = config.lr_B * decay_factor
            else:
                current_lr_A = config.lr_A
                current_lr_B = config.lr_B

            if use_decomposition:
                model.A_ell -= current_lr_A * model.A_ell.grad
                model.B_ell -= current_lr_B * model.B_ell.grad
            else:
                # For whole matrix, use the same learning rate for the entire W
                # (lr_A is the one applied; lr_B has no effect in this mode)
                model.W -= current_lr_A * model.W.grad

        # -------- AFTER update snapshots
        with torch.no_grad():
            if use_decomposition:
                A_after = model.A_ell.detach().clone()
                B_after = model.B_ell.detach().clone()
                # Compute full W after
                I = torch.eye(config.n, dtype=A_batch.dtype, device=device)
                J = torch.ones(config.n, config.n, dtype=A_batch.dtype, device=device)
                W_after = torch.kron(A_after, I) + torch.kron(B_after, J)
                # Deltas
                dA_ch = A_after - A_before
                dB_ch = B_after - B_before
                dW = W_after - W_before
                # Full gradient on W
                # (assembled from the unclipped channel gradients with the same
                # Kronecker structure as W; used for visualization only)
                gW = torch.kron(gA_ch, I) + torch.kron(gB_ch, J)
            else:
                W_after = model.W.detach().clone()
                A_after, B_after = model.project_W_to_channels()
                # Deltas
                dA_ch = A_after - A_before
                dB_ch = B_after - B_before
                dW = W_after - W_before
                # Use the full gradient we stored earlier
                gW = gW_full

        # Channel metrics at this step
        # (evaluated after the parameter update; the keys read below are
        # share_A, share_B, normA and normB)
        chan = channel_metrics(model, config.n, use_decomposition)

        # Track projection error for whole matrix model
        if not use_decomposition:
            # Stored as-is; later code reads element [0] as the absolute and
            # element [1] as the relative error.
            proj_error = model.compute_projection_error()
            projection_error_hist.append(proj_error)

            # Save W visualization at specific steps
            # (best-effort: a plotting failure is reported but does not stop training)
            if step in [1, 10, 100, 1000, 5000, 10000] and step <= config.num_steps:
                try:
                    fig_path = plot_W_and_projection(model, step, save_dir, p)
                    print(f"    Saved W visualization at step {step}: {fig_path}")
                except Exception as e:
                    print(
                        f"    Warning: Could not save W visualization at step {step}: {e}"
                    )

        # Update running maxima for color scales
        vlims["A"] = max(
            vlims["A"],
            float(torch.abs(A_before).max()),
            float(torch.abs(A_after).max()),
        )
        vlims["B"] = max(
            vlims["B"],
            float(torch.abs(B_before).max()),
            float(torch.abs(B_after).max()),
        )
        vlims["W"] = max(
            vlims["W"],
            float(torch.abs(W_before).max()),
            float(torch.abs(W_after).max()),
        )
        vlims["gA"] = max(vlims["gA"], float(torch.abs(gA_ch).max()))
        vlims["gB"] = max(vlims["gB"], float(torch.abs(gB_ch).max()))
        vlims["gW"] = max(vlims["gW"], float(torch.abs(gW).max()))
        vlims["dA"] = max(vlims["dA"], float(torch.abs(dA_ch).max()))
        vlims["dB"] = max(vlims["dB"], float(torch.abs(dB_ch).max()))
        vlims["dW"] = max(vlims["dW"], float(torch.abs(dW).max()))

        # Plot/save this step (every few steps to avoid too many files)
        # A step is plotted when plot_steps is on AND any of these holds:
        #   - it falls on the ~1/50-of-the-run grid,
        #   - it is one of the first 10 steps,
        #   - it is a multiple of 10 within the first 200 steps,
        #   - it is the final step.
        if (
            step % max(1, config.num_steps // 50) == 0
            or step <= 10
            or (step <= 200 and step % 10 == 0)
            or step == config.num_steps
        ) and config.plot_steps:
            # Prepare output dict for plotting
            # (A0 / R0 / Z0 / dLdZ0 all refer to the first sample of the batch)
            out = {
                "A0": A_batch[0],
                "R0": R_batch[0],
                "Z0": Z[0],
                "dLdZ0": dLdZ0,
                "A_before": A_before,
                "A_after": A_after,
                "B_before": B_before,
                "B_after": B_after,
                "W_before": W_before,
                "W_after": W_after,
                "dAch": dA_ch,
                "dBch": dB_ch,
                "dW": dW,
                "gAch": gA_ch,
                "gBch": gB_ch,
                "gW": gW,
                "loss": loss.item(),
                "chan": chan,
                "current_lr_A": current_lr_A,
                "current_lr_B": current_lr_B,
            }
            plot_step_figure(step, out, save_dir, vlims)

        # Record scalar dynamics
        # (the loss is the pre-update value for this step, while the channel
        # metrics were computed after the update)
        loss_history.append(loss.item())
        jshare_hist.append(chan["share_A"])
        ishare_hist.append(chan["share_B"])
        normA_hist.append(chan["normA"])
        normB_hist.append(chan["normB"])

    # Final evaluation
    # NOTE(review): evaluate_accuracy / evaluate_accuracy_by_capacity call
    # model.train() before returning, so this eval() does not persist past the
    # first evaluation call below.
    model.eval()

    # Evaluate on ER graphs (in-distribution) - fixed p=0.20 for evaluation
    # NOTE(review): the evaluation graphs always use p=0.20 and no diameter
    # restriction, independent of the training p and config.restrict_diam, so
    # they match the training distribution only when p == 0.20 and
    # restrict_diam is False. The large seed offset presumably keeps the
    # evaluation seed away from the training seeds (seed, seed + step).
    A_eval_er, R_eval_er = batch_er_graphs(
        config.n,
        0.20,  # Fixed p value for evaluation ER graphs
        config.eval_batch_size,
        seed=config.seed + 99999999,
        restrict_diam=False,
    )
    A_eval_er = A_eval_er.to(device)
    R_eval_er = R_eval_er.to(device)
    accuracy_er = evaluate_accuracy_by_capacity(
        model, A_eval_er, R_eval_er, config.alpha
    )

    # Sanity check: also run original evaluation function for comparison
    # (both functions should report identical overall avg/exact accuracy)
    accuracy_er_original = evaluate_accuracy(model, A_eval_er, R_eval_er, config.alpha)
    print(
        f"SANITY CHECK - Original function: avg={accuracy_er_original['avg_accuracy']:.4f}, exact={accuracy_er_original['exact_accuracy']:.4f}"
    )
    print(
        f"SANITY CHECK - New function: avg={accuracy_er['avg_accuracy']:.4f}, exact={accuracy_er['exact_accuracy']:.4f}"
    )
    print(
        f"SANITY CHECK - Difference: avg={abs(accuracy_er['avg_accuracy'] - accuracy_er_original['avg_accuracy']):.6f}, exact={abs(accuracy_er['exact_accuracy'] - accuracy_er_original['exact_accuracy']):.6f}"
    )

    # Evaluate on two-chains graphs (out-of-distribution)
    # (k = n // 2 is passed to batch_two_chains_graphs; no capacity partition
    # is computed for these graphs)
    A_eval_tc, R_eval_tc = batch_two_chains_graphs(
        config.n, config.eval_batch_size, k=config.n // 2
    )
    A_eval_tc = A_eval_tc.to(device)
    R_eval_tc = R_eval_tc.to(device)
    accuracy_tc = evaluate_accuracy(model, A_eval_tc, R_eval_tc, config.alpha)
    # Final channel metrics
    final_metrics = channel_metrics(model, config.n, use_decomposition)

    # Plot training dynamics
    dynamics_path = plot_training_dynamics(
        loss_history,
        normA_hist,
        normB_hist,
        jshare_hist,
        ishare_hist,
        config.num_steps,
        save_dir,
        p,
        projection_error_hist,
    )

    # Plot standalone projection error if we have data
    # (returns None for the decomposed model, whose history is None)
    # NOTE(review): projection_path is neither stored in the checkpoint nor
    # returned or printed.
    projection_path = plot_projection_error_standalone(
        projection_error_hist,
        loss_history,
        config.num_steps,
        save_dir,
        p,
    )

    # Plot standalone channel shares
    # NOTE(review): channel_path is likewise unused after this call.
    channel_path = plot_channel_shares_standalone(
        jshare_hist,
        ishare_hist,
        loss_history,
        config.num_steps,
        save_dir,
        p,
    )

    # Save training dynamics CSV
    # (column names here are the ones the read_*_from_csv helpers look up)
    steps = np.arange(1, config.num_steps + 1)
    csv_path = os.path.join(save_dir, "dt_metrics.csv")
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        if projection_error_hist is not None:
            # Include projection error columns for whole matrix model
            writer.writerow(
                [
                    "step",
                    "loss",
                    "normA",
                    "normB",
                    "A_share",
                    "B_share",
                    "proj_error_abs",
                    "proj_error_rel",
                ]
            )
            for s, L, nA, nB, ash, bsh, proj_err in zip(
                steps,
                loss_history,
                normA_hist,
                normB_hist,
                jshare_hist,
                ishare_hist,
                projection_error_hist,
            ):
                writer.writerow(
                    [
                        int(s),
                        float(L),
                        float(nA),
                        float(nB),
                        float(ash),
                        float(bsh),
                        float(proj_err[0]),
                        float(proj_err[1]),
                    ]
                )
        else:
            # Original format for decomposed model
            writer.writerow(["step", "loss", "normA", "normB", "A_share", "B_share"])
            for s, L, nA, nB, ash, bsh in zip(
                steps, loss_history, normA_hist, normB_hist, jshare_hist, ishare_hist
            ):
                writer.writerow(
                    [int(s), float(L), float(nA), float(nB), float(ash), float(bsh)]
                )

    # Save model checkpoint
    # (a single pickle holding weights, the Config object, histories, metrics
    # and the ER evaluation graphs)
    checkpoint = {
        "model_state_dict": model.state_dict(),
        "config": config,
        "use_decomposition": use_decomposition,
        "p": p,
        "loss_history": loss_history,
        "dynamics_history": {
            "jshare_hist": jshare_hist,
            "ishare_hist": ishare_hist,
            "normA_hist": normA_hist,
            "normB_hist": normB_hist,
            "projection_error_hist": projection_error_hist,
        },
        "final_metrics": final_metrics,
        "accuracy_er": accuracy_er,
        "accuracy_tc": accuracy_tc,
        "capacity_analysis": capacity_analysis,
        "diameter_distribution": diameter_distribution,
        "diameter_plot_path": diameter_plot_path,
        "dynamics_plot_path": dynamics_path,
        "metrics_csv_path": csv_path,
        "eval_graphs": {
            "A_eval_er": A_eval_er.cpu(),  # Save evaluation graphs for consistency
            "R_eval_er": R_eval_er.cpu(),
        },
    }

    # NOTE: the checkpoint is a pickle; it should only ever be loaded from
    # trusted locations.
    checkpoint_path = os.path.join(save_dir, "checkpoint.pkl")
    with open(checkpoint_path, "wb") as f:
        pickle.dump(checkpoint, f)

    # Return summary metrics
    # (raises IndexError on loss_history[-1] if config.num_steps is 0)
    results = {
        "p": p,
        "final_loss": loss_history[-1],
        "accuracy_er_avg": accuracy_er["avg_accuracy"],
        "accuracy_er_exact": accuracy_er["exact_accuracy"],
        "accuracy_er_within_capacity_avg": accuracy_er["within_capacity_avg_accuracy"],
        "accuracy_er_within_capacity_exact": accuracy_er[
            "within_capacity_exact_accuracy"
        ],
        "accuracy_er_beyond_capacity_avg": accuracy_er["beyond_capacity_avg_accuracy"],
        "accuracy_er_beyond_capacity_exact": accuracy_er[
            "beyond_capacity_exact_accuracy"
        ],
        "accuracy_tc_avg": accuracy_tc["avg_accuracy"],
        "accuracy_tc_exact": accuracy_tc["exact_accuracy"],
        "final_share_A": final_metrics["share_A"],
        "final_share_B": final_metrics["share_B"],
        "final_normA": final_metrics["normA"],
        "final_normB": final_metrics["normB"],
        "within_capacity_prop": capacity_analysis["within_capacity_prop"],
        "beyond_capacity_prop": capacity_analysis["beyond_capacity_prop"],
        "avg_diameter": capacity_analysis["avg_diameter"],
        "loss_history": loss_history,
    }

    # Add projection error data for non-decomposition models
    if not use_decomposition and projection_error_hist is not None:
        # Store final projection error values
        final_proj_error = (
            projection_error_hist[-1] if projection_error_hist else (0.0, 0.0)
        )
        results["final_proj_error_abs"] = final_proj_error[0]
        results["final_proj_error_rel"] = final_proj_error[1]
        results["projection_error_hist"] = projection_error_hist

    print(
        f"p={p:.3f}: ER avg_acc={accuracy_er['avg_accuracy']:.3f}, ER exact_acc={accuracy_er['exact_accuracy']:.3f}, "
        f"TC avg_acc={accuracy_tc['avg_accuracy']:.3f}, TC exact_acc={accuracy_tc['exact_accuracy']:.3f}, "
        f"A-share={final_metrics['share_A']:.3f}, B-share={final_metrics['share_B']:.3f}"
    )
    print(f"  Saved dynamics plot: {dynamics_path}")
    print(f"  Saved diameter distribution plot: {diameter_plot_path}")
    print(f"  Saved metrics CSV: {csv_path}")
    # NOTE(review): this line is printed even when config.plot_steps is False,
    # in which case no per-step figures were written by the loop above.
    print(f"  Saved per-step figures: {save_dir}/dt_step_*.png")

    return results


def load_checkpoint_and_extract_results(
    p: float, config: Config, save_dir: str, use_decomposition: bool = True
) -> Dict[str, Any]:
    """
    Build the summary dict for one p value from `save_dir/checkpoint.pkl`.

    This is the --resume path: no training happens here. Stored evaluation
    results are reused as-is, except that a checkpoint whose ER accuracy lacks
    the capacity-partitioned entries has its model rebuilt and re-evaluated.

    Args:
        p: Edge probability of the run; stored in the summary and used in logs.
        config: Run configuration; n, alpha, eval_batch_size and seed are read
            only when the model has to be re-evaluated.
        save_dir: Directory that holds checkpoint.pkl.
        use_decomposition: Picks OneLayerDisentangledTF (True) or
            OneLayerWholeMatrixTF (False) when rebuilding the model, and
            controls whether projection-error entries are added to the summary.

    Returns:
        Summary metrics, intended to match the format of train_single_run.

    Raises:
        FileNotFoundError: If checkpoint.pkl does not exist in save_dir.
    """
    ckpt_file = os.path.join(save_dir, "checkpoint.pkl")
    if not os.path.exists(ckpt_file):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_file}")

    print(f"Loading checkpoint for p={p:.3f} from {ckpt_file}")
    with open(ckpt_file, "rb") as handle:
        checkpoint = pickle.load(handle)

    # Pull the stored evaluation artefacts straight out of the checkpoint.
    er_acc = checkpoint["accuracy_er"]
    tc_acc = checkpoint["accuracy_tc"]
    metrics = checkpoint["final_metrics"]
    capacity = checkpoint["capacity_analysis"]
    losses = checkpoint["loss_history"]

    # Checkpoints without the capacity split need a fresh evaluation pass.
    if "within_capacity_avg_accuracy" not in er_acc:
        print(
            "Capacity-partitioned accuracy metrics not found in checkpoint. Re-evaluating..."
        )

        state = checkpoint["model_state_dict"]
        model_cls = (
            OneLayerDisentangledTF if use_decomposition else OneLayerWholeMatrixTF
        )
        model = model_cls(config.n, config.alpha)
        model.load_state_dict(state)
        model.eval()

        # Prefer the graphs the run was originally evaluated on.
        if "eval_graphs" not in checkpoint:
            print("Generating new evaluation graphs...")
            A_eval_er, R_eval_er = batch_er_graphs(
                config.n,
                0.20,  # evaluation graphs always use this fixed p
                config.eval_batch_size,
                seed=config.seed + 99999999,  # offset from the training seed
                restrict_diam=False,  # evaluation graphs are not diameter-filtered
            )
        else:
            print("Using saved evaluation graphs for consistency...")
            saved_graphs = checkpoint["eval_graphs"]
            A_eval_er = saved_graphs["A_eval_er"]
            R_eval_er = saved_graphs["R_eval_er"]

        # Put the graphs on whatever device the freshly built model lives on.
        device = next(model.parameters()).device
        A_eval_er = A_eval_er.to(device)
        R_eval_er = R_eval_er.to(device)

        er_acc = evaluate_accuracy_by_capacity(
            model, A_eval_er, R_eval_er, config.alpha
        )

        # Cross-check the partitioned evaluator against the plain one.
        baseline = evaluate_accuracy(model, A_eval_er, R_eval_er, config.alpha)
        print(
            f"RESUME SANITY CHECK - Original function: avg={baseline['avg_accuracy']:.4f}, exact={baseline['exact_accuracy']:.4f}"
        )
        print(
            f"RESUME SANITY CHECK - New function: avg={er_acc['avg_accuracy']:.4f}, exact={er_acc['exact_accuracy']:.4f}"
        )
        avg_gap = abs(er_acc["avg_accuracy"] - baseline["avg_accuracy"])
        exact_gap = abs(er_acc["exact_accuracy"] - baseline["exact_accuracy"])
        print(
            f"RESUME SANITY CHECK - Difference: avg={avg_gap:.6f}, exact={exact_gap:.6f}"
        )

    # Assemble the summary; (output key, source dict, source key) in output order.
    field_map = (
        ("accuracy_er_avg", er_acc, "avg_accuracy"),
        ("accuracy_er_exact", er_acc, "exact_accuracy"),
        ("accuracy_tc_avg", tc_acc, "avg_accuracy"),
        ("accuracy_tc_exact", tc_acc, "exact_accuracy"),
        ("final_share_A", metrics, "share_A"),
        ("final_share_B", metrics, "share_B"),
        ("final_normA", metrics, "normA"),
        ("final_normB", metrics, "normB"),
        ("within_capacity_prop", capacity, "within_capacity_prop"),
        ("beyond_capacity_prop", capacity, "beyond_capacity_prop"),
        ("avg_diameter", capacity, "avg_diameter"),
        ("accuracy_er_within_capacity_avg", er_acc, "within_capacity_avg_accuracy"),
        (
            "accuracy_er_within_capacity_exact",
            er_acc,
            "within_capacity_exact_accuracy",
        ),
        ("accuracy_er_beyond_capacity_avg", er_acc, "beyond_capacity_avg_accuracy"),
        (
            "accuracy_er_beyond_capacity_exact",
            er_acc,
            "beyond_capacity_exact_accuracy",
        ),
    )
    results = {"p": p, "final_loss": losses[-1]}
    for out_key, source, source_key in field_map:
        results[out_key] = source[source_key]
    results["loss_history"] = losses

    # Whole-matrix (non-decomposition) runs may carry a projection-error history.
    if not use_decomposition:
        proj_hist = checkpoint.get("dynamics_history", {}).get(
            "projection_error_hist", None
        )
        if proj_hist is not None:
            # An empty history falls back to zero error for the final values.
            last_entry = proj_hist[-1] if proj_hist else (0.0, 0.0)
            results["final_proj_error_abs"] = last_entry[0]
            results["final_proj_error_rel"] = last_entry[1]
            results["projection_error_hist"] = proj_hist

    report = (
        f"p={p:.3f}: ER avg_acc={results['accuracy_er_avg']:.3f}, ER exact_acc={results['accuracy_er_exact']:.3f}, "
        f"TC avg_acc={results['accuracy_tc_avg']:.3f}, TC exact_acc={results['accuracy_tc_exact']:.3f}, "
        f"A-share={results['final_share_A']:.3f}, B-share={results['final_share_B']:.3f}"
    )
    print(report)

    return results


def regenerate_training_dynamics_plot(p: float, save_dir: str, use_decomposition: bool):
    """
    Regenerate the training-dynamics figures for one p value from its checkpoint.

    Used in resume mode. The function is best-effort: a missing checkpoint or
    any error while reading/plotting is reported on stdout and never raised.

    Figures produced (via the plotting helpers of this module):
    - the combined training dynamics plot,
    - a standalone projection-error plot, when a projection-error history is
      available (from the checkpoint, or from dt_metrics.csv for
      non-decomposition runs),
    - a standalone channel-shares plot (from the checkpoint, or from
      dt_metrics.csv when the checkpoint has no share history).

    Args:
        p: Edge probability of the run; used in messages and passed to the plots.
        save_dir: Run directory containing checkpoint.pkl (and dt_metrics.csv).
        use_decomposition: When False, a missing projection-error history is
            looked up in dt_metrics.csv.

    Returns:
        None.
    """
    checkpoint_path = os.path.join(save_dir, "checkpoint.pkl")

    if not os.path.exists(checkpoint_path):
        print(f"Warning: Checkpoint not found at {checkpoint_path}")
        return

    def _is_missing(history) -> bool:
        # A history that was never recorded shows up as None or as an empty
        # sequence (the default used below), so both count as "not available".
        return history is None or len(history) == 0

    try:
        # NOTE: pickle executes arbitrary code on load; only open checkpoints
        # written by this script.
        with open(checkpoint_path, "rb") as f:
            checkpoint = pickle.load(f)

        # Extract training dynamics data
        loss_history = checkpoint["loss_history"]
        dynamics_history = checkpoint.get("dynamics_history", {})

        jshare_hist = dynamics_history.get("jshare_hist", [])
        ishare_hist = dynamics_history.get("ishare_hist", [])
        normA_hist = dynamics_history.get("normA_hist", [])
        normB_hist = dynamics_history.get("normB_hist", [])
        projection_error_hist = dynamics_history.get("projection_error_hist", None)

        num_steps = len(loss_history)
        if num_steps == 0:
            print(f"  Warning: No training history found for p={p:.3f}")
            return

        # Regenerate the combined training dynamics plot
        dynamics_path = plot_training_dynamics(
            loss_history,
            normA_hist,
            normB_hist,
            jshare_hist,
            ishare_hist,
            num_steps,
            save_dir,
            p,
            projection_error_hist,
        )
        print(f"  Regenerated training dynamics plot: {dynamics_path}")

        # Standalone projection error plot.
        # If the checkpoint has no projection error, try to read it from the CSV.
        if projection_error_hist is None and not use_decomposition:
            csv_path = os.path.join(save_dir, "dt_metrics.csv")
            projection_error_hist = read_projection_error_from_csv(csv_path)
            # When the errors come from the CSV, prefer the CSV losses too so
            # both series have the same length.
            # NOTE(review): the replaced loss_history is also what the
            # channel-shares plot below receives - confirm this is intended.
            if projection_error_hist is not None:
                csv_loss_history = read_loss_from_csv(csv_path)
                if csv_loss_history is not None and len(csv_loss_history) == len(
                    projection_error_hist
                ):
                    loss_history = csv_loss_history

        if projection_error_hist is not None:
            projection_path = plot_projection_error_standalone(
                projection_error_hist,
                loss_history,
                num_steps,
                save_dir,
                p,
            )
            if projection_path:
                print(f"  Regenerated projection error plot: {projection_path}")

        # Standalone channel shares plot.
        # Fall back to the CSV when the checkpoint carries no share history.
        # The histories default to [] above, so emptiness (not just None) has
        # to be treated as "not in checkpoint" for the fallback to ever run.
        if _is_missing(jshare_hist) or _is_missing(ishare_hist):
            csv_path = os.path.join(save_dir, "dt_metrics.csv")
            jshare_hist_csv, ishare_hist_csv = read_channel_shares_from_csv(csv_path)
            if jshare_hist_csv is not None and ishare_hist_csv is not None:
                jshare_hist = jshare_hist_csv
                ishare_hist = ishare_hist_csv

        if jshare_hist is not None and ishare_hist is not None:
            channel_path = plot_channel_shares_standalone(
                jshare_hist,
                ishare_hist,
                loss_history,
                num_steps,
                save_dir,
                p,
            )
            if channel_path:
                print(f"  Regenerated channel shares plot: {channel_path}")

    except Exception as e:
        # Deliberately broad: plot regeneration must not abort the resume loop.
        print(f"  Error regenerating training dynamics plot for p={p:.3f}: {e}")


# ----------------------------
# Main capacity analysis
# ----------------------------


def _write_capacity_summary_csv(
    all_results: List[Dict[str, Any]], csv_path: str
) -> None:
    """Write one CSV row per run, with every value placed under its own header."""
    base_columns = [
        "p",
        "final_loss",
        "accuracy_er_avg",
        "accuracy_er_exact",
        "accuracy_er_within_capacity_avg",
        "accuracy_er_within_capacity_exact",
        "accuracy_er_beyond_capacity_avg",
        "accuracy_er_beyond_capacity_exact",
        "accuracy_tc_avg",
        "accuracy_tc_exact",
        "final_share_A",
        "final_share_B",
        "final_normA",
        "final_normB",
        "within_capacity_prop",
        "beyond_capacity_prop",
        "avg_diameter",
    ]
    projection_columns = ["final_proj_error_abs", "final_proj_error_rel"]

    # Projection-error columns are only emitted when at least one run has them.
    has_projection_error = any(
        r.get("final_proj_error_abs") is not None for r in all_results
    )

    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        header = list(base_columns)
        if has_projection_error:
            header.extend(projection_columns)
        writer.writerow(header)

        for result in all_results:
            # Rows are derived from the same column list as the header, so the
            # two can never drift out of order.
            row = [result[column] for column in base_columns]
            if has_projection_error:
                # Runs without projection-error data get empty cells.
                row.extend(result.get(column, "") for column in projection_columns)
            writer.writerow(row)


def run_capacity_analysis(
    restrict_diam: Union[int, bool] = 3,
    resume: bool = False,
    use_decomposition: bool = True,
    specific_p: Union[float, None] = None,
):
    """
    Run the full capacity analysis over a range of edge probabilities.

    For every p value a run directory `<base_dir>/p_{p:.3f}` is used. Runs are
    either trained from scratch or, in resume mode, loaded from their
    checkpoint (falling back to training when the checkpoint is missing).
    Afterwards a summary CSV (`capacity_metrics.csv`) is written and, when more
    than one p value was processed, the cross-p analysis plots are generated.

    Args:
        restrict_diam: Diameter restriction forwarded to the Config; also
            encoded in the output folder name.
        resume: Load existing checkpoints instead of retraining.
        use_decomposition: Selects the decomposed model variant; also selects
            the learning rate and the output folder suffix.
        specific_p: If given, only this p value is analysed; otherwise 30
            values evenly spaced in [0.02, 0.6] are used.
    """
    lr_shared = 0.01 if use_decomposition else 0.1
    config = Config(
        n=8,
        alpha=5.0,
        batch_size=1024,
        lr_A=lr_shared,
        lr_B=lr_shared,
        clip_norm=1,
        seed=42,
        use_mse=False,
        num_steps=10_000,
        resample_each_step=False,
        restrict_diam=restrict_diam,
        eval_batch_size=10_000,
        use_lr_decay=True,
        plot_steps=True,
    )

    # Range of p values to test
    if specific_p is not None:
        p_values = [specific_p]
        print(f"Running analysis for specific p value: {specific_p}")
    else:
        p_values = np.linspace(0.02, 0.6, 30)  # 30 different p values
        print(
            f"Running analysis for {len(p_values)} p values from {p_values[0]:.3f} to {p_values[-1]:.3f}"
        )

    # Folder name reflects restrict_diam and the decomposition setting.
    # bool must be tested before the generic case because bool is an int.
    # The "restrict_dim" spelling is kept so existing run folders still resolve.
    if isinstance(config.restrict_diam, bool):
        diam_suffix = (
            "/restrict_dim=3" if config.restrict_diam else "/restrict_dim=False"
        )
    else:
        diam_suffix = f"/restrict_dim={config.restrict_diam}"

    decomp_suffix = "" if use_decomposition else "_no_decomp"
    base_dir = f"analysis_capacity_0915{decomp_suffix}{diam_suffix}"
    analysis_dir = os.path.join(base_dir, "analysis")
    os.makedirs(analysis_dir, exist_ok=True)

    # Generate evaluation ER diameter distribution before training
    # (called for its plot; the returned values are not used here).
    _eval_plot_path, _eval_diameter_dist = analyze_eval_er_diameter_distribution(
        config, analysis_dir, eval_p=0.20
    )

    # Store all results
    all_results = []

    if resume:
        print("Resume mode: Loading checkpoints and extracting evaluation results...")
    else:
        print("Training mode: Running full training for each p value...")

    for p in p_values:
        save_dir = os.path.join(base_dir, f"p_{p:.3f}")
        if not resume:
            results = train_single_run(p, config, save_dir, use_decomposition)
        else:
            try:
                results = load_checkpoint_and_extract_results(
                    p, config, save_dir, use_decomposition
                )
            except FileNotFoundError as e:
                # No checkpoint for this p: train it now instead of skipping.
                print(f"Warning: {e}")
                print(f"Checkpoint not found for p={p:.3f} - running training instead")
                results = train_single_run(p, config, save_dir, use_decomposition)
            else:
                # Regenerate training dynamics plot for each loaded p value
                regenerate_training_dynamics_plot(p, save_dir, use_decomposition)
        all_results.append(results)

    if not all_results:
        print("No results found. Cannot generate analysis plots or CSV.")
        if resume:
            print("Make sure you have run the training first before using --resume.")
        return

    # Save summary CSV
    csv_path = os.path.join(base_dir, "capacity_metrics.csv")
    _write_capacity_summary_csv(all_results, csv_path)

    # Generate analysis plots
    if len(all_results) > 1:
        generate_analysis_plots(all_results, analysis_dir)
        print(f"Analysis plots saved to: {analysis_dir}")
    else:
        print("Skipping analysis plots (only one p value)")

    print("Capacity analysis complete!")
    print(f"Results saved to: {base_dir}")
    print(f"Summary CSV saved to: {csv_path}")


def generate_projection_error_dynamics_plot(
    all_results: List[Dict[str, Any]], output_dir: str
):
    """
    Plot how the projection error evolves during training across p values.

    Two stacked panels are saved to
    `output_dir/projection_error_training_dynamics.png`:
    1. One line per p value: relative projection error against training step.
    2. Final relative (left axis) and final absolute (right axis) projection
       error against edge probability.

    Results without a non-empty "projection_error_hist" are ignored. If none
    remain, a note is printed and nothing is plotted.

    Args:
        all_results: Per-run summary dicts. Each history entry is indexed as
            (absolute_error, relative_error).
        output_dir: Directory that receives the PNG (not created here).
    """

    def _has_history(result: Dict[str, Any]) -> bool:
        # An empty history has no final entry to plot, so it is treated the
        # same as a missing one (indexing [-1] on it would raise IndexError).
        history = result.get("projection_error_hist")
        return history is not None and len(history) > 0

    # Keep only results that actually carry projection error data
    projection_error_results = [r for r in all_results if _has_history(r)]

    if not projection_error_results:
        print(
            "No projection error data found - skipping projection error dynamics plot"
        )
        return

    print(
        f"Generating projection error dynamics plot for {len(projection_error_results)} p values..."
    )

    # Create figure with two stacked subplots
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))

    # Plot 1: Relative projection error over training steps for different p values
    for result in projection_error_results:
        p = result["p"]
        projection_error_hist = result["projection_error_hist"]

        steps = np.arange(1, len(projection_error_hist) + 1)
        error_rel = [err[1] for err in projection_error_hist]

        # Each call takes the next colour from matplotlib's default cycle
        ax1.plot(steps, error_rel, label=f"p={p:.3f}", alpha=0.7, linewidth=1.5)

    ax1.set_xlabel("Training Step")
    ax1.set_ylabel("Relative Projection Error")
    ax1.set_title("Projection Error Dynamics During Training (Relative)")
    ax1.grid(True, alpha=0.3)
    ax1.set_ylim(bottom=0)

    # Add legend but limit number of entries to avoid clutter
    if len(projection_error_results) <= 10:
        ax1.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    else:
        # For many p values, just add a note
        ax1.text(
            0.02,
            0.98,
            f"Showing {len(projection_error_results)} p values",
            transform=ax1.transAxes,
            verticalalignment="top",
            bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
        )

    # Plot 2: Final projection error vs p value
    p_values = [r["p"] for r in projection_error_results]
    final_entries = [r["projection_error_hist"][-1] for r in projection_error_results]
    final_rel_errors = [entry[1] for entry in final_entries]
    final_abs_errors = [entry[0] for entry in final_entries]

    # Absolute error gets its own y-axis because its scale differs
    ax2_twin = ax2.twinx()

    line1 = ax2.plot(
        p_values,
        final_rel_errors,
        "o-",
        color="blue",
        label="Final Relative Error",
        linewidth=2,
        markersize=6,
    )
    line2 = ax2_twin.plot(
        p_values,
        final_abs_errors,
        "s-",
        color="red",
        label="Final Absolute Error",
        linewidth=2,
        markersize=6,
    )

    ax2.set_xlabel("Edge Probability (p)")
    ax2.set_ylabel("Final Relative Projection Error", color="blue")
    ax2_twin.set_ylabel("Final Absolute Projection Error", color="red")
    ax2.set_title("Final Projection Error vs Edge Probability")
    ax2.grid(True, alpha=0.3)
    ax2.set_ylim(bottom=0)
    ax2_twin.set_ylim(bottom=0)

    # Combined legend covering the lines of both y-axes
    lines = line1 + line2
    labels = [line.get_label() for line in lines]
    ax2.legend(lines, labels, loc="upper right")

    plt.tight_layout()

    # Save plot
    plot_path = os.path.join(output_dir, "projection_error_training_dynamics.png")
    fig.savefig(plot_path, dpi=200, bbox_inches="tight")
    plt.close(fig)

    print(f"Generated projection error dynamics plot: {plot_path}")


def generate_analysis_plots(all_results: List[Dict[str, Any]], output_dir: str):
    """
    Render the cross-p summary figures for a sweep of runs into `output_dir`.

    Every figure plots one or more summary metrics against the edge
    probability p and is saved as a PNG (dpi=200). The last step delegates to
    generate_projection_error_dynamics_plot for the per-step error figure.

    Args:
        all_results: One summary dict per p value, in plotting order.
        output_dir: Directory that receives the PNG files.
    """
    p_values = [r["p"] for r in all_results]
    x_label = "Edge probability (p)"
    cached_columns: Dict[str, list] = {}

    def column(key):
        # Extract a metric across all runs once; later figures reuse the list.
        if key not in cached_columns:
            cached_columns[key] = [r[key] for r in all_results]
        return cached_columns[key]

    def line_figure(size, curves, y_label, title, filename, with_legend):
        # One figure: open, resolve data, draw curves, decorate, save, close.
        plt.figure(figsize=size)
        resolved = [(column(key), marker, style) for key, marker, style in curves]
        for values, marker, style in resolved:
            plt.plot(p_values, values, marker, linewidth=2, **style)
        plt.xlabel(x_label)
        plt.ylabel(y_label)
        plt.title(title)
        if with_legend:
            plt.legend()
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, filename), dpi=200, bbox_inches="tight")
        plt.close()

    def er_partition_curves(suffix):
        # Overall / within / beyond curves for the per-metric ER figures.
        return [
            (
                f"accuracy_er_{suffix}",
                "o-",
                dict(markersize=8, label="Overall ER graphs"),
            ),
            (
                f"accuracy_er_within_capacity_{suffix}",
                "s-",
                dict(markersize=6, label="Within capacity (diam ≤ 3)", color="green"),
            ),
            (
                f"accuracy_er_beyond_capacity_{suffix}",
                "^-",
                dict(markersize=6, label="Beyond capacity (diam > 3)", color="red"),
            ),
        ]

    def comparison_curves(suffix):
        # ER (overall and partitioned) plus two-chain curves for figures 8a/8b.
        return [
            (
                f"accuracy_er_{suffix}",
                "o-",
                dict(markersize=8, label="ER graphs (overall)", color="blue"),
            ),
            (
                f"accuracy_er_within_capacity_{suffix}",
                "s-",
                dict(
                    markersize=6, label="ER within capacity (diam ≤ 3)", color="green"
                ),
            ),
            (
                f"accuracy_er_beyond_capacity_{suffix}",
                "^-",
                dict(markersize=6, label="ER beyond capacity (diam > 3)", color="red"),
            ),
            (
                f"accuracy_tc_{suffix}",
                "d-",
                dict(markersize=6, label="Two-chains (out-of-dist)", color="orange"),
            ),
        ]

    # 1. Average accuracy on ER graphs (in-distribution), capacity-partitioned
    line_figure(
        (10, 6),
        er_partition_curves("avg"),
        "Average Accuracy on ER graphs",
        "In-Distribution Average Performance vs Edge Probability",
        "accuracy_er_avg.png",
        True,
    )

    # 1b. Exact-match accuracy on ER graphs, capacity-partitioned
    line_figure(
        (10, 6),
        er_partition_curves("exact"),
        "Exact Match Accuracy on ER graphs",
        "In-Distribution Exact Match Performance vs Edge Probability",
        "accuracy_er_exact.png",
        True,
    )

    # 2. Average accuracy on two-chain graphs (out-of-distribution)
    line_figure(
        (8, 6),
        [("accuracy_tc_avg", "o-", dict(markersize=8, color="orange"))],
        "Average Accuracy on two-chain graphs",
        "Out-of-Distribution Average Performance vs Edge Probability",
        "accuracy_two_chains_avg.png",
        False,
    )

    # 2b. Exact-match accuracy on two-chain graphs (out-of-distribution)
    line_figure(
        (8, 6),
        [("accuracy_tc_exact", "o-", dict(markersize=8, color="red"))],
        "Exact Match Accuracy on two-chain graphs",
        "Out-of-Distribution Exact Match Performance vs Edge Probability",
        "accuracy_two_chains_exact.png",
        False,
    )

    # 3. Final energy shares
    line_figure(
        (8, 6),
        [
            (
                "final_share_A",
                "o-",
                dict(markersize=8, color="blue", label="A-channel share"),
            ),
            (
                "final_share_B",
                "s-",
                dict(markersize=8, color="orange", label="B-channel share"),
            ),
        ],
        "Energy share in ||W||²",
        "Channel Energy Shares vs Edge Probability",
        "energy_shares.png",
        True,
    )

    # 4. Final channel norms
    line_figure(
        (8, 6),
        [
            ("final_normA", "o-", dict(markersize=8, color="blue", label="||A_ell||_F")),
            (
                "final_normB",
                "s-",
                dict(markersize=8, color="orange", label="||B_ell||_F"),
            ),
        ],
        "Frobenius norm",
        "Channel Parameter Norms vs Edge Probability",
        "parameter_norms.png",
        True,
    )

    # 5. Proportion of graphs within / beyond capacity
    line_figure(
        (8, 6),
        [
            (
                "within_capacity_prop",
                "o-",
                dict(markersize=8, label="Within capacity (diam ≤ 3)"),
            ),
            (
                "beyond_capacity_prop",
                "s-",
                dict(markersize=8, label="Beyond capacity (diam > 3)"),
            ),
        ],
        "Proportion of graphs",
        "Graph Capacity Distribution vs Edge Probability",
        "capacity_distribution.png",
        True,
    )

    # 6. Average diameter
    line_figure(
        (8, 6),
        [("avg_diameter", "o-", dict(markersize=8, color="purple"))],
        "Average diameter",
        "Average Graph Diameter vs Edge Probability",
        "average_diameter.png",
        False,
    )

    # 7. Final loss
    line_figure(
        (8, 6),
        [("final_loss", "o-", dict(markersize=8, color="red"))],
        "Final training loss",
        "Final Training Loss vs Edge Probability",
        "final_loss.png",
        False,
    )

    # 8a. Combined average accuracy, capacity-partitioned
    line_figure(
        (12, 6),
        comparison_curves("avg"),
        "Average Accuracy",
        "Model Average Performance: Capacity-Partitioned Analysis",
        "accuracy_comparison_avg.png",
        True,
    )

    # 8b. Combined exact-match accuracy, capacity-partitioned
    line_figure(
        (12, 6),
        comparison_curves("exact"),
        "Exact Match Accuracy",
        "Model Exact Match Performance: Capacity-Partitioned Analysis",
        "accuracy_comparison_exact.png",
        True,
    )

    # 9. Final relative projection error vs p; only drawn when at least one
    #    result carries projection-error values.
    if any(r.get("final_proj_error_abs") is not None for r in all_results):
        plt.figure(figsize=(8, 6))

        # Keep only the runs that have a projection error, with their p value.
        runs_with_error = [
            (x, r)
            for x, r in zip(p_values, all_results)
            if r.get("final_proj_error_abs") is not None
        ]
        valid_p_values = [x for x, _ in runs_with_error]
        valid_proj_error_rel = [r["final_proj_error_rel"] for _, r in runs_with_error]

        if valid_proj_error_rel:
            plt.plot(
                valid_p_values,
                valid_proj_error_rel,
                "o-",
                linewidth=2,
                markersize=8,
                color="blue",
                label="Relative Projection Error",
            )

            plt.xlabel(x_label)
            plt.ylabel("Relative Projection Error")
            plt.title(
                "Decomposition Projection Error vs Edge Probability\n(Non-decomposition Model)"
            )
            plt.grid(True, alpha=0.3)
            plt.legend()

            plt.tight_layout()
            plt.savefig(
                os.path.join(output_dir, "projection_error_vs_p.png"),
                dpi=200,
                bbox_inches="tight",
            )
            plt.close()

    # 10. Per-step projection error dynamics across all p values
    generate_projection_error_dynamics_plot(all_results, output_dir)

    print(f"Generated analysis plots in {output_dir}")


def parse_arguments(argv: Union[List[str], None] = None) -> argparse.Namespace:
    """Parse command line arguments for the capacity analysis script.

    Args:
        argv: Argument strings to parse, without the program name. When None
            (the default) argparse falls back to ``sys.argv[1:]``, so existing
            zero-argument callers behave exactly as before. Passing an
            explicit list lets the parser be driven programmatically.

    Returns:
        Namespace with the attributes:
            restrict_diam (str): raw, unparsed value; defaults to "3".
            resume (bool): True when ``--resume`` is given.
            use_decomposition (bool): True when ``--use_decomposition`` is given.
            p (float or None): value of ``--p``; None when not supplied.
    """
    parser = argparse.ArgumentParser(
        description="Capacity Analysis for Disentangled Transformer on Erdős–Rényi graphs",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Kept as a string here because the option accepts 'true'/'false' as well
    # as integers; conversion is left to the caller.
    parser.add_argument(
        "--restrict_diam",
        type=str,
        default="3",
        help="Diameter restriction for training graphs. Use 'false' for no restriction, 'true' for default (3), or an integer value",
    )

    parser.add_argument(
        "--resume",
        action="store_true",
        help="Resume from previously saved checkpoints and only run evaluation (skip training)",
    )

    parser.add_argument(
        "--use_decomposition",
        action="store_true",
        default=False,
        help="Use decomposed W = A⊗I + B⊗J. If False, use whole matrix W with I/J projections",
    )

    parser.add_argument(
        "--p",
        type=float,
        default=None,
        help="Run analysis for a specific p value only. If not provided, runs for all p values in the default range",
    )

    return parser.parse_args(argv)


def parse_restrict_diam(restrict_diam_str: str) -> Union[int, bool]:
    """Convert the raw ``--restrict_diam`` string into a bool or an int.

    The keywords 'false' and 'true' are matched case-insensitively and
    ignoring surrounding whitespace, which is consistent with ``int()``
    already tolerating surrounding whitespace for the numeric form.

    Args:
        restrict_diam_str: Raw value, e.g. "false", "true" or "3".

    Returns:
        False for 'false', True for 'true', otherwise the value as an int.
        How True is mapped to a concrete diameter is up to the consumer of
        this value; it is not decided here.

    Raises:
        ValueError: If the value is neither keyword nor a valid integer.
    """
    keyword = restrict_diam_str.strip().lower()
    if keyword == "false":
        return False
    if keyword == "true":
        return True

    try:
        return int(restrict_diam_str)
    except ValueError as err:
        # Re-raise with a message that lists the accepted forms, keeping the
        # original int() failure attached as the explicit cause.
        raise ValueError(
            f"Invalid restrict_diam value: {restrict_diam_str}. Use 'false', 'true', or an integer."
        ) from err


# ----------------------------
# Main
# ----------------------------

if __name__ == "__main__":
    # Script entry point: read the CLI options, normalise the diameter
    # restriction, report the configuration, then launch the analysis.
    args = parse_arguments()
    restrict_diam = parse_restrict_diam(args.restrict_diam)

    # Collect the keyword arguments once so the call below stays compact.
    analysis_kwargs = {
        "restrict_diam": restrict_diam,
        "resume": args.resume,
        "use_decomposition": args.use_decomposition,
        "specific_p": args.p,
    }

    print(
        f"Running capacity analysis with restrict_diam={restrict_diam}, resume={args.resume}, use_decomposition={args.use_decomposition}"
    )
    run_capacity_analysis(**analysis_kwargs)
