from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from smac.env.multiagentenv import MultiAgentEnv

import atexit
from operator import attrgetter
from copy import deepcopy
import numpy as np
import enum
import math
from absl import logging
import random
import torch as th


class spread_xEnv(MultiAgentEnv):
    """Adapter exposing a ``multiagent`` particle-environment scenario
    through the ``MultiAgentEnv`` interface used by SMAC-style runners.

    The scenario named by ``scenario_name`` is loaded from
    ``multiagent.scenarios`` and wrapped in ``multiagent.environment``'s
    environment class. Discrete action indices are converted to one-hot
    vectors before being forwarded to the wrapped environment.
    """
    def __init__(
            self,
            reward_win=1000,
            episode_limit=50,
            scenario_name="simple_spread_x",
            seed=None
    ):
        """Builds the wrapped environment and resets it once.

        Args:
            reward_win: Accepted for interface compatibility; not used here.
            episode_limit: Maximum number of steps before an episode is
                forcibly terminated.
            scenario_name: Name (without ``.py``) of the scenario module to
                load from ``multiagent.scenarios``.
            seed: Seed for NumPy's global RNG. When ``None`` a random seed
                in ``[0, 9999]`` is drawn instead.
        """
        # Honour the caller's seed; only fall back to a random one when no
        # seed was supplied. The seed actually used is kept in ``_seed``.
        self._seed = random.randint(0, 9999) if seed is None else seed
        np.random.seed(self._seed)
        self.n_agents = 2

        # Statistics
        self._episode_count = 0
        self._episode_steps = 0
        self._total_steps = 0
        self.battles_won = 0
        self.battles_game = 0

        self.episode_limit = episode_limit

        # Imported lazily so the dependency is only needed on instantiation.
        from multiagent.environment import MultiAgentEnv as MultiAgentEnv_x
        import multiagent.scenarios as scenarios

        # load scenario from script
        scenario = scenarios.load(scenario_name + ".py").Scenario()
        # create world
        world = scenario.make_world()
        # create multiagent environment
        self.env = MultiAgentEnv_x(
            world,
            scenario.reset_world,
            scenario.reward,
            scenario.observation,
            done_callback=scenario.done,
        )
        self.obs = self.env.reset()

        # Actions
        self.n_actions = 5

        # Qatten
        self.unit_dim = self.get_obs_size()

    def step(self, actions):
        """Advances the environment by one step.

        Args:
            actions: One discrete action index per agent. Any value accepted
                by ``torch.as_tensor`` works (tensor on any device, NumPy
                array or list).

        Returns:
            A tuple ``(reward, terminated, info)`` where ``reward`` is the
            first agent's reward as returned by the wrapped environment,
            ``terminated`` tells whether the episode ended (any agent done
            or ``episode_limit`` reached), and ``info`` holds the boolean
            keys ``'battle_won'`` and ``'battle_lose'``.
        """
        self._total_steps += 1
        self._episode_steps += 1
        info = {}

        # scatter_ needs an int64 index living on the same device as the
        # target, so build the one-hot buffer wherever the actions are
        # instead of assuming a CUDA device is present.
        action_idx = th.as_tensor(actions, dtype=th.long).reshape(-1, 1)
        new_actions_onehot = th.zeros(
            (self.n_agents, self.n_actions), device=action_idx.device
        )
        new_actions_onehot.scatter_(1, action_idx, 1)
        self.obs, rew_n, done_n, info_n = self.env.step(new_actions_onehot)

        all_done = all(done_n)
        any_done = any(done_n)

        # A "win" is every agent reporting done; a "loss" is only some.
        info['battle_won'] = all_done
        info['battle_lose'] = any_done and not all_done

        terminated = any_done or self._episode_steps >= self.episode_limit

        if terminated:
            self._episode_count += 1
            self.battles_game += 1
            # Count the win here so that battles_won never exceeds
            # battles_game and get_stats() reports a meaningful rate.
            if all_done:
                self.battles_won += 1

        return rew_n[0], terminated, info

    def get_obs(self):
        """Returns all agent observations in a list."""
        return self.obs

    def get_obs_size(self):
        """Returns the flattened size of the first agent's observation."""
        shape = self.env.observation_space[0].shape
        return int(np.prod(shape))

    def get_state(self):
        """Returns the global state: all agents' flattened observations
        concatenated into one 1-D array."""
        flat_obs = [self.obs[agent_id].flatten()
                    for agent_id in range(self.n_agents)]
        return np.concatenate(flat_obs, axis=0)

    def get_state_size(self):
        """Returns the size of the global state."""
        # Use n_agents (not env.n) so this always matches get_state().
        return self.get_obs_size() * self.n_agents

    def get_avail_actions(self):
        """Returns the available actions of all agents in a list."""
        return [self.get_avail_agent_actions(i) for i in range(self.n_agents)]

    def get_avail_agent_actions(self, agent_id):
        """Returns the available actions for agent_id (all are allowed)."""
        return [1] * self.n_actions

    def get_total_actions(self):
        """Returns the total number of actions an agent could ever take."""
        return self.n_actions

    def reset(self):
        """Resets the episode; returns initial observations and state."""
        self._episode_steps = 0
        self.obs = self.env.reset()

        return self.get_obs(), self.get_state()

    def render(self):
        """No-op; rendering is not implemented for this wrapper."""
        pass

    def close(self):
        """No-op; there are no resources to release."""
        pass

    def seed(self):
        """Returns the seed that was used to initialise NumPy's RNG."""
        return self._seed

    def save_replay(self):
        """Save a replay (no-op)."""
        pass

    def get_env_info(self):
        """Returns a dict describing state/observation/action dimensions,
        the number of agents, the episode limit and ``unit_dim``."""
        env_info = {"state_shape": self.get_state_size(),
                    "obs_shape": self.get_obs_size(),
                    "n_actions": self.get_total_actions(),
                    "n_agents": self.n_agents,
                    "episode_limit": self.episode_limit,
                    "unit_dim": self.unit_dim}
        return env_info

    def get_stats(self):
        """Returns win statistics accumulated over finished episodes.

        ``win_rate`` is reported as ``0.0`` while no episode has finished,
        rather than raising ``ZeroDivisionError``.
        """
        if self.battles_game:
            win_rate = self.battles_won / self.battles_game
        else:
            win_rate = 0.0
        stats = {
            "battles_won": self.battles_won,
            "battles_game": self.battles_game,
            "win_rate": win_rate
        }
        return stats
