# Fixed-horizon Gym wrapper for robosuite (MuJoCo-based) environments
import gym
import numpy as np
import robosuite as suite
from robosuite import load_controller_config
from gym import spaces

class RobosuiteFH(gym.Env):
    """Gym-style wrapper around a robosuite environment driven by a Panda robot.

    The wrapper builds the robosuite environment with an ``OSC_POSE``
    controller, flattens the dictionary observation into a single vector
    (proprioceptive state followed by object state), optionally normalizes
    it with a fixed mean/std, and optionally replaces the environment reward
    with a user-supplied reward function.

    Attributes:
        env: The wrapped robosuite environment.
        T: Episode horizon used for the ``done`` flag once the episode has
            terminated (see ``step``).
        r: Optional callable mapping the (normalized) observation seen
            *before* the action to a reward; overrides the env reward.
        obs_mean, obs_std: Optional normalization statistics; either both
            are given or both are ``None``.
        t: Number of ``step`` calls since the last ``reset``.
        terminated: Whether the wrapped env has reported ``done``.
        terminal_state: Observation at the moment the env reported ``done``.
    """

    # Entries of the robosuite observation dict that make up the flat
    # observation, concatenated in this order.
    _OBS_KEYS = ("robot0_proprio-state", "object-state")

    def __init__(self, env_name, T=500, r=None, obs_mean=None, obs_std=None, seed=1):
        """Create the wrapped environment and derive the Gym spaces.

        Args:
            env_name: Name of the robosuite task passed to ``suite.make``.
            T: Episode horizon stored on the wrapper.
            r: Optional reward callable applied to the previous observation.
            obs_mean: Optional mean subtracted from every observation.
            obs_std: Optional scale every observation is divided by.
            seed: Accepted for interface compatibility; currently unused.

        Raises:
            AssertionError: If exactly one of ``obs_mean`` / ``obs_std`` is
                provided.
        """
        # Normalization needs both statistics; check before building the
        # (comparatively expensive) simulator.
        assert (obs_mean is None) == (obs_std is None), (
            "obs_mean and obs_std must be provided together"
        )

        controller = load_controller_config(default_controller="OSC_POSE")
        # These arguments are the same for all envs.
        config = {
            "controller_configs": controller,
            "horizon": 500,
            "control_freq": 20,
            "reward_shaping": True,
            "reward_scale": 1.0,
            "use_camera_obs": False,
            "ignore_done": True,
            "hard_reset": False,
            # Disabled to speed up training; a renderer is needed when
            # visualizing rollouts.
            "has_offscreen_renderer": False,
        }
        self.env = suite.make(env_name=env_name, robots="Panda", **config)

        self.T = T
        self.r = r
        self.obs_mean, self.obs_std = obs_mean, obs_std

        # Reset once so the observation dimensionality can be read off an
        # actual flattened observation.
        first_obs = self.reset()
        # NOTE(review): spaces are declared as float16 while the flattened
        # observations are presumably wider floats — confirm this is intended.
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(first_obs.shape[0],), dtype=np.float16
        )
        self.action_space = spaces.Box(
            low=-1, high=1, shape=(self.env.action_dim,), dtype=np.float16
        )

    def seed(self, seed):
        """No-op: the wrapped environment is not seeded by this wrapper."""
        pass

    def _flatten_obs(self, obs_dict):
        """Concatenate the selected observation entries and normalize them."""
        flat = np.concatenate([obs_dict[key] for key in self._OBS_KEYS], axis=0)
        return self.normalize_obs(flat)

    def reset(self):
        """Reset the wrapped env and episode bookkeeping.

        Returns:
            A copy of the flattened, normalized initial observation.
        """
        self.t = 0
        self.terminated = False
        self.terminal_state = None

        self.obs = self._flatten_obs(self.env.reset())
        return self.obs.copy()

    def step(self, action):
        """Advance the environment by one action.

        Args:
            action: Action forwarded unchanged to the wrapped environment.

        Returns:
            A 4-tuple ``(obs, reward, done, terminated)``. Note that the last
            element is a boolean, not a Gym ``info`` dict.

            * Before the wrapped env reports ``done``: the new observation,
              the reward (from ``self.r`` on the previous observation when
              set, otherwise the env reward), and the env's ``done`` flag
              twice.
            * After the wrapped env has reported ``done``: the stored terminal
              observation, reward ``0``, ``self.t == self.T`` and ``True``;
              the wrapped env is no longer stepped.
        """
        self.t += 1

        if self.terminated:
            # Absorbing terminal state: keep returning it with zero reward.
            return self.terminal_state.copy(), 0, self.t == self.T, True

        # The learned reward is evaluated on the observation the action was
        # taken from, so capture it before stepping.
        prev_obs = self.obs.copy()
        raw_obs, reward, done, _ = self.env.step(action)
        self.obs = self._flatten_obs(raw_obs)

        if self.r is not None:  # from irl model
            reward = self.r(prev_obs)

        if done:
            self.terminated = True
            # Copy so later use of self.obs cannot alias the stored state.
            self.terminal_state = self.obs.copy()

        # NOTE(review): the horizon T is not enforced on this path, and the
        # env is built with ignore_done=True, so `done` presumably stays
        # False — callers appear to be responsible for stopping at T; confirm.
        return self.obs.copy(), reward, done, done

    def normalize_obs(self, obs):
        """Return ``(obs - obs_mean) / obs_std``, or ``obs`` unchanged when no
        normalization statistics were provided."""
        if self.obs_mean is not None:
            obs = (obs - self.obs_mean) / self.obs_std
        return obs
