# filename: push_utils.py
"""
Utilities for push/close drawer dataset generation with control-point trajectory shaping.

Key design:
  - One control-point index (cp_idx) shapes the REACH phase (getting to handle)
  - The PUSH phase is a simple linear motion to close the drawer
  - Each shaped segment (reach) gets its own configurable sub-trajectory
  - Linear phases (push) use fewer steps
  - Uses the same canonical control-point parameterization as reach/pick-place tasks

For close_drawer task:
  Phase 1: REACH - shaped trajectory from start to handle position
  Phase 2: PUSH  - linear push motion to close the drawer
  Phase 3: RETREAT - move away from drawer after push (the retreat waypoint is
                     computed, but generate_phase_positions does not emit this phase)
"""

import numpy as np
from pyrep.errors import ConfigurationPathError, IKError
from pyrep.objects.dummy import Dummy
from pyrep.objects.joint import Joint

# Import configuration
from close_drawer_config import (
    CABINET_POSITION,
    get_cabinet_orientation_radians,
    HOME_JOINTS,
    PHASE_STEPS as DEFAULT_PHASE_STEPS,
    WAYPOINT_OFFSETS as DEFAULT_WAYPOINT_OFFSETS,
    WORKSPACE_SCALE,
)


# ============================================================================
# Drawer Handle Position Functions
# ============================================================================

def get_drawer_handle_position(task_env, drawer_option="middle"):
    """
    Read the drawer-handle pose from the scene's ``waypoint0`` dummy.

    The pose is taken verbatim from the dummy named ``'waypoint0'``; neither
    ``task_env`` nor ``drawer_option`` is consulted by this function, they are
    accepted only so the signature matches the other drawer helpers.

    NOTE(review): the original author states that RLBench's ``init_episode``
    places ``waypoint0`` at the handle height of the selected drawer variation
    and that it follows the cabinet when the cabinet is moved without
    rotation - confirm against the task definition.

    Args:
        task_env: RLBench task environment (unused here)
        drawer_option: str, one of "bottom", "middle", "top" (or index 0, 1, 2);
            unused here

    Returns:
        handle_pos: np.ndarray, position reported by ``waypoint0``
        handle_ori: np.ndarray, orientation reported by ``waypoint0``
    """
    handle_dummy = Dummy('waypoint0')
    return (
        np.array(handle_dummy.get_position()),
        np.array(handle_dummy.get_orientation()),
    )


def get_drawer_joint(task_env, drawer_option="middle"):
    """
    Return the joint object that drives the requested drawer.

    Args:
        task_env: RLBench task environment (not used by this lookup)
        drawer_option: str, one of "bottom", "middle", "top". Any other string
            falls back to "middle". A non-string value is used directly as an
            index into ("bottom", "middle", "top").

    Returns:
        joint: Joint object named ``drawer_joint_<name>``
    """
    drawer_names = ("bottom", "middle", "top")

    if isinstance(drawer_option, str):
        # Unknown names silently select the middle drawer.
        chosen = drawer_option if drawer_option in drawer_names else "middle"
    else:
        chosen = drawer_names[drawer_option]

    return Joint(f'drawer_joint_{chosen}')


def set_drawer_open(task_env, drawer_option="middle", open_amount=0.1):
    """
    Open one drawer to ``open_amount`` and force every other drawer shut.

    The target drawer's joint is set to ``open_amount``; the remaining joints
    are set to 0 with target position and target velocity 0. The simulation is
    then stepped 20 times, after which the non-target joints are reset to 0
    once more.

    Args:
        task_env: RLBench task environment
        drawer_option: str, drawer to open ("bottom", "middle", "top"; unknown
            strings select "middle"). A non-string value is used as the index.
        open_amount: float, joint position to set on the target drawer

    Returns:
        float: joint position of the target drawer read back after settling
    """
    names = ("bottom", "middle", "top")
    if isinstance(drawer_option, str):
        target_idx = {"bottom": 0, "middle": 1, "top": 2}.get(drawer_option, 1)
    else:
        target_idx = drawer_option

    # First pass: open the target, close and hold everything else.
    for position, name in enumerate(names):
        drawer = Joint(f'drawer_joint_{name}')
        if position != target_idx:
            drawer.set_joint_position(0.0, disable_dynamics=True)
            # Hold the joint at 0 via its target position/velocity.
            drawer.set_joint_target_position(0.0)
            drawer.set_joint_target_velocity(0.0)
            continue
        drawer.set_joint_position(open_amount, disable_dynamics=True)

    # Give the simulation a few steps to settle.
    simulator = task_env._scene.pyrep
    for _ in range(20):
        simulator.step()

    # Second pass: snap the non-target drawers back to 0 after settling.
    for position, name in enumerate(names):
        if position == target_idx:
            continue
        Joint(f'drawer_joint_{name}').set_joint_position(0.0, disable_dynamics=True)

    # Read back the target drawer and warn when it is off by more than 5mm.
    measured = Joint(f'drawer_joint_{names[target_idx]}').get_joint_position()
    if abs(measured - open_amount) > 0.005:
        print(f"  Warning: Drawer joint set to {open_amount:.4f}m but actual is {measured:.4f}m")

    return measured


