"""Worker script: train one (model_type, seed) policy run with structural observations."""
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import argparse
import torch
import numpy as np

from src.gru_policy import GRUPolicy, train_gru_policy, MLPPolicy, train_mlp_policy
from src.evaluation import (
    compute_log_likelihood_gru, compute_prediction_accuracy,
    compute_log_likelihood_mlp, compute_prediction_accuracy_mlp,
    compute_per_node_accuracy, compute_per_node_accuracy_mlp,
)

# Hyperparameters shared by both model types trained in main().

# Passed as `n_actions` to the GRUPolicy / MLPPolicy constructors.
N_ACTIONS = 3
# Passed as `hidden_dim` to the GRUPolicy / MLPPolicy constructors.
HIDDEN_DIM = 128
# Passed as `n_epochs` to the training functions; main() also passes
# `print_every=N_EPOCHS + 1` (presumably to silence periodic progress
# output -- confirm against the trainers in src.gru_policy).
N_EPOCHS = 200
# Passed as `batch_size` to the training functions.
BATCH_SIZE = 64
# Default `max_len` of build_obs_dataset_structural: the maximum number of
# actions placed in a single dataset chunk.
MAX_SEQ_LEN = 200


def _iter_windows(state_ids, action_ids, max_len):
    """Yield (states, actions) slices covering the trajectory in steps of max_len.

    The action slice holds at most ``max_len`` entries; the state slice is
    taken one index further than the action slice.
    """
    n_steps = len(action_ids)
    for lo in range(0, n_steps, max_len):
        hi = min(lo + max_len, n_steps)
        yield state_ids[lo:hi + 1], action_ids[lo:hi]


def build_obs_dataset_structural(trajs, structural_obs, bc_policy, max_len=MAX_SEQ_LEN):
    """Convert trajectories into per-chunk samples of observations and BC targets.

    Each trajectory is split by its action sequence into consecutive,
    non-overlapping chunks of at most ``max_len`` actions. The state window
    of a chunk extends one index past its last action, so it can contain
    one more entry than the action chunk.

    Args:
        trajs: iterable of mappings with "states" and "actions" sequences.
        structural_obs: lookup indexed by state that returns the observation
            tensor for that state.
        bc_policy: lookup indexed by state that returns the target tensor
            for that state.
        max_len: maximum number of actions per chunk.

    Returns:
        A list of dicts with keys "obs" and "targets" (stacked tensors over
        the state window), "actions" (long tensor) and "states" (list of
        ints).
    """
    samples = []
    for trajectory in trajs:
        windows = _iter_windows(trajectory["states"], trajectory["actions"], max_len)
        for window_states, window_actions in windows:
            sample = {
                "obs": torch.stack([structural_obs[s] for s in window_states]),
                "targets": torch.stack([bc_policy[s] for s in window_states]),
                "actions": torch.tensor(
                    [int(a) for a in window_actions], dtype=torch.long
                ),
                "states": [int(s) for s in window_states],
            }
            samples.append(sample)
    return samples


def main():
    """Train one policy for a (model, seed) pair and save its validation metrics.

    Command-line arguments:
        --model: "gru" or "mlp"; selects the policy class, trainer and
            evaluation functions.
        --seed: integer used to seed the torch and NumPy RNGs.
        --ckpt: path of a checkpoint loaded with ``torch.load``; it must
            provide the keys "bc_policy", "structural_obs", "train_trajs",
            "val_trajs" and "obs_dim".
        --outdir: directory receiving ``<model>_s<seed>.pt``; created if it
            does not exist.

    Side effects: limits torch to 2 threads, writes the result file (a dict
    with keys "ll", "acc" and "per_node") and prints a one-line summary.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True, choices=["gru", "mlp"])
    parser.add_argument("--seed", type=int, required=True)
    parser.add_argument("--ckpt", required=True)
    parser.add_argument("--outdir", required=True)
    args = parser.parse_args()

    # Make sure the output directory exists *before* training, so a missing
    # path cannot make the final save fail after the whole run has finished.
    # An empty --outdir resolves to the current directory and needs no mkdir.
    if args.outdir:
        os.makedirs(args.outdir, exist_ok=True)
    out_path = os.path.join(args.outdir, f"{args.model}_s{args.seed}.pt")

    torch.set_num_threads(2)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # SECURITY: weights_only=False unpickles arbitrary Python objects, which
    # can execute code. Only pass checkpoints from a trusted source.
    ckpt = torch.load(args.ckpt, weights_only=False)
    bc_policy = ckpt["bc_policy"]
    structural_obs = ckpt["structural_obs"]
    train_trajs = ckpt["train_trajs"]
    val_trajs = ckpt["val_trajs"]
    obs_dim = ckpt["obs_dim"]

    train_data = build_obs_dataset_structural(train_trajs, structural_obs, bc_policy)
    val_data = build_obs_dataset_structural(val_trajs, structural_obs, bc_policy)

    # Both model types follow the same build/train/evaluate sequence; only
    # the callables differ, so pick them once and share the rest.
    if args.model == "gru":
        policy_cls, train_fn = GRUPolicy, train_gru_policy
        ll_fn = compute_log_likelihood_gru
        acc_fn = compute_prediction_accuracy
        per_node_fn = compute_per_node_accuracy
    else:
        policy_cls, train_fn = MLPPolicy, train_mlp_policy
        ll_fn = compute_log_likelihood_mlp
        acc_fn = compute_prediction_accuracy_mlp
        per_node_fn = compute_per_node_accuracy_mlp

    pol = policy_cls(obs_dim=obs_dim, hidden_dim=HIDDEN_DIM, n_actions=N_ACTIONS)
    # print_every is set past the last epoch (presumably to suppress the
    # trainer's periodic progress output -- confirm against the trainer).
    pol, _ = train_fn(pol, train_data, n_epochs=N_EPOCHS,
                      lr=3e-4, batch_size=BATCH_SIZE,
                      print_every=N_EPOCHS + 1)
    ll = ll_fn(pol, val_data)
    acc = acc_fn(pol, val_data)
    per_node = per_node_fn(pol, val_data)

    torch.save({"ll": ll, "acc": acc, "per_node": per_node}, out_path)
    print(f"{args.model.upper()} seed={args.seed}: LL={ll:.4f}, Acc={acc:.3f}", flush=True)


# Run the worker only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
