from gym import spaces
import numpy as np
from agents.callbacks import DefaultCallbacks


class NMMOMetricsCallback(DefaultCallbacks):
    """Callback that publishes an episode's final info dict as metrics."""

    def on_episode_end(self, *, worker, base_env, policies, episode, **kwargs) -> None:
        """Runs when an episode is done.

        Several info sources are tried in a fixed order; the first one for
        which ``episode.last_info_for`` returns a truthy value is copied
        into ``episode.custom_metrics`` and the remaining sources are
        ignored. ``bool`` values are stored as ``int``; every other value
        is stored unchanged. If no source yields a truthy info, nothing is
        written.

        Args:
            worker: Unused.
            base_env: Unused.
            policies: Unused.
            episode: Object providing ``last_info_for`` and a mutable
                ``custom_metrics`` mapping.
            **kwargs: Unused.
        """
        # Each entry: (positional args for ``last_info_for``, whether the
        # returned info wraps the real dict in ``["_group_info"][0]``).
        # The order is the precedence order.
        lookups = (
            (("agent_0",), False),
            ((), False),  # ``last_info_for`` called with no argument
            (("high_level_policy",), False),
            (("high_level_0",), False),
            (("group_1",), True),
            ((0,), False),
        )
        for args, is_grouped in lookups:
            info = episode.last_info_for(*args)
            if not info:
                continue
            if is_grouped:
                # Grouped infos keep one dict per member; use the first.
                info = info["_group_info"][0]
            for key, value in info.items():
                episode.custom_metrics[key] = (
                    int(value) if isinstance(value, bool) else value)
            # Only the first available source is reported.
            return



class MultiAgentParameterSharingPolicyMappingFn:
    """Callable that maps every agent to the same policy id."""

    def __call__(self, agent_id, episode, worker, **kwargs):
        """Return ``"shared_policy"`` regardless of the arguments.

        Args:
            agent_id: Ignored.
            episode: Ignored.
            worker: Ignored.
            **kwargs: Ignored.

        Returns:
            The string ``"shared_policy"``.
        """
        policy_id = "shared_policy"
        return policy_id

