"""RSA: compare GRU hidden-state geometry against tree-structural RDMs via Mantel test."""

import os
import sys
import numpy as np
import torch
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.spatial.distance import squareform, pdist
from scipy.stats import bootstrap

import rsatoolbox
from rsatoolbox.rdm import RDMs
from skbio.stats.distance import mantel, DistanceMatrix

# Repository root: the parent of the directory that contains this script.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Prepend <root>/src to sys.path (presumably so project modules under src/ can
# be imported; no such import is visible in this file).
sys.path.insert(0, os.path.join(PROJECT_ROOT, "src"))
# Per-seed input checkpoints are read from here; main() also writes the
# aggregated results file here.
CHECKPOINT_DIR = os.path.join(PROJECT_ROOT, "checkpoints")
# Output directory for figures; created at import time if it does not exist.
FIGURE_DIR = os.path.join(PROJECT_ROOT, "figures")
os.makedirs(FIGURE_DIR, exist_ok=True)

# Number of tree nodes. 127 = 2**7 - 1, i.e. a full binary tree with 7 levels
# under the heap-style indexing used by build_binary_tree().
N_STATES = 127
# NOTE(review): not referenced anywhere in the visible code; presumably the
# index of the rewarded node -- confirm before relying on it.
WATER_PORT = 100
# Seeds whose checkpoints are analysed and aggregated in main().
SEEDS = [0, 1, 2, 3, 4]
# Seed whose neural RDM and per-node mean hidden states are plotted and saved.
PRIMARY_SEED = 2


def build_neural_rdm(hidden_states, positions):
    """Average the hidden states recorded at each node and build an RDM.

    Args:
        hidden_states: 2-D array with one hidden-state vector per row.
        positions: Node index for each row of ``hidden_states``; entry ``k``
            is the node at which row ``k`` was recorded.

    Returns:
        A tuple ``(neural_rdm, mean_hidden)`` where ``neural_rdm`` is the
        rsatoolbox RDM computed with ``method="euclidean"`` over all
        ``N_STATES`` nodes and ``mean_hidden`` is the float32
        ``(N_STATES, hidden_dim)`` array of per-node mean hidden states.
    """
    n_features = hidden_states.shape[1]
    mean_hidden = np.zeros((N_STATES, n_features), dtype=np.float32)
    visit_counts = np.zeros(N_STATES, dtype=np.int32)

    # First accumulate per-node sums and visit counts ...
    for step, node in enumerate(positions):
        mean_hidden[node] += hidden_states[step]
        visit_counts[node] += 1

    # ... then turn the sums into means. Nodes that were never visited keep
    # their all-zero row and are still part of the RDM below.
    seen = visit_counts > 0
    mean_hidden[seen] /= visit_counts[seen][:, None]

    n_visited = seen.sum()
    n_unvisited = N_STATES - n_visited
    print(f"Nodes visited: {n_visited}/{N_STATES} (unvisited: {n_unvisited})")

    node_ids = np.arange(N_STATES)
    dataset = rsatoolbox.data.Dataset(
        mean_hidden,
        obs_descriptors={"node_id": node_ids},
    )
    return rsatoolbox.rdm.calc_rdm(dataset, method="euclidean"), mean_hidden


