import numpy as np

from gym.vector import SyncVectorEnv as SyncVectorEnv_
from gym.vector.utils import concatenate, create_empty_array


class SyncVectorEnv(SyncVectorEnv_):
    """Synchronous vectorized environment for meta-learning.

    Extends gym's ``SyncVectorEnv`` with:

    - task switching via ``reset_task``;
    - done tracking across steps: ``step_wait`` skips environments whose
      done flag is set;
    - helpers that read and overwrite simulator state through ``env.sim``.

    Every sub-environment's ``unwrapped`` env must implement ``reset_task``.
    """

    def __init__(self,
                 env_fns,
                 observation_space=None,
                 action_space=None,
                 **kwargs):
        """Build the vector env and check that every env supports tasks.

        Args:
            env_fns: Environment constructors, forwarded unchanged to gym's
                ``SyncVectorEnv``.
            observation_space: Optional observation space, forwarded.
            action_space: Optional action space, forwarded.
            **kwargs: Extra keyword arguments, forwarded.

        Raises:
            ValueError: If any sub-environment's unwrapped env has no
                ``reset_task`` attribute.
        """
        super(SyncVectorEnv, self).__init__(env_fns,
                                            observation_space=observation_space,
                                            action_space=action_space,
                                            **kwargs)
        for env in self.envs:
            if not hasattr(env.unwrapped, 'reset_task'):
                raise ValueError('The environment provided is not a '
                                 'meta-learning environment. It does not have '
                                 'the method `reset_task` implemented.')

    @property
    def dones(self):
        """Per-environment done flags (the live ``_dones`` object, not a copy)."""
        return self._dones

    def reset_task(self, task):
        """Call ``reset_task(task)`` on every sub-environment's unwrapped env."""
        for env in self.envs:
            env.unwrapped.reset_task(task)

    def step_wait(self):
        """Step every environment that is not done with the pending actions.

        ``self._actions`` (presumably set by the base class's ``step_async`` --
        confirm) must hold exactly one action per environment whose done flag
        is currently False. Actions are matched to those environments in
        index order.

        Returns:
            A tuple ``(observations, rewards, dones, info)``:

            - ``observations``: batched observations (built with
              ``create_empty_array``/``concatenate``) for the environments
              that are still not done after this step, or ``None`` if every
              stepped environment finished.
            - ``rewards``: float64 array with one reward per action.
            - ``dones``: a copy of the per-environment done flags.
            - ``info``: dict with ``'batch_ids'`` (indices of the stepped
              environments, in action order) and ``'infos'`` (the ``info``
              values of the environments still not done, aligned with
              ``observations``).

        Raises:
            ValueError: If the number of pending actions differs from the
                number of environments that are not done. This is checked
                before any environment is stepped.
        """
        # Stepping env i only changes _dones[i], so the set of envs to step
        # can be fixed up front and validated before mutating anything.
        active_ids = [i for i in range(len(self.envs)) if not self._dones[i]]
        num_actions = len(self._actions)
        if num_actions != len(active_ids):
            raise ValueError(
                'Expected one action per environment that is not done: got '
                '{0} actions for {1} active environments.'.format(
                    num_actions, len(active_ids)))

        # `np.float_` was removed in NumPy 2.0; `np.float64` is the same dtype.
        rewards = np.zeros((num_actions,), dtype=np.float64)
        observations_list, infos = [], []
        for j, i in enumerate(active_ids):
            observation, rewards[j], self._dones[i], info = \
                self.envs[i].step(self._actions[j])
            # Envs that just finished contribute a reward but no
            # observation/info, which keeps `infos` aligned with
            # `observations`.
            if not self._dones[i]:
                observations_list.append(observation)
                infos.append(info)

        if observations_list:
            observations = create_empty_array(self.single_observation_space,
                                              n=len(observations_list),
                                              fn=np.zeros)
            concatenate(observations_list,
                        observations,
                        self.single_observation_space)
        else:
            observations = None

        return (observations, rewards, np.copy(self._dones),
                {'batch_ids': active_ids, 'infos': infos})

    def get_qpos_qvel(self):
        """Return the current ``qpos`` and ``qvel`` of every sub-environment.

        Reads ``env.sim.data.qpos`` and ``env.sim.data.qvel`` (presumably a
        MuJoCo simulator -- confirm). The vector lengths come from the first
        environment, so all environments must share them.

        Returns:
            ``(qpos_all, qvel_all)``: float64 arrays of shape
            ``(num_envs, qpos_length)`` and ``(num_envs, qvel_length)``.
        """
        qpos_length, qvel_length = self.get_qpos_qvel_length()
        qpos_all = np.zeros((len(self.envs), qpos_length))
        qvel_all = np.zeros((len(self.envs), qvel_length))
        for i, env in enumerate(self.envs):
            qpos_all[i] = env.sim.data.qpos
            qvel_all[i] = env.sim.data.qvel
        return qpos_all, qvel_all

    def get_qpos_qvel_length(self):
        """Return ``(len(qpos), len(qvel))`` of the first sub-environment."""
        qpos_length = len(self.envs[0].sim.data.qpos)
        qvel_length = len(self.envs[0].sim.data.qvel)
        return qpos_length, qvel_length

    def get_observations_from_qpos_qel(self, qposs, qvels):
        """Set each sub-environment's simulator state and return observations.

        For every environment ``i``:

        - the simulator state is overwritten with ``qposs[i]``/``qvels[i]``;
        - ``sim.forward()`` is run;
        - ``env.unwrapped._get_obs()`` is collected.

        This permanently changes the environments' simulator state.

        Args:
            qposs: Indexable of position vectors, one per environment.
            qvels: Indexable of velocity vectors, one per environment.

        Returns:
            float32 array of shape ``(num_envs, observation_width)``.

        Raises:
            ValueError: If the observations have different lengths. Also
                raised, for the ``'HalfCheetahVel'`` and ``'AntVel'``
                environments, if the observation length differs from the
                width expected from ``qposs``/``qvels``.
        """
        state_width = len(qposs[0]) + len(qvels[0])
        # Widths expected for envs whose `_get_obs` is built from qpos/qvel
        # only. HalfCheetahVel presumably omits one qpos entry (e.g. the root
        # x-coordinate) -- confirm against the environment. Other envs (or
        # envs without a `name`) previously crashed with UnboundLocalError;
        # their width is now taken from the observations themselves.
        expected_widths = {
            'HalfCheetahVel': state_width - 1,
            'AntVel': state_width,
        }
        env_name = getattr(self.envs[0], 'name', None)
        expected_width = expected_widths.get(env_name)

        observations = []
        for i, env in enumerate(self.envs):
            state = env.sim.get_state()
            state.qpos[:] = qposs[i]
            state.qvel[:] = qvels[i]
            env.sim.set_state(state)
            env.sim.forward()
            observations.append(
                np.asarray(env.unwrapped._get_obs(),
                           dtype=np.float64).reshape(-1))

        # np.stack raises ValueError when the observation lengths differ.
        observations = np.stack(observations)
        if (expected_width is not None
                and observations.shape[1] != expected_width):
            raise ValueError(
                'Observation width {0} does not match the expected width {1} '
                'for environment {2!r}.'.format(
                    observations.shape[1], expected_width, env_name))
        return observations.astype(np.float32)

    def notdone(self):
        """Clear every done flag so the next ``step_wait`` steps all envs."""
        for i in range(len(self.envs)):
            self._dones[i] = False