"""
Dataset generator for close drawer task WITH WALL OBSTACLE.

This extends dataset_generator_fixed_endpoints.py to add:
  - A 2D wall between robot and drawer handle
  - Collision detection for all robot arm links
  - Trajectory stops (and episode fails) when robot touches wall
  - Wall visualization in video frames

Design:
  - Wall is a plane at y = wall_y (between robot at y≈0 and handle at y≈-0.35)
  - Wall has configurable opening to allow some trajectories through
  - Trajectories that hit the wall are marked as failed
  - Only successful (non-colliding) demos are saved as training episodes
    (pass --save_failed to also keep colliding episodes under failed_episodes/)

Usage:
  python dataset_generator_with_wall.py --num_modes=8 --demos_per_mode=10 --save_video

  # With custom wall position:
  python dataset_generator_with_wall.py --wall_y=-0.15 --save_video

  # With opening in wall:
  python dataset_generator_with_wall.py --wall_opening_min_x=0.2 --wall_opening_max_x=0.4 --save_video
"""
import sys
import os

# Add parent directory for imports
sys.path.insert(0, os.path.dirname(__file__))

from multiprocessing import Process, Manager
from pyrep.const import RenderMode
from pyrep.objects.dummy import Dummy
from pyrep.errors import ConfigurationPathError, IKError

from rlbench import ObservationConfig
from rlbench.action_modes.action_mode import MoveArmThenGripper
from rlbench.action_modes.arm_action_modes import JointVelocity, JointPosition
from rlbench.action_modes.gripper_action_modes import Discrete
from rlbench.backend.utils import task_file_to_task_class
from rlbench.environment import Environment
import rlbench.backend.task as task

import pickle
import numpy as np
import imageio

from absl import app
from absl import flags

from types import SimpleNamespace

from push_utils import (
    DEFAULT_PHASE_STEPS,
    get_drawer_handle_position,
    set_drawer_open,
    fix_cabinet_orientation,
    reset_robot_to_default,
    compute_push_waypoints,
    generate_canonical_control_point_params,
    compute_control_point_from_params,
    check_task_success,
    prefilter_control_points_fast,
    move_robot_to_start,
    generate_phase_positions,
    parabola3D,
    HOME_JOINTS,
    lock_other_drawers,
)

from close_drawer_config import (
    CAMERA_POSITION,
    CAMERA_ORIENTATION,
    CAMERA_IMAGE_SIZE,
    DRAWER_VARIATION as DEFAULT_DRAWER_VARIATION,
    DRAWER_OPEN_AMOUNT as DEFAULT_DRAWER_OPEN_AMOUNT,
    CONTROL_POINT_RADIUS as DEFAULT_CONTROL_POINT_RADIUS,
    CP_ANGLE_NOISE as DEFAULT_CP_ANGLE_NOISE,
    CP_DIST_NOISE as DEFAULT_CP_DIST_NOISE,
    NUM_MODES as DEFAULT_NUM_MODES,
    DEMOS_PER_MODE as DEFAULT_DEMOS_PER_MODE,
)

from wall_collision import (
    DEFAULT_WALL_CONFIG,
    create_wall,
    check_wall_collision,
    WallCollisionTracker,
    overlay_wall_on_frames,
    project_wall_to_image,
)


FLAGS = flags.FLAGS

# Output root: $DPPO_DATA_DIR (falling back to /tmp/dppo_data when unset)
# joined with this dataset's sub-folder.
DEFAULT_SAVE_PATH = os.path.join(
    os.environ.get("DPPO_DATA_DIR", "/tmp/dppo_data"), "close_drawer_with_wall"
)

# ---------------------------------------------------------------------------
# Output / general behaviour
# NOTE(review): `tasks` and `joint_action_mode` are not read by main() in this
# file; they appear to be kept for command-line compatibility — confirm.
# ---------------------------------------------------------------------------
flags.DEFINE_string(name="save_path", default=DEFAULT_SAVE_PATH,
                    help="Where to save the demos.")
flags.DEFINE_list(name="tasks", default=["close_drawer"],
                  help="The tasks to collect.")
flags.DEFINE_string(name="joint_action_mode", default="abs",
                    help='Joint action mode: "vel", "abs", or "delta".')
flags.DEFINE_bool(name="save_video", default=False,
                  help="Whether to save video recordings with wall overlay.")

# ---------------------------------------------------------------------------
# Control-point sampling (defaults come from close_drawer_config)
# ---------------------------------------------------------------------------
flags.DEFINE_integer(name="num_modes", default=DEFAULT_NUM_MODES,
                     help="Number of different control point modes.")
flags.DEFINE_integer(name="demos_per_mode", default=DEFAULT_DEMOS_PER_MODE,
                     help="Number of noisy demos per mode.")
flags.DEFINE_float(name="control_point_radius", default=DEFAULT_CONTROL_POINT_RADIUS,
                   help="Control point radius (m).")
flags.DEFINE_float(name="cp_angle_noise", default=DEFAULT_CP_ANGLE_NOISE,
                   help="Angle noise (radians).")
