"""
Debug script for REACH + GRASP + LIFT trajectory generation.

Phases:
1. REACH: HOME → pregrasp (Bezier curve with control point, 64 steps)
2. DESCEND: pregrasp → grasp position (linear, 8 steps)
3. GRASP: close gripper at grasp position (8 steps)
4. LIFT: grasp → pregrasp (linear, 8 steps)

Total: 64 + 8 + 8 + 8 = 88 steps

Tests filtered control points and visualizes the full reach-grasp-lift trajectory.

IMPORTANT: This script extracts grasp parameters (height, orientation, xy_offset) from
actual RLBench demos using extract_demo_waypoints(), similar to how the original
dataset_generator_pick_place_cp.py does it. This ensures we use parameters that
actually work in successful grasps.
"""

import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'custom_tasks'))

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

from rlbench.environment import Environment
from pyrep.errors import ConfigurationPathError, IKError

from utils import (
    create_obs_config, create_action_mode, get_task_classes,
    generate_canonical_control_point_params,
    compute_control_point_from_params,
    prefilter_control_points,
    prefilter_control_points_with_collision,
    parabola3D,
    HOME_JOINTS,
)

# Import from pick_and_place_utils for grasp parameter extraction
from pick_and_place_utils import (
    extract_demo_waypoints,
    DEFAULT_WAYPOINT_OFFSETS,
    DEFAULT_GRASP_ORIENTATION,
)


def linear_interpolate(p1, p2, t):
    """Return the point a fraction ``t`` of the way along the segment p1 -> p2.

    ``t = 0`` yields ``p1`` and ``t = 1`` yields ``p2``; ``t`` is not clamped,
    so values outside [0, 1] extrapolate along the same line. The inputs only
    need to support ``-``, ``+`` and scalar ``*`` (scalars or NumPy arrays).
    """
    span = p2 - p1
    return p1 + span * t


