from lexa_benchmark.envs.kitchen import KitchenEnv
from collections import OrderedDict
import numpy as np
from gym.spaces import Box, Dict
import mujoco_py

from multiworld.core.serializable import Serializable
from multiworld.envs.env_util import (
    get_stat_in_paths,
    create_stats_ordered_dict,
    get_asset_full_path,
)

from multiworld.envs.mujoco.mujoco_env import MujocoEnv
import copy

from multiworld.core.multitask_env import MultitaskEnv
import matplotlib.pyplot as plt
import os.path as osp
from huge.envs.gymenv_wrapper import GymGoalEnvWrapper
import numpy as np
import gym
import random
import itertools
from itertools import combinations
from envs.base_envs import BenchEnv
from d4rl.kitchen.kitchen_envs import KitchenMicrowaveKettleLightTopLeftBurnerV0
from gym import spaces
import torch

# Layout note: the "internal" state vector used throughout this module is
# obs[7:30] of the wrapped d4rl kitchen observation (see
# KitchenIntermediateEnv.internal_extract_state), i.e. 23 entries. Indices
# below refer to positions in that vector.
#
# NOTE: every 'kettle' entry is intentionally excluded from these tables.

# Target joint value(s) per task; each list lines up with OBJECT_GOAL_IDXS[task].
OBJECT_GOAL_VALS = {
    'bottom_burner': [-0.88, -0.01],
    'light_switch': [-0.69, -0.05],
    'slide_cabinet': [0.37],
    'left_hinge_cabinet': [0.0],
    'hinge_cabinet': [1.45],
    'microwave': [-0.75],
}

# 3-D position the end effector is shaped towards before the task joint has
# moved (used by KitchenGoalEnv.compute_shaped_distance).
# NOTE(review): 'left_hinge_cabinet' has no entry here, so shaping for that
# task would raise KeyError.
OBJECT_KEY_POS = {
    'bottom_burner': [-0.125, 0.68, 2.22],
    'light_switch': [-0.4, 0.68, 2.3],
    'slide_cabinet': [-0.12, 0.65, 2.6],
    'hinge_cabinet': [-0.53, 0.65, 2.6],
    'microwave': [-0.63, 0.48, 1.8],
}

# End-effector position appended to the goal vector by
# KitchenIntermediateEnv.generate_goal.
# NOTE(review): only these three tasks are present; generate_goal would raise
# KeyError for 'bottom_burner' / 'light_switch'.
FINAL_KEY_POS = {
    'slide_cabinet': [0.2, 0.65, 2.6],
    'hinge_cabinet': [-0.45, 0.53, 2.6],
    'microwave': [-0.7, 0.38, 1.8],
}

# Position(s) in the internal state vector holding each task's joint(s).
OBJECT_GOAL_IDXS = {
    'bottom_burner': [2, 3],
    'light_switch': [10, 11],
    'slide_cabinet': [12],
    'left_hinge_cabinet': [13],
    'hinge_cabinet': [14],
    'microwave': [15],
}

# Nominal internal state at reset (23 entries, float64). Entries 16-22 are
# presumably the kettle position (16-18) and orientation quaternion (19-22) —
# confirm against the d4rl kitchen model.
INITIAL_STATE = np.array([
    0, 0, 0, 0,
    0, 0, 0, 0,
    0, 0, 0, 0,
    0, 0, 0, 0,
    -0.269, 0.35, 1.62, 1,
    0, 0, 0,
])

# Tasks considered by this module.
BASE_TASK_NAMES = [
    'bottom_burner',
    'light_switch',
    'slide_cabinet',
    'hinge_cabinet',
    'microwave',
]
 


