import os
import json
from collections.abc import Mapping
from typing import Any, Dict, List, Tuple

import numpy as np
import jax
import jax.numpy as jnp
from flax.core import freeze

import DFT
import dihedral
import report
import analysis.R2 as R2
from pca_diffusion_plots_w_helpers import generate_pdf_plots_for_matrix
from color_rules import colour_quad_a_only, colour_quad_mod_g, colour_c_mod_p
from transformer_class import TransformerOneEmbed, TransformerTwoEmbed, HookPoint
import controllers.paths_Transformer as paths
from controllers.prep_data_train_eval import eval_model, make_full_eval_grid


def final_eval_all_models(*, states, x_eval_batches, y_eval_batches, init_metrics, random_seed_ints: List[int]):
    """Evaluate all stacked models once and summarise their test metrics per seed.

    ``eval_model`` returns metrics stacked along a leading model axis; entry
    ``position`` of that axis belongs to ``random_seed_ints[position]``.

    Returns:
        ``{seed: {"reach_100pct_test", "loss", "l2_loss", "accuracy"}}`` where
        ``reach_100pct_test`` is true only when the accuracy is exactly 1.0.
    """
    stacked_metrics = eval_model(states, x_eval_batches, y_eval_batches, init_metrics)

    def _summary_for(position):
        # Slice this model's metrics out of the stacked pytree, then reduce.
        computed = jax.tree_util.tree_map(lambda leaf: leaf[position], stacked_metrics).compute()
        accuracy = float(computed["accuracy"])
        return {
            "reach_100pct_test": accuracy == 1.0,
            "loss": float(computed["loss"]),
            "l2_loss": float(computed["l2_loss"]),
            "accuracy": accuracy,
        }

    return {seed: _summary_for(position) for position, seed in enumerate(random_seed_ints)}


def save_epoch_logs(logs_by_seed: Dict[int, Dict[int, Dict]], out_dir: str, features_or_dm: int):
    """Write one pretty-printed JSON epoch log per seed into ``out_dir``.

    Each file is named ``log_features_{features_or_dm}_seed_{seed}.json``.
    ``out_dir`` is created if it does not exist. Note that JSON turns integer
    dict keys into strings.
    """
    os.makedirs(out_dir, exist_ok=True)
    for seed in logs_by_seed:
        filename = f"log_features_{features_or_dm}_seed_{seed}.json"
        target = os.path.join(out_dir, filename)
        with open(target, "w") as handle:
            json.dump(logs_by_seed[seed], handle, indent=2)
        print(f"[Transformer] Epoch log for seed {seed} saved to {target}")


def save_final_logs(log_by_seed: Dict[int, Dict[int, Dict]], out_dir: str, features: int):
    """Write one pretty-printed JSON final log per seed into ``out_dir``.

    Each file is named ``final_log_features_{features}_seed_{seed}.json``.
    ``out_dir`` is created if it does not exist.
    """
    os.makedirs(out_dir, exist_ok=True)
    for seed in log_by_seed:
        filename = f"final_log_features_{features}_seed_{seed}.json"
        target = os.path.join(out_dir, filename)
        with open(target, "w") as handle:
            json.dump(log_by_seed[seed], handle, indent=2)
        print(f"[Transformer] Final log for seed {seed} saved to {target}")


def save_prune_logs(log_by_seed: Dict[int, Dict[int, Dict]], out_dir: str, features: int):
    """Write one pretty-printed JSON prune log per seed into ``out_dir``.

    Each file is named ``prune_log_features_{features}_seed_{seed}.json``. This
    is the name ``_load_alive_indices_for_seed`` looks for. ``out_dir`` is
    created if it does not exist.
    """
    os.makedirs(out_dir, exist_ok=True)
    for seed in log_by_seed:
        filename = f"prune_log_features_{features}_seed_{seed}.json"
        target = os.path.join(out_dir, filename)
        with open(target, "w") as handle:
            json.dump(log_by_seed[seed], handle, indent=2)
        print(f"Prune log for seed {seed} saved to {target}")


