# filename: grasp_utils.py
"""
Utilities for grasp (pick_up_cup) dataset generation with control-point trajectory shaping.

Key design:
  - Two parameters define each grasp mode: (approach_angle, grasp_height)
  - Approach angle: horizontal direction the gripper approaches from
  - Grasp height: Z position where gripper contacts the cup
  - Reach phase is shaped using Bezier curves (same as close_drawer)
  - Other phases (descend, hold, close_gripper, lift) are linear

For pick_up_cup task:
  Phase 1: REACH     - shaped trajectory from start to pregrasp position
  Phase 2: DESCEND   - linear descent to grasp position
  Phase 3: HOLD      - dwell at grasp position (gripper open)
  Phase 4: CLOSE     - close gripper while holding position
  Phase 5: LIFT      - lift object up
"""

import numpy as np
from pyrep.errors import ConfigurationPathError, IKError
from pyrep.objects.shape import Shape

# Import configuration
from grasp_config import (
    HOME_JOINTS,
    PHASE_STEPS as DEFAULT_PHASE_STEPS,
    WAYPOINT_OFFSETS as DEFAULT_WAYPOINT_OFFSETS,
    CONTROL_POINT_RADIUS,
    DEFAULT_GRASP_ORIENTATION,
    get_grasp_orientation,
    LIFT_SUCCESS_THRESHOLD,
)


# ============================================================================
# Cup Position Functions
# ============================================================================

def get_cup_position(task_env, cup_name="cup1"):
    """
    Look up the position and orientation reported by one of the task's cups.

    Args:
        task_env: RLBench task environment; its ``_task`` must expose
            ``cup1`` and ``cup2``
        cup_name: str, "cup1" selects ``task.cup1`` (the target); any other
            value selects ``task.cup2`` (the distractor)

    Returns:
        cup_pos: np.ndarray, value of ``cup.get_position()``
        cup_ori: np.ndarray, value of ``cup.get_orientation()`` (euler angles)
    """
    task = task_env._task
    # Anything that is not exactly "cup1" falls through to cup2.
    cup = task.cup1 if cup_name == "cup1" else task.cup2
    return np.array(cup.get_position()), np.array(cup.get_orientation())


def get_cup_dimensions(task_env, cup_name="cup1"):
    """
    Measure the extents of a cup's bounding box.

    Args:
        task_env: RLBench task environment
        cup_name: str, "cup1" selects ``task.cup1``; any other value
            selects ``task.cup2``

    Returns:
        tuple: (width, depth, height), i.e. the X, Y and Z extents
    """
    task = task_env._task
    cup = task.cup1 if cup_name == "cup1" else task.cup2

    # Bounding box layout: [min_x, max_x, min_y, max_y, min_z, max_z]
    bbox = cup.get_bounding_box()
    return tuple(bbox[2 * axis + 1] - bbox[2 * axis] for axis in range(3))


def get_cup_base_z(task_env, cup_name="cup1"):
    """
    Compute the Z coordinate of the bottom of a cup.

    The result is the Z component of ``cup.get_position()`` plus the
    ``min_z`` entry of ``cup.get_bounding_box()``. Grasp heights are
    expressed relative to this value.

    Args:
        task_env: RLBench task environment
        cup_name: str, "cup1" selects ``task.cup1``; any other value
            selects ``task.cup2``

    Returns:
        float: Z position of the cup base
    """
    task = task_env._task
    cup = task.cup1 if cup_name == "cup1" else task.cup2

    center_z = cup.get_position()[2]
    # NOTE(review): assumes the bounding box is expressed relative to the
    # cup's own origin and that the cup is upright - confirm.
    local_min_z = cup.get_bounding_box()[4]
    return center_z + local_min_z


