"""
Debug script for FULL PICK-AND-PLACE trajectory generation.

Phases:
1. REACH: HOME → pregrasp (Bezier curve with control point, 64 steps)
2. DESCEND: pregrasp → grasp position (linear Z only, 8 steps)
3. GRASP: close gripper at grasp position (8 steps)
4. LIFT: grasp → pregrasp (linear Z only, 8 steps)
5. CARRY: pregrasp → prerelease (Bezier curve with control point, 64 steps)
6. DESCEND_RELEASE: prerelease → release position (linear Z only, 8 steps)
7. RELEASE: open gripper at release position (8 steps)

Total: 64 + 8 + 8 + 8 + 64 + 8 + 8 = 168 steps

The two Bezier curve phases (REACH and CARRY) only use control points whose indices were
pre-filtered and stored in the init file (reach_valid_indices / carry_valid_indices),
using:
- IK feasibility check
- Collision check
- Jitter metric (threshold 0.08)

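IMPORTANT: This script does not compute waypoints or grasp parameters itself. All waypoints
and grasp parameters (height, orientation, xy_offset) are loaded from stack_blocks_init.npz,
which must first be generated by running stack_block_init.py. Those parameters are meant to
come from actual RLBench demos via extract_demo_waypoints(), as in
dataset_generator_pick_place_cp.py, so that they match grasps that actually succeed.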
"""

import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'custom_tasks'))

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import imageio

from rlbench.environment import Environment
from pyrep.errors import ConfigurationPathError, IKError

from utils import (
    create_obs_config, create_action_mode, get_task_classes,
    compute_control_point_from_params,
    parabola3D,
    HOME_JOINTS,
)

from grasp_and_lift import (
    descend_and_grasp,
    lift,
    descend_and_release,
    capture_frame,
)

# Directory where per-trial MP4 videos are written; main() creates it if missing.
# NOTE(review): machine/cluster-specific absolute path — adjust before running elsewhere.
VIDEO_OUTPUT_DIR = "/scratch4/workspace/placeholder-hdp1/test/"


def _capture_if_enabled(task_env, frames, capture_video):
    """Append one frame from ``capture_frame`` to ``frames`` when capture is enabled."""
    if capture_video:
        frame = capture_frame(task_env)
        if frame is not None:
            frames.append(frame)


def _run_bezier_phase(task_env, start_pos, end_pos, control_point, orientation,
                      prev_joints, n_steps, label, gripper_value, fallback_pos,
                      steps_per_point, capture_video,
                      trace, phase_labels, gripper_states, frames):
    """Track a ``parabola3D`` curve from ``start_pos`` to ``end_pos`` with Jacobian IK.

    For each of the ``n_steps`` waypoints (t evenly spaced in [0, 1]):
    - The arm is reset to the last successful joint configuration, so IK is
      seeded near the previous waypoint.
    - IK is solved and the arm is teleported to the solution with dynamics
      disabled.
    - The gripper is actuated towards ``gripper_value``.
    - Physics is stepped ``steps_per_point`` times.

    The tip position, ``label`` and ``gripper_value`` are appended in place to
    ``trace``, ``phase_labels`` and ``gripper_states``. A frame is also
    appended to ``frames`` when ``capture_video`` is set.

    If IK fails, the step is still recorded, so the phase always contributes
    ``n_steps`` entries. The tip position repeats the previous trace entry,
    or ``fallback_pos`` if the trace is still empty.

    Returns:
        (prev_joints, ik_failures):
        - prev_joints: the last successful joint configuration.
        - ik_failures: the number of waypoints whose IK failed in this phase.
    """
    robot = task_env._scene.robot
    tip = robot.arm.get_tip()
    gripper = robot.gripper
    ik_failures = 0

    for i in range(n_steps):
        t = i / max(n_steps - 1, 1)
        target_pos = parabola3D(start_pos, end_pos, control_point, t)

        try:
            robot.arm.set_joint_positions(prev_joints, disable_dynamics=True)
            joint_positions = robot.arm.solve_ik_via_jacobian(target_pos, euler=orientation)
            robot.arm.set_joint_positions(joint_positions, disable_dynamics=True)
            robot.arm.set_joint_target_velocities([0] * 7)
            gripper.actuate(gripper_value, 0.2)

            for _ in range(steps_per_point):
                task_env._scene.pyrep.step()

            trace.append(tip.get_position().copy())
            prev_joints = list(joint_positions)
        except (IKError, ConfigurationPathError):
            ik_failures += 1
            # Hold the last recorded position so the trace keeps one row per step.
            trace.append(trace[-1].copy() if len(trace) > 0 else fallback_pos.copy())

        phase_labels.append(label)
        gripper_states.append(gripper_value)
        _capture_if_enabled(task_env, frames, capture_video)

    return prev_joints, ik_failures


def generate_full_pick_place_trajectory(task_env, home_pos, pregrasp_pos, grasp_pos,
                                          reach_cp, orientation,
                                          prerelease_pos, release_pos, carry_cp,
                                          object_shape_name='stack_blocks_target0',
                                          reach_steps=64, descend_steps=8,
                                          grasp_steps=8, lift_steps=8,
                                          carry_steps=64, descend_release_steps=8,
                                          release_steps=8,
                                          steps_per_point=5,
                                          capture_video=False):
    """
    Generate full pick-and-place trajectory with physics.

    Phases:
    1. REACH: home → pregrasp (Bezier curve, reach_steps, gripper open) - learned by diffusion
    2. DESCEND + GRASP: pregrasp → grasp, close gripper - hard-coded (descend_and_grasp)
    3. LIFT: grasp → lift position (== pregrasp) - hard-coded (lift)
    4. CARRY: lift → prerelease (Bezier curve, carry_steps, gripper closed) - learned by diffusion
    5. DESCEND_RELEASE + RELEASE: prerelease → release, open gripper - hard-coded
       (descend_and_release)

    Returns dict with:
        trace: np.ndarray of tip positions, one row per recorded step
        phase_labels: list of phase names for each step
        gripper_states: list of gripper states (1.0=open, 0.0=closed)
        ik_failures: int, total number of IK failures across all phases
        object_lifted: bool, whether object was successfully lifted
        object_released: bool, whether object was successfully released at target
        object_z_before: object Z before lifting (as reported by descend_and_grasp)
        object_z_after_lift: object Z after lifting (as reported by lift)
        object_z_after_release: object Z after release (Z of descend_and_release's
            'object_final_pos')
        frames: list of RGB frames (empty unless capture_video=True)
    """
    robot = task_env._scene.robot

    # Lift position is always same as pregrasp (only Z changes during descend/lift)
    lift_pos = pregrasp_pos.copy()

    trace = []
    phase_labels = []
    gripper_states = []
    frames = []
    ik_failures = 0

    # ========================================================================
    # Phase 1: REACH (Bezier curve) - home → pregrasp, gripper open
    # ========================================================================
    prev_joints, n_failed = _run_bezier_phase(
        task_env, home_pos, pregrasp_pos, reach_cp, orientation,
        prev_joints=list(robot.arm.get_joint_positions()),
        n_steps=reach_steps, label="reach", gripper_value=1.0,
        fallback_pos=home_pos,
        steps_per_point=steps_per_point, capture_video=capture_video,
        trace=trace, phase_labels=phase_labels,
        gripper_states=gripper_states, frames=frames,
    )
    ik_failures += n_failed

    # ========================================================================
    # Phase 2 & 3: DESCEND + GRASP (hard-coded from grasp_and_lift.py)
    # ========================================================================
    grasp_result = descend_and_grasp(
        task_env, pregrasp_pos, grasp_pos, orientation,
        object_shape_name=object_shape_name,
        descend_steps=descend_steps, grasp_steps=grasp_steps,
        steps_per_point=steps_per_point, capture_video=capture_video,
        verbose=True
    )
    trace.extend(grasp_result['trace'])
    phase_labels.extend(grasp_result['phase_labels'])
    gripper_states.extend(grasp_result['gripper_states'])
    ik_failures += grasp_result['ik_failures']
    frames.extend(grasp_result['frames'])
    object_z_before_lift = grasp_result['object_z_before']

    # ========================================================================
    # Phase 4: LIFT (hard-coded from grasp_and_lift.py)
    # ========================================================================
    lift_result = lift(
        task_env, grasp_pos, lift_pos, orientation,
        object_shape_name=object_shape_name,
        lift_steps=lift_steps, steps_per_point=steps_per_point,
        capture_video=capture_video, verbose=True,
        prev_joints=grasp_result['prev_joints']
    )
    trace.extend(lift_result['trace'])
    phase_labels.extend(lift_result['phase_labels'])
    gripper_states.extend(lift_result['gripper_states'])
    ik_failures += lift_result['ik_failures']
    frames.extend(lift_result['frames'])
    object_lifted = lift_result['object_lifted']
    object_z_after_lift = lift_result['object_z_after']

    # ========================================================================
    # Phase 5: CARRY (Bezier curve) - lift → prerelease, gripper closed
    # ========================================================================
    print(f"\n  CARRY: lift_pos={lift_pos} -> prerelease={prerelease_pos}")
    prev_joints, n_failed = _run_bezier_phase(
        task_env, lift_pos, prerelease_pos, carry_cp, orientation,
        prev_joints=lift_result['prev_joints'],
        n_steps=carry_steps, label="carry", gripper_value=0.0,
        fallback_pos=lift_pos,
        steps_per_point=steps_per_point, capture_video=capture_video,
        trace=trace, phase_labels=phase_labels,
        gripper_states=gripper_states, frames=frames,
    )
    ik_failures += n_failed

    # ========================================================================
    # Phase 6 & 7: DESCEND_RELEASE + RELEASE (hard-coded from grasp_and_lift.py)
    # ========================================================================
    release_result = descend_and_release(
        task_env, prerelease_pos, release_pos, orientation,
        object_shape_name=object_shape_name,
        descend_steps=descend_release_steps, release_steps=release_steps,
        steps_per_point=steps_per_point, capture_video=capture_video,
        verbose=True, prev_joints=prev_joints
    )
    trace.extend(release_result['trace'])
    phase_labels.extend(release_result['phase_labels'])
    gripper_states.extend(release_result['gripper_states'])
    ik_failures += release_result['ik_failures']
    frames.extend(release_result['frames'])
    object_released = release_result['object_released']
    object_z_after_release = release_result['object_final_pos'][2]

    return {
        'trace': np.array(trace),
        'phase_labels': phase_labels,
        'gripper_states': gripper_states,
        'ik_failures': ik_failures,
        'object_lifted': object_lifted,
        'object_released': object_released,
        'object_z_before': object_z_before_lift,
        'object_z_after_lift': object_z_after_lift,
        'object_z_after_release': object_z_after_release,
        'frames': frames,
    }


def compute_jitter_metric(trace, eps=1e-8):
    """Compute jitter as the sum of direction changes along a trace (higher = more jitter).

    Each consecutive pair of movement directions contributes ``1 - cos(angle)``:
    0 for going straight, 1 for a right-angle turn and 2 for a reversal.

    Segments shorter than ``eps`` (e.g. the repeated points recorded when IK
    fails, or a stalled step) have no direction. They are skipped, so a pause
    is not counted as a direction change. The motion before and after the
    pause is compared directly instead.

    Args:
        trace: sequence/array of points, shape (N, D).
        eps: minimum segment length treated as movement.

    Returns:
        float jitter value; 0.0 when fewer than two moving segments exist.
    """
    points = np.asarray(trace, dtype=float)
    if len(points) < 3:
        return 0.0
    velocities = np.diff(points, axis=0)
    norms = np.linalg.norm(velocities, axis=1)
    moving = norms >= eps
    directions = velocities[moving] / norms[moving][:, None]
    if len(directions) < 2:
        return 0.0
    # Cosine of the angle between consecutive movement directions.
    dots = np.sum(directions[:-1] * directions[1:], axis=1)
    return float(np.sum(1 - dots))


def _print_summary(all_results):
    """Print aggregate metrics and a per-trial table for the collected trials.

    Args:
        all_results: non-empty list of per-trial result dicts built in ``main``.

    Returns:
        Number of trials in which the object was both lifted and released.
    """
    print(f"\n{'='*60}")
    print("SUMMARY")
    print(f"{'='*60}")

    n_lifted = sum(1 for r in all_results if r['object_lifted'])
    n_released = sum(1 for r in all_results if r['object_released'])
    n_success = sum(1 for r in all_results if r['object_lifted'] and r['object_released'])
    n_task_success = sum(1 for r in all_results if r['task_success'])
    avg_jitter = np.mean([r['jitter'] for r in all_results])
    avg_z_rise = np.mean([r['object_z_rise'] for r in all_results])

    print(f"\n  Trajectories tested: {len(all_results)}")
    print(f"  Objects lifted: {n_lifted}/{len(all_results)}")
    print(f"  Objects released: {n_released}/{len(all_results)}")
    print(f"  Full success (lifted + released): {n_success}/{len(all_results)}")
    print(f"  RLBench task success: {n_task_success}/{len(all_results)}")
    print(f"  Average jitter: {avg_jitter:.4f}")
    print(f"  Average Z rise: {avg_z_rise*1000:.1f}mm")

    print(f"\n  {'Trial':>5} | {'Reach':>5} | {'Carry':>5} | {'Jitter':>8} | {'Z Rise':>10} | {'Lifted':>6} | {'Released':>8} | {'TaskOK':>6}")
    print(f"  {'-'*5}-+-{'-'*5}-+-{'-'*5}-+-{'-'*8}-+-{'-'*10}-+-{'-'*6}-+-{'-'*8}-+-{'-'*6}")
    for i, r in enumerate(all_results):
        lifted_str = "YES" if r['object_lifted'] else "NO"
        released_str = "YES" if r['object_released'] else "NO"
        task_ok_str = "YES" if r['task_success'] else "NO"
        print(f"  {i+1:5d} | {r['reach_cp_idx']:5d} | {r['carry_cp_idx']:5d} | {r['jitter']:8.4f} | {r['object_z_rise']*1000:8.1f}mm | {lifted_str:>6} | {released_str:>8} | {task_ok_str:>6}")

    return n_success


def _phase_segment(result, label):
    """Return the rows of ``result['trace']`` whose phase label equals ``label``.

    Selecting by the recorded labels, rather than by nominal step counts,
    keeps the plotted segment correct regardless of how many steps the
    hard-coded phases (descend/grasp/lift) actually appended.
    """
    mask = np.array([lbl == label for lbl in result['phase_labels']], dtype=bool)
    return result['trace'][mask]


def _plot_results(all_results, n_success, reach_steps, carry_steps,
                  out_path='pick_place_full_debug.png'):
    """Plot the full, REACH-only and CARRY-only 3D traces of every trial.

    The figure is saved to ``out_path`` and then shown.
    """
    print(f"\n{'='*60}")
    print("Creating visualization...")
    print(f"{'='*60}")

    fig = plt.figure(figsize=(18, 6))

    # Waypoints are copied from the same init data every trial, so the first
    # trial's values are used for the shared markers.
    r0 = all_results[0]
    colors = plt.cm.tab10(np.linspace(0, 1, len(all_results)))

    def _trial_label(r):
        success_label = "Y" if (r['object_lifted'] and r['object_released']) else "N"
        return f'R{r["reach_cp_idx"]}C{r["carry_cp_idx"]} {success_label}'

    # Plot 1: 3D view - full trajectory (all trials)
    ax1 = fig.add_subplot(131, projection='3d')
    for i, r in enumerate(all_results):
        trace = r['trace']
        ax1.plot(trace[:, 0], trace[:, 1], trace[:, 2], '-', color=colors[i],
                 linewidth=1.5, alpha=0.8, label=_trial_label(r))
        # Mark reach and carry control points
        ax1.scatter(*r["reach_cp"], c=[colors[i]], marker='x', s=30, alpha=0.5)
        ax1.scatter(*r["carry_cp"], c=[colors[i]], marker='+', s=30, alpha=0.5)

    ax1.scatter(*r0["home_pos"], c='black', marker='o', s=150, zorder=10, label='HOME')
    ax1.scatter(*r0["pregrasp_pos"], c='orange', marker='*', s=200, zorder=10, label='Pregrasp')
    ax1.scatter(*r0["grasp_pos"], c='red', marker='v', s=150, zorder=10, label='Grasp')
    ax1.scatter(*r0["lift_pos"], c='purple', marker='^', s=150, zorder=10, label='Lift')
    ax1.scatter(*r0["prerelease_pos"], c='cyan', marker='*', s=200, zorder=10, label='Prerelease')
    ax1.scatter(*r0["release_pos"], c='magenta', marker='v', s=150, zorder=10, label='Release')
    ax1.scatter(*r0["object_pos"], c='green', marker='s', s=100, zorder=10, label='Object')
    ax1.scatter(*r0["target_pos"], c='blue', marker='D', s=100, zorder=10, label='Target')

    ax1.set_xlabel('X (m)')
    ax1.set_ylabel('Y (m)')
    ax1.set_zlabel('Z (m)')
    ax1.set_title(f'Full Trajectory ({n_success}/{len(all_results)} success)')
    ax1.legend(fontsize=5, loc='upper left')

    # Plot 2: 3D view - REACH phase only (start to pre-grasp)
    ax2 = fig.add_subplot(132, projection='3d')
    for i, r in enumerate(all_results):
        reach_trace = _phase_segment(r, "reach")
        ax2.plot(reach_trace[:, 0], reach_trace[:, 1], reach_trace[:, 2], '-', color=colors[i],
                 linewidth=1.5, alpha=0.8, label=_trial_label(r))
        # Mark reach control point
        ax2.scatter(*r["reach_cp"], c=[colors[i]], marker='x', s=50, alpha=0.7)

    ax2.scatter(*r0["home_pos"], c='black', marker='o', s=150, zorder=10, label='HOME (start)')
    ax2.scatter(*r0["pregrasp_pos"], c='orange', marker='*', s=200, zorder=10, label='Pregrasp (end)')
    ax2.scatter(*r0["object_pos"], c='green', marker='s', s=100, zorder=10, label='Object')

    ax2.set_xlabel('X (m)')
    ax2.set_ylabel('Y (m)')
    ax2.set_zlabel('Z (m)')
    ax2.set_title(f'REACH Phase (start -> pre-grasp, {reach_steps} steps)')
    ax2.legend(fontsize=5, loc='upper left')

    # Plot 3: 3D view - CARRY phase only (lift to pre-release)
    ax3 = fig.add_subplot(133, projection='3d')
    for i, r in enumerate(all_results):
        carry_trace = _phase_segment(r, "carry")
        ax3.plot(carry_trace[:, 0], carry_trace[:, 1], carry_trace[:, 2], '-', color=colors[i],
                 linewidth=1.5, alpha=0.8, label=_trial_label(r))
        # Mark carry control point
        ax3.scatter(*r["carry_cp"], c=[colors[i]], marker='+', s=50, alpha=0.7)

    ax3.scatter(*r0["lift_pos"], c='purple', marker='^', s=150, zorder=10, label='Lift (start)')
    ax3.scatter(*r0["prerelease_pos"], c='cyan', marker='*', s=200, zorder=10, label='Prerelease (end)')
    ax3.scatter(*r0["object_pos"], c='green', marker='s', s=100, zorder=10, label='Object')
    ax3.scatter(*r0["target_pos"], c='blue', marker='D', s=100, zorder=10, label='Target')

    ax3.set_xlabel('X (m)')
    ax3.set_ylabel('Y (m)')
    ax3.set_zlabel('Z (m)')
    ax3.set_title(f'CARRY Phase (lift -> pre-release, {carry_steps} steps)')
    ax3.legend(fontsize=5, loc='upper left')

    plt.tight_layout()
    plt.savefig(out_path, dpi=150)
    print(f"\nPlot saved to: {out_path}")
    plt.show()


def main():
    """Run the full pick-and-place debug sweep over REACH x CARRY control points.

    Steps:
    1. Load precomputed waypoints, grasp parameters and valid control-point
       indices from ``stack_blocks_init.npz`` (generated by stack_block_init.py).
    2. Run one physics rollout per selected (reach, carry) control-point pair,
       saving a video per trial.
    3. Print a summary table and write ``pick_place_full_debug.png``.

    The simulator is always shut down, even if a trial raises.
    """
    from pyrep.objects.shape import Shape
    from pyrep.objects.proximity_sensor import ProximitySensor
    from rlbench.backend.conditions import DetectedSeveralCondition, NothingGrasped

    print("=" * 60)
    print("FULL PICK-AND-PLACE TRAJECTORY TEST")
    print("=" * 60)

    task_name = "stack_blocks"

    # Phase configuration
    REACH_STEPS = 64
    DESCEND_STEPS = 8
    GRASP_STEPS = 8
    LIFT_STEPS = 8
    CARRY_STEPS = 64
    DESCEND_RELEASE_STEPS = 8
    RELEASE_STEPS = 8
    TOTAL_STEPS = (REACH_STEPS + DESCEND_STEPS + GRASP_STEPS + LIFT_STEPS +
                   CARRY_STEPS + DESCEND_RELEASE_STEPS + RELEASE_STEPS)
    STEPS_PER_POINT = 5

    print(f"\nPhase configuration:")
    print(f"  REACH:           {REACH_STEPS} steps (Bezier curve)")
    print(f"  DESCEND:         {DESCEND_STEPS} steps (linear Z)")
    print(f"  GRASP:           {GRASP_STEPS} steps (close gripper)")
    print(f"  LIFT:            {LIFT_STEPS} steps (linear Z)")
    print(f"  CARRY:           {CARRY_STEPS} steps (Bezier curve)")
    print(f"  DESCEND_RELEASE: {DESCEND_RELEASE_STEPS} steps (linear Z)")
    print(f"  RELEASE:         {RELEASE_STEPS} steps (open gripper)")
    print(f"  TOTAL:           {TOTAL_STEPS} steps")

    # ========================================================================
    # Check for the pre-computed init file BEFORE launching the simulator,
    # so a missing file fails fast without starting CoppeliaSim.
    # Run stack_block_init.py first to generate this file!
    # ========================================================================
    INIT_FILE = os.path.join(os.path.dirname(__file__), "stack_blocks_init.npz")

    if not os.path.exists(INIT_FILE):
        print(f"\n  ERROR: Init file not found: {INIT_FILE}")
        print(f"  Please run stack_block_init.py first to generate it.")
        return

    # Enable video capture
    CAPTURE_VIDEO = True
    obs_config = create_obs_config(save_video=CAPTURE_VIDEO)
    action_mode = create_action_mode("abs")

    # Create video output directory
    os.makedirs(VIDEO_OUTPUT_DIR, exist_ok=True)
    print(f"\nVideo output directory: {VIDEO_OUTPUT_DIR}")

    # ========================================================================
    # LOAD PRE-COMPUTED INITIALIZATION DATA
    # ========================================================================
    print(f"\n{'='*60}")
    print("LOADING PRE-COMPUTED INITIALIZATION DATA")
    print(f"{'='*60}")
    print(f"  Loading from: {INIT_FILE}")

    # The context manager closes the npz file handle. Each array is read
    # fully into memory when accessed, so the values remain valid after the
    # file is closed.
    with np.load(INIT_FILE) as init_data:
        # Load positions
        home_pos = init_data['home_pos']
        FIXED_OBJECT_POS = init_data['object_pos']
        TARGET_POS = init_data['target_pos']
        pregrasp_pos = init_data['pregrasp_pos']
        grasp_pos = init_data['grasp_pos']
        lift_pos = init_data['lift_pos']
        prerelease_pos = init_data['prerelease_pos']
        release_pos = init_data['release_pos']

        # Load grasp parameters
        GRASP_ORIENTATION = init_data['grasp_orientation']
        GRASP_XY_OFFSET = init_data['grasp_xy_offset']
        PREGRASP_HEIGHT = float(init_data['pregrasp_height'])
        GRASP_HEIGHT = float(init_data['grasp_height'])
        PREPLACE_HEIGHT = float(init_data['preplace_height'])
        PLACE_HEIGHT = float(init_data['place_height'])

        # Load control point parameters
        canonical_params = [tuple(p) for p in init_data['canonical_params']]
        control_point_radius = float(init_data['control_point_radius'])

        # Load valid control point indices
        reach_valid_indices = list(init_data['reach_valid_indices'])
        carry_final_valid_indices = list(init_data['carry_valid_indices'])

    print(f"\n  Loaded positions:")
    print(f"    HOME pos:       [{home_pos[0]:.4f}, {home_pos[1]:.4f}, {home_pos[2]:.4f}]")
    print(f"    Object pos:     [{FIXED_OBJECT_POS[0]:.4f}, {FIXED_OBJECT_POS[1]:.4f}, {FIXED_OBJECT_POS[2]:.4f}]")
    print(f"    Target pos:     [{TARGET_POS[0]:.4f}, {TARGET_POS[1]:.4f}, {TARGET_POS[2]:.4f}]")
    print(f"    Pregrasp pos:   [{pregrasp_pos[0]:.4f}, {pregrasp_pos[1]:.4f}, {pregrasp_pos[2]:.4f}]")
    print(f"    Grasp pos:      [{grasp_pos[0]:.4f}, {grasp_pos[1]:.4f}, {grasp_pos[2]:.4f}]")
    print(f"    Lift pos:       [{lift_pos[0]:.4f}, {lift_pos[1]:.4f}, {lift_pos[2]:.4f}]")
    print(f"    Prerelease pos: [{prerelease_pos[0]:.4f}, {prerelease_pos[1]:.4f}, {prerelease_pos[2]:.4f}]")
    print(f"    Release pos:    [{release_pos[0]:.4f}, {release_pos[1]:.4f}, {release_pos[2]:.4f}]")

    print(f"\n  Loaded grasp parameters:")
    print(f"    GRASP_ORIENTATION: [{GRASP_ORIENTATION[0]:.4f}, {GRASP_ORIENTATION[1]:.4f}, {GRASP_ORIENTATION[2]:.4f}]")
    print(f"    GRASP_XY_OFFSET:   [{GRASP_XY_OFFSET[0]:.4f}, {GRASP_XY_OFFSET[1]:.4f}]")
    print(f"    PREGRASP_HEIGHT:   {PREGRASP_HEIGHT:.4f}m")
    print(f"    GRASP_HEIGHT:      {GRASP_HEIGHT:.4f}m")
    print(f"    PREPLACE_HEIGHT:   {PREPLACE_HEIGHT:.4f}m")
    print(f"    PLACE_HEIGHT:      {PLACE_HEIGHT:.4f}m")

    print(f"\n  Loaded control points:")
    print(f"    Total canonical CPs: {len(canonical_params)}")
    print(f"    Valid REACH CPs:     {len(reach_valid_indices)}")
    print(f"    Valid CARRY CPs:     {len(carry_final_valid_indices)}")

    # ========================================================================
    # SELECT CONTROL POINTS FOR TESTING (evenly spaced over the valid lists)
    # ========================================================================
    num_reach = min(5, len(reach_valid_indices))
    num_carry = min(2, len(carry_final_valid_indices))
    reach_test_indices = [reach_valid_indices[i] for i in np.linspace(0, len(reach_valid_indices)-1, num_reach, dtype=int)]
    carry_test_indices = [carry_final_valid_indices[i] for i in np.linspace(0, len(carry_final_valid_indices)-1, num_carry, dtype=int)]

    print(f"\n{'='*60}")
    print(f"TESTING FULL PICK-AND-PLACE TRAJECTORIES")
    print(f"  REACH CPs: {reach_test_indices}")
    print(f"  CARRY CPs: {carry_test_indices}")
    print(f"  Total combinations: {len(reach_test_indices) * len(carry_test_indices)}")
    print(f"{'='*60}")

    all_results = []

    rlbench_env = Environment(
        action_mode=action_mode, obs_config=obs_config, headless=True
    )
    rlbench_env.launch()
    try:
        tasks = get_task_classes([task_name])
        t_cls = tasks[0]
        task_env = rlbench_env.get_task(t_cls)
        task_env.set_variation(0)

        for trial_idx, (reach_cp_idx, carry_cp_idx) in enumerate(
            [(r, c) for r in reach_test_indices for c in carry_test_indices]
        ):
            print(f"\n--- Trial {trial_idx+1} (reach_cp={reach_cp_idx}, carry_cp={carry_cp_idx}) ---")

            np.random.seed(42)
            task_env.reset()

            robot = task_env._scene.robot
            gripper = robot.gripper

            robot.arm.set_joint_positions(HOME_JOINTS, disable_dynamics=True)
            robot.arm.set_joint_target_velocities([0] * 7)
            gripper.actuate(1.0, 0.2)

            for _ in range(20):
                task_env._scene.pyrep.step()

            # Use loaded home_pos (don't recompute)
            home_pos_trial = home_pos.copy()

            # FIX OBJECT POSITION - ensure block is at consistent location (use loaded position)
            target_block = Shape('stack_blocks_target0')
            target_block.set_position(FIXED_OBJECT_POS)

            # Move the other blocks out of the workspace for a cleaner visualization.
            # Best effort: a named block may not exist in this scene/variation.
            if task_name == "stack_blocks":
                for i in range(1, 4):
                    try:
                        Shape(f'stack_blocks_target{i}').set_position([10, 10, 0])
                    except Exception:
                        pass
                for i in range(4):
                    try:
                        Shape(f'stack_blocks_distractor{i}').set_position([10, 10 + i*0.1, 0])
                    except Exception:
                        pass

            for _ in range(10):  # Let physics settle
                task_env._scene.pyrep.step()

            object_pos = target_block.get_position().copy()

            # IMPORTANT: Override the success condition AFTER reset().
            # reset() calls init_episode(), which re-registers the original
            # success conditions (requiring 2+ blocks).
            success_detector = None
            if task_name == "stack_blocks":
                success_detector = ProximitySensor('stack_blocks_success')
                # Only require 1 block to be detected (instead of default 2-4)
                task_env._task._success_conditions = [
                    DetectedSeveralCondition([target_block], success_detector, 1),
                    NothingGrasped(robot.gripper)
                ]

            # Use pre-loaded waypoint positions (don't recompute)
            pregrasp_pos_trial = pregrasp_pos.copy()
            grasp_pos_trial = grasp_pos.copy()
            lift_pos_trial = lift_pos.copy()
            prerelease_pos_trial = prerelease_pos.copy()
            release_pos_trial = release_pos.copy()

            print(f"  HOME:       [{home_pos_trial[0]:.4f}, {home_pos_trial[1]:.4f}, {home_pos_trial[2]:.4f}]")
            print(f"  Object:     [{object_pos[0]:.4f}, {object_pos[1]:.4f}, {object_pos[2]:.4f}] (fixed)")
            print(f"  Pregrasp:   [{pregrasp_pos_trial[0]:.4f}, {pregrasp_pos_trial[1]:.4f}, {pregrasp_pos_trial[2]:.4f}]")
            print(f"  Grasp:      [{grasp_pos_trial[0]:.4f}, {grasp_pos_trial[1]:.4f}, {grasp_pos_trial[2]:.4f}]")
            print(f"  Lift:       [{lift_pos_trial[0]:.4f}, {lift_pos_trial[1]:.4f}, {lift_pos_trial[2]:.4f}]")
            print(f"  Target:     [{TARGET_POS[0]:.4f}, {TARGET_POS[1]:.4f}, {TARGET_POS[2]:.4f}]")
            print(f"  Prerelease: [{prerelease_pos_trial[0]:.4f}, {prerelease_pos_trial[1]:.4f}, {prerelease_pos_trial[2]:.4f}]")
            print(f"  Release:    [{release_pos_trial[0]:.4f}, {release_pos_trial[1]:.4f}, {release_pos_trial[2]:.4f}]")

            # Compute reach control point
            angle, dist_frac, pos_frac = canonical_params[reach_cp_idx]
            reach_cp = compute_control_point_from_params(
                home_pos_trial, pregrasp_pos_trial, control_point_radius, angle, dist_frac, pos_frac
            )

            # Compute carry control point
            angle, dist_frac, pos_frac = canonical_params[carry_cp_idx]
            carry_cp = compute_control_point_from_params(
                lift_pos_trial, prerelease_pos_trial, control_point_radius, angle, dist_frac, pos_frac
            )

            # Generate full pick-and-place trajectory
            result = generate_full_pick_place_trajectory(
                task_env, home_pos_trial, pregrasp_pos_trial, grasp_pos_trial,
                reach_cp, GRASP_ORIENTATION,
                prerelease_pos_trial, release_pos_trial, carry_cp,
                reach_steps=REACH_STEPS, descend_steps=DESCEND_STEPS,
                grasp_steps=GRASP_STEPS, lift_steps=LIFT_STEPS,
                carry_steps=CARRY_STEPS, descend_release_steps=DESCEND_RELEASE_STEPS,
                release_steps=RELEASE_STEPS,
                steps_per_point=STEPS_PER_POINT,
                capture_video=CAPTURE_VIDEO
            )

            trace = result['trace']
            jitter = compute_jitter_metric(trace)
            object_lifted = result['object_lifted']
            object_released = result['object_released']
            z_rise = result['object_z_after_lift'] - result['object_z_before']

            # Check RLBench task success condition
            task_success, task_terminate = task_env._task.success()

            # Debug: Check proximity sensor detection
            if success_detector is not None:
                block_detected = success_detector.is_detected(target_block)
                grasped_objects = robot.gripper.get_grasped_objects()
                print(f"  DEBUG: Block detected by sensor: {block_detected}, Gripper holding: {len(grasped_objects)} objects")

            print(f"  Total steps: {len(trace)}")
            print(f"  Jitter: {jitter:.4f}")
            print(f"  IK failures: {result['ik_failures']}")
            print(f"  Object Z rise: {z_rise*1000:.1f}mm")
            print(f"  Object lifted: {'YES' if object_lifted else 'NO'}")
            print(f"  Object released: {'YES' if object_released else 'NO'}")
            print(f"  RLBench task success: {'YES' if task_success else 'NO'}")

            # Save video
            if CAPTURE_VIDEO and len(result['frames']) > 0:
                video_path = os.path.join(VIDEO_OUTPUT_DIR, f"trial_{trial_idx+1}_R{reach_cp_idx}_C{carry_cp_idx}.mp4")
                imageio.mimwrite(video_path, result['frames'], fps=30, codec='libx264', quality=8)
                print(f"  Video saved: {video_path} ({len(result['frames'])} frames)")

            all_results.append({
                "reach_cp_idx": reach_cp_idx,
                "carry_cp_idx": carry_cp_idx,
                "trace": trace,
                "phase_labels": result['phase_labels'],
                "gripper_states": result['gripper_states'],
                "jitter": jitter,
                "ik_failures": result['ik_failures'],
                "object_lifted": object_lifted,
                "object_released": object_released,
                "task_success": task_success,
                "object_z_rise": z_rise,
                "home_pos": home_pos_trial,
                "pregrasp_pos": pregrasp_pos_trial,
                "grasp_pos": grasp_pos_trial,
                "lift_pos": lift_pos_trial,
                "prerelease_pos": prerelease_pos_trial,
                "release_pos": release_pos_trial,
                "target_pos": TARGET_POS,
                "object_pos": object_pos,
                "reach_cp": reach_cp,
                "carry_cp": carry_cp,
            })
    finally:
        # Always release the simulator, even if a trial raises. The summary
        # and plot below do not need it, so a blocking plt.show() cannot keep
        # CoppeliaSim alive.
        rlbench_env.shutdown()

    # ========================================================================
    # SUMMARY + VISUALIZATION
    # ========================================================================
    if all_results:
        n_success = _print_summary(all_results)
        _plot_results(all_results, n_success, REACH_STEPS, CARRY_STEPS)
    else:
        print("\n  No valid control-point combinations were tested; skipping summary and plot.")

    print("\n" + "=" * 60)
    print("TEST COMPLETE")
    print("=" * 60)


# Run the debug sweep only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
