"""Hidden state analysis of structural-observation GRU: PCA, t-SNE, linear probe."""
import os
import sys
# Put the parent of this file's directory at the front of sys.path so that the
# `src.*` imports further down resolve when this file is run directly.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, PROJECT_ROOT)

import torch
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from scipy.stats import spearmanr
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import silhouette_score

from src.rosenberg_data import load_rosenberg_everything, build_bc_targets, _ROS_TO_DIRL
from src.gru_policy import GRUPolicy, train_gru_policy
from src.analysis import collect_hidden_states, run_dimensionality_reduction

# Number of tree nodes; main() passes it to build_bc_targets and uses it as the
# range of node ids for which structural observations are built.
N_STATES = 127
# Action count handed to build_bc_targets and GRUPolicy.
N_ACTIONS = 3
# GRU hidden size passed to GRUPolicy.
HIDDEN_DIM = 128
# Training hyper-parameters forwarded to train_gru_policy.
N_EPOCHS = 200
BATCH_SIZE = 64
# Default chunk length (in actions) used by build_obs_dataset_structural.
MAX_SEQ_LEN = 200
# Seed whose run also gets the 2-D embedding figure and whose checkpoints are
# copied to the un-suffixed paths at the end of main().
BEST_SEED = 2
# Every seed in this list is trained and analysed by main().
SEEDS = [0, 1, 2, 3, 4]
# Node used as the target of tree_distance() in main().
# NOTE(review): presumably the water-port node, translated from Rosenberg id
# 100 into this file's node indexing -- confirm against src.rosenberg_data.
WATER_PORT_DIRL = _ROS_TO_DIRL[100]


