import numpy as np


# Environment names that load_env routes to the envs.smacv1 wrapper.
SMACV1_ENV_NAMES = [
    "2c_vs_64zg",
    "5m_vs_6m",
    "6h_vs_8z",
    "corridor",
]

# Environment names that load_env routes to the envs.smacv2 wrapper: every
# "<prefix>_<suffix>" combination, with the prefix varying slowest.
SMACV2_ENV_NAMES = [
    prefix + "_" + suffix
    for prefix in ["protoss", "terran", "zerg"]
    for suffix in ["5_vs_5", "10_vs_10", "10_vs_11", "20_vs_20", "20_vs_23"]
]

# Environment names that load_env routes to the envs.mamujoco wrapper.
MAMUJOCO_ENV_NAMES = [
    "Hopper-v2",
    "Ant-v2",
    "HalfCheetah-v2",
]


def load_env(env_name, seed):
    """Construct the environment wrapper registered for ``env_name``.

    The wrapper class is chosen by looking ``env_name`` up in
    ``MAMUJOCO_ENV_NAMES``, ``SMACV1_ENV_NAMES`` and ``SMACV2_ENV_NAMES`` (in
    that order). The matching wrapper module is imported lazily, so only the
    backend that is actually requested gets imported.

    Args:
        env_name: name of the environment to create.
        seed: seed forwarded unchanged to the wrapper constructor.

    Returns:
        The wrapper instance built as ``EnvWrapper(env_name, seed)``.

    Raises:
        NotImplementedError: if ``env_name`` is in none of the name lists.
    """
    if env_name in MAMUJOCO_ENV_NAMES:
        from envs.mamujoco.env import MaMujocoWrapper as EnvWrapper
    elif env_name in SMACV1_ENV_NAMES:
        from envs.smacv1.env import SMACWrapper as EnvWrapper
    elif env_name in SMACV2_ENV_NAMES:
        from envs.smacv2.env import SMACWrapper as EnvWrapper
    else:
        # Name the offending value so a typo is diagnosable from the traceback.
        raise NotImplementedError(
            f"Unknown environment name {env_name!r}; expected one of the names in "
            "MAMUJOCO_ENV_NAMES, SMACV1_ENV_NAMES or SMACV2_ENV_NAMES"
        )
    return EnvWrapper(env_name, seed)


def worker(env_name, seed, remote, parent_remote):
    """Serve one environment over a pipe until told to stop.

    Intended as the target of a child process. It builds the environment with
    ``load_env(env_name, seed)`` and then answers ``(cmd, data)`` messages
    received on ``remote``:

    * ``'step'``           -> sends ``game.step(data)``
    * ``'reset'``          -> sends ``game.reset()``
    * ``'get_curr_state'`` -> sends ``game.get_current_states()``
    * ``'get_next_state'`` -> sends ``game.get_next_states()``
    * ``'get_env_info'``   -> sends ``game.get_env_info()``
    * ``'close'``          -> stops the loop

    Any other command prints a message and stops the loop. However the loop
    ends (close command, invalid command, the other end of the pipe being
    closed, or an exception), the wrapped environment and ``remote`` are
    closed before the function returns.

    Args:
        env_name: environment name passed to ``load_env``.
        seed: seed passed to ``load_env``.
        remote: this worker's end of the pipe.
        parent_remote: the parent's end of the pipe; closed immediately
            because this process does not use it.
    """
    from envs.utils import silence_stderr

    parent_remote.close()
    # NOTE(review): silence_stderr is a project helper; presumably it
    # suppresses stderr output for the duration of the block -- confirm.
    with silence_stderr():
        print(f"Worker started for environment: {env_name} with seed: {seed}")
        game = None
        try:
            game = load_env(env_name, seed)
            while True:
                try:
                    cmd, data = remote.recv()
                except EOFError:
                    # Nothing left to receive and the other end is closed:
                    # treat it like a close request instead of crashing.
                    break
                if cmd == 'step':
                    remote.send(game.step(data))
                elif cmd == 'reset':
                    remote.send(game.reset())
                elif cmd == 'get_curr_state':
                    remote.send(game.get_current_states())
                elif cmd == 'get_next_state':
                    remote.send(game.get_next_states())
                elif cmd == 'get_env_info':
                    remote.send(game.get_env_info())
                elif cmd == 'close':
                    break
                else:
                    print("Invalid command sent by remote")
                    break
        finally:
            # Release the environment first, then the pipe, on every exit path.
            try:
                if game is not None:
                    game.env.close()
            finally:
                remote.close()


