#!/usr/bin/env python
"""
Visualize EE trajectories with wall obstacle for close_drawer task.

This script:
1. Generates 8 modes of trajectories (same as dataset_generator_fixed_endpoints.py)
2. Checks wall collision for each trajectory
3. Saves a visualization with:
   - First camera frame as background
   - All 8 EE trajectories overlaid
   - Wall plane visualization
   - Blue = success (no wall collision), Red = failed (wall collision)

Usage:
  # Default wall position (y=-0.2)
  python visualize_wall_trajectories.py

  # Custom wall position
  python visualize_wall_trajectories.py --wall_y=-0.15

  # With wall opening
  python visualize_wall_trajectories.py --wall_y=-0.2 --wall_opening_min_x=0.2 --wall_opening_max_x=0.4

  # Interactive wall adjustment (prints positions for tuning)
  python visualize_wall_trajectories.py --interactive
"""
import sys
import os
import pickle

sys.path.insert(0, os.path.dirname(__file__))

import numpy as np

# Use non-interactive backend for matplotlib (must be before pyplot import)
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

# NOTE: cv2 is imported INSIDE functions that use it (after RLBench launches)
# to avoid Qt initialization conflicts on headless clusters

from pyrep.errors import ConfigurationPathError, IKError
from pyrep.objects.dummy import Dummy

from rlbench import ObservationConfig
from rlbench.action_modes.action_mode import MoveArmThenGripper
from rlbench.action_modes.arm_action_modes import JointPosition
from rlbench.action_modes.gripper_action_modes import Discrete
from rlbench.backend.utils import task_file_to_task_class
from rlbench.environment import Environment

from absl import app, flags

from push_utils import (
    DEFAULT_PHASE_STEPS,
    get_drawer_handle_position,
    set_drawer_open,
    fix_cabinet_orientation,
    reset_robot_to_default,
    compute_push_waypoints,
    generate_canonical_control_point_params,
    compute_control_point_from_params,
    generate_phase_positions,
    move_robot_to_start,
    parabola3D,
    HOME_JOINTS,
    lock_other_drawers,
    generate_push_trajectory,
)

from close_drawer_config import (
    CAMERA_POSITION,
    CAMERA_ORIENTATION,
    CAMERA_IMAGE_SIZE,
    DRAWER_VARIATION as DEFAULT_DRAWER_VARIATION,
    DRAWER_OPEN_AMOUNT as DEFAULT_DRAWER_OPEN_AMOUNT,
    CONTROL_POINT_RADIUS as DEFAULT_CONTROL_POINT_RADIUS,
)

from wall_collision import (
    DEFAULT_WALL_CONFIG,
    DEFAULT_OPENING_CONFIG,
    WALL_STYLES,
    create_wall,
    check_wall_collision,
    WallCollisionTracker,
)


# Command-line flags (absl). Values are read through FLAGS after absl has
# parsed argv.
FLAGS = flags.FLAGS

# Default save path: RLBench_close_drawer/block_setting
# (the "block_setting" directory inside the parent of this script's directory).
DEFAULT_SAVE_PATH = os.path.join(
    os.path.dirname(os.path.dirname(__file__)),  # RLBench_close_drawer/
    "block_setting"
)
flags.DEFINE_string("save_path", DEFAULT_SAVE_PATH, "Where to save visualization.")
flags.DEFINE_integer("num_modes", 8, "Number of control point modes.")
flags.DEFINE_float("control_point_radius", DEFAULT_CONTROL_POINT_RADIUS, "Control point radius.")
flags.DEFINE_integer("drawer_variation", DEFAULT_DRAWER_VARIATION, "Drawer: 0=bottom, 1=middle, 2=top")
flags.DEFINE_float("drawer_open_amount", DEFAULT_DRAWER_OPEN_AMOUNT, "How far drawer is open")
flags.DEFINE_bool("interactive", False, "Interactive mode for wall tuning")

# Wall style (predefined configs: 1, 2, 3), looked up in WALL_STYLES.
# Style 1: only 0 degrees can succeed (no opening)
# Style 2: only 270 degrees can succeed (no opening)
# Style 3: opening enabled
flags.DEFINE_integer("style", 0, "Predefined wall style (1, 2, or 3). 0 = use individual flags below.")

# Wall configuration (used when style=0); defaults come from DEFAULT_WALL_CONFIG.
flags.DEFINE_float("wall_y", DEFAULT_WALL_CONFIG["wall_y"], "Wall Y position")
flags.DEFINE_float("wall_min_x", DEFAULT_WALL_CONFIG["wall_min_x"], "Wall min X")
flags.DEFINE_float("wall_max_x", DEFAULT_WALL_CONFIG["wall_max_x"], "Wall max X")
flags.DEFINE_float("wall_min_z", DEFAULT_WALL_CONFIG["wall_min_z"], "Wall min Z")
flags.DEFINE_float("wall_max_z", DEFAULT_WALL_CONFIG["wall_max_z"], "Wall max Z")

# Wall opening (used when style=0).
# main() builds an explicit opening when wall_opening_min_x or
# wall_opening_max_x is non-zero; otherwise enable_opening selects
# DEFAULT_OPENING_CONFIG.
flags.DEFINE_bool("enable_opening", False, "Enable opening at handle position (disabled by default)")
flags.DEFINE_float("wall_opening_min_x", 0.0, "Opening min X (0 to disable, overrides enable_opening)")
flags.DEFINE_float("wall_opening_max_x", 0.0, "Opening max X")
flags.DEFINE_float("wall_opening_min_z", 0.0, "Opening min Z")
flags.DEFINE_float("wall_opening_max_z", 0.0, "Opening max Z")

# Collision check mode: EE-only by default, full arm when this flag is set.
flags.DEFINE_bool("check_full_arm", False, "Check full robot arm for collision (default: only check EE)")


