from envs.customfetch.custom_fetch import PushEnv, SlideEnv, PickPlaceEnv, GoalType, StackEnv, PushLeft, PushRight, SlideNEnv
from gym.envs.robotics import rotations, robot_env, utils

import imageio
import numpy as np

env, external, internal = 'pickplace','all','obj'
# NOTE: `env` here is only a label; the environment class actually used is
# hard-coded to PickPlaceEnv below, where `env` is rebound to the constructed
# environment object.

# Lower-cased goal names -> GoalType member *names*. Members are resolved with
# getattr only for the selected entry, so an unselected name never has to
# exist on GoalType.
_EXTERNAL_GOAL_NAMES = {
  'all': 'ALL',
  'objgrip': 'OBJ_GRIP',
  'objspeed': 'OBJSPEED',
  'objspeedrot': 'OBJSPEED2',
  'obj': 'OBJ',
  'grip': 'GRIP',
}
# The internal goal may not be 'all', so that key is deliberately absent.
_INTERNAL_GOAL_NAMES = {
  'objgrip': 'OBJ_GRIP',
  'obj': 'OBJ',
  'grip': 'GRIP',
}


def _parse_goal_type(name, table, role):
  """Return the GoalType member that `name` (case-insensitive) maps to in `table`.

  Args:
    name: goal name as given by the user, e.g. 'objgrip'.
    table: mapping of lower-cased goal names to GoalType member names.
    role: 'external' or 'internal'; used only in the error message.

  Raises:
    ValueError: if `name` is not a key of `table`. The message names the
      offending value and lists the accepted choices.
  """
  key = name.lower()
  if key not in table:
    raise ValueError(
        f"unsupported {role} goal type {name!r}; "
        f"expected one of {sorted(table)}")
  return getattr(GoalType, table[key])


external = _parse_goal_type(external, _EXTERNAL_GOAL_NAMES, 'external')
internal = _parse_goal_type(internal, _INTERNAL_GOAL_NAMES, 'internal')

# Environment construction. This script always uses PickPlaceEnv;
# range_min / range_max are pick-and-place specific settings.
Env = PickPlaceEnv
n_blocks = 1  # THIS IS THE "IN_AIR_PERCENTAGE" (passed to the env as `n`)
range_min = 0.2  # THIS IS THE MINIMUM_AIR
range_max = 0.45  # THIS IS THE MAXIMUM_AIR
env = Env(max_step=51, internal_goal=internal, external_goal=external, mode=0,
          per_dim_threshold=0, hard=True, distance_threshold=0, n=n_blocks,
          range_min=range_min, range_max=range_max)
# NOTE(review): initial_qpos is not referenced anywhere else in this script;
# it presumably records default starting joint positions for the robot slides
# and the object's free joint. Confirm before relying on it.
initial_qpos = {
    'robot0:slide0': 0.405,
    'robot0:slide1': 0.48,
    'robot0:slide2': 0.0,
    'object0:joint': [1.25, 0.53, 0.4, 1., 0., 0., 0.],
}
sim = env.sim
# FIGURE OUT GRIP / OBJECT POSITION BOUNDS
# Place the box at each candidate pose, render one 500x500 frame per pose, and
# write all frames to test.mp4 for visual inspection.
# Each pose is the object's joint qpos: the first three entries are the
# position, the last four the orientation quaternion.
# first try setting box on the table
INIT_Q_POSES = [
    [1.3, 0.6, 0.41, 1., 0., 0., 0.],
    [1.3, 0.9, 0.41, 1., 0., 0., 0.],
    [1.2, 0.68, 0.41, 1., 0., 0., 0.],
    [1.4, 0.82, 0.41, 1., 0., 0., 0.],
    [1.4, 0.68, 0.41, 1., 0., 0., 0.],
    [1.2, 0.82, 0.41, 1., 0., 0., 0.],
]
ob_dict = env.reset()
# To inspect the reset state interactively, insert `breakpoint()` here
# locally; do not commit it, as it halts the script before the video is written.
gif = []
for box_pose in INIT_Q_POSES:
  # Move the object to the candidate pose (position + quaternion).
  sim.data.set_joint_qpos("object0:joint", box_pose)
  # Recompute derived state for the new qpos; this does not advance time.
  sim.forward()
  gif.append(env.render("rgb_array", 500, 500))
imageio.mimwrite("test.mp4", gif)


# TEST STATE RENDERING SIDE BY SIDE
# Disabled experiment (kept for reference): record 100 steps of random actions,
# then try to reproduce each rendered frame purely from its observation vector
# by setting the mocap gripper target, the finger joints and the object pose,
# and write [real | reconstructed | diff] frames side by side to test.mp4
# (the same file the active code above writes). Uncomment to run.
# all_real_img = []
# all_ob_dict = []
# for i in range(100):
#   if i == 0:
#     ob_dict = env.reset()
#   else:
#     ob_dict, *_ = env.step(env.action_space.sample())
#   real_img = env.render("rgb_array", 500,500)
#   all_real_img.append(real_img)
#   all_ob_dict.append(ob_dict)

# all_mock_img = []
# NOTE(review): the result of this reset is immediately overwritten by the loop
# variable below, and the loop resets again at i == 0; likely redundant.
# ob_dict = env.reset()
# for i, ob_dict in enumerate(all_ob_dict):
#   obs = ob_dict['observation']
#   # NOTE(review): the slice layout below is assumed from the variable names;
#   # confirm against the env's observation construction.
#   obj_pos = obs[:3]
#   grip_pos = obs[3:6]
#   obj_rel_pos = obs[6:9]
#   gripper_state = obs[9:11]
#   object_rot = obs[11:14]
#   # reset the robot.
#   if i == 0:
#     env.reset()
#     env.goal = ob_dict['desired_goal']
#   # move the robot end effector to correct position.
#   gripper_target = grip_pos 
#   gripper_rotation = np.array([1., 0., 1., 0.])
#   sim.data.set_mocap_pos('robot0:mocap', gripper_target)
#   sim.data.set_mocap_quat('robot0:mocap', gripper_rotation)
#   # set the gripper to the correct position.
#   gripper_vel = obs[-2:]
#   sim.data.set_joint_qpos("robot0:r_gripper_finger_joint", gripper_state[0])
#   sim.data.set_joint_qvel("robot0:r_gripper_finger_joint", gripper_vel[0])
#   sim.data.set_joint_qpos("robot0:l_gripper_finger_joint", gripper_state[1])
#   sim.data.set_joint_qvel("robot0:l_gripper_finger_joint", gripper_vel[1])
#   for _ in range(1):
#     env.sim.step()
#   # set the objects to the correct position.
#   obj_quat = rotations.euler2quat(object_rot)
#   sim.data.set_joint_qpos("object0:joint", [*obj_pos,*obj_quat])
#   # step the sim
#   env.sim.forward()

#   mock_img = env.render("rgb_array", 500,500)
#   all_mock_img.append(mock_img)
# gif = []
# for real, mock in zip(all_real_img, all_mock_img):
#   # NOTE(review): if render() returns uint8 frames, `real - mock` wraps
#   # around before np.clip runs, so this is not a true absolute difference.
#   # Consider np.abs(real.astype(np.int16) - mock).astype(np.uint8).
#   diff= np.clip(real - mock , 0 , 255)
#   img = np.concatenate([real, mock, diff], 1)
#   gif.append(img)
# imageio.mimwrite("test.mp4", gif)