import time
import operator
import jax
import jax.numpy as jnp  # JAX NumPy
from flax import linen as nn  # Linen API
from clu import metrics
from flax import struct
import optax  # Common loss functions and optimizers
import chex
import numpy as np
import os
import sys
import json
import flax.serialization as serialization
import jax.tree_util
from numpy.fft import fft2, ifft2
import plotly.io as pio
from flax.core.frozen_dict import freeze, unfreeze
from typing import Dict, Any, Tuple, Union, List
import re

jax.config.update("jax_traceback_filtering", 'off')
# jax.config.update("jax_debug_nans", True)

import optimizers
import training
from pca_diffusion_plots_w_helpers_transformers import generate_pdf_plots_for_matrix
from plotting import make_cluster_html_pages

#############################
### 1) DEFINE THE TRANSFORMER IN FLAX
#############################
# Directory for loss logs.  NOTE(review): not referenced anywhere in the visible
# part of this script — confirm it is still used elsewhere, otherwise remove.
loss_log_dir = ""

def check_unembed_rank(params, *, atol=1e-17, label="(unknown)", min_rank=2) -> None:
    """Check that ``params['W_U']`` has not collapsed in rank; print its top-5 σ.

    Singular values of W_U are computed in float64 and normalised by the
    largest one; every normalised value strictly above ``atol`` counts towards
    the effective rank.

    Args:
        params: Parameter pytree holding a 2-D array under the key ``"W_U"``.
        atol: Relative threshold (σ_i / σ_max).  NOTE: the default 1e-17 is
            below float64 epsilon, so numerically tiny singular values still
            count; effectively only exact zeros are excluded.
        label: Free-form tag included in the printed line.
        min_rank: Smallest effective rank that is accepted.  The default of 2
            is the historical threshold; pass ``min(W.shape)`` for a strict
            full-rank check.

    Raises:
        RuntimeError: If the effective rank is below ``min_rank``.
    """
    W = np.asarray(params["W_U"]).astype(np.float64)  # promote to f64
    s = np.linalg.svd(W, compute_uv=False)
    s_max = s.max() if s.size else 0.0
    if s_max > 0:
        eff_rank = int((s / s_max > atol).sum())
    else:
        # All-zero matrix: avoid 0/0 -> NaN and report rank 0 explicitly.
        eff_rank = 0
    top5 = ", ".join(f"{v:8.3e}" for v in s[:5])
    print(f"[rank-check {label}] effective-rank={eff_rank:>2}   top-5 σ: {top5}")
    if eff_rank < min_rank:
        raise RuntimeError(
            f"W_U lost rank! ({eff_rank} < {min_rank}; full rank would be "
            f"{min(W.shape)})  Did you mutate a shared buffer?"
        )

class HookPoint(nn.Module):
    """Identity module that records its input in the "intermediates" collection.

    On every call the input is stored with ``sow`` (the reducer discards any
    previously stored value, so only the latest activation is kept) and then
    returned unchanged.

    Attributes:
        key: Explicit name for the stored value.  When falsy (None by default)
            the module's scope path, joined with "/", is used instead.
    """
    key: str | None = None

    @nn.compact
    def __call__(self, x):
        sow_name = self.key if self.key else "/".join(self.scope.path)

        def _keep_latest(_previous, current):
            return current

        self.sow("intermediates", sow_name, x, reduce_fn=_keep_latest)
        return x


class Embed(nn.Module):
    """Token embedding: integer ids -> ``d_model``-dimensional vectors.

    Wraps an (auto-named) ``nn.Embed`` whose table is initialised from
    normal(0, 1/sqrt(d_model)).
    """
    d_vocab: int
    d_model: int

    @nn.compact
    def __call__(self, x):
        """Embed ``x`` ([batch, seq_len] ids) into [batch, seq_len, d_model]."""
        init_std = 1.0 / np.sqrt(self.d_model)
        lookup = nn.Embed(
            num_embeddings=self.d_vocab,
            features=self.d_model,
            embedding_init=nn.initializers.normal(stddev=init_std),
        )
        return lookup(x)

class PosEmbed(nn.Module):
    """Learned absolute position embeddings added to the input activations."""
    max_ctx: int
    d_model: int

    def setup(self):
        # One row per position: shape (max_ctx, d_model).
        init = nn.initializers.normal(stddev=1.0 / np.sqrt(self.d_model))
        self.W_pos = self.param("W_pos", init, (self.max_ctx, self.d_model))

    def __call__(self, x):
        """Add the first ``x.shape[1]`` position rows to ``x`` ([batch, seq_len, d_model]).

        The (seq_len, d_model) slice is broadcast over the batch axis.
        """
        n_positions = x.shape[1]
        return x + self.W_pos[None, :n_positions, :]

class Attention(nn.Module):
    """Bias-free multi-head self-attention with hook points on its internals.

    Attributes:
        d_model: Residual-stream width.
        num_heads: Number of attention heads.
        d_head: Width of each head.
        n_ctx: Maximum sequence length (size of the causal-mask buffer).
        attn_coeff: After the softmax, weights are mapped to
            ``attn * attn_coeff + (1 - attn_coeff)``; 1.0 leaves them unchanged.
        use_causal_mask: If True, query position q may only attend to key
            positions p <= q.  Defaults to False, which reproduces this module's
            previous behaviour (the mask was built but its addition to the
            scores was commented out).
    """
    d_model: int
    num_heads: int
    d_head: int
    n_ctx: int
    attn_coeff: float
    use_causal_mask: bool = False

    def setup(self):
        init = nn.initializers.normal(stddev=1.0 / np.sqrt(self.d_model))
        # Per-head projections, shape (num_heads, d_head, d_model).  The
        # creation order K, Q, V, O is unchanged.
        self.W_K = self.param("W_K", init, (self.num_heads, self.d_head, self.d_model))
        self.W_Q = self.param("W_Q", init, (self.num_heads, self.d_head, self.d_model))
        self.W_V = self.param("W_V", init, (self.num_heads, self.d_head, self.d_model))
        # Output projection applied to the concatenated heads.
        self.W_O = self.param("W_O", init, (self.d_model, self.num_heads * self.d_head))
        # Lower-triangular (n_ctx, n_ctx) mask: 1 where attention is allowed.
        causal_mask = np.tril(np.ones((self.n_ctx, self.n_ctx), dtype=np.float32))
        self.causal_mask = jnp.array(causal_mask)

        self.hook_k = HookPoint()
        self.hook_q = HookPoint()
        self.hook_v = HookPoint()
        self.hook_z = HookPoint()
        self.hook_attn = HookPoint()
        self.hook_attn_pre = HookPoint()

    def __call__(self, x):
        """Attend over ``x`` ([batch, seq_len, d_model]) -> [batch, seq_len, d_model]."""
        batch_size, seq_len, _ = x.shape

        def project(W, x_):
            # (heads, d_head, d_model) x (batch, pos, d_model) -> (batch, heads, pos, d_head)
            return jnp.einsum("ihd,bpd->biph", W, x_)

        k = self.hook_k(project(self.W_K, x))
        q = self.hook_q(project(self.W_Q, x))
        v = self.hook_v(project(self.W_V, x))

        # scores[b, i, q, p]: query position q attending to key position p.
        attn_scores_pre = jnp.einsum("biph,biqh->biqp", k, q)
        attn_scores_pre = self.hook_attn_pre(attn_scores_pre / np.sqrt(self.d_head))

        if self.use_causal_mask:
            full_mask = self.causal_mask[:seq_len, :seq_len]
            # Large negative additive bias on disallowed (future) positions.
            attn_scores_masked = attn_scores_pre + (1.0 - full_mask) * -1e10
        else:
            attn_scores_masked = attn_scores_pre

        attn_matrix = nn.softmax(attn_scores_masked, axis=-1)
        attn_matrix = attn_matrix * self.attn_coeff + (1.0 - self.attn_coeff)
        attn_matrix = self.hook_attn(attn_matrix)

        z = jnp.einsum("biph,biqp->biqh", v, attn_matrix)
        z = self.hook_z(z)

        z_trans = jnp.transpose(z, (0, 2, 1, 3))  # (b, seq_len, heads, d_head)
        z_flat = jnp.reshape(z_trans, (batch_size, seq_len, self.num_heads * self.d_head))

        return jnp.einsum("df,bpf->bpd", self.W_O, z_flat)

class MLP(nn.Module):
    """``num_layers`` dense + activation layers, then a projection back to ``d_model``.

    Per the einsum strings below, input and output are (batch, pos, d_model)
    and hidden activations are (batch, pos, d_mlp).  Each hidden layer's
    pre- and post-activation tensors pass through a ``HookPoint``.
    """
    d_model: int  # input/output feature width
    d_mlp: int  # hidden width of every layer
    num_layers: int  # number of hidden (dense + activation) layers
    act_type: str = "ReLU"  # only "ReLU" is accepted by _act

    # ---------- parameters & hooks ----------
    def setup(self):
        """Create W_i/b_i and hook_pre{i+1}/hook_post{i+1} per layer, then W_out/b_out.

        ``setattr`` is used because the attribute names are built dynamically.
        """
        # For each hidden layer i, stash its params and hooks under unique names.
        for i in range(self.num_layers):
            in_dim  = self.d_model if i == 0 else self.d_mlp
            out_dim = self.d_mlp

            # weight (out_dim, in_dim) + bias (out_dim,)
            # NOTE(review): the init std is 1/sqrt(out_dim) (fan-out), not
            # 1/sqrt(in_dim) — confirm this is intended.
            setattr(
                self, f"W_{i}",
                self.param(f"W_{i}",
                           nn.initializers.normal(stddev=1/np.sqrt(out_dim)),
                           (out_dim, in_dim))
            )
            setattr(
                self, f"b_{i}",
                self.param(f"b_{i}", nn.initializers.zeros, (out_dim,))
            )

            # hooks before & after activation
            # NOTE(review): the explicit sow key hard-codes "blocks_0" no matter
            # which block owns this MLP — confirm this is intended.
            setattr(self, f"hook_pre{i+1}",
                    HookPoint(key=f"blocks_0/mlp/hook_pre{i+1}"))
            setattr(self, f"hook_post{i+1}",
                    HookPoint(key=f"blocks_0/mlp/hook_post{i+1}"))

        # final projection back to d_model: W_out is (d_model, d_mlp)
        self.W_out = self.param(
            "W_out",
            nn.initializers.normal(stddev=1/np.sqrt(self.d_model)),
            (self.d_model, self.d_mlp)
        )
        self.b_out = self.param("b_out", nn.initializers.zeros, (self.d_model,))

    # ---------- forward ----------
    def _act(self, x):
        """Apply the configured activation; raises ValueError for anything but "ReLU"."""
        if self.act_type == "ReLU":
            return nn.relu(x)
        raise ValueError(f"Unsupported activation {self.act_type!r}")

    def __call__(self, x):
        """Run the hidden layers in order, then project back to d_model.

        NOTE(review): with num_layers == 0, ``h`` keeps width d_model, which only
        matches W_out when d_mlp == d_model.
        """
        h = x
        for i in range(self.num_layers):
            W     = getattr(self, f"W_{i}")
            b     = getattr(self, f"b_{i}")
            pre_h = getattr(self, f"hook_pre{i+1}")
            post_h= getattr(self, f"hook_post{i+1}")

            pre = pre_h(jnp.einsum("md,bpd->bpm", W, h) + b)
            h   = post_h(self._act(pre))

        # final output projection
        return jnp.einsum("dm,bpm->bpd", self.W_out, h) + self.b_out



class TransformerBlock(nn.Module):
    """Residual block without layer norm: ``h = x + attn(x)``, then ``h + mlp(h)``.

    Attributes:
        d_model, d_head, num_heads, n_ctx, attn_coeff: Forwarded to ``Attention``.
        act_type, num_mlp_layers: Forwarded to ``MLP``.
        mlp_ratio: The MLP hidden width is ``d_model * mlp_ratio``.  When None
            (the default) the script-level global ``num_neurons`` is used, which
            is the previous behaviour.  Passing it explicitly makes the block
            usable without that global being defined.
    """
    d_model: int
    d_head: int
    num_heads: int
    n_ctx: int
    act_type: str
    attn_coeff: float
    num_mlp_layers: int
    mlp_ratio: int | None = None

    def setup(self):
        self.attn = Attention(
            d_model=self.d_model,
            num_heads=self.num_heads,
            d_head=self.d_head,
            n_ctx=self.n_ctx,
            attn_coeff=self.attn_coeff
        )
        # Fall back to the CLI-level global for backward compatibility.
        ratio = num_neurons if self.mlp_ratio is None else self.mlp_ratio
        self.mlp = MLP(
            d_model=self.d_model,
            d_mlp=self.d_model * ratio,
            num_layers=self.num_mlp_layers,
            act_type=self.act_type,
        )
        self.hook_attn_out = HookPoint()
        self.hook_mlp_out = HookPoint()
        self.hook_resid_pre = HookPoint()
        self.hook_resid_mid = HookPoint()
        self.hook_resid_post = HookPoint()

    def __call__(self, x):
        """Apply attention then MLP, each added back to the residual stream."""
        resid_pre = self.hook_resid_pre(x)
        attn_out = self.hook_attn_out(self.attn(resid_pre))
        x_mid = self.hook_resid_mid(resid_pre + attn_out)

        mlp_out = self.hook_mlp_out(self.mlp(x_mid))
        return self.hook_resid_post(x_mid + mlp_out)

class Transformer(nn.Module):
    """
    A simple Transformer with:
    - Embedding
    - PosEmbed
    - N Transformer blocks
    - A final linear "unembed" (W_U) -> [batch, seq_len, d_vocab]

    There is no layer norm anywhere, and the unembedding has no bias.
    """
    num_layers: int  # number of TransformerBlocks
    num_mlp_layers: int  # hidden layers in each block's MLP
    d_vocab: int  # vocabulary size (= number of output logits)
    d_model: int  # residual-stream width
    d_head: int  # per-head width
    num_heads: int  # attention heads per block
    n_ctx: int  # maximum sequence length (rows of the position table)
    act_type: str  # MLP activation name
    attn_coeff: float  # attention mixing coefficient forwarded to Attention

    def setup(self):
        self.embed = Embed(self.d_vocab, self.d_model)
        self.pos_embed = PosEmbed(self.n_ctx, self.d_model)
        # The blocks are assigned as a list first, then wrapped in
        # nn.Sequential so __call__ can run them as a single module.
        # NOTE(review): the embedding-level helpers in this file read the first
        # block's params as params["blocks_0"]; re-check that naming if this
        # wrapping is ever changed.
        self.blocks = [TransformerBlock(
            d_model=self.d_model,
            d_head=self.d_head,
            num_heads=self.num_heads,
            n_ctx=self.n_ctx,
            act_type=self.act_type,
            attn_coeff=self.attn_coeff,
            num_mlp_layers=self.num_mlp_layers
        ) for _ in range(self.num_layers)]
        self.blocks = nn.Sequential(self.blocks)
        # Unembedding matrix, shape (d_model, d_vocab).
        self.W_U = self.param(
            "W_U",
            nn.initializers.normal(stddev=1.0 / np.sqrt(self.d_vocab)),
            (self.d_model, self.d_vocab),
        )

    def __call__(self, x, training=False):
        """Map token ids ``x`` ([batch, seq_len]) to logits [batch, seq_len, d_vocab].

        ``training`` is accepted for interface compatibility but not used here.
        """
        x_emb = self.embed(x)           # [batch, seq_len, d_model]
        x_emb = self.pos_embed(x_emb)   # [batch, seq_len, d_model]
        x_out = self.blocks(x_emb)      # [batch, seq_len, d_model]
        logits = jnp.einsum("dm,bpd->bpm", self.W_U, x_out)
        return logits

def _extract_embeddings_ab_transformer(self, params):
    """
    Returns W_E twice (transformer shares one table for 'a' and 'b').

    Works whether the weight lives at
        params["embed"]["embedding"]          (older Flax)
    or at
        params["embed"]["Embed_0"]["embedding"] (current Flax).
    """
    embed_sub = params["embed"]

    # 1️⃣  most common case
    if "embedding" in embed_sub:
        W_E = embed_sub["embedding"]
    else:
        # 2️⃣  look one level deeper for the first array leaf
        #     (this covers the 'Embed_0' scope)
        #     We purposely stop at the first ndarray we find.
        for v in embed_sub.values():
            leaf = jax.tree_util.tree_leaves(v)
            if leaf:                       # non-empty list → we found the weight
                W_E = leaf[0]
                break
        else:
            raise KeyError(
                "Could not locate the token-embedding weight inside params['embed']"
            )

    return W_E, W_E                            # share for a and b

def _call_from_embedding_transformer(self, x_emb, params):
    """Run the forward pass starting from an embedding tensor instead of token ids.

    Args:
        self: The ``Transformer`` instance (supplies the hyper-parameters).
        x_emb: Embeddings of a single, unbatched sequence, shape (seq_len, d_model).
        params: The Transformer's ``params`` collection (``pos_embed``,
            ``blocks_{i}``, ``W_U``).

    Returns:
        The logits for the final token, shape (d_vocab,).
    """
    # 1) add the learned position embeddings and add a batch axis of size 1
    seq_len = x_emb.shape[0]
    x = (x_emb + params["pos_embed"]["W_pos"][:seq_len])[None, ...]  # (1, seq_len, d_model)

    # 2) apply blocks_0 ... blocks_{num_layers-1} in order, matching the
    #    Transformer's own forward pass (not just the first block)
    block = TransformerBlock(
        d_model=self.d_model,
        d_head=self.d_head,
        num_heads=self.num_heads,
        n_ctx=self.n_ctx,
        act_type=self.act_type,
        attn_coeff=self.attn_coeff,
        num_mlp_layers=self.num_mlp_layers,
    )
    for layer in range(self.num_layers):
        x = block.apply({"params": params[f"blocks_{layer}"]}, x)

    # 3) unembed and return the last-token logits
    logits = jnp.einsum("dm,bpd->bpm", params["W_U"], x)   # (1, seq_len, d_vocab)
    return logits[0, -1]

# Expose the embedding-level helpers as methods of Transformer.
setattr(Transformer, "extract_embeddings_ab", _extract_embeddings_ab_transformer)
setattr(Transformer, "call_from_embedding", _call_from_embedding_transformer)

