"""Probe GRU hidden states for implicit previous-action information."""
import os
import sys

PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, PROJECT_ROOT)

import numpy as np
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import LabelEncoder

from src.gru_policy import GRUPolicy
from src.rosenberg_data import load_rosenberg_everything, build_bc_targets
from src.stats import bca_bootstrap_ci

sys.path.insert(0, os.path.join(PROJECT_ROOT, "scripts"))
from bayesian_filter_baseline import build_obs_class_map

# Number of tree nodes; structural observations are built for states
# 0 .. N_STATES - 1.
N_STATES = 127
# Size of the action space; action index a is displayed as ACTION_NAMES[a].
N_ACTIONS = 3
# Hidden size passed to GRUPolicy.
HIDDEN_DIM = 128
# Observation size passed to GRUPolicy; build_structural_obs emits
# 3 + 3 * 3 = 12 features.
OBS_DIM = 12
# Maximum number of actions per validation chunk.
MAX_SEQ_LEN = 200
# One checkpoint (checkpoints/structural_gru_model_seed{seed}.pt) is probed
# per seed; missing checkpoints are skipped.
SEEDS = [0, 1, 2, 3, 4]
# Display names for action indices 0, 1 and 2.
ACTION_NAMES = ["left", "right", "reverse"]


def build_structural_obs(node, n_states=127):
    """Build the 12-dimensional structural observation for a tree node.

    The tree is indexed heap-style: node ``i`` has children ``2i + 1`` and
    ``2i + 2`` and parent ``(i - 1) // 2``.  Node 0 is the root and every
    index from ``(n_states - 1) // 2`` upward is a leaf.

    Layout of the returned vector:
      * ``[0:3]``  -- the node itself: ``degree / 3``, ``is_root``, ``is_leaf``
      * ``[3:12]`` -- for the left, right and reverse destinations in turn:
        ``is_leaf``, ``is_root``, ``degree / 3`` (note that this order is
        the mirror image of the node's own triple).

    A move that would leave the tree (a child of a leaf, or the parent of the
    root) is treated as staying in place.

    Args:
        node: Index of the node to describe.
        n_states: Total number of nodes in the tree.

    Returns:
        A ``torch.float32`` tensor of length 12.
    """
    leaf_start = (n_states - 1) // 2

    def describe(idx):
        """Return ``(degree / 3, is_root, is_leaf)`` for node ``idx``."""
        root_flag = 1.0 if idx == 0 else 0.0
        leaf_flag = 1.0 if idx >= leaf_start else 0.0
        # The root test takes precedence over the leaf test.
        if idx == 0:
            degree = 2.0
        elif idx >= leaf_start:
            degree = 1.0
        else:
            degree = 3.0
        return degree / 3.0, root_flag, leaf_flag

    left_child = 2 * node + 1
    right_child = 2 * node + 2
    destinations = (
        left_child if left_child < n_states else node,
        right_child if right_child < n_states else node,
        (node - 1) // 2 if node > 0 else 0,
    )

    features = list(describe(node))
    for dest in destinations:
        scaled_degree, root_flag, leaf_flag = describe(dest)
        features += [leaf_flag, root_flag, scaled_degree]
    return torch.tensor(features, dtype=torch.float32)


def get_depth(node):
    """Return the depth of ``node`` in a heap-indexed binary tree.

    The parent of node ``i`` is ``(i - 1) // 2``; the depth is the number of
    parent hops needed to reach a non-positive index (the root, node 0, has
    depth 0).
    """
    def _ancestors(current):
        # Yield each successive parent until the walk reaches the root.
        while current > 0:
            current = (current - 1) // 2
            yield current

    return sum(1 for _ in _ancestors(node))


