from smac.env import StarCraft2Env
import numpy as np

from typing import List

from .multiagentenv import MultiAgentEnv

class SMACWrapper(MultiAgentEnv):
    """Wrap ``StarCraft2Env`` and return every transition as a dict of arrays.

    When ``use_absorbing_state`` is enabled, the termination reported by the
    underlying environment is withheld. The wrapper then emits
    ``trailing_absorbing_state_length`` extra "absorbing" steps (zero
    observations, only action 0 available, zero rewards, zero agent mask)
    and reports ``terminated=True`` on the last of them.
    """

    def __init__(
        self,
        map_name,
        use_absorbing_state,
        trailing_absorbing_state_length,
        seed,
        **kwargs
    ):
        """Create the underlying environment and the absorbing-state buffers.

        Args:
            map_name: Forwarded to ``StarCraft2Env`` as ``map_name``.
            use_absorbing_state: Whether to append absorbing steps after the
                underlying environment terminates.
            trailing_absorbing_state_length: Number of absorbing steps to
                emit; must be positive when ``use_absorbing_state`` is set.
            seed: Forwarded to ``StarCraft2Env`` as ``seed``.
            **kwargs: Forwarded unchanged to ``StarCraft2Env``.

        Raises:
            ValueError: If absorbing states are enabled with a non-positive
                ``trailing_absorbing_state_length``.
        """
        # Validate before launching the environment so a bad configuration
        # fails fast (and is not stripped under ``python -O`` like an assert).
        if use_absorbing_state and trailing_absorbing_state_length <= 0:
            raise ValueError(
                "trailing_absorbing_state_length must be > 0 when "
                "use_absorbing_state is enabled"
            )

        self.env = StarCraft2Env(map_name=map_name, seed=seed, **kwargs)
        env_info = self.env.get_env_info()
        self.n_actions = env_info["n_actions"]
        self.n_agents = env_info["n_agents"]
        self.obs_shape = (env_info["obs_shape"],)
        self.episode_limit = self.env.episode_limit

        # Info dict returned by the most recent ``env.step`` call.
        self.info = {}

        # absorbing state
        self.use_absorbing_state = use_absorbing_state
        self.trailing_absorbing_state_length = trailing_absorbing_state_length
        self.is_trailing_absorbing_state = False
        self.cur_absorbing_state_length = -1
        self.absorbing_obs = np.zeros(self.obs_shape, dtype=np.float32)
        self.absorbing_avail_actions = [0] * self.n_actions
        self.absorbing_avail_actions[0] = 1  # no-op
        self.absorbing_agent_mask = np.zeros((self.n_agents, 1))
        # float32 so absorbing steps match the reward dtype of regular steps.
        self.absorbing_rewards = np.zeros((self.n_agents, 1), dtype=np.float32)

    @staticmethod
    def _flag(value) -> np.ndarray:
        """Return ``value`` as a boolean array of shape ``(1,)``."""
        # ``np.bool_`` instead of the ``np.bool`` alias, which was removed in
        # NumPy 1.24.
        return np.array([value], dtype=np.bool_)

    def step(self, actions):
        """Advance one step and return the transition as a dict.

        While not in the trailing absorbing phase, ``actions`` is passed to
        the underlying environment. During the absorbing phase ``actions`` is
        ignored and pre-built absorbing values are returned instead.

        Returns:
            dict with the following entries:
                obs: (n_agents, obs_dim)
                avail_actions: (n_agents, n_actions)
                agent_mask: (n_agents, 1)
                rewards: (n_agents, 1)
                terminated: (1,)
                truncated: (1,), always False
                is_first: (1,), always False
                is_trailing_absorbing_state: (1,)
                log_battle_won: present only when ``terminated`` is true;
                    taken from the last info dict, False if missing.
        """
        if not self.is_trailing_absorbing_state:
            rews, terminated, self.info = self.env.step(actions)

            if self.use_absorbing_state and terminated:
                # Hide the real termination; it is reported once the
                # trailing absorbing steps have been emitted.
                self.is_trailing_absorbing_state = True
                self.cur_absorbing_state_length = 0
                terminated = False

            obss = self.get_obs()
            avail_actions = self.get_avail_actions()
            agent_mask = self.get_agent_mask()
            # The single reward from the env is given to every agent.
            rewards = np.full((self.n_agents, 1), rews, dtype=np.float32)
        else:
            self.cur_absorbing_state_length += 1
            terminated = (
                self.cur_absorbing_state_length
                >= self.trailing_absorbing_state_length
            )
            obss = [self.absorbing_obs] * self.n_agents
            avail_actions = [self.absorbing_avail_actions] * self.n_agents
            # Copy so callers cannot mutate the shared absorbing buffer.
            rewards = self.absorbing_rewards.copy()
            agent_mask = self.absorbing_agent_mask

        result = {
            "obs": np.stack(obss, axis=0),
            "avail_actions": np.stack(avail_actions, axis=0),
            "agent_mask": np.stack(agent_mask, axis=0),
            "rewards": rewards,
            "terminated": self._flag(terminated),
            "truncated": self._flag(False),
            "is_first": self._flag(False),
            "is_trailing_absorbing_state": self._flag(
                self.is_trailing_absorbing_state
            ),
        }
        if terminated:
            result["log_battle_won"] = self.info.get("battle_won", False)
        return result

    def get_obs(self) -> List[np.ndarray]:
        """Return one observation per agent.

        Agents whose entry in ``env.death_tracker_ally`` equals 1 get the
        all-zero absorbing observation instead of an environment observation.
        """
        return [
            self.absorbing_obs
            if self.env.death_tracker_ally[agent_id] == 1
            else self.env.get_obs_agent(agent_id)
            for agent_id in range(self.n_agents)
        ]

    def get_obs_size(self):
        """Return the observation size reported by the underlying env."""
        return self.env.get_obs_size()

    def get_state(self):
        """Return the global state reported by the underlying env."""
        return self.env.get_state()

    def get_state_size(self):
        """Return the state size reported by the underlying env."""
        return self.env.get_state_size()

    def get_avail_actions(self) -> List[np.ndarray]:
        """Return one available-action mask per agent.

        Agents whose entry in ``env.death_tracker_ally`` equals 1 get the
        absorbing mask (only action 0 enabled).
        """
        return [
            self.absorbing_avail_actions
            if self.env.death_tracker_ally[agent_id] == 1
            else self.env.get_avail_agent_actions(agent_id)
            for agent_id in range(self.n_agents)
        ]

    def get_total_actions(self):
        """Return the total number of actions an agent could ever take."""
        return self.env.get_total_actions()

    def get_agent_mask(self):
        """Return ``1 - env.death_tracker_ally`` with a trailing unit axis."""
        return 1 - self.env.death_tracker_ally[..., None]

    def reset(self, seed=None, options=None):
        """Reset the environment and return the initial transition dict.

        Args:
            seed: If not None, stored through :meth:`seed` before the reset.
            options: Accepted for interface compatibility; unused.

        Returns:
            dict with the same entries as :meth:`step` (except
            ``log_battle_won``): zero rewards, ``is_first`` True and all
            other flags False.
        """
        if seed is not None:
            # Use the wrapper's own seed() so both entry points set the seed
            # the same way.
            # NOTE(review): upstream SMAC's ``StarCraft2Env.seed()`` appears
            # to be a getter that takes no argument - confirm for this fork.
            self.seed(seed)
        self.env.reset()

        # A fresh episode never starts inside the trailing absorbing phase.
        self.is_trailing_absorbing_state = False
        self.cur_absorbing_state_length = -1

        return {
            "obs": np.stack(self.get_obs(), axis=0),
            "avail_actions": np.stack(self.get_avail_actions(), axis=0),
            "agent_mask": self.get_agent_mask(),
            "rewards": np.zeros((self.n_agents, 1), dtype=np.float32),
            "terminated": self._flag(False),
            "truncated": self._flag(False),
            "is_first": self._flag(True),
            "is_trailing_absorbing_state": self._flag(False),
        }

    def render(self):
        """Delegate rendering to the underlying env."""
        self.env.render()

    def close(self):
        """Close the underlying env."""
        self.env.close()

    def seed(self, seed=None):
        """Store ``seed`` on the underlying env's ``_seed`` attribute."""
        self.env._seed = seed

    def save_replay(self):
        """Ask the underlying env to save a replay."""
        self.env.save_replay()

    def get_env_info(self):
        """Return the info dict reported by the underlying env."""
        return self.env.get_env_info()

    def get_stats(self):
        """Return the statistics reported by the underlying env."""
        return self.env.get_stats()