flags.DEFINE_float(name="cp_dist_noise", default=DEFAULT_CP_DIST_NOISE,
                   help="Distance noise fraction.")

# ---------------------------------------------------------------------------
# Drawer setup
# ---------------------------------------------------------------------------
flags.DEFINE_integer(name="drawer_variation", default=DEFAULT_DRAWER_VARIATION,
                     help="Drawer: 0=bottom, 1=middle, 2=top")
flags.DEFINE_float(name="drawer_open_amount", default=DEFAULT_DRAWER_OPEN_AMOUNT,
                   help="How far drawer is open")

# ---------------------------------------------------------------------------
# Wall geometry (defaults come from wall_collision.DEFAULT_WALL_CONFIG)
# ---------------------------------------------------------------------------
flags.DEFINE_float(name="wall_y", default=DEFAULT_WALL_CONFIG["wall_y"],
                   help="Wall Y position (plane equation)")
flags.DEFINE_float(name="wall_min_x", default=DEFAULT_WALL_CONFIG["wall_min_x"],
                   help="Wall min X bound")
flags.DEFINE_float(name="wall_max_x", default=DEFAULT_WALL_CONFIG["wall_max_x"],
                   help="Wall max X bound")
flags.DEFINE_float(name="wall_min_z", default=DEFAULT_WALL_CONFIG["wall_min_z"],
                   help="Wall min Z bound")
flags.DEFINE_float(name="wall_max_z", default=DEFAULT_WALL_CONFIG["wall_max_z"],
                   help="Wall max Z bound")

# Optional opening in the wall; main() treats "all four are 0" as no opening.
flags.DEFINE_float(name="wall_opening_min_x", default=0.0,
                   help="Opening min X (0 to disable)")
flags.DEFINE_float(name="wall_opening_max_x", default=0.0,
                   help="Opening max X (0 to disable)")
flags.DEFINE_float(name="wall_opening_min_z", default=0.0,
                   help="Opening min Z (0 to disable)")
flags.DEFINE_float(name="wall_opening_max_z", default=0.0,
                   help="Opening max Z (0 to disable)")

# ---------------------------------------------------------------------------
# Behaviour
# ---------------------------------------------------------------------------
flags.DEFINE_bool(name="save_failed", default=False,
                  help="Whether to save failed (wall collision) episodes")


def check_and_make(dir_path):
    """Create ``dir_path`` (and any missing parents) if it does not exist yet.

    ``exist_ok=True`` makes this idempotent and avoids the race of a separate
    existence check: if another process creates the directory in between,
    no ``FileExistsError`` is raised.

    Raises:
        FileExistsError: if ``dir_path`` exists but is not a directory.
    """
    os.makedirs(dir_path, exist_ok=True)


# ============================================================================
# Video Saving with Wall and Trajectory Overlay
# ============================================================================

def project_world_to_image(points_3d, camera, image_size):
    """Project 3D world points to 2D pixel coordinates with a pinhole model.

    Args:
        points_3d: iterable of 3-element world-frame points.
        camera: object providing ``get_position()``, ``get_matrix()`` (whose
            upper-left 3x3 block is the camera rotation) and, optionally,
            ``get_perspective_angle()`` in degrees. If the angle cannot be
            obtained, a 60 degree field of view is assumed.
        image_size: ``(height, width)`` of the target image.

    Returns:
        list with one entry per input point: an ``(u, v)`` tuple of ints
        (truncated toward zero), or ``None`` when the point's camera-frame
        depth is not greater than 0.01.
    """
    cam_pos = np.array(camera.get_position())
    # Transposed rotation maps world-frame offsets into the camera frame.
    world_to_cam = camera.get_matrix()[:3, :3].T

    try:
        fov_deg = camera.get_perspective_angle()
        fov = fov_deg * np.pi / 180.0
    except Exception:
        # Narrowed from a bare `except:` so Ctrl-C / SystemExit still propagate.
        fov = 60.0 * np.pi / 180.0

    # NOTE(review): focal length is derived from the image height
    # (image_size[0]); for non-square images confirm which axis the camera's
    # perspective angle refers to.
    f = image_size[0] / (2.0 * np.tan(fov / 2.0))
    cx, cy = image_size[1] / 2.0, image_size[0] / 2.0

    projected = []
    for point in points_3d:
        x_cam, y_cam, z_cam = world_to_cam @ (np.array(point) - cam_pos)
        if z_cam > 0.01:
            u = cx - f * x_cam / z_cam
            v = cy - f * y_cam / z_cam
            projected.append((int(u), int(v)))
        else:
            # Behind (or too close to) the image plane: not projectable.
            projected.append(None)

    return projected


