import random
import numpy as np
import stopit
from envs.utils import silence_stderr


class Config:
    """Evaluation settings for one SMACv2 map plus its parsed env arguments.

    Attributes:
        n_eval_rollout_threads: Class-level constant, fixed to 1.
        env_name: Class-level constant, fixed to ``"StarCraft2"``.
        map_name: Map identifier given to the constructor.
        seed: Seed given to the constructor (stored only; not applied here).
        config: ``env_args`` mapping returned by :meth:`read_smac_config`.
    """

    n_eval_rollout_threads = 1
    env_name = "StarCraft2"

    def __init__(self, map_name, seed):
        """Store ``map_name`` and ``seed`` and load the matching env config."""
        self.map_name = map_name
        self.seed = seed
        self.config = self.read_smac_config(map_name)

    def read_smac_config(self, map_name):
        """Load the ``env_args`` section of the YAML config for ``map_name``.

        ``map_name`` is lower-cased and must look like ``<race>_<n>_<sep>_<m>``
        (for example ``"protoss_5_vs_5"``): ``race`` selects the YAML file,
        ``n`` overrides ``capability_config["n_units"]`` and ``m`` overrides
        ``capability_config["n_enemies"]``. The middle token is ignored.

        Args:
            map_name: Map identifier in the format described above.

        Returns:
            The ``env_args`` mapping from the YAML file with the two unit
            counts overridden.

        Raises:
            ValueError: If the race is not protoss/terran/zerg, the suffix
                does not have exactly three ``_``-separated parts, or a unit
                count is not an integer.
            OSError: If the YAML file cannot be opened (the path is relative
                to the current working directory).
        """
        map_type, _, params = map_name.lower().partition("_")
        if map_type not in ("protoss", "terran", "zerg"):
            # The original bare ``raise`` had no active exception and surfaced
            # as an unrelated RuntimeError; report the real problem instead.
            raise ValueError(
                f"Unknown race '{map_type}' in map name '{map_name}'; "
                "expected one of: protoss, terran, zerg"
            )

        parts = params.split("_")
        if len(parts) != 3:
            raise ValueError(
                f"Malformed map name '{map_name}'; "
                "expected '<race>_<n_agents>_vs_<n_enemies>'"
            )
        # Validate the counts before touching the filesystem.
        n_agents, n_enemy = int(parts[0]), int(parts[2])

        import yaml
        with open(f"envs/smacv2/configs/sc2_gen_{map_type}.yaml", "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)["env_args"]
        config["capability_config"]["n_units"] = n_agents
        config["capability_config"]["n_enemies"] = n_enemy
        return config


class SMACWrapper:
    """Single-environment wrapper around the SMACv2 capability environment.

    The wrapper keeps the current and next (obs, state, available-actions)
    triples, accumulates the episode reward, records the episode trajectory,
    and automatically resets the environment when an episode finishes.
    """

    def __init__(self, env_name, seed=0):
        """Create the environment for ``env_name`` and seed it with ``seed``."""
        # NumPy 1.24 removed the ``np.bool`` alias; restoring it here is
        # presumably for SMAC code that still references it — TODO confirm.
        np.bool = bool
        self.init(env_name)
        self.set_seed(seed)

    def init(self, env_name):
        """(Re)create the underlying environment and cache its dimensions.

        Any previously created environment is closed first.
        """
        self.close()
        from smacv2.env.starcraft2.wrapper import StarCraftCapabilityEnvWrapper as StarCraft2Env
        with silence_stderr():
            self.env = StarCraft2Env(**self.read_smac_config(env_name))
            self.env_info = self.env.get_env_info()
        self.st_dim = self.env_info["state_shape"]
        self.ob_dim = self.env_info["obs_shape"]
        self.ac_dim = self.env_info["n_actions"]
        self.n_agents = self.env.env.n_agents
        self.n_enemies = self.env.env.n_enemies
        self.max_len = self.env_info["episode_limit"]
        self.env_name = env_name

        # Per-unit feature counts reported by the inner environment.
        self.nf_al = self.env.env.get_ally_num_attributes()
        self.nf_en = self.env.env.get_enemy_num_attributes()

    def get_env_info(self):
        """Return ``(ob_dim, st_dim, ac_dim, n_agents, n_enemies, nf_al, nf_en)``."""
        return self.ob_dim, self.st_dim, self.ac_dim, self.n_agents, self.n_enemies, self.nf_al, self.nf_en

    def read_smac_config(self, map_name):
        """Load the ``env_args`` section of the YAML config for ``map_name``.

        ``map_name`` is lower-cased and must look like ``<race>_<n>_<sep>_<m>``
        (for example ``"protoss_5_vs_5"``): ``race`` selects the YAML file,
        ``n`` overrides ``capability_config["n_units"]`` and ``m`` overrides
        ``capability_config["n_enemies"]``. The middle token is ignored.

        Raises:
            ValueError: If the race is not protoss/terran/zerg, the suffix
                does not have exactly three ``_``-separated parts, or a unit
                count is not an integer.
            OSError: If the YAML file cannot be opened (the path is relative
                to the current working directory).
        """
        map_type, _, params = map_name.lower().partition("_")
        if map_type not in ("protoss", "terran", "zerg"):
            # The original bare ``raise`` had no active exception and surfaced
            # as an unrelated RuntimeError; report the real problem instead.
            raise ValueError(
                f"Unknown race '{map_type}' in map name '{map_name}'; "
                "expected one of: protoss, terran, zerg"
            )

        parts = params.split("_")
        if len(parts) != 3:
            raise ValueError(
                f"Malformed map name '{map_name}'; "
                "expected '<race>_<n_agents>_vs_<n_enemies>'"
            )
        # Validate the counts before touching the filesystem.
        n_agents, n_enemy = int(parts[0]), int(parts[2])

        import yaml
        with open(f"envs/smacv2/configs/sc2_gen_{map_type}.yaml", "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)["env_args"]
        config["capability_config"]["n_units"] = n_agents
        config["capability_config"]["n_enemies"] = n_enemy
        return config

    def set_seed(self, seed):
        """Seed the inner env, ``random`` and ``numpy.random``; remember ``seed``."""
        self.env.env._seed = seed
        random.seed(seed)
        np.random.seed(seed)
        self.seed = seed

    @stopit.threading_timeoutable()
    def _reset(self):
        """Reset the environment once and return ``(obs, state, avails)``.

        Returns ``None`` when the reset raises. The decorator adds the
        ``timeout`` keyword used by :meth:`reset`; per the stopit docs a
        timed-out call yields the decorator's default (``None``).
        """
        try:
            with silence_stderr():
                self.env.reset()
            # Advance the seed so that successive episodes differ.
            self.set_seed(self.seed + 1)
            obs = self.env.get_obs()
            state = self.env.get_state()
            avails = self.env.get_avail_actions()
            return obs, state, avails
        except Exception as e:
            print(f"Error resetting environment {self.env_name}: {e}")
            return None

    def reset(self):
        """Start a new episode, re-creating the environment until it succeeds.

        Clears the accumulated reward and trajectory, then stores the initial
        ``obs``/``state``/``avails`` on the instance. If a reset attempt fails
        or times out (100 s), the environment is rebuilt with :meth:`init`
        and the reset is retried.
        """
        self.total_reward = 0.0
        self.trajectory = []
        # Loop instead of recursing so repeated failures cannot exhaust the
        # stack, and catch only Exception so Ctrl-C / SystemExit propagate.
        while True:
            try:
                result = self._reset(timeout=100)
            except Exception as e:
                print(f"Error resetting environment {self.env_name}: {e}")
                result = None
            if result is not None:
                self.obs, self.state, self.avails = result
                return
            # NOTE(review): the re-created env is not re-seeded before its
            # first reset — confirm this is intended.
            self.init(self.env_name)

    def close(self):
        """Close the underlying environment if one exists (best effort)."""
        env = getattr(self, "env", None)
        if env is None:
            # Nothing to close yet, e.g. the first ``init`` call.
            return
        try:
            with silence_stderr():
                env.close()
        except Exception as e:
            # Closing is best effort: report the failure but never raise.
            print(f"Error closing environment: {e}")

    def get_current_states(self):
        """Return the current ``(obs, state, avails)`` triple."""
        return self.obs, self.state, self.avails

    def get_next_states(self):
        """Return the ``(next_obs, next_state, next_avails)`` from the last step."""
        return self.next_obs, self.next_state, self.next_avails

    def step(self, actions):
        """Apply ``actions`` and advance the environment by one step.

        The per-step reward is accumulated internally; the reward returned
        (and recorded in the trajectory) is ``0.0`` until the episode ends,
        at which point it is the accumulated episode reward.

        When the episode ends, ``myinfo`` contains ``dead_allies`` and
        ``dead_enemies`` (as fractions of the unit counts), ``go_count``,
        ``won`` and ``trajectory``, and the environment is reset
        automatically, so :meth:`get_current_states` then refers to the new
        episode.

        Returns:
            ``(next_obs, next_state, next_avails, reward, done, myinfo)``,
            where the ``next_*`` values belong to the episode that was just
            stepped, even if a reset happened.
        """
        with silence_stderr():
            reward, done, info = self.env.step(actions)
        self.next_obs = self.env.get_obs()
        self.next_state = self.env.get_state()
        self.next_avails = self.env.get_avail_actions()

        self.total_reward += reward
        # Sparse reward: emit the whole episode return on the final step only.
        reward = 0.0 if not done else self.total_reward
        self.trajectory.append((self.obs, self.state, self.avails, actions, reward, done))

        myinfo = {}
        if done:
            myinfo["dead_allies"] = info.get("dead_allies", 0) / self.n_agents
            myinfo["dead_enemies"] = info.get("dead_enemies", 0) / self.n_enemies
            myinfo["go_count"] = self.env._episode_steps
            myinfo["won"] = info.get("battle_won", False)

            # Terminal entry: final observation with no action/reward/done.
            self.trajectory.append((self.next_obs, self.next_state, self.next_avails, None, None, None))
            myinfo["trajectory"] = self.trajectory

            # reset() rebinds self.trajectory to a new list, so the list
            # stored in myinfo is left untouched.
            self.reset()
        else:
            self.obs, self.state, self.avails = self.next_obs, self.next_state, self.next_avails

        return self.next_obs, self.next_state, self.next_avails, reward, done, myinfo