def lock_other_drawers(target_drawer_idx):
    """
    Reset every drawer joint except the target one to position 0.

    Intended to be called repeatedly while the simulation runs so that the
    non-target drawers stay shut.

    Args:
        target_drawer_idx: int, index of the drawer being manipulated
            (0=bottom, 1=middle, 2=top); this drawer is left untouched
    """
    for position, name in enumerate(("bottom", "middle", "top")):
        if position == target_drawer_idx:
            continue
        Joint(f'drawer_joint_{name}').set_joint_position(0.0, disable_dynamics=True)


def extend_workspace(task_env, scale=None):
    """
    Extend the workspace/table to accommodate cabinet placement.

    The function first tries to scale the ``workspace`` shape non-uniformly.
    If that raises (the original author notes CoppeliaSim cannot scale
    compound shapes non-isometrically), it instead creates a thin static
    cuboid named ``workspace_extension`` adjacent to the workspace on its
    -Y side. If a shape with that name already exists, nothing is done.

    Args:
        task_env: RLBench task environment (not used by this function)
        scale: sequence of [X, Y, Z] scale factors (list, tuple or array);
            defaults to WORKSPACE_SCALE from the config. All-ones is a no-op.
    """
    if scale is None:
        scale = WORKSPACE_SCALE

    # Normalise to plain floats so tuples and numpy arrays behave like lists;
    # comparing an ndarray to a list with `==` would raise on truth-testing.
    scale = [float(factor) for factor in scale]

    # Skip if no scaling needed (checked before touching the simulator API)
    if all(factor == 1.0 for factor in scale):
        return

    from pyrep.objects.shape import Shape
    from pyrep.const import PrimitiveShape

    # Check if extension already exists (from previous episode) - skip if so
    try:
        Shape('workspace_extension')
    except Exception:
        pass  # Extension doesn't exist, continue to create it
    else:
        return

    try:
        workspace = Shape('workspace')

        # Get current bounding box to understand original size
        # bbox format: [min_x, max_x, min_y, max_y, min_z, max_z]
        bbox = workspace.get_bounding_box()
        original_size_x = bbox[1] - bbox[0]
        original_size_y = bbox[3] - bbox[2]
        original_pos = workspace.get_position()

        # Try non-uniform scaling first (works for simple shapes)
        try:
            workspace.scale_object(scale[0], scale[1], scale[2])

            # Adjust position to keep the robot-facing edge in place
            if scale[1] != 1.0:
                original_half_y = original_size_y / 2
                current_pos = list(workspace.get_position())
                # Shift the centre in -Y by the amount one edge grew
                y_offset = original_half_y * (scale[1] - 1.0)
                current_pos[1] -= y_offset
                workspace.set_position(current_pos)
            return
        except Exception:
            pass  # Fall through to create extension

        # If scaling failed (compound shape), create an extension piece
        # Calculate the extension size needed
        extension_size_y = original_size_y * (scale[1] - 1.0)
        if extension_size_y <= 0:
            return  # No extension needed

        print(f"  Workspace bbox: X=[{bbox[0]:.2f}, {bbox[1]:.2f}], Y=[{bbox[2]:.2f}, {bbox[3]:.2f}], Z=[{bbox[4]:.2f}, {bbox[5]:.2f}]")
        print(f"  Workspace pos: {original_pos}")
        print(f"  Creating extension: Y size = {extension_size_y:.2f}m")

        # Create a thin cuboid as the table extension
        # Position it adjacent to the original workspace in the -Y direction
        extension_size = [original_size_x * scale[0], extension_size_y, 0.02]  # Thin table top

        # Get workspace color to match (use a neutral gray if not available)
        try:
            workspace_color = workspace.get_color()
        except Exception:
            workspace_color = None

        # Create extension cuboid
        extension = Shape.create(
            type=PrimitiveShape.CUBOID,
            size=extension_size,
            mass=1.0,
            respondable=False,  # Don't interact with physics
            static=True,
            renderable=True,
            color=workspace_color if workspace_color else [0.6, 0.6, 0.6]
        )
        extension.set_name("workspace_extension")

        # Position the extension: adjacent to workspace in the -Y direction.
        # Extension centre goes at: (bbox min Y) - half_extension_y
        table_top_z = bbox[5]  # Top of workspace bounding box
        extension_pos = [
            original_pos[0],
            bbox[2] - extension_size_y / 2,  # Start from workspace min Y and extend further in -Y
            table_top_z - 0.01  # Slightly below table top surface
        ]
        extension.set_position(extension_pos)
        print(f"  Extension placed at: {extension_pos}")

    except Exception as e:
        # Best effort: a failed extension is reported but never fatal.
        print(f"  Warning: Could not extend workspace: {e}")


