import numpy as np
from amb.envs.smacv2.core.wrapper import StarCraftCapabilityEnvWrapper

import os.path as osp
from pathlib import Path
import yaml

from gym.spaces import Box, Discrete


class SMACv2Env:
    """Adapter exposing a ``StarCraftCapabilityEnvWrapper`` with per-agent lists.

    The underlying environment is NOT created in ``__init__``; it is built by
    :meth:`seed`, which also populates ``n_agents``, ``timeouts`` and the
    observation / shared-observation / action spaces. ``seed`` must therefore
    be called before ``reset``, ``step``, ``close``, ``save_replay``,
    ``get_env_info`` or ``get_stats``.
    """

    def __init__(self, args, host=True, ports=None,
                 multi_map_alignment=False, obs_align_v1=False):
        """Load the map configuration and remember the wrapper options.

        Args:
            args: Mapping that must contain the key ``"map_name"``; the value
                selects the ``<map_name>.yaml`` map configuration file.
            host: Forwarded unchanged to the wrapper as ``host``.
            ports: Forwarded unchanged to the wrapper as ``ports``.
            multi_map_alignment: Forwarded unchanged to the wrapper.
            obs_align_v1: Forwarded unchanged to the wrapper.
        """
        self.map_config = self.load_map_config(args["map_name"])
        self.host = host
        self.ports = ports
        self.multi_map_alignment = multi_map_alignment
        self.obs_align_v1 = obs_align_v1

    def step(self, actions):
        """Advance the environment by one step.

        Args:
            actions: Array-like whose axis 1 has length 1 (one action id per
                agent); that axis is squeezed away before the actions are
                handed to the underlying environment.

        Returns:
            A tuple ``(obs, state, rewards, dones, infos, avail_actions)``.
            ``state``, ``rewards``, ``dones`` and ``infos`` are lists with one
            entry per agent; every agent receives the same team reward, done
            flag and (shared) info dict.
        """
        # Flatten (n_agents, 1) -> plain list of action ids and pass THAT to
        # the environment (previously the squeezed list was computed but the
        # raw array was forwarded instead).
        processed_actions = np.squeeze(actions, axis=1).tolist()
        reward, terminated, info = self.env.step(processed_actions)
        obs = self.env.get_obs()
        state = self.repeat(self.env.get_state())
        # Independent inner lists so mutating one agent's reward entry cannot
        # silently change the others.
        rewards = [[reward] for _ in range(self.n_agents)]
        dones = [terminated] * self.n_agents
        if terminated:
            # The inner env's timeout counter grew during this step.
            # NOTE(review): presumably this means the episode ended because of
            # the time limit rather than a win/loss -- confirm against the
            # wrapped environment.
            if self.env.env.timeouts > self.timeouts:
                assert (
                    self.env.env.timeouts - self.timeouts == 1
                ), "Change of timeouts unexpected."
                info["bad_transition"] = True
                self.timeouts = self.env.env.timeouts
        # The same info dict object is shared by all agents.
        infos = [info] * self.n_agents
        avail_actions = self.env.get_avail_actions()
        return obs, state, rewards, dones, infos, avail_actions

    def reset(self):
        """Reset the underlying environment.

        Returns:
            A tuple ``(obs, state, avail_actions)`` where ``state`` is the
            environment state repeated once per agent.
        """
        self.env.reset()
        obs = self.env.get_obs()
        state = self.repeat(self.env.get_state())
        avail_actions = self.env.get_avail_actions()
        return obs, state, avail_actions

    def seed(self, seed):
        """Create the underlying environment with ``seed`` and build the spaces.

        A new ``StarCraftCapabilityEnvWrapper`` is constructed on every call
        and stored in ``self.env``.

        Args:
            seed: Seed forwarded to the wrapper constructor.
        """
        self.env = StarCraftCapabilityEnvWrapper(
            seed=seed,
            host=self.host,
            multi_map_alignment=self.multi_map_alignment,
            obs_align_v1=self.obs_align_v1,
            ports=self.ports,
            **self.map_config,
        )
        env_info = self.env.get_env_info()
        n_actions = env_info["n_actions"]
        state_shape = env_info["state_shape"]
        obs_shape = env_info["obs_shape"]
        self.n_agents = env_info["n_agents"]
        # Baseline used by ``step`` to detect an increase of the counter.
        self.timeouts = self.env.env.timeouts

        # One (identical) space object per agent.
        self.share_observation_space = self.repeat(
            Box(low=-np.inf, high=np.inf, shape=(state_shape,))
        )
        self.observation_space = self.repeat(
            Box(low=-np.inf, high=np.inf, shape=(obs_shape,))
        )
        self.action_space = self.repeat(Discrete(n_actions))

    def close(self):
        """Close the underlying environment."""
        self.env.close()

    def load_map_config(self, map_name):
        """Read ``configs/envs_cfgs/smacv2_map_config/<map_name>.yaml``.

        The ``configs`` directory is resolved two directory levels above the
        directory containing this file.

        Args:
            map_name: File stem of the YAML map configuration.

        Returns:
            The parsed YAML content.
        """
        base_path = osp.split(osp.split(osp.dirname(osp.abspath(__file__)))[0])[0]
        map_config_path = (
            Path(base_path)
            / "configs"
            / "envs_cfgs"
            / "smacv2_map_config"
            / f"{map_name}.yaml"
        )
        with open(str(map_config_path), "r", encoding="utf-8") as file:
            # NOTE(review): FullLoader can construct more than plain data
            # types; fine for config files shipped with the project, but
            # consider yaml.safe_load if map configs may come from untrusted
            # sources.
            map_config = yaml.load(file, Loader=yaml.FullLoader)
        return map_config

    def repeat(self, a):
        """Return a list holding ``a`` once per agent (same object, not copies)."""
        return [a for _ in range(self.n_agents)]

    def save_replay(self):
        """Ask the underlying environment to save a replay."""
        self.env.save_replay()

    def get_env_info(self):
        """Return the spaces, agent count and observation feature attributes.

        Returns:
            A 7-tuple ``(observation_space, share_observation_space,
            action_space, n_agents, obs_own_feat, obs_enemy_feat,
            obs_ally_feat)``; the last three are read from the wrapped
            environment.
        """
        return (
            self.observation_space,
            self.share_observation_space,
            self.action_space,
            self.n_agents,
            self.env.obs_own_feat,
            self.env.obs_enemy_feat,
            self.env.obs_ally_feat,
        )

    def get_stats(self):
        """Return the statistics reported by the underlying environment."""
        return self.env.get_stats()