import time
import operator
import jax
import jax.numpy as jnp  # JAX NumPy
from flax import linen as nn  # Linen API
from clu import metrics
from flax import struct
import optax  # Common loss functions and optimizers
import chex
import numpy as np
import os
import sys
import json
import flax.serialization as serialization
import jax.tree_util
from numpy.fft import fft2, ifft2
import plotly.io as pio
from flax.core.frozen_dict import freeze, unfreeze


jax.config.update("jax_traceback_filtering", 'off')
# jax.config.update("jax_debug_nans", True)

import optimizers
import training

#############################
### 1) DEFINE THE TRANSFORMER IN FLAX
#############################

class HookPoint(nn.Module):
    """Identity layer that records its input in the "intermediates" collection.

    The tensor passing through is stored with ``self.sow`` so it can be read
    back from ``model.apply(..., mutable=["intermediates"])``.

    Args
    ----
    key : str | None
        Name to store the activation under.  When falsy (``None`` or empty),
        the "/"-joined module scope path is used instead.
    """
    key: str | None = None

    @nn.compact
    def __call__(self, x):
        # Prefer the explicit key; otherwise derive a name from the scope path.
        if self.key:
            store_as = self.key
        else:
            store_as = "/".join(self.scope.path)

        def keep_latest(_previous, latest):
            # Overwrite instead of accumulating: only the newest value is kept.
            return latest

        self.sow("intermediates", store_as, x, reduce_fn=keep_latest)
        return x


class Embed(nn.Module):
    """Token embedding: integer ids -> ``d_model``-dimensional vectors."""
    d_vocab: int
    d_model: int

    @nn.compact
    def __call__(self, x):
        """Look up embeddings; x: [batch, seq_len] -> [batch, seq_len, d_model]."""
        # Table entries are drawn from N(0, 1/d_model).
        init_std = 1.0 / np.sqrt(self.d_model)
        table = nn.Embed(
            num_embeddings=self.d_vocab,
            features=self.d_model,
            embedding_init=nn.initializers.normal(stddev=init_std),
        )
        return table(x)

class PosEmbed(nn.Module):
    """Learned additive positional embedding for up to ``max_ctx`` positions."""
    max_ctx: int
    d_model: int

    def setup(self):
        init_fn = nn.initializers.normal(stddev=1.0 / np.sqrt(self.d_model))
        self.W_pos = self.param("W_pos", init_fn, (self.max_ctx, self.d_model))

    def __call__(self, x):
        """
        x: [batch, seq_len, d_model]
        Returns x plus the learned embeddings of the first seq_len positions,
        broadcast over the batch axis.
        """
        n_positions = x.shape[1]
        # [1, seq_len, d_model] so the addition broadcasts over batch.
        return x + self.W_pos[None, :n_positions, :]

class Attention(nn.Module):
    """Multi-head self-attention with hook points on k/q/v, scores, pattern and z.

    Attributes
    ----------
    d_model, num_heads, d_head : int
        Residual width, number of heads and per-head width.
    n_ctx : int
        Side length of the lower-triangular mask built in ``setup``.
    attn_coeff : float
        Blend factor; the pattern actually used is
        ``softmax(scores) * attn_coeff + (1 - attn_coeff)``.
    """
    d_model: int
    num_heads: int
    d_head: int
    n_ctx: int
    attn_coeff: float

    def setup(self):
        # Parameters are declared in the order W_K, W_Q, W_V, W_O.
        init_fn = nn.initializers.normal(stddev=1.0 / np.sqrt(self.d_model))
        per_head_shape = (self.num_heads, self.d_head, self.d_model)

        self.W_K = self.param("W_K", init_fn, per_head_shape)
        self.W_Q = self.param("W_Q", init_fn, per_head_shape)
        self.W_V = self.param("W_V", init_fn, per_head_shape)
        # Output projection applied to the concatenated heads.
        self.W_O = self.param(
            "W_O", init_fn, (self.d_model, self.num_heads * self.d_head)
        )

        # Lower-triangular (n_ctx, n_ctx) matrix of ones.
        lower_triangle = np.tril(np.ones((self.n_ctx, self.n_ctx), dtype=np.float32))
        self.causal_mask = jnp.array(lower_triangle)

        self.hook_k = HookPoint()
        self.hook_q = HookPoint()
        self.hook_v = HookPoint()
        self.hook_z = HookPoint()
        self.hook_attn = HookPoint()
        self.hook_attn_pre = HookPoint()

    def __call__(self, x):
        n_batch, n_pos, _ = x.shape

        # Per-head projections, each [batch, head, pos, d_head].
        k = self.hook_k(jnp.einsum("ihd,bpd->biph", self.W_K, x))
        q = self.hook_q(jnp.einsum("ihd,bpd->biph", self.W_Q, x))
        v = self.hook_v(jnp.einsum("ihd,bpd->biph", self.W_V, x))

        # Scaled dot-product scores, [batch, head, query_pos, key_pos].
        scores = jnp.einsum("biph,biqh->biqp", k, q) / np.sqrt(self.d_head)
        scores = self.hook_attn_pre(scores)

        # NOTE(review): the additive causal mask is computed but NOT added to
        # the scores (the "+ mask" term is disabled) — confirm this is intended.
        visible = self.causal_mask[:n_pos, :n_pos]
        mask = (1.0 - visible) * -1e10
        scores_for_softmax = scores  # + mask

        pattern = nn.softmax(scores_for_softmax, axis=-1)
        # Blend the softmax pattern with a constant according to attn_coeff.
        pattern = self.hook_attn(pattern * self.attn_coeff + (1.0 - self.attn_coeff))

        z = self.hook_z(jnp.einsum("biph,biqp->biqh", v, pattern))

        # Move heads next to d_head, then merge them: [batch, pos, heads * d_head].
        heads_last = jnp.transpose(z, (0, 2, 1, 3))
        merged = jnp.reshape(heads_last, (n_batch, n_pos, self.num_heads * self.d_head))

        return jnp.einsum("df,bpf->bpd", self.W_O, merged)

class MLP(nn.Module):
    """Stack of ``num_layers`` dense + activation layers, then a projection to ``d_model``.

    Layer ``i`` owns parameters ``W_{i}`` / ``b_{i}`` and hook points
    ``hook_pre{i+1}`` / ``hook_post{i+1}`` (hooks are 1-indexed).  The hooks
    sow under the fixed keys ``blocks_0/mlp/hook_pre{i+1}`` and
    ``blocks_0/mlp/hook_post{i+1}``.
    """
    d_model: int
    d_mlp: int
    num_layers: int
    act_type: str = "ReLU"

    # ---------- parameters & hooks ----------
    def setup(self):
        # Per layer: weight first, then bias (declaration order is kept).
        for layer in range(self.num_layers):
            fan_in = self.d_mlp if layer > 0 else self.d_model
            fan_out = self.d_mlp

            weight = self.param(
                f"W_{layer}",
                nn.initializers.normal(stddev=1/np.sqrt(fan_out)),
                (fan_out, fan_in),
            )
            setattr(self, f"W_{layer}", weight)

            bias = self.param(f"b_{layer}", nn.initializers.zeros, (fan_out,))
            setattr(self, f"b_{layer}", bias)

            # Hook points around the activation.
            hook_idx = layer + 1
            setattr(self, f"hook_pre{hook_idx}",
                    HookPoint(key=f"blocks_0/mlp/hook_pre{hook_idx}"))
            setattr(self, f"hook_post{hook_idx}",
                    HookPoint(key=f"blocks_0/mlp/hook_post{hook_idx}"))

        # Projection from the hidden width back to d_model.
        self.W_out = self.param(
            "W_out",
            nn.initializers.normal(stddev=1/np.sqrt(self.d_model)),
            (self.d_model, self.d_mlp)
        )
        self.b_out = self.param("b_out", nn.initializers.zeros, (self.d_model,))

    # ---------- forward ----------
    def _act(self, x):
        """Apply the configured activation; only "ReLU" is supported."""
        if self.act_type != "ReLU":
            raise ValueError(f"Unsupported activation {self.act_type!r}")
        return nn.relu(x)

    def __call__(self, x):
        hidden = x
        for layer in range(self.num_layers):
            hook_idx = layer + 1
            weight = getattr(self, f"W_{layer}")
            bias = getattr(self, f"b_{layer}")
            record_pre = getattr(self, f"hook_pre{hook_idx}")
            record_post = getattr(self, f"hook_post{hook_idx}")

            pre_activation = record_pre(jnp.einsum("md,bpd->bpm", weight, hidden) + bias)
            hidden = record_post(self._act(pre_activation))

        # Output projection.
        return jnp.einsum("dm,bpm->bpd", self.W_out, hidden) + self.b_out



