"""Worker script: train one (model_type, seed) for the radial arm maze."""
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import argparse
import torch
import numpy as np

from src.gru_policy import GRUPolicy, train_gru_policy, MLPPolicy, train_mlp_policy
from src.evaluation import (
    compute_log_likelihood_gru, compute_prediction_accuracy,
    compute_log_likelihood_mlp, compute_prediction_accuracy_mlp,
    compute_per_node_accuracy, compute_per_node_accuracy_mlp,
)
import src.radial_arm_env as ram

# Training hyper-parameters shared by both the GRU and the MLP runs.
HIDDEN_DIM: int = 128   # passed as `hidden_dim` to GRUPolicy / MLPPolicy
N_EPOCHS: int = 50      # number of training epochs
BATCH_SIZE: int = 64    # minibatch size used during training
MAX_SEQ_LEN: int = 200  # default chunk length used by build_obs_dataset


def build_obs_dataset(trajs, obs_encoding, max_len=MAX_SEQ_LEN):
    """Cut trajectories into consecutive chunks of at most ``max_len`` steps.

    Each trajectory is a mapping with ``"states"``, ``"actions"`` and
    ``"targets"`` entries. The number of steps is taken from
    ``len(traj["actions"])``, and the same index window is applied to all
    three entries.

    Args:
        trajs: Iterable of trajectory mappings.
        obs_encoding: Indexable lookup returning one tensor per state; the
            tensors of a chunk are combined with ``torch.stack``.
        max_len: Maximum number of steps per chunk (the final chunk of a
            trajectory may be shorter).

    Returns:
        A list with one dict per chunk, containing:
            "obs":     stacked observation encodings for the chunk's states
            "targets": ``float32`` tensor built from the chunk's targets
            "actions": ``long`` tensor of the chunk's actions cast to int
            "states":  plain Python list of the chunk's states cast to int
    """
    chunks = []
    for traj in trajs:
        traj_states, traj_actions, targets = (
            traj["states"],
            traj["actions"],
            traj["targets"],
        )
        n_steps = len(traj_actions)

        for lo in range(0, n_steps, max_len):
            hi = min(lo + max_len, n_steps)
            window = slice(lo, hi)
            seg_states = traj_states[window]
            seg_actions = traj_actions[window]
            seg_targets = targets[window]

            obs = torch.stack([obs_encoding[state] for state in seg_states])
            chunks.append(
                {
                    "obs": obs,
                    "targets": torch.tensor(seg_targets, dtype=torch.float32),
                    "actions": torch.tensor(
                        list(map(int, seg_actions)), dtype=torch.long
                    ),
                    "states": list(map(int, seg_states)),
                }
            )
    return chunks


def main():
    """Train one policy for one seed, evaluate it on validation data, save results.

    Command-line arguments:
        --model   "gru" or "mlp": which policy class / training routine to use.
        --seed    Seed for torch and numpy RNGs.
        --ckpt    Path to a torch checkpoint providing "obs_encoding",
                  "train_trajs", "val_trajs" and "obs_dim".
        --outdir  Output directory (created if missing). The result file is
                  "<model>_s<seed>.pt" and holds the validation log-likelihood
                  ("ll"), accuracy ("acc"), per-node accuracy ("per_node"),
                  the training loss history and the trained model state_dict.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True, choices=["gru", "mlp"],
                        help="policy architecture to train")
    parser.add_argument("--seed", type=int, required=True,
                        help="random seed for torch and numpy")
    parser.add_argument("--ckpt", required=True,
                        help="checkpoint with encodings and trajectories")
    parser.add_argument("--outdir", required=True,
                        help="directory for the result file (created if missing)")
    args = parser.parse_args()

    # Create the output directory before training so a bad path fails fast
    # instead of after all epochs have run.
    os.makedirs(args.outdir, exist_ok=True)

    torch.set_num_threads(2)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # NOTE: weights_only=False unpickles arbitrary objects -- only point
    # --ckpt at checkpoints you trust.
    # map_location="cpu": the policies below are built on CPU and never moved,
    # so keep the loaded encodings on CPU as well.
    ckpt = torch.load(args.ckpt, map_location="cpu", weights_only=False)
    obs_encoding = ckpt["obs_encoding"]
    train_trajs = ckpt["train_trajs"]
    val_trajs = ckpt["val_trajs"]
    obs_dim = ckpt["obs_dim"]

    train_data = build_obs_dataset(train_trajs, obs_encoding)
    val_data = build_obs_dataset(val_trajs, obs_encoding)

    # Select the architecture-specific constructor, trainer and evaluators;
    # everything after this point is shared between the two model types.
    if args.model == "gru":
        policy_cls, train_fn = GRUPolicy, train_gru_policy
        ll_fn = compute_log_likelihood_gru
        acc_fn = compute_prediction_accuracy
        per_node_fn = compute_per_node_accuracy
    else:
        policy_cls, train_fn = MLPPolicy, train_mlp_policy
        ll_fn = compute_log_likelihood_mlp
        acc_fn = compute_prediction_accuracy_mlp
        per_node_fn = compute_per_node_accuracy_mlp

    pol = policy_cls(obs_dim=obs_dim, hidden_dim=HIDDEN_DIM,
                     n_actions=ram.N_ACTIONS)
    # print_every is set past the last epoch -- presumably to silence
    # per-epoch logging inside the trainer (verify in train_*_policy).
    pol, hist = train_fn(pol, train_data, n_epochs=N_EPOCHS,
                         lr=3e-4, batch_size=BATCH_SIZE,
                         print_every=N_EPOCHS + 1)
    ll = ll_fn(pol, val_data)
    acc = acc_fn(pol, val_data)
    per_node = per_node_fn(pol, val_data)

    out_path = os.path.join(args.outdir, f"{args.model}_s{args.seed}.pt")
    # Write to a temporary file and atomically rename it, so an interrupted
    # worker never leaves a truncated result file at out_path.
    tmp_path = out_path + ".tmp"
    torch.save({
        "ll": ll, "acc": acc, "per_node": per_node,
        "loss_history": hist["loss"],
        "model_state": pol.state_dict(),
    }, tmp_path)
    os.replace(tmp_path, out_path)
    print(f"{args.model.upper()} seed={args.seed}: LL={ll:.4f}, Acc={acc:.3f}",
          flush=True)


# Run the worker only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
