from rlf.envs.env_interface import EnvInterface, register_env_interface
from rlf.args import str2bool
import numpy as np
import rlf.rl.utils as rutils
import gym
from gym import core
from gym import spaces
import torch

from goal_prox.envs.goal_check import GoalCheckerWrapper

def widowx_obs_transform(obs):
    """Convert a raw WidowX observation into the compact "easy" layout.

    Raw layout (per the env's comments): agent_pos, agent_vel (together the
    first 16 entries), cube_position, ee_pos, left_finger_pos,
    right_finger_pos, is_grasped.

    Output layout: every entry from index 16 on (cube_position, ee_pos,
    left_finger_pos, right_finger_pos, is_grasped), followed by
    ``relative_pos = cube_position - ee_pos``.

    The feature axis is always the last one, so this works on a single
    observation (1-D) as well as on batches with any number of leading
    dimensions, for both NumPy arrays and torch tensors.

    Args:
        obs: ``np.ndarray`` or ``torch.Tensor`` with at least one dimension.

    Returns:
        An object of the same array type as ``obs`` whose last dimension is
        13 shorter (16 entries dropped, 3 appended).

    Raises:
        ValueError: if ``obs`` is 0-dimensional.
    """
    if len(obs.shape) < 1:
        raise ValueError(
            f"expected an observation with at least 1 dim, got shape {tuple(obs.shape)}"
        )

    # Drop agent_pos and agent_vel (the first 16 features).
    obs = obs[..., 16:]
    cube_pos = obs[..., 0:3]
    ee_pos = obs[..., 3:6]
    relative_pos = cube_pos - ee_pos

    # Keep the caller's array type: torch stays torch (also works on GPU),
    # NumPy stays NumPy.
    if isinstance(obs, torch.Tensor):
        return torch.cat([obs, relative_pos], dim=-1)
    return np.concatenate([obs, relative_pos], axis=-1)


class EasyObsWidowxWrapper(core.ObservationWrapper):
    """Compact WidowX observations: drop proprioception, append cube-to-EE offset.

    Observations are rewritten by ``widowx_obs_transform``. It drops the first
    16 entries (agent_pos, agent_vel) and appends
    ``relative_pos = cube_position - ee_pos``, producing::

        cube_position (3), ee_pos (3), left_finger_pos (3),
        right_finger_pos (3), is_grasped (1), relative_pos (3)

    ``observation_space`` is rebuilt with its bounds in exactly that order.
    """

    def __init__(self, env):
        super().__init__(env)
        obs_space = self.observation_space

        # Keep the declared bounds from index 16 on; indices 16:19 hold
        # cube_position and 19:22 hold ee_pos.
        high = obs_space.high[16:]
        low = obs_space.low[16:]

        # relative_pos = cube_position - ee_pos, bounded by interval
        # arithmetic on the cube and end-effector bounds.
        rel_high = obs_space.high[16:19] - obs_space.low[19:22]
        rel_low = obs_space.low[16:19] - obs_space.high[19:22]

        # The is_grasped flag is bounded to [0, 1]. It must come BEFORE
        # relative_pos, because widowx_obs_transform appends relative_pos
        # after the raw features (which end with is_grasped).
        # NOTE(review): appending this bound assumes the wrapped env's declared
        # observation_space omits the trailing is_grasped entry that its
        # observations carry; confirm against the env.
        high = np.concatenate([high, np.array([1.0]), rel_high])
        low = np.concatenate([low, np.array([0.0]), rel_low])

        self.observation_space = spaces.Box(
                high=high,
                low=low,
                dtype=obs_space.dtype)
        # Mirror the wrapped env's episode limit when it exposes one.
        try:
            self.max_episode_steps = env._max_episode_steps
        except AttributeError:
            pass

    def observation(self, obs):
        """Return ``widowx_obs_transform(obs)``."""
        return widowx_obs_transform(obs)



class WidowxInterface(EnvInterface):
    """Env interface that builds WidowX environments with ``gym.make``."""

    def create_from_id(self, env_id):
        """Build ``env_id``, forwarding ``self.args.widowx_constrained``.

        The flag is passed to ``gym.make`` as the
        ``constrained_action_space`` keyword.
        """
        return gym.make(
            env_id,
            constrained_action_space=self.args.widowx_constrained,
        )
    

class GoalWidowxInterface(WidowxInterface):
    """WidowX interface that adds a goal check and optional compact observations."""

    def env_trans_fn(self, env, set_eval):
        """Apply the parent's transforms, then the WidowX-specific wrappers.

        The result is always wrapped in ``GoalCheckerWrapper``. When
        ``self.args.widowx_easy_obs`` is truthy, it is additionally wrapped in
        ``EasyObsWidowxWrapper``.
        """
        base_env = super().env_trans_fn(env, set_eval)

        # NOTE(review): presumably invoked by GoalCheckerWrapper as
        # check_goal(env, obs) — confirm; ``obs`` is unused here.
        def check_goal(env, obs):
            # Goal = element [2] of cube.pose.p[0] reaches float_thresh
            # (presumably the cube's height — confirm) AND
            # agent.is_grasping(cube) holds.
            inner = env.env.env
            lifted = inner.cube.pose.p[0][2] >= inner.float_thresh
            return lifted & inner.agent.is_grasping(inner.cube)

        checked = GoalCheckerWrapper(base_env, check_goal)
        return EasyObsWidowxWrapper(checked) if self.args.widowx_easy_obs else checked

    def get_add_args(self, parser):
        """Register the WidowX flags on top of the parent's arguments.

        --widowx-constrained: forwarded to gym.make as constrained_action_space.
        --widowx-easy-obs: enables EasyObsWidowxWrapper in env_trans_fn.
        """
        super().get_add_args(parser)
        widowx_flags = (
            ('--widowx-constrained', False),
            ('--widowx-easy-obs', True),
        )
        for flag, default in widowx_flags:
            parser.add_argument(flag, type=str2bool, default=default)


# Pattern anchored at the start: targets env ids beginning with "WidowX"
# (assuming register_env_interface regex-matches env ids — confirm).
WIDOWX_REGISTER_STR = "^(WidowX)"
# Module-level side effect: registering GoalWidowxInterface at import time.
register_env_interface(WIDOWX_REGISTER_STR, GoalWidowxInterface)
