#!/usr/bin/env python
"""
Visualize end-effector (EE) trajectories against a wall obstacle for the pick-and-place (stack_blocks) task.

This script:
1. Loads precomputed init data from stack_blocks_init.npz
2. Generates REACH and CARRY trajectories for all valid control point modes
3. Checks wall collision for each trajectory
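4. Saves a visualization for each phase (REACH and CARRY) with:
   - First front-camera frame as background
   - All EE trajectories overlaid
   - Wall plane visualization
   - Blue = success (no wall collision), Red = failed (wall collision)
   It also saves a 3D matplotlib plot per phase and the wall config / per-mode results as .npy files.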

Usage:
  # Default wall position
  python visualize_wall_trajectories.py

  # Custom wall position
  python visualize_wall_trajectories.py --wall_y=0.05

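  # With a wall opening (also pass --wall_opening_min_z / --wall_opening_max_z;
  # they default to 0.0, which gives a zero-height opening)
  python visualize_wall_trajectories.py --wall_y=0.05 --wall_opening_min_x=0.2 --wall_opening_max_x=0.35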

  # Use predefined wall style
  python visualize_wall_trajectories.py --style=1
"""
import sys
import os
import pickle

sys.path.insert(0, os.path.dirname(__file__))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'custom_tasks'))

import numpy as np

# Use non-interactive backend for matplotlib (must be before pyplot import)
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

from pyrep.errors import ConfigurationPathError, IKError
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from rlbench import ObservationConfig
from rlbench.action_modes.action_mode import MoveArmThenGripper
from rlbench.action_modes.arm_action_modes import JointPosition
from rlbench.action_modes.gripper_action_modes import Discrete
from rlbench.backend.utils import task_file_to_task_class
from rlbench.environment import Environment

from absl import app, flags

from utils import (
    create_obs_config, create_action_mode, get_task_classes,
    generate_canonical_control_point_params,
    compute_control_point_from_params,
    parabola3D,
    HOME_JOINTS,
)

from wall_collision import (
    DEFAULT_WALL_CONFIG,
    DEFAULT_OPENING_CONFIG,
    WALL_STYLES,
    create_wall,
    check_wall_collision,
    check_ee_trajectory_wall_collision,
    WallCollisionTracker,
)


FLAGS = flags.FLAGS

# Default save path: RLBench_pick_place/block_setting (the parent of this
# script's directory). abspath() matters on Python < 3.9, where __file__ of a
# script launched as `python visualize_wall_trajectories.py` is a bare relative
# name; dirname(dirname(...)) then collapses to "" and the outputs would land in
# ./block_setting relative to the current working directory instead.
DEFAULT_SAVE_PATH = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),  # RLBench_pick_place/
    "block_setting"
)
flags.DEFINE_string("save_path", DEFAULT_SAVE_PATH, "Where to save visualization.")
# NOTE(review): num_modes and control_point_radius are not read by main() in this
# file (control-point parameters come from stack_blocks_init.npz) — confirm they
# are still needed.
flags.DEFINE_integer("num_modes", 8, "Number of control point modes.")
flags.DEFINE_float("control_point_radius", 0.05, "Control point radius.")
# Only prints extra wall-tuning hints at the end of main(); there is no interactive UI.
flags.DEFINE_bool("interactive", False, "Interactive mode for wall tuning")
# main() aborts when this is False; only the precomputed-init path is implemented.
flags.DEFINE_bool("use_precomputed_init", True, "Use precomputed init data")

# Wall style (predefined configs: 1, 2, 3)
flags.DEFINE_integer("style", 0, "Predefined wall style (1, 2, or 3). 0 = use individual flags below.")

# Wall configuration (used when style=0)
flags.DEFINE_float("wall_y", DEFAULT_WALL_CONFIG["wall_y"], "Wall Y position")
flags.DEFINE_float("wall_min_x", DEFAULT_WALL_CONFIG["wall_min_x"], "Wall min X")
flags.DEFINE_float("wall_max_x", DEFAULT_WALL_CONFIG["wall_max_x"], "Wall max X")
flags.DEFINE_float("wall_min_z", DEFAULT_WALL_CONFIG["wall_min_z"], "Wall min Z")
flags.DEFINE_float("wall_max_z", DEFAULT_WALL_CONFIG["wall_max_z"], "Wall max Z")

# Wall opening (used when style=0). main() builds an opening from these flags when
# either X bound is non-zero; otherwise --enable_opening selects DEFAULT_OPENING_CONFIG.
flags.DEFINE_bool("enable_opening", False, "Enable opening (disabled by default)")
flags.DEFINE_float("wall_opening_min_x", 0.0, "Opening min X (0 to disable)")
flags.DEFINE_float("wall_opening_max_x", 0.0, "Opening max X")
flags.DEFINE_float("wall_opening_min_z", 0.0, "Opening min Z")
flags.DEFINE_float("wall_opening_max_z", 0.0, "Opening max Z")

# Collision check mode
# NOTE(review): not read by main() in this file; collisions are checked on the
# EE positions only (check_ee_trajectory_wall_collision).
flags.DEFINE_bool("check_full_arm", False, "Check full robot arm for collision (default: only check EE)")

# Phase steps
REACH_STEPS = 64      # samples along the REACH curve (HOME -> pregrasp)
CARRY_STEPS = 64      # samples along the CARRY curve (lift -> prerelease)
STEPS_PER_POINT = 5   # simulator steps after applying each IK solution


def generate_reach_trajectory(task_env, home_pos, pregrasp_pos, orientation, cp_idx, canonical_params, control_point_radius):
    """
    Sample the REACH end-effector path (HOME -> pregrasp) by IK-tracking a parabola.

    The arm is first snapped to HOME_JOINTS, and the measured tip position there
    is the start of the curve. The ``home_pos`` argument is accepted but not read.

    For each of REACH_STEPS evenly spaced t in [0, 1]:
      - the ``parabola3D`` target is solved with Jacobian IK, starting from the
        previous successful joint solution;
      - the simulator is stepped STEPS_PER_POINT times;
      - the resulting tip position is recorded.

    If IKError/ConfigurationPathError is raised, the previous sample (or the HOME
    tip position if none exists yet) is repeated. The output therefore always has
    REACH_STEPS entries. The simulated arm is moved as a side effect.

    Returns:
        (positions, cp): ``np.array`` of recorded tip positions and the control
        point returned by ``compute_control_point_from_params``.
    """
    scene = task_env._scene
    arm = scene.robot.arm
    tip = arm.get_tip()

    def settle(n_steps):
        for _ in range(n_steps):
            scene.pyrep.step()

    # Snap to HOME and measure where the tip really is.
    arm.set_joint_positions(HOME_JOINTS, disable_dynamics=True)
    arm.set_joint_target_velocities([0] * 7)
    settle(10)
    start = np.array(tip.get_position())

    angle, dist_frac, pos_frac = canonical_params[cp_idx]
    cp = compute_control_point_from_params(
        start, pregrasp_pos, control_point_radius, angle, dist_frac, pos_frac
    )

    samples = []
    seed = list(arm.get_joint_positions())
    denom = max(REACH_STEPS - 1, 1)

    for step in range(REACH_STEPS):
        goal = parabola3D(start, pregrasp_pos, cp, step / denom)
        try:
            # Put the arm back at the last good configuration so IK starts from there.
            arm.set_joint_positions(seed, disable_dynamics=True)
            solution = arm.solve_ik_via_jacobian(goal, euler=orientation)
            arm.set_joint_positions(solution, disable_dynamics=True)
            arm.set_joint_target_velocities([0] * 7)
            settle(STEPS_PER_POINT)
            samples.append(tip.get_position().copy())
            seed = list(solution)
        except (IKError, ConfigurationPathError):
            # Hold the previous sample so the trajectory keeps a fixed length.
            samples.append((samples[-1] if samples else start).copy())

    return np.array(samples), cp


def generate_carry_trajectory(task_env, lift_pos, prerelease_pos, orientation, cp_idx, canonical_params, control_point_radius):
    """
    Sample the CARRY end-effector path (lift -> prerelease) by IK-tracking a parabola.

    The arm is first moved to ``lift_pos`` with Jacobian IK, solved from its
    current configuration. If that IK raises IKError/ConfigurationPathError, the
    function returns ``(None, None)``. Otherwise the measured tip position at the
    lift pose is the start of the curve.

    For each of CARRY_STEPS evenly spaced t in [0, 1]:
      - the ``parabola3D`` target is solved with IK, starting from the previous
        successful joint solution;
      - the simulator is stepped STEPS_PER_POINT times;
      - the tip position is recorded.

    If IK fails for a sample, the previous sample (or the lift tip position) is
    repeated. The simulated arm is moved as a side effect.

    Returns:
        (positions, cp): ``np.array`` of recorded tip positions and the control
        point, or ``(None, None)`` if the lift pose could not be reached.
    """
    scene = task_env._scene
    arm = scene.robot.arm
    tip = arm.get_tip()

    def settle(n_steps):
        for _ in range(n_steps):
            scene.pyrep.step()

    # Reach the lift pose first; without it there is no CARRY start point.
    try:
        lift_joints = arm.solve_ik_via_jacobian(lift_pos, euler=orientation)
        arm.set_joint_positions(lift_joints, disable_dynamics=True)
        arm.set_joint_target_velocities([0] * 7)
        settle(10)
    except (IKError, ConfigurationPathError):
        return None, None

    start = np.array(tip.get_position())

    angle, dist_frac, pos_frac = canonical_params[cp_idx]
    cp = compute_control_point_from_params(
        start, prerelease_pos, control_point_radius, angle, dist_frac, pos_frac
    )

    samples = []
    seed = list(arm.get_joint_positions())
    denom = max(CARRY_STEPS - 1, 1)

    for step in range(CARRY_STEPS):
        goal = parabola3D(start, prerelease_pos, cp, step / denom)
        try:
            # Put the arm back at the last good configuration so IK starts from there.
            arm.set_joint_positions(seed, disable_dynamics=True)
            solution = arm.solve_ik_via_jacobian(goal, euler=orientation)
            arm.set_joint_positions(solution, disable_dynamics=True)
            arm.set_joint_target_velocities([0] * 7)
            settle(STEPS_PER_POINT)
            samples.append(tip.get_position().copy())
            seed = list(solution)
        except (IKError, ConfigurationPathError):
            # Hold the previous sample so the trajectory keeps a fixed length.
            samples.append((samples[-1] if samples else start).copy())

    return np.array(samples), cp


def project_world_to_image(points_3d, camera, image_size):
    """Project 3D world points to integer pixel coordinates of ``camera``.

    Uses a pinhole model built from the camera pose (``get_position()`` and the
    rotation block of ``get_matrix()``) and its perspective angle
    (``get_perspective_angle()``, in degrees; 60 deg is used if it cannot be read).
    As in CoppeliaSim, the perspective angle spans the larger image dimension, so
    the focal length is ``max(H, W) / (2 * tan(fov / 2))``.

    Args:
        points_3d: Iterable of (x, y, z) world points.
        camera: Object exposing ``get_position()``, ``get_matrix()`` (assumed to be
            the 4x4 camera-to-world pose — confirm for older PyRep versions) and,
            optionally, ``get_perspective_angle()``.
        image_size: ``(height, width)`` of the target image, e.g. ``frame.shape[:2]``.

    Returns:
        One entry per input point: ``(u, v)`` rounded to the nearest pixel, or
        ``None`` when the point is at or behind the camera (z_cam <= 0.01).
        Points in front of the camera but outside the image are still returned.
    """
    cam_pos = np.asarray(camera.get_position(), dtype=float)
    cam_rot = np.asarray(camera.get_matrix())[:3, :3]

    try:
        fov = np.radians(camera.get_perspective_angle())
    except Exception:
        # Sensor without a readable perspective angle: fall back to 60 degrees.
        fov = np.radians(60.0)

    height, width = image_size[0], image_size[1]
    f = max(height, width) / (2.0 * np.tan(fov / 2.0))
    cx, cy = width / 2.0, height / 2.0

    projected = []
    for p in points_3d:
        # World -> camera frame (transpose of the camera's rotation).
        x_cam, y_cam, z_cam = cam_rot.T @ (np.asarray(p, dtype=float) - cam_pos)
        if z_cam > 0.01:
            u = cx - f * x_cam / z_cam
            v = cy - f * y_cam / z_cam
            # Round rather than truncate: int() truncates toward zero, which biases
            # pixels and maps slightly-negative coordinates onto the image border.
            projected.append((int(round(u)), int(round(v))))
        else:
            projected.append(None)

    return projected


def project_wall_to_image(wall_config, camera, image_size):
    """Project the four corners of the wall rectangle into image pixels.

    Corners lie on the plane Y = wall_config["wall_y"]. They are ordered
    (min_x, min_z), (max_x, min_z), (max_x, max_z), (min_x, max_z), so consecutive
    entries trace the rectangle outline. Returns the list from
    ``project_world_to_image``: one pixel tuple or ``None`` per corner.
    """
    y = wall_config["wall_y"]
    lo_x, hi_x = wall_config["wall_min_x"], wall_config["wall_max_x"]
    lo_z, hi_z = wall_config["wall_min_z"], wall_config["wall_max_z"]
    outline = [
        [x, y, z]
        for x, z in ((lo_x, lo_z), (hi_x, lo_z), (hi_x, hi_z), (lo_x, hi_z))
    ]
    return project_world_to_image(outline, camera, image_size)


def visualize_on_frame(frame, all_results, wall_config, camera, phase_name="reach"):
    """
    Overlay all trajectories of one phase and the wall on a single camera frame.

    ``frame`` is the RGB front-camera image (``obs.front_rgb``); ``main`` converts
    the result RGB->BGR only when writing it with ``cv2.imwrite``. All colors here
    are therefore (R, G, B) tuples:
      - blue: no wall collision;
      - red: wall collision, drawn only up to the collision step;
      - green dot: end of a successful trajectory;
      - translucent red quad: the wall, with its opening (if any) left unfilled.

    Args:
        frame: HxWx3 RGB image; it is copied, not modified.
        all_results: List of result dicts with "positions" (sequence of world
            points), "collision", "collision_idx" and "mode_idx".
        wall_config: Wall dict with wall_y, wall_min/max_x, wall_min/max_z and an
            optional "opening" dict (min_x, max_x, min_z, max_z).
        camera: Camera object forwarded to ``project_world_to_image``.
        phase_name: Phase label shown in the legend text.

    Returns:
        The annotated RGB image.
    """
    import cv2

    # (R, G, B) colors -- the frame is RGB, see docstring.
    success_color = (0, 100, 255)      # blue
    collision_color = (255, 0, 0)      # red
    wall_fill_color = (255, 100, 100)  # light red
    wall_edge_color = (200, 50, 50)    # dark red
    end_fill_color = (0, 255, 0)       # green
    end_edge_color = (0, 0, 0)         # black

    image_size = frame.shape[:2]
    height, width = image_size[0], image_size[1]
    frame_overlay = frame.copy()

    def in_image(pt):
        return 0 <= pt[0] < width and 0 <= pt[1] < height

    # ---- Wall ----
    wall_y = wall_config["wall_y"]
    min_x, max_x = wall_config["wall_min_x"], wall_config["wall_max_x"]
    min_z, max_z = wall_config["wall_min_z"], wall_config["wall_max_z"]
    full_wall = [
        [min_x, wall_y, min_z],
        [max_x, wall_y, min_z],
        [max_x, wall_y, max_z],
        [min_x, wall_y, max_z],
    ]

    wall_pieces = [full_wall]
    opening = wall_config.get("opening", None)
    if opening is not None:
        # Clamp the opening into the wall rectangle. Otherwise the left/right pieces,
        # which span the opening's Z range, are drawn past the wall's top or bottom
        # whenever the opening reaches beyond it.
        op_min_x = min(max(opening["min_x"], min_x), max_x)
        op_max_x = min(max(opening["max_x"], min_x), max_x)
        op_min_z = min(max(opening["min_z"], min_z), max_z)
        op_max_z = min(max(opening["max_z"], min_z), max_z)
        # An opening with no area inside the wall leaves the wall solid.
        if op_min_x < op_max_x and op_min_z < op_max_z:
            wall_pieces = []
            if min_z < op_min_z:  # below the opening (full width)
                wall_pieces.append([
                    [min_x, wall_y, min_z],
                    [max_x, wall_y, min_z],
                    [max_x, wall_y, op_min_z],
                    [min_x, wall_y, op_min_z],
                ])
            if op_max_z < max_z:  # above the opening (full width)
                wall_pieces.append([
                    [min_x, wall_y, op_max_z],
                    [max_x, wall_y, op_max_z],
                    [max_x, wall_y, max_z],
                    [min_x, wall_y, max_z],
                ])
            if min_x < op_min_x:  # left of the opening
                wall_pieces.append([
                    [min_x, wall_y, op_min_z],
                    [op_min_x, wall_y, op_min_z],
                    [op_min_x, wall_y, op_max_z],
                    [min_x, wall_y, op_max_z],
                ])
            if op_max_x < max_x:  # right of the opening
                wall_pieces.append([
                    [op_max_x, wall_y, op_min_z],
                    [max_x, wall_y, op_min_z],
                    [max_x, wall_y, op_max_z],
                    [op_max_x, wall_y, op_max_z],
                ])

    # Fill the pieces on a copy, then alpha-blend (30% wall, 70% frame).
    overlay = frame_overlay.copy()
    filled_any = False
    for piece_3d in wall_pieces:
        piece_2d = project_world_to_image(piece_3d, camera, image_size)
        # Only pieces whose corners are all in front of the camera are drawn.
        if all(c is not None for c in piece_2d):
            cv2.fillPoly(overlay, [np.array(piece_2d, dtype=np.int32)], wall_fill_color)
            filled_any = True
    if filled_any:
        cv2.addWeighted(overlay, 0.3, frame_overlay, 0.7, 0, frame_overlay)

    # Outline of the full wall rectangle.
    wall_corners = project_world_to_image(full_wall, camera, image_size)
    if all(c is not None for c in wall_corners):
        cv2.polylines(frame_overlay, [np.array(wall_corners, dtype=np.int32)], True,
                      wall_edge_color, 2, cv2.LINE_AA)

    # ---- Trajectories ----
    for result in all_results:
        positions = result["positions"]
        collision = result["collision"]
        collision_idx = result["collision_idx"]
        mode_idx = result["mode_idx"]
        if len(positions) == 0:
            continue

        projected = project_world_to_image(positions, camera, image_size)
        color = collision_color if collision else success_color

        # Colliding trajectories are drawn only up to (and including) the collision step.
        if collision and collision_idx is not None:
            draw_end_idx = min(collision_idx + 1, len(projected))
        else:
            draw_end_idx = len(projected)

        for i in range(1, draw_end_idx):
            p1, p2 = projected[i - 1], projected[i]
            # Skip segments with an endpoint behind the camera or outside the image.
            if p1 is None or p2 is None or not (in_image(p1) and in_image(p2)):
                continue
            cv2.line(frame_overlay, p1, p2, color, thickness=1, lineType=cv2.LINE_AA)

        # Start marker
        start = projected[0]
        if start is not None:
            cv2.circle(frame_overlay, start, 3, color, -1)

        # End marker: green dot with a black rim, successful trajectories only.
        end_idx = draw_end_idx - 1
        if not collision and 0 <= end_idx < len(projected):
            end = projected[end_idx]
            if end is not None and in_image(end):
                cv2.circle(frame_overlay, end, 4, end_fill_color, -1)
                cv2.circle(frame_overlay, end, 4, end_edge_color, 1)

        # Mode label next to the start point
        if start is not None:
            cv2.putText(frame_overlay, f"M{mode_idx}", (start[0] + 5, start[1] - 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.3, color, 1)

    # ---- Legend ----
    n_total = len(all_results)
    n_success = sum(1 for r in all_results if not r["collision"])
    n_fail = sum(1 for r in all_results if r["collision"])

    cv2.putText(frame_overlay, f"{phase_name.upper()} - Success: {n_success}/{n_total}",
                (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.6, success_color, 2)
    cv2.putText(frame_overlay, f"Wall Collision: {n_fail}/{n_total}",
                (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.6, collision_color, 2)
    cv2.putText(frame_overlay, f"Wall Y={wall_config['wall_y']:.2f}",
                (10, 75), cv2.FONT_HERSHEY_SIMPLEX, 0.5, wall_fill_color, 1)

    return frame_overlay


def visualize_3d_plot(all_results, wall_config, save_path, phase_name="reach"):
    """
    Create and save a 3D matplotlib plot of one phase's trajectories and the wall.

    Trajectory styling:
      - blue: avoids the wall;
      - red: collides, truncated after ``collision_idx``;
      - green star: last point of a successful trajectory;
      - large dot: mean start point.

    The wall, minus its opening if any, is drawn as translucent red quads clipped
    to the padded bounding box of the plotted trajectories. The y-range is widened
    if needed so the wall plane is inside the plot. Nothing is written when there
    are no trajectory points.

    Args:
        all_results: List of result dicts with "positions" (N x 3 world points),
            "collision", "collision_idx" and "mode_idx".
        wall_config: Wall dict (wall_y, wall_min/max_x, wall_min/max_z and an
            optional "opening" with min/max_x and min/max_z).
        save_path: Directory where ``<phase_name>_trajectories_3d.png`` is written.
        phase_name: Phase label used in the title and the file name.
    """
    if not all_results:
        print(f"No {phase_name} trajectories to plot; skipping 3D plot.")
        return

    fig = plt.figure(figsize=(12, 10))
    try:
        ax = fig.add_subplot(111, projection='3d')

        # ---- Trajectories ----
        drawn = []
        for result in all_results:
            positions = np.asarray(result["positions"])
            collision = result["collision"]
            collision_idx = result["collision_idx"]
            mode_idx = result["mode_idx"]

            # Colliding trajectories are shown only up to (and including) the collision step.
            if collision and collision_idx is not None:
                draw_end = min(collision_idx + 1, len(positions))
            else:
                draw_end = len(positions)
            pos_to_draw = positions[:draw_end]
            if len(pos_to_draw) == 0:
                continue
            drawn.append(pos_to_draw[:, :3])

            x, y, z = pos_to_draw[:, 0], pos_to_draw[:, 1], pos_to_draw[:, 2]
            color = 'red' if collision else 'blue'
            ax.plot(x, y, z, '-', color=color, linewidth=2, alpha=0.8,
                    label=f'Mode {mode_idx}' if mode_idx == 0 else '')

            # End marker
            if not collision:
                ax.scatter(x[-1], y[-1], z[-1], c='green', marker='*', s=150)

        if not drawn:
            print(f"No {phase_name} trajectory points to plot; skipping 3D plot.")
            return

        # Start point (mean over trajectories that have at least one point)
        starts = [np.asarray(r["positions"])[0][:3] for r in all_results if len(r["positions"]) > 0]
        start_pos = np.mean(starts, axis=0)
        ax.scatter(*start_pos, color='blue', s=200, marker='o', label='Start',
                   edgecolors='black', linewidths=2)

        # ---- Bounds ----
        points = np.concatenate(drawn, axis=0)
        x_min, y_min, z_min = points.min(axis=0)
        x_max, y_max, z_max = points.max(axis=0)

        wall_y = wall_config["wall_y"]
        # Keep the wall plane inside the plotted y-range.
        y_min, y_max = min(y_min, wall_y), max(y_max, wall_y)

        # 30% padding, but at least 1 cm so a flat extent never gives identical limits.
        x_pad = max((x_max - x_min) * 0.3, 0.01)
        y_pad = max((y_max - y_min) * 0.3, 0.01)
        z_pad = max((z_max - z_min) * 0.3, 0.01)

        # ---- Wall (clipped to the padded trajectory box) ----
        wall_min_x = max(wall_config["wall_min_x"], x_min - x_pad)
        wall_max_x = min(wall_config["wall_max_x"], x_max + x_pad)
        wall_min_z = max(wall_config["wall_min_z"], z_min - z_pad)
        wall_max_z = min(wall_config["wall_max_z"], z_max + z_pad)

        opening = wall_config.get("opening", None)
        edge_width = 1 if opening is not None else 2

        wall_pieces = []
        # Skip the wall entirely if it does not overlap the plotted box.
        if wall_min_x < wall_max_x and wall_min_z < wall_max_z:
            wall_pieces = [[
                [wall_min_x, wall_y, wall_min_z],
                [wall_max_x, wall_y, wall_min_z],
                [wall_max_x, wall_y, wall_max_z],
                [wall_min_x, wall_y, wall_max_z],
            ]]
            if opening is not None:
                # Clamp both ends of the opening into the (clipped) wall rectangle.
                op_min_x = min(max(opening["min_x"], wall_min_x), wall_max_x)
                op_max_x = min(max(opening["max_x"], wall_min_x), wall_max_x)
                op_min_z = min(max(opening["min_z"], wall_min_z), wall_max_z)
                op_max_z = min(max(opening["max_z"], wall_min_z), wall_max_z)
                # An opening with no area inside the visible wall leaves it solid.
                if op_min_x < op_max_x and op_min_z < op_max_z:
                    wall_pieces = []
                    if wall_min_z < op_min_z:  # below the opening
                        wall_pieces.append([
                            [wall_min_x, wall_y, wall_min_z],
                            [wall_max_x, wall_y, wall_min_z],
                            [wall_max_x, wall_y, op_min_z],
                            [wall_min_x, wall_y, op_min_z],
                        ])
                    if op_max_z < wall_max_z:  # above the opening
                        wall_pieces.append([
                            [wall_min_x, wall_y, op_max_z],
                            [wall_max_x, wall_y, op_max_z],
                            [wall_max_x, wall_y, wall_max_z],
                            [wall_min_x, wall_y, wall_max_z],
                        ])
                    if wall_min_x < op_min_x:  # left of the opening
                        wall_pieces.append([
                            [wall_min_x, wall_y, op_min_z],
                            [op_min_x, wall_y, op_min_z],
                            [op_min_x, wall_y, op_max_z],
                            [wall_min_x, wall_y, op_max_z],
                        ])
                    if op_max_x < wall_max_x:  # right of the opening
                        wall_pieces.append([
                            [op_max_x, wall_y, op_min_z],
                            [wall_max_x, wall_y, op_min_z],
                            [wall_max_x, wall_y, op_max_z],
                            [op_max_x, wall_y, op_max_z],
                        ])

        for piece in wall_pieces:
            ax.add_collection3d(Poly3DCollection([piece], alpha=0.3, facecolor='red',
                                                 edgecolor='darkred', linewidth=edge_width))

        # Set axis limits
        ax.set_xlim(x_min - x_pad, x_max + x_pad)
        ax.set_ylim(y_min - y_pad, y_max + y_pad)
        ax.set_zlim(z_min - z_pad, z_max + z_pad)

        # Labels
        ax.set_xlabel('X (m)', fontsize=12)
        ax.set_ylabel('Y (m)', fontsize=12)
        ax.set_zlabel('Z (m)', fontsize=12)

        n_success = sum(1 for r in all_results if not r["collision"])
        n_fail = sum(1 for r in all_results if r["collision"])

        ax.set_title(f'Pick-and-Place {phase_name.upper()} - {len(all_results)} Mode Trajectories with Wall\n'
                     f'Success: {n_success}/{len(all_results)} | '
                     f'Wall Collision: {n_fail}/{len(all_results)} | '
                     f'Wall Y={wall_config["wall_y"]:.2f}',
                     fontsize=13, fontweight='bold')

        # Legend (explicit handles; per-line labels are not used)
        from matplotlib.lines import Line2D
        from matplotlib.patches import Patch
        wall_label = 'Wall (with opening)' if opening else 'Wall'
        legend_elements = [
            Patch(facecolor='red', alpha=0.3, edgecolor='darkred', label=wall_label),
            Line2D([0], [0], color='blue', linewidth=2, label='Success'),
            Line2D([0], [0], color='red', linewidth=2, label='Wall Collision'),
        ]
        ax.legend(handles=legend_elements, loc='upper right', fontsize=10)

        # View angle
        ax.view_init(elev=30, azim=-135)

        fig.tight_layout()
        plot_path = os.path.join(save_path, f'{phase_name}_trajectories_3d.png')
        fig.savefig(plot_path, dpi=150, bbox_inches='tight')
        print(f"Saved 3D plot to {plot_path}")
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


def _reset_to_home(task_env, settle_steps=20):
    """Reset the task with seed 42, park the arm at HOME_JOINTS and step the simulator.

    Args:
        task_env: RLBench task environment.
        settle_steps: Number of simulator steps after moving to HOME.

    Returns:
        The scene's robot object.
    """
    np.random.seed(42)
    task_env.reset()
    robot = task_env._scene.robot
    robot.arm.set_joint_positions(HOME_JOINTS, disable_dynamics=True)
    # Zero the joint target velocities before stepping, as done after every
    # set_joint_positions in this script.
    robot.arm.set_joint_target_velocities([0] * 7)
    for _ in range(settle_steps):
        task_env._scene.pyrep.step()
    return robot


def _build_wall_config():
    """Return the wall config chosen by --style, or built from the individual wall flags."""
    if FLAGS.style in WALL_STYLES:
        wall_config = WALL_STYLES[FLAGS.style].copy()
        # Copy the nested opening dict too so the shared WALL_STYLES entry is never mutated.
        if wall_config.get("opening") is not None:
            wall_config["opening"] = wall_config["opening"].copy()
        print(f"Using predefined style {FLAGS.style}")
        return wall_config

    if FLAGS.style != 0:
        # An unknown style falls back to the individual flags; say so explicitly.
        print(f"WARNING: unknown wall style {FLAGS.style} (available: {list(WALL_STYLES)}); "
              f"using individual wall flags instead.")

    wall_config = {
        "wall_y": FLAGS.wall_y,
        "wall_min_x": FLAGS.wall_min_x,
        "wall_max_x": FLAGS.wall_max_x,
        "wall_min_z": FLAGS.wall_min_z,
        "wall_max_z": FLAGS.wall_max_z,
        "wall_thickness": 0.002,
        "wall_color": [1.0, 0.2, 0.2],
        "wall_transparency": 0.6,
        "opening": None,
    }

    # An explicit opening (either X bound non-zero) takes precedence over --enable_opening.
    if FLAGS.wall_opening_min_x != 0 or FLAGS.wall_opening_max_x != 0:
        wall_config["opening"] = {
            "min_x": FLAGS.wall_opening_min_x,
            "max_x": FLAGS.wall_opening_max_x,
            "min_z": FLAGS.wall_opening_min_z,
            "max_z": FLAGS.wall_opening_max_z,
        }
    elif FLAGS.enable_opening:
        wall_config["opening"] = DEFAULT_OPENING_CONFIG.copy()
    return wall_config


def _evaluate_phase(phase, cp_indices, generate_fn, task_env, start_pos, goal_pos,
                    orientation, canonical_params, control_point_radius, wall_config):
    """Generate and wall-check one trajectory per control-point index of a phase.

    ``generate_fn`` is generate_reach_trajectory or generate_carry_trajectory.
    Modes whose generator returns no positions, or that raise, are reported and
    skipped. Returns the list of per-mode result dicts.
    """
    results = []
    for cp_idx in cp_indices:
        print(f"\n{phase} Mode {cp_idx}:")
        try:
            _reset_to_home(task_env)

            positions, cp = generate_fn(
                task_env, start_pos, goal_pos, orientation,
                cp_idx, canonical_params, control_point_radius
            )
            if positions is None:
                # In this file only the CARRY generator returns None (IK to the lift pose failed).
                print(f"  Mode {cp_idx} failed: could not reach lift position")
                continue

            collision, collision_idx = check_ee_trajectory_wall_collision(
                positions, wall_config, debug=True
            )

            angle_deg = np.degrees(canonical_params[cp_idx][0])
            dist_frac = canonical_params[cp_idx][1]
            if collision:
                print(f"  Angle: {angle_deg:.0f} deg, Dist: {dist_frac:.1f} -> WALL COLLISION at step {collision_idx}")
            else:
                print(f"  Angle: {angle_deg:.0f} deg, Dist: {dist_frac:.1f} -> SUCCESS")

            results.append({
                "mode_idx": cp_idx,
                "positions": positions,
                "collision": collision,
                "collision_idx": collision_idx,
                "cp_params": canonical_params[cp_idx],
                "cp": cp,
            })
        except Exception as e:
            # Best effort: one failing mode must not abort the whole visualization.
            print(f"  Mode {cp_idx} failed: {e}")
    return results


def _serializable_results(results):
    """Keep only the per-mode fields saved to trajectory_results.npy (no trajectories)."""
    return [
        {
            'mode_idx': r['mode_idx'],
            'collision': r['collision'],
            'collision_idx': r['collision_idx'],
            'cp_params': list(r['cp_params']),
        }
        for r in results
    ]


def main(argv):
    """Generate REACH/CARRY trajectories, check them against the wall and save visualizations.

    Reads stack_blocks_init.npz next to this script, launches a headless RLBench
    environment for the stack_blocks task, evaluates every valid control-point mode,
    and writes frame overlays, 3D plots, wall_config.npy and trajectory_results.npy
    to --save_path. The simulator is shut down even if a later step raises.
    """
    # Imported up front so a missing OpenCV fails before the simulator is launched.
    import cv2

    print(f"{'='*70}")
    print("PICK-AND-PLACE TRAJECTORY + WALL VISUALIZATION")
    print(f"{'='*70}")

    # Build wall config based on style or individual flags
    wall_config = _build_wall_config()

    print(f"Wall Configuration:")
    print(f"  Y position: {wall_config['wall_y']}")
    print(f"  X bounds: [{wall_config['wall_min_x']}, {wall_config['wall_max_x']}]")
    print(f"  Z bounds: [{wall_config['wall_min_z']}, {wall_config['wall_max_z']}]")
    if wall_config.get("opening"):
        print(f"  Opening: X=[{wall_config['opening']['min_x']}, {wall_config['opening']['max_x']}], "
              f"Z=[{wall_config['opening']['min_z']}, {wall_config['opening']['max_z']}]")
    print()

    # Create save directory
    os.makedirs(FLAGS.save_path, exist_ok=True)

    # Load precomputed init data (the only supported mode)
    if not FLAGS.use_precomputed_init:
        print("ERROR: Must use precomputed init data")
        return

    init_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "stack_blocks_init.npz")
    if not os.path.exists(init_file):
        print(f"ERROR: Init file not found: {init_file}")
        print(f"Please run stack_block_init.py first to generate it.")
        return

    print(f"Loading precomputed init data from: {init_file}")
    init_data = dict(np.load(init_file))

    home_pos = init_data['home_pos']
    pregrasp_pos = init_data['pregrasp_pos']
    lift_pos = init_data['lift_pos']
    prerelease_pos = init_data['prerelease_pos']
    grasp_orientation = init_data['grasp_orientation']
    canonical_params = init_data['canonical_params']
    control_point_radius = float(init_data['control_point_radius'])
    valid_reach_indices = list(init_data['reach_valid_indices'])
    valid_carry_indices = list(init_data['carry_valid_indices'])

    print(f"  Home pos: {home_pos}")
    print(f"  Pregrasp pos: {pregrasp_pos}")
    print(f"  Lift pos: {lift_pos}")
    print(f"  Prerelease pos: {prerelease_pos}")
    print(f"  Valid REACH CPs: {len(valid_reach_indices)}")
    print(f"  Valid CARRY CPs: {len(valid_carry_indices)}")

    # Setup environment
    obs_config = ObservationConfig()
    obs_config.set_all(False)
    obs_config.joint_positions = True
    obs_config.joint_velocities = True
    obs_config.gripper_open = True
    obs_config.gripper_pose = True
    obs_config.front_camera.rgb = True
    obs_config.front_camera.image_size = [256, 256]

    # Action bounds: 7 arm joints + gripper (lower bound and range per dimension).
    ACT_MIN = np.array([-2.8973, -1.7628, -2.8973, -3.0718,
                        -2.8973, -0.0175, -2.8973, 0.0], dtype=np.float32)
    ACT_RANGE = np.array([5.7946, 3.5256, 5.7946, 3.0020,
                          5.7946, 3.7700, 5.7946, 1.0], dtype=np.float32)

    class CustomMoveArmThenGripper(MoveArmThenGripper):
        def action_bounds(self):
            return (ACT_MIN, ACT_MIN + ACT_RANGE)

    action_mode = CustomMoveArmThenGripper(JointPosition(True), Discrete())
    rlbench_env = Environment(action_mode=action_mode, obs_config=obs_config, headless=True)
    rlbench_env.launch()

    try:
        task_class = task_file_to_task_class("stack_blocks")
        task_env = rlbench_env.get_task(task_class)
        task_env.set_variation(0)

        # Initialize: reset (seed 42) and park the robot at HOME
        robot = _reset_to_home(task_env)
        tip = robot.arm.get_tip()
        home_ee_pos = np.array(tip.get_position())
        print(f"HOME EE position: [{home_ee_pos[0]:.4f}, {home_ee_pos[1]:.4f}, {home_ee_pos[2]:.4f}]")

        # Create wall in scene and step so it is rendered
        wall = create_wall(task_env, wall_config)
        for _ in range(10):
            task_env._scene.pyrep.step()

        front_cam = task_env._scene._cam_front

        # ====================================================================
        # REACH PHASE TRAJECTORIES
        # ====================================================================
        print(f"\n{'='*70}")
        print("GENERATING REACH TRAJECTORIES")
        print(f"{'='*70}")
        reach_results = _evaluate_phase(
            "REACH", valid_reach_indices, generate_reach_trajectory, task_env,
            home_pos, pregrasp_pos, grasp_orientation,
            canonical_params, control_point_radius, wall_config,
        )

        # ====================================================================
        # CARRY PHASE TRAJECTORIES
        # ====================================================================
        print(f"\n{'='*70}")
        print("GENERATING CARRY TRAJECTORIES")
        print(f"{'='*70}")
        carry_results = _evaluate_phase(
            "CARRY", valid_carry_indices, generate_carry_trajectory, task_env,
            lift_pos, prerelease_pos, grasp_orientation,
            canonical_params, control_point_radius, wall_config,
        )

        # ====================================================================
        # SUMMARY
        # ====================================================================
        print(f"\n{'='*70}")
        print("RESULTS SUMMARY")
        print(f"{'='*70}")
        for phase, results in (("REACH", reach_results), ("CARRY", carry_results)):
            n_success = sum(1 for r in results if not r["collision"])
            n_fail = sum(1 for r in results if r["collision"])
            print(f"{phase}: {n_success}/{len(results)} success, {n_fail}/{len(results)} wall collision")

        # ====================================================================
        # VISUALIZATIONS
        # ====================================================================
        # Reset and capture a frame with the wall
        _reset_to_home(task_env)
        # NOTE(review): if task_env.reset() does not remove the wall created above,
        # this adds a second, overlapping wall to the scene — confirm.
        wall = create_wall(task_env, wall_config)
        for _ in range(10):
            task_env._scene.pyrep.step()
        obs = task_env._scene.get_observation()
        first_frame = obs.front_rgb.copy()

        # REACH visualization
        if len(reach_results) > 0:
            vis_frame_reach = visualize_on_frame(first_frame.copy(), reach_results, wall_config, front_cam, "reach")
            vis_path_reach = os.path.join(FLAGS.save_path, 'reach_trajectories_on_frame.png')
            # The frame is RGB; OpenCV writes BGR.
            cv2.imwrite(vis_path_reach, cv2.cvtColor(vis_frame_reach, cv2.COLOR_RGB2BGR))
            print(f"\nSaved REACH frame visualization to {vis_path_reach}")
            visualize_3d_plot(reach_results, wall_config, FLAGS.save_path, "reach")

        # CARRY visualization
        if len(carry_results) > 0:
            vis_frame_carry = visualize_on_frame(first_frame.copy(), carry_results, wall_config, front_cam, "carry")
            vis_path_carry = os.path.join(FLAGS.save_path, 'carry_trajectories_on_frame.png')
            cv2.imwrite(vis_path_carry, cv2.cvtColor(vis_frame_carry, cv2.COLOR_RGB2BGR))
            print(f"Saved CARRY frame visualization to {vis_path_carry}")
            visualize_3d_plot(carry_results, wall_config, FLAGS.save_path, "carry")

        # Save wall config (a dict, stored as a 0-d object array; load with allow_pickle=True)
        np.save(os.path.join(FLAGS.save_path, 'wall_config.npy'), wall_config)

        # Save results (without the raw trajectories)
        results_data = {
            'wall_config': wall_config,
            'reach_results': _serializable_results(reach_results),
            'carry_results': _serializable_results(carry_results),
        }
        np.save(os.path.join(FLAGS.save_path, 'trajectory_results.npy'), results_data)

        # Interactive mode: print key positions to help tune the wall
        if FLAGS.interactive:
            print("\n" + "="*70)
            print("INTERACTIVE MODE - Key positions for wall tuning:")
            print("="*70)
            print(f"Robot HOME EE position: {home_ee_pos}")
            print(f"Pregrasp position: {pregrasp_pos}")
            print(f"Lift position: {lift_pos}")
            print(f"Prerelease position: {prerelease_pos}")
            print(f"Current wall Y: {wall_config['wall_y']}")
            print(f"\nSuggested wall Y range for REACH: [0.0, {pregrasp_pos[1]:.2f}]")
            print(f"  - Wall at Y=0.05 blocks most direct paths")
            print("\nTry adjusting with:")
            print(f"  python visualize_wall_trajectories.py --wall_y=0.05")
            print(f"  python visualize_wall_trajectories.py --wall_y=0.08")
    finally:
        # Always stop the simulator, even if generation or plotting raised.
        rlbench_env.shutdown()

    print("\nDone!")


if __name__ == "__main__":
    # Force the "spawn" start method for any child processes before absl runs main
    # (presumably to avoid forking a process that hosts the simulator — confirm).
    from multiprocessing import set_start_method

    set_start_method("spawn", force=True)
    app.run(main)
