import numpy as np
import torch

from src.common.constants import FLOOR_AABB

def normalize_quaternion(q):
    """Return ``q`` scaled to unit Euclidean length.

    No guard against a zero-length input is applied.
    """
    length = np.linalg.norm(q)
    return q / length

def direction_to_quat(direction, angle=0.0):
    """Return the unit quaternion ``[w, x, y, z]`` rotating +Z onto ``direction``.

    The base rotation is the shortest-arc rotation that takes the +Z axis onto
    the normalised ``direction``. It is the identity when the two coincide and
    a half turn about +X when they are opposite. When ``angle`` is non-zero, a
    rotation of ``-angle`` radians about ``direction`` is composed on the left
    (``wrist * base``).

    Args:
        direction: 3-component array-like. It is copied and never modified.
        angle: Extra rotation about ``direction``, in radians.

    Returns:
        np.ndarray of four floats, normalised to unit length.

    Raises:
        ValueError: If ``direction`` has (near-)zero length.
    """
    # np.array (not np.asarray) so a caller's float array is not normalised
    # in place by the ``/=`` below.
    dir_vec = np.array(direction, dtype=float)
    n = np.linalg.norm(dir_vec)
    if n < 1e-8:
        raise ValueError("direction must be non-zero")
    dir_vec /= n

    z = np.array([0.0, 0.0, 1.0], dtype=float)
    dot = float(np.clip(np.dot(z, dir_vec), -1.0, 1.0))

    eps = 1e-6
    if abs(dot - 1.0) < eps:
        # Already along +Z: identity rotation.
        base = np.array([1.0, 0.0, 0.0, 0.0], dtype=float)
    elif abs(dot + 1.0) < eps:
        # Along -Z: the cross product degenerates, so use a half turn about +X.
        base = np.array([0.0, 1.0, 0.0, 0.0], dtype=float)
    else:
        axis = np.cross(z, dir_vec)
        axis /= np.linalg.norm(axis)
        half = np.arccos(dot) * 0.5
        s = np.sin(half)
        base = np.array(
            [np.cos(half), axis[0] * s, axis[1] * s, axis[2] * s], dtype=float
        )

    if angle != 0.0:
        # Wrist twist of -angle about the pointing direction.
        h = -angle * 0.5
        s = np.sin(h)
        wrist = np.array(
            [np.cos(h), dir_vec[0] * s, dir_vec[1] * s, dir_vec[2] * s], dtype=float
        )
        # Hamilton product wrist * base.
        w1, x1, y1, z1 = wrist
        w2, x2, y2, z2 = base
        base = np.array(
            [
                w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
                w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
                w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
                w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
            ],
            dtype=float,
        )

    base /= np.linalg.norm(base)
    return base

def get_gripper_offset():
    """Return the fixed vertical offset between a grasp point and the end-effector link.

    Callers add it to a target position before solving IK and subtract it from
    the end-effector position to get the gripper point. Units are presumably
    metres (confirm). A fresh array is returned on every call.
    """
    offset = np.zeros(3)
    offset[2] = 0.215
    return offset

def move_gripper_to(env, obj_name, pointing_to="down", depth=0.01):
    """Bring the gripper down onto the scene object ``obj_name``.

    The target is the object position raised by half of its bounding-box
    height and then lowered by ``depth``. This is the top face minus ``depth``,
    assuming ``env.get_obj_pos`` reports the box centre (TODO confirm). The
    move is done with ``move_to_position`` and the wrist angle reset to 0.

    Returns:
        bool: ``False`` if the object is unknown or fails the feasibility
        check (a message is printed), otherwise ``True``.
    """
    if obj_name not in env.env.scene_objects:
        print(f"Object '{obj_name}' not found in scene")
        return False
    if not _feasibility_check(env, obj_name):
        print(f"Object '{obj_name}' is not reachable.")
        return False

    approach = np.array([0.0, 0.0, -1.0])  # straight down

    bbox = env.get_obj_bbox(obj_name)
    half_height = (bbox[1] - bbox[0])[2] / 2.0

    center = env.get_obj_pos(obj_name)
    goal = center + (depth - half_height) * approach

    move_to_position(env, goal, pointing_to, angle=0.0)
    return True

