# filename: pick_and_place_utils.py
"""
Utilities for pick-and-place dataset generation with control-point trajectory shaping.

Key design:
  - One control-point index (cp_idx) shapes BOTH the reach-to-grasp and carry-to-place arcs
  - Each shaped segment (reach, carry) gets its own 64-step sub-trajectory
  - Linear phases (descend, grasp, lift, descend_place, release, retreat) use fewer steps
  - Total trajectory length is configurable (default: 196 steps, the sum of DEFAULT_PHASE_STEPS)
  - Uses the same canonical control-point parameterization as the reach task

Reference: This structure follows PerAct's action formulation where gripper open/close
are discrete events at semantic waypoints (grasp/release).
"""

import numpy as np
from pyrep.errors import ConfigurationPathError, IKError

from utils import (
    HOME_JOINTS,
    build_local_frame,
    compute_control_point_from_params,
    parabola3D,
)


# ============================================================================
# Phase Configuration
# ============================================================================

# Step budget for each of the ten phases, listed in execution order.
# The two SHAPED phases (reach, carry) get 64 steps each so they match the
# reach task; the straight-line and dwell phases are simple motions and need
# far fewer. The gripper phases are kept long enough for the fingers to
# actually close on (or let go of) the object.
DEFAULT_PHASE_STEPS = dict(
    reach=64,          # Phase 1: start -> pregrasp (shaped arc)
    descend=8,         # Phase 2: pregrasp -> grasp
    hold_grasp=8,      # Phase 3: dwell at grasp, gripper open (raised from 4)
    close_gripper=12,  # Phase 4: close gripper in place (raised from 4 for grip)
    lift=8,            # Phase 5: grasp -> lift, vertical before any lateral motion
    carry=64,          # Phase 6: lift -> preplace (shaped arc)
    descend_place=8,   # Phase 7: preplace -> place
    hold_place=8,      # Phase 8: dwell at place, gripper closed (raised from 4)
    release=8,         # Phase 9: open gripper in place (raised from 4)
    retreat=8,         # Phase 10: rise away from the place position
)
# Sum: 64 + 8 + 8 + 12 + 8 + 64 + 8 + 8 + 8 + 8 = 196 steps

# Fallback vertical offsets (metres), used when nothing could be extracted
# from a demo. Each value is measured from the reference named in its comment.
DEFAULT_WAYPOINT_OFFSETS = dict(
    pregrasp_height=0.12,  # above the object, before descending to grasp
    grasp_height=0.04,     # above the object at the grasp itself (gripper tip offset)
    lift_height=0.15,      # how far to raise the object after grasping
    preplace_height=0.15,  # above the target, before descending to place
    place_height=0.05,     # above the target at release
    retreat_height=0.10,   # how far to rise after release
)

# Fallback grasp orientation as Euler angles [roll, pitch, yaw] in radians.
# A 180-degree roll with zero pitch/yaw is the usual "top-down" pose for the
# Panda gripper, i.e. the gripper points straight down.
DEFAULT_GRASP_ORIENTATION = np.array((np.pi, 0.0, 0.0), dtype=np.float64)


# ============================================================================
# Extract Parameters from RLBench Demo
# ============================================================================

def _infer_obj_tgt_pose_indices_from_demo(demo):
    """
    Infer which 7D pose blocks in task_low_dim_state correspond to object + target.

    task_low_dim_state is treated as N consecutive 7-value pose blocks whose
    first three values are a position. The block whose position drifts most
    from its first-frame value (averaged over the demo) is taken to be the
    object; among the remaining blocks, the one that drifts least is taken
    to be the target.

    Args:
        demo: sequence of observations, each exposing ``task_low_dim_state``

    Returns:
        obj_idx: index of object pose block
        tgt_idx: index of target pose block
        poses: (T, N, 7) array of all pose blocks over time

        All three are None when the layout cannot be interpreted: empty demo,
        state that does not stack into a 2-D array, width not a multiple of 7,
        or fewer than two pose blocks (a distinct object and target cannot
        both be picked from a single block).
    """
    if len(demo) == 0:
        return None, None, None  # nothing to stack

    # Stack task_low_dim_state from all timesteps
    ld = np.stack([obs.task_low_dim_state for obs in demo], axis=0)  # (T, D)

    if ld.ndim != 2 or (ld.shape[1] % 7) != 0:
        return None, None, None  # unexpected layout

    num_poses = ld.shape[1] // 7
    if num_poses < 2:
        # With one block the "target" search would fall back onto the object
        # itself; with zero blocks argmax has nothing to choose from.
        return None, None, None

    poses = ld.reshape(ld.shape[0], num_poses, 7)  # (T, N, 7)
    pos = poses[:, :, :3]  # (T, N, 3) - just positions

    # Motion score per block: mean distance from that block's initial position
    motion = np.linalg.norm(pos - pos[:1], axis=2).mean(axis=0)  # (N,)

    # Object = most motion
    obj_idx = int(np.argmax(motion))

    # Target = least motion among the remaining blocks (object masked out)
    remaining = motion.copy()
    remaining[obj_idx] = np.inf
    tgt_idx = int(np.argmin(remaining))

    return obj_idx, tgt_idx, poses


def _peak_height_above(demo, first, last, reference_z):
    """Return the largest gripper Z above ``reference_z`` over demo[first:last], floored at 0."""
    peak = max(
        (demo[i].gripper_pose[2] - reference_z for i in range(first, last)),
        default=0,
    )
    return max(peak, 0)


