import numpy as np

from addict import Dict
from typing import List, Callable
from copy import deepcopy
import h5py as h5
from tqdm import trange

from s2clientprotocol import common_pb2 as sc_common
from s2clientprotocol import raw_pb2 as r_pb
from s2clientprotocol import sc2api_pb2 as sc_pb

from smac.env import StarCraft2Env
from smac.env.starcraft2.starcraft2 import actions as ACTION_TYPE_TO_ID_MAPPING

from mawm.envs.sc2.conventions import UnitTypes, ActionTypes
# This is not used, but it registers the new maps
import mawm.envs.sc2.map_registry


class ActionSpec(object):
    """Description of the action a single unit takes during one environment step.

    Attributes:
        action_type: One of the ``ActionTypes`` constants.
        target: The unit to attack; only read when ``action_type`` is ``ActionTypes.ATTACK``.
        target_id: Index of the target unit, or ``None`` when the action has no target.
    """

    # Placeholder stored in place of a missing ``target_id`` when actions are written to arrays.
    NULL_TARGET_ID = 255

    def __init__(self, action_type, target=None, target_id=None):
        self.action_type = action_type
        self.target = target
        self.target_id = target_id

    def __repr__(self):
        return (f'{type(self).__name__}(action_type={self.action_type!r}, '
                f'target={self.target!r}, target_id={self.target_id!r})')

    def to_proto(self, env: StarCraft2Env, agent_id: int):
        """Convert this spec to an SC2 action proto for the unit at ``agent_id``.

        Args:
            env: Environment used to look up the acting unit and the move amount.
            agent_id: Index of the acting unit, as understood by
                ``TrajectoryTools.get_agent_at_idx``.

        Returns:
            ``None`` for ``UNKNOWN`` and ``NOOP`` actions, otherwise an ``sc_pb.Action``
            wrapping a raw unit command.

        Raises:
            AssertionError: If the action is ``ATTACK`` but no target unit is set.
            RuntimeError: If the action type is not one of the supported types.
        """
        if self.action_type in [ActionTypes.UNKNOWN, ActionTypes.NOOP]:
            # Nothing to do here
            return None
        # Get the agent
        agent = TrajectoryTools.get_agent_at_idx(env, agent_id)
        # Now, if action_spec is a movement, we need to construct a proto appropriately
        if self.action_type in [ActionTypes.NORTH, ActionTypes.SOUTH, ActionTypes.EAST, ActionTypes.WEST]:
            # noinspection PyProtectedMember
            movement = {
                ActionTypes.NORTH: np.array([0, +env._move_amount]),
                ActionTypes.SOUTH: np.array([0, -env._move_amount]),
                ActionTypes.EAST: np.array([+env._move_amount, 0]),
                ActionTypes.WEST: np.array([-env._move_amount, 0])
            }[self.action_type]
            # The move target is the unit's current position shifted by the movement vector.
            target_x, target_y = list(TrajectoryTools.position_as_array(agent) + movement)
            unit_command = r_pb.ActionRawUnitCommand(ability_id=ACTION_TYPE_TO_ID_MAPPING['move'],
                                                     target_world_space_pos=sc_common.Point2D(x=target_x, y=target_y),
                                                     unit_tags=[agent.tag],
                                                     queue_command=False)
        elif self.action_type == ActionTypes.ATTACK:
            assert self.target is not None, 'ATTACK actions require a target unit'
            unit_command = r_pb.ActionRawUnitCommand(ability_id=ACTION_TYPE_TO_ID_MAPPING['attack'],
                                                     target_unit_tag=self.target.tag,
                                                     unit_tags=[agent.tag],
                                                     queue_command=False)
        else:
            raise RuntimeError(f'Unsupported action type: {self.action_type!r}')
        # Make sc_action and return
        action_proto = sc_pb.Action(action_raw=r_pb.ActionRaw(unit_command=unit_command))
        return action_proto

    @classmethod
    def null_action(cls):
        """Return a spec whose action type is ``ActionTypes.NOOP``."""
        return cls(ActionTypes.NOOP)