def fix_cabinet_orientation(task_env):
    """
    Move the task's base object to the configured cabinet pose.

    The base object returned by ``task_env._task.get_base()`` is placed at
    CABINET_POSITION with the orientation from
    ``get_cabinet_orientation_radians()``, then the simulation is stepped
    20 times so the scene can settle.

    Scene layout, per the original author's notes (not verifiable from this
    file - confirm against close_drawer_config):
    - Robot arm at approximately X=0.28, Y=0 (base at origin)
    - Cabinet in the -Y region (Y=-0.7), unrotated, drawer opening toward +Y
    - Robot reaches to the open handle (Y~=-0.35) and pushes in -Y to close

    Args:
        task_env: RLBench task environment
    """
    # NOTE: extend_workspace(task_env) is intentionally not called here. The
    # original author reports that modifying the workspace / creating
    # extensions does not persist correctly across task_env.reset() calls and
    # that the cabinet works without the table extension.

    cabinet_root = task_env._task.get_base()

    # Pose comes straight from the config module.
    target_position, target_orientation = (
        CABINET_POSITION,
        get_cabinet_orientation_radians(),
    )
    cabinet_root.set_position(target_position)
    cabinet_root.set_orientation(target_orientation)

    # Let physics settle
    simulator = task_env._scene.pyrep
    for _ in range(20):
        simulator.step()


def reset_robot_to_default(task_env):
    """
    Put the arm and gripper back at the joint positions stored by the scene.

    The arm is set to ``_scene._start_arm_joint_pos`` and the gripper to
    ``_scene._starting_gripper_joint_pos`` (both with dynamics disabled),
    the gripper is released, and the simulation is stepped 5 times.

    The original author recommends calling this BEFORE ``task_env.reset()``
    to avoid collision errors during RLBench's random placement validation.

    Args:
        task_env: RLBench task environment
    """
    scene = task_env._scene
    robot = scene.robot

    # Arm: restore the stored start joints and zero the target velocities.
    arm_start = scene._start_arm_joint_pos
    robot.arm.set_joint_positions(arm_start, disable_dynamics=True)
    robot.arm.set_joint_target_velocities([0, 0, 0, 0, 0, 0, 0])

    # Gripper: drop anything held, then restore the stored start joints.
    robot.gripper.release()
    gripper_start = scene._starting_gripper_joint_pos
    robot.gripper.set_joint_positions(gripper_start, disable_dynamics=True)

    # Step the simulation so the new state takes effect.
    for _ in range(5):
        scene.pyrep.step()


# ============================================================================
# Waypoint Computation
# ============================================================================

