import json
from pathlib import Path
import torch
from regawa.gnn import GraphAgent

from regawa.rddl import register_env
import mlflow
from tqdm import tqdm
from rddleval.scripts.evaluate import evaluate_instance
from get_artifacts import get_artifact_paths


def f():
    """Re-evaluate logged GNN agents and attach per-instance stats to their MLflow runs.

    Selects MLflow runs matching ``params.using_scaling = 'True'`` through
    ``get_artifact_paths``, loads each run's agent from the returned model
    path, evaluates it on instances 1-10 of the run's ``domain`` param, and
    logs every entry of the returned stats dict back onto the same run as
    ``eval_per_instance/i{instance}__return_{key}``.
    """
    env_id = register_env()
    query = "params.using_scaling = 'True'"
    # presumably maps run_id -> model artifact path; verify against get_artifacts
    paths = get_artifact_paths(query)
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    instances = range(1, 11)

    with tqdm(paths.items(), total=len(paths)) as runs_bar:
        for run_id, model_path in runs_bar:
            # Passing run_id re-opens the existing run, so the metrics below
            # are attached to the run that produced the model.
            with mlflow.start_run(run_id=run_id) as run:
                domain = run.data.params["domain"]
                # set_description is an instance method: it must be called on
                # the bar object. Calling it on the tqdm class would bind
                # `domain` as `self` and fail.
                runs_bar.set_description(domain)
                remove_false = run.data.params["remove_false"] == "True"
                agent, _ = GraphAgent.load_agent(model_path, device=device)
                for instance in tqdm(instances, total=len(instances)):
                    # NOTE(review): the positional 10 is passed straight to
                    # evaluate_instance; presumably an episode count — confirm.
                    stats = evaluate_instance(
                        env_id, domain, instance, agent, remove_false, 10, verbose=False
                    )
                    # One batched request per instance instead of one per metric.
                    mlflow.log_metrics(
                        {
                            f"eval_per_instance/i{instance}__return_{k}": v
                            for k, v in stats.items()
                        }
                    )


def main(args: list[str]):
    """Evaluate one saved GNN agent on instances 1-10 and dump each result to JSON.

    Args:
        args: argv-style list. ``args[0]`` is ignored. It is followed by the
            domain name, the agent path, ``remove_false`` (enabled only by the
            exact string ``"True"``) and the output directory.

    Writes one ``{domain}_{instance}_gnn.json`` file per instance into the
    output directory. Each file holds whatever ``evaluate_instance`` returned
    for that instance.
    """
    domain = args[1]
    agent_path = args[2]
    remove_false = args[3] == "True"
    env_id = register_env()
    output_dir = Path(args[4])
    # parents=True so a nested output path whose parents don't exist yet also works.
    output_dir.mkdir(parents=True, exist_ok=True)
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    agent, _ = GraphAgent.load_agent(agent_path, device=device)
    instances = range(1, 11)
    for instance in instances:
        data = evaluate_instance(env_id, domain, instance, agent, remove_false)
        out_file = output_dir / f"{domain}_{instance}_gnn.json"
        # Named `fh` rather than `f` to avoid shadowing the module-level f().
        with open(out_file, "w", encoding="utf-8") as fh:
            json.dump(data, fh)


# Script entry point: runs the MLflow batch re-evaluation in f().
# NOTE(review): main(args) is never called from here; confirm whether a
# command-line path such as main(sys.argv) was meant to be reachable.
if __name__ == "__main__":
    f()