def check_ee_trajectory_wall_collision(ee_positions, wall_config, debug=False):
    """
    Decide whether the end-effector path passes THROUGH the wall (EE only).

    Each consecutive pair of samples is inspected. A pair counts as a plane
    crossing only when it goes from Y >= wall_y to Y < wall_y. The crossing
    point on the wall plane is linearly interpolated; it is a collision when
    that point lies inside the wall's X-Z rectangle and not inside the
    optional "opening" rectangle. Crossings outside the rectangle (bypassing
    the wall) or through the opening are ignored and the scan continues.

    Args:
        ee_positions: indexable sequence of [x, y, z] positions.
        wall_config: dict with "wall_y", "wall_min_x", "wall_max_x",
            "wall_min_z", "wall_max_z" and an optional "opening" dict
            ("min_x", "max_x", "min_z", "max_z").
        debug: print a line for every plane crossing that is examined.

    Returns:
        collision: bool
        collision_idx: int or None, index of the sample just after the crossing
    """
    wall_y = wall_config["wall_y"]
    min_x, max_x = wall_config["wall_min_x"], wall_config["wall_max_x"]
    min_z, max_z = wall_config["wall_min_z"], wall_config["wall_max_z"]
    opening = wall_config.get("opening", None)

    for step in range(1, len(ee_positions)):
        before, after = ee_positions[step - 1], ee_positions[step]

        # Only a robot-side -> cabinet-side transition is of interest
        if not (before[1] >= wall_y and after[1] < wall_y):
            continue

        # Locate the point where the segment meets the plane Y = wall_y
        delta_y = after[1] - before[1]
        if abs(delta_y) > 1e-6:
            t = (wall_y - before[1]) / delta_y
            cross_x = before[0] + t * (after[0] - before[0])
            cross_z = before[2] + t * (after[2] - before[2])
        else:
            # Segment is too short in Y to interpolate; use the later sample
            cross_x, cross_z = after[0], after[2]

        in_x_bounds = min_x <= cross_x <= max_x
        in_z_bounds = min_z <= cross_z <= max_z

        if debug:
            print(f"    [DEBUG] Y crossing at step {step}: "
                  f"cross_pos=[{cross_x:.3f}, {wall_y:.3f}, {cross_z:.3f}], "
                  f"in_x_bounds={in_x_bounds} (x in [{min_x:.2f},{max_x:.2f}]), "
                  f"in_z_bounds={in_z_bounds}")

        if not (in_x_bounds and in_z_bounds):
            # The plane was crossed beside/above/below the wall rectangle
            if debug:
                print(f"    [DEBUG] -> Bypassed wall (outside X-Z bounds)")
            continue

        through_opening = (
            opening is not None
            and opening["min_x"] <= cross_x <= opening["max_x"]
            and opening["min_z"] <= cross_z <= opening["max_z"]
        )
        if through_opening:
            if debug:
                print(f"    [DEBUG] -> Passed through opening")
            continue

        # Inside the rectangle and not in the opening: the wall was hit
        if debug:
            print(f"    [DEBUG] -> COLLISION at wall")
        return True, step

    return False, None


def check_full_arm_trajectory_wall_collision(task_env, ee_positions, wall_config, debug=False):
    """
    Check if any robot arm link crosses the wall while following a trajectory.

    For every target position the arm is moved there (Jacobian IK, keeping the
    tip's current orientation, joints set with dynamics disabled), the
    simulation is stepped, and check_wall_collision() is asked about the
    resulting pose. Targets for which IK fails are skipped, so they are not
    checked at all.

    NOTE: this moves the robot in the simulator and does not restore the
    original joint configuration afterwards.

    Args:
        task_env: task environment exposing ``_scene.robot`` and ``_scene.pyrep``.
        ee_positions: iterable of target EE positions.
        wall_config: wall configuration dict forwarded to check_wall_collision();
            "wall_y" is also read for the debug message.
        debug: print details about IK failures and the first collision.

    Returns:
        collision: bool
        collision_idx: int or None, index where collision occurred
        collision_link: str or None, name of the colliding link
    """
    from pyrep.errors import ConfigurationPathError, IKError

    scene = task_env._scene
    arm = scene.robot.arm
    tip = arm.get_tip()

    # Only needed for the debug message; the geometric test itself is
    # delegated to check_wall_collision().
    wall_y = wall_config["wall_y"]

    for i, target_pos in enumerate(ee_positions):
        try:
            # Keep whatever orientation the tip currently has
            current_ori = tip.get_orientation()
            joint_positions = arm.solve_ik_via_jacobian(target_pos, euler=current_ori)
            arm.set_joint_positions(joint_positions, disable_dynamics=True)

            # Step simulation to update link positions
            for _ in range(2):
                scene.pyrep.step()

            # Check all links
            collision, link_name, link_pos = check_wall_collision(task_env, wall_config)
        except (ConfigurationPathError, IKError) as e:
            # Unreachable target: skip it and keep scanning the trajectory
            if debug:
                print(f"    [DEBUG] IK failed at step {i}: {e}")
        else:
            if collision:
                if debug:
                    print(f"    [DEBUG] Full arm collision at step {i}/{len(ee_positions)}: "
                          f"link={link_name}, pos=[{link_pos[0]:.3f}, {link_pos[1]:.3f}, {link_pos[2]:.3f}], "
                          f"wall_y={wall_y:.3f}")
                return True, i, link_name

    return False, None, None


def generate_ee_trajectory_only(task_env, handle_pos, handle_ori, cp_idx, canonical_params,
                                 control_point_radius, phase_steps, drawer_variation):
    """
    Plan the EE positions for one control-point mode without executing the push.

    The arm is placed at HOME_JOINTS (dynamics disabled) and the simulation is
    stepped so the tip pose can be read; the trajectory itself is only
    computed, never driven. ``drawer_variation`` is accepted but not used.

    Returns:
        positions: np.ndarray built from the generated phase positions
        waypoints: result of compute_push_waypoints()
        cp_reach: control point from compute_control_point_from_params()
        phase_indices: as returned by generate_phase_positions()
    """
    arm = task_env._scene.robot.arm
    ee_tip = arm.get_tip()

    # Put the arm at HOME and let the simulator settle for 10 steps
    arm.set_joint_positions(HOME_JOINTS, disable_dynamics=True)
    settle_count = 0
    while settle_count < 10:
        task_env._scene.pyrep.step()
        settle_count += 1

    home_pos = np.array(ee_tip.get_position())
    _home_ori = np.array(ee_tip.get_orientation())  # read, but not used below

    # Waypoints from the home pose to the drawer handle
    waypoints = compute_push_waypoints(home_pos, handle_pos, handle_ori, None)

    # Control point for the requested mode
    angle, dist_frac, pos_frac = canonical_params[cp_idx]
    cp_reach = compute_control_point_from_params(
        waypoints["start"],
        waypoints["handle"],
        control_point_radius,
        angle,
        dist_frac,
        pos_frac,
    )

    # Per-phase positions; the phase labels are not needed by callers
    generated = generate_phase_positions(waypoints, phase_steps, cp_reach)
    positions, phase_indices, _phase_labels = generated

    return np.array(positions), waypoints, cp_reach, phase_indices