def move_to_position(env, pos, pointing_to="down", lift_clearance=0.12, angle=0.0):
    """Move the end effector to ``pos`` via a lift / traverse / descend sequence.

    The sequence has three steps:

    1. Raise the end effector ``lift_clearance`` above its current position.
       This step is skipped when ``env.trajectory`` holds exactly one
       observation.
    2. Move to a hover point ``lift_clearance`` above the target.
    3. Descend onto the target.

    The traverse to the hover point uses ``env.franka.plan_path`` when nothing
    is grasped or welded and the planned path is fully valid. Otherwise it
    falls back to interpolated joint-space motion.

    Args:
        env: Simulation wrapper (robot, gripper state, trajectory, stepping).
        pos: Target position; ``get_gripper_offset()`` is added before IK.
        pointing_to: Stored on the gripper scene object. The approach
            direction itself is always straight down.
        lift_clearance: Height of the lift and hover waypoints above their
            reference positions.
        angle: Wrist angle in radians for the new pose. It is also written to
            ``env.gripper_state["angle"]``.
    """
    end_effector = env.env.franka.get_link(env.ee_name)

    # Orientation held while lifting away (current wrist angle) and the
    # orientation requested for the new pose; both point straight down.
    prev_direction = np.array([0, 0, -1])
    prev_quat = direction_to_quat(prev_direction, angle=env.gripper_state["angle"])
    curr_direction = np.array([0, 0, -1])
    curr_quat = direction_to_quat(curr_direction, angle=angle)

    env.gripper_state["angle"] = angle
    env.env.scene_objects["gripper"].pointing_to = pointing_to

    target_pos = pos + get_gripper_offset()
    current_pos = end_effector.get_pos().cpu().numpy()

    ascend_pos = current_pos - prev_direction * lift_clearance
    descend_pos = target_pos - curr_direction * lift_clearance

    waypoints = [ascend_pos, descend_pos, target_pos]
    quaternions = [prev_quat, curr_quat, curr_quat]

    # With exactly one recorded observation the lift step is skipped.
    if len(env.trajectory) == 1:
        waypoints.pop(0)
        quaternions.pop(0)

    # The long traverse ends at the hover waypoint (second to last).
    traverse_idx = len(waypoints) - 2

    for i, (waypoint, quat) in enumerate(zip(waypoints, quaternions)):
        qpos = env.env.franka.inverse_kinematics(
            link=end_effector, pos=waypoint, quat=quat
        )

        if i != traverse_idx:
            _execute_joint_motion(env, qpos[:7], steps=120)
            for _ in range(50):
                env.step()
            continue

        # Skip path planning if holding an object.
        holding_object = env._grasp["active"] or env._welded["active"]
        path, is_valid = None, None
        if not holding_object:
            path, is_valid = env.franka.plan_path(
                qpos_goal=qpos, num_waypoints=200, return_valid_mask=True
            )

        if is_valid is not None and is_valid.all():
            for wp in path:
                env.franka.control_dofs_position(wp)
                # Re-apply the finger command after the full-DOF path target.
                if not env.env.scene_objects["gripper"].gripper_open:
                    env.env.franka.control_dofs_force(
                        np.array([-0.5, -0.5]), [7, 8]
                    )
                else:
                    env.env.franka.control_dofs_position([0.05, 0.05], [7, 8])
                env.step()

            for _ in range(50):
                env.step()
        else:
            # Scale the interpolation length with the distance this move
            # actually covers. It starts at the lift waypoint, or at the
            # current pose when the lift step was skipped. Measuring the
            # following hover->target segment would always give lift_clearance.
            start_pos = waypoints[i - 1] if i > 0 else current_pos
            dist = np.linalg.norm(waypoint - start_pos)
            steps = max(int(dist * 150), 80)
            _execute_joint_motion(env, qpos[:7], steps=steps)

    env.trajectory.append(env.env.capture_obs())