def _mean_euler_angles(angle_sets):
    """
    Average Euler-angle triples component-wise using a circular mean.

    A plain arithmetic mean breaks at the +/-pi wrap-around: a top-down grasp
    has roll close to +/-pi, so two demos reporting +3.14 and -3.14 (the same
    physical pose) would average to ~0, i.e. a gripper pointing up. Averaging
    the sines and cosines and taking arctan2 avoids that. It is still only an
    approximation that is reasonable for small variations between demos.
    """
    stacked = np.asarray(angle_sets, dtype=float)
    return np.arctan2(np.sin(stacked).mean(axis=0), np.cos(stacked).mean(axis=0))


def extract_demo_waypoints(task_env, task_name, num_demos=1):
    """
    Extract waypoint parameters from RLBench demos.

    IMPORTANT: Uses task_low_dim_state from the demo observations to get
    object/target positions AT THE CORRECT TIMESTEPS (grasp moment, release moment).
    This avoids the bug where get_object_and_target_positions() returns positions
    after the demo has already moved the object.

    For each demo the grasp is the first open->closed gripper transition and
    the release is the first closed->open transition after it. Heights are
    the peak gripper Z inside a bounded window around those moments; a
    non-positive measurement is replaced by the matching entry of
    DEFAULT_WAYPOINT_OFFSETS (grasp_height is always used as measured).
    Demos that fail for any reason are reported and skipped.

    Args:
        task_env: RLBench task environment
        task_name: str, name of the task (not used by this function)
        num_demos: int, number of demos to average over

    Returns:
        dict with keys grasp_height, pregrasp_height, lift_height,
        place_height, preplace_height, retreat_height (scalars),
        grasp_orientation (3 euler angles) and grasp_xy_offset /
        place_xy_offset (2-vectors). Scalars and XY offsets are arithmetic
        means over the successful demos; the orientation is a component-wise
        circular mean. If no demo succeeds, the module defaults are returned.
    """
    print(f"  Extracting waypoint parameters from {num_demos} RLBench demo(s)...")

    def _or_default(measured, key):
        # A non-positive height means the window saw nothing above the reference.
        return measured if measured > 0 else DEFAULT_WAYPOINT_OFFSETS[key]

    all_params = []

    for demo_idx in range(num_demos):
        try:
            demos = task_env.get_demos(amount=1, live_demos=True)
            if not demos:
                print(f"    Demo {demo_idx}: Failed to get demo")
                continue

            demo = demos[0]
            num_obs = len(demo)

            # Grasp moment: first step where the gripper goes open -> closed
            grasp_idx = next(
                (i for i in range(1, num_obs)
                 if demo[i - 1].gripper_open > 0.5 and demo[i].gripper_open < 0.5),
                None,
            )
            if grasp_idx is None:
                print(f"    Demo {demo_idx}: Could not find grasp moment")
                continue

            # Release moment: first closed -> open transition after the grasp
            release_idx = next(
                (i for i in range(grasp_idx + 1, num_obs)
                 if demo[i - 1].gripper_open < 0.5 and demo[i].gripper_open > 0.5),
                None,
            )
            if release_idx is None:
                print(f"    Demo {demo_idx}: Could not find release moment")
                continue

            # Infer object/target indices from task_low_dim_state
            obj_idx, tgt_idx, poses = _infer_obj_tgt_pose_indices_from_demo(demo)
            if poses is None:
                print(f"    Demo {demo_idx}: task_low_dim_state missing or unexpected format")
                continue

            # Object position AT GRASP TIME (not after demo ends!)
            object_pos_at_grasp = poses[grasp_idx, obj_idx, :3]
            # Target position at release time (usually stationary anyway)
            target_pos_at_release = poses[release_idx, tgt_idx, :3]
            # Object position at start, reference for the pregrasp height
            object_pos_at_start = poses[0, obj_idx, :3]

            # Gripper positions at the key moments
            grasp_pos = demo[grasp_idx].gripper_pose[:3]
            release_pos = demo[release_idx].gripper_pose[:3]

            # Pregrasp: highest point above the object in the 20 steps before grasp
            pregrasp_height = _peak_height_above(
                demo, max(0, grasp_idx - 20), grasp_idx, object_pos_at_start[2]
            )
            # Lift: highest point above the grasp pose in the 30 steps after grasp
            lift_height = _peak_height_above(
                demo, grasp_idx, min(release_idx, grasp_idx + 30), grasp_pos[2]
            )
            # Preplace: highest point above the target in the 30 steps before release
            preplace_height = _peak_height_above(
                demo, max(grasp_idx, release_idx - 30), release_idx, target_pos_at_release[2]
            )
            # Retreat: highest point above the release pose in the 20 steps after release
            retreat_height = _peak_height_above(
                demo, release_idx, min(num_obs, release_idx + 20), release_pos[2]
            )

            # Grasp orientation
            grasp_ori = quaternion_to_euler(demo[grasp_idx].gripper_pose[3:7])

            # XY offsets - computed using positions AT THE CORRECT TIMESTEPS
            grasp_xy_offset = grasp_pos[:2] - object_pos_at_grasp[:2]
            place_xy_offset = release_pos[:2] - target_pos_at_release[:2]

            # Heights of the gripper above object/target at grasp/release
            grasp_height = grasp_pos[2] - object_pos_at_grasp[2]
            place_height = release_pos[2] - target_pos_at_release[2]

            # Sanity check - grasp_xy_offset should be small (cm-scale, not 10+ cm)
            if np.linalg.norm(grasp_xy_offset) > 0.08:
                print(f"    Demo {demo_idx}: WARNING suspicious grasp_xy_offset: {grasp_xy_offset}")

            params = {
                "grasp_height": grasp_height,
                "pregrasp_height": _or_default(pregrasp_height, "pregrasp_height"),
                "lift_height": _or_default(lift_height, "lift_height"),
                "place_height": _or_default(place_height, "place_height"),
                "preplace_height": _or_default(preplace_height, "preplace_height"),
                "retreat_height": _or_default(retreat_height, "retreat_height"),
                "grasp_orientation": grasp_ori,
                "grasp_xy_offset": grasp_xy_offset,
                "place_xy_offset": place_xy_offset,
            }

            all_params.append(params)
            print(f"    Demo {demo_idx}: grasp_h={grasp_height:.3f}, "
                  f"pregrasp_h={params['pregrasp_height']:.3f}, lift_h={params['lift_height']:.3f}, "
                  f"grasp_xy=[{grasp_xy_offset[0]:.3f}, {grasp_xy_offset[1]:.3f}]")

        except Exception as e:
            # One bad demo must not abort the whole extraction: report and move on.
            print(f"    Demo {demo_idx}: Error - {e}")
            import traceback
            traceback.print_exc()
            continue

    if not all_params:
        print("  WARNING: No demos succeeded, using default parameters")
        return {
            **DEFAULT_WAYPOINT_OFFSETS,
            "grasp_orientation": DEFAULT_GRASP_ORIENTATION.copy(),
            "grasp_xy_offset": np.zeros(2),
            "place_xy_offset": np.zeros(2),
        }

    # Average across successful demos
    avg_params = {}
    for key in all_params[0]:
        values = [p[key] for p in all_params]
        if key == "grasp_orientation":
            # Angles wrap at +/-pi, so they need a circular mean
            avg_params[key] = _mean_euler_angles(values)
        elif key in ("grasp_xy_offset", "place_xy_offset"):
            avg_params[key] = np.mean(values, axis=0)
        else:
            avg_params[key] = np.mean(values)

    print(f"  Extracted parameters (averaged over {len(all_params)} demos):")
    for key, val in avg_params.items():
        if key == "grasp_orientation":
            print(f"    {key}: [{val[0]:.3f}, {val[1]:.3f}, {val[2]:.3f}]")
        elif key in ("grasp_xy_offset", "place_xy_offset"):
            print(f"    {key}: [{val[0]:.3f}, {val[1]:.3f}]m")
        else:
            print(f"    {key}: {val:.3f}m")

    return avg_params