class KitchenIntermediateEnv(BenchEnv):
    """State-based kitchen env wrapping d4rl's KitchenMicrowaveKettleLightTopLeftBurnerV0.

    Each observation is ``internal_extract_state(wrapped_obs)`` (``obs[7:30]``)
    concatenated with the end-effector pose returned by ``get_ee_pose()``.
    Goals have the same layout and are produced by ``generate_goal`` (currently
    always the 'microwave' task).
    """

    def __init__(self, action_repeat=1, use_goal_idx=False, log_per_goal=False,
                 control_mode='end_effector', width=64):
        """Create the wrapped kitchen env, the state/goal spaces, and reset once.

        Args:
            action_repeat: forwarded to ``BenchEnv``; ``step`` repeats each action
                ``self._action_repeat`` times (presumably set by BenchEnv — confirm).
            use_goal_idx: stored on the instance; not read elsewhere in this class.
            log_per_goal: stored on the instance; not read elsewhere in this class.
            control_mode: control mode forwarded to the d4rl kitchen env.
            width: used for both the rendered image width and height.
        """
        super().__init__(action_repeat, width)
        self.use_goal_idx = use_goal_idx
        self.log_per_goal = log_per_goal

        with self.LOCK:
            self._env = KitchenMicrowaveKettleLightTopLeftBurnerV0(
                frame_skip=16, control_mode=control_mode,
                imwidth=width, imheight=width)

            self._env.sim_robot.renderer._camera_settings = dict(
                distance=3, lookat=[-0.3, .5, 2.], azimuth=90, elevation=-60)

            # Symmetric bounds: +/-8 for the extracted world state (the first 7
            # entries are dropped, matching internal_extract_state) and +/-4 for
            # the 3-D end-effector pose.
            obs_upper = 8.0 * np.ones(self._env.obs_dim // 2)
            obs_upper_pose = 4 * np.ones(3)
            high = np.concatenate([obs_upper[7:], obs_upper_pose])
            low = -high
            self._observation_space = spaces.Box(low, high, dtype=np.float32)
            self._goal_space = spaces.Box(low, high, dtype=np.float32)
            print("observation space in kitchen", self._observation_space)

        # reset() below sets this back to False; it only guards render_goal
        # until the first reset completes.
        self.rendered_goal = True

        initial_obs = self.reset()
        print("initial_obs", initial_obs)

    def generate_goal(self):
        """Build the goal vector for the current task and set ``self.goal_name``.

        The goal is a fixed reference world state with the task's joint(s)
        overwritten by OBJECT_GOAL_VALS, followed by the task's end-effector
        target from FINAL_KEY_POS.
        """
        # Reference world state with the same 23-entry layout as
        # internal_extract_state. A fresh array is built on every call, so the
        # in-place joint assignment below never leaks between calls.
        goal_state = np.array([4.79267505e-02,  3.71350919e-02, -2.66279850e-04, -5.18043486e-05,
                               3.12877220e-05, -4.51199853e-05, -3.90842156e-06, -4.22629655e-05,
                               6.28065475e-05,  4.04984708e-05,  4.62730939e-04, -2.26906415e-04,
                               -4.65501369e-04, -6.44129196e-03, -1.77048263e-03,  1.08009684e-03,
                               -2.69397440e-01,  3.50383255e-01,  1.61944683e+00,  9.99970159e-01,
                               4.03883905e-03, -6.58004743e-03, -2.66621172e-04])
        # Task is currently hard-coded; other candidates live in BASE_TASK_NAMES.
        self.goal_name = 'microwave'
        goal_state[OBJECT_GOAL_IDXS[self.goal_name]] = OBJECT_GOAL_VALS[self.goal_name]
        hook_pose = FINAL_KEY_POS[self.goal_name]
        return np.concatenate([goal_state, hook_pose])

    def internal_extract_state(self, obs):
        """Return entries 7..29 of the wrapped env's observation."""
        return obs[7:30]

    # NOTE(review): goal_idx / goals are never initialised in this class;
    # get_goal_idx raises AttributeError unless set_goal_idx was called first,
    # and get_goals unless something else assigns self.goals.
    def set_goal_idx(self, idx):
        self.goal_idx = idx

    def get_goal_idx(self):
        return self.goal_idx

    def get_goals(self):
        return self.goals

    def render_image(self):
        """Return an RGB frame from the wrapped env."""
        return self._env.render(mode="rgb_array")

    def render(self):
        """Render the wrapped env in "human" mode."""
        return self._env.render(mode="human")

    @property
    def state_space(self):
        return self._observation_space

    @property
    def goal_space(self):
        return self._goal_space

    @property
    def action_space(self):
        return self._env.action_space

    @property
    def observation_space(self):
        """Dict space matching the keys returned by ``_get_obs``."""
        return Dict([
            ('observation', self.state_space),
            ('desired_goal', self.goal_space),
            ('achieved_goal', self.state_space),
            ('state_observation', self.state_space),
            ('state_desired_goal', self.goal_space),
            ('state_achieved_goal', self.state_space),
        ])

    def _get_obs(self):
        """Return the goal-conditioned observation dict.

        The achieved goal is the current observation itself. The desired goal
        is ``self.goal`` as set by ``reset``.
        """
        world_obs = self.internal_extract_state(self._env._get_obs())
        ee_obs = self._env.get_ee_pose()
        obs = np.concatenate([world_obs, ee_obs])
        goal = self.goal

        return dict(
            observation=obs,
            desired_goal=goal,
            achieved_goal=obs,
            state_observation=obs,
            state_desired_goal=goal,
            state_achieved_goal=obs,
        )

    def step(self, action):
        """Apply ``action`` up to ``self._action_repeat`` times.

        Stops early if the wrapped env reports ``done``.

        Returns:
            (obs_dict, total_reward, done, info). The reward is always 0.0
            because the wrapped env's reward is discarded.
        """
        total_reward = 0.0
        for _ in range(self._action_repeat):
            _, reward, done, info = self._env.step(action)
            # Deliberately discard the wrapped env's reward (a call to
            # compute_reward used to be here and is disabled).
            reward = 0
            total_reward += reward
            if done:
                break
        obs = self._get_obs()
        # Forward any metric_* entries from the observation dict into info.
        for k, v in obs.items():
            if 'metric_' in k:
                info[k] = v
        return obs, total_reward, done, info

    def compute_reward(self, goal=None):
        """Negative distance between current qpos and the goal's target joints.

        NOTE(review): self.obs_element_indices / self.obs_element_goals are not
        defined in this class (presumably expected from elsewhere — confirm).
        Also, with the default ``goal=self.goal`` (an ndarray from
        generate_goal), the dict lookup would raise TypeError (unhashable). This
        method is not called by ``step``.
        """
        if goal is None:
            goal = self.goal
        qpos = self._env.sim.data.qpos.copy()

        if len(self.obs_element_indices[goal]) > 9:
            return -np.linalg.norm(qpos[self.obs_element_indices[goal]][9:] - self.obs_element_goals[goal][9:])
        else:
            return -np.linalg.norm(qpos[self.obs_element_indices[goal]] - self.obs_element_goals[goal])

    def compute_success(self, goal=None):
        """Return ``(task_relevant_success, all_objects_success)`` as ints.

        Args:
            goal: goal key; defaults to ``self.goal``. This parameter was
                previously missing, so the method raised UnboundLocalError on
                every call.

        NOTE(review): relies on self.init_qpos, self.obs_element_indices,
        self.obs_element_goals and self.goal_configs, none of which are defined
        in this class. The default self.goal is an ndarray, which cannot be used
        as a dict key — confirm the intended goal representation.
        """
        if goal is None:
            goal = self.goal
        qpos = self._env.sim.data.qpos.copy()

        goal_qpos = self.init_qpos.copy()
        goal_qpos[self.obs_element_indices[goal]] = self.obs_element_goals[goal]

        per_obj_success = {
            'bottom_burner': ((qpos[9] < -0.38) and (goal_qpos[9] < -0.38)) or ((qpos[9] > -0.38) and (goal_qpos[9] > -0.38)),
            'top_burner': ((qpos[13] < -0.38) and (goal_qpos[13] < -0.38)) or ((qpos[13] > -0.38) and (goal_qpos[13] > -0.38)),
            'light_switch': ((qpos[17] < -0.25) and (goal_qpos[17] < -0.25)) or ((qpos[17] > -0.25) and (goal_qpos[17] > -0.25)),
            'slide_cabinet': abs(qpos[19] - goal_qpos[19]) < 0.1,
            'hinge_cabinet': abs(qpos[21] - goal_qpos[21]) < 0.2,
            'microwave': abs(qpos[22] - goal_qpos[22]) < 0.2,
            'kettle': np.linalg.norm(qpos[23:25] - goal_qpos[23:25]) < 0.2,
        }
        task_objects = self.goal_configs[goal]

        task_rel_success = 1
        for _obj in task_objects:
            task_rel_success *= per_obj_success[_obj]

        all_obj_success = 1
        for _obj in per_obj_success:
            all_obj_success *= per_obj_success[_obj]

        return int(task_rel_success), int(all_obj_success)

    def render_goal(self):
        """Render (and cache until the next reset) an image of the goal state.

        The simulator state is temporarily set to the goal configuration and
        then restored.

        NOTE(review): uses self.init_qpos / self.obs_element_indices /
        self.obs_element_goals, which are not defined in this class. It also
        indexes them with self.goal (an ndarray).
        """
        if self.rendered_goal:
            return self.rendered_goal_obj

        backup_qpos = self._env.sim.data.qpos.copy()
        backup_qvel = self._env.sim.data.qvel.copy()

        qpos = self.init_qpos.copy()
        qpos[self.obs_element_indices[self.goal]] = self.obs_element_goals[self.goal]

        self._env.set_state(qpos, np.zeros(len(self._env.init_qvel)))

        goal_obs = self._env.render('rgb_array', width=self._env.imwidth, height=self._env.imheight)

        self._env.set_state(backup_qpos, backup_qvel)

        self.rendered_goal = True
        self.rendered_goal_obj = goal_obs
        return goal_obs

    def reset(self):
        """Reset the wrapped env, sample a new goal, and return the obs dict."""
        with self.LOCK:
            self._env.reset()
        self.goal = self.generate_goal()
        self.rendered_goal = False
        return self._get_obs()

class KitchenGoalEnv(GymGoalEnvWrapper):
    """Goal-env wrapper around KitchenIntermediateEnv.

    Uses the 'observation' key as the observation and 'achieved_goal' /
    'state_achieved_goal' as goal keys.
    """

    def __init__(self, fixed_start=True, fixed_goal=False, images=False, image_kwargs=None):
        # NOTE(review): fixed_start, fixed_goal, images and image_kwargs are
        # accepted for interface compatibility but currently ignored.
        env = KitchenIntermediateEnv()

        super(KitchenGoalEnv, self).__init__(
            env, observation_key='observation', goal_key='achieved_goal',
            state_goal_key='state_achieved_goal'
        )

    def compute_success(self, achieved_state, goal):
        """Return whether ``achieved_state`` satisfies the current task's goal.

        The task is ``self.base_env.goal_name``. Indices refer to the internal
        state layout (obs[7:30] of the d4rl observation, then the EE pose).
        Burners and the light switch succeed when achieved and goal values are
        on the same side of a threshold. Cabinets and the microwave succeed
        within an absolute tolerance. The kettle uses the norm over entries
        16-17.

        Fix: 'top_burner' previously compared achieved_state[15] (the microwave
        joint) against goal[6]; both sides now use index 6.
        """
        per_obj_success = {
            'bottom_burner': ((achieved_state[2] < -0.38) and (goal[2] < -0.38)) or ((achieved_state[2] > -0.38) and (goal[2] > -0.38)),
            'top_burner': ((achieved_state[6] < -0.38) and (goal[6] < -0.38)) or ((achieved_state[6] > -0.38) and (goal[6] > -0.38)),
            'light_switch': ((achieved_state[10] < -0.25) and (goal[10] < -0.25)) or ((achieved_state[10] > -0.25) and (goal[10] > -0.25)),
            'slide_cabinet': abs(achieved_state[12] - goal[12]) < 0.1,
            'hinge_cabinet': abs(achieved_state[14] - goal[14]) < 0.2,
            'microwave': abs(achieved_state[15] - goal[15]) < 0.2,
            'kettle': np.linalg.norm(achieved_state[16:18] - goal[16:18]) < 0.2,
        }

        return per_obj_success[self.base_env.goal_name]

    def compute_shaped_distance(self, achieved_state, goal):
        """Two-phase shaped distance for the current task (``base_env.goal_name``).

        Phase 1 applies while the task joint is within 0.03 of INITIAL_STATE and
        the end effector (last 3 entries) is farther than 0.05 from
        OBJECT_KEY_POS. It returns the distance to the key position, plus the
        distance of the gripper entries [:2] from [0, 0], plus the joint
        distance to the goal, plus 2.

        Phase 2 applies otherwise. It returns the joint distance to the goal
        plus the distance of the gripper entries from [1, 1].

        Torch tensors are converted to numpy first. Assumes a single 1-D state
        and goal (not a batch) — TODO confirm with callers.

        Returns:
            numpy array with one entry per index in OBJECT_GOAL_IDXS[task]
            (the joint distance is element-wise).

        Fixes: removed a leftover ``IPython.embed()`` breakpoint and per-call
        debug prints. The "joint unmoved" test now reduces with a norm, so
        multi-joint tasks no longer raise "truth value of an array is
        ambiguous". For single-joint tasks the result is unchanged.
        """
        if torch.is_tensor(achieved_state):
            achieved_state = achieved_state.detach().cpu().numpy()
        if torch.is_tensor(goal):
            goal = goal.detach().cpu().numpy()
        goal_name = self.base_env.goal_name

        goal_idxs = OBJECT_GOAL_IDXS[goal_name]
        achieved_joint = achieved_state[goal_idxs]
        goal_joint = goal[goal_idxs]
        original_joint = INITIAL_STATE[goal_idxs]

        distance_from_original = abs(original_joint - achieved_joint)
        dist_slide = abs(achieved_joint - goal_joint)

        key_position = OBJECT_KEY_POS[goal_name]
        distance_to_key_pos = np.linalg.norm(achieved_state[-3:] - key_position)

        if np.linalg.norm(distance_from_original) < 0.03 and distance_to_key_pos > 0.05:
            # Joint untouched and EE not yet at the key position: approach it
            # with the gripper open.
            gripper_open = np.linalg.norm(achieved_state[:2] - np.array([0, 0]))
            return distance_to_key_pos + gripper_open + dist_slide + 2

        gripper_closed = np.linalg.norm(achieved_state[:2] - np.array([1, 1]))
        return dist_slide + gripper_closed

    def get_shaped_distance(self, states, goal_states):
        """Alias for ``compute_shaped_distance``."""
        return self.compute_shaped_distance(states, goal_states)

    def render_image(self):
        """Return an RGB frame from the underlying KitchenIntermediateEnv."""
        return self.base_env.render_image()

    def get_diagnostics(self, trajectories, desired_goal_states):
        """Return diagnostic statistics for logging.

        Args:
            trajectories: Numpy Array [# Trajectories x Max Path Length x State Dim]
            desired_goal_states: Numpy Array [# Trajectories x State Dim]

        Returns:
            An empty OrderedDict; no diagnostics are implemented for this env yet.
        """
        return OrderedDict()