class TrajectoryTools(object):
    # ----------- TRAJECTORY -----------
    # The state comprises:
    #   - TA2 tensor of positions (float32)
    #   - TA tensor of healths (float32)
    #   - TA tensor of energy / cooldown (float32)
    #   - TA tensor of shields (float32)
    #   - A tensor of unit types (uint8)
    # The action comprises
    #   - TA tensor of base actions (0-7, uint8)
    #   - TA tensor of action targets (uint8, assumes there are fewer than 255 units (index 255 is reserved for null))
    # The reward comprises
    #   - TA tensor of rewards

    OWN_UNIT_TYPE_OFFSET = 1970

    # noinspection PyPep8Naming
    @classmethod
    def MAP_TYPE_2_OWN_UNIT_TYPE_MAPPINGS(cls, map_type):
        # See: https://github.com/oxwhirl/smac/blob/master/smac/env/starcraft2/starcraft2.py#L1319
        own_unit_type_offset = cls.OWN_UNIT_TYPE_OFFSET
        map_type_2_own_unit_type_mappings = {
            'marines': {
                own_unit_type_offset + 0: UnitTypes.MARINE,
                48: UnitTypes.MARINE
            },
            'stalkers_and_zealots': {
                own_unit_type_offset + 0: UnitTypes.STALKER,
                own_unit_type_offset + 1: UnitTypes.ZEALOT,
            },
            'colossi_stalkers_zealots': {
                own_unit_type_offset + 0: UnitTypes.COLOSSUS,
                own_unit_type_offset + 1: UnitTypes.STALKER,
                own_unit_type_offset + 2: UnitTypes.ZEALOT,
            },
            'MMM': {
                own_unit_type_offset + 0: UnitTypes.MARAUDER,
                own_unit_type_offset + 1: UnitTypes.MARINE,
                own_unit_type_offset + 2: UnitTypes.MEDIVAC,
            },
            'zealots': {
                own_unit_type_offset + 0: UnitTypes.ZEALOT,
            },
            'hydralisks': {
                own_unit_type_offset + 0: UnitTypes.HYDRALISK
            },
            'stalkers': {
                own_unit_type_offset + 0: UnitTypes.STALKER
            },
            'colossus': {
                own_unit_type_offset + 0: UnitTypes.COLOSSUS
            },
            'bane': {
                own_unit_type_offset + 0: UnitTypes.BANELING,
                own_unit_type_offset + 1: UnitTypes.ZERGLING
            }
        }
        return map_type_2_own_unit_type_mappings[map_type]

    EPS = 10e-5
    INF = float('inf')

    @classmethod
    def clip(cls, x):
        return np.clip(x, cls.EPS, cls.INF)

    @staticmethod
    def as_list(d):
        return [d[k] for k in range(len(d))]

    @classmethod
    def get_agent_at_idx(cls, env: StarCraft2Env, idx: int):
        return (cls.as_list(env.agents) + cls.as_list(env.enemies))[idx]

    @classmethod
    def get_all_agents(cls, env: StarCraft2Env):
        return cls.as_list(env.agents) + cls.as_list(env.enemies)

    @classmethod
    def get_num_friendlies(cls, env: StarCraft2Env):
        return len(env.agents)

    @classmethod
    def get_num_enemies(cls, env: StarCraft2Env):
        return len(env.enemies)

    @classmethod
    def get_num_agents(cls, env: StarCraft2Env):
        return len(env.agents) + len(env.enemies)

    @classmethod
    def as_unit(cls, env, unit=None, idx: int = None):
        if unit is None:
            assert idx is not None
            unit = cls.get_agent_at_idx(env, idx)
        else:
            assert idx is None
        return unit

    @classmethod
    def is_alive(cls, env: StarCraft2Env, unit=None, idx: int = None):
        unit = cls.as_unit(env, unit, idx)
        return unit.health > 0.

    @classmethod
    def is_friend(cls, env: StarCraft2Env, unit=None, idx: int = None):
        unit = cls.as_unit(env, unit, idx)
        return unit.owner == 1

    @classmethod
    def is_foe(cls, env: StarCraft2Env, unit=None, idx: int = None):
        unit = cls.as_unit(env, unit, idx)
        return unit.owner == 2

    @staticmethod
    def distance_between(unit_0, unit_1):
        return ((unit_0.pos.x - unit_1.pos.x) ** 2 + (unit_0.pos.y - unit_1.pos.y) ** 2) ** 0.5

    @staticmethod
    def get_vector(from_unit, to_unit):
        return np.array([to_unit.pos.x - from_unit.pos.x, to_unit.pos.y - from_unit.pos.y])

    @staticmethod
    def position_as_array(unit):
        return np.array([unit.pos.x, unit.pos.y])

    @classmethod
    def get_no_ops(cls, env: StarCraft2Env):
        return ([ActionSpec(ActionTypes.NOOP) for _ in range(len(env.agents))] +
                [ActionSpec(ActionTypes.UNKNOWN) for _ in range(len(env.enemies))])

    STATE_KEYS = ['positions', 'healths', 'energies', 'cooldowns', 'shields']

    # We do not normalize by the max, because then the normalized value means nothing to the game engine.
    # Instead, we normalize by a fixed value.
    HEALTH_NORMALIZER = lambda agent: 45.      # Max-health of a marine
    ENERGY_NORMALIZER = lambda agent: 200.     # Max-energy of a medivac
    COOLDOWN_NORMALIZER = lambda agent: 15.    # Max-cooldown of a marine
    SHIELD_NORMALIZER = lambda agent: 50.      # Max-shield of a zealot

    @classmethod
    def extract_state(cls, env: StarCraft2Env, buffer: Dict = None):
        # This function is applied at every timestep, and it extracts:
        #   - positions (A2)
        #   - health (A)
        #   - energy (A)
        #   - shields/cooldown (A)
        buffer = Dict() if buffer is None else buffer
        CDN = cls.COOLDOWN_NORMALIZER
        # Write normalized states
        buffer.positions = np.array([np.array([agent.pos.x, agent.pos.y]) for agent in cls.as_list(env.agents)] +
                                    [np.array([agent.pos.x, agent.pos.y]) for agent in cls.as_list(env.enemies)],
                                    dtype='float32')
        buffer.healths = np.array([agent.health / cls.HEALTH_NORMALIZER(agent) for agent in cls.as_list(env.agents)] +
                                  [agent.health / cls.HEALTH_NORMALIZER(agent) for agent in cls.as_list(env.enemies)],
                                  dtype='float32')
        buffer.energies = np.array([agent.energy / cls.ENERGY_NORMALIZER(agent) for agent in cls.as_list(env.agents)] +
                                   [agent.energy / cls.ENERGY_NORMALIZER(agent) for agent in cls.as_list(env.enemies)],
                                   dtype='float32')
        buffer.cooldowns = np.array([agent.weapon_cooldown / CDN(agent) for agent in cls.as_list(env.agents)] +
                                    [agent.weapon_cooldown / CDN(agent) for agent in cls.as_list(env.enemies)],
                                    dtype='float32')
        buffer.shields = np.array([agent.shield / cls.SHIELD_NORMALIZER(agent) for agent in cls.as_list(env.agents)] +
                                  [agent.shield / cls.SHIELD_NORMALIZER(agent) for agent in cls.as_list(env.enemies)],
                                  dtype='float32')
        return buffer

    @classmethod
    def extract_unit_types(cls, env: StarCraft2Env, buffer: Dict = None):
        # Helper functions
        def _get_agent_unit_type(agent):
            return cls.MAP_TYPE_2_OWN_UNIT_TYPE_MAPPINGS(env.map_type)[agent.unit_type]

        def _get_enemy_unit_type(agent):
            return UnitTypes.SC2_UNIT_TYPE_TO_OWN_UNIT_TYPE_MAPPINGS[agent.unit_type]

        buffer = Dict() if buffer is None else buffer
        buffer.unit_types = np.array([_get_agent_unit_type(agent) for agent in cls.as_list(env.agents)] +
                                     [_get_enemy_unit_type(agent) for agent in cls.as_list(env.enemies)],
                                     dtype='uint8')
        return buffer

    @classmethod
    def set_own_unit_type_offset(cls, env):
        cls.OWN_UNIT_TYPE_OFFSET = min([agent.unit_type for agent in cls.as_list(env.agents)])

    ACTION_KEYS = ['action_types', 'action_target_ids']
    @classmethod
    def extract_actions(cls, env: StarCraft2Env, action_specs: List[ActionSpec], buffer: Dict = None):
        # This function is applied at every timestep, and it extracts:
        #   - action types (A)
        #   - action target_id, like in the list get_all_agent returns (A)
        # Note that the action target_id defaults to ActionSpec.NULL_TARGET_ID if no targets are found
        buffer = Dict() if buffer is None else buffer
        buffer.action_types = np.array([action_spec.action_type for action_spec in action_specs], dtype='uint8')
        buffer.action_target_ids = np.array([action_spec.target_id
                                             if action_spec.target_id is not None else
                                             ActionSpec.NULL_TARGET_ID
                                             for action_spec in action_specs], dtype='uint8')
        return buffer

    @classmethod
    def delete_actions(cls, env: StarCraft2Env, buffer: Dict):
        for key in cls.ACTION_KEYS:
            del buffer[key]
        return buffer

    REWARD_KEYS = ['rewards']
    @classmethod
    def extract_step_info(cls, env: StarCraft2Env, reward_terminated_info: tuple, buffer: Dict = None):
        buffer = Dict() if buffer is None else buffer
        reward, terminated, info = reward_terminated_info
        buffer.rewards = np.array([reward], dtype='float32')
        return buffer

    @classmethod
    def consolidate_buffers(cls, env: StarCraft2Env, buffers: List[Dict], consolidated_buffer: Dict = None):
        consolidated_buffer = Dict() if consolidated_buffer is None else consolidated_buffer
        # Collect all the states
        for key in (cls.STATE_KEYS + cls.ACTION_KEYS + cls.REWARD_KEYS):
            consolidated_buffer[key] = np.array([buffer[key] for buffer in buffers if key in buffer])
        return consolidated_buffer

    @classmethod
    def pad_buffer_list(cls, env: StarCraft2Env, buffers: List[Dict], size: int = None):
        size = env.episode_limit if size is None else size
        # We sub 1 because the last buffer does not have an associated action (but only a state), but "steps" is a
        # count of the number of actions so far.
        num_steps_so_far = len(buffers) - 1
        if num_steps_so_far >= size:
            # This assert may or may not be needed -- let's keep it here for now
            assert num_steps_so_far == size
            return buffers
        # Okay, so num steps so far is smaller than size, and we need some padding. The strategy is to repeat the
        # last state with a bunch of NOOPs as action.
        # Buf first, let's add a Noop or Unknown as action with the existing last buffer
        cls.extract_actions(env, cls.get_no_ops(env), buffers[-1])
        # And a zero reward to go with it
        cls.extract_step_info(env, (0, None, None), buffers[-1])
        # Now copy and repeat size times. If you're thinking why there's no "size + 1" business, it's because the
        # extra buffer in the list was never deleted.
        pad_buffer = deepcopy(buffers[-1])
        while len(buffers) <= size:
            buffers.append(pad_buffer)
            pad_buffer = deepcopy(pad_buffer)
        # The last buffer should not have action or reward information
        for key in cls.ACTION_KEYS:
            if key in buffers[-1]:
                del buffers[-1][key]
        for key in cls.REWARD_KEYS:
            if key in buffers[-1]:
                del buffers[-1][key]
        # Done
        return buffers

    TRAJECTORY_SPECIFIC_KEYS = ['unit_types', 'dones', 'battle_wons']
    @classmethod
    def gather_trajectory(cls, env: StarCraft2Env, policy: Callable, size: int = None):
        size = env.episode_limit if size is None else size
        # Initialize containers
        trajectory = Dict()
        # Reset env
        env.reset()
        # Set own unit type offset
        cls.set_own_unit_type_offset(env)
        # Get the unit types (which remain constant over the episode)
        cls.extract_unit_types(env, trajectory)
        # Loop variables
        buffers = []
        terminated_at_timestep = size
        battle_won = False
        for t in range(size):
            # Gather state info
            buffer = cls.extract_state(env)
            # Get action spec for all agents
            action_specs, action_protos = [], []
            for agent_id, agent in enumerate(cls.get_all_agents(env)):
                policy_output = policy(env, agent_id)
                action_spec, action_proto = policy_output.action_spec, policy_output.action_proto
                action_specs.append(action_spec)
                action_protos.append(action_proto)
            # Get action info
            cls.extract_actions(env, action_specs, buffer)
            # Crop action_protos, because the env does not know how to deal with ActionTypes.UNKNOWN
            action_protos = action_protos[:cls.get_num_friendlies(env)]
            # Step the environment
            reward, terminated, info = env.step(action_protos)
            cls.extract_step_info(env, (reward, terminated, info), buffer)
            # Append buffer to list and loop
            buffers.append(buffer)
            if terminated:
                terminated_at_timestep = t
                battle_won = info['battle_won']
                break
        # Append final step to buffer
        buffers.append(cls.extract_state(env))
        # Do the padding
        buffers = cls.pad_buffer_list(env, buffers, size)
        # Consolidate buffers and write to trajectory
        cls.consolidate_buffers(env, buffers, trajectory)
        # Write out trajectory specific details
        trajectory.dones = np.array(terminated_at_timestep, dtype='uint16')
        trajectory.battle_wons = np.array(battle_won, dtype='bool')
        return trajectory

    @staticmethod
    def consolidate_identical_dict(dict_list: List[Dict]):
        """Stack a list of dicts sharing the same keys into one ``Dict`` of arrays.

        The keys are taken from the first dict; every dict in the list must provide them.
        """
        stacked = Dict()
        reference_keys = dict_list[0].keys()
        for field in reference_keys:
            per_dict_values = [entry[field] for entry in dict_list]
            stacked[field] = np.array(per_dict_values)
        return stacked

    @classmethod
    def gather_trajectories(cls, env: StarCraft2Env, policy: Callable, num_trajectories: int, size: int = None):
        # Gather them trajectories
        trajectories = []
        for t in trange(num_trajectories):
            try:
                trajectories.append(cls.gather_trajectory(env, policy, size))
            except Exception as e:
                print(str(e))
                # Another try
                trajectories.append(cls.gather_trajectory(env, policy, size))
        # trajectories = [cls.gather_trajectory(env, policy, size) for _ in range(num_trajectories)]
        # Consolidate to one trajectory
        trajectories = cls.consolidate_identical_dict(trajectories)
        # Add metadata
        trajectories.map_size = np.array([env.map_x, env.map_y])
        trajectories.terrain_height = env.terrain_height
        trajectories.num_friendlies = np.array(cls.get_num_friendlies(env))
        trajectories.num_enemies = np.array(cls.get_num_enemies(env))
        # Done
        return trajectories

    @staticmethod
    def dump_to_file(trajectories, filename):
        """Write every item of ``trajectories`` as a dataset of a new HDF5 file.

        The file at ``filename`` is opened in ``'w'`` mode. Returns ``filename``.
        """
        with h5.File(filename, 'w') as h5_file:
            for dataset_name, dataset_values in trajectories.items():
                h5_file.create_dataset(name=dataset_name, data=dataset_values)
        return filename