# ============================================================================
# Waypoint Computation
# ============================================================================

def compute_pick_place_waypoints(start_pos, object_pos, target_pos, waypoint_params=None, grasp_offset=None):
    """
    Build the seven Cartesian waypoints of a pick-and-place motion.

    None of the input arrays is modified; every returned point is a fresh array.

    Args:
        start_pos: np.ndarray(3,), end-effector position the motion starts from
        object_pos: np.ndarray(3,), position of the object to pick
        target_pos: np.ndarray(3,), position of the place target (bin/plate/etc)
        waypoint_params: dict of height / XY offsets (from extract_demo_waypoints);
            DEFAULT_WAYPOINT_OFFSETS when None. Missing keys fall back to the
            built-in values (zero for the XY offsets).
        grasp_offset: np.ndarray(3,), extra offset added to the grasp point only
            (zero when None)

    Returns:
        dict with keys: start, pregrasp, grasp, lift, preplace, place, retreat
    """
    extra = np.array([0.0, 0.0, 0.0]) if grasp_offset is None else grasp_offset
    cfg = DEFAULT_WAYPOINT_OFFSETS if waypoint_params is None else waypoint_params

    dz_pregrasp = cfg.get("pregrasp_height", 0.12)
    dz_grasp = cfg.get("grasp_height", 0.04)
    dz_lift = cfg.get("lift_height", 0.15)
    dz_preplace = cfg.get("preplace_height", 0.15)
    dz_place = cfg.get("place_height", 0.05)
    dz_retreat = cfg.get("retreat_height", 0.10)

    # Demo-derived lateral offsets of the gripper relative to object / target
    xy_pick = cfg.get("grasp_xy_offset", np.zeros(2))
    xy_drop = cfg.get("place_xy_offset", np.zeros(2))

    def _offset_copy(base, xy, dz):
        # Copy `base`, then shift it laterally by `xy` and vertically by `dz`.
        point = base.copy()
        point[:2] += xy
        point[2] += dz
        return point

    # Hover point above the object
    pregrasp = _offset_copy(object_pos, xy_pick, dz_pregrasp)

    # Grasp point: object position plus the optional extra offset
    grasp = object_pos.copy() + extra
    grasp[:2] += xy_pick
    grasp[2] += dz_grasp

    # Straight up from the grasp, so the object clears before moving sideways
    lift = grasp.copy()
    lift[2] = grasp[2] + dz_lift

    # Hover point above the target, never lower than the lift height
    preplace = target_pos.copy()
    preplace[:2] += xy_drop
    preplace[2] = max(target_pos[2] + dz_preplace, lift[2])

    # Release point at the target
    place = _offset_copy(target_pos, xy_drop, dz_place)

    # Straight up from the release point
    retreat = place.copy()
    retreat[2] += dz_retreat

    return {
        "start": start_pos.copy(),
        "pregrasp": pregrasp,
        "grasp": grasp,
        "lift": lift,
        "preplace": preplace,
        "place": place,
        "retreat": retreat,
    }


# ============================================================================
# Trajectory Generation for Pick-and-Place
# ============================================================================

