from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from smac.env.multiagentenv import MultiAgentEnv

import atexit
from operator import attrgetter
from copy import deepcopy
import numpy as np
import enum
import math
from absl import logging


class Spread1Env(MultiAgentEnv):
    """Cooperative grid-world "spread" task behind the ``MultiAgentEnv`` API.

    ``n_agents`` agents move on a ``grid_size`` x ``grid_size`` board that
    also holds ``n_agents`` landmarks.  A landmark is covered while at least
    one agent stands in the 3x3 neighbourhood centred on it.  Every covered
    landmark yields +1 reward per step; the episode is won on the step where
    all landmarks are covered, and times out after ``episode_limit`` steps.

    Board representation (both arrays are padded with ``sight_range`` border
    cells on every side, so an agent's view window never leaves the array):

    * ``grid``   -- 1 where an agent stands, 0 for free cells, -1 on the
      padding border.
    * ``l_grid`` -- 2 on real landmarks, 5 on fake ("noise") landmarks,
      0 elsewhere.

    All positions stored in ``agents``, ``landmarks`` and ``f_landmarks`` are
    in padded coordinates (playable cells span
    ``sight_range .. sight_range + grid_size - 1``).

    Arguments used by this class:

    * ``n_agents``, ``grid_size``, ``sight_range`` -- world dimensions.
    * ``random_start`` -- random placement when true; otherwise a fixed
      layout that is only defined for ``n_agents == 5``.
    * ``noise1`` -- add ``number_fake_landmarks`` fake landmarks (a falsy
      value selects ``grid_size * grid_size // 10``).
    * ``noise2`` -- add Gaussian noise (std 0.2) to every observation.
    * ``obs_last_action`` / ``state_last_action`` -- append one-hot last
      actions to observations / state.
    * ``continuing_episode`` -- flag time-outs with ``info['episode_limit']``.
    * ``reward_win`` -- bonus granted on a win.

    The remaining keyword arguments mirror the SMAC constructor; they are
    either stored unchanged or ignored and have no effect on the logic in
    this class.  In particular ``seed`` is stored but never applied: all
    randomness comes from the global ``np.random`` state.

    Raises:
        ValueError: if ``random_start`` is false and ``n_agents != 5``, or if
            ``noise1`` requests more fake landmarks than there are free cells.
    """
    def __init__(
        self,
        map_name="8m",
        step_mul=None,
        move_amount=2,
        difficulty="7",
        game_version=None,
        seed=None,
        continuing_episode=False,
        obs_all_health=True,
        obs_own_health=True,
        obs_last_action=False,
        obs_pathing_grid=False,
        obs_terrain_height=False,
        obs_instead_of_state=False,
        obs_timestep_number=False,
        state_last_action=True,
        state_timestep_number=False,
        reward_sparse=False,
        reward_only_positive=True,
        reward_death_value=10,
        reward_win=300,
        reward_defeat=0,
        reward_negative_scale=0.5,
        reward_scale=True,
        reward_scale_rate=20,
        replay_dir="",
        replay_prefix="",
        window_size_x=1920,
        window_size_y=1200,
        debug=False,
        n_agents=10,
        grid_size=20,
        sight_range=1,
        random_start=True,
        noise1=False,
        noise2=False,
        number_fake_landmarks=None
    ):
        # Map arguments
        self.n_agents = n_agents

        # Observations and state
        self.continuing_episode = continuing_episode
        self.obs_instead_of_state = obs_instead_of_state
        self.obs_last_action = obs_last_action
        self.obs_pathing_grid = obs_pathing_grid
        self.obs_terrain_height = obs_terrain_height
        self.state_last_action = state_last_action

        # Rewards args
        self.reward_sparse = reward_sparse
        self.reward_only_positive = reward_only_positive
        self.reward_negative_scale = reward_negative_scale
        self.reward_death_value = reward_death_value
        self.reward_win = reward_win
        self.reward_defeat = reward_defeat
        self.reward_scale = reward_scale
        self.reward_scale_rate = reward_scale_rate

        # Other
        self._seed = seed
        self.debug = debug
        self.window_size = (window_size_x, window_size_y)
        self.replay_dir = replay_dir
        self.replay_prefix = replay_prefix

        # Actions: 0 = x-1, 1 = y+1, 2 = x+1, 3 = y-1, 4 = stay
        self.n_actions = 5

        # Bookkeeping attributes kept for compatibility with SMAC-style
        # consumers; this class does not update them.
        self._episode_count = 0
        self._episode_steps = 0
        self._total_steps = 0
        self.last_stats = None
        self.previous_ally_units = None
        self.previous_enemy_units = None
        self.last_action = np.zeros((self.n_agents, self.n_actions))
        self._min_unit_type = 0
        self.max_distance_x = 0
        self.max_distance_y = 0

        # Running statistics (never cleared by reset()).
        self.battles_game = 0
        self.battles_won = 0
        self.episode_limit = 50
        self.sight_range = sight_range
        self.random_start = random_start
        self.timeouts = 0

        self.grid_size = grid_size
        self.noise1 = noise1
        self.noise2 = noise2
        if number_fake_landmarks:
            self.number_fake_landmarks = number_fake_landmarks
        else:
            self.number_fake_landmarks = self.grid_size * self.grid_size // 10

        # Build the first world so the environment is usable before reset().
        self._reset_world()

    # ------------------------------------------------------------------
    # World construction (shared by __init__ and reset)
    # ------------------------------------------------------------------
    def _reset_world(self):
        """Rebuild grids, entity positions and per-episode counters."""
        self._build_grids()
        self._place_entities()
        if self.noise1:
            self._place_fake_landmarks()

        self.t_step = 0
        self.last_action = np.zeros(shape=(self.n_agents, self.n_actions))

    def _build_grids(self):
        """Allocate zeroed padded grids and mark the agent grid's border -1."""
        pad = self.sight_range
        side = self.grid_size + 2 * pad
        self.grid = np.zeros((side, side))
        self.l_grid = np.zeros((side, side))
        # Guard pad == 0: the slice ``-0:`` would select the whole array.
        if pad > 0:
            self.grid[:pad, :] = -1
            self.grid[-pad:, :] = -1
            self.grid[:, :pad] = -1
            self.grid[:, -pad:] = -1

    def _random_cell(self):
        """Draw one uniformly random playable cell in padded coordinates."""
        return np.random.randint(self.grid_size, size=2) + self.sight_range

    def _place_entities(self):
        """Choose landmark and agent positions and stamp them on the grids."""
        pad = self.sight_range
        size = self.grid_size

        if self.random_start:
            # Landmarks are drawn before agents; cells may repeat.
            self.landmarks = [self._random_cell()
                              for _ in range(self.n_agents)]
            self.agents = [self._random_cell()
                           for _ in range(self.n_agents)]
        elif self.n_agents == 5:
            centre = np.array([size // 2, size // 2]) + pad
            self.landmarks = [np.array([3, size - 3]) + pad,
                              np.array([3, 3]) + pad,
                              np.array([size - 3, 3]) + pad,
                              np.array([size - 3, size - 3]) + pad,
                              centre.copy()]

            corners = [np.array([1, size - 1]) + pad,
                       np.array([1, 1]) + pad,
                       np.array([size - 1, 1]) + pad,
                       np.array([size - 1, size - 1]) + pad]
            # Three agents start on three distinct random corners, the other
            # two share the centre cell.
            picked = np.random.choice(4, 3, replace=False)
            self.agents = [corners[picked[0]],
                           corners[picked[1]],
                           corners[picked[2]],
                           centre.copy(),
                           centre.copy()]
        else:
            raise ValueError(
                "random_start=False is only defined for n_agents == 5 "
                "(got n_agents={})".format(self.n_agents))

        for cell in self.landmarks:
            self.l_grid[tuple(cell)] = 2
        for cell in self.agents:
            self.grid[tuple(cell)] = 1

    def _place_fake_landmarks(self):
        """Scatter fake landmarks on cells not used by any other landmark."""
        self.f_landmarks = []
        occupied = {(int(c[0]), int(c[1])) for c in self.landmarks}

        # Rejection sampling below would never terminate otherwise.
        free_cells = self.grid_size * self.grid_size - len(occupied)
        if self.number_fake_landmarks > free_cells:
            raise ValueError(
                "number_fake_landmarks={} exceeds the {} free cells".format(
                    self.number_fake_landmarks, free_cells))

        while len(self.f_landmarks) < self.number_fake_landmarks:
            candidate = self._random_cell()
            key = (int(candidate[0]), int(candidate[1]))
            if key in occupied:
                continue
            occupied.add(key)
            self.f_landmarks.append(candidate)

        for cell in self.f_landmarks:
            self.l_grid[tuple(cell)] = 5

    # ------------------------------------------------------------------
    # MultiAgentEnv interface
    # ------------------------------------------------------------------
    def step(self, actions):
        """Apply one action per agent and return ``(reward, terminated, info)``.

        ``actions`` is a sequence of values convertible to ``int`` in
        ``0 .. n_actions - 1``.  Moves are clamped to the playable area and
        several agents may share a cell.

        ``info`` always contains ``'battle_won'``; ``'episode_limit'`` is
        added only when the episode timed out without a win and
        ``continuing_episode`` is set.
        """
        self.t_step += 1
        info = {'battle_won': False}

        actions = [int(a) for a in actions]
        self.last_action = np.eye(self.n_actions)[np.array(actions)]

        # Move agents.
        lowest = self.sight_range
        highest = self.sight_range + self.grid_size - 1
        previous = list(self.agents)
        for i, action in enumerate(actions):
            new_x = self.agents[i][0]
            new_y = self.agents[i][1]

            if action == 0:
                new_x = max(new_x - 1, lowest)
            elif action == 1:
                new_y = min(new_y + 1, highest)
            elif action == 2:
                new_x = min(new_x + 1, highest)
            elif action == 3:
                new_y = max(new_y - 1, lowest)

            self.agents[i] = np.array([new_x, new_y])

        # Clear every old cell first, then stamp new ones, so a cell that one
        # agent left and another entered ends up occupied.
        for cell in previous:
            self.grid[tuple(cell)] = 0.
        for cell in self.agents:
            self.grid[tuple(cell)] = 1.

        # Reward: +1 for every landmark with an agent in its 3x3 window.
        reward = 0
        covered = 0
        for landmark in self.landmarks:
            # Clamp the lower bound: a negative slice start would wrap
            # around and yield an empty window (possible if sight_range == 0).
            row_from = max(landmark[0] - 1, 0)
            col_from = max(landmark[1] - 1, 0)
            window = self.grid[row_from: landmark[0] + 2,
                               col_from: landmark[1] + 2]
            if (window == 1).any():
                reward += 1
                covered += 1

        terminated = False

        # A win takes precedence over the step limit: winning on the very
        # last step is a win, not a timeout.
        if covered == self.n_agents:
            terminated = True
            reward += self.reward_win + (self.episode_limit - self.t_step) * 10
            info['battle_won'] = True
            self.battles_won += 1
        elif self.t_step >= self.episode_limit:
            terminated = True
            if self.continuing_episode:
                info['episode_limit'] = True
            self.timeouts += 1

        if terminated:
            self.battles_game += 1

        return reward, terminated, info

    def get_obs(self):
        """Returns all agent observations in a list."""
        return [self.get_obs_agent(i) for i in range(self.n_agents)]

    def get_obs_agent(self, agent_id):
        """Returns observation for agent_id.

        Layout: the flattened ``(2*sight_range+1)**2`` window of
        ``grid + l_grid`` centred on the agent, then the agent's own padded
        ``(x, y)``, then -- if ``obs_last_action`` -- the flattened one-hot
        last actions of all agents with rows of agents outside the window
        zeroed.  Gaussian noise is added to the whole vector when ``noise2``
        is set.
        """
        x = self.agents[agent_id][0]
        y = self.agents[agent_id][1]
        reach = self.sight_range

        rows = slice(x - reach, x + reach + 1)
        cols = slice(y - reach, y + reach + 1)
        view = (self.grid[rows, cols].reshape(-1)
                + self.l_grid[rows, cols].reshape(-1))
        obs = np.concatenate([view, self.agents[agent_id]])

        if self.obs_last_action:
            visible_actions = self.last_action.copy()
            for other_id, other in enumerate(self.agents):
                if other_id == agent_id:
                    continue
                out_of_sight = (abs(other[0] - x) > reach
                                or abs(other[1] - y) > reach)
                if out_of_sight:
                    visible_actions[other_id] *= 0.
            obs = np.concatenate([obs, visible_actions.reshape(-1)])

        if self.noise2:
            obs += np.random.normal(loc=0.0, scale=0.2, size=obs.shape)

        return obs

    def get_obs_size(self):
        """Returns the length of a single agent's observation vector."""
        window = (2 * self.sight_range + 1) * (2 * self.sight_range + 1)
        size = window + 2  # + own (x, y)
        if self.obs_last_action:
            size += self.n_actions * self.n_agents
        return size

    def get_state(self):
        """Returns the global state.

        The state is the flattened playable area of ``grid + l_grid``
        (padding stripped), followed by the flattened last actions when
        ``state_last_action`` is set.
        """
        lo = self.sight_range
        hi = self.sight_range + self.grid_size
        board = (self.grid[lo:hi, lo:hi].reshape(-1)
                 + self.l_grid[lo:hi, lo:hi].reshape(-1))

        if self.state_last_action:
            return np.concatenate([board, self.last_action.copy().reshape(-1)])
        return board

    def get_state_size(self):
        """Returns the size of the global state."""
        size = self.grid_size * self.grid_size
        if self.state_last_action:
            size += self.n_agents * self.n_actions
        return size

    def get_avail_actions(self):
        """Returns the available actions of all agents in a list."""
        return [self.get_avail_agent_actions(i) for i in range(self.n_agents)]

    def get_avail_agent_actions(self, agent_id):
        """Returns the available actions for agent_id (always all of them)."""
        return [1] * self.n_actions

    def get_total_actions(self):
        """Returns the total number of actions an agent could ever take."""
        return self.n_actions

    def reset(self):
        """Rebuild the world and return initial observations and state.

        Win/timeout statistics are not cleared.
        """
        self._reset_world()
        return self.get_obs(), self.get_state()

    def render(self):
        """No-op; this environment has no renderer."""
        pass

    def close(self):
        """No-op; there are no external resources to release."""
        pass

    def seed(self):
        """No-op; randomness comes from the global ``np.random`` state."""
        pass

    def save_replay(self):
        """Save a replay (no-op)."""
        pass

    def get_env_info(self):
        """Returns the shape/size summary dictionary for this environment."""
        env_info = {"state_shape": self.get_state_size(),
                    "obs_shape": self.get_obs_size(),
                    "n_actions": self.get_total_actions(),
                    "n_agents": self.n_agents,
                    "episode_limit": self.episode_limit}
        return env_info

    def get_stats(self):
        """Returns cumulative win/timeout statistics.

        ``win_rate`` is 0.0 until at least one episode has finished.
        """
        if self.battles_game > 0:
            win_rate = float(self.battles_won) / self.battles_game
        else:
            win_rate = 0.0
        stats = {
            "battles_won": self.battles_won,
            "battles_game": self.battles_game,
            "battles_draw": self.timeouts,
            "win_rate": win_rate,
            "timeouts": self.timeouts,
            "restarts": 0
        }
        return stats
