import os
import csv
from typing import List, Dict, Any

import numpy as np
import torch
import matplotlib.pyplot as plt

from envs.registration import make as gym_make
from envs.wrappers import (
    VecMonitor,
    VecPreprocessImageWrapper,
    ParallelAdversarialVecEnv,
    MultiGridFullyObsWrapper,
    CarRacingWrapper,
)
from util import is_discrete_actions

try:
    import gymnasium as gym
except ImportError:
    import gym


def _make_raw_env(env_name: str, seed: int, record_video: bool = False, **kwargs):
    """Create one unvectorised environment and seed it.

    The two stock BipedalWalker ids are built with ``gym.make``; every other
    id goes through the project's ``gym_make``.  ``MultiGrid*`` environments
    are additionally wrapped in ``MultiGridFullyObsWrapper`` when the
    ``use_global_policy`` keyword is truthy.

    Args:
        env_name: Environment id.
        seed: Passed to ``env.reset(seed=...)`` and to the ``seed`` method of
            the action and observation spaces when they expose one.
        record_video: Accepted but not used by this function.
        **kwargs: Only ``use_global_policy`` is read here.

    Returns:
        The constructed environment, already reset once with ``seed``.
    """
    stock_ids = ["BipedalWalker-v3", "BipedalWalkerHardcore-v3"]
    env = gym.make(env_name) if env_name in stock_ids else gym_make(env_name)

    wants_full_obs = env_name.startswith("MultiGrid") and kwargs.get(
        "use_global_policy"
    )
    if wants_full_obs:
        env = MultiGridFullyObsWrapper(env, is_adversarial=False)

    env.reset(seed=seed)

    # Seed the action space first, then the observation space.
    for space_attr in ("action_space", "observation_space"):
        space = getattr(env, space_attr)
        if hasattr(space, "seed"):
            space.seed(seed)

    return env


def _save_hardcore_layout_images(
    env_name: str,
    num_seeds: int,
    seed_start: int,
    out_dir: str,
    **kwargs,
):
    """Save one rendered frame per seed for ``BipedalWalkerHardcore-v3``.

    For any other ``env_name`` this is a no-op.  Otherwise ``out_dir`` is
    created if needed and, for every seed in
    ``seed_start .. seed_start + num_seeds - 1``, a fresh environment is
    built with ``_make_raw_env``, rendered once, closed, and the frame is
    written to ``<out_dir>/<env_name>_seed<seed:04d>.png``.

    Args:
        env_name: Environment id; only the Hardcore id triggers any work.
        num_seeds: Number of consecutive seeds to render.
        seed_start: First seed.
        out_dir: Directory that receives the PNG files.
        **kwargs: Forwarded unchanged to ``_make_raw_env``.

    Returns:
        None.
    """
    if env_name != "BipedalWalkerHardcore-v3":
        return

    os.makedirs(out_dir, exist_ok=True)
    print(f"[Eval] Saving Hardcore layouts to: {out_dir}")

    for seed in range(seed_start, seed_start + num_seeds):
        env = _make_raw_env(env_name, seed, record_video=False, **kwargs)
        # Always release the environment, even if rendering fails with
        # something other than the TypeError handled below.
        try:
            try:
                frame = env.render()
            except TypeError:
                # Fallback for render APIs that take the mode as an argument.
                frame = env.render("rgb_array")
        finally:
            env.close()

        # Some render APIs hand back a sequence of frames; keep the first
        # one and treat an empty sequence like "no frame".
        if isinstance(frame, (list, tuple)):
            frame = frame[0] if frame else None
        if frame is None:
            # NOTE(review): a render() that returns None (e.g. because no
            # image render mode was configured) makes this seed be skipped
            # silently -- confirm that is acceptable.
            continue

        fname = f"{env_name}_seed{seed:04d}.png"
        fpath = os.path.join(out_dir, fname)
        plt.imsave(fpath, frame)

    print("[Eval] Hardcore layouts saved.")