def linear_interpolate(p1, p2, t):
    """Return the point a fraction ``t`` of the way from ``p1`` to ``p2`` (t=0 -> p1, t=1 -> p2)."""
    span = p2 - p1
    return p1 + t * span


def generate_phase_positions(waypoints, phase_steps, cp_reach, cp_carry):
    """
    Lay out every end-effector target of the 10-phase pick-and-place motion.

    The phases run in a fixed order: reach, descend, hold_grasp,
    close_gripper, lift, carry, descend_place, hold_place, release, retreat.
    Each contributes ``phase_steps[name]`` samples. Moving phases are sampled
    at t = i / max(n - 1, 1), so a phase with n >= 2 steps hits both of its
    endpoints, while a one-step phase only emits its starting point. Reach
    and carry follow parabola3D through their control point; the other
    moving phases are straight lines; the dwell phases repeat a copy of the
    grasp or place waypoint.

    Args:
        waypoints: dict from compute_pick_place_waypoints
        phase_steps: dict with step counts per phase
        cp_reach: np.ndarray(3,), control point for reach phase
        cp_carry: np.ndarray(3,), control point for carry phase

    Returns:
        positions: list of np.ndarray(3,), one per step
        gripper_states: list of float (1.0 = open, 0.0 = closed)
        phase_indices: dict mapping phase name to (start_idx, end_idx), end exclusive
        phase_labels: list of str, phase name for each step
    """
    gripper_open, gripper_closed = 1.0, 0.0

    # Each factory returns a sampler t -> position; waypoints are looked up
    # only when a sample is actually requested.
    def _arc(src, dst, control_point):
        return lambda t: parabola3D(waypoints[src], waypoints[dst], control_point, t)

    def _line(src, dst):
        return lambda t: linear_interpolate(waypoints[src], waypoints[dst], t)

    def _dwell(at):
        return lambda _t: waypoints[at].copy()

    # (phase name, position sampler, gripper command) in execution order
    schedule = (
        ("reach", _arc("start", "pregrasp", cp_reach), gripper_open),
        ("descend", _line("pregrasp", "grasp"), gripper_open),
        ("hold_grasp", _dwell("grasp"), gripper_open),          # settle before closing
        ("close_gripper", _dwell("grasp"), gripper_closed),
        ("lift", _line("grasp", "lift"), gripper_closed),       # up first, to avoid collisions
        ("carry", _arc("lift", "preplace", cp_carry), gripper_closed),
        ("descend_place", _line("preplace", "place"), gripper_closed),
        ("hold_place", _dwell("place"), gripper_closed),        # settle before opening
        ("release", _dwell("place"), gripper_open),
        ("retreat", _line("place", "retreat"), gripper_open),
    )

    positions, gripper_states, phase_labels = [], [], []
    phase_indices = {}
    cursor = 0

    for name, sample, grip in schedule:
        count = phase_steps[name]
        phase_indices[name] = (cursor, cursor + count)
        denom = max(count - 1, 1)
        for k in range(count):
            positions.append(sample(k / denom))
            gripper_states.append(grip)
            phase_labels.append(name)
        cursor += count

    return positions, gripper_states, phase_indices, phase_labels


def quaternion_to_euler(quat):
    """
    Convert a quaternion to roll / pitch / yaw angles.

    The components are read as [x, y, z, w] (the ordering this module assumes
    for RLBench/PyRep poses). Roll is the rotation about x, pitch about y and
    yaw about z. When the pitch sine reaches +/-1 the pitch is pinned to
    +/-90 degrees instead of calling arcsin on an out-of-range value.

    NOTE(review): the result is handed to code that expects PyRep-style euler
    angles - confirm the two Euler conventions actually agree.

    Args:
        quat: np.ndarray(4,), quaternion in [x, y, z, w] order

    Returns:
        euler: np.ndarray(3,), euler angles [roll, pitch, yaw] in radians
    """
    qx, qy, qz, qw = (quat[k] for k in range(4))

    # Rotation about the x-axis
    roll = np.arctan2(2 * (qw * qx + qy * qz), 1 - 2 * (qx * qx + qy * qy))

    # Rotation about the y-axis, clamped at the +/-90 degree singularity
    pitch_sin = 2 * (qw * qy - qz * qx)
    pitch = np.sign(pitch_sin) * np.pi / 2 if abs(pitch_sin) >= 1 else np.arcsin(pitch_sin)

    # Rotation about the z-axis
    yaw = np.arctan2(2 * (qw * qz + qx * qy), 1 - 2 * (qy * qy + qz * qz))

    return np.array([roll, pitch, yaw])


def get_grasp_orientation(task_env, task_name):
    """
    Look up the grasp orientation for a task from a single RLBench demo.

    Runs extract_demo_waypoints with one demo and returns its
    "grasp_orientation" entry, i.e. the orientation observed at the grasp
    moment of that demo. A copy of DEFAULT_GRASP_ORIENTATION is returned
    only if that entry is absent from the extracted parameters.

    Args:
        task_env: RLBench task environment
        task_name: str, name of the task

    Returns:
        orientation: np.ndarray(3,), euler angles for grasp orientation
    """
    extracted = extract_demo_waypoints(task_env, task_name, num_demos=1)
    if "grasp_orientation" in extracted:
        return extracted["grasp_orientation"]
    return DEFAULT_GRASP_ORIENTATION.copy()