def call_from_embedding_sequence(self, seq_emb, params):
    """Run the forward pass from a batch of embedded sequences.

    Args:
        self: The ``Transformer`` instance (supplies the hyper-parameters).
        seq_emb: Embeddings of shape (batch, seq_len, d_model), e.g. (1, 2, d_model).
        params: The Transformer's ``params`` collection.

    Returns:
        Logits for every position, shape (batch, seq_len, d_vocab).
    """
    # 1) add position embeddings (broadcast over the batch axis)
    seq_len = seq_emb.shape[1]
    x = seq_emb + params["pos_embed"]["W_pos"][:seq_len]

    # 2) apply blocks_0 ... blocks_{num_layers-1} in order, matching the
    #    Transformer's own forward pass (not just the first block)
    block = TransformerBlock(
        d_model=self.d_model,
        d_head=self.d_head,
        num_heads=self.num_heads,
        n_ctx=self.n_ctx,
        act_type=self.act_type,
        attn_coeff=self.attn_coeff,
        num_mlp_layers=self.num_mlp_layers
    )
    for layer in range(self.num_layers):
        x = block.apply({"params": params[f"blocks_{layer}"]}, x)

    # 3) unembed
    return jnp.einsum("dm,bpd->bpm", params["W_U"], x)


Transformer.call_from_embedding_sequence = call_from_embedding_sequence

# argv = script name + 14 positional hyper-parameters + at least one seed, so at
# least 16 entries are required.  (A `< 15` check let a call with no seeds
# through, leaving num_models == 0 and failing much later.)
if len(sys.argv) < 16:
    print("Usage: script.py <learning_rate> <weight_decay> <p> <batch_size> <optimizer> <epochs> <k> <batch_experiment> <num_neurons> <zeta> <training_set_size> <momentum> <injected_noise> <num_mlp_layers> <random_seed_int_1> [<random_seed_int_2> ...]")
    sys.exit(1)

print("start args parsing")
# Parse command-line arguments
learning_rate = float(sys.argv[1])  # stepsize_
weight_decay = float(sys.argv[2])  # L2 norm
p = int(sys.argv[3])
batch_size = int(sys.argv[4])
optimizer = sys.argv[5]
epochs = int(sys.argv[6])
k = int(sys.argv[7])
batch_experiment = sys.argv[8]
num_neurons = int(sys.argv[9])  # MLP width multiplier: TransformerBlock uses d_mlp = d_model * num_neurons
zeta = int(sys.argv[10])  # not used in the visible part of this script
# argv[11] (<training_set_size>) is accepted for CLI compatibility but ignored:
# the training set is always k batches of batch_size samples.
training_set_size = k * batch_size
momentum = float(sys.argv[12])
injected_noise = float(sys.argv[13]) / float(k)  # CLI value is scaled by 1/k
num_mlp_layers = int(sys.argv[14])

# Accept multiple random seeds (one model is trained per seed)
random_seed_ints = [int(seed) for seed in sys.argv[15:]]
num_models = len(random_seed_ints)

def compute_useless_metrics(
    *,
    model,
    params: dict,
    p: int,
    rng_seed: int = 42,
    max_samples: int = 3481,
) -> Tuple[Dict[str, float], Dict[str, float]]:
    """Compute gradient-symmetricity and distance-irrelevance statistics.

    Args:
        model: Flax model whose ``apply`` maps (N, 2) int32 token pairs to
            (N, seq_len, p) logits (only the last position is used).
        params: Parameters for ``model``.
        p: Modulus; inputs range over all (a, b) in [0, p)^2.
        rng_seed: Currently unused; kept for interface compatibility.
        max_samples: Currently unused; every triple / pair is evaluated.

    Returns:
        grad_stats: ``{"average_gradient_symmetricity", "std_dev_gradient_symmetricity"}``,
            mean / std of ``batched_gradient_similarity`` over all p**3
            (a, b, c) triples.
        dist_stats: ``{"average_distance_irrelevance", "std_dev_distance_irrelevance"}``,
            mean / std over columns j of std_i(L[i, j]) / std(L), where L[i, j]
            is the correct-class logit for the pair with a + b ≡ i and
            a - b ≡ j (mod p).
    """
    # ---------- (A)  gradient-symmetricity ---------------------------
    axis = jnp.arange(p, dtype=jnp.int32)
    a_arr, b_arr, c_arr = jnp.meshgrid(axis, axis, axis, indexing="ij")
    cos_sims = batched_gradient_similarity(
        model=model,
        params=params,
        a_batch=a_arr.ravel(),
        b_batch=b_arr.ravel(),
        c_batch=c_arr.ravel(),
    )
    cos_np = np.asarray(cos_sims)

    grad_stats = {
        "average_gradient_symmetricity": float(cos_np.mean()),
        "std_dev_gradient_symmetricity": float(cos_np.std()),
    }

    # ---------- (B)  distance-irrelevance ----------------------------
    a_grid, b_grid = np.meshgrid(np.arange(p), np.arange(p), indexing="ij")
    x_full = jnp.stack([a_grid.ravel(), b_grid.ravel()], axis=-1).astype(jnp.int32)

    logits = model.apply({"params": params}, x_full, training=False)[:, -1, :]  # (p², p)
    logits_np = np.asarray(logits)
    correct_idx = ((a_grid + b_grid) % p).ravel()
    correct_logits = logits_np[np.arange(p * p), correct_idx]            # (p²,)

    # Arrange into L[i, j] with i = a + b, j = a - b (both mod p).  That map is
    # a bijection only for odd p; for even p some cells are never written, so
    # start from NaN rather than uninitialised memory (np.empty), making a
    # degenerate result visible instead of silently reading garbage.
    L = np.full((p, p), np.nan)
    i_mat, j_mat = (a_grid + b_grid) % p, (a_grid - b_grid) % p
    L[i_mat, j_mat] = correct_logits.reshape(p, p)

    col_stds = L.std(axis=0)
    global_std = L.std() + 1e-12      # avoid /0 for degenerate nets

    q_vals = col_stds / global_std
    dist_stats = {
        "average_distance_irrelevance": float(q_vals.mean()),
        "std_dev_distance_irrelevance": float(q_vals.std()),
    }

    return grad_stats, dist_stats

# def lr_schedule_fn(step):
#     total_steps = epochs * k
#     warmup_steps = total_steps // 4  # warmup over first 25% of training

#     def warmup_fn(step_):
#         return learning_rate * (step_ / warmup_steps)

#     def constant_fn(step_):
#         return learning_rate

#     return jax.lax.cond(
#         step < warmup_steps,
#         warmup_fn,
#         constant_fn,
#         operand=step
#     )

def lr_schedule_fn(step, *, peak_lr=None, total_steps=None):
    """Triangular learning-rate schedule: linear warm-up, then linear decay.

    The rate rises from 0 to ``peak_lr`` over the first half of training and
    falls back linearly over the second half.

    Args:
        step: Current optimizer step (as passed by optax).
        peak_lr: Maximum learning rate; defaults to the CLI ``learning_rate``.
        total_steps: Length of the schedule; defaults to ``epochs * k``.

    Returns:
        The learning rate as a JAX scalar, clamped to be >= 0 so calls past
        ``total_steps`` never yield a negative rate.
    """
    if peak_lr is None:
        peak_lr = learning_rate
    if total_steps is None:
        total_steps = epochs * k
    warmup_steps = total_steps // 2
    cooldown_steps = total_steps - warmup_steps
    # lax.cond traces both branches, so clamp the denominators to avoid a
    # division by zero for degenerate (0- or 1-step) schedules.
    warmup_den = max(warmup_steps, 1)
    cooldown_den = max(cooldown_steps, 1)

    def warmup_fn(step_):
        return peak_lr * (step_ / warmup_den)

    def cooldown_fn(step_):
        return peak_lr * (1 - (step_ - warmup_steps) / cooldown_den)

    lr = jax.lax.cond(step < warmup_steps, warmup_fn, cooldown_fn, step)
    return jnp.maximum(lr, 0.0)

print("making dataset")


def generate_polynomial_dataset(p, f, num_train_batches, batch_size, rng: jax.random.PRNGKey):
    """Sample (a, b, f(a, b)) triples without replacement from the p x p grid.

    Args:
        p: Grid size; a and b range over [0, p).
        f: Vectorised label function applied to the full (p, p) grids.
        num_train_batches: Number of batches to return.
        batch_size: Samples per batch.
        rng: PRNG key; it is split once and the second half drives the sampling.

    Returns:
        Array of shape (num_train_batches, batch_size, 3) with rows (a, b, y).

    Raises:
        ValueError: If more samples are requested than the p*p grid contains.
    """
    n_available = p * p
    n_needed = num_train_batches * batch_size
    if n_needed > n_available:
        raise ValueError("Not enough data samples for the requested number of batches.")

    grid_a, grid_b = jnp.mgrid[0:p, 0:p]
    labels = f(grid_a, grid_b)
    table = jnp.stack([grid_a.ravel(), grid_b.ravel(), labels.ravel()], axis=1)

    _, pick_key = jax.random.split(rng)
    picked = jax.random.choice(pick_key, n_available, (n_needed,), replace=False)
    return table[picked].reshape(num_train_batches, batch_size, 3)

def generate_polynomial_dataset_for_seed(seed):
    """Build the (a + b) mod p training set for one seed: shape [k, batch_size, 3]."""
    return generate_polynomial_dataset(
        p,
        lambda a, b: jnp.mod(a + b, p),
        k,
        batch_size,
        jax.random.PRNGKey(seed),
    )

if batch_experiment == "random_random":
    # One independently sampled training set per seed.
    train_ds_list = []
    for seed in random_seed_ints:
        train_data = generate_polynomial_dataset_for_seed(seed)
        train_ds_list.append(train_data)
    train_ds = jnp.stack(train_ds_list)  # [num_models, k, batch_size, 3]
    print(f"Number of training batches: {train_ds.shape[1]}")
else:
    # Without this, train_ds would be undefined and the script would fail
    # later with a confusing NameError.
    raise ValueError(
        f"Unsupported batch_experiment {batch_experiment!r}; "
        "only 'random_random' is implemented."
    )

print("made dataset")

def compute_pytree_size(pytree):
    """Return the total number of bytes held by the array leaves of ``pytree``."""
    return sum(
        leaf.size * leaf.dtype.itemsize
        for leaf in jax.tree_util.tree_leaves(pytree)
    )

# Bytes in one model's slice of the data (train_ds is [num_models, k, batch_size, 3]).
dataset_size_bytes = train_ds.dtype.itemsize * (
    train_ds.shape[1] * train_ds.shape[2] * train_ds.shape[3]
)
dataset_size_mb = dataset_size_bytes / 2 ** 20
print(f"Dataset size per model: {dataset_size_mb:.2f} MB")

@struct.dataclass
class Metrics(metrics.Collection):
    """clu metric collection accumulated over training and evaluation batches.

    ``compute_metrics`` feeds it through ``single_from_model_output`` with
    ``logits``, ``labels``, ``loss`` and ``l2_loss`` keyword arguments.
    """
    accuracy: metrics.Accuracy  # computed from the logits/labels outputs
    loss: metrics.Average.from_output('loss')  # running mean of the 'loss' output
    l2_loss: metrics.Average.from_output('l2_loss')  # running mean of the 'l2_loss' output

# Transformer hyper-parameters; num_mlp_layers comes from the command line.
transformer_config = {
    "num_layers": 1,
    "num_mlp_layers": num_mlp_layers,
    "d_vocab": p,
    "d_model": 128,
    "d_head": 32,
    "num_heads": 4,
    "n_ctx": 2,
    "act_type": "ReLU",
    "attn_coeff": 1.0
}
NUM_MLP_LAYERS = transformer_config["num_mlp_layers"]

# The heads must tile the residual stream exactly.  Raise instead of `assert`
# so the check is not stripped under `python -O`.
if transformer_config["d_head"] * transformer_config["num_heads"] != transformer_config["d_model"]:
    raise ValueError(
        "d_head * num_heads must equal d_model, got "
        f"{transformer_config['d_head']} * {transformer_config['num_heads']} "
        f"!= {transformer_config['d_model']}"
    )
BASE_DIR = f"scratch/ICLR-freqs-run-1/quantitative_metrics_transformer_{num_mlp_layers}_heatmaps_log{p}_{transformer_config['attn_coeff']}_k_{k}"
os.makedirs(BASE_DIR, exist_ok=True)

model = Transformer(**transformer_config)
# Token pairs (a, b): sequence length 2.
dummy_x = jnp.zeros((batch_size, 2), dtype=jnp.int32)


def cross_entropy_loss(y_pred, y):
    """Mean negative log-likelihood of labels ``y`` under the last-position logits.

    Args:
        y_pred: Logits of shape (batch, seq_len, vocab); only position -1 is scored.
        y: Integer class labels of shape (batch,).
    """
    last_logits = y_pred[:, -1, :]
    rows = jnp.arange(y.shape[0])
    nll = -jax.nn.log_softmax(last_logits, axis=-1)[rows, y]
    return jnp.mean(nll)

def total_loss(y_pred_and_l2, y):
    """Last-position cross-entropy plus ``weight_decay`` times the L2 penalty.

    ``y_pred_and_l2`` is the ``(logits, aux, l2_loss)`` triple produced by ``apply``.
    """
    logits, _aux, l2_penalty = y_pred_and_l2
    return cross_entropy_loss(logits, y) + weight_decay * l2_penalty

def apply(variables, x, training=False):
    """Forward pass plus L2 penalty for one model.

    Returns ``(logits, {}, l2)``, where ``l2`` is the sum of squares over every
    parameter leaf.
    """
    params = variables["params"]
    logits = model.apply({"params": params}, x, training=training)
    squared_norms = [jnp.sum(jnp.square(leaf)) for leaf in jax.tree_util.tree_leaves(params)]
    return logits, {}, sum(squared_norms)

def batched_apply(variables_batch, x_batch, training=False):
    """vmap ``apply`` over a leading model axis of both the variables and the inputs."""
    def _single(vars_, xx):
        return apply(vars_, xx, training)

    return jax.vmap(_single, in_axes=(0, 0))(variables_batch, x_batch)

def sample_hessian(prediction, sample):
    """Cross-entropy Hessian of the last-position logits, via ``optimizers``.

    ``prediction`` is the ``apply`` output triple.  Labels are read from
    column 2 of ``sample[0]``.  NOTE(review): if ``sample`` is the ``(x, y)``
    pair passed to ``train_step``, ``x`` has only 2 columns and this index
    is out of range — confirm how the training module calls this.
    """
    logits, aux = prediction[0], prediction[1]
    labels = sample[0][:, 2]
    hessian = optimizers.sample_crossentropy_hessian(logits[:, -1, :], labels)
    return hessian, aux, 0.0

def compute_metrics(metrics_, *, loss, l2_loss, outputs, labels):
    """Merge one batch's statistics (scored at the last position) into ``metrics_``."""
    batch_update = metrics_.single_from_model_output(
        logits=outputs[0][:, -1, :],
        labels=labels,
        loss=loss,
        l2_loss=l2_loss,
    )
    return metrics_.merge(batch_update)

def prepare_batches(batches_array):
    """Split a 4-D [num_models, k, batch_size, 3] array into int32 inputs and labels.

    Returns ``(x, y)``: x holds columns 0-1 (shape [..., 2]) and y column 2.
    """
    inputs = batches_array[:, :, :, 0:2]
    labels = batches_array[:, :, :, 2]
    return inputs.astype(jnp.int32), labels.astype(jnp.int32)

print("Transformer model created.")


def print_param_tree(tree, prefix=""):
    """Recursively print every parameter's "/"-joined path and its shape.

    Nested mappings (plain ``dict`` or Flax ``FrozenDict``, which is not a
    ``dict`` subclass) are recursed into; other values are treated as arrays.
    """
    for name, subtree in tree.items():
        if hasattr(subtree, "items"):
            print_param_tree(subtree, prefix + name + "/")
        else:
            print(f"{prefix + name} — shape {tuple(subtree.shape)}")


# Initialise one parameter set per seed and list its parameter names.
variables_list = []
for seed in random_seed_ints:
    rng_key = jax.random.PRNGKey(seed)
    variables = model.init(rng_key, dummy_x, training=False)

    print("=== Transformer parameter names ===")
    print_param_tree(variables["params"])

    variables_list.append(variables)

# Size of a single model's parameters.
model_size_bytes = compute_pytree_size(variables_list[0]["params"])
model_size_mb = model_size_bytes / (1024 ** 2)
print(f"Single model size: {model_size_mb:.2f} MB")

# Stack the per-seed parameter trees along a new leading "model" axis.
variables_batch = dict(
    params=jax.tree_util.tree_map(
        lambda *leaves: jnp.stack(leaves),
        *[v["params"] for v in variables_list]
    ),
    batch_stats=None,
)

# Adam follows lr_schedule_fn; any "SGD..." optimizer uses the constant
# CLI learning_rate with the CLI momentum.
if optimizer == "adam":
    tx = optax.adam(lr_schedule_fn)
elif optimizer.startswith("SGD"):
    tx = optax.sgd(learning_rate=learning_rate, momentum=momentum)
else:
    raise ValueError("Unsupported optimizer type")

def init_opt_state(params):
    """Initialise ``tx``'s optimizer state for one model's parameters."""
    return tx.init(params)


# Per-model optimizer states, stacked along a leading model axis.
# (``jax.tree_map`` is deprecated and removed in newer JAX; use tree_util.)
opt_state_list = []
for i in range(num_models):
    params_i = jax.tree_util.tree_map(lambda x: x[i].copy(), variables_batch["params"])
    opt_state_i = init_opt_state(params_i)
    opt_state_list.append(opt_state_i)

opt_state_batch = jax.tree_util.tree_map(lambda *arrs: jnp.stack(arrs), *opt_state_list)

def create_train_state(params, opt_state, rng_key):
    """Bundle one model's parameters, optimizer state and hooks into a TrainState."""
    shared_settings = dict(
        apply_fn=apply,
        tx=tx,
        loss_fn=total_loss,
        loss_hessian_fn=sample_hessian,
        compute_metrics_fn=compute_metrics,
        initial_metrics=Metrics,
        batch_stats=None,
        injected_noise=injected_noise,
    )
    return training.TrainState(
        params=params,
        opt_state=opt_state,
        rng_key=rng_key,
        **shared_settings,
    )

