from telnetlib import IP
from dependencies.ravens.ravens.environments.environment import EnvironmentNoRotationsWithHeightmap
from dependencies.ravens.ravens.tasks.align_box_corner import AlignBoxCorner
from lexa_benchmark.envs.kitchen import KitchenEnv
from collections import OrderedDict
import numpy as np
from gym.spaces import Box, Dict
import mujoco_py

from multiworld.core.serializable import Serializable
from multiworld.envs.env_util import (
    get_stat_in_paths,
    create_stats_ordered_dict,
    get_asset_full_path,
)

from multiworld.envs.mujoco.mujoco_env import MujocoEnv
import copy

from multiworld.core.multitask_env import MultitaskEnv
from ravens.environments.environment import ContinuousEnvironment
import matplotlib.pyplot as plt
import os.path as osp
from huge.envs.gymenv_wrapper import GymGoalEnvWrapper
import numpy as np
import gym
import random
import itertools
from itertools import combinations
from gym import spaces
from huge.envs.env_utils import Discretized

import pybullet as p


class RavensEnvPickAndPlace():
    """Discrete-action pick-and-place wrapper around a Ravens ``ContinuousEnvironment``.

    The underlying task is ``AlignBoxCorner(continuous=True)`` with a random goal.
    Integer actions are interpreted as follows:

    * ``0 .. len(base_actions) - 1``: translate the end effector by one of the
      ``granularity ** 3`` fixed (x, y, z) deltas, scaled by ``action_scale``.
    * ``len(base_actions)``: activate suction (no movement).
    * ``len(base_actions) + 1``: deactivate suction (no movement).

    Observation vectors (see ``_get_obs``) have 8 entries:
    ``[ee_x, ee_y, ee_z, object_grabbed, suction_state, w0, w1, w2]`` where
    ``w0..w2`` are the first three entries of the world state (presumably the
    object position -- confirm against ``ContinuousEnvironment``).
    """

    def __init__(self,
                 disp=False,
                 shared_memory=False,
                 hz=240,
                 use_egl=False,
                 assets_root="./dependencies/ravens/ravens/environments/assets/"):
        """Build the Ravens environment, the spaces and the discrete action table.

        Args:
            disp: forwarded to ``ContinuousEnvironment``.
            shared_memory: forwarded to ``ContinuousEnvironment``.
            hz: forwarded to ``ContinuousEnvironment``.
            use_egl: forwarded to ``ContinuousEnvironment``.
            assets_root: directory of Ravens assets; defaults to the path that
                used to be hard-coded here.
        """
        task = AlignBoxCorner(continuous=True)
        self._env = ContinuousEnvironment(assets_root,
                                          task,
                                          disp,
                                          shared_memory,
                                          hz,
                                          use_egl,
                                          random_goal=True,
                                          random_box_position=False)

        # 8 = ee xyz (3) + grabbed/suction flags (2) + 3 world-state entries.
        obs_upper = 1.0 * np.ones(8)
        self._observation_space = spaces.Box(-obs_upper, obs_upper, dtype=np.float32)
        # Goals share the observation layout, so they share its bounds
        # (previously low == high == 1, a degenerate space).
        self.goal_space = spaces.Box(-obs_upper, obs_upper, dtype=np.float32)

        # Discretize end-effector translation only: each axis in [-1, 1] gets
        # `granularity` levels, later scaled by `action_scale` in step().
        self.action_scale = 2. / 100
        action_upper = np.ones(3)
        action_lower = -action_upper
        move_space = gym.spaces.Box(
            low=np.array(action_lower, dtype=np.float32),
            high=np.array(action_upper, dtype=np.float32),
            shape=(3,),
            dtype=np.float32
        )

        granularity = 3

        # Levels run from high to low so every action index keeps the meaning
        # it had before the Box bounds above were corrected.
        actions_meshed = np.meshgrid(*[np.linspace(hi, lo, granularity)
                                       for lo, hi in zip(move_space.low, move_space.high)])
        self.base_actions = np.array([a.flat[:] for a in actions_meshed]).T
        n_dims = move_space.shape[0]
        assert len(self.base_actions) == granularity ** n_dims

        class Discretized(gym.spaces.Discrete):
            """Discrete space that also records the movement-grid layout."""

            def __init__(self, n, n_dims, granularity):
                self.n_dims = n_dims
                self.granularity = granularity
                # Two actions beyond the movement grid: suction on / suction off.
                assert n - 2 == granularity ** n_dims
                super(Discretized, self).__init__(n)

        # +2: one action activates suction, one deactivates it.
        self.action_space = Discretized(len(self.base_actions) + 2,
                                        n_dims=n_dims,
                                        granularity=granularity)

        self.ee_init_pos = [0.4831041007489618, 0.029937637798535994, 0.34, 0, 0, 0, 1]

        # Per-axis [min, max] limits applied to the commanded ee position.
        self.ee_bounds = np.array([[0.25, 0.75], [-0.5, 0.5], [0, 0.35]])

        self.reset()

    def get_postion(self, obs):
        """Return the end-effector xyz: the first three entries of ``obs['observation']``.

        The misspelled name is kept because callers may depend on it.
        """
        return obs['observation'][:3]

    def step(self, action=None):
        """Apply one discrete action and return ``(obs, reward, done, info)``.

        Args:
            action: index into ``action_space``. ``None`` keeps the end
                effector where it is and leaves the suction command unchanged
                (this used to raise ``UnboundLocalError``).

        Returns:
            ``obs`` from ``_get_obs()``, ``reward`` from ``self.reward(obs)``
            (the underlying env's reward is discarded), and ``done``/``info``
            as returned by the underlying env.
        """
        n_moves = len(self.base_actions)
        new_action = {
            'acts_left': 0,
            # Default: repeat the primitive's current suction bit (presumably
            # the live suction state -- confirm). Moves must not overwrite it,
            # otherwise a grabbed object is dropped on the first translation.
            'suction_cmd': self._env.task.primitive.s_bit,
        }
        delta = np.zeros(3)

        if action is not None:
            if action == n_moves:  # activate suction, no movement
                new_action['suction_cmd'] = 1
            elif action == n_moves + 1:  # deactivate suction, no movement
                new_action['suction_cmd'] = 0
            else:  # pure translation; suction command left as is
                delta = self.base_actions[action]

        orientation = np.array([0, 0, 0, 1])
        ee_pose = self.prev_position + delta * self.action_scale
        ee_pose = np.clip(ee_pose, self.ee_bounds[:, 0], self.ee_bounds[:, 1])
        new_action['move_cmd'] = (ee_pose, orientation)

        _, _, done, info = self._env.step(new_action)
        obs = self._get_obs()
        reward = self.reward(obs)

        # Next move is relative to where the ee actually ended up.
        self.prev_position = self.get_postion(obs)

        return obs, reward, done, info

    def reward(self, obs):
        """Return the negative L2 distance between ``obs['observation']`` and ``obs['desired_goal']``."""
        return -np.linalg.norm(obs['observation'] - obs['desired_goal'])

    @property
    def state_space(self):
        """Space of state vectors; identical to ``goal_space``."""
        return self.goal_space

    @property
    def observation_space(self):
        """Dict space matching the keys returned by ``_get_obs``."""
        return Dict([
            ('observation', self.state_space),
            ('desired_goal', self.goal_space),
            ('achieved_goal', self.state_space),
            ('state_observation', self.state_space),
            ('state_desired_goal', self.goal_space),
            ('state_achieved_goal', self.state_space),
        ])

    def get_world_obs(self, ):
        """Return ``(state, goal, object_grabbed, suction_state)`` from the underlying env.

        The two flags are converted to ``int``.
        """
        obs = self._env._get_obs()
        return obs['state'], obs['goal'], int(obs['object_grabbed']), int(obs['suction_state'])

    def _get_obs(self, ):
        """Build the goal-env observation dict.

        The observation is ``[ee xyz, object_grabbed, suction_state, state[:3]]``.
        The goal is ``[goal[:3], 1, 0, goal[:3]]``: the goal position is used
        both as the end-effector target and as the object target, with the
        grabbed flag set to 1 and the suction flag set to 0.
        """
        world_obs, world_goal, object_grabbed, suction_state = self.get_world_obs()
        ee_obs = np.concatenate(self._env.get_ee_pose())
        obs = np.concatenate([ee_obs[:3], [object_grabbed, suction_state], world_obs[:3]])
        goal = np.concatenate([world_goal[:3], [1, 0], world_goal[:3]])

        return dict(
            observation=obs,
            desired_goal=goal,
            achieved_goal=obs,
            state_observation=obs,
            state_desired_goal=goal,
            state_achieved_goal=obs
        )

    def reset(self, poses=None):
        """Reset the underlying env and return the initial observation dict.

        Args:
            poses: optional dict forwarded to the underlying env's ``reset``.
                When it is non-empty, ``poses['goal'][0]`` is truncated to its
                first three entries before forwarding. The caller's dict is
                not modified. ``None`` (the default) forwards an empty dict.
        """
        # Work on a fresh dict: avoids the shared mutable default and keeps
        # the caller's `poses` untouched.
        poses = {} if poses is None else dict(poses)
        if poses:
            goal_pose = poses['goal'].copy()
            goal_pose[0] = goal_pose[0][:3]
            poses['goal'] = goal_pose
        self._env.reset(poses)
        obs = self._get_obs()
        self.prev_position = self.get_postion(obs)
        return obs

    def render_image(self):
        """Return an RGB frame from the underlying env."""
        return self._env.render(mode="rgb_array")

    def render(self):
        """Return an RGB frame from the underlying env (same as ``render_image``)."""
        return self._env.render(mode="rgb_array")