def overlay_trajectory_and_wall_on_frames(frames, ee_trace, camera, wall_config,
                                           phase_labels=None, collision_frame=None):
    """
    Overlay both EE trajectory and wall on video frames.

    ``save_demo`` passes ``obs.front_rgb`` images here and writes the result
    with imageio, so the frames are RGB. cv2 drawing calls copy a colour tuple
    straight into the array channels, so every colour below is (R, G, B);
    BGR tuples would make the "red" collision highlights render blue.

    Args:
        frames: list of H x W x 3 images; returned as-is when empty.
        ee_trace: sequence of 3D tip positions; frame ``k`` shows the path up
            to point ``min(k, len(ee_trace) - 1)``.
        camera: camera the frames were rendered from (used for projection).
        wall_config: wall description forwarded to ``project_wall_to_image``.
        phase_labels: optional per-trace-point phase names; the segment ending
            at point ``i`` is green for "reach", orange for "push", else white.
        collision_frame: optional index; from this frame on the wall is drawn
            more opaque, segments ending at or after it are red, and a
            "WALL COLLISION" banner is shown.

    Returns:
        A new list of annotated frames (input frames are not modified).
    """
    import cv2

    if len(frames) == 0:
        return frames

    image_size = frames[0].shape[:2]  # (height, width)
    height, width = image_size
    projected = project_world_to_image(ee_trace, camera, image_size) if len(ee_trace) > 0 else []

    # Segment / marker colours (RGB).
    phase_colors = {
        "reach": (0, 255, 0),      # Green
        "push": (255, 165, 0),     # Orange
    }
    default_color = (255, 255, 255)
    collision_color = (255, 0, 0)  # Red

    def _inside(pt):
        # cv2 would clip, but off-screen points are skipped explicitly.
        return 0 <= pt[0] < width and 0 <= pt[1] < height

    # Camera pose and wall geometry do not change while frames are annotated,
    # so the wall polygon is projected once instead of once per frame.
    wall_corners = project_wall_to_image(wall_config, camera, image_size)
    wall_pts = None
    if (wall_corners is not None and len(wall_corners) > 0
            and all(c is not None for c in wall_corners)):
        wall_pts = np.array(wall_corners, dtype=np.int32)

    frames_with_overlay = []

    for frame_idx, frame in enumerate(frames):
        frame_overlay = frame.copy()
        trace_idx = min(frame_idx, len(projected) - 1) if projected else 0
        is_collision_frame = collision_frame is not None and frame_idx >= collision_frame

        # Draw wall first (semi-transparent); it becomes more opaque on collision.
        if wall_pts is not None:
            if is_collision_frame:
                wall_color, wall_alpha = (200, 0, 0), 0.5      # Red
            else:
                wall_color, wall_alpha = (255, 100, 100), 0.3  # Light red
            overlay = frame_overlay.copy()
            cv2.fillPoly(overlay, [wall_pts], wall_color)
            cv2.addWeighted(overlay, wall_alpha, frame_overlay, 1 - wall_alpha, 0, frame_overlay)
            cv2.polylines(frame_overlay, [wall_pts], True, (150, 50, 50), 2, cv2.LINE_AA)

        # Draw trajectory segments up to the current frame.
        if projected:
            for i in range(1, trace_idx + 1):
                p1 = projected[i - 1]
                p2 = projected[i]
                if p1 is None or p2 is None or not (_inside(p1) and _inside(p2)):
                    continue

                if collision_frame is not None and i >= collision_frame:
                    color = collision_color
                elif phase_labels is not None and i < len(phase_labels):
                    color = phase_colors.get(phase_labels[i], default_color)
                else:
                    color = default_color

                cv2.line(frame_overlay, p1, p2, color, thickness=2, lineType=cv2.LINE_AA)

            # Current end-effector position marker.
            curr_pt = projected[trace_idx] if trace_idx < len(projected) else None
            if curr_pt is not None and _inside(curr_pt):
                marker_color = collision_color if is_collision_frame else (255, 255, 255)
                cv2.circle(frame_overlay, curr_pt, 6, marker_color, -1)
                cv2.circle(frame_overlay, curr_pt, 6, (0, 0, 0), 2)

        if is_collision_frame:
            cv2.putText(frame_overlay, "WALL COLLISION", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)

        frames_with_overlay.append(frame_overlay)

    return frames_with_overlay


def save_demo(demo, episode_path, save_video=False, ee_trace=None, camera=None,
              phase_labels=None, wall_config=None, collision_frame=None):
    """Pickle ``demo`` into ``episode_path`` and optionally write a video.

    ``low_dim_obs.pkl`` is always written. With ``save_video`` set, the
    ``front_rgb`` image of every observation that has a non-None one is
    collected; if any exist they are annotated with the wall and trajectory
    (only when ``ee_trace``, ``camera`` and ``wall_config`` are all given)
    and written to ``video.mp4`` at 30 fps.
    """
    pickle_path = os.path.join(episode_path, "low_dim_obs.pkl")
    with open(pickle_path, "wb") as handle:
        pickle.dump(demo, handle)

    if not save_video:
        return

    frames = [
        obs.front_rgb
        for obs in demo
        if hasattr(obs, 'front_rgb') and obs.front_rgb is not None
    ]
    if not frames:
        return

    can_annotate = ee_trace is not None and camera is not None and wall_config is not None
    if can_annotate:
        frames = overlay_trajectory_and_wall_on_frames(
            frames, ee_trace, camera, wall_config, phase_labels, collision_frame
        )

    video_path = os.path.join(episode_path, "video.mp4")
    imageio.mimwrite(video_path, frames, fps=30, codec='libx264', quality=8)
    print(f"\t  Video saved: {video_path} ({len(frames)} frames)")