def make_some_jsons(
    *,
    preacts: list[np.ndarray],
    group_size: int,
    clusters_by_layer: list[dict[int, list[int]]],
    cluster_weights_to_logits: dict[int, np.ndarray],
    cluster_weights_to_dmodel: dict[int, np.ndarray] | None = None,
    save_dir: str,
    subdir: str = "json",
    float_dtype=np.float32,
    sanity_check: bool = True,
    cluster_contribs_to_logits: dict[int, np.ndarray] | None = None,
    cluster_contribs_to_dmodel: dict[int, np.ndarray] | None = None,
) -> str:
    if not preacts:
        raise ValueError("make_some_jsons: empty `preacts`.")
    Z_last = np.asarray(preacts[-1])
    n_rows, width_last = Z_last.shape
    if n_rows != group_size * group_size:
        raise ValueError(f"make_some_jsons: expected group_size^2={group_size*group_size} rows, got {n_rows}.")
    if not clusters_by_layer:
        raise ValueError("make_some_jsons: empty `clusters_by_layer`.")

    last_layer_clusters = clusters_by_layer[-1] or {}
    if not isinstance(last_layer_clusters, dict):
        raise TypeError("make_some_jsons: clusters_by_layer[-1] must be a dict {freq -> [neuron_ids]}.")

    json_root = os.path.join(save_dir, subdir)
    os.makedirs(json_root, exist_ok=True)

    for freq, neuron_ids in last_layer_clusters.items():
        if not neuron_ids:
            continue

        W_block = cluster_weights_to_logits.get(freq, None)
        if W_block is None:
            continue
        W_block = np.asarray(W_block)

        W_block_dmodel = None
        if cluster_weights_to_dmodel is not None:
            W_block_dmodel = cluster_weights_to_dmodel.get(freq)
            if W_block_dmodel is not None:
                W_block_dmodel = np.asarray(W_block_dmodel)

        ids = np.asarray(neuron_ids, dtype=int)
        if np.any((ids < 0) | (ids >= width_last)):
            bad = ids[(ids < 0) | (ids >= width_last)]
            raise ValueError(f"cluster {freq}: invalid neuron ids {bad.tolist()} for width_last={width_last}")
        if W_block.shape[0] != ids.shape[0]:
            raise ValueError(f"cluster {freq}: W_block rows {W_block.shape[0]} != len(ids) {ids.shape[0]}")

        if W_block.shape[1] != group_size:
            raise ValueError(
                f"make_some_jsons: for freq={freq}, W_block has {W_block.shape[1]} columns, expected group_size={group_size}."
            )

        Z_cluster = Z_last[:, ids]
        H_cluster = np.maximum(Z_cluster, 0.0)

        contribs_logits = H_cluster[:, :, None] * W_block[None, :, :]

        Hf = np.asarray(H_cluster, dtype=np.float64)
        Wf = np.asarray(W_block, dtype=np.float64)
        C_sum_broadcast = (Hf[:, :, None] * Wf[None, :, :]).sum(axis=1)
        C_mm = Hf @ Wf

        if not np.allclose(C_sum_broadcast, C_mm, rtol=1e-9, atol=1e-12):
            diff = np.max(np.abs(C_sum_broadcast - C_mm))
            where = np.unravel_index(np.argmax(np.abs(C_sum_broadcast - C_mm)), C_mm.shape)
            raise RuntimeError(
                f"[debug] freq={freq}: broadcast-sum vs matmul mismatch. "
                f"max_abs_diff={diff:.3e} at {where}; "
                f"dtypes H:{H_cluster.dtype} W:{W_block.dtype}"
            )

        contribs_dmodel = None
        if W_block_dmodel is not None:
            contribs_dmodel = H_cluster[:, :, None] * W_block_dmodel[None, :, :]

        if sanity_check and (cluster_contribs_to_logits is not None):
            C_freq_expected = cluster_contribs_to_logits.get(freq)
            if C_freq_expected is not None and np.size(C_freq_expected):
                C_exp = np.asarray(C_freq_expected, dtype=np.float64)
                if C_exp.shape != C_mm.shape:
                    raise ValueError(f"[debug] freq={freq}: shape mismatch exp{C_exp.shape} vs mm{C_mm.shape}")
                scale = max(1.0, float(np.max(np.abs(C_exp))))
                if not np.allclose(C_mm, C_exp, rtol=1e-4, atol=1e-5 * scale):
                    diff = np.max(np.abs(C_mm - C_exp))
                    where = np.unravel_index(np.argmax(np.abs(C_mm - C_exp)), C_mm.shape)
                    raise ValueError(
                        f"make_some_jsons: contribution mismatch for freq={freq}. "
                        f"max_abs_diff={diff:.3e} at {where}. "
                        f"|C|={W_block.shape[0]}, group_size={group_size}"
                    )

        if sanity_check and (cluster_contribs_to_dmodel is not None) and (contribs_dmodel is not None):
            D_freq_expected = np.asarray(cluster_contribs_to_dmodel.get(freq))
            if D_freq_expected is not None and D_freq_expected.size:
                D_sum = contribs_dmodel.sum(axis=1)
                if D_freq_expected.shape != D_sum.shape:
                    raise ValueError(
                        f"make_some_jsons: cluster_contribs_to_dmodel[{freq}] has shape {D_freq_expected.shape}, "
                        f"expected {D_sum.shape}."
                    )
                if not np.allclose(D_sum, D_freq_expected, rtol=1e-5, atol=1e-6):
                    raise ValueError(f"make_some_jsons: d_model contribution mismatch for freq={freq} (sum != total).")

        payload = {}
        for j, nid in enumerate(ids.tolist()):
            entry = {
                "preactivations": Z_cluster[:, j].astype(float_dtype).tolist(),
                "w_out": W_block[j, :].astype(float_dtype).tolist(),
                "contribs_to_logits": contribs_logits[:, j, :].astype(float_dtype).tolist(),
            }
            if W_block_dmodel is not None:
                entry["w_dmodel"] = W_block_dmodel[j, :].astype(float_dtype).tolist()
                entry["contribs_to_dmodel"] = contribs_dmodel[:, j, :].astype(float_dtype).tolist()
            payload[str(int(nid))] = entry

        out_path = os.path.join(json_root, f"cluster_{freq}.json")
        with open(out_path, "w") as f:
            json.dump(payload, f)

    return json_root


def _load_alive_indices_for_seed(
    prune_dir: str,
    features_or_dm: int,
    seed: int,
    *,
    num_layers: int,
    params_seed: dict,
) -> list[list[int]]:
    path = os.path.join(prune_dir, f"prune_log_features_{features_or_dm}_seed_{seed}.json")
    if os.path.exists(path):
        with open(path, "r") as f:
            rep = json.load(f)
        alive = rep.get("alive_final") or rep.get("stageB_alive") or rep.get("stageA_alive")
        if alive is not None:
            return [[int(x) for x in alive.get(str(li), [])] for li in range(num_layers)]

    mlp = params_seed["blocks_0"]["mlp"]
    return [list(range(int(mlp[f"b_{li}"].shape[0]))) for li in range(num_layers)]