class TransformerBlock(nn.Module):
    """One residual block: attention then MLP, each added back to the stream.

    Every intermediate residual-stream value passes through a HookPoint so it
    can be read from the "intermediates" collection.
    """
    d_model: int
    d_head: int
    num_heads: int
    n_ctx: int
    act_type: str
    attn_coeff: float
    num_mlp_layers: int 

    def setup(self):
        self.attn = Attention(
            d_model=self.d_model,
            num_heads=self.num_heads,
            d_head=self.d_head,
            n_ctx=self.n_ctx,
            attn_coeff=self.attn_coeff,
        )
        # NOTE(review): `num_neurons` is not a field of this module; it is read
        # from module scope and must be defined before the block is set up.
        hidden_width = self.d_model * num_neurons
        self.mlp = MLP(
            d_model=self.d_model,
            d_mlp=hidden_width,
            num_layers=self.num_mlp_layers,
            act_type=self.act_type,
        )
        self.hook_attn_out = HookPoint()
        self.hook_mlp_out = HookPoint()
        self.hook_resid_pre = HookPoint()
        self.hook_resid_mid = HookPoint()
        self.hook_resid_post = HookPoint()

    def __call__(self, x):
        # Attention sub-layer with residual connection.
        resid_pre = self.hook_resid_pre(x)
        attn_out = self.hook_attn_out(self.attn(resid_pre))
        resid_mid = self.hook_resid_mid(resid_pre + attn_out)

        # MLP sub-layer with residual connection.
        mlp_out = self.hook_mlp_out(self.mlp(resid_mid))
        return self.hook_resid_post(resid_mid + mlp_out)

class Transformer(nn.Module):
    """
    A simple Transformer with:
    - Embedding
    - PosEmbed
    - N Transformer blocks
    - A final linear "unembed" (W_U) -> [batch, seq_len, d_vocab]

    Attributes
    ----------
    num_layers : int
        Number of TransformerBlock instances chained in ``self.blocks``.
    num_mlp_layers : int
        Hidden layers inside each block's MLP.
    d_vocab, d_model, d_head, num_heads, n_ctx : int
        Vocabulary size, residual width, per-head width, head count and
        maximum context length.
    act_type : str
        Activation name forwarded to the MLP.
    attn_coeff : float
        Attention blend factor forwarded to each block.
    """
    num_layers: int
    num_mlp_layers: int
    d_vocab: int
    d_model: int
    d_head: int
    num_heads: int
    n_ctx: int
    act_type: str
    attn_coeff: float

    def setup(self):
        self.embed = Embed(self.d_vocab, self.d_model)
        self.pos_embed = PosEmbed(self.n_ctx, self.d_model)
        # The list assignment followed by the Sequential wrapper is kept exactly
        # as it was, because submodule/parameter naming depends on it.
        self.blocks = [TransformerBlock(
            d_model=self.d_model,
            d_head=self.d_head,
            num_heads=self.num_heads,
            n_ctx=self.n_ctx,
            act_type=self.act_type,
            attn_coeff=self.attn_coeff,
            num_mlp_layers=self.num_mlp_layers
        ) for _ in range(self.num_layers)]
        self.blocks = nn.Sequential(self.blocks)
        self.W_U = self.param(
            "W_U",
            nn.initializers.normal(stddev=1.0 / np.sqrt(self.d_vocab)),
            (self.d_model, self.d_vocab),
        )

    def __call__(self, x, training=False):
        """Map token ids x [batch, seq_len] to logits [batch, seq_len, d_vocab].

        ``training`` is accepted for interface compatibility and is not used.
        """
        x_emb = self.embed(x)           # [batch, seq_len, d_model]
        x_emb = self.pos_embed(x_emb)   # [batch, seq_len, d_model]
        x_out = self.blocks(x_emb)      # [batch, seq_len, d_model]
        # Contract over d_model (d) and keep the vocab axis (m).  The previous
        # output spec "bpd" summed over the vocab axis instead and returned
        # d_model-sized "logits".
        logits = jnp.einsum("dm,bpd->bpm", self.W_U, x_out)
        return logits


# 14 positional settings plus at least one random seed are required, i.e.
# 16 argv entries including the script name.  Without a seed no model would
# be created and the code further down fails on an empty model list.
if len(sys.argv) < 16:
    print("Usage: script.py <learning_rate> <weight_decay> <p> <batch_size> <optimizer> <epochs> <k> <batch_experiment> <num_neurons> <zeta> <training_set_size> <momentum> <injected_noise> <num_mlp_layers> <random_seed_int_1> [<random_seed_int_2> ...]")
    sys.exit(1)

print("start args parsing")
# Parse command-line arguments
learning_rate = float(sys.argv[1])  # peak step size
weight_decay = float(sys.argv[2])  # coefficient of the L2 penalty
p = int(sys.argv[3])  # modulus; also the vocabulary size
batch_size = int(sys.argv[4])
optimizer = sys.argv[5]  # "adam" or a name starting with "SGD"
epochs = int(sys.argv[6])
k = int(sys.argv[7])  # number of training batches per epoch
batch_experiment = sys.argv[8]
num_neurons = int(sys.argv[9])  # MLP width multiplier: d_mlp = d_model * num_neurons
zeta = int(sys.argv[10])
# sys.argv[11] (<training_set_size>) is not read; the size is derived instead.
training_set_size = k * batch_size
momentum = float(sys.argv[12])
injected_noise = float(sys.argv[13]) / float(k)  # the CLI value is divided by k
num_mlp_layers = int(sys.argv[14])

# Accept multiple random seeds; one model is trained per seed.
random_seed_ints = [int(seed) for seed in sys.argv[15:]]
num_models = len(random_seed_ints)

def lr_schedule_fn(step):
    """Triangular learning-rate schedule over ``epochs * k`` steps.

    The rate rises linearly from 0 to ``learning_rate`` during the first half
    of the steps (integer division) and falls linearly back to 0 afterwards.
    """
    n_total = epochs * k
    n_warmup = n_total // 2
    n_cooldown = n_total - n_warmup

    def ramp_up(s):
        return learning_rate * (s / n_warmup)

    def ramp_down(s):
        return learning_rate * (1 - (s - n_warmup) / n_cooldown)

    still_warming_up = step < n_warmup
    return jax.lax.cond(still_warming_up, ramp_up, ramp_down, operand=step)



# Progress marker: dataset helpers and construction follow.
print("making dataset")

def generate_polynomial_dataset(p, f, num_train_batches, batch_size, rng: jax.random.PRNGKey):
    """Sample distinct (a, b, f(a, b)) rows from the p x p grid and batch them.

    Args:
        p: grid side; a and b each range over 0..p-1.
        f: callable applied to the two (p, p) coordinate grids.
        num_train_batches: number of batches to produce.
        batch_size: rows per batch.
        rng: PRNG key; it is split once and the second half drives sampling.

    Returns:
        Array of shape [num_train_batches, batch_size, 3] with columns
        (a, b, f(a, b)); rows are drawn without replacement.

    Raises:
        ValueError: if more rows are requested than the p * p available.
    """
    n_available = p * p
    n_requested = num_train_batches * batch_size
    if n_requested > n_available:
        raise ValueError("Not enough data samples for the requested number of batches.")

    a, b = jnp.mgrid[0:p, 0:p]
    table = jnp.stack([a.ravel(), b.ravel(), f(a, b).ravel()], axis=1)

    _, draw_key = jax.random.split(rng)
    picked = jax.random.choice(draw_key, n_available, (n_requested,), replace=False)
    return table[picked].reshape(num_train_batches, batch_size, 3)

def generate_polynomial_dataset_for_seed(seed):
    """Build the (a + b) mod p training set for one seed.

    Uses the module-level ``p``, ``k`` (number of batches) and ``batch_size``.
    """
    def modular_sum(a, b):
        return jnp.mod(a + b, p)

    return generate_polynomial_dataset(
        p, modular_sum, k, batch_size, jax.random.PRNGKey(seed))

# Build one independently sampled training set per seed.
if batch_experiment == "random_random":
    train_ds_list = []
    for seed in random_seed_ints:
        train_data = generate_polynomial_dataset_for_seed(seed)
        train_ds_list.append(train_data)
    train_ds = jnp.stack(train_ds_list)  # [num_models, k, batch_size, 3]
    print(f"Number of training batches: {train_ds.shape[1]}")
else:
    # No other experiment type builds `train_ds`; fail here with a clear
    # message instead of a NameError on its first use below.
    raise ValueError(
        f"Unsupported batch_experiment {batch_experiment!r}; expected 'random_random'."
    )

print("made dataset")

def compute_pytree_size(pytree):
    """Return the total byte count (size * itemsize) over all leaves of ``pytree``."""
    return sum(
        leaf.size * leaf.dtype.itemsize
        for leaf in jax.tree_util.tree_leaves(pytree)
    )

