"""Temporal representation formation: when does the GRU's spatial map crystallize within a trial?"""
import os
import sys
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, PROJECT_ROOT)

import numpy as np
import torch
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from scipy.stats import spearmanr
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression

from src.rosenberg_data import load_rosenberg_everything, build_bc_targets

# Number of environment states. build_transition_matrices indexes them as a
# heap-ordered binary tree (children 2s+1 / 2s+2); 127 = 2**7 - 1 nodes.
N_STATES = 127
# Action indices as laid out in build_transition_matrices:
# 0 = left child, 1 = right child, 2 = parent.
N_ACTIONS = 3
# Trajectories are cut into chunks of at most this many steps, and only
# within-chunk timesteps below this value are analysed in main().
MAX_SEQ_LEN = 200
BIN_WIDTH = 20     # absolute timestep bin width
N_BINS = MAX_SEQ_LEN // BIN_WIDTH  # 10 bins
MIN_CHUNK_LEN = 100  # only include chunks >= this length
# Seeds whose checkpoints main() loads and aggregates over.
SEEDS = [0, 1, 2, 3, 4]


def build_obs_class_map(n_states=127):
    """Group state indices into five observation classes.

    The classes are contiguous index ranges:
    state 0 -> class 0, states 1-2 -> class 1, states 3-30 -> class 2,
    states 31-62 -> class 3, and every state from 63 upwards -> class 4.

    Returns:
        obs_class: integer array of length ``n_states`` holding the class
            of each state.
        obs_masks: float array of shape ``(5, n_states)``; row ``c`` is 1.0
            at the states belonging to class ``c`` and 0.0 elsewhere.
    """
    state_ids = np.arange(n_states)

    # Inclusive upper bound of classes 0..3; anything above the last bound
    # lands in class 4. side="left" keeps a bound inside its own class.
    class_upper_bounds = np.array([0, 2, 30, 62])
    obs_class = np.searchsorted(
        class_upper_bounds, state_ids, side="left").astype(int)

    # One-hot membership, laid out as (class, state).
    obs_masks = np.zeros((5, n_states))
    obs_masks[obs_class, state_ids] = 1.0
    return obs_class, obs_masks