def build_structural_obs(node, n_states=127):
    """Build the 12-dimensional structural observation for ``node``.

    Nodes are indexed heap-style: the children of ``i`` are ``2*i + 1`` and
    ``2*i + 2`` and its parent is ``(i - 1) // 2``. A node counts as a leaf
    when its index is at least ``(n_states - 1) // 2``.

    Layout of the returned vector:
      * entries 0-2: ``[degree / 3, is_root, is_leaf]`` for ``node`` itself;
      * then one triple per destination (left child, right child, parent),
        each ordered ``[is_leaf, is_root, degree / 3]``. Note that this order
        is the reverse of the node's own triple.

    A child index that falls outside the tree is replaced by ``node`` itself,
    and the parent of a node with index <= 0 is taken to be node 0.

    Args:
        node: Index of the node to describe.
        n_states: Total number of nodes in the tree.

    Returns:
        A 1-D ``torch.float32`` tensor with 12 entries.
    """
    leaf_start = (n_states - 1) // 2

    def describe(idx):
        """Return ``(degree / 3, is_root, is_leaf)`` for node ``idx``."""
        root_flag = 1.0 if idx == 0 else 0.0
        leaf_flag = 1.0 if idx >= leaf_start else 0.0
        # Root is checked first, so a node that is both root and leaf gets 2.
        if idx == 0:
            deg = 2.0
        elif idx >= leaf_start:
            deg = 1.0
        else:
            deg = 3.0
        return deg / 3.0, root_flag, leaf_flag

    # Destinations in the fixed order: left child, right child, parent.
    neighbours = [
        child if child < n_states else node
        for child in (2 * node + 1, 2 * node + 2)
    ]
    neighbours.append((node - 1) // 2 if node > 0 else 0)

    own_deg, own_root, own_leaf = describe(node)
    values = [own_deg, own_root, own_leaf]
    for dest in neighbours:
        dest_deg, dest_root, dest_leaf = describe(dest)
        values += [dest_leaf, dest_root, dest_deg]
    return torch.tensor(values, dtype=torch.float32)


def get_node_depth(node):
    """Return how many parent hops separate ``node`` from the root (node 0).

    Uses heap-style indexing, where the parent of ``i`` is ``(i - 1) // 2``.
    Any index that is not positive yields 0.
    """
    hops = 0
    # Rebinding the parameter is safe: it only changes the local name.
    while node > 0:
        node = (node - 1) // 2
        hops += 1
    return hops


def tree_distance(a, b):
    """Return the number of edges on the tree path between nodes ``a`` and ``b``.

    Nodes use heap-style indexing (parent of ``i`` is ``(i - 1) // 2``). The
    deeper node is first lifted to the depth of the shallower one, then both
    are lifted in lock-step until they meet at their common ancestor.
    """
    depth_a = get_node_depth(a)
    depth_b = get_node_depth(b)
    hops = 0

    # Level the two nodes: at most one of these loops runs.
    while depth_a > depth_b:
        a = (a - 1) // 2
        depth_a -= 1
        hops += 1
    while depth_b > depth_a:
        b = (b - 1) // 2
        depth_b -= 1
        hops += 1

    # Climb together; each joint step adds one edge on either side.
    while a != b:
        a = (a - 1) // 2
        b = (b - 1) // 2
        hops += 2
    return hops


def get_obs_class(node, n_states=127):
    """Map a node index to a coarse ``(class_id, label)`` pair.

    The classes are 0 "Root", 1 "Depth 1", 2 "Depth 2-4", 3 "Depth 5" and
    4 "Leaves". The upper bounds 2 and 30 are hard-coded; only the leaf
    boundary ``(n_states - 1) // 2`` follows ``n_states``, so the labels line
    up with actual depths for the default 127-node tree.

    Args:
        node: Node index (heap-style).
        n_states: Total number of nodes, used for the leaf boundary only.

    Returns:
        Tuple ``(class_id, label)``.
    """
    leaf_start = (n_states - 1) // 2
    if node == 0:
        return 0, "Root"
    # Fixed index bands, checked in increasing order.
    for upper, cls_id, label in ((2, 1, "Depth 1"), (30, 2, "Depth 2-4")):
        if node <= upper:
            return cls_id, label
    return (3, "Depth 5") if node < leaf_start else (4, "Leaves")


def build_obs_dataset_structural(trajs, structural_obs, bc_policy, max_len=MAX_SEQ_LEN):
    """Cut trajectories into chunks and attach structural observations.

    Each trajectory is split into consecutive windows of at most ``max_len``
    actions. The state slice of a window extends one index past its action
    slice (``states[start:end + 1]``), so it also holds the state following
    the last action whenever the trajectory provides one.

    Args:
        trajs: Iterable of mappings with ``"states"`` and ``"actions"``
            sequences.
        structural_obs: Lookup from state to its observation tensor.
        bc_policy: Lookup from state to its target tensor.
            NOTE(review): presumably per-state action probabilities from
            build_bc_targets -- confirm.
        max_len: Maximum number of actions per chunk.

    Returns:
        List of dicts with keys ``"obs"`` and ``"targets"`` (stacked tensors,
        one row per state in the chunk), ``"actions"`` (long tensor) and
        ``"states"`` (list of ints). Trajectories without actions contribute
        no chunks.
    """

    def make_chunk(states, actions, lo, hi):
        """Assemble the sample covering actions ``lo:hi``."""
        visited = states[lo:hi + 1]
        taken = actions[lo:hi]
        return {
            "obs": torch.stack([structural_obs[s] for s in visited]),
            "targets": torch.stack([bc_policy[s] for s in visited]),
            "actions": torch.tensor([int(a) for a in taken], dtype=torch.long),
            "states": [int(s) for s in visited],
        }

    samples = []
    for traj in trajs:
        states, actions = traj["states"], traj["actions"]
        n_steps = len(actions)
        for lo in range(0, n_steps, max_len):
            samples.append(make_chunk(states, actions, lo, min(lo + max_len, n_steps)))
    return samples


def main():
    """Train one structural-observation GRU per seed and analyse its hidden states.

    For every seed in ``SEEDS`` this:
      1. trains a ``GRUPolicy`` on the training chunks and saves its state
         dict under ``checkpoints/``;
      2. collects hidden states on the validation chunks;
      3. runs PCA and reports Spearman correlations of PC1 with tree depth,
         distance to ``WATER_PORT_DIRL`` and timestep;
      4. for ``BEST_SEED`` only, draws a four-panel scatter of the 2-D
         embedding returned by ``run_dimensionality_reduction`` and saves it
         under ``figures/``;
      5. fits a logistic-regression probe that decodes the node id from the
         hidden state, using an 80/20 split over trajectory ids;
      6. reports silhouette scores per observation class and overall;
      7. saves all results with ``torch.save``.

    Finally the ``BEST_SEED`` artefacts are copied to un-suffixed paths.
    Output directories are relative to the current working directory.
    """
    # Function-scope imports, hoisted out of the per-seed loop.
    import shutil
    from sklearn.decomposition import PCA as PCAModel

    print("Structural-Observation GRU Hidden State Analysis \n", flush=True)
    print(f"Seeds: {SEEDS}\n", flush=True)

    data = load_rosenberg_everything()

    bc_policy = build_bc_targets(data["train_sa"], n_states=N_STATES,
                                 n_actions=N_ACTIONS, laplace=1.0)
    # Pass N_STATES explicitly so the observations always describe the same
    # tree as the BC targets, instead of relying on the helper's own default.
    structural_obs = {s: build_structural_obs(s, N_STATES) for s in range(N_STATES)}
    obs_dim = len(structural_obs[0])

    train_data = build_obs_dataset_structural(data["train_trajs"], structural_obs, bc_policy)
    val_data = build_obs_dataset_structural(data["val_trajs"], structural_obs, bc_policy)

    os.makedirs("checkpoints", exist_ok=True)
    os.makedirs("figures", exist_ok=True)

    # Loop-invariant lookup tables for the five observation classes.
    obs_class_labels = {0: "Root", 1: "Depth 1", 2: "Depth 2-4", 3: "Depth 5", 4: "Leaves"}
    class_colors = ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00"]

    for seed in SEEDS:
        print(f"Seed {seed} ", flush=True)

        print(f"Training GRU (seed={seed}, {N_EPOCHS} epochs)...", flush=True)
        torch.manual_seed(seed)
        np.random.seed(seed)
        policy = GRUPolicy(obs_dim=obs_dim, hidden_dim=HIDDEN_DIM, n_actions=N_ACTIONS)
        policy, _ = train_gru_policy(policy, train_data, n_epochs=N_EPOCHS,
                                     lr=3e-4, batch_size=BATCH_SIZE, print_every=50)

        model_path = f"checkpoints/structural_gru_model_seed{seed}.pt"
        torch.save(policy.state_dict(), model_path)
        print(f"Saved model: {model_path}", flush=True)

        print("\nCollecting hidden states from validation set...", flush=True)
        # NOTE(review): the returned values are indexed with boolean masks
        # below, so they are assumed to be NumPy arrays -- confirm in
        # src.analysis.
        hidden_states, positions, timesteps, traj_ids = collect_hidden_states(policy, val_data)
        N = len(positions)
        print(f"{N} hidden state vectors collected", flush=True)

        # Per-sample covariates derived from the node occupied at each step.
        depths = np.array([get_node_depth(p) for p in positions])
        water_dists = np.array([tree_distance(p, WATER_PORT_DIRL) for p in positions])
        obs_classes = np.array([get_obs_class(p, N_STATES)[0] for p in positions])

        # Only the primary seed pays for the 2-D embedding; the other seeds
        # just get a PCA capped at 50 components.
        is_primary = (seed == BEST_SEED)
        if is_primary:
            hidden_2d, pca, hidden_pca = run_dimensionality_reduction(hidden_states)
        else:
            pca = PCAModel(n_components=min(50, hidden_states.shape[1], hidden_states.shape[0]))
            hidden_pca = pca.fit_transform(hidden_states)
            hidden_2d = None

        print("Explained variance (top 10 PCs):", flush=True)
        cumvar = 0.0
        for i, v in enumerate(pca.explained_variance_ratio_[:10]):
            cumvar += v
            print(f"PC{i+1}: {v:.4f}  (cumulative: {cumvar:.4f})", flush=True)

        pc1 = hidden_pca[:, 0]
        rho_depth, p_depth = spearmanr(pc1, depths)
        rho_water, p_water = spearmanr(pc1, water_dists)
        rho_time, p_time = spearmanr(pc1, timesteps)
        print("\nPC1 Spearman correlations:", flush=True)
        print(f"vs tree depth: rho={rho_depth:+.4f} (p={p_depth:.2e})", flush=True)
        print(f"vs water distance: rho={rho_water:+.4f} (p={p_water:.2e})", flush=True)
        print(f"vs timestep: rho={rho_time:+.4f} (p={p_time:.2e})", flush=True)

        if is_primary:
            print("\nt-SNE Visualization", flush=True)
            fig, axes = plt.subplots(1, 4, figsize=(22, 5))

            # (a) coloured by depth in the tree.
            ax = axes[0]
            sc = ax.scatter(hidden_2d[:, 0], hidden_2d[:, 1], c=depths, cmap="viridis",
                            s=6, alpha=0.7, edgecolors="none")
            ax.set_title("(a) Tree Depth", fontsize=12, fontweight="bold")
            plt.colorbar(sc, ax=ax, shrink=0.82, label="Depth")
            ax.set_xticks([])
            ax.set_yticks([])

            # (b) one scatter per observation class so each gets a legend entry.
            ax = axes[1]
            for cls_id in range(5):
                mask = obs_classes == cls_id
                ax.scatter(hidden_2d[mask, 0], hidden_2d[mask, 1],
                           c=class_colors[cls_id], s=6, alpha=0.6,
                           label=obs_class_labels[cls_id], edgecolors="none")
            ax.set_title("(b) Structural Observation Class", fontsize=12, fontweight="bold")
            ax.legend(fontsize=7, markerscale=3, loc="best")
            ax.set_xticks([])
            ax.set_yticks([])

            # (c) coloured by timestep.
            ax = axes[2]
            sc = ax.scatter(hidden_2d[:, 0], hidden_2d[:, 1], c=timesteps, cmap="plasma",
                            s=6, alpha=0.7, edgecolors="none")
            ax.set_title("(c) Timestep within Bout", fontsize=12, fontweight="bold")
            plt.colorbar(sc, ax=ax, shrink=0.82, label="Timestep")
            ax.set_xticks([])
            ax.set_yticks([])

            # (d) coloured by tree distance to WATER_PORT_DIRL.
            ax = axes[3]
            sc = ax.scatter(hidden_2d[:, 0], hidden_2d[:, 1], c=water_dists, cmap="RdYlGn_r",
                            s=6, alpha=0.7, edgecolors="none")
            ax.set_title("(d) Distance to Water Port", fontsize=12, fontweight="bold")
            plt.colorbar(sc, ax=ax, shrink=0.82, label="Tree Distance")
            ax.set_xticks([])
            ax.set_yticks([])

            plt.tight_layout()
            plt.savefig("figures/structural_hidden_main.png", dpi=200, bbox_inches="tight")
            plt.close(fig)
            print("Saved figures/structural_hidden_main.png", flush=True)

        # ---- Linear probe: decode node id from the hidden state ----
        unique_nodes = np.unique(positions)
        n_visited = len(unique_nodes)
        chance = 1.0 / n_visited

        # Split by trajectory id (first 80% of the sorted ids train the probe)
        # so that samples sharing an id never straddle the train/test divide.
        unique_traj_ids = np.unique(traj_ids)
        n_probe_train = int(0.8 * len(unique_traj_ids))
        probe_train_tids = set(unique_traj_ids[:n_probe_train])
        train_mask = np.array([tid in probe_train_tids for tid in traj_ids])
        test_mask = ~train_mask

        clf = LogisticRegression(max_iter=1000, solver="lbfgs", C=1.0)
        clf.fit(hidden_states[train_mask], positions[train_mask])
        probe_acc = clf.score(hidden_states[test_mask], positions[test_mask])

        print(f"Visited nodes: {n_visited}", flush=True)
        print(f"Chance level:  {chance:.4f} ({chance*100:.1f}%)", flush=True)
        print(f"Probe accuracy: {probe_acc:.4f} ({probe_acc*100:.1f}%)", flush=True)
        print(f"Ratio over chance: {probe_acc/chance:.1f}x", flush=True)

        print("\nPer observation class:", flush=True)
        test_hidden = hidden_states[test_mask]
        test_positions = positions[test_mask]
        test_classes = obs_classes[test_mask]
        for cls_id in range(5):
            cls_mask = test_classes == cls_id
            if cls_mask.sum() == 0:
                continue
            cls_positions = test_positions[cls_mask]
            n_cls_nodes = len(np.unique(cls_positions))
            cls_chance = 1.0 / max(n_cls_nodes, 1)
            cls_acc = clf.score(test_hidden[cls_mask], cls_positions)
            print(f"{obs_class_labels[cls_id]:>10s}: acc={cls_acc:.3f} "
                  f"(chance={cls_chance:.3f}, {n_cls_nodes} nodes)", flush=True)

        # ---- Silhouette scores with the true node id as cluster label ----
        print("Within-class silhouette scores (true node ID as label):", flush=True)

        for cls_id in range(5):
            cls_mask = obs_classes == cls_id
            cls_hidden = hidden_states[cls_mask]
            cls_nodes = positions[cls_mask]
            n_cls_labels = len(np.unique(cls_nodes))

            # The silhouette needs at least two distinct labels.
            if n_cls_labels < 2:
                print(f"{obs_class_labels[cls_id]:>10s}: skipped (only {n_cls_labels} node)", flush=True)
                continue

            # Cap very large classes with a fixed-seed subsample.
            if len(cls_hidden) > 10000:
                idx = np.random.RandomState(42).choice(len(cls_hidden), 10000, replace=False)
                cls_hidden_sub = cls_hidden[idx]
                cls_nodes_sub = cls_nodes[idx]
            else:
                cls_hidden_sub = cls_hidden
                cls_nodes_sub = cls_nodes

            # random_state is pinned so the internal subsample does not depend
            # on the global NumPy RNG state, matching the overall score below.
            sil = silhouette_score(cls_hidden_sub, cls_nodes_sub,
                                   sample_size=min(5000, len(cls_hidden_sub)),
                                   random_state=42)
            print(f"{obs_class_labels[cls_id]:>10s}: silhouette={sil:+.4f} "
                  f"({n_cls_labels} aliased nodes, {cls_mask.sum()} samples)", flush=True)

        sil_overall = np.nan
        if len(unique_nodes) >= 2:
            sil_sub = min(10000, N)
            sil_overall = silhouette_score(hidden_states, positions, sample_size=sil_sub,
                                           random_state=42)
            print(f"\nOverall silhouette (node ID labels): {sil_overall:+.4f}", flush=True)

        # ---- Per-seed summary ----
        print(f"PCA: PC1 explains {pca.explained_variance_ratio_[0]:.1%} of variance", flush=True)
        print(f"PC1 vs depth: rho={rho_depth:+.3f}", flush=True)
        print(f"PC1 vs water dist: rho={rho_water:+.3f}", flush=True)
        print(f"Linear probe node decoding: {probe_acc:.1%} (chance={chance:.1%}, "
              f"{probe_acc/chance:.1f}x)", flush=True)

        results = {
            "seed": seed,
            "pca_explained_variance": pca.explained_variance_ratio_.tolist(),
            "pc1_rho_depth": rho_depth, "pc1_rho_water": rho_water, "pc1_rho_time": rho_time,
            "probe_accuracy": probe_acc, "chance_level": chance,
            "n_visited_nodes": n_visited,
            "silhouette_overall": sil_overall,
            "hidden_states": hidden_states, "positions": positions,
            "depths": depths, "water_dists": water_dists,
            "obs_classes": obs_classes, "timesteps": timesteps,
            "hidden_2d": hidden_2d,
        }
        ckpt_path = f"checkpoints/structural_hidden_analysis_seed{seed}.pt"
        torch.save(results, ckpt_path)
        print(f"Saved to {ckpt_path}", flush=True)

    # Keep the un-suffixed checkpoint names pointing at the BEST_SEED run.
    print(f"\nCopying seed={BEST_SEED} results to original paths for backward compat...", flush=True)
    shutil.copy2(f"checkpoints/structural_hidden_analysis_seed{BEST_SEED}.pt",
                 "checkpoints/structural_hidden_analysis.pt")
    shutil.copy2(f"checkpoints/structural_gru_model_seed{BEST_SEED}.pt",
                 "checkpoints/structural_gru_model.pt")
    print("to checkpoints/structural_hidden_analysis.pt", flush=True)
    print("to checkpoints/structural_gru_model.pt", flush=True)
    print("\nAll seeds complete ", flush=True)


# Run the full multi-seed analysis only when executed as a script, not on import.
if __name__ == "__main__":
    main()
