from agents.callbacks import DefaultCallbacks
import logging
import ray

from worker.metrics import collect_metrics

# Module-level logger shared by the callback and evaluation code in this file
# (it is also referenced by the env-logging function shipped to remote eval workers).
logger = logging.getLogger(__name__)


class PvEMetricsCallback(DefaultCallbacks):
    """Callbacks that export the final step's info dict as custom episode metrics."""

    def on_episode_end(self, *, worker, base_env, policies, episode, **kwargs) -> None:
        """Runs when an episode is done.

        Probes a fixed, ordered list of agent ids via ``episode.last_info_for``
        and copies every key/value pair of the first non-empty info dict into
        ``episode.custom_metrics``. For the ``"group_1"`` agent, the first entry
        of ``info["_group_info"]`` is copied instead of the info dict itself.
        If no candidate yields a non-empty info dict, nothing is recorded.
        """
        # (positional args for last_info_for, optional unwrapping of the info)
        # -- order matters: the first non-empty info wins.
        candidates = (
            (("agent_0",), None),
            ((), None),  # last_info_for's own default agent id
            (("high_level_policy",), None),
            (("high_level_0",), None),
            (("group_1",), lambda grouped: grouped["_group_info"][0]),
        )

        for lookup_args, unwrap in candidates:
            info = episode.last_info_for(*lookup_args)
            if not info:
                continue
            payload = info if unwrap is None else unwrap(info)
            for metric_name, metric_value in payload.items():
                episode.custom_metrics[metric_name] = metric_value
            return


class MultiAgentParameterSharingPolicyMappingFn:
    """Policy mapping callable that sends every agent to one shared policy.

    Whatever agent id, episode or worker is passed in, the policy id
    ``"shared_policy"`` is returned, so all agents share parameters.
    """

    def __call__(self, agent_id, episode, worker, **kwargs):
        # Inputs are deliberately ignored: there is only one policy to map to.
        shared_policy_id = "shared_policy"
        return shared_policy_id


