"""PC ablation: remove PCA directions from GRU hidden states and measure behavioral impact."""
import os
import sys
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, PROJECT_ROOT)

import torch
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from scipy.stats import spearmanr

from src.rosenberg_data import load_rosenberg_everything, build_bc_targets, _ROS_TO_DIRL
from src.gru_policy import GRUPolicy, train_gru_policy

# Problem / model configuration shared by every function in this script.
N_STATES = 127  # nodes in the heap-indexed binary tree (children of i are 2i+1, 2i+2; leaves start at (N_STATES-1)//2)
N_ACTIONS = 3  # 0 = left child, 1 = right child, 2 = parent (see ablated_rollout)
HIDDEN_DIM = 128  # GRU hidden-state size; PCA and the ablation projectors live in this space
N_EPOCHS = 200  # training epochs, used only when no saved checkpoint is found
BATCH_SIZE = 64  # training batch size, used only when training
MAX_SEQ_LEN = 200  # trajectories are split into chunks of at most this many actions
BEST_SEED = 2  # torch/numpy seed set before constructing and training a fresh model
WATER_PORT = _ROS_TO_DIRL[100]  # presumably Rosenberg's node label 100 (water port) mapped to this file's node indexing — confirm in src.rosenberg_data


def build_structural_obs(node, n_states=127):
    """Return the 12-dim structural observation for ``node`` as a float32 tensor.

    The tree is heap-indexed: node i has children 2i+1 and 2i+2 and parent
    (i-1)//2, and nodes >= (n_states-1)//2 are leaves. The vector is:

      [own degree/3, own is_root, own is_leaf]
      followed by, for each of (left child, right child, parent):
      [is_leaf, is_root, degree/3]

    A missing child (index >= n_states) is replaced by ``node`` itself, and the
    parent of the root is the root. Degrees are 2 (root), 1 (leaf) and 3
    (internal). The root test takes precedence over the leaf test.

    Note the differing field order between the node's own triple and its
    neighbours' triples. It must be kept as is, because trained checkpoints
    consume exactly this layout.
    """
    first_leaf = (n_states - 1) // 2

    def _describe(idx):
        # -> (is_root, is_leaf, degree / 3) as floats
        root = idx == 0
        leaf = idx >= first_leaf
        if root:
            deg = 2.0
        elif leaf:
            deg = 1.0
        else:
            deg = 3.0
        return float(root), float(leaf), deg / 3.0

    def _child_or_self(idx):
        return idx if idx < n_states else node

    neighbours = (
        _child_or_self(2 * node + 1),
        _child_or_self(2 * node + 2),
        (node - 1) // 2 if node > 0 else 0,
    )

    own_root, own_leaf, own_deg = _describe(node)
    values = [own_deg, own_root, own_leaf]
    for dest in neighbours:
        d_root, d_leaf, d_deg = _describe(dest)
        values += [d_leaf, d_root, d_deg]
    return torch.tensor(values, dtype=torch.float32)


def get_node_depth(node):
    """Return the depth of ``node`` in a heap-indexed binary tree (root depth 0).

    The function steps to the parent ``(i - 1) // 2`` until the root is reached
    and counts the steps. Inputs that are not > 0 return 0.
    """
    depth, cursor = 0, node
    while cursor > 0:
        cursor, depth = (cursor - 1) // 2, depth + 1
    return depth


def tree_distance(a, b):
    """Return the number of edges on the tree path between heap-indexed nodes ``a`` and ``b``.

    The deeper node is first lifted to the other's depth, one edge per step.
    Then both nodes climb in lockstep, two edges per step, until they meet at
    their lowest common ancestor.
    """
    depth_a = get_node_depth(a)
    depth_b = get_node_depth(b)
    edges = 0

    while depth_a > depth_b:
        a = (a - 1) // 2
        depth_a -= 1
        edges += 1
    while depth_b > depth_a:
        b = (b - 1) // 2
        depth_b -= 1
        edges += 1

    while a != b:
        a, b = (a - 1) // 2, (b - 1) // 2
        edges += 2
    return edges