def compute_push_waypoints(start_pos, handle_pos, handle_ori, waypoint_params=None):
    """
    Compute the key waypoints for a push/close drawer trajectory.

    All offsets are applied along the world Y axis only: the pre-handle point
    sits at +Y from the handle, the push ends at -Y from the handle, and the
    retreat point moves back toward +Y from the push end. The inputs are
    copied into float arrays, so integer or list inputs are handled without
    truncation and the caller's arrays are never modified.

    Args:
        start_pos: array-like (3,), starting end-effector position
        handle_pos: array-like (3,), position of drawer handle (from waypoint0)
        handle_ori: orientation for grasping (euler angles); currently unused
        waypoint_params: dict with optional keys "prehandle_offset_y"
            (default 0.15), "push_distance" (default 0.05) and
            "retreat_distance" (default 0.08); defaults to a copy of
            DEFAULT_WAYPOINT_OFFSETS when None

    Returns:
        dict of np.ndarray(3,) with keys: start, prehandle, handle, push_end, retreat
    """
    if waypoint_params is None:
        waypoint_params = DEFAULT_WAYPOINT_OFFSETS.copy()

    prehandle_offset_y = waypoint_params.get("prehandle_offset_y", 0.15)
    push_distance = waypoint_params.get("push_distance", 0.05)
    retreat_distance = waypoint_params.get("retreat_distance", 0.08)

    # Fresh float copies: an in-place `+= 0.15` on an integer array would
    # silently truncate the offset to 0.
    start = np.array(start_pos, dtype=float)
    handle = np.array(handle_pos, dtype=float)  # reach ends exactly at the handle

    # Pre-handle: in front of the open drawer, offset in +Y
    prehandle = handle.copy()
    prehandle[1] += prehandle_offset_y

    # Push end: handle displaced in -Y (into the cabinet)
    push_end = handle.copy()
    push_end[1] -= push_distance

    # Retreat: back off from the push end in +Y
    retreat = push_end.copy()
    retreat[1] += retreat_distance

    return {
        "start": start,
        "prehandle": prehandle,
        "handle": handle,
        "push_end": push_end,
        "retreat": retreat,
    }


# ============================================================================
# Control Point Functions (same as reach/pick-place)
# ============================================================================

def build_local_frame(start_pos, target_pos):
    """
    Build an orthonormal frame whose first axis points from start to target.

    Returns (line_vec_norm, perp1, perp2). When the two points are closer than
    1e-6 the world axes (X, Y, Z) are returned instead.
    """
    direction = target_pos - start_pos
    length = np.linalg.norm(direction)
    if length < 1e-6:
        # Degenerate case: no direction to align with, fall back to world axes.
        return tuple(np.array(axis) for axis in ([1, 0, 0], [0, 1, 0], [0, 0, 1]))

    unit_dir = direction / length

    # Cross with world Z unless the line is nearly vertical, then use world X
    # so the cross product stays well away from zero.
    reference = np.array([0, 0, 1]) if abs(unit_dir[2]) < 0.9 else np.array([1, 0, 0])
    first_perp = np.cross(unit_dir, reference)
    first_perp = first_perp / np.linalg.norm(first_perp)
    second_perp = np.cross(unit_dir, first_perp)
    return unit_dir, first_perp, second_perp


def parabola3D(S, E, C, t):
    """
    Evaluate at parameter t a quadratic Bezier curve from S to E that passes through C.

    C is the point reached at t=0.5; the actual Bezier control point is
    derived from it as 2*C - midpoint(S, E).
    """
    midpoint = 0.5 * (S + E)
    bezier_ctrl = 2 * C - midpoint
    u = 1 - t
    return u * u * S + 2 * u * t * bezier_ctrl + t * t * E


def compute_control_point_from_params(start_pos, target_pos, radius, angle, dist_frac, pos_frac):
    """
    Turn canonical (angle, dist_frac, pos_frac) parameters into a 3D control point.

    The point is found by walking ``pos_frac`` of the way from start to target
    and then stepping ``dist_frac * radius`` sideways, in the direction given
    by ``angle`` within the plane perpendicular to the start->target line.

    Args:
        start_pos: np.ndarray(3,), start position
        target_pos: np.ndarray(3,), target position
        radius: float, maximum offset radius (meters)
        angle: float, angle around the line (radians)
        dist_frac: float, fraction of radius for offset distance
        pos_frac: float, fraction along start->target line for base position

    Returns:
        control_point: np.ndarray(3,)
    """
    # Only the two perpendicular axes of the frame are needed here.
    _, perp1, perp2 = build_local_frame(start_pos, target_pos)

    anchor = start_pos + pos_frac * (target_pos - start_pos)
    magnitude = dist_frac * radius
    sideways = np.cos(angle) * perp1 + np.sin(angle) * perp2

    return anchor + magnitude * sideways


