from envs.multiagentenv import MultiAgentEnv

import numpy as np

class matrixgameEnv(MultiAgentEnv):
    """Single-state, two-agent cooperative matrix game.

    On every step each of the two agents picks one action and the team
    receives the shared payoff ``rewards[actions[0]][actions[1]]``.
    Observations and the global state are constant zero vectors, so the
    payoff matrix is the only thing there is to learn.

    An episode ends once ``episode_limit`` steps have been taken. It is
    counted as "won" when the reward of its terminating step equals the
    largest entry of the payoff matrix. The ``battles_*`` statistic names
    are presumably kept to match what the surrounding framework expects
    from its environments -- verify against the caller.
    """

    def __init__(
            self,
            n_actions=3,
            rewards=None,
            episode_limit=None,
            is_print=False,
            print_rew=True,
            print_steps=1000,
            seed=None
    ):
        """Set up the payoff matrix, logging options and counters.

        Args:
            n_actions: Number of actions reported for each agent. It is not
                checked against the shape of ``rewards``.
            rewards: Payoff matrix indexed as ``rewards[a0][a1]``. ``None``
                selects the default ``[[2, 2, 0], [2, 2, 0], [0, 0, 3]]``;
                a fresh list is built per instance so that instances never
                share (and cannot corrupt) one default object.
            episode_limit: Steps per episode; ``None`` means 1.
            is_print: If true, print the step index and the joint action on
                every step.
            print_rew: If true, collect terminal rewards and periodically
                print their running average.
            print_steps: Print the running average every this many steps.
            seed: Accepted but not used by this environment.
        """
        # Map arguments
        self.print_rew = print_rew
        self.is_print = is_print
        self.print_steps = print_steps

        if rewards is None:
            # Built here rather than in the signature: a list default would
            # be one shared object across all instances.
            rewards = [[2, 2, 0], [2, 2, 0], [0, 0, 3]]

        self.n_agents = 2
        self.n_states = 1
        self.n_actions = n_actions
        self.rewards = rewards

        # Keep an int for integer-valued matrices (the historical behaviour)
        # but do not truncate a fractional maximum: truncation would make the
        # optimality check in step() impossible to satisfy.
        best = float(np.max(rewards))
        self.max_reward = int(best) if best.is_integer() else best

        # Statistics
        self._episode_count = 0
        self._episode_steps = 0
        self._total_steps = 0
        self.battles_won = 0
        self.battles_game = 0

        # Reward-logging state (cleared again by clean()).
        self.p_step = 0
        self.rew_gather = []
        self.is_print_once = False

        self.episode_limit = 1 if episode_limit is None else episode_limit

    def step(self, actions):
        """Apply the joint action and return ``(reward, terminated, info)``.

        Args:
            actions: Indexable holding the action of agent 0 and agent 1.

        Returns:
            reward: ``rewards[actions[0]][actions[1]]``.
            terminated: True once ``episode_limit`` steps have been taken
                since the last reset.
            info: Empty dict, except on the terminating step where it holds
                ``"battle_won"`` (whether this step's reward was the maximum).
        """
        self._total_steps += 1
        self._episode_steps += 1
        info = {}

        if self.is_print:
            print('t_steps: %d' % self._episode_steps)
            print(actions)

        reward = self.rewards[actions[0]][actions[1]]
        opt = not (reward != self.max_reward)
        terminated = self._episode_steps >= self.episode_limit

        if terminated:
            info["battle_won"] = opt
            self._episode_count += 1
            self.battles_game += 1
            if opt:
                self.battles_won += 1

        if self.print_rew:
            self.p_step += 1
            if terminated:
                self.rew_gather.append(reward)
            if self.p_step % self.print_steps == 0:
                # No episode may have finished yet when the print interval
                # is shorter than the episode; report nan explicitly instead
                # of letting np.mean warn about an empty list.
                if self.rew_gather:
                    avg_rew = float(np.mean(self.rew_gather))
                else:
                    avg_rew = float('nan')
                print('steps: %d, average rew: %.3lf' % (self.p_step,
                                                         avg_rew))
                self.is_print_once = True

        return reward, terminated, info

    def get_obs(self):
        """Returns all agent observations in a list."""
        return [self.get_obs_agent(i) for i in range(self.n_agents)]

    def get_obs_agent(self, agent_id):
        """Returns observation for agent_id (always the constant ``[0]``)."""
        return np.array([0])

    def get_obs_size(self):
        """Returns the size of the observation."""
        return 1

    def get_state(self):
        """Returns the global state: a zero vector with one entry per agent."""
        return np.array([0] * self.n_agents)

    def get_state_size(self):
        """Returns the size of the global state."""
        return self.n_agents

    def get_avail_actions(self):
        """Returns the available actions of all agents in a list."""
        return [self.get_avail_agent_actions(i) for i in range(self.n_agents)]

    def get_avail_agent_actions(self, agent_id):
        """Returns the availability mask for agent_id (every action allowed)."""
        return [1] * self.n_actions

    def get_total_actions(self):
        """Returns the total number of actions an agent could ever take."""
        return self.n_actions

    def reset(self):
        """Starts a new episode and returns initial observations and state.

        Only the per-episode step counter is reset; cumulative statistics
        and the reward log are left untouched.
        """
        self._episode_steps = 0
        return self.get_obs(), self.get_state()

    def render(self):
        """No-op: there is nothing to render."""
        pass

    def close(self):
        """No-op: the environment holds no external resources."""
        pass

    def seed(self):
        """No-op: step() draws no random numbers."""
        pass

    def save_replay(self):
        """Save a replay (no-op for this environment)."""
        pass

    def get_env_info(self):
        """Returns the shapes and limits describing this environment."""
        env_info = {"state_shape": self.get_state_size(),
                    "obs_shape": self.get_obs_size(),
                    "n_actions": self.get_total_actions(),
                    "n_agents": self.n_agents,
                    "episode_limit": self.episode_limit}
        return env_info

    def get_stats(self):
        """Returns cumulative win statistics.

        ``win_rate`` is 0.0 while no episode has finished yet, instead of
        dividing by zero.
        """
        if self.battles_game:
            win_rate = self.battles_won / self.battles_game
        else:
            win_rate = 0.0
        stats = {
            "battles_won": self.battles_won,
            "battles_game": self.battles_game,
            "win_rate": win_rate
        }
        return stats

    def clean(self):
        """Clears the reward-logging state used by step()."""
        self.p_step = 0
        self.rew_gather = []
        self.is_print_once = False
