import sys
import os
import csv
import json
import argparse
import fnmatch
import re
from collections import defaultdict

import numpy as np
import torch
from baselines.common.vec_env import DummyVecEnv
from baselines.logger import HumanOutputFormat
from tqdm import tqdm

import os
import matplotlib as mpl
import matplotlib.pyplot as plt

from envs.registration import make as gym_make
from envs.multigrid.maze import *
from envs.multigrid.crossing import *
from envs.multigrid.fourrooms import *
from envs.multigrid.mst_maze import *
from envs.box2d import *
from envs.bipedalwalker import *
from envs.wrappers import (
    VecMonitor,
    VecPreprocessImageWrapper,
    ParallelAdversarialVecEnv,
    MultiGridFullyObsWrapper,
    VecFrameStack,
    CarRacingWrapper,
)
from util import DotDict, str2bool, make_agent, create_parallel_env, is_discrete_actions
from arguments import parser

"""
Example usage:

python -m eval \
--env_name=MultiGrid-SixteenRooms-v0 \
--xpid=<xpid> \
--base_path="~/logs/dcd" \
--result_path="eval_results/"
--verbose
"""


def parse_args():
    parser = argparse.ArgumentParser(description="Eval")

    parser.add_argument(
        "--base_path",
        type=str,
        default="~/logs/dcd",
        help="Base path to experiment results directories.",
    )
    parser.add_argument(
        "--xpid",
        type=str,
        default="latest",
        help="Experiment ID (result directory name) for evaluation.",
    )
    parser.add_argument(
        "--prefix",
        type=str,
        default=None,
        help="Experiment ID prefix for evaluation (evaluate all matches).",
    )
    parser.add_argument(
        "--env_names",
        type=str,
        default="MultiGrid-Labyrinth-v0",
        help="CSV string of evaluation environments.",
    )
    parser.add_argument(
        "--result_path",
        type=str,
        default="eval_results/",
        help="Relative path to evaluation results directory.",
    )
    parser.add_argument(
        "--benchmark",
        type=str,
        default=None,
        choices=["maze", "f1", "bipedal", "poetrose"],
        help="Name of benchmark for evaluation.",
    )
    parser.add_argument(
        "--accumulator",
        type=str,
        default=None,
        help="Function for accumulating across multiple evaluation runs.",
    )
    parser.add_argument(
        "--singleton_env",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="When using a fixed env, whether the same environment should also be reused across workers.",
    )
    parser.add_argument("--seed", type=int, default=1, help="Random seed.")
    parser.add_argument(
        "--max_seeds",
        type=int,
        default=None,
        help="Maximum number of matched experiment IDs to evaluate.",
    )
    parser.add_argument(
        "--num_processes", type=int, default=2, help="Number of CPU processes to use."
    )
    parser.add_argument(
        "--max_num_processes",
        type=int,
        default=10,
        help="Maximum number of CPU processes to use.",
    )
    parser.add_argument(
        "--num_episodes",
        type=int,
        default=100,
        help="Number of evaluation episodes per xpid per environment.",
    )
    parser.add_argument(
        "--model_tar", type=str, default="model", help="Name of .tar to evaluate."
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="agent",
        choices=["agent", "adversary_agent"],
        help="Which agent to evaluate.",
    )
    parser.add_argument(
        "--deterministic",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Evaluate policy greedily.",
    )
    parser.add_argument(
        "--verbose",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Show logging messages in stdout",
    )
    parser.add_argument(
        "--render",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Render environment in first evaluation process to screen.",
    )
    parser.add_argument(
        "--record_video",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Record video of first environment evaluation process.",
    )

    return parser.parse_args()


