"""Worker script: train one (arch, seed) policy run for architecture baselines."""
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import argparse
import time
import torch
import numpy as np
from scipy.stats import spearmanr
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression

from src.gru_policy import GRUPolicy, train_gru_policy
from src.architectures import ARCH_REGISTRY
from src.evaluation import (
    compute_log_likelihood_gru, compute_prediction_accuracy,
    compute_per_node_accuracy,
)
from src.analysis import collect_hidden_states_generic

# Run-wide hyperparameters shared by every (arch, seed) worker.
N_ACTIONS: int = 3      # size of the policy's action output (passed as n_actions)
HIDDEN_DIM: int = 64    # recurrent/hidden width passed to every architecture
N_EPOCHS: int = 200     # default for --n_epochs
BATCH_SIZE: int = 64    # minibatch size handed to train_gru_policy
MAX_SEQ_LEN: int = 200  # max actions per training chunk in build_obs_dataset_structural


def build_obs_dataset_structural(trajs, structural_obs, bc_policy, max_len=MAX_SEQ_LEN):
    """Cut trajectories into chunks of at most ``max_len`` actions for sequence training.

    Args:
        trajs: iterable of dicts, each holding "states" and "actions" sequences.
        structural_obs: indexable by state; entries are tensors stacked into "obs".
        bc_policy: indexable by state; entries are tensors stacked into "targets".
        max_len: maximum number of actions per chunk.

    Returns:
        A list of dicts with keys:
            "obs"     -- torch.stack of structural_obs[s] for the chunk's states,
            "targets" -- torch.stack of bc_policy[s] for the chunk's states,
            "actions" -- LongTensor of the chunk's actions,
            "states"  -- list of the chunk's states as Python ints.
        The chunk's states are taken as states[lo:hi + 1], so they include one
        entry past the last action when the trajectory has it (presumably the
        successor state; assumes len(states) == len(actions) + 1 -- TODO confirm).
        Trajectories with no actions contribute no chunks.
    """
    chunks = []
    for traj in trajs:
        traj_states, traj_actions = traj["states"], traj["actions"]
        n_steps = len(traj_actions)

        for lo in range(0, n_steps, max_len):
            hi = min(lo + max_len, n_steps)
            # States run one index further than actions (see docstring).
            window_states = traj_states[lo:hi + 1]
            window_actions = traj_actions[lo:hi]

            chunks.append({
                "obs": torch.stack([structural_obs[st] for st in window_states]),
                "targets": torch.stack([bc_policy[st] for st in window_states]),
                "actions": torch.tensor(list(map(int, window_actions)), dtype=torch.long),
                "states": list(map(int, window_states)),
            })
    return chunks


def run_probing(policy, val_data, n_states=127):
    """Probe a policy's hidden states for position and depth structure.

    Hidden states and their positions are gathered with
    ``collect_hidden_states_generic(policy, val_data)``.

    Returns:
        ``(probe_acc, abs_rho)`` where
        - probe_acc is the accuracy of a logistic-regression classifier that
          predicts position from hidden state. It is scored on the same samples
          it was fit on, and is 0.0 when fewer than two distinct positions occur.
        - abs_rho is |Spearman rho| between the hidden states' first principal
          component and depth = floor(log2(position + 1)). This is the level of
          index p in a 0-indexed binary-heap layout, presumably how positions are
          numbered; confirm. scipy returns NaN for constant input, e.g. a single
          depth, and NaN is passed through.

    ``n_states`` is accepted for interface compatibility but is not used.
    """
    feats, node_ids, _, _ = collect_hidden_states_generic(policy, val_data)

    # A classifier cannot be fit with a single class label.
    if len(np.unique(node_ids)) >= 2:
        probe = LogisticRegression(max_iter=1000, solver="lbfgs")
        probe.fit(feats, node_ids)
        probe_acc = probe.score(feats, node_ids)
    else:
        probe_acc = 0.0

    first_pc = PCA(n_components=1).fit_transform(feats).ravel()
    node_depths = np.array([int(np.floor(np.log2(node + 1))) for node in node_ids])
    rho, _pval = spearmanr(first_pc, node_depths)

    return probe_acc, abs(rho)


