import os
import json
import math
from collections import Counter
from pathlib import Path
from typing import Any, Dict, List

import numpy as np
import jax
import jax.numpy as jnp

import DFT
import dihedral
import report
import analysis.R2 as R2
from mlp_models_multilayer import DonutMLP
import controllers.paths_MLP as paths
from controllers.prep_data_train_eval import eval_model, make_full_eval_grid
from pca_diffusion_plots_w_helpers import generate_pdf_plots_for_matrix
from color_rules import colour_quad_a_only, colour_c_mod_p


def final_eval_all_models(*, states, x_eval_batches, y_eval_batches, init_metrics, random_seed_ints: List[int]):
    """Evaluate every stacked model once and summarise its test metrics per seed.

    ``eval_model`` is run over all models at once. Its metric pytree is then
    sliced along the leading axis (index ``i`` for the ``i``-th entry of
    ``random_seed_ints``) and each slice's ``.compute()`` result is read.

    Args:
        states: Stacked train states forwarded to ``eval_model``.
        x_eval_batches: Evaluation inputs forwarded to ``eval_model``.
        y_eval_batches: Evaluation targets forwarded to ``eval_model``.
        init_metrics: Initial metrics object forwarded to ``eval_model``.
        random_seed_ints: Seeds, in the same order as the stacked models.

    Returns:
        ``{seed: {"reach_100pct_test", "loss", "l2_loss", "accuracy"}}``, where
        ``reach_100pct_test`` is True only when accuracy is exactly 1.0.
    """
    stacked_metrics = eval_model(states, x_eval_batches, y_eval_batches, init_metrics)

    def _summarise(model_index: int) -> dict:
        # Pull this model's slice out of every leaf, then finalise the metric values.
        computed = jax.tree_util.tree_map(lambda leaf: leaf[model_index], stacked_metrics).compute()
        accuracy = float(computed["accuracy"])
        return {
            "reach_100pct_test": accuracy == 1.0,
            "loss": float(computed["loss"]),
            "l2_loss": float(computed["l2_loss"]),
            "accuracy": accuracy,
        }

    return {seed: _summarise(idx) for idx, seed in enumerate(random_seed_ints)}


def _save_logs_per_seed(
    log_by_seed: Dict[int, Dict[int, Dict]],
    out_dir: str,
    features: int,
    *,
    file_prefix: str,
    label: str,
) -> None:
    """Dump each seed's log to ``<out_dir>/<file_prefix>features_<features>_seed_<seed>.json``.

    Args:
        log_by_seed: Mapping seed -> JSON-serialisable log. Integer keys inside
            the log become strings, as usual for ``json.dump``.
        out_dir: Output directory; created if missing.
        features: Hidden width, embedded in the file name.
        file_prefix: File-name prefix, e.g. ``"log_"`` or ``"prune_log_"``.
        label: Human-readable name used in the progress message.
    """
    os.makedirs(out_dir, exist_ok=True)
    for seed, logs in log_by_seed.items():
        path = os.path.join(out_dir, f"{file_prefix}features_{features}_seed_{seed}.json")
        with open(path, "w", encoding="utf-8") as f:
            json.dump(logs, f, indent=2)
        print(f"{label} for seed {seed} saved to {path}")


def save_epoch_logs(log_by_seed: Dict[int, Dict[int, Dict]], out_dir: str, features: int):
    """Write per-epoch logs to ``log_features_<features>_seed_<seed>.json`` files."""
    _save_logs_per_seed(log_by_seed, out_dir, features, file_prefix="log_", label="Epoch log")


def save_final_logs(log_by_seed: Dict[int, Dict[int, Dict]], out_dir: str, features: int):
    """Write final logs to ``final_log_features_<features>_seed_<seed>.json`` files."""
    _save_logs_per_seed(log_by_seed, out_dir, features, file_prefix="final_log_", label="Final log")


def save_prune_logs(log_by_seed: Dict[int, Dict[int, Dict]], out_dir: str, features: int):
    """Write pruning logs to ``prune_log_features_<features>_seed_<seed>.json`` files."""
    _save_logs_per_seed(log_by_seed, out_dir, features, file_prefix="prune_log_", label="Prune log")