def _find_by_suffix(d, suffix):
    if isinstance(d, dict):
        for k, v in d.items():
            if isinstance(k, str) and k.endswith(suffix):
                if isinstance(v, list):
                    return v[0]
                if isinstance(v, dict):
                    return next(iter(v.values()))
                return v
            out = _find_by_suffix(v, suffix)
            if out is not None:
                return out
    elif isinstance(d, list):
        for x in d:
            out = _find_by_suffix(x, suffix)
            if out is not None:
                return out
    return None


def extract_attn_last_token(*, model, params: dict, x_full: jnp.ndarray, last_token_index: int = 1):
    """Return the attention paid by the query at ``last_token_index`` to positions 0 and 1.

    The model is run once with the ``"intermediates"`` collection mutable. The
    attention intermediate is looked up under ``blocks_0/attn/hook_attn``,
    falling back to any key ending in ``hook_attn``; it must be 4-D.

    Returns:
        ``(attn_to_a, attn_to_b, attn)``: ``attn[:, :, q, 0]``,
        ``attn[:, :, q, 1]`` and the full numpy attention array.

    Raises:
        KeyError: if no attention hook is found.
        ValueError: if ``last_token_index`` is outside ``[0, attn.shape[2])``.
    """
    _, state = model.apply({"params": params}, x_full, training=False, mutable=["intermediates"])
    intermediates = state["intermediates"]

    pattern = None
    for candidate in ("blocks_0/attn/hook_attn", "hook_attn"):
        pattern = _find_by_suffix(intermediates, candidate)
        if pattern is not None:
            break
    if pattern is None:
        raise KeyError("Could not find attention hook 'hook_attn' in intermediates.")

    pattern = np.asarray(pattern)
    _, _, n_query, _ = pattern.shape
    query = int(last_token_index)
    if not 0 <= query < n_query:
        raise ValueError(f"last_token_index={query} out of range for attn T={n_query}")

    return pattern[:, :, query, 0], pattern[:, :, query, 1], pattern


def _extract_hook_pre_all_layers(model, params, x_full, num_mlp_layers: int):
    """Run the model once and collect the ``hook_pre{l}`` intermediate for l = 1..num_mlp_layers.

    Each layer first tries ``blocks_0/mlp/hook_pre{l}``, then any key ending in
    ``hook_pre{l}``. A layer whose hook is not found contributes ``None``.
    """
    _, state = model.apply({"params": params}, x_full, training=False, mutable=["intermediates"])
    intermediates = state["intermediates"]

    def _lookup(layer):
        found = _find_by_suffix(intermediates, f"blocks_0/mlp/hook_pre{layer}")
        if found is None:
            found = _find_by_suffix(intermediates, f"hook_pre{layer}")
        return found

    return [_lookup(layer) for layer in range(1, num_mlp_layers + 1)]


def _extract_hook_post_all_layers(model, params, x_full, num_mlp_layers: int):
    """Run the model once and collect the ``hook_post{l}`` intermediate for l = 1..num_mlp_layers.

    Each layer first tries ``blocks_0/mlp/hook_post{l}``, then any key ending in
    ``hook_post{l}``. A layer whose hook is not found contributes ``None``.
    """
    _, state = model.apply({"params": params}, x_full, training=False, mutable=["intermediates"])
    intermediates = state["intermediates"]

    def _lookup(layer):
        found = _find_by_suffix(intermediates, f"blocks_0/mlp/hook_post{layer}")
        if found is None:
            found = _find_by_suffix(intermediates, f"hook_post{layer}")
        return found

    return [_lookup(layer) for layer in range(1, num_mlp_layers + 1)]


def extract_preacts_last_token(*, model, params: dict, x_full: jnp.ndarray, num_mlp_layers: int, last_token_index: int = 1):
    """Return each MLP layer's ``hook_pre`` activations at position ``last_token_index``.

    Every per-layer hook array is converted to numpy and indexed as
    ``[:, last_token_index, :]``, giving one 2-D array per layer.
    """
    def _take_last(hook_value):
        return np.asarray(hook_value)[:, last_token_index, :]

    hooks = _extract_hook_pre_all_layers(model, params, x_full, num_mlp_layers)
    return list(map(_take_last, hooks))


def extract_postacts_last_token(*, model, params: dict, x_full: jnp.ndarray, num_mlp_layers: int, last_token_index: int = 1):
    """Return each MLP layer's ``hook_post`` activations at position ``last_token_index``.

    Every per-layer hook array is converted to numpy and indexed as
    ``[:, last_token_index, :]``, giving one 2-D array per layer.
    """
    def _take_last(hook_value):
        return np.asarray(hook_value)[:, last_token_index, :]

    hooks = _extract_hook_post_all_layers(model, params, x_full, num_mlp_layers)
    return list(map(_take_last, hooks))


def _find_hook_path(d, suffix, path=None):
    if path is None:
        path = []
    if isinstance(d, dict):
        for k, v in d.items():
            new_path = path + [k]
            if isinstance(k, str) and k.endswith(suffix):
                return new_path
            res = _find_hook_path(v, suffix, new_path)
            if res is not None:
                return res
    return None