def sample_noisy_control_point_params(base_params, angle_noise, dist_noise):
    """Jitter ``(angle, dist_frac, pos_frac)`` control-point parameters.

    The angle gets uniform noise in ``[-angle_noise, angle_noise]``; the
    distance fraction gets uniform noise in ``[-dist_noise, dist_noise]`` and
    is then clipped to ``[0.1, 1.0]``. ``pos_frac`` is passed through
    unchanged. Uses the global NumPy RNG (angle sample drawn first).
    """
    base_angle, base_dist, base_pos = base_params
    angle_jitter = np.random.uniform(-angle_noise, angle_noise)
    dist_jitter = np.random.uniform(-dist_noise, dist_noise)
    clipped_dist = np.clip(base_dist + dist_jitter, 0.1, 1.0)
    return (base_angle + angle_jitter, clipped_dist, base_pos)


def generate_push_trajectory_with_wall(
    task_env,
    handle_pos,
    handle_ori,
    cp_idx,
    canonical_params,
    control_point_radius,
    wall_config,
    waypoint_params=None,
    phase_steps=None,
    steps_per_point=5,
    target_drawer_idx=2,
):
    """
    Generate push/close drawer trajectory with wall collision checking.

    The robot is moved to its start pose, push waypoints and a reach control
    point are computed, and the per-step target positions are tracked with
    Jacobian IK. The wall tracker is queried on every recorded state: the
    start state and the state produced by each executed step, including the
    final one. Execution stops at the first reported collision; the colliding
    observation is kept as the last element of ``demo``.

    Args:
        task_env: RLBench task environment (its ``_scene`` is driven directly).
        handle_pos, handle_ori: drawer handle pose used to build waypoints.
        cp_idx: index into ``canonical_params`` selecting the control point.
        canonical_params: indexable of ``(angle, dist_frac, pos_frac)``.
        control_point_radius: forwarded to ``compute_control_point_from_params``.
        wall_config: wall description forwarded to ``WallCollisionTracker``.
        waypoint_params: optional overrides for ``compute_push_waypoints``.
        phase_steps: per-phase step counts (default ``DEFAULT_PHASE_STEPS``).
        steps_per_point: simulation steps per target (doubled in "push").
        target_drawer_idx: forwarded to ``lock_other_drawers`` after every
            simulation step.

    Returns:
        demo: list of observations (element 0 is the start state)
        metadata: dict with trajectory info
        wall_collision: bool, True if wall was hit
        collision_step: int or None; index into ``demo`` (and into
            ``metadata["trace"]``) of the observation at which the collision
            was detected, i.e. ``len(demo) - 1`` when a collision occurred.

    Raises:
        RuntimeError: message contains "IK failure" when IK fails at any step.
    """
    if phase_steps is None:
        phase_steps = DEFAULT_PHASE_STEPS.copy()

    robot = task_env._scene.robot
    tip = robot.arm.get_tip()
    gripper = robot.gripper

    # Move to HOME position
    actual_start_pos, actual_ori = move_robot_to_start(
        task_env, np.zeros(3), handle_ori, gripper_open=True
    )

    # Compute waypoints
    waypoints = compute_push_waypoints(actual_start_pos, handle_pos, handle_ori, waypoint_params)

    # Compute control point for reach phase
    angle, dist_frac, pos_frac = canonical_params[cp_idx]
    cp_reach = compute_control_point_from_params(
        waypoints["start"], waypoints["handle"],
        control_point_radius, angle, dist_frac, pos_frac
    )

    print(f"    cp_reach: [{cp_reach[0]:.4f}, {cp_reach[1]:.4f}, {cp_reach[2]:.4f}] "
          f"(angle={np.degrees(angle):.1f}deg)")

    # Per-step target tip positions and phase labels; index 0 is the start.
    positions, phase_indices, phase_labels = generate_phase_positions(
        waypoints, phase_steps, cp_reach
    )
    num_positions = len(positions)

    # NOTE(review): assumes check_and_update() inspects the robot's current
    # configuration and returns True on wall contact — confirm in wall_collision.
    wall_tracker = WallCollisionTracker(task_env, wall_config)

    demo = []
    trace = []
    trace_phase_labels = []
    gripper_states = []

    # Initial observation
    demo.append(task_env._scene.get_observation())
    trace.append(tip.get_position().copy())
    trace_phase_labels.append("reach")
    gripper_states.append(1.0)

    prev_joints = list(robot.arm.get_joint_positions())
    wall_collision = False
    collision_step = None

    for i in range(num_positions):
        if i > 0:
            target_pos_step = positions[i]
            phase = phase_labels[i]

            try:
                # Each step restarts from the previous IK solution.
                robot.arm.set_joint_positions(prev_joints, disable_dynamics=True)

                if phase == "reach":
                    # Orientation-constrained IK, applied directly with dynamics disabled.
                    joint_positions = robot.arm.solve_ik_via_jacobian(target_pos_step, euler=actual_ori)
                    robot.arm.set_joint_positions(joint_positions, disable_dynamics=True)
                else:
                    # Position-only IK, tracked through joint position targets.
                    joint_positions = robot.arm.solve_ik_via_jacobian(target_pos_step)
                    robot.arm.set_joint_target_positions(joint_positions)

                robot.arm.set_joint_target_velocities([0] * 7)
                gripper.actuate(1.0, 0.2)

                # "push" steps get twice the simulation time.
                sim_steps = steps_per_point * 2 if phase == "push" else steps_per_point
                for _ in range(sim_steps):
                    task_env._scene.pyrep.step()
                    task_env._scene.task.step()
                    lock_other_drawers(target_drawer_idx)

                obs = task_env._scene.get_observation()
                if not hasattr(obs, 'misc'):
                    obs.misc = {}
                obs.misc['joint_position_action'] = np.concatenate([joint_positions, [1.0]])
                demo.append(obs)

                trace.append(tip.get_position().copy())
                trace_phase_labels.append(phase)
                gripper_states.append(1.0)
                prev_joints = list(joint_positions)

            except (IKError, ConfigurationPathError) as e:
                print(f"    IK failed at step {i}/{num_positions}")
                # Callers recognise this case by the "IK failure" substring.
                raise RuntimeError(f"IK failure at step {i}/{num_positions}") from e

        # Check the state just recorded (the start state when i == 0, else the
        # state produced by step i), so a collision caused by the last step is
        # caught too. The index reported is that observation's position in demo.
        if wall_tracker.check_and_update():
            wall_collision = True
            collision_step = len(demo) - 1
            print(f"\t  WALL COLLISION at step {collision_step}! Link: {wall_tracker.collision_link}")
            break

    # Build metadata
    metadata = {
        "waypoints_list": {k: v.tolist() for k, v in waypoints.items()},
        "cp_reach": cp_reach.tolist(),
        "phase_indices": phase_indices,
        "phase_steps": phase_steps,
        # Any IK failure raises above, so a returned trajectory never has one.
        "ik_failures": 0,
        "actual_orientation": actual_ori.tolist(),
        "trace": np.array(trace),
        "phase_labels": trace_phase_labels,
        "gripper_states": gripper_states,
        "wall_collision": wall_collision,
        "collision_step": collision_step,
        "wall_config": wall_config,
    }

    return demo, metadata, wall_collision, collision_step