def build_obs_dataset_structural(trajs, structural_obs, bc_policy, max_len=MAX_SEQ_LEN):
    """Split trajectories into chunks of at most ``max_len`` actions.

    Each chunk is a dict with:
      "obs":     stacked ``structural_obs[s]`` for the chunk's states
      "targets": stacked ``bc_policy[s]`` for the same states
      "actions": LongTensor of the chunk's actions
      "states":  list of the chunk's states as ints

    The state slice extends one index past the action slice. Presumably
    ``traj["states"]`` holds one more entry than ``traj["actions"]`` (the state
    after the final action); confirm against the loader.
    """
    chunks = []
    for traj in trajs:
        traj_states, traj_actions = traj["states"], traj["actions"]
        n_act = len(traj_actions)
        for lo in range(0, n_act, max_len):
            hi = min(lo + max_len, n_act)
            seg_states = traj_states[lo:hi + 1]
            seg_actions = traj_actions[lo:hi]
            chunks.append({
                "obs": torch.stack([structural_obs[s] for s in seg_states]),
                "targets": torch.stack([bc_policy[s] for s in seg_states]),
                "actions": torch.tensor([int(a) for a in seg_actions], dtype=torch.long),
                "states": [int(s) for s in seg_states],
            })
    return chunks


def ablated_eval(policy, val_data, pca, ablate_pcs, structural_obs):
    """Score the policy on held-out chunks with selected PCA directions lesioned.

    The lesion is applied inside the recurrence. After every GRU step, the
    span of ``pca.components_[ablate_pcs]`` is projected out of the hidden
    state. That modified state is both fed to the policy head and carried to
    the next step, so the information is destroyed rather than merely hidden
    from the output.

    Args:
        policy: module exposing ``obs_encoder``, ``gru``, ``policy_head`` and
            ``hidden_dim``.
        val_data: iterable of dicts with "obs" (per-step observations) and
            "actions" (1-D tensor); only the first len(actions) obs are used.
        pca: fitted PCA-like object; only ``components_`` is read.
        ablate_pcs: indices into ``components_``. ``None`` or an empty list
            delegates to ``_eval_no_ablation``.
        structural_obs: unused; kept for signature parity with ``ablated_rollout``.

    Returns:
        (mean log-likelihood per decision in bits, argmax accuracy).

    NOTE(review): sklearn's PCA centers the data on ``pca.mean_``, but the
    projection here acts on the raw state. Each removed coordinate is therefore
    pinned to 0 in raw space, not to its mean. ``h - P @ (h - mean_)`` would
    give a mean-ablation instead; confirm which one is intended.
    NOTE(review): the baseline uses ``policy(obs)``. This path assumes
    ``GRUPolicy.forward`` is exactly obs_encoder -> gru -> policy_head; confirm.
    """
    if ablate_pcs is None or len(ablate_pcs) == 0:
        return _eval_no_ablation(policy, val_data)

    basis = pca.components_[ablate_pcs]
    projector = torch.tensor(basis.T @ basis, dtype=torch.float32)

    policy.eval()
    log_lik_sum = 0.0
    hits = 0
    count = 0
    with torch.no_grad():
        for chunk in val_data:
            hidden = torch.zeros(1, 1, policy.hidden_dim)
            for t, action_t in enumerate(chunk["actions"]):
                encoded = policy.obs_encoder(chunk["obs"][t].unsqueeze(0).unsqueeze(0))
                _, hidden = policy.gru(encoded, hidden)

                flat = hidden.squeeze()
                flat = flat - projector @ flat
                hidden = flat.unsqueeze(0).unsqueeze(0)  # lesioned state feeds the next step

                step_logits = policy.policy_head(flat.unsqueeze(0))
                log_probs = torch.log_softmax(step_logits.squeeze(0), dim=-1)

                a = action_t.item()
                log_lik_sum += log_probs[a].item()
                hits += int(log_probs.argmax().item() == a)
                count += 1

    return log_lik_sum / (count * np.log(2)), hits / count


