"""Per-node policy distortion under ablation of GRU hidden-state PC1 (depth-aligned) and of PCs 2-4."""
import os
import sys
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, PROJECT_ROOT)

import torch
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from sklearn.decomposition import PCA

from src.rosenberg_data import load_rosenberg_everything, build_bc_targets, _ROS_TO_DIRL
from src.gru_policy import GRUPolicy
from src.analysis import collect_hidden_states

N_STATES = 127          # nodes in the heap-indexed complete binary-tree maze (depths 0-6)
N_ACTIONS = 3           # size of the policy's action distribution
HIDDEN_DIM = 128        # GRU hidden size; must match the saved checkpoints
MAX_SEQ_LEN = 200       # chunk length when splitting validation trajectories
WATER_PORT = _ROS_TO_DIRL[100]  # water-port node: index 100 converted via _ROS_TO_DIRL
SEEDS = list(range(5))  # model seeds analysed (0-4)
PRIMARY_SEED = 2        # seed whose per-node results are plotted and saved in detail


def build_structural_obs(node, n_states=127):
    """Encode the local tree structure around ``node`` as a 12-element float32 tensor.

    Layout:
      * ``[degree/3, is_root, is_leaf]`` for ``node`` itself, then
      * ``[is_leaf, is_root, degree/3]`` for each of the left, right and reverse
        destinations, in that order.

    The per-destination triple is deliberately NOT in the same order as the
    node's own triple; trained checkpoints were fit on this exact layout, so it
    must not be reordered.

    Destinations are the heap children ``2*node+1`` / ``2*node+2`` when they lie
    inside the tree (otherwise ``node`` itself), and the parent ``(node-1)//2``
    for reverse (``0`` for the root). Leaves are the nodes with index
    ``>= (n_states - 1) // 2``.
    """
    leaf_start = (n_states - 1) // 2

    def degree_of(n):
        # Tree-neighbour count: root 2 (two children), leaf 1 (parent), else 3.
        if n == 0:
            return 2.0
        return 1.0 if n >= leaf_start else 3.0

    def flags(n):
        # (is_root, is_leaf) as floats.
        return (1.0 if n == 0 else 0.0), (1.0 if n >= leaf_start else 0.0)

    left = 2 * node + 1
    right = 2 * node + 2
    destinations = (
        left if left < n_states else node,
        right if right < n_states else node,
        0 if node <= 0 else (node - 1) // 2,
    )

    root_flag, leaf_flag = flags(node)
    vec = [degree_of(node) / 3.0, root_flag, leaf_flag]
    for dest in destinations:
        d_root, d_leaf = flags(dest)
        vec += [d_leaf, d_root, degree_of(dest) / 3.0]
    return torch.tensor(vec, dtype=torch.float32)


def get_node_depth(node):
    """Return the depth of ``node`` in a heap-indexed binary tree (root = depth 0).

    Follows parent links ``(n - 1) // 2`` until the root is reached; any
    ``node <= 0`` is reported as depth 0.
    """
    hops = 0
    while node > 0:
        node, hops = (node - 1) // 2, hops + 1
    return hops


def build_obs_dataset_structural(trajs, structural_obs, bc_policy, max_len=MAX_SEQ_LEN):
    """Split trajectories into chunks of at most ``max_len`` actions.

    Each trajectory is a mapping with ``"states"`` and ``"actions"`` sequences.
    For a chunk covering actions ``[start, end)`` the states slice is
    ``[start, end + 1)``, so when ``states`` is longer than ``actions`` the chunk
    also contains the state reached after its last action; consecutive chunks
    then share that boundary state.

    Returns a list of dicts, one per chunk:
      ``"obs"``      -- ``torch.stack`` of ``structural_obs[s]`` per chunk state
      ``"targets"``  -- ``torch.stack`` of ``bc_policy[s]`` per chunk state
      ``"actions"``  -- ``torch.long`` tensor of the chunk's actions
      ``"states"``   -- the chunk's states as Python ints
    """
    def make_chunk(states, actions, lo, hi):
        chunk_states = states[lo:hi + 1]
        chunk_actions = actions[lo:hi]
        return {
            "obs": torch.stack([structural_obs[s] for s in chunk_states]),
            "targets": torch.stack([bc_policy[s] for s in chunk_states]),
            "actions": torch.tensor([int(a) for a in chunk_actions], dtype=torch.long),
            "states": [int(s) for s in chunk_states],
        }

    chunks = []
    for traj in trajs:
        states, actions = traj["states"], traj["actions"]
        n_steps = len(actions)
        chunks.extend(
            make_chunk(states, actions, lo, min(lo + max_len, n_steps))
            for lo in range(0, n_steps, max_len)
        )
    return chunks