def compute_injected_accuracy_transformer(
    *,
    model,
    params: dict,
    x_eval: jnp.ndarray,
    y_eval: jnp.ndarray,
    layer_idx: int,
    fitted_pre_last: np.ndarray,
    last_token_index: int = 1,
) -> float:
    """Accuracy after overwriting one MLP pre-activation hook at the last token.

    Steps:

    1. Run ``model`` once to read the intermediate whose key ends with
       ``hook_pre{layer_idx}``.
    2. Replace its ``[:, last_token_index, :]`` slice with ``fitted_pre_last``.
    3. Rebuild a nested ``"intermediates"`` dict holding only that hook (leaf
       stored as a one-element list) and apply the model again with it.
    4. Score argmax predictions at ``last_token_index`` against ``y_eval``.

    NOTE(review): this relies on the model's hook modules substituting a value
    supplied through the ``"intermediates"`` collection. That logic lives
    outside this file — confirm in ``HookPoint``.

    Args:
        layer_idx: 1-based MLP layer number used in the hook name.
        fitted_pre_last: Array-like of shape ``(batch, width_of_layer)``.

    Returns:
        Fraction of rows whose argmax logit equals ``y_eval``.

    Raises:
        KeyError: if the hook or its key path cannot be found.
        ValueError: if ``fitted_pre_last`` has the wrong shape.
    """
    suffix = f"hook_pre{layer_idx}"

    _, inter = model.apply({"params": params}, x_eval, training=False, mutable=["intermediates"])
    ints = inter["intermediates"]

    hook_arr = _find_by_suffix(ints, suffix)
    if hook_arr is None:
        raise KeyError(f"Could not find any intermediate ending with '{suffix}'.")

    hook_arr = np.asarray(hook_arr)
    B, _, width_l = hook_arr.shape
    # Accept any array-like (lists, jax arrays), not only numpy arrays.
    fitted = np.asarray(fitted_pre_last)
    if fitted.shape != (B, width_l):
        raise ValueError(f"fitted_pre_last shape {fitted.shape} != ({B}, {width_l}) for layer {layer_idx}")

    new_hook = hook_arr.copy()
    new_hook[:, last_token_index, :] = fitted

    hook_path = _find_hook_path(ints, suffix, [])
    if not hook_path:
        raise KeyError(f"Could not find path for intermediate ending with '{suffix}'.")

    # Wrap the leaf in a list exactly once, then nest it under each path key,
    # innermost first. Wrapping by position (rather than by comparing a key to
    # the leaf key) stays correct even when a key name repeats along the path.
    new_val = [jnp.asarray(new_hook)]
    for key in reversed(hook_path):
        new_val = {key: new_val}

    vars2 = freeze({"params": params, "intermediates": new_val})
    logits_inj = model.apply(vars2, x_eval, training=False)
    logits_last = np.asarray(logits_inj)[:, last_token_index, :]

    preds = logits_last.argmax(axis=-1)
    y_true = np.asarray(y_eval)
    return float((preds == y_true).mean())