class FeatureParser:
    """Turn raw per-agent observations into fixed-size feature arrays.

    ``parse`` produces, for every agent, the four arrays described by
    ``feature_spec``: a terrain map, a camp map, a stack of per-entity
    attribute planes and a valid-action mask.
    """

    map_size = 15
    n_actions = 5
    NEIGHBOR = [(6, 7), (8, 7), (7, 8), (7, 6)]  # north, south, east, west
    OBSTACLE = (0, 1, 5)  # lava, water, stone
    feature_spec = {
        "terrain": spaces.Box(low=0, high=6, shape=(15, 15), dtype=np.int64),
        "camp": spaces.Box(low=0, high=4, shape=(15, 15), dtype=np.int64),
        "entity": spaces.Box(low=0,
                             high=2000,
                             shape=(7, 15, 15),
                             dtype=np.float32),
        "va": spaces.Box(low=0, high=2, shape=(5, ), dtype=np.int64),
    }

    def parse(self, obs):
        """Build the feature dict for every agent in ``obs``.

        Args:
            obs: Mapping from agent id to that agent's raw observation.
                Each observation is read at ``["Tile"]["Continuous"]`` and
                ``["Entity"]["Continuous"]``; both are indexed as 2-D
                numeric arrays (presumably numpy arrays -- confirm against
                the environment).

        Returns:
            Dict mapping each agent id to a dict with the keys
            ``"terrain"``, ``"camp"``, ``"entity"`` and ``"va"``.

        Raises:
            AssertionError: If the first entity row is not located at
                window cell (7, 7).
        """
        size = self.map_size
        features = {}
        for agent_id in obs:
            agent_obs = obs[agent_id]
            terrain_map = np.zeros((size, size), dtype=np.int64)
            camp_map = np.zeros((size, size), dtype=np.int64)
            entity_planes = np.zeros((7, size, size), dtype=np.float32)
            action_mask = np.ones(self.n_actions, dtype=np.int64)

            # Terrain: columns 2 and 3 of a tile row hold its coordinates
            # and column 1 its terrain id. The first tile row's coordinates
            # are the origin of the local window.
            tiles = agent_obs["Tile"]["Continuous"]
            origin_r, origin_c = tiles[0, 2], tiles[0, 3]
            for tile_row in tiles:
                row_idx = int(tile_row[2] - origin_r)
                col_idx = int(tile_row[3] - origin_c)
                terrain_map[row_idx, col_idx] = int(tile_row[1])

            # NPCs and players. The first entity row supplies the reference
            # population (presumably the observing agent itself -- confirm)
            # and must sit on the centre cell of the window.
            entities = agent_obs["Entity"]["Continuous"]
            first_entity = entities[0]
            own_pop = first_entity[4]
            own_r, own_c = first_entity[5], first_entity[6]
            centre_r, centre_c = int(own_r - origin_r), int(own_c - origin_c)
            assert centre_r == centre_c == 7, f"({centre_r}, {centre_c})"

            for entity_row in entities:
                # Only rows whose first column equals 1 are used.
                if entity_row[0] == 1:
                    pop = entity_row[4]
                    r = int(entity_row[5] - origin_r)
                    c = int(entity_row[6] - origin_c)
                    # 2 for the reference population; otherwise 1 for a
                    # negative population id and 3 for a positive one.
                    if pop == own_pop:
                        camp_map[r, c] = 2
                    else:
                        camp_map[r, c] = np.sign(pop) + 2
                    # Plane 0: level.
                    entity_planes[0, r, c] = entity_row[3]
                    # Planes 1..6: damage, timealive, food, water, health,
                    # is_freezed.
                    entity_planes[1:, r, c] = entity_row[7:]

            # Valid actions: entries 1..4 correspond to the NEIGHBOR cells
            # and are cleared when that cell is an obstacle. Entry 0 is
            # never cleared.
            for action, (r, c) in enumerate(self.NEIGHBOR, start=1):
                if terrain_map[r, c] in self.OBSTACLE:
                    action_mask[action] = 0

            features[agent_id] = {
                "terrain": terrain_map,
                "camp": camp_map,
                "entity": entity_planes,
                "va": action_mask,
            }
        return features


class RewardParser:
    """Compute per-player rewards from changes in achievement counters."""

    @staticmethod
    def _scaled_gain(current, previous, metric, scale):
        """Return the change in ``metric`` between two dicts over ``scale``."""
        return (current[metric] - previous[metric]) / scale

    def parse(self, prev_achv, achv):
        """Return the reward earned by each player since the previous step.

        Args:
            prev_achv: Mapping from player id to the earlier achievement
                dict. Must contain every player id present in ``achv``.
            achv: Mapping from player id to the current achievement dict.
                Each achievement dict is read at the keys
                ``"PlayerDefeats"``, ``"Exploration"``, ``"Foraging"`` and
                ``"Equipment"``.

        Returns:
            Dict mapping each player id in ``achv`` to the sum of the four
            metric changes, each divided by its fixed scale (6, 127, 50
            and 20 respectively).

        Raises:
            KeyError: If a player id is missing from ``prev_achv`` or a
                metric key is missing from an achievement dict.
        """
        rewards = {}
        for pid in achv:
            current, previous = achv[pid], prev_achv[pid]
            gain = self._scaled_gain
            # Terms are added left to right in this fixed order.
            rewards[pid] = (
                gain(current, previous, "PlayerDefeats", 6.0)
                + gain(current, previous, "Exploration", 127.0)
                + gain(current, previous, "Foraging", 50.0)
                + gain(current, previous, "Equipment", 20.0)
            )
        return rewards