def move_parallel(env, move_dir, offset, pointing_to="down"):
    """Translate the end effector by ``offset`` along one world axis.

    ``front``/``back`` move along +X/-X, ``left``/``right`` along +Y/-Y and
    ``up``/``down`` along +Z/-Z. The orientation stays pointing down at the
    current wrist angle. An observation is appended to ``env.trajectory``.

    Raises:
        ValueError: If ``move_dir`` is not one of the six names, or if
            ``offset`` is not positive.
    """
    allowed = ("left", "right", "front", "back", "up", "down")
    if move_dir not in allowed:
        raise ValueError(
            "Invalid direction. Choose from 'left', 'right', 'front', 'back', 'up', 'down' "
        )
    if offset <= 0:
        raise ValueError("Offset must be a positive value.")

    ee_link = env.franka.get_link(env.ee_name)
    orientation = direction_to_quat(
        np.array([0, 0, -1]), angle=env.gripper_state["angle"]
    )
    env.env.scene_objects["gripper"].pointing_to = pointing_to

    start = ee_link.get_pos().cpu().numpy()
    shift = {
        "front": [offset, 0, 0],
        "back": [-offset, 0, 0],
        "left": [0, offset, 0],
        "right": [0, -offset, 0],
        "up": [0, 0, offset],
        "down": [0, 0, -offset],
    }[move_dir]
    goal = start + np.array(shift)

    joint_goal = env.env.franka.inverse_kinematics(
        link=ee_link, pos=goal, quat=orientation
    )
    _execute_joint_motion(env, joint_goal[:7])

    env.trajectory.append(env.env.capture_obs())

def rotate_gripper(env, angle, steps=80):
    """Turn the wrist joint (arm qpos index 6) by ``angle`` degrees.

    ``env.gripper_state["angle"]`` is increased by the angle in radians, and
    the joint-6 target is decreased by the same amount. The motion runs
    through ``_execute_joint_motion`` over ``steps`` interpolation steps, and
    an observation is appended to ``env.trajectory``. The end-effector type is
    not checked.
    """
    delta = np.radians(angle)
    env.gripper_state["angle"] += delta

    # astype() returns a fresh float32 array, so editing it in place is safe.
    arm_goal = env.env.franka.get_qpos().cpu().numpy()[:7].astype(np.float32)
    arm_goal[6] -= delta

    _execute_joint_motion(env, arm_goal, steps=steps)

    env.trajectory.append(env.env.capture_obs())

def open_gripper(env):
    """Release any welded object and open the parallel-jaw fingers.

    If ``env._welded`` is active, the weld between the object link and the
    end-effector link is deleted and the weld record is cleared. The finger
    DOFs 7 and 8 are then driven linearly to 0.05 over 50 simulation steps.
    Fingers are only commanded when ``env.ee_type == "gripper"``; other end
    effectors just step. The start is read from the last two qpos entries,
    which assumes those are DOFs 7 and 8 (TODO confirm). Finally the gripper
    is marked open and an observation is appended to ``env.trajectory``.
    """
    num_steps = 50
    src_qpos = env.env.franka.get_qpos().cpu().numpy()[-2:]
    tgt_qpos = np.array([0.05, 0.05])
    step = (tgt_qpos - src_qpos) / num_steps

    if env._welded["active"]:
        rigid = env.env.scene.sim.rigid_solver
        link_object = env._welded["obj_link_idx"]
        link_franka = env.franka.get_link(env.ee_name).idx
        rigid.delete_weld_constraint(link_object, link_franka)
        env._welded.update(
            dict(
                active=False,
                object=None,
                obj_link_idx=None,
            )
        )

    # Count from 1 so the last command is exactly ``tgt_qpos``; counting from
    # 0 would stop one increment short of fully open.
    for i in range(1, num_steps + 1):
        if env.ee_type == "gripper":
            env.env.franka.control_dofs_position(src_qpos + step * i, [7, 8])
        env.step()

    env.env.scene_objects["gripper"].gripper_open = True
    env.trajectory.append(env.env.capture_obs())