class VectorizedEnv(object):
    """Drive ``n_envs`` environment copies, each in its own worker process.

    Every copy runs ``worker`` in a separate ``multiprocessing.Process`` and is
    controlled through a ``Pipe`` with ``(command, payload)`` messages. Most
    operations come as an ``*_async`` method (send the command to all
    workers), a matching ``*_wait`` method (collect one reply per worker) and
    a blocking convenience method that calls both.

    Attributes:
        waiting: True between an ``*_async`` call and its ``*_wait``.
        closed: True once ``close`` has completed.
        remotes: parent-side pipe ends, one per worker.
        work_remotes: worker-side pipe ends (closed in the parent after the
            workers have started).
        ps: the worker ``Process`` objects.
        ob_dim, st_dim, ac_dim, n_agents, n_enemies, nf_al, nf_en: the seven
            values returned by the first worker's ``get_env_info`` reply.
            NOTE(review): their meaning is defined by the environment
            wrappers, not by this class -- confirm there.
    """

    def __init__(self, env_name, n_envs, seed=0):
        """Start the worker processes and fetch the environment info.

        Args:
            env_name: environment name handed to every worker.
            n_envs: number of worker processes to start.
            seed: base seed; worker ``i`` is started with ``seed + i``.
        """
        from multiprocessing import Pipe, Process

        self.waiting = False
        self.closed = False
        self.env_name = env_name
        self.n_envs = n_envs
        self.seed = seed
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(n_envs)])
        self.ps = [
            Process(target=worker, args=(env_name, seed + rank, work_remote, remote))
            for rank, (work_remote, remote) in enumerate(zip(self.work_remotes, self.remotes))
        ]
        for proc in self.ps:
            # Daemon processes are terminated when the parent process exits.
            proc.daemon = True
            proc.start()
        # The workers hold their own ends now; the parent's copies are unused.
        for work_remote in self.work_remotes:
            work_remote.close()
        (self.ob_dim, self.st_dim, self.ac_dim, self.n_agents,
         self.n_enemies, self.nf_al, self.nf_en) = self.get_env_infos()

    def _broadcast(self, command):
        """Send ``(command, None)`` to every worker and mark replies pending."""
        for remote in self.remotes:
            remote.send((command, None))
        self.waiting = True

    def _gather(self):
        """Receive one reply from every worker, in worker order."""
        results = [remote.recv() for remote in self.remotes]
        self.waiting = False
        return results

    def _gather_stacked(self):
        """Receive ``(obs, states, avails)`` replies and stack each field."""
        obs, states, avails = zip(*self._gather())
        obs, states, avails = map(np.stack, (obs, states, avails))
        return obs, states, avails

    def step_async(self, actions):
        """Send one ``'step'`` command per worker.

        Args:
            actions: iterable whose i-th item is the payload for worker i.
                Items beyond the number of workers are ignored.

        Raises:
            ValueError: if fewer items than workers are supplied. Nothing is
                sent in that case; otherwise the workers left without a
                command would never reply and ``step_wait`` would block
                forever.
        """
        per_env_actions = list(actions)
        if len(per_env_actions) < len(self.remotes):
            raise ValueError(
                f"step_async received {len(per_env_actions)} action sets "
                f"for {len(self.remotes)} environments"
            )
        for remote, env_actions in zip(self.remotes, per_env_actions):
            remote.send(('step', env_actions))
        self.waiting = True

    def step_wait(self):
        """Collect the step replies.

        Each reply is unpacked as a 6-tuple
        ``(obs, states, avails, rewards, dones, infos)``.

        Returns:
            ``(obs, states, avails, rewards, dones, infos)`` where the first
            five are stacked across workers with ``np.stack`` and ``infos``
            is a tuple holding each worker's info item unchanged.
        """
        obs, states, avails, rewards, dones, infos = zip(*self._gather())
        obs, states, avails, rewards, dones = map(
            np.stack, (obs, states, avails, rewards, dones)
        )
        return obs, states, avails, rewards, dones, infos

    def step(self, actions):
        """Step every environment and return the collected results."""
        self.step_async(actions)
        return self.step_wait()

    def reset_async(self):
        """Send ``'reset'`` to every worker."""
        self._broadcast('reset')

    def reset_wait(self):
        """Return the list of raw reset replies, one per worker (not stacked)."""
        return self._gather()

    def reset(self):
        """Reset every environment and return the per-worker replies."""
        self.reset_async()
        return self.reset_wait()

    def curr_infos_async(self):
        """Send ``'get_env_info'`` to every worker."""
        self._broadcast('get_env_info')

    def curr_infos_wait(self):
        """Return the list of ``get_env_info`` replies, one per worker."""
        return self._gather()

    def curr_states_async(self):
        """Send ``'get_curr_state'`` to every worker."""
        self._broadcast('get_curr_state')

    def curr_states_wait(self):
        """Return ``(obs, states, avails)``, each stacked across workers."""
        return self._gather_stacked()

    def next_states_async(self):
        """Send ``'get_next_state'`` to every worker."""
        self._broadcast('get_next_state')

    def next_states_wait(self):
        """Return ``(obs, states, avails)``, each stacked across workers."""
        return self._gather_stacked()

    def get_current_states(self):
        """Blocking query of ``get_curr_state`` on every worker."""
        self.curr_states_async()
        return self.curr_states_wait()

    def get_next_states(self):
        """Blocking query of ``get_next_state`` on every worker."""
        self.next_states_async()
        return self.next_states_wait()

    def get_env_infos(self):
        """Ask only the first worker for its env info and return the reply."""
        self.remotes[0].send(('get_env_info', None))
        return self.remotes[0].recv()

    def close(self):
        """Shut down all workers and release the parent's pipe ends.

        Calling it again after it has completed is a no-op. Replies still
        pending from an ``*_async`` call are drained first. A worker that has
        already exited does not stop the remaining workers from being closed.
        """
        if self.closed:
            return
        if self.waiting:
            for remote in self.remotes:
                try:
                    remote.recv()
                except (EOFError, OSError):
                    # Worker is already gone; there is no reply to drain.
                    pass
            self.waiting = False
        for remote in self.remotes:
            try:
                remote.send(('close', None))
            except OSError:
                # Covers BrokenPipeError: the worker end is already closed.
                pass
        for proc in self.ps:
            proc.join()
        # Release the parent-side pipe ends instead of leaving them open.
        for remote in self.remotes:
            remote.close()
        self.closed = True