def collect_hidden_with_actions(policy, val_data):
    """Pair each GRU hidden state with the action that preceded it.

    Every chunk in ``val_data`` is replayed one observation at a time through
    ``policy.obs_encoder`` and ``policy.gru``, starting from a zero hidden
    state.  For each timestep ``t >= 1`` that has a preceding action, one
    sample is recorded:

      - ``hidden``:       h_t, the hidden state after processing obs[0..t]
      - ``prev_actions``: actions[t - 1], the action that transitioned
                          states[t - 1] -> states[t]
      - ``obs_prev`` / ``obs_curr``: observation classes of states[t - 1]
                          and states[t]
      - ``traj_ids``:     index of the chunk within ``val_data`` (a chunk
                          index, not an index into the original trajectories)
      - ``depths``:       tree depth of states[t]
      - ``timesteps``:    t, the position inside the chunk

    The policy is switched to eval mode as a side effect, and the replay runs
    under ``torch.no_grad()``.

    Args:
        policy: Model exposing ``hidden_dim``, ``obs_encoder`` and ``gru``.
        val_data: Sequence of chunk dicts with keys ``"obs"``, ``"states"``
            and ``"actions"``.

    Returns:
        Dict mapping the seven names above to NumPy arrays holding one entry
        per recorded sample.
    """
    obs_class_map, _ = build_obs_class_map()
    policy.eval()

    columns = ("hidden", "prev_actions", "obs_prev", "obs_curr",
               "traj_ids", "depths", "timesteps")
    samples = {name: [] for name in columns}

    with torch.no_grad():
        for chunk_idx, chunk in enumerate(val_data):
            chunk_states = chunk["states"]
            chunk_actions = chunk["actions"]
            n_chunk_actions = len(chunk_actions)

            # Each chunk starts from a fresh zero state of shape
            # (1, 1, hidden_dim).
            hidden = torch.zeros(1, 1, policy.hidden_dim)
            for step, state in enumerate(chunk_states):
                # Add two leading singleton axes to the single observation.
                step_obs = chunk["obs"][step][None, None]
                _, hidden = policy.gru(policy.obs_encoder(step_obs), hidden)

                # Step 0 has no previous action; later steps are kept only
                # while actions[step - 1] exists.
                if step == 0 or step > n_chunk_actions:
                    continue

                samples["hidden"].append(hidden.squeeze().numpy().copy())
                samples["prev_actions"].append(int(chunk_actions[step - 1]))
                samples["obs_prev"].append(
                    obs_class_map[chunk_states[step - 1]])
                samples["obs_curr"].append(obs_class_map[state])
                samples["traj_ids"].append(chunk_idx)
                samples["depths"].append(get_depth(state))
                samples["timesteps"].append(step)

    return {name: np.array(values) for name, values in samples.items()}


def obs_transition_features(obs_prev, obs_curr, n_classes=5):
    """One-hot encode each (obs_class_prev, obs_class_curr) pair.

    The pair ``(p, c)`` activates column ``p * n_classes + c`` of its row, so
    every distinct pair gets its own column.

    Args:
        obs_prev: 1-D sequence of integer observation classes at t - 1.
        obs_curr: 1-D sequence of integer observation classes at t, with the
            same length as ``obs_prev``.
        n_classes: Number of observation classes; valid ids are
            ``0 .. n_classes - 1``.

    Returns:
        Float array of shape ``(len(obs_prev), n_classes * n_classes)`` with
        exactly one 1.0 per row.

    Raises:
        ValueError: If the inputs are not 1-D arrays of equal length, or if
            any class id lies outside ``[0, n_classes)``.
    """
    prev = np.asarray(obs_prev)
    curr = np.asarray(obs_curr)
    if prev.ndim != 1 or prev.shape != curr.shape:
        raise ValueError(
            "obs_prev and obs_curr must be 1-D and of equal length, got "
            f"shapes {prev.shape} and {curr.shape}")

    n_samples = prev.shape[0]
    features = np.zeros((n_samples, n_classes * n_classes))
    if n_samples == 0:
        return features

    # An out-of-range id would not necessarily fail: p * n_classes + c can
    # still land inside the row and silently light up another pair's column
    # (and negative ids wrap around), so reject such ids explicitly.
    for label, values in (("obs_prev", prev), ("obs_curr", curr)):
        if values.min() < 0 or values.max() >= n_classes:
            raise ValueError(
                f"{label} contains class ids outside [0, {n_classes})")

    features[np.arange(n_samples), prev * n_classes + curr] = 1.0
    return features


def probe_accuracy(X, y, traj_ids, seed=42):
    """Fit a logistic-regression probe and score it on held-out groups.

    The distinct values of ``traj_ids`` are shuffled with
    ``np.random.RandomState(seed)``; the first 80% become training groups and
    the rest test groups.  Each sample follows its group id, so no group
    contributes to both sides of the split.

    Args:
        X: Feature matrix, one row per sample.
        y: Integer action labels, one per sample.
        traj_ids: Group id of each sample.
        seed: Seed for the group shuffle.

    Returns:
        Tuple ``(acc, per_action, test_mask)``:
          * ``acc`` -- accuracy on the test samples,
          * ``per_action`` -- dict mapping each action in
            ``range(N_ACTIONS)`` to the fraction of its test samples that
            were predicted correctly (0.0 when the action never occurs in
            the test split),
          * ``test_mask`` -- boolean array marking the test samples.
    """
    group_ids = np.unique(traj_ids)
    np.random.RandomState(seed).shuffle(group_ids)
    n_train_groups = int(0.8 * len(group_ids))

    train_mask = np.isin(traj_ids, group_ids[:n_train_groups])
    test_mask = ~train_mask

    probe = LogisticRegression(max_iter=1000, solver="lbfgs", C=1.0)
    probe.fit(X[train_mask], y[train_mask])

    X_test = X[test_mask]
    y_test = y[test_mask]
    acc = probe.score(X_test, y_test)
    preds = probe.predict(X_test)

    per_action = {}
    for action in range(N_ACTIONS):
        hits = preds[y_test == action] == action
        per_action[action] = float(hits.mean()) if hits.size > 0 else 0.0

    return acc, per_action, test_mask