def generate_canonical_control_point_params(num_modes):
    """
    Generate a canonical set of control point parameters based on num_modes.

    Parameters are (angle, distance, position_along_line):
      - angle: angle around the start-end line (radians)
      - distance: fraction of radius for perpendicular offset (0-1)
      - position: fraction along start->target line (0=start, 1=end) - FIXED at 0.5

    Structure:
      - num_angles = num_modes // 2, each used with two distances (1.0 and 0.5)
      - angles are evenly split across 360 degrees
      - rows are ordered: all angles at distance 1.0, then all at distance 0.5

    Example: num_modes=8 -> 4 angles (0°, 90°, 180°, 270°) x 2 distances (1.0, 0.5)

    Args:
        num_modes: Total number of control point modes. An odd value is
            rounded down to the nearest even number; values below 2 yield
            no modes.

    Returns:
        params: np.ndarray of shape (2 * (num_modes // 2), 3) with
            (angle, dist_frac, pos_frac); always 2-D, shape (0, 3) when empty
    """
    num_angles = num_modes // 2
    if num_angles <= 0:
        # Keep the result 2-D so callers can index columns unconditionally.
        return np.empty((0, 3))

    distances = (1.0, 0.5)  # Two distance values
    pos_frac = 0.5  # Fixed position

    angles = np.radians(np.arange(num_angles) * (360.0 / num_angles))
    blocks = [
        np.column_stack([
            angles,
            np.full(num_angles, dist),
            np.full(num_angles, pos_frac),
        ])
        for dist in distances
    ]
    return np.vstack(blocks)


# ============================================================================
# Trajectory Generation
# ============================================================================

def linear_interpolate(p1, p2, t):
    """Return the point a fraction t of the way from p1 to p2 (t=0 -> p1, t=1 -> p2)."""
    delta = p2 - p1
    return p1 + t * delta


def generate_phase_positions(waypoints, phase_steps, cp_reach):
    """
    Produce the end-effector target positions for the reach and push phases.

    Phase 1: REACH (shaped) - start -> handle along the curve through cp_reach
    Phase 2: PUSH (linear)  - handle -> push_end on a straight line

    Each phase samples its parameter t evenly from 0 to 1 inclusive, so the
    handle position appears both as the last reach point and the first push
    point. No retreat phase is generated.

    Args:
        waypoints: dict from compute_push_waypoints
        phase_steps: dict with step counts under the keys "reach" and "push"
        cp_reach: np.ndarray(3,), control point for reach phase

    Returns:
        positions: list of np.ndarray(3,), one per step
        phase_indices: dict mapping phase name to (start_idx, end_idx)
        phase_labels: list of str, phase name for each step
    """
    def _unit_fractions(count):
        # Evenly spaced t in [0, 1]; a single step maps to t=0.
        last = max(count - 1, 1)
        return [step / last for step in range(count)]

    # Phase 1: REACH - shaped curve straight to the handle
    reach_count = phase_steps["reach"]
    reach_points = [
        parabola3D(waypoints["start"], waypoints["handle"], cp_reach, t)
        for t in _unit_fractions(reach_count)
    ]

    # Phase 2: PUSH - straight line from the handle to the push end
    push_count = phase_steps["push"]
    push_points = [
        linear_interpolate(waypoints["handle"], waypoints["push_end"], t)
        for t in _unit_fractions(push_count)
    ]

    positions = reach_points + push_points
    phase_labels = ["reach"] * len(reach_points) + ["push"] * len(push_points)
    phase_indices = {
        "reach": (0, reach_count),
        "push": (reach_count, reach_count + push_count),
    }

    return positions, phase_indices, phase_labels


