from typing import Dict

import torch
from python_utils.transformations import (
    affine_inverse,
    affine_to_pose_euler_zyx,
    euler_zyx_to_quaternion,
    pose_to_affine,
    quaternion_to_euler_zyx,
    quaternion_to_rotation_matrix,
    rotation_matrix_to_quaternion,
)

from furniture_bench_api.skills.robot_skill import RobotSkill


class MoveLinearSkill(RobotSkill):
    """Skill that drives the end effector towards a goal pose with clamped steps.

    Each call to :meth:`compute_action` returns a delta action made of a
    position delta, an orientation delta (quaternion) and a gripper command,
    with the position and rotation parts limited by the ``vel`` input.
    """

    def __init__(self, tolerance: float, sampling_interval: float, max_steps: int = 50):
        """Store the tolerances and the default acceleration limits.

        Args:
            tolerance: Threshold forwarded to ``RobotSkill`` and mirrored in
                ``position_tolerance`` / ``orientation_tolerance``.
            sampling_interval: Forwarded unchanged to ``RobotSkill``.
            max_steps: Forwarded unchanged to ``RobotSkill``.
        """
        super().__init__(tolerance=tolerance, sampling_interval=sampling_interval, max_steps=max_steps)
        # NOTE(review): is_goal_reached reads self.tolerance (presumably set by
        # RobotSkill.__init__ -- confirm), not these two attributes.
        self.position_tolerance = tolerance
        self.orientation_tolerance = tolerance
        self.default_acc = torch.as_tensor([0.001, 0.005])

    def _resolve_goal_pose(self, current_pose: torch.Tensor, inputs: Dict[str, torch.Tensor]):
        """Return ``(goal_pose, should_end)`` for the ``"goal_pose"`` input.

        ``inputs["goal_pose"]`` may be a pose, or a callable taking
        ``current_pose`` and returning either a pose or a 2-tuple
        ``(pose, should_end)``. ``should_end`` is ``False`` whenever the
        input does not provide it.
        """
        goal = inputs["goal_pose"]
        if not callable(goal):
            return goal, False

        resolved = goal(current_pose)
        if isinstance(resolved, tuple):
            assert len(resolved) == 2
            return resolved
        return resolved, False

    def compute_action(self, current_pose: torch.Tensor, inputs: Dict[str, torch.Tensor]):
        """Compute the clamped delta action that moves towards the goal pose.

        Args:
            current_pose: Current state; the first 7 entries are used as the
                end-effector pose (3 position values followed by a quaternion).
            inputs: Must contain ``"goal_pose"`` (see
                :meth:`_resolve_goal_pose`). ``"vel"`` and ``"acc"`` are
                optional; missing entries are written into ``inputs`` with
                default values.

        Returns:
            The concatenation of position delta, orientation delta quaternion
            and the gripper action, after :meth:`normalize`.
        """
        if "vel" not in inputs:
            inputs["vel"] = torch.as_tensor([0.05, torch.pi / 4], device=current_pose.device)
        if "acc" not in inputs:
            # Forwarded to normalize(), which currently ignores it.
            inputs["acc"] = self.default_acc.to(current_pose.device)

        ee_pose = current_pose[:7]
        goal_pose, _ = self._resolve_goal_pose(current_pose, inputs)

        # Rotation taking the current orientation to the goal orientation,
        # expressed relative to the current orientation: R_cur^T @ R_goal.
        orientation_error = rotation_matrix_to_quaternion(
            quaternion_to_rotation_matrix(ee_pose[3:]).T @ quaternion_to_rotation_matrix(goal_pose[3:])
        )
        delta_pos = goal_pose[:3] - ee_pose[:3]

        orientation_correction = quaternion_to_euler_zyx(orientation_error)

        # Wrap any angle outside [-pi, pi] back into [-pi, pi). The modulus is
        # the full turn 2*pi; angles already in range are left untouched.
        full_turn = 2 * torch.pi
        orientation_correction = torch.where(
            orientation_correction.abs() > torch.pi,
            (orientation_correction + torch.pi) % full_turn - torch.pi,
            orientation_correction,
        )

        delta_quat = euler_zyx_to_quaternion(orientation_correction)
        delta_gripper = self.action_to_keep_gripper_state(current_pose=current_pose)

        delta_action = torch.concat((delta_pos, delta_quat, delta_gripper), dim=-1)

        return self.normalize(action=delta_action, vel=inputs["vel"], acc=inputs["acc"])

    def is_goal_reached(self, current_pose: torch.Tensor, inputs: Dict[str, torch.Tensor]):
        """Tell whether the end effector is within tolerance of the goal pose.

        Returns ``True`` immediately when a callable goal reports
        ``should_end``. Otherwise the goal pose is expressed relative to the
        current end-effector pose and the norms of its position part and of
        its remaining (Euler ZYX) part are compared against ``self.tolerance``
        and ``self.tolerance * 6`` respectively.
        """
        goal_pose, should_end = self._resolve_goal_pose(current_pose, inputs)
        if should_end:
            return True

        delta_pose = affine_to_pose_euler_zyx(
            affine_inverse(pose_to_affine(current_pose[:7])) @ pose_to_affine(goal_pose)
        )
        position_error = delta_pose[:3].norm()
        orientation_error = delta_pose[3:].norm()
        return position_error < self.tolerance and orientation_error < (self.tolerance * 6)

    def normalize(self, action: torch.Tensor, vel: torch.Tensor, acc: torch.Tensor):
        """Clamp the position and rotation parts of ``action``.

        Args:
            action: Delta action; ``action[:3]`` is the position delta and
                ``action[3:7]`` the rotation delta as a quaternion. Any
                further entries are copied through unchanged.
            vel: ``vel[0]`` caps the norm of the position delta and ``vel[1]``
                caps the norm of the rotation delta's Euler ZYX angles.
            acc: Unused.

        Returns:
            A clone of ``action`` with the clamped position and rotation.
        """
        # TODO: why not for sampling_interval?
        max_pos_change = vel[0]  #  * self.sampling_interval
        max_rot_change = vel[1]  #  * self.sampling_interval

        pos = action[:3]
        pos_norm = pos.norm()
        if pos_norm > max_pos_change:
            normalized_pos = pos / pos_norm * max_pos_change
        else:
            normalized_pos = pos

        # The rotation is clamped in Euler-angle space and converted back.
        rot_euler = quaternion_to_euler_zyx(action[3:7])
        rot_norm = rot_euler.norm()
        if rot_norm > max_rot_change:
            normalized_rot_euler = rot_euler / rot_norm * max_rot_change
        else:
            normalized_rot_euler = rot_euler
        normalized_rot = euler_zyx_to_quaternion(normalized_rot_euler)

        normalized_action = action.clone()
        normalized_action[:3] = normalized_pos
        normalized_action[3:7] = normalized_rot

        return normalized_action
