# isort: off
from typing import Callable, Optional, Tuple, Union
from furniture_bench_api.api.api_predicates import StateAssembledPredicate, StateTouchingPredicate
from furniture_bench_api.furniture_bench_environment import FurnitureBenchEnvironment
import torch

# isort: on
from python_utils.transformations import (
    euler_zyx_to_quaternion,
    absolute_to_relative,
    relative_to_absolute,
    quaternion_to_euler_zyx,
    pose_euler_zyx_to_quaternion,
    pose_quaternion_to_euler_zyx,
    euler_zyx_to_quaternion,
)

from furniture_bench_api.skills.gripper_skills import SetGripperStateSkill
from furniture_bench_api.skills.move_skills import MoveLinearSkill
from furniture_bench_api.skills.robot_skill import RobotSkill
from furniture_bench_api.utils.pose_utils import (
    transform_pose_in_world_coords,
    transform_pose_in_world_coords_by_pose,
)


class APISkills:
    """Named manipulation skills built on top of a ``FurnitureBenchEnvironment``.

    Every skill composes ``MoveLinearSkill`` / ``SetGripperStateSkill`` runs with
    environment queries (part poses, bounding boxes, assembly/touching predicates).
    """

    def __init__(self, furniture_bench_env: FurnitureBenchEnvironment, tolerance: float):
        """Store the environment and the tolerance forwarded to every low-level skill.

        Args:
            furniture_bench_env: Environment the skills act on and query.
            tolerance: Passed as ``tolerance`` to each ``MoveLinearSkill`` and
                ``SetGripperStateSkill`` constructed by this class.
        """
        self.furniture_bench_env = furniture_bench_env
        self.tolerance = tolerance

    def move_linear_relative_until_touching(
        self,
        touching_obj_a: str,
        touching_obj_b: str,
        *,
        x_mm=0,
        y_mm=0,
        z_mm=-150,
        in_world_coords: bool = True
    ):
        """Move by ``(x_mm, y_mm, z_mm)`` in small steps, stopping once two parts touch.

        The displacement is split into steps of at most 1 cm. Before each step the
        x/y components are corrected towards the "center" pose of ``touching_obj_b``.
        Each step is interrupted as soon as ``StateTouchingPredicate`` reports contact
        between the two parts. The loop ends on contact or once the accumulated step
        length reaches the requested distance.

        Args:
            touching_obj_a: First part name; must be in ``get_parts()``.
            touching_obj_b: Second part name; must be in ``get_parts()``. Its
                "center" pose is used for the x/y correction.
            x_mm, y_mm, z_mm: Requested relative displacement in millimetres.
            in_world_coords: Forwarded to ``move_linear_relative``.

        Raises:
            RuntimeError: If either object is not a part of the environment.
        """
        for obj in (touching_obj_a, touching_obj_b):
            if obj not in self.furniture_bench_env.get_parts():
                raise RuntimeError("Object %s is no part in the environment" % obj)

        relative_pos = torch.as_tensor([x_mm, y_mm, z_mm]) * 1e-3  # mm -> m
        relative_rot = torch.as_tensor([1, 0, 0, 0])
        target_distance = relative_pos.norm()
        if target_distance == 0:
            # Nothing to move; also avoids a 0/0 when normalising the direction below.
            return

        # One step of 1 cm along the requested direction, clamped so that a single
        # step never exceeds the requested distance.
        direction = relative_pos / target_distance * 1e-2
        if direction.norm() > target_distance:
            direction = direction * target_distance / direction.norm()
        travelled = torch.zeros(3, dtype=torch.float32)

        def is_touching() -> bool:
            touching, _ = StateTouchingPredicate().validate(
                env=self.furniture_bench_env, obj1=touching_obj_a, obj2=touching_obj_b
            )
            return touching

        while travelled.norm() < target_distance:
            step_pose = torch.concat((direction, relative_rot)).to(self.furniture_bench_env.device)
            # Re-centre x/y on part B before every step to correct lateral drift.
            goal_object_pose = self.furniture_bench_env.get_transformed_pose(touching_obj_b, pose="center")
            curr_pose = self.furniture_bench_env.get_current_pose()[:7]
            pose_correction = goal_object_pose - curr_pose
            step_pose[:2] += pose_correction[:2]

            step_mm = (step_pose[:3] / 1e-3).tolist()  # m -> mm
            self.move_linear_relative(
                x_mm=step_mm[0],
                y_mm=step_mm[1],
                z_mm=step_mm[2],
                in_world_coords=in_world_coords,
                should_stop_callback=is_touching,
                lin_acc=0.1,  # factor to change default acceleration
            )
            # NOTE: the x/y correction is counted towards the travelled distance.
            travelled += step_pose[:3].cpu()

            touching = is_touching()
            print("Is touching", touching)
            if touching:
                break

    def move_linear_up(self):
        """Lift the end effector; the lift height depends on the furniture and grasped part.

        Lifts 200 mm when holding "lamp_hood" (lamp) or "round_table_base"
        (round_table), 50 mm when holding "round_table_leg" (round_table), and
        120 mm otherwise.
        """
        env = self.furniture_bench_env
        needs_high_lift = (env.furniture == "lamp" and env.grasps_object("lamp_hood")) or (
            env.furniture == "round_table" and env.grasps_object("round_table_base")
        )
        if needs_high_lift:
            self.move_linear_relative(z_mm=200)
        elif env.furniture == "round_table" and env.grasps_object("round_table_leg"):
            self.move_linear_relative(z_mm=50)
        else:
            self.move_linear_relative(z_mm=120)

    def move_linear_relative(
        self,
        x_mm=0,
        y_mm=0,
        z_mm=0,
        *,
        in_world_coords: bool = True,
        should_stop_callback: Optional[Callable[[], bool]] = None,
        lin_vel=None,
        lin_acc=None,
    ):
        """Translate the end effector by ``(x_mm, y_mm, z_mm)`` millimetres.

        The relative rotation is ``[1, 0, 0, 0]``, presumably the identity quaternion
        in w-first order (TODO confirm), so the orientation is kept.

        Args:
            x_mm, y_mm, z_mm: Relative translation in millimetres.
            in_world_coords: If True the offset is applied with
                ``transform_pose_in_world_coords_by_pose``; otherwise with
                ``relative_to_absolute`` using the current pose as reference.
            should_stop_callback: Optional zero-argument predicate. When given, the
                goal is passed as a callable returning ``(goal_pose, should_stop_callback())``.
            lin_vel, lin_acc: Forwarded to ``move_linear``.
        """
        relative_pos = torch.as_tensor([x_mm, y_mm, z_mm]) * 1e-3  # mm -> m
        relative_rot = torch.as_tensor([1, 0, 0, 0])
        rel_pose = torch.concat((relative_pos, relative_rot)).to(self.furniture_bench_env.device)

        current_pose = self.furniture_bench_env.get_current_pose()[:7]
        if in_world_coords:
            goal_pose = transform_pose_in_world_coords_by_pose(current_pose, rel_pose)
        else:
            goal_pose = relative_to_absolute(rel_pose, reference=current_pose)

        if should_stop_callback is not None:

            def callback(current_pose: torch.Tensor):
                return goal_pose, should_stop_callback()

            self.move_linear(callback, lin_vel=lin_vel, lin_acc=lin_acc)
        else:
            self.move_linear(pose=goal_pose, lin_vel=lin_vel, lin_acc=lin_acc)

    def move_linear(
        self,
        pose: Union[
            Callable[[torch.Tensor], Union[torch.Tensor, Tuple[torch.Tensor, bool]]],
            torch.Tensor,
        ],
        *,
        max_steps: int = 50,
        lin_vel=None,
        lin_acc=None,
    ):
        """Run a ``MoveLinearSkill`` towards ``pose``.

        Args:
            pose: A fixed goal pose tensor, or a callable that receives the current
                pose. Callers in this class pass callables that return either just
                the goal pose or a ``(goal_pose, should_stop)`` tuple; the
                interpretation is left to ``MoveLinearSkill``.
            max_steps: Forwarded to ``MoveLinearSkill``.
            lin_vel: If given, sent as ``inputs["vel"] = [lin_vel, pi / 4]``.
            lin_acc: If given, a scale factor applied to the base acceleration
                ``[0.001, 0.005]``, sent as ``inputs["acc"]``.
        """
        inputs = {
            "goal_pose": pose,
        }
        if lin_vel is not None:
            inputs["vel"] = torch.as_tensor([lin_vel, torch.pi / 4], device=self.furniture_bench_env.device)
        if lin_acc is not None:
            inputs["acc"] = torch.as_tensor([0.001, 0.005], device=self.furniture_bench_env.device) * lin_acc
        MoveLinearSkill(
            tolerance=self.tolerance, sampling_interval=self.furniture_bench_env.sampling_interval, max_steps=max_steps
        ).run(
            env=self.furniture_bench_env,
            inputs=inputs,
        )

    def move_to_part_above(self, part_name: str):
        """Move to the "pre_grasp" pose of ``part_name``, raised to clear other geometry.

        The goal is recomputed on every control step. Its z is raised by two terms:
        the distance from the lowest point of the currently grasped part (if any)
        up to the end-effector height, and the amount by which any part assembled
        with ``part_name`` sticks out above ``part_name``'s top.

        Raises:
            RuntimeError: If ``part_name`` is currently grasped.
        """
        if self.furniture_bench_env.grasps_object(part_name):
            raise RuntimeError("Cannot move above %s as it is currently grasped" % part_name)

        # Imported lazily, presumably to avoid a circular import (TODO confirm).
        from furniture_bench_api.api.api_predicates_validator import supported_predicates

        def above_pose():
            env = self.furniture_bench_env
            pose_z = env.get_current_pose()[2]
            offset = 0
            for part in env.get_parts():
                if env.grasps_object(part):
                    part_min_z = env.get_object_bounding_box(part)[0][2]
                    # End-effector height above the held part's lowest point.
                    # NOTE(review): an older comment marked this as negative; the
                    # sign depends on the pose frame -- confirm.
                    offset = pose_z - part_min_z
                    break

            for part in env.get_parts():
                if part == part_name:
                    continue
                if supported_predicates["assembled"].validate(env, part_name, part)[0]:
                    other_part_max_z = env.get_object_bounding_box(part)[1][2]
                    part_max_z = env.get_object_bounding_box(part_name)[1][2]
                    if other_part_max_z > part_max_z:
                        offset += other_part_max_z - part_max_z

            pose = env.get_transformed_pose(part=part_name, pose="pre_grasp")
            pose[2] += offset
            return pose

        self.move_linear(pose=lambda _: above_pose(), max_steps=50)

    def move_to_part_center(self, part_name: str):
        """Move to the "grasp" pose of ``part_name``.

        Raises:
            RuntimeError: If the part has no "grasp" pose (lookup raises ``KeyError``).
        """
        try:
            pose = self.furniture_bench_env.get_transformed_pose(part=part_name, pose="grasp")
        except KeyError as err:
            raise RuntimeError("Part %s is fixed and cannot be moved." % part_name) from err

        self.move_linear(pose=pose)

    def _update_tool_offset(self, part: str) -> None:
        """Set ``T_ee_to_tool`` to ``part``'s "center_for_align" pose relative to the flange pose."""
        part_pose = self.furniture_bench_env.get_transformed_pose(part=part, pose="center_for_align")
        flange_pose = self.furniture_bench_env.get_current_pose(at_flange=True)[:7]
        self.furniture_bench_env.T_ee_to_tool = absolute_to_relative(part_pose, flange_pose)

    def align(self, part_a: str, part_b: str):
        """Align held ``part_a`` over the "center" of ``part_b`` at the starting height.

        The tool frame is set to ``part_a``'s "center_for_align" pose. The arm then
        moves twice towards ``part_b``'s center with orientation Euler ZYX
        ``(180, 0, current third angle)`` degrees and z fixed to the initial height.
        The tool offset is refined after each move.
        """
        self._update_tool_offset(part_a)

        # Read after the tool offset is set so the height refers to the new tool frame.
        init_pose = self.furniture_bench_env.get_current_pose()[:7]

        def get_pose(current_pose: torch.Tensor) -> torch.Tensor:
            ee_pose = current_pose[:7]

            goal_pose = self.furniture_bench_env.get_transformed_pose(part=part_b, pose="center")
            curr_euler = torch.rad2deg(quaternion_to_euler_zyx(ee_pose[3:]))
            goal_quat = euler_zyx_to_quaternion(
                torch.as_tensor([180, 0, curr_euler[2]], device=init_pose.device, dtype=torch.float32), degrees=True
            )
            goal_pose[3:] = goal_quat
            goal_pose[2] = init_pose[2]  # keep height
            return goal_pose

        for _ in range(2):
            self.move_linear(get_pose)
            # Refine the tool offset with the part's pose after the move.
            self._update_tool_offset(part_a)

    def _set_gripper_open(self, percent_open: int) -> None:
        """Run a ``SetGripperStateSkill`` with ``gripper_percent_open = [percent_open]``."""
        SetGripperStateSkill(
            tolerance=self.tolerance, sampling_interval=self.furniture_bench_env.sampling_interval
        ).run(
            env=self.furniture_bench_env,
            inputs={
                "gripper_percent_open": torch.as_tensor([percent_open], device=self.furniture_bench_env.device),
            },
        )

    def close_gripper(self):
        """Fully close the gripper and mark ``gripper_closed = True`` on the environment."""
        self._set_gripper_open(0)
        self.furniture_bench_env.gripper_closed = True

    def open_gripper(self):
        """Fully open the gripper and reset the tool offset.

        Sets ``T_ee_to_tool`` to ``[0, 0, 0, 1, 0, 0, 0]`` and ``gripper_closed`` to
        False. If an object was grasped before opening, the environment then idles
        for 50 steps.
        """
        holds_object = any(
            self.furniture_bench_env.grasps_object(obj) for obj in self.furniture_bench_env.get_objects()
        )

        self._set_gripper_open(1)

        self.furniture_bench_env.T_ee_to_tool = torch.as_tensor(
            [0, 0, 0, 1, 0, 0, 0], device=self.furniture_bench_env.device, dtype=torch.float32
        )
        self.furniture_bench_env.gripper_closed = False

        if holds_object:
            self.idle(steps=50)

    def rotate_arm_around_z(self, degrees, *, should_stop_callback: Optional[Callable] = None):
        """Rotate the current pose about z by ``degrees`` using ``transform_pose_in_world_coords``.

        Args:
            degrees: Rotation angle in degrees (number or scalar tensor).
            should_stop_callback: Optional zero-argument predicate. When given, it
                can end the motion early via the ``(goal, should_stop)`` callable form.
        """
        curr_pose = self.furniture_bench_env.get_current_pose()[:7]
        new_pose = transform_pose_in_world_coords(curr_pose, rz=torch.deg2rad(torch.as_tensor([degrees])).item())

        if should_stop_callback is not None:

            def callback(current_pose: torch.Tensor):
                return new_pose, should_stop_callback()

            self.move_linear(callback)
        else:
            self.move_linear(new_pose)

    @staticmethod
    def _joint_unwind_targets(joint_deg: torch.Tensor, step_deg: float = 30) -> list:
        """Return cumulative targets from 0 towards ``joint_deg`` in increments of at most ``step_deg``.

        Works for both signs, e.g. 75 -> [30, 60, 75] and -75 -> [-30, -60, -75].
        The final target is always ``joint_deg`` itself and is not duplicated.
        """
        # abs() is required: flooring a negative angle gives a negative count, which
        # would make range() empty and skip the incremental steps entirely.
        n_full_steps = int(joint_deg.abs() // step_deg)
        targets = [joint_deg.sign() * step_deg * (k + 1) for k in range(n_full_steps)]
        if not targets or bool(targets[-1] != joint_deg):
            targets.append(joint_deg)
        return targets

    def _unwind_wrist(self) -> None:
        """Lift 30 mm, rotate about z by ``-joint_6`` in <=30 degree increments, then lower 30 mm.

        Presumably intended to bring joint 6 back towards 0 degrees (assumes the tool
        axis is vertical -- TODO confirm). Splitting the rotation keeps each
        individual move small.
        """
        self.move_linear_relative(x_mm=0, y_mm=0, z_mm=30)
        joints = self.furniture_bench_env.get_observation()["robot_state"]["joint_positions"]
        joint_6 = torch.rad2deg(joints[0, 6])

        previous_target = 0
        for target in self._joint_unwind_targets(joint_6):
            self.rotate_arm_around_z(previous_target - target)
            previous_target = target
        self.move_linear_relative(x_mm=0, y_mm=0, z_mm=-30)

    def rotate_arm_until_screwed(self, part_being_held: str, part_on_table: str):
        """Screw ``part_being_held`` into ``part_on_table`` with up to 10 regrasp-and-turn cycles.

        Each cycle:
        1. Opens the gripper and approaches the held part.
        2. Winds back: on the first cycle the wrist is unwound; afterwards the arm
           rotates -90 degrees.
        3. Regrasps and rotates +90 degrees, stopping early once the "assembled"
           predicate holds.

        On success (except for "round_table_base"), the part is regrasped at its
        "grasp" position with the last Euler angle of the pose set to 90 degrees.

        Raises:
            RuntimeError: If ``part_on_table`` is currently grasped.
        """
        if self.furniture_bench_env.grasps_object(part_on_table):
            raise RuntimeError("Cannot rotate arm as %s is currently grasped" % part_on_table)

        from furniture_bench_api.api.api_predicates_validator import supported_predicates

        def is_assembled():
            assembled, _ = supported_predicates["assembled"].validate(
                self.furniture_bench_env, part_being_held, part_on_table
            )
            return assembled

        for attempt in range(10):
            self.open_gripper()

            # Only the pose lookup is guarded: a KeyError means the part defines no
            # "screwing_grasp_pre" pose, so approach directly instead.
            try:
                part_pre = self.furniture_bench_env.get_transformed_pose(
                    part=part_being_held, pose="screwing_grasp_pre"
                )
            except KeyError:
                self._approach_for_screwing(part_being_held)
                did_approach = True
            else:
                self.move_linear(pose=part_pre)
                did_approach = False

            if attempt == 0:
                self._unwind_wrist()
            else:
                self.rotate_arm_around_z(degrees=-90)

            if not did_approach:
                self._approach_for_screwing(part_being_held)
            self.close_gripper()
            self.rotate_arm_around_z(degrees=90, should_stop_callback=is_assembled)

            if is_assembled():
                print("Assembled")
                break

        if not is_assembled():
            print("Not assembled")
            return

        # reset gripper orientation
        if part_being_held == "round_table_base":
            return
        self.open_gripper()
        self.move_linear_relative(z_mm=10)
        # Regrasp at the part's "grasp" position with the last Euler angle set to 90 deg.
        approach_pose = self.furniture_bench_env.get_transformed_pose(part=part_being_held, pose="grasp")
        e_pose = pose_quaternion_to_euler_zyx(self.furniture_bench_env.get_current_pose()[:7])
        e_pose[-1] = torch.deg2rad(torch.as_tensor([90], device=e_pose.device))
        q_pose = pose_euler_zyx_to_quaternion(e_pose)
        q_pose[:3] = approach_pose[:3]
        self.move_linear(q_pose)
        self.close_gripper()

    def _approach_for_screwing(self, part_being_held: str):
        """Move to the part's "screwing_grasp" pose.

        If that pose is missing (``KeyError``), fall back to the part's "center" pose
        raised by 1 cm, with the orientation replaced by
        ``euler_zyx_to_quaternion(deg2rad([-180, 0, 54]))``.
        """
        try:
            part_center = self.furniture_bench_env.get_transformed_pose(part=part_being_held, pose="screwing_grasp")
        except KeyError:
            part_center = self.furniture_bench_env.get_transformed_pose(part=part_being_held, pose="center")

            part_center[2] += 0.01  # 1 cm higher
            part_center[3:] = euler_zyx_to_quaternion(
                torch.deg2rad(torch.as_tensor([-180, 0, 54], device=self.furniture_bench_env.device))
            )

        self.move_linear(pose=part_center)

    def idle(self, steps: int = 100):
        """Step the underlying env ``steps`` times with ``RobotSkill.default_action()``.

        Used, e.g., after releasing an object so moving parts can settle.
        """
        device = self.furniture_bench_env.get_current_pose().device
        for _ in range(steps):
            self.furniture_bench_env.env.step(RobotSkill.default_action().to(device))