def project_world_to_image(points_3d, camera, image_size):
    """
    Project 3D world points to 2D integer pixel coordinates.

    A pinhole model is used: the focal length in pixels is derived from the
    camera's perspective angle and ``image_size[0]``, and the principal point
    is the image centre. Points are moved into the camera frame with the
    transpose of the 3x3 rotation taken from ``camera.get_matrix()``.

    Args:
        points_3d: iterable of [x, y, z] world positions.
        camera: object providing ``get_position()``, ``get_matrix()`` (array
            whose top-left 3x3 block is the rotation) and, optionally,
            ``get_perspective_angle()`` in degrees.
        image_size: (height, width); index 0 is used for the focal length and
            the vertical centre, index 1 for the horizontal centre.

    Returns:
        List with one entry per input point: an ``(u, v)`` tuple of ints, or
        ``None`` when the point is not in front of the camera
        (camera-frame z <= 0.01).
    """
    cam_pos = np.array(camera.get_position())
    cam_matrix = camera.get_matrix()[:3, :3]

    try:
        fov = camera.get_perspective_angle() * np.pi / 180.0
    except Exception:
        # Best-effort fallback when the angle cannot be queried. Only
        # Exception is caught so Ctrl-C / SystemExit still propagate.
        fov = 60.0 * np.pi / 180.0

    f = image_size[0] / (2.0 * np.tan(fov / 2.0))
    cx, cy = image_size[1] / 2.0, image_size[0] / 2.0

    projected = []
    for p in points_3d:
        # World -> camera frame
        x_cam, y_cam, z_cam = cam_matrix.T @ (np.array(p) - cam_pos)

        if z_cam > 0.01:
            # Both image axes are flipped relative to the camera-frame axes
            u = cx - f * x_cam / z_cam
            v = cy - f * y_cam / z_cam
            projected.append((int(u), int(v)))
        else:
            # Behind (or practically on) the camera plane: not drawable
            projected.append(None)

    return projected


def project_wall_to_image(wall_config, camera, image_size):
    """
    Project the four corners of the wall rectangle into the image.

    Corner order: (min_x, min_z), (max_x, min_z), (max_x, max_z), (min_x, max_z),
    all on the plane Y = wall_y. Returns whatever project_world_to_image()
    returns for those four points.
    """
    y = wall_config["wall_y"]
    x_lo, x_hi = wall_config["wall_min_x"], wall_config["wall_max_x"]
    z_lo, z_hi = wall_config["wall_min_z"], wall_config["wall_max_z"]

    rectangle = [[x, y, z] for x, z in ((x_lo, z_lo), (x_hi, z_lo), (x_hi, z_hi), (x_lo, z_hi))]
    return project_world_to_image(rectangle, camera, image_size)