def _load_alive_indices_for_seed(
    prune_dir: str,
    features: int,
    seed: int,
    *,
    num_layers: int,
    params_seed: dict,
) -> list[list[int]]:
    """Return the alive neuron ids of each hidden layer for one seed.

    Reads ``prune_log_features_<features>_seed_<seed>.json`` from ``prune_dir``
    when it exists. It takes the first truthy entry among ``alive_final``,
    ``stageB_alive`` and ``stageA_alive``; if none is truthy, the last one
    looked up is used (``or``-chain semantics). That mapping is keyed by the
    0-based layer index as a string, and absent layers yield ``[]``.

    If there is no log or no usable mapping, every neuron of ``dense_1`` ..
    ``dense_<num_layers>`` is treated as alive (width = bias length).
    """
    log_path = os.path.join(prune_dir, f"prune_log_features_{features}_seed_{seed}.json")

    alive_map = None
    if os.path.exists(log_path):
        with open(log_path, "r") as fh:
            prune_report = json.load(fh)
        # Mirrors `a or b or c`: stop at the first truthy stage, else keep the last lookup.
        for stage_key in ("alive_final", "stageB_alive", "stageA_alive"):
            alive_map = prune_report.get(stage_key)
            if alive_map:
                break

    if alive_map is not None:
        return [[int(nid) for nid in alive_map.get(str(layer), [])] for layer in range(num_layers)]

    return [
        list(range(int(params_seed[f"dense_{layer}"]["bias"].shape[0])))
        for layer in range(1, num_layers + 1)
    ]


def save_preacts_json_chunked(
    *,
    preacts_by_layer: list[np.ndarray],
    alive_by_layer: list[list[int]] | None,
    out_dir: str,
    seed: int | str,
    p: int,
    float_dtype=np.float32,
    chunk_neurons: int = 256,
    include_full_index: bool = True,
) -> str:
    root = os.path.join(out_dir, "preacts_json", f"seed_{seed}")
    os.makedirs(root, exist_ok=True)

    B_expected = p * p
    for li, Z_full in enumerate(preacts_by_layer, start=1):
        Z_full = np.asarray(Z_full)
        B, W = Z_full.shape
        if B != B_expected:
            raise ValueError(f"[save_preacts_json_chunked] layer{li}: B={B} != p^2={B_expected}")

        if alive_by_layer is None:
            neuron_ids = list(range(W))
        else:
            neuron_ids = [int(x) for x in alive_by_layer[li - 1] if 0 <= int(x) < W]
            if not neuron_ids:
                continue

        Z = Z_full[:, neuron_ids]
        N = Z.shape[1]

        layer_dir = os.path.join(root, f"layer{li}")
        os.makedirs(layer_dir, exist_ok=True)

        for c0 in range(0, N, chunk_neurons):
            c1 = min(N, c0 + chunk_neurons)
            ids_chunk = neuron_ids[c0:c1]
            Z_chunk = Z[:, c0:c1].astype(float_dtype)

            payload = {
                "meta": {
                    "seed": int(seed) if str(seed).isdigit() else str(seed),
                    "layer": int(li),
                    "p": int(p),
                    "B": int(B),
                    "width_full": int(W),
                    "n_saved": int(N),
                    "chunk_range": [int(c0), int(c1)],
                    "dtype": str(np.dtype(float_dtype)),
                },
                "neuron_ids": ids_chunk if include_full_index else list(range(c1 - c0)),
                "preacts": Z_chunk.tolist(),
            }
            out_path = os.path.join(layer_dir, f"chunk_{c0:06d}_{c1:06d}.json")
            with open(out_path, "w") as f:
                json.dump(payload, f)

    return root