def _eval_no_ablation(policy, val_data):
    policy.eval()
    total_ll, total_correct, n_steps = 0.0, 0, 0
    with torch.no_grad():
        for d in val_data:
            obs = d["obs"].unsqueeze(0)
            logits, _ = policy(obs)
            log_probs = torch.log_softmax(logits.squeeze(0), dim=-1)
            for t in range(len(d["actions"])):
                a = d["actions"][t].item()
                total_ll += log_probs[t, a].item()
                if log_probs[t].argmax().item() == a:
                    total_correct += 1
                n_steps += 1
    return total_ll / (n_steps * np.log(2)), total_correct / n_steps


def ablated_rollout(policy, pca, ablate_pcs, structural_obs, n_rollouts=500,
                    max_steps=200, start_node=0):
    """Sample free-running trajectories from the policy with selected PCs lesioned.

    Each rollout starts at ``start_node`` with a zero hidden state and runs
    ``max_steps`` steps. Each step does the following:
      1. Record the current node (visited set, depth, first WATER_PORT hit).
      2. Feed its structural observation through obs_encoder -> gru.
      3. If ``ablate_pcs`` is non-empty, project the span of those components
         out of the new hidden state. The lesioned state is carried forward.
      4. Sample an action from the policy head: 0 -> left child,
         1 -> right child, otherwise -> parent.
    A child index >= N_STATES leaves the agent in place, and the parent of
    the root is the root. Sampling uses torch's global RNG.

    Returns:
        dict with
          "mean_depth" / "std_depth": mean / std over rollouts of each rollout's
              mean node depth;
          "mean_unique_nodes": mean count of distinct nodes visited;
          "frac_reach_water": fraction of rollouts that visited WATER_PORT;
          "mean_steps_to_water": mean step index of the first WATER_PORT visit
              over rollouts that reached it (``inf`` if none did);
          "depth_distributions": list of per-rollout mean depths.
    """
    projector = None
    if ablate_pcs is not None and len(ablate_pcs) > 0:
        basis = pca.components_[ablate_pcs]
        projector = torch.tensor(basis.T @ basis, dtype=torch.float32)

    policy.eval()
    per_rollout_depth = []
    per_rollout_unique = []
    reached = []
    first_hit = []

    with torch.no_grad():
        for _ in range(n_rollouts):
            current = start_node
            hidden = torch.zeros(1, 1, policy.hidden_dim)
            seen = set()
            depth_trace = []
            hit_step = None

            for step in range(max_steps):
                seen.add(current)
                depth_trace.append(get_node_depth(current))
                if hit_step is None and current == WATER_PORT:
                    hit_step = step

                encoded = policy.obs_encoder(structural_obs[current].unsqueeze(0).unsqueeze(0))
                _, hidden = policy.gru(encoded, hidden)

                if projector is None:
                    logits = policy.policy_head(hidden.squeeze(0))
                else:
                    flat = hidden.squeeze()
                    flat = flat - projector @ flat
                    hidden = flat.unsqueeze(0).unsqueeze(0)  # lesion persists into the next step
                    logits = policy.policy_head(flat.unsqueeze(0))

                action = torch.multinomial(torch.softmax(logits.squeeze(0), dim=-1), 1).item()

                if action in (0, 1):
                    child = 2 * current + 1 + action
                    current = child if child < N_STATES else current
                else:
                    current = (current - 1) // 2 if current > 0 else 0

            per_rollout_depth.append(np.mean(depth_trace))
            per_rollout_unique.append(len(seen))
            reached.append(hit_step is not None)
            first_hit.append(max_steps if hit_step is None else hit_step)

    hit_steps = [s for s, ok in zip(first_hit, reached) if ok]
    return {
        "mean_depth": np.mean(per_rollout_depth),
        "std_depth": np.std(per_rollout_depth),
        "mean_unique_nodes": np.mean(per_rollout_unique),
        "frac_reach_water": np.mean(reached),
        "mean_steps_to_water": np.mean(hit_steps) if hit_steps else float('inf'),
        "depth_distributions": per_rollout_depth,
    }


