from telnetlib import IP
from dependencies.ravens.ravens.environments.environment import EnvironmentNoRotationsWithHeightmap
from dependencies.ravens.ravens.tasks.align_box_corner import AlignBoxCorner
from dependencies.ravens.ravens.tasks.stack_blocks import StackBlocks
from lexa_benchmark.envs.kitchen import KitchenEnv
from collections import OrderedDict
import numpy as np
from gym.spaces import Box, Dict
import mujoco_py

from multiworld.core.serializable import Serializable
from multiworld.envs.env_util import (
    get_stat_in_paths,
    create_stats_ordered_dict,
    get_asset_full_path,
)


from ravens.environments.environment import Environment
import matplotlib.pyplot as plt
import os.path as osp
from huge.envs.gymenv_wrapper import GymGoalEnvWrapper
import numpy as np
import gym
import random
import itertools
from itertools import combinations
from gym import spaces
from huge.envs.env_utils import Discretized

import pybullet as p


class RavensEnvStackBlock():
    """Goal-conditioned wrapper around the Ravens ``StackBlocks`` task.

    The wrapper exposes a discrete action set (a regular grid of planar
    end-effector targets) and returns observations as a dict with the keys
    ``observation``, ``desired_goal``, ``achieved_goal`` and their
    ``state_*`` twins, each holding a 9-element vector (see ``_get_obs``).
    """

    def __init__(self,
                 disp=False,
                 shared_memory=False,
                 hz=240,
                 use_egl=False):
        """Create the underlying Ravens environment and the discrete action set.

        Args:
            disp: forwarded positionally to ``Environment``.
            shared_memory: forwarded positionally to ``Environment``.
            hz: forwarded positionally to ``Environment``.
            use_egl: forwarded positionally to ``Environment``.
        """
        assets_root = "./dependencies/ravens/ravens/environments/assets/"
        task = StackBlocks(continuous=False, pick_or_place=True)

        self._env = Environment(assets_root,
                                task,
                                disp,
                                shared_memory,
                                hz,
                                use_egl,
                                random_goal=True,
                                random_box_position=False)

        # TODO: adjust the bounds to the real range of the observation vector.
        obs_upper = 1.0 * np.ones(9)
        self._observation_space = spaces.Box(-obs_upper, obs_upper, dtype=np.float32)
        # The lower bound must be -obs_upper: using obs_upper for both bounds
        # yields a degenerate box whose only element is the all-ones vector.
        self.goal_space = spaces.Box(-obs_upper, obs_upper, dtype=np.float32)

        # Continuous range of the two planar target coordinates, shrunk by
        # ``delta_margin`` on every side relative to the first two rows of
        # ``self.ee_bounds`` below.
        delta_margin = 0.05
        intermediate_action_space = gym.spaces.Box(
            low=np.array([0.25 + delta_margin, -0.5 + delta_margin], dtype=np.float32),
            high=np.array([0.75 - delta_margin, 0.5 - delta_margin], dtype=np.float32),
            shape=(2,),
            dtype=np.float32,
        )

        # Discretize: ``granularity`` evenly spaced values per dimension,
        # combined into every possible pair (granularity ** n_dims targets).
        granularity = 8
        axes = [
            np.linspace(lo, hi, granularity)
            for lo, hi in zip(intermediate_action_space.low, intermediate_action_space.high)
        ]
        actions_meshed = np.meshgrid(*axes)
        # One row per discrete action, one column per coordinate.
        self.base_actions = np.array([a.flat[:] for a in actions_meshed]).T
        n_dims = intermediate_action_space.shape[0]
        assert len(self.base_actions) == granularity ** n_dims

        # NOTE: this local class shadows the ``Discretized`` name imported at
        # the top of the file for the remainder of ``__init__``.
        class Discretized(gym.spaces.Discrete):
            """``Discrete`` space that remembers the grid it was built from."""

            def __init__(self, n, n_dims, granularity):
                self.n_dims = n_dims
                self.granularity = granularity
                assert n == granularity ** n_dims

                super().__init__(n)

        # One discrete action per grid target. Picking vs. placing is not part
        # of the action; ``step`` derives it from the current grasp state.
        self.action_space = Discretized(len(self.base_actions), n_dims=n_dims, granularity=granularity)

        self.ee_init_pos = [0.4831041007489618, 0.029937637798535994, 0.34, 0, 0, 0, 1]

        self.ee_bounds = np.array([[0.25, 0.75], [-0.5, 0.5], [0, 0.35]])

        self.reset()

    def get_postion(self, obs):
        """Return the first three entries of ``obs['observation']``.

        With the layout produced by ``_get_obs`` these are the two planar
        end-effector coordinates followed by the object-grabbed flag. The
        misspelled name is kept so existing callers keep working.
        """
        return obs['observation'][:3]

    def step(self, action=None):
        """Apply one discrete action and return ``(obs, reward, done, info)``.

        Args:
            action: index into ``self.base_actions``. When ``None``, no target
                pose is built and ``None`` is handed to the wrapped
                environment instead of failing on an unbound local variable.

        Returns:
            ``obs`` from ``_get_obs``, ``reward`` from ``self.reward(obs)``,
            and ``done``/``info`` as reported by the wrapped environment.
        """
        if action is None:
            # NOTE(review): assumes the wrapped Environment.step accepts None
            # as "no action", as upstream Ravens does -- confirm against the
            # vendored fork.
            env_action = None
        else:
            # Planar target from the grid, padded with a zero third coordinate.
            position = np.concatenate([self.base_actions[action], [0]])
            orientation = np.array([0, 0, 0, 1])
            env_action = {
                # Pick when nothing is grasped, otherwise place.
                'pick_action': not self._env.ee.check_grasp(),
                'pose0': (position, orientation),
            }

        # The environment's own observation and reward are discarded in
        # favour of the goal-conditioned versions computed below.
        _, _, done, info = self._env.step(env_action)
        obs = self._get_obs()
        reward = self.reward(obs)

        self.prev_position = self.get_postion(obs)

        return obs, reward, done, info

    def reward(self, obs):
        """Return the negative Euclidean distance between observation and goal."""
        achieved_state = obs['observation']
        goal_state = obs['desired_goal']
        return -np.linalg.norm(achieved_state - goal_state)

    @property
    def state_space(self):
        """Space of state vectors; identical to ``self.goal_space``."""
        return self.goal_space

    @property
    def observation_space(self):
        """Dict space mirroring the keys of the dict returned by ``_get_obs``."""
        # A list of pairs (rather than a plain dict) keeps this key order.
        observation_space = Dict([
            ('observation', self.state_space),
            ('desired_goal', self.goal_space),
            ('achieved_goal', self.state_space),
            ('state_observation', self.state_space),
            ('state_desired_goal', self.goal_space),
            ('state_achieved_goal', self.state_space),
        ])
        return observation_space

    def get_world_obs(self):
        """Return ``(state, goal, object_grabbed, suction_state)`` from the wrapped env.

        The last two values are converted to ``int``.
        """
        obs = self._env._get_obs()
        return obs['state'], obs['goal'], int(obs['object_grabbed']), int(obs['suction_state'])

    def _get_obs(self):
        """Build the goal-conditioned observation dict.

        The 9-element observation is laid out as
        ``[ee[0], ee[1], object_grabbed, state[0:2], state[7:9], state[14:16]]``
        and the goal as ``[goal[0:2], 0, goal[0:2], goal[0:2], goal[0:2]]``,
        i.e. the same two goal coordinates are the target for the end
        effector and for each of the three state pairs.
        """
        world_obs, world_goal, object_grabbed, suction_state = self.get_world_obs()
        ee_obs = np.concatenate(self._env.get_ee_pose())
        # presumably the world state is 21 values, three objects x 7-value
        # pose, so each slice is one object's planar position -- TODO confirm.
        # TODO: the original author flagged this selection as wrong.
        obs = np.concatenate([
            ee_obs[:2],
            [object_grabbed],
            world_obs[:2],
            world_obs[7:9],
            world_obs[14:16],
        ])
        goal_xy = world_goal[:2]
        goal = np.concatenate([goal_xy, [0], goal_xy, goal_xy, goal_xy])

        return dict(
            observation=obs,
            desired_goal=goal,
            achieved_goal=obs,
            state_observation=obs,
            state_desired_goal=goal,
            state_achieved_goal=obs,
        )

    def reset(self, poses=None):
        """Reset the wrapped environment and return the first observation.

        Args:
            poses: optional dict forwarded to the wrapped environment's
                ``reset``. When non-empty it must contain ``'goal'``, whose
                first element is replaced by its last two entries followed by
                a zero. The caller's dict is left untouched; a shallow copy is
                modified and forwarded. ``None`` (the default) forwards an
                empty dict.

        Raises:
            KeyError: if ``poses`` is non-empty but has no ``'goal'`` entry.
        """
        # Copy so that neither a shared default nor the caller's dict is
        # mutated; re-using the same dict for a second reset would otherwise
        # re-truncate an already rewritten goal.
        poses = {} if poses is None else dict(poses)
        if poses:
            goal_pose = poses['goal'].copy()
            goal_pose[0] = np.concatenate([goal_pose[0][-2:], [0]])
            poses['goal'] = goal_pose
        self._env.reset(poses)
        obs = self._get_obs()
        self.prev_position = self.get_postion(obs)
        return obs

    def render_image(self):
        """Return the wrapped environment's ``rgb_array`` render."""
        return self._env.render(mode="rgb_array")

    def render(self):
        """Return the wrapped environment's ``rgb_array`` render."""
        return self._env.render(mode="rgb_array")

