import csv
import os
import traceback
import warnings
from datetime import datetime as dt
from pathlib import Path
from typing import Callable, Optional

os.environ["D4RL_SUPPRESS_IMPORT_ERROR"] = "1"

warnings.filterwarnings("ignore")

import d4rl

warnings.filterwarnings("ignore")

import gym
from comet_ml import Experiment
from omegaconf import DictConfig, OmegaConf
from stable_baselines3 import DDPG
from utils.evaluate import evaluate
from utils.utils import process_args


def _evaluate(
    args: DictConfig,
    experiment: Optional[Experiment] = None,
):
    """Evaluate the DDPG agent saved in ``args.logdir`` and dump metrics to CSV.

    Loads ``<logdir>/ddpg_agent.zip``, runs ``evaluate`` on
    ``args.target_env_id`` for ``args.inference_task_ids`` (100 episodes,
    ``domain_id=1``), and writes the returned metrics to ``<logdir>/eval.csv``.

    Args:
        args: Resolved config. Reads ``logdir``, ``target_env_id``,
            ``inference_task_ids`` and ``num_task_ids``; the whole object is
            also forwarded to ``evaluate``.
        experiment: Currently unused; accepted for interface compatibility.
    """
    logdir = Path(args.logdir)

    # NOTE(review): this env is never passed to evaluate(); it appears to be
    # built only as a fail-fast check that target_env_id can be created with
    # reward_type="sparse". Close it so the environment is not leaked.
    env = gym.make(args.target_env_id, reward_type="sparse")
    try:
        env.reset()
    finally:
        env.close()

    model = DDPG.load(logdir / "ddpg_agent.zip")

    metrics_dict = evaluate(
        args=args,
        env_id=args.target_env_id,
        task_ids=args.inference_task_ids,
        model=model,
        domain_id=1,
        num_task_ids=args.num_task_ids,
        n_episodes=100,
    )

    _write_metrics_csv(metrics_dict, logdir / "eval.csv")


def _write_metrics_csv(metrics_dict: dict, logfile: Path) -> None:
    """Write column-oriented metrics to ``logfile`` as CSV.

    The header row holds the dict keys. Row ``i`` holds the ``i``-th element
    of every column. The number of rows is taken from the first column.
    Values are stringified with ``str()`` and quoted by the ``csv`` module
    when needed, so commas inside values cannot break the layout.
    An empty ``metrics_dict`` produces a file containing only an empty
    header line.
    """
    keys = list(metrics_dict.keys())
    # newline="" lets csv.writer control line endings; keep "\n" as before.
    with open(logfile, "w", newline="") as f:
        writer = csv.writer(f, lineterminator="\n")
        writer.writerow(keys)
        if not keys:
            return
        n_rows = len(metrics_dict[keys[0]])
        for i in range(n_rows):
            writer.writerow([str(metrics_dict[key][i]) for key in keys])


if __name__ == "__main__":
    # Build the run configuration in layers (later layers win on conflicts):
    #   1. a built-in default for `logdir`,
    #   2. `key=value` overrides from the command line,
    #   3. the run's saved `<logdir>/config.yaml`.
    # NOTE(review): because config.yaml is merged last, any key it defines
    # overrides a CLI value for the same key -- confirm this is intended.
    defaults = OmegaConf.create(
        {"logdir": "results/test_m2m_o/m2m_o_20220917_112754"})
    overrides = OmegaConf.from_cli()
    args = OmegaConf.merge(defaults, overrides)

    saved_config_path = Path(args.logdir) / "config.yaml"
    saved_config = OmegaConf.load(saved_config_path)
    args = OmegaConf.merge(args, saved_config)
    OmegaConf.resolve(args)

    args = process_args(args, "align", args.inference_task_ids)

    # No Comet experiment is created for this standalone evaluation run.
    experiment = None

    _evaluate(args)
