"""Probe baselines: raw obs, untrained GRU, and selectivity control."""

import os
import sys
import argparse

import numpy as np
import torch
from scipy.stats import spearmanr
from sklearn.linear_model import LogisticRegression
from sklearn.decomposition import PCA

# The project root is the parent of this file's directory. Putting it first
# on sys.path lets the ``src.*`` imports below resolve when this file is run
# directly as a script.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, PROJECT_ROOT)

from src.rosenberg_data import load_rosenberg_everything, build_bc_targets
from src.gru_policy import GRUPolicy
from src.analysis import collect_hidden_states

# Number of tree nodes; states are indexed 0..N_STATES-1 (root is node 0).
N_STATES = 127
# Number of discrete actions, passed to build_bc_targets and GRUPolicy.
N_ACTIONS = 3
# GRU hidden size; must match the size used when the checkpoints were trained.
HIDDEN_DIM = 128
# Default maximum number of actions per chunk in build_obs_dataset_structural.
MAX_SEQ_LEN = 200
# Number of label shuffles averaged for the selectivity control in main().
N_SHUFFLES = 10
# Directory the trained models are read from and the results are written to.
CHECKPOINT_DIR = os.path.join(PROJECT_ROOT, "checkpoints")


def build_structural_obs(node, n_states=127):
    """Return the 12-dimensional structural observation for ``node``.

    The tree is heap-indexed: node ``k`` has children ``2k + 1`` and
    ``2k + 2`` and parent ``(k - 1) // 2``; node 0 is the root and every
    node with index ``>= (n_states - 1) // 2`` is treated as a leaf.

    Layout of the returned float32 tensor:

    * ``[0:3]``  -- the node itself: ``(degree / 3, is_root, is_leaf)``
    * ``[3:6]``  -- left destination: ``(is_leaf, is_root, degree / 3)``
    * ``[6:9]``  -- right destination, same ordering as the left one
    * ``[9:12]`` -- reverse (parent) destination, same ordering

    Note that the destination triples use a different field order from the
    node's own triple. A move that would leave the tree (a child of a leaf)
    maps back to ``node`` itself, and the root's reverse move maps to the
    root.
    """
    first_leaf = (n_states - 1) // 2

    def _describe(idx):
        # (degree / 3, is_root, is_leaf) for a single node index. The root
        # test takes precedence over the leaf test when picking the degree.
        root_flag = 1.0 if idx == 0 else 0.0
        leaf_flag = 1.0 if idx >= first_leaf else 0.0
        if idx == 0:
            deg = 2.0
        elif idx >= first_leaf:
            deg = 1.0
        else:
            deg = 3.0
        return deg / 3.0, root_flag, leaf_flag

    left_child = 2 * node + 1
    right_child = 2 * node + 2
    destinations = (
        left_child if left_child < n_states else node,
        right_child if right_child < n_states else node,
        (node - 1) // 2 if node > 0 else 0,
    )

    own_deg, own_root, own_leaf = _describe(node)
    feats = [own_deg, own_root, own_leaf]
    for dest in destinations:
        dest_deg, dest_root, dest_leaf = _describe(dest)
        feats += [dest_leaf, dest_root, dest_deg]
    return torch.tensor(feats, dtype=torch.float32)


