import torch

from .habitat import construct_envs, construct_envs21


def make_vec_envs(args):
    """Create the vectorized environments and wrap them in ``VecPyTorch``.

    Args:
        args: Configuration object. It is passed unchanged to
            ``construct_envs21``, and its ``device`` attribute selects where
            observation tensors are placed.

    Returns:
        A ``VecPyTorch`` wrapping the environments built by ``construct_envs21``.
    """
    return VecPyTorch(construct_envs21(args), args.device)


# Adapted from
# https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail/blob/master/a2c_ppo_acktr/envs.py#L159
class VecPyTorch:
    """Adapter that exposes a vectorized environment through torch tensors.

    Observation batches returned by the wrapped ``venv`` are converted with
    ``torch.from_numpy`` to ``float32`` tensors on ``device``. Actions passed
    in as torch tensors are converted to numpy arrays before being forwarded.
    ``done`` flags and infos are passed through unchanged. Rewards returned by
    ``venv.step`` / ``venv.step_wait`` are discarded, so those methods return
    ``(obs, done, info)``.

    Args:
        venv: The vectorized environment to wrap. It must expose
            ``num_envs``, ``observation_space`` and ``action_space``.
        device: Target torch device for observation tensors.
    """

    def __init__(self, venv, device):
        self.venv = venv
        self.num_envs = venv.num_envs
        self.observation_space = venv.observation_space
        self.action_space = venv.action_space
        self.device = device

    def _obs_to_tensor(self, obs):
        """Convert a numpy observation batch to a float32 tensor on ``self.device``."""
        return torch.from_numpy(obs).float().to(self.device)

    @staticmethod
    def _actions_to_numpy(actions):
        """Convert an action tensor to a numpy array.

        ``detach()`` is needed because ``Tensor.numpy()`` refuses tensors
        that require grad (e.g. actions taken directly from a policy output).
        """
        return actions.detach().cpu().numpy()

    def reset(self):
        """Reset all environments.

        Returns:
            ``(obs, infos, original_RGBs)``:

            - ``obs`` is a float32 tensor on ``self.device``.
            - ``infos`` holds the infos of all environment processes; the
              original author notes it is a tuple.
            - ``original_RGBs`` is returned exactly as the wrapped env
              produced it, without conversion to a tensor. The original
              author notes it is a float32 ndarray.
        """
        obs, infos, original_RGBs = self.venv.reset()
        return self._obs_to_tensor(obs), infos, original_RGBs

    def step_async(self, actions):
        """Forward ``actions`` (a torch tensor) to ``venv.step_async`` as numpy."""
        self.venv.step_async(self._actions_to_numpy(actions))

    def step_wait(self):
        """Wait for a pending async step and return ``(obs, done, info)``.

        The reward produced by the wrapped env is dropped, matching ``step``.
        It is deliberately not converted, since the result would be unused.
        """
        obs, _reward, done, info = self.venv.step_wait()
        return self._obs_to_tensor(obs), done, info

    def step(self, actions):
        """Step all environments synchronously and return ``(obs, done, info)``.

        ``actions`` is a torch tensor; the reward from the wrapped env is
        dropped.
        """
        obs, _reward, done, info = self.venv.step(self._actions_to_numpy(actions))
        return self._obs_to_tensor(obs), done, info

    def plan_act_and_preprocess(self, inputs):
        """Forward ``inputs`` to ``venv.plan_act_and_preprocess``.

        Returns:
            ``(obs, done, infos, rgbd)``, where ``obs`` is converted to a
            float32 tensor on ``self.device`` and the other three values are
            passed through unchanged.
        """
        obs, done, infos, rgbd = self.venv.plan_act_and_preprocess(inputs)
        return self._obs_to_tensor(obs), done, infos, rgbd

    def close(self):
        """Close the wrapped environment and return its result."""
        return self.venv.close()

    def get_scenes_count(self):
        """Return ``venv.get_scenes_count()``."""
        return self.venv.get_scenes_count()