def _load_or_train_policy(obs_dim, train_data):
    """Return the structural GRU policy, loading PROJECT_ROOT/checkpoints/structural_gru_model.pt if present.

    When no checkpoint exists, torch and numpy are seeded with BEST_SEED before
    the model is constructed. The model is then trained on ``train_data`` and
    saved to the same path.
    """
    model_path = os.path.join(PROJECT_ROOT, "checkpoints", "structural_gru_model.pt")
    if os.path.exists(model_path):
        print("Loading saved GRU model...", flush=True)
        policy = GRUPolicy(obs_dim=obs_dim, hidden_dim=HIDDEN_DIM, n_actions=N_ACTIONS)
        policy.load_state_dict(torch.load(model_path, weights_only=True))
        return policy

    print(f"Training GRU (seed={BEST_SEED}, {N_EPOCHS} epochs)...", flush=True)
    torch.manual_seed(BEST_SEED)
    np.random.seed(BEST_SEED)
    policy = GRUPolicy(obs_dim=obs_dim, hidden_dim=HIDDEN_DIM, n_actions=N_ACTIONS)
    policy, _ = train_gru_policy(policy, train_data, n_epochs=N_EPOCHS,
                                 lr=3e-4, batch_size=BATCH_SIZE, print_every=50)
    # Create the directory model_path actually lives in. A cwd-relative
    # "checkpoints" is a different directory unless run from PROJECT_ROOT.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    torch.save(policy.state_dict(), model_path)
    print(f"Saved model to {model_path}", flush=True)
    return policy


def _print_pc_correlations(pca, hidden_states, depths, water_dists, timesteps, n_pcs=6):
    """Print Spearman correlations of the leading PC scores with depth, water distance and time.

    Scores are computed on uncentered states. The PCA mean only shifts each
    score by a constant, which a rank correlation ignores.
    """
    print("\nPC correlations:", flush=True)
    for i in range(n_pcs):
        pc = hidden_states @ pca.components_[i]
        rho_d, _ = spearmanr(pc, depths)
        rho_w, _ = spearmanr(pc, water_dists)
        rho_t, _ = spearmanr(pc, timesteps)
        print(f"PC{i+1}: depth={rho_d:+.3f}  water={rho_w:+.3f}  time={rho_t:+.3f}  "
              f"(var={pca.explained_variance_ratio_[i]:.3f})", flush=True)


def _eval_random_direction(policy, val_data, pca, structural_obs, slot=49, seed=42):
    """Ablate one random unit direction by temporarily swapping it into ``pca.components_[slot]``."""
    rng = np.random.RandomState(seed)
    rand_dir = rng.randn(HIDDEN_DIM).astype(np.float32)
    rand_dir /= np.linalg.norm(rand_dir)
    saved = pca.components_[slot].copy()
    pca.components_[slot] = rand_dir
    try:
        return ablated_eval(policy, val_data, pca, [slot], structural_obs)
    finally:
        # Always restore, so a failed evaluation cannot leave the fitted PCA
        # (used by later ablations and saved to disk) silently corrupted.
        pca.components_[slot] = saved


