import concurrent.futures
import mlflow
from tqdm import tqdm
from rddleval.scripts.evaluate import evaluate_instance
import pandas as pd
from regawa.gnn import GraphAgent
import json
from regawa.rddl import register_env

from get_artifacts import get_all_paths
from get_run_ids import save_run_ids

import logging

# Register the RDDL environment at import time. evaluate_run() reads this
# module-level `env_id`; keeping registration at module scope means it is also
# performed when worker processes re-import this module (e.g. under the
# "spawn" start method).
env_id = register_env()

# Use the conventional module-name logger rather than a path-based name, so the
# logger name does not depend on how the script was invoked.
logger = logging.getLogger(__name__)


def evaluate_run(
    r: pd.Series,
    model_paths: dict[str, str],
    device: str = "cuda:0",
    instances=range(1, 11),
    episodes_per_instance: int = 10,
):
    """Evaluate the saved agent of one run on a series of RDDL instances.

    Args:
        r: One row of ``runs.csv``; must provide ``"run_id"`` and
            ``"params.domain"``.
        model_paths: Mapping from run id to the saved agent checkpoint path.
        device: Device string forwarded to ``GraphAgent.load_agent``.
        instances: Iterable of instance numbers of the run's domain to
            evaluate, in order.
        episodes_per_instance: Forwarded as the sixth positional argument of
            ``evaluate_instance``; presumably the number of evaluation
            episodes per instance — TODO confirm against rddleval.

    Returns:
        ``{run_id: {"domain": domain, "instance_returns": [...]}}`` with one
        list per instance (built from the second value returned by
        ``evaluate_instance``). Returns ``{}`` when the run has no model path
        or its agent fails to load with a ``RuntimeError``. The caller then
        leaves the run out of the results file, so it is attempted again on a
        later invocation.
    """
    run_id = r["run_id"]
    domain = r["params.domain"]

    try:
        agent_path = model_paths[run_id]
    except KeyError:
        logger.warning("No model path for run %s; skipping", run_id)
        return {}

    try:
        agent, _ = GraphAgent.load_agent(agent_path, device=device)
    except RuntimeError:
        # Keep the run id and traceback in the log; a bare message is hard to trace.
        logger.exception("Failed to load agent for run %s from %s", run_id, agent_path)
        return {}

    remove_false = True  # forwarded unchanged to evaluate_instance
    instance_list = list(instances)  # materialize so tqdm can report a correct total
    instance_returns = []
    for instance in tqdm(instance_list, desc=str(domain)):
        _, history = evaluate_instance(
            env_id,
            domain,
            instance,
            agent,
            remove_false,
            episodes_per_instance,
            verbose=False,
        )
        instance_returns.append(list(history))

    return {run_id: {"domain": domain, "instance_returns": instance_returns}}


if __name__ == "__main__":
    mlflow.set_tracking_uri("http://localhost:5000")
    get_all_paths()
    save_run_ids()

    logger.addHandler(logging.FileHandler("evaluate.log"))

    with open("evaluation_results.json", "r") as f:
        results = json.load(f)

    runs = pd.read_csv("runs.csv")
    with open("model_paths.json") as f:
        model_paths = json.load(f)

    with concurrent.futures.ProcessPoolExecutor(max_workers=3) as executor:
        d = {
            executor.submit(evaluate_run, r, model_paths): r["run_id"]
            for i, r in runs.iterrows()
            if r["run_id"] not in results
        }

        for future in concurrent.futures.as_completed(d):
            run_id = d[future]
            try:
                data = future.result()
            except Exception as exc:
                print("%r generated an exception: %s" % (run_id, exc))
            else:
                results.update(data)
    with open("evaluation_results.json", "w") as f:
        json.dump(results, f)
