from regawa.rl.util import evaluate, save_eval_data
import gymnasium as gym
from dataclasses import dataclass
import numpy as np
from regawa.gnn import GraphAgent


@dataclass()
class Args:
    """Settings for one evaluation run of a saved graph agent.

    Field order is significant: instances may be built positionally.
    """

    # Environment id passed to ``gym.make``.
    env_id: str
    # Domain name forwarded to the environment constructor.
    domain: str
    # Instance number within ``domain``.
    instance: int
    # Path of the checkpoint handed to ``GraphAgent.load_agent``.
    agent_file: str
    # Output path *without* extension; ".json" is appended when saving.
    output_file: str
    # Forwarded unchanged to the environment constructor.
    remove_false: bool = False


def main(args: Args, num_seeds: int = 75):
    """Evaluate a saved graph agent deterministically over several seeds.

    Loads the agent from ``args.agent_file`` and builds the environment with
    ``gym.make``. It runs one ``evaluate`` call per seed in ``range(num_seeds)``
    and saves the raw results to ``f"{args.output_file}.json"``. It then prints
    summary statistics of the per-episode returns (sum of rewards), plus the
    average of each episode's mean per-step reward.

    Args:
        args: Run configuration.
        num_seeds: Number of seeds/episodes to evaluate (seeds ``0..num_seeds-1``).

    Raises:
        ValueError: If ``num_seeds`` is less than 1, since no statistics can be
            computed from zero episodes.
    """
    if num_seeds < 1:
        raise ValueError(f"num_seeds must be >= 1, got {num_seeds}")

    agent, _config = GraphAgent.load_agent(args.agent_file)
    eval_env = gym.make(
        args.env_id,
        domain=args.domain,
        instance=args.instance,
        remove_false=args.remove_false,
    )
    # Always release the environment, even if an evaluation episode raises.
    try:
        data = [
            evaluate(eval_env, agent, seed, deterministic=True)
            for seed in range(num_seeds)
        ]
    finally:
        eval_env.close()

    # NOTE(review): assumes each ``evaluate`` result is a tuple whose first
    # element is that episode's sequence of per-step rewards — confirm.
    rewards, *_ = zip(*data)
    avg_mean_reward = float(np.mean([np.mean(r) for r in rewards]))
    returns = [np.sum(r) for r in rewards]
    save_eval_data(data, f"{args.output_file}.json")

    # Convert to built-in floats so the printed dict is readable under NumPy >= 2
    # (which otherwise prints reprs like ``np.float64(1.0)``).
    stats = {
        "mean": float(np.mean(returns)),
        "median": float(np.median(returns)),
        "min": float(np.min(returns)),
        "max": float(np.max(returns)),
        "std": float(np.std(returns)),
    }

    print(stats)
    print(f"avg_reward: {avg_mean_reward}")


if __name__ == "__main__":
    from regawa.rddl import register_env

    # Fixed configuration for a single evaluation run. ``register_env()``
    # supplies the env id later passed to ``gym.make`` inside ``main``.
    run_args = Args(
        env_id=register_env(),
        domain="AcademicAdvising_ippc2018",
        instance=1,
        agent_file="AcademicAdvising_ippc2018__1__ppo_gnn__0.pth",
        output_file="eval_data",
        remove_false=True,
    )
    main(run_args)