class RavensGoalEnvPickAndPlace(GymGoalEnvWrapper):
    """Goal-conditioned wrapper around ``RavensEnvPickAndPlace`` with a shaped distance.

    State and goal vectors follow the base env's 8-entry layout:
    ``[ee_x, ee_y, ee_z, object_grabbed, suction_state, obj_x, obj_y, obj_z]``.
    """

    def __init__(self,
                 disp=False,
                 shared_memory=False,
                 hz=240,
                 use_egl=False):
        """Create the base env and hand it to ``GymGoalEnvWrapper``."""
        env = RavensEnvPickAndPlace(disp, shared_memory, hz, use_egl)

        super().__init__(
            env, observation_key='observation', goal_key='achieved_goal', state_goal_key='state_achieved_goal'
        )

    def compute_success(self, achieved_state, goal):
        """Return whether the shaped distance between ``achieved_state`` and ``goal`` is below 0.05."""
        return self.compute_shaped_distance(achieved_state, goal) < 0.05

    def goal_distance(self, state, goal_state):
        """Return the shaped distance after extracting observation/goal vectors via the wrapper."""
        achieved_state = self.observation(state)
        goal = self.extract_goal(goal_state)
        return self.compute_shaped_distance(achieved_state, goal)

    def compute_shaped_distance(self, achieved_state, goal):
        """Shaped pick-and-place distance: reach the object, grab it, bring it to the goal.

        Sum of three terms:
        * L2 distance from the end effector (``[:3]``) to the object (``[-3:]``),
        * 1 if the object-grabbed flag (index 3) is 0, else 0,
        * L2 distance from the object (``[-3:]``) to the goal object position.

        Works on a single vector (returns a scalar) or on a batch of vectors
        with the feature dimension last (returns one distance per row).

        Raises:
            AssertionError: if ``achieved_state`` and ``goal`` shapes differ.
        """
        achieved_state = np.asarray(achieved_state)
        goal = np.asarray(goal)
        assert achieved_state.shape == goal.shape

        ee_pos = achieved_state[..., :3]
        obj_pos = achieved_state[..., -3:]
        distance_ee_obj = np.linalg.norm(ee_pos - obj_pos, axis=-1)
        distance_obj_goal = np.linalg.norm(obj_pos - goal[..., -3:], axis=-1)
        # Constant penalty until the object is held.
        not_grabbed = (achieved_state[..., 3] == 0).astype(int)
        return distance_ee_obj + not_grabbed + distance_obj_goal

    def get_shaped_distance(self, states, goal_states):
        """Return ``compute_shaped_distance`` for single vectors or batches."""
        return self.compute_shaped_distance(states, goal_states)

    def render_image(self):
        """Return an RGB frame from the wrapped base env."""
        return self.base_env.render_image()

    def get_diagnostics(self, trajectories, desired_goal_states):
        """
        Logs things

        Args:
            trajectories: Numpy Array [# Trajectories x Max Path Length x State Dim]
            desired_goal_states: Numpy Array [# Trajectories x State Dim]

        Returns:
            OrderedDict of summary statistics for the final-state plain L2
            distance and the final-state shaped distance to each goal.
        """
        final_pairs = [(trajectory[-1], goal)
                       for trajectory, goal in zip(trajectories, desired_goal_states)]

        # Plain Euclidean distance in the extracted observation/goal space
        # (previously this duplicated the shaped distance).
        euclidean_distances = np.array([
            np.linalg.norm(self.observation(state) - self.extract_goal(goal))
            for state, goal in final_pairs
        ])
        shaped_distances = np.array([self.goal_distance(state, goal)
                                     for state, goal in final_pairs])

        statistics = OrderedDict()
        for stat_name, stat in [
            ('final l2 distance', euclidean_distances),
            ('final shaped distance', shaped_distances),
        ]:
            statistics.update(create_stats_ordered_dict(
                stat_name,
                stat,
                always_show_all_stats=True,
            ))

        return statistics

