from telnetlib import IP
from dependencies.ravens.ravens.environments.environment import EnvironmentNoRotationsWithHeightmap
from dependencies.ravens.ravens.tasks.align_box_corner import AlignBoxCorner
from dependencies.ravens.ravens.tasks.stack_blocks import StackBlocks
from lexa_benchmark.envs.kitchen import KitchenEnv
from collections import OrderedDict
import numpy as np
from gym.spaces import Box, Dict
import mujoco_py

from multiworld.core.serializable import Serializable
from multiworld.envs.env_util import (
    get_stat_in_paths,
    create_stats_ordered_dict,
    get_asset_full_path,
)

from multiworld.envs.mujoco.mujoco_env import MujocoEnv
import copy

from multiworld.core.multitask_env import MultitaskEnv
from ravens.environments.environment import ContinuousEnvironment
import matplotlib.pyplot as plt
import os.path as osp
from huge.envs.gymenv_wrapper import GymGoalEnvWrapper
import numpy as np
import gym
import random
import itertools
from itertools import combinations
from gym import spaces
from huge.envs.env_utils import Discretized

import pybullet as p

from ravens.environments.environment import Environment



class RavensEnvStackBlockContinuous():
  """Dict-observation wrapper around the Ravens ``StackBlocks`` task.

  Two control modes are supported, selected by ``pickorplace``:

  * ``pickorplace=False`` (default): the end effector is moved in small
    Cartesian steps.  The action is a discrete index into ``base_actions``
    (a 3x3x3 grid of unit displacements); one extra index,
    ``len(base_actions)``, toggles the suction command instead of moving.
  * ``pickorplace=True``: the action is a continuous ``(x, y)`` location
    that is sent to the underlying environment as a pick/place pose.

  Observations are flat vectors laid out as::

      [ee_xyz (3), suction_state (1), grabbed one-hot (num_blocks),
       block_0 xyz, block_1 xyz, ...]

  and are returned inside a dict under the keys ``observation``,
  ``desired_goal``, ``achieved_goal`` and their ``state_*`` twins.
  """

  def __init__(self,
               disp=False,
               shared_memory=False,
               hz=240,
               use_egl=False,
               num_blocks=1,
               random_goal=True,
               pickorplace=False):
    """Build the underlying Ravens environment and the action/observation spaces.

    Args:
      disp, shared_memory, hz, use_egl: forwarded positionally to the
        underlying Ravens environment constructor.
      num_blocks: number of blocks in the task; fixes the observation size.
      random_goal: forwarded to the underlying environment.
      pickorplace: if True use the pick-or-place primitive environment with
        a 2-D continuous action space, otherwise the continuous environment
        with a discretized action space.
    """
    self.num_blocks = num_blocks
    self.pickorplace = pickorplace

    assets_root = "./dependencies/ravens/ravens/environments/assets/"
    task = StackBlocks(num_blocks=num_blocks, continuous=True, pick_or_place=pickorplace)

    env_cls = Environment if self.pickorplace else ContinuousEnvironment
    self._env = env_cls(assets_root,
                        task,
                        disp,
                        shared_memory,
                        hz,
                        use_egl,
                        random_goal=random_goal,
                        random_box_position=False)

    # TODO: adjust
    # TODO: how do I get the state of suction or not?
    # ee xyz + suction flag + grabbed one-hot + xyz per block (see _get_obs).
    obs_dim = 3 + 1 + self.num_blocks + self.num_blocks * 3
    obs_upper = 0.5 * np.ones(obs_dim)
    obs_lower = -0.5 * np.ones(obs_dim)
    self._observation_space = spaces.Box(obs_lower, obs_upper, dtype=np.float32)
    self.goal_space = spaces.Box(-obs_upper, obs_upper, dtype=np.float32)

    if self.pickorplace:
        # Continuous (x, y) target, kept `delta_margin` inside the workspace
        # rectangle x in [0.25, 0.75], y in [-0.5, 0.5].
        delta_margin = 0.05
        self.action_space = gym.spaces.Box(
            low=np.array([0.25 + delta_margin, -0.5 + delta_margin], dtype=np.float32),
            high=np.array([0.75 - delta_margin, 0.5 - delta_margin], dtype=np.float32),
            shape=(2,),
            dtype=np.float32
        )
        self.ee_init_pos = [0.4831041007489618, 0.029937637798535994, 0.34, 0, 0, 0, 1]
    else:
        # Discretize end-effector displacements only (no rotation).
        self.action_scale = 2./100
        granularity = 3
        n_dims = 3

        # Each axis runs from +1 down to -1.  The descending order is kept on
        # purpose: it is the order the action indices have always had, so
        # changing it would silently remap every discrete action.
        axis_values = [np.linspace(1.0, -1.0, granularity) for _ in range(n_dims)]
        actions_meshed = np.meshgrid(*axis_values)
        self.base_actions = np.array([grid.ravel() for grid in actions_meshed]).T
        assert len(self.base_actions) == granularity ** n_dims

        class Discretized(gym.spaces.Discrete):
            """Discrete space over the displacement grid plus one suction toggle."""

            def __init__(self, n, n_dims, granularity):
                self.n_dims = n_dims
                self.granularity = granularity
                # n - 1 because the last index is the suction toggle, not a move.
                assert n - 1 == granularity ** n_dims

                super(Discretized, self).__init__(n)

        # +1 corresponds to activate/deactivate suction
        self.action_space = Discretized(len(self.base_actions) + 1, n_dims=n_dims, granularity=granularity)

    # Per-axis [min, max] used to clip the commanded end-effector position.
    self.ee_bounds = np.array([[0.25, 0.75], [-0.5, 0.5], [0, 0.35]])

    self.reset()

  def get_postion(self, obs):
    """Return the first three entries (end-effector xyz) of ``obs['observation']``.

    The misspelled name is kept because it is the established public name.
    """
    return obs['observation'][:3]

  def get_position(self, obs):
    """Correctly spelled alias of :meth:`get_postion`."""
    return self.get_postion(obs)

  def _pick_or_place_command(self, action):
    """Translate an ``(x, y)`` action into the pick-or-place command dict."""
    xy = np.asarray(action, dtype=float).reshape(-1)
    pose = np.concatenate([xy, [0]])  # z is fixed to 0
    orientation = np.array([0, 0, 0, 1])
    return {
        # Pick when nothing is currently grasped, otherwise place.
        'pick_action': not self._env.ee.check_grasp(),
        'pose0': (pose, orientation),
    }

  def _continuous_command(self, action):
    """Translate a discrete action index into the continuous command dict."""
    holding = self._env.ee.check_grasp()
    if action == len(self.base_actions):
        # Extra index: stay in place and flip the suction command.
        delta = np.zeros(3)
        suction_cmd = not holding
    else:
        # Regular move: keep the suction command equal to the grasp state.
        delta = self.base_actions[action]
        suction_cmd = holding

    target = self.prev_position + delta * self.action_scale
    target = np.clip(target, self.ee_bounds[:, 0], self.ee_bounds[:, 1])
    orientation = np.array([0, 0, 0, 1])

    # TODO: make sure it is going to terminate if I have this
    return {
        'acts_left': 0,
        'suction_cmd': suction_cmd,
        'move_cmd': (target, orientation),
    }

  def step(self, action=None):
    """Apply ``action`` and return ``(obs, reward, done, info)``.

    Args:
      action: ``None`` to forward a no-command step to the underlying
        environment; otherwise an ``(x, y)`` pair in pick-or-place mode or a
        discrete index in ``[0, len(base_actions)]`` in continuous mode.

    Returns:
      The observation dict from :meth:`_get_obs`, the reward from
      :meth:`reward` (the underlying environment's reward is discarded), and
      the ``done`` / ``info`` values reported by the underlying environment.
    """
    if action is None:
        # NOTE(review): assumes the underlying env accepts step(None) as
        # "advance without a command", as upstream Ravens does -- confirm
        # against the vendored fork.
        new_action = None
    elif self.pickorplace:
        new_action = self._pick_or_place_command(action)
    else:
        new_action = self._continuous_command(action)

    _, _, done, info = self._env.step(new_action)
    obs = self._get_obs()
    reward = self.reward(obs)

    # Remember where the end effector ended up; the next relative move in
    # continuous mode is computed from this position.
    self.prev_position = self.get_postion(obs)

    return obs, reward, done, info

  def reward(self, obs):
    """Return the negative Euclidean distance between observation and desired goal."""
    achieved_state = obs['observation']
    goal_state = obs['desired_goal']
    return -np.linalg.norm(achieved_state - goal_state)

  @property
  def state_space(self):
    """Space of flat state vectors; identical to ``goal_space``."""
    return self.goal_space

  @property
  def observation_space(self):
    """Dict space with one entry per key returned by :meth:`_get_obs`."""
    return Dict([
        ('observation', self.state_space),
        ('desired_goal', self.goal_space),
        ('achieved_goal', self.state_space),
        ('state_observation', self.state_space),
        ('state_desired_goal', self.goal_space),
        ('state_achieved_goal', self.state_space),
    ])

  def get_world_obs(self, ):
    """Return ``(state, goal, object_grabbed, suction_state)`` from the underlying env.

    The two flags are converted to ``int``.
    """
    obs = self._env._get_obs()
    return obs['state'], obs['goal'], int(obs['object_grabbed']), int(obs['suction_state'])

  def _get_obs(self, ):
    """Assemble the flat observation and goal vectors and wrap them in a dict.

    Layout of both vectors: end-effector xyz, suction flag, one grabbed flag
    per block, then xyz of every block in order.
    """
    world_obs, world_goal, object_grabbed, suction_state = self.get_world_obs()
    ee_obs = np.concatenate(self._env.get_ee_pose())
    ee_pos = ee_obs[:3]

    # Block i occupies 7 consecutive entries of the world state; the first
    # three are taken as its position.
    obj_pos = np.array([world_obs[i * 7:i * 7 + 3] for i in range(self.num_blocks)])

    # The env only reports *whether* something is grabbed, so the grabbed
    # block is taken to be the one closest to the end effector.
    grabbed_emb = np.zeros(self.num_blocks)
    if object_grabbed == 1:
        closest_object_to_ee = np.argmin(np.linalg.norm(ee_pos - obj_pos, axis=-1))
        grabbed_emb[closest_object_to_ee] = 1

    obs = np.concatenate([ee_pos, [suction_state], grabbed_emb, np.concatenate(obj_pos)])

    # The goal reuses the same target xyz for the end effector and for every
    # block, with suction and grabbed flags all zero.
    # TODO: fix np ones
    target = world_goal[:3]
    goal = np.concatenate([
        target,
        [0],
        np.zeros(self.num_blocks),
        np.concatenate([target for _ in range(self.num_blocks)]),
    ])
    assert obs.shape == goal.shape
    return dict(
        observation=obs,
        desired_goal=goal,
        achieved_goal=obs,
        state_observation=obs,
        state_desired_goal=goal,
        state_achieved_goal=obs
    )

  def reset(self, poses=None):
    """Reset the underlying environment and return the first observation.

    Args:
      poses: optional dict.  When non-empty, ``poses['goal']`` is replaced by
        a copy whose first element is truncated to three entries.
        NOTE(review): ``poses`` is not forwarded to the underlying
        environment, so it currently has no effect on the reset itself.
    """
    if poses:
        set_pose = poses['goal'].copy()
        set_pose[0] = set_pose[0][:3]
        poses['goal'] = set_pose
    self._env.reset()
    obs = self._get_obs()
    self.prev_position = self.get_postion(obs)
    return obs

  def render_image(self):
    """Return the underlying environment's ``rgb_array`` rendering."""
    return self.render()

  def render(self):
    """Return the underlying environment's ``rgb_array`` rendering."""
    return self._env.render(mode="rgb_array")