def generate_reach_grasp_lift_trajectory(task_env, home_pos, pregrasp_pos, grasp_pos,
                                          control_point, orientation, lift_pos=None,
                                          reach_steps=64, descend_steps=8,
                                          grasp_steps=8, lift_steps=8,
                                          steps_per_point=5, object_name='chicken'):
    """
    Generate full reach-grasp-lift trajectory with physics.

    Phases:
    1. REACH: home → pregrasp (parabola3D curve through control_point, reach_steps)
    2. DESCEND: pregrasp → grasp (linear, descend_steps)
    3. GRASP: hold at grasp, close gripper (grasp_steps), then gripper.grasp()
    4. LIFT: grasp → lift_pos (linear, lift_steps) - if lift_pos provided, else grasp → pregrasp

    Every waypoint is handled the same way: the arm is put back at the last
    joint configuration that solved successfully, IK is solved for the target
    with ``solve_ik_via_jacobian``, the solution is written to the joints, the
    gripper is commanded, and the simulator is stepped. The measured tip
    position (not the commanded target) is what goes into the trace.

    If IK fails for a waypoint, the failure is counted and the previous trace
    entry is repeated so that the trace always has one entry per step.

    Args:
        task_env: RLBench task environment; ``task_env._scene`` is used to
            reach the robot and the simulator.
        home_pos: start point of the reach curve (also used as the trace entry
            if the very first reach step fails).
        pregrasp_pos: end of the reach curve / start of the descent.
        grasp_pos: tip target while the gripper closes.
        control_point: control point passed to ``parabola3D``.
        orientation: euler orientation passed to the IK solver at every step.
        lift_pos: lift target; defaults to a copy of ``pregrasp_pos``.
        reach_steps, descend_steps, grasp_steps, lift_steps: waypoints per phase.
        steps_per_point: simulator steps per waypoint (the GRASP phase uses
            ``steps_per_point + 5`` to give the fingers extra time to close).
        object_name: scene name of the object to grasp and monitor.

    Returns:
        dict with keys:
            'trace': np.ndarray of measured tip positions, one row per step
            'phase_labels': list of phase names ("reach", "descend", "grasp", "lift")
            'gripper_states': list of commanded gripper states (1.0=open, 0.0=closed)
            'ik_failures': number of waypoints whose IK solve failed
            'object_lifted': bool, True if the object rose by more than 2cm
            'object_z_before': object Z just before the GRASP phase
            'object_z_after': object Z after the LIFT phase
            'grasp_success': value returned by ``gripper.grasp()``
    """
    from pyrep.objects.shape import Shape

    scene = task_env._scene
    robot = scene.robot
    tip = robot.arm.get_tip()
    gripper = robot.gripper

    # Use lift_pos if provided, otherwise lift back to pregrasp
    if lift_pos is None:
        lift_pos = pregrasp_pos.copy()

    trace = []
    phase_labels = []
    gripper_states = []
    ik_failures = 0

    # Last joint configuration that solved successfully; each IK call starts here.
    prev_joints = list(robot.arm.get_joint_positions())

    # Object reference for checking whether the grasp/lift succeeded
    target_object = Shape(object_name)

    def run_waypoint(phase, step_idx, target_pos, gripper_cmd, physics_steps,
                     empty_fallback=None, announce_failure=True):
        """Servo to one waypoint and record it.

        Returns the measured tip position, or None if IK failed (in which case
        the previous trace entry is repeated instead).
        """
        nonlocal prev_joints, ik_failures
        try:
            robot.arm.set_joint_positions(prev_joints, disable_dynamics=True)
            joint_positions = robot.arm.solve_ik_via_jacobian(target_pos, euler=orientation)
            robot.arm.set_joint_positions(joint_positions, disable_dynamics=True)
            robot.arm.set_joint_target_velocities([0] * 7)

            gripper.actuate(gripper_cmd, 0.2)

            for _ in range(physics_steps):
                scene.pyrep.step()

            actual_pos = tip.get_position().copy()
        except (IKError, ConfigurationPathError) as e:
            ik_failures += 1
            if announce_failure:
                print(f"    step {step_idx}: IK FAILED - {e}")
            if trace:
                recorded = trace[-1].copy()
            elif empty_fallback is not None:
                recorded = empty_fallback.copy()
            else:
                # Nothing recorded yet (e.g. reach_steps == 0): fall back to
                # wherever the tip currently is rather than indexing an empty list.
                recorded = tip.get_position().copy()
            actual_pos = None
        else:
            prev_joints = list(joint_positions)
            recorded = actual_pos

        trace.append(recorded)
        phase_labels.append(phase)
        gripper_states.append(gripper_cmd)
        return actual_pos

    # ========================================================================
    # Phase 1: REACH (curve) - home → pregrasp, gripper open
    # ========================================================================
    for i in range(reach_steps):
        t = i / max(reach_steps - 1, 1)
        target_pos = parabola3D(home_pos, pregrasp_pos, control_point, t)
        run_waypoint("reach", i, target_pos, 1.0, steps_per_point,
                     empty_fallback=home_pos, announce_failure=False)

    # ========================================================================
    # Phase 2: DESCEND (linear) - pregrasp → grasp, gripper still open
    # ========================================================================
    print(f"\n  DESCEND: pregrasp={pregrasp_pos} -> grasp={grasp_pos}")
    for i in range(descend_steps):
        t = i / max(descend_steps - 1, 1)
        target_pos = linear_interpolate(pregrasp_pos, grasp_pos, t)
        actual_pos = run_waypoint("descend", i, target_pos, 1.0, steps_per_point)

        # Debug: print first and last descend step
        if actual_pos is not None and i in (0, descend_steps - 1):
            error = np.linalg.norm(actual_pos - target_pos) * 1000
            print(f"    step {i}: target_z={target_pos[2]:.4f}, actual_z={actual_pos[2]:.4f}, error={error:.1f}mm")

    # ========================================================================
    # Phase 3: GRASP - hold position, close gripper, then ATTACH object
    # ========================================================================
    # Record object Z before grasp; this is the baseline for the lift check.
    object_z_before_lift = target_object.get_position()[2]
    print(f"\n  GRASP: hold at grasp_pos={grasp_pos}, object_z_before={object_z_before_lift:.4f}")

    for i in range(grasp_steps):
        # Extra physics steps for gripper to close
        actual_pos = run_waypoint("grasp", i, grasp_pos, 0.0, steps_per_point + 5)

        # Debug: print gripper state after the final grasp step
        if actual_pos is not None and i == grasp_steps - 1:
            gripper_open_amount = gripper.get_open_amount()[0]
            print(f"    After GRASP: ee_z={actual_pos[2]:.4f}, obj_z={target_object.get_position()[2]:.4f}, gripper_open={gripper_open_amount:.3f}")

    # IMPORTANT: After gripper closes, use gripper.grasp() to ATTACH object to gripper!
    # gripper.actuate() just closes the fingers, but doesn't establish parent-child attachment
    # gripper.grasp() checks proximity sensor and attaches object to gripper
    grasp_success = gripper.grasp(target_object)
    print(f"    gripper.grasp() result: {grasp_success}")

    # ========================================================================
    # Phase 4: LIFT (linear) - grasp → lift_pos, gripper closed
    # ========================================================================
    print(f"\n  LIFT: grasp={grasp_pos} -> lift={lift_pos}")
    for i in range(lift_steps):
        t = i / max(lift_steps - 1, 1)
        target_pos = linear_interpolate(grasp_pos, lift_pos, t)
        actual_pos = run_waypoint("lift", i, target_pos, 0.0, steps_per_point)

        # Debug: print first and last lift step
        if actual_pos is not None and i in (0, lift_steps - 1):
            obj_z = target_object.get_position()[2]
            print(f"    step {i}: ee_z={actual_pos[2]:.4f}, obj_z={obj_z:.4f}")

    # Check if object was lifted
    object_z_after_lift = target_object.get_position()[2]
    object_lifted = (object_z_after_lift - object_z_before_lift) > 0.02  # Lifted at least 2cm

    return {
        'trace': np.array(trace),
        'phase_labels': phase_labels,
        'gripper_states': gripper_states,
        'ik_failures': ik_failures,
        'object_lifted': object_lifted,
        'object_z_before': object_z_before_lift,
        'object_z_after': object_z_after_lift,
        'grasp_success': grasp_success,
    }