def get_all_preacts_and_embeddings(
    *,
    model: DonutMLP,
    params: dict,
    group_size: int | None = None,
    clusters_by_layer: list[dict[int, list[int]]] | None = None,
):
    """Run the model on all p^2 embeddings and split last-layer logit contributions by cluster.

    Args:
        model: Model exposing ``all_p_squared_embeddings``, ``call_from_embedding``
            and ``num_layers``.
        params: Model parameters with ``dense_<l>`` and ``output_dense`` entries.
        group_size: Accepted for backward compatibility; not used by the computation.
        clusters_by_layer: Per-layer ``{freq: [neuron_ids]}``. Only the last
            entry is used; a ``None`` last entry is treated as "no clusters".

    Returns:
        ``(preacts_np, X_in, weights_np, cluster_contribs, cluster_weights)``:

        - ``preacts_np``: per-layer pre-activations as NumPy arrays.
        - ``X_in``: the embeddings returned by the model.
        - ``weights_np``: hidden-layer kernels ``dense_1..dense_<num_layers>``.
        - ``cluster_contribs[freq]``: ``relu(last preacts)[:, ids] @ W_out[ids, :]``.
        - ``cluster_weights[freq]``: ``W_out[ids, :]``.

    Raises:
        ValueError: if ``clusters_by_layer`` is None or empty.
    """
    if clusters_by_layer is None:
        raise ValueError("clusters_by_layer cannot be None")
    if not clusters_by_layer:
        raise ValueError("clusters_by_layer cannot be empty")

    X_in = model.all_p_squared_embeddings(params)

    _, preacts = model.call_from_embedding(jnp.asarray(X_in), params)
    preacts_np = [np.asarray(layer) for layer in preacts]
    H_last = np.maximum(preacts_np[-1], 0.0)

    weights_np = [np.asarray(params[f"dense_{li}"]["kernel"]) for li in range(1, model.num_layers + 1)]

    W_out = np.asarray(params["output_dense"]["kernel"])
    cluster_contribs: dict[int, np.ndarray] = {}
    cluster_weights: dict[int, np.ndarray] = {}
    # Same tolerance for a None entry as make_some_jsons uses.
    last_layer_clusters = clusters_by_layer[-1] or {}
    for freq, neuron_ids in last_layer_clusters.items():
        if not neuron_ids:
            continue
        W_block = W_out[neuron_ids, :]
        # Summed contribution of this cluster's ReLU outputs to every logit.
        cluster_contribs[freq] = H_last[:, neuron_ids] @ W_block
        cluster_weights[freq] = W_block

    return preacts_np, X_in, weights_np, cluster_contribs, cluster_weights


def make_some_jsons(
    *,
    preacts: list[np.ndarray],
    group_size: int,
    clusters_by_layer: list[dict[int, list[int]]],
    cluster_weights_to_logits: dict[int, np.ndarray],
    save_dir: str,
    subdir: str = "json",
    float_dtype=np.float32,
    sanity_check: bool = True,
    cluster_contribs_to_logits: dict[int, np.ndarray] | None = None,
) -> str:
    if not preacts:
        raise ValueError("make_some_jsons: empty preacts.")
    Z_last = np.asarray(preacts[-1])
    n_rows, width_last = Z_last.shape
    if n_rows != group_size * group_size:
        raise ValueError(
            f"make_some_jsons: expected group_size^2={group_size*group_size} rows, got {n_rows}."
        )
    if not clusters_by_layer:
        raise ValueError("make_some_jsons: empty clusters_by_layer.")

    last_layer_clusters = clusters_by_layer[-1] or {}
    if not isinstance(last_layer_clusters, dict):
        raise TypeError("make_some_jsons: clusters_by_layer[-1] must be a dict {freq: [neuron_ids]}.")

    json_root = os.path.join(save_dir, subdir)
    os.makedirs(json_root, exist_ok=True)

    for freq, neuron_ids in last_layer_clusters.items():
        if not neuron_ids:
            continue

        W_block = cluster_weights_to_logits.get(freq, None)
        if W_block is None:
            continue
        W_block = np.asarray(W_block)

        ids = np.asarray(neuron_ids, dtype=int)
        valid_mask = (ids >= 0) & (ids < width_last)
        if not np.all(valid_mask):
            bad = ids[~valid_mask].tolist()
            ids = ids[valid_mask]
            W_block = W_block[valid_mask, :]
            if ids.size == 0:
                continue
            print(f"[make_some_jsons] freq={freq}: dropped invalid neuron ids {bad}")

        if W_block.shape[0] != ids.shape[0]:
            raise ValueError(
                f"make_some_jsons: for freq={freq}, W_block rows ({W_block.shape[0]}) "
                f"!= number of neuron ids ({ids.shape[0]})."
            )
        if W_block.shape[1] != group_size:
            raise ValueError(
                f"make_some_jsons: for freq={freq}, W_block has {W_block.shape[1]} columns, expected group_size={group_size}."
            )

        Z_cluster = Z_last[:, ids]
        H_cluster = np.maximum(Z_cluster, 0.0)
        contribs = H_cluster[:, :, None] * W_block[None, :, :]

        if sanity_check and (cluster_contribs_to_logits is not None):
            C_freq_expected = np.asarray(cluster_contribs_to_logits.get(freq))
            if C_freq_expected is not None and C_freq_expected.size:
                C_sum = contribs.sum(axis=1)
                if C_freq_expected.shape != C_sum.shape:
                    raise ValueError(
                        f"make_some_jsons: cluster_contribs_to_logits[{freq}] has shape {C_freq_expected.shape}, "
                        f"expected {C_sum.shape}."
                    )
                if not np.allclose(C_sum, C_freq_expected, rtol=1e-5, atol=1e-6):
                    raise ValueError(
                        f"make_some_jsons: contribution mismatch for freq={freq} (sum of per-neuron != cluster total)."
                    )

        payload = {}
        for j, nid in enumerate(ids.tolist()):
            payload[str(int(nid))] = {
                "preactivations": Z_cluster[:, j].astype(float_dtype).tolist(),
                "w_out": W_block[j, :].astype(float_dtype).tolist(),
                "contribs_to_logits": contribs[:, j, :].astype(float_dtype).tolist(),
            }

        out_path = os.path.join(json_root, f"cluster_{freq}.json")
        with open(out_path, "w") as f:
            json.dump(payload, f)

    return json_root