def get_node_depth(node):
    """Return the depth of ``node`` in a heap-indexed binary tree.

    The root (index 0) has depth 0; each step to the parent
    ``(node - 1) // 2`` adds one level. Any index that is not strictly
    positive is reported as depth 0.
    """
    if not node > 0:
        return 0
    # One level for this node plus the depth of its parent.
    return 1 + get_node_depth((node - 1) // 2)


def build_obs_dataset_structural(trajs, structural_obs, bc_policy, max_len=MAX_SEQ_LEN):
    """Split trajectories into chunks of at most ``max_len`` actions.

    Each trajectory is a mapping with ``"states"`` and ``"actions"``. For
    every window of actions ``[start, end)`` the matching state slice is
    ``[start, end + 1)``, so a chunk carries one state more than it has
    actions whenever the trajectory provides it, and consecutive chunks
    share their boundary state.

    Args:
        trajs: iterable of trajectories (mappings with "states"/"actions").
        structural_obs: lookup from state to its observation tensor.
        bc_policy: lookup from state to its target tensor.
        max_len: maximum number of actions per chunk.

    Returns:
        A list of dicts with keys ``"obs"`` and ``"targets"`` (stacked
        tensors, one row per state), ``"actions"`` (long tensor) and
        ``"states"`` (list of ints).
    """
    chunks = []
    for trajectory in trajs:
        all_states = trajectory["states"]
        all_actions = trajectory["actions"]
        n_steps = len(all_actions)
        windows = [(lo, min(lo + max_len, n_steps))
                   for lo in range(0, n_steps, max_len)]
        for lo, hi in windows:
            # One extra state so the state reached by the last action is kept.
            window_states = all_states[lo:hi + 1]
            window_actions = all_actions[lo:hi]
            chunk = {}
            chunk["obs"] = torch.stack([structural_obs[s] for s in window_states])
            chunk["targets"] = torch.stack([bc_policy[s] for s in window_states])
            chunk["actions"] = torch.tensor(
                [int(a) for a in window_actions], dtype=torch.long)
            chunk["states"] = [int(s) for s in window_states]
            chunks.append(chunk)
    return chunks


def probe_train_test(X, y, traj_ids):
    """Fit a logistic-regression probe on an 80/20 split by trajectory.

    The distinct trajectory ids are sorted (``np.unique``) and the first 80%
    of them form the training set; all samples of a trajectory therefore
    land on the same side of the split. The split depends only on
    ``traj_ids``, so the returned masks can be reused for other feature
    matrices whose rows are aligned with the same ``traj_ids``.

    Args:
        X: feature matrix, one row per sample (indexed with boolean masks).
        y: labels, one per sample.
        traj_ids: trajectory id of each sample.

    Returns:
        ``(test_accuracy, train_mask, test_mask)`` where the masks are
        boolean arrays over the samples.

    Raises:
        ValueError: if the split leaves the training or test set empty,
            which happens with fewer than two distinct trajectories.
    """
    traj_ids = np.asarray(traj_ids)
    unique_tids = np.unique(traj_ids)
    n_train = int(0.8 * len(unique_tids))
    # Vectorized membership test instead of a per-sample Python loop.
    train_mask = np.isin(traj_ids, unique_tids[:n_train])
    test_mask = ~train_mask
    if not train_mask.any() or not test_mask.any():
        # Fail here with an actionable message rather than deep inside
        # scikit-learn with an "array with 0 sample(s)" error.
        raise ValueError(
            "Need at least 2 distinct trajectories for an 80/20 "
            f"train/test split; got {len(unique_tids)}."
        )
    clf = LogisticRegression(max_iter=1000, solver="lbfgs", C=1.0)
    clf.fit(X[train_mask], y[train_mask])
    return clf.score(X[test_mask], y[test_mask]), train_mask, test_mask


def _pc1_abs_correlations(hidden, depths, timesteps):
    """Return ``(|rho_depth|, |rho_time|)`` for the first principal component.

    Fits a PCA on ``hidden`` and computes the absolute Spearman rank
    correlation of PC1 with ``depths`` and with ``timesteps``. The values are
    returned as plain Python floats.
    """
    pc1 = PCA(n_components=min(5, HIDDEN_DIM)).fit_transform(hidden)[:, 0]
    rho_depth, _ = spearmanr(pc1, depths)
    rho_time, _ = spearmanr(pc1, timesteps)
    return float(abs(rho_depth)), float(abs(rho_time))


def _mean_std_str(values, digits):
    """Format ``values`` as ``mean +/- std``, or the bare value if only one."""
    arr = np.asarray(values, dtype=float)
    if arr.size == 1:
        return f"{arr[0]:.{digits}f}"
    return f"{arr.mean():.{digits}f} +/- {arr.std():.{digits}f}"


def main():
    """Run the probe baselines for every requested seed and save a summary.

    For each seed whose checkpoint exists in ``CHECKPOINT_DIR`` this computes
    the test accuracy of a logistic-regression position probe on:

    * the raw structural observations,
    * the hidden states of a freshly initialised (untrained) GRU,
    * the trained GRU's hidden states with shuffled labels (selectivity
      control, averaged over ``N_SHUFFLES`` shuffles),
    * the trained GRU's hidden states with the true labels,

    plus the absolute Spearman correlation of PC1 with node depth and
    timestep for both GRUs. Seeds without a checkpoint are skipped.

    The summary is printed and written to ``probe_baselines.pt``. All metrics
    are stored as plain Python floats so the file can be read back with
    ``torch.load(..., weights_only=True)``. ``"seeds"`` holds the requested
    seeds and ``"completed_seeds"`` the ones that actually contributed, in
    the same order as the per-seed result lists.

    Raises:
        ValueError: if the validation set contains no states.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--seeds", nargs="+", type=int, default=[0, 1, 2, 3, 4],
                        help="seeds whose checkpoints should be probed")
    args = parser.parse_args()
    seeds = args.seeds

    print("Loading Rosenberg data...", flush=True)
    d = load_rosenberg_everything()
    bc_policy = build_bc_targets(d["train_sa"], n_states=N_STATES,
                                 n_actions=N_ACTIONS, laplace=1.0)
    structural_obs = {s: build_structural_obs(s) for s in range(N_STATES)}
    obs_dim = len(structural_obs[0])
    val_data = build_obs_dataset_structural(d["val_trajs"], structural_obs, bc_policy)

    # Row s holds the observation of state s, so indexing with an array of
    # positions yields one observation per sample.
    obs_matrix = np.stack([structural_obs[s].numpy() for s in range(N_STATES)])

    results = {
        "raw_obs": [], "untrained_gru": [],
        "selectivity": [], "trained_gru": [],
        "untrained_pc1_rho": [], "trained_pc1_rho": [],
        "untrained_pc1_time": [], "trained_pc1_time": [],
    }
    # Seeds that produced results; skipped seeds are absent, which keeps this
    # list index-aligned with every list in ``results``.
    completed_seeds = []

    # Chance level is uniform over the distinct states seen in validation.
    visited_states = {s for chunk in val_data for s in chunk["states"]}
    if not visited_states:
        raise ValueError("Validation set contains no states; cannot run probes.")
    chance = 1.0 / len(visited_states)

    for seed in seeds:
        print(f"\nSeed {seed} ", flush=True)

        model_path = os.path.join(CHECKPOINT_DIR, f"structural_gru_model_seed{seed}.pt")
        if not os.path.exists(model_path):
            print(f"SKIP: {model_path} not found", flush=True)
            continue

        trained = GRUPolicy(obs_dim=obs_dim, hidden_dim=HIDDEN_DIM, n_actions=N_ACTIONS)
        trained.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=True))
        trained.eval()

        h_trained, positions, timesteps, traj_ids = collect_hidden_states(trained, val_data)
        depths = np.array([get_node_depth(p) for p in positions])

        # Baseline 1: probe the raw observations. The train/test masks depend
        # only on traj_ids, so they are reused for the selectivity control.
        obs_vectors = obs_matrix[positions]
        raw_acc, train_mask, test_mask = probe_train_test(
            obs_vectors, positions, traj_ids)
        results["raw_obs"].append(float(raw_acc))
        print(f"Raw observation probe: {raw_acc:.4f} ({raw_acc*100:.1f}%)", flush=True)

        # Baseline 2: randomly initialised GRU, seeded for reproducibility.
        # NOTE(review): assumes collect_hidden_states returns rows in the same
        # order for both models so `positions` still lines up -- confirm.
        torch.manual_seed(seed)
        untrained = GRUPolicy(obs_dim=obs_dim, hidden_dim=HIDDEN_DIM, n_actions=N_ACTIONS)
        untrained.eval()
        h_untrained, _, _, _ = collect_hidden_states(untrained, val_data)

        untrained_acc, _, _ = probe_train_test(h_untrained, positions, traj_ids)
        results["untrained_gru"].append(float(untrained_acc))

        rho_u_depth, rho_u_time = _pc1_abs_correlations(h_untrained, depths, timesteps)
        results["untrained_pc1_rho"].append(rho_u_depth)
        results["untrained_pc1_time"].append(rho_u_time)

        print(f"Untrained GRU probe: {untrained_acc:.4f} ({untrained_acc*100:.1f}%)"
              f"  |PC1-depth| = {rho_u_depth:.3f}"
              f"  |PC1-time| = {rho_u_time:.3f}", flush=True)

        # Baseline 3: selectivity control -- same features, shuffled labels.
        shuffle_accs = []
        rng = np.random.RandomState(seed)
        for _ in range(N_SHUFFLES):
            shuffled = rng.permutation(positions)
            clf = LogisticRegression(max_iter=1000, solver="lbfgs", C=1.0)
            clf.fit(h_trained[train_mask], shuffled[train_mask])
            shuffle_accs.append(clf.score(h_trained[test_mask], shuffled[test_mask]))
        mean_sel = float(np.mean(shuffle_accs))
        std_sel = float(np.std(shuffle_accs))
        results["selectivity"].append(mean_sel)
        print(f"Selectivity control: {mean_sel:.4f} +/- {std_sel:.4f} "
              f"({mean_sel*100:.1f}%)", flush=True)

        # Reference: the trained GRU probed with the true labels.
        trained_acc, _, _ = probe_train_test(h_trained, positions, traj_ids)
        results["trained_gru"].append(float(trained_acc))

        rho_t_depth, rho_t_time = _pc1_abs_correlations(h_trained, depths, timesteps)
        results["trained_pc1_rho"].append(rho_t_depth)
        results["trained_pc1_time"].append(rho_t_time)

        print(f"Trained GRU probe: {trained_acc:.4f} ({trained_acc*100:.1f}%)"
              f"  |PC1-depth| = {rho_t_depth:.3f}"
              f"  |PC1-time| = {rho_t_time:.3f}", flush=True)

        completed_seeds.append(seed)

    n = len(results["trained_gru"])
    if n == 0:
        print("\nNo seeds completed.")
        return

    print(f"{'Baseline':<25s} {'Accuracy':>18s}  {'|PC1-depth|':>14s}  {'|PC1-time|':>14s}")

    rows = [
        ("Raw observation", results["raw_obs"], None, None),
        ("Untrained GRU", results["untrained_gru"], results["untrained_pc1_rho"], results["untrained_pc1_time"]),
        ("Selectivity (shuffled)", results["selectivity"], None, None),
        ("Trained GRU", results["trained_gru"], results["trained_pc1_rho"], results["trained_pc1_time"]),
    ]
    for label, accs, rhos, times in rows:
        acc_str = _mean_std_str(accs, 4)
        if n == 1:
            # With a single seed there is no spread; show a percentage instead.
            acc_str += f" ({accs[0]*100:.1f}%)"
        rho_str = _mean_std_str(rhos, 3) if rhos is not None else ""
        time_str = _mean_std_str(times, 3) if times is not None else ""
        print(f"{label:<25s} {acc_str:>18s}  {rho_str:>14s}  {time_str:>14s}")

    print(f"\nChance level: {chance:.4f} ({chance*100:.1f}%)")
    print(f"Memory contribution: trained ({np.mean(results['trained_gru'])*100:.1f}%) "
          f"- raw obs ({np.mean(results['raw_obs'])*100:.1f}%) "
          f"= {(np.mean(results['trained_gru']) - np.mean(results['raw_obs']))*100:.1f} pp")

    save_path = os.path.join(CHECKPOINT_DIR, "probe_baselines.pt")
    torch.save({
        "seeds": seeds,
        "completed_seeds": completed_seeds,
        "results": results,
        "chance": chance,
        "n_shuffles": N_SHUFFLES,
    }, save_path)
    print(f"\nResults saved: {save_path}")


# Script entry point: run the baselines only when executed directly, not on import.
if __name__ == "__main__":
    main()