def compute_jitter_metric(trace):
    """Score how much a path changes direction from one step to the next.

    For every pair of consecutive displacement vectors the contribution is
    ``1 - cos(angle between them)``: 0 for a straight continuation, 1 for a
    right-angle turn, 2 for a full reversal. The score is the sum over the
    whole path, so higher means more jitter.

    Displacement lengths are floored at 1e-8 before normalising, so a
    stationary step becomes a (near-)zero vector and each pair that includes
    one contributes about 1.

    Args:
        trace: sequence of points, one row per time step.

    Returns:
        The jitter score; 0.0 when fewer than three points are given.
    """
    if len(trace) < 3:
        return 0.0

    # Per-step displacement vectors and their lengths
    steps = np.diff(trace, axis=0)
    lengths = np.linalg.norm(steps, axis=1, keepdims=True)

    # Unit directions, with the length floored to avoid dividing by zero
    unit_steps = steps / np.maximum(lengths, 1e-8)

    # Cosine of the turn between each step and the next one
    cosines = (unit_steps[:-1] * unit_steps[1:]).sum(axis=1)

    return np.sum(1 - cosines)


def _reset_robot_to_home(task_env, settle_steps=20):
    """Reset the task and park the arm at HOME_JOINTS with the gripper commanded open.

    NumPy's global RNG is re-seeded with 42 before every reset.
    NOTE(review): presumably this is meant to make the task's randomisation
    repeatable across resets - confirm the task actually samples from it.

    Returns:
        (robot, tip): the scene's robot and the arm's tip object.
    """
    np.random.seed(42)
    task_env.reset()

    robot = task_env._scene.robot
    robot.arm.set_joint_positions(HOME_JOINTS, disable_dynamics=True)
    robot.arm.set_joint_target_velocities([0] * 7)
    robot.gripper.actuate(1.0, 0.2)
    for _ in range(settle_steps):
        task_env._scene.pyrep.step()

    return robot, robot.arm.get_tip()


def _pin_object(task_env, position, settle_steps, object_name='chicken'):
    """Move the named object to ``position``, step the simulator, and return the object."""
    from pyrep.objects.shape import Shape

    target_object = Shape(object_name)
    target_object.set_position(position)
    for _ in range(settle_steps):  # Let physics settle
        task_env._scene.pyrep.step()
    return target_object


def _offset_above(object_pos, xy_offset, height):
    """Return a fresh copy of ``object_pos`` shifted by ``xy_offset`` in XY and ``height`` in Z."""
    waypoint = object_pos.copy()
    waypoint[:2] += xy_offset
    waypoint[2] += height
    return waypoint