def build_transition_matrices(n_states=127):
    """Build deterministic per-action transition matrices for the tree.

    States are indexed like a binary heap: the children of ``s`` are
    ``2s + 1`` and ``2s + 2``, and its parent is ``(s - 1) // 2``.
    Action 0 moves to the left child, action 1 to the right child and
    action 2 to the parent. A move that would leave the tree (a child
    index >= ``n_states``, or the parent of state 0) keeps the agent
    where it is.

    Returns:
        Array of shape ``(N_ACTIONS, n_states, n_states)`` where entry
        ``[a, s, s2]`` is 1.0 if action ``a`` takes ``s`` to ``s2``.
    """
    nodes = np.arange(n_states)

    left_child = 2 * nodes + 1
    right_child = 2 * nodes + 2
    # Out-of-range children become self-loops.
    left_target = np.where(left_child < n_states, left_child, nodes)
    right_target = np.where(right_child < n_states, right_child, nodes)
    # State 0 has no parent and maps to itself.
    parent_target = np.where(nodes > 0, (nodes - 1) // 2, 0)

    T = np.zeros((N_ACTIONS, n_states, n_states))
    T[0, nodes, left_target] = 1.0
    T[1, nodes, right_target] = 1.0
    T[2, nodes, parent_target] = 1.0
    return T


def run_filter_per_timestep(val_trajs, obs_class, obs_masks, bc_policy_np,
                            T_action, max_seq_len=MAX_SEQ_LEN,
                            min_chunk_len=0):
    """Run an observation-only Bayesian filter and record belief entropy.

    Every trajectory is cut into consecutive chunks of at most
    ``max_seq_len`` actions. Within a chunk the belief starts uniform over
    the states that share the first state's observation class, and is then
    propagated one step at a time through the policy-averaged transition
    matrix and re-masked by the observation class of the next state. The
    entropy (in bits) of the belief *before* each update is recorded
    together with the step index inside the chunk.

    Args:
        val_trajs: iterable of dicts with ``"states"`` and ``"actions"``
            sequences. ``"states"`` must hold at least one more entry than
            ``"actions"`` because the state after every action is read.
        obs_class: array mapping state index -> observation class.
        obs_masks: array of shape ``(n_classes, n_states)``; row ``c`` is
            the indicator of the states in class ``c``.
        bc_policy_np: array indexed as ``[state, action]`` giving the
            policy probabilities used to average over actions.
        T_action: array of shape ``(n_actions, n_states, n_states)`` with
            per-action transition matrices.
        max_seq_len: maximum number of actions per chunk.
        min_chunk_len: chunks with fewer actions than this are skipped.

    Returns:
        ``(entropies, abs_timesteps)``: two 1-D arrays of equal length with
        the belief entropy and the within-chunk step index of every
        recorded step.
    """
    bc_policy_np = np.asarray(bc_policy_np)
    T_action = np.asarray(T_action)

    # Take the sizes from the transition tensor itself so the filter works
    # for whatever tree size the caller passes in, instead of relying on
    # module-level constants.
    n_actions, n_states = T_action.shape[0], T_action.shape[1]

    # Policy-averaged dynamics: M[s, s2] = sum_a pi(a | s) * T_a[s, s2].
    # Scaling row s of T_a by pi(a | s) is the same as diag(pi[:, a]) @ T_a
    # without materialising a dense diagonal matrix.
    M = np.zeros((n_states, n_states))
    for a in range(n_actions):
        M += bc_policy_np[:, a, None] * T_action[a]

    entropies = []
    abs_timesteps = []

    for traj in val_trajs:
        states = traj["states"]
        actions = traj["actions"]
        n_steps = len(actions)

        for start in range(0, n_steps, max_seq_len):
            end = min(start + max_seq_len, n_steps)
            # One extra state so the observation after the last action of
            # the chunk is available.
            cs = states[start:end + 1]
            ca = actions[start:end]
            if len(ca) < min_chunk_len:
                continue

            # Initial belief: uniform over states consistent with the first
            # observation.
            belief = obs_masks[obs_class[cs[0]]].copy()
            belief /= belief.sum()

            for t in range(len(ca)):
                # Entropy of the current belief; near-zero entries are
                # dropped to avoid log2(0).
                pos = belief[belief > 1e-30]
                entropies.append(-np.sum(pos * np.log2(pos)))
                abs_timesteps.append(t)

                # Predict, then condition on the next observation class.
                c_next = obs_class[cs[t + 1]]
                new_belief = (belief @ M) * obs_masks[c_next]
                total = new_belief.sum()
                if total > 1e-30:
                    new_belief /= total
                else:
                    # The observation is impossible under the model:
                    # fall back to uniform over the observed class.
                    new_belief = obs_masks[c_next].copy()
                    new_belief /= new_belief.sum()
                belief = new_belief

    return np.array(entropies), np.array(abs_timesteps)


def compute_bin_metrics(hidden_states, positions, depths, timesteps,
                        n_bins, bin_width):
    """Evaluate position decodability and PC1/depth alignment per time bin.

    Samples are assigned to bin ``timesteps // bin_width``. For each of the
    first ``n_bins`` bins that holds at least 50 samples:

    * a logistic-regression probe is trained on a fixed (seed 42) random
      80% of the bin's samples to predict ``positions`` from
      ``hidden_states`` and scored on the remaining 20%;
    * PCA is fitted on the bin's hidden states and the absolute Spearman
      correlation between the first component and ``depths`` is taken.

    Bins with fewer than 50 samples get NaN for both metrics.

    Returns:
        ``(probe_accs, pc1_corrs, bin_ns)``: three lists of length
        ``n_bins`` with the probe accuracy, the |Spearman rho| value and
        the sample count of every bin.
    """
    bin_of_sample = timesteps // bin_width

    probe_accs, pc1_corrs, bin_ns = [], [], []
    for bin_id in range(n_bins):
        in_bin = bin_of_sample == bin_id
        feats = hidden_states[in_bin]
        labels = positions[in_bin]
        bin_depths = depths[in_bin]
        count = in_bin.sum()
        bin_ns.append(count)

        if count >= 50:
            # Reproducible 80/20 split: the same seed is reused in every bin.
            split_at = int(0.8 * count)
            order = np.arange(count)
            np.random.RandomState(42).shuffle(order)
            fit_rows = order[:split_at]
            eval_rows = order[split_at:]

            probe = LogisticRegression(max_iter=1000, solver="lbfgs", C=1.0)
            probe.fit(feats[fit_rows], labels[fit_rows])
            accuracy = probe.score(feats[eval_rows], labels[eval_rows])

            # Only the leading component is used below; the component count
            # is capped by both the feature and the sample dimension.
            n_components = min(10, feats.shape[1], feats.shape[0])
            projected = PCA(n_components=n_components).fit_transform(feats)
            rho, _ = spearmanr(projected[:, 0], bin_depths)
            pc1_strength = abs(rho)
        else:
            # Too few samples for a meaningful estimate.
            accuracy = np.nan
            pc1_strength = np.nan

        probe_accs.append(accuracy)
        pc1_corrs.append(pc1_strength)

    return probe_accs, pc1_corrs, bin_ns


def _long_chunk_mask(timesteps, min_chunk_len):
    """Return ``(mask, n_long)`` marking samples that lie in long chunks.

    A new chunk starts at every index where the timestep does not increase
    relative to the previous sample. ``mask`` is a boolean array aligned
    with ``timesteps`` that is True for samples in chunks of at least
    ``min_chunk_len`` samples; ``n_long`` is the number of such chunks.
    """
    n_total = len(timesteps)
    resets = np.flatnonzero(timesteps[1:] <= timesteps[:-1]) + 1
    boundaries = np.concatenate(([0], resets, [n_total]))
    chunk_lens = np.diff(boundaries)
    is_long = chunk_lens >= min_chunk_len
    # Expand the per-chunk flag back to one flag per sample.
    return np.repeat(is_long, chunk_lens), int(is_long.sum())


def main():
    """Aggregate temporal-formation metrics over seeds, plot and save them.

    For every seed in ``SEEDS`` the stored hidden states are loaded,
    restricted to chunks of at least ``MIN_CHUNK_LEN`` samples and to
    timesteps below ``MAX_SEQ_LEN``, and passed to ``compute_bin_metrics``.
    The per-bin mean and standard deviation across seeds are plotted
    together with the per-bin mean entropy of the observation-only Bayesian
    filter, and all numbers are written to
    ``checkpoints/temporal_formation.pt``.

    All input and output paths are relative to the current working
    directory.
    """
    print(f"Aggregating over {len(SEEDS)} seeds: {SEEDS}\n", flush=True)

    all_probe_accs = []  # list of lists, one per seed
    all_pc1_corrs = []

    for seed in SEEDS:
        ckpt_path = f"checkpoints/structural_hidden_analysis_seed{seed}.pt"
        print(f"Seed {seed}: Loading {ckpt_path} ", flush=True)
        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # checkpoints that were produced by this project.
        ckpt = torch.load(ckpt_path, weights_only=False)
        hidden_states = np.array(ckpt["hidden_states"])
        positions = np.array(ckpt["positions"])
        timesteps = np.array(ckpt["timesteps"])
        depths = np.array(ckpt["depths"])
        N = len(positions)
        print(f"{N} vectors, hidden_dim={hidden_states.shape[1]}", flush=True)

        # Keep only samples from sufficiently long chunks and within the
        # analysed timestep range.
        long_mask, n_long = _long_chunk_mask(timesteps, MIN_CHUNK_LEN)
        use_mask = long_mask & (timesteps < MAX_SEQ_LEN)

        h_long = hidden_states[use_mask]
        p_long = positions[use_mask]
        d_long = depths[use_mask]
        ts_long = timesteps[use_mask]
        print(f"Long chunks: {n_long}, samples: {use_mask.sum()}", flush=True)

        probe_accs, pc1_corrs, _ = compute_bin_metrics(
            h_long, p_long, d_long, ts_long, N_BINS, BIN_WIDTH)
        all_probe_accs.append(probe_accs)
        all_pc1_corrs.append(pc1_corrs)
        print(f"Last bin probe: {probe_accs[-1]*100:.1f}%, "
              f"|PC1-depth|: {pc1_corrs[-1]:.3f}", flush=True)

    # NaN-aware statistics: bins with too few samples are NaN for a seed.
    all_probe_accs = np.array(all_probe_accs)  # (n_seeds, n_bins)
    all_pc1_corrs = np.array(all_pc1_corrs)
    mean_probe = np.nanmean(all_probe_accs, axis=0)
    std_probe = np.nanstd(all_probe_accs, axis=0)
    mean_pc1 = np.nanmean(all_pc1_corrs, axis=0)
    std_pc1 = np.nanstd(all_pc1_corrs, axis=0)

    print("\nRunning Bayesian filter (long chunks only)...", flush=True)
    data = load_rosenberg_everything()
    bc_policy = build_bc_targets(data["train_sa"], n_states=N_STATES,
                                 n_actions=N_ACTIONS, laplace=1.0)
    bc_policy_np = bc_policy.numpy()
    obs_class, obs_masks = build_obs_class_map()
    T_action = build_transition_matrices()

    bayes_ent, bayes_ts = run_filter_per_timestep(
        data["val_trajs"], obs_class, obs_masks, bc_policy_np, T_action,
        min_chunk_len=MIN_CHUNK_LEN)
    print(f"{len(bayes_ent)} Bayesian filter steps", flush=True)

    bayes_valid = bayes_ts < MAX_SEQ_LEN
    bayes_bins = bayes_ts[bayes_valid] // BIN_WIDTH
    bayes_ent_valid = bayes_ent[bayes_valid]

    # Plain Python floats so the saved results contain no NumPy scalars,
    # like the other entries that go through .tolist().
    bayes_ent_per_bin = []
    for b in range(N_BINS):
        in_bin = bayes_ent_valid[bayes_bins == b]
        bayes_ent_per_bin.append(
            float(in_bin.mean()) if in_bin.size > 0 else float("nan"))

    x_steps = [(b + 0.5) * BIN_WIDTH for b in range(N_BINS)]
    chance = 1.0 / N_STATES
    os.makedirs("figures", exist_ok=True)

    # Figure 1: probe accuracy over time.
    fig, ax = plt.subplots(figsize=(5, 4.5))
    ax.plot(x_steps, mean_probe * 100, "o-", color="#2c7bb6",
            linewidth=2, markersize=6, label=f"GRU probe (n={len(SEEDS)})")
    ax.fill_between(x_steps,
                    (mean_probe - std_probe) * 100,
                    (mean_probe + std_probe) * 100,
                    alpha=0.2, color="#2c7bb6")
    ax.set_xlabel("Timestep within trial")
    ax.set_ylabel("Node decoding accuracy (%)")
    ax.set_xlim(0, MAX_SEQ_LEN)
    ax.set_ylim(bottom=0)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    fig.savefig("figures/temporal_probe.png", dpi=200, bbox_inches="tight")
    plt.close(fig)
    print("Saved figures/temporal_probe.png", flush=True)

    # Figure 2: |Spearman rho| between PC1 and depth over time.
    fig, ax = plt.subplots(figsize=(5, 4.5))
    ax.plot(x_steps, mean_pc1, "s-", color="#d7191c",
            linewidth=2, markersize=6)
    ax.fill_between(x_steps,
                    mean_pc1 - std_pc1,
                    mean_pc1 + std_pc1,
                    alpha=0.2, color="#d7191c")
    ax.set_xlabel("Timestep within trial")
    ax.set_ylabel("|Spearman $\\rho$| (PC1 vs depth)")
    ax.set_xlim(0, MAX_SEQ_LEN)
    # nanmin/nanmax: a NaN bin must not wipe out the data-driven limits.
    lo = max(0, np.nanmin(mean_pc1 - std_pc1) - 0.05)
    hi = min(1, np.nanmax(mean_pc1 + std_pc1) + 0.05)
    ax.set_ylim(lo, hi)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    fig.savefig("figures/temporal_pc1_depth.png", dpi=200, bbox_inches="tight")
    plt.close(fig)
    print("Saved figures/temporal_pc1_depth.png", flush=True)

    # Figure 3: Bayesian-filter belief entropy over time.
    fig, ax = plt.subplots(figsize=(5, 4.5))
    ax.plot(x_steps, bayes_ent_per_bin, "D-", color="#fdae61",
            linewidth=2, markersize=5, label="Bayes filter entropy")
    ax.set_xlabel("Timestep within trial")
    ax.set_ylabel("Belief entropy (bits)")
    ax.set_xlim(0, MAX_SEQ_LEN)
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    fig.savefig("figures/temporal_bayes.png", dpi=200, bbox_inches="tight")
    plt.close(fig)
    print("Saved figures/temporal_bayes.png", flush=True)

    results = {
        "mean_probe_accs": mean_probe.tolist(),
        "std_probe_accs": std_probe.tolist(),
        "mean_pc1_depth_corrs": mean_pc1.tolist(),
        "std_pc1_depth_corrs": std_pc1.tolist(),
        "all_probe_accs": all_probe_accs.tolist(),
        "all_pc1_corrs": all_pc1_corrs.tolist(),
        "bayes_ent_per_bin": bayes_ent_per_bin,
        "bin_centers_steps": x_steps,
        "seeds": SEEDS,
        "min_chunk_len": MIN_CHUNK_LEN,
        "chance_level": chance,
        "bin_width": BIN_WIDTH,
    }
    os.makedirs("checkpoints", exist_ok=True)
    torch.save(results, "checkpoints/temporal_formation.pt")
    print("\nSaved to checkpoints/temporal_formation.pt", flush=True)

    print(f"probe {mean_probe[0]*100:.1f}% to {mean_probe[-1]*100:.1f}%, "
          f"|PC1-depth| {mean_pc1[0]:.3f} to {mean_pc1[-1]:.3f}, "
          f"Bayes entropy {bayes_ent_per_bin[0]:.2f} to {bayes_ent_per_bin[-1]:.2f} bits")


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
