import logging
from typing import Optional

import torch
from python_utils.transformations import (
    absolute_to_relative,
    relative_to_absolute,
    affine_transform,
    euler_zyx_to_quaternion,
    pose_euler_zyx_to_quaternion,
    quaternion_to_rotation_matrix,
    rotation_matrix_to_quaternion,
)


def get_pose_from_obs(obs) -> torch.Tensor:
    """Return the first environment's end-effector pose followed by its gripper width.

    Reads ``ee_pos``, ``ee_quat`` and ``gripper_width`` from ``obs["robot_state"]`` and keeps
    only batch index 0. The quaternion's last component is moved to the front (presumably
    xyzw -> wxyz), so the result is presumably ``[x, y, z, qw, qx, qy, qz, gripper_width]``.
    """
    state = obs["robot_state"]
    # Move the last quaternion component to the front.
    quat_w_first = state["ee_quat"][:, [3, 0, 1, 2]]
    first_pose = torch.cat((state["ee_pos"], quat_w_first), dim=-1)[0]
    first_width = state["gripper_width"][0]
    return torch.cat([first_pose, first_width])


def get_joints_from_obs(obs) -> torch.Tensor:
    """Return ``obs["robot_state"]["joint_positions"]`` unchanged (no copy, no batch indexing)."""
    return obs["robot_state"]["joint_positions"]


def relative_to_base(pose: torch.Tensor, fb_env, inverse: bool = False) -> torch.Tensor:
    """Convert ``pose`` between world coordinates and the robot-base frame.

    Part poses are reported in world coordinates, while actions and the robot end-effector pose
    are in the robot's base frame. The reference pose is built from the base position(s) read from
    ``fb_env.env.env.env.rb_states`` with an identity orientation.

    Args:
        pose: Pose tensor(s) with 7 values in the last dimension (position + quaternion).
        fb_env: Wrapped environment; ``fb_env.env.env.env`` must expose ``rb_states`` and ``base_idxs``.
        inverse: If False (default), map world -> base via ``absolute_to_relative``; if True, map
            base -> world via ``relative_to_absolute``.
    """
    sim = fb_env.env.env.env
    base_xyz = sim.rb_states[sim.base_idxs, :3]

    reference = torch.zeros((base_xyz.shape[0], 7), device=pose.device)
    reference[..., :3] = base_xyz
    # Only the base translation is used; rotation is ignored in furniture bench
    # https://github.com/clvrai/furniture-bench/blob/a87960a2490a7a19c86ea16a5ef371718b2baa70/furniture_bench/envs/furniture_sim_env.py#L1003
    # Index 3 is the first quaternion component; 1 with zeros elsewhere is an identity rotation
    # (presumably w-first order).
    reference[..., 3] = 1

    if not inverse:
        return absolute_to_relative(pose=pose, reference=reference)
    return relative_to_absolute(pose, reference=reference)


def get_part_poses_from_obs(obs, env):
    """Return every furniture part's pose in the robot-base frame, keyed by part name.

    Args:
        obs: Observation dict. ``obs["parts_poses"]`` is reshaped to ``(batch, num_parts, 7)``
            (presumably one row per env with the parts' 7-value poses concatenated — confirm).
            It is NOT modified.
        env: Wrapped env; ``env.env.env.env`` must expose ``furniture.parts`` (whose order is
            assumed to match the order of poses in ``obs["parts_poses"]``).

    Returns:
        Dict mapping ``part.name`` to a ``(batch, 7)`` tensor. The quaternion's last component has
        been moved to the front (presumably xyzw -> wxyz), and the result has been converted with
        ``relative_to_base``.
    """
    furniture_env = env.env.env.env  # type: FurnitureSimEnv
    parts = furniture_env.furniture.parts
    raw = obs["parts_poses"]
    raw = raw.reshape(raw.shape[0], -1, 7)
    # Build a new tensor instead of assigning into a slice. `reshape` usually returns a view,
    # so an in-place write would mutate obs["parts_poses"]. A second call on the same obs
    # would then reorder the quaternion twice.
    parts_poses = torch.cat((raw[..., :3], raw[..., [6, 3, 4, 5]]), dim=-1)

    # transpose -> (num_parts, batch, 7) so zip pairs each part with its per-env poses.
    return {
        part.name: relative_to_base(pose=pose, fb_env=env)
        for part, pose in zip(parts, parts_poses.transpose(0, 1))
    }