def _run_eval_ablations(policy, val_data, pca, structural_obs, ablations):
    """Evaluate each (name, pcs) ablation, print a table, and return {name: metrics}.

    The first entry's LL is the baseline for the "drop" column. ``pcs`` is
    either a list of component indices or the string "random".
    """
    print("\nEvaluation with ablated hidden states ", flush=True)
    print(f"{'Ablation':<25s} {'LL (bits/dec)':<16s} {'Accuracy':<12s} {'LL drop':<10s}", flush=True)

    baseline_ll = None
    results = {}
    for name, pcs in ablations:
        if pcs == "random":
            ll, acc = _eval_random_direction(policy, val_data, pca, structural_obs)
        else:
            ll, acc = ablated_eval(policy, val_data, pca, pcs, structural_obs)

        if baseline_ll is None:
            baseline_ll = ll

        drop = ll - baseline_ll
        print(f"{name:<25s} {ll:>+10.4f}        {acc:>6.3f}      {drop:>+7.4f}", flush=True)
        results[name] = {"ll": ll, "acc": acc, "drop": drop, "pcs": pcs}
    return results


def _run_rollout_ablations(policy, pca, structural_obs, rollout_ablations):
    """Run ablated_rollout for each (name, pcs) entry, print a table, and return {name: stats}."""
    print("\nRollouts with ablated hidden states", flush=True)
    print(f"{'Ablation':<25s} {'Mean depth':<14s} {'Unique nodes':<15s} "
          f"{'Reach water':<14s} {'Steps to water'}", flush=True)

    results = {}
    for name, pcs in rollout_ablations:
        r = ablated_rollout(policy, pca, pcs, structural_obs)
        print(f"{name:<25s} {r['mean_depth']:>5.2f}+/-{r['std_depth']:<5.2f}  "
              f"{r['mean_unique_nodes']:>6.1f}         "
              f"{r['frac_reach_water']:>5.1%}         "
              f"{r['mean_steps_to_water']:>6.1f}", flush=True)
        results[name] = r
    return results