# One TrainState per seed, then stacked along a leading model axis.
# (``jax.tree_map`` is deprecated and removed in newer JAX; use tree_util.)
states_list = []
for i, seed in enumerate(random_seed_ints):
    rng_key = jax.random.PRNGKey(seed)
    params_i = jax.tree_util.tree_map(lambda x: x[i].copy(), variables_batch["params"])
    opt_state_i = jax.tree_util.tree_map(lambda x: x[i].copy(), opt_state_batch)
    st = create_train_state(params_i, opt_state_i, rng_key)
    states_list.append(st)

states = jax.tree_util.tree_map(lambda *arrs: jnp.stack(arrs), *states_list)

# Inputs [num_models, k, batch_size, 2] and labels [num_models, k, batch_size] on device.
train_x, train_y = prepare_batches(train_ds)
train_x, train_y = jax.device_put(train_x), jax.device_put(train_y)

# Empty metric collections, stacked along the model axis.
initial_metrics_list = [state.initial_metrics.empty() for state in states_list]
initial_metrics = jax.tree_util.tree_map(
    lambda *leaves: jnp.stack(leaves), *initial_metrics_list
)

# EVAL data: x_eval [num_models, N, 2], y_eval [num_models, N].
if training_set_size != p * p:
    # Full p x p grid of (a, b) pairs, shared by every model.
    a_eval, b_eval = jnp.mgrid[0:p, 0:p]
    a_eval, b_eval = a_eval.ravel(), b_eval.ravel()
    x_eval = jax.device_put(jnp.stack([a_eval, b_eval], axis=-1).astype(jnp.int32))
    y_eval = jax.device_put(jnp.mod(a_eval + b_eval, p).astype(jnp.int32))
    # Add a leading model axis and replicate.
    x_eval = jnp.tile(x_eval[None, :, :], (num_models, 1, 1))
    y_eval = jnp.tile(y_eval[None, :], (num_models, 1))
else:
    print("Train set is the entire dataset. Using the training set for evaluation.")
    # train_ds is [num_models, k, batch_size, 3]: flatten the batch axes.
    x_eval = train_ds[:, :, :, :2].reshape(num_models, -1, 2)
    y_eval = train_ds[:, :, :, 2].reshape(num_models, -1)

# Make sure both arrays are on device regardless of the branch taken.
x_eval, y_eval = jax.device_put(x_eval), jax.device_put(y_eval)

# Split the eval set into equally sized batches WITHOUT padding.  eval_model
# passes every row to eval_step with no mask, so padding rows ((0, 0) -> 0)
# used to be counted in the reported test loss/accuracy.  Instead use the
# fewest batches whose size divides the eval set exactly and is <= 1024.
max_eval_batch_size = 1024
total_eval_samples = x_eval.shape[1]
if total_eval_samples == 0:
    raise ValueError("Evaluation set is empty.")
num_eval_batches = -(-total_eval_samples // max_eval_batch_size)  # ceil division
while total_eval_samples % num_eval_batches:
    num_eval_batches += 1  # terminates at the latest when batches have size 1
eval_batch_size = total_eval_samples // num_eval_batches

# Kept so that any later code reading these names still works.
num_full_batches = num_eval_batches
remaining_samples = 0
x_eval_padded = x_eval
y_eval_padded = y_eval

x_eval_batches = x_eval_padded.reshape(num_models, num_eval_batches, eval_batch_size, -1)
y_eval_batches = y_eval_padded.reshape(num_models, num_eval_batches, eval_batch_size)


@jax.jit
def train_epoch(states_, x_batches, y_batches, init_metrics):
    """Run one epoch for all models in lock-step.

    x_batches: [num_models, k2, batch_size2, 2]; y_batches: [num_models, k2, batch_size2].
    The batch axis (1) is moved to the front and scanned over; at each scan
    step every model takes one ``train_step`` on its own batch via vmap.

    Returns:
        (final_states, final_metrics) after all k2 steps.
    """
    step_all_models = jax.vmap(lambda s, m, xx, yy: s.train_step(m, (xx, yy)))

    def body(carry, batch):
        st, mets = carry
        xb, yb = batch
        new_st, new_mets = step_all_models(st, mets, xb, yb)
        return (new_st, new_mets), None

    scanned = (jnp.swapaxes(x_batches, 0, 1), jnp.swapaxes(y_batches, 0, 1))
    (final_states, final_metrics), _ = jax.lax.scan(body, (states_, init_metrics), scanned)
    return final_states, final_metrics

@jax.jit
def eval_model(states_, x_batches, y_batches, init_metrics):
    """Evaluate all models on pre-batched data and return the accumulated metrics.

    x_batches: [num_models, n_batches, batch, 2]; y_batches: [num_models, n_batches, batch].
    The batch axis is scanned over; the models are vmapped at each step.
    """
    eval_all_models = jax.vmap(lambda s, mm, xx, yy: s.eval_step(mm, (xx, yy)))

    def body(mets, batch):
        xb, yb = batch
        return eval_all_models(states_, mets, xb, yb), None

    scanned = (jnp.swapaxes(x_batches, 0, 1), jnp.swapaxes(y_batches, 0, 1))
    final_metrics, _ = jax.lax.scan(body, init_metrics, scanned)
    return final_metrics


# Per-model bookkeeping for the first epoch at which test accuracy reaches ~100%.
first_100_epoch = [None] * num_models
first_100_test_loss = [None] * num_models
first_100_cross_entropy_loss = [None] * num_models
first_100_summary = [None] * num_models
cross_entropies = {}  # seed -> {"cross_entropy": ..., "avg_margin": ...}

######################################
# 1) Original Training Loop
######################################
for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs} (original batch_size={batch_size})")
    states, train_metrics = train_epoch(states, train_x, train_y, initial_metrics)

    # -- per-model training summary --
    for i in range(num_models):
        train_metric_i = jax.tree_util.tree_map(lambda leaf: leaf[i], train_metrics).compute()
        print(
            f"Model {i + 1}/{num_models}: Train Loss: {train_metric_i['loss']:.6f}, "
            f"Train Accuracy: {train_metric_i['accuracy']:.2%}"
        )

    # -- evaluation on the eval set --
    print(f"\n--- Test Evaluation at Epoch {epoch + 1} ---")
    test_metrics = eval_model(states, x_eval_batches, y_eval_batches, initial_metrics)
    for i in range(num_models):
        test_metric_i = jax.tree_util.tree_map(lambda leaf: leaf[i], test_metrics).compute()
        test_loss, test_accuracy = test_metric_i['loss'], test_metric_i['accuracy']
        print(f" Model {i + 1}/{num_models} -> Test Loss: {test_loss:.6f},  Test Acc: {test_accuracy:.2%}")
        # Test loss with the weight-decay term subtracted.
        cross_entropy_val = test_loss - weight_decay * test_metric_i['l2_loss']

        # Record only the first epoch at which this model reaches ~100%.
        if first_100_test_loss[i] is not None or not test_accuracy > 0.99999:
            continue
        first_100_epoch[i] = epoch + 1  # 1-based
        first_100_test_loss[i] = float(test_loss)
        first_100_cross_entropy_loss[i] = float(cross_entropy_val)
        # Built once so it can be dumped verbatim later.
        first_100_summary[i] = dict(
            epoch=first_100_epoch[i],
            loss=first_100_test_loss[i],
            ce_loss=first_100_cross_entropy_loss[i],
        )

    # Sanity check on model 0's unembedding after every epoch.
    check_unembed_rank(jax.tree_util.tree_map(lambda x: x[0], states.params), label=f"after epoch {epoch+1}")
    print("--- End of Test Evaluation ---\n")

# Final evaluation of every model after training.
final_test_accuracies = list()
test_metrics = eval_model(states, x_eval_batches, y_eval_batches, initial_metrics)

def average_margin(logits_2d: jnp.ndarray, labels_1d: jnp.ndarray) -> float:
    """
    Mean classification margin over a batch.

    The margin of one sample is ``logit[correct] - max(logit[wrong])``. A
    positive mean means the correct class usually wins by that much.

    Parameters
    ----------
    logits_2d : (N, p) logits for the last token. Any array-like is accepted
                (NumPy arrays are converted with ``jnp.asarray``). Non-floating
                dtypes are promoted to float32 so that the ``-inf`` mask used
                below is representable.
    labels_1d : (N,) integer indices of the correct classes.

    Returns
    -------
    float : mean margin over the N samples. With a single class (p == 1)
            there is no wrong logit, so the margin is ``+inf``.
    """
    logits = jnp.asarray(logits_2d)
    if not jnp.issubdtype(logits.dtype, jnp.floating):
        # `-inf` cannot be stored in an integer array.
        logits = logits.astype(jnp.float32)
    labels = jnp.asarray(labels_1d)
    rows = jnp.arange(logits.shape[0])

    # logit of the correct class
    corr = logits[rows, labels]

    # best wrong logit: mask the correct one with -inf, then take the row max
    wrong = logits.at[rows, labels].set(-jnp.inf)
    best_wrong = wrong.max(axis=1)

    return float(jnp.mean(corr - best_wrong))

# Per-model post-training pass: final accuracy, cross-entropy / margin
# bookkeeping, saving of (near-)perfect models, and loss_log.txt append.
for i in range(num_models):
    test_metric = jax.tree_util.tree_map(lambda x: x[i], test_metrics)
    test_metric = test_metric.compute()
    test_accuracy = test_metric["accuracy"]
    final_test_accuracies.append(test_accuracy)
    print(f"Model {i + 1} final test accuracy: {test_accuracy:.2%}")
    # Test loss with the weight-decay term removed (presumably
    # loss = CE + weight_decay * l2_loss — confirm in the metrics code).
    cross_entropy = test_metric['loss'] - weight_decay * test_metric['l2_loss']
    seed = random_seed_ints[i]

    # ---------- single-model views ----------
    x_eval_i = x_eval[i]          # (p² , 2)
    y_eval_i = y_eval[i]          # (p² ,)
    # Owned copy of model i's parameters; reused for the margin and for saving.
    params_i = jax.tree_util.tree_map(lambda x: x[i].copy(), states.params)

    # ---------- logits & margin ----------
    logits_full = model.apply({'params': params_i},
                              x_eval_i,
                              training=False)
    logits_last = logits_full[:, -1, :]               # (p² , p)
    avg_margin  = average_margin(logits_last, y_eval_i)

    cross_entropies[seed] = {
        "cross_entropy": float(cross_entropy),
        "avg_margin":    avg_margin
    }

    # Width used in directory / file names (d_model * num_neurons).
    num_neurons_transformer = transformer_config["d_model"] * num_neurons

    # Only models at (essentially) 100% test accuracy are saved.
    if test_accuracy >= 0.99999:
        experiment_name = batch_experiment
        optimizer_name = optimizer + str(momentum)
        params_file_path = (
            f"{BASE_DIR}/params_transformer_r2_heatmap_k={k}_{epochs}/{p}_{k}_nn_{num_neurons}_fits_attn-co={transformer_config['attn_coeff']}_top-k_layers={num_mlp_layers}/"
            f"p={p}_bs={batch_size}_k={k}_nn={num_neurons_transformer}_lr={learning_rate}_wd={weight_decay}_epochs={epochs}_"
            f"training_set_size={training_set_size}/params_p_{p}_{batch_experiment}_"
            f"{optimizer_name}_ts_{training_set_size}_bs={batch_size}_nn={num_neurons}_"
            f"lr={learning_rate}_wd={weight_decay}_noise={injected_noise}_zeta={zeta}_k={k}_"
            f"rs_{random_seed_ints[i]}.params"
        )
        os.makedirs(os.path.dirname(params_file_path), exist_ok=True)
        with open(params_file_path, 'wb') as f:
            f.write(serialization.to_bytes(params_i))
        print(f"Model {i + 1}: Parameters saved to {params_file_path}")
    else:
        # The message states the real threshold used above (0.99999 = 99.999%).
        print(f"Model {i + 1}: Test accuracy did not reach 99.999%. Model parameters will not be saved.")
        # print(f"\n--- Misclassified Test Examples for Model {i + 1} ---")
        # single_params = jax.tree_map(lambda x: x[i], states.params)
        # logits = model.apply({'params': single_params}, x_eval, training=False)
        # predictions = jnp.argmax(logits[:, -1, :], axis=-1)
        # y_true = y_eval
        # incorrect_mask = predictions != y_true
        # incorrect_indices = jnp.where(incorrect_mask)[0]
        # if incorrect_indices.size > 0:
        #     print(f"  Total misclassifications: {len(incorrect_indices)}")
        #     for idx, (x_vals, true_label, pred_label) in enumerate(
        #         zip(x_eval[incorrect_indices],
        #             y_true[incorrect_indices],
        #             predictions[incorrect_indices]), 1
        #     ):
        #         a_val, b_val = x_vals
        #         print(f"    {idx}. a: {int(a_val)}, b: {int(b_val)}, True: {int(true_label)}, Predicted: {int(pred_label)}")
        # else:
        #     print("No misclassifications found. All predictions correct.")

    # Append "seed,cross_entropy" to loss_log.txt. The directory does not
    # depend on i, so every model appends to the same file.
    loss_log_dir = (
        f"{BASE_DIR}/transformer_r2_heatmap_k={k}_{epochs}/{p}_{k}_nn_{num_neurons}_fits_attn-co={transformer_config['attn_coeff']}_top-k_layers={num_mlp_layers}/"
        f"p={p}_bs={batch_size}_k={k}_nn={num_neurons_transformer}_lr={learning_rate}_wd={weight_decay}_epochs={epochs}_"
        f"training_set_size={training_set_size}/"
    )
    loss_log_path = os.path.join(loss_log_dir, "loss_log.txt")
    os.makedirs(loss_log_dir, exist_ok=True)
    with open(loss_log_path, "a") as log_file:
        log_file.write(f"{random_seed_ints[i]},{cross_entropy}\n")

# Finally, write the first_100% records
# One CSV line per model that reached >99.999% test accuracy during training:
# "seed,test_loss,cross_entropy" at the first such epoch. The epoch itself
# (kept in first_100_epoch / first_100_summary) is not written here.
# `loss_log_dir` is whatever the per-model loop above left behind. It does not
# depend on the model index. If that loop never ran, it is still the
# module-level "" and the file is created in the current working directory.
# Append mode: repeated runs accumulate lines.
first_100_path = os.path.join(loss_log_dir, "first_100_acc_test_loss_records.txt")
with open(first_100_path, "a") as f:
    for i in range(num_models):
        if first_100_test_loss[i] is not None:
            f.write(f"{random_seed_ints[i]},{first_100_test_loss[i]},{first_100_cross_entropy_loss[i]}\n")

import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from collections import defaultdict


# Every ordered input pair (a, b) with a, b in [0, p): `full_inputs` has shape
# (p*p, 2), dtype int32. Rows are ordered with `a` as the slow index
# (indexing="ij"), so row a*p + b holds (a, b).
n = p                               # grid size
a_vals = np.arange(n)
b_vals = np.arange(n)
a_grid, b_grid = np.meshgrid(a_vals, b_vals, indexing="ij")
full_inputs = np.stack([a_grid.ravel(), b_grid.ravel()], axis=1).astype(np.int32)

def _extract_hook_pre(params, x_batch, layer: int = 1):
    """
    Run the module-level ``model`` and return the activation captured under
    the first intermediates key ending in ``hook_pre{layer}``.

    Parameters
    ----------
    params  : single-model parameter tree, passed as ``{"params": params}``.
    x_batch : model input batch.
    layer   : index used to build the key suffix ``hook_pre{layer}``.

    Returns
    -------
    np.ndarray : host copy of the captured activation.

    Raises
    ------
    KeyError     : no intermediates key ends with the suffix.
    RuntimeError : the matched entry does not unwrap to an array.

    Handled tree forms: nodes may be plain dicts or any other Mapping
    (e.g. FrozenDict), and captured values may be wrapped in a list or a
    tuple, possibly under a single-entry name -> value mapping. All of these
    are unwrapped until an object with ``.shape`` is reached. A debug dump of
    the tree is printed on every call.
    """
    from collections.abc import Mapping

    suffix = f"hook_pre{layer}"

    # 1) Run and grab intermediates
    _, inter = model.apply(
        {"params": params},
        x_batch,
        mutable=["intermediates"],
        training=False,
    )
    ints = inter["intermediates"]

    # 2) DEBUG dump
    print(f"\n[debug] searching for '{suffix}' in your intermediates tree:")
    def _dump_tree(d, indent=0):
        for k, v in d.items():
            if isinstance(v, Mapping):
                print(" " * indent + f"{k}/")
                _dump_tree(v, indent + 2)
            elif isinstance(v, (list, tuple)):
                shape = getattr(v[0], "shape", None) if v else None
                print(" " * indent + f"{k} -> list(len={len(v)}), item.shape={shape}")
            else:
                shape = getattr(v, "shape", None)
                print(" " * indent + f"{k}: array.shape={shape}")
    _dump_tree(ints)

    # 3) Peel sow containers (list/tuple) and single-entry name -> value maps
    #    until something array-like (has .shape) or unrecognised remains.
    def _unwrap(v):
        while not hasattr(v, "shape"):
            if isinstance(v, (list, tuple)) and v:
                v = v[0]
            elif isinstance(v, Mapping) and v:
                v = next(iter(v.values()))
            else:
                break
        return v

    # Depth-first search for the first key ending in the suffix.
    def _find_hook(d):
        if isinstance(d, Mapping):
            for k, v in d.items():
                if k.endswith(suffix):
                    return _unwrap(v)
                found = _find_hook(v)
                if found is not None:
                    return found
        return None

    arr = _find_hook(ints)
    if arr is None:
        raise KeyError(
            f"Couldn't find any key ending in '{suffix}'.\n"
            f"(Top-level keys were: {list(ints.keys())})"
        )

    # 4) Final sanity check / conversion
    if not hasattr(arr, "shape"):
        raise RuntimeError(f"Expected an array for '{suffix}', but got {type(arr)}")
    print(f"[debug] found '{suffix}' with shape {arr.shape}")
    return np.array(jax.device_get(arr))



