"""
Grasp and Lift utilities for pick-and-place tasks.

This module contains hard-coded functions for:
1. descend_and_grasp: DESCEND (pregrasp → grasp) + GRASP (close gripper)
2. lift: LIFT (grasp → lift position)
3. descend_and_release: DESCEND_RELEASE (prerelease → release) + RELEASE (open gripper)

These phases are deterministic and don't need to be learned by a diffusion model.
The diffusion model only needs to learn REACH and CARRY phases.
"""

import numpy as np
from pyrep.errors import ConfigurationPathError, IKError
from pyrep.objects.shape import Shape


def linear_interpolate(p1, p2, t):
    """Linearly interpolate between two points.

    Args:
        p1: Start point (scalar, sequence or numpy array).
        p2: End point, same shape as ``p1``.
        t: Interpolation factor; 0 yields ``p1`` and 1 yields ``p2``.
            Values outside [0, 1] extrapolate along the same line.

    Returns:
        ``p1 + t * (p2 - p1)`` evaluated with numpy arithmetic.
    """
    # Coerce to arrays so plain lists/tuples work too; `list - list` would
    # otherwise raise TypeError.
    p1 = np.asarray(p1)
    p2 = np.asarray(p2)
    return p1 + t * (p2 - p1)


def capture_frame(task_env):
    """Return a copy of the front camera RGB frame, or None if unavailable.

    The observation is fetched from ``task_env._scene.get_observation()``.
    None is returned when the observation has no ``front_rgb`` attribute or
    when that attribute is None.
    """
    observation = task_env._scene.get_observation()
    front_rgb = getattr(observation, 'front_rgb', None)
    return None if front_rgb is None else front_rgb.copy()


