from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from smac.env.multiagentenv import MultiAgentEnv

import numpy as np
import enum
import os
import json
import copy
import random
import itertools

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from matplotlib.patches import Rectangle


# Numeric ids keyed by action name. Presumably StarCraft II ability ids left
# over from the SMAC environment this file was adapted from (TODO confirm);
# no code in the visible part of this module reads the table.
actions = dict(
    move=16,  # target: PointOrUnit
    attack=23,  # target: PointOrUnit
    stop=4,  # target: None
    heal=386,  # Unit
)


@enum.unique
class Direction(enum.IntEnum):
    """Compass directions as integer constants.

    Note: ``ShapesEnv._move`` does not use this enum and orders its action
    codes differently (0 = north, 1 = east, 2 = south, 3 = west).
    """

    NORTH = 0
    SOUTH = 1
    EAST = 2
    WEST = 3


class Agent:
    """Mutable grid position of a single agent.

    ``row`` and ``col`` are plain attributes; ``ShapesEnv`` uses them as the
    top-left corner of the agent's observation window.
    """

    def __init__(self, row, col):
        self.row, self.col = row, col

    @property
    def pos(self):
        """Current position as a ``(row, col)`` tuple."""
        return (self.row, self.col)


class ShapesEnv(MultiAgentEnv):
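    """Multi-agent environment over "shapes" images: every agent observes a
    small window of a shared image, moves that window around, and the team is
    rewarded for the fraction of agents whose window sits on their goal cell.
    """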
    def __init__(
        self,
        obs_last_action=False,
        state_last_action=True,
        reward_win=200,
        is_print=False,
        debug=False,
        num_agents=4,
        sight_range=3,
        obs_coordinates=False,
        obs_time_step=False,
        obs_goals=False,
        obs_agent_ids=True,
        normalize=True,
        step_size=1,
        fix_spawn=False,
        fix_image=False,
        split='train.large',  # train.med, train.small, train.tiny, val, test
        task='colors.4,0,0',
        episode_limit=30,
        size=30,
        data_dir='./shapes/shapes_3x3_single_red/',
        seed=None
    ):
        """Load one dataset split and sample the first image and spawn points.

        Args:
            obs_last_action: append the agent's own last action (one-hot) to
                its observation.
            state_last_action: stored on the instance; not read by
                ``get_state`` in this class.
            reward_win: stored on the instance; not read by ``step`` in this
                class.
            is_print, debug: stored on the instance only.
            num_agents: number of agents; must equal the sum of the goal
                counts in ``task``.
            sight_range: side length of the square observation window.
            obs_coordinates, obs_time_step, obs_goals, obs_agent_ids:
                toggles for the extra observation features built in
                ``get_obs_agent``.
            normalize: divide observed pixels by 255 and scale the
                coordinate / time-step features.
            step_size: number of pixels moved per action.
            fix_spawn: spawn agents in the four corners instead of randomly.
            fix_image: start from image index 25 instead of a random image.
            split: dataset split name, used as file prefix inside
                ``data_dir``.
            task: ``"<attribute>.<n0>,<n1>,..."`` where ``<attribute>`` is one
                of the loaded annotation tables (``shapes``, ``colors``,
                ``sizes``) and ``<ni>`` is the number of agents assigned to
                goal id ``i``.
            episode_limit: step count at which an episode terminates.
            size: length of the one-hot vectors ``get_state`` uses for each
                agent's row and column.
            data_dir: directory containing the ``<split>.*`` files.
            seed: accepted for interface compatibility; currently unused.
        """
        # Map arguments
        self.num_agents = num_agents
        # Actions: four moves plus a no-op (see _move)
        self.n_actions = 5

        # Observations and state
        self.obs_last_action = obs_last_action
        self.obs_time_step = obs_time_step
        self.obs_goals = obs_goals
        self.obs_agent_ids = obs_agent_ids
        self.obs_coordinates = obs_coordinates

        self.state_last_action = state_last_action

        # Rewards args
        self.reward_win = reward_win

        # Other
        self.debug = debug
        self.is_print = is_print

        self._episode_count = 0
        self._episode_steps = 0
        self._total_steps = 0
        self.battles_won = 0
        self.battles_game = 0

        self.last_action = np.zeros((self.num_agents, self.n_actions))

        # #######
        # Shapes modification

        self.step_size = step_size
        self.episode_limit = episode_limit
        self.obs_height = sight_range
        self.obs_width = sight_range
        self.fix_spawn = fix_spawn
        self.fix_image = fix_image
        self.normalize = normalize
        self.split = split

        # Load data
        self.data_root = os.path.join(data_dir, split)

        # The last (channel) axis is reversed on load
        # (presumably BGR -> RGB; TODO confirm against the dataset writer).
        self.data = np.load(self.data_root + '.input.npy')[:, :, :, ::-1]

        # Use context managers so the annotation files are closed promptly.
        self.attr = {}
        for attr_name in ('shapes', 'colors', 'sizes'):
            with open(self.data_root + '.' + attr_name, 'r') as attr_file:
                self.attr[attr_name] = json.load(attr_file)

        self.n_imgs = self.data.shape[0]
        self.img_shape = self.data[0].shape
        self.img_obs_shape = (self.obs_height, self.obs_width,
                              self.img_shape[2])

        # Sets self.task (attribute name) and self.task_id (goal per agent).
        self._parse_task(task)

        # Largest valid top-left corner for an observation window.
        self.max_row = self.img_shape[0] - self.obs_height
        self.max_col = self.img_shape[1] - self.obs_width

        # [NOTE] these constants assume 30 x 30 image has 3x3 grid
        self.max_row_cells = int(self.img_shape[0] / 10)
        self.max_col_cells = int(self.img_shape[1] / 10)

        self.n_steps = 0
        self.visited_states = []
        self.max_states = (self.max_row // self.step_size) * (
            self.max_col // self.step_size)

        idx = self._get_image_id()
        attr = self.attr[self.task][idx]

        # Make sure the sampled image has all three goals. Note that this
        # resamples at random even when fix_image is set.
        while len({
                x for x in itertools.chain.from_iterable(attr)
                if isinstance(x, int)
        }) != 3:
            idx = np.random.randint(self.n_imgs)
            attr = self.attr[self.task][idx]

        # Record the index only after resampling so that it always refers to
        # the image that was actually loaded.
        self.image_idx = idx
        self.image = copy.deepcopy(self.data[idx])
        self.success_maps = [np.array(attr) == i for i in range(3)]

        agent_positions = self._get_agent_pos()
        self.agents = [Agent(*pos) for pos in agent_positions]
        self.size = size

        # Rows of this identity matrix are the one-hot position encodings.
        self.pos_eye = np.eye(size)

    def step(self, actions):
        """Returns reward, terminated, info."""
        info = {}
        self.n_steps += 1

        actions = [int(a) for a in actions]
        self.last_action = np.eye(self.n_actions)[np.array(actions)]

        for agent, action in zip(self.agents, actions):
            self._move(agent, action)

        # team reward
        reward = sum([
            self._on_goal(self.agents[i], self.success_maps[self.task_id[i]])
            for i in range(self.num_agents)
        ]) / self.num_agents
        # individual reward
        # reward = np.array([
        #     self._on_goal(self.agents[i], self.success_maps[self.task_id[i]])
        #     for i in range(self.num_agents)
        # ], dtype=np.float32)
        done = (self.n_steps == self.episode_limit)

        info['success'] = self.is_success()

        if done:
            self.battles_game += 1

        if info['success']:
            done = True
            self.battles_won += 1
            self.battles_game += 1

        info['coverage'] = self.get_coverage()
        info['image_idx'] = self.image_idx
        info['n_steps'] = self.n_steps

        return reward, done, info

    def get_obs(self):
        """Returns all agent observations in a list."""
        agents_obs = [self.get_obs_agent(i) for i in range(self.num_agents)]
        return agents_obs

    def get_obs_agent(self, agent_id):
        img_obs = self._get_img_obs(self.agents[agent_id])

        joint_obs = img_obs.flatten()

        # Coordinate observations
        if self.obs_coordinates:
            coordinates = np.array(self.agents[agent_id].pos)

            if self.normalize:
                coordinates[0] = coordinates[0] / self.max_row
                coordinates[1] = coordinates[1] / self.max_col

            joint_obs = np.concatenate([joint_obs, coordinates])

        # Time step observation
        if self.obs_time_step:
            time_step = np.array([self.n_steps])

            if self.normalize:
                time_step = time_step / self.episode_limit

            joint_obs = np.concatenate([joint_obs, time_step])

        # Agent ids
        #
        # Sequence doesn't mean much
        # should be used to index a lookup table
        if self.obs_agent_ids:
            agents = np.array([agent_id / len(self.agents)])
            joint_obs = np.concatenate([joint_obs, agents])

        # Goal ids
        #
        # Sequence doesn't mean much
        # should be used to index a lookup table
        if self.obs_goals:
            goals = np.array([self.task_id[agent_id]])
            joint_obs = np.concatenate([joint_obs, goals])

        if self.obs_last_action:
            last_action = self.last_action.reshape(self.num_agents, -1)

            joint_obs = np.concatenate([joint_obs, last_action[agent_id]])

        return joint_obs

    def get_obs_size(self):
        """Returns the size of the observation."""
        len_obs = self.img_obs_shape[0] * self.img_obs_shape[
            1] * self.img_obs_shape[2]
        if self.obs_coordinates:
            len_obs += 2
        if self.obs_time_step:
            len_obs += 1
        if self.obs_agent_ids:
            len_obs += 1
        if self.obs_goals:
            len_obs += 1
        if self.obs_last_action:
            len_obs += self.n_actions

        return len_obs

    def get_state(self):
        """Returns the global state."""

        s = [self.image.flatten() / 255., np.concatenate([
            np.concatenate([self.pos_eye[int(agent.pos[0])], self.pos_eye[int(agent.pos[1])]])
            for agent in self.agents])]

        return np.concatenate(s)

    def get_state_size(self):
        """Returns the size of the global state."""
        state_len = 0
        size = self.image.shape
        state_len += np.prod(size)

        state_len += self.num_agents * 2 * self.size

        return int(state_len)

    def get_avail_actions(self):
        """Returns the available actions of all agents in a list."""
        return [self.get_avail_agent_actions(i) for i in range(self.num_agents)]

    def get_avail_agent_actions(self, agent_id):
        """Returns the available actions for agent_id."""
        return [1] * self.n_actions

    def get_total_actions(self):
        """Returns the total number of actions an agent could ever take."""
        return self.n_actions

    def reset(self):
        """Start a new episode and return ``(observations, state)``.

        Clears the step counter, the coverage bookkeeping and the last
        actions, samples an image that contains all three goals, and
        respawns the agents.
        """
        self.n_steps = 0
        self.visited_states = []
        self.max_states = (self.max_row // self.step_size) * (
            self.max_col // self.step_size)

        # Do not leak the previous episode's final actions into the first
        # observation of the new episode (relevant when obs_last_action is
        # set).
        self.last_action = np.zeros((self.num_agents, self.n_actions))

        idx = self._get_image_id()
        attr = self.attr[self.task][idx]

        # Make sure the sampled image has all three goals. Note that this
        # resamples at random even when fix_image is set.
        while len({
                x for x in itertools.chain.from_iterable(attr)
                if isinstance(x, int)
        }) != 3:
            idx = np.random.randint(self.n_imgs)
            attr = self.attr[self.task][idx]

        # Record the index only after resampling so that it always refers to
        # the image that was actually loaded.
        self.image_idx = idx
        self.image = copy.deepcopy(self.data[idx])
        self.success_maps = [np.array(attr) == i for i in range(3)]

        agent_positions = self._get_agent_pos()
        self.agents = [Agent(*pos) for pos in agent_positions]

        return self.get_obs(), self.get_state()

    def render(self):
        """Plot the current step, save it under ``./shapes/viz`` and show it.

        Three panels are drawn: the full image with every agent's window
        outlined, only the pixels currently inside some agent's window, and
        each agent's window laid out next to its ``(row, col)`` position.
        The figure is saved as ``<n_steps, zero-padded to 5 digits>`` in the
        output directory, which is created when missing.
        """
        # Always empty here, so the optional 'p_attn' panel is never drawn;
        # the code is kept as a hook for passing attention weights.
        aux = {}

        h, w = self.obs_height, self.obs_width

        visions = np.zeros(self.img_shape, dtype=int)
        positions = np.ones(self.img_shape, dtype=int) * 255

        for agent in self.agents:
            r, c = agent.pos
            visions[r:r + h, c:c + w, :] = self.image[r:r + h, c:c + w, :]

        def outline_agents(axis):
            # Draw a numbered white rectangle around every agent's window.
            for i, agent in enumerate(self.agents):
                r, c = agent.pos
                axis.add_patch(
                    Rectangle(
                        (c - 0.5, r - 0.5),
                        w,
                        h,
                        fill=False,
                        edgecolor='white'))
                axis.text(c + 1.5, r + 3, str(i + 1), color='white')

        hdim = 3
        if 'p_attn' in aux:
            hdim += 1

        fig = plt.figure(figsize=(4 * hdim, 4))

        ax = plt.subplot(1, hdim, 1)
        plt.imshow(self.image)
        outline_agents(ax)
        plt.title('Complete state')

        ax = plt.subplot(1, hdim, 2)
        plt.imshow(visions)
        outline_agents(ax)
        plt.title('Observed state (jointly by all)')

        ax = plt.subplot(1, hdim, 3)
        # Number of agent thumbnails placed on one row of the third panel.
        per_row = self.max_col / (self.obs_width + 2)
        for i, agent in enumerate(self.agents):
            r, c = agent.pos
            plt_r = 14 * int(i // per_row)
            plt_c = 2 + 7 * int(i % per_row)
            positions[plt_r:plt_r + h, plt_c:plt_c + w] = \
                self.image[r:r + h, c:c + w, :]
            ax.text(
                plt_c,
                plt_r + 7,
                "(%02d,%02d)" % (r, c),
                color='black',
                fontsize=6)
            ax.text(plt_c + 1.5, plt_r + 10, str(i + 1), color='black')
        ax.axis('off')
        plt.imshow(positions)
        plt.title('Agent views & positions')
        plt.tick_params(
            which='both',
            left=False,
            bottom=False,
            labelbottom=False,
            labelleft=False)

        if 'p_attn' in aux:
            ax = plt.subplot(1, hdim, 4)
            aux['p_attn'] = aux['p_attn'].data.numpy()
            plt.imshow(aux['p_attn'], cmap='Blues', vmin=0.0, vmax=1.0)

            plt.xlabel('Sender')
            plt.ylabel('Receiver')

            ticks = ticker.FuncFormatter(lambda x, pos: '{0:g}'.format(x + 1))
            loc = ticker.MultipleLocator(base=1.0)
            ax.xaxis.set_major_formatter(ticks)
            ax.xaxis.set_major_locator(loc)
            ax.yaxis.set_major_formatter(ticks)
            ax.yaxis.set_major_locator(loc)

            plt.title('Communication attention')

        save_dir = './shapes/viz'
        # savefig does not create missing directories.
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)

        plt.savefig("%s/%05d" % (save_dir, self.n_steps))
        plt.show()
        # pyplot keeps every figure alive until it is closed; render() makes
        # a new one per call, so release it here.
        plt.close(fig)

    def close(self):
        pass

    def seed(self):
        pass

    def save_replay(self):
        """Save a replay."""
        pass

    def _get_img_obs(self, agent):
        obs = self.image[agent.row:agent.row + self.obs_height, agent.col:
                         agent.col + self.obs_width, :]

        if self.normalize:
            obs = obs / 255.0

        return obs

    def _get_agent_pos(self, agent_positions=None):
        if self.fix_spawn:
            # corners
            pos = []
            for i in range(self.num_agents):
                if i % 4 == 0:
                    pos.append((0, 0))
                elif i % 4 == 1:
                    pos.append((0, self.max_col))
                elif i % 4 == 2:
                    pos.append((self.max_row, self.max_col))
                elif i % 4 == 3:
                    pos.append((self.max_row, 0))
            return pos

        if agent_positions is not None:
            return agent_positions
        else:
            return [(random.randint(0, self.max_row),
                     random.randint(0, self.max_col))
                    for _ in range(self.num_agents)]

    def _on_goal(self, agent, success_map):
        cell_row = (agent.row + int(self.obs_height / 2)) // 10
        cell_col = (agent.col + int(self.obs_width / 2)) // 10

        assert cell_row < self.max_row_cells
        assert cell_col < self.max_col_cells

        return success_map[cell_row, cell_col]

    def is_success(self):
        return sum([
            self._on_goal(self.agents[i], self.success_maps[self.task_id[i]])
            for i in range(self.num_agents)
        ]) == self.num_agents

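    # [TODO] potentially buggy: `max_states` is computed as
    # (max_row // step_size) * (max_col // step_size), which does not count
    # both ends of the valid 0..max_row / 0..max_col ranges, so the ratio
    # returned here is only approximate and can exceed 1.0.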
    def get_coverage(self):
        for agent in self.agents:
            r, c = agent.pos
            p = r * self.max_row + c
            if p not in self.visited_states:
                self.visited_states.append(p)
        return (1.0 * len(self.visited_states)) / self.max_states

    def _get_image_id(self, img_idx=None):
        if self.fix_image:
            return 25

        if img_idx is not None:
            return img_idx
        else:
            return np.random.randint(self.n_imgs)

    def _parse_task(self, task):
        self.task = task.split('.')[0]
        agent_goal_counts = [int(x) for x in task.split('.')[1].split(',')]
        assert sum(
            agent_goal_counts
        ) == self.num_agents, "No. of agents does not match goal definition"

        self.task_id = []
        for i in range(len(agent_goal_counts)):
            for j in range(agent_goal_counts[i]):
                self.task_id.append(i)

    def _move(self, agent, action):
        assert action in range(5)

        if action == 0:  # Move get_obsNorth
            agent.row = max(agent.row - self.step_size, 0)
        elif action == 1:  # Move East
            agent.col = min(agent.col + self.step_size, self.max_col)
        elif action == 2:  # Move South
            agent.row = min(agent.row + self.step_size, self.max_row)
        elif action == 3:  # Move West
            agent.col = max(agent.col - self.step_size, 0)
        else:
            pass

    def get_env_info(self):
        env_info = {"state_shape": self.get_state_size(),
                    "obs_shape": self.get_obs_size(),
                    "n_actions": self.get_total_actions(),
                    "n_agents": self.num_agents,
                    "episode_limit": self.episode_limit}
        return env_info

    def get_stats(self):
        stats = {
            "battles_won": self.battles_won,
            "battles_game": self.battles_game,
            "win_rate": self.battles_won / self.battles_game
        }
        return stats


if __name__ == '__main__':
    # Smoke test: drive the environment with random moves for 100 steps.
    env = ShapesEnv()
    env.reset()

    for _ in range(100):
        # Only the four movement actions (0-3) are sampled; the upper bound
        # is exclusive, so action 4 ("stay") is never drawn.
        joint_action = [
            np.random.randint(0, 4) for _ in range(env.num_agents)
        ]
        reward, terminated, info = env.step(joint_action)
        state = env.get_state()

        # Start a fresh episode instead of stepping past a terminal state.
        if terminated:
            env.reset()