def set_cup_position(task_env, position, cup_name="cup1", max_retries=5, tolerance=0.005):
    """
    Set the cup to a fixed position with verification.

    Both the physics object and its visual mesh are moved, the simulation
    is stepped 30 times, and the resulting XY position is compared with
    the request. Only XY drift is checked; Z is not verified.

    Args:
        task_env: RLBench task environment
        position: np.ndarray(3,), target position [X, Y, Z]
        cup_name: str, name of the cup to move ("cup1" selects cup1; any
            other value selects cup2)
        max_retries: int, number of times to retry if cup drifts
        tolerance: float, maximum allowed XY drift in meters (default 5mm)

    Returns:
        np.ndarray(3,): actual position after setting

    Raises:
        RuntimeError: if cup position cannot be set within tolerance
            (including when ``max_retries`` allows no attempt at all)
    """
    task = task_env._task
    if cup_name == "cup1":
        cup = task.cup1
        cup_visual = task.cup1_visual
    else:
        cup = task.cup2
        cup_visual = task.cup2_visual

    position = np.array(position)

    # Pre-initialise so the error below is well-defined even if the loop
    # body never runs (max_retries <= 0).
    drift = float("inf")

    for _attempt in range(max_retries):
        # Set position for both physics object and visual mesh
        cup.set_position(position)
        cup_visual.set_position(position)

        # Step simulation to apply and let physics settle
        for _ in range(30):
            task_env._scene.pyrep.step()

        # Verify position (XY only)
        actual_pos = np.array(cup.get_position())
        drift = np.linalg.norm(actual_pos[:2] - position[:2])

        if drift <= tolerance:
            return actual_pos

    # If we get here, the cup never settled within tolerance
    raise RuntimeError(f"Cup position drift ({drift:.4f}m) exceeds tolerance ({tolerance}m) after {max_retries} attempts")


def reset_robot_to_default(task_env):
    """
    Put the arm and gripper back into the joint positions stored by the scene.

    Intended to be called BEFORE task_env.reset() to avoid collision errors.

    Args:
        task_env: RLBench task environment
    """
    scene = task_env._scene
    robot = scene.robot

    # Arm: restore the scene's stored start joints and zero target velocities.
    arm_start = scene._start_arm_joint_pos
    robot.arm.set_joint_positions(arm_start, disable_dynamics=True)
    robot.arm.set_joint_target_velocities([0] * 7)

    # Gripper: release whatever is held, then restore the stored start joints.
    robot.gripper.release()
    gripper_start = scene._starting_gripper_joint_pos
    robot.gripper.set_joint_positions(gripper_start, disable_dynamics=True)

    # A few simulation steps so the new state takes effect.
    for _ in range(5):
        task_env._scene.pyrep.step()


# ============================================================================
# Waypoint Computation
# ============================================================================

def compute_grasp_waypoints(start_pos, cup_pos, cup_base_z, approach_angle, grasp_height, waypoint_params=None):
    """
    Compute the key waypoints for a grasp trajectory.

    The approach is computed as:
    1. Grasp: offset from cup center in approach direction, at specified height
       (EE is ~4cm from cup center based on official RLBench demos)
    2. Pregrasp: further offset from grasp position, above grasp height
    3. Lift: above grasp position

    Inputs are copied into float arrays before any offset is applied, so
    integer-typed positions do not silently truncate the (centimetre-scale)
    offsets and the caller's arrays are never modified.

    Args:
        start_pos: np.ndarray(3,), starting end-effector position (HOME)
        cup_pos: np.ndarray(3,), position of cup center
        cup_base_z: float, Z position of cup base (bottom) in world coords
        approach_angle: float, angle in radians (direction to approach from in XY plane);
            0 means approach from +X, pi/2 means approach from +Y
        grasp_height: float, Z offset above cup base for grasping
        waypoint_params: dict with offset parameters; recognised keys are
            "pregrasp_offset_z", "approach_offset_xy",
            "gripper_offset_from_cup" and "lift_height". Defaults to
            DEFAULT_WAYPOINT_OFFSETS when None.

    Returns:
        dict with keys: start, pregrasp, grasp, lift (each a float np.ndarray)
    """
    if waypoint_params is None:
        waypoint_params = DEFAULT_WAYPOINT_OFFSETS.copy()

    pregrasp_offset_z = waypoint_params.get("pregrasp_offset_z", 0.08)
    approach_offset_xy = waypoint_params.get("approach_offset_xy", 0.10)
    gripper_offset = waypoint_params.get("gripper_offset_from_cup", 0.04)
    lift_height = waypoint_params.get("lift_height", 0.15)

    cos_a = np.cos(approach_angle)
    sin_a = np.sin(approach_angle)

    # Grasp position: offset from cup center in approach direction.
    # The EE doesn't go exactly to cup center - it's offset ~4cm away
    # because the gripper fingers wrap around the cup rim.
    # np.array(..., dtype=float) both copies and forces a float dtype.
    grasp = np.array(cup_pos, dtype=float)
    grasp[0] += gripper_offset * cos_a
    grasp[1] += gripper_offset * sin_a
    grasp[2] = cup_base_z + grasp_height  # Absolute Z = cup base + offset

    # Pregrasp: further out along the same approach direction, and above grasp
    pregrasp = grasp.copy()
    pregrasp[0] += approach_offset_xy * cos_a
    pregrasp[1] += approach_offset_xy * sin_a
    pregrasp[2] += pregrasp_offset_z

    # Lift: straight up from grasp position
    lift = grasp.copy()
    lift[2] += lift_height

    return {
        "start": np.array(start_pos, dtype=float),
        "pregrasp": pregrasp,
        "grasp": grasp,
        "lift": lift,
    }