def descend_and_grasp(task_env, pregrasp_pos, grasp_pos, orientation,
                      object_shape_name='stack_blocks_target0',
                      descend_steps=8, grasp_steps=8,
                      steps_per_point=5, capture_video=False,
                      verbose=True, target_object=None):
    """
    Execute DESCEND and GRASP phases.

    DESCEND: Move from pregrasp position to grasp position along a straight
        line, with the gripper commanded open.
    GRASP: Hold the grasp pose, command the gripper closed, then call
        gripper.grasp() on the object.

    Every way-point is reached by solving Jacobian IK starting from the last
    valid joint configuration and writing the solution directly to the joints
    (dynamics disabled), followed by a few simulation steps.

    IK failure handling:
        DESCEND: the way-point is skipped; the previous trace entry (or
            pregrasp_pos when the trace is empty) is repeated and the
            simulation is not stepped.
        GRASP: the arm is held at the last valid joint configuration and the
            gripper is still commanded closed, so the logged gripper state
            matches what was actually commanded.

    Args:
        task_env: RLBench task environment
        pregrasp_pos: Starting position (above object); array, list or tuple
        grasp_pos: Grasp position (at object); array, list or tuple
        orientation: End-effector orientation (euler angles)
        object_shape_name: Name of object to grasp (used if target_object is None)
        descend_steps: Number of steps for descend phase
        grasp_steps: Number of steps for grasp phase
        steps_per_point: Physics steps per trajectory point (the grasp phase
            uses steps_per_point + 5 so the gripper has time to close)
        capture_video: Whether to capture video frames
        verbose: Whether to print debug info
        target_object: Direct object reference (if provided, overrides object_shape_name)

    Returns:
        dict with:
            trace: list of tip positions
            phase_labels: list of phase names ("descend" / "grasp")
            gripper_states: list of gripper states (1.0=open, 0.0=closed)
            ik_failures: number of IK failures
            grasp_success: return value of gripper.grasp()
            object_z_before: object Z before grasp
            frames: list of RGB frames (if capture_video=True)
            prev_joints: joint positions after grasp (for next phase)
    """
    scene = task_env._scene
    arm = scene.robot.arm
    gripper = scene.robot.gripper
    tip = arm.get_tip()

    # Accept lists/tuples as well as arrays for the Cartesian way-points.
    pregrasp_pos = np.asarray(pregrasp_pos, dtype=float)
    grasp_pos = np.asarray(grasp_pos, dtype=float)

    trace = []
    phase_labels = []
    gripper_states = []
    frames = []
    ik_failures = 0

    prev_joints = list(arm.get_joint_positions())

    # Get object reference - use target_object if provided, else lookup by name
    if target_object is None:
        target_object = Shape(object_shape_name)

    def _record(position, label, grip_state):
        """Append one sample to the logs and optionally grab a video frame."""
        trace.append(position)
        phase_labels.append(label)
        gripper_states.append(grip_state)
        if capture_video:
            frame = capture_frame(task_env)
            if frame is not None:
                frames.append(frame)

    def _drive(joints, grip_amount, n_steps):
        """Apply joints, command the gripper, step the sim; return tip position."""
        arm.set_joint_positions(joints, disable_dynamics=True)
        # Size the zero-velocity command from the arm instead of assuming 7 DoF.
        arm.set_joint_target_velocities([0] * len(joints))
        gripper.actuate(grip_amount, 0.2)
        for _ in range(n_steps):
            scene.pyrep.step()
        return tip.get_position().copy()

    # ========================================================================
    # Phase: DESCEND (linear) - pregrasp → grasp
    # ========================================================================
    if verbose:
        print(f"\n  DESCEND: pregrasp={pregrasp_pos} -> grasp={grasp_pos}")

    descend_denominator = max(descend_steps - 1, 1)  # avoids 0/0 for a single step
    for i in range(descend_steps):
        target_pos = linear_interpolate(pregrasp_pos, grasp_pos, i / descend_denominator)

        # Jacobian IK linearises around the current configuration, so start
        # from the last configuration that was known to be valid.
        arm.set_joint_positions(prev_joints, disable_dynamics=True)
        try:
            joint_positions = arm.solve_ik_via_jacobian(target_pos, euler=orientation)
        except (IKError, ConfigurationPathError) as e:
            ik_failures += 1
            if verbose:
                print(f"    step {i}: IK FAILED - {e}")
            _record(trace[-1].copy() if trace else pregrasp_pos.copy(), "descend", 1.0)
            continue

        # Gripper open during descend
        actual_pos = _drive(joint_positions, 1.0, steps_per_point)
        prev_joints = list(joint_positions)

        if verbose and (i == 0 or i == descend_steps - 1):
            error = np.linalg.norm(actual_pos - target_pos) * 1000
            print(f"    step {i}: target_z={target_pos[2]:.4f}, actual_z={actual_pos[2]:.4f}, error={error:.1f}mm")

        _record(actual_pos, "descend", 1.0)

    # ========================================================================
    # Phase: GRASP - hold position, close gripper, then ATTACH object
    # ========================================================================
    object_z_before = target_object.get_position()[2]
    if verbose:
        print(f"\n  GRASP: hold at grasp_pos={grasp_pos}, object_z_before={object_z_before:.4f}")

    for i in range(grasp_steps):
        arm.set_joint_positions(prev_joints, disable_dynamics=True)
        try:
            joint_positions = arm.solve_ik_via_jacobian(grasp_pos, euler=orientation)
        except (IKError, ConfigurationPathError) as e:
            ik_failures += 1
            if verbose:
                print(f"    step {i}: IK FAILED - {e}")
            # The arm only has to hold still here: keep the last valid
            # configuration so the gripper is still closed and simulated.
            joint_positions = prev_joints

        # Close gripper; extra physics steps give it time to close
        actual_pos = _drive(joint_positions, 0.0, steps_per_point + 5)
        prev_joints = list(joint_positions)

        if verbose and i == grasp_steps - 1:
            gripper_open_amount = gripper.get_open_amount()[0]
            print(f"    After GRASP: ee_z={actual_pos[2]:.4f}, obj_z={target_object.get_position()[2]:.4f}, gripper_open={gripper_open_amount:.3f}")

        _record(actual_pos, "grasp", 0.0)

    # IMPORTANT: After gripper closes, use gripper.grasp() to ATTACH object to gripper!
    grasp_success = gripper.grasp(target_object)
    if verbose:
        print(f"    gripper.grasp() result: {grasp_success}")

    return {
        'trace': trace,
        'phase_labels': phase_labels,
        'gripper_states': gripper_states,
        'ik_failures': ik_failures,
        'grasp_success': grasp_success,
        'object_z_before': object_z_before,
        'frames': frames,
        'prev_joints': prev_joints,
    }