def apply_in_batches(model, params_seed, x_all, batch=4096):
    """Run ``model.apply`` over ``x_all`` in row batches and stitch the outputs.

    ``model.apply({"params": params_seed}, x, training=False)`` is expected to
    return ``(_, preacts, left, right)``, where ``preacts`` is a per-layer
    sequence. Every output is concatenated along axis 0.

    Chunks are collected in lists and concatenated once at the end. Re-concatenating
    the accumulated arrays on every batch would copy O(n^2 / batch) rows.

    Args:
        model: Object with a flax-style ``apply`` method.
        params_seed: Parameters for a single model.
        x_all: Inputs indexed along axis 0.
        batch: Rows per forward call; must be positive.

    Returns:
        ``(preacts_per_layer, left, right)`` as NumPy arrays, or
        ``(None, None, None)`` when ``x_all`` has no rows.

    Raises:
        ValueError: if ``batch <= 0``.
    """
    if batch <= 0:
        raise ValueError(f"apply_in_batches: batch must be positive, got {batch}")

    preact_chunks = None  # one list of chunks per layer, created on the first batch
    left_chunks: list = []
    right_chunks: list = []
    for start in range(0, x_all.shape[0], batch):
        _, preacts, left, right = model.apply(
            {"params": params_seed}, x_all[start : start + batch], training=False
        )
        # np.asarray materialises device arrays on the host, which also synchronises.
        preacts_np = [np.asarray(z) for z in preacts]
        if preact_chunks is None:
            preact_chunks = [[] for _ in preacts_np]
        for bucket, z in zip(preact_chunks, preacts_np):
            bucket.append(z)
        left_chunks.append(np.asarray(left))
        right_chunks.append(np.asarray(right))

    if preact_chunks is None:
        return None, None, None

    return (
        [np.concatenate(chunks, axis=0) for chunks in preact_chunks],
        np.concatenate(left_chunks, axis=0),
        np.concatenate(right_chunks, axis=0),
    )


def _infer_num_layers_from_params(params: dict) -> int:
    """Return the largest ``i`` among ``dense_<i>`` parameter keys.

    For each key starting with ``"dense_"``, the text between the first and
    second underscore is parsed as an int. Keys where that fails (e.g.
    ``"dense_x"``) are ignored. Keys such as ``"output_dense"`` never match the
    prefix.

    Raises:
        ValueError: if no key yields a layer index.
    """
    found_ids = []
    for name in params.keys():
        if not name.startswith("dense_"):
            continue
        suffix = name.split("_")[1]
        try:
            found_ids.append(int(suffix))
        except ValueError:
            # int() on a str can only fail with ValueError; skip non-numeric suffixes.
            continue

    if not found_ids:
        raise ValueError("Could not infer num_layers: no 'dense_i' keys found in params.")
    return max(found_ids)