if __name__ == '__main__':
    # Script entry point: gather heuristic-policy trajectories on one scenario,
    # dump them to an HDF5 file and report timing and win rate.
    import time
    from mawm.envs.sc2.policies import HeuristicPolicy
    from mawm.envs.sc2.env_wrapper import StarCraft2Env

    SCENARIO = '2c1s6z'
    NUM_TRAJS = 1000
    TRAJ_LEN = 128
    EPS_GREEDY = 0.1

    env = StarCraft2Env(map_name=SCENARIO, heuristic_ai=False, difficulty='8')

    # The measured span covers both the rollouts and the file dump.
    started_at = time.time()
    policy = HeuristicPolicy(EPS_GREEDY)
    trajs = TrajectoryTools.gather_trajectories(env, policy, NUM_TRAJS, TRAJ_LEN)
    output_path = (f'/Users/redacted/Python/mawm/data/'
                   f'sc2_{SCENARIO}_greed-{EPS_GREEDY}_{NUM_TRAJS}x_trajs.h5')
    TrajectoryTools.dump_to_file(trajs, output_path)
    finished_at = time.time()
    # Average wall-clock seconds per trajectory
    print("Elapsed Time : ", (finished_at - started_at)/NUM_TRAJS)
    print("Win Rate     : ", trajs.battle_wons.mean())