def _fit_cosine_product(grid_true, freq, n):
    """
    Fits grid_true using term1 = cos(f*(a+b))*cos(f*c) and term2 = sin(f*(a+b))*sin(f*c),
    but if freq == n/2 on an even n, does a single-column parity fit instead.
    Returns
    -------
    grid_fit : (n,n) reconstruction
    coeffs   : array of length 1 or 2
    """
    # coords
    a = np.arange(n)
    b = np.arange(n)
    A, B = np.meshgrid(a, b, indexing="ij")
    C = (A + B) % n
    y = grid_true.ravel()

    # Nyquist parity special case
    if (n % 2 == 0) and (freq == n // 2):
        # parity on (a+b) mod n  ↦  (-1)^(a+b)
        P = ((-1) ** (A + B)).ravel().astype(float)
        # safe‐standardize
        mu, sigma = P.mean(), P.std()
        sigma = sigma if sigma > 0 else 1.0
        Pn = (P - mu) / sigma
        # one‐coef least squares
        c, *_ = np.linalg.lstsq(Pn[:, None], y, rcond=None)
        grid_fit = (Pn * c).reshape(n, n)
        bias = np.mean(y - grid_fit.ravel())
        return grid_fit + bias, c

    # otherwise original 2-term fit
    f_scaled = 2 * np.pi * freq / n
    term1 = np.cos(f_scaled * (A + B)) * np.cos(f_scaled * C)
    term2 = np.sin(f_scaled * (A + B)) * np.sin(f_scaled * C)
    X = np.column_stack([term1.ravel(), term2.ravel()])

    # safe‐standardize (never divide by zero)
    means = X.mean(axis=0)
    stds  = X.std(axis=0)
    stds_safe = np.where(stds == 0, 1.0, stds)
    Xn = (X - means) / stds_safe

    coeffs, *_ = np.linalg.lstsq(Xn, y, rcond=None)
    grid_fit = (Xn @ coeffs).reshape(n, n)
    bias = np.mean(y - grid_fit.ravel())
    return grid_fit + bias, coeffs


# def _fit_with_sin_cos(grid_true, fa, fb, n):
#     a_vals = np.arange(n)
#     b_vals = np.arange(n)
#     a_grid, b_grid = np.meshgrid(a_vals, b_vals, indexing="ij")

#     sin_a = np.sin(2 * np.pi * fa * a_grid / n)
#     cos_a = np.cos(2 * np.pi * fa * a_grid / n)
#     sin_b = np.sin(2 * np.pi * fb * b_grid / n)
#     cos_b = np.cos(2 * np.pi * fb * b_grid / n)

#     X = np.column_stack([sin_a.ravel(), cos_a.ravel(), sin_b.ravel(), cos_b.ravel()])
#     y = grid_true.ravel()
#     X_mean, X_std = X.mean(axis=0), X.std(axis=0)
#     X_norm = (X - X_mean) / X_std

#     coeffs, *_ = np.linalg.lstsq(X_norm, y, rcond=None)
#     grid_fit = (X_norm @ coeffs).reshape(n, n)
#     return grid_fit, coeffs


def _fit_sequential(
    grid_true: np.ndarray,
    fab1: int,
    fab2: int,
    fa: int,
    fb: int,
    n: int,
    p: int,
):
    """
    Sequential fit:
      1. Fit the (a + b) terms only,
      2. Fit the (a - b) terms on the residual,
      3. Fit the axis‐aligned sine/cosine terms on the remaining residual.

    Returns:
      grid_full  –  [n,n] full reconstructed grid,
      coeffs_sum –  (4,) coefficients for the (a + b) terms,
      coeffs_diff – (4,) coefficients for the (a - b) terms,
      coeffs_sep –  (4,) coefficients for the a-only and b-only terms,
      intermediate_grids – dict with 'sum', 'diff', 'residual', and partial reconstructions.
    """
    a_vals = np.arange(n)
    b_vals = np.arange(n)
    a_grid, b_grid = np.meshgrid(a_vals, b_vals, indexing="ij")
    ab_grid = a_grid + b_grid
    ab_diff = a_grid - b_grid

    # frequencies (scaled to radians)
    w1 = 2 * np.pi * fab1 / p
    w2 = 2 * np.pi * fab2 / p
    wa = 2 * np.pi * fa / n
    wb = 2 * np.pi * fb / n

    y_true = grid_true.ravel()

    # ----- Stage 1: Fit (a + b) terms -----
    basis_sum = np.column_stack([
        np.cos(w1 * ab_grid).ravel(),
        np.sin(w1 * ab_grid).ravel(),
        np.cos(w2 * ab_grid).ravel(),
        np.sin(w2 * ab_grid).ravel(),
    ])
    mean1, std1 = basis_sum.mean(0), basis_sum.std(0)
    X1 = (basis_sum - mean1) / std1

    coeffs_sum, *_ = np.linalg.lstsq(X1, y_true, rcond=None)
    grid_sum = (X1 @ coeffs_sum).reshape(n, n)
    residual1 = y_true - grid_sum.ravel()

    # ----- Stage 2: Fit (a - b) terms -----
    basis_diff = np.column_stack([
        np.cos(w1 * ab_diff).ravel(),
        np.sin(w1 * ab_diff).ravel(),
        np.cos(w2 * ab_diff).ravel(),
        np.sin(w2 * ab_diff).ravel(),
    ])
    mean2, std2 = basis_diff.mean(0), basis_diff.std(0)
    X2 = (basis_diff - mean2) / std2

    coeffs_diff, *_ = np.linalg.lstsq(X2, residual1, rcond=None)
    grid_diff = (X2 @ coeffs_diff).reshape(n, n)
    residual2 = residual1 - grid_diff.ravel()

    # ----- Stage 3: Fit axis-aligned terms -----
    basis_sep = np.column_stack([
        np.sin(wa * a_grid).ravel(),
        np.cos(wa * a_grid).ravel(),
        np.sin(wb * b_grid).ravel(),
        np.cos(wb * b_grid).ravel(),
    ])
    mean3, std3 = basis_sep.mean(0), basis_sep.std(0)
    X3 = (basis_sep - mean3) / std3

    coeffs_sep, *_ = np.linalg.lstsq(X3, residual2, rcond=None)
    grid_sep = (X3 @ coeffs_sep).reshape(n, n)

    # Final combined reconstruction
    grid_full = grid_sum + grid_diff + grid_sep

    intermediate_grids = {
        "sum": grid_sum,
        "diff": grid_diff,
        "sep": grid_sep,
        "residual": residual2.reshape(n, n)
    }

    return grid_full, coeffs_sum, coeffs_diff, coeffs_sep, intermediate_grids


import numpy as np
from collections import defaultdict
# (Other necessary modules such as jax, plotly, etc., are assumed to be imported elsewhere.)

def _conjecture_lstsq(X, y, n):
    """
    Least-squares fit of ``y`` on the column-standardized ``X`` plus a bias.

    Zero-variance columns are divided by 1, so they become all-zero after
    centring instead of NaN.

    Returns (grid, coeffs, bias, mse, r2). ``grid`` is the (n, n)
    reconstruction INCLUDING the bias, and ``mse``/``r2`` are measured against
    that same bias-corrected reconstruction.
    """
    mu = X.mean(axis=0)
    sig = X.std(axis=0)
    sig_safe = np.where(sig == 0, 1.0, sig)
    Xn = (X - mu) / sig_safe

    coeffs, *_ = np.linalg.lstsq(Xn, y, rcond=None)
    fit_flat = Xn @ coeffs
    bias = np.mean(y - fit_flat)
    fit_flat = fit_flat + bias

    resid = y - fit_flat
    mse = np.mean(resid ** 2)
    var = np.var(y)
    r2 = 1 - mse / (var if var > 0 else 1.0)
    return fit_flat.reshape(n, n), coeffs, bias, mse, r2


def fit_conjecture_direct(grid_true, n, fa, fb):
    """
    Second-order fit of an (n, n) grid with (a ± b) sinusoids.

    Regular case: 4 columns built from sin/cos of 2*pi*fa*a/n and 2*pi*fb*b/n
    combined via angle-sum identities, i.e. sin/cos(2*pi*(fa*a ± fb*b)/n).
    The labels "sin(a+b)" etc. describe them exactly only when fa == fb.

    Nyquist case (even n and fa or fb == n/2): collapses to the parity basis
    [(-1)^(a+b), (-1)^(a-b)], with zero-variance columns dropped. Note that
    a+b and a-b always share parity, so these two columns are identical, and
    lstsq returns a minimum-norm split of a single coefficient between them.

    All columns are standardized and a bias is fitted. ``mse``/``r2`` are
    computed on the returned (bias-included) grid.

    Returns
    -------
    dict with keys "grid", "coeffs", "labels", "r2", "mse", "bias", "fa", "fb"
    """
    a = np.arange(n)
    b = np.arange(n)
    A, B = np.meshgrid(a, b, indexing="ij")
    y = grid_true.ravel()

    if (n % 2 == 0) and (fa == n//2 or fb == n//2):
        # parity of sum & parity of difference
        sump  = ((-1) ** (A + B)).ravel().astype(float)
        diffp = np.where(((A - B) % 2) == 0, 1.0, -1.0).ravel()
        X = np.column_stack([sump, diffp])
        labels = ["parity(a+b)", "parity(a−b)"]

        # drop any zero-variance columns
        mask = X.var(axis=0) > 0
        X = X[:, mask]
        labels = [lab for lab, ok in zip(labels, mask) if ok]
    else:
        sin_a = np.sin(2*np.pi*fa * A / n)
        cos_a = np.cos(2*np.pi*fa * A / n)
        sin_b = np.sin(2*np.pi*fb * B / n)
        cos_b = np.cos(2*np.pi*fb * B / n)

        X = np.column_stack([
            (sin_a*cos_b + cos_a*sin_b).ravel(),   # sin(a+b)
            (sin_a*cos_b - cos_a*sin_b).ravel(),   # sin(a-b)
            (cos_a*cos_b - sin_a*sin_b).ravel(),   # cos(a+b)
            (cos_a*cos_b + sin_a*sin_b).ravel(),   # cos(a-b)
        ])
        labels = ["sin(a+b)", "sin(a−b)", "cos(a+b)", "cos(a−b)"]

    grid, coeffs, bias, mse, r2 = _conjecture_lstsq(X, y, n)

    return {
        "grid": grid,
        "coeffs": coeffs,
        "labels": labels,
        "bias": bias,
        "mse": mse,
        "r2": r2,
        "fa": fa,
        "fb": fb,
    }


def _fit_with_sin_cos(grid_true, fa, fb, n):
    """
    Fits grid_true with either:
      • full sin/cos basis (4 columns), or
      • a parity basis (1 or 2 columns) if fa or fb == n/2 on even n.

    Returns
    -------
    grid_fit : np.ndarray of shape (n,n)
    coeffs   : np.ndarray of length 1–4
    """
    a = np.arange(n)
    b = np.arange(n)
    A, B = np.meshgrid(a, b, indexing="ij")
    y = grid_true.ravel()

    # parity branch if we hit the Nyquist freq
    if n % 2 == 0 and (fa == n//2 or fb == n//2):
        cols = []
        # parity(a+b)?
        if fa == n//2:
            parity_sum = np.where(((A + B) % 2) == 0, 1.0, -1.0).ravel()
            cols.append(parity_sum)
        # parity(a−b)?
        if fb == n//2:
            parity_diff = np.where(((A - B) % 2) == 0, 1.0, -1.0).ravel()
            cols.append(parity_diff)

        X = np.column_stack(cols)
    else:
        # original sin/cos basis
        sin_a = np.sin(2*np.pi*fa * A / n).ravel()
        cos_a = np.cos(2*np.pi*fa * A / n).ravel()
        sin_b = np.sin(2*np.pi*fb * B / n).ravel()
        cos_b = np.cos(2*np.pi*fb * B / n).ravel()
        X = np.column_stack([sin_a, cos_a, sin_b, cos_b])

    # safe standardize
    mu = X.mean(axis=0)
    sigma = X.std(axis=0)
    sigma_safe = np.where(sigma == 0, 1.0, sigma)
    Xn = (X - mu) / sigma_safe

    # least squares + bias
    coeffs, *_ = np.linalg.lstsq(Xn, y, rcond=None)
    grid_fit = (Xn @ coeffs).reshape(n, n)
    bias = np.mean(y - grid_fit.ravel())
    grid_fit += bias

    return grid_fit, coeffs


def _dominant_axis_freqs(grid, n):
    """
    Strongest non-DC frequency along each axis of ``grid``.

    The 2-D FFT magnitude (DC bin zeroed) is summed over the other axis, and
    the argmax over bins 1..n//2 is taken. Returns (fa, fb), each in [1, n//2].
    """
    mag = np.abs(np.fft.fft2(grid))
    mag[0, 0] = 0
    fa = int(np.argmax(mag.sum(axis=1)[1:(n // 2 + 1)]) + 1)
    fb = int(np.argmax(mag.sum(axis=0)[1:(n // 2 + 1)]) + 1)
    return fa, fb


def fit_conjecture_enhanced(grid_true, n, fa, fb, top_k=3):
    """
    Enhanced additive conjecture fit with top_k second-order fits followed by
    top_k first-order fits, each fitted greedily on the current residual at
    the residual's dominant FFT frequencies.

    Parameters:
      grid_true : array-like of shape (n, n) - the true grid. Non-floating
                  inputs are promoted to float64 (the in-place residual
                  updates would otherwise fail).
      n         : int - grid dimension.
      fa, fb    : int - initial frequencies (unused, for backward compatibility).
      top_k     : int >= 1 - number of second-order and first-order fits.

    Returns a dictionary containing:
      "r2", "mse", "grid", "bias",
      "second_order_details": list of dicts (one per second-order fit),
      "first_order_details": list of dicts (one per first-order fit),
      "coeffs1", "labels": backward-compatible first SO coefficients/labels,
      "res1": residual after second-order fits,
      "res2": residual after first-order fits,
      "res_total": overall residual.

    Raises:
      ValueError if top_k < 1.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")
    grid_true = np.asarray(grid_true)
    if not np.issubdtype(grid_true.dtype, np.floating):
        grid_true = grid_true.astype(np.float64)

    # Stage 1: top_k second-order fits
    total_fit_so = np.zeros_like(grid_true)
    resid = grid_true.copy()
    second_order_details = []
    for _ in range(top_k):
        fa_i, fb_i = _dominant_axis_freqs(resid, n)
        fit_i = fit_conjecture_direct(resid, n, fa_i, fb_i)
        grid_fit_i = fit_i["grid"]
        second_order_details.append({
            "fa": fa_i,
            "fb": fb_i,
            "coeffs": fit_i["coeffs"],
            "labels": fit_i["labels"],
            "r2": fit_i["r2"],
            "mse": fit_i["mse"]
        })
        total_fit_so += grid_fit_i
        resid -= grid_fit_i
    res1 = grid_true - total_fit_so

    # Stage 2: top_k first-order sin/cos fits on what stage 1 left over
    total_fit_fo = np.zeros_like(grid_true)
    first_order_details = []
    for _ in range(top_k):
        fa_r, fb_r = _dominant_axis_freqs(resid, n)
        grid_fit_cos, coeffs_cos = _fit_with_sin_cos(resid, fa_r, fb_r, n)
        first_order_details.append({
            "fa": fa_r,
            "fb": fb_r,
            "coeffs": coeffs_cos
        })
        total_fit_fo += grid_fit_cos
        resid -= grid_fit_cos
    res2 = grid_true - (total_fit_so + total_fit_fo)

    # Combine all fits and adjust bias
    grid_total = total_fit_so + total_fit_fo
    y_true = grid_true.ravel()
    bias_total = np.mean(y_true - grid_total.ravel())
    grid_total += bias_total
    res_total = grid_true - grid_total

    # Performance metrics
    var = np.var(y_true)
    mse = np.mean((y_true - grid_total.ravel()) ** 2)
    r2 = 1.0 - mse / (var if var > 0 else 1.0)

    # Backward compatibility: first second-order fit
    coeffs1 = second_order_details[0]["coeffs"]
    labels = second_order_details[0]["labels"]

    return {
        "r2": r2,
        "mse": mse,
        "grid": grid_total,
        "bias": bias_total,
        "second_order_details": second_order_details,
        "first_order_details": first_order_details,
        "coeffs1": coeffs1,
        "labels": labels,
        "res1": res1,
        "res2": res2,
        "res_total": res_total,
    }

def compute_injected_accuracy(params, fitted_pre, x_all, y_all, layer_idx):
    """
    Accuracy of the module-level ``model`` on ``(x_all, y_all)`` with
    ``fitted_pre`` supplied as the ``hook_pre{layer_idx}`` intermediate.

    Steps:
      1. Run the model once with a mutable ``intermediates`` collection.
      2. Locate the key path of the hook.
      3. Build a nested override dict along that path, whose deepest entry
         is ``[fitted_pre]``.
      4. Re-run the model with ``{"params": params, "intermediates": override}``.
      5. Compare the argmax of the last-position logits with ``y_all``.

    Returns
    -------
    Scalar jnp array: fraction of correct predictions.

    Raises
    ------
    KeyError : no intermediates key ends with ``hook_pre{layer_idx}``.

    NOTE(review): the second ``model.apply`` is not given
    ``mutable=["intermediates"]``. Supplying a collection in the variables
    does not by itself replace an activation. This measures an *injected*
    accuracy only if the hook module reads that collection back; confirm
    against HookPoint. Otherwise the result equals the plain accuracy.
    """
    from collections.abc import Mapping

    suffix = f"hook_pre{layer_idx}"

    # 1) Capture the current intermediates
    _, inter = model.apply(
        {"params": params},
        x_all,
        mutable=["intermediates"],
        training=False,
    )
    ints = inter["intermediates"]

    # 2) Key path to hook_preX. Nodes may be dicts or other Mappings
    #    (e.g. FrozenDict). Leaves may be lists/tuples or bare arrays.
    def find_hook_path(d, path):
        if isinstance(d, Mapping):
            for k, v in d.items():
                new_path = path + [k]
                if k.endswith(suffix):
                    if isinstance(v, Mapping):
                        # assume exactly one entry in that mapping: descend into it
                        inner = next(iter(v.keys()))
                        return new_path + [inner]
                    # list/tuple container or an array
                    return new_path
                res = find_hook_path(v, new_path)
                if res:
                    return res
        return None

    hook_path = find_hook_path(ints, [])
    if not hook_path:
        raise KeyError(f"No intermediate ending in '{suffix}' found. Keys seen:\n{list(ints.keys())}")

    print(f"[debug] will inject into path: {' ➔ '.join(hook_path)}")

    # 3) Minimal override: only the deepest key maps to a one-element list.
    #    Wrapping by position (not by comparing key names) stops a key name
    #    that repeats higher up the path from being wrapped as well.
    new_intermediates = [fitted_pre]
    for key in reversed(hook_path):
        new_intermediates = {key: new_intermediates}

    # 4) Rerun with our override
    vars2 = freeze({"params": params, "intermediates": new_intermediates})
    logits_inj = model.apply(vars2, x_all, training=False)
    preds = jnp.argmax(logits_inj[:, -1, :], axis=-1)
    return jnp.mean(preds == y_all)


def batched_gradient_similarity(
        *, model, params,
        a_batch: jnp.ndarray,
        b_batch: jnp.ndarray,
        c_batch: jnp.ndarray
    ) -> jnp.ndarray:
    """
    Cosine similarity of (dQ/dx_a, dQ/dx_b) for every triple (a, b, c) taken
    from the three equal-length 1-D index arrays a_batch, b_batch, c_batch.

    Q is the logit of class c at the last sequence position. The inputs are
    x_a = E[a] + W_pos[0] and x_b = E[b] + W_pos[1], i.e. the token embeddings
    *after* the positional embedding is added, as the docstring and argument
    names (ea_plus_pos / eb_plus_pos) require. Both gradients are obtained from
    a single ``jax.grad`` call.

    Returns an (N,) array. A 1e-12 term in the denominator guards against
    zero-norm gradients.

    NOTE(review): assumes ``model.call_from_embedding_sequence`` consumes
    embeddings that already include the positional term (and does not add
    W_pos again internally) — confirm.
    """
    emb, _ = model.extract_embeddings_ab(params)           # (p , d_model)
    pos0, pos1 = params["pos_embed"]["W_pos"][:2]          # (d_model,)

    def scalar_logit(ea_plus_pos, eb_plus_pos, cls):
        seq = jnp.stack([ea_plus_pos, eb_plus_pos])[None, ...]   # (1,2,d)
        logits = model.call_from_embedding_sequence(seq, params)[0]
        return logits[-1, cls]                                   # scalar

    # Gradients w.r.t. both embedding slots in one pass.
    grad_ab = jax.grad(scalar_logit, argnums=(0, 1))

    # Differentiate at the position-augmented embeddings.
    vec_a = emb[a_batch] + pos0        # (N , d_model)
    vec_b = emb[b_batch] + pos1        # (N , d_model)

    g_a, g_b = jax.vmap(grad_ab)(vec_a, vec_b, c_batch)     # each (N , d_model)

    dot   = jnp.sum(g_a * g_b, axis=1)
    norms = (jnp.linalg.norm(g_a, axis=1) *
             jnp.linalg.norm(g_b, axis=1) + 1e-12)
    return dot / norms

def _filter_neurons_by_max(mat: np.ndarray, thr: float = 1e-2) -> tuple[np.ndarray, np.ndarray]:
    """
    Keep only columns (neurons) whose max activation over the p^2 grid >= thr.
    mat: (p^2, num_neurons)
    returns: (filtered_mat, keep_mask)
    """
    if mat.ndim != 2:
        raise ValueError(f"Expected 2D matrix (p^2, N), got shape {mat.shape}")
    keep = (np.max(mat, axis=0) >= thr)
    if not np.any(keep):
        # Return a correctly shaped empty matrix and the mask
        return mat[:, :0], keep
    return mat[:, keep], keep


def get_all_preacts_and_embeddings(
    *,
    neuron_data: dict[int, dict[int, dict]],
    dominant_freq_clusters,              # list[dict] or {freq: [ids]}
    params: dict,
):
    """
    Collect token embeddings, per-cluster pre-activations and per-cluster
    output contributions for the single MLP block ('blocks_0').

    Parameters
    ----------
    neuron_data : {layer_idx: {neuron_id: record}}. Each record provides
        'real_preactivations' (flattened to p^2 values). An optional
        'postactivations' entry is used for the output contributions;
        otherwise ReLU(real_preactivations) is computed, only when needed.
    dominant_freq_clusters : a single {freq_key: [neuron_ids]} dict (treated
        as layer 1) or a list of such dicts, one per layer starting at 1.
    params : single-model parameter tree. Reads
        params['blocks_0']['mlp']['W_out'] and params['W_U'], and uses the
        module-level ``model`` for ``extract_embeddings_ab``.

    Neurons whose max pre-activation (section 2) or max post-activation
    (section 3) is below 1e-2 are skipped. Neuron ids are used directly as
    row indices into W_out.T — NOTE(review): assumes ids are W_out column
    indices.

    Returns
    -------
    embeddings             : np.ndarray                       # (p , d_model)
    layer_preacts          : list[dict[str, np.ndarray]]      # per layer, (p^2, kept) per cluster
    cluster_contribs_logits: dict[str, np.ndarray]            # (p² , p) – after W_U
    cluster_contribs_dmodel: dict[str, np.ndarray]            # (p² , d_model) – before W_U
    """
    import numpy as np

    # ─────────────────────────── 0. constants & helpers ──────────────────────────
    last_layer_idx   = max(neuron_data)
    last_block_key   = 'blocks_0'             # we only have 1 block

    W_out = np.array(params[last_block_key]['mlp']['W_out'])   # (d_model , d_mlp)
    W_U   = np.array(params['W_U'])                            # (d_model , p)

    eff_W_logits  = W_out.T @ W_U                              # (d_mlp , p)
    eff_W_dmodel  = W_out.T                                    # (d_mlp , d_model)

    # ─────────────────────────── 1. token embeddings ────────────────────────────
    W_E, _ = model.extract_embeddings_ab(params)               # (p , d_model)
    embeddings = np.asarray(W_E)

    # ─────────────────────────── 2. layer-wise pre-acts ─────────────────────────
    layer_preacts: list[dict[str, np.ndarray]] = []

    # allow both the "single-dict" and the "list of dicts" formats
    if isinstance(dominant_freq_clusters, dict):
        clusters_by_layer = {1: dominant_freq_clusters}
    else:
        clusters_by_layer = {i + 1: d for i, d in enumerate(dominant_freq_clusters)}

    for layer_idx in sorted(neuron_data):
        pre_dict = {}
        for freq_key, ids in clusters_by_layer[layer_idx].items():
            if not ids:
                continue
            cols = []
            for nid in ids:
                v = neuron_data[layer_idx][nid]['real_preactivations'].reshape(-1)  # (p^2,)
                if float(v.max()) >= 1e-2:  # 0.01 threshold
                    cols.append(v)
            if cols:
                pre_dict[freq_key] = np.stack(cols, axis=1)  # (p^2, kept_neurons)
        layer_preacts.append(pre_dict)

    # ─────────────────────────── 3. cluster → logits  &  cluster → d_model ──────
    cluster_contribs_logits : dict[str, np.ndarray] = {}
    cluster_contribs_dmodel : dict[str, np.ndarray] = {}
    last_layer_records = neuron_data[last_layer_idx]

    for freq_key, ids in clusters_by_layer[last_layer_idx].items():
        if not ids:
            continue
        kept_cols = []
        kept_ids  = []
        for nid in ids:
            record = last_layer_records[nid]
            # Only compute the ReLU fallback when no stored post-activations exist.
            if "postactivations" in record:
                post = record["postactivations"]
            else:
                post = np.maximum(record["real_preactivations"], 0.0)
            post = post.reshape(-1)  # (p^2,)
            if float(post.max()) >= 1e-2:  # 0.01 threshold on *activations*
                kept_cols.append(post)
                kept_ids.append(nid)

        if kept_cols:
            post_mat = np.stack(kept_cols, axis=1)  # (p^2, kept_neurons)

            # IMPORTANT: index the effective weights with the *kept* neuron ids
            W_logits_sub = eff_W_logits[kept_ids, :]   # (kept_neurons, p)
            cluster_contribs_logits[freq_key]  = post_mat @ W_logits_sub

            W_dmodel_sub = eff_W_dmodel[kept_ids, :]   # (kept_neurons, d_model)
            cluster_contribs_dmodel[freq_key] = post_mat @ W_dmodel_sub

    return embeddings, layer_preacts, cluster_contribs_logits, cluster_contribs_dmodel

def detach_to_numpy(tree):
    """Deep-copy a pytree of arrays into host-resident, self-owned NumPy arrays.

    Every leaf is first pulled to the host with ``jax.device_get`` and then
    copied with ``np.array(..., copy=True)``, so no returned array shares a
    buffer with the input (device arrays and NumPy arrays alike).  The tree
    structure is preserved by ``jax.tree_util.tree_map``.
    """
    def _owned_host_copy(leaf):
        on_host = jax.device_get(leaf)
        return np.array(on_host, copy=True)

    return jax.tree_util.tree_map(_owned_host_copy, tree)

# --- Main analysis loop ---
# Module-level numeric cutoff (1e-5).
# NOTE(review): its consumer is not visible in this part of the file.
threshold = 1e-5
# layer1_fits_all[m] accumulates the layer-1 fit records ({"fa", "fb", "grid"})
# of the m-th trained model, in the same order as random_seed_ints.
layer1_fits_all = [list() for _model_idx in range(len(random_seed_ints))]

for m_i, seed in enumerate(random_seed_ints):
    trained_params_i = jax.tree_util.tree_map(lambda x: x[m_i].copy(), states.params)
    params_i_safe = detach_to_numpy(trained_params_i)
    check_unembed_rank(params_i_safe, label=f"seed {seed} – analysis-copy")
    print(f"\n===== Seed {seed} =====")
    # extract this model’s params
    # params_i = jax.tree_util.tree_map(lambda x: x[m_i], states.params)
    params_i = params_i_safe
    W_U = np.array(params_i['W_U'])                   # or params_top['W_U'] if using block param split
    s = np.linalg.svd(W_U, compute_uv=False)
    rank = np.sum(s > 1e-6)
    W = np.array(params_i['W_U'])          # (128, 59)

    # don't trust Float32 SVD – promote to float64
    s  = np.linalg.svd(W.astype(np.float64), compute_uv=False)
    rel = s / s.max()                    # relative magnitudes

    print("top 10 σ:", s[:10])
    print("relative :", rel[:10])
    print("effective rank(1e-8):", (rel > 1e-8).sum())

    c0 = W[:, 0]
    c1 = W[:, 1]

    # angle between them
    angle = np.rad2deg(np.arccos( np.dot(c0, c1) /
                                (np.linalg.norm(c0)*np.linalg.norm(c1)) ))
    print(f"angle: {angle}")    

    print("old: W_U shape:", W_U.shape)
    print("Singular values:", s[:10])              # show top 10 for context
    print("Estimated rank(W_U):", rank)
    # np.set_printoptions(threshold=np.inf, linewidth=200, suppress=True)
    print(W_U)

    deep_layer_grids = [] 
    layer_summaries = {}

    # Analyze each layer
    for layer in range(1, NUM_MLP_LAYERS + 1):
        is_deepest = (layer == NUM_MLP_LAYERS)
        print(f"\n[analyzing] Seed {seed} · Layer {layer}")
        # 1) Extract true pre-activations
        pre_all = _extract_hook_pre(params_i, full_inputs, layer)
        pre_tok1 = pre_all[:, 1, :]
        d_mlp = pre_tok1.shape[-1]

        freq_r2s = defaultdict(list)
        freq_fit_type = {}

        # Fit each neuron
        for neuron in range(d_mlp):
            grid_true = pre_tok1[:, neuron].reshape(n, n)

            # first-order candidate freqs
            fft2d = np.fft.fft2(grid_true)
            mag = np.abs(fft2d)
            mag[0, 0] = 0
            fa_raw = int(np.argmax(mag.sum(axis=1)[1:(n//2 + 1)]) + 1)
            fb_raw = int(np.argmax(mag.sum(axis=0)[1:(n//2 + 1)]) + 1)

            # best sin/cos fit
            best_sc = {"r2": -np.inf}
            for candidate_fa in (fa_raw, n - fa_raw):
                for candidate_fb in (fb_raw, n - fb_raw):
                    grid_sc, _ = _fit_with_sin_cos(grid_true, candidate_fa, candidate_fb, n)
                    y_true = grid_true.ravel()
                    y_pred = grid_sc.ravel()
                    var = float(np.var(y_true))
                    mse = float(np.mean((y_true - y_pred) ** 2))
                    r2 = 1 - mse / (var if var > 0 else 1.0)
                    bias = np.mean(y_true - y_pred)
                    ypb = y_pred + bias
                    mse_b = float(np.mean((y_true - ypb) ** 2))
                    r2_b = 1 - mse_b / (var if var > 0 else 1.0)
                    if r2_b > r2:
                        r2, grid_sc = r2_b, grid_sc + bias
                    if r2 > best_sc["r2"]:
                        best_sc = {"grid": grid_sc, "r2": r2, "fa": candidate_fa, "fb": candidate_fb}

            # Store layer-1 fits
            if layer == 1:
                layer1_fits_all[m_i].append({
                    "fa": best_sc["fa"],
                    "fb": best_sc["fb"],
                    "grid": best_sc["grid"]
                })

            # Layer-2 additional fits
            if layer != 1:
                # cosine-product fit
                grid_cp, _ = _fit_cosine_product(grid_true, best_sc["fa"], p)
                y_t = grid_true.ravel()
                y_cp = grid_cp.ravel()
                var = float(np.var(y_t))
                mse_cp = float(np.mean((y_t - y_cp) ** 2))
                r2_cp = 1 - mse_cp / (var if var > 0 else 1.0)
                bias_cp = np.mean(y_t - y_cp)
                ycpb = y_cp + bias_cp
                mse_cpb = float(np.mean((y_t - ycpb) ** 2))
                r2_cpb = 1 - mse_cpb / (var if var > 0 else 1.0)
                if r2_cpb > r2_cp:
                    r2_cp, grid_cp = r2_cpb, grid_cp + bias_cp

                # compute dynamic top_k
                unique_freqs = set()
                for entry in layer1_fits_all[m_i]:
                    unique_freqs.update([entry["fa"], entry["fb"]])
                top_k = len(unique_freqs)

                # enhanced conjecture with top_k
                conj = fit_conjecture_enhanced(
                    grid_true, n, best_sc["fa"], best_sc["fb"], top_k=top_k
                )
                grid_conj, r2_conj = conj["grid"], conj["r2"]

            # choose best fit
            best_r2 = best_sc["r2"]
            best_type = "sin_cos"
            grid_best = best_sc["grid"]

            if layer != 1:
                if r2_cp > best_r2:
                    best_r2, best_type, grid_best = r2_cp, "cos_product", grid_cp
                if r2_conj > best_r2:
                    best_r2, best_type, grid_best = r2_conj, "conjecture", grid_conj
            if is_deepest:
                deep_layer_grids.append(grid_best)

            freq_r2s[(best_sc["fa"], best_sc["fb"])].append(best_r2)
            freq_fit_type[(best_sc["fa"], best_sc["fb"])] = best_type

        # Determine top_k for summary
        if is_deepest:
            unique_freqs = set()
            for entry in layer1_fits_all[m_i]:
                unique_freqs.update([entry["fa"], entry["fb"]])
            top_k = len(unique_freqs)
        else:
            top_k = None

        # Build summary dict
        summary = {
            "test_accuracy": float(final_test_accuracies[m_i]),
            "top_k": top_k
        }

        # manual forward debug (unchanged)
        embed_mod = Embed(d_vocab=p, d_model=transformer_config['d_model'])
        pos_mod = PosEmbed(max_ctx=2, d_model=transformer_config['d_model'])
        attn_mod = Attention(
            d_model=transformer_config['d_model'], num_heads=transformer_config['num_heads'], d_head=transformer_config['d_head'],
            n_ctx=transformer_config['n_ctx'], attn_coeff=transformer_config['attn_coeff']
        )

        x_debug = full_inputs[:16]
        y_debug = (x_debug[:, 0] + x_debug[:, 1]) % p

        logits_ref = model.apply({'params': params_i}, x_debug)
        preds_ref = jnp.argmax(logits_ref[:, -1, :], axis=-1)

        def manual_forward(params, x):
            x_emb  = embed_mod.apply({'params': params['embed']}, x)
            x_emb  = pos_mod.apply({'params': params['pos_embed']}, x_emb)
            att    = attn_mod.apply({'params': params['blocks_0']['attn']}, x_emb)
            resid  = x_emb + att           # after attention

            h = resid
            for i in range(NUM_MLP_LAYERS):
                W = params['blocks_0']['mlp'][f'W_{i}']
                b = params['blocks_0']['mlp'][f'b_{i}']
                pre = jnp.einsum("md,bpd->bpm", W, h) + b
                h   = jax.nn.relu(pre)
            out = jnp.einsum("dm,bpm->bpd",
                            params['blocks_0']['mlp']['W_out'], h) + params['blocks_0']['mlp']['b_out']
            resid = resid + out
            return jnp.einsum("dm,bpd->bpm", params['W_U'], resid)

        preds_man = jnp.argmax(manual_forward(params_i, x_debug)[:, -1, :], axis=-1)
        print("manual == reference?", jnp.all(preds_man == preds_ref))
        print("debug accuracy:", jnp.mean(preds_ref == y_debug))

        # compute fake-injection accuracy
        y_full = (full_inputs[:, 0] + full_inputs[:, 1]) % p
        real_pre = pre_all[:, 1, :]
        fake_acc = compute_injected_accuracy(
            params_i, real_pre, full_inputs, y_full, layer_idx=layer
        )
        print(f"▶️ fake-injection acc: {fake_acc:.2%}")

        # compute and record injected accuracy for layer 2
        if is_deepest:
            fitted_pre = jnp.stack([g.reshape(-1) for g in deep_layer_grids], axis=1)
            inj_acc = compute_injected_accuracy(params_i, fitted_pre, full_inputs, y_full, layer_idx=layer)
            print(f"▶️ injected test accuracy: {inj_acc:.2%}")
            summary["injected_test_accuracy"] = float(inj_acc)

        # record freq stats
        for (fa, fb), r2_list in sorted(freq_r2s.items(), key=lambda x: -np.mean(x[1])):
            summary[f"{fa},{fb}"] = {
                "avg_r2": np.mean(r2_list),
                "count": len(r2_list),
                "fit_type": freq_fit_type[(fa, fb)]
            }

        layer_summaries[layer] = summary

    # ─── Propagate the deepest-layer injected accuracy upward ───
    deep_key = NUM_MLP_LAYERS                              # last (deepest) hidden layer
    deep_acc = layer_summaries.get(deep_key, {}).get("injected_test_accuracy")

    if deep_acc is not None:
        # Copy to every shallower layer
        for l in range(1, deep_key):                      # 1-based indexing
            layer_summaries[l]["injected_test_accuracy"] = deep_acc

    # right after building layer_summaries, before dumping JSON:
    # for each layer, count only keys that look like "number,number"
    THRESH = 4  
    unique_freqs_before_comma = []   # one entry per layer
    unique_freqs_after_comma  = []   # one entry per layer
    min_unique_freqs          = []   # the min() you ultimately want

    for L in sorted(layer_summaries):
        # keep only "int,int" keys **and** whose count > THRESH
        keys = [
            k for k in layer_summaries[L]
            if re.fullmatch(r"\d+,\d+", k)
            and layer_summaries[L][k]["count"] > THRESH
        ]

        before = {int(k.split(",")[0]) for k in keys}
        after  = {int(k.split(",")[1]) for k in keys}

        unique_freqs_before_comma.append(len(before))
        unique_freqs_after_comma.append(len(after))
        min_unique_freqs.append(min(len(before), len(after)))

    margin_log_dir = f"{BASE_DIR}/margins_on_pizza-9"
    os.makedirs(margin_log_dir, exist_ok=True)
    
    file_path_name = os.path.join(margin_log_dir, f"seed_{seed}_loss_layer_freq_counts.txt")
    with open(file_path_name, "w") as f:
        f.write(
            f"{seed},"
            f"{cross_entropies[seed]},"
            + ",".join(str(c) for c in min_unique_freqs)
            + "\n"
        )
    print(f"created {file_path_name}")

    # Write out JSON for both layers
    for layer, summary in layer_summaries.items():  
        pth = os.path.join(
            loss_log_dir,
            f"all_preactivations_L{layer}_seed{seed}.json"
        )
        with open(pth, 'w') as jf:
            json.dump(summary, jf, indent=2)
        print(f"[analysis] saved → {pth}")


    

    # ──────────────────────────────────────────────────────────────
    # 0)  imports we still need here
    # ──────────────────────────────────────────────────────────────
    import collections, json, os, functools
    from collections import Counter
    import numpy as np
    import jax, jax.numpy as jnp


    # ──────────────────────────────────────────────────────────────
    # 1)  helper – effective weights of the **last MLP layer**
    #     into the final logits   (W_outᵀ · W_U)
    # ──────────────────────────────────────────────────────────────
    def _effective_final_weights(params_block, params_top) -> np.ndarray:
        """Fold the last block's MLP output projection into the unembedding.

        Args:
            params_block: params of the *last* TransformerBlock
                (e.g. ``params['blocks_0']``); ``['mlp']['W_out']`` is
                (d_model, d_mlp).
            params_top: root-level params containing ``'W_U'`` of shape
                (d_model, p).

        Returns:
            (d_mlp, p) array equal to ``W_out.T @ W_U``.  Because
            logits = W_Uᵀ · (resid + W_out · h), neuron n of the deepest MLP
            layer adds (W_Uᵀ · W_out[:, n]) · h_n to the logits, so row n is
            that neuron's direct per-unit contribution to the p logits.
        """
        w_out = np.array(params_block["mlp"]["W_out"])
        w_unembed = np.array(params_top["W_U"])
        return np.matmul(w_out.T, w_unembed)

    def _mod_inverse(a: int, p: int) -> int:
        """Return the multiplicative inverse of ``a`` modulo ``p``.

        Args:
            a: value to invert; may be negative or larger than ``p``.
            p: modulus (documented as prime by the original helper, but any
                modulus works).

        Returns:
            The unique ``x`` in ``[0, p)`` with ``(a * x) % p == 1``.

        Raises:
            ValueError: if ``a`` has no inverse modulo ``p``
                (``gcd(a, p) != 1``, e.g. ``a % p == 0``).
        """
        # pow(a, -1, p) (Python 3.8+) uses the extended Euclidean algorithm.
        # The Fermat shortcut pow(a, p - 2, p) silently returns 0 for a ≡ 0 and
        # a wrong value for composite p; this form raises instead.
        return pow(a, -1, p)
    # ──────────────────────────────────────────────────────────────
    # 2)  helper – grab **all pre-activations** we need
    # ──────────────────────────────────────────────────────────────
    def _collect_neuron_data(params_single,
                            *, p: int,
                            num_mlp_layers: int,
                            model) -> Tuple[dict, list]:
        """Collect per-neuron pre-activation grids and dominant-frequency clusters.

        For every MLP layer 1..num_mlp_layers, the pre-activations at token
        position 1 are extracted (via ``_extract_hook_pre``) for all p×p inputs
        (a, b).  Each neuron's (p, p) grid is 2-D FFT'd with the DC bin zeroed.
        Its dominant row frequency f_a and column frequency f_b are taken from
        bins 1..⌊p/2⌋.  The neuron is filed under the key "f_first,f_second",
        where the frequency whose FFT row/column carries more total magnitude
        comes first (ties keep "f_a,f_b").

        Args:
            params_single: parameter pytree for a single model (the caller
                slices one seed out of the stacked params).
            p: modulus; the input grid is p×p.
            num_mlp_layers: number of MLP layers to scan (layers are 1-based).
            model: unused; kept so existing call sites keep working.

        Returns:
            neuron_data: ``{layer_idx -> {neuron_id -> {"a_values",
                "b_values", "real_preactivations" (p, p),
                "postactivations" (p, p) = ReLU of the grid}}}``
            clusters: list with one dict per layer (list index = layer - 1)
                mapping "f_first,f_second" -> [neuron_ids].
        """
        n = p                       # short alias
        a_vals = np.arange(n)
        b_vals = np.arange(n)
        A, B = np.meshgrid(a_vals, b_vals, indexing="ij")
        # Row-major (a, b) enumeration: flat index = a * n + b, so reshaping a
        # per-input column to (n, n) puts a on axis 0 and b on axis 1.
        full_inputs = np.stack([A.ravel(), B.ravel()], axis=-1).astype(np.int32)

        # Positive, non-DC frequency bins 1..⌊n/2⌋ (the rest mirror them for
        # real-valued input).
        pos_bins = slice(1, n // 2 + 1)

        neuron_data: dict = {}
        freq_clusts: list = []

        for layer in range(1, num_mlp_layers + 1):
            pre_all = _extract_hook_pre(params_single, full_inputs, layer)
            # keep only token-1; shape (p², d_mlp), copied to host NumPy
            pre_tok1 = np.array(pre_all[:, 1, :])
            d_mlp = pre_tok1.shape[1]

            neuron_data[layer] = {}
            freq2ids: Dict[str, List[int]] = collections.defaultdict(list)

            for n_id in range(d_mlp):
                grid = pre_tok1[:, n_id].reshape(n, n)          # (p, p)
                post = np.maximum(grid, 0.0)

                spectrum = np.fft.fft2(grid)
                mag = np.abs(spectrum)
                mag[0, 0] = 0                                   # drop DC (mean) term

                fa = int(np.argmax(mag.sum(axis=1)[pos_bins]) + 1)  # row (a) freq
                fb = int(np.argmax(mag.sum(axis=0)[pos_bins]) + 1)  # col (b) freq

                # put the frequency carrying more magnitude first
                energy_a = mag[fa, :].sum()
                energy_b = mag[:, fb].sum()
                key = f"{fa},{fb}" if energy_a >= energy_b else f"{fb},{fa}"
                freq2ids[key].append(n_id)

                neuron_data[layer][n_id] = {
                    "a_values": a_vals,
                    "b_values": b_vals,
                    "real_preactivations": grid,
                    "postactivations": post,
                }

            # Every key was created by an append, so no cluster is empty.
            freq_clusts.append(dict(freq2ids))

        return neuron_data, freq_clusts


    # ──────────────────────────────────────────────────────────────
    # 3)  build everything and call the generic tracker
    # ──────────────────────────────────────────────────────────────
    for m_i, seed in enumerate(random_seed_ints):
        print(f"\n[metrics] collecting full metrics for seed {seed}")
        # single-model params → plain lil’ pytree on host
        params_i = jax.tree_util.tree_map(lambda x: x[m_i], states.params)

        # 3.1  neuron-level tensors & frequency clusters
        neuron_data, dominant_freq_clusters = _collect_neuron_data(
            params_i,
            p=p,
            num_mlp_layers=NUM_MLP_LAYERS,
            model=model,
        )

        

        # 3.2  effective weights of deepest MLP → logits
        last_block_key = f'blocks_{transformer_config["num_layers"]-1}'
        eff_W = _effective_final_weights(params_i[last_block_key], params_i)  # (d_mlp, p)
        def _phase_distribution(
            preacts: jnp.ndarray,      # (p , p , N)  real-valued
            threshold: float,
            p: int
        ) -> Counter:
            """Histogram the integer phase pairs of all "strong" neurons.

            A neuron is kept when its maximum pre-activation over the grid is
            strictly greater than ``threshold``.  For each kept neuron:

            * the row mean (mean over axis 1, a function of the axis-0 index a)
              and the column mean (mean over axis 0, a function of b) are
              Fourier-transformed along their length-p axis;
            * the dominant non-DC frequency in 1..⌊p/2⌋ is chosen by maximum
              power, independently for rows (f_a) and columns (f_b);
            * the phase is read from the angle of that FFT coefficient,

                  phi = -angle(X[f]) * p / (2π f),

              then rounded to the nearest integer and reduced mod p.

            Up to the 2π wrap of the angle, this formula inverts the shift model
            cos(2π f (x − phi) / p), i.e. phi is the location of the cosine's
            peak.  NOTE(review): it does NOT invert an additive-phase sine model
            sin(2π f x/p + 2π phi/p) — confirm which convention downstream
            consumers assume.

            The FFTs and argmaxes are vectorised over neurons; only the final
            Counter construction iterates in Python.

            Args:
                preacts: (p, p, N) pre-activations, axis 2 indexing neurons.
                threshold: strict lower bound on a neuron's max pre-activation.
                p: grid side length / modulus.

            Returns:
                Counter mapping "phi_a,phi_b" -> number of kept neurons; empty
                when no neuron passes ``threshold``.
            """
            max_per_neuron = jnp.max(preacts, axis=(0, 1))           # (N,)
            strong_mask = max_per_neuron > threshold
            if not bool(jnp.any(strong_mask)):
                return Counter()

            pre_strong = jnp.compress(strong_mask, preacts, axis=2)  # (p, p, N')

            row_mean = jnp.mean(pre_strong, axis=1)                  # (p, N')
            col_mean = jnp.mean(pre_strong, axis=0)                  # (p, N')

            fft_row = jnp.fft.fft(row_mean, axis=0)                  # (p, N')
            fft_col = jnp.fft.fft(col_mean, axis=0)                  # (p, N')

            # Skip the DC bin; bins above ⌊p/2⌋ mirror these for real input.
            pos_freq_slice = slice(1, p // 2 + 1)
            fa = jnp.argmax(jnp.abs(fft_row[pos_freq_slice, :]) ** 2, axis=0) + 1  # (N',) 1-based
            fb = jnp.argmax(jnp.abs(fft_col[pos_freq_slice, :]) ** 2, axis=0) + 1

            # take_along_axis needs the gather axis present, hence [None, :].
            coeff_row = jnp.take_along_axis(fft_row, fa[None, :], axis=0).squeeze(0)  # (N',)
            coeff_col = jnp.take_along_axis(fft_col, fb[None, :], axis=0).squeeze(0)

            p_float = float(p)
            phi_a = (-jnp.angle(coeff_row) * p_float) / (2 * jnp.pi * fa.astype(jnp.float32))
            phi_b = (-jnp.angle(coeff_col) * p_float) / (2 * jnp.pi * fb.astype(jnp.float32))

            phi_a_int = jnp.mod(jnp.rint(phi_a), p).astype(jnp.int32)   # 0 … p-1
            phi_b_int = jnp.mod(jnp.rint(phi_b), p).astype(jnp.int32)

            phi_pairs = np.asarray(jnp.stack([phi_a_int, phi_b_int], axis=1))  # (N', 2)
            return Counter(f"{int(a)},{int(b)}" for a, b in phi_pairs)
        
        def _phase_distribution_equal_freq(
            preacts: jnp.ndarray,          # (p , p , N)
            threshold: float,
            p: int
        ) -> tuple[Counter, Counter, Counter, Counter, Counter]:
            """Greedy three-step equal-frequency phase fit over all strong neurons.

            Neurons whose max pre-activation is strictly above ``threshold`` are
            fitted three times.  Each fit:

            * picks ONE frequency f in 1..⌊p/2⌋ per neuron, shared by the a- and
              b-axis, that maximises the summed row-mean + column-mean FFT power;
            * reads integer phases phi = -angle(X[f]) * p / (2π f), rounded and
              reduced mod p, for both axes;
            * builds the unit-amplitude reconstruction
              sin(2π f a/p + 2π phi_a/p) + sin(2π f b/p + 2π phi_b/p).

            The second fit runs on ``pre - recon1`` with f1's bin masked out.
            The third runs on ``pre - recon1 - recon2`` with f1 and f2 masked.

            NOTE(review): the phase formula inverts a cosine shift model
            cos(2π f (x − phi)/p), while the reconstruction uses an
            additive-phase sine with unit amplitude and no offset.  The
            "residual" is therefore not the exact remainder of the previous
            fit — confirm this is intended.

            Returns:
                ctr_first:         Counter "phi_a,phi_b" -> count, first fit
                ctr_second:        Counter "phi_a,phi_b" -> count, second fit
                ctr_third:         Counter "phi_a,phi_b" -> count, third fit
                freq_pairs_ctr:    Counter "f1,f2" -> count
                freq_triplets_ctr: Counter "f1,f2,f3" -> count
                All five are empty when no neuron passes ``threshold``.
            """
            p_float = float(p)
            fft_lim = p // 2 + 1   # exclusive end of the positive, non-DC bins

            strong_mask = jnp.max(preacts, axis=(0, 1)) > threshold
            if not bool(jnp.any(strong_mask)):
                # Same 5-tuple arity as the normal return: callers unpack five
                # values, so returning only three raised ValueError here.
                return Counter(), Counter(), Counter(), Counter(), Counter()

            pre_str = jnp.compress(strong_mask, preacts, axis=2)        # (p,p,N')

            # ── helper: ONE equal-freq fit ───────────────────────────────────────
            def _single_equal_freq_fit(tensor, avoid_f=None):
                """One shared-frequency fit; returns (phase Counter, f_sel, recon).

                ``avoid_f`` is None, an (N',) array, or a (k, N') array of 1-based
                frequencies whose power is overwritten with -1 before the argmax.
                """
                row_m = jnp.mean(tensor, axis=1)   # (p, N'): mean over axis 1
                col_m = jnp.mean(tensor, axis=0)   # (p, N'): mean over axis 0

                fft_r = jnp.fft.fft(row_m, axis=0)
                fft_c = jnp.fft.fft(col_m, axis=0)
                pow_r = jnp.abs(fft_r) ** 2
                pow_c = jnp.abs(fft_c) ** 2

                row_p = pow_r[1:fft_lim, :]
                col_p = pow_c[1:fft_lim, :]

                if avoid_f is not None:
                    # turn avoid_f into a list of per-neuron arrays
                    if avoid_f.ndim == 1:
                        avoids = [avoid_f]
                    else:
                        avoids = [avoid_f[i] for i in range(avoid_f.shape[0])]

                    # build one mask that bans *any* of the listed frequencies
                    rows = jnp.arange(row_p.shape[0])[:, None]   # shape (fft_lim‑1, 1)
                    mask = sum(rows == (af - 1)[None, :] for af in avoids) > 0

                    # apply it; powers are >= 0, so a banned bin (-1) only wins
                    # the argmax if every bin is banned
                    row_p = jnp.where(mask, -1.0, row_p)
                    col_p = jnp.where(mask, -1.0, col_p)

                f_sel = jnp.argmax(row_p + col_p, axis=0) + 1           # (N’,)

                coeff_r = jnp.take_along_axis(fft_r, f_sel[None, :], axis=0).squeeze(0)
                coeff_c = jnp.take_along_axis(fft_c, f_sel[None, :], axis=0).squeeze(0)

                phi_a = (-jnp.angle(coeff_r) * p_float) / (2 * jnp.pi * f_sel.astype(jnp.float32))
                phi_b = (-jnp.angle(coeff_c) * p_float) / (2 * jnp.pi * f_sel.astype(jnp.float32))

                phi_a_i = jnp.mod(jnp.rint(phi_a), p).astype(jnp.int32)
                phi_b_i = jnp.mod(jnp.rint(phi_b), p).astype(jnp.int32)

                ctr = Counter()
                for a, b in np.asarray(jnp.stack([phi_a_i, phi_b_i], axis=1)):
                    ctr[f"{int(a)},{int(b)}"] += 1

                # build reconstruction for residual, shape (p, p, N')
                a_lin = jnp.arange(p)[:, None, None]
                b_lin = jnp.arange(p)[None, :, None]
                two_pi_over_p = 2 * jnp.pi / p
                recon = (jnp.sin(two_pi_over_p * f_sel * a_lin + two_pi_over_p * phi_a_i)
                    + jnp.sin(two_pi_over_p * f_sel * b_lin + two_pi_over_p * phi_b_i))

                return ctr, f_sel, recon

            def build_freq_counter(*freq_arrays: jnp.ndarray) -> Counter[str]:
                """
                Count how often each tuple of frequencies occurs.
                E.g. build_freq_counter(f1, f2)  → Counter of "f1,f2"
                    build_freq_counter(f1, f2, f3) → Counter of "f1,f2,f3"
                """
                ctr = Counter()
                # turn them into plain Python lists of ints
                lists = [np.asarray(arr).reshape(-1).tolist() for arr in freq_arrays]
                for freqs in zip(*lists):
                    key = ",".join(str(int(f)) for f in freqs)
                    ctr[key] += 1
                return ctr
            # ── first fit ────────────────────────────────────────────────────────
            ctr_first, f1, recon1 = _single_equal_freq_fit(pre_str)

            # ── second fit on residual ──────────────────────────────────────────
            residual1 = pre_str - recon1
            ctr_second, f2, recon2 = _single_equal_freq_fit(residual1, avoid_f=f1)
            residual2 = residual1 - recon2
            # avoid both f1 and f2 to force a new dominant freq
            avoid_both = jnp.stack([f1, f2], axis=0)
            ctr_third, f3, _ = _single_equal_freq_fit(residual2, avoid_f=avoid_both)

            # ── frequency-pair / triplet counters ───────────────────────────────
            freq_pairs_ctr    = build_freq_counter(f1, f2)
            freq_triplets_ctr = build_freq_counter(f1, f2, f3)

            return ctr_first, ctr_second, ctr_third, freq_pairs_ctr, freq_triplets_ctr
        
        def compute_and_track_quantities(
            *,
            seed: int,
            p: int,
            model,                        # trained DonutMLP (or subclass)
            params: dict,                 # parameters for this seed
            neuron_data: Dict[int, Dict[int, Dict[str, Any]]],
            cluster_groupings: Union[Dict[int, list], list],
            final_layer_weights: np.ndarray,     # shape (num_neurons_last, p)
            save_dir: str = ".",
        ) -> None:
            """
            Writes *quantities_{seed}.json* containing:

            • distribution_of_max_preactivations
            • networks_equivariantness_stats      (correct-logit stats)
            • network_margin_stats                (margin  stats)
            • network_loss_stats                  (per-sample CE-loss stats)   ← NEW
            • clusters_equivariantness_stats      (per-cluster correct-logit stats)
            • clusters_margin_stats               (per-cluster margin stats)
            """

            # ───────────── 1) where does each neuron reach its maximum? ─────────────
            dist_counter: collections.Counter[str] = collections.Counter()
            for layer_dict in neuron_data.values():
                for nd in layer_dict.values():
                    real = np.asarray(nd.get("real_preactivations", []))
                    if real.size:
                        a_idx, b_idx = np.unravel_index(real.argmax(), real.shape)
                        dist_counter[f"{a_idx},{b_idx}"] += 1
            distribution_of_max_preactivations = dict(dist_counter)

            # ───────────── 2) run the whole network on the complete p² grid ─────────
            a_grid, b_grid = np.meshgrid(np.arange(p), np.arange(p), indexing="ij")
            x_full = np.stack([a_grid.ravel(), b_grid.ravel()], axis=-1).astype(jnp.int32)

            logits_full = model.apply({"params": params}, x_full, training=False)
            logits = logits_full[:, -1, :]        # shape (p² , p)
            logits_np = np.asarray(logits)
            correct_idx = ((a_grid + b_grid) % p).ravel()

            correct_logits = logits_np[np.arange(p * p), correct_idx]                 # (p²,)

            # ----- margins -----
            tmp = logits_np.copy()
            tmp[np.arange(p * p), correct_idx] = -np.inf
            second_logits = tmp.max(axis=1)
            margins = correct_logits - second_logits

            # ----- per-sample CE loss (log-softmax trick, row-wise) -----
            row_max = logits_np.max(axis=1, keepdims=True)
            logsumexp = row_max + np.log(np.exp(logits_np - row_max).sum(axis=1, keepdims=True))
            ce_losses = (logsumexp.squeeze() - correct_logits)                         # (p²,)

            networks_equivariantness_stats = {
                "min":  float(correct_logits.min()),
                "max":  float(correct_logits.max()),
                "mean": float(correct_logits.mean()),
                "std":  float(correct_logits.std()),
            }
            network_margin_stats = {
                "avg_margin":     float(margins.mean()),
                "min_margin":     float(margins.min()),
                "max_margin":     float(margins.max()),
                "std_dev_margin": float(margins.std()),
            }
            network_loss_stats = {
                "avg_loss":  float(ce_losses.mean()),
                "min_loss":  float(ce_losses.min()),
                "max_loss":  float(ce_losses.max()),
                "std_dev_loss": float(ce_losses.std()),
            }

            # ───────────── 3) stats for frequency-clusters in last hidden layer ─────
            if isinstance(cluster_groupings, collections.abc.Mapping):
                last_clusters = cluster_groupings            # type: ignore
                last_layer_idx = max(neuron_data)
            else:
                last_clusters = cluster_groupings[-1]
                last_layer_idx = len(cluster_groupings)

            layer_nd = neuron_data[last_layer_idx]
            correct_idx_grid = (a_grid + b_grid) % p                                   # p×p

            clusters_equivariantness_stats = {}
            clusters_margin_stats = {}

            for freq, neuron_ids in last_clusters.items():
                if not neuron_ids:
                    continue

                # build cluster logits: (p, p, p)
                cluster_logits = np.zeros((p, p, p), dtype=float)
                for n in neuron_ids:
                    nd = layer_nd.get(n)
                    if nd is None:
                        continue
                    post = np.asarray(
                        nd.get("postactivations",
                            np.maximum(nd["real_preactivations"], 0.0))
                    )                                           # p×p
                    w_row = final_layer_weights[n]              # p,
                    cluster_logits += post[..., None] * w_row

                # correct-logit stats
                corr = cluster_logits[np.arange(p)[:, None],
                                    np.arange(p)[None, :],
                                    correct_idx_grid]
                corr_flat = corr.ravel()
                clusters_equivariantness_stats[str(freq)] = {
                    "min":  float(corr_flat.min()),
                    "max":  float(corr_flat.max()),
                    "mean": float(corr_flat.mean()),
                    "std":  float(corr_flat.std()),
                }

                # margin stats (for the cluster contribution alone)
                logits_flat = cluster_logits.reshape(p * p, p)
                tmp = logits_flat.copy()
                tmp[np.arange(p * p), correct_idx] = -np.inf
                second = tmp.max(axis=1)
                cluster_margins = corr_flat - second
                clusters_margin_stats[str(freq)] = {
                    "avg_margin":     float(cluster_margins.mean()),
                    "min_margin":     float(cluster_margins.min()),
                    "max_margin":     float(cluster_margins.max()),
                    "std_dev_margin": float(cluster_margins.std()),
                }

            # ───────────── 4) dump everything to JSON ───────────────────────────────
            out = {
                "distribution_of_max_preactivations": distribution_of_max_preactivations,
                "networks_equivariantness_stats":     networks_equivariantness_stats,
                "network_margin_stats":               network_margin_stats,
                "network_loss_stats":                 network_loss_stats,   # ← NEW
                "clusters_equivariantness_stats":     clusters_equivariantness_stats,
                "clusters_margin_stats":              clusters_margin_stats,
            }

            grad_stats, dist_stats = compute_useless_metrics(
                model=model,
                params=params,
                p=p,                    # 59
                rng_seed=42,
                max_samples=p*p         # use the full 59² = 3 481 triples
            )
            out.update(grad_stats)
            out.update(dist_stats)

            distribution_of_center_mass = compute_center_mass_distribution(
                neuron_data=neuron_data,
                dominant_freq_clusters=cluster_groupings,
                p=p,
            )

            out["distribution_of_center_mass"] = distribution_of_center_mass

           # ─────────────  Phase & frequency histograms  ─────────────
            phases_free              = Counter()
            phases_equal_first       = Counter()
            phases_equal_second_fit  = Counter()
            phases_equal_third_fit   = Counter()
            freq_pairs_total         = Counter()
            freq_triplets_total      = Counter()

            for layer_dict in neuron_data.values():
                if not layer_dict:
                    continue
                pre_layer = jnp.stack(
                    [layer_dict[n]["real_preactivations"]
                     for n in sorted(layer_dict)],
                    axis=-1                                   # (p,p,N)
                )

                phases_free += _phase_distribution(
                    pre_layer, 0.01, p)

                ctr_first, ctr_second, ctr_third, ctr_pairs, ctr_triplets = _phase_distribution_equal_freq(
                    pre_layer, 0.01, p)

                phases_equal_first       += ctr_first
                phases_equal_second_fit  += ctr_second
                phases_equal_third_fit   += ctr_third
                freq_pairs_total         += ctr_pairs
                freq_triplets_total      += ctr_triplets
                

            out["distribution_of_phases"]                       = dict(phases_free)
            out["distribution_of_phases_f_a=f_b"]               = dict(phases_equal_first)
            out["distribution_of_phases_f_a=f_b_second_fit"]    = dict(phases_equal_second_fit)
            out["distribution_of_phases_f_a=f_b_third_fit"]     = dict(phases_equal_third_fit)
            out["frequencies_equal"]                            = dict(freq_pairs_total)
            out["frequencies_equal_triplets"]                   = dict(freq_triplets_total)
            

            os.makedirs(save_dir, exist_ok=True)
            path = os.path.join(save_dir, f"quantities_{seed}.json")
            with open(path, "w") as f:
                json.dump(out, f, indent=2)

            print(f"[compute_and_track_quantities] wrote {path}")

        # 3.3  where to save
        # Output directory for the per-seed equivariance / distribution JSONs.
        # NOTE(review): in the visible code it is only consumed by the
        # commented-out compute_and_track_quantities(...) call further down.
        transf_eqv_dir = os.path.join(
            BASE_DIR,
            f"{p}_distributions_equivariantness",
            f"transformer_p={p}_bs={batch_size}_k={k}_dm={transformer_config['d_model']}"
            f"_wd={weight_decay}_lr={learning_rate}"
        )
        os.makedirs(transf_eqv_dir, exist_ok=True)

        def _layer_centres_of_mass(
            preacts: jnp.ndarray,   # (p, p, N)
            freqs:   np.ndarray,    # (N,)
            p: int
        ) -> np.ndarray:            # → (N, 2)  [CoM_a , CoM_b]
            """Circular centre of mass of each neuron's |pre-activation| map.

            Grid coordinates 0…p-1 are mapped to angles 2π·i/p, multiplied by the
            modular inverse of the neuron's frequency, and the |activation|-weighted
            circular mean of those angles is converted back to the 0…p scale.

            Args:
                preacts: pre-activations indexed (a, b, neuron).
                freqs:   one integer frequency per neuron (last axis of preacts).
                p:       modulus / grid size.

            Returns:
                NumPy array of shape (N, 2) with columns [CoM_a, CoM_b].
            """
            # Modular inverses are computed on the host with plain Python ints.
            inverses = np.array(
                [_mod_inverse(int(freq), p) for freq in freqs], dtype=np.int32
            )

            positions = jnp.arange(p, dtype=jnp.float32)   # same grid for a and b
            full_turn = 2 * jnp.pi

            def _wrapped_position(phasor):
                """Angle of a complex sum, wrapped to [0, 2π), rescaled to 0…p."""
                angle = (jnp.angle(phasor) + full_turn) % full_turn
                return angle / full_turn * p

            @jax.jit
            def _centres(act_3d, inv_1d):
                scale = inv_1d.astype(jnp.float32)[None, None, :]               # (1,1,N)
                theta_a = (full_turn * positions[:, None, None] / p) * scale    # (p,1,N)
                theta_b = (full_turn * positions[None, :, None] / p) * scale    # (1,p,N)

                # |activation| as the weight keeps both positive and negative peaks.
                weight = jnp.abs(act_3d)                                        # (p,p,N)
                phasor_a = jnp.sum(weight * jnp.exp(1j * theta_a), axis=(0, 1))  # (N,)
                phasor_b = jnp.sum(weight * jnp.exp(1j * theta_b), axis=(0, 1))  # (N,)

                return jnp.stack(
                    [_wrapped_position(phasor_a), _wrapped_position(phasor_b)],
                    axis=1,
                )                                                               # (N,2)

            return np.asarray(_centres(preacts, inverses))

        def compute_center_mass_distribution(
            *,
            neuron_data: Dict[int, Dict[int, Dict[str, Any]]],
            dominant_freq_clusters,      # same structure you already use
            p: int,
        ) -> Dict[str, int]:
            """
            Histogram of integer-rounded neuron centres of mass across all layers.

            For each layer, the neurons' ``"real_preactivations"`` grids are stacked
            into a (p, p, N) tensor. Each neuron gets the dominant frequency of
            the FIRST cluster in the layer's cluster map that lists it; the
            frequency is the part of the cluster key before the first comma, and
            it falls back to 1 when no cluster lists the neuron. Centres of mass
            come from ``_layer_centres_of_mass``.

            Args:
                neuron_data: ``{layer_idx: {neuron_id: {"real_preactivations": ...}}}``.
                dominant_freq_clusters: either one ``{cluster_key: neuron_ids}`` dict
                    (used for every layer) or a list of such dicts indexed by
                    ``layer_idx - 1``.
                p: modulus / grid size.

            Returns:
                ``{"a,b": count}`` with (a, b) the rounded CoM wrapped to 0..p-1.
            """
            from collections import Counter  # local: keeps the helper self-contained

            counter = Counter()

            for layer_idx, layer_dict in neuron_data.items():
                neuron_ids = sorted(layer_dict)
                if not neuron_ids:
                    continue
                pre_layer = np.stack(
                    [layer_dict[n]["real_preactivations"] for n in neuron_ids],
                    axis=-1,
                )                                                    # (p,p,N)

                if isinstance(dominant_freq_clusters, dict):
                    freq_map = dominant_freq_clusters                # 1-layer case
                else:
                    freq_map = dominant_freq_clusters[layer_idx - 1] # list-of-dicts

                # Invert the cluster map once (neuron -> first cluster key listing
                # it) instead of scanning every cluster's id list per neuron.
                # setdefault keeps the first match, like a linear search would.
                neuron_to_key: Dict[Any, Any] = {}
                for cluster_key, ids in freq_map.items():
                    for nid in ids:
                        neuron_to_key.setdefault(nid, cluster_key)

                freqs = np.array(
                    [
                        int(neuron_to_key[n].split(',')[0]) if n in neuron_to_key else 1
                        for n in neuron_ids
                    ],
                    dtype=int,
                )

                coms = _layer_centres_of_mass(jnp.asarray(pre_layer), freqs, p)

                # round to nearest integer grid point and wrap to 0..p-1
                com_int = np.rint(coms).astype(int) % p
                for a, b in com_int:
                    counter[f"{a},{b}"] += 1

            return dict(counter)
    
        # 3.4  compute & dump – ***same routine as before***
        # compute_and_track_quantities(
        #     seed=seed,
        #     p=p,
        #     model=model,
        #     params=params_i,
        #     neuron_data=neuron_data,
        #     cluster_groupings=dominant_freq_clusters,
        #     final_layer_weights=eff_W,
        #     save_dir=transf_eqv_dir,
        # )

        # Root output directory for this model configuration's qualitative plots.
        # NOTE(review): this line uses `num_mlp_layers` while the HTML call below
        # uses `NUM_MLP_LAYERS` — confirm both refer to the same value.
        mlp_class_lower = f"transformer_{transformer_config['attn_coeff']}_{num_mlp_layers}"
        features = transformer_config['d_model']
        PDF_DIR = f"ICLR-appendix-run-7-s-2-250/qualitative_{p}_{mlp_class_lower}_{num_neurons}_features_{features}_k_{k}"

        # Per-seed HTML pages for the frequency clusters.
        html_out_dir = os.path.join(PDF_DIR, "cluster_html", f"seed_{seed}")
        make_cluster_html_pages(
            neuron_data=neuron_data,
            clusters=dominant_freq_clusters,   # list[dict] per layer in your code
            layer_idx=NUM_MLP_LAYERS,          # last hidden layer clusters
            p=p,
            out_dir=html_out_dir,
            show_full_fft=False,               # change to True if you want full fftshift view
        )

        # Matrices consumed by the PDF/PCA plotting further down.
        pdf_root = os.path.join(PDF_DIR, "pdf_plots", f"seed_{seed}")
        embeddings, layer_preacts, cluster_contribs, cluster_contribs_no_wu = get_all_preacts_and_embeddings(
            neuron_data=neuron_data,
            dominant_freq_clusters=dominant_freq_clusters,
            params=params_i,                 # or any single-model params pytree
        )

        # NOTE(review): the first two f-strings are joined with no separator, so
        # the directory name reads "..._embed_<d_model>p=<p>_bs=..."; changing it
        # would move existing outputs, so it is left as-is.
        first_dir_path = os.path.join(
            PDF_DIR,
            f"{p}_models_embed_{transformer_config['d_model']}"
            f"p={p}_bs={batch_size}_nn={num_neurons}"
            f"_wd={weight_decay}_epochs={epochs}"
            f"_training_set_size={training_set_size}",
        )
        os.makedirs(first_dir_path, exist_ok=True)

        first100_path = os.path.join(
            first_dir_path,
            f"first100_testacc_seed_{seed}.json",
        )

        # Write the per-seed "first 100 % test accuracy" summary when one exists
        # (a None entry means nothing is written, only a message printed).
        info = first_100_summary[m_i]              # m_i is the index of this seed
        if info is not None:                       # wasn’t guaranteed to reach 100 %
            with open(first100_path, "w") as f:
                json.dump(info, f, indent=2)
            print(f"First-100-epoch summary for seed {seed} saved to {first100_path}")
        else:
            print(f"Seed {seed} never hit 100 % test accuracy – no summary written.")
        
        def plot_cluster_logit_heatmap(
            cluster_contribs_logits: dict[str, np.ndarray],
            cluster_key: str,
            logit_idx: int,
            p: int,
            *,
            title: str | None = None,
            colorscale: str = "RdBu",
            symmetric: bool = True,
            show: bool = True,
        ):
            """
            Visualise how one frequency-cluster contributes to a single logit
            over the (a, b) input grid.

            Parameters
            ----------
            cluster_contribs_logits : dict[str, np.ndarray]
                Output of `get_all_preacts_and_embeddings`.  
                Each value is shape `(p², p)` –  rows are flattened (a,b) pairs,
                columns are logits 0…p-1.
            cluster_key : str
                Which cluster to show (must be a key in `cluster_contribs_logits`).
            logit_idx : int
                The logit you care about (0 ≤ logit_idx < p).
            p : int
                Modulus – size of the a,b grid (so the heat-map will be p × p).
            title : str, optional
                Custom Plotly title.  If None, a default is generated.
            colorscale : str
                Any Plotly colourscale – e.g. "Viridis", "Cividis", "RdBu", …
            symmetric : bool
                If True, the colour range is centred on 0 (nice for positive/negative).
            show : bool
                If True (default) call `fig.show()`.  Otherwise just return the Figure.

            Returns
            -------
            plotly.graph_objects.Figure
            """
            # ── 1. fetch & reshape ────────────────────────────────────────────
            if cluster_key not in cluster_contribs_logits:
                raise KeyError(f"{cluster_key!r} not found in cluster_contribs_logits")
            contrib_mat = cluster_contribs_logits[cluster_key]               # (p², p)
            if not (0 <= logit_idx < contrib_mat.shape[1]):
                raise IndexError(f"logit_idx {logit_idx} out of range 0–{contrib_mat.shape[1]-1}")

            flat = contrib_mat[:, logit_idx]                                 # (p²,)
            # The training code flattened with `np.meshgrid(..., indexing='ij')`
            # so rows iterate over *a* and columns over *b*.  We want x=a, y=b,
            # which is the transpose of that layout:
            heat = flat.reshape(p, p).T                                      # (b, a)

            # ── 2. colour limits (optional symmetric) ────────────────────────
            if symmetric:
                vmax = np.abs(heat).max()
                vmin = -vmax
            else:
                vmin, vmax = heat.min(), heat.max()

            # ── 3. build the Plotly figure ───────────────────────────────────
            fig = go.Figure(
                go.Heatmap(
                    x=np.arange(p),          # a-values → x-axis
                    y=np.arange(p),          # b-values → y-axis
                    z=heat,
                    colorscale=colorscale,
                    zmin=vmin,
                    zmax=vmax,
                    colorbar=dict(
                        title=f"Δ logit {logit_idx}",
                        title_side="right"
                    ),
                )
            )

            fig.update_layout(
                title=title or f'Cluster "{cluster_key}" → logit {logit_idx}',
                xaxis_title="a",
                yaxis_title="b",
                yaxis=dict(autorange="reversed"),  # (0,0) bottom-left
                width=500, height=500,
            )

            if show:
                fig.show()

            return fig

        out_dir = os.path.join(pdf_root, "temp_logit_plots")
        os.makedirs(out_dir, exist_ok=True)

        # 2) loop over clusters & desired logits
        # Logits run 0…p-1; keep only the requested indices that exist for this
        # modulus. Without this, p ≤ 32 makes plot_cluster_logit_heatmap raise
        # IndexError and aborts the whole run.
        logits_to_plot = [idx for idx in (30, 31, 32) if idx < p]

        for cluster_key in cluster_contribs:
            # extract just the numeric freq (assumes keys like "freq_3" or "3");
            # str() so non-string keys fall through to the fallback instead of
            # raising AttributeError
            try:
                freq = int(str(cluster_key).split("_")[-1])
            except ValueError:
                freq = cluster_key

            for logit_idx in logits_to_plot:
                # 2a) plot (but don't auto-show)
                fig = plot_cluster_logit_heatmap(
                    cluster_contribs,
                    cluster_key=cluster_key,
                    logit_idx=logit_idx,
                    p=p,
                    show=False
                )

                # 2b) save as PDF
                filename = f"f={freq}-logit={logit_idx}_cluster.pdf"
                out_path = os.path.join(out_dir, filename)
                fig.write_image(out_path, format="pdf")

                print(f"Saved {out_path}")

        def _best_line_and_freq(mat: np.ndarray, p: int) -> tuple[str, int] | None:
            """
            Look at average-over-columns grid, take 2D FFT, and find the strongest of:
            (0,f)  vertical   → 'axis'
            (f,0)  horizontal → 'axis'
            (f,f)  diagonal   → 'diag'
            Returns ('axis'|'diag', f) or None if everything is zero.
            """
            if mat.shape[0] != p * p:
                raise ValueError("The first dimension must be p².")
            grid = mat.mean(axis=1).reshape(p, p)
            fft2 = np.fft.fft2(grid)
            mag  = np.abs(fft2)
            mag[0, 0] = 0.0

            vert   = mag[0, 1:]          # (0,f)
            horiz  = mag[1:, 0]          # (f,0)
            diag   = np.diag(mag)[1:]    # (f,f)

            f_vert  = int(np.argmax(vert)  + 1); m_vert  = float(vert[f_vert  - 1])
            f_horiz = int(np.argmax(horiz) + 1); m_horiz = float(horiz[f_horiz - 1])
            f_diag  = int(np.argmax(diag)  + 1); m_diag  = float(diag[f_diag  - 1])

            # choose the biggest line; vertical/horizontal are both 'axis'
            candidates = [
                ("axis", m_vert,  f_vert),
                ("axis", m_horiz, f_horiz),
                ("diag", m_diag,  f_diag),
            ]
            kind, val, f = max(candidates, key=lambda t: t[1])
            if val <= 0.0:
                return None
            return kind, f


        def _concat_mats(mat_list: list[np.ndarray]) -> np.ndarray:
            """Horizontally concatenate a list of matrices (or return the single one)."""
            if not mat_list:
                raise ValueError("Empty matrix list.")
            if len(mat_list) == 1:
                return mat_list[0]
            return np.concatenate(mat_list, axis=1)

        # embeddings  →  freq_list_embeds
        # Collect every frequency appearing in any layer's cluster key.
        # Keys here must be exactly "fa,fb" (two comma-separated ints).
        # NOTE(review): freq_list_embeds is only consumed by the commented-out
        # generate_pdf_plots_for_matrix(embeddings, ...) call below.
        freq_set = set()
        for layer_dict in layer_preacts:
            for freq_key in layer_dict:
                fa, fb = map(int, freq_key.split(","))
                freq_set.update((fa, fb))

        freq_list_embeds = sorted(freq_set)

        # generate_pdf_plots_for_matrix(
        #     embeddings,
        #     p,
        #     save_dir=pdf_root,
        #     seed=seed,
        #     freq_list=freq_list_embeds,
        #     tag="embeds",
        #     class_string=mlp_class_lower,
        #     num_principal_components=2,
        # )

        # each MLP layer pre-activations
        # Per layer (numbered from 1): group the cluster matrices by frequency into
        # diagonal (fa == fb) and axis-aligned (fa != fb) groups, concatenate each
        # group column-wise, pass it through _filter_neurons_by_max (presumably
        # dropping low-amplitude neurons — TODO confirm), and plot only groups
        # that keep >= 2 columns and have matrix rank >= 2.
        for layer_idx, layer_dict in enumerate(layer_preacts, start=1):
            axis_by_f: dict[int, list[np.ndarray]] = {}
            diag_by_f: dict[int, list[np.ndarray]] = {}

            for freq_key, mat in layer_dict.items():
                fa, fb = map(int, freq_key.split(","))

                # If fa == fb we treat it as a diagonal (a+b)–type cluster, else axis-aligned.
                if fa == fb:
                    diag_by_f.setdefault(fa, []).append(mat)
                else:
                    # You already ensured the first entry in the key is the dominant axis.
                    axis_by_f.setdefault(fa, []).append(mat)

            # (0,f) / (f,0) → concatenate & tag “…_second_order”
            for best_f, mats in axis_by_f.items():
                merged = _concat_mats(mats)  # (p^2, total_neurons)
                merged, _ = _filter_neurons_by_max(merged, thr=1e-2)
                if merged.shape[1] < 2:
                    continue 
                print(f"[debug] merged AXIS matrix shape: {merged.shape}")
                # rank < 2 → skip the PCA plots for this group
                if np.linalg.matrix_rank(merged) < 2:
                    continue
                tag = f"layer{layer_idx}_freq={best_f}_second_order"
                generate_pdf_plots_for_matrix(
                    merged, p, save_dir=pdf_root, seed=seed,
                    freq_list=[best_f], tag=tag, class_string=mlp_class_lower,
                    num_principal_components=3,
                )

            # (f,f) → separate concatenation & base tag (no suffix)
            for best_f, mats in diag_by_f.items():
                merged = _concat_mats(mats)  # (p^2, total_neurons)
                merged, _ = _filter_neurons_by_max(merged, thr=1e-2)
                if merged.shape[1] < 2:
                    continue 
                print(f"[debug] merged DIAG matrix shape: {merged.shape}")
                # rank < 2 → skip the PCA plots for this group
                if np.linalg.matrix_rank(merged) < 2:
                    continue
                tag = f"layer{layer_idx}_freq={best_f}"
                generate_pdf_plots_for_matrix(
                    merged, p, save_dir=pdf_root, seed=seed,
                    freq_list=[best_f], tag=tag, class_string=mlp_class_lower,
                    num_principal_components=3,
                )

        # last-layer cluster-to-logit contributions
        def _parse_freq_key(freq_key: str) -> tuple[int, int]:
            # Accept "fa,fb", "freq_fa_fb", "fa_fb", etc.
            nums = list(map(int, re.findall(r"\d+", str(freq_key))))
            if not nums:
                raise ValueError(f"Unrecognised freq_key: {freq_key!r}")
            if len(nums) == 1:
                return nums[0], nums[0]
            return nums[0], nums[1]

        # Group the per-cluster logit-contribution matrices by frequency:
        # fa == fb (including single-number keys, which _parse_freq_key maps to
        # (f, f)) → diagonal group; otherwise → axis group keyed by fa.
        axis_cc: dict[int, list[np.ndarray]] = {}
        diag_cc: dict[int, list[np.ndarray]] = {}

        for freq_key, mat in cluster_contribs.items():
            fa, fb = _parse_freq_key(freq_key)

            # Your clustering guaranteed the *first* entry is the dominant axis
            # (see energy_a >= energy_b). Use that consistently here.
            if fa == fb:
                # diagonal (a+b)-type clusters
                diag_cc.setdefault(fa, []).append(mat)
            else:
                # axis-aligned (“second order”) clusters → group by the dominant axis freq
                axis_cc.setdefault(fa, []).append(mat)

        # axis → “…_second_order”
        # Unlike the per-layer loop above, no _filter_neurons_by_max step is
        # applied here; groups with matrix rank < 2 are skipped.
        for best_f, mats in axis_cc.items():
            merged = _concat_mats(mats)
            print(f"[debug] merged AXIS (cluster) shape: {merged.shape}")
            if np.linalg.matrix_rank(merged) < 2:
                continue
            tag = f"cluster_contributions_to_logits_freq={best_f}_second_order"
            generate_pdf_plots_for_matrix(
                merged, p, save_dir=pdf_root, seed=seed,
                freq_list=[best_f], tag=tag, class_string=mlp_class_lower,
                num_principal_components=3,
            )

        # diagonal → base tag
        for best_f, mats in diag_cc.items():
            merged = _concat_mats(mats)
            print(f"[debug] merged DIAG (cluster) shape: {merged.shape}")
            if np.linalg.matrix_rank(merged) < 2:
                continue
            tag = f"cluster_contributions_to_logits_freq={best_f}"
            generate_pdf_plots_for_matrix(
                merged, p, save_dir=pdf_root, seed=seed,
                freq_list=[best_f], tag=tag, class_string=mlp_class_lower,
                num_principal_components=3,
            )

        from functools import partial
        # 1)  pure-JAX embedding extractor
        def compute_embeddings_transformer(
            params: dict,
            x: jnp.ndarray,                       # (B , 2)  int32 tokens  (a , b)
            *,
            concat: bool = False,                 # False →  “Eₐ + E_b”    True →  “[Eₐ‖E_b]”
        ) -> jnp.ndarray:
            """
            Return token + learned-position embeddings for a batch of (a, b) pairs.

            Tokens are looked up in the table returned by
            ``model.extract_embeddings_ab(params)``, where ``model`` comes from
            the enclosing scope. Rows 0 and 1 of
            ``params["pos_embed"]["W_pos"]`` are added to the a and b tokens
            respectively.

            • ``concat=False`` → sum of the two vectors,  shape (B, D)
            • ``concat=True``  → their concatenation,      shape (B, 2D)
            """
            token_table, _ = model.extract_embeddings_ab(params)            # (p, D)
            first_pos, second_pos = params["pos_embed"]["W_pos"][:2]        # (D,) each

            with_pos_a = token_table[x[:, 0]] + first_pos                   # (B, D)
            with_pos_b = token_table[x[:, 1]] + second_pos                  # (B, D)

            if not concat:
                return with_pos_a + with_pos_b                              # (B, D)
            return jnp.concatenate([with_pos_a, with_pos_b], axis=-1)       # (B, 2D)

        # 2)  make_energy_funcs_transformer
        def make_energy_funcs_transformer(
            model: "Transformer",           # the *initialised* model instance
            params: dict,                   # its parameters
            *,
            concat: bool = False,
        ):
            """
            Build the two callables used for the embedding Dirichlet energy.

            • emb_fn(x)                  → token + position embeddings of the (B, 2)
                                           int pairs; shape (B, D), or (B, 2D) if
                                           ``concat``
            • batch_energy_sum(e_batch)  → Σ ‖J‖²_F over the batch, where J is the
                                           Jacobian of the last-token logits with
                                           respect to one flattened embedding

            Both callables use the ``model`` and ``params`` passed here.
            """
            # ---------- embedding table & positions (from THIS model) --------------
            # emb_fn is built here rather than via compute_embeddings_transformer,
            # because that helper reads `model` from the enclosing scope and would
            # ignore the `model` argument of this function.
            token_table, _ = model.extract_embeddings_ab(params)            # (p, D)
            pos0, pos1 = params["pos_embed"]["W_pos"][:2]                   # (D,) each

            def emb_fn(x: jnp.ndarray, *, concat: bool = concat) -> jnp.ndarray:
                """(B, 2) int tokens → (B, D) sum or (B, 2D) concatenation."""
                embed_a = token_table[x[:, 0]] + pos0                       # (B, D)
                embed_b = token_table[x[:, 1]] + pos1                       # (B, D)
                if concat:
                    return jnp.concatenate([embed_a, embed_b], axis=-1)     # (B, 2D)
                return embed_a + embed_b                                    # (B, D)

            # ---------- f_embed : (D,) or (2D,)  →  (p,) logits ----------------------
            def _to_seq(x_flat: jnp.ndarray) -> jnp.ndarray:
                """
                Convert *one* flattened embedding vector back to a (1, 2, D) tensor
                that the Transformer can consume.

                * concat=False :  put ½ x on each token position.
                * concat=True  :  split x into its a- and b-halves.
                """
                if concat:
                    ea, eb = jnp.split(x_flat, 2)                           # (D,) each
                else:
                    ea = x_flat * 0.5                                       # (D,)
                    eb = x_flat * 0.5
                return jnp.stack([ea, eb])[None, ...]                       # (1,2,D)

            def f_embed(x_flat: jnp.ndarray) -> jnp.ndarray:                # (p,)
                seq_emb = _to_seq(x_flat)                                   # (1,2,D)
                logits  = model.call_from_embedding_sequence(seq_emb, params)[0]
                return logits[-1]                                           # last-token

            # jit once so jacrev uses the cached XLA executable
            f_embed = jax.jit(f_embed)

            # ---------- ‖∂f/∂x‖²_F for one embedding -------------------------------
            @jax.jit
            def _squared_frobenius_norm_of_jac(x_flat: jnp.ndarray) -> jnp.ndarray:
                J = jax.jacrev(f_embed)(x_flat)            # (p , D | 2D)
                return jnp.sum(J * J)                      # scalar

            def batch_energy_sum(batch_emb: jnp.ndarray) -> jnp.ndarray:    # Σ ‖J‖² over batch
                return jax.vmap(_squared_frobenius_norm_of_jac)(batch_emb).sum()

            return emb_fn, batch_energy_sum

        # 3)  driver that averages over an arbitrary input set
        def compute_dirichlet_energy_embedding_transformer(
            model: "Transformer",
            params: dict,
            x_data: jnp.ndarray,                 # (N , 2)  all (a , b) pairs of interest
            *,
            batch_size: int = 1024,
            concat: bool = False,
        ) -> float:
            """
            Mean over ``x_data`` of ‖∂logits/∂embedding‖²_F.

            Plain (non-JIT) driver. It embeds ``x_data`` in chunks of
            ``batch_size`` rows to keep memory modest, sums the per-example
            energies from ``make_energy_funcs_transformer``, and divides by
            the number of rows.

            Raises:
                ValueError: if ``x_data`` has no rows or ``batch_size`` < 1.
            """
            n = x_data.shape[0]
            # An empty mean is undefined: fail with a clear message instead of
            # ZeroDivisionError.
            if n == 0:
                raise ValueError("x_data is empty – Dirichlet energy is undefined")
            # A negative step would make range() empty and silently return 0.0.
            if batch_size < 1:
                raise ValueError(f"batch_size must be >= 1, got {batch_size}")

            emb_fn, batch_energy_sum = make_energy_funcs_transformer(
                model, params, concat=concat
            )

            total = 0.0
            for start in range(0, n, batch_size):
                x_batch = x_data[start:start + batch_size]
                e_batch = emb_fn(x_batch)                       # (B , D | 2D)
                total  += batch_energy_sum(e_batch)

            return float(total / n)

        # All p² input pairs: mgrid's first output varies along rows, so after
        # ravel() the rows of X_FULL_GRID enumerate (a, b) with a as the outer
        # index.
        a_all, b_all = jnp.mgrid[0:p, 0:p]
        X_FULL_GRID  = jnp.stack([a_all.ravel(), b_all.ravel()], axis=-1).astype(jnp.int32)

        # Mean squared Frobenius norm of ∂logits/∂embedding over the full grid.
        dirichlet_E = compute_dirichlet_energy_embedding_transformer(
            model, params_i_safe, X_FULL_GRID, concat=False)
        print(f"[Dirichlet] seed {seed}: {dirichlet_E:.6e}")

        # optional: keep it in `layer_summaries` so you see it in the layer-JSONs
        layer_summaries.setdefault(1, {})["dirichlet_energy_everything"] = float(dirichlet_E)

        # ----------------------------------------------------------------
        # 1)  append to the *reconstruction_metrics_…seed_*.json file
        #     (read-modify-write: loads the file if present, sets
        #     ["model"]["dirichlet_energy_everything"], rewrites the file)
        # ----------------------------------------------------------------
        freq_json_dir = os.path.join(
            PDF_DIR,
            f"{p}_freqs_distribution_r2_jsons",
            f"mlp=transformer_p={p}_bs={batch_size}_k={k}_nn={transformer_config['d_model']*num_neurons}"
            f"_wd={weight_decay}_lr={learning_rate}"
        )
        os.makedirs(freq_json_dir, exist_ok=True)

        out_path = os.path.join(
            freq_json_dir,
            f"reconstruction_metrics_top-k={k}_seed_{seed}.json"
        )

        if os.path.exists(out_path):
            with open(out_path) as f:
                rec_data = json.load(f)
        else:
            rec_data = {}

        # Overwrites any previous value under this key; other keys are kept.
        rec_data.setdefault("model", {})["dirichlet_energy_everything"] = float(dirichlet_E)

        with open(out_path, "w") as f:
            json.dump(rec_data, f, indent=2)

        print(f"[Dirichlet] appended to → {out_path}")