def close_gripper(env):
    """Close the parallel-jaw fingers and weld the first held object, if any.

    A closing force of -0.5 is applied to finger DOFs 7 and 8, followed by 50
    simulation steps. If neither ``env._grasp`` nor ``env._welded`` is active,
    the scene objects are scanned in order. The scan skips ``"gripper"``,
    ``"floor"`` and any name containing ``"handle"``. The first object that
    satisfies ``env.obj_in_gripper`` and exposes ``links`` has its first link
    welded to the end-effector link, and the weld is recorded in
    ``env._welded``. The gripper is then marked closed and an observation is
    appended to ``env.trajectory``.
    """
    world = env.env
    world.franka.control_dofs_force(np.array([-0.5, -0.5]), [7, 8])
    for _ in range(50):
        env.step()

    if not (env._grasp["active"] or env._welded["active"]):
        for name, candidate in world.scene_objects.items():
            if name in ("gripper", "floor") or "handle" in name:
                continue
            if not env.obj_in_gripper(name):
                continue

            solver = world.scene.sim.rigid_solver
            if not hasattr(candidate, "links"):
                continue

            obj_link = candidate.links[0].idx
            ee_link = env.franka.get_link(env.ee_name).idx
            solver.add_weld_constraint(obj_link, ee_link)
            env._welded.update(
                dict(active=True, object=name, obj_link_idx=obj_link)
            )
            break

    world.scene_objects["gripper"].gripper_open = False
    env.trajectory.append(world.capture_obs())

def grasp_handle(env, handle_name):
    """Grasp an articulated handle with the parallel-jaw gripper.

    If ``handle_name`` is already the grasped object, the gripper is simply
    closed again. Otherwise the grasp is refused in any of these cases:

    * the name does not contain ``"handle"``;
    * the handle is not in the scene;
    * an object is currently welded;
    * the gripper point (end-effector position minus ``get_gripper_offset()``)
      is more than 0.1 away from the handle.

    On success the handle is recorded in ``env._grasp`` and the gripper closes.

    Returns:
        bool: Whether the handle is (now) grasped.
    """
    grasp = env._grasp
    if grasp["active"] and grasp["object"] == handle_name:
        close_gripper(env)
        return True

    scene_objects = env.env.scene_objects
    known_handle = "handle" in handle_name and handle_name in scene_objects
    if not known_handle or env._welded["active"]:
        return False

    handle = scene_objects[handle_name]
    ee_link = env.franka.get_link(env.ee_name)

    target = env.get_obj_pos(handle_name)
    grip_point = ee_link.get_pos().cpu().numpy() - get_gripper_offset()
    if np.linalg.norm(grip_point - target) > 0.1:
        return False

    grasp.update(dict(active=True, object=handle_name, obj_link_idx=handle.idx))
    close_gripper(env)
    return True

def release_handle(env):
    """Clear the handle-grasp record and open the parallel-jaw gripper."""
    env._grasp.update(active=False, object=None, obj_link_idx=None)
    open_gripper(env)

def pick(env, obj_name, pointing_to="down"):
    """Move onto ``obj_name`` and close the parallel-jaw gripper on it.

    Returns:
        bool: ``False`` (gripper untouched) if ``move_gripper_to`` fails,
        otherwise ``True``.
    """
    reached = move_gripper_to(env, obj_name, pointing_to=pointing_to)
    if not reached:
        return False
    close_gripper(env)
    return True

def place(env, obj_name, pointing_to="down"):
    """Move onto ``obj_name`` and open the parallel-jaw gripper there.

    Returns:
        bool: ``False`` (gripper untouched) if ``move_gripper_to`` fails,
        otherwise ``True``.
    """
    reached = move_gripper_to(env, obj_name, pointing_to=pointing_to)
    if not reached:
        return False
    open_gripper(env)
    return True

# Robotiq85 gripper skills
def open_robotiq85(env):
    """Release any welded object and open the Robotiq 85 finger joint.

    If ``env.ee_type`` is not ``"robotiq85"``, a message is printed and
    nothing else happens. Otherwise:

    * any active weld is deleted and the weld record is cleared;
    * DOF 7 (``finger_joint``: 0 closed ~ 0.725 open) is driven linearly from
      its current value to 0.7 over 50 simulation steps;
    * the gripper is marked open and an observation is appended to
      ``env.trajectory``.
    """
    if env.ee_type != "robotiq85":
        print("[env] open_robotiq85 called but end-effector is not robotiq85.")
        return

    if env._welded["active"]:
        rigid = env.env.scene.sim.rigid_solver
        link_object = env._welded["obj_link_idx"]
        link_franka = env.franka.get_link(env.ee_name).idx
        rigid.delete_weld_constraint(link_object, link_franka)
        env._welded.update(
            dict(
                active=False,
                object=None,
                obj_link_idx=None,
            )
        )

    # finger_joint: 0 (closed) ~ 0.725 (open)
    src_qpos = env.env.franka.get_qpos().cpu().numpy()[7]
    tgt_qpos = 0.7
    steps = 50

    # alpha runs over (0, 1] so the final command is exactly ``tgt_qpos``.
    for i in range(1, steps + 1):
        alpha = i / steps
        interp = (1 - alpha) * src_qpos + alpha * tgt_qpos
        env.env.franka.control_dofs_position([interp], [7])
        env.step()

    env.env.scene_objects["gripper"].gripper_open = True
    env.trajectory.append(env.env.capture_obs())