class Evaluator(object):
    def __init__(
        self,
        env_names,
        num_processes,
        num_episodes=10,
        record_video=False,
        device="cpu",
        **kwargs,
    ):
        self.kwargs = kwargs  # kwargs for env wrappers
        self._init_parallel_envs(
            env_names, num_processes, device=device, record_video=record_video, **kwargs
        )
        self.num_episodes = num_episodes
        if "Bipedal" in env_names[0]:
            self.solved_threshold = 230
        else:
            self.solved_threshold = 0

    def get_stats_keys(self):
        keys = []
        for env_name in self.env_names:
            keys += [f"solved_rate:{env_name}", f"test_returns:{env_name}"]
        return keys

    @staticmethod
    def make_env(env_name, record_video=False, **kwargs):
        if env_name in ["BipedalWalker-v3", "BipedalWalkerHardcore-v3"]:
            env = gym.make(env_name)
        else:
            env = gym_make(env_name)

        is_multigrid = env_name.startswith("MultiGrid")
        is_car_racing = env_name.startswith("CarRacing")

        if is_car_racing:
            grayscale = kwargs.get("grayscale", False)
            num_action_repeat = kwargs.get("num_action_repeat", 8)
            nstack = kwargs.get("frame_stack", 4)
            crop = kwargs.get("crop_frame", False)

            env = CarRacingWrapper(
                env=env,
                grayscale=grayscale,
                reward_shaping=False,
                num_action_repeat=num_action_repeat,
                nstack=nstack,
                crop=crop,
                eval_=True,
            )

            if record_video:
                from gym.wrappers.monitor import Monitor

                env = Monitor(env, "videos/", force=True)
                print("Recording video!", flush=True)

        if is_multigrid and kwargs.get("use_global_policy"):
            env = MultiGridFullyObsWrapper(env, is_adversarial=False)

        return env

    @staticmethod
    def wrap_venv(venv, env_name, device="cpu"):
        is_multigrid = env_name.startswith("MultiGrid") or env_name.startswith("MiniGrid")
        is_car_racing = env_name.startswith("CarRacing")
        is_bipedal = env_name.startswith("BipedalWalker")

        obs_key = None
        scale = None
        if is_multigrid:
            obs_key = "image"
            scale = 10.0

        # Channels first
        transpose_order = [2, 0, 1]

        if is_bipedal:
            transpose_order = None

        venv = VecMonitor(venv=venv, filename=None, keep_buf=100)

        venv = VecPreprocessImageWrapper(
            venv=venv,
            obs_key=obs_key,
            transpose_order=transpose_order,
            scale=scale,
            device=device,
        )

        return venv

    def _init_parallel_envs(
        self, env_names, num_processes, device=None, record_video=False, **kwargs
    ):
        self.env_names = env_names
        self.num_processes = num_processes
        self.device = device
        self.venv = {env_name: None for env_name in env_names}

        make_fn = []
        for env_name in env_names:
            make_fn = [
                lambda: Evaluator.make_env(env_name, record_video, **kwargs)
            ] * self.num_processes
            venv = ParallelAdversarialVecEnv(make_fn, adversary=False, is_eval=True)
            venv = Evaluator.wrap_venv(venv, env_name, device=device)
            self.venv[env_name] = venv

        self.is_discrete_actions = is_discrete_actions(self.venv[env_names[0]])

    def close(self):
        for _, venv in self.venv.items():
            venv.close()

    def evaluate(
        self,
        agent,
        deterministic=False,
        show_progress=False,
        render=False,
        accumulator="mean",
    ):

        # Evaluate agent for N episodes
        venv = self.venv
        env_returns = {}
        env_solved_episodes = {}

        for env_name, venv in self.venv.items():
            returns = []
            solved_episodes = 0

            obs = venv.reset()
            recurrent_hidden_states = torch.zeros(
                self.num_processes,
                agent.algo.actor_critic.recurrent_hidden_state_size,
                device=self.device,
            )
            if (
                agent.algo.actor_critic.is_recurrent
                and agent.algo.actor_critic.rnn.arch == "lstm"
            ):
                recurrent_hidden_states = (
                    recurrent_hidden_states,
                    torch.zeros_like(recurrent_hidden_states),
                )
            masks = torch.ones(self.num_processes, 1, device=self.device)

            pbar = None
            if show_progress:
                pbar = tqdm(total=self.num_episodes)

            while len(returns) < self.num_episodes:
                # Sample actions
                with torch.no_grad():
                    _, action, _, recurrent_hidden_states = agent.act(
                        obs, recurrent_hidden_states, masks, deterministic=deterministic
                    )

                # Observe reward and next obs
                action = action.cpu().numpy()
                if not self.is_discrete_actions:
                    action = agent.process_action(action)
                obs, reward, done, infos = venv.step(action)

                masks = torch.tensor(
                    [[0.0] if done_ else [1.0] for done_ in done],
                    dtype=torch.float32,
                    device=self.device,
                )

                for i, info in enumerate(infos):
                    if "episode" in info.keys():
                        returns.append(info["episode"]["r"])
                        if returns[-1] > self.solved_threshold:
                            solved_episodes += 1
                        if pbar:
                            pbar.update(1)

                        # zero hidden states
                        if agent.is_recurrent:
                            recurrent_hidden_states[0][i].zero_()
                            recurrent_hidden_states[1][i].zero_()

                        if len(returns) >= self.num_episodes:
                            break

                if render:
                    venv.render_to_screen()

            if pbar:
                pbar.close()

            env_returns[env_name] = returns
            env_solved_episodes[env_name] = solved_episodes

        stats = {}
        for env_name in self.env_names:
            if accumulator == "mean":
                stats[f"solved_rate:{env_name}"] = (
                    env_solved_episodes[env_name] / self.num_episodes
                )

            if accumulator == "mean":
                stats[f"test_returns:{env_name}"] = np.mean(env_returns[env_name])
            else:
                stats[f"test_returns:{env_name}"] = env_returns[env_name]

        return stats


