from agents.callbacks import DefaultCallbacks


class PvEMetricsCallback(DefaultCallbacks):
    """Callback that publishes an episode's final info dict as custom metrics."""

    def on_episode_end(self, *, worker, base_env, policies, episode, **kwargs) -> None:
        """Copy the last info dict of one agent into ``episode.custom_metrics``.

        Agent ids are probed in a fixed priority order and the first one whose
        ``episode.last_info_for(...)`` result is truthy wins:

        1. ``"agent_0"``
        2. no agent id (whatever ``last_info_for()`` defaults to)
        3. ``"high_level_policy"``
        4. ``"high_level_0"``
        5. ``"group_1"`` -- here the metrics are read from the first element
           of the ``"_group_info"`` entry rather than from the dict itself.

        Every key/value pair of the selected info dict is written to
        ``episode.custom_metrics``; ``bool`` values are stored as ``int``.
        If no probe yields a truthy info, nothing is written.

        Args:
            worker: Unused.
            base_env: Unused.
            policies: Unused.
            episode: Object providing ``last_info_for`` and a mutable
                ``custom_metrics`` mapping.
            **kwargs: Unused.

        Raises:
            KeyError, IndexError: If the ``"group_1"`` info is truthy but has
                no ``"_group_info"`` entry, or that entry is empty.
        """
        # Each tuple holds the positional args for one ``last_info_for`` probe;
        # the empty tuple is the no-argument call.
        probes = (("agent_0",), (), ("high_level_policy",), ("high_level_0",))

        for probe_args in probes:
            info = episode.last_info_for(*probe_args)
            if info:
                break
        else:
            # No plain agent id produced an info dict; fall back to the group.
            # NOTE(review): presumably "_group_info" lists one info dict per
            # grouped agent and only the first is reported -- confirm.
            grouped = episode.last_info_for("group_1")
            if not grouped:
                return
            info = grouped["_group_info"][0]

        for key, value in info.items():
            # Booleans are stored as 0/1 instead of True/False.
            episode.custom_metrics[key] = int(value) if isinstance(value, bool) else value


class MultiAgentParameterSharingPolicyMappingFn:
    """Callable that maps every agent to the same policy id."""

    def __call__(self, agent_id, episode, worker, **kwargs):
        """Return the policy id for an agent.

        All arguments are ignored; the result is always ``"shared_policy"``.
        """
        shared_policy_id = "shared_policy"
        return shared_policy_id