def lift(task_env, grasp_pos, lift_pos, orientation,
         object_shape_name='stack_blocks_target0',
         lift_steps=8, steps_per_point=5,
         capture_video=False, verbose=True,
         prev_joints=None, target_object=None,
         min_rise=0.02):
    """
    Execute LIFT phase.

    LIFT: Move from grasp position to lift position along a straight line,
        with the gripper commanded closed.

    Every way-point is reached by solving Jacobian IK starting from the last
    valid joint configuration and writing the solution directly to the joints
    (dynamics disabled), followed by steps_per_point simulation steps. When
    IK fails for a way-point it is skipped: the previous trace entry (or
    grasp_pos when the trace is empty) is repeated and the simulation is not
    stepped.

    Args:
        task_env: RLBench task environment
        grasp_pos: Starting position (at object); array, list or tuple
        lift_pos: Lift position (above object, usually same as pregrasp);
            array, list or tuple
        orientation: End-effector orientation (euler angles)
        object_shape_name: Name of object being lifted (used if target_object is None)
        lift_steps: Number of steps for lift phase
        steps_per_point: Physics steps per trajectory point
        capture_video: Whether to capture video frames
        verbose: Whether to print debug info
        prev_joints: Previous joint positions (optional, will use current if None)
        target_object: Direct object reference (if provided, overrides object_shape_name)
        min_rise: Minimum object Z rise for the lift to count as successful
            (default 0.02, i.e. 2cm)

    Returns:
        dict with:
            trace: list of tip positions
            phase_labels: list of phase names ("lift")
            gripper_states: list of gripper states (always 0.0=closed)
            ik_failures: number of IK failures
            object_lifted: bool, whether object Z rose by more than min_rise
            object_z_before: object Z before lift
            object_z_after: object Z after lift
            frames: list of RGB frames (if capture_video=True)
            prev_joints: joint positions after lift (for next phase)
    """
    scene = task_env._scene
    arm = scene.robot.arm
    gripper = scene.robot.gripper
    tip = arm.get_tip()

    # Accept lists/tuples as well as arrays for the Cartesian way-points.
    grasp_pos = np.asarray(grasp_pos, dtype=float)
    lift_pos = np.asarray(lift_pos, dtype=float)

    trace = []
    phase_labels = []
    gripper_states = []
    frames = []
    ik_failures = 0

    if prev_joints is None:
        prev_joints = arm.get_joint_positions()
    # Copy so the caller's sequence is never aliased by the returned state.
    prev_joints = list(prev_joints)

    # Get object reference - use target_object if provided, else lookup by name
    if target_object is None:
        target_object = Shape(object_shape_name)
    object_z_before = target_object.get_position()[2]

    def _record(position):
        """Append one sample to the logs and optionally grab a video frame."""
        trace.append(position)
        phase_labels.append("lift")
        gripper_states.append(0.0)
        if capture_video:
            frame = capture_frame(task_env)
            if frame is not None:
                frames.append(frame)

    # ========================================================================
    # Phase: LIFT (linear) - grasp → lift_pos
    # ========================================================================
    if verbose:
        print(f"\n  LIFT: grasp={grasp_pos} -> lift={lift_pos}")

    lift_denominator = max(lift_steps - 1, 1)  # avoids 0/0 for a single step
    for i in range(lift_steps):
        target_pos = linear_interpolate(grasp_pos, lift_pos, i / lift_denominator)

        # Jacobian IK linearises around the current configuration, so start
        # from the last configuration that was known to be valid.
        arm.set_joint_positions(prev_joints, disable_dynamics=True)
        try:
            joint_positions = arm.solve_ik_via_jacobian(target_pos, euler=orientation)
        except (IKError, ConfigurationPathError) as e:
            ik_failures += 1
            if verbose:
                print(f"    step {i}: IK FAILED - {e}")
            _record(trace[-1].copy() if trace else grasp_pos.copy())
            continue

        arm.set_joint_positions(joint_positions, disable_dynamics=True)
        # Size the zero-velocity command from the arm instead of assuming 7 DoF.
        arm.set_joint_target_velocities([0] * len(joint_positions))

        # Gripper closed during lift
        gripper.actuate(0.0, 0.2)

        for _ in range(steps_per_point):
            scene.pyrep.step()

        actual_pos = tip.get_position().copy()
        prev_joints = list(joint_positions)

        if verbose and (i == 0 or i == lift_steps - 1):
            obj_z = target_object.get_position()[2]
            print(f"    step {i}: ee_z={actual_pos[2]:.4f}, obj_z={obj_z:.4f}")

        _record(actual_pos)

    # Check if object was lifted
    object_z_after = target_object.get_position()[2]
    z_rise = object_z_after - object_z_before
    # Plain bool (not numpy.bool_) so the result dict serialises cleanly.
    object_lifted = bool(z_rise > min_rise)
    if verbose:
        print(f"    Object lifted: {object_lifted}, Z rise: {z_rise*1000:.1f}mm")

    return {
        'trace': trace,
        'phase_labels': phase_labels,
        'gripper_states': gripper_states,
        'ik_failures': ik_failures,
        'object_lifted': object_lifted,
        'object_z_before': object_z_before,
        'object_z_after': object_z_after,
        'frames': frames,
        'prev_joints': prev_joints,
    }


