from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from .multiagentenv import MultiAgentEnv
from matplotlib.colors import Normalize
from matplotlib.cm import viridis
import numpy as np
import random


class HallwayEnv(MultiAgentEnv):
    """The StarCraft II environment for decentralised multi-agent
    micromanagement scenarios.
    """
    def __init__(
            self,
            n_agents=5,
            n_groups=2,
            state_numbers=[4,4,4,4,4],
            group_ids=[0,0,1,1,1],
            reward_win=10,
            obs_last_action=False,
            state_last_action=True,
            seed=None
    ):
        # Map arguments
        self._seed = random.randint(0, 9999)
        np.random.seed(self._seed)
        self.n_agents = n_agents
        self.n_states = np.array(state_numbers,
                                 dtype=int)

        # Observations and state
        self.obs_last_action = obs_last_action
        self.state_last_action = state_last_action

        # Rewards args
        self.reward_win = reward_win

        # Other
        self._seed = seed

        # Actions
        self.n_actions = 3
        self.n_groups = n_groups
        self.group_ids = np.array(group_ids)

        # Statistics
        self._episode_count = 0
        self._episode_steps = 0
        self._total_steps = 0
        self.battles_won = 0
        self.battles_game = 0

        self.last_action = np.zeros((self.n_agents, self.n_actions))

        self.episode_limit = max(state_numbers) + 10

        # initialize agents
        self.state_n = np.array([np.random.randint(low=1, high=self.n_states[i]+1) for i in range(self.n_agents)],
                                dtype=int)

        self.episode_states = [self.state_n.copy()]

        # self.group_members = [[] for _ in range(self.n_groups)]
        # self.status_by_group = [[] for _ in range(self.n_groups)]
        # for agent_i, group_i in enumerate(self.group_ids):
        #     self.group_members[group_i].append(agent_i)
        #     self.status_by_group[group_i].append(0)

        self.active_group = [True for _ in range(self.n_groups)]
        self.active_agent = np.array([True for _ in range(self.n_agents)])
        self._win_group = 0
        self._fail_group = 0

    def step(self, actions):
        """Returns reward, terminated, info."""
        self._total_steps += 1
        self._episode_steps += 1
        info = {}

        last_state = self.state_n

        for agent_i, action in enumerate(actions):
            if self.active_agent[agent_i]:
                if action == 0:
                    pass
                elif action == 1:
                    self.state_n[agent_i] = max(0, self.state_n[agent_i] - 1)
                elif action == 2:
                    self.state_n[agent_i] = min(self.n_states[agent_i], self.state_n[agent_i] + 1)

        reward = 0
        terminated = False
        info['battle_won'] = False

        win_in_this_round = 0
        win_agents = np.array([False for _ in range(self.n_agents)])

        for group_i in range(self.n_groups):
            if self.active_group[group_i]:
                id = self.state_n[self.group_ids == group_i]

                if (id == 0).all():
                    reward += self.reward_win
                    self._win_group += 1
                    self.active_group[group_i] = False
                    self.active_agent[self.group_ids == group_i] = False
                    win_agents[self.group_ids == group_i] = True
                    win_in_this_round += 1
                elif (id == 0).any():
                    self.active_group[group_i] = False
                    self.active_agent[self.group_ids == group_i] = False
                    self._fail_group += 1

        info['win_group'] = self._win_group

        if win_in_this_round > 1:
            self._win_group -= win_in_this_round
            reward -= self.reward_win * 1.5 * win_in_this_round
            self.active_agent[win_agents] = True
            self.state_n[win_agents] = last_state[win_agents]

        if self._win_group == self.n_groups:
            terminated = True
            self.battles_won += 1
            info['battle_won'] = True
        elif self._fail_group == self.n_groups:
            terminated = True

        if self._episode_steps >= self.episode_limit:
            terminated = True

        if terminated:
            self._episode_count += 1
            self.battles_game += 1
        else:
            self.episode_states.append(self.state_n.copy())

        return reward, terminated, info

    def get_obs(self):
        """Returns all agent observations in a list."""
        return [self.get_obs_agent(i) for i in range(self.n_agents)]

    def get_obs_agent(self, agent_id):
        """Returns observation for agent_id."""
        return np.array([self.state_n[agent_id], float(self.active_agent[agent_id])])

    def get_obs_size(self):
        """Returns the size of the observation."""
        return 2

    def get_state(self):
        """Returns the global state."""
        return np.concatenate(self.get_obs())

    def get_state_size(self):
        """Returns the size of the global state."""
        return self.n_agents * self.get_obs_size()

    def get_avail_actions(self):
        """Returns the available actions of all agents in a list."""
        return [self.get_avail_agent_actions(i) for i in range(self.n_agents)]

    def get_avail_agent_actions(self, agent_id):
        """Returns the available actions for agent_id."""
        return [1] * self.n_actions

    def get_total_actions(self):
        """Returns the total number of actions an agent could ever take."""
        return self.n_actions

    def reset(self):
        """Returns initial observations and states."""
        self._episode_steps = 0
        self.last_action = np.zeros((self.n_agents, self.n_actions))
        self.state_n = np.array([np.random.randint(low=1, high=self.n_states[i]+1) for i in range(self.n_agents)],
                                dtype=int)
        self.active_group = [True for _ in range(self.n_groups)]
        self.active_agent = np.array([True for _ in range(self.n_agents)])
        self._win_group = 0
        self._fail_group = 0
        self.episode_states = [self.state_n.copy()]

        return self.get_obs(), self.get_state()

    def render(self, mode, state):
        matrix_view = np.stack([np.eye(self.n_states[0])[e-1] * 10 * (id + 1) for e, id in zip(state, self.group_ids)])

        upscaled_data = np.kron(matrix_view, np.ones((40, 40)))
        # Normalize data to be in the range [0, 1]
        norm = Normalize(vmin=np.min(upscaled_data), vmax=np.max(upscaled_data))
        # Use the Viridis colormap to map values to RGB
        rgb_values = viridis(norm(upscaled_data))

        return rgb_values

    def close(self):
        pass

    def seed(self):
        pass

    def get_episode_frames(self, logger):
        frames = []
        for i in range(len(self.episode_states)):
            frame = self.render(mode='image', state=self.episode_states[i])
            frames.append(frame)
        frames = np.array(frames)
        return frames

    def get_env_info(self):
        env_info = {"state_shape": self.get_state_size(),
                    "obs_shape": self.get_obs_size(),
                    "n_actions": self.get_total_actions(),
                    "n_agents": self.n_agents,
                    "episode_limit": self.episode_limit}
        return env_info

    def get_stats(self):
        stats = {
            "battles_won": self.battles_won,
            "battles_game": self.battles_game,
            "win_rate": self.battles_won / self.battles_game
        }
        return stats

    def clean(self):
        self.p_step = 0
        self.rew_gather = []
        self.is_print_once = False