def _make_eval_env(env_name: str, device: str, record_video: bool, **kwargs):
    """Build one evaluation environment plus flags describing its family.

    The two stock BipedalWalker ids are created with ``gym.make``; all other
    ids use the project's ``gym_make``.  ``MultiGrid*`` environments get the
    ``MultiGridFullyObsWrapper`` when ``use_global_policy`` is truthy, and
    ``CarRacing*`` environments are wrapped in ``CarRacingWrapper`` (and, if
    ``record_video`` is set, in a ``Monitor`` that writes to ``videos/``).

    Args:
        env_name: Environment id.
        device: Accepted but not used by this function.
        record_video: Only has an effect for ``CarRacing*`` environments.
        **kwargs: Reads ``use_global_policy``, ``grayscale``,
            ``num_action_repeat`` (default 8), ``frame_stack`` (default 4)
            and ``crop_frame``.

    Returns:
        ``(env, is_multigrid, is_car_racing, is_bipedal, is_lava)`` where the
        four booleans are prefix tests on ``env_name``.
    """
    is_multigrid, is_car_racing, is_bipedal, is_lava = (
        env_name.startswith(prefix)
        for prefix in ("MultiGrid", "CarRacing", "BipedalWalker", "Lava")
    )

    # ---- base env ----
    if env_name in ("BipedalWalker-v3", "BipedalWalkerHardcore-v3"):
        env = gym.make(env_name)
    else:
        env = gym_make(env_name)

    if is_multigrid and kwargs.get("use_global_policy"):
        env = MultiGridFullyObsWrapper(env, is_adversarial=False)

    if not is_car_racing:
        return env, is_multigrid, is_car_racing, is_bipedal, is_lava

    # ---- CarRacing-specific wrapping ----
    env = CarRacingWrapper(
        env=env,
        grayscale=kwargs.get("grayscale", False),
        reward_shaping=False,
        num_action_repeat=kwargs.get("num_action_repeat", 8),
        nstack=kwargs.get("frame_stack", 4),
        crop=kwargs.get("crop_frame", False),
        eval_=True,
    )

    if record_video:
        from gym.wrappers.monitor import Monitor

        env = Monitor(env, "videos/", force=True)
        print("Recording video!", flush=True)

    return env, is_multigrid, is_car_racing, is_bipedal, is_lava


def _build_parallel_env(
    env_name: str,
    num_processes: int,
    device: str,
    record_video: bool,
    **kwargs,
):
    """Assemble the vectorised evaluation environment for ``env_name``.

    ``num_processes`` constructor callables (each delegating to
    ``_make_eval_env``) are handed to ``ParallelAdversarialVecEnv``; the
    result is wrapped in ``VecMonitor`` and then in
    ``VecPreprocessImageWrapper``.

    Preprocessing settings chosen from the name prefix:
      * ``MultiGrid*`` / ``MiniGrid*``: ``obs_key="image"``, ``scale=10.0``;
        otherwise both are ``None``.
      * ``BipedalWalker*`` / ``Lava*``: no transpose; otherwise the
        transpose order is ``[2, 0, 1]``.

    Args:
        env_name: Environment id.
        num_processes: Number of constructor callables to create.
        device: Forwarded to ``_make_eval_env`` and the preprocessing wrapper.
        record_video: Forwarded to ``_make_eval_env``.
        **kwargs: Forwarded to ``_make_eval_env``.

    Returns:
        The fully wrapped vectorised environment.
    """
    env_fns = []
    for _ in range(num_processes):

        # Per-call values are bound as defaults; ``device`` is closed over.
        def _thunk(name=env_name, video=record_video, extra=kwargs):
            built = _make_eval_env(name, device=device, record_video=video, **extra)
            return built[0]

        env_fns.append(_thunk)

    venv = ParallelAdversarialVecEnv(env_fns, adversary=False, is_eval=True)

    grid_like = env_name.startswith(("MultiGrid", "MiniGrid"))
    skip_transpose = env_name.startswith(("BipedalWalker", "Lava"))

    obs_key, scale = ("image", 10.0) if grid_like else (None, None)
    transpose_order = None if skip_transpose else [2, 0, 1]

    monitored = VecMonitor(venv=venv, filename=None, keep_buf=100)
    return VecPreprocessImageWrapper(
        venv=monitored,
        obs_key=obs_key,
        transpose_order=transpose_order,
        scale=scale,
        device=device,
    )


