import numpy as np
import random

from .gather_trajectories import TrajectoryTools, ActionSpec
from .conventions import ActionTypes
from addict import Dict

# Imports for type-checking
from smac.env import StarCraft2Env


class HeuristicPolicy(object):
    """Scripted epsilon-greedy policy for friendly units in a SMAC environment.

    With probability ``eps_greedy`` an action type is drawn at random from
    ``VALID_RANDOM_ACTIONS``; otherwise the unit attacks the closest living
    foe if it is within shooting range, and moves towards it if it is not.

    Calling the policy returns an ``addict.Dict`` with two entries:
    ``action_spec`` (the chosen ``ActionSpec``) and ``action_proto`` (the
    result of ``action_spec.to_proto(env, agent_id)``).
    """

    # Action types the random branch picks from.
    VALID_RANDOM_ACTIONS = [
        ActionTypes.NORTH,
        ActionTypes.SOUTH,
        ActionTypes.EAST,
        ActionTypes.WEST,
        ActionTypes.NOOP
    ]

    def __init__(self, eps_greedy=0.):
        """Store the exploration rate.

        Args:
            eps_greedy: Probability of taking a random action instead of the
                heuristic one. The draw is ``uniform(0, 1) < eps_greedy``, so
                0 never explores.
        """
        self.eps_greedy = eps_greedy

    @staticmethod
    def _get_inactive_action_spec(env, agent):
        """Return the spec for a unit that cannot act, or None if it can.

        Foes get an ``UNKNOWN`` spec and dead friendly units get ``NOOP``.
        The foe check runs first, matching the order both public spec
        methods rely on.
        """
        if TrajectoryTools.is_foe(env, agent):
            return ActionSpec(action_type=ActionTypes.UNKNOWN, target=None, target_id=None)
        if not TrajectoryTools.is_alive(env, agent):
            return ActionSpec(action_type=ActionTypes.NOOP, target=None, target_id=None)
        return None

    def get_action_spec(self, env: StarCraft2Env, agent_id: int):
        """Return a random spec with probability ``eps_greedy``, else the heuristic one."""
        if np.random.uniform(0, 1) < self.eps_greedy:
            return self.get_random_action_spec(env, agent_id)
        return self.get_heuristic_action_spec(env, agent_id)

    def get_random_action_spec(self, env: StarCraft2Env, agent_id: int):
        """Return a spec whose action type is drawn from ``VALID_RANDOM_ACTIONS``.

        Foes yield ``UNKNOWN`` and dead units yield ``NOOP`` without sampling.
        """
        agent = TrajectoryTools.get_agent_at_idx(env, agent_id)
        inactive_spec = self._get_inactive_action_spec(env, agent)
        if inactive_spec is not None:
            return inactive_spec
        # The unit is friendly and alive: pick one of the allowed action types.
        return ActionSpec(action_type=random.choice(self.VALID_RANDOM_ACTIONS))

    def get_heuristic_action_spec(self, env: StarCraft2Env, agent_id: int):
        """Attack the closest living foe if in range, otherwise move towards it.

        Strategy:
          0. If the unit is a foe return ``UNKNOWN``; if it is dead return ``NOOP``.
          1. Find the closest living foe (inside or outside FOV). This might
             look like cheating, but it's fair if the agent is matched with a
             bot with full visibility.
          2. If that foe is closer than ``env.unit_shoot_range(agent_id)``,
             attack it; otherwise step in its direction.

        If no living foe remains there is nothing to engage and ``NOOP`` is
        returned.

        Raises:
            AssertionError: If ``env.map_type`` is ``'MMM'``.
        """
        assert env.map_type != 'MMM', "No Medivacs for now."
        agent = TrajectoryTools.get_agent_at_idx(env, agent_id)
        # Step 0: foes and dead units get a fixed spec.
        inactive_spec = self._get_inactive_action_spec(env, agent)
        if inactive_spec is not None:
            return inactive_spec
        # Step 1: collect (distance, unit_id, unit) for every living foe.
        # Word of warning: unit_id is the index into the full agent list, not
        # the position inside this filtered list.
        living_foes = [
            (TrajectoryTools.distance_between(agent, unit), uid, unit)
            for uid, unit in enumerate(TrajectoryTools.get_all_agents(env))
            if TrajectoryTools.is_foe(env, unit) and TrajectoryTools.is_alive(env, unit)
        ]
        if not living_foes:
            # Nothing left to fight; picking a minimum would fail on an empty list.
            return ActionSpec(action_type=ActionTypes.NOOP, target=None, target_id=None)
        # min() keeps the first entry on ties, like argmin did.
        closest_distance, closest_unit_id, closest_unit = min(
            living_foes, key=lambda entry: entry[0])
        # Step 2: Check if in or out-of range.
        # Note that agent_id is understood by the SMAC SC2 env, because the agent is friendly.
        if closest_distance < env.unit_shoot_range(agent_id):
            # Engage!
            return ActionSpec(action_type=ActionTypes.ATTACK, target=closest_unit, target_id=closest_unit_id)
        movement_direction = self.get_movement_direction(from_unit=agent, to_unit=closest_unit)
        return ActionSpec(action_type=movement_direction, target=None, target_id=None)

    @staticmethod
    def get_movement_direction(from_unit, to_unit):
        """Return the movement action type along the dominant axis of the offset.

        The offset comes from ``TrajectoryTools.get_vector(from_unit, to_unit)``.
        The component with the largest absolute value decides the direction:
        index 0 maps to EAST (positive) / WEST (negative) and index 1 maps to
        NORTH (positive) / SOUTH (negative).

        Returns ``ActionTypes.NOOP`` when the dominant component is zero (the
        offset has no direction), instead of failing on the lookup.
        """
        vector = TrajectoryTools.get_vector(from_unit, to_unit)
        dominant_axis = int(np.argmax(np.abs(vector)))
        sign = int(np.sign(vector[dominant_axis]))
        if sign == 0:
            # The largest component is zero, so there is no direction to move in.
            return ActionTypes.NOOP
        directions = {
            (0, +1): ActionTypes.EAST,
            (0, -1): ActionTypes.WEST,
            (1, +1): ActionTypes.NORTH,
            (1, -1): ActionTypes.SOUTH
        }
        return directions[dominant_axis, sign]

    def __call__(self, env: StarCraft2Env, agent_id: int):
        """Choose an action for ``agent_id`` and return its spec and proto.

        Returns:
            ``Dict(action_spec=..., action_proto=...)`` where ``action_proto``
            is ``action_spec.to_proto(env, agent_id)``.
        """
        action_spec = self.get_action_spec(env, agent_id)
        action_proto = action_spec.to_proto(env, agent_id)
        return Dict(action_spec=action_spec, action_proto=action_proto)