class EvalFn:
    """Custom evaluation function that also reports win/draw/lose rates.

    An instance is called with ``(trainer, eval_workers)``. It runs
    ``evaluation_duration`` episodes, summarizes them with ``collect_metrics``
    and derives outcome rates from the reward histories in ``hist_stats``.
    """

    def __call__(self, trainer, eval_workers):
        """A custom evaluation function.

        Args:
            trainer (Trainer): trainer class to evaluate.
            eval_workers (WorkerSet): evaluation workers, or ``None`` to
                evaluate on the trainer's own local worker.
        Returns:
            metrics (dict): evaluation metrics dict, including
                ``eval_timesteps_this_iter`` and, when the matching reward
                histories are present and non-empty, win/draw/lose rates.
        Raises:
            ValueError: if there is nothing to evaluate on, or if
                ``evaluation_duration_unit`` is not ``"episodes"``.
        """

        if eval_workers is None and trainer.workers.local_worker().input_reader is None:
            raise ValueError(
                "Cannot evaluate w/o an evaluation worker set in "
                "the Trainer or w/o an env on the local worker!\n"
                "Try one of the following:\n1) Set "
                "`evaluation_interval` >= 0 to force creating a "
                "separate evaluation worker set.\n2) Set "
                "`create_env_on_driver=True` to force the local "
                "(non-eval) worker to have an environment to "
                "evaluate on."
            )

        def _log_env_config(env):
            # Executed on each remote eval worker: log the team setup of its env.
            if hasattr(env, "num_left_agents"):
                logger.info(
                    f"Evaluation controls {env.num_left_agents} on left and {env.num_right_agents} on right."
                )
            elif hasattr(env, "groups"):
                logger.info(f"Evaluation controls {env.groups} on right.")

        # `eval_workers` may legitimately be None (local-worker evaluation), so
        # guard before touching it. The returned object refs are not awaited.
        if eval_workers is not None:
            for w in eval_workers.remote_workers():
                w.foreach_env.remote(_log_env_config)

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if trainer.config["evaluation_duration_unit"] != "episodes":
            raise ValueError(
                "EvalFn only supports evaluation_duration_unit='episodes', got "
                f"{trainer.config['evaluation_duration_unit']!r}."
            )
        duration = (
            trainer.config["evaluation_duration"]
            if trainer.config["evaluation_duration"] != "auto"
            else (trainer.config["evaluation_num_workers"] or 1)
        )
        num_ts_run = 0

        logger.info(f"Evaluating current policy for {duration} episodes.")

        metrics = None
        if eval_workers is None:
            # No evaluation worker set: sample on the trainer's local worker,
            # which the guard above ensures has an input reader.
            # Each `sample()` call is expected to produce exactly 1 episode.
            local_worker = trainer.workers.local_worker()
            for _ in range(duration):
                num_ts_run += len(local_worker.sample())
            metrics = collect_metrics(local_worker)
        elif trainer.config["evaluation_num_workers"] == 0:
            # Evaluation worker set only has a local worker.
            # Each `sample()` call is expected to produce exactly 1 episode.
            local_worker = eval_workers.local_worker()
            for _ in range(duration):
                num_ts_run += len(local_worker.sample())
        else:
            # Evaluation worker set has n remote workers.
            num_ts_run = self._sample_remote_episodes(eval_workers, duration)

        # Collect the accumulated episodes on the workers, and then summarize
        # the episode stats into a metrics dict.
        if metrics is None:
            metrics = collect_metrics(
                eval_workers.local_worker(),
                eval_workers.remote_workers(),
            )
        metrics["eval_timesteps_this_iter"] = num_ts_run

        self._add_outcome_rates(metrics)

        if "eval_right_win_rate" in metrics:
            # Left rates may be absent (no left-side history), so use .get().
            logger.info(
                f"Iter {trainer.iteration} with left win rate = {metrics.get('eval_left_win_rate')} and "
                f"right win rate = {metrics['eval_right_win_rate']} against bots."
            )
        if "win_rate" in metrics:
            logger.info(f"Iter {trainer.iteration} with win-rate {metrics['win_rate']} against bots.")

        return metrics

    @staticmethod
    def _sample_remote_episodes(eval_workers, duration):
        """Sample ``duration`` episodes on the remote eval workers in rounds.

        Each round calls ``sample()`` on at most ``duration - done`` remote
        workers in parallel; every returned batch is expected to hold exactly
        one episode.

        Returns:
            int: total number of timesteps contained in the returned batches.
        """
        num_ts_run = 0
        num_episodes_done = 0
        round_ = 0
        while num_episodes_done < duration:
            episodes_left = duration - num_episodes_done
            # Re-query each round, in case the worker list changed.
            workers = eval_workers.remote_workers()[:episodes_left]
            if not workers:
                # Without remote workers no progress is possible; stop instead
                # of looping forever.
                logger.warning(
                    "No remote evaluation workers available; stopping after "
                    f"{num_episodes_done}/{duration} episodes."
                )
                break

            round_ += 1
            batches = ray.get([w.sample.remote() for w in workers])
            # 1 episode per returned batch.
            num_episodes_done += len(batches)
            num_ts_run += sum(len(batch) for batch in batches)

            logger.info(
                f"Ran round {round_} of parallel worker "
                f"({num_episodes_done}/{duration} episodes done)"
            )
        return num_ts_run

    @staticmethod
    def _fraction(values, predicate):
        """Return the share of ``values`` for which ``predicate`` holds.

        ``values`` must be non-empty.
        """
        return sum(1 for value in values if predicate(value)) / len(values)

    @classmethod
    def _add_outcome_rates(cls, metrics):
        """Add win/draw/lose rates derived from ``metrics["hist_stats"]`` in place.

        Evaluation is assumed to use only the scoring reward, so the sign of a
        reward encodes the match outcome. Reward histories that are missing or
        empty are skipped, since their rates are undefined.
        """
        hist_stats = metrics.get("hist_stats") or {}

        # Asymmetric self-play, left side: win > 0, draw == 0, lose < 0.
        left_rewards = hist_stats.get("policy_main_left_reward")
        if left_rewards is not None and len(left_rewards) > 0:
            metrics["eval_left_win_rate"] = cls._fraction(left_rewards, lambda r: r > 0)
            metrics["eval_left_draw_rate"] = cls._fraction(left_rewards, lambda r: r == 0)
            metrics["eval_left_lose_rate"] = cls._fraction(left_rewards, lambda r: r < 0)

        # Asymmetric self-play, right side: only win/lose are reported, and a
        # zero reward counts as a right-side win (`>= 0`).
        right_rewards = hist_stats.get("policy_main_right_reward")
        if right_rewards is not None and len(right_rewards) > 0:
            metrics["eval_right_win_rate"] = cls._fraction(right_rewards, lambda r: r >= 0)
            metrics["eval_right_lose_rate"] = cls._fraction(right_rewards, lambda r: r < 0)

        # Asymmetric PvE: classify the total episode reward.
        episode_reward = hist_stats.get("episode_reward")
        if episode_reward is not None and len(episode_reward) > 0:
            metrics["win_rate"] = cls._fraction(episode_reward, lambda r: r > 0)
            # "withdraw_rate" is the zero-reward (draw) share; the key name is
            # kept unchanged for existing consumers of the metrics dict.
            metrics["withdraw_rate"] = cls._fraction(episode_reward, lambda r: r == 0)
            metrics["lose_rate"] = cls._fraction(episode_reward, lambda r: r < 0)
