"""
Dataset generator for push/close drawer task with FIXED start/end points and VARIED control points.

Design:
  - Fixed start position (robot HOME) and fixed end position (drawer handle)
  - Control points are defined with 3 parameters: (angle, dist_frac, pos_frac)
  - Each control point mode generates different curved paths to reach the same handle
  - Around each control point, generate demos_per_mode demos: the first uses the
    unperturbed control point, the rest add noise to angle and distance only

Train: num_modes control points x demos_per_mode noisy samples = total train demos

Usage:
  python dataset_generator_fixed_endpoints.py --num_modes=10 --demos_per_mode=5 --save_video
"""
import sys
import os

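# Add this script's own directory to sys.path so sibling modules
# (push_utils, close_drawer_config) import regardless of the working directory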
sys.path.insert(0, os.path.dirname(__file__))

from multiprocessing import Process, Manager
from pyrep.const import RenderMode
from pyrep.objects.dummy import Dummy
from pyrep.errors import ConfigurationPathError, IKError

from rlbench import ObservationConfig
from rlbench.action_modes.action_mode import MoveArmThenGripper
from rlbench.action_modes.arm_action_modes import JointVelocity, JointPosition
from rlbench.action_modes.gripper_action_modes import Discrete
from rlbench.backend.utils import task_file_to_task_class
from rlbench.environment import Environment
import rlbench.backend.task as task

import pickle
import numpy as np
import imageio

from absl import app
from absl import flags

from types import SimpleNamespace

from push_utils import (
    DEFAULT_PHASE_STEPS,
    get_drawer_handle_position,
    set_drawer_open,
    fix_cabinet_orientation,
    reset_robot_to_default,
    compute_push_waypoints,
    generate_canonical_control_point_params,
    compute_control_point_from_params,
    generate_push_trajectory,
    check_task_success,
    prefilter_control_points_fast,
)

from close_drawer_config import (
    CAMERA_POSITION,
    CAMERA_ORIENTATION,
    CAMERA_IMAGE_SIZE,
    DRAWER_VARIATION as DEFAULT_DRAWER_VARIATION,
    DRAWER_OPEN_AMOUNT as DEFAULT_DRAWER_OPEN_AMOUNT,
    CONTROL_POINT_RADIUS as DEFAULT_CONTROL_POINT_RADIUS,
    CP_ANGLE_NOISE as DEFAULT_CP_ANGLE_NOISE,
    CP_DIST_NOISE as DEFAULT_CP_DIST_NOISE,
    NUM_MODES as DEFAULT_NUM_MODES,
    DEMOS_PER_MODE as DEFAULT_DEMOS_PER_MODE,
)


FLAGS = flags.FLAGS

# Default save path: $DPPO_DATA_DIR/close_drawer, falling back to
# /tmp/dppo_data/close_drawer when DPPO_DATA_DIR is unset.
DEFAULT_SAVE_PATH = os.path.join(
    os.environ.get("DPPO_DATA_DIR", "/tmp/dppo_data"),
    "close_drawer"
)

flags.DEFINE_string("save_path", DEFAULT_SAVE_PATH, "Where to save the demos.")
# NOTE(review): `tasks` is not read anywhere in this script; main() always
# loads the close_drawer task.
flags.DEFINE_list("tasks", ["close_drawer"], "The tasks to collect.")
# Enum so a typo fails at parse time instead of silently falling through to
# velocity control. NOTE: main() only special-cases "abs"; "vel" and "delta"
# both end up using JointVelocity.
flags.DEFINE_enum("joint_action_mode", "abs", ["vel", "abs", "delta"],
                  'Joint action mode: "vel", "abs", or "delta".')
flags.DEFINE_bool("save_video", False, "Whether to save video recordings with EE trajectory overlay.")

# Control point parameters (defaults from config file)
# num_modes >= 2: main() derives num_modes // 2 angles from it and reports the
# angular spacing as 360 / num_angles.
flags.DEFINE_integer("num_modes", DEFAULT_NUM_MODES,
                     "Number of different control point modes for training.",
                     lower_bound=2)
flags.DEFINE_integer("demos_per_mode", DEFAULT_DEMOS_PER_MODE,
                     "Number of noisy demos per control point mode.",
                     lower_bound=1)
flags.DEFINE_float("control_point_radius", DEFAULT_CONTROL_POINT_RADIUS, "Control point sampling radius (meters).")

# Noise parameters for generating variations around each control point (defaults from config file).
# Noise is drawn uniformly from [-noise, +noise], so the ranges must be non-negative.
# Note: position is not perturbed, only angle and distance have noise.
flags.DEFINE_float("cp_angle_noise", DEFAULT_CP_ANGLE_NOISE,
                   "Noise range for angle (radians) around base control point.",
                   lower_bound=0.0)
flags.DEFINE_float("cp_dist_noise", DEFAULT_CP_DIST_NOISE,
                   "Noise range for distance fraction around base control point.",
                   lower_bound=0.0)

# Drawer configuration (defaults from config file).
# Bounded to the three drawers: the value is used to index
# ['bottom', 'middle', 'top'], where -1 would silently select 'top'.
flags.DEFINE_integer("drawer_variation", DEFAULT_DRAWER_VARIATION,
                     "Drawer variation: 0=bottom, 1=middle, 2=top",
                     lower_bound=0, upper_bound=2)
flags.DEFINE_float("drawer_open_amount", DEFAULT_DRAWER_OPEN_AMOUNT, "How far the drawer is open")


def check_and_make(dir_path):
    """Create ``dir_path`` (and any missing parents) if it does not exist.

    ``exist_ok=True`` makes this atomic with respect to concurrent creators:
    there is no window between an existence check and the ``makedirs`` call
    in which another process could create the directory and trigger
    ``FileExistsError``.
    """
    os.makedirs(dir_path, exist_ok=True)


# ============================================================================
# Video Saving with EE Trajectory Overlay
# ============================================================================

def project_world_to_image(points_3d, camera, image_size):
    """
    Project 3D world points to 2D pixel coordinates with a pinhole model.

    Works with RLBench/CoppeliaSim cameras. The camera pose comes from
    ``camera.get_position()`` and the rotation block of ``camera.get_matrix()``.
    The field of view comes from ``camera.get_perspective_angle()`` (degrees).
    If that query fails, a 60 degree FOV is assumed.

    Args:
        points_3d: iterable of world-space (x, y, z) points.
        camera: object exposing ``get_position``, ``get_matrix`` and
            ``get_perspective_angle``.
        image_size: (height, width) of the target image in pixels.

    Returns:
        One entry per input point. Each entry is an integer ``(u, v)`` pixel
        tuple (truncated with ``int``), or ``None`` when the point's
        camera-frame depth is <= 0.01 (behind or too close to the camera).
        Points are NOT clipped to the image bounds.
    """
    cam_pos = np.array(camera.get_position())
    # Camera-to-world rotation; its transpose maps world offsets into the camera frame.
    cam_matrix = camera.get_matrix()[:3, :3]

    try:
        fov_deg = camera.get_perspective_angle()
        fov = fov_deg * np.pi / 180.0
    except Exception:
        # Narrowed from a bare ``except:`` so Ctrl-C / SystemExit still propagate.
        fov = 60.0 * np.pi / 180.0

    # NOTE(review): focal length is derived from image_size[0] (the height);
    # confirm this matches the simulator's FOV convention for non-square images.
    f = image_size[0] / (2.0 * np.tan(fov / 2.0))
    cx, cy = image_size[1] / 2.0, image_size[0] / 2.0

    projected = []
    for p in points_3d:
        p_rel = np.array(p) - cam_pos
        p_cam = cam_matrix.T @ p_rel
        x_cam, y_cam, z_cam = p_cam[0], p_cam[1], p_cam[2]

        if z_cam > 0.01:
            # Both image axes are flipped relative to the camera x/y axes.
            u = cx - f * x_cam / z_cam
            v = cy - f * y_cam / z_cam
            projected.append((int(u), int(v)))
        else:
            projected.append(None)

    return projected


def overlay_trajectory_on_frames(frames, ee_trace, camera, phase_labels=None):
    """
    Overlay EE trajectory on video frames with phase-based coloring.

    Frame ``k`` is paired with trace point ``min(k, len(ee_trace) - 1)``. It
    shows every trajectory segment up to that point plus a marker at that
    point. A segment (or marker) is skipped when any endpoint projects behind
    the camera or outside the image.

    Args:
        frames: list of image arrays of shape (H, W, ...). The overlay is drawn
            on copies; the inputs are not modified.
        ee_trace: sequence of 3D world-space end-effector positions.
        camera: camera handed to ``project_world_to_image``.
        phase_labels: optional per-trace-point labels. The segment ending at
            trace point ``i`` is colored by ``phase_labels[i]``. Unknown labels
            and missing entries are drawn white.

    Returns:
        A new list of annotated frames, or ``frames`` itself when ``frames``
        or ``ee_trace`` is empty.
    """
    if len(frames) == 0 or len(ee_trace) == 0:
        return frames

    import cv2  # Import here to avoid Qt initialization issues on cluster

    image_size = frames[0].shape[:2]
    height, width = image_size
    projected = project_world_to_image(ee_trace, camera, image_size)

    # Colors use the frames' own channel order. The frames come from
    # obs.front_rgb and are written with imageio, so these tuples are
    # interpreted as RGB (not BGR, even though cv2 does the drawing).
    phase_colors = {
        "reach": (0, 255, 0),         # Green - learned/shaped
        "push": (255, 165, 0),        # Orange - linear push
    }
    default_color = (255, 255, 255)  # White

    def _in_image(pt):
        """True if ``pt`` is a projected (u, v) pixel inside the image."""
        return pt is not None and 0 <= pt[0] < width and 0 <= pt[1] < height

    # Decide each segment's drawability and color once, instead of repeating
    # the checks for every frame.
    segments = []  # (end_index, p1, p2, color), in increasing end_index order
    for i in range(1, len(projected)):
        p1, p2 = projected[i - 1], projected[i]
        if not (_in_image(p1) and _in_image(p2)):
            continue
        if phase_labels is not None and i < len(phase_labels):
            color = phase_colors.get(phase_labels[i], default_color)
        else:
            color = default_color
        segments.append((i, p1, p2, color))

    last_trace_idx = len(projected) - 1
    frames_with_overlay = []

    for frame_idx, frame in enumerate(frames):
        frame_overlay = frame.copy()
        trace_idx = min(frame_idx, last_trace_idx)

        # Draw trajectory lines up to the current trace point.
        for end_idx, p1, p2, color in segments:
            if end_idx > trace_idx:
                break  # segments are sorted by end index
            cv2.line(frame_overlay, p1, p2, color, thickness=2, lineType=cv2.LINE_AA)

        # Draw the current position marker (white fill, black outline).
        curr_pt = projected[trace_idx]
        if _in_image(curr_pt):
            cv2.circle(frame_overlay, curr_pt, 6, (255, 255, 255), -1)
            cv2.circle(frame_overlay, curr_pt, 6, (0, 0, 0), 2)

        frames_with_overlay.append(frame_overlay)

    return frames_with_overlay


def save_demo(demo, episode_path, save_video=False, ee_trace=None, camera=None, phase_labels=None):
    """
    Persist one demo into ``episode_path``.

    The demo is always pickled to ``low_dim_obs.pkl``. If ``save_video`` is
    set, ``video.mp4`` (30 fps, libx264) is also written. The video is built
    from every observation with a non-None ``front_rgb``. When an EE trace and
    a camera are given, the trajectory is drawn onto the frames first. No
    video is written if no observation carries an RGB frame.
    """
    pickle_path = os.path.join(episode_path, "low_dim_obs.pkl")
    with open(pickle_path, "wb") as pkl_file:
        pickle.dump(demo, pkl_file)

    if not save_video:
        return

    # Front camera frames (the camera pose is customized in main)
    frames = [
        obs.front_rgb
        for obs in demo
        if hasattr(obs, 'front_rgb') and obs.front_rgb is not None
    ]
    if not frames:
        return

    wants_overlay = ee_trace is not None and camera is not None and len(ee_trace) > 0
    if wants_overlay:
        frames = overlay_trajectory_on_frames(frames, ee_trace, camera, phase_labels)

    video_path = os.path.join(episode_path, "video.mp4")
    imageio.mimwrite(video_path, frames, fps=30, codec='libx264', quality=8)
    print(f"\t  Video saved: {video_path} ({len(frames)} frames)")


def sample_noisy_control_point_params(base_params, angle_noise, dist_noise):
    """
    Perturb a control point's angle and distance with uniform noise.

    The angle gets an offset drawn from U(-angle_noise, angle_noise). The
    distance fraction gets an offset from U(-dist_noise, dist_noise) and is
    then clipped to [0.1, 1.0]. The position fraction is returned unchanged.
    Uses numpy's global RNG, drawing the angle offset before the distance
    offset.

    Args:
        base_params: (angle, dist_frac, pos_frac) to perturb.
        angle_noise: half-width of the angle noise range (radians).
        dist_noise: half-width of the distance-fraction noise range.

    Returns:
        (noisy_angle, noisy_dist, pos_frac)
    """
    base_angle, base_dist, base_pos = base_params

    angle_offset = np.random.uniform(-angle_noise, angle_noise)
    dist_offset = np.random.uniform(-dist_noise, dist_noise)

    return (
        base_angle + angle_offset,
        np.clip(base_dist + dist_offset, 0.1, 1.0),
        base_pos,
    )


def select_diverse_control_point_indices(canonical_params, num_modes, seed=42):
    """
    Select diverse control point indices from the canonical set.

    The control points are grouped into 5 angular sectors of 72 degrees each.
    ``num_modes`` indices are then drawn as evenly as possible across the
    sectors; the first ``num_modes % 5`` sectors get one extra pick. Any
    shortfall (sectors with too few members) is filled with random indices
    from the points not yet chosen.

    Args:
        canonical_params: sequence of (angle, dist_frac, pos_frac). The angle
            is in radians and may be in any range; it is wrapped into
            [0, 2*pi) for sector assignment.
        num_modes: number of indices to return.
        seed: seed passed to ``np.random.seed``.

    Returns:
        List of up to ``num_modes`` distinct Python ints. If
        ``num_modes >= len(canonical_params)``, all indices in order.

    Note:
        This seeds numpy's *global* RNG, which affects later ``np.random``
        draws in the process. The side effect is preserved deliberately.
    """
    np.random.seed(seed)

    n_total = len(canonical_params)

    if num_modes >= n_total:
        return list(range(n_total))

    # Group by angular sector (5 sectors of 72 degrees each)
    n_sectors = 5
    sector_size = 2 * np.pi / n_sectors

    sector_indices = {i: [] for i in range(n_sectors)}
    for idx, (angle, _dist, _pos) in enumerate(canonical_params):
        # floor (not int(), which truncates toward zero) so negative angles
        # wrap correctly: e.g. -10 deg -> sector 4, not sector 0.
        sector = int(np.floor(angle / sector_size)) % n_sectors
        sector_indices[sector].append(idx)

    # Sample from each sector as evenly as possible
    selected = []
    modes_per_sector, extra_modes = divmod(num_modes, n_sectors)

    for sector in range(n_sectors):
        members = sector_indices[sector]
        n_from_sector = modes_per_sector + (1 if sector < extra_modes else 0)
        if members:
            # Called even when n_from_sector == 0 so the RNG draw sequence
            # matches the original implementation.
            sampled = np.random.choice(
                members,
                size=min(n_from_sector, len(members)),
                replace=False
            )
            selected.extend(int(i) for i in sampled)

    # If we still need more, sample randomly from the remaining points.
    # A set gives O(1) membership tests.
    chosen = set(selected)
    while len(selected) < num_modes:
        remaining = [i for i in range(n_total) if i not in chosen]
        if not remaining:
            break
        pick = int(np.random.choice(remaining))
        selected.append(pick)
        chosen.add(pick)

    return selected[:num_modes]


def run_training_split(task_env, cfg, variation_path, canonical_params, selected_indices,
                       handle_pos, handle_ori):
    """
    Generate training demos with multiple control point modes and noisy variations per mode.

    For each mode:
      - Demo 0: Use base params (no noise). If IK fails, reject entire mode.
      - Demo 1+: Use reject sampling - if IK fails, try another noise sample (up to max attempts).

    Noise is only applied to angle and distance; pos_frac is copied unchanged
    from the base params.

    Failure handling (as implemented below):
      - On demo 0, an exception whose message contains "IK failure" rejects
        the whole mode. The remaining demos of that mode are skipped, and the
        mode index is recorded in ``rejected_modes``, which is printed but
        not returned.
      - Any other exception on demo 0 only skips that demo, since demo 0 gets
        a single attempt. Later demos of the mode are still attempted.
      - Demos 1+ get up to ``max_noise_attempts`` attempts, re-sampling the
        noise each time. A demo that fails every attempt is skipped.

    Args:
        task_env: RLBench task environment. It is reset and driven through the
            push_utils helpers on every attempt.
        cfg: namespace read for num_modes, demos_per_mode,
            control_point_radius, cp_angle_noise, cp_dist_noise, save_video,
            drawer_variation and drawer_open_amount.
        variation_path: output root. Episodes go to
            ``<variation_path>/train/episodes/episode<N>/``.
        canonical_params: indexable collection of (angle, dist_frac, pos_frac).
        selected_indices: indices into ``canonical_params``. The first
            ``cfg.num_modes`` are used, one mode each.
        handle_pos: only printed in the header. Each attempt re-reads the
            handle pose after the environment reset.
        handle_ori: not used by this function.

    Returns:
        List of metadata dicts, one per saved episode. These are the same dicts
        written to each episode's ``metadata.npy`` and collectively to
        ``<variation_path>/train/train_metadata.npy``. This function does not
        itself add a 'rejected' key to them.

    Side effects:
        Each saved episode directory receives ``low_dim_obs.pkl`` (via
        save_demo), ``metadata.npy``, ``ee_trajectory.npy`` (when the trajectory
        metadata has a trace) and, with ``cfg.save_video``, ``video.mp4``.
        ``episode_idx`` only advances after a complete save, so episode numbers
        are contiguous.

        NOTE: if an attempt fails after its episode directory was created
        (e.g. inside np.save), the partial files remain on disk. They are
        overwritten by the next successful demo, which reuses the same episode
        number, or left as an orphan directory if no later demo succeeds.
    """
    num_modes = cfg.num_modes
    demos_per_mode = cfg.demos_per_mode
    max_noise_attempts = 10  # Max attempts to find a valid noisy sample

    print(f"\n{'='*70}")
    print(f"Generating TRAIN split")
    print(f"  Modes: {num_modes}, Demos per mode: {demos_per_mode}")
    print(f"  Handle position: {handle_pos}")
    print(f"  Control point radius: {cfg.control_point_radius}m")
    print(f"  CP noise: angle={cfg.cp_angle_noise}rad, dist={cfg.cp_dist_noise} (position fixed at 0.5)")
    print(f"  Reject sampling: up to {max_noise_attempts} attempts per noisy demo")
    print(f"{'='*70}\n")

    episodes_path = os.path.join(variation_path, "train", "episodes")
    check_and_make(episodes_path)

    all_metadata = []
    episode_idx = 0  # Next episode number; only advanced after a complete save
    rejected_modes = []  # Track which modes were rejected

    # Get front camera for video overlay (position already set in main).
    # Reaches into the scene's private _cam_front attribute.
    custom_cam = task_env._scene._cam_front if cfg.save_video else None

    mode_indices = selected_indices[:num_modes]

    for mode_idx, cp_idx in enumerate(mode_indices):
        base_params = canonical_params[cp_idx]
        print(f"\n--- Mode {mode_idx}/{num_modes} (CP index: {cp_idx}) ---")
        print(f"    Base params: angle={np.degrees(base_params[0]):.1f}deg, "
              f"dist={base_params[1]:.2f}, pos={base_params[2]:.2f}")

        mode_rejected = False

        for demo_idx in range(demos_per_mode):
            if mode_rejected:
                # Skip remaining demos for this mode
                print(f"  Demo {demo_idx}/{demos_per_mode} - SKIPPED (mode rejected)")
                continue

            print(f"  Demo {demo_idx}/{demos_per_mode} (Episode {episode_idx})")

            # Determine number of attempts based on demo type
            if demo_idx == 0:
                # First demo: use base params, only 1 attempt (no noise to resample)
                max_attempts = 1
                current_params = base_params
                add_noise = False
            else:
                # Subsequent demos: use reject sampling for noisy params
                max_attempts = max_noise_attempts
                current_params = None  # Will be sampled in the loop
                add_noise = True

            attempt = 0
            success_this_demo = False

            while attempt < max_attempts and not success_this_demo:
                attempt += 1

                # Sample noisy params if needed (a fresh sample on every attempt)
                if add_noise:
                    current_params = sample_noisy_control_point_params(
                        base_params,
                        cfg.cp_angle_noise,
                        cfg.cp_dist_noise
                    )
                    if attempt > 1:
                        print(f"    Resample attempt {attempt}/{max_attempts}: "
                              f"angle={np.degrees(current_params[0]):.1f}deg, dist={current_params[1]:.2f}")

                try:
                    # Return the arm to its default configuration before resetting the task
                    reset_robot_to_default(task_env)

                    # Reset environment, fix cabinet orientation, and set drawer to open
                    task_env.reset()
                    fix_cabinet_orientation(task_env)
                    actual_drawer_open = set_drawer_open(task_env, cfg.drawer_variation, cfg.drawer_open_amount)

                    # Re-get handle position after reset (the handle_pos argument is not used for planning)
                    current_handle_pos, current_handle_ori = get_drawer_handle_position(
                        task_env, cfg.drawer_variation
                    )

                    # generate_push_trajectory selects a control point by index, so wrap the
                    # current (possibly noisy) params in a one-row array and pass cp_idx=0.
                    temp_params = np.array([current_params])

                    # Generate trajectory
                    demo, traj_metadata = generate_push_trajectory(
                        task_env,
                        start_pos=np.zeros(3),
                        handle_pos=current_handle_pos,
                        handle_ori=current_handle_ori,
                        cp_idx=0,
                        canonical_params=temp_params,
                        control_point_radius=cfg.control_point_radius,
                        waypoint_params=None,
                        phase_steps=None,
                        steps_per_point=5,
                        target_drawer_idx=cfg.drawer_variation,
                    )

                    if len(demo) == 0:
                        raise RuntimeError("Empty trajectory")

                    # Validation checks: diagnostic printouts only. They warn but never
                    # reject the demo.
                    trace = traj_metadata.get("trace")
                    phase_indices = traj_metadata.get("phase_indices", {})

                    if trace is not None and len(trace) > 0 and phase_indices:
                        # Phase ranges are used as half-open [start, end) indices into the trace.
                        reach_range = phase_indices.get("reach", (0, 0))
                        if reach_range[1] > 0:
                            reach_end_pos = trace[reach_range[1] - 1]
                            dist_to_handle = np.linalg.norm(reach_end_pos - current_handle_pos)

                            push_range = phase_indices.get("push", (0, 0))
                            if push_range[1] > push_range[0]:
                                push_start_pos = trace[push_range[0]]
                                push_end_pos = trace[push_range[1] - 1]
                                push_distance = np.linalg.norm(push_end_pos - push_start_pos)

                                print(f"\t  Validation:")
                                print(f"\t    Handle position: {current_handle_pos}")
                                print(f"\t    Reach end → Handle distance: {dist_to_handle:.4f}m")
                                print(f"\t    Drawer joint: {actual_drawer_open:.4f}m (config: {cfg.drawer_open_amount:.4f}m)")
                                print(f"\t    Push distance: {push_distance:.4f}m")

                                if dist_to_handle > 0.01:
                                    print(f"\t    ⚠ WARNING: Reach phase ended {dist_to_handle:.4f}m from handle!")
                                if abs(push_distance - actual_drawer_open) > 0.02:
                                    print(f"\t    ⚠ WARNING: Push distance differs from drawer opening!")

                    # Check if task succeeded. The result is only recorded; failed
                    # (unsuccessful) demos are still saved.
                    task_success = check_task_success(task_env)

                    # Debug: check drawer joint position after trajectory.
                    # NOTE(review): this lookup sits inside the try, so a failure here
                    # discards an otherwise valid demo.
                    from pyrep.objects.joint import Joint
                    drawer_names = ['bottom', 'middle', 'top']
                    final_drawer_joint = Joint(f'drawer_joint_{drawer_names[cfg.drawer_variation]}')
                    final_drawer_pos = final_drawer_joint.get_joint_position()
                    print(f"\t    Drawer joint after trajectory: {final_drawer_pos:.4f}m (need <0.04m for success)")

                    # Save demo with video overlay
                    episode_path = os.path.join(episodes_path, f"episode{episode_idx}")
                    os.makedirs(episode_path, exist_ok=True)
                    save_demo(
                        demo, episode_path, save_video=cfg.save_video,
                        ee_trace=traj_metadata.get("trace"),
                        camera=custom_cam,
                        phase_labels=traj_metadata.get("phase_labels")
                    )

                    # Save metadata. Bulky per-step entries from traj_metadata are
                    # excluded; the trace is saved separately below.
                    metadata = {
                        'mode': mode_idx,
                        'demo_in_mode': demo_idx,
                        'cp_idx': int(cp_idx),
                        'canonical_cp_params': tuple(current_params),
                        'base_cp_params': tuple(base_params),
                        'handle_pos': current_handle_pos.tolist(),
                        'handle_ori': current_handle_ori.tolist(),
                        'with_noise': add_noise,
                        'noise_attempt': attempt if add_noise else 0,
                        'success': task_success,
                        **{k: v for k, v in traj_metadata.items() if k not in ['trace', 'phase_labels', 'gripper_states', 'waypoints']},
                    }
                    np.save(os.path.join(episode_path, "metadata.npy"), metadata)

                    # Save EE trajectory separately for visualization
                    ee_trace = traj_metadata.get("trace")
                    if ee_trace is not None:
                        np.save(os.path.join(episode_path, "ee_trajectory.npy"), ee_trace)
                    all_metadata.append(metadata)

                    print(f"\t  Demo length: {len(demo)} steps, success: {task_success}, noise: {add_noise}")
                    episode_idx += 1
                    success_this_demo = True

                except Exception as e:
                    # IK failures are recognized by message text, not by exception type.
                    error_msg = str(e)
                    if "IK failure" in error_msg:
                        if demo_idx == 0:
                            # First demo failed - reject entire mode
                            print(f"\t  ✗ Mode {mode_idx} REJECTED (base params IK failure): {error_msg}")
                            mode_rejected = True
                            rejected_modes.append(mode_idx)
                            break
                        else:
                            # Noisy demo failed - try another sample
                            if attempt < max_attempts:
                                print(f"\t  ✗ IK failure, resampling noise... ({attempt}/{max_attempts})")
                            else:
                                print(f"\t  ✗ IK failure after {max_attempts} attempts, skipping this demo")
                    else:
                        # Other error: does not reject the mode, even on demo 0
                        print(f"\t  Attempt failed: {e}")
                        if attempt >= max_attempts:
                            print(f"\t  Failed after {max_attempts} attempts, skipping this demo")

    # Save all metadata for training split
    np.save(os.path.join(variation_path, "train", "train_metadata.npy"), all_metadata)

    print(f"\n  Generated {episode_idx} episodes")
    print(f"  Rejected modes: {len(rejected_modes)} {rejected_modes if rejected_modes else ''}")
    print(f"  Saved metadata to train_metadata.npy")

    return all_metadata


def main(argv):
    """
    Entry point: generate the close-drawer training dataset.

    Steps:
      1. Build the canonical control point set
         (``generate_canonical_control_point_params``).
      2. Launch a headless RLBench environment with the ``close_drawer`` task
         at ``--drawer_variation``, open the target drawer and read its
         handle pose.
      3. Move the front camera to CAMERA_POSITION / CAMERA_ORIENTATION, which
         is the view used for videos.
      4. Write ``canonical_params.npy`` and ``config.npy`` under
         ``<save_path>/variation<drawer_variation>``.
      5. Run ``run_training_split`` and print a summary.

    Once launched, the simulator is always shut down, even if generation
    raises.

    Args:
        argv: positional command-line arguments left over by absl (unused).
    """
    print(f"{'='*70}")
    print("CLOSE DRAWER DATASET GENERATOR")
    print(f"{'='*70}")
    print(f"Save path: {FLAGS.save_path}")
    print(f"Joint action mode: {FLAGS.joint_action_mode}")
    print(f"Control point modes: {FLAGS.num_modes}")
    print(f"Demos per mode: {FLAGS.demos_per_mode}")
    print(f"Total episodes: {FLAGS.num_modes * FLAGS.demos_per_mode}")
    print(f"Control point radius: {FLAGS.control_point_radius}m")
    print(f"Drawer variation: {FLAGS.drawer_variation} ({'bottom' if FLAGS.drawer_variation==0 else 'middle' if FLAGS.drawer_variation==1 else 'top'})")
    print(f"Save video: {FLAGS.save_video}")
    print(f"{'='*70}\n")

    # Generate canonical control point parameters
    # Structure: num_modes = num_angles * 2 (two distance values: 1.0 and 0.5)
    # Angles are evenly split across 360 degrees, position fixed at 0.5
    canonical_params = generate_canonical_control_point_params(FLAGS.num_modes)
    print(f"Generated {len(canonical_params)} canonical control point parameters:")
    num_angles = FLAGS.num_modes // 2
    if num_angles > 0:
        print(f"  Angles: {num_angles} (every {360.0/num_angles:.1f}°)")
    else:
        # num_modes < 2: avoid ZeroDivisionError in the spacing printout.
        print(f"  Angles: {num_angles}")
    print(f"  Distances: [1.0, 0.5]")
    print(f"  Position: fixed at 0.5")

    # Use all generated params directly (indices 0 to len(canonical_params)-1)
    selected_indices = list(range(len(canonical_params)))

    # Setup environment: low-dimensional observations only, plus front RGB
    # when videos are requested.
    obs_config = ObservationConfig()
    obs_config.set_all(False)
    obs_config.joint_positions = True
    obs_config.joint_velocities = True
    obs_config.gripper_open = True
    obs_config.gripper_pose = True
    obs_config.task_low_dim_state = True

    if FLAGS.save_video:
        # Use front camera with custom position for good view (robot on left, cabinet on right)
        obs_config.front_camera.rgb = True
        obs_config.front_camera.image_size = CAMERA_IMAGE_SIZE

    if FLAGS.joint_action_mode == "abs":
        # Absolute joint-position actions with explicit bounds (8 entries;
        # presumably 7 arm joints + gripper).
        ACT_MIN = np.array([-2.8973, -1.7628, -2.8973, -3.0718,
                            -2.8973, -0.0175, -2.8973, 0.0], dtype=np.float32)
        ACT_RANGE = np.array([5.7946, 3.5256, 5.7946, 3.0020,
                              5.7946, 3.7700, 5.7946, 1.0], dtype=np.float32)

        class CustomMoveArmThenGripper(MoveArmThenGripper):
            def action_bounds(self):
                return (ACT_MIN, ACT_MIN + ACT_RANGE)

        action_mode = CustomMoveArmThenGripper(JointPosition(True), Discrete())
    else:
        if FLAGS.joint_action_mode != "vel":
            # Only "abs" has a dedicated setup; make the fallback visible.
            print(f"WARNING: joint_action_mode={FLAGS.joint_action_mode!r} is not implemented; "
                  "falling back to JointVelocity.")
        action_mode = MoveArmThenGripper(JointVelocity(), Discrete())

    rlbench_env = Environment(action_mode=action_mode, obs_config=obs_config, headless=True)
    rlbench_env.launch()
    try:
        # Get task
        task_class = task_file_to_task_class("close_drawer")
        task_env = rlbench_env.get_task(task_class)
        task_env.set_variation(FLAGS.drawer_variation)

        # Initialize episode, fix cabinet orientation, and get handle position
        task_env.reset()
        fix_cabinet_orientation(task_env)  # Fix cabinet to face robot
        set_drawer_open(task_env, FLAGS.drawer_variation, FLAGS.drawer_open_amount)

        handle_pos, handle_ori = get_drawer_handle_position(task_env, FLAGS.drawer_variation)
        print(f"\nDrawer handle position: {handle_pos}")
        print(f"Drawer handle orientation: {handle_ori}")

        # Create output directory
        variation_path = os.path.join(
            FLAGS.save_path,
            f"variation{FLAGS.drawer_variation}"
        )
        check_and_make(variation_path)

        # Get camera and set custom position for visualization
        front_cam = task_env._scene._cam_front
        front_cam.set_position(CAMERA_POSITION)
        front_cam.set_orientation(CAMERA_ORIENTATION)

        # Get the camera matrix after positioning (needed for visualization)
        camera_matrix = front_cam.get_matrix()[:3, :3]
        camera_fov = front_cam.get_perspective_angle()

        # Save canonical parameters and configuration
        np.save(os.path.join(variation_path, "canonical_params.npy"), canonical_params)
        config = {
            'drawer_variation': FLAGS.drawer_variation,
            'drawer_open_amount': FLAGS.drawer_open_amount,
            'control_point_radius': FLAGS.control_point_radius,
            'num_modes': FLAGS.num_modes,
            'demos_per_mode': FLAGS.demos_per_mode,
            'cp_angle_noise': FLAGS.cp_angle_noise,
            'cp_dist_noise': FLAGS.cp_dist_noise,
            'cp_pos_fixed': 0.5,  # Position is fixed at 0.5
            'handle_pos': handle_pos.tolist(),
            'handle_ori': handle_ori.tolist(),
            'selected_indices': selected_indices,
            'phase_steps': DEFAULT_PHASE_STEPS,
            'camera_position': CAMERA_POSITION,
            'camera_orientation': CAMERA_ORIENTATION,
            'camera_matrix': camera_matrix.tolist(),
            'camera_fov': camera_fov,
        }
        np.save(os.path.join(variation_path, "config.npy"), config)

        # Create config namespace (position noise removed - position is fixed at 0.5)
        cfg = SimpleNamespace(
            save_video=FLAGS.save_video,
            control_point_radius=FLAGS.control_point_radius,
            cp_angle_noise=FLAGS.cp_angle_noise,
            cp_dist_noise=FLAGS.cp_dist_noise,
            drawer_variation=FLAGS.drawer_variation,
            drawer_open_amount=FLAGS.drawer_open_amount,
            num_modes=FLAGS.num_modes,
            demos_per_mode=FLAGS.demos_per_mode,
        )

        # Generate training split only
        train_metadata = run_training_split(
            task_env, cfg, variation_path, canonical_params, selected_indices,
            handle_pos, handle_ori
        )

        # Summary. run_training_split returns one metadata dict per *saved*
        # episode and never marks entries as rejected, so the number of
        # rejected/failed modes is derived from modes that produced no episode.
        n_saved = len(train_metadata)
        n_success = sum(1 for m in train_metadata if m.get("success", False))
        n_requested = FLAGS.num_modes * FLAGS.demos_per_mode
        n_modes_run = len(selected_indices[:FLAGS.num_modes])
        n_modes_empty = n_modes_run - len({m.get("mode") for m in train_metadata})

        print(f"\n{'='*70}")
        print("DATA GENERATION COMPLETE")
        print(f"  Train modes: {FLAGS.num_modes}")
        print(f"  Episodes saved: {n_saved} / {n_requested} requested ({n_success} successful)")
        print(f"  Modes with no saved episodes (IK-rejected or all attempts failed): {n_modes_empty}")
        print(f"  Data saved to: {variation_path}")
        print(f"{'='*70}\n")
    finally:
        # Always release the simulator, even if generation failed.
        rlbench_env.shutdown()


if __name__ == "__main__":
    # Force the "spawn" start method for any multiprocessing used by this
    # script or its libraries; force=True overrides a method set earlier.
    import multiprocessing
    multiprocessing.set_start_method("spawn", force=True)
    app.run(main)
