"""Worker script: train one leave-one-out fold (IRL + GRU + eval + optional fine-tune)."""
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import argparse
import copy
import time
import torch
import numpy as np
from scipy.stats import spearmanr
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression

from src.maxent_irl import MaxCausalEntIRL, train_irl
from src.gru_policy import GRUPolicy, train_gru_policy
from src.evaluation import compute_log_likelihood_gru, compute_prediction_accuracy
from src.analysis import collect_hidden_states_generic
from src.utils import trajectories_to_sa_pairs

# Size of the action space; passed as ``n_actions`` to GRUPolicy in main().
N_ACTIONS = 3
# Number of states; used for the IRL model and for the structural-observation
# table in main().  NOTE(review): 127 == 2**7 - 1, presumably a complete
# binary tree with 7 levels -- confirm against the environment definition.
N_STATES = 127
# Hidden size passed as ``hidden_dim`` to GRUPolicy.
HIDDEN_DIM = 128
# Default epoch counts for the --irl_epochs, --gru_epochs and --ft_epochs flags.
IRL_EPOCHS = 300
GRU_EPOCHS = 200
FT_EPOCHS = 20
# Batch size passed to train_gru_policy for both training and fine-tuning.
BATCH_SIZE = 64
# Default maximum number of actions per chunk in build_obs_dataset().
MAX_SEQ_LEN = 200


def build_structural_obs(node, n_states=127):
    """Encode the local structure around ``node`` as a 12-element feature vector.

    Nodes are indexed heap-style: the children of ``i`` are ``2*i + 1`` and
    ``2*i + 2`` and its parent is ``(i - 1) // 2``.  Node 0 is the root and
    every index ``>= (n_states - 1) // 2`` is treated as a leaf.

    Layout of the returned vector:

    * ``[0:3]``  -- the node itself: ``degree / 3``, ``is_root``, ``is_leaf``
    * ``[3:6]``  -- left destination:  ``is_leaf``, ``is_root``, ``degree / 3``
    * ``[6:9]``  -- right destination: same layout as the left destination
    * ``[9:12]`` -- reverse (parent) destination: same layout

    Note that the destination triples are ordered differently from the node's
    own triple.  A child index that falls outside the tree maps back to
    ``node`` itself, and the reverse destination of the root is the root.

    Args:
        node: Index of the node to describe.
        n_states: Total number of nodes in the tree.

    Returns:
        A 1-D ``torch.float32`` tensor with 12 entries.
    """
    leaf_start = (n_states - 1) // 2

    def describe(idx):
        """Return ``(degree / 3, is_root, is_leaf)`` for node ``idx``."""
        root_flag = 1.0 if idx == 0 else 0.0
        leaf_flag = 1.0 if idx >= leaf_start else 0.0
        # The root check takes precedence over the leaf check, so a node
        # that is both root and leaf is assigned degree 2.
        if idx == 0:
            deg = 2.0
        elif idx >= leaf_start:
            deg = 1.0
        else:
            deg = 3.0
        return deg / 3.0, root_flag, leaf_flag

    left_child = 2 * node + 1
    right_child = 2 * node + 2
    neighbours = (
        left_child if left_child < n_states else node,
        right_child if right_child < n_states else node,
        (node - 1) // 2 if node > 0 else 0,
    )

    vec = list(describe(node))
    for nb in neighbours:
        scaled_deg, root_flag, leaf_flag = describe(nb)
        vec += [leaf_flag, root_flag, scaled_deg]
    return torch.tensor(vec, dtype=torch.float32)


def build_obs_dataset(trajs, structural_obs, pi_star, max_len=MAX_SEQ_LEN):
    """Split trajectories into chunks and attach observations and soft targets.

    Each trajectory is cut into consecutive windows of at most ``max_len``
    actions.  The state slice of a window runs one index past its action
    slice (``states[start:end + 1]``), so when that extra state exists the
    window carries one more observation/target than it has actions.
    Trajectories with no actions contribute no chunks.

    Args:
        trajs: Iterable of mappings with ``"states"`` and ``"actions"``
            sequences.
        structural_obs: Mapping from state id to its observation tensor.
        pi_star: Indexable by state id, yielding the target tensor for that
            state.  NOTE(review): presumably the IRL soft-optimal policy of
            shape (n_states, n_actions) -- confirm against the caller.
        max_len: Maximum number of actions per chunk.

    Returns:
        A list of dicts with keys ``"obs"``, ``"targets"``, ``"actions"``
        (a ``torch.long`` tensor) and ``"states"`` (a list of ints).
    """
    samples = []
    for trajectory in trajs:
        all_states, all_actions = trajectory["states"], trajectory["actions"]
        n_steps = len(all_actions)

        for lo in range(0, n_steps, max_len):
            hi = min(lo + max_len, n_steps)
            window_states = all_states[lo:hi + 1]
            window_actions = all_actions[lo:hi]

            samples.append({
                "obs": torch.stack([structural_obs[s] for s in window_states]),
                "targets": torch.stack([pi_star[s] for s in window_states]),
                "actions": torch.tensor(
                    [int(a) for a in window_actions], dtype=torch.long
                ),
                "states": [int(s) for s in window_states],
            })
    return samples


