from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np


@dataclass
class ObsData:
    """Named view of the observation fields used by the scripted policies.

    Instances are built by ``parse_obs`` from a flat observation vector.
    """
    # Position the policies compare against their target waypoint
    # (presumably the end-effector position -- confirm against the env).
    current_pos: np.ndarray
    # Scalar gripper reading; some policies wait for it to fall below a
    # threshold before advancing (presumably smaller means more closed).
    current_gripper: float
    # Position of the first object in the observation.
    obj1_pos: np.ndarray
    # Position of the second object in the observation.
    obj2_pos: np.ndarray
    # Goal position, taken from the tail of the observation vector.
    goal_pos: np.ndarray


def parse_obs(obs: np.ndarray) -> ObsData:
    """Split a flat observation vector into named fields.

    The vector must hold exactly ``18 * 2 + 3`` entries. Only the first 18
    entries and the last 3 are read (presumably two stacked 18-d frames
    followed by a 3-d goal -- confirm against the environment).

    Args:
        obs: Flat observation of length 39.

    Returns:
        An ``ObsData`` whose array fields are slices of ``obs``.
    """
    assert len(obs) == 18 * 2 + 3

    # Everything except the goal is read from the leading 18-entry frame.
    frame = obs[:18]
    hand_xyz = frame[:3]
    gripper_reading = frame[3]
    first_obj_xyz = frame[4:7]
    # Counted from the end of the frame: entries 11..13 of an 18-entry frame.
    second_obj_xyz = frame[-7:-4]
    # The goal occupies the final three entries of the full vector.
    goal_xyz = obs[-3:]

    return ObsData(
        current_pos=hand_xyz,
        current_gripper=gripper_reading,
        obj1_pos=first_obj_xyz,
        obj2_pos=second_obj_xyz,
        goal_pos=goal_xyz,
    )


def diff_to_action(diff: np.ndarray, scale: float = 50.) -> np.ndarray:
    """Turn a position difference into a bounded action.

    Args:
        diff: Difference between a target and the current position.
        scale: Gain applied to ``diff`` before clipping.

    Returns:
        ``diff * scale`` with every component limited to ``[-1, 1]``.
    """
    scaled = diff * scale
    return np.clip(scaled, -1, 1)