# ============================================================================
# Control Point Functions (same as close_drawer/reach tasks)
# ============================================================================

def build_local_frame(start_pos, target_pos):
    """
    Build an orthonormal frame whose first axis points from start to target.

    Args:
        start_pos: np.ndarray(3,), start position
        target_pos: np.ndarray(3,), target position

    Returns:
        (line_vec_norm, perp1, perp2): unit direction of the start->target
        line and two unit vectors perpendicular to it and to each other.
        When the two points are closer than 1e-6 the world X, Y, Z axes
        are returned instead.
    """
    direction = target_pos - start_pos
    length = np.linalg.norm(direction)
    if length < 1e-6:
        # Degenerate case: start and target coincide, no direction to use
        return np.array([1, 0, 0]), np.array([0, 1, 0]), np.array([0, 0, 1])

    axis = direction / length

    # Cross with world Z unless the line is nearly vertical, in which case
    # world X is used so the cross product stays well away from zero.
    nearly_vertical = not (abs(axis[2]) < 0.9)
    reference = np.array([1, 0, 0]) if nearly_vertical else np.array([0, 0, 1])

    first_perp = np.cross(axis, reference)
    first_perp = first_perp / np.linalg.norm(first_perp)
    second_perp = np.cross(axis, first_perp)
    return axis, first_perp, second_perp


def parabola3D(S, E, C, t):
    """
    Evaluate a quadratic Bezier curve in 3D from S to E shaped by C.

    The Bezier control point is derived as ``2 * C - midpoint(S, E)``,
    which makes the curve pass through C itself at t = 0.5.

    Args:
        S: np.ndarray(3,), curve start (t = 0)
        E: np.ndarray(3,), curve end (t = 1)
        C: np.ndarray(3,), shaping point reached at t = 0.5
        t: float, curve parameter

    Returns:
        np.ndarray(3,): point on the curve at parameter t
    """
    midpoint = 0.5 * (S + E)
    bezier_ctrl = 2 * C - midpoint
    u = 1 - t
    return u * u * S + 2 * u * t * bezier_ctrl + t * t * E


def compute_control_point_from_params(start_pos, target_pos, radius, angle, dist_frac, pos_frac):
    """
    Place a control point relative to the start->target line.

    The point sits ``pos_frac`` of the way along the line and is pushed
    sideways by ``dist_frac * radius`` in the direction obtained by
    rotating through ``angle`` within the plane perpendicular to the line.

    Args:
        start_pos: np.ndarray(3,), start position
        target_pos: np.ndarray(3,), target position
        radius: float, maximum offset radius (meters)
        angle: float, angle around the line (radians)
        dist_frac: float, fraction of radius for offset distance
        pos_frac: float, fraction along start->target line for base position

    Returns:
        control_point: np.ndarray(3,)
    """
    # Only the two perpendicular axes of the frame are needed here.
    _, perp1, perp2 = build_local_frame(start_pos, target_pos)

    anchor = start_pos + pos_frac * (target_pos - start_pos)
    sideways = dist_frac * radius * (np.cos(angle) * perp1 + np.sin(angle) * perp2)
    return anchor + sideways