def _evaluate_single_env_fixed_seeds(
    env_name: str,
    venv,
    agent,
    num_seeds: int,
    seed_start: int,
    num_processes: int,
    episodes_per_seed: int,
    device: str,
    render_first_batch: bool,
    solved_threshold: float,
):
    """Roll out ``agent`` on a fixed range of seeds and summarise the returns.

    Seeds ``seed_start .. seed_start + num_seeds - 1`` are processed in
    consecutive batches of ``num_processes``.  Every batch starts with
    ``venv.reset_with_seeds(batch)`` and is stepped (always with
    ``deterministic=True``) until each environment slot has reported
    ``episodes_per_seed`` finished episodes.  A slot's episode is counted
    when its info dict contains an ``"episode"`` entry; that entry's ``"r"``
    field is taken as the episode return.  Slots that reach their quota keep
    being stepped with the rest of the batch, but their further episodes are
    ignored.

    NOTE(review): only the first episode of a slot is known to start from the
    requested seed; whether later episodes reuse it depends on how ``venv``
    resets itself -- confirm before relying on ``episodes_per_seed > 1``.

    Args:
        env_name: Not used by this function; kept for the keyword interface.
        venv: Vectorised environment providing ``reset_with_seeds``, ``step``
            and, when rendering, ``render_to_screen``.
        agent: Object providing ``act``, ``process_action`` and
            ``algo.actor_critic``.
        num_seeds: Number of consecutive seeds; must be a multiple of
            ``num_processes``.
        seed_start: First seed.
        num_processes: Batch size (number of seeds evaluated together).
        episodes_per_seed: Episodes to record for every seed.
        device: Torch device specifier for hidden states and masks.
        render_first_batch: Call ``venv.render_to_screen()`` after every step
            of the first batch.
        solved_threshold: An episode counts as solved when its return is
            strictly greater than this value.

    Returns:
        ``(per_seed_returns, mean_return, std_return, solved_rate)``:
        a dict mapping each seed to its list of returns, the mean and
        standard deviation of the per-seed mean returns (0.0 when nothing was
        recorded), and the fraction of recorded episodes that were solved.

    Raises:
        ValueError: If ``num_processes`` is not positive or does not divide
            ``num_seeds``.
    """
    # Explicit checks instead of ``assert`` so they survive ``python -O``.
    if num_processes <= 0:
        raise ValueError(f"num_processes must be positive, got {num_processes}")
    if num_seeds % num_processes != 0:
        raise ValueError(
            f"num_seeds ({num_seeds}) must be a multiple of "
            f"num_processes ({num_processes})"
        )

    all_seeds = list(range(seed_start, seed_start + num_seeds))
    per_seed_returns: Dict[int, list] = {seed: [] for seed in all_seeds}

    total_episodes = 0
    total_solved = 0

    device = torch.device(device)
    actor_critic = agent.algo.actor_critic
    hidden_size = actor_critic.recurrent_hidden_state_size
    is_recurrent = getattr(actor_critic, "is_recurrent", False) or getattr(
        agent, "is_recurrent", False
    )

    is_disc = is_discrete_actions(venv)

    for batch_start in range(0, num_seeds, num_processes):
        seeds_batch = all_seeds[batch_start : batch_start + num_processes]
        num_envs = len(seeds_batch)
        render = render_first_batch and batch_start == 0

        obs = venv.reset_with_seeds(seeds_batch)

        # Recurrent policies receive a pair of zero-initialised states,
        # non-recurrent ones a single zero tensor.
        rh = torch.zeros(num_envs, hidden_size, device=device)
        if is_recurrent:
            rh = (rh, torch.zeros_like(rh))

        masks = torch.ones(num_envs, 1, device=device)
        episodes_done = [0] * num_envs
        target_total_episodes = num_envs * episodes_per_seed

        while sum(episodes_done) < target_total_episodes:
            with torch.no_grad():
                _, action, _, rh = agent.act(
                    obs,
                    rh,
                    masks,
                    deterministic=True,
                )

            action_np = action.cpu().numpy()
            if not is_disc:
                action_np = agent.process_action(action_np)

            obs, reward, done, infos = venv.step(action_np)

            # Mask is 0.0 for slots whose episode just ended, 1.0 otherwise.
            masks = torch.tensor(
                [[0.0] if d else [1.0] for d in done],
                dtype=torch.float32,
                device=device,
            )

            if render:
                venv.render_to_screen()

            for i, info in enumerate(infos):
                if "episode" not in info:
                    continue
                if episodes_done[i] >= episodes_per_seed:
                    # Quota reached: ignore extra episodes from this slot.
                    continue

                ret = info["episode"]["r"]
                per_seed_returns[seeds_batch[i]].append(ret)
                episodes_done[i] += 1
                total_episodes += 1

                # Strict comparison: with a threshold of 0.0 a zero-return
                # episode must not be counted as solved.
                if ret > solved_threshold:
                    total_solved += 1

                if is_recurrent:
                    # Clear this slot's recurrent state for its next episode.
                    rh[0][i].zero_()
                    rh[1][i].zero_()

    # The quota check above already caps every list at episodes_per_seed,
    # so no trimming is needed here.
    means = [
        np.mean(per_seed_returns[seed])
        for seed in all_seeds
        if per_seed_returns[seed]
    ]

    if means:
        mean_return = float(np.mean(means))
        std_return = float(np.std(means))
    else:
        mean_return = 0.0
        std_return = 0.0

    solved_rate = total_solved / total_episodes if total_episodes else 0.0

    return per_seed_returns, mean_return, std_return, solved_rate