def _reach_only_trace(task_env, robot, tip, home_pos, pregrasp_pos, control_point,
                      orientation, reach_steps, steps_per_point):
    """Roll out only the REACH curve and return the measured tip positions.

    Used by the jitter filter. Unlike the full rollout, the gripper is not
    commanded here. A waypoint whose IK fails repeats the previous trace entry
    (or ``home_pos`` if nothing has been recorded yet).
    """
    trace = []
    prev_joints = list(robot.arm.get_joint_positions())

    for i in range(reach_steps):
        t = i / max(reach_steps - 1, 1)
        target_pos = parabola3D(home_pos, pregrasp_pos, control_point, t)

        try:
            # Start each IK solve from the last configuration that solved.
            robot.arm.set_joint_positions(prev_joints, disable_dynamics=True)
            joint_positions = robot.arm.solve_ik_via_jacobian(target_pos, euler=orientation)
            robot.arm.set_joint_positions(joint_positions, disable_dynamics=True)
            robot.arm.set_joint_target_velocities([0] * 7)

            for _ in range(steps_per_point):
                task_env._scene.pyrep.step()

            trace.append(tip.get_position().copy())
            prev_joints = list(joint_positions)
        except (IKError, ConfigurationPathError):
            trace.append(trace[-1].copy() if trace else home_pos.copy())

    return np.array(trace)


def _plot_results(all_results, n_lifted, reach_steps, descend_steps, grasp_steps,
                  jitter_threshold, output_path='reach_grasp_lift_debug.png'):
    """Draw the six-panel summary figure, save it to ``output_path`` and show it.

    Landmark markers (HOME, pregrasp, grasp, lift, object) are taken from the
    first result only; ``all_results`` must therefore be non-empty.
    """
    fig = plt.figure(figsize=(18, 12))

    r0 = all_results[0]
    colors = plt.cm.tab10(np.linspace(0, 1, len(all_results)))

    # (result key, colour, marker, size, legend label) for each landmark
    landmarks = [
        ("home_pos", 'black', 'o', 150, 'HOME'),
        ("pregrasp_pos", 'orange', '*', 200, 'Pregrasp'),
        ("grasp_pos", 'red', 'v', 150, 'Grasp'),
        ("lift_pos", 'purple', '^', 150, 'Lift'),
        ("object_pos", 'green', 's', 100, 'Object'),
    ]

    # Plot 1: 3D view - All trajectories
    ax1 = fig.add_subplot(231, projection='3d')
    for line_color, r in zip(colors, all_results):
        trace = r['trace']
        lifted_label = "✓" if r['object_lifted'] else "✗"
        ax1.plot(trace[:, 0], trace[:, 1], trace[:, 2], '-', color=line_color,
                 linewidth=1.5, alpha=0.8, label=f'CP{r["cp_idx"]} {lifted_label}')
        ax1.scatter(*r["control_point"], c=[line_color], marker='x', s=30, alpha=0.5)

    for key, face, marker, size, label in landmarks:
        ax1.scatter(*r0[key], c=face, marker=marker, s=size, zorder=10, label=label)

    ax1.set_xlabel('X (m)')
    ax1.set_ylabel('Y (m)')
    ax1.set_zlabel('Z (m)')
    ax1.set_title(f'3D View - REACH+GRASP+LIFT ({n_lifted}/{len(all_results)} lifted)')
    ax1.legend(fontsize=6, loc='upper left')

    # Plot 2: XZ view (side)
    ax2 = fig.add_subplot(232)
    for line_color, r in zip(colors, all_results):
        trace = r['trace']
        ax2.plot(trace[:, 0], trace[:, 2], '-', color=line_color, linewidth=1.5, alpha=0.8)

    for key, face, marker, size, label in landmarks:
        ax2.scatter(r0[key][0], r0[key][2], c=face, marker=marker, s=size, zorder=10, label=label)

    ax2.set_xlabel('X (m)')
    ax2.set_ylabel('Z (m)')
    ax2.set_title('XZ View (Side)')
    ax2.legend(fontsize=8)
    ax2.grid(True, alpha=0.3)

    # Plot 3: XY view (top) - same landmarks, no legend entries
    ax3 = fig.add_subplot(233)
    for line_color, r in zip(colors, all_results):
        trace = r['trace']
        ax3.plot(trace[:, 0], trace[:, 1], '-', color=line_color, linewidth=1.5, alpha=0.8)

    for key, face, marker, size, _label in landmarks:
        ax3.scatter(r0[key][0], r0[key][1], c=face, marker=marker, s=size, zorder=10)

    ax3.set_xlabel('X (m)')
    ax3.set_ylabel('Y (m)')
    ax3.set_title('XY View (Top)')
    ax3.grid(True, alpha=0.3)

    # Plot 4: Z position over time (shows descent and lift)
    ax4 = fig.add_subplot(234)
    for line_color, r in zip(colors, all_results):
        trace = r['trace']
        ax4.plot(trace[:, 2], '-', color=line_color, linewidth=1.5, alpha=0.8, label=f'CP{r["cp_idx"]}')

    # Add phase boundaries
    ax4.axvline(x=reach_steps, color='gray', linestyle='--', alpha=0.5, label='Phase boundary')
    ax4.axvline(x=reach_steps + descend_steps, color='gray', linestyle='--', alpha=0.5)
    ax4.axvline(x=reach_steps + descend_steps + grasp_steps, color='gray', linestyle='--', alpha=0.5)

    ax4.set_xlabel('Step')
    ax4.set_ylabel('Z (m)')
    ax4.set_title('Z Position Over Time (REACH → DESCEND → GRASP → LIFT)')
    ax4.legend(fontsize=6, ncol=2)
    ax4.grid(True, alpha=0.3)

    # Shared by the two bar charts
    x = np.arange(len(all_results))
    width = 0.6
    bar_colors = ['green' if r['object_lifted'] else 'red' for r in all_results]
    tick_labels = [f'{r["cp_idx"]}' for r in all_results]

    # Plot 5: Jitter and lift success bar chart
    ax5 = fig.add_subplot(235)
    ax5.bar(x, [r['jitter'] for r in all_results], width, color=bar_colors, alpha=0.7)
    ax5.axhline(y=jitter_threshold, color='orange', linestyle='--', label=f'Threshold={jitter_threshold}')
    ax5.set_xlabel('Control Point')
    ax5.set_ylabel('Jitter Metric')
    ax5.set_title('Jitter (green=lifted, red=not lifted)')
    ax5.set_xticks(x)
    ax5.set_xticklabels(tick_labels, fontsize=8)
    ax5.legend(fontsize=8)

    # Plot 6: Object Z rise
    ax6 = fig.add_subplot(236)
    z_rises = [r['object_z_rise'] * 1000 for r in all_results]  # mm
    ax6.bar(x, z_rises, width, color=bar_colors, alpha=0.7)
    ax6.axhline(y=20, color='orange', linestyle='--', label='Threshold=20mm')

    ax6.set_xlabel('Control Point')
    ax6.set_ylabel('Object Z Rise (mm)')
    ax6.set_title('Object Lift Height')
    ax6.set_xticks(x)
    ax6.set_xticklabels(tick_labels, fontsize=8)
    ax6.legend(fontsize=8)

    plt.tight_layout()
    plt.savefig(output_path, dpi=150)
    print(f"\nPlot saved to: {output_path}")
    plt.show()


