from dataclasses import asdict
import json
from pathlib import Path
import sys
import torch
from regawa.gnn import GraphAgent

from regawa.rddl import register_env
from tqdm import tqdm

from rddleval.eval_data import EvalEntry
from rddleval.eval import evaluate_instance


def main():
    """Evaluate a saved agent on instances 1-10 of a domain and print JSON.

    Positional command-line arguments (read from ``sys.argv``):
        1. domain: domain name forwarded to ``evaluate_instance``.
        2. run_id: run identifier stored in the emitted ``EvalEntry``.
        3. agent_path: path handed to ``GraphAgent.load_agent``.
        4. remove_false: the literal string ``"True"`` turns the flag on;
           any other value turns it off.
        5. batch_id: batch identifier stored in the emitted ``EvalEntry``.

    Arguments beyond the fifth are ignored. Each instance is evaluated for
    100 episodes, and the per-instance returns are printed to stdout as a
    single JSON object built from ``EvalEntry``.

    Raises:
        SystemExit: if fewer than five positional arguments are supplied.
    """
    args = sys.argv
    if len(args) < 6:
        # Fail with a readable message on stderr (non-zero exit status)
        # instead of an IndexError; stdout stays reserved for the JSON.
        sys.exit(
            f"usage: {args[0] if args else 'eval'} "
            "DOMAIN RUN_ID AGENT_PATH REMOVE_FALSE BATCH_ID"
        )

    domain = args[1]
    run_id = args[2]
    agent_path = args[3]
    remove_false = args[4] == "True"
    batch_id = args[5]
    episodes = 100

    # NOTE: this function used to build an output directory from args[4],
    # i.e. from the remove_false flag, which left a stray "True"/"False"
    # directory in the working directory. The directory was never written
    # to (results go to stdout), so it is no longer created.
    env_id = register_env()
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    agent, _ = GraphAgent.load_agent(agent_path, device=device)

    instances = range(1, 11)
    instance_returns: list[list[float]] = []
    for instance in tqdm(instances):
        _, r = evaluate_instance(
            env_id, domain, instance, agent, remove_false, episodes=episodes
        )
        instance_returns.append(list(r))

    print(json.dumps(asdict(EvalEntry(batch_id, run_id, domain, instance_returns))))


# Run the evaluation only when this file is executed as a script, so that
# importing the module does not trigger it.
if __name__ == "__main__":
    main()