def move_robot_to_start_with_gripper(task_env, start_pos, desired_ori, gripper_open=True):
    """
    Snap the arm to HOME_JOINTS, set the gripper, and report the resulting tip pose.

    The joints are written with disable_dynamics=True because, with dynamics
    enabled, the physics engine overrides the commanded joint positions.

    ``start_pos`` and ``desired_ori`` are accepted but not used: the
    trajectory always starts from HOME (Jacobian IK can't handle large jumps
    from arbitrary poses), and the caller is expected to use the pose
    returned here instead.

    Args:
        task_env: RLBench task environment
        start_pos: np.ndarray(3,), requested start position (currently ignored)
        desired_ori: np.ndarray(3,), desired orientation (currently ignored)
        gripper_open: bool, whether gripper should be open

    Returns:
        actual_start_pos: np.ndarray(3,), copy of the tip position at HOME
        actual_ori: np.ndarray(3,), tip orientation at HOME
    """
    scene = task_env._scene
    robot = scene.robot
    arm = robot.arm
    tip = arm.get_tip()
    gripper = robot.gripper

    def _advance(count):
        # Step the simulator `count` times.
        for _ in range(count):
            scene.pyrep.step()

    # Jump to HOME with dynamics disabled, then zero the velocity targets so
    # nothing keeps moving.
    arm.set_joint_positions(HOME_JOINTS, disable_dynamics=True)
    arm.set_joint_target_velocities([0] * 7)
    _advance(10)

    # Command the gripper (1.0 = open, 0.0 = closed) and give it time to move
    gripper.actuate(1.0 if gripper_open else 0.0, 0.2)
    _advance(20)

    # Read back the pose the tip actually ended up in
    home_pos = tip.get_position()
    home_ori = tip.get_orientation()

    print(f"\t  After HOME: pos=[{home_pos[0]:.3f}, {home_pos[1]:.3f}, {home_pos[2]:.3f}], "
          f"ori=[{home_ori[0]:.3f}, {home_ori[1]:.3f}, {home_ori[2]:.3f}]")

    _advance(1)

    return home_pos.copy(), home_ori


def execute_step_with_gripper(task_env, target_ee_pos, desired_ori, gripper_state, extra_steps=0,
                              require_orientation=False):
    """
    Move the arm to one end-effector target, command the gripper, and observe.

    The IK solution is written with set_joint_positions(disable_dynamics=True),
    because set_joint_target_positions doesn't work reliably in some scenes
    (like meat_off_grill). The arm is therefore positioned instantly rather
    than driven there by the dynamics.

    Args:
        task_env: RLBench task environment
        target_ee_pos: np.ndarray(3,), target end-effector position
        desired_ori: np.ndarray(3,), desired orientation (euler)
        gripper_state: float, 1.0 = open, 0.0 = closed
        extra_steps: int, simulation steps on top of the base 5
        require_orientation: bool, if True an orientation-constrained IK
                             failure is re-raised; if False the solve is
                             retried with position only

    Returns:
        obs: observation after the step, with
             obs.misc['joint_position_action'] = joint positions + gripper state
        joint_positions: the commanded joint positions

    Raises:
        IKError or ConfigurationPathError: if IK fails
    """
    scene = task_env._scene
    arm = scene.robot.arm
    gripper = scene.robot.gripper

    # Orientation-constrained IK first; optionally relax to position-only
    try:
        joint_positions = arm.solve_ik(target_ee_pos, euler=desired_ori)
    except (IKError, ConfigurationPathError):
        if not require_orientation:
            joint_positions = arm.solve_ik(target_ee_pos)
        else:
            raise

    # Place the arm directly and zero the velocity targets
    arm.set_joint_positions(joint_positions, disable_dynamics=True)
    arm.set_joint_target_velocities([0] * 7)

    gripper.actuate(gripper_state, 0.2)

    # Let physics settle and the gripper move: simulator step, then task step
    for _ in range(5 + extra_steps):
        scene.pyrep.step()
        scene.task.step()

    obs = scene.get_observation()

    if not hasattr(obs, 'misc'):
        obs.misc = {}
    gripper_entry = np.array([gripper_state])
    obs.misc['joint_position_action'] = np.concatenate([joint_positions, gripper_entry])

    return obs, joint_positions


def _get_tracked_object(task):
    """Return the task's graspable object (``_chicken``, else ``_steak``), or None if it has neither."""
    for attr in ("_chicken", "_steak"):
        obj = getattr(task, attr, None)
        if obj is not None:
            return obj
    return None