def generate_reach_control_point(start_pos, pregrasp_pos, approach_angle, control_point_radius):
    """
    Pick the control point that shapes the reach phase.

    The point lies at the middle of the start->pregrasp line, displaced
    perpendicular to it by 70% of ``control_point_radius``. The approach
    angle is reused as the rotation angle around that line.

    Args:
        start_pos: np.ndarray(3,), starting position (HOME)
        pregrasp_pos: np.ndarray(3,), pregrasp position
        approach_angle: float, approach direction in radians
        control_point_radius: float, offset radius

    Returns:
        control_point: np.ndarray(3,)
    """
    return compute_control_point_from_params(
        start_pos,
        pregrasp_pos,
        radius=control_point_radius,
        angle=approach_angle,
        dist_frac=0.7,  # Strong curve
        pos_frac=0.5,   # Control point at middle of trajectory
    )


# ============================================================================
# Trajectory Generation
# ============================================================================

def linear_interpolate(p1, p2, t):
    """
    Blend linearly from p1 (t = 0) to p2 (t = 1).

    Args:
        p1: start value (scalar or array)
        p2: end value (scalar or array)
        t: float, interpolation parameter

    Returns:
        p1 + t * (p2 - p1)
    """
    span = p2 - p1
    return p1 + t * span


def generate_phase_positions(waypoints, phase_steps, cp_reach):
    """
    Expand the grasp waypoints into per-step EE targets for all five phases.

    Phases, in order:
      reach         - shaped (Bezier) start -> pregrasp, gripper open
      descend       - linear pregrasp -> grasp, gripper open
      hold_grasp    - dwell at grasp, gripper open
      close_gripper - dwell at grasp, gripper closed
      lift          - linear grasp -> lift, gripper closed

    Within a moving phase of n steps, step i uses t = i / max(n - 1, 1),
    so the phase starts at t = 0 and (for n > 1) ends at t = 1.

    Args:
        waypoints: dict from compute_grasp_waypoints
        phase_steps: dict with step counts per phase
        cp_reach: np.ndarray(3,), control point for reach phase

    Returns:
        positions: list of np.ndarray(3,), one per step
        gripper_states: list of float (1.0 = open, 0.0 = closed)
        phase_indices: dict mapping phase name to (start_idx, end_idx),
            with end_idx exclusive
        phase_labels: list of str, phase name for each step
    """
    gripper_open, gripper_closed = 1.0, 0.0

    # (phase name, gripper command, position as a function of t).
    # Waypoints are looked up lazily, only when a step is actually generated.
    schedule = (
        ("reach", gripper_open,
         lambda t: parabola3D(waypoints["start"], waypoints["pregrasp"], cp_reach, t)),
        ("descend", gripper_open,
         lambda t: linear_interpolate(waypoints["pregrasp"], waypoints["grasp"], t)),
        ("hold_grasp", gripper_open,
         lambda t: waypoints["grasp"].copy()),
        ("close_gripper", gripper_closed,
         lambda t: waypoints["grasp"].copy()),
        ("lift", gripper_closed,
         lambda t: linear_interpolate(waypoints["grasp"], waypoints["lift"], t)),
    )

    positions, gripper_states, phase_labels = [], [], []
    phase_indices = {}
    cursor = 0

    for name, grip, point_at in schedule:
        count = phase_steps[name]
        phase_indices[name] = (cursor, cursor + count)
        denom = max(count - 1, 1)  # avoids division by zero for 1-step phases
        for step in range(count):
            positions.append(point_at(step / denom))
            gripper_states.append(grip)
            phase_labels.append(name)
        cursor += count

    return positions, gripper_states, phase_indices, phase_labels