class ScriptedPolicy:
    """Shared machinery for waypoint-following scripted policies.

    Subclasses supply ``get_target``. Calling the policy with an observation
    advances the waypoint index when the current target has been reached and
    returns an action that moves toward the (possibly new) target.
    """

    def __init__(self, precision: float = 0.05, velocity: float = 40.):
        """Store the tuning parameters and start at the first waypoint.

        Args:
            precision: Distance below which a waypoint counts as reached.
            velocity: Gain passed to ``diff_to_action`` as ``scale``.
        """
        self.target_waypoint_idx = 0
        self.velocity = velocity
        self.precision = precision

    def reset(self):
        """Return to the first waypoint."""
        self.target_waypoint_idx = 0

    def update_target_if_necessary(self, obs_data: ObsData):
        """Move on to the next waypoint once the current one is reached."""
        waypoint, _ = self.get_target(obs_data)
        remaining = np.linalg.norm(waypoint - obs_data.current_pos)
        if remaining < self.precision:
            self.target_waypoint_idx += 1

    def get_target(self, obs_data: ObsData) -> Tuple[np.ndarray, float]:
        """Return the target position and the gripper command.

        Must be overridden; the base implementation always raises.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def __call__(
        self,
        obs: np.ndarray,
        img: Optional[np.ndarray] = None,
        gripper_img: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """Compute the action for one observation.

        Args:
            obs: Flat observation vector, decoded with ``parse_obs``.
            img: Accepted but not used by this implementation.
            gripper_img: Accepted but not used by this implementation.

        Returns:
            The clipped movement components followed by the gripper command.
        """
        parsed = parse_obs(obs)
        self.update_target_if_necessary(parsed)
        # Query the target again: the waypoint index may just have advanced.
        waypoint, gripper_cmd = self.get_target(parsed)
        move_cmd = diff_to_action(
            waypoint - parsed.current_pos, scale=self.velocity)
        return np.concatenate([move_cmd, [gripper_cmd]])


class ReachPolicy(ScriptedPolicy):
    """Head straight for the goal position."""

    def get_target(self, obs_data: ObsData) -> Tuple[np.ndarray, float]:
        """Return the goal position with a gripper command of 0."""
        gripper_cmd = 0.
        return obs_data.goal_pos, gripper_cmd


class DetourPolicy(ScriptedPolicy):
    """Visit the first object, then continue to the goal."""

    def get_target(self, obs_data) -> Tuple[np.ndarray, float]:
        """Return the target for the current waypoint, gripper command 0.

        Raises:
            ValueError: if the waypoint index matches no branch
                (e.g. it is negative).
        """
        idx = self.target_waypoint_idx
        if idx == 0:
            return obs_data.obj1_pos, 0.
        if idx >= 1:
            # Stay on the goal for every later index.
            return obs_data.goal_pos, 0.
        raise ValueError(
            f'Unknown target waypoint index: {self.target_waypoint_idx}')


class HookPolicy(ScriptedPolicy):
    """Go above the object, descend onto it, then head for the goal."""

    def get_target(self, obs_data):
        """Return the target for the current waypoint, gripper command 0.

        Raises:
            ValueError: if the waypoint index matches no branch
                (e.g. it is negative).
        """
        idx = self.target_waypoint_idx
        if idx == 0:
            # First waypoint: object position with 0.1 added to the third axis.
            hover_pos = obs_data.obj1_pos + np.array([0, 0, 0.1])
            return hover_pos, 0.
        if idx == 1:
            return obs_data.obj1_pos, 0.
        if idx >= 2:
            # Stay on the goal for every later index.
            return obs_data.goal_pos, 0.
        raise ValueError(
            f'Unknown target waypoint index: {self.target_waypoint_idx}')


class ButtonPressTopdownPolicy(ScriptedPolicy):
    """Scripted policy for the top-down button press task."""

    def get_target(self, obs_data):
        """Return the target position and gripper command for the waypoint.

        Raises:
            ValueError: if the waypoint index matches no branch
                (e.g. it is negative).
        """
        idx = self.target_waypoint_idx
        if idx == 0:
            # Approach point: object position with 0.1 added to the third
            # axis, gripper command 0.
            approach_pos = obs_data.obj1_pos + np.array([0, 0, 0.1])
            return approach_pos, 0.
        if idx >= 1:
            # Then target the object itself with gripper command 1.
            return obs_data.obj1_pos, 1.
        raise ValueError(
            f'Unknown target waypoint index: {self.target_waypoint_idx}')


class LeverPullPolicy(ScriptedPolicy):
    """Scripted policy for the lever pull task."""

    def __init__(self):
        """Use a tighter precision (0.02) and a velocity of 30."""
        super().__init__(velocity=30, precision=0.02)

    def get_target(self, obs_data) -> Tuple[np.ndarray, float]:
        """Return the target for the current waypoint, gripper command 0.

        Raises:
            ValueError: if the waypoint index matches no branch
                (e.g. it is negative).
        """
        idx = self.target_waypoint_idx
        if idx == 0:
            # First waypoint is offset from the object by (0, -0.1, -0.12).
            start_pos = obs_data.obj1_pos + np.array([0., -0.1, -0.12])
            return start_pos, 0.
        if idx >= 1:
            # Stay on the goal for every later index.
            return obs_data.goal_pos, 0.
        raise ValueError(
            f'Unknown target waypoint index: {self.target_waypoint_idx}')


class PegInsertSidePolicy(ScriptedPolicy):
    """Policy specialized for peg insertion task.

    Waypoints: move to the object, command the gripper to 1 at the object,
    move to an offset position beside the goal, then to the final position
    just above the goal. The final waypoint is held once reached.
    """

    def update_target_if_necessary(self, obs_data: ObsData):
        """Advance the waypoint index once the current target is reached.

        Waypoint 1 additionally waits until the gripper reading drops below
        0.3 (presumably meaning the grasp has closed -- confirm against the
        environment) before advancing.
        """
        target_pos, _ = self.get_target(obs_data)
        distance = np.linalg.norm(target_pos - obs_data.current_pos)
        if distance < self.precision:
            if self.target_waypoint_idx == 1:
                if obs_data.current_gripper < 0.3:
                    self.target_waypoint_idx += 1
            else:
                self.target_waypoint_idx += 1

    def get_target(self, obs_data) -> Tuple[np.ndarray, float]:
        """Return the target position and gripper command for the waypoint.

        Raises:
            ValueError: if the waypoint index matches no branch
                (e.g. it is negative).
        """
        if self.target_waypoint_idx == 0:
            return obs_data.obj1_pos, 0.
        elif self.target_waypoint_idx == 1:
            return obs_data.obj1_pos, 1.
        elif self.target_waypoint_idx == 2:
            return obs_data.goal_pos + np.array([0.2, 0, 0.03]), 1.
        elif self.target_waypoint_idx >= 3:
            # Hold the last waypoint. The index keeps growing while the
            # target stays within precision, so an exact `== 3` match would
            # raise on the very call that reaches the final position.
            return obs_data.goal_pos + np.array([0, 0, 0.03]), 1.
        else:
            raise ValueError(
                f'Unknown target waypoint index: {self.target_waypoint_idx}')


class PushPolicy(ScriptedPolicy):
    """Pinch and push the object to the goal position.

    Waypoints: move to the object, command the gripper to 1 at the object,
    then move to the goal. The final waypoint is held once reached.
    """

    def __init__(self, precision: float = 0.05, velocity: float = 30.):
        """Same parameters as the base class, with a default velocity of 30."""
        super().__init__(precision, velocity)

    def update_target_if_necessary(self, obs_data: ObsData):
        """Advance the waypoint index once the current target is reached.

        Waypoint 1 additionally waits until the gripper reading drops below
        0.5 (presumably meaning the pinch has closed -- confirm against the
        environment) before advancing.
        """
        target_pos, _ = self.get_target(obs_data)
        distance = np.linalg.norm(target_pos - obs_data.current_pos)
        if distance < self.precision:
            if self.target_waypoint_idx == 1:
                if obs_data.current_gripper < 0.5:
                    self.target_waypoint_idx += 1
            else:
                self.target_waypoint_idx += 1

    def get_target(self, obs_data) -> Tuple[np.ndarray, float]:
        """Return the target position and gripper command for the waypoint.

        Raises:
            ValueError: if the waypoint index matches no branch
                (e.g. it is negative).
        """
        if self.target_waypoint_idx == 0:
            return obs_data.obj1_pos, 0.
        elif self.target_waypoint_idx == 1:
            return obs_data.obj1_pos, 1.
        elif self.target_waypoint_idx >= 2:
            # Hold the last waypoint. The index keeps growing while the goal
            # stays within precision, so an exact `== 2` match would raise
            # on the very call that reaches the goal.
            return obs_data.goal_pos, 1.
        else:
            raise ValueError(
                f'Unknown target waypoint index: {self.target_waypoint_idx}')


class PickPlacePolicy(ScriptedPolicy):
    """Pick, move up, and get to the goal position horizontally (y direction).

    Waypoints: move to the object, command the gripper to 1 at the object,
    move to a point offset from the goal by ``-y_margin`` in y and 0.07 in
    the third axis, then to the point 0.07 above the goal. The final
    waypoint is held once reached.
    """

    def __init__(self,
                 precision: float = 0.05,
                 velocity: float = 30.,
                 gripper_interval: float = 0.5,
                 y_margin: float = 0.15):
        """Configure the policy.

        Args:
            precision: Distance below which a waypoint counts as reached.
            velocity: Gain used when converting differences to actions.
            gripper_interval: Waypoint 1 only advances once the gripper
                reading is below this value.
            y_margin: Offset subtracted from the goal's y coordinate for
                the intermediate approach waypoint.
        """
        super().__init__(precision, velocity)
        self.gripper_interval = gripper_interval
        self.y_margin = y_margin

    def update_target_if_necessary(self, obs_data: ObsData):
        """Advance the waypoint index once the current target is reached.

        Waypoint 1 additionally waits until the gripper reading drops below
        ``gripper_interval`` before advancing.
        """
        target_pos, _ = self.get_target(obs_data)
        distance = np.linalg.norm(target_pos - obs_data.current_pos)
        if distance < self.precision:
            if self.target_waypoint_idx == 1:
                if obs_data.current_gripper < self.gripper_interval:
                    self.target_waypoint_idx += 1
            else:
                self.target_waypoint_idx += 1

    def get_target(self, obs_data) -> Tuple[np.ndarray, float]:
        """Return the target position and gripper command for the waypoint.

        Raises:
            ValueError: if the waypoint index matches no branch
                (e.g. it is negative).
        """
        if self.target_waypoint_idx == 0:
            return obs_data.obj1_pos, 0.
        elif self.target_waypoint_idx == 1:
            return obs_data.obj1_pos, 1.
        elif self.target_waypoint_idx == 2:
            return obs_data.goal_pos + np.array([0, -self.y_margin, 0.07]), 1.
        elif self.target_waypoint_idx >= 3:
            # Hold the last waypoint. The index keeps growing while the
            # target stays within precision, so an exact `== 3` match would
            # raise on the very call that reaches the final position.
            return obs_data.goal_pos + np.array([0, 0., 0.07]), 1.
        else:
            raise ValueError(
                f'Unknown target waypoint index: {self.target_waypoint_idx}')


def get_scripted_policy(task_name: str) -> ScriptedPolicy:
    """Build the scripted policy registered for ``task_name``.

    Args:
        task_name: Task identifier such as ``'reach-v2'``.

    Returns:
        A freshly constructed policy instance for that task.

    Raises:
        ValueError: if no policy is registered for ``task_name``.
    """
    # (task name, zero-argument factory) pairs, checked in order. Only the
    # factory of the matching entry is ever called.
    registry = (
        ('reach-v2', ReachPolicy),
        ('door-close-v2', ReachPolicy),
        ('door-open-v2', HookPolicy),
        ('drawer-close-v2', ReachPolicy),
        ('drawer-open-v2', HookPolicy),
        ('window-open-v2', DetourPolicy),  # can be better
        ('button-press-topdown-v2', ButtonPressTopdownPolicy),
        ('lever-pull-v2', LeverPullPolicy),  # a bit unstable
        ('peg-insert-side-v2', PegInsertSidePolicy),
        ('push-v2', lambda: PushPolicy(velocity=25)),
        ('sweep-v2', PushPolicy),
        ('sweep-into-v2', PushPolicy),
        # NOTE(review): pick-place uses PushPolicy, not PickPlacePolicy --
        # confirm this is intentional.
        ('pick-place-v2', PushPolicy),
        ('basketball-v2',
         lambda: PickPlacePolicy(gripper_interval=0.7, precision=0.02)),
        ('shelf-place-v2',
         lambda: PickPlacePolicy(gripper_interval=0.3, y_margin=0.2)),
    )
    for registered_name, build in registry:
        if task_name == registered_name:
            return build()
    raise ValueError(f'Unknown task name: {task_name}')