def generate_pick_and_place_trajectory(
    task_env,
    start_pos,
    object_pos,
    target_pos,
    cp_idx,
    canonical_params,
    control_point_radius,
    grasp_orientation,
    waypoint_params=None,
    phase_steps=None,
    grasp_offset=None,
):
    """
    Generate a full pick-and-place trajectory with shaped reach and carry phases.

    The robot is first moved to HOME; the tip position and orientation
    measured there (not ``start_pos`` / ``grasp_orientation``) are what the
    trajectory actually starts from and holds. The same canonical control
    point (``cp_idx``) shapes both the reach and the carry arc. The first
    planned position is not executed: its slot in the returned demo is the
    observation taken at HOME. A step whose IK fails re-uses the previous
    observation as a placeholder.

    Args:
        task_env: RLBench task environment
        start_pos: np.ndarray(3,), requested starting EE position (HOME is used instead)
        object_pos: np.ndarray(3,), object position (thing to pick)
        target_pos: np.ndarray(3,), target position (where to place)
        cp_idx: int, canonical control point index (used for both phases)
        canonical_params: np.ndarray, canonical control point parameters;
            row cp_idx unpacks to (angle, dist_frac, pos_frac)
        control_point_radius: float, radius for control point offset
        grasp_orientation: array-like(3,), requested euler angles; only recorded in metadata
        waypoint_params: dict, extracted waypoint heights from demo (or defaults)
        phase_steps: dict, step counts per phase (default: DEFAULT_PHASE_STEPS)
        grasp_offset: np.ndarray(3,), optional offset for grasp

    Returns:
        demo: list of observations, one per trajectory step
        metadata: dict with waypoints, control points, phase layout, IK failure
            count, requested/actual orientation and "grasp_verified"
            (True/False when the object height could be measured before and
            after the lift, None when it could not)

    Raises:
        RuntimeError: if more than 25% of the steps fail IK
        AssertionError: if the number of observations does not equal the sum of phase_steps
    """
    if phase_steps is None:
        phase_steps = DEFAULT_PHASE_STEPS.copy()

    total_steps = sum(phase_steps.values())

    # Move to HOME position first and get actual start position/orientation
    # IMPORTANT: We use HOME as the start, not the original gripper position,
    # because Jacobian IK can't handle large jumps from arbitrary positions
    actual_start_pos, actual_ori = move_robot_to_start_with_gripper(
        task_env, start_pos, grasp_orientation, gripper_open=True
    )
    print(f"\t  Using orientation: [{actual_ori[0]:.3f}, {actual_ori[1]:.3f}, {actual_ori[2]:.3f}]")

    # Compute waypoints using ACTUAL start position (HOME), not requested start
    waypoints = compute_pick_place_waypoints(actual_start_pos, object_pos, target_pos, waypoint_params, grasp_offset)

    # Debug: print waypoints
    print(f"\t  Waypoints:")
    print(f"\t    start:    [{waypoints['start'][0]:.3f}, {waypoints['start'][1]:.3f}, {waypoints['start'][2]:.3f}]")
    print(f"\t    pregrasp: [{waypoints['pregrasp'][0]:.3f}, {waypoints['pregrasp'][1]:.3f}, {waypoints['pregrasp'][2]:.3f}]")
    print(f"\t    grasp:    [{waypoints['grasp'][0]:.3f}, {waypoints['grasp'][1]:.3f}, {waypoints['grasp'][2]:.3f}]")
    print(f"\t    object:   [{object_pos[0]:.3f}, {object_pos[1]:.3f}, {object_pos[2]:.3f}]")

    # Compute control points for shaped phases using same cp_idx
    angle, dist_frac, pos_frac = canonical_params[cp_idx]

    # Control point for reach: start -> pregrasp
    cp_reach = compute_control_point_from_params(
        waypoints["start"], waypoints["pregrasp"],
        control_point_radius, angle, dist_frac, pos_frac
    )

    # Control point for carry: lift -> preplace (same canonical params, different endpoints)
    cp_carry = compute_control_point_from_params(
        waypoints["lift"], waypoints["preplace"],
        control_point_radius, angle, dist_frac, pos_frac
    )

    # Generate all target positions and gripper states
    positions, gripper_states, phase_indices, phase_labels = generate_phase_positions(
        waypoints, phase_steps, cp_reach, cp_carry
    )

    # Execute trajectory using actual achieved orientation
    demo = [task_env._scene.get_observation()]  # Initial observation fills slot 0

    successful_steps = 1
    failed_steps = 0
    max_failures = total_steps // 4  # Allow up to 25% failures

    # Phases that need extra physics steps for gripper operation
    GRIPPER_PHASES = {"close_gripper", "release"}

    # For grasp verification
    object_z_at_grasp = None
    object_z_after_lift = None
    task = task_env._task

    for i in range(1, len(positions)):
        target_pos_step = positions[i]
        gripper_state = gripper_states[i]
        phase = phase_labels[i]
        prev_phase = phase_labels[i - 1]

        # Add extra physics steps during gripper close/release for reliable grip
        extra_steps = 10 if phase in GRIPPER_PHASES else 0

        # Debug: print first few steps to see what's happening
        if i <= 3:
            robot = task_env._scene.robot
            current_ee = robot.arm.get_tip().get_position()
            dist = np.linalg.norm(target_pos_step - current_ee)
            print(f"\t  Step {i} ({phase}): current_ee=[{current_ee[0]:.3f}, {current_ee[1]:.3f}, {current_ee[2]:.3f}], "
                  f"target=[{target_pos_step[0]:.3f}, {target_pos_step[1]:.3f}, {target_pos_step[2]:.3f}], dist={dist*1000:.1f}mm")

        try:
            obs, _ = execute_step_with_gripper(
                task_env, target_pos_step, actual_ori, gripper_state, extra_steps
            )
        except (IKError, ConfigurationPathError):
            failed_steps += 1
            if failed_steps > max_failures:
                raise RuntimeError(
                    f"Too many IK failures: {failed_steps}/{total_steps}"
                )
            # Use previous observation as placeholder
            demo.append(demo[-1])
            continue

        demo.append(obs)
        successful_steps += 1

        # Debug: record object Z and EE position on the first lift step, i.e.
        # right after the close_gripper phase has finished.
        if prev_phase == "close_gripper" and phase == "lift":
            try:
                obj = _get_tracked_object(task)
                if obj is not None:
                    obj_pos = obj.get_position()
                    object_z_at_grasp = obj_pos[2]
                    # Also print actual EE position vs target
                    actual_ee = task_env._scene.robot.arm.get_tip().get_position()
                    target_grasp = waypoints["grasp"]
                    print(f"\t  At grasp: EE=[{actual_ee[0]:.3f}, {actual_ee[1]:.3f}, {actual_ee[2]:.3f}], "
                          f"target=[{target_grasp[0]:.3f}, {target_grasp[1]:.3f}, {target_grasp[2]:.3f}], "
                          f"obj=[{obj_pos[0]:.3f}, {obj_pos[1]:.3f}, {obj_pos[2]:.3f}]")
            except Exception as exc:
                # Diagnostics are best-effort and must never abort the
                # trajectory, but Ctrl-C / SystemExit still propagate.
                print(f"\t  (grasp diagnostics unavailable: {exc})")

        # Debug: record object Z on the first carry step, i.e. after the lift
        if prev_phase == "lift" and phase == "carry":
            try:
                obj = _get_tracked_object(task)
                if obj is not None:
                    object_z_after_lift = obj.get_position()[2]
            except Exception as exc:
                print(f"\t  (lift diagnostics unavailable: {exc})")

    assert len(demo) == total_steps, f"Expected {total_steps} steps, got {len(demo)}"

    # Debug: verify grasp worked (object should have risen during lift).
    # Stays None when either height could not be measured.
    grasp_verified = None
    if object_z_at_grasp is not None and object_z_after_lift is not None:
        z_rise = object_z_after_lift - object_z_at_grasp
        grasp_verified = bool(z_rise > 0.05)  # Object rose at least 5cm
        if not grasp_verified:
            print(f"\t  WARNING: Grasp may have failed (object Z rise: {z_rise*1000:.1f}mm)")

    if failed_steps > 0:
        print(f"\t  IK: {successful_steps} ok, {failed_steps} failed")

    # Build metadata
    metadata = {
        "waypoints": {k: v.tolist() for k, v in waypoints.items()},
        "cp_reach": cp_reach.tolist(),
        "cp_carry": cp_carry.tolist(),
        "phase_indices": phase_indices,
        "phase_steps": phase_steps,
        "ik_failures": failed_steps,
        "requested_orientation": np.asarray(grasp_orientation).tolist(),  # What we asked for
        "actual_orientation": np.asarray(actual_ori).tolist(),            # What we actually used
        "grasp_verified": grasp_verified,
    }

    return demo, metadata