def run_training_split_with_wall(task_env, cfg, variation_path, canonical_params,
                                  selected_indices, handle_pos, handle_ori, wall_config):
    """Generate training demos with wall collision checking.

    For each mode in ``selected_indices[:cfg.num_modes]`` this tries to write
    ``cfg.demos_per_mode`` episodes under ``<variation_path>/train/episodes``:

    * demo 0 uses the mode's base control-point params with a single attempt.
      A wall collision or an IK failure (a RuntimeError whose message contains
      "IK failure") rejects the whole mode, and its remaining demos are
      skipped. Any other exception only skips demo 0.
    * Every later demo samples noisy params and retries up to 10 times after a
      wall collision, IK failure or other exception. If every attempt fails,
      no episode is written for that demo.

    Every attempt resets the task, re-opens the drawer and re-reads the handle
    pose. The ``handle_pos`` / ``handle_ori`` arguments are therefore not used
    in this function. When ``cfg.save_failed`` is set, wall-colliding attempts
    are also saved under ``train/failed_episodes``.

    Also writes ``train/train_metadata.npy`` (list of per-episode metadata
    dicts) and ``wall_config.npy``, and prints summary statistics.

    Returns:
        list of metadata dicts, one per successfully saved episode.
    """
    num_modes = cfg.num_modes
    demos_per_mode = cfg.demos_per_mode
    max_noise_attempts = 10

    print(f"\n{'='*70}")
    print(f"Generating TRAIN split WITH WALL OBSTACLE")
    print(f"  Wall Y: {wall_config['wall_y']:.2f}")
    print(f"  Wall bounds: X=[{wall_config['wall_min_x']:.2f}, {wall_config['wall_max_x']:.2f}], "
          f"Z=[{wall_config['wall_min_z']:.2f}, {wall_config['wall_max_z']:.2f}]")
    if wall_config.get("opening"):
        print(f"  Wall opening: X=[{wall_config['opening']['min_x']:.2f}, {wall_config['opening']['max_x']:.2f}], "
              f"Z=[{wall_config['opening']['min_z']:.2f}, {wall_config['opening']['max_z']:.2f}]")
    print(f"  Modes: {num_modes}, Demos per mode: {demos_per_mode}")
    print(f"{'='*70}\n")

    # Create the wall in the scene
    # NOTE(review): `wall` is never referenced again, and task_env.reset() runs
    # before every attempt below — confirm the wall object survives resets.
    wall = create_wall(task_env, wall_config)

    episodes_path = os.path.join(variation_path, "train", "episodes")
    check_and_make(episodes_path)

    # Also create failed episodes path if saving failed demos
    # (`failed_path` is only defined, and only used, when cfg.save_failed is set.)
    if cfg.save_failed:
        failed_path = os.path.join(variation_path, "train", "failed_episodes")
        check_and_make(failed_path)

    all_metadata = []
    episode_idx = 0
    failed_episode_idx = 0
    rejected_modes = []

    # Camera used to project the wall/trajectory overlay; only needed for videos.
    custom_cam = task_env._scene._cam_front if cfg.save_video else None
    mode_indices = selected_indices[:num_modes]

    # Statistics
    total_attempts = 0
    wall_collisions = 0
    ik_failures = 0

    for mode_idx, cp_idx in enumerate(mode_indices):
        base_params = canonical_params[cp_idx]
        print(f"\n--- Mode {mode_idx}/{num_modes} (CP index: {cp_idx}) ---")
        print(f"    Base params: angle={np.degrees(base_params[0]):.1f}deg, "
              f"dist={base_params[1]:.2f}, pos={base_params[2]:.2f}")

        mode_rejected = False

        for demo_idx in range(demos_per_mode):
            if mode_rejected:
                print(f"  Demo {demo_idx}/{demos_per_mode} - SKIPPED (mode rejected)")
                continue

            print(f"  Demo {demo_idx}/{demos_per_mode} (Episode {episode_idx})")

            # Demo 0: exact base params, one shot. Later demos: noisy params,
            # resampled on each attempt.
            if demo_idx == 0:
                max_attempts = 1
                current_params = base_params
                add_noise = False
            else:
                max_attempts = max_noise_attempts
                current_params = None
                add_noise = True

            attempt = 0
            success_this_demo = False

            while attempt < max_attempts and not success_this_demo:
                attempt += 1
                total_attempts += 1

                if add_noise:
                    current_params = sample_noisy_control_point_params(
                        base_params, cfg.cp_angle_noise, cfg.cp_dist_noise
                    )
                    if attempt > 1:
                        print(f"    Resample attempt {attempt}/{max_attempts}")

                try:
                    # Fresh scene for every attempt.
                    reset_robot_to_default(task_env)
                    task_env.reset()
                    fix_cabinet_orientation(task_env)
                    set_drawer_open(task_env, cfg.drawer_variation, cfg.drawer_open_amount)

                    current_handle_pos, current_handle_ori = get_drawer_handle_position(
                        task_env, cfg.drawer_variation
                    )

                    # Wrap the single param tuple so it is canonical_params[0] below.
                    temp_params = np.array([current_params])

                    demo, traj_metadata, wall_collision, collision_step = generate_push_trajectory_with_wall(
                        task_env,
                        handle_pos=current_handle_pos,
                        handle_ori=current_handle_ori,
                        cp_idx=0,
                        canonical_params=temp_params,
                        control_point_radius=cfg.control_point_radius,
                        wall_config=wall_config,
                        waypoint_params=None,
                        phase_steps=None,
                        steps_per_point=5,
                        target_drawer_idx=cfg.drawer_variation,
                    )

                    if wall_collision:
                        wall_collisions += 1
                        print(f"\t  ✗ Wall collision at step {collision_step}")

                        if cfg.save_failed:
                            # Save failed episode
                            failed_ep_path = os.path.join(failed_path, f"episode{failed_episode_idx}")
                            os.makedirs(failed_ep_path, exist_ok=True)
                            save_demo(
                                demo, failed_ep_path, save_video=cfg.save_video,
                                ee_trace=traj_metadata.get("trace"),
                                camera=custom_cam,
                                phase_labels=traj_metadata.get("phase_labels"),
                                wall_config=wall_config,
                                collision_frame=collision_step,
                            )
                            failed_episode_idx += 1

                        # Wall collision on base params -> reject mode
                        if demo_idx == 0:
                            mode_rejected = True
                            rejected_modes.append(mode_idx)
                            print(f"\t  Mode {mode_idx} REJECTED (base params hit wall)")
                            break
                        else:
                            # Noisy demo hit wall, try another sample
                            # (on the last attempt this silently drops the demo).
                            continue

                    # No wall collision - successful demo
                    task_success = check_task_success(task_env)

                    episode_path = os.path.join(episodes_path, f"episode{episode_idx}")
                    os.makedirs(episode_path, exist_ok=True)
                    save_demo(
                        demo, episode_path, save_video=cfg.save_video,
                        ee_trace=traj_metadata.get("trace"),
                        camera=custom_cam,
                        phase_labels=traj_metadata.get("phase_labels"),
                        wall_config=wall_config,
                        collision_frame=None,
                    )

                    # Per-episode metadata: sampling info plus the trajectory
                    # metadata minus the bulky/duplicated entries.
                    metadata = {
                        'mode': mode_idx,
                        'demo_in_mode': demo_idx,
                        'cp_idx': int(cp_idx),
                        'canonical_cp_params': tuple(current_params),
                        'base_cp_params': tuple(base_params),
                        'handle_pos': current_handle_pos.tolist(),
                        'handle_ori': current_handle_ori.tolist(),
                        'with_noise': add_noise,
                        'noise_attempt': attempt if add_noise else 0,
                        'success': task_success,
                        'wall_collision': False,
                        **{k: v for k, v in traj_metadata.items()
                           if k not in ['trace', 'phase_labels', 'gripper_states', 'waypoints', 'wall_config']},
                    }
                    np.save(os.path.join(episode_path, "metadata.npy"), metadata)

                    ee_trace = traj_metadata.get("trace")
                    if ee_trace is not None:
                        np.save(os.path.join(episode_path, "ee_trajectory.npy"), ee_trace)

                    all_metadata.append(metadata)

                    print(f"\t  Demo length: {len(demo)} steps, success: {task_success}")
                    episode_idx += 1
                    success_this_demo = True

                except Exception as e:
                    # generate_push_trajectory_with_wall reports IK problems as a
                    # RuntimeError whose message contains "IK failure".
                    error_msg = str(e)
                    if "IK failure" in error_msg:
                        ik_failures += 1
                        if demo_idx == 0:
                            print(f"\t  ✗ Mode {mode_idx} REJECTED (IK failure): {error_msg}")
                            mode_rejected = True
                            rejected_modes.append(mode_idx)
                            break
                        else:
                            if attempt < max_attempts:
                                print(f"\t  ✗ IK failure, resampling...")
                            else:
                                print(f"\t  ✗ IK failure after {max_attempts} attempts")
                    else:
                        # Any other error: logged, and the while loop retries if
                        # attempts remain (demo 0 has only one attempt).
                        print(f"\t  Attempt failed: {e}")

    # Save metadata
    np.save(os.path.join(variation_path, "train", "train_metadata.npy"), all_metadata)

    # Save wall config
    np.save(os.path.join(variation_path, "wall_config.npy"), wall_config)

    print(f"\n{'='*70}")
    print(f"TRAINING GENERATION COMPLETE")
    print(f"  Successful episodes: {episode_idx}")
    print(f"  Failed (wall collision): {wall_collisions}")
    print(f"  IK failures: {ik_failures}")
    print(f"  Rejected modes: {len(rejected_modes)} {rejected_modes if rejected_modes else ''}")
    print(f"  Wall collision rate: {wall_collisions / max(total_attempts, 1) * 100:.1f}%")
    print(f"{'='*70}\n")

    return all_metadata