def _build_val_chunks(val_trajs, structural_obs, bc_policy):
    """Cut validation trajectories into chunks of at most MAX_SEQ_LEN actions."""
    chunks = []
    for traj in val_trajs:
        states = traj["states"]
        actions = traj["actions"]
        n_actions = len(actions)
        for start in range(0, n_actions, MAX_SEQ_LEN):
            end = min(start + MAX_SEQ_LEN, n_actions)
            # Slice one state past the last action so the chunk also holds
            # the state reached by its final action.
            chunk_states = states[start:end + 1]
            chunk_actions = actions[start:end]
            chunks.append({
                "obs": torch.stack([structural_obs[s] for s in chunk_states]),
                "targets": torch.stack([bc_policy[s] for s in chunk_states]),
                "actions": torch.tensor([int(a) for a in chunk_actions],
                                        dtype=torch.long),
                "states": [int(s) for s in chunk_states],
            })
    return chunks


def _accuracy_breakdowns(preds, y_test, depths_test, timesteps_test):
    """Break test accuracy down by tree depth and by timestep window."""
    correct = preds == y_test

    # np.unique returns sorted values, so depths are inserted in order.
    depth_accs = {}
    for dep in np.unique(depths_test):
        depth_accs[dep] = float(correct[depths_test == dep].mean())

    windows = [
        ("early (t<=10)", timesteps_test <= 10),
        ("mid (10<t<=100)", (timesteps_test > 10) & (timesteps_test <= 100)),
        ("late (t>100)", timesteps_test > 100),
    ]
    temporal_accs = {}
    for label, mask in windows:
        # Windows without any test sample are left out entirely.
        if mask.any():
            temporal_accs[label] = float(correct[mask].mean())
    return depth_accs, temporal_accs


