from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from .multiagentenv import MultiAgentEnv
from .smac_maps import get_map_params

import atexit
from operator import attrgetter
from copy import deepcopy
import numpy as np
import enum
import math
from absl import logging

from pysc2 import maps
from pysc2 import run_configs
from pysc2.lib import protocol

from s2clientprotocol import common_pb2 as sc_common
from s2clientprotocol import sc2api_pb2 as sc_pb
from s2clientprotocol import raw_pb2 as r_pb
from s2clientprotocol import debug_pb2 as d_pb

import os.path as osp
from pathlib import Path
import yaml

import random
from gym.spaces import Discrete

# Race codes used in map parameters ("a_race" / "b_race") -> SC2 race enums.
races = dict(
    R=sc_common.Random,
    P=sc_common.Protoss,
    T=sc_common.Terran,
    Z=sc_common.Zerg,
)

# Difficulty keys "1".."9", "A" (in that order) -> built-in bot difficulty.
difficulties = dict(
    zip(
        "123456789A",
        (
            sc_pb.VeryEasy,
            sc_pb.Easy,
            sc_pb.Medium,
            sc_pb.MediumHard,
            sc_pb.Hard,
            sc_pb.Harder,
            sc_pb.VeryHard,
            sc_pb.CheatVision,
            sc_pb.CheatMoney,
            sc_pb.CheatInsane,
        ),
    )
)

# SC2 ability ids used for raw unit commands.
actions = dict(
    move=16,  # target: PointOrUnit
    attack=23,  # target: PointOrUnit
    stop=4,  # target: None
    heal=386,  # target: Unit
)


class Direction(enum.IntEnum):
    """Compass directions as integer codes: NORTH=0, SOUTH=1, EAST=2, WEST=3."""

    NORTH, SOUTH, EAST, WEST = range(4)


