import argparse
import json
import sys
from dataclasses import asdict
from typing import Iterable, Optional, Sequence

import torch
from tqdm import tqdm

from regawa import GraphAgent
from regawa.rddl import register_env
from rddleval.eval import evaluate_instance
from rddleval.eval_data import EvalEntry

# Called once at import time; the returned value is passed by run() as the
# first argument of evaluate_instance, i.e. the environment identifier.
# NOTE(review): presumably this registers the RDDL environment with a
# gym-style registry and returns its id — confirm in regawa.rddl.register_env.
env_id = register_env()


def run(
    batch_id: str,
    domain: str,
    agent_path: str,
    episodes: int,
    instances: Iterable[int] = range(1, 11),
    run_id: Optional[str] = None,
):
    """Evaluate a saved GraphAgent on a set of instances of one domain.

    Args:
        batch_id: Identifier stored on the resulting ``EvalEntry``.
        domain: Domain name forwarded to ``evaluate_instance``.
        agent_path: Path of the saved agent, loaded with
            ``GraphAgent.load_agent``. Unless ``run_id`` is given, the run id
            is the fourth ``/``-separated component of this path.
        episodes: Forwarded unchanged to ``evaluate_instance`` for every
            instance.
        instances: Instance numbers to evaluate, in order. Defaults to
            instances 1 through 10.
        run_id: Explicit run id to store on the result. When ``None`` (the
            default) it is derived from ``agent_path`` as described above.

    Returns:
        An ``EvalEntry`` whose ``instance_returns`` holds one list per
        evaluated instance, in the order given by ``instances``.

    Raises:
        ValueError: If ``run_id`` is ``None`` and ``agent_path`` has fewer
            than four ``/``-separated components.
    """
    # Resolve the run id before loading the agent and running any episodes,
    # so a malformed path fails immediately instead of after the whole
    # (potentially long) evaluation.
    if run_id is None:
        path_parts = agent_path.split("/")
        if len(path_parts) < 4:
            raise ValueError(
                f"cannot derive run_id from agent_path {agent_path!r}: expected "
                f"at least 4 '/'-separated components, got {len(path_parts)}; "
                "pass run_id explicitly"
            )
        run_id = path_parts[3]

    agent, _ = GraphAgent.load_agent(
        agent_path, device="cuda:0" if torch.cuda.is_available() else "cpu"
    )

    # Materialize once so tqdm can show progress against a known total even
    # when a one-shot iterator is passed in.
    instance_list = list(instances)
    instance_returns = []
    for instance in tqdm(instance_list, total=len(instance_list)):
        # NOTE(review): the meaning of the positional ``True`` is defined by
        # evaluate_instance's signature — confirm and pass it by keyword.
        result = evaluate_instance(
            env_id, domain, instance, agent, True, episodes, verbose=False
        )
        # Element [1] of the result is kept as this instance's returns
        # (presumably one value per episode — verify against evaluate_instance).
        instance_returns.append(list(result[1]))

    return EvalEntry(
        batch_id, run_id=run_id, domain=domain, instance_returns=instance_returns
    )


def main(argv: Optional[Sequence[str]] = None):
    """Command-line entry point: evaluate an agent and print the result as JSON.

    Usage::

        <script> BATCH_ID DOMAIN AGENT_PATH [--episodes N]

    The resulting ``EvalEntry`` is converted with ``dataclasses.asdict`` and
    printed to stdout as a single JSON line.

    Args:
        argv: Arguments to parse, excluding the program name. When ``None``
            (the default) argparse reads ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Evaluate a saved GraphAgent on instances of a domain and print "
            "the resulting EvalEntry as JSON."
        )
    )
    parser.add_argument("batch_id", help="Batch identifier stored in the output.")
    parser.add_argument("domain", help="Domain to evaluate on.")
    parser.add_argument("agent_path", help="Path of the saved agent.")
    parser.add_argument(
        "--episodes",
        type=int,
        default=100,
        help="Episodes per instance (default: 100).",
    )
    args = parser.parse_args(argv)
    if args.episodes < 1:
        # parser.error prints the usage line and exits with status 2.
        parser.error(f"--episodes must be a positive integer, got {args.episodes}")

    stats = run(args.batch_id, args.domain, args.agent_path, args.episodes)
    print(json.dumps(asdict(stats)))


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
