import numpy as np
import gym
from gym import spaces
import time

from manipulator_learning.sim.envs.thing_generic import ThingEnv


class ThingPushingGeneric(ThingEnv):
    """Common base for the pushing environments defined in this file.

    Builds a ``ThingEnv`` with the ``'thing_rod'`` robot configuration and adds
    the reward / termination rule used by the pushing tasks: the episode
    succeeds once the block has stayed within ``reach_radius`` of the goal for
    longer than ``reach_radius_time``.
    """

    def __init__(self, task, camera_in_state, dense_reward, poses_ref_frame,
                 state_data=('pos', 'obj_pos', 'obj_rot_z_first_only'),
                 max_real_time=5, n_substeps=10, reach_radius=.085,
                 gap_between_prev_pos=.2, image_width=160, image_height=120,
                 moving_base=False, t_vel_limit=.3, r_vel_limit=1.5, control_frame='b', **kwargs):
        """Construct the underlying ``ThingEnv`` and store pushing-specific settings.

        ``task``, ``camera_in_state``, ``dense_reward``, ``poses_ref_frame``,
        ``state_data``, ``max_real_time``, ``n_substeps``,
        ``gap_between_prev_pos``, ``image_width``, ``image_height``,
        ``moving_base``, ``control_frame`` and any extra ``kwargs`` are
        forwarded unchanged to ``ThingEnv.__init__``.

        ``reach_radius`` is the block-to-goal distance below which the success
        timer runs. ``t_vel_limit`` and ``r_vel_limit`` are only stored on the
        instance here (presumably translational / rotational velocity limits
        consumed by the base class -- confirm against ``ThingEnv``).
        """
        super().__init__(
            task,
            'thing_rod',
            camera_in_state,
            dense_reward,
            False,
            poses_ref_frame,
            state_data,
            max_real_time=max_real_time,
            n_substeps=n_substeps,
            gap_between_prev_pos=gap_between_prev_pos,
            image_width=image_width,
            image_height=image_height,
            moving_base=moving_base,
            control_frame=control_frame,
            **kwargs)

        # Success criterion: block must remain inside reach_radius of the goal
        # for more than reach_radius_time before the episode counts as solved.
        self.reach_radius = reach_radius
        self.reach_radius_time = .5
        self.reach_radius_start_time = None  # timestep at which the block entered the radius
        self.in_reach_radius = False  # not read anywhere in this class

        # Motion limits stored for use outside this class.
        self.t_vel_limit = t_vel_limit
        self.r_vel_limit = r_vel_limit
        self.pos_limits = [[.55, -.45, .64], [1.05, .05, 1.0]]

    def _calculate_reward_and_done(self, dense_reward, limit_reached, limits_cause_failure=False):
        """Compute the step reward and the success / failure flags.

        Returns a 3-tuple ``(reward, done_success, done_failure)``. When
        ``dense_reward`` is falsy the first element is ``done_success`` itself
        rather than the shaped reward.

        ``done_failure`` is set only when both ``limits_cause_failure`` and
        ``limit_reached`` are truthy; in that case ``done_success`` is forced
        to ``False``.
        """
        pb_client = self.env._pb_client
        block_pose = pb_client.getBasePositionAndOrientation(self.env.block_ids[0])

        # 'coaster' tasks use a dedicated goal body; otherwise the second block is the goal.
        goal_body = self.env.goal_id if 'coaster' in self.task else self.env.block_ids[1]
        goal_pose = pb_client.getBasePositionAndOrientation(goal_body)

        manipulator = self.env.gripper.manipulator
        ee_pose_world = manipulator.get_link_pose(manipulator._tool_link_ind, ref_frame_index=None)

        block_xyz = np.array(block_pose[0])
        # First three entries of the tool-link pose are used as its position.
        block_ee_dist = np.linalg.norm(block_xyz - np.array(ee_pose_world[:3]))
        block_goal_dist = np.linalg.norm(block_xyz - np.array(goal_pose[0]))

        # Shaped reward: block-to-goal term is weighted 3x the block-to-end-effector term.
        reward = 3 * (1 - np.tanh(10.0 * block_goal_dist)) + 1 - np.tanh(10.0 * block_ee_dist)

        done_success = False
        within_radius = block_goal_dist < self.reach_radius
        if not within_radius:
            # Leaving the radius resets the hold timer.
            self.reach_radius_start_time = None
        elif self.reach_radius_start_time is None:
            # Just entered the radius: start timing from this timestep.
            self.reach_radius_start_time = self.ep_timesteps
        else:
            held_for = (self.ep_timesteps - self.reach_radius_start_time) * self.real_t_per_ts
            if held_for > self.reach_radius_time:
                done_success = True

        done_failure = False
        if limits_cause_failure and limit_reached:
            # A limit violation overrides any success on the same step.
            done_failure, done_success = True, False

        step_reward = reward if dense_reward else done_success
        return step_reward, done_success, done_failure