def binary_tree_layout(n_states=127):
    """Compute 2-D plotting coordinates for a heap-indexed binary tree.

    Returns a dict ``node -> (x, y)`` for ``node in range(n_states)``, where
    ``y = -depth`` (root at the top) and ``x`` is the centre of the node's slot
    when ``[0, 1]`` is divided evenly among the ``2**depth`` positions at that
    depth.
    """
    def slot_centre(node, depth):
        slots = 1 << depth
        rank = node - (slots - 1)  # position of node among the nodes at this depth
        return (rank + 0.5) / slots

    layout = {}
    for node in range(n_states):
        depth = get_node_depth(node)
        layout[node] = (slot_centre(node, depth), -depth)
    return layout


def collect_per_node_distributions(policy, val_data, pca, ablate_pcs=None):
    """Step the GRU through validation chunks and average its action distribution per node.

    Each chunk in ``val_data`` is rolled out one timestep at a time from a zero
    hidden state using ``policy.obs_encoder`` -> ``policy.gru`` ->
    ``policy.policy_head``. When ``ablate_pcs`` is a non-empty sequence of PCA
    component indices, the span of those components is projected out of the
    hidden state at every step, and the ablated state is fed into the next step,
    so the ablation propagates through time.

    Returns:
        dict ``node -> (n_visits, mean_probs)`` for every node in
        ``range(N_STATES)`` visited at least once; ``mean_probs`` is the numpy
        mean of the softmax outputs recorded at that node.
    """
    projector = None
    if ablate_pcs is not None and len(ablate_pcs) > 0:
        basis = pca.components_[ablate_pcs]
        # NOTE(review): sklearn PCA components are defined relative to pca.mean_,
        # but this projection is applied to the raw hidden state, so the mean's
        # component along these directions is removed too. Confirm that this
        # (rather than h - P @ (h - mean_)) is the intended ablation.
        projector = torch.tensor(basis.T @ basis, dtype=torch.float32)

    visits = {node: [] for node in range(N_STATES)}
    policy.eval()

    with torch.no_grad():
        for episode in val_data:
            obs = episode["obs"]
            hidden = torch.zeros(1, 1, policy.hidden_dim)
            for step, node in enumerate(episode["states"]):
                encoded = policy.obs_encoder(obs[step].unsqueeze(0).unsqueeze(0))
                _, hidden = policy.gru(encoded, hidden)
                if projector is None:
                    logits = policy.policy_head(hidden.squeeze(0))
                else:
                    flat = hidden.squeeze()
                    flat = flat - projector @ flat
                    hidden = flat.unsqueeze(0).unsqueeze(0)  # carry the ablated state forward
                    logits = policy.policy_head(flat.unsqueeze(0))
                visits[node].append(torch.softmax(logits.squeeze(), dim=-1).numpy())

    return {
        node: (len(samples), np.array(samples).mean(axis=0))
        for node, samples in visits.items()
        if samples
    }


def compute_divergences(intact_dist, ablated_dist):
    """Score how much each node's action distribution moved between two runs.

    Both arguments map ``node -> (n_visits, mean_probs)``, as returned by
    collect_per_node_distributions. Only nodes in ``range(N_STATES)`` present in
    both mappings are scored.

    Returns:
        ``(tv, kl)`` -- dicts keyed by node holding the total-variation distance
        ``0.5 * sum|p - q|`` and ``KL(p || q)`` with ``p`` = intact and
        ``q`` = ablated probabilities. Inside the log both are clipped to
        ``[1e-10, 1]`` to avoid ``log(0)`` and division by zero.
    """
    def total_variation(p, q):
        return 0.5 * np.abs(p - q).sum()

    def kl_divergence(p, q):
        ratio = np.clip(p, 1e-10, 1) / np.clip(q, 1e-10, 1)
        return np.sum(p * np.log(ratio))

    shared_nodes = (n for n in range(N_STATES) if n in intact_dist and n in ablated_dist)
    tv, kl = {}, {}
    for node in shared_nodes:
        p = intact_dist[node][1]
        q = ablated_dist[node][1]
        tv[node] = total_variation(p, q)
        kl[node] = kl_divergence(p, q)
    return tv, kl


