import enum
import random
import collections

import numpy as np

from operator import attrgetter
from types import SimpleNamespace as SN
from typing import Any, Dict, Tuple, Union, Optional, Sequence, List

from absl import logging
from pysc2 import maps
from pysc2 import run_configs
from pysc2.lib import protocol
from pysc2.lib.units import get_unit_type

from s2clientprotocol import common_pb2 as sc_common
from s2clientprotocol import sc2api_pb2 as sc_pb
from s2clientprotocol import raw_pb2 as r_pb
from s2clientprotocol import debug_pb2 as d_pb

from smac.env.starcraft2.starcraft2 import StarCraft2Env, actions


# Discrete action indices of the SMAC action space used by the scheduler:
# NOOP is the no-op (only valid for dead units), STOP is the stop command.
NOOP, STOP = 0, 1

# Delay value of an action that is executed immediately (i.e. not async).
ZERO_DELAY = 0


class ActionManager:
    """
    Bookkeeping record for one agent's action inside one step's execution plan.

    Fields (exposed as read/write properties):
    * ``action``      -- the action index issued by the policy;
    * ``start_step``  -- the environment step at which the action was issued;
    * ``delay_value`` -- number of steps the action is delayed (0 means not delayed);
    * ``completed``   -- whether the action has been executed or retired.
    """

    def __init__(self, action: int = 0, start_step: int = 0, delay_value: int = 0, completed: bool = False):
        self._start_step = start_step
        self._delay_value = delay_value
        self._completed = completed
        self._action = action
        # Stays True until update()/upate() overwrites the record, i.e. while
        # this manager is still the placeholder created for the step.
        self._dummy = True

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(action={self._action}, start_step={self._start_step}, "
                f"delay_value={self._delay_value}, completed={self._completed})")

    @property
    def action(self) -> int:
        return self._action

    @action.setter
    def action(self, v: int):
        self._action = v

    @property
    def start_step(self) -> int:
        return self._start_step

    @start_step.setter
    def start_step(self, v: int):
        self._start_step = v

    @property
    def delay_value(self) -> int:
        return self._delay_value

    @delay_value.setter
    def delay_value(self, v: int):
        self._delay_value = v

    @property
    def completed(self) -> bool:
        return self._completed

    @completed.setter
    def completed(self, v: bool):
        self._completed = v

    def update(self, action: int, start_step: int, delay_value: int, completed: bool):
        """Overwrite every field at once and mark the record as no longer a placeholder."""
        self.action = action
        self.start_step = start_step
        self.delay_value = delay_value
        self.completed = completed
        self._dummy = False

    # Backward-compatible alias for the original (misspelled) method name.
    upate = update