def move_robot_to_start(task_env, start_pos, desired_ori, gripper_open=True):
    """
    Snap the arm to the HOME joint configuration and set the gripper state.

    Args:
        task_env: RLBench task environment
        start_pos: np.ndarray(3,), requested start position; not used, the
            arm always goes to HOME_JOINTS
        desired_ori: np.ndarray(3,), desired orientation, passed through
            unchanged to the caller
        gripper_open: bool, whether gripper should be open

    Returns:
        actual_start_pos: np.ndarray(3,), tip position measured at HOME
        actual_ori: np.ndarray(3,), the orientation to use for trajectory
            (this is ``desired_ori``, not the orientation measured at HOME)
    """
    scene = task_env._scene
    arm = scene.robot.arm
    tip = arm.get_tip()
    gripper = scene.robot.gripper

    # Jump straight to HOME and let the simulation settle.
    arm.set_joint_positions(HOME_JOINTS, disable_dynamics=True)
    arm.set_joint_target_velocities([0] * 7)
    for _ in range(10):
        scene.pyrep.step()

    # Command the gripper (1.0 = open, 0.0 = closed) and give it time to move.
    gripper_velocity = 0.2
    gripper.actuate(1.0 if gripper_open else 0.0, gripper_velocity)
    for _ in range(20):
        scene.pyrep.step()

    # Measure where the tip actually ended up at HOME.
    home_pos = tip.get_position()

    scene.pyrep.step()

    print(f"  HOME EE position: [{home_pos[0]:.4f}, {home_pos[1]:.4f}, {home_pos[2]:.4f}]")

    # The HOME orientation may not suit grasping, so the requested
    # orientation is handed back rather than the measured one.
    return home_pos.copy(), desired_ori


def execute_step(task_env, target_ee_pos, desired_ori, gripper_state, extra_steps=0):
    """
    Drive the arm to one EE target, command the gripper, and step the sim.

    IK is first solved with the requested orientation; if that raises
    IKError or ConfigurationPathError, a position-only solve is attempted.
    The solution is applied with dynamics disabled, then the simulation
    and the task are stepped ``5 + extra_steps`` times.

    Args:
        task_env: RLBench task environment
        target_ee_pos: np.ndarray(3,), target end-effector position
        desired_ori: np.ndarray(3,), desired orientation (euler)
        gripper_state: float, 1.0 = open, 0.0 = closed
        extra_steps: int, additional physics steps

    Returns:
        obs: observation after the step; ``obs.misc['joint_position_action']``
            holds the commanded joints followed by the gripper state
        joint_positions: the commanded joint positions

    Raises:
        IKError or ConfigurationPathError: if the position-only fallback
            IK also fails
    """
    scene = task_env._scene
    arm = scene.robot.arm
    gripper = scene.robot.gripper

    try:
        joint_positions = arm.solve_ik(target_ee_pos, euler=desired_ori)
    except (IKError, ConfigurationPathError):
        # Orientation-constrained solve failed: retry with position only
        joint_positions = arm.solve_ik(target_ee_pos)

    # Teleport the joints (dynamics disabled) for reliable movement
    arm.set_joint_positions(joint_positions, disable_dynamics=True)
    arm.set_joint_target_velocities([0] * 7)

    gripper.actuate(gripper_state, 0.2)

    for _ in range(5 + extra_steps):
        scene.pyrep.step()
        scene.task.step()

    obs = scene.get_observation()

    if not hasattr(obs, 'misc'):
        obs.misc = {}
    action = np.concatenate((joint_positions, np.array([gripper_state])))
    obs.misc['joint_position_action'] = action

    return obs, joint_positions