def close_robotiq85(env):
    """Close the Robotiq 85 gripper and weld the first held object, if any.

    If ``env.ee_type`` is not ``"robotiq85"``, a message is printed and
    nothing else happens. Otherwise a force of -50.0 is applied to DOF 7
    (``finger_joint``), followed by 50 simulation steps. If neither
    ``env._grasp`` nor ``env._welded`` is active, the scene objects are
    scanned in order. The scan skips ``"gripper"``, ``"floor"`` and any name
    containing ``"handle"``. The first object that satisfies
    ``env.obj_in_gripper`` and exposes ``links`` has its first link welded to
    the end-effector link. The gripper is then marked closed and an
    observation is appended to ``env.trajectory``.
    """
    if env.ee_type != "robotiq85":
        print("[env] close_robotiq85 called but end-effector is not robotiq85.")
        return

    # finger_joint: 0 (closed) ~ 0.725 (open). Use force control for grasping.
    env.env.franka.control_dofs_force(np.array([-50.0]), [7])
    for _ in range(50):
        env.step()

    if not env._grasp["active"] and not env._welded["active"]:
        for obj_name in env.env.scene_objects:
            if obj_name in ["gripper", "floor"] or "handle" in obj_name:
                continue

            scene_obj = env.env.scene_objects[obj_name]
            if not env.obj_in_gripper(obj_name):
                continue

            rigid = env.env.scene.sim.rigid_solver
            # Objects without links cannot be welded; try the next candidate.
            if not hasattr(scene_obj, "links"):
                continue

            link_object = scene_obj.links[0].idx
            link_franka = env.franka.get_link(env.ee_name).idx

            rigid.add_weld_constraint(link_object, link_franka)
            env._welded.update(
                dict(
                    active=True,
                    object=obj_name,
                    obj_link_idx=link_object,
                )
            )
            break

    env.env.scene_objects["gripper"].gripper_open = False
    env.trajectory.append(env.env.capture_obs())

def grasp_handle_robotiq85(env, handle_name):
    """Grasp an articulated handle with the Robotiq 85 gripper.

    If ``env.ee_type`` is not ``"robotiq85"``, a message is printed and
    ``False`` is returned. If ``handle_name`` is already the grasped object,
    the gripper is simply closed again. Otherwise the grasp is refused in any
    of these cases:

    * the name does not contain ``"handle"``;
    * the handle is not in the scene;
    * an object is currently welded;
    * the gripper point (end-effector position minus ``get_gripper_offset()``)
      is more than 0.1 away from the handle.

    On success the handle is recorded in ``env._grasp`` and the gripper closes.

    Returns:
        bool: Whether the handle is (now) grasped.
    """
    if env.ee_type != "robotiq85":
        print("[env] grasp_handle_robotiq85 called but end-effector is not robotiq85.")
        return False

    grasp = env._grasp
    if grasp["active"] and grasp["object"] == handle_name:
        close_robotiq85(env)
        return True

    scene_objects = env.env.scene_objects
    known_handle = "handle" in handle_name and handle_name in scene_objects
    if not known_handle or env._welded["active"]:
        return False

    handle = scene_objects[handle_name]
    ee_link = env.franka.get_link(env.ee_name)

    target = env.get_obj_pos(handle_name)
    grip_point = ee_link.get_pos().cpu().numpy() - get_gripper_offset()
    if np.linalg.norm(grip_point - target) > 0.1:
        return False

    grasp.update(dict(active=True, object=handle_name, obj_link_idx=handle.idx))
    close_robotiq85(env)
    return True

def release_handle_robotiq85(env):
    """Clear the handle-grasp record and open the Robotiq 85 gripper.

    Only prints a message when ``env.ee_type`` is not ``"robotiq85"``.
    """
    if env.ee_type == "robotiq85":
        env._grasp.update(active=False, object=None, obj_link_idx=None)
        open_robotiq85(env)
    else:
        print("[env] release_handle_robotiq85 called but end-effector is not robotiq85.")