def alert_about_joint_limits(joints: torch.Tensor, *, margin: float = 0.001) -> None:
    """Log a warning listing the joints within ``margin`` of a Franka Panda joint limit.

    Args:
        joints: Joint positions whose last dimension has one entry per joint (7 for the Panda).
            Leading batch dimensions are allowed; a joint is reported if it is near a limit in any
            batch entry.
        margin: Distance from a limit at which a joint counts as "near", in the units of the limits
            below (radians per the Franka docs). Defaults to 0.001.
    """
    # Franka Panda Joint Limits: https://frankaemika.github.io/docs/control_parameters.html
    upper_limits = torch.as_tensor([2.8973, 1.7628, 2.8973, -0.0698, 2.8973, 3.7525, 2.8973], device=joints.device)
    lower_limits = torch.as_tensor(
        [-2.8973, -1.7628, -2.8973, -3.0718, -2.8973, -0.0175, -2.8973], device=joints.device
    )
    mask = (joints <= lower_limits + margin) | (joints >= upper_limits - margin)
    if mask.any():
        # Collapse batch dimensions first so the reported numbers are joint indices (0-6). Flattening
        # the whole mask would report offsets into the flattened batch for batched input.
        near = mask.reshape(-1, mask.shape[-1]).any(dim=0).nonzero().flatten().tolist()
        logging.warning("Robot near joint limits at %s", ", ".join(map(str, near)))


def transform_pose_in_local_coords(
    pose: Optional[torch.Tensor] = None,
    *,
    x: float = 0,
    y: float = 0,
    z: float = 0,
    rx: float = 0,
    ry: float = 0,
    rz: float = 0,
    device: Optional[torch.device] = None,
):
    """Compose ``pose`` with an offset ``[x, y, z, quat]`` via ``affine_transform``.

    The offset is presumably applied in ``pose``'s local frame, as the function name suggests;
    the composition itself is done by ``affine_transform``.

    Args:
        pose: 7-value pose (position then quaternion; presumably w-first, matching the identity
            ``[0, 0, 0, 1, 0, 0, 0]`` used here). If None, that identity pose is created on ``device``.
        x, y, z: Translation offset. It is used as-is with no unit conversion, so it must be in the
            same length unit as ``pose``'s position. NOTE(review): an earlier docstring said mm —
            confirm against callers.
        rx, ry, rz: Rotation in degrees. It is converted to radians and then to a quaternion by
            ``euler_zyx_to_quaternion``.
        device: Device for the offset. Required when ``pose`` is None; otherwise defaults to
            ``pose.device``.

    Returns:
        Whatever ``affine_transform(pose, offset)`` returns.
    """
    if pose is None:
        assert device is not None, "device is required when pose is None"
        pose = torch.as_tensor([0, 0, 0, 1, 0, 0, 0], device=device, dtype=torch.float32)

    if device is None:
        device = pose.device
    # Follow a floating-point pose's dtype so a float64 pose is not combined with a float32 offset.
    dtype = pose.dtype if pose.is_floating_point() else torch.float32

    angles = torch.deg2rad(torch.as_tensor([rx, ry, rz], dtype=dtype, device=device))
    quat = torch.as_tensor(euler_zyx_to_quaternion(angles), dtype=dtype, device=device).reshape(-1)
    # Concatenate tensors directly. Building a tensor from a Python list that contains tensor
    # elements converts each element to a host scalar one by one.
    offset = torch.cat((torch.as_tensor([x, y, z], dtype=dtype, device=device), quat))
    return affine_transform(pose, offset)