def generate_grasp_trajectory(
    task_env,
    start_pos,
    cup_pos,
    cup_ori,
    approach_angle,
    grasp_height,
    control_point_radius=None,
    waypoint_params=None,
    phase_steps=None,
    steps_per_point=5,
):
    """
    Generate a full grasp trajectory with shaped reach phase.

    The robot is first moved to HOME, waypoints and per-step targets are
    computed, and every target after the first is executed with Jacobian
    IK seeded from the previous step's joints. The episode runs in strict
    mode: the first IK failure aborts it.

    Args:
        task_env: RLBench task environment
        start_pos: np.ndarray(3,), starting EE position (forwarded to
            move_robot_to_start; the measured HOME position is what the
            trajectory actually starts from)
        cup_pos: np.ndarray(3,), cup position
        cup_ori: np.ndarray(3,), cup orientation (euler); not used here
        approach_angle: float, approach direction in radians
        grasp_height: float, grasp height above cup base in meters
        control_point_radius: float, radius for control point offset
            (defaults to CONTROL_POINT_RADIUS)
        waypoint_params: dict, waypoint offset parameters
        phase_steps: dict, step counts per phase (defaults to a copy of
            DEFAULT_PHASE_STEPS)
        steps_per_point: int, number of physics steps per trajectory point

    Returns:
        demo: list of observations (the initial observation plus one per
            executed step)
        metadata: dict with trajectory info (includes 'trace' and 'phase_labels' for video)

    Raises:
        RuntimeError: if IK fails at any step; the original IKError /
            ConfigurationPathError is attached as ``__cause__``
    """
    if phase_steps is None:
        phase_steps = DEFAULT_PHASE_STEPS.copy()
    if control_point_radius is None:
        control_point_radius = CONTROL_POINT_RADIUS

    total_steps = sum(phase_steps.values())
    robot = task_env._scene.robot
    tip = robot.arm.get_tip()
    gripper = robot.gripper

    # Get cup base Z position (bottom of cup in world coords).
    # Uses get_cup_base_z's default cup ("cup1").
    cup_base_z = get_cup_base_z(task_env)

    # Compute grasp orientation based on approach angle
    # This aligns the gripper fingers perpendicular to the approach direction
    grasp_ori = get_grasp_orientation(approach_angle)

    # Move to HOME position first (use grasp orientation for trajectory)
    actual_start_pos, actual_ori = move_robot_to_start(
        task_env, start_pos, grasp_ori, gripper_open=True
    )

    # Record HOME orientation for smooth interpolation during reach phase.
    # Using grasp_ori from step 1 causes Jacobian IK failure for some angles
    # (e.g. 120°, 150°) because the orientation gap from HOME is too large.
    home_ori = np.array(tip.get_orientation())

    # Compute waypoints using actual start position and cup base Z
    waypoints = compute_grasp_waypoints(
        actual_start_pos, cup_pos, cup_base_z, approach_angle, grasp_height, waypoint_params
    )

    # Debug: print waypoints
    print("  Waypoints:")
    print(f"    start:    [{waypoints['start'][0]:.4f}, {waypoints['start'][1]:.4f}, {waypoints['start'][2]:.4f}]")
    print(f"    pregrasp: [{waypoints['pregrasp'][0]:.4f}, {waypoints['pregrasp'][1]:.4f}, {waypoints['pregrasp'][2]:.4f}]")
    print(f"    grasp:    [{waypoints['grasp'][0]:.4f}, {waypoints['grasp'][1]:.4f}, {waypoints['grasp'][2]:.4f}]")
    print(f"    lift:     [{waypoints['lift'][0]:.4f}, {waypoints['lift'][1]:.4f}, {waypoints['lift'][2]:.4f}]")
    print(f"    cup_pos:  [{cup_pos[0]:.4f}, {cup_pos[1]:.4f}, {cup_pos[2]:.4f}]")
    print(f"    cup_base_z: {cup_base_z:.4f}")
    print(f"    approach: {np.degrees(approach_angle):.1f} deg, height: {grasp_height:.4f} m (above base)")
    print(f"    grasp_ori: [{np.degrees(grasp_ori[0]):.1f}, {np.degrees(grasp_ori[1]):.1f}, {np.degrees(grasp_ori[2]):.1f}] deg")

    # Compute control point for reach phase
    cp_reach = generate_reach_control_point(
        waypoints["start"], waypoints["pregrasp"],
        approach_angle, control_point_radius
    )

    # Generate all target positions and gripper states
    positions, gripper_states, phase_indices, phase_labels = generate_phase_positions(
        waypoints, phase_steps, cp_reach
    )

    print(f"  Trajectory: {total_steps} steps")

    # Execute trajectory - track EE trace for video overlay
    demo = []
    trace = []  # EE positions for video overlay
    trace_phase_labels = []  # Phase labels for coloring
    trace_gripper_states = []

    # Record initial cup position for success check
    initial_cup_z = cup_pos[2]

    # Initial observation stands in for positions[0]: the robot is already
    # at the start of the reach curve, so the loop below begins at index 1.
    demo.append(task_env._scene.get_observation())
    trace.append(tip.get_position().copy())
    trace_phase_labels.append("reach")
    trace_gripper_states.append(1.0)

    # Seed for the Jacobian solver; updated after every successful step
    prev_joints = list(robot.arm.get_joint_positions())

    # Phases that need extra physics steps for gripper operation
    GRIPPER_PHASES = {"close_gripper"}

    # Track previous gripper state to detect close action
    prev_gripper_state = 1.0  # Start open
    has_grasped = False

    reach_start, reach_end = phase_indices["reach"]

    for i in range(1, len(positions)):
        target_pos_step = positions[i]
        gripper_state = gripper_states[i]
        phase = phase_labels[i]

        # Add extra physics steps during gripper close for reliable grip
        extra_steps = 10 if phase in GRIPPER_PHASES else 0

        # Interpolate orientation during reach phase: HOME ori -> grasp ori
        # This avoids Jacobian IK failure at early steps where the orientation
        # gap from HOME is too large for the local solver.
        # NOTE(review): the blend is component-wise on Euler angles and does
        # not wrap at +/-pi - confirm HOME and grasp angles lie on the same
        # branch, otherwise the gripper takes the long way round.
        if phase == "reach":
            t = (i - reach_start) / max(reach_end - reach_start - 1, 1)
            step_ori = home_ori + t * (actual_ori - home_ori)
        else:
            step_ori = actual_ori

        try:
            # Use Jacobian IK for smooth motion, starting from the last
            # commanded joints rather than wherever physics left the arm
            robot.arm.set_joint_positions(prev_joints, disable_dynamics=True)
            joint_positions = robot.arm.solve_ik_via_jacobian(target_pos_step, euler=step_ori)
            robot.arm.set_joint_positions(joint_positions, disable_dynamics=True)
            robot.arm.set_joint_target_velocities([0] * 7)

            # Handle gripper state change - IMPORTANT: call grasp() when closing!
            if gripper_state < 0.5 and prev_gripper_state >= 0.5 and not has_grasped:
                # Closing gripper - try to grasp the cup
                graspable_objects = task_env._task.get_graspable_objects()
                for g_obj in graspable_objects:
                    grasped = gripper.grasp(g_obj)
                    if grasped:
                        has_grasped = True
                        print(f"    Grasped object: {g_obj.get_name()}")
                        break

            # Actuate gripper
            gripper.actuate(gripper_state, 0.2)
            prev_gripper_state = gripper_state

            # Step simulation
            sim_steps = steps_per_point + extra_steps
            for _ in range(sim_steps):
                task_env._scene.pyrep.step()
                task_env._scene.task.step()

            obs = task_env._scene.get_observation()
            if not hasattr(obs, 'misc'):
                obs.misc = {}
            obs.misc['joint_position_action'] = np.concatenate([joint_positions, [gripper_state]])
            demo.append(obs)

            # Track EE position for video
            trace.append(tip.get_position().copy())
            trace_phase_labels.append(phase)
            trace_gripper_states.append(gripper_state)
            prev_joints = list(joint_positions)

        except (IKError, ConfigurationPathError) as e:
            print(f"    IK failed at step {i}/{len(positions)}: target=[{target_pos_step[0]:.4f}, {target_pos_step[1]:.4f}, {target_pos_step[2]:.4f}]")
            # Strict mode: reject entire episode on first IK failure.
            # Chain the solver's exception so its details are not lost.
            raise RuntimeError(f"IK failure at step {i}/{len(positions)} - rejecting episode") from e

    assert len(demo) == total_steps, f"Expected {total_steps} steps, got {len(demo)}"

    # Debug: print final EE position
    final_ee_pos = tip.get_position()
    target_lift = waypoints["lift"]
    ee_error = np.linalg.norm(final_ee_pos - target_lift)
    print(f"  Final EE: [{final_ee_pos[0]:.4f}, {final_ee_pos[1]:.4f}, {final_ee_pos[2]:.4f}]")
    print(f"  Target lift: [{target_lift[0]:.4f}, {target_lift[1]:.4f}, {target_lift[2]:.4f}]")
    print(f"  Position error: {ee_error*100:.2f}cm")

    # Build metadata
    metadata = {
        "waypoints": waypoints,
        "waypoints_list": {k: v.tolist() for k, v in waypoints.items()},
        "cp_reach": cp_reach.tolist(),
        "approach_angle": approach_angle,
        "grasp_height": grasp_height,
        "phase_indices": phase_indices,
        "phase_steps": phase_steps,
        # Always 0: in strict mode any IK failure raises before reaching here.
        "ik_failures": 0,
        "actual_orientation": actual_ori.tolist(),
        "trace": np.array(trace),
        "phase_labels": trace_phase_labels,
        "gripper_states": trace_gripper_states,
        "initial_cup_z": initial_cup_z,
        "gripper_grasped": has_grasped,
    }

    return demo, metadata