class RavensGoalEnvStackBlockContinuous(GymGoalEnvWrapper):
    """Goal-env wrapper that adds success and shaped-distance metrics.

    State and goal vectors are expected in the layout produced by
    ``RavensEnvStackBlockContinuous._get_obs``::

        [ee_xyz (3), suction (1), grabbed flag per block (num_blocks),
         block_0 xyz, ..., block_{num_blocks-1} xyz]
    """

    def __init__(self,
               disp=False,
               shared_memory=False,
               hz=240,
               use_egl=False,
               num_blocks=1,
               random_goal=False,
               goal_threshold=0.05,
               pick_or_place=False,
               ):
        """Create the wrapped environment and register it with the base wrapper.

        Args:
            disp, shared_memory, hz, use_egl, num_blocks, random_goal:
                forwarded to ``RavensEnvStackBlockContinuous``.
            goal_threshold: distance under which a block counts as being at
                its goal in :meth:`compute_success`.
            pick_or_place: forwarded as ``pickorplace``.
        """
        env = RavensEnvStackBlockContinuous(
            disp,
            shared_memory,
            hz,
            use_egl,
            num_blocks=num_blocks,
            random_goal=random_goal,
            pickorplace=pick_or_place
        )

        self.goal_threshold = goal_threshold
        self.num_blocks = num_blocks

        super().__init__(
            env, observation_key='observation', goal_key='achieved_goal', state_goal_key='state_achieved_goal'
        )

    @staticmethod
    def _block_xyz(vector, i):
        """Return the xyz triple of the i-th block counted from the END of ``vector``."""
        stop = -3 * i if i else None
        return vector[-3 * i - 3:stop]

    def compute_success(self, achieved_state, goal):
        """Count the blocks lying within ``goal_threshold`` of their goal position."""
        success = 0
        for i in range(self.num_blocks):
            obj_pos = self._block_xyz(achieved_state, i)
            obj_goal = self._block_xyz(goal, i)
            success += np.linalg.norm(obj_pos - obj_goal) < self.goal_threshold
        return success

    def goal_distance(self, state, goal_state):
        """Return the shaped distance between ``state`` and ``goal_state``.

        Both arguments are first mapped through the wrapper's
        ``observation`` / ``extract_goal`` helpers.
        """
        # Uses distance in state_goal_key to determine distance (useful for images)
        achieved_state = self.observation(state)
        goal = self.extract_goal(goal_state)

        return self.compute_shaped_distance(achieved_state, goal)

    # TODO: write extract functions

    def compute_shaped_distance(self, achieved_state, goal):
        """Return a staged distance-to-go for the stacking task (lower is better).

        Blocks are processed one at a time, starting from the LAST block in
        the state vector.  A block that is at its goal (xy distance below
        0.05) and released is considered done and the next block is examined.
        The first unfinished block determines the value:

        * at its goal but still grabbed: ``stage_offset - bonus``;
        * otherwise: ee-to-block distance + 4 * |z - 0.05| + xy distance to
          goal + a grasp penalty + ``stage_offset``.

        ``stage_offset = 2 * bonus * (num_blocks - i)`` shrinks with every
        finished block, so progress through the stages always lowers the
        value.  Returns 0 when every block is done.
        """
        assert achieved_state.shape == goal.shape
        ee_pos = achieved_state[:3]
        bonus = self.num_blocks + 5
        grabbed_flags = achieved_state[4:4 + self.num_blocks]

        for i in range(self.num_blocks):
            obj_pos = self._block_xyz(achieved_state, i)
            obj_goal = self._block_xyz(goal, i)

            # Positions are read back-to-front, so stage i is block
            # (num_blocks - 1 - i); use that block's grabbed flag, not flag i.
            block_idx = self.num_blocks - 1 - i
            this_grabbed = grabbed_flags[block_idx]
            stage_offset = bonus * (self.num_blocks - i) * 2

            distance_obj_goal = np.linalg.norm(obj_pos[:2] - obj_goal[:2])

            if distance_obj_goal < 0.05:
                if this_grabbed == 0:
                    continue  # placed and released: move on to the next block
                return stage_offset - bonus  # placed, only the release is missing

            distance_ee_obj = np.linalg.norm(ee_pos - obj_pos)
            distance_above_ground = abs(obj_pos[2] - 0.05) * 4

            # 0 when exactly this block is held; +1 if it is not held and +1
            # for every other block that is held instead.
            others_grabbed = np.sum(grabbed_flags) - this_grabbed
            dist_grabbed_obj = 1 - this_grabbed + others_grabbed

            return distance_ee_obj + distance_above_ground + distance_obj_goal + dist_grabbed_obj + stage_offset
        return 0

    def get_shaped_distance(self, states, goal_states):
        """Alias of :meth:`compute_shaped_distance`."""
        return self.compute_shaped_distance(states, goal_states)

    def render_image(self):
        """Return the wrapped environment's rendered image."""
        return self.base_env.render_image()

    def get_diagnostics(self, trajectories, desired_goal_states):
        """
        Logs things
        Args:
            trajectories: Numpy Array [# Trajectories x Max Path Length x State Dim]
            desired_goal_states: Numpy Array [# Trajectories x State Dim]
        Returns:
            OrderedDict of summary statistics for the final-state L2 distance
            and the final-state shaped distance.
        """
        n_traj = trajectories.shape[0]
        final_states = [trajectories[i][-1] for i in range(n_traj)]
        goal_states = [desired_goal_states[i] for i in range(n_traj)]

        # Plain Euclidean distance between extracted observation and goal.
        euclidean_distances = np.array([
            np.linalg.norm(self.observation(s) - self.extract_goal(g))
            for s, g in zip(final_states, goal_states)
        ])
        shaped_distances = np.array([
            self.goal_distance(s, g) for s, g in zip(final_states, goal_states)
        ])

        statistics = OrderedDict()
        for stat_name, stat in [
            ('final l2 distance', euclidean_distances),
            ('final shaped distance', shaped_distances),
        ]:
            statistics.update(create_stats_ordered_dict(
                    stat_name,
                    stat,
                    always_show_all_stats=True,
                ))

        return statistics