# Bytes held by one model's slice of train_ds (axis 0 indexes the models).
elements_per_model = train_ds.shape[1] * train_ds.shape[2] * train_ds.shape[3]
dataset_size_bytes = elements_per_model * train_ds.dtype.itemsize
dataset_size_mb = dataset_size_bytes / (1 << 20)
print(f"Dataset size per model: {dataset_size_mb:.2f} MB")

@struct.dataclass
class Metrics(metrics.Collection):
    """Metric collection tracked during training and evaluation."""
    # Classification accuracy.
    accuracy: metrics.Accuracy
    # Running average of the 'loss' model output.
    loss: metrics.Average.from_output('loss')
    # Running average of the 'l2_loss' model output.
    l2_loss: metrics.Average.from_output('l2_loss')

# Model hyper-parameters; the MLP depth comes from the command line.
transformer_config = dict(
    num_layers=1,
    num_mlp_layers=num_mlp_layers,
    d_vocab=p,
    d_model=128,
    d_head=32,
    num_heads=4,
    n_ctx=2,
    act_type="ReLU",
    attn_coeff=0.0,
)
NUM_MLP_LAYERS = transformer_config["num_mlp_layers"]
# The heads together must span the full residual width.
assert transformer_config["d_head"] * transformer_config["num_heads"] == transformer_config["d_model"]

model = Transformer(**transformer_config)
# Dummy token batch used only to initialise parameters.
dummy_x = jnp.zeros((batch_size, 2), dtype=jnp.int32)

def cross_entropy_loss(y_pred, y):
    """Mean negative log-likelihood of labels ``y`` under the last-position logits.

    y_pred is indexed as [batch, position, class]; only the final position is used.
    """
    final_logits = y_pred[:, -1, :]
    log_probs = jax.nn.log_softmax(final_logits, axis=-1)
    row_ids = jnp.arange(y.shape[0])
    # Log-probability assigned to the correct class of each example.
    label_log_probs = log_probs[row_ids, y]
    return (-label_log_probs).mean()

def total_loss(y_pred_and_l2, y):
    """Cross-entropy plus ``weight_decay`` times the L2 term.

    ``y_pred_and_l2`` is a 3-tuple (logits, <ignored>, l2_loss).
    """
    logits, _unused, l2_penalty = y_pred_and_l2
    data_term = cross_entropy_loss(logits, y)
    return data_term + l2_penalty * weight_decay

def apply(variables, x, training=False):
    """Run the model and return (outputs, {}, sum of squared parameters)."""
    params = variables["params"]
    outputs = model.apply({"params": params}, x, training=training)
    squared_norms = [
        jnp.sum(jnp.square(leaf)) for leaf in jax.tree_util.tree_leaves(params)
    ]
    return outputs, {}, sum(squared_norms)

def batched_apply(variables_batch, x_batch, training=False):
    """Apply ``apply`` to every (variables, inputs) pair along the leading axis."""
    def apply_single(single_variables, single_x):
        return apply(single_variables, single_x, training)

    vmapped_apply = jax.vmap(apply_single, in_axes=(0, 0))
    return vmapped_apply(variables_batch, x_batch)

def sample_hessian(prediction, sample):
    """Return (cross-entropy Hessian sample, prediction[1], 0.0).

    The Hessian is computed by ``optimizers.sample_crossentropy_hessian`` from
    the last-position logits and the labels read from column 2 of ``sample[0]``.
    """
    # NOTE(review): this assumes sample[0] still carries the label in column 2
    # (i.e. rows of (a, b, y)) — confirm against the caller.
    last_position_logits = prediction[0][:, -1, :]
    targets = sample[0][:, 2]
    hessian = optimizers.sample_crossentropy_hessian(last_position_logits, targets)
    return (hessian, prediction[1], 0.0)

def compute_metrics(metrics_, *, loss, l2_loss, outputs, labels):
    """Merge one batch's metrics (from last-position logits) into ``metrics_``."""
    batch_values = dict(
        logits=outputs[0][:, -1, :],
        labels=labels,
        loss=loss,
        l2_loss=l2_loss,
    )
    batch_update = metrics_.single_from_model_output(**batch_values)
    return metrics_.merge(batch_update)

def prepare_batches(batches_array):
    """Split [num_models, k, batch_size, 3] rows into int32 inputs and labels.

    Returns (x, y) with x = columns 0-1 ([num_models, k, batch_size, 2]) and
    y = column 2 ([num_models, k, batch_size]).
    """
    inputs = batches_array[:, :, :, 0:2]
    targets = batches_array[:, :, :, 2]
    return inputs.astype(jnp.int32), targets.astype(jnp.int32)

print("Transformer model created.")


def print_param_tree(tree, prefix=""):
    """Print "<path> — shape <shape>" for every leaf of a nested dict of arrays."""
    for name, subtree in tree.items():
        full_name = prefix + name
        if isinstance(subtree, dict):
            # Descend, extending the "/"-separated path.
            print_param_tree(subtree, full_name + "/")
            continue
        print(f"{full_name} — shape {tuple(subtree.shape)}")


# Initialise one set of variables per seed and list its parameters.
variables_list = []
for seed in random_seed_ints:
    rng_key = jax.random.PRNGKey(seed)
    variables = model.init(rng_key, dummy_x, training=False)

    print("=== Transformer parameter names ===")
    print_param_tree(variables["params"])

    variables_list.append(variables)

model_size_bytes = compute_pytree_size(variables_list[0]["params"])
model_size_mb = model_size_bytes / (1 << 20)
print(f"Single model size: {model_size_mb:.2f} MB")

# Stack every parameter leaf across models along a new leading axis.
per_model_params = [model_vars["params"] for model_vars in variables_list]
variables_batch = {
    "params": jax.tree_util.tree_map(
        lambda *leaves: jnp.stack(leaves), *per_model_params
    ),
    "batch_stats": None,
}

# Select the optax optimizer from the CLI name.
# NOTE(review): only Adam uses lr_schedule_fn; the SGD branch runs at the
# constant `learning_rate` — confirm this difference is intended.
if optimizer == "adam":
    tx = optax.adam(learning_rate=lr_schedule_fn)
elif optimizer.startswith("SGD"):
    tx = optax.sgd(learning_rate=learning_rate, momentum=momentum)
else:
    raise ValueError("Unsupported optimizer type")

def init_opt_state(params):
    """Return a fresh optimizer state for a single model's parameter tree."""
    return tx.init(params)

# Initialise the optimizer state model by model, then stack the states so they
# share the leading [num_models] axis with the parameters.
opt_state_list = []
for i in range(num_models):
    # jax.tree_util.tree_map replaces the deprecated top-level jax.tree_map
    # alias, which newer JAX releases no longer provide.
    params_i = jax.tree_util.tree_map(lambda x: x[i], variables_batch["params"])
    opt_state_i = init_opt_state(params_i)
    opt_state_list.append(opt_state_i)

opt_state_batch = jax.tree_util.tree_map(lambda *arrs: jnp.stack(arrs), *opt_state_list)

def create_train_state(params, opt_state, rng_key):
    """Build a ``training.TrainState`` for one model.

    Only ``params``, ``opt_state`` and ``rng_key`` vary per model; every other
    field is taken from module-level definitions.
    """
    shared_fields = dict(
        apply_fn=apply,
        tx=tx,
        loss_fn=total_loss,
        loss_hessian_fn=sample_hessian,
        compute_metrics_fn=compute_metrics,
        initial_metrics=Metrics,
        batch_stats=None,
        injected_noise=injected_noise,
    )
    return training.TrainState(
        params=params,
        opt_state=opt_state,
        rng_key=rng_key,
        **shared_fields,
    )

# Build one TrainState per model from its slice of the stacked parameters and
# optimizer state, then stack the states along a leading [num_models] axis.
states_list = []
for i in range(num_models):
    seed = random_seed_ints[i]
    rng_key = jax.random.PRNGKey(seed)
    # jax.tree_util.tree_map replaces the deprecated top-level jax.tree_map
    # alias, which newer JAX releases no longer provide.
    params_i = jax.tree_util.tree_map(lambda x: x[i], variables_batch["params"])
    opt_state_i = jax.tree_util.tree_map(lambda x: x[i], opt_state_batch)
    st = create_train_state(params_i, opt_state_i, rng_key)
    states_list.append(st)

states = jax.tree_util.tree_map(lambda *arrs: jnp.stack(arrs), *states_list)

# Split the dataset into inputs/labels and place both on the default device.
# train_x: [num_models, k, batch_size, 2]; train_y: [num_models, k, batch_size]
train_x, train_y = (jax.device_put(part) for part in prepare_batches(train_ds))

# One empty metrics object per model, stacked along the model axis.
initial_metrics_list = [model_state.initial_metrics.empty() for model_state in states_list]
initial_metrics = jax.tree_util.tree_map(
    lambda *leaves: jnp.stack(leaves), *initial_metrics_list
)