def build_X_in_and_halves(model, params) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return the model's full input-embedding matrix and its two column halves.

    ``model.all_p_squared_embeddings(params)`` must produce a 2-D array with an
    even number of columns.

    Returns:
        ``(X_in, X_in[:, :D], X_in[:, D:])`` where ``D = X_in.shape[1] // 2``.

    Raises:
        ValueError: if the embedding matrix is not 2-D or its width is odd.
            This was previously an ``assert``, which is stripped under ``python -O``.
    """
    X_in = np.asarray(model.all_p_squared_embeddings(params))
    if X_in.ndim != 2:
        raise ValueError(f"build_X_in_and_halves: expected a 2-D embedding matrix, got shape {X_in.shape}")
    D2 = X_in.shape[1]
    if D2 % 2 != 0:
        raise ValueError(f"build_X_in_and_halves: embedding width {D2} is not even")
    D = D2 // 2
    return X_in, X_in[:, :D], X_in[:, D:]


def compute_first_layer_ab_contribs_transformer(
    *,
    model,
    params: dict,
    group_size: int,
    X_full,
    last_token_index: int = 1,
    bias_mode: str = "b",
    target_chunk_mb: int = 256,
    use_bfloat16: bool = True,
):
    """Split first-MLP-layer pre-activations into parts attributed to token a and token b.

    ``X_full`` is processed in row chunks. For each chunk the model is run once
    to collect the ``hook_attn``, ``hook_v`` and ``hook_pre1`` intermediates.
    The pre-activation at query position ``q = last_token_index`` is then
    rebuilt as:

    * from a: ``W_0 @ (W_O @ z_a)``, where ``z_a`` is the value at source
      position 0 scaled by the attention weight from ``q`` to position 0;
    * from b: ``W_0 @ (Wb[b] + pos1 + W_O @ z_b)``, where ``z_b`` is the
      analogous term for source position 1 and ``b = X_full[:, 1]``.

    ``b_0`` is added according to ``bias_mode``:

    * ``"b"`` adds it to the b part;
    * ``"even"`` adds half to each part;
    * ``"none"`` omits it.

    Args:
        model: Provides ``extract_embeddings_ab(params)`` and ``apply(...)``
            with an ``"intermediates"`` collection.
        params: Single-model parameter tree (``pos_embed``, ``blocks_0`` ...).
        group_size: Rows processed = ``group_size ** 2``.
        X_full: Integer token matrix. Column 1 is read as the b token.
        last_token_index: Query position ``q``.
        bias_mode: One of ``"b"``, ``"even"``, ``"none"``.
        target_chunk_mb: Rough memory budget used to pick the chunk size.
        use_bfloat16: Cast ``attn``/``v`` to bfloat16 before the per-head products.

    Returns:
        ``(pre_from_a, pre_from_b, pre_total, pre_hook)``, each a float32 numpy
        array of shape ``(group_size**2, d_mlp)``. ``pre_total`` is
        ``pre_from_a + pre_from_b``. ``pre_hook`` is the model's own
        ``hook_pre1`` at position ``q``, kept for comparison.

    Raises:
        KeyError: if the attention or ``hook_pre1`` intermediates are missing.
        ValueError: for an unknown ``bias_mode`` (raised on the first chunk).

    NOTE(review): the b part always reads column 1 of ``X_full`` and row 1 of
    ``W_pos``, whatever ``last_token_index`` is. It is only self-consistent for
    ``last_token_index == 1`` (the default).
    NOTE(review): no attention-output bias or normalisation is added to the
    reconstruction. If the model has either, ``pre_total`` will differ from
    ``pre_hook`` — confirm against the model definition.
    """
    # Number of rows processed. X_full is assumed to have exactly this many rows — TODO confirm.
    Bsize = int(group_size) * int(group_size)

    # Only Wb is used below; Wa is unpacked but unused.
    Wa, Wb = model.extract_embeddings_ab(params)
    # First two positional embeddings; only pos1 is used below.
    pos0, pos1 = np.asarray(params["pos_embed"]["W_pos"][:2])
    Wb = np.asarray(Wb)

    W_O = jnp.asarray(params["blocks_0"]["attn"]["W_O"])
    W0 = jnp.asarray(params["blocks_0"]["mlp"]["W_0"])
    b0 = jnp.asarray(params["blocks_0"]["mlp"]["b_0"])
    d_mlp = int(W0.shape[0])

    # Host-side float32 outputs, filled chunk by chunk.
    pre_from_a = np.empty((Bsize, d_mlp), dtype=np.float32)
    pre_from_b = np.empty((Bsize, d_mlp), dtype=np.float32)
    pre_total = np.empty((Bsize, d_mlp), dtype=np.float32)
    pre_hook = np.empty((Bsize, d_mlp), dtype=np.float32)

    # Heuristic chunk size: the byte budget divided by ~4 rows of width d_mlp
    # per example. The result is at least 256 rows and at most Bsize.
    bytes_per_elem = 2 if use_bfloat16 else 4
    denom = max(1, 4 * d_mlp * bytes_per_elem)
    chunk = max(256, (target_chunk_mb * 1024 * 1024) // denom)
    chunk = int(min(chunk, Bsize))

    q = int(last_token_index)
    _dtype = jnp.bfloat16 if use_bfloat16 else jnp.float32

    for i0 in range(0, Bsize, chunk):
        i1 = min(Bsize, i0 + chunk)

        x_chunk = jnp.asarray(X_full[i0:i1, :], dtype=jnp.int32)
        b_idx = np.asarray(X_full[i0:i1, 1])

        # Embedding of token b plus the position-1 embedding (float32).
        Eb_pos1 = np.asarray(Wb[b_idx] + pos1, dtype=np.float32)
        Eb_pos1 = jnp.asarray(Eb_pos1)

        _, inter = model.apply({"params": params}, x_chunk, training=False, mutable=["intermediates"])
        ints = inter["intermediates"]

        attn = _find_by_suffix(ints, "blocks_0/attn/hook_attn")
        v = _find_by_suffix(ints, "blocks_0/attn/hook_v")
        if attn is None:
            attn = _find_by_suffix(ints, "hook_attn")
        if v is None:
            v = _find_by_suffix(ints, "hook_v")
        if attn is None or v is None:
            raise KeyError("Could not find attention hooks (hook_attn, hook_v).")

        attn = jnp.asarray(attn, _dtype)
        v = jnp.asarray(v, _dtype)

        # Per-head value vectors of source positions 0 (a) and 1 (b), weighted by
        # the attention from query q. Assumes attn is (batch, heads, query, source)
        # and v is (batch, heads, pos, d_head) — confirm against the attention module.
        z_from_a = v[:, :, 0, :] * attn[:, :, q, 0][..., None]
        z_from_b = v[:, :, 1, :] * attn[:, :, q, 1][..., None]

        # Flatten heads x d_head per example and project through W_O.
        zfa = z_from_a.reshape(z_from_a.shape[0], -1)
        zfb = z_from_b.reshape(z_from_b.shape[0], -1)
        attn_from_a = jnp.einsum("df,bf->bd", W_O, zfa)
        attn_from_b = jnp.einsum("df,bf->bd", W_O, zfb)

        # Residual-stream input to the MLP, split by origin: a only reaches
        # position q through attention, while b also contributes its own
        # embedding plus pos1.
        xmid_from_a = attn_from_a
        xmid_from_b = Eb_pos1 + attn_from_b

        pre_a_chunk = jnp.einsum("md,bd->bm", W0, xmid_from_a)
        pre_b_chunk = jnp.einsum("md,bd->bm", W0, xmid_from_b)

        if bias_mode == "b":
            pre_b_chunk = pre_b_chunk + b0
        elif bias_mode == "even":
            half = 0.5 * b0
            pre_a_chunk = pre_a_chunk + half
            pre_b_chunk = pre_b_chunk + half
        elif bias_mode == "none":
            pass
        else:
            raise ValueError("bias_mode must be one of {'b','even','none'}")

        # The model's actual first-layer pre-activation at q, for comparison.
        pre1_hook = _find_by_suffix(ints, "blocks_0/mlp/hook_pre1")
        if pre1_hook is None:
            pre1_hook = _find_by_suffix(ints, "hook_pre1")
        if pre1_hook is None:
            raise KeyError("Could not find 'hook_pre1' in intermediates.")
        pre1_hook = jnp.asarray(pre1_hook)[:, q, :]

        pa = np.asarray(pre_a_chunk, dtype=np.float32)
        pb = np.asarray(pre_b_chunk, dtype=np.float32)
        ph = np.asarray(pre1_hook, dtype=np.float32)

        pre_from_a[i0:i1, :] = pa
        pre_from_b[i0:i1, :] = pb
        pre_total[i0:i1, :] = pa + pb
        pre_hook[i0:i1, :] = ph

        # Drop references to this chunk's arrays before the next iteration.
        del attn, v, z_from_a, z_from_b, zfa, zfb, attn_from_a, attn_from_b
        del xmid_from_a, xmid_from_b, pre_a_chunk, pre_b_chunk, pre1_hook

    return pre_from_a, pre_from_b, pre_total, pre_hook


def cluster_contribs_last_layer_transformer(
    *,
    preacts_last: np.ndarray,
    params: dict,
    clusters_last_layer: dict[int, list[int]],
):
    """Per-cluster contributions of the last MLP layer to d_model and to the logits.

    Activations are ``max(preacts_last, 0)``. Effective per-neuron weights are:

    * d_model: ``W_out.T``;
    * logits: ``W_out.T @ W_U``.

    For each non-empty cluster, out-of-range neuron ids are dropped. A cluster
    left empty after that is skipped.

    Returns:
        ``(contribs_to_dmodel, contribs_to_logits, Wblocks_to_logits,
        Wblocks_to_dmodel)``. Each is a dict keyed by cluster freq. The
        contributions are ``H[:, ids] @ W_eff[ids]``; the blocks are the
        corresponding ``W_eff[ids]`` rows.
    """
    w_out = np.asarray(params["blocks_0"]["mlp"]["W_out"])
    w_unembed = np.asarray(params["W_U"])

    to_dmodel = w_out.T
    to_logits = w_out.T @ w_unembed

    hidden = np.maximum(np.asarray(preacts_last), 0.0)
    width = hidden.shape[1]

    out_dmodel: dict[int, np.ndarray] = {}
    out_logits: dict[int, np.ndarray] = {}
    blocks_logits: dict[int, np.ndarray] = {}
    blocks_dmodel: dict[int, np.ndarray] = {}

    for freq, neuron_ids in (clusters_last_layer or {}).items():
        if not neuron_ids:
            continue
        idx = np.asarray(neuron_ids, dtype=int)
        idx = idx[(idx >= 0) & (idx < width)]
        if idx.size == 0:
            continue

        h_block = hidden[:, idx]
        wd_block = to_dmodel[idx, :]
        wl_block = to_logits[idx, :]

        out_dmodel[freq] = h_block @ wd_block
        out_logits[freq] = h_block @ wl_block
        blocks_logits[freq] = wl_block
        blocks_dmodel[freq] = wd_block

    return out_dmodel, out_logits, blocks_logits, blocks_dmodel


def run_post_training_analysis(
    *,
    model,
    states,
    random_seed_ints: List[int],
    p: int,
    group_size: int,
    num_layers: int,
    mdir: str,
    class_lower: str = "transformer",
    colour_rule=None,
    dmodel: int | None = None,
    alive_by_layer_override: dict[int, list[list[int]]] | None = None,
    write_json: bool = False,
    write_pdfs: bool = False,
    do_layerwise_sinefit_ablation: bool = False,
):
    """Per-seed post-training analysis of the stacked transformer models.

    Seed ``random_seed_ints[seed_idx]`` uses slice ``seed_idx`` of
    ``states.params``. For each seed this:

    1. Resolves the alive neuron ids per MLP layer. The source is
       ``alive_by_layer_override[seed]``, else the prune log in ``mdir``, else
       all neurons. The ids are validated against the layer widths.
    2. Extracts last-token MLP pre-activations on the full evaluation grid and
       the first-layer a/b pre-activation split.
    3. For each layer:
       * builds ``report.prepare_layer_artifacts``;
       * writes ``approx_summary_layer{k}_p{p}.json``;
       * writes the layer reports into ``report_layer{k}``;
       * optionally runs the sine-fit injection ablation, saving
         ``sinefit_injected_accuracy_p{p}.json``.
    4. Computes last-layer cluster contributions and optionally writes the
       cluster JSONs (``write_json``) and PCA PDF plots (``write_pdfs``).

    NOTE(review): the ``colour_rule`` argument is overwritten inside the layer
    loop (``colour_quad_a_only`` for layer 1, ``colour_c_mod_p`` afterwards), so
    a caller-supplied value is never used.
    NOTE(review): ``num_layers`` must be >= 1, because ``layers_freq[-1]`` is
    read after the loop.
    """
    G, irreps = DFT.make_irreps_Dn(p)
    freq_map = {name: freq for (name, dim, R, freq) in irreps}
    rho_cache = DFT.build_rho_cache(G, irreps)
    dft_fn = DFT.jit_wrap_group_dft(rho_cache, irreps, group_size)

    subgroups = dihedral.enumerate_subgroups_Dn(p)
    coset_masks_L = dihedral.build_coset_masks(G, subgroups, dihedral.mult, p, side="left")
    coset_masks_R = dihedral.build_coset_masks(G, subgroups, dihedral.mult, p, side="right")

    x_eval_full, y_eval_full = make_full_eval_grid(p)

    for seed_idx, seed in enumerate(random_seed_ints):
        print(f"\n=== Transformer post-training analysis for seed {seed} ===")
        gdir = paths.seed_graph_dir(mdir, seed)
        os.makedirs(gdir, exist_ok=True)

        # Slice this seed's parameters out of the stacked (vmapped) states.
        params_seed = jax.tree_util.tree_map(lambda x: x[seed_idx], states.params)

        if alive_by_layer_override is not None and seed in alive_by_layer_override:
            alive_by_layer = alive_by_layer_override[seed]
        else:
            alive_by_layer = _load_alive_indices_for_seed(
                prune_dir=mdir,
                features_or_dm=dmodel,
                seed=seed,
                num_layers=num_layers,
                params_seed=params_seed,
            )

        for li in range(num_layers):
            width_li = int(params_seed["blocks_0"]["mlp"][f"b_{li}"].shape[0])
            bad = [i for i in alive_by_layer[li] if i < 0 or i >= width_li]
            if bad:
                raise ValueError(f"[seed={seed}] alive_by_layer[{li}] includes bad index: {bad} (width={width_li})")

        layers_freq: List[Dict[int, list]] = []

        # Post-activations were previously extracted here as well, but never
        # used. That extra full-grid forward pass has been removed.
        pre_acts_all = extract_preacts_last_token(
            model=model,
            params=params_seed,
            x_full=x_eval_full,
            num_mlp_layers=num_layers,
            last_token_index=1,
        )

        pre_a, pre_b, _pre_sum, _pre_hook = compute_first_layer_ab_contribs_transformer(
            model=model,
            params=params_seed,
            group_size=group_size,
            X_full=x_eval_full,
            last_token_index=1,
            bias_mode="b",
        )

        thresh_small = 2.0 if group_size < 50 else 3.0
        cluster_tau = 1e-3

        layer0_artifacts = None
        f_pool_L1 = None
        sinefit_acc_log: dict[str, float] = {}

        for layer_idx in range(num_layers):
            prei_full = pre_acts_all[layer_idx]
            alive_ids = alive_by_layer[layer_idx]

            prei = prei_full[:, alive_ids]
            prei_grid = prei.reshape(group_size, group_size, -1)
            pre_a_alive = pre_a[:, alive_ids]
            pre_b_alive = pre_b[:, alive_ids]

            artifacts = report.prepare_layer_artifacts(
                prei_grid,
                pre_a_alive,
                pre_b_alive,
                dft_fn,
                irreps,
                freq_map,
                prune_cfg={"thresh1": thresh_small, "thresh2": thresh_small, "seed": 0},
            )

            # Map cluster members from positions among alive neurons back to
            # original neuron ids.
            local_clusters = artifacts.get("freq_cluster", {}) or {}
            clusters_layer = {freq: [alive_ids[j] for j in ids] for freq, ids in local_clusters.items()}
            layers_freq.append(clusters_layer)

            diag_labels = artifacts["diag_labels"]
            names = artifacts["names"]
            approx = report.summarize_diag_labels(diag_labels, p, names)

            cluster_sizes_main = {}
            pruned = artifacts.get("cluster_prune", {}) or {}
            for freq, orig_ids in clusters_layer.items():
                if freq in pruned and "main" in pruned[freq]:
                    cluster_sizes_main[str(freq)] = int(len(pruned[freq]["main"]))
                else:
                    cluster_sizes_main[str(freq)] = int(len(orig_ids))
            approx["cluster_sizes_main"] = cluster_sizes_main

            with open(os.path.join(gdir, f"approx_summary_layer{layer_idx+1}_p{p}.json"), "w") as f:
                json.dump(approx, f, indent=2)

            report_dir = os.path.join(gdir, f"report_layer{layer_idx+1}")
            os.makedirs(report_dir, exist_ok=True)

            if layer_idx == 0:
                colour_rule = colour_quad_a_only
                use_pair = layer_idx > 0
                layer0_artifacts = artifacts
                try:
                    f_pool_L1 = report.build_f_pool_from_layer1_artifacts(layer0_artifacts, freq_map, p)
                    f_pool_L1 = list(f_pool_L1)
                except Exception as e:
                    # Best effort: later reports receive base_f_pool=None instead.
                    print(f"[sinefit] build_f_pool_from_layer1_artifacts failed: {e}")
                    f_pool_L1 = None

                report.make_layer_report(
                    prei_grid,
                    pre_a_alive,
                    pre_b_alive,
                    p,
                    dft_fn,
                    irreps,
                    coset_masks_L,
                    coset_masks_R,
                    report_dir,
                    cluster_tau,
                    colour_rule,
                    artifacts,
                    layer_idx=layer_idx,
                    freq_map=freq_map,
                )

                report.make_R2_c_angle_report_no_plot(
                    prei_grid,
                    p,
                    report_dir,
                    artifacts=artifacts,
                    base_layer_artifacts=layer0_artifacts,
                    base_f_pool=f_pool_L1,
                    layer_idx=layer_idx,
                    freq_map=freq_map,
                    use_pair_terms=use_pair,
                )

                report.make_stripe_report_only_phase_pdf(prei_grid, p, report_dir, artifacts)
                report.epsilon_analysis(prei_grid, p, report_dir, artifacts)

                if do_layerwise_sinefit_ablation:
                    base_layer_artifacts = layer0_artifacts if layer0_artifacts is not None else artifacts

                    fitted_full = R2._sinefit_preact_layer_full(
                        p=group_size,
                        pre_full=prei_full,
                        pre_grid_alive=prei_grid,
                        alive_ids=alive_ids,
                        artifacts=artifacts,
                        base_layer_artifacts=base_layer_artifacts,
                        freq_map=freq_map,
                        base_f_pool=f_pool_L1,
                        use_axes_only=True,
                        use_pair_terms=use_pair,
                    )

                    acc_inj = compute_injected_accuracy_transformer(
                        model=model,
                        params=params_seed,
                        x_eval=x_eval_full,
                        y_eval=y_eval_full,
                        layer_idx=layer_idx + 1,
                        fitted_pre_last=fitted_full,
                        last_token_index=1,
                    )
                    print(f"[sinefit-ablation][seed={seed}] layer {layer_idx+1}: full-grid accuracy = {acc_inj:.6f}")
                    sinefit_acc_log[str(layer_idx + 1)] = float(acc_inj)

            else:
                colour_rule = colour_c_mod_p
                use_pair = layer_idx > 0

                report.make_layer_report(
                    prei_grid,
                    pre_a_alive,
                    pre_b_alive,
                    p,
                    dft_fn,
                    irreps,
                    coset_masks_L,
                    coset_masks_R,
                    report_dir,
                    cluster_tau,
                    colour_rule,
                    artifacts,
                    layer_idx=layer_idx,
                    freq_map=freq_map,
                )

                report.make_R2_c_angle_report_no_plot(
                    prei_grid,
                    p,
                    report_dir,
                    artifacts=artifacts,
                    base_layer_artifacts=layer0_artifacts,
                    base_f_pool=f_pool_L1,
                    layer_idx=layer_idx,
                    freq_map=freq_map,
                    use_pair_terms=use_pair,
                )

                if do_layerwise_sinefit_ablation:
                    fitted_full = R2._sinefit_preact_layer_full(
                        p=group_size,
                        pre_full=prei_full,
                        pre_grid_alive=prei_grid,
                        alive_ids=alive_ids,
                        artifacts=artifacts,
                        base_layer_artifacts=layer0_artifacts,
                        base_f_pool=f_pool_L1,
                        freq_map=freq_map,
                        use_axes_only=True,
                        use_pair_terms=use_pair,
                    )

                    acc_inj = compute_injected_accuracy_transformer(
                        model=model,
                        params=params_seed,
                        x_eval=x_eval_full,
                        y_eval=y_eval_full,
                        layer_idx=layer_idx + 1,
                        fitted_pre_last=fitted_full,
                        last_token_index=1,
                    )
                    print(f"[sinefit-ablation][seed={seed}] layer {layer_idx+1}: full-grid accuracy = {acc_inj:.6f}")
                    sinefit_acc_log[str(layer_idx + 1)] = float(acc_inj)

        if do_layerwise_sinefit_ablation:
            out_path = os.path.join(gdir, f"sinefit_injected_accuracy_p{p}.json")
            payload = {
                "seed": int(seed),
                "p": int(p),
                "group_size": int(group_size),
                "num_layers": int(num_layers),
                "accuracy_by_layer": sinefit_acc_log,
            }
            with open(out_path, "w") as f:
                json.dump(payload, f, indent=2)
            print(f"[sinefit-ablation][seed={seed}] saved layerwise injected accuracies -> {out_path}")

        last_layer_clusters = layers_freq[-1]
        contribs_dmodel, contribs_logits, Wblocks_logits, Wblocks_dmodel = cluster_contribs_last_layer_transformer(
            preacts_last=pre_acts_all[-1],
            params=params_seed,
            clusters_last_layer=last_layer_clusters,
        )

        if write_json:
            _ = make_some_jsons(
                preacts=pre_acts_all,
                group_size=group_size,
                clusters_by_layer=layers_freq,
                cluster_weights_to_logits=Wblocks_logits,
                cluster_weights_to_dmodel=Wblocks_dmodel,
                cluster_contribs_to_logits=contribs_logits,
                cluster_contribs_to_dmodel=contribs_dmodel,
                save_dir=gdir,
                subdir="json_preacts",
                sanity_check=True,
            )
            print(f"[Transformer] cluster JSONs written -> {os.path.join(gdir, 'json_preacts')}")

        if write_pdfs and contribs_logits:
            pdf_root = os.path.join(gdir, "pdf_plots", f"seed_{seed}")
            os.makedirs(pdf_root, exist_ok=True)
            num_pc = 4
            for freq, C_freq in contribs_logits.items():
                generate_pdf_plots_for_matrix(
                    C_freq,
                    p,
                    save_dir=pdf_root,
                    seed=seed,
                    freq_list=[freq],
                    tag=f"cluster_contributions_to_logits_freq={freq}",
                    tag_q="full",
                    colour_rule=colour_quad_mod_g,
                    class_string=class_lower,
                    num_principal_components=num_pc,
                )
            print(f"[Transformer] PDF plots written -> {pdf_root}")
