import ray

from agents.callbacks import DefaultCallbacks


class EvalMainLeftPolicyMappingFn:
    """Policy mapping function that assigns every agent to ``"main_left"``."""

    def __call__(self, agent_id, episode, worker, **kwargs):
        # All arguments are ignored: the mapping is constant.
        policy_id = "main_left"
        return policy_id


class EvalMainRightPolicyMappingFn:
    """Policy mapping function that assigns every agent to ``"main_right"``."""

    def __call__(self, agent_id, episode, worker, **kwargs):
        # All arguments are ignored: the mapping is constant.
        policy_id = "main_right"
        return policy_id


class EvalLeftPopulationPolicyMappingFn:
    """Placeholder policy mapping function for the left population."""

    def __call__(self, agent_id, episode, worker, **kwargs):
        """Return the placeholder policy id ``"left_0"`` for every agent.

        Dummy assignment. The policy mapping function will be assigned in the
        callback on episode start.
        """
        return "left_0"


class EvalRightPopulationPolicyMappingFn:
    """Placeholder policy mapping function for the right population."""

    def __call__(self, agent_id, episode, worker, **kwargs):
        """Return the placeholder policy id ``"right_0"`` for every agent.

        Dummy assignment. The policy mapping function will be assigned in the
        callback on episode start.
        """
        return "right_0"


class SelfPlayMetricsCallback(DefaultCallbacks):
    """Self-play callback for an environment whose agents are keyed 0 and 1.

    Agent 0 is logged with a ``left_`` prefix and agent 1 with a ``right_``
    prefix. Before each episode a new policy mapping is fetched from the
    ``coordinator`` actor; after each episode the outcome seen by the
    ``"main"`` policy is reported back to it.
    """

    def on_episode_start(self, *, worker, base_env, policies, episode, **kwargs) -> None:
        """Install on the worker a policy mapping fn obtained from the coordinator actor."""
        coordinator = ray.get_actor("coordinator")
        policy_ids = list(policies.keys())
        new_mapping_fn = ray.get(coordinator.get_new_policy_mapping_fn.remote(policy_ids))
        worker.set_policy_mapping_fn(new_mapping_fn)

    def on_episode_end(self, *, worker, base_env, policies, episode, **kwargs) -> None:
        """Report the main policy's result to the coordinator and log info metrics.

        The result is "win", "lose" or "draw", decided by comparing the last
        rewards of the two agents from the main policy's side.

        Raises:
            ValueError: if neither agent is mapped to the ``"main"`` policy.
        """
        left_policy = worker.policy_mapping_fn(0, episode, worker)
        right_policy = worker.policy_mapping_fn(1, episode, worker)

        # Locate the agent played by "main"; the opposing policy is the "away" side.
        if left_policy == "main":
            main_agent, other_agent, away = 0, 1, right_policy
        elif right_policy == "main":
            main_agent, other_agent, away = 1, 0, left_policy
        else:
            raise ValueError("None of the matched policy is main policy.")

        main_reward = episode.last_reward_for(main_agent)
        other_reward = episode.last_reward_for(other_agent)
        if main_reward > other_reward:
            result = "win"
        elif main_reward < other_reward:
            result = "lose"
        else:
            result = "draw"

        # Fire-and-forget: the ObjectRef returned by .remote() is not awaited.
        ray.get_actor("coordinator").update.remote(away, result)

        for key, value in episode.last_info_for(0).items():
            if key == "result":
                continue
            episode.custom_metrics[f"left_{key}"] = value
        for key, value in episode.last_info_for(1).items():
            if key == "result":
                continue
            # The right side's "score" is logged negated.
            episode.custom_metrics[f"right_{key}"] = -value if key == "score" else value


class AsymmetricSelfPlayMetricsCallback(DefaultCallbacks):
    """Self-play callback for an environment with ``"left"`` and ``"right"`` agents.

    The side whose policy id starts with ``"main"`` is treated as the home side;
    its ``"result"`` info entry is reported to the ``coordinator`` actor.
    """

    def on_episode_start(self, *, worker, base_env, policies, episode, **kwargs) -> None:
        """Install on the worker a policy mapping fn obtained from the coordinator actor."""
        coordinator = ray.get_actor("coordinator")
        policy_ids = list(policies.keys())
        mapping_fn_ref = coordinator.get_new_policy_mapping_fn.remote(policy_ids)
        worker.set_policy_mapping_fn(ray.get(mapping_fn_ref))

    def on_episode_end(self, *, worker, base_env, policies, episode, **kwargs) -> None:
        """Report the home side's result to the coordinator and log info metrics.

        Raises:
            ValueError: if neither side's policy id starts with ``"main"``.
        """
        left_policy = worker.policy_mapping_fn("left", episode, worker)
        right_policy = worker.policy_mapping_fn("right", episode, worker)

        # Left is checked first, so it wins the home slot if both sides are "main*".
        if left_policy.startswith("main"):
            home_side, home, away = "left", left_policy, right_policy
        elif right_policy.startswith("main"):
            home_side, home, away = "right", right_policy, left_policy
        else:
            raise ValueError("None of the matched policy is main policy.")
        result = episode.last_info_for(home_side)["result"]

        coordinator = ray.get_actor("coordinator")
        coordinator.update.remote(home, away, result)

        left_info = episode.last_info_for("left")
        for key, value in left_info.items():
            if key == "result":
                continue
            episode.custom_metrics[f"left_{key}"] = value

        right_info = episode.last_info_for("right")
        for key, value in right_info.items():
            if key == "result":
                continue
            # The right side's "score" is logged negated.
            episode.custom_metrics[f"right_{key}"] = -value if key == "score" else value

        # NOTE: pruning of policies flagged by the coordinator is currently disabled.
        # Previous implementation, kept for reference:
        # remove_pol_ids = ray.get(coordinator.get_policy_to_remove.remote())
        # for pol_id in remove_pol_ids:
        #     if pol_id in worker.policy_map:
        #         print("========DELETE========:", pol_id)
        #         worker.remove_policy(pol_id)


