import numpy as np

from s2clientprotocol import sc2api_pb2 as sc_pb
from absl import logging
from pysc2.lib import protocol

from smac.env import StarCraft2Env as _StarCraft2Env


class StarCraft2Env(_StarCraft2Env):
    """SMAC environment that also accepts raw SC2 action protos from agents.

    Each agent action passed to :meth:`step` may be either a discrete action
    index (handled by the parent class) or an already-built action proto
    (an ``sc_pb.Action``, a dict carrying one under ``'action_proto'``, or
    ``None`` for "no action").
    """

    # Types treated as discrete action indices. NumPy integer scalars are
    # included because indices frequently come out of NumPy operations
    # (e.g. ``argmax`` or iterating an integer array) and are not ``int``.
    _INDEX_ACTION_TYPES = (int, np.integer)

    def get_agent_action(self, a_id, action):
        """Return the SC2 action proto for agent ``a_id``.

        Args:
            a_id: Index of the acting agent.
            action: One of
                * an integer action index (``int`` or NumPy integer), which
                  is converted to a proto by the parent class;
                * an ``sc_pb.Action``, returned unchanged;
                * a dict with an ``'action_proto'`` key, whose value is
                  returned;
                * ``None``, returned unchanged.

        Returns:
            The action proto to send to the game, or ``None``.

        Raises:
            ValueError: If ``action`` is none of the supported kinds.
        """
        if isinstance(action, self._INDEX_ACTION_TYPES):
            # Discrete index: let the super-class do the work of converting
            # it to a proto. Normalise to a plain int first so NumPy scalars
            # behave exactly like built-in ints downstream.
            return super(StarCraft2Env, self).get_agent_action(
                a_id, int(action))

        # For everything below we have direct access to the action proto.
        if action is None or isinstance(action, sc_pb.Action):
            return action
        if isinstance(action, dict) and 'action_proto' in action:
            return action['action_proto']

        raise ValueError(
            "Unsupported action of type {} for agent {}: expected an integer "
            "action index, an sc_pb.Action, a dict with an 'action_proto' "
            "key, or None.".format(type(action).__name__, a_id))

    def step(self, actions):
        """A single environment step. Returns reward, terminated, info.

        Args:
            actions: One action per agent. Either every entry is an integer
                action index, or every entry is a dict, an ``sc_pb.Action``
                or ``None`` (see :meth:`get_agent_action`).

        Returns:
            A ``(reward, terminated, info)`` tuple. If the SC2 controller
            raises a protocol or connection error, the environment is fully
            restarted and ``(0, True, {})`` is returned.

        Raises:
            ValueError: If ``actions`` is neither all integer indices nor all
                dict / ``sc_pb.Action`` / ``None`` entries.
        """
        # Allow actions to be a list of protos instead of integer indices.
        if all(isinstance(a, self._INDEX_ACTION_TYPES) for a in actions):
            # One-hot encode the chosen indices. The explicit integer dtype
            # keeps the array usable as an index even when it is empty.
            action_ids = np.array([int(a) for a in actions], dtype=np.int64)
            self.last_action = np.eye(self.n_actions)[action_ids]
        elif not all(isinstance(a, (dict, sc_pb.Action, type(None)))
                     for a in actions):
            # Explicit exception rather than ``assert`` so the validation is
            # not stripped when Python runs with -O.
            raise ValueError(
                "actions must be either all integer action indices or all "
                "dict / sc_pb.Action / None entries.")
        # NOTE: when protos are supplied, self.last_action is left untouched
        # because no action index is available to encode.

        # Collect individual actions
        sc_actions = []
        if self.debug:
            logging.debug("Actions".center(60, "-"))

        for a_id, action in enumerate(actions):
            if not self.heuristic_ai:
                agent_action = self.get_agent_action(a_id, action)
            else:
                agent_action = self.get_agent_action_heuristic(a_id, action)
            # Falsy results (e.g. None) mean there is nothing to send for
            # this agent.
            if agent_action:
                sc_actions.append(agent_action)

        # Send action request
        req_actions = sc_pb.RequestAction(actions=sc_actions)
        try:
            self._controller.actions(req_actions)
            # Make step in SC2, i.e. apply actions
            self._controller.step(self._step_mul)
            # Observe here so that we know if the episode is over.
            self._obs = self._controller.observe()
        except (protocol.ProtocolError, protocol.ConnectionError):
            # The game connection is broken: restart and end the episode.
            self.full_restart()
            return 0, True, {}

        self._total_steps += 1
        self._episode_steps += 1

        # Update units
        game_end_code = self.update_units()

        terminated = False
        reward = self.reward_battle()
        info = {"battle_won": False}

        if game_end_code is not None:
            # Battle is over
            terminated = True
            self.battles_game += 1
            if game_end_code == 1 and not self.win_counted:
                self.battles_won += 1
                self.win_counted = True
                info["battle_won"] = True
                if not self.reward_sparse:
                    reward += self.reward_win
                else:
                    reward = 1
            elif game_end_code == -1 and not self.defeat_counted:
                self.defeat_counted = True
                if not self.reward_sparse:
                    reward += self.reward_defeat
                else:
                    reward = -1

        elif self._episode_steps >= self.episode_limit:
            # Episode limit reached
            terminated = True
            if self.continuing_episode:
                info["episode_limit"] = True
            self.battles_game += 1
            self.timeouts += 1

        if self.debug:
            logging.debug("Reward = {}".format(reward).center(60, '-'))

        if terminated:
            self._episode_count += 1

        if self.reward_scale:
            # Normalise the reward by max_reward / reward_scale_rate.
            reward /= self.max_reward / self.reward_scale_rate

        return reward, terminated, info