def compute_injected_accuracy_mlp(
    *,
    params: dict,
    y_eval: np.ndarray | jnp.ndarray,
    layer_idx: int,
    fitted_pre_layer: np.ndarray,
) -> float:
    """Inject substitute pre-activations at one hidden layer and return the resulting accuracy.

    ``fitted_pre_layer`` replaces the pre-activations of ``dense_<layer_idx>``.
    It goes through ReLU, then through the remaining ``dense_*`` layers
    (affine + ReLU), then through ``output_dense`` (bias optional). The argmax
    of the logits is compared with the labels. All math is done in float32 NumPy.

    Args:
        params: Parameter dict with ``dense_<i>`` (``kernel``/``bias``) and
            ``output_dense`` (``kernel``, optional ``bias``).
        y_eval: Integer class labels, one per row. Shapes ``(B,)`` and ``(B, 1)``
            are both accepted; they are flattened.
        layer_idx: 1-based index of the layer whose pre-activations are replaced.
        fitted_pre_layer: 2-D ``(B, width)`` replacement pre-activations.

    Returns:
        Fraction of rows whose predicted class equals the label.

    Raises:
        ValueError: on a non-2-D or empty ``fitted_pre_layer``, a label count
            that differs from ``B``, an out-of-range ``layer_idx``, or a width mismatch.
        KeyError: if a required parameter entry is missing.
    """
    pre = np.asarray(fitted_pre_layer, dtype=np.float32)
    if pre.ndim != 2:
        raise ValueError(f"fitted_pre_layer must be 2-D (batch, width); got shape {pre.shape}")
    n_rows, width_l = pre.shape
    if n_rows == 0:
        raise ValueError("fitted_pre_layer has no rows; accuracy is undefined.")

    # Flatten labels: comparing (B,) predictions against (B, 1) labels would
    # broadcast to a (B, B) matrix and silently give a wrong accuracy.
    y_true = np.asarray(y_eval).reshape(-1)
    if y_true.shape[0] != n_rows:
        raise ValueError(
            f"y_eval has {y_true.shape[0]} labels but fitted_pre_layer has {n_rows} rows"
        )

    num_layers = _infer_num_layers_from_params(params)
    if not (1 <= layer_idx <= num_layers):
        raise ValueError(f"layer_idx={layer_idx} out of range; inferred num_layers={num_layers}")

    key_l = f"dense_{layer_idx}"
    if key_l not in params:
        raise KeyError(f"Params missing key '{key_l}'. Available keys: {list(params.keys())}")
    width_param = int(np.asarray(params[key_l]["bias"]).shape[0])
    if width_param != width_l:
        raise ValueError(
            f"fitted_pre_layer width={width_l} != params['{key_l}']['bias'].shape[0]={width_param}"
        )

    act = np.maximum(pre, 0.0).astype(np.float32)

    # Propagate through the layers after the injected one.
    for l in range(layer_idx + 1, num_layers + 1):
        k = f"dense_{l}"
        if k not in params:
            raise KeyError(f"Params missing key '{k}' when propagating from layer {layer_idx}.")

        W = np.asarray(params[k]["kernel"], dtype=np.float32)
        b = np.asarray(params[k]["bias"], dtype=np.float32)
        act = np.maximum(act @ W + b[None, :], 0.0)

    if "output_dense" not in params:
        raise KeyError("Params missing 'output_dense' layer.")

    W_out = np.asarray(params["output_dense"]["kernel"], dtype=np.float32)
    b_out = params["output_dense"].get("bias", None)

    logits = act @ W_out
    if b_out is not None:
        logits = logits + np.asarray(b_out, dtype=np.float32)[None, :]

    preds = logits.argmax(axis=-1)
    return float((preds == y_true).mean())