class ThingPushingXYState(ThingPushingGeneric):
    """'pushing_xy' task with a flat state observation and a 2-element action.

    Declares an 8-element unbounded float observation space and a 2-element
    action space bounded to [-1, 1], then configures the generic pushing
    environment without camera data and with the ``'w'`` pose reference frame.
    """

    def __init__(self, max_real_time=7, n_substeps=10, dense_reward=True, **kwargs):
        action_dim, state_dim = 2, 8

        # Spaces are assigned before the base constructor runs.
        self.action_space = spaces.Box(-1, 1, (action_dim,), dtype=np.float32)
        self.observation_space = spaces.Box(-np.inf, np.inf, (state_dim,), dtype=np.float32)

        super().__init__(
            'pushing_xy',   # task
            False,          # camera_in_state
            dense_reward,
            'w',            # poses_ref_frame
            max_real_time=max_real_time,
            n_substeps=n_substeps,
            **kwargs)

class ThingPushingXYImage(ThingPushingGeneric):
    """'pushing_xy' task whose observation includes camera data.

    The observation space is a dict with a 10-element float vector (``'obs'``),
    a colour image (``'img'``) and a depth image (``'depth'``). The action
    space is a one-element ``Tuple`` wrapping a 2-element box bounded to
    [-1, 1]. The generic pushing environment is configured with the camera
    enabled and the ``'b'`` pose reference frame.
    """

    def __init__(self, max_real_time=7, n_substeps=10, dense_reward=True, **kwargs):
        # The trailing comma matters: without it the parentheses are plain
        # grouping and a bare Box (not an iterable of spaces) reaches Tuple().
        self.action_space = spaces.Tuple((
            spaces.Box(-1, 1, (2,), dtype=np.float32),
        ))
        # NOTE(review): the image shapes below are (160, 120[, 3]) while
        # ThingPushing6DofMultiview in this file declares
        # (image_height, image_width[, 3]); with the inherited defaults
        # image_width=160 / image_height=120 these look transposed -- confirm
        # against the actual rendered frames before relying on them.
        self.observation_space = spaces.Dict({
            'obs': spaces.Box(-np.inf, np.inf, (10,), dtype=np.float32),
            'img': spaces.Box(0, 255, (160, 120, 3), dtype=np.uint8),
            'depth': spaces.Box(0, 255, (160, 120), dtype=np.uint8),
        })

        # NOTE(review): ('prev_pos') is a plain string, not a 1-tuple. It is
        # left unchanged because how ThingEnv consumes state_data (membership
        # test vs. iteration) is not visible here -- confirm intent.
        super().__init__('pushing_xy', True, dense_reward, 'b', ('prev_pos'), max_real_time=max_real_time,
                         n_substeps=n_substeps, **kwargs)

class ThingPushing6DofMultiview(ThingPushingGeneric):
    """'pushing_6dof_coaster' task with a 6-element action and image observations.

    The observation space is a dict holding a 7-element float vector, a colour
    image of shape ``(image_height, image_width, 3)`` and a float depth image
    of shape ``(image_height, image_width)`` bounded to [0, 1]. The generic
    pushing environment is configured with the camera enabled, a moving base,
    the ``'b'`` pose reference frame and a tighter ``reach_radius`` of 0.033.
    """

    def __init__(self, max_real_time=12, n_substeps=10, dense_reward=True,
                 image_width=64, image_height=48, **kwargs):
        frame_shape = (image_height, image_width)

        self.action_space = spaces.Box(-1, 1, (6,), dtype=np.float32)

        obs_components = {
            'obs': spaces.Box(-np.inf, np.inf, (7,), dtype=np.float32),
            'img': spaces.Box(0, 255, frame_shape + (3,), dtype=np.uint8),
            'depth': spaces.Box(0, 1, frame_shape, dtype=np.float32)
        }
        self.observation_space = spaces.Dict(obs_components)

        # NOTE(review): ('pos') is a plain string, not a 1-tuple -- confirm
        # that this is what ThingEnv expects for state_data.
        super().__init__(
            'pushing_6dof_coaster',
            True,
            dense_reward,
            'b',
            state_data=('pos'),
            max_real_time=max_real_time,
            n_substeps=n_substeps,
            image_width=image_width,
            image_height=image_height,
            moving_base=True,
            reach_radius=.033,
            **kwargs)

        # Override the position limits set by the generic pushing constructor.
        self.pos_limits = [[.55, -.45, .64], [.9, .15, 0.8]]