def transform_pose_in_world_coords(pose, *, x=0, y=0, z=0, rx=0, ry=0, rz=0):
    """Offset ``pose`` by a translation and ZYX-Euler rotation given in world coordinates.

    ``[x, y, z, rx, ry, rz]`` is converted to a 7-value pose by ``pose_euler_zyx_to_quaternion``.
    The result is composed with ``pose`` by ``transform_pose_in_world_coords_by_pose``.

    The angles are passed through without conversion, so they are presumably radians. This differs
    from ``transform_pose_in_local_coords``, which takes degrees. NOTE(review): confirm.
    """
    # Build the offset in pose's dtype and device. Otherwise the dtype is inferred from the Python
    # args (int64 if all are ints, else float32). A float32 offset against a float64 pose can make
    # the downstream rotation-matrix product fail with a dtype mismatch.
    trans = pose_euler_zyx_to_quaternion(
        torch.as_tensor([x, y, z, rx, ry, rz], dtype=pose.dtype, device=pose.device)
    )
    return transform_pose_in_world_coords_by_pose(pose, trans)


def transform_pose_in_world_coords_by_pose(pose, trans):
    """Compose the 7-value ``pose`` with the 7-value offset ``trans``.

    Position: ``pose[:3] + trans[:3]``, i.e. the offset translation is added directly.
    Orientation: ``R(pose) @ R(trans)``, converted back to a quaternion.

    NOTE(review): if ``quaternion_to_rotation_matrix`` returns local->world matrices,
    post-multiplying applies ``trans``'s rotation about the pose's *local* axes. A world-frame
    rotation would be ``R(trans) @ R(pose)``. Confirm which is intended before relying on the name.
    """
    pose_rot = quaternion_to_rotation_matrix(pose[3:])
    offset_rot = quaternion_to_rotation_matrix(trans[3:])

    shifted_position = pose[:3] + trans[:3]
    combined_quat = rotation_matrix_to_quaternion(pose_rot @ offset_rot)

    return torch.concat((shifted_position, combined_quat), dim=-1)


def normalize_pose_to_z(part_pose: torch.Tensor) -> torch.Tensor:
    """Rebuild ``part_pose``'s orientation from the world up-axis and the part's own heading.

    The intermediate rotation has its y-axis along world +z. Its z-axis is the part's local z-axis
    projected onto the world XY plane and normalized, and its x-axis is ``y x z``. Local rotations
    of -90 deg about z and then +90 deg about y are then applied with
    ``transform_pose_in_local_coords``.

    Args:
        part_pose: 7-value pose with position in ``[:3]`` and quaternion in ``[3:]``. It is not
            modified.

    Returns:
        A new pose tensor.

    NOTE(review): if the part's z-axis is (anti)parallel to world z, its XY projection is ~zero.
    ``F.normalize`` then yields a near-zero vector and the rebuilt matrix is not a valid rotation.
    """
    import torch.nn.functional as F

    # Work on a copy. Writing into ``part_pose[3:]`` would otherwise leave the caller's tensor
    # holding the intermediate orientation.
    part_pose = part_pose.clone()

    R = quaternion_to_rotation_matrix(part_pose[3:])
    z_local = R[:, 2]  # third column: the part's local z-axis

    # New y-axis is aligned with world z. Match R's dtype because torch.cross needs equal dtypes.
    new_y = torch.tensor([0.0, 0.0, 1.0], device=R.device, dtype=R.dtype)

    # Project the local z-axis onto the world XY plane, then normalize.
    new_z = z_local.clone()
    new_z[2] = 0
    new_z = F.normalize(new_z, dim=0)

    # x = y x z keeps the frame right-handed.
    new_x = torch.cross(new_y, new_z, dim=0)

    R_new = torch.stack([new_x, new_y, new_z], dim=1)
    part_pose[3:] = rotation_matrix_to_quaternion(R_new)

    part_pose = transform_pose_in_local_coords(part_pose, rz=-90)
    part_pose = transform_pose_in_local_coords(part_pose, ry=90)

    return part_pose