class AsymmetricSelfPlayEvalCallback(DefaultCallbacks):
    """Evaluation callback logging per-side metrics for ``main_left`` vs ``main_right``."""

    # (agent id, policy id the agent is expected to be mapped to), in logging order.
    _SIDES = (("left", "main_left"), ("right", "main_right"))

    def on_episode_end(self, *, worker, base_env, policies, episode, **kwargs) -> None:
        """Log env evaluation results when an episode is done.

        For each side whose last info dict is not None:
          * the side must be mapped to its main policy (checked with ``assert``);
          * every info entry except ``"result"`` is stored in
            ``episode.custom_metrics`` as ``<side>_<key>`` (the right side's
            ``"score"`` is stored negated);
          * ``"result"`` is stored as ``<side>_win``: 1 for "win",
            0.5 for "draw", 0 for anything else.
        """
        for side, expected_policy in self._SIDES:
            info = episode.last_info_for(side)
            if info is None:
                continue
            assert worker.policy_mapping_fn(side, episode, worker) == expected_policy
            for k, v in info.items():
                if k == "result":
                    episode.custom_metrics[f"{side}_win"] = self._win_value(v)
                elif k == "score" and side == "right":
                    episode.custom_metrics[f"{side}_{k}"] = -v
                else:
                    episode.custom_metrics[f"{side}_{k}"] = v

    @staticmethod
    def _win_value(result):
        """Convert a result string to win credit: 1 for "win", 0.5 for "draw", else 0."""
        if result == "win":
            return 1
        if result == "draw":
            return 0.5
        return 0


class PopulationMetricsCallback(DefaultCallbacks):
    """Population-training callback for an environment with ``"left"``/``"right"`` agents.

    Each finished game is reported to the ``coordinator`` actor once from each
    side's perspective, and info entries are logged under the policy's name.
    """

    def on_episode_start(self, *, worker, base_env, policies, episode, **kwargs) -> None:
        """Update policy mapping of the rollout worker before each episode starts."""
        coordinator = ray.get_actor("coordinator")
        worker.set_policy_mapping_fn(
            ray.get(coordinator.get_new_policy_mapping_fn.remote(list(policies.keys())))
        )

    def on_episode_end(self, *, worker, base_env, policies, episode, **kwargs) -> None:
        """Update game results when an episode is done.

        Sends ``update(home, away, result)`` to the coordinator for the left
        side and then for the right side, with each side's own ``"result"``
        info entry. Every other info entry is stored in ``episode.custom_metrics``
        as ``<policy_name>_<key>``; the right side's ``"score"`` is stored negated.
        """
        left_policy_name = worker.policy_mapping_fn("left", episode, worker)
        right_policy_name = worker.policy_mapping_fn("right", episode, worker)
        left_info = episode.last_info_for("left")
        right_info = episode.last_info_for("right")

        # One report per side, each from that side's own perspective
        # (fire-and-forget: the returned ObjectRefs are not awaited).
        coordinator = ray.get_actor("coordinator")
        coordinator.update.remote(left_policy_name, right_policy_name, left_info["result"])
        coordinator.update.remote(right_policy_name, left_policy_name, right_info["result"])

        for k, v in left_info.items():
            if k != "result":
                episode.custom_metrics[f"{left_policy_name}_{k}"] = v
        for k, v in right_info.items():
            if k == "score":
                episode.custom_metrics[f"{right_policy_name}_{k}"] = -v
            elif k != "result":
                episode.custom_metrics[f"{right_policy_name}_{k}"] = v


class PopulationEvalCallback(DefaultCallbacks):
    """Population evaluation callback logging per-policy metrics for both sides."""

    def on_episode_start(self, *, worker, base_env, policies, episode, **kwargs) -> None:
        """Update policy mapping of the rollout worker before each episode starts."""
        coordinator = ray.get_actor("coordinator")
        worker.set_policy_mapping_fn(
            ray.get(coordinator.get_new_policy_mapping_fn.remote(list(policies.keys())))
        )

    def on_episode_end(self, *, worker, base_env, policies, episode, **kwargs) -> None:
        """Log env evaluation results when an episode is done.

        For each of the ``"left"`` and ``"right"`` agents whose last info dict
        is not None, metrics are keyed by the policy the agent is mapped to:
          * every entry except ``"result"`` is stored as ``<policy>_<key>``
            (the right side's ``"score"`` is stored negated);
          * ``"result"`` is stored as ``<policy>_win``: 1 for "win",
            0.5 for "draw", 0 for anything else.
        """
        # Both info dicts are read up front, as before, then processed left first.
        side_infos = (
            ("left", episode.last_info_for("left")),
            ("right", episode.last_info_for("right")),
        )
        for side, info in side_infos:
            if info is None:
                continue
            policy_name = worker.policy_mapping_fn(side, episode, worker)
            for k, v in info.items():
                if k == "result":
                    episode.custom_metrics[f"{policy_name}_win"] = self._win_value(v)
                elif k == "score" and side == "right":
                    episode.custom_metrics[f"{policy_name}_{k}"] = -v
                else:
                    episode.custom_metrics[f"{policy_name}_{k}"] = v

    @staticmethod
    def _win_value(result):
        """Convert a result string to win credit: 1 for "win", 0.5 for "draw", else 0."""
        if result == "win":
            return 1
        if result == "draw":
            return 0.5
        return 0