def check_task_success(task_env):
    """
    Report whether the task's own success condition is currently met.

    The task's ``success()`` result is unpacked as a 2-tuple and only its
    first element (the success flag) is used. Any exception raised while
    querying the task (missing ``_task`` attribute, simulator error,
    unexpected return shape, ...) is treated as "not successful".

    Args:
        task_env: RLBench task environment exposing ``_task.success()``

    Returns:
        The success flag reported by the task, or False if the query failed
    """
    try:
        task_succeeded, _ = task_env._task.success()
    except Exception:
        # Best-effort check: a failed query counts as an unsuccessful episode.
        return False
    else:
        return task_succeeded


def get_object_and_target_positions(task_env, task_name):
    """
    Look up where the pick object and the place target currently are.

    NOTE: positions are read live from the scene handles, so call this only
    AFTER the environment has been reset and the object has settled.

    Supported tasks and the task attributes they read:
        - meat_off_grill: ``_chicken`` (``_steak`` as fallback) -> ``_success_sensor``
        - put_rubbish_in_bin: ``_rubbish`` -> ``_bin``
        - place_wine_at_rack_location: ``_wine_bottle`` -> ``_target_rack_location``
          (``_rack`` as fallback)

    Args:
        task_env: RLBench task environment (its ``_task`` attribute is used)
        task_name: str, name of the task

    Returns:
        object_pos: copy of the position of the object to pick
        target_pos: copy of the position where the object should be placed
        object_name: str, name of the object (for logging)

    Raises:
        ValueError: if task_name is not one of the supported tasks
    """
    scene_task = task_env._task

    if task_name == "meat_off_grill":
        # Prefer the chicken handle; fall back to the steak when it is absent.
        try:
            pick_obj, object_name = scene_task._chicken, "chicken"
        except AttributeError:
            pick_obj, object_name = scene_task._steak, "steak"
        # The success sensor is used as the place target rather than a plate
        # handle, which may not exist or be positioned correctly.
        place_ref = scene_task._success_sensor

    elif task_name == "put_rubbish_in_bin":
        pick_obj, object_name = scene_task._rubbish, "rubbish"
        place_ref = scene_task._bin

    elif task_name == "place_wine_at_rack_location":
        pick_obj, object_name = scene_task._wine_bottle, "wine_bottle"
        # The slot depends on the variation (left/middle/right); fall back to
        # the rack itself when no specific location handle is available.
        try:
            place_ref = scene_task._target_rack_location
        except AttributeError:
            place_ref = scene_task._rack

    else:
        raise ValueError(
            f"Unknown task: {task_name}. "
            "Supported: meat_off_grill, put_rubbish_in_bin, place_wine_at_rack_location"
        )

    # Query the object first, then the target; hand back copies so callers
    # can modify the results without touching the queried arrays.
    object_pos = pick_obj.get_position()
    target_pos = place_ref.get_position()
    return object_pos.copy(), target_pos.copy(), object_name