def descend_and_release(task_env, prerelease_pos, release_pos, orientation,
                        object_shape_name='stack_blocks_target0',
                        descend_steps=8, release_steps=8,
                        steps_per_point=5, capture_video=False,
                        verbose=True, prev_joints=None, target_object=None,
                        xy_tolerance=0.05):
    """
    Execute DESCEND_RELEASE and RELEASE phases.

    DESCEND_RELEASE: Move from prerelease position to release position along
        a straight line, with the gripper commanded closed.
    RELEASE: Hold the release pose, command the gripper open, call
        gripper.release(), then step the simulation 20 times so the object
        can settle.

    Every way-point is reached by solving Jacobian IK starting from the last
    valid joint configuration and writing the solution directly to the joints
    (dynamics disabled), followed by a few simulation steps.

    IK failure handling:
        DESCEND_RELEASE: the way-point is skipped; the previous trace entry
            (or prerelease_pos when the trace is empty) is repeated and the
            simulation is not stepped.
        RELEASE: the arm is held at the last valid joint configuration and
            the gripper is still commanded open, so the logged gripper state
            matches what was actually commanded.

    Args:
        task_env: RLBench task environment
        prerelease_pos: Starting position (above target); array, list or tuple
        release_pos: Release position (at target); array, list or tuple
        orientation: End-effector orientation (euler angles)
        object_shape_name: Name of object being released (used if target_object is None)
        descend_steps: Number of steps for descend phase
        release_steps: Number of steps for release phase
        steps_per_point: Physics steps per trajectory point (the release
            phase uses steps_per_point + 5 so the gripper has time to open)
        capture_video: Whether to capture video frames
        verbose: Whether to print debug info
        prev_joints: Previous joint positions (optional, will use current if None)
        target_object: Direct object reference (if provided, overrides object_shape_name)
        xy_tolerance: Maximum XY distance between the settled object and
            release_pos for the release to count as successful (default
            0.05, i.e. 5cm)

    Returns:
        dict with:
            trace: list of tip positions
            phase_labels: list of phase names ("descend_release" / "release")
            gripper_states: list of gripper states (1.0=open, 0.0=closed)
            ik_failures: number of IK failures
            object_released: bool, whether object is within xy_tolerance
                (default 5cm) of target XY
            object_final_pos: final object position
            xy_distance: distance from target XY
            frames: list of RGB frames (if capture_video=True)
            prev_joints: joint positions after release (for next phase)
    """
    scene = task_env._scene
    arm = scene.robot.arm
    gripper = scene.robot.gripper
    tip = arm.get_tip()

    # Accept lists/tuples as well as arrays for the Cartesian way-points.
    prerelease_pos = np.asarray(prerelease_pos, dtype=float)
    release_pos = np.asarray(release_pos, dtype=float)

    trace = []
    phase_labels = []
    gripper_states = []
    frames = []
    ik_failures = 0

    if prev_joints is None:
        prev_joints = arm.get_joint_positions()
    # Copy so the caller's sequence is never aliased by the returned state.
    prev_joints = list(prev_joints)

    # Get object reference - use target_object if provided, else lookup by name
    if target_object is None:
        target_object = Shape(object_shape_name)

    def _record(position, label, grip_state):
        """Append one sample to the logs and optionally grab a video frame."""
        trace.append(position)
        phase_labels.append(label)
        gripper_states.append(grip_state)
        if capture_video:
            frame = capture_frame(task_env)
            if frame is not None:
                frames.append(frame)

    def _drive(joints, grip_amount, n_steps):
        """Apply joints, command the gripper, step the sim; return tip position."""
        arm.set_joint_positions(joints, disable_dynamics=True)
        # Size the zero-velocity command from the arm instead of assuming 7 DoF.
        arm.set_joint_target_velocities([0] * len(joints))
        gripper.actuate(grip_amount, 0.2)
        for _ in range(n_steps):
            scene.pyrep.step()
        return tip.get_position().copy()

    # ========================================================================
    # Phase: DESCEND_RELEASE (linear) - prerelease → release
    # ========================================================================
    if verbose:
        print(f"\n  DESCEND_RELEASE: prerelease={prerelease_pos} -> release={release_pos}")

    descend_denominator = max(descend_steps - 1, 1)  # avoids 0/0 for a single step
    for i in range(descend_steps):
        target_pos = linear_interpolate(prerelease_pos, release_pos, i / descend_denominator)

        # Jacobian IK linearises around the current configuration, so start
        # from the last configuration that was known to be valid.
        arm.set_joint_positions(prev_joints, disable_dynamics=True)
        try:
            joint_positions = arm.solve_ik_via_jacobian(target_pos, euler=orientation)
        except (IKError, ConfigurationPathError) as e:
            ik_failures += 1
            if verbose:
                print(f"    step {i}: IK FAILED - {e}")
            _record(trace[-1].copy() if trace else prerelease_pos.copy(),
                    "descend_release", 0.0)
            continue

        # Gripper still closed during descend
        actual_pos = _drive(joint_positions, 0.0, steps_per_point)
        prev_joints = list(joint_positions)

        if verbose and (i == 0 or i == descend_steps - 1):
            error = np.linalg.norm(actual_pos - target_pos) * 1000
            print(f"    step {i}: target_z={target_pos[2]:.4f}, actual_z={actual_pos[2]:.4f}, error={error:.1f}mm")

        _record(actual_pos, "descend_release", 0.0)

    # ========================================================================
    # Phase: RELEASE - hold position, open gripper, then DETACH object
    # ========================================================================
    object_z_before_release = target_object.get_position()[2]
    if verbose:
        print(f"\n  RELEASE: hold at release_pos={release_pos}, object_z_before={object_z_before_release:.4f}")

    for i in range(release_steps):
        arm.set_joint_positions(prev_joints, disable_dynamics=True)
        try:
            joint_positions = arm.solve_ik_via_jacobian(release_pos, euler=orientation)
        except (IKError, ConfigurationPathError) as e:
            ik_failures += 1
            if verbose:
                print(f"    step {i}: IK FAILED - {e}")
            # The arm only has to hold still here: keep the last valid
            # configuration so the gripper is still opened and simulated.
            joint_positions = prev_joints

        # Open gripper; extra physics steps give it time to open
        actual_pos = _drive(joint_positions, 1.0, steps_per_point + 5)
        prev_joints = list(joint_positions)

        if verbose and i == release_steps - 1:
            gripper_open_amount = gripper.get_open_amount()[0]
            print(f"    After RELEASE: ee_z={actual_pos[2]:.4f}, obj_z={target_object.get_position()[2]:.4f}, gripper_open={gripper_open_amount:.3f}")

        _record(actual_pos, "release", 1.0)

    # IMPORTANT: After gripper opens, use gripper.release() to DETACH object from gripper!
    gripper.release()
    if verbose:
        print("    gripper.release() called - object detached")

    # Let object settle after release
    for _ in range(20):
        scene.pyrep.step()

    object_final_pos = target_object.get_position()

    # Check if object is near target XY
    target_xy = np.array([release_pos[0], release_pos[1]])
    object_xy = np.array([object_final_pos[0], object_final_pos[1]])
    xy_distance = np.linalg.norm(object_xy - target_xy)
    # Plain bool (not numpy.bool_) so the result dict serialises cleanly.
    object_released = bool(xy_distance < xy_tolerance)

    if verbose:
        print(f"    Object final pos: [{object_final_pos[0]:.4f}, {object_final_pos[1]:.4f}, {object_final_pos[2]:.4f}]")
        print(f"    Target XY: [{target_xy[0]:.4f}, {target_xy[1]:.4f}], Object XY: [{object_xy[0]:.4f}, {object_xy[1]:.4f}]")
        print(f"    XY distance to target: {xy_distance*100:.1f}cm")

    return {
        'trace': trace,
        'phase_labels': phase_labels,
        'gripper_states': gripper_states,
        'ik_failures': ik_failures,
        'object_released': object_released,
        'object_final_pos': object_final_pos,
        'xy_distance': xy_distance,
        'frames': frames,
        'prev_joints': prev_joints,
    }