class AsyncStarCraft2Env(StarCraft2Env):
    """
    The StarCraft II environment for decentralised asynchronous
    multi-agent micromanagement scenarios.

    On top of ``StarCraft2Env``, an action whose delay is positive (currently
    only ATTACK/heal actions, see ``_get_action_delay``) is not executed when it
    is issued.

    * It is stored in the per-step ``execution_plans`` as an ``ActionManager``.
    * ``update_past_memory`` releases it at step
      ``start_step + delay_value - 1``.
    * Until then, the agent executes an earlier released action, or STOP.
    """
    def __init__(
        self,
        map_name="8m",
        step_mul=8,
        move_amount=2,
        difficulty="7",
        game_version=None,
        seed=None,
        continuing_episode=False,
        obs_all_health=True,
        obs_own_health=True,
        obs_last_action=False,
        obs_pathing_grid=False,
        obs_terrain_height=False,
        obs_instead_of_state=False,
        obs_timestep_number=False,
        state_last_action=True,
        state_timestep_number=False,
        reward_sparse=False,
        reward_only_positive=True,
        reward_death_value=10,
        reward_win=200,
        reward_defeat=0,
        reward_negative_scale=0.5,
        reward_scale=True,
        reward_scale_rate=20,
        replay_dir="",
        replay_prefix="",
        window_size_x=1920,
        window_size_y=1200,
        heuristic_ai=False,
        heuristic_rest=False,
        debug=False,
        random_delay=True,
        async_config_static=None,
        async_config_random=None,
        async_on=True,
        policy_override_rule='no', # no, partial, full
        policy_override_prob=-1, # policy override prob for each episode, -1 means no overriding
        episode_limit=-1,  # default is -1
    ):
        """
        Create an asynchronous StarCraft2Env environment.

        Parameters
        ----------
        map_name : str, optional
            The name of the SC2 map to play (default is "8m"). The full list
            can be found by running bin/map_list.
        step_mul : int, optional
            How many game steps per agent step (default is 8). None
            indicates to use the default map step_mul.
        move_amount : float, optional
            How far away units are ordered to move per step (default is 2).
        difficulty : str, optional
            The difficulty of built-in computer AI bot (default is "7").
        game_version : str, optional
            StarCraft II game version (default is None). None indicates the
            latest version.
        seed : int, optional
            Random seed used during game initialisation (default is None).
        continuing_episode : bool, optional
            Whether to consider episodes continuing or finished after time
            limit is reached (default is False).
        obs_all_health : bool, optional
            Agents receive the health of all units (in the sight range) as part
            of observations (default is True).
        obs_own_health : bool, optional
            Agents receive their own health as a part of observations (default
            is True). This flag is ignored when obs_all_health == True.
        obs_last_action : bool, optional
            Agents receive the last actions of all units (in the sight range)
            as part of observations (default is False).
        obs_pathing_grid : bool, optional
            Whether observations include pathing values surrounding the agent
            (default is False).
        obs_terrain_height : bool, optional
            Whether observations include terrain height values surrounding the
            agent (default is False).
        obs_instead_of_state : bool, optional
            Use combination of all agents' observations as the global state
            (default is False).
        obs_timestep_number : bool, optional
            Whether observations include the current timestep of the episode
            (default is False).
        state_last_action : bool, optional
            Include the last actions of all agents as part of the global state
            (default is True).
        state_timestep_number : bool, optional
            Whether the state includes the current timestep of the episode
            (default is False).
        reward_sparse : bool, optional
            Receive 1/-1 reward for winning/losing an episode (default is
            False). The rest of reward parameters are ignored if True.
        reward_only_positive : bool, optional
            Reward is always positive (default is True).
        reward_death_value : float, optional
            The amount of reward received for killing an enemy unit (default
            is 10). This is also the negative penalty for having an allied unit
            killed if reward_only_positive == False.
        reward_win : float, optional
            The reward for winning in an episode (default is 200).
        reward_defeat : float, optional
            The reward for losing in an episode (default is 0). This value
            should be nonpositive.
        reward_negative_scale : float, optional
            Scaling factor for negative rewards (default is 0.5). This
            parameter is ignored when reward_only_positive == True.
        reward_scale : bool, optional
            Whether or not to scale the reward (default is True).
        reward_scale_rate : float, optional
            Reward scale rate (default is 20). When reward_scale == True, the
            reward received by the agents is divided by (max_reward /
            reward_scale_rate), where max_reward is the maximum possible
            reward per episode without considering the shield regeneration
            of Protoss units.
        replay_dir : str, optional
            The directory to save replays (default is ""). If empty, the
            replay will be saved in the Replays directory where StarCraft II
            is installed.
        replay_prefix : str, optional
            The prefix of the replay to be saved (default is ""). If empty,
            the name of the map will be used.
        window_size_x : int, optional
            The length of StarCraft II window size (default is 1920).
        window_size_y: int, optional
            The height of StarCraft II window size (default is 1200).
        heuristic_ai: bool, optional
            Whether or not to use a non-learning heuristic AI (default False).
        heuristic_rest: bool, optional
            At any moment, restrict the actions of the heuristic AI to be
            chosen from actions available to RL agents (default is False).
            Ignored if heuristic_ai == False.
        debug: bool, optional
            Log messages about observations, state, actions and rewards for
            debugging purposes (default is False).
        random_delay: bool, optional
            Whether to use random delays (default is True). If True, the delay
            of an action is sampled (``random.choice``) from
            ``async_config_random`` every time it is needed. Otherwise the
            fixed values of ``async_config_static`` are used.
        async_config_static: dict, optional
            Static delay configuration.
            Keys: ``NOOP``, ``MOVE``, ``STOP``, ``ATTACK``; values: ints.
            Only required when ``random_delay`` is False.
        async_config_random: dict, optional
            Random delay configuration.
            Keys: ``NOOP``, ``MOVE``, ``STOP``, ``ATTACK``; values: sequences of
            candidate delays. Only required when ``random_delay`` is True.
            NOTE: only the ATTACK delay is currently applied to actions, but
            all four values contribute to ``max_delay_step``.
        async_on: bool, optional
            Whether the asynchronous scheduler is active (default is True).
            If False, ``step`` forwards the actions unchanged and every action
            delay is 0.
        policy_override_rule: str, optional
            One of 'no', 'partial', 'full' (default is 'no'). Only 'no' is
            implemented; see ``policy_override``.
        policy_override_prob: float, optional
            Policy override probability for each episode; -1 (default) means
            no overriding. Stored but not used in this class.
        episode_limit: int, optional
            If not -1 (default), overrides the ``episode_limit`` set by
            ``StarCraft2Env``.
        """
        super().__init__(
            map_name=map_name,
            step_mul=step_mul,
            move_amount=move_amount,
            difficulty=difficulty,
            game_version=game_version,
            seed=seed,
            continuing_episode=continuing_episode,
            obs_all_health=obs_all_health,
            obs_own_health=obs_own_health,
            obs_last_action=obs_last_action,
            obs_pathing_grid=obs_pathing_grid,
            obs_terrain_height=obs_terrain_height,
            obs_instead_of_state=obs_instead_of_state,
            obs_timestep_number=obs_timestep_number,
            state_last_action=state_last_action,
            state_timestep_number=state_timestep_number,
            reward_sparse=reward_sparse,
            reward_only_positive=reward_only_positive,
            reward_death_value=reward_death_value,
            reward_win=reward_win,
            reward_defeat=reward_defeat,
            reward_negative_scale=reward_negative_scale,
            reward_scale=reward_scale,
            reward_scale_rate=reward_scale_rate,
            replay_dir=replay_dir,
            replay_prefix=replay_prefix,
            window_size_x=window_size_x,
            window_size_y=window_size_y,
            heuristic_ai=heuristic_ai,
            heuristic_rest=heuristic_rest,
            debug=debug
        )

        if episode_limit != -1:
            self.episode_limit = episode_limit

        self.policy_override_rule = policy_override_rule
        self.policy_override_prob = policy_override_prob

        self.random_delay = random_delay

        # One list of ActionManager (one per agent) per executed step; cleared in reset().
        self.execution_plans = []

        # Only the config selected by ``random_delay`` is read, so the other one
        # may legitimately be left as None (SN(**None) would raise TypeError).
        self.async_config_static = SN(**(async_config_static or {}))
        self.async_config_random = SN(**(async_config_random or {}))

        self._setup_async_config()

        # current time step
        self.curr_step = 0

        # configuration for delay
        self.delay_step_min = 1   # no delay
        self.delay_step_max = self.max_delay_step = self._get_max_delay()

        # Per agent: True if the action sent to the game at the last step differs
        # from (or replaces) the policy's action because of the async scheduler.
        self.action_changed = [False] * self.n_agents
        # Actions actually sent to the game at the last step (set in step()).
        self.allowed_actions = [NOOP] * self.n_agents

        self.async_on = async_on

    def _get_max_delay(self) -> int:
        """Return the largest delay found in ``self.action_duration`` (at least 1)."""
        max_delay = 1
        for delay_scheme in self.action_duration.values():
            for value in delay_scheme:
                # ``value`` is a scalar (static config) or a sequence of candidates (random config).
                max_delay = max(max_delay, np.max(value))
        return max_delay

    def _setup_async_config(self):
        """Build ``self.action_duration``: unit type name -> Delay(NOOP, MOVE, STOP, ATTACK)."""
        Delay = collections.namedtuple('Delay', ['NOOP', 'MOVE', 'STOP', 'ATTACK'])
        cfg = self.async_config_random if self.random_delay else self.async_config_static
        unit_types = ('baneling', 'zergling', 'colossus', 'stalker', 'hydralisk',
                      'zealot', 'medivac', 'marauder', 'marine')
        # Every unit type currently shares the same delay scheme.
        self.action_duration = {
            unit_type: Delay(NOOP=cfg.NOOP, MOVE=cfg.MOVE, STOP=cfg.STOP, ATTACK=cfg.ATTACK)
            for unit_type in unit_types
        }

    def reset(self):
        """Reset the episode: step counter, async memory, delay config and change flags."""
        self.curr_step = 0
        self.execution_plans.clear()  # clear memory for each episode
        self._setup_async_config()
        # Stale flags from the previous episode would make get_avail_agent_actions
        # (which the parent reset may use while building observations) unmask
        # actions left over from the last step.
        self.action_changed = [False] * self.n_agents
        return super().reset()

    @staticmethod
    def _is_action_async(action: int, delay_value: int, async_on: bool=True) -> bool:
        """
        Return True if an action with ``delay_value`` must be executed asynchronously.

        ``action`` is currently ignored: any positive delay makes an action
        async, provided ``async_on`` is set.
        """
        return bool(async_on and delay_value > ZERO_DELAY)

    @staticmethod
    def _get_action_meanining(action: int) -> str:
        """
        Map a discrete action index to its category.

        Mapping: 0 -> 'NOOP', 1 -> 'STOP', 2..5 -> 'MOVE', > 5 -> 'ATTACK'.

        NOTE(review): assumes the first attack index is 6 (n_actions_no_attack
        == 6, as the move branches 2..5 in get_agent_action suggest). For MMM
        medivacs these indices are heal actions and are also reported as 'ATTACK'.

        Raises
        ------
        ValueError
            For negative action indices.
        """
        if action in {2, 3, 4, 5}:
            return 'MOVE'
        if action > 5:
            return 'ATTACK'
        if action == 0:
            return 'NOOP'
        if action == 1:
            return 'STOP'
        raise ValueError(f'No such action: {action}')

    def get_agent_pending_actions(self, agent_id: int) -> List[int]:
        """
        Get pending actions: [a_{t-m}, a_{t-m+1}, ..., a_{t-1}]

        I.e. the not-yet-completed actions of ``agent_id``, oldest first.
        """
        pending_actions = [
            plan[agent_id].action
            for plan in self.execution_plans
            if not plan[agent_id].completed
        ]
        assert len(pending_actions) <= self.max_delay_step
        return pending_actions

    def get_agents_pending_actions(self) -> List[List[int]]:
        """Return ``get_agent_pending_actions`` for every agent, indexed by agent id."""
        return [self.get_agent_pending_actions(agent_id) for agent_id in range(self.n_agents)]

    def update_past_memory(self, agent_id: int) -> int:
        """
        Release the async action of ``agent_id`` that finishes at the current step.

        Scans the agent's pending async actions from newest to oldest. The first
        one whose finish step ``start_step + delay_value - 1`` equals
        ``self.curr_step`` is marked completed and returned. Afterwards, pending
        async actions whose finish step is already in the past are marked
        completed without being returned.

        NOTE(review): a delay of 1 finishes at the step it was issued, which is
        never checked (plans are appended after this call). Such actions are
        retired by the sweep one step later without ever being executed.
        Likewise, if two actions finish at the same step, only the newest one
        is returned. Confirm this is intended.

        Returns
        -------
        int
            The released action, or ``STOP`` if none finishes at this step.
        """
        released_action = STOP  # if not found, do STOP

        managers = [plan[agent_id] for plan in self.execution_plans]
        for manager in reversed(managers):
            if not manager.completed and manager.delay_value > ZERO_DELAY:
                if manager.start_step + manager.delay_value - 1 == self.curr_step:
                    released_action = manager.action
                    manager.completed = True
                    break

        # Retire every async action whose finish step has already passed.
        for manager in reversed(managers):
            if not manager.completed and manager.delay_value > ZERO_DELAY:
                if manager.start_step + manager.delay_value - 1 < self.curr_step:
                    manager.completed = True

        return released_action

    def _get_action_delay(self, agent_id, action):
        """
        Return the delay of ``action`` for agent ``agent_id``.

        Only ATTACK actions get a non-zero delay, and only when ``async_on`` is
        set. With ``random_delay`` the delay is drawn with ``random.choice``
        from the unit type's ATTACK candidates, so EVERY call draws a new value.
        Otherwise the static ATTACK delay is returned. MOVE delays are
        intentionally not applied.
        """
        agent = self.agents[agent_id]
        agent_type = self.agents_type[agent.unit_type]
        delay_scheme = self.action_duration[agent_type]

        if self._get_action_meanining(action) == 'ATTACK' and self.async_on:
            if self.random_delay:
                return random.choice(delay_scheme.ATTACK)
            return delay_scheme.ATTACK
        return ZERO_DELAY

    def _async_action_scheduler(self, actions: Sequence[int]) -> List[int]:
        """
        Receive actions from the RL model and return allowed actions to be executed in the env.

        For every agent:

        * If its new action is async (positive delay), the new action is stored
          in this step's execution plan. The agent instead executes the earlier
          async action finishing now, or STOP if none finishes (always STOP on
          the first step).
        * Otherwise the new action is executed, unless an earlier ATTACK
          finishes now, in which case that ATTACK takes priority.

        Parameters
        ----------
        actions : sequence
            One action per agent; each element must be convertible with ``int``
            (e.g. a torch tensor).

        Returns
        -------
        list of int
            The actions to send to the game, one per agent.
        """
        actions_int = [int(a) for a in actions]
        assert len(actions_int) == self.n_agents, "len(actions_int) != self.n_agents"

        allowed_actions = []
        self.action_changed = [False] * len(actions_int)

        # Draw each agent's delay exactly once. With random delays every call to
        # _get_action_delay samples a new value, so reusing one draw keeps the
        # async decision, the stored delay_value and the completed flag consistent.
        delay_values = [
            self._get_action_delay(agent_id, action)
            for agent_id, action in enumerate(actions_int)
        ]

        # Placeholder managers; async agents overwrite theirs below.
        execution_plan = [
            ActionManager(start_step=self.curr_step,
                          action=action,
                          completed=delay == ZERO_DELAY)
            for action, delay in zip(actions_int, delay_values)
        ]

        has_history = len(self.execution_plans) > 0
        for agent_id, (action, delay_value) in enumerate(zip(actions_int, delay_values)):
            if self._is_action_async(action, delay_value, async_on=self.async_on):
                if has_history:
                    # Execute whatever earlier async action finishes now (STOP by default).
                    allowed_actions.append(self.update_past_memory(agent_id))
                else:
                    # No history yet, so nothing can finish now.
                    allowed_actions.append(STOP)
                # An async action has a positive delay, so it starts out pending.
                execution_plan[agent_id].upate(action=action,
                                               start_step=self.curr_step,
                                               delay_value=delay_value,
                                               completed=False)
                self.action_changed[agent_id] = True
            else:
                released_action = self.update_past_memory(agent_id)  # STOP by default
                # A finishing ATTACK (the only async kind) overrides the new action.
                if self._get_action_meanining(released_action) == 'ATTACK':
                    allowed_actions.append(released_action)
                    self.action_changed[agent_id] = True
                else:
                    allowed_actions.append(action)

        self.execution_plans.append(execution_plan)

        assert len(allowed_actions) == len(actions_int), \
            f"Some actions are missing, len: {len(allowed_actions)}, expected len: {len(actions_int)}"
        return allowed_actions

    def step(self, actions: List[int]) -> Tuple[float, bool, dict]:
        """
        Override step()
        A single environment step. Returns reward, terminated, info.

        ``info['async_action']`` is 1.0 if the scheduler changed at least one
        agent's action in this step, else 0.0.
        """
        self.action_changed = [False] * self.n_agents  # reset the value

        allowed_actions = self._async_action_scheduler(actions) if self.async_on else actions
        self.allowed_actions = allowed_actions

        reward, terminated, info = super().step(allowed_actions)
        self.curr_step += 1
        info['async_action'] = float(any(self.action_changed))
        return reward, terminated, info

    def policy_override(self, actions, override_policy=True):
        """Optionally override the policy's actions; currently always returns them unchanged."""
        if self.policy_override_rule != 'no' and override_policy:
            # TODO: implement the 'partial' and 'full' override rules.
            pass
        return actions

    def get_agent_action(self, a_id, action):
        """Construct the action for agent a_id (``None`` for a no-op)."""
        # NOTE: overrides StarCraft2Env.get_agent_action so that actions
        #       substituted by the async scheduler are accepted.
        avail_actions = self.get_avail_agent_actions(a_id)

        delay_value = self._get_action_delay(a_id, action)
        async_flag = self._is_action_async(action, delay_value, async_on=self.async_on)
        # The scheduler may substitute an action the current avail mask does not
        # allow (e.g. an ATTACK released for a unit that has died). Accept it
        # based on the scheduler's decision instead of only a fresh (possibly
        # random) delay draw, which made this check nondeterministic.
        scheduled_flag = bool(self.async_on and self.action_changed[a_id])

        assert avail_actions[action] == 1 or async_flag or scheduled_flag, \
                f"Agent {a_id} cannot perform action {action}, please check the action."

        unit = self.get_unit_by_id(a_id)
        tag = unit.tag
        x = unit.pos.x
        y = unit.pos.y

        if action == 0:
            # no-op (valid only when dead)
            assert unit.health == 0, "No-op only available for dead agents."
            if self.debug:
                logging.debug("Agent {}: Dead".format(a_id))
            return None
        elif action == 1:
            # stop
            cmd = r_pb.ActionRawUnitCommand(
                ability_id=actions["stop"],
                unit_tags=[tag],
                queue_command=False)
            if self.debug:
                logging.debug("Agent {}: Stop".format(a_id))

        elif action == 2:
            # move north
            cmd = r_pb.ActionRawUnitCommand(
                ability_id=actions["move"],
                target_world_space_pos=sc_common.Point2D(
                    x=x, y=y + self._move_amount),
                unit_tags=[tag],
                queue_command=False)
            if self.debug:
                logging.debug("Agent {}: Move North".format(a_id))

        elif action == 3:
            # move south
            cmd = r_pb.ActionRawUnitCommand(
                ability_id=actions["move"],
                target_world_space_pos=sc_common.Point2D(
                    x=x, y=y - self._move_amount),
                unit_tags=[tag],
                queue_command=False)
            if self.debug:
                logging.debug("Agent {}: Move South".format(a_id))

        elif action == 4:
            # move east
            cmd = r_pb.ActionRawUnitCommand(
                ability_id=actions["move"],
                target_world_space_pos=sc_common.Point2D(
                    x=x + self._move_amount, y=y),
                unit_tags=[tag],
                queue_command=False)
            if self.debug:
                logging.debug("Agent {}: Move East".format(a_id))

        elif action == 5:
            # move west
            cmd = r_pb.ActionRawUnitCommand(
                ability_id=actions["move"],
                target_world_space_pos=sc_common.Point2D(
                    x=x - self._move_amount, y=y),
                unit_tags=[tag],
                queue_command=False)
            if self.debug:
                logging.debug("Agent {}: Move West".format(a_id))
        else:
            # attack/heal units that are in range
            target_id = action - self.n_actions_no_attack
            if self.map_type == "MMM" and unit.unit_type == self.medivac_id:
                target_unit = self.agents[target_id]
                action_name = "heal"
            else:
                target_unit = self.enemies[target_id]
                action_name = "attack"

            action_id = actions[action_name]
            target_tag = target_unit.tag

            cmd = r_pb.ActionRawUnitCommand(
                ability_id=action_id,
                target_unit_tag=target_tag,
                unit_tags=[tag],
                queue_command=False)

            if self.debug:
                logging.debug("Agent {} {}s unit # {}".format(
                    a_id, action_name, target_id))

        sc_action = sc_pb.Action(action_raw=r_pb.ActionRaw(unit_command=cmd))
        return sc_action

    def get_avail_agent_actions(self, agent_id: int) -> List[int]:
        """
        Return the parent's availability mask for ``agent_id``. If the
        scheduler changed this agent's action at the last step, the executed
        action is also marked available (for living agents, non-NOOP actions).
        """
        avail_actions = super().get_avail_agent_actions(agent_id)
        if self.action_changed[agent_id]:
            unit = self.get_unit_by_id(agent_id)
            agent_action = self.allowed_actions[agent_id]
            if unit.health > 0 and agent_action != NOOP:
                avail_actions[agent_action] = 1
        return avail_actions

    def _init_ally_unit_types(self, min_unit_type: int):
        """
        Initialise ally unit types. Should be called once from the
        init_units function.

        Besides the ``<unit>_id`` attributes, this builds ``self.agents_type``
        (unit type id -> unit name). ``_get_action_delay`` uses it to pick the
        delay scheme.
        """
        self.agents_type = {}

        self._min_unit_type = min_unit_type
        if self.map_type == "marines":
            self.marine_id = min_unit_type
            self.agents_type.update({
                self.marine_id: 'marine',
            })
        elif self.map_type == "stalkers_and_zealots":
            self.stalker_id = min_unit_type
            self.zealot_id = min_unit_type + 1
            self.agents_type.update({
                self.stalker_id: 'stalker',
                self.zealot_id: 'zealot',
            })
        elif self.map_type == "colossi_stalkers_zealots":
            self.colossus_id = min_unit_type
            self.stalker_id = min_unit_type + 1
            self.zealot_id = min_unit_type + 2
            self.agents_type.update({
                self.colossus_id: 'colossus',
                self.stalker_id: 'stalker',
                self.zealot_id: 'zealot',
            })
        elif self.map_type == "MMM":
            self.marauder_id = min_unit_type
            self.marine_id = min_unit_type + 1
            self.medivac_id = min_unit_type + 2
            self.agents_type.update({
                self.marauder_id: 'marauder',
                self.marine_id: 'marine',
                self.medivac_id: 'medivac',
            })
        elif self.map_type == "zealots":
            self.zealot_id = min_unit_type
            self.agents_type.update({
                self.zealot_id: 'zealot',
            })
        elif self.map_type == "hydralisks":
            self.hydralisk_id = min_unit_type
            self.agents_type.update({
                self.hydralisk_id: 'hydralisk',
            })
        elif self.map_type == "stalkers":
            self.stalker_id = min_unit_type
            self.agents_type.update({
                self.stalker_id: 'stalker',
            })
        elif self.map_type == "colossus":
            self.colossus_id = min_unit_type
            self.agents_type.update({
                self.colossus_id: 'colossus',
            })
        elif self.map_type == "bane":
            self.baneling_id = min_unit_type
            self.zergling_id = min_unit_type + 1
            self.agents_type.update({
                self.baneling_id: 'baneling',
                self.zergling_id: 'zergling',
            })

        # An unknown map_type leaves the mapping empty.
        assert len(self.agents_type) != 0, "Something wrong?"

    def debug_get_true_reward(self, info, scheme):
        """
        Debug helper: shift positive per-step rewards back by ``max_delay_step - 1`` steps.

        Parameters
        ----------
        info : dict
            Must contain ``'episode_rewards'``, a sequence of per-step rewards.
        scheme : str
            Only schemes starting with ``'proximal_'`` are handled. For any
            other scheme, a list of zeros of the same length is returned.
        """
        old_episode_rewards = info['episode_rewards']
        episode_rewards = [0.0] * len(old_episode_rewards)
        if scheme.startswith('proximal_'):
            for i, rew in enumerate(old_episode_rewards):
                if rew > 0:
                    pivot_time = i - (self.max_delay_step - 1)
                    assert pivot_time >= 0, f"pivot_time < 0: {pivot_time}"
                    episode_rewards[pivot_time] = rew
        return episode_rewards

    def get_obs(self, w_hist_actions=False):
        """Return the parent's observations; ``w_hist_actions`` is currently unused."""
        # TODO w_hist_actions is not used
        return super().get_obs()

    def get_async_agents_actions(self) -> List[List[int]]:
        """
        Return a per-agent, per-action 0/1 matrix: 1 = async action, 0 = sync action.

        With ``random_delay`` each entry depends on a fresh random delay draw.
        """
        # Build each row independently: ``[[0] * n] * m`` would alias one list
        # across all agents, so every row would end up as the last agent's flags.
        async_action_flags = []
        for agent_idx in range(self.n_agents):
            row = []
            for action in range(self.n_actions):
                delay_value = self._get_action_delay(agent_idx, action)
                row.append(1 if self._is_action_async(action, delay_value) else 0)
            async_action_flags.append(row)
        return async_action_flags