def check_task_success(task_env):
    """
    Check if the grasp task was successful using RLBench's built-in check.

    For pick_up_cup:
    - Cup1 must be grasped
    - Cup1 must be lifted (above success sensor)

    This is a best-effort check: any exception raised while querying the
    task is reported on stdout and treated as "not successful" rather
    than propagated.

    Args:
        task_env: RLBench task environment

    Returns:
        bool: True if task succeeded, False if it did not or if the check
            itself raised
    """
    try:
        # success() returns a pair; only the first element is used here.
        success, _ = task_env._task.success()
    except Exception as e:
        # Report the reason instead of swallowing it silently, matching
        # check_grasp_success_manual's behaviour.
        print(f"  Task success check failed: {e}")
        return False
    return success


def check_grasp_success_manual(task_env, initial_cup_z, lift_threshold=None):
    """
    Decide grasp success by how far cup1 has risen since the episode began.

    Args:
        task_env: RLBench task environment
        initial_cup_z: float, Z position of the cup before the grasp
        lift_threshold: float, minimum Z rise required; falls back to
            LIFT_SUCCESS_THRESHOLD when None

    Returns:
        bool: True if the rise is at least ``lift_threshold``; False if it
            is smaller or if reading the cup position raised
    """
    lift_threshold = LIFT_SUCCESS_THRESHOLD if lift_threshold is None else lift_threshold

    try:
        z_rise = task_env._task.cup1.get_position()[2] - initial_cup_z
        print(f"  Grasp check: cup rose {z_rise*100:.2f}cm (threshold: {lift_threshold*100:.2f}cm)")
        lifted = z_rise >= lift_threshold
    except Exception as e:
        # Best effort: report the problem and count it as a failed grasp
        print(f"  Grasp check failed: {e}")
        lifted = False
    return lifted