def _write_json_indent2(path: str, obj) -> None:
    """Write ``obj`` to ``path`` as JSON with 2-space indentation."""
    with open(path, "w") as f:
        json.dump(obj, f, indent=2)


def _safe_layer1_f_pool(layer0_artifacts, freq_map, p):
    """Best-effort build of the layer-1 frequency pool; returns None (and logs) on failure."""
    try:
        return list(report.build_f_pool_from_layer1_artifacts(layer0_artifacts, freq_map, p))
    except Exception as e:
        print(f"[sinefit] build_f_pool_from_layer1_artifacts failed: {e}")
        return None


def _sinefit_injected_accuracy(
    *,
    params_seed: dict,
    y_eval_full,
    group_size: int,
    layer_idx: int,
    prei_full,
    prei_grid,
    alive_ids,
    artifacts,
    layer0_artifacts,
    freq_map: dict,
    f_pool_L1,
    use_pair: bool,
) -> float:
    """Fit layer ``layer_idx`` (0-based) with R2's sine fit and return the injected full-grid accuracy."""
    # Fall back to this layer's own artifacts if layer 1 produced none.
    base_layer_artifacts = layer0_artifacts if layer0_artifacts is not None else artifacts
    fitted_full = R2._sinefit_preact_layer_full(
        p=group_size,
        pre_full=prei_full,
        pre_grid_alive=prei_grid,
        alive_ids=alive_ids,
        artifacts=artifacts,
        base_layer_artifacts=base_layer_artifacts,
        freq_map=freq_map,
        base_f_pool=f_pool_L1,
        use_axes_only=True,
        use_pair_terms=use_pair,
    )
    return compute_injected_accuracy_mlp(
        params=params_seed,
        y_eval=y_eval_full,
        layer_idx=layer_idx + 1,
        fitted_pre_layer=fitted_full,
    )