def _get_f1_env_names():
    env_names = [
        f"CarRacingF1-{name}-v0"
        for name, cls in formula1.__dict__.items()
        if isinstance(cls, RaceTrack)
    ]
    env_names.remove("CarRacingF1-LagunaSeca-v0")
    return env_names


def _get_zs_minigrid_env_names():
    env_names = [
        "MultiGrid-SixteenRooms-v0",
        "MultiGrid-SixteenRoomsFewerDoors-v0" "MultiGrid-Labyrinth-v0",
        "MultiGrid-Labyrinth2-v0",
        "MultiGrid-Maze-v0",
        "MultiGrid-Maze2-v0",
        "MultiGrid-LargeCorridor-v0",
        "MultiGrid-PerfectMazeMedium-v0",
        "MultiGrid-PerfectMazeLarge-v0",
        "MultiGrid-PerfectMazeXL-v0",
    ]
    return env_names


def _get_bipedal_env_names():
    env_names = [
        "BipedalWalker-v3",
        "BipedalWalkerHardcore-v3",
        "BipedalWalker-Med-Stairs-v0",
        "BipedalWalker-Med-PitGap-v0",
        "BipedalWalker-Med-StumpHeight-v0",
        "BipedalWalker-Med-Roughness-v0",
    ]
    return env_names


def _get_poet_rose_env_names():
    env_names = [
        f"BipedalWalker-POET-Rose-{id}-v0" for id in ["1a", "1b", "2a", "2b", "3a", "3b"]
    ]
    return env_names