def run_probing(policy, val_data, n_states=127):
    """Probe the policy's hidden states for position and depth information.

    Two statistics are computed from the hidden states collected on
    ``val_data``:

    * a logistic-regression probe predicting the position label from the
      hidden state.  The reported accuracy is measured on the same samples
      the probe was fit on (there is no held-out split), so it is an
      in-sample score.
    * the absolute Spearman correlation between the first principal
      component of the hidden states and the depth of each position, where
      depth is ``floor(log2(position + 1))``.

    Args:
        policy: Model handed to ``collect_hidden_states_generic``.
        val_data: Dataset handed to ``collect_hidden_states_generic``.
        n_states: Currently unused; kept so existing callers keep working.

    Returns:
        ``(probe_acc, abs_rho)``.  ``(0.0, 0.0)`` is returned when fewer than
        two distinct positions are present; ``abs_rho`` is ``0.0`` when the
        correlation is undefined.
    """
    hidden_states, positions, _, _ = collect_hidden_states_generic(policy, val_data)

    # A classifier needs at least two classes; treat this as "no signal".
    if len(np.unique(positions)) < 2:
        return 0.0, 0.0

    clf = LogisticRegression(max_iter=1000, solver="lbfgs")
    clf.fit(hidden_states, positions)
    probe_acc = clf.score(hidden_states, positions)

    pca = PCA(n_components=1)
    pc1 = pca.fit_transform(hidden_states).ravel()
    depths = np.array([int(np.floor(np.log2(p + 1))) for p in positions])
    rho, _ = spearmanr(pc1, depths)

    # spearmanr yields NaN when either input is constant (e.g. every visited
    # position sits at the same depth).  Report that as zero correlation,
    # consistent with the degenerate-case return above, instead of letting
    # NaN leak into the saved fold results.
    if np.isnan(rho):
        return probe_acc, 0.0

    return probe_acc, abs(rho)


