from gym.spaces import Discrete, Box, Dict
import random
import copy
import numpy as np
import nmmo
from ijcai2022nmmo import CompetitionConfig
from ijcai2022nmmo.env.team_based_env import TeamBasedEnv
from ijcai2022nmmo.scripted import CombatTeam, ForageTeam, RandomTeam
from ijcai2022nmmo.scripted.baselines import Scripted
from ijcai2022nmmo.scripted.scripted_team import ScriptedTeam
from env.multi_agent_env import MultiAgentEnv
from env.nmmo.utils import FeatureParser, RewardParser
from env.nmmo.config import P4A4Config, P4A8Config, P4A16Config

class RLlibNMMO2Env(MultiAgentEnv):
    """Wraps an IJCAI 2022 Neural MMO env to be compatible with RLlib.

    Team ``TT_ID`` is controlled by the learner; every other team is driven
    by a scripted policy (see ``reset_scripted_team``). The learner picks one
    ``Discrete(5)`` action per agent: 0 is a no-op and 1-4 map to
    ``nmmo.action.Move`` with direction ``val - 1``. When
    ``use_auxiliary_script`` is enabled, an ``AttackTeam`` script merges its
    own actions into those move decisions.
    """

    def __init__(self, **nmmo_args):
        """Create a new multi-agent Neural MMO env compatible with RLlib.

        Args:
            nmmo_args (dict): Wrapper options. Recognized keys:
                ``num_of_controlled_agents`` (required; 4, 8 or 16) selects
                the P4A4/P4A8/P4A16 config passed to
                ``ijcai2022nmmo.env.team_based_env.TeamBasedEnv``;
                ``max_step`` (optional, default 1024) is the number of
                steps after which every agent is reported done;
                ``use_auxiliary_script`` (optional, default False) enables
                the scripted auxiliary attack policy.

        Raises:
            KeyError: if ``num_of_controlled_agents`` is missing.
            NotImplementedError: if ``num_of_controlled_agents`` is not
                4, 8 or 16.
        """
        self.num_agents = nmmo_args['num_of_controlled_agents']
        config_classes = {4: P4A4Config, 8: P4A8Config, 16: P4A16Config}
        if self.num_agents not in config_classes:
            raise NotImplementedError(
                f"unsupported num_of_controlled_agents: {self.num_agents!r} "
                "(expected 4, 8 or 16)")
        self.config = config_classes[self.num_agents]()
        self.num_pop = self.config.NPOP
        self._env = TeamBasedEnv(self.config)
        self.use_auxiliary_script = nmmo_args.get('use_auxiliary_script', False)
        self.max_step = nmmo_args.get('max_step', 1024)
        self.TT_ID = 0  # training team index
        self._ready_agents = []

        self.feature_parser = FeatureParser()
        self.reward_parser = RewardParser()

        self.observation_space = Dict(self.feature_parser.feature_spec)
        self.action_space = Discrete(5)
        # All-zero observation used as padding for agents that no longer
        # receive an observation from the env.
        self._dummy_feature = {
            key: np.zeros(shape=val.shape, dtype=val.dtype)
            for key, val in self.feature_parser.feature_spec.items()
        }

    def reset(self):
        """Reset the env and return observations for the training team.

        Also resets (or lazily builds) the scripted opponent teams and the
        optional auxiliary script, and records the training team's starting
        metrics, which ``step`` diffs against to compute rewards.

        Returns:
            obs (dict): Parsed observation per agent of team ``TT_ID``.
        """
        raw_obs = self._env.reset()  # observations keyed by team id
        obs = self.feature_parser.parse(raw_obs[self.TT_ID])

        self.reset_auxiliary_script(self.config)
        self.reset_scripted_team(self.config)
        # Agents present at reset; step() keeps reporting an entry for each
        # of them until the episode ends.
        self.agents = list(obs.keys())
        self._prev_achv = self._env.metrices_by_team()[self.TT_ID]
        self._prev_raw_obs = raw_obs
        self._step = 0
        return obs

    def step(self, actions):
        """Advance the env by one step.

        Args:
            actions (dict): ``Discrete(5)`` action per training-team agent.

        Returns:
            tuple: ``(obs, reward, done, info)``. ``obs``, ``reward`` and
            ``done`` are per-agent dicts covering every agent seen at reset;
            ``done`` also has an ``"__all__"`` entry. ``info`` is a deep
            copy of the training team's metrics (not keyed per agent).
        """
        decisions = self.get_scripted_team_decision(self._prev_raw_obs)
        decisions[self.TT_ID] = self.transform_action(
            actions,
            observations=self._prev_raw_obs[self.TT_ID],
            auxiliary_script=self.auxiliary_script)

        raw_obs, _, raw_done, _ = self._env.step(decisions)
        if self.TT_ID in raw_obs:
            obs = self.feature_parser.parse(raw_obs[self.TT_ID])
            # Copy so the padding below does not mutate the env's own dict.
            done_tmp = dict(raw_done[self.TT_ID])
            achv = self._env.metrices_by_team()[self.TT_ID]
            reward = self.reward_parser.parse(self._prev_achv, achv)
            self._prev_achv = achv
            new_info = copy.deepcopy(achv)
        else:
            # No observation for the training team this step (presumably all
            # of its agents are gone); every agent gets padded below.
            obs, reward, done_tmp = {}, {}, {}
            new_info = copy.deepcopy(self._prev_achv)

        # Agents without a fresh observation get a zero observation, zero
        # reward and done=True. Each gets its own array copies so in-place
        # writes downstream cannot corrupt the shared template.
        for agent_id in self.agents:
            if agent_id not in obs:
                obs[agent_id] = {
                    key: val.copy() for key, val in self._dummy_feature.items()
                }
                reward[agent_id] = 0
                done_tmp[agent_id] = True

        self._prev_raw_obs = raw_obs
        self._step += 1

        if self._step >= self.max_step:
            done = {key: True for key in done_tmp}
        else:
            # Individual agents are only reported done once the whole
            # training team is done.
            team_done = all(done_tmp.values())
            done = {key: team_done for key in done_tmp}
        done["__all__"] = all(done.values())
        return obs, reward, done, new_info

    def reset_auxiliary_script(self, config):
        """Reset, build or clear the auxiliary attack script.

        When ``use_auxiliary_script`` is False the script is set to None.
        Otherwise an existing script is reset, or an ``AttackTeam`` is
        created on first use.
        """
        if not self.use_auxiliary_script:
            self.auxiliary_script = None
            return
        if getattr(self, "auxiliary_script", None) is not None:
            self.auxiliary_script.reset()
            return
        self.auxiliary_script = AttackTeam("auxiliary", config)

    def reset_scripted_team(self, config):
        """Reset the scripted opponent teams, creating them on first call.

        Every team index except ``TT_ID`` gets a scripted team. Indices in
        ``(TT_ID, NPOP - 2]`` get a ``ForageTeam``; the rest (indices below
        ``TT_ID`` and the last index ``NPOP - 1``) get a ``CombatTeam``.
        """
        if getattr(self, "_scripted_team", None) is not None:
            for team in self._scripted_team.values():
                team.reset()
            return
        self._scripted_team = {}
        assert config.NPOP == self.num_pop
        for i in range(config.NPOP):
            if i == self.TT_ID:
                continue
            if self.TT_ID < i <= config.NPOP - 2:
                self._scripted_team[i] = ForageTeam(f"forage-{i}", config)
            else:
                self._scripted_team[i] = CombatTeam(f"combat-{i}", config)

    def get_scripted_team_decision(self, observations):
        """Return ``{team_id: actions}`` for every non-training team present."""
        return {
            team_id: self._scripted_team[team_id].act(obs)
            for team_id, obs in observations.items()
            if team_id != self.TT_ID
        }

    @staticmethod
    def transform_action(actions, observations=None, auxiliary_script=None):
        """Convert network move actions (plus optional scripted attacks) to env decisions.

        Args:
            actions (dict): ``agent_id -> int`` in ``[0, 4]``; 0 is a no-op
                and 1-4 become ``nmmo.action.Move`` with direction ``val - 1``.
            observations (dict, optional): Raw team observations. If given,
                actions for agents missing from it are dropped.
            auxiliary_script (optional): Object whose ``act(observations)``
                returns per-agent action dicts, which are merged into the
                move decisions.

        Returns:
            dict: ``agent_id -> {action: {arg: value}}``.

        Raises:
            ValueError: if an action value is outside ``[0, 4]``.
        """
        decisions = {}

        # move decisions
        for agent_id, val in actions.items():
            if observations is not None and agent_id not in observations:
                continue
            if val == 0:
                decisions[agent_id] = {}
            elif 1 <= val <= 4:
                decisions[agent_id] = {
                    nmmo.action.Move: {
                        nmmo.action.Direction: val - 1
                    }
                }
            else:
                raise ValueError(f"invalid action: {val}")

        # attack decisions
        if auxiliary_script is not None:
            assert observations is not None
            attack_decisions = auxiliary_script.act(observations)
            # Merge in place; an agent the script produced nothing for
            # keeps just its move decision.
            for agent_id, d in decisions.items():
                d.update(attack_decisions.get(agent_id, {}))
        return decisions

    def close(self):
        """Close the wrapped env."""
        self._env.close()

    def seed(self, seed=None):
        """Seed Python's and NumPy's global RNGs (the wrapped env is not seeded here)."""
        random.seed(seed)
        np.random.seed(seed)
        
class Attack(Scripted):
    '''Scripted agent policy that runs an attack step with the Range style.

    After the base ``Scripted.__call__`` processes the observation, this
    calls the base-class helpers ``scan_agents`` and ``target_weak``
    (presumably to choose a target), sets ``self.style`` to
    ``nmmo.action.Range`` and calls ``attack``. No movement helper is called
    here. Used as the agent class of ``AttackTeam``.
    '''
    name = 'Attack_'

    def __call__(self, obs):
        # Let the base class ingest the observation first; the helpers
        # below presumably rely on the state it sets up (TODO confirm).
        super().__call__(obs)

        self.scan_agents()
        self.target_weak()
        # Must be set before attack(), which presumably reads self.style.
        self.style = nmmo.action.Range
        self.attack()
        # self.actions is presumably populated by the base class/attack().
        return self.actions


class AttackTeam(ScriptedTeam):
    """Scripted team that uses ``Attack`` as its per-agent policy class.

    ``ScriptedTeam`` presumably instantiates ``agent_klass`` once per team
    member (TODO confirm). ``RLlibNMMO2Env`` creates one of these as its
    optional auxiliary attack script.
    """
    agent_klass = Attack