def _wall_config_from_flags():
    """Build the wall configuration dict from the command-line flags.

    If all four ``--wall_opening_*`` flags are 0 there is no opening
    (``"opening": None``). Otherwise, an axis whose min and max are both 0 is
    treated as unconstrained and spans the wall's full extent on that axis.
    As a result, ``--wall_opening_min_x=0.2 --wall_opening_max_x=0.4`` gives a
    full-height slot instead of the degenerate Z range [0, 0].
    """
    wall_config = {
        "wall_y": FLAGS.wall_y,
        "wall_min_x": FLAGS.wall_min_x,
        "wall_max_x": FLAGS.wall_max_x,
        "wall_min_z": FLAGS.wall_min_z,
        "wall_max_z": FLAGS.wall_max_z,
        "wall_thickness": 0.002,
        "wall_color": [1.0, 0.2, 0.2],
        "wall_transparency": 0.6,
        "opening": None,
    }

    min_x, max_x = FLAGS.wall_opening_min_x, FLAGS.wall_opening_max_x
    min_z, max_z = FLAGS.wall_opening_min_z, FLAGS.wall_opening_max_z
    if any(v != 0 for v in (min_x, max_x, min_z, max_z)):
        if min_x == 0 and max_x == 0:
            min_x, max_x = wall_config["wall_min_x"], wall_config["wall_max_x"]
        if min_z == 0 and max_z == 0:
            min_z, max_z = wall_config["wall_min_z"], wall_config["wall_max_z"]
        wall_config["opening"] = {
            "min_x": min_x,
            "max_x": max_x,
            "min_z": min_z,
            "max_z": max_z,
        }
    return wall_config


