import numpy as np
from k_level_policy_gradients.src.core.environment import (
    Environment,
    MDPInfo,
)
from k_level_policy_gradients.src.utils.spaces import *
from k_level_policy_gradients.smacv2.env.starcraft2.starcraft2 import StarCraft2Env


class SMAC(Environment):
    """
    Wraps a smac StarCraft env to be compatible with MushroomRL multi-agent setup.

    ``reset`` and ``step`` both return a dictionary (rather than a tuple) that
    carries the global state, the per-agent observations and the per-agent
    action masks; ``step`` additionally carries rewards, the absorbing flag
    and the info dictionary of the wrapped environment.
    """

    def __init__(
        self,
        map_name="",
        difficulty=None,
        state_last_action=True,
        capability_config=None,
        horizon=None,
        gamma=0.99,
        seed=0,
        bool_render=False,
    ):
        """Create a new multi-agent StarCraft env compatible with MushroomRL.

        Arguments:
            map_name (str): Name of the map to load.
            difficulty: Difficulty forwarded to the wrapped env as a string.
                When ``None`` the argument is not forwarded at all, so the
                wrapped env falls back to its own default.
            state_last_action (bool): Forwarded unchanged to the wrapped env.
            capability_config (dict): Forwarded to the wrapped env. ``None``
                (the default) is replaced by a fresh empty dict.
            horizon (int): Episode length. When given, it also overrides the
                ``episode_limit`` of the wrapped env; when ``None`` the
                wrapped env's ``episode_limit`` is used as the horizon.
            gamma (float): Discount factor stored in the MDP info.
            seed (int): Seed forwarded to the wrapped env.
            bool_render (bool): If True, the env is rendered after every step.

        """
        self.bool_render = bool_render

        # A literal `{}` default would be one dict shared by every instance
        # created without an explicit config; build a new one per call.
        if capability_config is None:
            capability_config = {}

        smac_args = {
            "map_name": map_name,
            "state_last_action": state_last_action,
            "capability_config": capability_config,
            "seed": seed,
        }
        # str(None) would hand the literal string "None" to the wrapped env.
        # NOTE(review): assumes StarCraft2Env declares its own default for
        # `difficulty` -- confirm against the wrapped env's signature.
        if difficulty is not None:
            smac_args["difficulty"] = str(difficulty)

        self._env = StarCraft2Env(**smac_args)
        self._n_agents = self._env.n_agents

        # Set the observation and action spaces (one entry per agent)
        n_actions = self._env.get_total_actions()
        action_space = [Discrete(n_actions) for _ in range(self._n_agents)]
        state_space = Box(-1.0, 1.0, shape=(self._env.get_state_size(),))
        observation_space = [
            Box(-1.0, 1.0, shape=(self._env.get_obs_size(),))
            for _ in range(self._n_agents)
        ]
        # Space descriptor for the per-agent action masks. It is not read
        # anywhere else in this class; kept as a public attribute.
        self.action_mask = [
            Box(0, 1, shape=(n_actions,)) for _ in range(self._n_agents)
        ]

        # Set the horizon
        if horizon is None:
            horizon = self._env.episode_limit
        else:
            horizon = int(horizon)
            self._set_horizon(horizon)

        mdp_info = MDPInfo(
            state_space=state_space,
            observation_space=observation_space,
            action_space=action_space,
            discrete_actions=True,
            gamma=gamma,
            horizon=horizon,
            has_obs=True,
            has_action_masks=True,
            n_agents=self._n_agents,
        )

        super().__init__(mdp_info)

    def reset(self):
        """
        Resets the env and returns the initial step data.

        Returns:
            step (dict): Dictionary with the keys
                "state": state returned by the wrapped env's reset,
                "obs": observations returned by the wrapped env's reset,
                "action_masks": available-action mask of every agent.
        """

        obs, state = self._env.reset()
        return {
            "state": state,
            "obs": obs,
            "action_masks": self._get_action_masks(),
        }

    def step(self, actions):
        """Applies the joint action and returns the resulting step data.

        Arguments:
            actions (list): One indexable entry per agent; only the first
                element of each entry is forwarded to the wrapped env.

        Returns
        -------
            step (dict): Dictionary with the keys
                "state": State of the environment.
                "obs": New observations for each agent.
                "rewards": Reward for each agent (the shared team reward,
                    repeated once per observation).
                "absorbing": Whether the wrapped env reported termination.
                "action_masks": Mask of available actions for each agent for
                    the next step.
                "info": Info values of the wrapped env (e.g. battle_won,
                    dead_allies, dead_enemies, episode_limit).
        """

        # rew, absorbing, and info are single values that apply to all agents
        actions_single = [a[0] for a in actions]  # make actions 0-dimensional
        rew, absorbing, info = self._env.step(actions_single)
        state = self._env.get_state()
        obs = self._env.get_obs()
        action_masks = self._get_action_masks()
        rewards = [rew] * len(obs)

        if self.bool_render:
            self._env.render()

        return {
            "state": state,
            "obs": obs,
            "rewards": rewards,
            "absorbing": absorbing,
            "action_masks": action_masks,
            "info": info,
        }

    def stop(self):
        """Close the environment"""
        self._env.close()

    def render(self):
        """Render the wrapped environment and return whatever it returns."""
        return self._env.render()

    def _get_action_masks(self):
        """Query the wrapped env for the available-action mask of every agent."""
        return [
            self._env.get_avail_agent_actions(agent_id)
            for agent_id in range(self._n_agents)
        ]

    def _set_horizon(self, horizon):
        """Override the episode limit of the wrapped environment."""
        self._env.episode_limit = horizon
