import logging
from typing import Dict

import torch
from PIL import Image
from python_utils.transformations import (
    affine_inverse,
    affine_to_pose,
    affine_transform,
    pose_inverse,
    pose_to_affine,
    quaternion_to_rotation_matrix,
    rotation_matrix_to_quaternion,
)

from furniture_bench_api.furniture_bench_environment import FurnitureBenchEnvironment
from furniture_bench_api.utils.pose_utils import alert_about_joint_limits, get_joints_from_obs


class RobotSkill:
    """Base class for closed-loop skills that move the robot tool towards a goal.

    Subclasses implement :meth:`compute_action` and :meth:`is_goal_reached`.
    :meth:`run` repeatedly queries them, converts the tool delta into an
    end-effector action, and steps the environment. It stops when the goal is
    reached or ``max_steps`` actions have been executed.

    Attributes:
        tolerance: Stored for subclasses; not read by this base class
            (presumably a goal tolerance used in ``is_goal_reached`` — confirm).
        sampling_interval: Stored for subclasses; not read by this base class.
        max_steps: Maximum number of environment steps taken by :meth:`run`.
    """

    # Call ``env.add_sample_to_data()`` after every step whose 0-based index is a
    # positive multiple of this value. It is a class attribute so subclasses can
    # override it without touching their constructor. Values <= 0 disable recording.
    add_data_every_steps: int = 5
    # File overwritten by :meth:`run` with a snapshot of ``color_image2`` before and
    # after the rollout (debugging aid).
    debug_image_path: str = "image2.png"

    def __init__(self, tolerance: float, sampling_interval: float, max_steps: int = 50):
        self.tolerance = tolerance
        self.sampling_interval = sampling_interval
        self.max_steps = max_steps

    @staticmethod
    def default_action() -> torch.Tensor:
        """Return a zero-motion float32 action.

        The layout is ``[dx, dy, dz, q0, q1, q2, q3, gripper]``: no translation,
        a quaternion with ``q0 = 1`` (presumably the identity rotation in
        scalar-first order — confirm), and gripper command 0.
        """
        return torch.as_tensor([0, 0, 0, 1, 0, 0, 0, 0], dtype=torch.float32)

    def compute_action(self, current_pose: torch.Tensor, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the next tool delta action; must be implemented by subclasses.

        :meth:`run` reads indices ``[:3]`` as a translation delta, ``[3:7]`` as a
        quaternion rotation delta and ``[7]`` as the gripper command.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def is_goal_reached(self, current_pose: torch.Tensor, inputs: Dict[str, torch.Tensor]) -> bool:
        """Return whether the skill's goal is satisfied; must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def action_to_keep_gripper_state(self, current_pose: torch.Tensor) -> torch.Tensor:
        """Return a one-element gripper command of ``0`` on ``current_pose``'s device.

        The value does not depend on the current gripper opening. Presumably the
        environment interprets 0 as "leave the gripper as it is" — confirm. The
        tensor is float32, like :meth:`default_action`, so it concatenates cleanly
        with the other float action components.
        """
        # NOTE: an earlier variant returned -1 (keep open) when current_pose[7] > 0.065
        # and 1 otherwise; a constant 0 is used instead.
        return torch.zeros(1, dtype=torch.float32, device=current_pose.device)

    def run(self, env: "FurnitureBenchEnvironment", inputs: Dict[str, torch.Tensor]):
        """Execute the skill in ``env`` until the goal is reached or ``max_steps`` run out.

        Each iteration does the following:

        - asks :meth:`compute_action` for a tool delta;
        - converts it with :meth:`_tool_delta_to_ee_action`;
        - passes the result to ``env.env.step``;
        - records a sample when :meth:`_should_record_sample` says so;
        - forwards the new joints to ``alert_about_joint_limits``.

        A snapshot of ``color_image2`` is written to ``debug_image_path`` before and
        after the rollout. The goal is also checked on the pose reached by the final
        step, so a goal hit on the last step is not reported as a failure.

        Args:
            env: Environment wrapper providing ``get_current_pose``,
                ``get_observation``, ``T_ee_to_tool``, ``add_sample_to_data`` and
                the underlying ``env.env.step``.
            inputs: Skill-specific inputs forwarded to ``compute_action`` and
                ``is_goal_reached``.

        Returns:
            The poses observed after each executed step, stacked with
            ``torch.stack``, or ``None`` if no step was executed.
        """
        print("%s: started" % self.__class__.__name__)
        current_pose = env.get_current_pose()
        trajectory = []

        self._save_debug_image(env)

        for step in range(self.max_steps):
            if self.is_goal_reached(current_pose=current_pose, inputs=inputs):
                finished = True
                break

            delta_action_tool = self.compute_action(current_pose=current_pose, inputs=inputs)
            action = self._tool_delta_to_ee_action(env, current_pose, delta_action_tool)
            new_obs = env.env.step(action=action)[0]

            if self._should_record_sample(step):
                env.add_sample_to_data()

            current_pose = env.get_current_pose()
            alert_about_joint_limits(joints=get_joints_from_obs(new_obs))
            trajectory.append(current_pose)
        else:
            # The loop only checks the goal *before* each step, so without this the
            # pose reached by the final step would never be evaluated.
            finished = self.is_goal_reached(current_pose=current_pose, inputs=inputs)

        if finished:
            print("%s: goal reached" % self.__class__.__name__)

        self._save_debug_image(env)
        if not finished:
            logging.warning("Skill did not finish")

        return torch.stack(trajectory) if trajectory else None

    @staticmethod
    def _tool_delta_to_ee_action(
        env: "FurnitureBenchEnvironment", current_pose: torch.Tensor, delta_action_tool: torch.Tensor
    ) -> torch.Tensor:
        """Convert a tool delta action into the end-effector action passed to ``env.env.step``.

        The result is ``[ee position delta (3), ee rotation delta (4, reordered), gripper]``
        with NaNs replaced via ``nan_to_num``.
        """
        gripper_action = delta_action_tool[7]

        # 1. Goal pose of the tool. The translation is added directly; the rotation
        #    delta is composed on the right of the current orientation.
        # NOTE(review): an earlier comment called the delta a world-frame delta that
        # "must apply before the tool_pose", which would be R_delta @ R_current. The
        # code computes R_current @ R_delta (a delta in the tool's own frame).
        # Behaviour is kept as-is — confirm the intended convention.
        goal_pos_tool = current_pose[:3] + delta_action_tool[:3]
        goal_rot_tool = rotation_matrix_to_quaternion(
            quaternion_to_rotation_matrix(current_pose[3:7]) @ quaternion_to_rotation_matrix(delta_action_tool[3:7])
        )
        goal_pose_tool = torch.concat((goal_pos_tool, goal_rot_tool))

        # 2. Goal and current end-effector poses, obtained by applying the inverse of
        #    env.T_ee_to_tool (computed once for both).
        inv_T_ee_to_tool = pose_inverse(env.T_ee_to_tool)
        goal_pose_ee = affine_transform(goal_pose_tool, inv_T_ee_to_tool)
        current_pose_ee = affine_transform(current_pose[:7], inv_T_ee_to_tool)

        # 3. End-effector delta: the position is a plain difference; the rotation is
        #    the goal relative to the current pose, inv(T_current) @ T_goal.
        delta_ee_pos = goal_pose_ee[:3] - current_pose_ee[:3]
        delta_ee_rot = affine_to_pose(
            affine_inverse(pose_to_affine(current_pose_ee[:7])) @ pose_to_affine(goal_pose_ee)
        )[3:]

        action = torch.concat((delta_ee_pos, delta_ee_rot, gripper_action[None]))
        # Move the first quaternion component to the end (presumably w,x,y,z -> x,y,z,w,
        # as the environment expects — confirm). Advanced indexing makes a copy of the
        # right-hand side, so this in-place assignment does not read overwritten values.
        action[3:7] = action[[4, 5, 6, 3]]
        return action.nan_to_num()

    def _should_record_sample(self, step: int) -> bool:
        """Return True if a data sample should be recorded after 0-based ``step``."""
        every = self.add_data_every_steps
        # Guard against a zero/negative interval (would raise ZeroDivisionError).
        return every > 0 and step != 0 and step % every == 0

    def _save_debug_image(self, env: "FurnitureBenchEnvironment") -> None:
        """Save the first ``color_image2`` frame of the current observation to ``debug_image_path``."""
        frame = env.get_observation()["color_image2"][0]
        Image.fromarray(frame.cpu().numpy()).save(self.debug_image_path)