def pick_robotiq85(env, obj_name, pointing_to="down"):
    """Move onto ``obj_name`` and close the Robotiq 85 gripper on it.

    Returns:
        bool: ``False`` when the end effector is not ``"robotiq85"`` (a
        message is printed) or ``move_gripper_to`` fails, otherwise ``True``.
    """
    if env.ee_type == "robotiq85":
        if move_gripper_to(env, obj_name, pointing_to=pointing_to):
            close_robotiq85(env)
            return True
        return False

    print("[env] pick_robotiq85 called but end-effector is not robotiq85.")
    return False

def place_robotiq85(env, obj_name, pointing_to="down"):
    """Move onto ``obj_name`` and open the Robotiq 85 gripper there.

    Returns:
        bool: ``False`` when the end effector is not ``"robotiq85"`` (a
        message is printed) or ``move_gripper_to`` fails, otherwise ``True``.
    """
    if env.ee_type == "robotiq85":
        if move_gripper_to(env, obj_name, pointing_to=pointing_to):
            open_robotiq85(env)
            return True
        return False

    print("[env] place_robotiq85 called but end-effector is not robotiq85.")
    return False

# Vacuum gripper skills
def activate_vacuum(env):
    """Turn the suction end effector on by reusing ``close_gripper``.

    Only prints a message when ``env.ee_type`` is not ``"suction"``.
    ``env.ee_strict`` is set to ``False`` for the duration of the call. The
    purpose is presumably to relax end-effector checks made elsewhere
    (confirm). The previous value is restored afterwards, even when
    ``close_gripper`` raises.
    """
    if env.ee_type != "suction":
        print("[env] activate_vacuum called but end-effector is not suction.")
        return

    strict = env.ee_strict
    env.ee_strict = False
    try:
        close_gripper(env)
    finally:
        env.ee_strict = strict

def deactivate_vacuum(env):
    """Turn the suction end effector off by reusing ``open_gripper``.

    Only prints a message when ``env.ee_type`` is not ``"suction"``.
    ``env.ee_strict`` is set to ``False`` for the duration of the call and
    restored afterwards, even when ``open_gripper`` raises.
    """
    if env.ee_type != "suction":
        print("[env] deactivate_vacuum called but end-effector is not suction.")
        return

    strict = env.ee_strict
    env.ee_strict = False
    try:
        open_gripper(env)
    finally:
        env.ee_strict = strict

def attach_vacuum_handle(env, handle_name):
    """Attach the suction end effector to a handle by reusing ``grasp_handle``.

    Returns ``False`` (after printing a message) when ``env.ee_type`` is not
    ``"suction"``. Otherwise it returns the result of ``grasp_handle``.
    ``env.ee_strict`` is set to ``False`` for the duration of the call and
    restored afterwards, even when ``grasp_handle`` raises.
    """
    if env.ee_type != "suction":
        print("[env] attach_vacuum_handle called but end-effector is not suction.")
        return False

    strict = env.ee_strict
    env.ee_strict = False
    try:
        return grasp_handle(env, handle_name)
    finally:
        env.ee_strict = strict

def detach_vacuum_handle(env):
    """Detach the suction end effector from a handle by reusing ``release_handle``.

    Only prints a message when ``env.ee_type`` is not ``"suction"``.
    ``env.ee_strict`` is set to ``False`` for the duration of the call and
    restored afterwards, even when ``release_handle`` raises.
    """
    if env.ee_type != "suction":
        print("[env] detach_vacuum_handle called but end-effector is not suction.")
        return

    strict = env.ee_strict
    env.ee_strict = False
    try:
        release_handle(env)
    finally:
        env.ee_strict = strict