def set_object_position_and_settle(task_env, task_name, new_pos, settle_steps=20, max_z_drop=0.05):
    """
    Move the task's pick object to a new position, then step the simulation.

    The position measured after stepping is returned; it may differ from the
    commanded one once the object has come to rest.

    Args:
        task_env: RLBench task environment
        task_name: str, one of meat_off_grill, put_rubbish_in_bin,
            place_wine_at_rack_location
        new_pos: np.ndarray(3,), commanded object position
        settle_steps: int, number of simulation steps to run after the move
        max_z_drop: float, largest tolerated drop in Z (m) between the
            commanded and the settled position

    Returns:
        actual_pos: copy of the object position after settling

    Raises:
        ValueError: if task_name is not a supported task
        RuntimeError: if the object dropped by more than max_z_drop in Z
    """
    scene_task = task_env._task

    # Resolve the handle of the object to move for this task.
    if task_name == "meat_off_grill":
        try:
            movable = scene_task._chicken
        except AttributeError:
            movable = scene_task._steak
    elif task_name == "put_rubbish_in_bin":
        movable = scene_task._rubbish
    elif task_name == "place_wine_at_rack_location":
        movable = scene_task._wine_bottle
    else:
        raise ValueError(f"Unknown task: {task_name}")

    # Teleport the object, then advance the simulation so it can come to rest.
    movable.set_position(new_pos)
    for _ in range(settle_steps):
        task_env._scene.pyrep.step()

    settled_pos = movable.get_position().copy()

    # A large downward displacement means the object left its support surface.
    dropped_by = new_pos[2] - settled_pos[2]
    if dropped_by > max_z_drop:
        raise RuntimeError(
            f"Object fell off surface (Z drop: {dropped_by*1000:.1f}mm > {max_z_drop*1000:.1f}mm). "
            f"Commanded: {new_pos}, Actual: {settled_pos}"
        )

    return settled_pos


# ============================================================================
# Pose Alignment for Evaluation
# ============================================================================

def align_robot_to_carry_start(task_env, target_ee_pos, orientation, settle_steps=10):
    """
    Snap the robot onto the pose a CARRY phase is expected to start from.

    Used during evaluation: the diffusion model expects input states that
    match its training data, and after the LIFT phase the real robot pose can
    deviate slightly from the CARRY start seen in training (arm at the lift
    position, gripper closed).

    The arm is placed directly at an IK solution for the requested pose, its
    joint target velocities are zeroed, the gripper is commanded closed, and
    the simulation is stepped to settle. A tip position more than 2 cm from
    the target only produces a printed warning; it does not fail the call.

    Args:
        task_env: RLBench task environment
        target_ee_pos: np.ndarray(3,), target end-effector position (lift position)
        orientation: np.ndarray(3,), euler orientation passed to the IK solver
        settle_steps: int, simulation steps to run after positioning

    Returns:
        success: bool, False if any step of the alignment raised
        actual_joints: np.ndarray of joint positions after alignment,
            or None on failure
    """
    panda = task_env._scene.robot
    hand = panda.gripper

    try:
        arm = panda.arm

        # Solve IK for the lift pose and place the arm there directly.
        ik_solution = arm.solve_ik(target_ee_pos, euler=orientation)
        arm.set_joint_positions(ik_solution, disable_dynamics=True)
        arm.set_joint_target_velocities([0] * 7)

        # CARRY starts with a closed gripper (amount 0.0 = closed).
        hand.actuate(0.0, 0.2)

        for _ in range(settle_steps):
            task_env._scene.pyrep.step()

        # Read back where the robot actually ended up.
        actual_joints = np.array(arm.get_joint_positions())
        tip_position = np.array(arm.get_tip().get_position())

        ee_error = np.linalg.norm(tip_position - target_ee_pos)
        if ee_error > 0.02:  # more than 2cm off target
            print(f"    WARNING: Alignment EE error = {ee_error*1000:.1f}mm")

    except Exception as e:
        print(f"    Alignment failed: {e}")
        return False, None

    return True, actual_joints


def get_carry_start_pose_from_training(carry_dataset_path, carry_metadata_path, traj_idx=0):
    """
    Get the expected CARRY starting pose from training data.

    Loads the CARRY dataset, locates the first row of trajectory ``traj_idx``
    and returns that row's state and action together with the trajectory's
    metadata. The result can be used to align the robot before running the
    CARRY phase. Values are returned exactly as stored; the dataset is
    expected to hold normalized states/actions.

    Args:
        carry_dataset_path: str, path to CARRY dataset .npz file; must contain
            the arrays 'traj_lengths', 'states' and 'actions'
        carry_metadata_path: str, path to CARRY metadata .npy file holding one
            record per trajectory with 'canonical_cp_params' and 'end_pos'
        traj_idx: int, trajectory index to use (default: first trajectory)

    Returns:
        dict with:
            - start_joints: first 7 entries of the starting state (joint positions)
            - start_gripper: entry 7 of the starting state (gripper state)
            - start_state: full starting state row
            - start_action: action row stored for the starting step
            - z_embedding: np.ndarray, control point parameters from metadata
            - target_pos: np.ndarray, target end position from metadata

    Raises:
        IndexError: if traj_idx points past the stored trajectories
        KeyError: if a required array or metadata field is missing
    """
    # np.load on an .npz returns a lazy NpzFile that keeps the archive open;
    # the context manager closes it once the needed rows are in memory.
    with np.load(carry_dataset_path) as data:
        # Trajectories are stored back-to-back, so a trajectory's first row
        # sits at the summed length of every trajectory before it.
        traj_lengths = data['traj_lengths']
        start_idx = int(np.sum(traj_lengths[:traj_idx]))

        start_state = data['states'][start_idx]
        start_actions = data['actions'][start_idx]

    # NOTE(security): allow_pickle=True unpickles arbitrary objects, so only
    # load metadata files that come from a trusted source.
    metadata = np.load(carry_metadata_path, allow_pickle=True)
    meta = metadata[traj_idx]

    return {
        'start_joints': start_state[:7],
        'start_gripper': start_state[7],
        'start_state': start_state,
        'start_action': start_actions,
        'z_embedding': np.array(meta['canonical_cp_params']),
        'target_pos': np.array(meta['end_pos']),
    }