def move_robot_to_start(task_env, start_pos, desired_ori, gripper_open=True):
    """
    Snap the arm to the HOME joint configuration and report the resulting tip pose.

    The arm is set to HOME_JOINTS with dynamics disabled, the gripper is
    actuated toward open or closed, and the simulation is stepped in between
    so the state settles. ``start_pos`` and ``desired_ori`` are accepted but
    not used: the returned pose is always the tip pose at HOME.

    Args:
        task_env: RLBench task environment
        start_pos: np.ndarray(3,), requested start position (ignored)
        desired_ori: np.ndarray(3,), desired orientation (ignored)
        gripper_open: bool, whether gripper should be open

    Returns:
        actual_start_pos: np.ndarray(3,), copy of the tip position at HOME
        actual_ori: tip orientation at HOME, to be used for the trajectory
    """
    scene = task_env._scene
    arm = scene.robot.arm
    ee_tip = arm.get_tip()
    hand = scene.robot.gripper

    # Jump to HOME and let the simulation catch up.
    arm.set_joint_positions(HOME_JOINTS, disable_dynamics=True)
    arm.set_joint_target_velocities([0] * 7)
    for _ in range(10):
        scene.pyrep.step()

    # Drive the gripper toward the requested state (1.0 = open, 0.0 = closed).
    hand.actuate(1.0 if gripper_open else 0.0, 0.2)
    for _ in range(20):
        scene.pyrep.step()

    # Read the tip pose reached at HOME.
    home_pos = ee_tip.get_position()
    home_ori = ee_tip.get_orientation()

    scene.pyrep.step()

    # Debug: print HOME EE position (helpful for verifying Z matches handle Z)
    print(f"  HOME EE position: [{home_pos[0]:.4f}, {home_pos[1]:.4f}, {home_pos[2]:.4f}]")

    return home_pos.copy(), home_ori


def execute_step(task_env, target_ee_pos, desired_ori, extra_steps=0):
    """
    Move the arm to one end-effector target and return the resulting observation.

    IK is solved with the desired orientation first; if that fails, it is
    retried with position only. The solution is applied directly to the
    joints (dynamics disabled) and the simulation and task are stepped
    ``5 + extra_steps`` times.

    Args:
        task_env: RLBench task environment
        target_ee_pos: np.ndarray(3,), target end-effector position
        desired_ori: np.ndarray(3,), desired orientation (euler)
        extra_steps: int, additional physics steps

    Returns:
        obs: observation after the step; ``obs.misc['joint_position_action']``
            holds the commanded joints followed by ``obs.gripper_open``
        joint_positions: the commanded joint positions

    Raises:
        IKError or ConfigurationPathError: if the position-only IK also fails
    """
    scene = task_env._scene
    arm = scene.robot.arm

    try:
        joint_positions = arm.solve_ik(target_ee_pos, euler=desired_ori)
    except (IKError, ConfigurationPathError):
        # Orientation-constrained solve failed; settle for position only.
        joint_positions = arm.solve_ik(target_ee_pos)

    # Apply the solution directly rather than through the dynamics engine.
    arm.set_joint_positions(joint_positions, disable_dynamics=True)
    arm.set_joint_target_velocities([0] * 7)

    for _ in range(5 + extra_steps):
        scene.pyrep.step()
        scene.task.step()

    obs = scene.get_observation()

    if not hasattr(obs, 'misc'):
        obs.misc = {}
    gripper_flag = np.array([obs.gripper_open])
    obs.misc['joint_position_action'] = np.concatenate([joint_positions, gripper_flag])

    return obs, joint_positions