def main(argv):
    """Generate the close-drawer-with-wall dataset described by the flags.

    Builds the wall config, launches a headless RLBench environment, saves
    the configs under ``<save_path>/variation<drawer_variation>`` and runs
    training-split generation. The simulator is shut down even if
    generation raises.
    """
    del argv  # Unused; absl passes the remaining command-line arguments.

    print(f"{'='*70}")
    print("CLOSE DRAWER DATASET GENERATOR WITH WALL OBSTACLE")
    print(f"{'='*70}")
    print(f"Save path: {FLAGS.save_path}")
    print(f"Modes: {FLAGS.num_modes}, Demos/mode: {FLAGS.demos_per_mode}")
    print(f"Save video: {FLAGS.save_video}")
    print(f"Save failed demos: {FLAGS.save_failed}")
    print(f"{'='*70}\n")

    wall_config = _wall_config_from_flags()

    print(f"Wall Configuration:")
    print(f"  Y position: {wall_config['wall_y']}")
    print(f"  X bounds: [{wall_config['wall_min_x']}, {wall_config['wall_max_x']}]")
    print(f"  Z bounds: [{wall_config['wall_min_z']}, {wall_config['wall_max_z']}]")
    if wall_config["opening"]:
        print(f"  Opening: X=[{wall_config['opening']['min_x']}, {wall_config['opening']['max_x']}], "
              f"Z=[{wall_config['opening']['min_z']}, {wall_config['opening']['max_z']}]")
    print()

    # Generate canonical control points
    canonical_params = generate_canonical_control_point_params(FLAGS.num_modes)
    selected_indices = list(range(len(canonical_params)))

    # Setup environment: low-dim observations, plus front RGB when recording.
    obs_config = ObservationConfig()
    obs_config.set_all(False)
    obs_config.joint_positions = True
    obs_config.joint_velocities = True
    obs_config.gripper_open = True
    obs_config.gripper_pose = True
    obs_config.task_low_dim_state = True

    if FLAGS.save_video:
        obs_config.front_camera.rgb = True
        obs_config.front_camera.image_size = CAMERA_IMAGE_SIZE

    # Absolute action bounds: 7 arm joints followed by the gripper in [0, 1].
    ACT_MIN = np.array([-2.8973, -1.7628, -2.8973, -3.0718,
                        -2.8973, -0.0175, -2.8973, 0.0], dtype=np.float32)
    ACT_RANGE = np.array([5.7946, 3.5256, 5.7946, 3.0020,
                          5.7946, 3.7700, 5.7946, 1.0], dtype=np.float32)

    class CustomMoveArmThenGripper(MoveArmThenGripper):
        def action_bounds(self):
            return (ACT_MIN, ACT_MIN + ACT_RANGE)

    action_mode = CustomMoveArmThenGripper(JointPosition(True), Discrete())
    rlbench_env = Environment(action_mode=action_mode, obs_config=obs_config, headless=True)
    rlbench_env.launch()

    try:
        task_class = task_file_to_task_class("close_drawer")
        task_env = rlbench_env.get_task(task_class)
        task_env.set_variation(FLAGS.drawer_variation)

        task_env.reset()
        fix_cabinet_orientation(task_env)
        set_drawer_open(task_env, FLAGS.drawer_variation, FLAGS.drawer_open_amount)

        handle_pos, handle_ori = get_drawer_handle_position(task_env, FLAGS.drawer_variation)
        print(f"Drawer handle position: {handle_pos}")

        # Create output directory
        variation_path = os.path.join(FLAGS.save_path, f"variation{FLAGS.drawer_variation}")
        check_and_make(variation_path)

        # Setup camera
        front_cam = task_env._scene._cam_front
        front_cam.set_position(CAMERA_POSITION)
        front_cam.set_orientation(CAMERA_ORIENTATION)

        # Save configs
        np.save(os.path.join(variation_path, "canonical_params.npy"), canonical_params)
        np.save(os.path.join(variation_path, "wall_config.npy"), wall_config)

        config = {
            'drawer_variation': FLAGS.drawer_variation,
            'drawer_open_amount': FLAGS.drawer_open_amount,
            'control_point_radius': FLAGS.control_point_radius,
            'num_modes': FLAGS.num_modes,
            'demos_per_mode': FLAGS.demos_per_mode,
            'cp_angle_noise': FLAGS.cp_angle_noise,
            'cp_dist_noise': FLAGS.cp_dist_noise,
            'handle_pos': handle_pos.tolist(),
            'handle_ori': handle_ori.tolist(),
            'wall_config': wall_config,
        }
        np.save(os.path.join(variation_path, "config.npy"), config)

        cfg = SimpleNamespace(
            save_video=FLAGS.save_video,
            save_failed=FLAGS.save_failed,
            control_point_radius=FLAGS.control_point_radius,
            cp_angle_noise=FLAGS.cp_angle_noise,
            cp_dist_noise=FLAGS.cp_dist_noise,
            drawer_variation=FLAGS.drawer_variation,
            drawer_open_amount=FLAGS.drawer_open_amount,
            num_modes=FLAGS.num_modes,
            demos_per_mode=FLAGS.demos_per_mode,
        )

        # Generate training data
        train_metadata = run_training_split_with_wall(
            task_env, cfg, variation_path, canonical_params, selected_indices,
            handle_pos, handle_ori, wall_config
        )

        # Summary
        n_success = sum(1 for m in train_metadata if m.get("success", False))

        print(f"\n{'='*70}")
        print("DATA GENERATION COMPLETE")
        print(f"  Train episodes: {len(train_metadata)} ({n_success} task success)")
        print(f"  Data saved to: {variation_path}")
        print(f"{'='*70}\n")
    finally:
        # Always stop the simulator, even when generation raises.
        rlbench_env.shutdown()


if __name__ == "__main__":
    import multiprocessing

    # Child processes (if any) start from a fresh interpreter rather than a
    # fork of this one; force=True overrides an already-set start method.
    multiprocessing.set_start_method("spawn", force=True)
    app.run(main)