def main():
    """Run the reach + grasp + lift debug session end to end.

    Steps:
    1. Launch a headless RLBench environment for "meat_off_grill".
    2. Extract grasp parameters from one demo via extract_demo_waypoints().
    3. Pre-filter the canonical control points (IK + collision) for the REACH phase.
    4. Drop control points whose reach-only rollout exceeds the jitter threshold.
    5. Run up to 10 full REACH + DESCEND + GRASP + LIFT rollouts and summarise them.
    6. Save and show the summary figure.

    The environment is shut down in a ``finally`` block, so the simulator is
    released on early exits and on exceptions as well as on success.
    """
    from pyrep.objects.shape import Shape

    print("=" * 60)
    print("REACH + GRASP + LIFT TRAJECTORY TEST")
    print("=" * 60)

    task_name = "meat_off_grill"

    # Phase configuration
    REACH_STEPS = 64
    DESCEND_STEPS = 8
    GRASP_STEPS = 8
    LIFT_STEPS = 8
    TOTAL_STEPS = REACH_STEPS + DESCEND_STEPS + GRASP_STEPS + LIFT_STEPS
    STEPS_PER_POINT = 5

    print(f"\nPhase configuration:")
    print(f"  REACH:   {REACH_STEPS} steps (Bezier curve)")
    print(f"  DESCEND: {DESCEND_STEPS} steps (linear)")
    print(f"  GRASP:   {GRASP_STEPS} steps (close gripper)")
    print(f"  LIFT:    {LIFT_STEPS} steps (linear)")
    print(f"  TOTAL:   {TOTAL_STEPS} steps")

    obs_config = create_obs_config(save_video=False)
    action_mode = create_action_mode("abs")

    rlbench_env = Environment(
        action_mode=action_mode, obs_config=obs_config, headless=True
    )
    rlbench_env.launch()

    try:
        tasks = get_task_classes([task_name])
        t_cls = tasks[0]
        task_env = rlbench_env.get_task(t_cls)
        task_env.set_variation(0)

        # ====================================================================
        # STEP 1: Extract grasp parameters from actual RLBench demos
        # This is how the original dataset_generator_pick_place_cp.py does it!
        # ====================================================================
        print(f"\n{'='*60}")
        print("EXTRACTING GRASP PARAMETERS FROM RLBENCH DEMO")
        print(f"{'='*60}")

        waypoint_params = extract_demo_waypoints(task_env, task_name, num_demos=1)

        # Get the extracted parameters
        # Override PREGRASP_HEIGHT to be lower (closer to object)
        PREGRASP_HEIGHT = 0.05  # 5cm above object (lowered from extracted ~10cm)
        GRASP_HEIGHT = waypoint_params.get("grasp_height", 0.04)
        LIFT_HEIGHT = waypoint_params.get("lift_height", 0.15)
        GRASP_ORIENTATION = waypoint_params.get("grasp_orientation", DEFAULT_GRASP_ORIENTATION.copy())
        GRASP_XY_OFFSET = waypoint_params.get("grasp_xy_offset", np.zeros(2))

        print(f"\n  Extracted parameters:")
        print(f"    PREGRASP_HEIGHT: {PREGRASP_HEIGHT:.4f}m")
        print(f"    GRASP_HEIGHT:    {GRASP_HEIGHT:.4f}m")
        print(f"    LIFT_HEIGHT:     {LIFT_HEIGHT:.4f}m")
        print(f"    GRASP_XY_OFFSET: [{GRASP_XY_OFFSET[0]:.4f}, {GRASP_XY_OFFSET[1]:.4f}]m")
        print(f"    GRASP_ORIENTATION: [{GRASP_ORIENTATION[0]:.4f}, {GRASP_ORIENTATION[1]:.4f}, {GRASP_ORIENTATION[2]:.4f}]")

        # Build canonical control points
        canonical_params = generate_canonical_control_point_params(n_per_axis=5)
        control_point_radius = 0.05
        print(f"\nTotal canonical control points: {len(canonical_params)}")

        # ====================================================================
        # PREFILTER: Collision check for REACH phase only
        # ====================================================================
        print(f"\n{'='*60}")
        print("PRE-FILTERING CONTROL POINTS (REACH phase)")
        print(f"{'='*60}")

        robot, tip = _reset_robot_to_home(task_env)
        home_pos = tip.get_position().copy()

        # Get object position and FIX it for all tests
        chicken = Shape('chicken')
        FIXED_OBJECT_POS = chicken.get_position().copy()  # Store as fixed reference

        # Compute waypoints using extracted parameters
        # Apply XY offset from demo (important for accurate grasp)
        pregrasp_pos = _offset_above(FIXED_OBJECT_POS, GRASP_XY_OFFSET, PREGRASP_HEIGHT)
        grasp_pos = _offset_above(FIXED_OBJECT_POS, GRASP_XY_OFFSET, GRASP_HEIGHT)

        # Lift position - go back to pregrasp (same XY, same Z as pregrasp)
        # DESCEND/GRASP/LIFT only change Z, XY stays constant
        lift_pos = pregrasp_pos.copy()  # Lift back to pregrasp position

        print(f"  HOME pos:     [{home_pos[0]:.4f}, {home_pos[1]:.4f}, {home_pos[2]:.4f}]")
        print(f"  Object pos (FIXED): [{FIXED_OBJECT_POS[0]:.4f}, {FIXED_OBJECT_POS[1]:.4f}, {FIXED_OBJECT_POS[2]:.4f}]")
        print(f"  Pregrasp pos: [{pregrasp_pos[0]:.4f}, {pregrasp_pos[1]:.4f}, {pregrasp_pos[2]:.4f}]")
        print(f"  Grasp pos:    [{grasp_pos[0]:.4f}, {grasp_pos[1]:.4f}, {grasp_pos[2]:.4f}]")
        print(f"  Lift pos:     [{lift_pos[0]:.4f}, {lift_pos[1]:.4f}, {lift_pos[2]:.4f}]")
        print(f"  Grasp XY offset: [{GRASP_XY_OFFSET[0]:.4f}, {GRASP_XY_OFFSET[1]:.4f}]")

        # Prefilter control points for REACH phase (using extracted GRASP_ORIENTATION)
        print(f"\n  Testing IK + COLLISION for all {len(canonical_params)} control points...")
        print(f"  Using orientation: [{GRASP_ORIENTATION[0]:.4f}, {GRASP_ORIENTATION[1]:.4f}, {GRASP_ORIENTATION[2]:.4f}]")
        valid_indices, prefilter_results = prefilter_control_points_with_collision(
            robot.arm, home_pos, pregrasp_pos, GRASP_ORIENTATION,
            canonical_params, control_point_radius,
            num_samples=16, require_orientation=True
        )

        collision_failures = sum(1 for r in prefilter_results.values() if r['collision_count'] > 0)
        ik_failures = sum(1 for r in prefilter_results.values() if r['ik_failures'] > 0)
        print(f"  Valid (IK + collision-free): {len(valid_indices)}/{len(canonical_params)}")
        print(f"  Rejected due to collision: {collision_failures}")
        print(f"  Rejected due to IK failure: {ik_failures}")

        if len(valid_indices) == 0:
            print("  ERROR: No valid control points found!")
            return  # the finally block shuts the environment down

        # ====================================================================
        # SECOND FILTER: Jitter filter for REACH phase
        # ====================================================================
        print(f"\n{'='*60}")
        print("SECOND FILTER: Jitter filter (REACH phase only)")
        print(f"{'='*60}")

        JITTER_THRESHOLD = 0.08
        final_valid_indices = []

        for cp_idx in valid_indices:
            robot, tip = _reset_robot_to_home(task_env)
            home_pos = tip.get_position().copy()

            # Use FIXED object position (ensure meat is in consistent location)
            _pin_object(task_env, FIXED_OBJECT_POS, settle_steps=5)

            # Recompute pregrasp with XY offset
            pregrasp_pos = _offset_above(FIXED_OBJECT_POS, GRASP_XY_OFFSET, PREGRASP_HEIGHT)

            angle, dist_frac, pos_frac = canonical_params[cp_idx]
            cp = compute_control_point_from_params(
                home_pos, pregrasp_pos, control_point_radius, angle, dist_frac, pos_frac
            )

            # Quick reach-only rollout for jitter check (using extracted GRASP_ORIENTATION)
            trace = _reach_only_trace(
                task_env, robot, tip, home_pos, pregrasp_pos, cp,
                GRASP_ORIENTATION, REACH_STEPS, STEPS_PER_POINT
            )
            jitter = compute_jitter_metric(trace)

            if jitter <= JITTER_THRESHOLD:
                final_valid_indices.append(cp_idx)
                print(f"  CP {cp_idx:3d}: jitter={jitter:.4f} - PASS")
            else:
                print(f"  CP {cp_idx:3d}: jitter={jitter:.4f} - REJECT (>{JITTER_THRESHOLD})")

        print(f"\n  Final valid CPs: {len(final_valid_indices)}/{len(valid_indices)} (after jitter filter)")

        if len(final_valid_indices) == 0:
            print("  ERROR: No valid control points after jitter filter!")
            return  # the finally block shuts the environment down

        # ====================================================================
        # TEST: Full REACH + GRASP + LIFT trajectory
        # ====================================================================
        num_tests = min(10, len(final_valid_indices))  # Test up to 10 trajectories
        # Spread the tested control points evenly over the surviving ones
        test_cp_indices = [final_valid_indices[i] for i in np.linspace(0, len(final_valid_indices)-1, num_tests, dtype=int)]

        print(f"\n{'='*60}")
        print(f"TESTING {num_tests} FULL TRAJECTORIES (REACH + GRASP + LIFT)")
        print(f"{'='*60}")

        all_results = []

        for trial_idx, cp_idx in enumerate(test_cp_indices):
            print(f"\n--- Trial {trial_idx+1}/{num_tests} (cp_idx={cp_idx}) ---")

            _, tip = _reset_robot_to_home(task_env)
            home_pos = tip.get_position().copy()
            home_ori = tip.get_orientation()

            # FIX OBJECT POSITION - ensure meat is at consistent location
            chicken = _pin_object(task_env, FIXED_OBJECT_POS, settle_steps=10)
            object_pos = chicken.get_position().copy()

            # Compute waypoints using extracted parameters (fresh arrays per trial,
            # since they are stored in the result record below)
            pregrasp_pos = _offset_above(FIXED_OBJECT_POS, GRASP_XY_OFFSET, PREGRASP_HEIGHT)
            grasp_pos = _offset_above(FIXED_OBJECT_POS, GRASP_XY_OFFSET, GRASP_HEIGHT)

            # Lift position - go back to pregrasp (same XY, same Z as pregrasp)
            # DESCEND/GRASP/LIFT only change Z, XY stays constant
            lift_pos = pregrasp_pos.copy()

            print(f"  HOME:     [{home_pos[0]:.4f}, {home_pos[1]:.4f}, {home_pos[2]:.4f}]")
            print(f"  HOME ori: [{home_ori[0]:.4f}, {home_ori[1]:.4f}, {home_ori[2]:.4f}]")
            print(f"  Object:   [{object_pos[0]:.4f}, {object_pos[1]:.4f}, {object_pos[2]:.4f}] (fixed)")
            print(f"  Pregrasp: [{pregrasp_pos[0]:.4f}, {pregrasp_pos[1]:.4f}, {pregrasp_pos[2]:.4f}]")
            print(f"  Grasp:    [{grasp_pos[0]:.4f}, {grasp_pos[1]:.4f}, {grasp_pos[2]:.4f}]")
            print(f"  Lift:     [{lift_pos[0]:.4f}, {lift_pos[1]:.4f}, {lift_pos[2]:.4f}]")
            print(f"  Using orientation: [{GRASP_ORIENTATION[0]:.4f}, {GRASP_ORIENTATION[1]:.4f}, {GRASP_ORIENTATION[2]:.4f}]")

            angle, dist_frac, pos_frac = canonical_params[cp_idx]
            cp = compute_control_point_from_params(
                home_pos, pregrasp_pos, control_point_radius, angle, dist_frac, pos_frac
            )

            # Generate full trajectory with extracted GRASP_ORIENTATION
            result = generate_reach_grasp_lift_trajectory(
                task_env, home_pos, pregrasp_pos, grasp_pos, cp, GRASP_ORIENTATION,
                lift_pos=lift_pos,
                reach_steps=REACH_STEPS, descend_steps=DESCEND_STEPS,
                grasp_steps=GRASP_STEPS, lift_steps=LIFT_STEPS,
                steps_per_point=STEPS_PER_POINT
            )

            trace = result['trace']
            jitter = compute_jitter_metric(trace)
            object_lifted = result['object_lifted']
            z_rise = result['object_z_after'] - result['object_z_before']

            print(f"  Total steps: {len(trace)}")
            print(f"  Jitter: {jitter:.4f}")
            print(f"  IK failures: {result['ik_failures']}")
            print(f"  Object Z rise: {z_rise*1000:.1f}mm")
            print(f"  Object lifted: {'YES' if object_lifted else 'NO'}")

            all_results.append({
                "cp_idx": cp_idx,
                "trace": trace,
                "phase_labels": result['phase_labels'],
                "gripper_states": result['gripper_states'],
                "jitter": jitter,
                "ik_failures": result['ik_failures'],
                "object_lifted": object_lifted,
                "object_z_rise": z_rise,
                "home_pos": home_pos,
                "pregrasp_pos": pregrasp_pos,
                "grasp_pos": grasp_pos,
                "lift_pos": lift_pos,
                "object_pos": object_pos,
                "control_point": cp,
            })

        # ====================================================================
        # SUMMARY
        # ====================================================================
        print(f"\n{'='*60}")
        print("SUMMARY")
        print(f"{'='*60}")

        n_lifted = sum(1 for r in all_results if r['object_lifted'])
        avg_jitter = np.mean([r['jitter'] for r in all_results])
        avg_z_rise = np.mean([r['object_z_rise'] for r in all_results])

        print(f"\n  Trajectories tested: {len(all_results)}")
        print(f"  Objects lifted: {n_lifted}/{len(all_results)}")
        print(f"  Average jitter: {avg_jitter:.4f}")
        print(f"  Average Z rise: {avg_z_rise*1000:.1f}mm")

        print(f"\n  {'Trial':>5} | {'CP':>4} | {'Jitter':>8} | {'Z Rise':>10} | {'Lifted':>6}")
        print(f"  {'-'*5}-+-{'-'*4}-+-{'-'*8}-+-{'-'*10}-+-{'-'*6}")
        for i, r in enumerate(all_results):
            lifted_str = "YES" if r['object_lifted'] else "NO"
            print(f"  {i+1:5d} | {r['cp_idx']:4d} | {r['jitter']:8.4f} | {r['object_z_rise']*1000:8.1f}mm | {lifted_str:>6}")

        # ====================================================================
        # VISUALIZATION
        # ====================================================================
        print(f"\n{'='*60}")
        print("Creating visualization...")
        print(f"{'='*60}")

        _plot_results(
            all_results, n_lifted,
            reach_steps=REACH_STEPS, descend_steps=DESCEND_STEPS,
            grasp_steps=GRASP_STEPS, jitter_threshold=JITTER_THRESHOLD,
        )
    finally:
        # Always release the simulator, including on early return or error.
        rlbench_env.shutdown()

    print("\n" + "=" * 60)
    print("TEST COMPLETE")
    print("=" * 60)


# Run the debug session only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