class StarCraft2Env(MultiAgentEnv):
    """The StarCraft II environment for decentralised multi-agent
    micromanagement scenarios.
    """

    def __init__(
        self,
        args,
        step_mul=8,
        move_amount=2,
        difficulty="7",
        game_version=None,
        seed=None,
        continuing_episode=False,
        obs_all_health=True,
        obs_own_health=True,
        obs_last_action=True,
        obs_pathing_grid=False,
        obs_terrain_height=False,
        obs_instead_of_state=False,
        obs_timestep_number=False,
        obs_agent_id=True,
        state_pathing_grid=False,
        state_terrain_height=False,
        state_last_action=True,
        state_timestep_number=False,
        state_agent_id=True,
        reward_sparse=False,
        reward_only_positive=True,
        reward_death_value=10,
        reward_win=200,
        reward_defeat=0,
        reward_negative_scale=0.5,
        reward_scale=True,
        reward_scale_rate=20,
        replay_dir="",
        replay_prefix="",
        window_size_x=1920,
        window_size_y=1200,
        heuristic_ai=False,
        heuristic_rest=False,
        debug=False,
    ):
        """
        Create a StarCraft2Env environment.

        Map-specific settings (number of agents/enemies, episode limit, races,
        map type) come from ``get_map_params(args["map_name"])``. Most
        observation/state feature flags come from the YAML file selected by
        ``args["state_type"]`` (see ``load_state_config``). Because of that,
        some keyword arguments below are accepted but overridden by the
        config; this is noted per argument.

        Parameters
        ----------
        args : dict-like
            Must provide ``"state_type"`` (name of the state-config YAML file)
            and ``"map_name"`` (the SC2 map to play).
        step_mul : int, optional
            How many game steps are simulated per agent step (default is 8).
            Passed to the controller's ``step`` call.
        move_amount : float, optional
            How far away units are ordered to move per step (default is 2).
        difficulty : str, optional
            Key into the module-level ``difficulties`` table ("1"-"9" or "A")
            selecting the built-in computer AI (default is "7").
        game_version : str, optional
            StarCraft II game version passed to ``run_configs.get`` (default
            is None; presumably None selects the latest installed version).
        seed : int, optional
            Random seed used during game initialisation (default is None).
            ``_launch`` increments the stored value by one on every
            (re)launch.
        continuing_episode : bool, optional
            Whether to consider episodes continuing or finished after the time
            limit is reached (default is False). Consulted by ``step`` when
            the episode limit is hit.
        obs_all_health : bool, optional
            Agents receive the health of all units (in the sight range) as part
            of observations (default is True). When True, ``obs_own_health``
            is forced to True.
        obs_own_health : bool, optional
            Agents receive their own health as part of observations (default
            is True).
        obs_last_action : bool, optional
            Agents receive the last actions of all units (in the sight range)
            as part of observations (default is True).
        obs_pathing_grid : bool, optional
            Whether observations include pathing values surrounding the agent
            (default is False).
        obs_terrain_height : bool, optional
            Whether observations include terrain height values surrounding the
            agent (default is False).
        obs_instead_of_state : bool, optional
            Ignored: overridden by the config key ``use_obs_instead_of_state``.
        obs_timestep_number : bool, optional
            Whether observations include the current timestep of the episode
            (default is False).
        obs_agent_id : bool, optional
            Stored as ``self.obs_agent_id`` (default is True); presumably adds
            the agent id to observations.
        state_pathing_grid, state_terrain_height, state_last_action,
        state_timestep_number : bool, optional
            Ignored: overridden by the config keys of the same names.
        state_agent_id : bool, optional
            Stored as ``self.state_agent_id`` (default is True).
        reward_sparse : bool, optional
            On win/loss, set the reward to 1/-1 instead of adding
            ``reward_win``/``reward_defeat`` (default is False). Note that
            ``reward_scale`` is still applied afterwards by ``step``.
        reward_only_positive : bool, optional
            Reward is always positive (default is True).
        reward_death_value : float, optional
            The amount of reward received for killing an enemy unit (default
            is 10). This is also the negative penalty for having an allied
            unit killed if reward_only_positive == False.
        reward_win : float, optional
            Reward added when the episode is won (default is 200).
        reward_defeat : float, optional
            Reward added when the episode is lost (default is 0). This value
            should be nonpositive.
        reward_negative_scale : float, optional
            Scaling factor for negative rewards (default is 0.5). This
            parameter is ignored when reward_only_positive == True.
        reward_scale : bool, optional
            Whether or not to scale the reward (default is True).
        reward_scale_rate : float, optional
            Reward scale rate (default is 20). When reward_scale == True, the
            reward is divided by (max_reward / reward_scale_rate), where
            max_reward = n_enemies * reward_death_value + reward_win (shield
            regeneration of Protoss units is not accounted for).
        replay_dir : str, optional
            Directory to save replays in (default is ""). Presumably an empty
            value means the Replays directory of the SC2 installation.
        replay_prefix : str, optional
            Prefix of saved replay files (default is ""). Presumably an empty
            value means the map name is used.
        window_size_x : int, optional
            Width of the StarCraft II window (default is 1920).
        window_size_y : int, optional
            Height of the StarCraft II window (default is 1200).
        heuristic_ai : bool, optional
            Whether to use a non-learning heuristic AI instead of the given
            actions (default is False).
        heuristic_rest : bool, optional
            Restrict the heuristic AI to actions currently available to the RL
            agents, moving towards the target otherwise (default is False).
            Ignored if heuristic_ai == False.
        debug : bool, optional
            Log messages about observations, state, actions and rewards for
            debugging purposes (default is False).
        """
        # Map arguments
        # Feature flags are read from
        # configs/envs_cfgs/smac_state_config/<state_type>.yaml.
        state_config = self.load_state_config(args["state_type"])
        self.map_name = args["map_name"]
        self.add_local_obs = state_config["add_local_obs"]
        self.add_move_state = state_config["add_move_state"]
        self.add_visible_state = state_config["add_visible_state"]
        self.add_distance_state = state_config["add_distance_state"]
        self.add_xy_state = state_config["add_xy_state"]
        self.add_enemy_action_state = state_config["add_enemy_action_state"]
        self.add_agent_id = state_config["add_agent_id"]
        self.use_state_agent = state_config["use_state_agent"]
        self.use_mustalive = state_config["use_mustalive"]
        self.add_center_xy = state_config["add_center_xy"]
        self.use_stacked_frames = state_config["use_stacked_frames"]
        self.stacked_frames = state_config["stacked_frames"]

        map_params = get_map_params(self.map_name)
        self.n_agents = map_params["n_agents"]
        self.n_enemies = map_params["n_enemies"]
        self.episode_limit = map_params["limit"]
        self._move_amount = move_amount
        self._step_mul = step_mul
        self.difficulty = difficulty

        # Observations and state
        self.obs_own_health = obs_own_health
        self.obs_all_health = obs_all_health
        # NOTE: the obs_instead_of_state keyword argument is not used; the
        # config value wins.
        self.obs_instead_of_state = state_config["use_obs_instead_of_state"]
        self.obs_last_action = obs_last_action
        self.use_global_state = state_config["use_global_state"]
        self.global_state_include_info = state_config["global_state_include_info"]

        self.obs_pathing_grid = obs_pathing_grid
        self.obs_terrain_height = obs_terrain_height
        self.obs_timestep_number = obs_timestep_number
        self.obs_agent_id = obs_agent_id
        # NOTE: the state_* keyword arguments are overridden by the config,
        # except state_agent_id.
        self.state_pathing_grid = state_config["state_pathing_grid"]
        self.state_terrain_height = state_config["state_terrain_height"]
        self.state_last_action = state_config["state_last_action"]
        self.state_timestep_number = state_config["state_timestep_number"]
        self.state_agent_id = state_agent_id
        # All-unit health implies own health.
        if self.obs_all_health:
            self.obs_own_health = True
        # Widths of the optional pathing / terrain-height observation features
        # (presumably consumed by the observation-size code elsewhere).
        self.n_obs_pathing = 8
        self.n_obs_height = 9

        # Rewards args
        self.reward_sparse = reward_sparse
        self.reward_only_positive = reward_only_positive
        self.reward_negative_scale = reward_negative_scale
        self.reward_death_value = reward_death_value
        self.reward_win = reward_win
        self.reward_defeat = reward_defeat

        self.reward_scale = reward_scale
        self.reward_scale_rate = reward_scale_rate

        # Other
        self.game_version = game_version
        self.continuing_episode = continuing_episode
        self._seed = seed
        self.heuristic_ai = heuristic_ai
        self.heuristic_rest = heuristic_rest
        self.debug = debug
        self.window_size = (window_size_x, window_size_y)
        self.replay_dir = replay_dir
        self.replay_prefix = replay_prefix

        # Actions
        # 0 = no-op, 1 = stop, 2-5 = move N/S/E/W; the remaining n_enemies
        # indices target a unit (see get_agent_action).
        self.n_actions_no_attack = 6
        self.n_actions_move = 4
        self.n_actions = self.n_actions_no_attack + self.n_enemies

        # Map info
        self._agent_race = map_params["a_race"]
        self._bot_race = map_params["b_race"]
        # One extra feature bit for shields when the side is Protoss.
        self.shield_bits_ally = 1 if self._agent_race == "P" else 0
        self.shield_bits_enemy = 1 if self._bot_race == "P" else 0
        self.unit_type_bits = map_params["unit_type_bits"]
        self.map_type = map_params["map_type"]

        # Upper bound used by step() to normalise rewards when reward_scale.
        self.max_reward = self.n_enemies * self.reward_death_value + self.reward_win

        # Episode bookkeeping; most of it is re-initialised in reset().
        self.agents = {}
        self.enemies = {}
        self._episode_count = 0
        self._episode_steps = 0
        self._total_steps = 0
        self._obs = None
        self.battles_won = 0
        self.battles_game = 0
        self.timeouts = 0
        self.force_restarts = 0
        self.last_stats = None
        self.death_tracker_ally = np.zeros(self.n_agents, dtype=np.float32)
        self.death_tracker_enemy = np.zeros(self.n_enemies, dtype=np.float32)
        self.previous_ally_units = None
        self.previous_enemy_units = None
        self.last_action = np.zeros((self.n_agents, self.n_actions), dtype=np.float32)
        # Placeholders; map geometry/grids are filled in by _launch(), unit
        # type ids presumably by unit-initialisation code elsewhere.
        self._min_unit_type = 0
        self.marine_id = self.marauder_id = self.medivac_id = 0
        self.hydralisk_id = self.zergling_id = self.baneling_id = 0
        self.stalker_id = self.colossus_id = self.zealot_id = 0
        self.max_distance_x = 0
        self.max_distance_y = 0
        self.map_x = 0
        self.map_y = 0
        self.terrain_height = None
        self.pathing_grid = None
        self._run_config = None
        self._sc2_proc = None
        self._controller = None

        # Try to avoid leaking SC2 processes on shutdown
        atexit.register(lambda: self.close())

        # One entry per agent. get_obs_size()/get_state_size() are indexed
        # with [0] below, so they presumably return a shape-like sequence.
        self.action_space = []
        self.observation_space = []
        self.share_observation_space = []
        for i in range(self.n_agents):
            self.action_space.append(Discrete(self.n_actions))
            self.observation_space.append(self.get_obs_size())
            self.share_observation_space.append(self.get_state_size())

        if self.use_stacked_frames:
            # Frame-history buffers of shape
            # (n_agents, stacked_frames, per-frame feature size).
            self.stacked_local_obs = np.zeros(
                (
                    self.n_agents,
                    self.stacked_frames,
                    int(self.get_obs_size()[0] / self.stacked_frames),
                ),
                dtype=np.float32,
            )
            self.stacked_global_state = np.zeros(
                (
                    self.n_agents,
                    self.stacked_frames,
                    int(self.get_state_size()[0] / self.stacked_frames),
                ),
                dtype=np.float32,
            )

    def _launch(self):
        """Launch the StarCraft II game and cache static map information.

        Starts a fresh SC2 process and creates a game on ``self.map_name``
        against a built-in computer opponent (race ``self._bot_race``,
        difficulty ``self.difficulty``). It then joins that game with the agent
        race and reads the playable-area extents, map size, pathing grid and
        terrain height from the game info.

        If a seed was configured it is incremented by one before every launch,
        so each (re)launch uses a different game seed. Without a seed, the game
        is created without an explicit ``random_seed``.
        """
        self._run_config = run_configs.get(version=self.game_version)
        _map = maps.get(self.map_name)
        # ``seed`` defaults to None; ``None += 1`` would raise TypeError.
        if self._seed is not None:
            self._seed += 1

        # Setting up the interface
        interface_options = sc_pb.InterfaceOptions(raw=True, score=False)
        self._sc2_proc = self._run_config.start(
            window_size=self.window_size, want_rgb=False
        )
        self._controller = self._sc2_proc.controller

        # Request to create the game; only pin the RNG seed when one is set.
        create_kwargs = dict(
            local_map=sc_pb.LocalMap(
                map_path=_map.path, map_data=self._run_config.map_data(_map.path)
            ),
            realtime=False,
        )
        if self._seed is not None:
            create_kwargs["random_seed"] = self._seed
        create = sc_pb.RequestCreateGame(**create_kwargs)
        create.player_setup.add(type=sc_pb.Participant)
        create.player_setup.add(
            type=sc_pb.Computer,
            race=races[self._bot_race],
            difficulty=difficulties[self.difficulty],
        )
        self._controller.create_game(create)

        join = sc_pb.RequestJoinGame(
            race=races[self._agent_race], options=interface_options
        )
        self._controller.join_game(join)

        game_info = self._controller.game_info()
        map_info = game_info.start_raw
        map_play_area_min = map_info.playable_area.p0
        map_play_area_max = map_info.playable_area.p1
        self.max_distance_x = map_play_area_max.x - map_play_area_min.x
        self.max_distance_y = map_play_area_max.y - map_play_area_min.y
        self.map_x = map_info.map_size.x
        self.map_y = map_info.map_size.y

        if map_info.pathing_grid.bits_per_pixel == 1:
            # Bit-packed grid: 8 cells per byte, most significant bit first.
            vals = np.array(list(map_info.pathing_grid.data)).reshape(
                self.map_x, self.map_y // 8
            )
            self.pathing_grid = np.transpose(
                np.array(
                    [
                        [(b >> i) & 1 for b in row for i in range(7, -1, -1)]
                        for row in vals
                    ],
                    dtype=bool,
                )
            )
        else:
            # One value per cell; transposed, flipped along axis 1 and inverted.
            self.pathing_grid = np.invert(
                np.flip(
                    np.transpose(
                        np.array(
                            list(map_info.pathing_grid.data), dtype=bool
                        ).reshape(self.map_x, self.map_y)
                    ),
                    axis=1,
                )
            )

        # Divided by 255 (presumably 8-bit height values, giving [0, 1]).
        self.terrain_height = (
            np.flip(
                np.transpose(
                    np.array(list(map_info.terrain_height.data)).reshape(
                        self.map_x, self.map_y
                    )
                ),
                1,
            )
            / 255
        )

    def reset(self):
        """Reset the environment. Required after each full episode.

        The first call launches StarCraft II; later calls restart the episode
        inside the running game via ``_restart``. Per-episode bookkeeping
        (death trackers, win/defeat flags, last actions, heuristic targets) is
        cleared.

        Returns
        -------
        local_obs
            Per-agent observations from ``get_obs()``. With frame stacking
            enabled this is an array of shape (n_agents, stacked_frames *
            per-frame size) holding zeros for all but the newest frame.
        global_state
            Per-agent state: ``get_state_agent(i)``, ``get_global_state()`` or
            ``get_state(i)`` depending on the config, stacked like
            ``local_obs`` when frame stacking is enabled.
        available_actions : list
            ``get_avail_agent_actions(i)`` for every agent ``i``.
        """
        self._episode_steps = 0
        if self._episode_count == 0:
            # Launch StarCraft II
            self._launch()
        else:
            self._restart()

        # Information kept for counting the reward
        self.death_tracker_ally = np.zeros(self.n_agents, dtype=np.float32)
        self.death_tracker_enemy = np.zeros(self.n_enemies, dtype=np.float32)
        self.previous_ally_units = None
        self.previous_enemy_units = None
        self.win_counted = False
        self.defeat_counted = False

        self.last_action = np.zeros((self.n_agents, self.n_actions), dtype=np.float32)

        if self.heuristic_ai:
            self.heuristic_targets = [None] * self.n_agents

        try:
            self._obs = self._controller.observe()
            self.init_units()
        except (protocol.ProtocolError, protocol.ConnectionError):
            self.full_restart()
            # full_restart() starts a new game but leaves self._obs and the
            # unit dictionaries describing the dead one; refresh them so the
            # first observation and action masks match the new game.
            self._obs = self._controller.observe()
            self.init_units()

        available_actions = [
            self.get_avail_agent_actions(i) for i in range(self.n_agents)
        ]

        if self.debug:
            logging.debug(
                "Started Episode {}".format(self._episode_count).center(60, "*")
            )

        if self.use_state_agent:
            global_state = [
                self.get_state_agent(agent_id) for agent_id in range(self.n_agents)
            ]
        elif self.use_global_state:
            global_state = [
                self.get_global_state() for agent_id in range(self.n_agents)
            ]
        else:
            global_state = [
                self.get_state(agent_id) for agent_id in range(self.n_agents)
            ]

        local_obs = self.get_obs()

        if self.use_stacked_frames:
            # Start each episode from an empty history so frames from the
            # previous episode never leak into the first observation. Fresh
            # arrays are allocated so views returned earlier stay untouched.
            self.stacked_local_obs = np.zeros_like(self.stacked_local_obs)
            self.stacked_global_state = np.zeros_like(self.stacked_global_state)

            # Newest frame always lives in the last slot.
            self.stacked_local_obs[:, -1, :] = np.array(local_obs)
            self.stacked_global_state[:, -1, :] = np.array(global_state)

            local_obs = self.stacked_local_obs.reshape(self.n_agents, -1)
            global_state = self.stacked_global_state.reshape(self.n_agents, -1)

        return local_obs, global_state, available_actions

    def load_state_config(self, state_type):
        """Load the observation/state feature flags for ``state_type``.

        The file read is
        ``<base>/configs/envs_cfgs/smac_state_config/<state_type>.yaml``, where
        ``<base>`` is two directories above the directory containing this
        module.

        Returns
        -------
        dict
            The parsed YAML mapping (keys such as ``add_local_obs`` and
            ``use_state_agent`` are read by ``__init__``).

        Raises
        ------
        FileNotFoundError
            If no config file exists for ``state_type``.
        ValueError
            If the file does not contain a YAML mapping (e.g. it is empty).
        """
        base_path = osp.split(osp.split(osp.dirname(osp.abspath(__file__)))[0])[0]
        state_config_path = (
            Path(base_path)
            / "configs"
            / "envs_cfgs"
            / "smac_state_config"
            / f"{state_type}.yaml"
        )
        with state_config_path.open("r", encoding="utf-8") as file:
            # The config is plain data (flags and ints); safe_load refuses the
            # Python-object tags that FullLoader is willing to construct.
            state_config = yaml.safe_load(file)
        if not isinstance(state_config, dict):
            raise ValueError(
                f"State config {state_config_path} must contain a YAML mapping, "
                f"got {type(state_config).__name__}"
            )
        return state_config

    def _restart(self):
        """Begin a new episode inside the already-running game.

        All units on the map are killed; a trigger in the SC2Map file restarts
        the episode once no units are left. If the game connection fails
        along the way, the whole SC2 process is relaunched instead.
        """
        recoverable_errors = (protocol.ProtocolError, protocol.ConnectionError)
        try:
            self._kill_all_units()
            # Advance two game loops (presumably so the map trigger can fire).
            self._controller.step(2)
        except recoverable_errors:
            self.full_restart()

    def full_restart(self):
        """Tear down the current SC2 process and start a brand-new game.

        This is the recovery path after a lost game connection.
        ``self.force_restarts`` is bumped only after ``_launch`` returns.
        """
        old_process = self._sc2_proc
        old_process.close()
        self._launch()
        self.force_restarts += 1

    def step(self, actions):
        """Advance the environment by one agent step.

        Parameters
        ----------
        actions : sequence
            One action index per agent. Elements may be plain ints or
            scalar-like objects exposing ``.item()`` (numpy scalars / 1-element
            arrays, tensors). When ``heuristic_ai`` is True, the action chosen
            by the heuristic is written back into ``actions`` in place.

        Returns
        -------
        tuple
            ``(local_obs, global_state, rewards, dones, infos,
            available_actions)``:

            - ``rewards`` is ``[[r]] * n_agents``;
            - ``dones`` is a bool array of length ``n_agents`` that is True for
              every agent once the episode terminates and for dead agents
              otherwise;
            - ``infos`` is a list of per-agent dicts with battle statistics,
              plus ``"episode_limit": True`` on timeouts when
              ``continuing_episode`` is set.

            If the game connection fails, the game is fully restarted and the
            step is reported as terminal with zero reward.
        """
        terminated = False
        bad_transition = False
        episode_limit_reached = False
        infos = [{} for _ in range(self.n_agents)]
        dones = np.zeros(self.n_agents, dtype=bool)

        def _agent_info(is_bad_transition, hit_episode_limit):
            # Fresh dict per agent so callers may mutate them independently.
            info = {
                "battles_won": self.battles_won,
                "battles_game": self.battles_game,
                "battles_draw": self.timeouts,
                "restarts": self.force_restarts,
                "bad_transition": is_bad_transition,
                "won": self.win_counted,
            }
            if hit_episode_limit and self.continuing_episode:
                info["episode_limit"] = True
            return info

        def _observe():
            # Build (local_obs, global_state) and push them onto the frame
            # stacks when stacking is enabled.
            if self.use_state_agent:
                global_state = [
                    self.get_state_agent(agent_id) for agent_id in range(self.n_agents)
                ]
            elif self.use_global_state:
                global_state = [
                    self.get_global_state() for _ in range(self.n_agents)
                ]
            else:
                global_state = [
                    self.get_state(agent_id) for agent_id in range(self.n_agents)
                ]

            local_obs = self.get_obs()

            if self.use_stacked_frames:
                # Shift history one slot towards the front (dropping the
                # oldest frame at index 0) and put the newest frame last.
                # np.roll returns a new array, so views handed out earlier
                # are not modified.
                self.stacked_local_obs = np.roll(self.stacked_local_obs, -1, axis=1)
                self.stacked_global_state = np.roll(
                    self.stacked_global_state, -1, axis=1
                )

                self.stacked_local_obs[:, -1, :] = np.array(local_obs)
                self.stacked_global_state[:, -1, :] = np.array(global_state)

                local_obs = self.stacked_local_obs.reshape(self.n_agents, -1)
                global_state = self.stacked_global_state.reshape(self.n_agents, -1)

            return local_obs, global_state

        # Accept plain ints as well as numpy/tensor scalars.
        actions_int = [int(a.item() if hasattr(a, "item") else a) for a in actions]

        self.last_action = np.eye(self.n_actions)[np.array(actions_int)]

        # Collect individual actions
        sc_actions = []
        if self.debug:
            logging.debug("Actions".center(60, "-"))

        for a_id, action in enumerate(actions_int):
            if not self.heuristic_ai:
                sc_action = self.get_agent_action(a_id, action)
            else:
                sc_action, action_num = self.get_agent_action_heuristic(a_id, action)
                # Report the action the heuristic actually took to the caller.
                actions[a_id] = action_num
            if sc_action:
                sc_actions.append(sc_action)

        # Send action request
        req_actions = sc_pb.RequestAction(actions=sc_actions)
        try:
            self._controller.actions(req_actions)
            # Make step in SC2, i.e. apply actions
            self._controller.step(self._step_mul)
            # Observe here so that we know if the episode is over.
            self._obs = self._controller.observe()
        except (protocol.ProtocolError, protocol.ConnectionError):
            self.full_restart()
            # The episode ends here. Count it (as the normal terminal path
            # does) so the next reset() restarts the freshly launched game
            # instead of launching yet another SC2 process.
            self._episode_count += 1
            available_actions = [
                self.get_avail_agent_actions(i) for i in range(self.n_agents)
            ]
            infos = [_agent_info(bad_transition, False) for _ in range(self.n_agents)]
            dones[:] = True
            local_obs, global_state = _observe()
            return (
                local_obs,
                global_state,
                [[0]] * self.n_agents,
                dones,
                infos,
                available_actions,
            )

        self._total_steps += 1
        self._episode_steps += 1

        # Update units
        game_end_code = self.update_units()

        reward = self.reward_battle()

        available_actions = [
            self.get_avail_agent_actions(i) for i in range(self.n_agents)
        ]

        if game_end_code is not None:
            # Battle is over
            terminated = True
            self.battles_game += 1
            if game_end_code == 1 and not self.win_counted:
                self.battles_won += 1
                self.win_counted = True
                if not self.reward_sparse:
                    reward += self.reward_win
                else:
                    reward = 1
            elif game_end_code == -1 and not self.defeat_counted:
                self.defeat_counted = True
                if not self.reward_sparse:
                    reward += self.reward_defeat
                else:
                    reward = -1

        elif self._episode_steps >= self.episode_limit:
            # Episode limit reached
            terminated = True
            bad_transition = True
            episode_limit_reached = True
            self.battles_game += 1
            self.timeouts += 1

        infos = [
            _agent_info(bad_transition, episode_limit_reached)
            for _ in range(self.n_agents)
        ]
        # Everyone is done at episode end; otherwise only dead agents are.
        for i in range(self.n_agents):
            dones[i] = terminated or bool(self.death_tracker_ally[i])

        if self.debug:
            logging.debug("Reward = {}".format(reward).center(60, "-"))

        if terminated:
            self._episode_count += 1

        if self.reward_scale:
            reward /= self.max_reward / self.reward_scale_rate

        rewards = [[reward]] * self.n_agents

        local_obs, global_state = _observe()

        return local_obs, global_state, rewards, dones, infos, available_actions

    def get_agent_action(self, a_id, action):
        """Build the SC2 raw command that carries out ``action`` for agent ``a_id``.

        Discrete action layout:
            0    no-op (only legal for dead agents; returns None)
            1    stop
            2-5  move north / south / east / west by ``self._move_amount``
            6+   target unit ``action - self.n_actions_no_attack``: a Medivac on
                 an "MMM" map heals that ally; every other unit attacks that
                 enemy

        Asserts that ``action`` is currently available to the agent.
        Returns ``None`` for the no-op, otherwise an ``sc_pb.Action``.
        """
        avail_actions = self.get_avail_agent_actions(a_id)
        assert avail_actions[action] == 1, "Agent {} cannot perform action {}".format(
            a_id, action
        )

        agent_unit = self.get_unit_by_id(a_id)
        unit_tag = agent_unit.tag
        cur_x = agent_unit.pos.x
        cur_y = agent_unit.pos.y

        if action == 0:
            assert agent_unit.health == 0, "No-op only available for dead agents."
            if self.debug:
                logging.debug("Agent {}: Dead".format(a_id))
            return None

        delta = self._move_amount
        # action -> (destination x, destination y, debug message template)
        move_orders = {
            2: (cur_x, cur_y + delta, "Agent {}: Move North"),
            3: (cur_x, cur_y - delta, "Agent {}: Move South"),
            4: (cur_x + delta, cur_y, "Agent {}: Move East"),
            5: (cur_x - delta, cur_y, "Agent {}: Move West"),
        }

        if action == 1:
            cmd = r_pb.ActionRawUnitCommand(
                ability_id=actions["stop"], unit_tags=[unit_tag], queue_command=False
            )
            if self.debug:
                logging.debug("Agent {}: Stop".format(a_id))
        elif action in move_orders:
            dest_x, dest_y, message = move_orders[action]
            cmd = r_pb.ActionRawUnitCommand(
                ability_id=actions["move"],
                target_world_space_pos=sc_common.Point2D(x=dest_x, y=dest_y),
                unit_tags=[unit_tag],
                queue_command=False,
            )
            if self.debug:
                logging.debug(message.format(a_id))
        else:
            target_id = action - self.n_actions_no_attack
            is_healer = self.map_type == "MMM" and agent_unit.unit_type == self.medivac_id
            if is_healer:
                target_unit, action_name = self.agents[target_id], "heal"
            else:
                target_unit, action_name = self.enemies[target_id], "attack"

            cmd = r_pb.ActionRawUnitCommand(
                ability_id=actions[action_name],
                target_unit_tag=target_unit.tag,
                unit_tags=[unit_tag],
                queue_command=False,
            )
            if self.debug:
                logging.debug(
                    "Agent {} {}s unit # {}".format(a_id, action_name, target_id)
                )

        return sc_pb.Action(action_raw=r_pb.ActionRaw(unit_command=cmd))

    def get_agent_action_heuristic(self, a_id, action):
        """Scripted policy: act on the nearest suitable target of agent ``a_id``.

        The ``action`` argument is ignored; the heuristic decides by itself.

        Medivacs heal the closest living, damaged, non-Medivac ally. Every
        other unit attacks the closest living enemy, except that Marauders
        never pick an enemy Medivac. The chosen target id is stored in
        ``self.heuristic_targets[a_id]``. It is reused until it dies, or for a
        heal target, until it is dead or back at full health.

        When ``self.heuristic_rest`` is set and the attack/heal action on the
        target is currently unavailable, the unit instead moves one
        ``self._move_amount`` step towards the target, along the axis with the
        larger offset.

        Returns:
            ``(sc_action, action_num)``. ``action_num`` is the action index
            matching the issued command: 2=north, 3=south, 4=east, 5=west, or
            ``n_actions_no_attack + target_id``. Returns ``(None, 0)`` when no
            suitable target exists.
        """
        me = self.get_unit_by_id(a_id)
        my_tag = me.tag
        healer = me.unit_type == self.medivac_id
        # Medivacs look for targets among allies; all other units among enemies.
        candidates = self.agents if healer else self.enemies

        current = self.heuristic_targets[a_id]
        if current is None:
            retarget = True
        elif healer:
            tgt = candidates[current]
            retarget = tgt.health == 0 or tgt.health == tgt.health_max
        else:
            retarget = candidates[current].health == 0

        if retarget:
            best_dist = math.hypot(self.max_distance_x, self.max_distance_y)
            best_id = None
            for cand_id, cand in candidates.items():
                if healer:
                    if cand.unit_type == self.medivac_id:
                        continue  # Medivacs are never chosen as heal targets
                    suitable = cand.health != 0 and cand.health != cand.health_max
                else:
                    if (
                        me.unit_type == self.marauder_id
                        and cand.unit_type == self.medivac_id
                    ):
                        continue  # Marauders skip enemy Medivacs
                    suitable = cand.health > 0
                if not suitable:
                    continue
                d = self.distance(me.pos.x, me.pos.y, cand.pos.x, cand.pos.y)
                if d < best_dist:
                    best_dist, best_id = d, cand_id
            self.heuristic_targets[a_id] = best_id
            if best_id is None:
                return None, 0

        target_id = self.heuristic_targets[a_id]
        target_unit = candidates[target_id]
        target_tag = target_unit.tag
        ability = actions["heal"] if healer else actions["attack"]
        action_num = target_id + self.n_actions_no_attack

        if self.heuristic_rest and self.get_avail_agent_actions(a_id)[action_num] == 0:
            # Target cannot be attacked/healed right now: step towards it.
            dx = target_unit.pos.x - me.pos.x
            dy = target_unit.pos.y - me.pos.y
            step = self._move_amount
            new_x, new_y = me.pos.x, me.pos.y
            if abs(dx) > abs(dy):
                if dx > 0:
                    new_x, action_num = me.pos.x + step, 4  # east
                else:
                    new_x, action_num = me.pos.x - step, 5  # west
            elif dy > 0:
                new_y, action_num = me.pos.y + step, 2  # north
            else:
                new_y, action_num = me.pos.y - step, 3  # south

            cmd = r_pb.ActionRawUnitCommand(
                ability_id=actions["move"],
                target_world_space_pos=sc_common.Point2D(x=new_x, y=new_y),
                unit_tags=[my_tag],
                queue_command=False,
            )
        else:
            cmd = r_pb.ActionRawUnitCommand(
                ability_id=ability,
                target_unit_tag=target_tag,
                unit_tags=[my_tag],
                queue_command=False,
            )

        return sc_pb.Action(action_raw=r_pb.ActionRaw(unit_command=cmd)), action_num

    def reward_battle(self):
        """Dense reward for the current step; 0 when ``self.reward_sparse`` is set.

        The reward is the hit-point plus shield damage dealt to enemies since the
        previous step, plus ``reward_death_value`` per newly killed enemy.
        Unless ``self.reward_only_positive`` is set, it then subtracts
        ``reward_negative_scale`` * (damage taken by allies +
        ``reward_death_value`` per newly lost ally). In positive-only mode the
        absolute value of the enemy-side sum is returned, because shield
        regeneration can make it negative.

        Side effect: newly dead units are flagged in ``death_tracker_ally`` /
        ``death_tracker_enemy``, so each death is counted only once.
        """
        if self.reward_sparse:
            return 0

        scale = self.reward_negative_scale
        death_term = 0
        ally_loss = 0
        enemy_loss = 0

        for idx, ally in self.agents.items():
            if self.death_tracker_ally[idx]:
                continue  # death already accounted for
            before = self.previous_ally_units[idx]
            before_hp = before.health + before.shield
            if ally.health != 0:
                ally_loss += scale * (before_hp - ally.health - ally.shield)
                continue
            # The ally died during this step.
            self.death_tracker_ally[idx] = 1
            if not self.reward_only_positive:
                death_term -= self.reward_death_value * scale
            ally_loss += before_hp * scale

        for idx, foe in self.enemies.items():
            if self.death_tracker_enemy[idx]:
                continue
            before = self.previous_enemy_units[idx]
            before_hp = before.health + before.shield
            if foe.health != 0:
                enemy_loss += before_hp - foe.health - foe.shield
                continue
            # The enemy died during this step.
            self.death_tracker_enemy[idx] = 1
            death_term += self.reward_death_value
            enemy_loss += before_hp

        if self.reward_only_positive:
            return abs(enemy_loss + death_term)  # abs() absorbs shield regeneration
        return enemy_loss + death_term - ally_loss

    def get_total_actions(self):
        """Size of the per-agent action space, i.e. every action an agent could ever take."""
        total = self.n_actions
        return total

    @staticmethod
    def distance(x1, y1, x2, y2):
        """Distance between two points."""
        return math.hypot(x2 - x1, y2 - y1)

    def unit_shoot_range(self, agent_id):
        """Shooting range of an agent; a constant 6 regardless of ``agent_id``."""
        shoot_range = 6
        return shoot_range

    def unit_sight_range(self, agent_id):
        """Sight range of an agent; a constant 9 regardless of ``agent_id``."""
        sight_range = 9
        return sight_range

    def unit_max_cooldown(self, unit):
        """Normalising constant for a unit's weapon cooldown (max energy for Medivacs).

        Unit types not listed default to 15. If several of the ``*_id``
        attributes are equal, the entry listed last wins.
        """
        cooldown_by_type = dict(
            [
                (self.marine_id, 15),
                (self.marauder_id, 25),
                (self.medivac_id, 200),  # Medivac: maximum energy
                (self.stalker_id, 35),
                (self.zealot_id, 22),
                (self.colossus_id, 24),
                (self.hydralisk_id, 10),
                (self.zergling_id, 11),
                (self.baneling_id, 1),
            ]
        )
        try:
            return cooldown_by_type[unit.unit_type]
        except KeyError:
            return 15

    def save_replay(self):
        """Save a replay of the current game through the SC2 run config.

        The file prefix falls back to ``self.map_name`` and the directory to
        ``""`` when ``replay_prefix`` / ``replay_dir`` are unset (falsy).
        The resulting path is logged.
        """
        replay_data = self._controller.save_replay()
        replay_path = self._run_config.save_replay(
            replay_data,
            replay_dir=self.replay_dir or "",
            prefix=self.replay_prefix or self.map_name,
        )
        logging.info("Replay saved at: %s" % replay_path)

    def unit_max_shield(self, unit):
        """Maximum shield of a Protoss unit, or None for unit types not listed.

        Each unit type is matched either by a hard-coded raw id or by the
        corresponding ``*_id`` attribute. Checks run in the order Stalker,
        Zealot, Colossus.
        """
        utype = unit.unit_type
        if utype in (74, self.stalker_id):
            return 80  # Protoss Stalker
        if utype in (73, self.zealot_id):
            return 50  # Protoss Zealot
        if utype in (4, self.colossus_id):
            return 150  # Protoss Colossus
        return None

    def can_move(self, unit, direction):
        """Return True if ``unit`` may step in ``direction``.

        The cell half a move-amount away in that direction, with coordinates
        truncated by ``int``, must lie inside the map and be truthy in
        ``self.pathing_grid``. Any direction other than NORTH/SOUTH/EAST is
        treated as WEST.
        """
        half_step = self._move_amount / 2
        px, py = unit.pos.x, unit.pos.y
        if direction == Direction.NORTH:
            py = py + half_step
        elif direction == Direction.SOUTH:
            py = py - half_step
        elif direction == Direction.EAST:
            px = px + half_step
        else:
            px = px - half_step

        cell = (int(px), int(py))
        return bool(self.check_bounds(*cell) and self.pathing_grid[cell])

    def get_surrounding_points(self, unit, include_self=False):
        """Grid points around ``unit``'s position, truncated with ``int``.

        The 8 points come in a fixed order:
        1. two move-amount steps in +y, -y, +x and -x;
        2. one step on each diagonal, in the order (+,+), (-,-), (+,-), (-,+).

        With ``include_self`` the unit's own cell is appended last.
        """
        cx, cy = int(unit.pos.x), int(unit.pos.y)
        step = self._move_amount
        far = 2 * step

        neighbours = [
            (cx, cy + far),
            (cx, cy - far),
            (cx + far, cy),
            (cx - far, cy),
        ]
        neighbours += [
            (cx + step, cy + step),
            (cx - step, cy - step),
            (cx + step, cy - step),
            (cx - step, cy + step),
        ]
        if include_self:
            neighbours.append((cx, cy))
        return neighbours

    def check_bounds(self, x, y):
        """True iff 0 <= x < map_x and 0 <= y < map_y."""
        if not 0 <= x < self.map_x:
            return False
        return 0 <= y < self.map_y

    def get_surrounding_pathing(self, unit):
        """Pathing-grid values at the points around ``unit``, excluding its own cell.

        Points outside the map are reported as 1.
        """
        surrounding = self.get_surrounding_points(unit, include_self=False)
        result = []
        for px, py in surrounding:
            if self.check_bounds(px, py):
                result.append(self.pathing_grid[px, py])
            else:
                result.append(1)
        return result

    def get_surrounding_height(self, unit):
        """Terrain-height values at the points around ``unit``, including its own cell.

        Points outside the map are reported as 1.
        """
        heights = []
        for px, py in self.get_surrounding_points(unit, include_self=True):
            if not self.check_bounds(px, py):
                heights.append(1)
                continue
            heights.append(self.terrain_height[px, py])
        return heights

    def get_obs_agent(self, agent_id):
        """Returns the local observation of agent ``agent_id``.

        The returned float32 vector concatenates, in this order:

        - ally features: one row per other agent, in increasing agent id
          (visible, distance, relative_x, relative_y, [health, [shield]],
          [unit_type one-hot], [last action])
        - enemy features: one row per enemy (available_to_attack, distance,
          relative_x, relative_y, [health, [shield]], [unit_type one-hot])
        - movement features: availability of the move actions, then
          optionally the surrounding pathing grid and terrain height
        - own features (visible, distance, x, y, [health, [shield]],
          [unit_type one-hot], [last action])
        - if ``self.obs_agent_id``: a one-hot encoding of ``agent_id``
        - if ``self.obs_timestep_number``: ``_episode_steps / episode_limit``

        Bracketed entries depend on the environment configuration
        (``obs_all_health``, ``obs_own_health``, the shield bits,
        ``unit_type_bits``, ``obs_last_action``, ...). To know the sizes of
        the individual parts, see ``get_obs_move_feats_size()``,
        ``get_obs_enemy_feats_size()``, ``get_obs_ally_feats_size()`` and
        ``get_obs_own_feats_size()``.

        Distances and relative coordinates are divided by the agent's sight
        range. Rows of allies/enemies that are dead or not strictly inside
        sight range stay zero. If the agent itself is dead, every feature
        part is zero, but the agent-id / timestep parts are still appended.

        NOTE: Agents should have access only to their local observations
        during decentralised execution.
        """
        unit = self.get_unit_by_id(agent_id)

        move_feats_dim = self.get_obs_move_feats_size()
        enemy_feats_dim = self.get_obs_enemy_feats_size()
        ally_feats_dim = self.get_obs_ally_feats_size()
        own_feats_dim = self.get_obs_own_feats_size()

        move_feats = np.zeros(move_feats_dim, dtype=np.float32)
        enemy_feats = np.zeros(enemy_feats_dim, dtype=np.float32)
        ally_feats = np.zeros(ally_feats_dim, dtype=np.float32)
        own_feats = np.zeros(own_feats_dim, dtype=np.float32)
        agent_id_feats = np.zeros(self.n_agents, dtype=np.float32)

        if unit.health > 0:  # otherwise dead: all feature parts stay zero
            x = unit.pos.x
            y = unit.pos.y
            sight_range = self.unit_sight_range(agent_id)

            # Movement features: the move actions start at index 2 of the
            # availability mask.
            avail_actions = self.get_avail_agent_actions(agent_id)
            for m in range(self.n_actions_move):
                move_feats[m] = avail_actions[m + 2]

            ind = self.n_actions_move

            if self.obs_pathing_grid:
                move_feats[
                    ind : ind + self.n_obs_pathing
                ] = self.get_surrounding_pathing(unit)
                ind += self.n_obs_pathing

            if self.obs_terrain_height:
                move_feats[ind:] = self.get_surrounding_height(unit)

            # Enemy features
            for e_id, e_unit in self.enemies.items():
                e_x = e_unit.pos.x
                e_y = e_unit.pos.y
                dist = self.distance(x, y, e_x, e_y)

                if dist < sight_range and e_unit.health > 0:  # visible and alive
                    # Whether the attack action on this enemy is available.
                    enemy_feats[e_id, 0] = avail_actions[
                        self.n_actions_no_attack + e_id
                    ]
                    enemy_feats[e_id, 1] = dist / sight_range  # distance
                    enemy_feats[e_id, 2] = (e_x - x) / sight_range  # relative X
                    enemy_feats[e_id, 3] = (e_y - y) / sight_range  # relative Y

                    ind = 4
                    if self.obs_all_health:
                        enemy_feats[e_id, ind] = (
                            e_unit.health / e_unit.health_max
                        )  # health
                        ind += 1
                        if self.shield_bits_enemy > 0:
                            max_shield = self.unit_max_shield(e_unit)
                            enemy_feats[e_id, ind] = (
                                e_unit.shield / max_shield
                            )  # shield
                            ind += 1

                    if self.unit_type_bits > 0:
                        type_id = self.get_unit_type_id(e_unit, False)
                        enemy_feats[e_id, ind + type_id] = 1  # unit type

            # Ally features: one row per other agent, in increasing id order.
            al_ids = [al_id for al_id in range(self.n_agents) if al_id != agent_id]
            for i, al_id in enumerate(al_ids):
                al_unit = self.get_unit_by_id(al_id)
                al_x = al_unit.pos.x
                al_y = al_unit.pos.y
                dist = self.distance(x, y, al_x, al_y)

                if dist < sight_range and al_unit.health > 0:  # visible and alive
                    ally_feats[i, 0] = 1  # visible
                    ally_feats[i, 1] = dist / sight_range  # distance
                    ally_feats[i, 2] = (al_x - x) / sight_range  # relative X
                    ally_feats[i, 3] = (al_y - y) / sight_range  # relative Y

                    ind = 4
                    if self.obs_all_health:
                        ally_feats[i, ind] = (
                            al_unit.health / al_unit.health_max
                        )  # health
                        ind += 1
                        if self.shield_bits_ally > 0:
                            max_shield = self.unit_max_shield(al_unit)
                            ally_feats[i, ind] = al_unit.shield / max_shield  # shield
                            ind += 1

                    if self.unit_type_bits > 0:
                        type_id = self.get_unit_type_id(al_unit, True)
                        ally_feats[i, ind + type_id] = 1
                        ind += self.unit_type_bits

                    if self.obs_last_action:
                        ally_feats[i, ind:] = self.last_action[al_id]

            # Own features: same leading layout as an ally row, seen from itself.
            own_feats[0] = 1  # visible
            own_feats[1] = 0  # distance
            own_feats[2] = 0  # X
            own_feats[3] = 0  # Y
            ind = 4
            if self.obs_own_health:
                own_feats[ind] = unit.health / unit.health_max
                ind += 1
                if self.shield_bits_ally > 0:
                    max_shield = self.unit_max_shield(unit)
                    own_feats[ind] = unit.shield / max_shield
                    ind += 1

            if self.unit_type_bits > 0:
                type_id = self.get_unit_type_id(unit, True)
                own_feats[ind + type_id] = 1
                ind += self.unit_type_bits

            if self.obs_last_action:
                own_feats[ind:] = self.last_action[agent_id]

        # Concatenate once, in the documented order.
        parts = [
            ally_feats.flatten(),
            enemy_feats.flatten(),
            move_feats.flatten(),
            own_feats.flatten(),
        ]
        if self.obs_agent_id:
            agent_id_feats[agent_id] = 1.0
            parts.append(agent_id_feats.flatten())
        agent_obs = np.concatenate(parts)

        if self.obs_timestep_number:
            agent_obs = np.append(agent_obs, self._episode_steps / self.episode_limit)

        if self.debug:
            logging.debug("Obs Agent: {}".format(agent_id).center(60, "-"))
            logging.debug(
                "Avail. actions {}".format(self.get_avail_agent_actions(agent_id))
            )
            logging.debug("Move feats {}".format(move_feats))
            logging.debug("Enemy feats {}".format(enemy_feats))
            logging.debug("Ally feats {}".format(ally_feats))
            logging.debug("Own feats {}".format(own_feats))

        return agent_obs

    def get_obs(self):
        """Collect the local observation of every agent.

        Returns:
            A list with one entry per agent, ordered by agent id, each
            produced by ``get_obs_agent``.

        NOTE: Agents should have access only to their local observations
        during decentralised execution.
        """
        observations = []
        for agent_idx in range(self.n_agents):
            observations.append(self.get_obs_agent(agent_idx))
        return observations

    def get_state(self, agent_id=-1):
        """Returns the global state, built from the point of view of ``agent_id``.

        If ``self.obs_instead_of_state`` is set, this is the concatenation of
        all agents' local observations (``get_obs()``).

        Otherwise the float32 state concatenates:

        - ally rows (one per agent): health, cooldown (energy for Medivacs on
          "MMM" maps), [center-relative x/y], [shield], [unit_type one-hot],
          then, only while ``agent_id``'s unit is alive, [distance],
          [relative x/y], [visible], [last action];
        - enemy rows (one per enemy): health, [center-relative x/y], [shield],
          [unit_type one-hot], then, only while ``agent_id``'s unit is alive,
          [distance], [relative x/y], [visible], [attack availability];
        - optionally ``agent_id``'s movement features (``add_move_state``),
          its local observation (``add_local_obs``), the normalised timestep
          (``state_timestep_number``) and a one-hot agent id
          (``add_agent_id``).

        Distances and relative x/y are divided by ``agent_id``'s sight range.
        Rows of dead units stay zero. When ``self.use_mustalive`` is set and
        ``agent_id``'s unit is dead, all ally/enemy/move features are zero.

        NOTE: This function should not be used during decentralised execution.
        NOTE(review): the default ``agent_id=-1`` is passed straight to
        ``get_unit_by_id`` and used to index ``agent_id_feats``; confirm that
        callers always pass an explicit, valid agent id.
        """
        if self.obs_instead_of_state:
            obs_concat = np.concatenate(self.get_obs(), axis=0).astype(np.float32)
            return obs_concat

        # Base row widths: allies (health, cooldown/energy), enemies (health),
        # plus shield and unit-type columns; optional groups are added below.
        nf_al = 2 + self.shield_bits_ally + self.unit_type_bits
        nf_en = 1 + self.shield_bits_enemy + self.unit_type_bits

        if self.add_center_xy:
            nf_al += 2
            nf_en += 2

        if self.add_distance_state:
            nf_al += 1
            nf_en += 1

        if self.add_xy_state:
            nf_al += 2
            nf_en += 2

        if self.add_visible_state:
            nf_al += 1
            nf_en += 1

        if self.state_last_action:
            nf_al += self.n_actions
            # Enemy rows are widened by the same n_actions, but the enemy loop
            # below never writes last actions, so that many columns per enemy
            # row remain zero.
            nf_en += self.n_actions

        if self.add_enemy_action_state:
            nf_en += 1

        nf_mv = self.get_state_move_feats_size()

        ally_state = np.zeros((self.n_agents, nf_al), dtype=np.float32)
        enemy_state = np.zeros((self.n_enemies, nf_en), dtype=np.float32)
        move_state = np.zeros((1, nf_mv), dtype=np.float32)
        agent_id_feats = np.zeros((self.n_agents, 1), dtype=np.float32)

        center_x = self.map_x / 2
        center_y = self.map_y / 2

        unit = self.get_unit_by_id(agent_id)  # reference unit for relative features
        x = unit.pos.x
        y = unit.pos.y
        sight_range = self.unit_sight_range(agent_id)
        avail_actions = self.get_avail_agent_actions(agent_id)

        if (self.use_mustalive and unit.health > 0) or (
            not self.use_mustalive
        ):  # or else all zeros
            # Movement features: move actions start at index 2 of the mask
            for m in range(self.n_actions_move):
                move_state[0, m] = avail_actions[m + 2]

            ind = self.n_actions_move

            if self.state_pathing_grid:
                move_state[
                    0, ind : ind + self.n_obs_pathing
                ] = self.get_surrounding_pathing(unit)
                ind += self.n_obs_pathing

            if self.state_terrain_height:
                move_state[0, ind:] = self.get_surrounding_height(unit)

            for al_id, al_unit in self.agents.items():
                if al_unit.health > 0:
                    al_x = al_unit.pos.x
                    al_y = al_unit.pos.y
                    max_cd = self.unit_max_cooldown(al_unit)
                    dist = self.distance(x, y, al_x, al_y)

                    ally_state[al_id, 0] = al_unit.health / al_unit.health_max  # health
                    if self.map_type == "MMM" and al_unit.unit_type == self.medivac_id:
                        ally_state[al_id, 1] = al_unit.energy / max_cd  # energy
                    else:
                        ally_state[al_id, 1] = (
                            al_unit.weapon_cooldown / max_cd
                        )  # cooldown

                    ind = 2

                    if self.add_center_xy:
                        ally_state[al_id, ind] = (
                            al_x - center_x
                        ) / self.max_distance_x  # center X
                        # center Y
                        ally_state[al_id, ind + 1] = (
                            al_y - center_y
                        ) / self.max_distance_y
                        ind += 2

                    if self.shield_bits_ally > 0:
                        max_shield = self.unit_max_shield(al_unit)
                        ally_state[al_id, ind] = al_unit.shield / max_shield  # shield
                        ind += 1

                    if self.unit_type_bits > 0:
                        type_id = self.get_unit_type_id(al_unit, True)
                        ally_state[al_id, ind + type_id] = 1

                    # Features relative to the reference unit are written only
                    # while it is alive; otherwise they stay zero.
                    if unit.health > 0:
                        ind += self.unit_type_bits
                        if self.add_distance_state:
                            ally_state[al_id, ind] = dist / sight_range  # distance
                            ind += 1
                        if self.add_xy_state:
                            ally_state[al_id, ind] = (
                                al_x - x
                            ) / sight_range  # relative X
                            # relative Y
                            ally_state[al_id, ind + 1] = (al_y - y) / sight_range
                            ind += 2
                        if self.add_visible_state:
                            if dist < sight_range:
                                ally_state[al_id, ind] = 1  # visible
                            ind += 1
                        if self.state_last_action:
                            ally_state[al_id, ind:] = self.last_action[al_id]

            for e_id, e_unit in self.enemies.items():
                if e_unit.health > 0:
                    e_x = e_unit.pos.x
                    e_y = e_unit.pos.y
                    dist = self.distance(x, y, e_x, e_y)

                    enemy_state[e_id, 0] = e_unit.health / e_unit.health_max  # health

                    ind = 1
                    if self.add_center_xy:
                        enemy_state[e_id, ind] = (
                            e_x - center_x
                        ) / self.max_distance_x  # center X
                        # center Y
                        enemy_state[e_id, ind + 1] = (
                            e_y - center_y
                        ) / self.max_distance_y
                        ind += 2

                    if self.shield_bits_enemy > 0:
                        max_shield = self.unit_max_shield(e_unit)
                        enemy_state[e_id, ind] = e_unit.shield / max_shield  # shield
                        ind += 1

                    if self.unit_type_bits > 0:
                        type_id = self.get_unit_type_id(e_unit, False)
                        enemy_state[e_id, ind + type_id] = 1

                    # Reference-relative features, only while the reference
                    # unit is alive.
                    if unit.health > 0:
                        ind += self.unit_type_bits
                        if self.add_distance_state:
                            enemy_state[e_id, ind] = dist / sight_range  # distance
                            ind += 1
                        if self.add_xy_state:
                            enemy_state[e_id, ind] = (
                                e_x - x
                            ) / sight_range  # relative X
                            # relative Y
                            enemy_state[e_id, ind + 1] = (e_y - y) / sight_range
                            ind += 2
                        if self.add_visible_state:
                            if dist < sight_range:
                                enemy_state[e_id, ind] = 1  # visible
                            ind += 1
                        if self.add_enemy_action_state:
                            # whether the reference agent can attack this enemy
                            enemy_state[e_id, ind] = avail_actions[
                                self.n_actions_no_attack + e_id
                            ]

        state = np.append(ally_state.flatten(), enemy_state.flatten())

        if self.add_move_state:
            state = np.append(state, move_state.flatten())

        if self.add_local_obs:
            state = np.append(state, self.get_obs_agent(agent_id).flatten())

        if self.state_timestep_number:
            state = np.append(state, self._episode_steps / self.episode_limit)

        if self.add_agent_id:
            agent_id_feats[agent_id] = 1.0
            state = np.append(state, agent_id_feats.flatten())

        state = state.astype(dtype=np.float32)

        if self.debug:
            logging.debug("STATE".center(60, "-"))
            logging.debug("Ally state {}".format(ally_state))
            logging.debug("Enemy state {}".format(enemy_state))
            logging.debug("Move state {}".format(move_state))
            if self.state_last_action:
                logging.debug("Last actions {}".format(self.last_action))

        return state

    def get_global_state(self):
        """Returns the agent-agnostic global state.

        If ``self.obs_instead_of_state`` is set, this is the concatenation of
        all agents' local observations (``get_obs()``).

        Otherwise the float32 state concatenates:

        - ally rows (one per agent): health, cooldown (energy for Medivacs on
          "MMM" maps), [center-relative x/y], [shield], [unit_type one-hot],
          [last action];
        - enemy rows (one per enemy): health, [center-relative x/y], [shield],
          [unit_type one-hot];
        - if ``add_move_state``: one row per agent with its full
          available-action mask, then the optional surrounding pathing grid
          and terrain height;
        - if ``state_timestep_number``: ``_episode_steps / episode_limit``;
        - if ``global_state_include_info``: map_x, map_y, max_distance_x,
          max_distance_y and the sight range of agent 0.

        Ally/enemy rows of dead units stay zero. Move rows are filled for
        every agent, dead or alive.

        NOTE: This function should not be used during decentralised execution.
        """
        if self.obs_instead_of_state:
            obs_concat = np.concatenate(self.get_obs(), axis=0).astype(np.float32)
            return obs_concat

        # Base row widths: allies (health, cooldown/energy), enemies (health),
        # plus shield and unit-type columns.
        nf_al = 2 + self.shield_bits_ally + self.unit_type_bits
        nf_en = 1 + self.shield_bits_enemy + self.unit_type_bits

        if self.add_center_xy:
            nf_al += 2
            nf_en += 2

        if self.state_last_action:
            nf_al += self.n_actions

        # NOTE(review): the move loop below writes n_actions mask entries plus
        # optional pathing/height values per agent; this relies on
        # get_state_move_feats_size_global() being sized to match.
        nf_mv_glb = self.get_state_move_feats_size_global()

        ally_state = np.zeros((self.n_agents, nf_al), dtype=np.float32)
        enemy_state = np.zeros((self.n_enemies, nf_en), dtype=np.float32)
        move_state = np.zeros((self.n_agents, nf_mv_glb), dtype=np.float32)
        info_state = np.zeros((1, 5), dtype=np.float32)

        center_x = self.map_x / 2
        center_y = self.map_y / 2

        # move_state: full action-availability mask per agent, then terrain info
        for agent_id in range(self.n_agents):
            unit = self.get_unit_by_id(agent_id)
            avail_actions = self.get_avail_agent_actions(agent_id)
            for m in range(self.n_actions):
                move_state[agent_id, m] = avail_actions[m]
            ind = self.n_actions
            if self.state_pathing_grid:
                move_state[
                    agent_id, ind : ind + self.n_obs_pathing
                ] = self.get_surrounding_pathing(unit)
                ind += self.n_obs_pathing
            if self.state_terrain_height:
                move_state[agent_id, ind:] = self.get_surrounding_height(unit)

        # ally_state
        for al_id, al_unit in self.agents.items():
            if al_unit.health > 0:
                al_x = al_unit.pos.x
                al_y = al_unit.pos.y
                max_cd = self.unit_max_cooldown(al_unit)

                ally_state[al_id, 0] = al_unit.health / al_unit.health_max  # health
                if self.map_type == "MMM" and al_unit.unit_type == self.medivac_id:
                    ally_state[al_id, 1] = al_unit.energy / max_cd  # energy
                else:
                    ally_state[al_id, 1] = al_unit.weapon_cooldown / max_cd  # cooldown

                ind = 2

                if self.add_center_xy:
                    ally_state[al_id, ind] = (
                        al_x - center_x
                    ) / self.max_distance_x  # center X
                    # center Y
                    ally_state[al_id, ind + 1] = (al_y - center_y) / self.max_distance_y
                    ind += 2

                if self.shield_bits_ally > 0:
                    max_shield = self.unit_max_shield(al_unit)
                    ally_state[al_id, ind] = al_unit.shield / max_shield  # shield
                    ind += 1

                if self.unit_type_bits > 0:
                    type_id = self.get_unit_type_id(al_unit, True)
                    ally_state[al_id, ind + type_id] = 1
                    ind += self.unit_type_bits

                if self.state_last_action:
                    ally_state[al_id, ind:] = self.last_action[al_id]

        # enemy_state
        for e_id, e_unit in self.enemies.items():
            if e_unit.health > 0:
                e_x = e_unit.pos.x
                e_y = e_unit.pos.y

                enemy_state[e_id, 0] = e_unit.health / e_unit.health_max  # health

                ind = 1
                if self.add_center_xy:
                    enemy_state[e_id, ind] = (
                        e_x - center_x
                    ) / self.max_distance_x  # center X
                    # center Y
                    enemy_state[e_id, ind + 1] = (e_y - center_y) / self.max_distance_y
                    ind += 2

                if self.shield_bits_enemy > 0:
                    max_shield = self.unit_max_shield(e_unit)
                    enemy_state[e_id, ind] = e_unit.shield / max_shield  # shield
                    ind += 1

                if self.unit_type_bits > 0:
                    type_id = self.get_unit_type_id(e_unit, False)
                    enemy_state[e_id, ind + type_id] = 1
                    ind += self.unit_type_bits

        # info_state: raw map/normalisation constants
        info_state[0, 0] = self.map_x
        info_state[0, 1] = self.map_y
        info_state[0, 2] = self.max_distance_x
        info_state[0, 3] = self.max_distance_y
        info_state[0, 4] = self.unit_sight_range(0)

        state = np.append(ally_state.flatten(), enemy_state.flatten())

        if self.add_move_state:
            state = np.append(state, move_state.flatten())

        if self.state_timestep_number:
            state = np.append(state, self._episode_steps / self.episode_limit)

        if self.global_state_include_info:
            state = np.append(state, info_state.flatten())

        state = state.astype(dtype=np.float32)

        if self.debug:
            logging.debug("STATE".center(60, "-"))
            logging.debug("Ally state {}".format(ally_state))
            logging.debug("Enemy state {}".format(enemy_state))
            logging.debug("Move state {}".format(move_state))
            logging.debug("Info state {}".format(info_state))
            if self.state_last_action:
                logging.debug("Last actions {}".format(self.last_action))

        return state

    def get_state_agent(self, agent_id):
        """Returns the agent-specific global state for agent_id.

        When ``obs_instead_of_state`` is set, this ignores ``agent_id`` and
        returns the concatenation of every agent's observation
        (``self.get_obs()``), cast to float32.

        Otherwise the state is built from the point of view of ``agent_id``
        and is composed of:

        - ally features, one row per *other* agent (visible, distance,
          relative_x, relative_y, weapon cooldown or Medivac energy, health,
          shield, center_x, center_y, unit_type, last action)
        - enemy features, one row per enemy (available_to_attack, distance,
          relative_x, relative_y, visible, health, shield, unit_type,
          center_x, center_y)
        - agent movement features (available move actions, optionally the
          surrounding pathing grid and terrain height)
        - own features (visible, distance, x, y, health, shield, center_x,
          center_y, unit_type, last action)

        Each block is flattened and they are concatenated in exactly that
        order (ally, enemy, move, own). They are optionally followed by a
        one-hot agent id (``state_agent_id``) and the episode progress
        ``_episode_steps / episode_limit`` (``state_timestep_number``).
        Optional features are present only when the corresponding flag is
        enabled. See ``get_state_ally_feats_size()``,
        ``get_state_enemy_feats_size()``, ``get_obs_move_feats_size()`` and
        ``get_state_own_feats_size()`` for the size of each block.

        Distances and relative positions are divided by this agent's sight
        range. The center_x / center_y features are divided by
        ``max_distance_x`` / ``max_distance_y``.

        If ``use_mustalive`` is set and the agent is dead, every feature block
        is left as zeros. The agent-id / timestep suffix is still appended.

        Args:
            agent_id: index of the agent whose state is built.

        Returns:
            np.ndarray: the 1-D state vector.
        """
        if self.obs_instead_of_state:
            obs_concat = np.concatenate(self.get_obs(), axis=0).astype(np.float32)
            return obs_concat

        unit = self.get_unit_by_id(agent_id)

        # NOTE(review): the movement block is sized from obs_pathing_grid /
        # obs_terrain_height (get_obs_move_feats_size), but it is filled below
        # according to state_pathing_grid / state_terrain_height. The two pairs
        # of flags must agree, otherwise the slices written below do not line
        # up with the allocated size. get_state_size() sizes it the same way.
        move_feats_dim = self.get_obs_move_feats_size()
        enemy_feats_dim = self.get_state_enemy_feats_size()
        ally_feats_dim = self.get_state_ally_feats_size()
        own_feats_dim = self.get_state_own_feats_size()

        move_feats = np.zeros(move_feats_dim, dtype=np.float32)
        enemy_feats = np.zeros(enemy_feats_dim, dtype=np.float32)
        ally_feats = np.zeros(ally_feats_dim, dtype=np.float32)
        own_feats = np.zeros(own_feats_dim, dtype=np.float32)
        agent_id_feats = np.zeros(self.n_agents, dtype=np.float32)

        # Map centre, used by the optional add_center_xy features.
        center_x = self.map_x / 2
        center_y = self.map_y / 2

        # otherwise dead, return all zeros
        # (when use_mustalive is off, features are computed even for a dead agent)
        if (self.use_mustalive and unit.health > 0) or (not self.use_mustalive):
            x = unit.pos.x
            y = unit.pos.y
            sight_range = self.unit_sight_range(agent_id)

            # Movement features
            avail_actions = self.get_avail_agent_actions(agent_id)
            for m in range(self.n_actions_move):
                # Move actions start at index 2 of the action mask
                # (index 0 is no-op, index 1 is stop).
                move_feats[m] = avail_actions[m + 2]

            ind = self.n_actions_move

            if self.state_pathing_grid:
                move_feats[
                    ind : ind + self.n_obs_pathing
                ] = self.get_surrounding_pathing(unit)
                ind += self.n_obs_pathing

            if self.state_terrain_height:
                move_feats[ind:] = self.get_surrounding_height(unit)

            # Enemy features
            for e_id, e_unit in self.enemies.items():
                e_x = e_unit.pos.x
                e_y = e_unit.pos.y
                dist = self.distance(x, y, e_x, e_y)

                if e_unit.health > 0:  # visible and alive
                    # Sight range > shoot range
                    if unit.health > 0:
                        # available
                        enemy_feats[e_id, 0] = avail_actions[
                            self.n_actions_no_attack + e_id
                        ]
                        enemy_feats[e_id, 1] = dist / sight_range  # distance
                        enemy_feats[e_id, 2] = (e_x - x) / sight_range  # relative X
                        enemy_feats[e_id, 3] = (e_y - y) / sight_range  # relative Y
                        if dist < sight_range:
                            enemy_feats[e_id, 4] = 1  # visible

                    ind = 5
                    if self.obs_all_health:
                        enemy_feats[e_id, ind] = (
                            e_unit.health / e_unit.health_max
                        )  # health
                        ind += 1
                        if self.shield_bits_enemy > 0:
                            max_shield = self.unit_max_shield(e_unit)
                            enemy_feats[e_id, ind] = (
                                e_unit.shield / max_shield
                            )  # shield
                            ind += 1

                    if self.unit_type_bits > 0:
                        type_id = self.get_unit_type_id(e_unit, False)
                        enemy_feats[e_id, ind + type_id] = 1  # unit type
                        ind += self.unit_type_bits

                    # For enemies the centre-relative position comes AFTER the
                    # unit-type one-hot. For allies and self it comes before it.
                    if self.add_center_xy:
                        enemy_feats[e_id, ind] = (
                            e_x - center_x
                        ) / self.max_distance_x  # center X
                        # center Y
                        enemy_feats[e_id, ind + 1] = (
                            e_y - center_y
                        ) / self.max_distance_y

            # Ally features
            al_ids = [al_id for al_id in range(self.n_agents) if al_id != agent_id]
            for i, al_id in enumerate(al_ids):
                al_unit = self.get_unit_by_id(al_id)
                al_x = al_unit.pos.x
                al_y = al_unit.pos.y
                dist = self.distance(x, y, al_x, al_y)
                max_cd = self.unit_max_cooldown(al_unit)

                if al_unit.health > 0:  # visible and alive
                    if unit.health > 0:
                        if dist < sight_range:
                            ally_feats[i, 0] = 1  # visible
                        ally_feats[i, 1] = dist / sight_range  # distance
                        ally_feats[i, 2] = (al_x - x) / sight_range  # relative X
                        ally_feats[i, 3] = (al_y - y) / sight_range  # relative Y

                    # Medivacs report energy in the cooldown slot.
                    if self.map_type == "MMM" and al_unit.unit_type == self.medivac_id:
                        ally_feats[i, 4] = al_unit.energy / max_cd  # energy
                    else:
                        ally_feats[i, 4] = al_unit.weapon_cooldown / max_cd  # cooldown

                    ind = 5
                    if self.obs_all_health:
                        ally_feats[i, ind] = (
                            al_unit.health / al_unit.health_max
                        )  # health
                        ind += 1
                        if self.shield_bits_ally > 0:
                            max_shield = self.unit_max_shield(al_unit)
                            ally_feats[i, ind] = al_unit.shield / max_shield  # shield
                            ind += 1

                    if self.add_center_xy:
                        ally_feats[i, ind] = (
                            al_x - center_x
                        ) / self.max_distance_x  # center X
                        # center Y
                        ally_feats[i, ind + 1] = (al_y - center_y) / self.max_distance_y
                        ind += 2

                    if self.unit_type_bits > 0:
                        type_id = self.get_unit_type_id(al_unit, True)
                        ally_feats[i, ind + type_id] = 1
                        ind += self.unit_type_bits

                    # NOTE(review): filled when state_last_action is set, but
                    # get_state_ally_feats_size() reserves these columns only
                    # when obs_last_action is set. The flags must agree.
                    if self.state_last_action:
                        ally_feats[i, ind:] = self.last_action[al_id]

            # Own features
            ind = 0
            own_feats[0] = 1  # visible
            own_feats[1] = 0  # distance
            own_feats[2] = 0  # X
            own_feats[3] = 0  # Y
            ind = 4
            if self.obs_own_health:
                own_feats[ind] = unit.health / unit.health_max
                ind += 1
                if self.shield_bits_ally > 0:
                    max_shield = self.unit_max_shield(unit)
                    own_feats[ind] = unit.shield / max_shield
                    ind += 1

            if self.add_center_xy:
                own_feats[ind] = (x - center_x) / self.max_distance_x  # center X
                own_feats[ind + 1] = (y - center_y) / self.max_distance_y  # center Y
                ind += 2

            if self.unit_type_bits > 0:
                type_id = self.get_unit_type_id(unit, True)
                own_feats[ind + type_id] = 1
                ind += self.unit_type_bits

            # NOTE(review): same state_last_action / obs_last_action coupling
            # as for the ally rows (see get_state_own_feats_size()).
            if self.state_last_action:
                own_feats[ind:] = self.last_action[agent_id]

        # Concatenation order: ally, enemy, move, own.
        state = np.concatenate(
            (
                ally_feats.flatten(),
                enemy_feats.flatten(),
                move_feats.flatten(),
                own_feats.flatten(),
            )
        )

        # Agent id features
        if self.state_agent_id:
            agent_id_feats[agent_id] = 1.0
            state = np.append(state, agent_id_feats.flatten())

        if self.state_timestep_number:
            state = np.append(state, self._episode_steps / self.episode_limit)

        if self.debug:
            logging.debug("Obs Agent: {}".format(agent_id).center(60, "-"))
            logging.debug(
                "Avail. actions {}".format(self.get_avail_agent_actions(agent_id))
            )
            logging.debug("Move feats {}".format(move_feats))
            logging.debug("Enemy feats {}".format(enemy_feats))
            logging.debug("Ally feats {}".format(ally_feats))
            logging.debug("Own feats {}".format(own_feats))

        return state

    def get_obs_enemy_feats_size(self):
        """Returns the dimensions of the matrix containing enemy features.
        Size is n_enemies x n_features.
        """
        nf_en = 4 + self.unit_type_bits

        if self.obs_all_health:
            nf_en += 1 + self.shield_bits_enemy

        return self.n_enemies, nf_en

    def get_state_enemy_feats_size(self):
        """Returns the dimensions of the matrix containing enemy features.
        Size is n_enemies x n_features.
        """
        nf_en = 5 + self.unit_type_bits

        if self.obs_all_health:
            nf_en += 1 + self.shield_bits_enemy

        if self.add_center_xy:
            nf_en += 2

        return self.n_enemies, nf_en

    def get_obs_ally_feats_size(self):
        """Returns the dimensions of the matrix containing ally features.
        Size is n_allies x n_features.
        """
        nf_al = 4 + self.unit_type_bits

        if self.obs_all_health:
            nf_al += 1 + self.shield_bits_ally

        if self.obs_last_action:
            nf_al += self.n_actions

        return self.n_agents - 1, nf_al

    def get_state_ally_feats_size(self):
        """Returns the dimensions of the matrix containing ally features.
        Size is n_allies x n_features.
        """
        nf_al = 5 + self.unit_type_bits

        if self.obs_all_health:
            nf_al += 1 + self.shield_bits_ally

        if self.obs_last_action:
            nf_al += self.n_actions

        if self.add_center_xy:
            nf_al += 2

        return self.n_agents - 1, nf_al

    def get_obs_own_feats_size(self):
        """Returns the size of the vector containing the agents' own features."""
        own_feats = 4 + self.unit_type_bits
        if self.obs_own_health:
            own_feats += 1 + self.shield_bits_ally

        if self.obs_last_action:
            own_feats += self.n_actions

        return own_feats

    def get_state_own_feats_size(self):
        """Returns the size of the vector containing the agents' own features."""
        own_feats = 4 + self.unit_type_bits
        if self.obs_own_health:
            own_feats += 1 + self.shield_bits_ally

        if self.obs_last_action:
            own_feats += self.n_actions

        if self.add_center_xy:
            own_feats += 2

        return own_feats

    def get_obs_move_feats_size(self):
        """Returns the size of the vector containing the agents's movement-related features."""
        move_feats = self.n_actions_move
        if self.obs_pathing_grid:
            move_feats += self.n_obs_pathing
        if self.obs_terrain_height:
            move_feats += self.n_obs_height

        return move_feats

    def get_state_move_feats_size(self):
        """Returns the size of the vector containing the agents's movement-related features."""
        move_feats = self.n_actions_move
        if self.state_pathing_grid:
            move_feats += self.n_obs_pathing
        if self.state_terrain_height:
            move_feats += self.n_obs_height

        return move_feats

    def get_state_move_feats_size_global(self):
        """Returns the size of the vector containing the agents's movement-related features. global"""
        move_feats = self.n_actions
        if self.state_pathing_grid:
            move_feats += self.n_obs_pathing
        if self.state_terrain_height:
            move_feats += self.n_obs_height

        return move_feats

    def get_obs_size(self):
        """Returns the size of the observation."""
        own_feats = self.get_obs_own_feats_size()
        move_feats = self.get_obs_move_feats_size()

        n_enemies, n_enemy_feats = self.get_obs_enemy_feats_size()
        n_allies, n_ally_feats = self.get_obs_ally_feats_size()

        enemy_feats = n_enemies * n_enemy_feats
        ally_feats = n_allies * n_ally_feats

        all_feats = move_feats + enemy_feats + ally_feats + own_feats

        agent_id_feats = 0
        timestep_feats = 0

        if self.obs_agent_id:
            agent_id_feats = self.n_agents
            all_feats += agent_id_feats

        if self.obs_timestep_number:
            timestep_feats = 1
            all_feats += timestep_feats

        return [
            all_feats * self.stacked_frames if self.use_stacked_frames else all_feats,
            [n_allies, n_ally_feats],
            [n_enemies, n_enemy_feats],
            [1, move_feats],
            [1, own_feats + agent_id_feats + timestep_feats],
        ]

    def get_state_size(self):
        """Returns the size of the global state.

        The layout depends on which state variant is configured. The flags
        are checked in this order:

        - ``obs_instead_of_state``: the concatenation of all agents'
          observations. Returns ``[total, [n_agents, obs_size]]``. Here
          ``obs_size`` is ``get_obs_size()[0]``, which already includes frame
          stacking when enabled.
        - ``use_state_agent``: the agent-specific state built by
          ``get_state_agent``. Returns ``[total, [n_allies, ally_dim],
          [n_enemies, enemy_dim], [1, move_dim], [1, own_dim]]``.
        - ``use_global_state``: returns ``[total, [n_agents, ally_dim],
          [n_enemies, enemy_dim], [n_agents, move_dim], [1, timestep_dim],
          [1, info_dim]]``.
        - otherwise, the default global state: returns ``[total,
          [n_agents, ally_dim], [n_enemies, enemy_dim], [1, extra_dim]]``.
          Here ``extra_dim`` covers the move, local-observation, timestep and
          agent-id features.

        In the last three variants, ``total`` is multiplied by
        ``stacked_frames`` when ``use_stacked_frames`` is set.
        """
        if self.obs_instead_of_state:
            return [
                self.get_obs_size()[0] * self.n_agents,
                [self.n_agents, self.get_obs_size()[0]],
            ]

        if self.use_state_agent:
            own_feats = self.get_state_own_feats_size()
            # Sized with the obs_* movement flags, matching how
            # get_state_agent allocates its movement block.
            move_feats = self.get_obs_move_feats_size()

            n_enemies, n_enemy_feats = self.get_state_enemy_feats_size()
            n_allies, n_ally_feats = self.get_state_ally_feats_size()

            enemy_feats = n_enemies * n_enemy_feats
            ally_feats = n_allies * n_ally_feats

            all_feats = move_feats + enemy_feats + ally_feats + own_feats

            agent_id_feats = 0
            timestep_feats = 0

            if self.state_agent_id:
                agent_id_feats = self.n_agents
                all_feats += agent_id_feats

            if self.state_timestep_number:
                timestep_feats = 1
                all_feats += timestep_feats

            return [
                all_feats * self.stacked_frames
                if self.use_stacked_frames
                else all_feats,
                [n_allies, n_ally_feats],
                [n_enemies, n_enemy_feats],
                [1, move_feats],
                [1, own_feats + agent_id_feats + timestep_feats],
            ]

        # NOTE(review): the per-feature counts in the two branches below are
        # expected to mirror what get_state() writes. get_state is outside
        # this view, so confirm against it when changing either side.
        if self.use_global_state:
            nf_al = 2 + self.shield_bits_ally + self.unit_type_bits
            nf_en = 1 + self.shield_bits_enemy + self.unit_type_bits

            if self.add_center_xy:
                nf_al += 2
                nf_en += 2

            if self.state_last_action:
                nf_al += self.n_actions

            nf_mv_glb = self.get_state_move_feats_size_global()

            enemy_state = self.n_enemies * nf_en
            ally_state = self.n_agents * nf_al

            size = enemy_state + ally_state

            move_state = 0
            timestep_state = 0
            info_state = 0

            if self.add_move_state:
                # One movement block per agent.
                move_state = self.n_agents * nf_mv_glb
                size += move_state

            if self.state_timestep_number:
                timestep_state = 1
                size += timestep_state

            if self.global_state_include_info:
                info_state = 5
                size += info_state

            return [
                size * self.stacked_frames if self.use_stacked_frames else size,
                [self.n_agents, nf_al],
                [self.n_enemies, nf_en],
                [self.n_agents, nf_mv_glb if self.add_move_state else 0],
                [1, timestep_state],
                [1, info_state],
            ]

        # Default global state.
        nf_al = 2 + self.shield_bits_ally + self.unit_type_bits
        nf_en = 1 + self.shield_bits_enemy + self.unit_type_bits
        nf_mv = self.get_state_move_feats_size()

        if self.add_center_xy:
            nf_al += 2
            nf_en += 2

        if self.state_last_action:
            nf_al += self.n_actions
            nf_en += self.n_actions

        if self.add_visible_state:
            nf_al += 1
            nf_en += 1

        if self.add_distance_state:
            nf_al += 1
            nf_en += 1

        if self.add_xy_state:
            nf_al += 2
            nf_en += 2

        if self.add_enemy_action_state:
            nf_en += 1

        enemy_state = self.n_enemies * nf_en
        ally_state = self.n_agents * nf_al

        size = enemy_state + ally_state

        move_state = 0
        obs_agent_size = 0
        timestep_state = 0
        agent_id_feats = 0

        if self.add_move_state:
            # A single movement block here (not one per agent).
            move_state = nf_mv
            size += move_state

        if self.add_local_obs:
            obs_agent_size = self.get_obs_size()[0]
            size += obs_agent_size

        if self.state_timestep_number:
            timestep_state = 1
            size += timestep_state

        if self.add_agent_id:
            agent_id_feats = self.n_agents
            size += agent_id_feats

        return [
            size * self.stacked_frames if self.use_stacked_frames else size,
            [self.n_agents, nf_al],
            [self.n_enemies, nf_en],
            [1, move_state + obs_agent_size + timestep_state + agent_id_feats],
        ]

    def get_visibility_matrix(self):
        """Returns a boolean numpy array of dimensions
        (n_agents, n_agents + n_enemies) indicating which units
        are visible to each agent.
        """
        arr = np.zeros((self.n_agents, self.n_agents + self.n_enemies), dtype=bool)

        for agent_id in range(self.n_agents):
            current_agent = self.get_unit_by_id(agent_id)
            if current_agent.health > 0:  # it agent not dead
                x = current_agent.pos.x
                y = current_agent.pos.y
                sight_range = self.unit_sight_range(agent_id)

                # Enemies
                for e_id, e_unit in self.enemies.items():
                    e_x = e_unit.pos.x
                    e_y = e_unit.pos.y
                    dist = self.distance(x, y, e_x, e_y)

                    if dist < sight_range and e_unit.health > 0:
                        # visible and alive
                        arr[agent_id, self.n_agents + e_id] = 1

                # The matrix for allies is filled symmetrically
                al_ids = [al_id for al_id in range(self.n_agents) if al_id > agent_id]
                for i, al_id in enumerate(al_ids):
                    al_unit = self.get_unit_by_id(al_id)
                    al_x = al_unit.pos.x
                    al_y = al_unit.pos.y
                    dist = self.distance(x, y, al_x, al_y)

                    if dist < sight_range and al_unit.health > 0:
                        # visible and alive
                        arr[agent_id, al_id] = arr[al_id, agent_id] = 1

        return arr

    def get_unit_type_id(self, unit, ally):
        """Returns the ID of unit type in the given scenario."""
        if ally:  # use new SC2 unit types
            type_id = unit.unit_type - self._min_unit_type
        else:  # use default SC2 unit types
            if self.map_type == "stalkers_and_zealots":
                # id(Stalker) = 74, id(Zealot) = 73
                type_id = unit.unit_type - 73
            elif self.map_type == "colossi_stalkers_zealots":
                # id(Stalker) = 74, id(Zealot) = 73, id(Colossus) = 4
                if unit.unit_type == 4:
                    type_id = 0
                elif unit.unit_type == 74:
                    type_id = 1
                else:
                    type_id = 2
            elif self.map_type == "bane":
                if unit.unit_type == 9:
                    type_id = 0
                else:
                    type_id = 1
            elif self.map_type == "MMM":
                if unit.unit_type == 51:
                    type_id = 0
                elif unit.unit_type == 48:
                    type_id = 1
                else:
                    type_id = 2

        return type_id

    def get_avail_agent_actions(self, agent_id):
        """Build the action-availability mask for one agent.

        The mask has ``n_actions`` entries. Index 0 is no-op, 1 is stop, and
        2-5 are move north/south/east/west. Each index from
        ``n_actions_no_attack`` onward is a target that is alive and within
        shooting range. Targets are enemies, except for an MMM Medivac, whose
        targets are the non-Medivac allies it can heal. A dead agent may only
        take the no-op.

        Returns:
            list[int]: 1 for each available action, 0 otherwise.
        """
        unit = self.get_unit_by_id(agent_id)
        n_actions = self.n_actions

        if not unit.health > 0:
            # only no-op allowed
            return [1] + [0] * (n_actions - 1)

        # A living agent may not choose no-op, but stop is always allowed.
        mask = [0] * n_actions
        mask[1] = 1

        directions = (Direction.NORTH, Direction.SOUTH, Direction.EAST, Direction.WEST)
        for slot, direction in enumerate(directions, start=2):
            if self.can_move(unit, direction):
                mask[slot] = 1

        shoot_range = self.unit_shoot_range(agent_id)

        if self.map_type == "MMM" and unit.unit_type == self.medivac_id:
            # Medivacs cannot heal themselves or other flying units
            targets = [
                (t_id, t_unit)
                for t_id, t_unit in self.agents.items()
                if t_unit.unit_type != self.medivac_id
            ]
        else:
            targets = self.enemies.items()

        first_target_slot = self.n_actions_no_attack
        for t_id, t_unit in targets:
            if not t_unit.health > 0:
                continue
            dist = self.distance(unit.pos.x, unit.pos.y, t_unit.pos.x, t_unit.pos.y)
            if dist <= shoot_range:
                mask[first_target_slot + t_id] = 1

        return mask

    def get_avail_actions(self):
        """Returns the available actions of all agents in a list."""
        avail_actions = []
        for agent_id in range(self.n_agents):
            avail_agent = self.get_avail_agent_actions(agent_id)
            avail_actions.append(avail_agent)
        return avail_actions

    def close(self):
        """Close StarCraft II."""
        if self._sc2_proc:
            self._sc2_proc.close()

    def seed(self, seed):
        """Returns the random seed used by the environment."""
        self._seed = seed

    def render(self):
        """Use save_replay instead"""
        pass

    def _kill_all_units(self):
        """Kill every unit still alive on either side, using an SC2 debug command."""
        alive_tags = [
            unit.tag
            for side in (self.agents, self.enemies)
            for unit in side.values()
            if unit.health > 0
        ]
        kill_request = d_pb.DebugKillUnit(tag=alive_tags)
        self._controller.debug([d_pb.DebugCommand(kill_unit=kill_request)])

    def init_units(self):
        """Initialise the units."""
        while True:
            # Sometimes not all units have yet been created by SC2
            self.agents = {}
            self.enemies = {}

            ally_units = [
                unit for unit in self._obs.observation.raw_data.units if unit.owner == 1
            ]
            ally_units_sorted = sorted(
                ally_units,
                key=attrgetter("unit_type", "pos.x", "pos.y"),
                reverse=False,
            )

            for i in range(len(ally_units_sorted)):
                self.agents[i] = ally_units_sorted[i]
                if self.debug:
                    logging.debug(
                        "Unit {} is {}, x = {}, y = {}".format(
                            len(self.agents),
                            self.agents[i].unit_type,
                            self.agents[i].pos.x,
                            self.agents[i].pos.y,
                        )
                    )

            for unit in self._obs.observation.raw_data.units:
                if unit.owner == 2:
                    self.enemies[len(self.enemies)] = unit
                    if self._episode_count == 0:
                        self.max_reward += unit.health_max + unit.shield_max

            if self._episode_count == 0:
                min_unit_type = min(unit.unit_type for unit in self.agents.values())
                self._init_ally_unit_types(min_unit_type)

            all_agents_created = len(self.agents) == self.n_agents
            all_enemies_created = len(self.enemies) == self.n_enemies

            if all_agents_created and all_enemies_created:  # all good
                return

            try:
                self._controller.step(1)
                self._obs = self._controller.observe()
            except (protocol.ProtocolError, protocol.ConnectionError):
                self.full_restart()
                self.reset()

    def update_units(self):
        """Update units after an environment step.
        This function assumes that self._obs is up-to-date.
        """
        n_ally_alive = 0
        n_enemy_alive = 0

        # Store previous state
        self.previous_ally_units = deepcopy(self.agents)
        self.previous_enemy_units = deepcopy(self.enemies)

        for al_id, al_unit in self.agents.items():
            updated = False
            for unit in self._obs.observation.raw_data.units:
                if al_unit.tag == unit.tag:
                    self.agents[al_id] = unit
                    updated = True
                    n_ally_alive += 1
                    break

            if not updated:  # dead
                al_unit.health = 0

        for e_id, e_unit in self.enemies.items():
            updated = False
            for unit in self._obs.observation.raw_data.units:
                if e_unit.tag == unit.tag:
                    self.enemies[e_id] = unit
                    updated = True
                    n_enemy_alive += 1
                    break

            if not updated:  # dead
                e_unit.health = 0

        if n_ally_alive == 0 and n_enemy_alive > 0 or self.only_medivac_left(ally=True):
            return -1  # lost
        if (
            n_ally_alive > 0
            and n_enemy_alive == 0
            or self.only_medivac_left(ally=False)
        ):
            return 1  # won
        if n_ally_alive == 0 and n_enemy_alive == 0:
            return 0

        return None

    def _init_ally_unit_types(self, min_unit_type):
        """Initialise ally unit types. Should be called once from the
        init_units function.
        """
        self._min_unit_type = min_unit_type
        if self.map_type == "marines":
            self.marine_id = min_unit_type
        elif self.map_type == "stalkers_and_zealots":
            self.stalker_id = min_unit_type
            self.zealot_id = min_unit_type + 1
        elif self.map_type == "colossi_stalkers_zealots":
            self.colossus_id = min_unit_type
            self.stalker_id = min_unit_type + 1
            self.zealot_id = min_unit_type + 2
        elif self.map_type == "MMM":
            self.marauder_id = min_unit_type
            self.marine_id = min_unit_type + 1
            self.medivac_id = min_unit_type + 2
        elif self.map_type == "zealots":
            self.zealot_id = min_unit_type
        elif self.map_type == "hydralisks":
            self.hydralisk_id = min_unit_type
        elif self.map_type == "stalkers":
            self.stalker_id = min_unit_type
        elif self.map_type == "colossus":
            self.colossus_id = min_unit_type
        elif self.map_type == "bane":
            self.baneling_id = min_unit_type
            self.zergling_id = min_unit_type + 1

    def only_medivac_left(self, ally):
        """Check if only Medivac units are left."""
        if self.map_type != "MMM":
            return False

        if ally:
            units_alive = [
                a
                for a in self.agents.values()
                if (a.health > 0 and a.unit_type != self.medivac_id)
            ]
            if len(units_alive) == 0:
                return True
            return False
        else:
            units_alive = [
                a
                for a in self.enemies.values()
                if (a.health > 0 and a.unit_type != self.medivac_id)
            ]
            if len(units_alive) == 1 and units_alive[0].unit_type == 54:
                return True
            return False

    def get_unit_by_id(self, a_id):
        """Get unit by ID."""
        return self.agents[a_id]

    def get_stats(self):
        stats = {
            "battles_won": self.battles_won,
            "battles_game": self.battles_game,
            "battles_draw": self.timeouts,
            "win_rate": self.battles_won / self.battles_game,
            "timeouts": self.timeouts,
            "restarts": self.force_restarts,
        }
        return stats