# ---------------------- Evaluator class ---------------------- #


class Evaluator(object):
    """Evaluate an agent on a fixed seed range for several environments.

    One vectorised environment per name in ``env_names`` is built at
    construction time and reused by every :meth:`evaluate` call.  Each call
    can append one summary row per environment to a CSV file.
    """

    # Column order of the CSV written by ``_write_eval_csv_rows``.
    _CSV_FIELDNAMES = (
        "eval_index",
        "env_name",
        "mean_return",
        "std_return",
        "solved_rate",
        "num_seeds",
        "seed_start",
        "episodes_per_seed",
        "num_processes",
    )

    def __init__(
        self,
        env_names: List[str],
        num_processes: int,
        num_episodes: int = 128,
        record_video: bool = False,
        device: str = "cpu",
        seed_start: int = 0,
        episodes_per_seed: int = 1,
        save_hardcore_layouts: bool = False,
        hardcore_layout_dir: str = None,
        save_eval_csv: bool = True,
        eval_csv_dir: str = "./eval_results",
        eval_csv_filename: str = "fixed_seed_eval_results.csv",
        **kwargs,
    ):
        """Store the configuration and build one vectorised env per name.

        Args:
            env_names: Environment ids to evaluate on.
            num_processes: Number of environments stepped in parallel.
            num_episodes: Used by ``evaluate`` as the number of seeds.
            record_video: Forwarded to the environment builder.
            device: Torch device specifier.
            seed_start: First evaluation seed.
            episodes_per_seed: Episodes recorded for every seed.
            save_hardcore_layouts: Enable layout image export in ``evaluate``.
            hardcore_layout_dir: Output directory for layout images; export
                is skipped while this is ``None``.
            save_eval_csv: Enable CSV logging.
            eval_csv_dir: Directory of the CSV file.
            eval_csv_filename: File name of the CSV file.
            **kwargs: Forwarded to ``_build_parallel_env``.
        """
        self.env_names = env_names
        self.num_processes = num_processes
        self.num_episodes = num_episodes
        self.record_video = record_video
        self.device = device
        self.seed_start = seed_start
        self.episodes_per_seed = episodes_per_seed
        self.save_hardcore_layouts = save_hardcore_layouts
        self.hardcore_layout_dir = hardcore_layout_dir

        self.save_eval_csv = save_eval_csv
        self.eval_csv_dir = eval_csv_dir
        self.eval_csv_filename = eval_csv_filename
        self.eval_csv_path = os.path.join(self.eval_csv_dir, self.eval_csv_filename)
        self._eval_call_idx = 0  # number of completed evaluate() calls

        # Extra environment options passed through to the env builders.
        self.kwargs = kwargs

        self.venv: Dict[str, Any] = {}
        try:
            for env_name in self.env_names:
                self.venv[env_name] = _build_parallel_env(
                    env_name=env_name,
                    num_processes=self.num_processes,
                    device=self.device,
                    record_video=self.record_video,
                    **self.kwargs,
                )
        except Exception:
            # Do not leak the environments that were already built when a
            # later one fails to construct.
            self.close()
            raise

    def get_stats_keys(self):
        """Return the keys of the dict produced by ``evaluate``, in order."""
        keys = []
        for env_name in self.env_names:
            keys += [
                f"solved_rate:{env_name}",
                f"test_returns:{env_name}",
                f"test_returns_std:{env_name}",
            ]
        return keys

    def close(self):
        """Close every vectorised environment held by this evaluator."""
        for v in self.venv.values():
            v.close()

    # ----- CSV -----

    def _write_eval_csv_rows(self, rows: List[Dict[str, Any]]):
        """Append ``rows`` to the CSV file, writing the header when needed.

        Does nothing when CSV logging is disabled or ``rows`` is empty.  The
        header is written when the file does not exist yet or exists but is
        empty.
        """
        if not self.save_eval_csv or not rows:
            return

        os.makedirs(self.eval_csv_dir, exist_ok=True)

        # An existing zero-byte file still needs a header row.
        needs_header = (
            not os.path.exists(self.eval_csv_path)
            or os.path.getsize(self.eval_csv_path) == 0
        )

        with open(self.eval_csv_path, "a", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=list(self._CSV_FIELDNAMES))
            if needs_header:
                writer.writeheader()
            writer.writerows(rows)

    # ----- evaluate -----

    def evaluate(
        self,
        agent,
        deterministic: bool = True,
        show_progress: bool = False,
        render: bool = False,
        accumulator: str = "mean",
    ) -> Dict[str, Any]:
        """Evaluate ``agent`` on every configured environment.

        Args:
            agent: Agent passed to the fixed-seed rollout helper.
            deterministic: Accepted for interface compatibility; ignored.
            show_progress: Accepted for interface compatibility; ignored.
            render: Render the first batch of the first environment.
            accumulator: Accepted for interface compatibility; ignored.

        Returns:
            A dict with ``solved_rate:<env>``, ``test_returns:<env>`` and
            ``test_returns_std:<env>`` for every environment name.
        """
        # These parameters exist only so callers can keep passing them.
        _ = deterministic
        _ = show_progress
        _ = accumulator

        stats: Dict[str, Any] = {}
        csv_rows: List[Dict[str, Any]] = []

        num_seeds = self.num_episodes
        seed_start = self.seed_start

        for idx, env_name in enumerate(self.env_names):
            # Return threshold handed to the rollout helper for "solved".
            if "Bipedal" in env_name:
                solved_threshold = 230.0
            else:
                solved_threshold = 0.0

            venv = self.venv[env_name]

            per_seed_returns, mean_ret, std_ret, solved_rate = (
                _evaluate_single_env_fixed_seeds(
                    env_name=env_name,
                    venv=venv,
                    agent=agent,
                    num_seeds=num_seeds,
                    seed_start=seed_start,
                    num_processes=self.num_processes,
                    episodes_per_seed=self.episodes_per_seed,
                    device=self.device,
                    render_first_batch=(render and idx == 0),
                    solved_threshold=solved_threshold,
                )
            )

            stats[f"solved_rate:{env_name}"] = solved_rate
            stats[f"test_returns:{env_name}"] = mean_ret
            stats[f"test_returns_std:{env_name}"] = std_ret

            csv_rows.append(
                {
                    "eval_index": self._eval_call_idx,
                    "env_name": env_name,
                    "mean_return": mean_ret,
                    "std_return": std_ret,
                    "solved_rate": solved_rate,
                    "num_seeds": num_seeds,
                    "seed_start": seed_start,
                    "episodes_per_seed": self.episodes_per_seed,
                    "num_processes": self.num_processes,
                }
            )

            if (
                self.save_hardcore_layouts
                and env_name == "BipedalWalkerHardcore-v3"
                and self.hardcore_layout_dir is not None
            ):
                _save_hardcore_layout_images(
                    env_name=env_name,
                    num_seeds=num_seeds,
                    seed_start=seed_start,
                    out_dir=self.hardcore_layout_dir,
                )

        self._write_eval_csv_rows(csv_rows)
        self._eval_call_idx += 1

        return stats