def visualize_on_frame(frame, all_trajectories, wall_config, camera, results, style=0, num_base_modes=8):
    """
    Overlay the wall and all trajectories in ``results`` on a copy of ``frame``.

    Args:
        frame: numpy array (H, W, 3); it is copied, the input is not modified.
        all_trajectories: not used by this function; kept so existing callers
            keep working. The trajectories that are drawn come from ``results``.
        wall_config: wall configuration dict ("wall_y", "wall_min_x",
            "wall_max_x", "wall_min_z", "wall_max_z", optional "opening" with
            "min_x", "max_x", "min_z", "max_z").
        camera: camera object handed to project_world_to_image().
        results: list of dicts with keys "positions", "collision",
            "collision_idx" and "mode_idx".
        style: wall style (3 = extra modes are drawn as a dotted line).
        num_base_modes: number of base modes (extra modes have index >= this).

    Returns:
        The annotated copy of ``frame``.
    """
    import cv2  # Import here to avoid Qt conflicts on headless clusters

    image_size = frame.shape[:2]
    height, width = image_size
    frame_overlay = frame.copy()

    wall_y = wall_config["wall_y"]
    w_min_x, w_max_x = wall_config["wall_min_x"], wall_config["wall_max_x"]
    w_min_z, w_max_z = wall_config["wall_min_z"], wall_config["wall_max_z"]
    opening = wall_config.get("opening", None)

    wall_fill = (100, 100, 255)  # Light red in BGR
    wall_edge = (50, 50, 200)

    def _rect(x0, x1, z0, z1):
        """Corners of an axis-aligned rectangle on the wall plane."""
        return [[x0, wall_y, z0], [x1, wall_y, z0], [x1, wall_y, z1], [x0, wall_y, z1]]

    def _project_polygon(corners_3d):
        """Return int32 pixel corners, or None if any corner is not projectable."""
        corners_2d = project_world_to_image(corners_3d, camera, image_size)
        if any(c is None for c in corners_2d):
            return None
        return np.array(corners_2d, dtype=np.int32)

    def _in_frame(pt):
        """True when the pixel lies inside the image."""
        return 0 <= pt[0] < width and 0 <= pt[1] < height

    # ---- Wall -------------------------------------------------------------
    # The opening is drawn as a hole: the wall is split into up to four
    # rectangles around it. The opening is first clipped to the wall so that
    # an opening reaching past the wall edge cannot make the side pieces
    # stick out of the wall.
    hole = None
    if opening is not None:
        op_min_x = max(opening["min_x"], w_min_x)
        op_max_x = min(opening["max_x"], w_max_x)
        op_min_z = max(opening["min_z"], w_min_z)
        op_max_z = min(opening["max_z"], w_max_z)
        if op_min_x < op_max_x and op_min_z < op_max_z:
            hole = (op_min_x, op_max_x, op_min_z, op_max_z)

    if hole is None:
        # No (visible) opening - solid wall
        wall_pieces = [_rect(w_min_x, w_max_x, w_min_z, w_max_z)]
    else:
        op_min_x, op_max_x, op_min_z, op_max_z = hole
        wall_pieces = []
        if w_min_z < op_min_z:  # Bottom piece (below opening)
            wall_pieces.append(_rect(w_min_x, w_max_x, w_min_z, op_min_z))
        if op_max_z < w_max_z:  # Top piece (above opening)
            wall_pieces.append(_rect(w_min_x, w_max_x, op_max_z, w_max_z))
        if w_min_x < op_min_x:  # Left piece, between the opening's Z bounds
            wall_pieces.append(_rect(w_min_x, op_min_x, op_min_z, op_max_z))
        if op_max_x < w_max_x:  # Right piece, between the opening's Z bounds
            wall_pieces.append(_rect(op_max_x, w_max_x, op_min_z, op_max_z))

    # Fill all pieces on a scratch copy, then blend once for a
    # semi-transparent wall.
    tinted = frame_overlay.copy()
    filled_any = False
    for piece_3d in wall_pieces:
        pts = _project_polygon(piece_3d)
        if pts is not None:
            cv2.fillPoly(tinted, [pts], wall_fill)
            filled_any = True
    if filled_any:
        cv2.addWeighted(tinted, 0.3, frame_overlay, 0.7, 0, frame_overlay)

    # Outlines: the wall itself, then the hole (if any)
    outlines = [_rect(w_min_x, w_max_x, w_min_z, w_max_z)]
    if hole is not None:
        outlines.append(_rect(*hole))
    for outline_3d in outlines:
        pts = _project_polygon(outline_3d)
        if pts is not None:
            cv2.polylines(frame_overlay, [pts], True, wall_edge, 2, cv2.LINE_AA)

    # ---- Trajectories -----------------------------------------------------
    for result in results:
        positions = result["positions"]
        collision = result["collision"]
        collision_idx = result["collision_idx"]
        mode_idx = result["mode_idx"]

        projected = project_world_to_image(positions, camera, image_size)
        if not projected:
            # Empty trajectory: nothing to draw (and projected[0] would fail)
            continue

        # Color: BLUE for success, RED for collision (BGR tuples)
        color = (0, 0, 255) if collision else (255, 100, 0)

        # Stop drawing at the collision point (inclusive) if there was one
        if collision and collision_idx is not None:
            draw_end_idx = min(collision_idx + 1, len(projected))
        else:
            draw_end_idx = len(projected)

        # Extra modes (style 3 only) are drawn dotted
        is_extra_mode = (style == 3 and mode_idx >= num_base_modes)

        # Trajectory polyline, thickness 1
        for i in range(1, draw_end_idx):
            p1 = projected[i - 1]
            p2 = projected[i]

            if p1 is None or p2 is None:
                continue
            if not (_in_frame(p1) and _in_frame(p2)):
                continue
            # Dotted effect: keep only every 3rd segment
            if is_extra_mode and i % 3 != 0:
                continue

            cv2.line(frame_overlay, p1, p2, color, thickness=1, lineType=cv2.LINE_AA)

        # Start marker (small filled circle)
        start_pt = projected[0]
        if start_pt is not None:
            cv2.circle(frame_overlay, start_pt, 3, color, -1)

        # End marker: only successful trajectories get one; a collision is
        # already signalled by the line color.
        if not collision and draw_end_idx >= 1:
            end_pt = projected[draw_end_idx - 1]
            if end_pt is not None and _in_frame(end_pt):
                cv2.circle(frame_overlay, end_pt, 4, (0, 255, 0), -1)
                cv2.circle(frame_overlay, end_pt, 4, (0, 0, 0), 1)

        # Mode label near start
        if start_pt is not None:
            label_pos = (start_pt[0] + 5, start_pt[1] - 5)
            cv2.putText(frame_overlay, f"M{mode_idx}", label_pos,
                        cv2.FONT_HERSHEY_SIMPLEX, 0.3, color, 1)

    # ---- Legend -----------------------------------------------------------
    n_success = sum(1 for r in results if not r["collision"])
    n_fail = sum(1 for r in results if r["collision"])

    cv2.putText(frame_overlay, f"Success: {n_success}/{len(results)}",
                (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 100, 0), 2)
    cv2.putText(frame_overlay, f"Wall Collision: {n_fail}/{len(results)}",
                (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
    cv2.putText(frame_overlay, f"Wall Y={wall_config['wall_y']:.2f}",
                (10, 75), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (100, 100, 255), 1)

    return frame_overlay


def visualize_3d_plot(all_results, wall_config, save_path, style=0, num_base_modes=8):
    """
    Create a 3D matplotlib plot of trajectories and wall and save it as
    ``trajectories_3d.png`` inside ``save_path``.

    Trajectories with a collision are drawn in red and cut off at the
    collision sample; the others are drawn in blue with a green star at the
    end. The wall (and its opening, if any) is clipped to the padded
    trajectory bounds so the plot stays focused on the trajectories.

    Args:
        all_results: list of dicts with keys "positions" (N x 3), "collision",
            "collision_idx" and "mode_idx". If it is empty, or contains no
            drawable samples, nothing is saved.
        wall_config: wall configuration dict
        save_path: directory to save the plot
        style: wall style (3 = extra mode uses dotted line)
        num_base_modes: number of base modes (extra modes have index >= this)
    """
    from matplotlib.lines import Line2D
    from matplotlib.patches import Patch

    if len(all_results) == 0:
        print("No trajectories to plot; skipping 3D plot.")
        return

    # Padding added on each side of the trajectory bounds, as a fraction of
    # the extent along that axis.
    pad_fraction = 0.3

    fig = plt.figure(figsize=(12, 10))
    try:
        ax = fig.add_subplot(111, projection='3d')

        # Collect trajectory bounds for proper axis scaling
        all_x, all_y, all_z = [], [], []
        start_points = []

        # Plot trajectories first (to get bounds)
        for result in all_results:
            positions = np.asarray(result["positions"])
            collision = result["collision"]
            collision_idx = result["collision_idx"]
            mode_idx = result["mode_idx"]

            # Stop at collision point if collision occurred
            if collision and collision_idx is not None:
                draw_end = min(collision_idx + 1, len(positions))
            else:
                draw_end = len(positions)

            pos_to_draw = positions[:draw_end]
            if len(pos_to_draw) == 0:
                # Nothing to draw for this mode; x[-1] below would fail
                continue

            start_points.append(positions[0])
            x, y, z = pos_to_draw[:, 0], pos_to_draw[:, 1], pos_to_draw[:, 2]

            all_x.extend(x)
            all_y.extend(y)
            all_z.extend(z)

            color = 'red' if collision else 'blue'

            # Extra modes (style 3 only) use a dotted line
            is_extra_mode = (style == 3 and mode_idx >= num_base_modes)
            linestyle = ':' if is_extra_mode else '-'

            ax.plot(x, y, z, linestyle, color=color, linewidth=2, alpha=0.8)

            # End marker (only for success trajectories)
            if not collision:
                ax.scatter(x[-1], y[-1], z[-1], c='green', marker='*', s=150)

        if not all_x:
            print("No trajectories to plot; skipping 3D plot.")
            return

        # Start point (average over the trajectories that were drawn)
        start_pos = np.mean(start_points, axis=0)
        ax.scatter(*start_pos, color='blue', s=200, marker='o', label='Start',
                   edgecolors='black', linewidths=2)

        # Trajectory-focused bounds with padding
        x_min, x_max = min(all_x), max(all_x)
        y_min, y_max = min(all_y), max(all_y)
        z_min, z_max = min(all_z), max(all_z)

        x_pad = (x_max - x_min) * pad_fraction
        y_pad = (y_max - y_min) * pad_fraction
        z_pad = (z_max - z_min) * pad_fraction

        # Wall, clipped to the padded trajectory area for better visibility
        wall_y = wall_config["wall_y"]
        wall_min_x = max(wall_config["wall_min_x"], x_min - x_pad)
        wall_max_x = min(wall_config["wall_max_x"], x_max + x_pad)
        wall_min_z = max(wall_config["wall_min_z"], z_min - z_pad)
        wall_max_z = min(wall_config["wall_max_z"], z_max + z_pad)

        def _rect(x0, x1, z0, z1):
            """Corners of an axis-aligned rectangle on the wall plane."""
            return [[x0, wall_y, z0], [x1, wall_y, z0], [x1, wall_y, z1], [x0, wall_y, z1]]

        # The opening is drawn as a hole. It is clipped to the displayed wall;
        # if nothing of it remains inside, the wall is drawn solid instead of
        # building pieces from inverted bounds.
        opening = wall_config.get("opening", None)
        hole = None
        if opening is not None:
            op_min_x = max(opening["min_x"], wall_min_x)
            op_max_x = min(opening["max_x"], wall_max_x)
            op_min_z = max(opening["min_z"], wall_min_z)
            op_max_z = min(opening["max_z"], wall_max_z)
            if op_min_x < op_max_x and op_min_z < op_max_z:
                hole = (op_min_x, op_max_x, op_min_z, op_max_z)

        if hole is not None:
            op_min_x, op_max_x, op_min_z, op_max_z = hole

            # Split wall into up to 4 rectangles around the opening
            wall_pieces = []
            if wall_min_z < op_min_z:  # Bottom piece (below opening)
                wall_pieces.append(_rect(wall_min_x, wall_max_x, wall_min_z, op_min_z))
            if op_max_z < wall_max_z:  # Top piece (above opening)
                wall_pieces.append(_rect(wall_min_x, wall_max_x, op_max_z, wall_max_z))
            if wall_min_x < op_min_x:  # Left piece, between the opening's Z bounds
                wall_pieces.append(_rect(wall_min_x, op_min_x, op_min_z, op_max_z))
            if op_max_x < wall_max_x:  # Right piece, between the opening's Z bounds
                wall_pieces.append(_rect(op_max_x, wall_max_x, op_min_z, op_max_z))

            for piece in wall_pieces:
                piece_poly = Poly3DCollection([piece], alpha=0.3, facecolor='red',
                                              edgecolor='darkred', linewidth=1)
                ax.add_collection3d(piece_poly)
        else:
            # No (visible) opening - draw solid wall
            wall_vertices = _rect(wall_min_x, wall_max_x, wall_min_z, wall_max_z)
            wall_poly = Poly3DCollection([wall_vertices], alpha=0.3, facecolor='red',
                                         edgecolor='darkred', linewidth=2)
            ax.add_collection3d(wall_poly)

        # Set axis limits focused on trajectories
        ax.set_xlim(x_min - x_pad, x_max + x_pad)
        ax.set_ylim(y_min - y_pad, y_max + y_pad)
        ax.set_zlim(z_min - z_pad, z_max + z_pad)

        # Labels
        ax.set_xlabel('X (m)', fontsize=12)
        ax.set_ylabel('Y (m)', fontsize=12)
        ax.set_zlabel('Z (m)', fontsize=12)

        n_success = sum(1 for r in all_results if not r["collision"])
        n_fail = sum(1 for r in all_results if r["collision"])

        ax.set_title(f'Close Drawer - {len(all_results)} Mode Trajectories with Wall\n'
                     f'Success: {n_success}/{len(all_results)} | '
                     f'Wall Collision: {n_fail}/{len(all_results)} | '
                     f'Wall Y={wall_config["wall_y"]:.2f}',
                     fontsize=13, fontweight='bold')

        # Legend is built from explicit handles, so per-line labels are not needed
        wall_label = 'Wall (with opening)' if opening else 'Wall'
        legend_elements = [
            Patch(facecolor='red', alpha=0.3, edgecolor='darkred', label=wall_label),
            Line2D([0], [0], color='blue', linewidth=2, label='Success'),
            Line2D([0], [0], color='red', linewidth=2, label='Wall Collision'),
        ]
        ax.legend(handles=legend_elements, loc='upper right', fontsize=10)

        # View angle chosen to roughly match the camera frame
        # (per the original notes: camera at [1.25, 0.4, 1.58], looking toward
        # the cabinet at Y~-0.6).
        ax.view_init(elev=30, azim=-135)

        fig.tight_layout()
        plot_path = os.path.join(save_path, 'trajectories_3d.png')
        fig.savefig(plot_path, dpi=150, bbox_inches='tight')
        print(f"Saved 3D plot to {plot_path}")
    finally:
        # Always release the figure, even when drawing or saving fails
        plt.close(fig)


def _wall_config_from_flags():
    """Build the wall configuration dict from the command-line flags.

    If ``FLAGS.style`` is a key of ``WALL_STYLES`` the matching preset is
    copied and returned. Otherwise the dict is assembled from the individual
    ``--wall_*`` flags, optionally with an opening.

    Returns:
        dict: wall configuration with keys ``wall_y``, ``wall_min_x``,
        ``wall_max_x``, ``wall_min_z``, ``wall_max_z`` and ``opening``
        (presets may carry further keys; their schema is defined in
        ``WALL_STYLES``, outside this function).
    """
    if FLAGS.style in WALL_STYLES:
        wall_config = WALL_STYLES[FLAGS.style].copy()
        # dict.copy() is shallow: copy the nested opening dict too so the
        # module-level preset is never aliased by the returned config.
        if wall_config.get("opening") is not None:
            wall_config["opening"] = wall_config["opening"].copy()
        print(f"Using predefined style {FLAGS.style}")
        return wall_config

    # style=0 or an unknown style: fall back to the individual flags.
    wall_config = {
        "wall_y": FLAGS.wall_y,
        "wall_min_x": FLAGS.wall_min_x,
        "wall_max_x": FLAGS.wall_max_x,
        "wall_min_z": FLAGS.wall_min_z,
        "wall_max_z": FLAGS.wall_max_z,
        "wall_thickness": 0.002,
        "wall_color": [1.0, 0.2, 0.2],
        "wall_transparency": 0.6,
        "opening": None,
    }

    # An explicit X range for the opening takes precedence over the generic
    # --enable_opening switch.
    if FLAGS.wall_opening_min_x != 0 or FLAGS.wall_opening_max_x != 0:
        wall_config["opening"] = {
            "min_x": FLAGS.wall_opening_min_x,
            "max_x": FLAGS.wall_opening_max_x,
            "min_z": FLAGS.wall_opening_min_z,
            "max_z": FLAGS.wall_opening_max_z,
        }
    elif FLAGS.enable_opening:
        wall_config["opening"] = DEFAULT_OPENING_CONFIG.copy()

    return wall_config


def _launch_headless_environment():
    """Create and launch the headless RLBench environment used by ``main``.

    Only joint positions/velocities, gripper state/pose and the front RGB
    camera are enabled in the observation config.

    Returns:
        The launched ``Environment``; the caller is responsible for calling
        ``shutdown()`` on it.
    """
    obs_config = ObservationConfig()
    obs_config.set_all(False)
    obs_config.joint_positions = True
    obs_config.joint_velocities = True
    obs_config.gripper_open = True
    obs_config.gripper_pose = True
    obs_config.front_camera.rgb = True
    obs_config.front_camera.image_size = CAMERA_IMAGE_SIZE

    # Lower bound and span of the 8-dim action vector.
    # NOTE(review): presumably 7 arm-joint limits followed by the gripper
    # open/close value - confirm against the robot model.
    act_min = np.array([-2.8973, -1.7628, -2.8973, -3.0718,
                        -2.8973, -0.0175, -2.8973, 0.0], dtype=np.float32)
    act_range = np.array([5.7946, 3.5256, 5.7946, 3.0020,
                          5.7946, 3.7700, 5.7946, 1.0], dtype=np.float32)

    class CustomMoveArmThenGripper(MoveArmThenGripper):
        def action_bounds(self):
            return (act_min, act_min + act_range)

    action_mode = CustomMoveArmThenGripper(JointPosition(True), Discrete())
    rlbench_env = Environment(action_mode=action_mode, obs_config=obs_config, headless=True)
    rlbench_env.launch()
    return rlbench_env


def _build_episode_metadata(result, handle_pos, handle_ori, wall_config, success, error=None):
    """Return the per-episode metadata dict saved next to each demo.

    Args:
        result: one entry of the per-mode results list (reads ``mode_idx``
            and ``cp_params``).
        handle_pos: drawer handle position; must provide ``.tolist()``.
        handle_ori: drawer handle orientation; must provide ``.tolist()``.
        wall_config: wall configuration dict stored verbatim.
        success: whether executing the trajectory succeeded.
        error: optional exception; when given, its string form is stored
            under ``'error'``.
    """
    metadata = {
        'mode': result["mode_idx"],
        'demo_in_mode': 0,
        'cp_idx': result["mode_idx"],
        'canonical_cp_params': tuple(result["cp_params"]),
        'base_cp_params': tuple(result["cp_params"]),
        'handle_pos': handle_pos.tolist(),
        'handle_ori': handle_ori.tolist(),
        'with_noise': False,
        'noise_attempt': 0,
        'success': success,
        'wall_config': wall_config,
        'style': FLAGS.style,
    }
    if error is not None:
        metadata['error'] = str(error)
    return metadata


def _check_modes_against_wall(task_env, handle_pos, handle_ori, canonical_params,
                              phase_steps, wall_config):
    """Generate the EE trajectory of every mode and test it against the wall.

    A mode whose generation or collision check raises is reported and left
    out of the returned list (best effort, so one bad mode does not abort the
    whole run).

    Returns:
        list[dict]: one dict per evaluated mode with keys ``mode_idx``,
        ``positions``, ``collision``, ``collision_idx``, ``cp_params`` and
        ``cp_reach``.
    """
    all_results = []

    for mode_idx, cp_params in enumerate(canonical_params):
        print(f"\nMode {mode_idx}:")

        try:
            positions, _waypoints, cp_reach, phase_indices = generate_ee_trajectory_only(
                task_env, handle_pos, handle_ori, mode_idx, canonical_params,
                FLAGS.control_point_radius, phase_steps, FLAGS.drawer_variation
            )

            # phase_indices maps phase name -> (start_idx, end_idx); fall back
            # to the whole trajectory when a phase is missing.
            reach_range = phase_indices.get("reach", (0, len(positions)))
            push_range = phase_indices.get("push", (reach_range[1], len(positions)))
            reach_end = reach_range[1]
            print(f"  Trajectory: {len(positions)} steps total, reach=[{reach_range[0]},{reach_range[1]}], push=[{push_range[0]},{push_range[1]}]")
            print(f"  Start pos: [{positions[0][0]:.3f}, {positions[0][1]:.3f}, {positions[0][2]:.3f}]")
            print(f"  Handle pos (reach end): [{positions[reach_end-1][0]:.3f}, {positions[reach_end-1][1]:.3f}, {positions[reach_end-1][2]:.3f}]")
            print(f"  End pos (push end): [{positions[-1][0]:.3f}, {positions[-1][1]:.3f}, {positions[-1][2]:.3f}]")

            if FLAGS.check_full_arm:
                # Full-arm check also reports which link hit the wall.
                collision, collision_idx, collision_link = check_full_arm_trajectory_wall_collision(
                    task_env, positions, wall_config, debug=True)
            else:
                # EE-only check (default).
                collision, collision_idx = check_ee_trajectory_wall_collision(
                    positions, wall_config, debug=True)
                collision_link = "tip" if collision else None

            angle_deg = np.degrees(cp_params[0])
            dist_frac = cp_params[1]

            if collision:
                # A collision index before the end of the reach range is
                # attributed to the reach phase, everything else to push.
                phase_name = "reach" if collision_idx < reach_range[1] else "push"
                link_info = f" (link: {collision_link})" if FLAGS.check_full_arm else ""
                print(f"  Angle: {angle_deg:.0f} deg, Dist: {dist_frac:.1f} -> WALL COLLISION at step {collision_idx} (in {phase_name} phase){link_info}")
            else:
                print(f"  Angle: {angle_deg:.0f} deg, Dist: {dist_frac:.1f} -> SUCCESS")

            all_results.append({
                "mode_idx": mode_idx,
                "positions": positions,
                "collision": collision,
                "collision_idx": collision_idx,
                "cp_params": cp_params,
                "cp_reach": cp_reach,
            })

        except Exception as e:
            print(f"  Mode {mode_idx} failed: {e}")
            continue

    return all_results


def _save_successful_demos(task_env, successful_results, canonical_params, phase_steps,
                           handle_pos, handle_ori, wall_config):
    """Execute every collision-free mode and write one episode directory each.

    Episodes go to ``<save_path>/<style_name>/episodes/episode<i>`` and hold
    ``low_dim_obs.pkl`` (only when execution succeeded), ``ee_trajectory.npy``
    and ``metadata.npy``. A ``summary.npy`` is written next to ``episodes``.
    If executing a trajectory raises, the planned EE positions and a metadata
    dict with ``success=False`` are saved instead.
    """
    style_name = f"style{FLAGS.style}" if FLAGS.style in WALL_STYLES else "default"
    demos_path = os.path.join(FLAGS.save_path, style_name, "episodes")
    os.makedirs(demos_path, exist_ok=True)

    print(f"\nSaving {len(successful_results)} successful demos to {demos_path}")
    print("Executing trajectories to collect full state/action data...")

    for idx, result in enumerate(successful_results):
        episode_path = os.path.join(demos_path, f"episode{idx}")
        os.makedirs(episode_path, exist_ok=True)

        # Restore cabinet orientation and drawer opening before each demo.
        fix_cabinet_orientation(task_env)
        set_drawer_open(task_env, FLAGS.drawer_variation, FLAGS.drawer_open_amount)

        # Pass a one-row parameter array so cp_idx=0 selects this mode.
        mode_idx = result["mode_idx"]
        temp_params = np.array([canonical_params[mode_idx]])

        try:
            demo, traj_metadata = generate_push_trajectory(
                task_env,
                start_pos=np.zeros(3),  # Will use HOME position
                handle_pos=handle_pos,
                handle_ori=handle_ori,
                cp_idx=0,
                canonical_params=temp_params,
                control_point_radius=FLAGS.control_point_radius,
                waypoint_params=None,
                phase_steps=phase_steps,
                steps_per_point=1,
                target_drawer_idx=FLAGS.drawer_variation,
            )

            # Full observations - needed for BC finetuning.
            with open(os.path.join(episode_path, "low_dim_obs.pkl"), "wb") as f:
                pickle.dump(demo, f)

            # Prefer the executed trace; fall back to the planned positions.
            ee_trace = traj_metadata.get("trace")
            if ee_trace is None:
                ee_trace = result["positions"]
            np.save(os.path.join(episode_path, "ee_trajectory.npy"), ee_trace)

            angle, dist_frac, _pos_frac = result["cp_params"]
            metadata = _build_episode_metadata(
                result, handle_pos, handle_ori, wall_config, success=True)
            np.save(os.path.join(episode_path, "metadata.npy"), metadata)

            angle_deg = np.degrees(angle)
            print(f"  Episode {idx}: Mode {result['mode_idx']} (angle={angle_deg:.0f}°, dist={dist_frac:.1f}), {len(demo)} steps")

        except Exception as e:
            print(f"  Episode {idx}: Mode {result['mode_idx']} FAILED: {e}")
            # Still save EE trajectory and metadata even if execution failed.
            np.save(os.path.join(episode_path, "ee_trajectory.npy"), result["positions"])
            metadata = _build_episode_metadata(
                result, handle_pos, handle_ori, wall_config, success=False, error=e)
            np.save(os.path.join(episode_path, "metadata.npy"), metadata)

    summary = {
        'style': FLAGS.style,
        'wall_config': wall_config,
        'num_successful': len(successful_results),
        'successful_modes': [r["mode_idx"] for r in successful_results],
        'handle_pos': handle_pos.tolist(),
        'handle_ori': handle_ori.tolist(),
    }
    summary_path = os.path.join(FLAGS.save_path, style_name, "summary.npy")
    np.save(summary_path, summary)
    print(f"Saved summary to {summary_path}")


def _run_wall_visualization(rlbench_env, wall_config):
    """Run the whole pipeline on an already-launched environment.

    Sets up the ``close_drawer`` task, captures a front-camera frame with the
    wall and the robot at HOME, evaluates all trajectory modes against the
    wall, writes both visualizations and ``wall_config.npy``, saves demos for
    the collision-free modes and prints the interactive tuning hints when
    ``--interactive`` is set. Does not shut the environment down.
    """
    task_class = task_file_to_task_class("close_drawer")
    task_env = rlbench_env.get_task(task_class)
    task_env.set_variation(FLAGS.drawer_variation)

    # Initialize the scene.
    task_env.reset()
    fix_cabinet_orientation(task_env)
    set_drawer_open(task_env, FLAGS.drawer_variation, FLAGS.drawer_open_amount)

    handle_pos, handle_ori = get_drawer_handle_position(task_env, FLAGS.drawer_variation)
    print(f"Handle position: {handle_pos}")

    # Setup camera.
    front_cam = task_env._scene._cam_front
    front_cam.set_position(CAMERA_POSITION)
    front_cam.set_orientation(CAMERA_ORIENTATION)

    # Move robot to HOME position (so EE is at trajectory start point).
    robot = task_env._scene.robot
    robot.arm.set_joint_positions(HOME_JOINTS, disable_dynamics=True)
    robot.arm.set_joint_target_velocities([0] * 7)
    for _ in range(20):
        task_env._scene.pyrep.step()

    # Print positions for reference.
    home_ee_pos = np.array(robot.arm.get_tip().get_position())
    print(f"HOME EE position: [{home_ee_pos[0]:.4f}, {home_ee_pos[1]:.4f}, {home_ee_pos[2]:.4f}]")
    print(f"Handle position:  [{handle_pos[0]:.4f}, {handle_pos[1]:.4f}, {handle_pos[2]:.4f}]")
    print(f"Wall config: Y={wall_config['wall_y']:.4f}, X=[{wall_config['wall_min_x']:.4f}, {wall_config['wall_max_x']:.4f}]")

    # Create wall in scene for visualization; the handle is kept referenced
    # for the rest of the run although nothing below reads it.
    _wall = create_wall(task_env, wall_config)

    # Step to render wall and robot at HOME.
    for _ in range(10):
        task_env._scene.pyrep.step()

    # Capture first frame with wall and robot at HOME position.
    first_frame = task_env._scene.get_observation().front_rgb.copy()

    # Generate canonical control points (one parameter tuple per mode).
    canonical_params = generate_canonical_control_point_params(FLAGS.num_modes)
    phase_steps = DEFAULT_PHASE_STEPS.copy()

    # For style 3, add an extra mode: angle=315°, dist=1.0, pos_frac=0.5
    if FLAGS.style == 3:
        extra_mode = (np.radians(315), 1.0, 0.5)  # (angle, dist_frac, pos_frac)
        canonical_params = list(canonical_params) + [extra_mode]
        print("\nStyle 3: Adding extra mode (angle=315°, dist=1.0, pos_frac=0.5)")

    num_modes = len(canonical_params)
    print(f"\nGenerating {num_modes} trajectory modes...")
    print(f"Collision check mode: {'Full arm (all links)' if FLAGS.check_full_arm else 'EE only (default)'}")
    print(f"{'='*70}")

    all_results = _check_modes_against_wall(
        task_env, handle_pos, handle_ori, canonical_params, phase_steps, wall_config)

    print(f"\n{'='*70}")
    n_success = sum(1 for r in all_results if not r["collision"])
    n_fail = sum(1 for r in all_results if r["collision"])
    print(f"Results: {n_success}/{len(all_results)} success, {n_fail}/{len(all_results)} wall collision")

    # Create visualization on first frame.
    vis_frame = visualize_on_frame(first_frame, all_results, wall_config, front_cam, all_results,
                                   style=FLAGS.style, num_base_modes=FLAGS.num_modes)

    # Save visualization (import cv2 here to avoid Qt conflicts).
    import cv2
    vis_path = os.path.join(FLAGS.save_path, 'trajectories_on_frame.png')
    cv2.imwrite(vis_path, cv2.cvtColor(vis_frame, cv2.COLOR_RGB2BGR))
    print(f"\nSaved frame visualization to {vis_path}")

    # Create 3D plot.
    visualize_3d_plot(all_results, wall_config, FLAGS.save_path,
                      style=FLAGS.style, num_base_modes=FLAGS.num_modes)

    # Save wall config.
    np.save(os.path.join(FLAGS.save_path, 'wall_config.npy'), wall_config)

    # Save successful demos in same format as dataset_generator_fixed_endpoints.py
    successful_results = [r for r in all_results if not r["collision"]]
    if successful_results:
        _save_successful_demos(task_env, successful_results, canonical_params, phase_steps,
                               handle_pos, handle_ori, wall_config)
    else:
        print("\nNo successful demos to save.")

    # Interactive mode: print reference positions for tuning the wall.
    if FLAGS.interactive:
        print("\n" + "="*70)
        print("INTERACTIVE MODE - Key positions for wall tuning:")
        print("="*70)
        print("Robot HOME EE position: ~ [0.28, 0.0, 0.95]")
        print(f"Drawer handle position: {handle_pos}")
        print(f"Current wall Y: {wall_config['wall_y']}")
        print(f"\nSuggested wall Y range: [{handle_pos[1]:.2f}, 0.0]")
        print(f"  - Closer to {handle_pos[1]:.2f}: blocks more trajectories")
        print("  - Closer to 0.0: allows more trajectories")
        print("\nTry adjusting with:")
        print("  python visualize_wall_trajectories.py --wall_y=-0.25")
        print("  python visualize_wall_trajectories.py --wall_y=-0.15")


def main(argv):
    """Script entry point: evaluate trajectory modes against a wall and save results.

    Builds the wall configuration from the flags, launches a headless RLBench
    environment, runs the visualization/demo pipeline and always shuts the
    environment down afterwards, even when the pipeline raises.

    Args:
        argv: positional command-line arguments handed over by ``app.run``;
            not used.
    """
    del argv  # Unused.

    print(f"{'='*70}")
    print("TRAJECTORY + WALL VISUALIZATION")
    print(f"{'='*70}")

    wall_config = _wall_config_from_flags()

    print("Wall Configuration:")
    print(f"  Y position: {wall_config['wall_y']}")
    print(f"  X bounds: [{wall_config['wall_min_x']}, {wall_config['wall_max_x']}]")
    print(f"  Z bounds: [{wall_config['wall_min_z']}, {wall_config['wall_max_z']}]")
    if wall_config.get("opening"):
        print(f"  Opening: X=[{wall_config['opening']['min_x']}, {wall_config['opening']['max_x']}], "
              f"Z=[{wall_config['opening']['min_z']}, {wall_config['opening']['max_z']}]")
    print()

    # Create save directory.
    os.makedirs(FLAGS.save_path, exist_ok=True)

    rlbench_env = _launch_headless_environment()
    try:
        _run_wall_visualization(rlbench_env, wall_config)
    finally:
        # Shut down on every exit path so a failure mid-run does not leave
        # the launched environment behind.
        rlbench_env.shutdown()

    print("\nDone!")


if __name__ == "__main__":
    # Select the "spawn" start method (overriding any earlier choice) before
    # absl parses the flags and hands control to main().
    from multiprocessing import set_start_method
    set_start_method("spawn", force=True)
    app.run(main)