# Helper functions
def _execute_joint_motion(env, goal_qpos, steps=80):
    """Drive the seven arm joints from their current values to ``goal_qpos``.

    The joint target follows a quadratic ease-in/ease-out profile over
    ``steps`` simulation steps; the first command is the start pose. While
    the gripper is marked closed, a -0.5 force is re-applied to finger DOFs 7
    and 8 on every step. Afterwards the goal is commanded once more and held
    for 50 additional steps.

    Args:
        env: Simulation wrapper exposing ``env.env.franka`` and ``step``.
        goal_qpos: Seven joint targets (array-like or torch tensor).
        steps: Number of interpolation steps.
    """
    if torch.is_tensor(goal_qpos):
        goal_qpos = goal_qpos.detach().cpu().numpy()
    target = np.asarray(goal_qpos, dtype=np.float32)

    robot = env.env.franka
    start = robot.get_qpos().cpu().numpy()[:7].astype(np.float32)

    # Quadratic ease-in for the first half, mirrored ease-out for the second.
    ramp = np.linspace(0, 1, steps)
    profile = np.where(ramp < 0.5, 2 * ramp**2, 1 - 2 * (1 - ramp) ** 2)

    for weight in profile:
        robot.control_dofs_position((1 - weight) * start + weight * target, range(7))
        if not env.env.scene_objects["gripper"].gripper_open:
            robot.control_dofs_force(np.array([-0.5, -0.5]), [7, 8])
        env.step()

    # Settle on the exact goal.
    robot.control_dofs_position(target, range(7))
    for _ in range(50):
        env.step()

def _feasibility_check(env, obj_name):
    """Return whether ``obj_name`` may currently be manipulated.

    Three scene rules are applied. Each rule is skipped when a bounding box
    it needs is unavailable (``env.get_obj_bbox`` returns ``None``).

    * Put in hinge: fruits (apple/orange/lemon), cubes and cylinders must lie
      within ``FLOOR_AABB`` in X and Y.
    * Pick place: a cube/cylinder is blocked if another cube/cylinder
      overlaps it in XY (within ``tol_xy``) and has its bottom at or above
      this object's top (within ``tol_z``). In other words, something is
      stacked on it.
    * Put in prismatic: ``bottom_drawer_handle`` is blocked while
      ``top_drawer`` overlaps it in XY.

    Bounding boxes are treated as ``(low, high)`` corner pairs.
    """
    is_stackable = "cube" in obj_name or "cylinder" in obj_name
    needs_bbox = obj_name in ("apple", "orange", "lemon") or is_stackable

    target = None
    if needs_bbox:
        target_bbox = env.get_obj_bbox(obj_name)
        # Guard None here too, matching the other rules.
        if target_bbox is not None:
            target = tuple(map(np.asarray, target_bbox))

    # Put in hinge: the object must lie within the floor in X and Y.
    if target is not None:
        floor_low, floor_high = map(np.asarray, FLOOR_AABB)
        low, high = target
        if (
            (low[0] < floor_low[0])
            or (high[0] > floor_high[0])
            or (low[1] < floor_low[1])
            or (high[1] > floor_high[1])
        ):
            return False

    # Pick place: another cube/cylinder resting on top blocks the pick.
    if is_stackable and target is not None:
        tl, th = target
        tol_xy = 0.005
        tol_z = 0.004

        for other_name in env.env.scene_objects:
            if other_name == obj_name:
                continue
            if not ("cube" in other_name or "cylinder" in other_name):
                continue

            other_bbox = env.get_obj_bbox(other_name)
            if other_bbox is None:
                continue
            ol, oh = map(np.asarray, other_bbox)

            overlap_x = (tl[0] - tol_xy) < oh[0] and (th[0] + tol_xy) > ol[0]
            overlap_y = (tl[1] - tol_xy) < oh[1] and (th[1] + tol_xy) > ol[1]
            if not (overlap_x and overlap_y):
                continue

            if ol[2] >= th[2] - tol_z and ol[2] > tl[2] + tol_z:
                return False

    # Put in prismatic: the top drawer must not cover the bottom handle.
    if obj_name in ["bottom_drawer_handle"]:
        if env.env.scene_objects.get("top_drawer") is not None:
            drawer_bbox = env.get_obj_bbox("top_drawer")
            handle_bbox = env.get_obj_bbox(obj_name)

            if drawer_bbox is not None and handle_bbox is not None:
                dl, dh = map(np.asarray, drawer_bbox)
                hl, hh = map(np.asarray, handle_bbox)

                overlap_x = (hl[0] - 0.005) < dh[0] and (hh[0] + 0.005) > dl[0]
                overlap_y = (hl[1] - 0.005) < dh[1] and (hh[1] + 0.005) > dl[1]
                if overlap_x and overlap_y:
                    return False
    return True