def main():
    """Train one (arch, seed) policy, evaluate it on the validation split, and save metrics.

    Command-line arguments:
        --arch      architecture name; "gru" builds GRUPolicy, the others are
                    looked up in ARCH_REGISTRY.
        --seed      seed for the torch and numpy RNGs.
        --ckpt      checkpoint file holding "bc_policy", "structural_obs",
                    "train_trajs", "val_trajs" and "obs_dim".
        --outdir    directory for the result file "<arch>_s<seed>.pt"; it is
                    created if it does not exist.
        --n_epochs  number of training epochs (default N_EPOCHS).

    The saved dict holds the validation log-likelihood, accuracy, per-node
    accuracy, probe accuracy, |PC1-depth Spearman rho|, wall-clock training
    time, parameter count and history["loss"] from train_gru_policy. A one-line
    summary is printed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--arch", required=True, choices=["gru", "lstm", "gtrxl", "mamba", "mingru"])
    parser.add_argument("--seed", type=int, required=True)
    parser.add_argument("--ckpt", required=True)
    parser.add_argument("--outdir", required=True)
    parser.add_argument("--n_epochs", type=int, default=N_EPOCHS)
    args = parser.parse_args()

    # Ensure the output directory exists before the (long) training run.
    # Previously torch.save failed only after training if it was missing.
    os.makedirs(args.outdir, exist_ok=True)

    # Cap intra-op threads, presumably so several workers can share one machine.
    torch.set_num_threads(2)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # NOTE(review): weights_only=False unpickles arbitrary Python objects;
    # only pass checkpoints from a trusted source.
    ckpt = torch.load(args.ckpt, weights_only=False)
    bc_policy = ckpt["bc_policy"]
    structural_obs = ckpt["structural_obs"]
    train_trajs = ckpt["train_trajs"]
    val_trajs = ckpt["val_trajs"]
    obs_dim = ckpt["obs_dim"]

    train_data = build_obs_dataset_structural(train_trajs, structural_obs, bc_policy)
    val_data = build_obs_dataset_structural(val_trajs, structural_obs, bc_policy)

    if args.arch == "gru":
        pol = GRUPolicy(obs_dim=obs_dim, hidden_dim=HIDDEN_DIM, n_actions=N_ACTIONS)
    else:
        Cls = ARCH_REGISTRY[args.arch]
        pol = Cls(obs_dim=obs_dim, hidden_dim=HIDDEN_DIM, n_actions=N_ACTIONS)

    n_params = sum(p.numel() for p in pol.parameters())

    t0 = time.time()
    # print_every > n_epochs, presumably to silence per-epoch progress output.
    pol, history = train_gru_policy(pol, train_data, n_epochs=args.n_epochs,
                                    lr=3e-4, batch_size=BATCH_SIZE,
                                    print_every=args.n_epochs + 1)
    train_time = time.time() - t0

    ll = compute_log_likelihood_gru(pol, val_data)
    acc = compute_prediction_accuracy(pol, val_data)
    per_node = compute_per_node_accuracy(pol, val_data)

    probe_acc, pc1_rho = run_probing(pol, val_data)

    out_path = os.path.join(args.outdir, f"{args.arch}_s{args.seed}.pt")
    torch.save({
        "ll": ll,
        "acc": acc,
        "per_node": per_node,
        "probe_acc": probe_acc,
        "pc1_rho": pc1_rho,
        "train_time": train_time,
        "n_params": n_params,
        "loss_history": history["loss"],
    }, out_path)
    print(f"{args.arch.upper()} seed={args.seed}: LL={ll:.4f}, Acc={acc:.3f}, "
          f"Probe={probe_acc:.3f}, PC1-rho={pc1_rho:.3f}, "
          f"Time={train_time:.0f}s, Params={n_params}", flush=True)


if __name__ == "__main__":
    main()