def generate_push_trajectory(
    task_env,
    start_pos,
    handle_pos,
    handle_ori,
    cp_idx,
    canonical_params,
    control_point_radius,
    waypoint_params=None,
    phase_steps=None,
    steps_per_point=5,
    target_drawer_idx=2,  # Default to top drawer (variation 2)
):
    """
    Generate and execute a push/close drawer trajectory with a shaped reach phase.

    The robot is first moved to HOME; the trajectory then starts from the tip
    pose reached there. The reach phase follows a curve shaped by the
    canonical control point ``cp_idx`` and is executed with dynamics disabled;
    the push phase is a straight line executed through joint position targets
    so the gripper can physically move the drawer. The non-target drawers are
    reset to 0 after every simulation step.

    Args:
        task_env: RLBench task environment
        start_pos: np.ndarray(3,), requested starting EE position (passed to
            move_robot_to_start; the pose it returns is what is actually used)
        handle_pos: np.ndarray(3,), handle position
        handle_ori: np.ndarray(3,), handle orientation (euler)
        cp_idx: int, canonical control point index for reach phase
        canonical_params: np.ndarray, canonical control point parameters
        control_point_radius: float, radius for control point offset
        waypoint_params: dict, waypoint offset parameters
        phase_steps: dict, step counts per phase (defaults to a copy of
            DEFAULT_PHASE_STEPS)
        steps_per_point: int, number of physics steps per trajectory point
            (doubled during the push phase)
        target_drawer_idx: int, index of target drawer (0=bottom, 1=middle, 2=top)

    Returns:
        demo: list of observations, one per trajectory point
        metadata: dict with trajectory info (includes 'trace' and 'phase_labels' for video)

    Raises:
        RuntimeError: on the first IK failure (the episode is rejected); the
            underlying IK exception is attached as ``__cause__``
    """
    if phase_steps is None:
        phase_steps = DEFAULT_PHASE_STEPS.copy()

    robot = task_env._scene.robot
    tip = robot.arm.get_tip()
    gripper = robot.gripper

    # Move to HOME position first
    actual_start_pos, actual_ori = move_robot_to_start(
        task_env, start_pos, handle_ori, gripper_open=True
    )

    # Compute waypoints using actual start position
    waypoints = compute_push_waypoints(actual_start_pos, handle_pos, handle_ori, waypoint_params)

    # Debug: print waypoints
    print(f"  Waypoints:")
    print(f"    start:  [{waypoints['start'][0]:.4f}, {waypoints['start'][1]:.4f}, {waypoints['start'][2]:.4f}]")
    print(f"    handle: [{waypoints['handle'][0]:.4f}, {waypoints['handle'][1]:.4f}, {waypoints['handle'][2]:.4f}] (reach target)")

    # Compute control point for reach phase
    # Reach goes directly from start to handle (not to prehandle)
    angle, dist_frac, pos_frac = canonical_params[cp_idx]
    cp_reach = compute_control_point_from_params(
        waypoints["start"], waypoints["handle"],
        control_point_radius, angle, dist_frac, pos_frac
    )

    # Debug: print control point
    print(f"    cp_reach: [{cp_reach[0]:.4f}, {cp_reach[1]:.4f}, {cp_reach[2]:.4f}] (angle={np.degrees(angle):.1f}deg)")

    # Generate all target positions
    positions, phase_indices, phase_labels = generate_phase_positions(
        waypoints, phase_steps, cp_reach
    )

    # total_steps comes from the positions actually generated, not the config
    total_steps = len(positions)
    print(f"  Trajectory: {total_steps} steps (reach={phase_labels.count('reach')}, push={phase_labels.count('push')})")

    # Execute trajectory - track EE trace for video overlay
    demo = []
    trace = []  # EE positions for video overlay
    trace_phase_labels = []  # Phase labels for coloring
    gripper_states = []

    # Initial observation (stands in for positions[0], the start pose)
    demo.append(task_env._scene.get_observation())
    trace.append(tip.get_position().copy())
    trace_phase_labels.append("reach")
    gripper_states.append(1.0)

    prev_joints = list(robot.arm.get_joint_positions())

    for i in range(1, len(positions)):
        target_pos_step = positions[i]
        phase = phase_labels[i]

        try:
            # Restore the last commanded joints before solving; presumably
            # this seeds the Jacobian solver from a known configuration.
            robot.arm.set_joint_positions(prev_joints, disable_dynamics=True)
            if phase == "reach":
                # Reach: maintain orientation, disable dynamics (no collision needed)
                joint_positions = robot.arm.solve_ik_via_jacobian(target_pos_step, euler=actual_ori)
                robot.arm.set_joint_positions(joint_positions, disable_dynamics=True)
            else:
                # Push: position-only IK, applied as joint targets so the
                # gripper physically pushes the drawer
                joint_positions = robot.arm.solve_ik_via_jacobian(target_pos_step)
                robot.arm.set_joint_target_positions(joint_positions)
            robot.arm.set_joint_target_velocities([0] * 7)

            # Gripper stays open for push task
            gripper.actuate(1.0, 0.2)

            # Step simulation - more steps during push for physics to work
            sim_steps = steps_per_point * 2 if phase == "push" else steps_per_point
            for _ in range(sim_steps):
                task_env._scene.pyrep.step()
                task_env._scene.task.step()
                # Lock other drawers to prevent them from opening during simulation
                lock_other_drawers(target_drawer_idx)

            obs = task_env._scene.get_observation()
            if not hasattr(obs, 'misc'):
                obs.misc = {}
            obs.misc['joint_position_action'] = np.concatenate([joint_positions, [1.0]])
            demo.append(obs)

            # Track EE position for video
            trace.append(tip.get_position().copy())
            trace_phase_labels.append(phase)
            gripper_states.append(1.0)
            prev_joints = list(joint_positions)

        except (IKError, ConfigurationPathError) as e:
            print(f"    IK failed at step {i}/{len(positions)}: target=[{target_pos_step[0]:.4f}, {target_pos_step[1]:.4f}, {target_pos_step[2]:.4f}]")
            # Strict mode: reject entire episode on first IK failure
            raise RuntimeError(f"IK failure at step {i}/{len(positions)} - rejecting episode") from e

    assert len(demo) == total_steps, f"Expected {total_steps} steps, got {len(demo)}"

    # Debug: compare the final EE position with the last commanded target
    # (push_end when the push phase runs, the handle otherwise)
    final_ee_pos = tip.get_position()
    final_target = positions[-1]
    ee_error = np.linalg.norm(final_ee_pos - final_target)
    print(f"  Final EE: [{final_ee_pos[0]:.4f}, {final_ee_pos[1]:.4f}, {final_ee_pos[2]:.4f}]")
    print(f"  Target:   [{final_target[0]:.4f}, {final_target[1]:.4f}, {final_target[2]:.4f}]")
    print(f"  Position error: {ee_error*100:.2f}cm")

    # Build metadata (includes trace for video overlay)
    metadata = {
        "waypoints": waypoints,  # Keep as numpy arrays for validation checks
        "waypoints_list": {k: v.tolist() for k, v in waypoints.items()},  # JSON-serializable version
        "cp_reach": cp_reach.tolist(),
        "phase_indices": phase_indices,
        "phase_steps": phase_steps,
        # Always 0: any IK failure aborts the episode above. Key kept for
        # consumers of the metadata.
        "ik_failures": 0,
        "actual_orientation": actual_ori.tolist(),
        "trace": np.array(trace),  # For video overlay
        "phase_labels": trace_phase_labels,  # For video coloring
        "gripper_states": gripper_states,
    }

    return demo, metadata