# ============================================================================
# Pre-filtering Control Points (for IK feasibility)
# ============================================================================

def prefilter_grasp_modes(robot_arm, start_pos, cup_pos, cup_base_z, orientation,
                          canonical_params, waypoint_params=None):
    """
    Keep only the grasp modes whose pregrasp and grasp poses are IK-solvable.

    For each (angle, height) pair the waypoints are computed and
    ``robot_arm.solve_ik`` is called for the pregrasp and then the grasp
    position with the given orientation. A mode is dropped as soon as
    either call raises IKError or ConfigurationPathError.

    Args:
        robot_arm: PyRep arm object
        start_pos: np.ndarray(3,), start position
        cup_pos: np.ndarray(3,), cup position
        cup_base_z: float, Z position of cup base (bottom) in world coords
        orientation: np.ndarray(3,), desired orientation
        canonical_params: np.ndarray of (angle, height) pairs
        waypoint_params: dict, waypoint parameters

    Returns:
        valid_indices: list of indices into ``canonical_params`` that passed
    """
    feasible = []

    for mode_idx, (angle, height) in enumerate(canonical_params):
        waypoints = compute_grasp_waypoints(
            start_pos, cup_pos, cup_base_z, angle, height, waypoint_params
        )

        try:
            # Only the two key poses are probed; the solutions are discarded.
            for key in ("pregrasp", "grasp"):
                robot_arm.solve_ik(waypoints[key], euler=orientation)
        except (IKError, ConfigurationPathError):
            continue

        feasible.append(mode_idx)

    return feasible