def run_post_training_analysis(
    *,
    model,
    states,
    random_seed_ints: List[int],
    p: int,
    group_size: int,
    num_layers: int,
    mdir: str,
    mlp_class_lower: str,
    colour_rule=None,
    features: int | None = None,
    alive_by_layer_override: dict[int, list[list[int]]] | None = None,
    do_layerwise_sinefit_ablation: bool = False,
):
    """Run the per-seed, per-layer post-training analysis on alive neurons only.

    For each seed, the model slice ``states.params[seed_idx]`` is evaluated on
    the full grid. The following outputs are written under
    ``paths.seed_graph_dir(mdir, seed)``:

    - chunked pre-activation JSONs;
    - per-layer ``approx_summary_layer<k>_p<p>.json`` files and
      ``report_layer<k>/`` reports;
    - per-cluster logit-contribution JSONs and PDF plots;
    - optionally, ``sinefit_injected_accuracy_p<p>.json``.

    Args:
        model: Model used by ``apply_in_batches``/``get_all_preacts_and_embeddings``.
        states: Stacked states; ``states.params`` leaves are indexed by seed position.
        random_seed_ints: Seeds, in stacked order.
        p: Dihedral parameter passed to DFT/dihedral helpers and report functions.
        group_size: Grid side; alive pre-activations are reshaped to
            ``(group_size, group_size, -1)``.
        num_layers: Number of hidden layers to analyse.
        mdir: Model directory; holds the prune logs and the per-seed graph dirs.
        mlp_class_lower: Lower-cased model class name (affects PCA component count).
        colour_rule: NOTE(review): currently ignored. The per-layer colour rule is
            fixed (``colour_quad_a_only`` for layer 1, ``colour_c_mod_p`` otherwise),
            as in the previous implementation.
        features: Hidden width, used to locate the prune log.
        alive_by_layer_override: Optional ``{seed: alive ids per layer}``.
        do_layerwise_sinefit_ablation: If True, also measure accuracy after
            replacing each layer's pre-activations with their sine fit.
    """
    G, irreps = DFT.make_irreps_Dn(p)
    freq_map = {name: freq for name, _dim, _rho, freq in irreps}
    rho_cache = DFT.build_rho_cache(G, irreps)
    dft_fn = DFT.jit_wrap_group_dft(rho_cache, irreps, group_size)
    subgroups = dihedral.enumerate_subgroups_Dn(p)

    coset_masks_L = dihedral.build_coset_masks(G, subgroups, dihedral.mult, p, side="left")
    coset_masks_R = dihedral.build_coset_masks(G, subgroups, dihedral.mult, p, side="right")
    x_eval_full, y_eval_full = make_full_eval_grid(p)

    cluster_tau = 1e-3
    thresh_small = 1.7 if group_size < 50 else 1.8

    for seed_idx, seed in enumerate(random_seed_ints):
        print(f"\n=== Post-training analysis (alive-only, original IDs) for seed {seed} ===")
        gdir = paths.seed_graph_dir(mdir, f"{seed}")
        os.makedirs(gdir, exist_ok=True)

        params_seed = jax.tree_util.tree_map(lambda x: x[seed_idx], states.params)

        if alive_by_layer_override is not None and seed in alive_by_layer_override:
            alive_by_layer = alive_by_layer_override[seed]
        else:
            alive_by_layer = _load_alive_indices_for_seed(
                prune_dir=mdir,
                features=features,
                seed=seed,
                num_layers=num_layers,
                params_seed=params_seed,
            )

        pre_acts_all, left, right = apply_in_batches(model, params_seed, x_eval_full, batch=4096)
        pre_acts_all = [np.asarray(z) for z in pre_acts_all]

        # This writes *every* layer in one call. It used to sit inside the layer
        # loop and rewrite identical files once per non-empty layer; it is now
        # called once, and only if some layer has alive neurons (as before).
        if any(len(alive_by_layer[i]) > 0 for i in range(num_layers)):
            preacts_json_dir = save_preacts_json_chunked(
                preacts_by_layer=pre_acts_all,
                alive_by_layer=alive_by_layer,
                out_dir=gdir,
                seed=seed,
                p=group_size,
                float_dtype=np.float32,
                chunk_neurons=256,
            )
            print(f"[preacts] saved chunked JSONs -> {preacts_json_dir}")

        layers_freq: List[Dict[int, list]] = []
        layer0_artifacts = None
        f_pool_L1 = None
        sinefit_acc_log: dict[str, float] = {}

        for layer_idx in range(num_layers):
            layer_no = layer_idx + 1
            prei_full = pre_acts_all[layer_idx]
            alive_ids = alive_by_layer[layer_idx]
            summary_path = os.path.join(gdir, f"approx_summary_layer{layer_no}_p{p}.json")

            if len(alive_ids) == 0:
                layers_freq.append({})
                _write_json_indent2(summary_path, {})
                _write_json_indent2(
                    os.path.join(gdir, f"approx_summary_layer{layer_no}_p{p}_postacts.json"), {}
                )
                continue

            prei_grid = prei_full[:, alive_ids].reshape(group_size, group_size, -1)
            left_alive = left[:, alive_ids]
            right_alive = right[:, alive_ids]
            assert prei_grid.shape[-1] == left_alive.shape[1] == right_alive.shape[1], (
                prei_grid.shape,
                left_alive.shape,
                right_alive.shape,
            )

            artifacts = report.prepare_layer_artifacts(
                prei_grid,
                left_alive,
                right_alive,
                dft_fn,
                irreps,
                freq_map,
                prune_cfg={"thresh1": thresh_small, "thresh2": thresh_small, "seed": 0},
            )

            # Map cluster members from alive-local positions back to original neuron ids.
            local_clusters = artifacts.get("freq_cluster", {}) or {}
            clusters_layer = {freq: [alive_ids[j] for j in ids] for freq, ids in local_clusters.items()}
            layers_freq.append(clusters_layer)

            approx = report.summarize_diag_labels(artifacts["diag_labels"], p, artifacts["names"])

            pruned = artifacts.get("cluster_prune", {}) or {}
            cluster_sizes_main = {}
            for freq, orig_ids in clusters_layer.items():
                if freq in pruned and "main" in pruned[freq]:
                    cluster_sizes_main[str(freq)] = int(len(pruned[freq]["main"]))
                else:
                    cluster_sizes_main[str(freq)] = int(len(orig_ids))
            approx["cluster_sizes_main"] = cluster_sizes_main
            _write_json_indent2(summary_path, approx)

            report_dir = os.path.join(gdir, f"report_layer{layer_no}")
            os.makedirs(report_dir, exist_ok=True)

            is_first_layer = layer_idx == 0
            use_pair = layer_idx > 0
            if is_first_layer:
                layer_colour_rule = colour_quad_a_only
                layer0_artifacts = artifacts
                f_pool_L1 = _safe_layer1_f_pool(layer0_artifacts, freq_map, p)
            else:
                layer_colour_rule = colour_c_mod_p

            report.make_layer_report(
                prei_grid,
                left_alive,
                right_alive,
                p,
                dft_fn,
                irreps,
                coset_masks_L,
                coset_masks_R,
                report_dir,
                cluster_tau,
                layer_colour_rule,
                artifacts,
                layer_idx=layer_idx,
                freq_map=freq_map,
            )

            # Call order matches the original. Layer 1 runs the stripe PDF and the
            # epsilon analysis before the R2 report; deeper layers run the epsilon
            # analysis after it.
            if is_first_layer:
                report.make_stripe_report_only_phase_pdf(prei_grid, p, report_dir, artifacts)
                report.epsilon_analysis(prei_grid, p, report_dir, artifacts)

            report.make_R2_c_angle_report_no_plot(
                prei_grid,
                p,
                report_dir,
                artifacts=artifacts,
                base_layer_artifacts=layer0_artifacts,
                base_f_pool=f_pool_L1,
                layer_idx=layer_idx,
                freq_map=freq_map,
                use_pair_terms=use_pair,
            )

            if not is_first_layer:
                report.epsilon_analysis(prei_grid, p, report_dir, artifacts)

            if do_layerwise_sinefit_ablation:
                acc_inj = _sinefit_injected_accuracy(
                    params_seed=params_seed,
                    y_eval_full=y_eval_full,
                    group_size=group_size,
                    layer_idx=layer_idx,
                    prei_full=prei_full,
                    prei_grid=prei_grid,
                    alive_ids=alive_ids,
                    artifacts=artifacts,
                    layer0_artifacts=layer0_artifacts,
                    freq_map=freq_map,
                    f_pool_L1=f_pool_L1,
                    use_pair=use_pair,
                )
                print(f"[sinefit-ablation][seed={seed}] layer {layer_no}: full-grid accuracy = {acc_inj:.6f}")
                sinefit_acc_log[str(layer_no)] = float(acc_inj)

        if do_layerwise_sinefit_ablation:
            out_path = os.path.join(gdir, f"sinefit_injected_accuracy_p{p}.json")
            _write_json_indent2(
                out_path,
                {
                    "seed": int(seed),
                    "p": int(p),
                    "group_size": int(group_size),
                    "num_layers": int(num_layers),
                    "accuracy_by_layer": sinefit_acc_log,
                },
            )
            print(f"[sinefit-ablation][seed={seed}] saved layerwise injected accuracies -> {out_path}")

        preacts, _X_in, _weights_by_layer, cluster_contribs, cluster_W_blocks = get_all_preacts_and_embeddings(
            model=model,
            params=params_seed,
            group_size=group_size,
            clusters_by_layer=layers_freq,
        )

        make_some_jsons(
            preacts=preacts,
            group_size=group_size,
            clusters_by_layer=layers_freq,
            cluster_weights_to_logits=cluster_W_blocks,
            cluster_contribs_to_logits=cluster_contribs,
            save_dir=gdir,
            subdir="json_preacts",
            sanity_check=True,
        )

        num_pc = 4 if "cheating" not in mlp_class_lower else 2

        pdf_root = os.path.join(gdir, "pdf_plots_preacts", f"seed_{seed}")
        os.makedirs(pdf_root, exist_ok=True)
        for freq, C_freq in cluster_contribs.items():
            generate_pdf_plots_for_matrix(
                C_freq,
                p,
                save_dir=pdf_root,
                seed=seed,
                freq_list=[freq],
                tag=f"cluster_contributions_to_logits_freq={freq}",
                tag_q="full",
                colour_rule=colour_c_mod_p,
                class_string=mlp_class_lower,
                num_principal_components=num_pc,
            )

        print(f"PDF plots (preacts) written -> {pdf_root}")