def run_ablation_for_seed(seed, val_data, obs_dim):
    """Load one seed's GRU, fit PCA on its hidden states, and measure ablation effects.

    Args:
        seed: model seed; the checkpoint is
            ``PROJECT_ROOT/checkpoints/structural_gru_model_seed{seed}.pt``.
        val_data: chunked validation data (see build_obs_dataset_structural).
        obs_dim: observation size used to construct the GRUPolicy.

    Returns:
        ``(tv_pc1, tv_pc24, kl_pc1, kl_pc24, intact, ablated_pc1, ablated_pc24)``.
        The tv/kl entries are per-node dicts from compute_divergences, comparing
        the intact rollout against ablating PC1 (index 0) and PCs 2-4
        (indices 1-3). The last three are per-node distributions from
        collect_per_node_distributions.
    """
    from scipy.stats import spearmanr as _spearmanr

    model_path = os.path.join(PROJECT_ROOT, "checkpoints", f"structural_gru_model_seed{seed}.pt")
    print(f"Loading model: {model_path}", flush=True)
    policy = GRUPolicy(obs_dim=obs_dim, hidden_dim=HIDDEN_DIM, n_actions=N_ACTIONS)
    # The rollouts in this file create their hidden states on CPU, so map any
    # GPU-saved tensors to CPU. Without this, a CUDA checkpoint fails to load on
    # a CPU-only machine.
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    policy.load_state_dict(state_dict)

    hidden_states, positions, _, _ = collect_hidden_states(policy, val_data)
    # Pin random_state: with many samples sklearn's 'auto' solver may choose
    # randomized SVD, which is otherwise non-deterministic across runs.
    pca = PCA(n_components=50, random_state=seed)
    pca.fit(hidden_states)

    # Orient PC1 so it increases with tree depth. The ablation projector is
    # v^T v, which is sign-invariant, so this flip does not change the returned
    # divergences; it only fixes a consistent orientation for logging.
    depths = np.array([get_node_depth(p) for p in positions])
    pc1_vals = hidden_states @ pca.components_[0]
    rho, _ = _spearmanr(pc1_vals, depths)
    if rho < 0:
        pca.components_[0] *= -1
        print(f"Flipped PC1 sign (rho was {rho:.3f})", flush=True)

    intact = collect_per_node_distributions(policy, val_data, pca, ablate_pcs=None)
    ablated_pc1 = collect_per_node_distributions(policy, val_data, pca, ablate_pcs=[0])
    ablated_pc24 = collect_per_node_distributions(policy, val_data, pca, ablate_pcs=[1, 2, 3])

    tv_pc1, kl_pc1 = compute_divergences(intact, ablated_pc1)
    tv_pc24, kl_pc24 = compute_divergences(intact, ablated_pc24)

    return tv_pc1, tv_pc24, kl_pc1, kl_pc24, intact, ablated_pc1, ablated_pc24


def _path_from_root(node):
    """Return the heap-indexed node path from the root (0) down to ``node``."""
    path = [node]
    while node > 0:
        node = (node - 1) // 2
        path.append(node)
    return path[::-1]


def _draw_tree_panel(ax, pos, values, water_edges, cmap, vmin, vmax, cbar_label, title):
    """Draw the N_STATES-node tree on ``ax`` with nodes coloured by ``values``.

    Tree edges are faint black lines and the edges in ``water_edges`` are
    highlighted in blue. The root is ringed with a triangle and WATER_PORT with
    a square. A colourbar labelled ``cbar_label`` and a bold ``title`` are added.
    """
    for node in range(N_STATES):
        for child in (2 * node + 1, 2 * node + 2):
            if child < N_STATES:
                ax.plot([pos[node][0], pos[child][0]], [pos[node][1], pos[child][1]],
                        'k-', alpha=0.1, linewidth=0.5)
    for n1, n2 in water_edges:
        ax.plot([pos[n1][0], pos[n2][0]], [pos[n1][1], pos[n2][1]],
                'b-', alpha=0.4, linewidth=2)
    xs = [pos[n][0] for n in range(N_STATES)]
    ys = [pos[n][1] for n in range(N_STATES)]
    sc = ax.scatter(xs, ys, c=values, cmap=cmap, s=40, edgecolors="gray",
                    linewidths=0.3, vmin=vmin, vmax=vmax)
    ax.scatter(*pos[0], c="none", s=100, marker="^", edgecolors="black", linewidths=1.5, zorder=10)
    ax.scatter(*pos[WATER_PORT], c="none", s=100, marker="s", edgecolors="blue",
               linewidths=1.5, zorder=10)
    plt.colorbar(sc, ax=ax, shrink=0.82, label=cbar_label)
    ax.set_title(title, fontsize=12, fontweight="bold")
    ax.axis("off")


