from selectors import PollSelector
from tokenize import PseudoExtras
import numpy as np

from .peginhole_env import *
from gym import spaces

class SinglePeginHole(PeginHole):
    """Single-task peg-in-hole environment with a gym-style interface.

    Wraps ``PeginHole`` so that ``reset`` and ``step`` return one numpy
    observation chosen by ``observation_mode`` (see ``_flatten_obs``) and so
    that the reward can be replaced by one of the sparse variants.
    """

    def __init__(self, observation_mode='one', sparse=False, sparse3=False, depth=0.03, friction=1, action_mode='free', robots=None, img_size=100, stop_after_success=False, **kwargs):
        """Build the environment.

        Args:
            observation_mode: Mode string handed to ``_flatten_obs`` by
                ``reset``, ``step`` and ``observation_space``. Any value that
                is not a recognised mode yields the concatenation of every
                entry of the observation dict.
            sparse: When true, ``step`` replaces the reward with
                ``self.sparse_reward()``.
            sparse3: When true, ``step`` replaces the reward with
                ``self.sparse_reward3(depth=self.depth)``. Applied after
                ``sparse``, so it wins when both flags are set.
            depth: Value forwarded to ``sparse_reward3``.
            friction: Value written to ``self.friction[0]``.
            action_mode: ``'fixed_ori'`` exposes a 3-dimensional action that
                ``step`` pads with three zeros; any other value uses the
                parent's ``action_spec`` unchanged.
            robots: Robot specification forwarded to the parent constructor.
                ``None`` (the default) means ``["Panda"]``.
            img_size: Width and height used when rendering the ``'pixel'``
                observation.
            stop_after_success: Stored on the instance; not read anywhere in
                this class.
            **kwargs: Forwarded to the parent constructor.
        """
        # A fresh list per instance instead of a shared mutable default.
        if robots is None:
            robots = ["Panda"]
        super().__init__(robots, gripper_types=None, **kwargs)
        self._goal = self.peg_class
        self._max_episode_steps = 200
        self.friction[0] = friction
        self.observation_mode = observation_mode
        self.action_mode = action_mode
        self.sparse = sparse
        self.sparse3 = sparse3
        self.depth = depth
        self.stop_after_success = stop_after_success
        # Fallback joint torques used by the 'standerd' observation while the
        # robot has not reported any torques yet.
        self.torques_origin = np.array([-5.71109586e-05, -1.11633171e+01, -1.87585357e-04, 1.21644823e+01, -1.53880417e-04, 1.12848703e+00, -2.10842555e-05])
        # NOTE(review): when ``norm`` is enabled, ``_flatten_obs`` reads
        # ``self.mu`` and ``self.std``, which are not defined in this class.
        self.norm = False
        self.img_size = img_size
        if self.observation_mode == 'pixel':
            self.camera_id = 5
        self.pre_state = None
        self.ob_dict = None

    def _pose_relative_to_hole(self):
        """Return the 6-D robot pose with its first three entries made relative to the hole.

        Works on a copy so the array handed out by ``get_robot_pose_6d`` is
        never modified in place.
        """
        pose = np.array(self.get_robot_pose_6d())
        pose[0:3] = pose[0:3] - self.ob_dict['hole_pos']
        return pose

    def get_eef_info(self):
        """Return hole-relative 6-D pose (presumably Euler angles), end-effector velocity and force/torque."""
        robot = self.robots[0]
        vel = robot.recent_ee_vel.current
        forcetorque = robot.recent_ee_forcetorques.current
        return np.concatenate((self._pose_relative_to_hole(), vel, forcetorque))

    def get_eef_info2(self):
        """Return hole-minus-end-effector position, end-effector quaternion, velocity and force/torque.

        Note that the position offset here is ``hole - eef``, the opposite
        sign of the offset used by the 6-D pose helpers.
        """
        pose = np.concatenate((self.ob_dict['hole_pos'] - self.ob_dict['robot0_eef_pos'], self.ob_dict['robot0_eef_quat']))
        robot = self.robots[0]
        vel = robot.recent_ee_vel.current
        forcetorque = robot.recent_ee_forcetorques.current
        return np.concatenate((pose, vel, forcetorque))

    def get_eef_info3(self):
        """Return hole-relative 6-D pose and end-effector velocity (no force/torque)."""
        vel = self.robots[0].recent_ee_vel.current
        return np.concatenate((self._pose_relative_to_hole(), vel))

    def without_torque(self):
        """Return the first 21 proprioceptive entries, hole-relative 6-D pose and end-effector velocity."""
        joint = self.ob_dict['robot0_proprio-state'][0:21]
        vel = self.robots[0].recent_ee_vel.current
        return np.concatenate((joint, self._pose_relative_to_hole(), vel))

    def without_torque2(self):
        """Return recent joint positions, proprio entries 14..20, hole-relative 6-D pose and end-effector velocity."""
        joint_pos = self.robots[0].recent_qpos.current
        # presumably the joint velocities — verify against the proprio-state layout
        joint_vel = self.ob_dict['robot0_proprio-state'][14:21]
        vel = self.robots[0].recent_ee_vel.current
        return np.concatenate((joint_pos, joint_vel, self._pose_relative_to_hole(), vel))

    def _flatten_obs(self, ob_dict, mode='object-state', blind=True):
        """Convert an observation dict into the array for ``mode``.

        Args:
            ob_dict: Observation dict. Only the ``'standerd'`` mode and the
                fallback read it; every other mode reads ``self.ob_dict``,
                which callers must have set beforehand.
            mode: One of ``'standerd'`` (``'standard'`` is accepted as an
                alias), ``'eef'``, ``'eef2'``, ``'eef3'``, ``'pixel'``,
                ``'without_torque'``, ``'without_torque2'``. Anything else
                concatenates every flattened entry of ``ob_dict``.
            blind: Only used by ``'standerd'``. When true, the end-effector
                offset is taken from the fixed point ``[-0.1, 0, 0.8]``
                instead of ``ob_dict['hole_pos']``.

        Returns:
            A numpy array; for ``'pixel'`` the rendered image with its last
            axis moved to the front, otherwise a 1-D vector.
        """
        if mode in ('standerd', 'standard'):  # 28+7
            robot = self.robots[0]
            if robot.torques is None:
                # Hand out a copy so later in-place changes to the robot's
                # torques cannot corrupt the stored fallback.
                robot.torques = self.torques_origin.copy()
            joint = np.concatenate((ob_dict['robot0_proprio-state'][0:21], robot.torques))
            if blind:
                target = np.array([-0.1, 0, 0.8])
            else:
                target = ob_dict['hole_pos']
            eef = np.concatenate((target - ob_dict['robot0_eef_pos'], ob_dict['robot0_eef_quat']))
            flat = np.concatenate((joint, eef))
            if self.norm:
                return (flat - self.mu) / self.std
            return flat
        if mode == 'eef':  # pose + vel + force/torque
            return self.get_eef_info()
        if mode == 'pixel':
            img = self.render(width=self.img_size, height=self.img_size, camera_id=self.camera_id)
            return np.transpose(img, (2, 0, 1)).copy()
        if mode == 'without_torque':
            return self.without_torque()
        if mode == 'without_torque2':
            return self.without_torque2()
        if mode == 'eef2':  # quaternion pose + vel + force/torque
            return self.get_eef_info2()
        if mode == 'eef3':  # pose + vel
            return self.get_eef_info3()
        # Fallback: every entry of the dict, flattened and concatenated.
        return np.concatenate([ob_dict[key].flatten() for key in ob_dict])

    @property
    def observation_space(self):
        """Unbounded ``spaces.Box`` with one entry per element of the current observation.

        Side effect: ``self.ob_dict`` is replaced by ``observation_spec()``,
        because most observation modes read it.
        """
        ob_dict = self.observation_spec()
        self.ob_dict = ob_dict
        flat_ob = self._flatten_obs(ob_dict, self.observation_mode)

        # NOTE(review): in 'pixel' mode the observation is a 3-D array while
        # this space is 1-D with ``flat_ob.size`` entries — confirm consumers
        # expect that.
        high = np.inf * np.ones(flat_ob.size)
        low = -high
        return spaces.Box(low=low, high=high)

    @property
    def action_space(self):
        """``spaces.Box`` for the action expected by ``step``."""
        low, high = self.action_spec
        if self.action_mode == 'fixed_ori':
            # Three-component action only; step() pads the remaining three
            # components with zeros.
            low, high = (np.array([-1., -1., -1.]), np.array([1., 1., 1.]))
        return spaces.Box(low=low, high=high)

    def reset(self):
        """Reset the parent environment and return the observation for ``observation_mode``."""
        ob_dict = super().reset()
        self.ob_dict = ob_dict
        self.pre_state = None
        return self._flatten_obs(ob_dict, self.observation_mode)

    def step(self, action):
        """Advance the environment by one step.

        Args:
            action: Action vector. In ``'fixed_ori'`` mode only its first
                three entries are used and three zeros are appended before
                it is passed to the parent.

        Returns:
            ``(observation, reward, done, info)`` where ``observation`` is
            produced by ``_flatten_obs``, ``done`` is forced to ``True`` when
            ``self.check_done()`` is truthy, and ``reward`` is replaced by the
            sparse variant selected in the constructor, if any.
        """
        if self.action_mode == 'fixed_ori':
            action = np.concatenate((action[0:3], np.array([0, 0, 0])))
        ob_dict, reward, done, info = super().step(action)
        self.ob_dict = ob_dict
        if self.check_done():
            done = True
        # Sparse rewards override the parent's reward; sparse3 is applied
        # last, so it takes precedence over sparse.
        if self.sparse:
            reward = self.sparse_reward()
        if self.sparse3:
            reward = self.sparse_reward3(depth=self.depth)

        return self._flatten_obs(ob_dict, self.observation_mode), reward, done, info