def _plot_ablation_figure(ablations, eval_results, rollout_ablations, rollout_results, out_path):
    """Save a two-panel figure: LL change per ablation (left), rollout depth / unique nodes (right)."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5.5))

    names = [n for n, _ in ablations]
    drops = [eval_results[n]["drop"] for n in names]
    colors = ["#2166ac" if dr == 0 else "#b2182b" if dr < -0.002 else "#f4a582" for dr in drops]
    bars = ax1.barh(range(len(names)), drops, color=colors, edgecolor="white")
    ax1.set_yticks(range(len(names)))
    ax1.set_yticklabels(names, fontsize=9)
    ax1.set_xlabel("LL change (bits/dec)", fontsize=11)
    ax1.set_title("Impact of PC Ablation on Log-Likelihood", fontsize=12, fontweight="bold")
    ax1.axvline(x=0, color="black", linewidth=0.5)
    ax1.invert_yaxis()

    pad = 0.001
    for bar, drop in zip(bars, drops):
        if drop == 0:
            continue
        end = bar.get_width()
        inside = abs(drop) > 0.003  # bar long enough to hold a white label
        if drop > 0:
            x, ha = (end - pad, "right") if inside else (end + pad, "left")
        else:
            # Negative bars grow leftwards from 0, so "inside" means anchoring
            # just right of the bar end. A right-anchored label there would sit
            # outside the bar as white-on-white text.
            x, ha = (end + pad, "left") if inside else (end - pad, "right")
        ax1.text(x, bar.get_y() + bar.get_height() / 2, f"{drop:+.4f}",
                 va="center", ha=ha, fontsize=8, fontweight="bold",
                 color="white" if inside else "black")

    rollout_names = [n for n, _ in rollout_ablations]
    mean_depths = [rollout_results[n]["mean_depth"] for n in rollout_names]
    unique_nodes = [rollout_results[n]["mean_unique_nodes"] for n in rollout_names]

    x = np.arange(len(rollout_names))
    width = 0.35
    ax2.bar(x - width / 2, mean_depths, width, label="Mean depth", color="#4393c3", alpha=0.85)
    ax2_r = ax2.twinx()
    ax2_r.bar(x + width / 2, unique_nodes, width, label="Unique nodes", color="#e6550d", alpha=0.85)

    ax2.set_xticks(x)
    ax2.set_xticklabels([n.replace(" (depth)", "\n(depth)").replace("PCs ", "PCs\n")
                         for n in rollout_names], fontsize=8)
    ax2.set_ylabel("Mean Depth", fontsize=11, color="#4393c3")
    ax2_r.set_ylabel("Unique Nodes Visited", fontsize=11, color="#e6550d")
    ax2.set_title("Rollout Behavior Under Ablation", fontsize=12, fontweight="bold")
    ax2.legend(loc="upper left", fontsize=9)
    ax2_r.legend(loc="upper right", fontsize=9)

    fig.tight_layout()
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved {out_path}", flush=True)


def main():
    """Run the PCA ablation study end to end.

    The script loads data, then loads or trains the structural GRU. It fits a
    50-component PCA on validation hidden states and reports PC correlations.
    It then measures how projecting out PC subspaces changes held-out
    likelihood and free-running rollouts. Outputs are written to
    PROJECT_ROOT/figures/pca_ablation.png and
    PROJECT_ROOT/checkpoints/pca_ablation.pt.
    """
    print("PCA Ablation Study \n", flush=True)

    data = load_rosenberg_everything()
    bc_policy = build_bc_targets(data["train_sa"], n_states=N_STATES, n_actions=N_ACTIONS, laplace=1.0)
    structural_obs = {s: build_structural_obs(s) for s in range(N_STATES)}
    obs_dim = len(structural_obs[0])

    train_data = build_obs_dataset_structural(data["train_trajs"], structural_obs, bc_policy)
    val_data = build_obs_dataset_structural(data["val_trajs"], structural_obs, bc_policy)

    policy = _load_or_train_policy(obs_dim, train_data)

    print("\nCollecting hidden states for PCA...", flush=True)
    from src.analysis import collect_hidden_states
    hidden_states, positions, timesteps, _traj_ids = collect_hidden_states(policy, val_data)
    pca = PCA(n_components=50)
    pca.fit(hidden_states)

    depths = np.array([get_node_depth(p) for p in positions])
    water_dists = np.array([tree_distance(p, WATER_PORT) for p in positions])
    _print_pc_correlations(pca, hidden_states, depths, water_dists, timesteps)

    ablations = [
        ("No ablation", []),
        ("Ablate PC1 (depth)", [0]),
        ("Ablate PC2", [1]),
        ("Ablate PC3", [2]),
        ("Ablate PC4", [3]),
        ("Ablate PCs 1-2", [0, 1]),
        ("Ablate PCs 1-4", [0, 1, 2, 3]),
        ("Ablate PCs 5-50", list(range(4, 50))),
        ("Ablate random dir", "random"),
    ]
    eval_results = _run_eval_ablations(policy, val_data, pca, structural_obs, ablations)

    rollout_ablations = [
        ("No ablation", []),
        ("Ablate PC1 (depth)", [0]),
        ("Ablate PCs 1-4", [0, 1, 2, 3]),
        ("Ablate PCs 2-4", [1, 2, 3]),
    ]
    rollout_results = _run_rollout_ablations(policy, pca, structural_obs, rollout_ablations)

    _plot_ablation_figure(ablations, eval_results, rollout_ablations, rollout_results,
                          os.path.join(PROJECT_ROOT, "figures", "pca_ablation.png"))

    results_path = os.path.join(PROJECT_ROOT, "checkpoints", "pca_ablation.pt")
    os.makedirs(os.path.dirname(results_path), exist_ok=True)
    torch.save({
        "eval_results": eval_results,
        "rollout_results": rollout_results,
        "pca_components": pca.components_,
        "pca_explained_variance": pca.explained_variance_ratio_,
    }, results_path)
    print(f"Saved {results_path}", flush=True)


if __name__ == "__main__":
    main()