# EVAL data
if training_set_size != p * p:
    # Evaluate on the full p x p grid; every model sees the same eval set.
    a_eval, b_eval = (axis_grid.ravel() for axis_grid in jnp.mgrid[0:p, 0:p])
    x_eval = jnp.stack([a_eval, b_eval], axis=-1).astype(jnp.int32)
    y_eval = jnp.mod(a_eval + b_eval, p).astype(jnp.int32)
    x_eval = jax.device_put(x_eval)
    y_eval = jax.device_put(y_eval)
    # Add a leading model axis by repeating the eval set num_models times.
    x_eval = jnp.tile(x_eval[None, :, :], (num_models, 1, 1))
    y_eval = jnp.tile(y_eval[None, :], (num_models, 1))
else:
    print("Train set is the entire dataset. Using the training set for evaluation.")
    # train_ds is [num_models, k, batch_size, 3]: columns 0-1 are the inputs,
    # column 2 is the label; batches are flattened per model.
    x_eval = train_ds[:, :, :, :2].reshape(num_models, -1, 2)
    y_eval = train_ds[:, :, :, 2].reshape(num_models, -1)

# Make sure the eval data is on-device in either case.
x_eval = jax.device_put(x_eval)
y_eval = jax.device_put(y_eval)

# Cut the eval set into fixed-size batches, zero-padding the last one.
# NOTE(review): padded rows are (a=0, b=0, label=0); whether they are excluded
# from the metrics depends on eval_step, which is not visible here — confirm.
eval_batch_size = 1024
total_eval_samples = x_eval.shape[1]
num_full_batches, remaining_samples = divmod(total_eval_samples, eval_batch_size)

x_eval_padded, y_eval_padded = x_eval, y_eval
num_eval_batches = num_full_batches
if remaining_samples > 0:
    pad_size = eval_batch_size - remaining_samples
    x_padding = jnp.zeros((num_models, pad_size, x_eval.shape[2]), dtype=x_eval.dtype)
    y_padding = jnp.zeros((num_models, pad_size), dtype=y_eval.dtype)
    x_eval_padded = jnp.concatenate([x_eval, x_padding], axis=1)
    y_eval_padded = jnp.concatenate([y_eval, y_padding], axis=1)
    num_eval_batches = num_full_batches + 1

x_eval_batches = x_eval_padded.reshape(num_models, num_eval_batches, eval_batch_size, -1)
y_eval_batches = y_eval_padded.reshape(num_models, num_eval_batches, eval_batch_size)


@jax.jit
def train_epoch(states_, x_batches, y_batches, init_metrics):
    """
    Run one pass over all batches for every model at once.

    x_batches: [num_models, k2, batch_size2, 2]
    y_batches: [num_models, k2, batch_size2]
    A jax.lax.scan walks the k2 (batch) axis; inside each step the per-model
    train_step is vmapped over the leading model axis.
    Returns (final_states, final_metrics).
    """
    def step_one_model(state, mets, xx, yy):
        return state.train_step(mets, (xx, yy))

    step_all_models = jax.vmap(step_one_model, in_axes=(0, 0, 0, 0))

    def scan_body(carry, batch):
        current_states, current_metrics = carry
        batch_x, batch_y = batch
        next_states, next_metrics = step_all_models(
            current_states, current_metrics, batch_x, batch_y
        )
        return (next_states, next_metrics), None

    # scan iterates over axis 0, so bring the batch axis to the front.
    x_by_batch = x_batches.transpose(1, 0, 2, 3)  # [k2, num_models, batch_size2, 2]
    y_by_batch = y_batches.transpose(1, 0, 2)     # [k2, num_models, batch_size2]

    (final_states, final_metrics), _ = jax.lax.scan(
        scan_body, (states_, init_metrics), (x_by_batch, y_by_batch)
    )
    return final_states, final_metrics

@jax.jit
def eval_model(states_, x_batches, y_batches, init_metrics):
    """Accumulate evaluation metrics over all batches for every model at once.

    The batch axis (axis 1 of the inputs) is scanned; within each step the
    per-model eval_step is vmapped over the leading model axis.  The states
    are not modified.
    """
    def eval_one_model(state, mets, xx, yy):
        return state.eval_step(mets, (xx, yy))

    eval_all_models = jax.vmap(eval_one_model, in_axes=(0, 0, 0, 0))

    def scan_body(running_metrics, batch):
        batch_x, batch_y = batch
        updated = eval_all_models(states_, running_metrics, batch_x, batch_y)
        return updated, None

    # scan iterates over axis 0, so bring the batch axis to the front.
    x_by_batch = x_batches.transpose(1, 0, 2, 3)
    y_by_batch = y_batches.transpose(1, 0, 2)
    final_metrics, _ = jax.lax.scan(scan_body, init_metrics, (x_by_batch, y_by_batch))
    return final_metrics


# Per model: test loss / cross-entropy at the first epoch whose test accuracy
# exceeds 0.99999 (None until that happens).
first_100_test_loss = [None for _ in range(num_models)]
first_100_cross_entropy_loss = [None for _ in range(num_models)]


######################################
# 1) Original Training Loop
######################################
for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs} (original batch_size={batch_size})")
    # Metrics start from the empty collection every epoch.
    states, train_metrics = train_epoch(states, train_x, train_y, initial_metrics)

    # Training metrics, one line per model.
    for i in range(num_models):
        train_metric_i = jax.tree_util.tree_map(lambda leaf: leaf[i], train_metrics).compute()
        print(f"Model {i + 1}/{num_models}: Train Loss: {train_metric_i['loss']:.6f}, "
              f"Train Accuracy: {train_metric_i['accuracy']:.2%}")

    # Test metrics, one line per model.
    print(f"\n--- Test Evaluation at Epoch {epoch + 1} ---")
    test_metrics = eval_model(states, x_eval_batches, y_eval_batches, initial_metrics)
    for i in range(num_models):
        test_metric_i = jax.tree_util.tree_map(lambda leaf: leaf[i], test_metrics).compute()
        test_loss = test_metric_i['loss']
        test_accuracy = test_metric_i['accuracy']
        print(f" Model {i + 1}/{num_models} -> Test Loss: {test_loss:.6f},  Test Acc: {test_accuracy:.2%}")
        # Remove the weight-decay term to recover the plain cross-entropy.
        cross_entropy_val = test_loss - weight_decay * test_metric_i['l2_loss']

        # Remember the losses the first time this model reaches ~100%.
        if first_100_test_loss[i] is None and test_accuracy > 0.99999:
            first_100_test_loss[i] = test_loss
            first_100_cross_entropy_loss[i] = cross_entropy_val

    print("--- End of Test Evaluation ---\n")

# Final evaluation: report each model's accuracy, save the parameters of the
# models that reach the accuracy threshold, and log the cross-entropy.
test_metrics = eval_model(states, x_eval_batches, y_eval_batches, initial_metrics)
final_test_accuracies = []

# Everything below is the same for all models, so it is computed once.  This
# also keeps `loss_log_dir` defined for the code after the loop.
experiment_name = batch_experiment
optimizer_name = optimizer + str(momentum)
num_neurons_transformer = transformer_config["d_model"] * num_neurons
loss_log_dir = (
    f"/logn/transformer_r2_heatmap_k={k}_{epochs}/{p}_{k}_nn_{num_neurons}_fits_attn-co={transformer_config['attn_coeff']}_top-k_layers={num_mlp_layers}/"
    f"p={p}_bs={batch_size}_k={k}_nn={num_neurons_transformer}_lr={learning_rate}_wd={weight_decay}_epochs={epochs}_"
    f"training_set_size={training_set_size}/"
)
loss_log_path = os.path.join(loss_log_dir, "loss_log.txt")
# Parameter files and the logs all live in this one directory.
os.makedirs(loss_log_dir, exist_ok=True)

for i in range(num_models):
    test_metric = jax.tree_util.tree_map(lambda x: x[i], test_metrics)
    test_metric = test_metric.compute()
    test_accuracy = test_metric["accuracy"]
    final_test_accuracies.append(test_accuracy)
    print(f"Model {i + 1} final test accuracy: {test_accuracy:.2%}")
    # Remove the weight-decay term to recover the plain cross-entropy.
    cross_entropy = test_metric['loss'] - weight_decay * test_metric['l2_loss']

    if test_accuracy >= 0.99999:
        # loss_log_dir ends with "/", so plain concatenation gives the path.
        params_file_path = (
            f"{loss_log_dir}params_p_{p}_{batch_experiment}_"
            f"{optimizer_name}_ts_{training_set_size}_bs={batch_size}_nn={num_neurons}_"
            f"lr={learning_rate}_wd={weight_decay}_noise={injected_noise}_zeta={zeta}_k={k}_"
            f"rs_{random_seed_ints[i]}.params"
        )
        with open(params_file_path, 'wb') as f:
            f.write(serialization.to_bytes(jax.tree_util.tree_map(lambda x: x[i], states.params)))
        print(f"Model {i + 1}: Parameters saved to {params_file_path}")
    else:
        # The message now states the threshold that is actually tested.
        print(f"Model {i + 1}: Test accuracy did not reach 99.999%. Model parameters will not be saved.")

    # Append "<seed>,<cross-entropy>" to loss_log.txt
    with open(loss_log_path, "a") as log_file:
        log_file.write(f"{random_seed_ints[i]},{cross_entropy}\n")