class MultitaskPeginHole(PeginHole):
    """Multi-task peg-in-hole environment with a flat-observation gym interface.

    A task index selects a peg class together with a hole-size and friction
    setting; see ``reset_task`` for the mapping.
    """

    # Task indices below this value keep ``large_hole`` enabled with friction
    # 1; indices at or above it reuse the same peg classes with ``large_hole``
    # disabled and friction 2.
    _TASKS_PER_HOLE_SETTING = 5

    def __init__(self, robots, n_tasks=10, randomize_tasks=False, **kwargs):
        """Build the environment.

        Args:
            robots: Robot specification forwarded to the parent constructor.
            n_tasks: Number of task indices reported by ``get_all_task_idx``.
            randomize_tasks: Accepted for interface compatibility; not used
                by this class.
            **kwargs: Forwarded to the parent constructor.
        """
        super().__init__(robots, gripper_types=None, **kwargs)
        self._goal = self.peg_class
        # Honour the constructor argument (it used to be ignored in favour of
        # a hard-coded 10, which is still the default).
        self.num_tasks = n_tasks

    def get_all_task_idx(self):
        """Return the range of valid task indices, ``0 .. num_tasks - 1``."""
        return range(self.num_tasks)

    def reset_task(self, idx):
        """Switch to task ``idx`` and reset the environment.

        Indices below 5 select peg class ``idx`` with ``large_hole`` enabled
        and friction 1; indices from 5 upwards select peg class ``idx - 5``
        with ``large_hole`` disabled and friction 2. ``idx`` is not
        range-checked.
        """
        split = self._TASKS_PER_HOLE_SETTING
        if idx >= split:
            self.peg_class = idx - split
            self.large_hole = False
            self.friction[0] = 2
        else:
            self.peg_class = idx
            self.large_hole = True
            self.friction[0] = 1
        self._goal = self.peg_class
        self.reset()

    def _flatten_obs(self, ob_dict):
        """Concatenate every entry of ``ob_dict``, flattened, into one 1-D array."""
        return np.concatenate([ob_dict[key].flatten() for key in ob_dict])

    @property
    def observation_space(self):
        """Unbounded ``spaces.Box`` sized to the flattened observation spec."""
        flat_ob = self._flatten_obs(self.observation_spec())

        high = np.inf * np.ones(flat_ob.size)
        low = -high
        return spaces.Box(low=low, high=high)

    @property
    def action_space(self):
        """``spaces.Box`` spanning the parent's ``action_spec`` bounds."""
        low, high = self.action_spec
        return spaces.Box(low=low, high=high)

    def reset(self):
        """Reset the parent environment and return the flattened observation."""
        ob_dict = super().reset()
        return self._flatten_obs(ob_dict)

    def step(self, action):
        """Advance one step; ``done`` is forced to ``True`` when ``_check_success()`` is truthy.

        Returns:
            ``(flat_observation, reward, done, info)``.
        """
        ob_dict, reward, done, info = super().step(action)
        if self._check_success():
            done = True
        return self._flatten_obs(ob_dict), reward, done, info
    
        