def main():
    """Train and evaluate a single leave-one-out fold described by a checkpoint.

    Pipeline:

    1. Load the fold (train/test trajectories, fold metadata, transition
       tensor) from ``--ckpt``.
    2. Fit the MaxCausalEnt IRL model on the training state-action pairs and
       take its soft-optimal policy as per-state targets.
    3. Train a GRU policy on structural observations of the training
       trajectories and evaluate it on the test trajectories
       (log-likelihood, prediction accuracy, probing statistics).
    4. For ``fold_type == "cross"`` only: fine-tune a copy of the GRU on a
       seeded 80% split of the test trajectories and evaluate on the
       remaining 20%.
    5. Save the result dict to ``<outdir>/<fold_type>_<held_out>.pt`` and
       print a one-line summary.

    ``train_time`` covers steps 2-3 only; fine-tuning is not included.

    Command-line arguments:
        --ckpt: Path to the fold checkpoint (required).
        --outdir: Directory for the result file (required; created if
            missing).
        --irl_epochs / --gru_epochs / --ft_epochs: Epoch counts for the IRL,
            GRU and fine-tuning stages.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", required=True)
    parser.add_argument("--outdir", required=True)
    parser.add_argument("--irl_epochs", type=int, default=IRL_EPOCHS)
    parser.add_argument("--gru_epochs", type=int, default=GRU_EPOCHS)
    parser.add_argument("--ft_epochs", type=int, default=FT_EPOCHS)
    args = parser.parse_args()

    # Make sure the output directory exists before any training happens, so
    # a missing or unwritable path fails immediately rather than at the
    # final torch.save after the whole fold has been trained.
    os.makedirs(args.outdir, exist_ok=True)

    torch.set_num_threads(2)

    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only point --ckpt at checkpoint files produced by a trusted process.
    ckpt = torch.load(args.ckpt, weights_only=False)
    train_trajs = ckpt["train_trajs"]
    test_trajs = ckpt["test_trajs"]
    fold_type = ckpt["fold_type"]
    held_out = ckpt["held_out"]
    T_tensor = ckpt["T_tensor"]

    # Fixed seeds so that every fold is trained reproducibly.
    torch.manual_seed(42)
    np.random.seed(42)

    t0 = time.time()

    # --- Stage 1: inverse RL on the training trajectories -----------------
    train_sa = trajectories_to_sa_pairs(train_trajs)
    irl_model = MaxCausalEntIRL(N_STATES, T_tensor, gamma=0.99, n_vi_iters=200, l2_reg=0.01)
    train_irl(irl_model, train_sa, n_epochs=args.irl_epochs, lr=0.01,
              print_every=max(1, args.irl_epochs // 3))

    # Only the policy (third return value) of soft value iteration is used.
    with torch.no_grad():
        _, _, pi_star = irl_model.soft_vi(irl_model.reward_params)
    pi_star = pi_star.detach()

    # --- Stage 2: GRU policy on structural observations --------------------
    structural_obs = {s: build_structural_obs(s, N_STATES) for s in range(N_STATES)}
    obs_dim = len(structural_obs[0])

    train_data = build_obs_dataset(train_trajs, structural_obs, pi_star)
    test_data = build_obs_dataset(test_trajs, structural_obs, pi_star)

    gru = GRUPolicy(obs_dim=obs_dim, hidden_dim=HIDDEN_DIM, n_actions=N_ACTIONS)
    gru, history = train_gru_policy(gru, train_data, n_epochs=args.gru_epochs,
                                    lr=3e-4, batch_size=BATCH_SIZE,
                                    print_every=max(1, args.gru_epochs // 3))

    # --- Stage 3: evaluation on the held-out trajectories ------------------
    ll = compute_log_likelihood_gru(gru, test_data)
    acc = compute_prediction_accuracy(gru, test_data)
    probe_acc, pc1_rho = run_probing(gru, test_data)

    # Measured before fine-tuning, so it excludes the optional stage below.
    train_time = time.time() - t0

    result = {
        "fold_type": fold_type,
        "held_out": held_out,
        "ll": ll,
        "acc": acc,
        "probe_acc": probe_acc,
        "pc1_rho": pc1_rho,
        "train_time": train_time,
    }

    # --- Stage 4 (cross folds only): fine-tune on part of the test set ----
    if fold_type == "cross":
        # Fine-tune a copy so the zero-shot model above stays untouched.
        ft_gru = copy.deepcopy(gru)

        # Seeded 80/20 split of the held-out trajectories.
        n_test = len(test_trajs)
        rng = np.random.RandomState(42)
        perm = rng.permutation(n_test)
        n_ft_train = int(0.8 * n_test)
        ft_train_trajs = [test_trajs[i] for i in perm[:n_ft_train]]
        ft_val_trajs = [test_trajs[i] for i in perm[n_ft_train:]]

        ft_train_data = build_obs_dataset(ft_train_trajs, structural_obs, pi_star)
        ft_val_data = build_obs_dataset(ft_val_trajs, structural_obs, pi_star)

        ft_gru, _ = train_gru_policy(ft_gru, ft_train_data, n_epochs=args.ft_epochs,
                                     lr=1e-4, batch_size=BATCH_SIZE,
                                     print_every=max(1, args.ft_epochs // 3))

        ft_ll = compute_log_likelihood_gru(ft_gru, ft_val_data)
        ft_acc = compute_prediction_accuracy(ft_gru, ft_val_data)
        ft_probe_acc, ft_pc1_rho = run_probing(ft_gru, ft_val_data)

        result["ft_ll"] = ft_ll
        result["ft_acc"] = ft_acc
        result["ft_probe_acc"] = ft_probe_acc
        result["ft_pc1_rho"] = ft_pc1_rho

    out_path = os.path.join(args.outdir, f"{fold_type}_{held_out}.pt")
    torch.save(result, out_path)

    tag = f"{fold_type}/{held_out}"
    ft_str = f", FT_LL={result.get('ft_ll', 0):.4f}" if fold_type == "cross" else ""
    print(f"{tag}: LL={ll:.4f}, Acc={acc:.3f}, Probe={probe_acc:.3f}, "
          f"PC1-rho={pc1_rho:.3f}{ft_str}, Time={train_time:.0f}s", flush=True)


# Run the fold only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