# Finally, write the first_100% records: "<seed>,<test loss>,<cross-entropy>"
# for every model that reached the threshold during training.
first_100_path = os.path.join(loss_log_dir, "first_100_acc_test_loss_records.txt")
with open(first_100_path, "a") as f:
    for i in range(num_models):
        if first_100_test_loss[i] is not None:
            f.write(f"{random_seed_ints[i]},{first_100_test_loss[i]},{first_100_cross_entropy_loss[i]}\n")

import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from collections import defaultdict

# Full (a, b) input grid for the analysis below; row a * n + b is the pair (a, b).
n = p                               # grid size
a_vals = np.arange(n)
b_vals = np.arange(n)
a_grid, b_grid = np.meshgrid(a_vals, b_vals, indexing="ij")
full_inputs = np.column_stack((a_grid.ravel(), b_grid.ravel())).astype(np.int32)

def _extract_hook_pre(params, x_batch, layer: int = 1):
    """Run the model and return the activation sown under ``hook_pre{layer}``.

    Args:
        params: parameter tree of a single model.
        x_batch: input batch passed to ``model.apply``.
        layer: 1-based MLP layer index; selects the ``hook_pre{layer}`` entry.

    Returns:
        The first stored value whose key ends in ``hook_pre{layer}``, copied to
        host memory as a NumPy array.

    Raises:
        KeyError: if no key in the intermediates tree ends in the suffix.
        RuntimeError: if the entry found has no ``shape`` attribute.
    """
    # Local import: FrozenDict and plain dict are both Mappings.
    from collections.abc import Mapping

    suffix = f"hook_pre{layer}"

    # 1) Run and grab intermediates
    _, inter = model.apply(
        {"params": params},
        x_batch,
        mutable=["intermediates"],
        training=False,
    )
    ints = inter["intermediates"]

    # 2) DEBUG dump of the whole intermediates tree
    print(f"\n[debug] searching for '{suffix}' in your intermediates tree:")

    def _dump_tree(d, indent=0):
        pad = " " * indent
        for k, v in d.items():
            if isinstance(v, Mapping):
                print(pad + f"{k}/")
                _dump_tree(v, indent + 2)
            elif isinstance(v, (list, tuple)):
                shape = getattr(v[0], "shape", None) if v else None
                print(pad + f"{k} -> {type(v).__name__}(len={len(v)}), item.shape={shape}")
            else:
                shape = getattr(v, "shape", None)
                print(pad + f"{k}: array.shape={shape}")

    _dump_tree(ints)

    # 3) Recursive finder.  A matching entry may be the array itself, a
    #    list/tuple of sown values (tuple is sow's default container), or a
    #    one-entry mapping full_name -> array.
    def _unwrap(value):
        if isinstance(value, (list, tuple)):
            return value[0] if value else None
        if isinstance(value, Mapping):
            return next(iter(value.values()), None)
        return value

    def _find_hook(d):
        if not isinstance(d, Mapping):
            return None
        for k, v in d.items():
            if k.endswith(suffix):
                candidate = _unwrap(v)
                if candidate is not None:
                    return candidate
            # Otherwise (or if the match was empty) look deeper.
            found = _find_hook(v)
            if found is not None:
                return found
        return None

    arr = _find_hook(ints)
    if arr is None:
        raise KeyError(
            f"Couldn't find any key ending in '{suffix}'.\n"
            f"(Top-level keys were: {list(ints.keys())})"
        )

    # 4) Final sanity check / conversion
    if not hasattr(arr, "shape"):
        raise RuntimeError(f"Expected an array for '{suffix}', but got {type(arr)}")
    print(f"[debug] found '{suffix}' with shape {arr.shape}")
    return np.array(jax.device_get(arr))



