import os
import gym
import numpy as np

from environments.unimal.config import cfg
from environments.unimal.envs.modules.agent import create_agent_xml
from environments.unimal.envs.tasks.escape_bowl import make_env_escape_bowl
from environments.unimal.envs.tasks.locomotion import make_env_locomotion
from environments.unimal.envs.tasks.obstacle import make_env_obstacle
from environments.unimal.envs.wrappers.select_keys import SelectKeysWrapper
from environments.unimal.utils import file as fu


class UnimalEnvWrapper(gym.Wrapper):
    """Expose a Unimal env through flat, per-actuator observations and actions.

    The wrapped env is expected to provide ``limb_obs_size``, ``model``,
    ``obs_padding_mask``, ``act_padding_mask`` and a dict observation with a
    ``'proprioceptive'`` entry. This wrapper:

    * turns the padded per-limb proprioceptive vector into one row per
      unmasked action slot, appends a 4-dim one-hot id to each row and
      returns the flattened, clipped result as ``float32``;
    * scatters the compact action vector back into the padded action layout
      before forwarding it to the wrapped env.

    Both masks are assumed to be boolean numpy arrays in which ``True`` marks
    padding -- TODO confirm against the wrapped env.
    """

    # Width of the one-hot id appended to every observation row.
    _JOINT_ID_SIZE = 4

    def __init__(self, env):
        """Cache sizes and masks from ``env`` and build the id table.

        Args:
            env: The env to wrap (see the class docstring for the attributes
                that are read from it).
        """
        super(UnimalEnvWrapper, self).__init__(env)
        self.limb_obs_size = self.env.limb_obs_size
        # The first body name is skipped -- presumably the MuJoCo world body;
        # TODO confirm.
        self.num_limbs = len(self.env.model.body_names[1:])

        self.joint_obs_size = self.limb_obs_size + self._JOINT_ID_SIZE  # joint id
        self.num_joints = self.metadata["num_joints"]

        self.num_agents = self.num_joints + 1  # with torso
        # TODO: add ID for x,y joints
        self.observation_space = gym.spaces.Box(
            -np.inf, np.inf, (self.joint_obs_size * self.num_agents,), np.float32
        )
        self.action_space = gym.spaces.Box(
            -1.0, 1.0, (self.num_agents,), np.float32
        )

        self.obs_padding_mask = self.env.obs_padding_mask.copy()  # [max_n_limb]
        self.act_padding_mask = self.env.act_padding_mask.copy()  # [max_n_limb * 2]
        # Un-pad slot 0 in the local copy only, so the first limb (the torso,
        # per the "with torso" note above) gets its own observation row and
        # action entry. The wrapped env keeps its own, unmodified mask.
        self.act_padding_mask[0] = False

        # Each limb owns two consecutive action slots (presumably its x and y
        # joints -- TODO confirm). One id row is emitted per active slot:
        #   only first slot active  -> [1, 0, 0, 0]
        #   only second slot active -> [0, 1, 0, 0]
        #   both slots active       -> [0, 0, 1, 0] then [0, 0, 0, 1]
        # Rows are appended in slot order, so they line up with the rows that
        # `_obs_out` selects with the same mask.
        obs_id = []
        first_slots = self.act_padding_mask[0::2]
        second_slots = self.act_padding_mask[1::2]
        for first_padded, second_padded in zip(first_slots, second_slots):
            has_first = not first_padded
            has_second = not second_padded
            if has_first and has_second:
                obs_id.append([0., 0., 1., 0.])
                obs_id.append([0., 0., 0., 1.])
            elif has_first:
                obs_id.append([1., 0., 0., 0.])
            elif has_second:
                obs_id.append([0., 1., 0., 0.])
        self.obs_id = np.array(obs_id)

    def step(self, action):
        """Scatter ``action`` into the padded layout and step the wrapped env.

        Args:
            action: One value per unmasked slot of ``self.act_padding_mask``;
                numpy raises if the length does not match.

        Returns:
            ``(obs, reward, done, info)`` from the wrapped env, with ``obs``
            converted by ``_obs_out``.
        """
        # Padded slots stay at zero.
        env_action = np.zeros(len(self.act_padding_mask))
        env_action[~self.act_padding_mask] = action
        obs, reward, done, info = self.env.step(env_action)
        return self._obs_out(obs), reward, done, info

    def reset(self):
        """Reset the wrapped env and return the converted first observation."""
        return self._obs_out(self.env.reset())

    def _obs_out(self, obs):
        """Convert the wrapped env's dict observation to the flat layout.

        Args:
            obs: Dict observation containing a 1-D ``'proprioceptive'`` array.

        Returns:
            1-D ``float32`` array: for every unmasked action slot, the owning
            limb's features followed by that slot's one-hot id, clipped to
            ``[-5, 5]``.
        """
        used = self.limb_obs_size * self.num_limbs
        per_limb = obs['proprioceptive'][:used].reshape(
            self.num_limbs, self.limb_obs_size
        )
        # Duplicate each limb row so there is one row per action slot, then
        # keep only the rows whose slot is not padding.
        per_slot = per_limb.repeat(2, axis=0)
        active = per_slot[~self.act_padding_mask[:self.num_limbs * 2]]
        flat = np.concatenate([active, self.obs_id], axis=-1).ravel()
        # Cast to the dtype declared in `observation_space`; without it the
        # float64 id table promotes the result to float64, which the float32
        # Box does not accept as a member.
        return np.clip(flat, -5, 5).astype(np.float32)

def make_env(xml_path, env_name):
    """Create the env for the task named by ``cfg.ENV.TASK``, fully wrapped.

    Args:
        xml_path: Passed to ``create_agent_xml`` and on to the task factory.
        env_name: Passed on to the task factory.

    Returns:
        The task env wrapped in ``SelectKeysWrapper`` and then
        ``UnimalEnvWrapper``.

    Raises:
        KeyError: If this module has no global named ``make_env_<TASK>``.
    """
    agent_xml = create_agent_xml(xml_path)

    # The task factories are module-level names, so the one matching the
    # configured task is resolved through this module's globals.
    factory_name = "make_env_{task}".format(task=cfg.ENV.TASK)
    task_factory = globals()[factory_name]
    wrapped = task_factory(agent_xml, env_name, xml_path)

    # Add common wrappers in the end
    retained_keys = cfg.ENV.KEYS_TO_KEEP + cfg.MODEL.OBS_TYPES
    wrapped = SelectKeysWrapper(wrapped, keys_to_keep=retained_keys)
    return UnimalEnvWrapper(wrapped)  # TODO: test