def check_task_success(task_env):
    """
    Check if the close drawer task was successful using the task's own success check.

    Calls ``task_env._task.success()`` and returns its first element. Any
    exception raised while doing so is reported as a warning and treated as
    "not successful", so a broken check is visible in the log instead of
    being indistinguishable from a failed episode.

    NOTE(review): per the original author, RLBench's JointCondition tests
    |current_pos - original_pos| > 0.06 with original_pos = 0.1m, i.e. the
    drawer must end below 0.04m - confirm against the RLBench task.

    Args:
        task_env: RLBench task environment

    Returns:
        The success flag reported by the task, or False if it could not be
        evaluated
    """
    try:
        success, _ = task_env._task.success()
    except Exception as exc:
        # Best effort: never crash dataset generation on a failed check,
        # but do not hide the reason either.
        print(f"  Warning: Could not evaluate task success: {exc}")
        return False
    return success


# ============================================================================
# Pre-filtering Control Points (for IK feasibility)
# ============================================================================

def prefilter_control_points_fast(robot_arm, start_pos, target_pos, orientation,
                                   canonical_params, control_point_radius,
                                   require_orientation=True):
    """
    Cheaply screen control points by solving IK at two samples of each curve.

    For every canonical parameter set, the curve's midpoint (t=0.5) and
    endpoint (t=1.0) are passed to ``robot_arm.solve_ik``. A control point is
    kept only if both solves succeed.

    Args:
        robot_arm: PyRep arm object
        start_pos, target_pos, orientation: trajectory parameters
        canonical_params: iterable of (angle, dist_frac, pos_frac) rows
        control_point_radius: float
        require_orientation: bool, constrain IK to ``orientation`` (only
            applied when ``orientation`` is not None)

    Returns:
        valid_indices: list of valid control point indices
    """
    # Orientation is only passed to the solver when requested and available.
    constrain = require_orientation and orientation is not None
    ik_kwargs = {"euler": orientation} if constrain else {}

    feasible = []
    for cp_idx, (angle, dist_frac, pos_frac) in enumerate(canonical_params):
        control_point = compute_control_point_from_params(
            start_pos, target_pos, control_point_radius, angle, dist_frac, pos_frac
        )

        # Sample the curve at its midpoint and at its end.
        probes = [
            parabola3D(start_pos, target_pos, control_point, t)
            for t in (0.5, 1.0)
        ]

        try:
            for probe in probes:
                robot_arm.solve_ik(probe, **ik_kwargs)
        except (IKError, ConfigurationPathError):
            continue  # unreachable sample: discard this control point

        feasible.append(cp_idx)

    return feasible