def _fit_cosine_product(grid_true, freq, n):
    """
    Least-squares fit of an (n, n) grid indexed by (a, b) with features of a + b.

    Regular case: two features, cos(w*(a+b))*cos(w*c) and sin(w*(a+b))*sin(w*c),
    where w = 2*pi*freq/n and c = (a+b) mod n.
    Nyquist case (n even and freq == n/2): a single parity feature (-1)^(a+b).

    Features are standardized per column (a zero standard deviation is replaced
    by 1), fitted without an intercept, and the mean residual is then added
    back as a constant bias.

    Returns
    -------
    grid_fit : (n,n) reconstruction, bias included
    coeffs   : array of length 1 (Nyquist case) or 2
    """
    idx = np.arange(n)
    sum_ab = np.add.outer(idx, idx)      # sum_ab[a, b] = a + b
    sum_mod = sum_ab % n
    target = grid_true.ravel()

    nyquist = (n % 2 == 0) and (freq == n // 2)
    if nyquist:
        # Single parity column.
        design = ((-1) ** sum_ab).ravel().astype(float)[:, None]
    else:
        omega = 2 * np.pi * freq / n
        cos_feature = np.cos(omega * sum_ab) * np.cos(omega * sum_mod)
        sin_feature = np.sin(omega * sum_ab) * np.sin(omega * sum_mod)
        design = np.column_stack([cos_feature.ravel(), sin_feature.ravel()])

    # Standardize each column without ever dividing by zero.
    col_means = design.mean(axis=0)
    col_stds = design.std(axis=0)
    design_std = (design - col_means) / np.where(col_stds == 0, 1.0, col_stds)

    coeffs, *_ = np.linalg.lstsq(design_std, target, rcond=None)
    fitted = (design_std @ coeffs).reshape(n, n)
    # Constant offset the zero-mean features cannot represent.
    offset = np.mean(target - fitted.ravel())
    return fitted + offset, coeffs


# def _fit_with_sin_cos(grid_true, fa, fb, n):
#     a_vals = np.arange(n)
#     b_vals = np.arange(n)
#     a_grid, b_grid = np.meshgrid(a_vals, b_vals, indexing="ij")

#     sin_a = np.sin(2 * np.pi * fa * a_grid / n)
#     cos_a = np.cos(2 * np.pi * fa * a_grid / n)
#     sin_b = np.sin(2 * np.pi * fb * b_grid / n)
#     cos_b = np.cos(2 * np.pi * fb * b_grid / n)

#     X = np.column_stack([sin_a.ravel(), cos_a.ravel(), sin_b.ravel(), cos_b.ravel()])
#     y = grid_true.ravel()
#     X_mean, X_std = X.mean(axis=0), X.std(axis=0)
#     X_norm = (X - X_mean) / X_std

#     coeffs, *_ = np.linalg.lstsq(X_norm, y, rcond=None)
#     grid_fit = (X_norm @ coeffs).reshape(n, n)
#     return grid_fit, coeffs


def _fit_sequential(
    grid_true: np.ndarray,
    fab1: int,
    fab2: int,
    fa: int,
    fb: int,
    n: int,
    p: int,
):
    """
    Sequential fit:
      1. Fit the (a + b) terms only,
      2. Fit the (a - b) terms on the residual,
      3. Fit the axis‐aligned sine/cosine terms on the remaining residual.

    Returns:
      grid_full  –  [n,n] full reconstructed grid,
      coeffs_sum –  (4,) coefficients for the (a + b) terms,
      coeffs_diff – (4,) coefficients for the (a - b) terms,
      coeffs_sep –  (4,) coefficients for the a-only and b-only terms,
      intermediate_grids – dict with 'sum', 'diff', 'residual', and partial reconstructions.
    """
    a_vals = np.arange(n)
    b_vals = np.arange(n)
    a_grid, b_grid = np.meshgrid(a_vals, b_vals, indexing="ij")
    ab_grid = a_grid + b_grid
    ab_diff = a_grid - b_grid

    # frequencies (scaled to radians)
    w1 = 2 * np.pi * fab1 / p
    w2 = 2 * np.pi * fab2 / p
    wa = 2 * np.pi * fa / n
    wb = 2 * np.pi * fb / n

    y_true = grid_true.ravel()

    # ----- Stage 1: Fit (a + b) terms -----
    basis_sum = np.column_stack([
        np.cos(w1 * ab_grid).ravel(),
        np.sin(w1 * ab_grid).ravel(),
        np.cos(w2 * ab_grid).ravel(),
        np.sin(w2 * ab_grid).ravel(),
    ])
    mean1, std1 = basis_sum.mean(0), basis_sum.std(0)
    X1 = (basis_sum - mean1) / std1

    coeffs_sum, *_ = np.linalg.lstsq(X1, y_true, rcond=None)
    grid_sum = (X1 @ coeffs_sum).reshape(n, n)
    residual1 = y_true - grid_sum.ravel()

    # ----- Stage 2: Fit (a - b) terms -----
    basis_diff = np.column_stack([
        np.cos(w1 * ab_diff).ravel(),
        np.sin(w1 * ab_diff).ravel(),
        np.cos(w2 * ab_diff).ravel(),
        np.sin(w2 * ab_diff).ravel(),
    ])
    mean2, std2 = basis_diff.mean(0), basis_diff.std(0)
    X2 = (basis_diff - mean2) / std2

    coeffs_diff, *_ = np.linalg.lstsq(X2, residual1, rcond=None)
    grid_diff = (X2 @ coeffs_diff).reshape(n, n)
    residual2 = residual1 - grid_diff.ravel()

    # ----- Stage 3: Fit axis-aligned terms -----
    basis_sep = np.column_stack([
        np.sin(wa * a_grid).ravel(),
        np.cos(wa * a_grid).ravel(),
        np.sin(wb * b_grid).ravel(),
        np.cos(wb * b_grid).ravel(),
    ])
    mean3, std3 = basis_sep.mean(0), basis_sep.std(0)
    X3 = (basis_sep - mean3) / std3

    coeffs_sep, *_ = np.linalg.lstsq(X3, residual2, rcond=None)
    grid_sep = (X3 @ coeffs_sep).reshape(n, n)

    # Final combined reconstruction
    grid_full = grid_sum + grid_diff + grid_sep

    intermediate_grids = {
        "sum": grid_sum,
        "diff": grid_diff,
        "sep": grid_sep,
        "residual": residual2.reshape(n, n)
    }

    return grid_full, coeffs_sum, coeffs_diff, coeffs_sep, intermediate_grids


import numpy as np
from collections import defaultdict
# (Other necessary modules such as jax, plotly, etc., are assumed to be imported elsewhere.)

def fit_conjecture_direct(grid_true, n, fa, fb):
    """
    Least-squares fit of ``grid_true`` with second-order (a ± b) terms.

    General case: four columns built from the product identities
    ``sin(αa ± βb)`` and ``cos(αa ± βb)`` with ``α = 2π·fa/n`` and
    ``β = 2π·fb/n``.

    Nyquist case (``n`` even and ``fa`` or ``fb`` equal to ``n/2``): the basis
    collapses to the two parity columns ``(-1)^(a+b)`` and ``(-1)^(a-b)``.
    Because ``a + b`` and ``a - b`` always share parity these two columns are
    identical; ``np.linalg.lstsq`` accepts the rank-deficient system, and both
    columns are kept so ``coeffs`` / ``labels`` retain their length.

    Columns are centred and divided by their standard deviation (zero
    deviations are replaced by 1), the fit is done without an intercept, and
    the mean residual is then added back as ``bias``.  ``mse`` and ``r2``
    describe the returned ``"grid"`` (i.e. the fit *including* ``bias``).

    Parameters
    ----------
    grid_true : array of shape (n, n)
    n : int
        Side length of the grid.
    fa, fb : int
        Frequencies applied to the first and second grid axis respectively.

    Returns
    -------
    dict with keys "grid", "coeffs", "labels", "r2", "mse", "bias", "fa", "fb"
    """
    a = np.arange(n)
    b = np.arange(n)
    A, B = np.meshgrid(a, b, indexing="ij")
    y = grid_true.ravel()

    if (n % 2 == 0) and (fa == n // 2 or fb == n // 2):
        # special-case any Nyquist: parity of sum & parity of difference
        sump = ((-1) ** (A + B)).ravel().astype(float)
        diffp = np.where(((A - B) % 2) == 0, 1.0, -1.0).ravel()
        X = np.column_stack([sump, diffp])
        labels = ["parity(a+b)", "parity(a−b)"]

        # drop any zero-variance columns
        mask = X.var(axis=0) > 0
        X = X[:, mask]
        labels = [lab for lab, ok in zip(labels, mask) if ok]
    else:
        sin_a = np.sin(2*np.pi*fa * A / n)
        cos_a = np.cos(2*np.pi*fa * A / n)
        sin_b = np.sin(2*np.pi*fb * B / n)
        cos_b = np.cos(2*np.pi*fb * B / n)

        # angle-sum / angle-difference identities
        X = np.column_stack([
            (sin_a*cos_b + cos_a*sin_b).ravel(),   # sin(a+b)
            (sin_a*cos_b - cos_a*sin_b).ravel(),   # sin(a-b)
            (cos_a*cos_b - sin_a*sin_b).ravel(),   # cos(a+b)
            (cos_a*cos_b + sin_a*sin_b).ravel(),   # cos(a-b)
        ])
        labels = ["sin(a+b)", "sin(a−b)", "cos(a+b)", "cos(a−b)"]

    # safe-standardize (never divide by zero)
    mu = X.mean(axis=0)
    sig = X.std(axis=0)
    sig_safe = np.where(sig == 0, 1.0, sig)
    Xn = (X - mu) / sig_safe

    coeffs, *_ = np.linalg.lstsq(Xn, y, rcond=None)
    grid_fit = (Xn @ coeffs).reshape(n, n)
    bias = np.mean(y - grid_fit.ravel())
    grid_out = grid_fit + bias

    # Score the grid that is actually returned.  Scoring ``grid_fit`` alone
    # would add bias**2 to the MSE and misreport R^2 whenever ``grid_true``
    # has a non-zero mean.
    resid = y - grid_out.ravel()
    mse = np.mean(resid**2)
    var = np.var(y)
    r2 = 1 - mse/(var if var > 0 else 1.0)

    return {
        "grid": grid_out,
        "coeffs": coeffs,
        "labels": labels,
        "bias": bias,
        "mse": mse,
        "r2": r2,
        "fa": fa,
        "fb": fb,
    }


def _fit_with_sin_cos(grid_true, fa, fb, n):
    """
    Fits grid_true with either:
      • full sin/cos basis (4 columns), or
      • a parity basis (1 or 2 columns) if fa or fb == n/2 on even n.

    Returns
    -------
    grid_fit : np.ndarray of shape (n,n)
    coeffs   : np.ndarray of length 1–4
    """
    a = np.arange(n)
    b = np.arange(n)
    A, B = np.meshgrid(a, b, indexing="ij")
    y = grid_true.ravel()

    # parity branch if we hit the Nyquist freq
    if n % 2 == 0 and (fa == n//2 or fb == n//2):
        cols = []
        # parity(a+b)?
        if fa == n//2:
            parity_sum = np.where(((A + B) % 2) == 0, 1.0, -1.0).ravel()
            cols.append(parity_sum)
        # parity(a−b)?
        if fb == n//2:
            parity_diff = np.where(((A - B) % 2) == 0, 1.0, -1.0).ravel()
            cols.append(parity_diff)

        X = np.column_stack(cols)
    else:
        # original sin/cos basis
        sin_a = np.sin(2*np.pi*fa * A / n).ravel()
        cos_a = np.cos(2*np.pi*fa * A / n).ravel()
        sin_b = np.sin(2*np.pi*fb * B / n).ravel()
        cos_b = np.cos(2*np.pi*fb * B / n).ravel()
        X = np.column_stack([sin_a, cos_a, sin_b, cos_b])

    # safe standardize
    mu = X.mean(axis=0)
    sigma = X.std(axis=0)
    sigma_safe = np.where(sigma == 0, 1.0, sigma)
    Xn = (X - mu) / sigma_safe

    # least squares + bias
    coeffs, *_ = np.linalg.lstsq(Xn, y, rcond=None)
    grid_fit = (Xn @ coeffs).reshape(n, n)
    bias = np.mean(y - grid_fit.ravel())
    grid_fit += bias

    return grid_fit, coeffs


def _dominant_axis_freqs(grid, n):
    """Return the strongest non-DC frequency index (1..n//2) along each axis of ``grid``."""
    mag = np.abs(np.fft.fft2(grid))
    mag[0, 0] = 0  # ignore the DC component
    upper = n // 2 + 1
    f_rows = int(np.argmax(mag.sum(axis=1)[1:upper]) + 1)
    f_cols = int(np.argmax(mag.sum(axis=0)[1:upper]) + 1)
    return f_rows, f_cols


def fit_conjecture_enhanced(grid_true, n, fa, fb, top_k=3):
    """
    Greedy additive fit: ``top_k`` second-order fits followed by ``top_k``
    first-order fits, each applied to the residual left by the previous one.

    At every step the dominant frequency along each axis is re-estimated from
    the 2-D FFT magnitude of the current residual.  Second-order steps use
    ``fit_conjecture_direct``; first-order steps use ``_fit_with_sin_cos``.

    Parameters:
      grid_true : array-like of shape (n, n) – the true grid.  It is converted
                  to a float NumPy array, so integer input is accepted and the
                  caller's array is never modified.
      n         : int – grid dimension.
      fa, fb    : int – initial frequencies (unused, for backward compatibility).
      top_k     : int – number of second-order and first-order fits to perform.
                  With ``top_k <= 0`` no fits are run and the result is the
                  constant-mean grid.

    Returns a dictionary containing:
      "r2", "mse", "grid", "bias",
      "second_order_details": list of dicts (one per second-order fit),
      "first_order_details": list of dicts (one per first-order fit),
      "coeffs1", "labels": backward-compatible coefficients/labels of the first
                           second-order fit (empty when there is none),
      "res1": residual after second-order fits,
      "res2": residual after first-order fits (before the final bias adjustment),
      "res_total": overall residual.
    """
    grid_true = np.asarray(grid_true, dtype=float)

    # Stage 1: top_k second-order fits
    total_fit_so = np.zeros_like(grid_true)
    resid = grid_true.copy()
    second_order_details = []
    for _ in range(top_k):
        # compute dominant freqs on current residual
        fa_i, fb_i = _dominant_axis_freqs(resid, n)
        fit_i = fit_conjecture_direct(resid, n, fa_i, fb_i)
        grid_fit_i = fit_i["grid"]
        second_order_details.append({
            "fa": fa_i,
            "fb": fb_i,
            "coeffs": fit_i["coeffs"],
            "labels": fit_i["labels"],
            "r2": fit_i["r2"],
            "mse": fit_i["mse"]
        })
        total_fit_so += grid_fit_i
        resid -= grid_fit_i
    res1 = grid_true - total_fit_so

    # Stage 2: top_k first-order sin/cos fits on what stage 1 left over
    total_fit_fo = np.zeros_like(grid_true)
    first_order_details = []
    for _ in range(top_k):
        fa_r, fb_r = _dominant_axis_freqs(resid, n)
        grid_fit_cos, coeffs_cos = _fit_with_sin_cos(resid, fa_r, fb_r, n)
        first_order_details.append({
            "fa": fa_r,
            "fb": fb_r,
            "coeffs": coeffs_cos
        })
        total_fit_fo += grid_fit_cos
        resid -= grid_fit_cos
    res2 = grid_true - (total_fit_so + total_fit_fo)

    # Combine all fits and adjust bias
    grid_total = total_fit_so + total_fit_fo
    y_true = grid_true.ravel()
    bias_total = np.mean(y_true - grid_total.ravel())
    grid_total += bias_total
    res_total = grid_true - grid_total

    # Performance metrics
    var = np.var(y_true)
    mse = np.mean((y_true - grid_total.ravel()) ** 2)
    r2 = 1.0 - mse / (var if var > 0 else 1.0)

    # Backward compatibility: expose the first second-order fit, if any.
    if second_order_details:
        coeffs1 = second_order_details[0]["coeffs"]
        labels = second_order_details[0]["labels"]
    else:
        coeffs1 = np.empty(0)
        labels = []

    return {
        "r2": r2,
        "mse": mse,
        "grid": grid_total,
        "bias": bias_total,
        "second_order_details": second_order_details,
        "first_order_details": first_order_details,
        "coeffs1": coeffs1,
        "labels": labels,
        "res1": res1,
        "res2": res2,
        "res_total": res_total,
    }

def compute_injected_accuracy(params, fitted_pre, x_all, y_all, layer_idx):
    """
    Re-run the module-level ``model`` with ``fitted_pre`` supplied in the
    "intermediates" collection and return the last-position accuracy.

    Steps:
      1. Apply the model with ``mutable=["intermediates"]`` to discover where
         the intermediate whose key ends in ``hook_pre{layer_idx}`` lives.
      2. Build a nested dict that mirrors that path and holds ``[fitted_pre]``
         at the leaf.
      3. Apply the model again with that dict passed as the "intermediates"
         collection and compare ``argmax`` of the last position with ``y_all``.

    NOTE(review): whether the second forward pass actually *uses* the supplied
    values depends on how the model's hook modules are written.  A hook that
    only calls ``sow`` records activations and does not read them back, in
    which case this returns the unmodified model accuracy.  Confirm against
    the model definition.

    Args:
      params: parameter tree passed to ``model.apply`` under "params".
      fitted_pre: replacement activations placed at the hook location.
      x_all: model inputs.
      y_all: target labels compared against the predictions.
      layer_idx: index appended to "hook_pre" to form the key suffix.

    Returns:
      Scalar mean of ``preds == y_all``.

    Raises:
      KeyError: if no intermediate key ends in the expected suffix.
    """
    # Local import keeps this block self-contained.
    from collections.abc import Mapping

    suffix = f"hook_pre{layer_idx}"

    # 1) Capture the current intermediates
    _, inter = model.apply(
        {"params": params},
        x_all,
        mutable=["intermediates"],
        training=False,
    )
    ints = inter["intermediates"]

    # 2) Find the full key-path to the hook.  ``Mapping`` (not ``dict``) so
    #    that FrozenDict collections are traversed as well.
    def find_hook_path(d, path):
        if not isinstance(d, Mapping):
            return None
        for k, v in d.items():
            new_path = path + [k]
            if k.endswith(suffix):
                if isinstance(v, Mapping):
                    # assume exactly one entry in that mapping
                    inner = next(iter(v.keys()))
                    return new_path + [inner]
                # v is a list of arrays or already an array
                return new_path
            # otherwise recurse
            res = find_hook_path(v, new_path)
            if res:
                return res
        return None

    hook_path = find_hook_path(ints, [])
    if not hook_path:
        raise KeyError(f"No intermediate ending in '{suffix}' found. Keys seen:\n{list(ints.keys())}")

    print(f"[debug] will inject into path: {' ➔ '.join(hook_path)}")

    # 3) Build the minimal intermediates override dict.  Only the deepest key
    #    maps to a list-of-arrays; it is selected by position so that a key
    #    name repeated higher up the path is not wrapped a second time.
    *parent_keys, leaf_key = hook_path
    new_intermediates = {leaf_key: [fitted_pre]}
    for key in reversed(parent_keys):
        new_intermediates = {key: new_intermediates}

    # 4) Rerun with our override
    vars2 = freeze({"params": params, "intermediates": new_intermediates})
    logits_inj = model.apply(vars2, x_all, training=False)
    preds = jnp.argmax(logits_inj[:, -1, :], axis=-1)
    return jnp.mean(preds == y_all)


# --- Main analysis loop ---
# For every seed and every MLP layer: fit each neuron's pre-activation grid
# with closed-form bases, collect per-frequency R^2 statistics, run two
# forward-pass sanity checks, and write one JSON summary per (layer, seed).
# Relies on names defined elsewhere in the file (random_seed_ints, states,
# NUM_MLP_LAYERS, full_inputs, n, p, model, final_test_accuracies,
# loss_log_dir, _extract_hook_pre, _fit_cosine_product, Attention).
threshold = 0.00001  # NOTE(review): not referenced in the visible part of this loop
# collect layer-1 fits per model
layer1_fits_all = [[] for _ in range(len(random_seed_ints))]

for m_i, seed in enumerate(random_seed_ints):
    print(f"\n===== Seed {seed} =====")
    # extract this model's params: take entry m_i along the leading axis of every leaf
    params_i = jax.tree_util.tree_map(lambda x: x[m_i], states.params)
    deep_layer_grids = []  # best-fit grid per neuron, filled only for the deepest layer
    layer_summaries = {}

    # Analyze each layer (1-based)
    for layer in range(1, NUM_MLP_LAYERS + 1):
        is_deepest = (layer == NUM_MLP_LAYERS)
        print(f"\n[analyzing] Seed {seed} · Layer {layer}")
        # 1) Extract true pre-activations
        pre_all = _extract_hook_pre(params_i, full_inputs, layer)
        # index 1 along the second axis -- presumably the second token position; TODO confirm
        pre_tok1 = pre_all[:, 1, :]
        d_mlp = pre_tok1.shape[-1]

        # (fa, fb) -> list of best R^2 values of the neurons assigned to that pair
        freq_r2s = defaultdict(list)
        # (fa, fb) -> winning fit type; overwritten per neuron, so the last neuron wins
        freq_fit_type = {}

        # Fit each neuron
        for neuron in range(d_mlp):
            # assumes the first axis of pre_tok1 enumerates all n*n (a, b) pairs
            # in row-major order -- TODO confirm against how full_inputs is built
            grid_true = pre_tok1[:, neuron].reshape(n, n)

            # first-order candidate freqs: strongest non-DC FFT row / column
            # in the index range 1..n//2
            fft2d = np.fft.fft2(grid_true)
            mag = np.abs(fft2d)
            mag[0, 0] = 0
            fa_raw = int(np.argmax(mag.sum(axis=1)[1:(n//2 + 1)]) + 1)
            fb_raw = int(np.argmax(mag.sum(axis=0)[1:(n//2 + 1)]) + 1)

            # best sin/cos fit over each frequency and its mirror n - f
            best_sc = {"r2": -np.inf}
            for candidate_fa in (fa_raw, n - fa_raw):
                for candidate_fb in (fb_raw, n - fb_raw):
                    grid_sc, _ = _fit_with_sin_cos(grid_true, candidate_fa, candidate_fb, n)
                    y_true = grid_true.ravel()
                    y_pred = grid_sc.ravel()
                    var = float(np.var(y_true))
                    mse = float(np.mean((y_true - y_pred) ** 2))
                    r2 = 1 - mse / (var if var > 0 else 1.0)
                    # Re-score with a constant offset and keep it only if it helps.
                    # _fit_with_sin_cos already adds the mean residual, so this
                    # offset is expected to be close to zero.
                    bias = np.mean(y_true - y_pred)
                    ypb = y_pred + bias
                    mse_b = float(np.mean((y_true - ypb) ** 2))
                    r2_b = 1 - mse_b / (var if var > 0 else 1.0)
                    if r2_b > r2:
                        r2, grid_sc = r2_b, grid_sc + bias
                    if r2 > best_sc["r2"]:
                        best_sc = {"grid": grid_sc, "r2": r2, "fa": candidate_fa, "fb": candidate_fb}

            # Store layer-1 fits
            if layer == 1:
                layer1_fits_all[m_i].append({
                    "fa": best_sc["fa"],
                    "fb": best_sc["fb"],
                    "grid": best_sc["grid"]
                })

            # Additional fits for every layer after the first
            if layer != 1:
                # cosine-product fit
                # NOTE(review): passes p where the other fits pass n -- confirm
                # the two are equal or that p is what _fit_cosine_product expects
                grid_cp, _ = _fit_cosine_product(grid_true, best_sc["fa"], p)
                y_t = grid_true.ravel()
                y_cp = grid_cp.ravel()
                var = float(np.var(y_t))
                mse_cp = float(np.mean((y_t - y_cp) ** 2))
                r2_cp = 1 - mse_cp / (var if var > 0 else 1.0)
                # same "keep the offset only if it helps" scoring as above
                bias_cp = np.mean(y_t - y_cp)
                ycpb = y_cp + bias_cp
                mse_cpb = float(np.mean((y_t - ycpb) ** 2))
                r2_cpb = 1 - mse_cpb / (var if var > 0 else 1.0)
                if r2_cpb > r2_cp:
                    r2_cp, grid_cp = r2_cpb, grid_cp + bias_cp

                # compute dynamic top_k: number of distinct frequencies seen in
                # this seed's layer-1 fits (recomputed for every neuron)
                unique_freqs = set()
                for entry in layer1_fits_all[m_i]:
                    unique_freqs.update([entry["fa"], entry["fb"]])
                top_k = len(unique_freqs)

                # enhanced conjecture with top_k
                conj = fit_conjecture_enhanced(
                    grid_true, n, best_sc["fa"], best_sc["fb"], top_k=top_k
                )
                grid_conj, r2_conj = conj["grid"], conj["r2"]

            # choose best fit: start from sin/cos, let later candidates win on
            # strictly higher R^2
            best_r2 = best_sc["r2"]
            best_type = "sin_cos"
            grid_best = best_sc["grid"]

            if layer != 1:
                if r2_cp > best_r2:
                    best_r2, best_type, grid_best = r2_cp, "cos_product", grid_cp
                if r2_conj > best_r2:
                    best_r2, best_type, grid_best = r2_conj, "conjecture", grid_conj
            if is_deepest:
                deep_layer_grids.append(grid_best)

            # statistics are always keyed by the sin/cos frequencies, whichever fit won
            freq_r2s[(best_sc["fa"], best_sc["fb"])].append(best_r2)
            freq_fit_type[(best_sc["fa"], best_sc["fb"])] = best_type

        # Determine top_k for summary (reported for the deepest layer only)
        if is_deepest:
            unique_freqs = set()
            for entry in layer1_fits_all[m_i]:
                unique_freqs.update([entry["fa"], entry["fb"]])
            top_k = len(unique_freqs)
        else:
            top_k = None

        # Build summary dict
        summary = {
            "test_accuracy": float(final_test_accuracies[m_i]),
            "top_k": top_k
        }

        # manual forward debug (unchanged)
        # NOTE(review): this debug section does not depend on `layer` but is
        # re-run on every layer iteration; the sizes below are hard-coded.
        embed_mod = Embed(d_vocab=p, d_model=128)
        pos_mod = PosEmbed(max_ctx=2, d_model=128)
        attn_mod = Attention(
            d_model=128, num_heads=4, d_head=32,
            n_ctx=2, attn_coeff=1.0
        )

        # first 16 inputs and their modular-addition targets
        x_debug = full_inputs[:16]
        y_debug = (x_debug[:, 0] + x_debug[:, 1]) % p

        # reference predictions from the real model
        # (called without training=False, unlike compute_injected_accuracy)
        logits_ref = model.apply({'params': params_i}, x_debug)
        preds_ref = jnp.argmax(logits_ref[:, -1, :], axis=-1)

        def manual_forward(params, x):
            # Hand-written forward pass: embed -> pos-embed -> attention ->
            # ReLU MLP stack -> residual add -> final projection with W_U.
            x_emb  = embed_mod.apply({'params': params['embed']}, x)
            x_emb  = pos_mod.apply({'params': params['pos_embed']}, x_emb)
            att    = attn_mod.apply({'params': params['blocks_0']['attn']}, x_emb)
            resid  = x_emb + att           # after attention

            h = resid
            for i in range(NUM_MLP_LAYERS):
                W = params['blocks_0']['mlp'][f'W_{i}']
                b = params['blocks_0']['mlp'][f'b_{i}']
                pre = jnp.einsum("md,bpd->bpm", W, h) + b
                h   = jax.nn.relu(pre)
            out = jnp.einsum("dm,bpm->bpd",
                            params['blocks_0']['mlp']['W_out'], h) + params['blocks_0']['mlp']['b_out']
            resid = resid + out
            # NOTE(review): the output subscripts "bpd" sum out index m and keep
            # the d axis; a projection onto W_U's second axis would be "->bpm".
            # Confirm against the layout of W_U.
            return jnp.einsum("dm,bpd->bpd", params['W_U'], resid)

        preds_man = jnp.argmax(manual_forward(params_i, x_debug)[:, -1, :], axis=-1)
        print("manual == reference?", jnp.all(preds_man == preds_ref))
        print("debug accuracy:", jnp.mean(preds_ref == y_debug))

        # compute fake-injection accuracy: inject the *true* pre-activations as
        # a sanity check of the injection path
        y_full = (full_inputs[:, 0] + full_inputs[:, 1]) % p
        real_pre = pre_all[:, 1, :]
        fake_acc = compute_injected_accuracy(
            params_i, real_pre, full_inputs, y_full, layer_idx=layer
        )
        print(f"fake-injection acc: {fake_acc:.2%}")

        # compute and record injected accuracy for the deepest layer:
        # one flattened best-fit grid per neuron, stacked as columns
        if is_deepest:
            fitted_pre = jnp.stack([g.reshape(-1) for g in deep_layer_grids], axis=1)
            inj_acc = compute_injected_accuracy(params_i, fitted_pre, full_inputs, y_full, layer_idx=layer)
            print(f"injected test accuracy: {inj_acc:.2%}")
            summary["injected_test_accuracy"] = float(inj_acc)

        # record freq stats, highest average R^2 first, under the key "fa,fb"
        for (fa, fb), r2_list in sorted(freq_r2s.items(), key=lambda x: -np.mean(x[1])):
            summary[f"{fa},{fb}"] = {
                "avg_r2": np.mean(r2_list),
                "count": len(r2_list),
                "fit_type": freq_fit_type[(fa, fb)]
            }

        layer_summaries[layer] = summary

    # --- Propagate the deepest-layer injected accuracy upward ---
    deep_key = NUM_MLP_LAYERS                              # last (deepest) hidden layer
    deep_acc = layer_summaries.get(deep_key, {}).get("injected_test_accuracy")

    if deep_acc is not None:
        # Copy to every shallower layer
        for l in range(1, deep_key):                      # 1-based indexing
            layer_summaries[l]["injected_test_accuracy"] = deep_acc



    # Write out one JSON file per layer for this seed
    for layer, summary in layer_summaries.items():  # keys are the 1-based layer indices
        pth = os.path.join(
            loss_log_dir,
            f"all_preactivations_L{layer}_seed{seed}.json"
        )
        with open(pth, 'w') as jf:
            json.dump(summary, jf, indent=2)
        print(f"[analysis] saved → {pth}")