def main():
    """Run the previous-action probing experiment and save the results.

    For every seed in SEEDS whose checkpoint exists under ``checkpoints/``,
    three logistic-regression probes predict the previous action on the
    chunked validation data:

      * from the trained GRU's hidden states,
      * from one-hot (previous, current) observation-class pairs,
      * from the hidden states of a freshly initialised, untrained GRU.

    The trained probe is additionally broken down by tree depth and by
    timestep window.  Per-seed numbers and cross-seed summaries (including
    bootstrap confidence intervals) are printed, then written to
    ``checkpoints/action_probe.pt``.

    Note that the probe splits are made over the ids returned by
    collect_hidden_with_actions, which are chunk indices within the chunked
    validation data.
    """
    print("Testing whether GRU hidden states encode previous actions")
    print("(actions are NEVER part of the GRU input)\n")

    d = load_rosenberg_everything()
    bc_policy = build_bc_targets(d["train_sa"], n_states=N_STATES,
                                 n_actions=N_ACTIONS, laplace=1.0)
    structural_obs = {s: build_structural_obs(s, n_states=N_STATES)
                      for s in range(N_STATES)}
    val_data = _build_val_chunks(d["val_trajs"], structural_obs, bc_policy)

    print(f"{len(d['val_trajs'])} val trajectories, {len(val_data)} chunks\n")

    all_actions = np.array([a for chunk in val_data
                            for a in chunk["actions"].numpy().tolist()])
    print("Action distribution in validation data:")
    for a in range(N_ACTIONS):
        frac = (all_actions == a).mean()
        print(f"{ACTION_NAMES[a]}: {100 * frac:.1f}%")
    print()

    seed_results = []

    for seed in SEEDS:
        print(f"Seed {seed} ", flush=True)
        model_path = f"checkpoints/structural_gru_model_seed{seed}.pt"
        if not os.path.exists(model_path):
            print("Model not found, skipping\n")
            continue

        trained = GRUPolicy(obs_dim=OBS_DIM, hidden_dim=HIDDEN_DIM,
                            n_actions=N_ACTIONS)
        trained.load_state_dict(
            torch.load(model_path, map_location="cpu", weights_only=True))
        trained.eval()

        data = collect_hidden_with_actions(trained, val_data)
        H = data["hidden"]
        y = data["prev_actions"]
        traj_ids = data["traj_ids"]

        print(f"{len(y)} timesteps with previous actions")

        # Probe on the trained GRU's hidden states.
        gru_acc, gru_per_action, test_mask = probe_accuracy(
            H, y, traj_ids, seed=seed)

        # Baseline: what the observation-class transition alone reveals.
        obs_trans = obs_transition_features(data["obs_prev"],
                                            data["obs_curr"])
        obs_acc, obs_per_action, _ = probe_accuracy(
            obs_trans, y, traj_ids, seed=seed)

        # Baseline: randomly initialised GRU with a seed-dependent init.
        # Labels and group ids are reused because both collections walk the
        # same val_data in the same order.
        torch.manual_seed(seed + 1000)
        untrained = GRUPolicy(obs_dim=OBS_DIM, hidden_dim=HIDDEN_DIM,
                              n_actions=N_ACTIONS)
        untrained.eval()
        H_unt = collect_hidden_with_actions(untrained, val_data)["hidden"]
        unt_acc, _, _ = probe_accuracy(H_unt, y, traj_ids, seed=seed)

        # probe_accuracy does not expose its predictions, so refit the same
        # probe on the split it already reported (test_mask) rather than
        # rebuilding that split by hand.
        train_mask = ~test_mask
        clf_full = LogisticRegression(max_iter=1000, solver="lbfgs", C=1.0)
        clf_full.fit(H[train_mask], y[train_mask])
        preds_full = clf_full.predict(H[test_mask])

        depth_accs, temporal_accs = _accuracy_breakdowns(
            preds_full, y[test_mask],
            data["depths"][test_mask], data["timesteps"][test_mask])

        # ":+.1f" prints the sign itself, so a negative gap reads "-1.2"
        # instead of "+-1.2".
        print(f"obs-trans {100*obs_acc:.1f}%, untrained {100*unt_acc:.1f}%, "
              f"trained {100*gru_acc:.1f}% ({100*(gru_acc-obs_acc):+.1f} pp vs obs)")
        for a in range(N_ACTIONS):
            print(f"  {ACTION_NAMES[a]}: {100*gru_per_action[a]:.1f}% (obs-trans {100*obs_per_action[a]:.1f}%)")
        for dep in sorted(depth_accs.keys()):
            print(f"  depth {dep}: {100*depth_accs[dep]:.1f}%")
        for label, acc in temporal_accs.items():
            print(f"  {label}: {100*acc:.1f}%")

        seed_results.append({
            "seed": seed,
            "gru_acc": gru_acc,
            "obs_trans_acc": obs_acc,
            "untrained_acc": unt_acc,
            "gru_per_action": gru_per_action,
            "obs_per_action": obs_per_action,
            "depth_accs": depth_accs,
            "temporal_accs": temporal_accs,
        })
        print()

    if not seed_results:
        print("No results.")
        return

    gru_accs = np.array([r["gru_acc"] for r in seed_results])
    obs_accs = np.array([r["obs_trans_acc"] for r in seed_results])
    unt_accs = np.array([r["untrained_acc"] for r in seed_results])
    gaps = gru_accs - obs_accs

    gru_ci = bca_bootstrap_ci(gru_accs)
    obs_ci = bca_bootstrap_ci(obs_accs)
    gap_ci = bca_bootstrap_ci(gaps)

    print(f"obs-trans {100*obs_accs.mean():.1f}%, untrained {100*unt_accs.mean():.1f}%, "
          f"trained {100*gru_accs.mean():.1f}% +/- {100*gru_accs.std():.1f}%, "
          f"gap {100*gaps.mean():+.1f} pp 95% CI [{100*gap_ci[0]:+.1f}, {100*gap_ci[1]:+.1f}]")

    for a in range(N_ACTIONS):
        gru_a = np.mean([r["gru_per_action"][a] for r in seed_results])
        obs_a = np.mean([r["obs_per_action"][a] for r in seed_results])
        print(f"{ACTION_NAMES[a]}: GRU {100 * gru_a:.1f}%  "
              f"obs-trans {100 * obs_a:.1f}%  "
              f"gap {100 * (gru_a - obs_a):+.1f} pp")

    os.makedirs("checkpoints", exist_ok=True)
    save_data = {
        "seed_results": seed_results,
        "gru_accs": gru_accs,
        "obs_trans_accs": obs_accs,
        "untrained_accs": unt_accs,
        "gaps": gaps,
        "gru_ci": gru_ci,
        # Previously computed but dropped; stored alongside the other CIs.
        "obs_ci": obs_ci,
        "gap_ci": gap_ci,
    }
    torch.save(save_data, "checkpoints/action_probe.pt")
    print("\nSaved to checkpoints/action_probe.pt")


# Script entry point: run the probing experiment only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()