def build_binary_tree():
    """Build the binary tree over nodes ``0 .. N_STATES - 1``.

    Nodes use heap-style indexing: the children of node ``k`` are ``2k + 1``
    (left) and ``2k + 2`` (right), so node 0 is the root.

    Returns:
        An undirected ``networkx.Graph`` with one edge per parent/child pair.
    """
    tree = nx.Graph()
    tree.add_nodes_from(range(N_STATES))
    # Every non-root node has exactly one parent, (child - 1) // 2. Walking the
    # children in ascending order inserts the edges in the same order as
    # visiting each parent and adding its left child, then its right child.
    tree.add_edges_from(
        ((child - 1) // 2, child) for child in range(1, N_STATES)
    )
    return tree


def get_root_path(node):
    """Return the root-to-node route as a string of "L"/"R" turns.

    Uses heap-style indexing (children of ``k`` are ``2k + 1`` and ``2k + 2``).
    The root (node 0) maps to the empty string.
    """
    turns = []
    current = node
    while current > 0:
        # For a child c, (c - 1) splits into (parent, 0) for a left child and
        # (parent, 1) for a right child.
        parent, is_right = divmod(current - 1, 2)
        turns.insert(0, "R" if is_right else "L")
        current = parent
    return "".join(turns)


def build_model_rdms(G):
    """Build the four candidate structural dissimilarity matrices.

    Args:
        G: Tree graph over nodes ``0 .. N_STATES - 1`` with node 0 as root.

    Returns:
        Dict mapping model name to an ``(N_STATES, N_STATES)`` float matrix:
        "Graph distance" (shortest-path length), "Depth difference"
        (absolute difference in distance from the root), "Subtree identity"
        (0 if both nodes share the same first two turns from the root, else 1)
        and "Path Hamming" (position-wise mismatches between root paths).
    """
    nodes = range(N_STATES)

    # Shortest-path length between every pair of nodes.
    hop_counts = dict(nx.all_pairs_shortest_path_length(G))
    graph_dm = np.array(
        [[hop_counts[src][dst] for dst in nodes] for src in nodes],
        dtype=float,
    )

    # Absolute difference in depth (distance from node 0).
    depth_of = nx.single_source_shortest_path_length(G, 0)
    depth_vec = np.array([depth_of[node] for node in nodes])
    depth_dm = np.abs(depth_vec[:, None] - depth_vec[None, :]).astype(float)

    root_paths = {node: get_root_path(node) for node in nodes}

    # Subtree label = first two turns of the root path. Slicing leaves shorter
    # paths untouched, so the root ("") and the two depth-1 nodes ("L", "R")
    # each form their own group.
    branch_labels = np.array([root_paths[node][:2] for node in nodes])
    subtree_dm = (branch_labels[:, None] != branch_labels[None, :]).astype(float)

    # Hamming distance between root paths, where positions beyond the end of
    # the shorter path always count as mismatches: that equals the mismatches
    # over the common length plus the difference in length.
    hamming_dm = np.zeros((N_STATES, N_STATES), dtype=float)
    for a in nodes:
        for b in range(a + 1, N_STATES):
            path_a, path_b = root_paths[a], root_paths[b]
            mismatches = sum(x != y for x, y in zip(path_a, path_b))
            mismatches += abs(len(path_a) - len(path_b))
            hamming_dm[a, b] = hamming_dm[b, a] = mismatches

    return {
        "Graph distance": graph_dm,
        "Depth difference": depth_dm,
        "Subtree identity": subtree_dm,
        "Path Hamming": hamming_dm,
    }


def run_mantel_tests(neural_rdm, model_rdms, n_permutations=9999):
    """Run a Spearman Mantel test of the neural RDM against each model RDM.

    Args:
        neural_rdm: Object whose ``dissimilarities`` attribute flattens to the
            condensed (upper-triangle) vector of one RDM -- assumes it holds a
            single RDM, as produced by ``build_neural_rdm``.
        model_rdms: Dict mapping model name to a square dissimilarity matrix
            with the same node ordering as the neural RDM.
        n_permutations: Number of permutations used for the p-value.

    Returns:
        Dict mapping model name to ``(rho, p_value)``. Each result is also
        printed.
    """
    neural_sq = squareform(neural_rdm.dissimilarities.flatten())
    # The neural matrix is the same for every comparison, so wrap it once
    # instead of rebuilding it on each loop iteration.
    neural_skbio = DistanceMatrix(neural_sq)

    results = {}
    for name, model_dm in model_rdms.items():
        # The third return value (number of objects compared) is not needed.
        coeff, p_value, _ = mantel(
            neural_skbio, DistanceMatrix(model_dm),
            method="spearman",
            permutations=n_permutations,
        )
        results[name] = (coeff, p_value)
        print(f"{name:25s}: rho = {coeff:.4f}, p = {p_value:.2e}")

    return results


def bootstrap_mantel_ci(neural_sq, model_dm, n_bootstrap=1000, seed=42):
    """Percentile-bootstrap a 95% interval for the Spearman RDM correlation.

    Both square matrices are reduced to their condensed pair vectors; each
    bootstrap draw resamples those pair entries with replacement and
    recomputes Spearman's rho.

    Args:
        neural_sq: Square neural dissimilarity matrix.
        model_dm: Square model dissimilarity matrix (same node ordering).
        n_bootstrap: Number of bootstrap draws.
        seed: Seed for the NumPy random generator.

    Returns:
        ``(ci_low, ci_high)``: the 2.5th and 97.5th percentiles of the
        bootstrapped correlations.
    """
    from scipy.stats import spearmanr

    neural_pairs = squareform(neural_sq, checks=False)
    model_pairs = squareform(model_dm, checks=False)
    pair_count = len(neural_pairs)

    rng = np.random.default_rng(seed)
    boot_rhos = np.zeros(n_bootstrap)
    for draw in range(n_bootstrap):
        picked = rng.choice(pair_count, size=pair_count, replace=True)
        rho, _pval = spearmanr(neural_pairs[picked], model_pairs[picked])
        boot_rhos[draw] = rho

    lower, upper = (np.percentile(boot_rhos, q) for q in (2.5, 97.5))
    return lower, upper


def dfs_order(G, root=0):
    """Return the nodes of ``G`` in depth-first preorder, starting at ``root``.

    Used to order matrix rows/columns so that nodes close in the tree end up
    next to each other.
    """
    visit_sequence = nx.dfs_preorder_nodes(G, source=root)
    return list(visit_sequence)


def plot_rdm_heatmaps(neural_sq, model_rdms, G, save_path):
    """Save a 1x4 row of heatmaps: the neural RDM next to three model RDMs.

    Rows and columns are reordered by depth-first traversal of ``G`` and each
    matrix is scaled by its maximum (when that maximum is positive) so all
    panels share a 0-1 colour range.

    Args:
        neural_sq: Square neural dissimilarity matrix.
        model_rdms: Dict with at least the keys "Graph distance",
            "Depth difference" and "Subtree identity".
        G: Tree graph used to derive the node ordering.
        save_path: Destination file for the figure.
    """
    node_order = np.array(dfs_order(G))
    reorder = np.ix_(node_order, node_order)

    fig, axes = plt.subplots(1, 4, figsize=(22, 5))

    # The labels identify the panels for the reader of this code only; no
    # titles are drawn on the axes.
    panels = (
        ("Neural (GRU)", neural_sq),
        ("Graph Distance", model_rdms["Graph distance"]),
        ("Depth Difference", model_rdms["Depth difference"]),
        ("Subtree Identity", model_rdms["Subtree identity"]),
    )

    for ax, (_label, matrix) in zip(axes, panels):
        shown = matrix[reorder]
        peak = shown.max()
        if peak > 0:
            shown = shown / peak
        image = ax.imshow(shown, cmap="viridis", aspect="equal", interpolation="none")
        ax.set_xlabel("Node (DFS order)", fontsize=9)
        ax.set_ylabel("Node (DFS order)", fontsize=9)
        fig.colorbar(image, ax=ax, shrink=0.82)

    fig.tight_layout()
    fig.savefig(save_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved: {save_path}")


def plot_mantel_bar_chart(mean_rhos, std_rhos, p_vals, save_path):
    """Save a bar chart of mean Mantel correlations with error bars.

    Each bar is annotated with its value (inside the bar) and a significance
    marker beyond the error bar ("***" p < 0.001, "**" p < 0.01, "*" p < 0.05,
    otherwise "n.s.").

    Args:
        mean_rhos: Dict mapping model name to mean correlation (bar height).
        std_rhos: Dict mapping model name to the error-bar half-length.
        p_vals: Dict mapping model name to the p-value used for the marker.
        save_path: Destination file for the figure.
    """
    names = list(mean_rhos)
    means = [mean_rhos[n] for n in names]
    stds = [std_rhos[n] for n in names]
    ps = [p_vals[n] for n in names]

    colors = sns.color_palette("deep", len(names))

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.bar(
        range(len(names)), means,
        yerr=stds,
        capsize=5, color=colors, edgecolor="black", linewidth=0.8,
    )

    for i, (r, p, s) in enumerate(zip(means, ps, stds)):
        star = "***" if p < 0.001 else "**" if p < 0.01 else "*" if p < 0.05 else "n.s."
        # Put the marker just past the error bar, on the side the bar points to.
        upward = r >= 0
        star_y = r + s + 0.02 if upward else r - s - 0.02
        ax.text(i, star_y, star, ha="center",
                va="bottom" if upward else "top", fontsize=12, fontweight="bold")
        # The value label sits at mid-height of the bar.
        ax.text(i, r / 2, f"{r:.3f}", ha="center", va="center",
                fontsize=11, fontweight="bold", color="white")

    ax.set_xticks(range(len(names)))
    ax.set_xticklabels(names, rotation=20, ha="right", fontsize=11)
    ax.set_ylabel("Spearman $\\rho$ (Mantel)", fontsize=13)
    spread = max(stds)
    # Bars start at zero, so the upper limit must never fall below the
    # baseline; otherwise an all-negative chart would cut off the bar bases
    # and the zero line.
    y_max = max(max(means) + spread + 0.1, 0.05)
    y_min = min(means) - spread - 0.1 if min(means) < 0 else -0.05
    ax.set_ylim(y_min, y_max)
    ax.axhline(0, color="gray", linestyle="-", linewidth=0.8)

    plt.tight_layout()
    plt.savefig(save_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved: {save_path}")


def main():
    """Run the RSA pipeline over all seeds and save figures and results.

    For every seed in ``SEEDS`` this loads the checkpoint
    ``structural_hidden_analysis_seed{seed}.pt`` from ``CHECKPOINT_DIR``,
    builds the neural RDM from its ``"hidden_states"`` and ``"positions"``
    entries and runs Mantel tests against every model RDM. It then saves the
    RDM heatmaps (for ``PRIMARY_SEED``) and the bar chart of mean correlations
    to ``FIGURE_DIR``, writes ``rsa_results.pt`` to ``CHECKPOINT_DIR`` and
    prints a summary table.

    Raises:
        ValueError: If ``PRIMARY_SEED`` is not one of ``SEEDS``.
    """
    # Fail before the expensive permutation tests rather than at plotting time.
    if PRIMARY_SEED not in SEEDS:
        raise ValueError(
            f"PRIMARY_SEED={PRIMARY_SEED} is not in SEEDS={SEEDS}; "
            "no neural RDM would be available for the heatmap figure."
        )

    print(f"Aggregating over {len(SEEDS)} seeds: {SEEDS}\n")

    G = build_binary_tree()
    model_rdms = build_model_rdms(G)
    for name, dm in model_rdms.items():
        print(f"{name}: shape={dm.shape}, range=[{dm.min():.1f}, {dm.max():.1f}]")
    print()

    all_rhos = {name: [] for name in model_rdms}
    all_pvals = {name: [] for name in model_rdms}
    primary_neural_sq = None
    primary_mean_hidden = None

    for seed in SEEDS:
        ckpt_path = os.path.join(CHECKPOINT_DIR, f"structural_hidden_analysis_seed{seed}.pt")
        print(f"Seed {seed}: Loading {ckpt_path}")
        # SECURITY: weights_only=False unpickles arbitrary objects, so only
        # load checkpoints from a trusted source.
        # map_location="cpu" lets checkpoints saved on a GPU load on any machine.
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        hidden_states = ckpt["hidden_states"]
        positions = ckpt["positions"]

        # detach()/cpu() make the conversion work for tensors that still
        # track gradients or live on another device.
        if isinstance(hidden_states, torch.Tensor):
            hidden_states = hidden_states.detach().cpu().numpy()
        if isinstance(positions, torch.Tensor):
            positions = positions.detach().cpu().numpy()

        print(f"Hidden states: {hidden_states.shape}, "
              f"unique positions: {len(np.unique(positions))}")

        neural_rdm, mean_hidden = build_neural_rdm(hidden_states, positions)

        # Only the primary seed's square RDM is needed (heatmap + saved results).
        if seed == PRIMARY_SEED:
            primary_neural_sq = squareform(neural_rdm.dissimilarities.flatten())
            primary_mean_hidden = mean_hidden

        print("Running Mantel tests (9999 permutations)...")
        mantel_results = run_mantel_tests(neural_rdm, model_rdms)
        for name in model_rdms:
            rho, p = mantel_results[name]
            all_rhos[name].append(rho)
            all_pvals[name].append(p)
        print()

    mean_rhos = {name: np.mean(all_rhos[name]) for name in model_rdms}
    std_rhos = {name: np.std(all_rhos[name]) for name in model_rdms}
    # NOTE(review): the significance markers use the smallest p-value across
    # seeds, which is the most optimistic summary -- confirm this is intended.
    min_pvals = {name: np.min(all_pvals[name]) for name in model_rdms}

    plot_rdm_heatmaps(
        primary_neural_sq, model_rdms, G,
        os.path.join(FIGURE_DIR, "rsa_rdm_heatmaps.png"),
    )
    plot_mantel_bar_chart(
        mean_rhos, std_rhos, min_pvals,
        os.path.join(FIGURE_DIR, "rsa_mantel_bar.png"),
    )

    results = {
        "mean_rhos": mean_rhos,
        "std_rhos": std_rhos,
        "all_rhos": all_rhos,
        "all_pvals": all_pvals,
        "seeds": SEEDS,
        "mean_hidden_states": primary_mean_hidden,
        "neural_dm": primary_neural_sq,
        "model_rdms": model_rdms,
    }
    save_path = os.path.join(CHECKPOINT_DIR, "rsa_results.pt")
    torch.save(results, save_path)
    print(f"\nResults saved: {save_path}")

    print(f"\nRSA Summary ({len(SEEDS)} seeds) \n")
    print(f"{'Model RDM':25s}  {'Mean rho':>10s}  {'Std':>8s}  Per-seed rhos")
    for name in model_rdms:
        rhos_str = ", ".join(f"{r:.4f}" for r in all_rhos[name])
        print(f"{name:25s}  {mean_rhos[name]:10.4f}  {std_rhos[name]:8.4f}  {rhos_str}")



# Run the full RSA pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