def main():
    """Run the per-node PC-ablation analysis over SEEDS, then plot and save the results.

    For every seed, per-depth TV statistics for PC1 ablation are printed. The
    PRIMARY_SEED results are plotted and saved. All outputs are anchored at
    PROJECT_ROOT, the same root the model checkpoints are loaded from, so the
    script works from any working directory:
      * ``figures/ablation_per_node.png``     -- tree heatmaps of TV distortion
      * ``figures/ablation_action_shift.png`` -- intact vs PC1-ablated bar charts
      * ``checkpoints/ablation_per_node.pt``  -- per-node TV/KL and distributions

    Raises:
        ValueError: if PRIMARY_SEED is not one of SEEDS. Checked up front,
            before any expensive work.
    """
    if PRIMARY_SEED not in SEEDS:
        raise ValueError(f"PRIMARY_SEED ({PRIMARY_SEED}) must be one of SEEDS {SEEDS}")

    print("Per-Node Policy Distortion Under PC Ablation \n", flush=True)
    print(f"Seeds: {SEEDS}\n", flush=True)

    d = load_rosenberg_everything()
    bc_policy = build_bc_targets(d["train_sa"], n_states=N_STATES, n_actions=N_ACTIONS, laplace=1.0)
    structural_obs = {s: build_structural_obs(s, n_states=N_STATES) for s in range(N_STATES)}
    obs_dim = len(structural_obs[0])
    val_data = build_obs_dataset_structural(d["val_trajs"], structural_obs, bc_policy)

    # Group nodes by depth once. In heap order, the last node is at the maximum depth.
    max_depth = get_node_depth(N_STATES - 1)
    nodes_by_depth = [[n for n in range(N_STATES) if get_node_depth(n) == depth]
                      for depth in range(max_depth + 1)]

    all_tv_pc1 = []
    all_tv_pc24 = []
    primary = None

    for seed in SEEDS:
        print(f"\nSeed {seed} ", flush=True)
        results = run_ablation_for_seed(seed, val_data, obs_dim)
        seed_tv_pc1, seed_tv_pc24 = results[0], results[1]
        all_tv_pc1.append(seed_tv_pc1)
        all_tv_pc24.append(seed_tv_pc24)
        if seed == PRIMARY_SEED:
            primary = results

        print("Per-depth TV (PC1 ablated):", flush=True)
        for depth, depth_nodes in enumerate(nodes_by_depth):
            vals = [seed_tv_pc1[n] for n in depth_nodes if n in seed_tv_pc1]
            if vals:
                print(f"depth {depth}: {np.mean(vals):.4f} +/- {np.std(vals):.4f}", flush=True)

    (tv_pc1, tv_pc24, kl_pc1, kl_pc24,
     intact, ablated_pc1, ablated_pc24) = primary

    figures_dir = os.path.join(PROJECT_ROOT, "figures")
    checkpoints_dir = os.path.join(PROJECT_ROOT, "checkpoints")
    os.makedirs(figures_dir, exist_ok=True)
    os.makedirs(checkpoints_dir, exist_ok=True)

    # --- Tree heatmaps -----------------------------------------------------
    print("\nPlotting tree heatmaps (primary seed)...", flush=True)
    pos = binary_tree_layout(N_STATES)
    water_path = _path_from_root(WATER_PORT)
    water_edges = list(zip(water_path[:-1], water_path[1:]))

    # Nodes never visited in validation get 0 distortion.
    tv_vals = np.array([tv_pc1.get(n, 0) for n in range(N_STATES)])
    tv24_vals = np.array([tv_pc24.get(n, 0) for n in range(N_STATES)])
    diff_vals = tv_vals - tv24_vals
    diff_lim = max(abs(diff_vals.min()), abs(diff_vals.max()))  # symmetric diverging scale

    fig, axes = plt.subplots(1, 3, figsize=(20, 7))
    _draw_tree_panel(axes[0], pos, tv_vals, water_edges, "YlOrRd", 0, tv_vals.max(),
                     "TV Distance", "PC1 Ablation: Policy Distortion")
    _draw_tree_panel(axes[1], pos, tv24_vals, water_edges, "YlOrRd", 0, tv24_vals.max(),
                     "TV Distance", "PCs 2-4 Ablation: Policy Distortion")
    _draw_tree_panel(axes[2], pos, diff_vals, water_edges, "RdBu_r", -diff_lim, diff_lim,
                     "TV(PC1) - TV(PCs 2-4)", "PC1 vs PCs 2-4: Where Depth Matters Most")
    fig.suptitle("Per-Node Policy Distortion Under Representational Ablation\n"
                 "(triangle = root, square = water port, blue line = path to water)",
                 fontsize=13, fontweight="bold", y=1.03)
    fig.tight_layout()
    heatmap_path = os.path.join(figures_dir, "ablation_per_node.png")
    fig.savefig(heatmap_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved {heatmap_path}", flush=True)

    # --- Action-distribution shift at example nodes --------------------------
    candidate_nodes = (0, 5, 11, 24, 49, 100, WATER_PORT)
    example_nodes = [n for n in candidate_nodes if n in intact and n in ablated_pc1]

    if not example_nodes:
        print("No example nodes were visited in the validation data; "
              "skipping action-shift plot.", flush=True)
    else:
        # squeeze=False keeps axes 2-D even for a single panel.
        fig, axes = plt.subplots(1, len(example_nodes),
                                 figsize=(3.2 * len(example_nodes), 4), squeeze=False)
        axes = axes[0]

        # NOTE(review): labels assume action index order Left, Right, Reverse -- confirm.
        action_labels = ["Left", "Right", "Reverse"]
        x = np.arange(3)
        width = 0.35

        for idx, node in enumerate(example_nodes):
            ax = axes[idx]
            p = intact[node][1]
            q = ablated_pc1[node][1]

            ax.bar(x - width / 2, p, width, label="Intact", color="#4393c3", alpha=0.85)
            ax.bar(x + width / 2, q, width, label="PC1 ablated", color="#d6604d", alpha=0.85)

            ax.set_xticks(x)
            ax.set_xticklabels(action_labels, fontsize=8)
            ax.set_ylim(0, 1.0)
            label = f"Node {node}" if node != WATER_PORT else f"Node {node} (water)"
            ax.set_title(f"{label}\ndepth={get_node_depth(node)}, TV={tv_pc1[node]:.3f}",
                         fontsize=9, fontweight="bold")
            if idx == 0:
                ax.set_ylabel("Action probability", fontsize=10)
                ax.legend(fontsize=7)

        fig.suptitle("Action Distribution Shift Under PC1 (Depth) Ablation",
                     fontsize=12, fontweight="bold", y=1.02)
        fig.tight_layout()
        shift_path = os.path.join(figures_dir, "ablation_action_shift.png")
        fig.savefig(shift_path, dpi=200, bbox_inches="tight")
        plt.close(fig)
        print(f"Saved {shift_path}", flush=True)

    # --- Persist results -----------------------------------------------------
    results_path = os.path.join(checkpoints_dir, "ablation_per_node.pt")
    torch.save({
        "tv_pc1": tv_pc1, "kl_pc1": kl_pc1,
        "tv_pc24": tv_pc24, "kl_pc24": kl_pc24,
        "all_tv_pc1": all_tv_pc1,
        "all_tv_pc24": all_tv_pc24,
        "seeds": SEEDS,
        "intact": {n: (v[0], v[1].tolist()) for n, v in intact.items()},
        "ablated_pc1": {n: (v[0], v[1].tolist()) for n, v in ablated_pc1.items()},
        "ablated_pc24": {n: (v[0], v[1].tolist()) for n, v in ablated_pc24.items()},
    }, results_path)
    print(f"Saved {results_path}", flush=True)


if __name__ == "__main__":
    main()