if __name__ == "__main__":
    os.environ["OMP_NUM_THREADS"] = "1"

    display = None
    if sys.platform.startswith("linux"):
        print("Setting up virtual display")

        import pyvirtualdisplay

        display = pyvirtualdisplay.Display(visible=0, size=(1400, 900), color_depth=24)
        display.start()

    args = DotDict(vars(parse_args()))
    args.num_processes = min(args.num_processes, args.num_episodes)

    # === Determine device ====
    device = "cpu"

    # === Load checkpoint ===
    # Load meta.json into flags object
    base_path = os.path.expandvars(os.path.expanduser(args.base_path))

    xpids = [args.xpid]
    if args.prefix is not None:
        all_xpids = fnmatch.filter(os.listdir(base_path), f"{args.prefix}*")
        filter_re = re.compile(".*_[0-9]*$")
        xpids = [x for x in all_xpids if filter_re.match(x)]

    # Set up results management
    os.makedirs(args.result_path, exist_ok=True)
    if args.prefix is not None:
        result_fname = args.prefix
    else:
        result_fname = args.xpid
    result_fname = f"{result_fname}-{args.model_tar}-{args.model_name}"
    result_fpath = os.path.join(args.result_path, result_fname)
    if os.path.exists(f"{result_fpath}.csv"):
        result_fpath = os.path.join(args.result_path, f"{result_fname}_redo")
    result_fpath = f"{result_fpath}.csv"

    csvout = open(result_fpath, "w", newline="")
    csvwriter = csv.writer(csvout)

    env_results = defaultdict(list)

    # Get envs
    if args.benchmark == "maze":
        env_names = _get_zs_minigrid_env_names()
    elif args.benchmark == "f1":
        env_names = _get_f1_env_names()
    elif args.benchmark == "bipedal":
        env_names = _get_bipedal_env_names()
    elif args.benchmark == "poetrose":
        env_names = _get_poet_rose_env_names()
    else:
        env_names = args.env_names.split(",")

    num_envs = len(env_names)
    if num_envs * args.num_processes > args.max_num_processes:
        chunk_size = args.max_num_processes // args.num_processes
    else:
        chunk_size = num_envs

    num_chunks = int(np.ceil(num_envs / chunk_size))

    if args.record_video:
        num_chunks = 1
        chunk_size = 1
        args.num_processes = 1

    num_seeds = 0
    for xpid in xpids:
        if args.max_seeds is not None and num_seeds >= args.max_seeds:
            break

        xpid_dir = os.path.join(base_path, xpid)
        meta_json_path = os.path.join(xpid_dir, "meta.json")

        model_tar = f"{args.model_tar}.tar"
        checkpoint_path = os.path.join(xpid_dir, model_tar)

        if os.path.exists(checkpoint_path):
            meta_json_file = open(meta_json_path)
            xpid_flags = DotDict(json.load(meta_json_file)["args"])

            make_fn = [lambda: Evaluator.make_env(env_names[0])]
            dummy_venv = ParallelAdversarialVecEnv(make_fn, adversary=False, is_eval=True)
            dummy_venv = Evaluator.wrap_venv(
                dummy_venv, env_name=env_names[0], device=device
            )

            # Load the agent
            agent = make_agent(name="agent", env=dummy_venv, args=xpid_flags, device=device)

            try:
                checkpoint = torch.load(checkpoint_path, map_location="cpu")
            except:
                continue
            model_name = args.model_name

            if "runner_state_dict" in checkpoint:
                agent.algo.actor_critic.load_state_dict(
                    checkpoint["runner_state_dict"]["agent_state_dict"][model_name]
                )
            else:
                agent.algo.actor_critic.load_state_dict(checkpoint)

            num_seeds += 1

            # Evaluate environment batch in increments of chunk size
            for i in range(num_chunks):
                start_idx = i * chunk_size
                env_names_ = env_names[start_idx : start_idx + chunk_size]

                # Evaluate the model
                xpid_flags.update(args)
                xpid_flags.update({"use_skip": False})

                evaluator = Evaluator(
                    env_names_,
                    num_processes=args.num_processes,
                    num_episodes=args.num_episodes,
                    frame_stack=xpid_flags.frame_stack,
                    grayscale=xpid_flags.grayscale,
                    use_global_critic=xpid_flags.use_global_critic,
                    record_video=args.record_video,
                )

                stats = evaluator.evaluate(
                    agent,
                    deterministic=args.deterministic,
                    show_progress=args.verbose,
                    render=args.render,
                    accumulator=args.accumulator,
                )

                for k, v in stats.items():
                    if args.accumulator:
                        env_results[k].append(v)
                    else:
                        env_results[k] += v

                evaluator.close()
        else:
            print(f"No model path {checkpoint_path}")

    output_results = {}
    for k, _ in stats.items():
        results = env_results[k]
        output_results[k] = f"{np.mean(results):.2f} +/- {np.std(results):.2f}"
        q1 = np.percentile(results, 25, interpolation="midpoint")
        q3 = np.percentile(results, 75, interpolation="midpoint")
        median = np.median(results)
        output_results[f"iq_{k}"] = f"{q1:.2f}--{median:.2f}--{q3:.2f}"
        print(f"{k}: {output_results[k]}")
    HumanOutputFormat(sys.stdout).writekvs(output_results)

    if args.accumulator:
        csvwriter.writerow(
            [
                "metric",
            ]
            + [x for x in range(num_seeds)]
        )
    else:
        csvwriter.writerow(
            [
                "metric",
            ]
            + [x for x in range(num_seeds * args.num_episodes)]
        )
    for k, v in env_results.items():
        row = [
            k,
        ] + v
        csvwriter.writerow(row)

    if display:
        display.stop()