class RavensGoalEnvStackBlock(GymGoalEnvWrapper):
    """``GymGoalEnvWrapper`` around ``RavensEnvStackBlock``.

    Distances between a state and a goal are "shaped": they count how many of
    the three trailing coordinate pairs of the state are still further than a
    fixed tolerance from the matching pair of the goal.
    """

    # A coordinate pair counts as placed when it is closer than this to its goal.
    _PLACEMENT_TOLERANCE = 0.08
    # The three trailing coordinate pairs of the state vector that are compared.
    _OBJECT_XY_SLICES = (slice(-2, None), slice(-4, -2), slice(-6, -4))

    def __init__(self,
                 disp=False,
                 shared_memory=False,
                 hz=240,
                 use_egl=False):
        """Build the base environment and register the dict keys to read.

        All four arguments are forwarded positionally to
        ``RavensEnvStackBlock``.
        """
        env = RavensEnvStackBlock(disp, shared_memory, hz, use_egl)

        super().__init__(
            env, observation_key='observation', goal_key='achieved_goal', state_goal_key='state_achieved_goal'
        )

    def compute_success(self, achieved_state, goal):
        """Return whether the shaped distance is below 0.05.

        The shaped distance is an integer count, so this is true exactly when
        the count is zero.
        """
        return self.compute_shaped_distance(achieved_state, goal) < 0.05

    def goal_distance(self, state, goal_state):
        """Return the shaped distance between ``state`` and ``goal_state``.

        ``state`` is passed through ``self.observation`` and ``goal_state``
        through ``self.extract_goal`` before the comparison.
        """
        achieved_state = self.observation(state)
        goal = self.extract_goal(goal_state)

        return self.compute_shaped_distance(achieved_state, goal)

    # TODO: write extract functions

    def compute_shaped_distance(self, achieved_state, goal):
        """Count the coordinate pairs that are not yet at their goal.

        Args:
            achieved_state: state vector; only its last six entries are read.
            goal: goal vector with the same shape as ``achieved_state``.

        Returns:
            An ``int`` in ``0..3``: the number of trailing coordinate pairs
            whose Euclidean distance to the goal pair is not below the
            placement tolerance.

        Raises:
            AssertionError: if the two shapes differ.
        """
        assert achieved_state.shape == goal.shape
        misplaced = 0
        for xy in self._OBJECT_XY_SLICES:
            offset = np.linalg.norm(achieved_state[xy] - goal[xy])
            misplaced += 1 - int(offset < self._PLACEMENT_TOLERANCE)
        return misplaced

    def get_shaped_distance(self, states, goal_states):
        """Alias for ``compute_shaped_distance``."""
        return self.compute_shaped_distance(states, goal_states)

    def render_image(self):
        """Return the base environment's rendered image."""
        return self.base_env.render_image()

    def get_diagnostics(self, trajectories, desired_goal_states):
        """
        Logs things

        Args:
            trajectories: Numpy Array [# Trajectories x Max Path Length x State Dim]
            desired_goal_states: Numpy Array [# Trajectories x State Dim]

        Returns:
            OrderedDict of summary statistics of the final-state distances.
        """
        # Distance between the last state of each trajectory and its goal.
        final_distances = np.array([
            self.goal_distance(trajectory[-1], goal_state)
            for trajectory, goal_state in zip(trajectories, desired_goal_states)
        ])

        # NOTE(review): both statistics are fed by ``goal_distance``, which
        # returns the shaped distance in this class, so 'final l2 distance'
        # does not report a Euclidean norm. The values are kept as they were.
        statistics = OrderedDict()
        for stat_name, stat in [
            ('final l2 distance', final_distances),
            ('final shaped distance', final_distances),
        ]:
            statistics.update(create_stats_ordered_dict(
                stat_name,
                stat,
                always_show_all_stats=True,
            ))

        return statistics

