"""
Dataset generator for grasp (pick_up_cup) task with parameterized approach angle and height.

Design:
  - Each grasp mode is defined by (approach_angle, grasp_height)
  - 4 angles x 4 heights = 16 modes
  - Around each mode, sample noisy variations for diversity
  - Total: 16 modes x 10 demos/mode = 160 demos

Usage:
  python dataset_generator_fixed_endpoints.py --num_modes=16 --demos_per_mode=10 --save_video
"""
import sys
import os

# Add this file's own directory to sys.path (presumably so grasp_utils / grasp_config next to it can be imported)
sys.path.insert(0, os.path.dirname(__file__))

from multiprocessing import Process, Manager
from pyrep.const import RenderMode
from pyrep.objects.dummy import Dummy
from pyrep.errors import ConfigurationPathError, IKError

from rlbench import ObservationConfig
from rlbench.action_modes.action_mode import MoveArmThenGripper
from rlbench.action_modes.arm_action_modes import JointVelocity, JointPosition
from rlbench.action_modes.gripper_action_modes import Discrete
from rlbench.backend.utils import task_file_to_task_class
from rlbench.environment import Environment
import rlbench.backend.task as task

import pickle
import numpy as np
import imageio

from absl import app
from absl import flags

from types import SimpleNamespace

from grasp_utils import (
    DEFAULT_PHASE_STEPS,
    get_cup_position,
    get_cup_base_z,
    set_cup_position,
    reset_robot_to_default,
    compute_grasp_waypoints,
    generate_grasp_trajectory,
    check_task_success,
    check_grasp_success_manual,
)

from grasp_config import (
    CAMERA_POSITION,
    CAMERA_ORIENTATION,
    CAMERA_IMAGE_SIZE,
    CUP_VARIATION as DEFAULT_CUP_VARIATION,
    CONTROL_POINT_RADIUS as DEFAULT_CONTROL_POINT_RADIUS,
    ANGLE_NOISE as DEFAULT_ANGLE_NOISE,
    HEIGHT_NOISE as DEFAULT_HEIGHT_NOISE,
    NUM_MODES as DEFAULT_NUM_MODES,
    DEMOS_PER_MODE as DEFAULT_DEMOS_PER_MODE,
    APPROACH_ANGLES_DEG,
    GRASP_HEIGHTS,
    FIXED_CUP_POSITION,
    HOME_JOINTS,
    generate_canonical_mode_params,
)


# Global absl flag registry; values are only available after app.run() has parsed argv.
FLAGS = flags.FLAGS

# Default save path: "<DPPO_DATA_DIR>/grasp", falling back to "/tmp/dppo_data/grasp"
# when the DPPO_DATA_DIR environment variable is not set.
DEFAULT_SAVE_PATH = os.path.join(
    os.environ.get("DPPO_DATA_DIR", "/tmp/dppo_data"),
    "grasp"
)

flags.DEFINE_string("save_path", DEFAULT_SAVE_PATH, "Where to save the demos.")
# NOTE(review): "tasks" is not read anywhere in the visible code; main() always
# loads "pick_up_cup".
flags.DEFINE_list("tasks", ["pick_up_cup"], "The tasks to collect.")
# NOTE(review): main() only special-cases "abs" (absolute joint positions); every
# other value, including "delta", selects the JointVelocity action mode.
flags.DEFINE_string("joint_action_mode", "abs", 'Joint action mode: "vel", "abs", or "delta".')
flags.DEFINE_bool("save_video", False, "Whether to save video recordings with EE trajectory overlay.")

# Grasp mode parameters (defaults come from grasp_config)
flags.DEFINE_integer("num_modes", DEFAULT_NUM_MODES, "Number of different grasp modes for training.")
flags.DEFINE_integer("demos_per_mode", DEFAULT_DEMOS_PER_MODE, "Number of noisy demos per grasp mode.")
flags.DEFINE_float("control_point_radius", DEFAULT_CONTROL_POINT_RADIUS, "Control point sampling radius (meters).")

# Noise parameters for generating variations around each mode; each value is the
# half-width of the uniform range sampled in sample_noisy_mode_params().
flags.DEFINE_float("angle_noise", DEFAULT_ANGLE_NOISE, "Noise range for angle (radians) around base mode.")
flags.DEFINE_float("height_noise", DEFAULT_HEIGHT_NOISE, "Noise range for height (meters) around base mode.")

# Cup configuration; main() passes this value to task_env.set_variation() and uses
# it to name the output directory ("variation<N>").
flags.DEFINE_integer("cup_variation", DEFAULT_CUP_VARIATION, "Cup variation (color index 0-19)")


def check_and_make(dir_path):
    """Create ``dir_path`` (and any missing parents) if it doesn't exist.

    Creation is done in a single ``os.makedirs(..., exist_ok=True)`` call
    instead of a separate existence check, so a directory created by another
    process between the check and the creation cannot raise
    ``FileExistsError``.

    Args:
        dir_path: path of the directory to create.

    Raises:
        FileExistsError: if ``dir_path`` exists but is not a directory.
    """
    os.makedirs(dir_path, exist_ok=True)


# ============================================================================
# Video Saving with EE Trajectory Overlay
# ============================================================================

def project_world_to_image(points_3d, camera, image_size):
    """
    Project 3D world points to 2D pixel coordinates with a pinhole model.

    The model is built from the camera pose (``get_position`` and the 3x3
    rotation block of ``get_matrix``) and its perspective angle. If the
    perspective angle cannot be queried, a 60 degree field of view is used.

    Args:
        points_3d: iterable of world-frame points, each convertible to a
            length-3 array.
        camera: object exposing ``get_position()``, ``get_matrix()`` and
            ``get_perspective_angle()`` (treated as degrees).
        image_size: ``(rows, cols)`` of the target image, e.g.
            ``frame.shape[:2]``.

    Returns:
        list with one entry per input point: an ``(u, v)`` tuple of ints
        (column, row), or ``None`` when the point is not at least 0.01 in
        front of the camera along its z axis. Returned pixels are NOT clipped
        to the image bounds.
    """
    cam_pos = np.array(camera.get_position())
    cam_matrix = camera.get_matrix()[:3, :3]

    try:
        fov_deg = camera.get_perspective_angle()
        fov = fov_deg * np.pi / 180.0
    except Exception:
        # Best-effort fallback. Deliberately not a bare ``except`` so that
        # KeyboardInterrupt / SystemExit still propagate.
        fov = 60.0 * np.pi / 180.0

    # NOTE(review): a single focal length derived from image_size[0] is used for
    # both axes; this is exact for square images - TODO confirm for non-square.
    f = image_size[0] / (2.0 * np.tan(fov / 2.0))
    cx, cy = image_size[1] / 2.0, image_size[0] / 2.0

    # The transposed rotation block maps world-frame offsets into the camera frame.
    world_to_cam = cam_matrix.T

    projected = []
    for p in points_3d:
        x_cam, y_cam, z_cam = world_to_cam @ (np.array(p) - cam_pos)

        if z_cam > 0.01:
            u = cx - f * x_cam / z_cam
            v = cy - f * y_cam / z_cam
            projected.append((int(u), int(v)))
        else:
            # Behind (or too close to) the image plane: no valid projection.
            projected.append(None)

    return projected


def overlay_trajectory_on_frames(frames, ee_trace, camera, phase_labels=None):
    """
    Overlay the EE trajectory on video frames with phase-based coloring.

    Frame ``k`` shows the trace from its start up to point
    ``min(k, len(ee_trace) - 1)`` plus a marker at that point. A segment is
    skipped when either endpoint has no projection or falls outside the image.

    Args:
        frames: sequence of image arrays; all are assumed to share the shape of
            ``frames[0]``. Frames are treated as RGB, which is what
            ``save_demo`` passes (``obs.front_rgb``) and what imageio encodes.
        ee_trace: sequence of 3D world-frame EE positions, one per frame.
        camera: camera object forwarded to ``project_world_to_image``.
        phase_labels: optional sequence of phase names aligned with
            ``ee_trace``; segment ``i-1 -> i`` is colored by ``phase_labels[i]``.

    Returns:
        A new list of annotated frame copies (inputs are not modified). If
        ``frames`` or ``ee_trace`` is empty, ``frames`` is returned as-is.
    """
    import cv2  # Import here to avoid Qt initialization issues on cluster

    if len(frames) == 0 or len(ee_trace) == 0:
        return frames

    image_size = frames[0].shape[:2]
    height, width = image_size
    projected = project_world_to_image(ee_trace, camera, image_size)

    # Colors per phase, in RGB channel order. cv2 drawing calls write the tuple
    # to the channels as given, and the frames here are RGB, so BGR-ordered
    # tuples would render with red and blue swapped.
    phase_colors = {
        "reach": (0, 255, 0),           # Green - shaped approach
        "descend": (255, 200, 0),       # Orange - linear descent
        "hold_grasp": (0, 255, 255),    # Cyan - hold before grasp
        "close_gripper": (0, 0, 255),   # Blue - closing gripper
        "lift": (255, 0, 255),          # Magenta - lifting
    }
    default_color = (255, 255, 255)  # White

    def _in_image(pt):
        """True when ``pt`` is a projected pixel inside the frame bounds."""
        return pt is not None and 0 <= pt[0] < width and 0 <= pt[1] < height

    # Resolve every drawable segment once instead of re-checking it per frame.
    # segments[i - 1] joins trace points i-1 and i (None when it must be skipped).
    segments = []
    for i in range(1, len(projected)):
        start_pt, end_pt = projected[i - 1], projected[i]
        if not (_in_image(start_pt) and _in_image(end_pt)):
            segments.append(None)
            continue
        if phase_labels is not None and i < len(phase_labels):
            color = phase_colors.get(phase_labels[i], default_color)
        else:
            color = default_color
        segments.append((start_pt, end_pt, color))

    frames_with_overlay = []

    for frame_idx, frame in enumerate(frames):
        frame_overlay = frame.copy()
        # When there are more frames than trace points, hold the final point.
        trace_idx = min(frame_idx, len(projected) - 1)

        # Draw trajectory lines up to current frame
        for segment in segments[:trace_idx]:
            if segment is None:
                continue
            start_pt, end_pt, color = segment
            cv2.line(frame_overlay, start_pt, end_pt, color, thickness=2, lineType=cv2.LINE_AA)

        # Draw current position marker (white disc with a black outline)
        curr_pt = projected[trace_idx]
        if _in_image(curr_pt):
            cv2.circle(frame_overlay, curr_pt, 6, (255, 255, 255), -1)
            cv2.circle(frame_overlay, curr_pt, 6, (0, 0, 0), 2)

        frames_with_overlay.append(frame_overlay)

    return frames_with_overlay


def save_demo(demo, episode_path, save_video=False, ee_trace=None, camera=None, phase_labels=None):
    """
    Persist one demonstration into ``episode_path``.

    ``demo`` is always pickled to ``low_dim_obs.pkl``. When ``save_video`` is
    true, the non-None ``front_rgb`` images found on the observations are
    encoded to ``video.mp4``; the EE trajectory is drawn on top of them when
    both a non-empty ``ee_trace`` and a ``camera`` are supplied. No video file
    is written if no observation carries a front RGB image.
    """
    pickle_path = os.path.join(episode_path, "low_dim_obs.pkl")
    with open(pickle_path, "wb") as handle:
        pickle.dump(demo, handle)

    if not save_video:
        return

    # Collect front-camera images (the camera with the custom pose applied)
    frames = [
        obs.front_rgb
        for obs in demo
        if hasattr(obs, 'front_rgb') and obs.front_rgb is not None
    ]
    if len(frames) == 0:
        return

    # The overlay needs both the trace and the camera used for projection
    overlay_requested = ee_trace is not None and camera is not None and len(ee_trace) > 0
    if overlay_requested:
        frames = overlay_trajectory_on_frames(frames, ee_trace, camera, phase_labels)

    video_path = os.path.join(episode_path, "video.mp4")
    imageio.mimwrite(video_path, frames, fps=30, codec='libx264', quality=8)
    print(f"\t  Video saved: {video_path} ({len(frames)} frames)")


def sample_noisy_mode_params(base_params, angle_noise, height_noise,
                             height_bounds=(0.08, 0.14)):
    """
    Sample a noisy variation of grasp mode parameters.

    Args:
        base_params: tuple of (angle, height)
        angle_noise: half-width of the uniform angle perturbation (radians)
        height_noise: half-width of the uniform height perturbation (meters);
            values <= 0 disable height noise entirely
        height_bounds: (low, high) range the perturbed height is clipped to.
            Only applied when height noise is enabled; the default keeps the
            previously hard-coded cup-rim range.

    Returns:
        tuple of (noisy_angle, noisy_height); the angle is wrapped into
        [0, 2*pi). Without height noise the base height is returned untouched
        (i.e. not clipped).
    """
    angle, height = base_params

    # Perturb, then wrap so the angle stays a canonical value in [0, 2*pi)
    noisy_angle = (angle + np.random.uniform(-angle_noise, angle_noise)) % (2 * np.pi)

    if height_noise > 0:
        min_height, max_height = height_bounds
        perturbed = height + np.random.uniform(-height_noise, height_noise)
        # Keep height in a reasonable range for the grasp
        noisy_height = np.clip(perturbed, min_height, max_height)
    else:
        noisy_height = height  # No noise, keep original

    return (noisy_angle, noisy_height)


def validate_grasp_position(demo, cup_pos, expected_angle, control_point_radius,
                            angle_tolerance_deg=20.0, radius_tolerance=0.015,
                            grasp_step=72):
    """
    Validate that the EE grasp position is within acceptable range of expected contact point.

    The EE position is read from the observation at ``grasp_step`` and compared,
    in the horizontal (x, y) plane, against the expected polar coordinates
    around the cup center.

    Args:
        demo: non-empty list of observations exposing ``gripper_pose``
            (first three entries are used as x, y, z)
        cup_pos: np.ndarray(3,), cup center position
        expected_angle: float, expected approach angle in radians
        control_point_radius: float, expected distance from cup center (gripper offset)
        angle_tolerance_deg: float, max deviation from expected angle in degrees
        radius_tolerance: float, max deviation from expected radius in meters
        grasp_step: int, index of the observation taken as the grasp pose; it is
            clamped to the last observation of shorter demos. The default (72)
            is the previously hard-coded value, described there as the end of
            the descend phase.

    Returns:
        tuple: (is_valid, actual_angle_deg, actual_radius, error_msg) where
        ``actual_angle_deg`` is in [0, 360) and ``error_msg`` is None when valid.
        The angle is checked first; the radius is only reported when the angle
        is within tolerance.

    Raises:
        ValueError: if ``demo`` is empty.
    """
    if len(demo) == 0:
        # Previously this fell through to demo[-1] and raised an opaque IndexError.
        raise ValueError("Cannot validate grasp position: demo is empty")

    grasp_idx = min(grasp_step, len(demo) - 1)
    ee_pos = demo[grasp_idx].gripper_pose[:3]

    # Polar coordinates of the EE relative to the cup center (z is ignored)
    rel_x = ee_pos[0] - cup_pos[0]
    rel_y = ee_pos[1] - cup_pos[1]
    actual_radius = np.hypot(rel_x, rel_y)
    actual_angle_deg = np.degrees(np.arctan2(rel_y, rel_x)) % 360
    expected_angle_deg = np.degrees(expected_angle) % 360

    # Shortest angular distance, so that e.g. 350 deg vs 5 deg counts as 15 deg
    angle_diff = abs(actual_angle_deg - expected_angle_deg)
    if angle_diff > 180:
        angle_diff = 360 - angle_diff

    radius_diff = abs(actual_radius - control_point_radius)

    if angle_diff > angle_tolerance_deg:
        error_msg = f"Angle deviation {angle_diff:.1f}° > {angle_tolerance_deg}° tolerance"
    elif radius_diff > radius_tolerance:
        error_msg = f"Radius deviation {radius_diff*100:.1f}cm > {radius_tolerance*100:.1f}cm tolerance"
    else:
        error_msg = None

    is_valid = error_msg is None
    return is_valid, actual_angle_deg, actual_radius, error_msg


def run_training_split(task_env, cfg, variation_path, canonical_params, cup_pos, cup_ori):
    """
    Generate training demos with multiple grasp modes and noisy variations per mode.

    For each of the first ``min(cfg.num_modes, len(canonical_params))`` modes,
    demos are collected until ``cfg.demos_per_mode`` have succeeded. Every demo
    slot gets up to 50 attempts; a failed attempt (exception, task failure or
    rejected grasp position) resamples the noise and retries. If a slot uses up
    all its attempts, the remaining slots of that mode are skipped, so a mode
    can end up with FEWER than ``demos_per_mode`` demos.

    Args:
        task_env: task environment handed to the grasp_utils helpers.
        cfg: namespace providing ``num_modes``, ``demos_per_mode``,
            ``angle_noise``, ``height_noise``, ``control_point_radius`` and
            ``save_video``.
        variation_path: output root; episodes are written to
            ``<variation_path>/train/episodes/episode<N>``.
        canonical_params: sequence of ``(angle_rad, height)`` base parameters,
            indexed by mode.
        cup_pos, cup_ori: not read; kept so existing callers keep working. The
            cup pose is re-queried after every environment reset.

    Returns:
        list of metadata dicts, one per saved episode (also written to
        ``<variation_path>/train/train_metadata.npy``).
    """
    num_modes = min(cfg.num_modes, len(canonical_params))
    demos_per_mode = cfg.demos_per_mode
    max_attempts_per_demo = 50  # Max attempts to get a single successful demo

    # NOTE: gripper_offset (0.04m) is the actual EE distance from cup center used
    # for validation, NOT control_point_radius, which is for pregrasp.
    # It is meant to match the grasp_utils.py default.
    gripper_offset = 0.04

    print(f"\n{'='*70}")
    print("Generating TRAIN split")
    print(f"  Modes: {num_modes}, Demos per mode: {demos_per_mode}")
    print(f"  Fixed cup position: {FIXED_CUP_POSITION}")
    print(f"  Robot HOME joints: {HOME_JOINTS}")
    print(f"  Control point radius: {cfg.control_point_radius}m")
    print(f"  Noise: angle={cfg.angle_noise}rad ({np.degrees(cfg.angle_noise):.1f}deg), height={cfg.height_noise}m")
    print(f"  Max attempts per demo: {max_attempts_per_demo}")
    print(f"{'='*70}\n")

    episodes_path = os.path.join(variation_path, "train", "episodes")
    check_and_make(episodes_path)

    all_metadata = []
    episode_idx = 0

    # Get front camera for video overlay
    custom_cam = task_env._scene._cam_front if cfg.save_video else None

    for mode_idx in range(num_modes):
        base_params = canonical_params[mode_idx]
        base_angle, base_height = base_params

        print(f"\n--- Mode {mode_idx}/{num_modes} ---")
        print(f"    Base params: angle={np.degrees(base_angle):.1f}deg, height={base_height:.4f}m")

        successful_demos_in_mode = 0

        while successful_demos_in_mode < demos_per_mode:
            print(f"  Demo {successful_demos_in_mode}/{demos_per_mode} (Episode {episode_idx})")

            attempt = 0
            success_this_demo = False

            while attempt < max_attempts_per_demo and not success_this_demo:
                attempt += 1

                # First attempt of first demo uses base params (no noise)
                # All retries and subsequent demos use noise
                with_noise = not (successful_demos_in_mode == 0 and attempt == 1)
                if with_noise:
                    current_params = sample_noisy_mode_params(
                        base_params,
                        cfg.angle_noise,
                        cfg.height_noise
                    )
                else:
                    current_params = base_params

                current_angle, current_height = current_params

                if attempt > 1:
                    print(f"    Retry {attempt}/{max_attempts_per_demo}: "
                          f"angle={np.degrees(current_angle):.1f}deg")

                try:
                    # Reset robot to default position before reset
                    reset_robot_to_default(task_env)

                    # Reset environment
                    task_env.reset()

                    # Set cup to fixed position (CRITICAL for reproducibility)
                    if FIXED_CUP_POSITION is not None:
                        set_cup_position(task_env, FIXED_CUP_POSITION)

                    current_cup_pos, current_cup_ori = get_cup_position(task_env)

                    # Verify the cup really sits at the fixed position. Only
                    # meaningful when one is configured; without this guard a
                    # None FIXED_CUP_POSITION made every attempt fail with a
                    # TypeError.
                    if FIXED_CUP_POSITION is not None:
                        cup_drift = np.linalg.norm(current_cup_pos[:2] - FIXED_CUP_POSITION[:2])
                        if cup_drift > 0.005:  # 5mm tolerance
                            raise RuntimeError(f"Cup drifted {cup_drift:.4f}m from fixed position")

                    # Generate trajectory
                    demo, traj_metadata = generate_grasp_trajectory(
                        task_env,
                        start_pos=np.zeros(3),  # Will be overwritten by HOME
                        cup_pos=current_cup_pos,
                        cup_ori=current_cup_ori,
                        approach_angle=current_angle,
                        grasp_height=current_height,
                        control_point_radius=cfg.control_point_radius,
                        waypoint_params=None,
                        phase_steps=None,
                        steps_per_point=5,
                    )

                    if len(demo) == 0:
                        raise RuntimeError("Empty trajectory")

                    # Check if task succeeded
                    task_success = check_task_success(task_env)

                    # Also check manual grasp success (cup lifted); recorded in
                    # the metadata but not used as an acceptance criterion.
                    initial_cup_z = traj_metadata.get("initial_cup_z", current_cup_pos[2])
                    grasp_success = check_grasp_success_manual(task_env, initial_cup_z)

                    # IMPORTANT: Only count as success if task actually succeeded
                    if not task_success:
                        print(f"\t  Task failed (grasp not successful), retrying...")
                        continue

                    # Validate grasp position is within expected range
                    grasp_valid, actual_angle, actual_radius, grasp_error = validate_grasp_position(
                        demo, current_cup_pos, current_angle, gripper_offset,
                        angle_tolerance_deg=20.0,  # Allow ±20° deviation
                        radius_tolerance=0.015     # Allow ±1.5cm deviation from expected radius
                    )
                    if not grasp_valid:
                        print(f"\t  Grasp position invalid: {grasp_error}, retrying...")
                        continue

                    ee_trace = traj_metadata.get("trace")

                    # Save demo with video overlay
                    episode_path = os.path.join(episodes_path, f"episode{episode_idx}")
                    os.makedirs(episode_path, exist_ok=True)
                    save_demo(
                        demo, episode_path, save_video=cfg.save_video,
                        ee_trace=ee_trace,
                        camera=custom_cam,
                        phase_labels=traj_metadata.get("phase_labels")
                    )

                    # Save metadata. Trajectory metadata is merged last, so its
                    # keys take precedence over the ones set here; the bulky
                    # per-step entries are left out.
                    metadata = {
                        'mode': mode_idx,
                        'demo_in_mode': successful_demos_in_mode,
                        'approach_angle': current_angle,
                        'grasp_height': current_height,
                        'base_angle': base_angle,
                        'base_height': base_height,
                        'cup_pos': current_cup_pos.tolist(),
                        'cup_ori': current_cup_ori.tolist(),
                        'with_noise': with_noise,
                        'noise_attempt': attempt,
                        'task_success': task_success,
                        'grasp_success': grasp_success,
                        # Fixed positions for evaluation reproducibility
                        'fixed_cup_position': FIXED_CUP_POSITION.tolist() if FIXED_CUP_POSITION is not None else None,
                        'home_joints': HOME_JOINTS.tolist(),
                        **{k: v for k, v in traj_metadata.items()
                           if k not in ['trace', 'phase_labels', 'gripper_states', 'waypoints']},
                    }
                    np.save(os.path.join(episode_path, "metadata.npy"), metadata)

                    # Save EE trajectory separately for visualization
                    if ee_trace is not None:
                        np.save(os.path.join(episode_path, "ee_trajectory.npy"), ee_trace)
                    all_metadata.append(metadata)

                    print(f"\t  SUCCESS! Demo length: {len(demo)} steps, attempts: {attempt}")
                    episode_idx += 1
                    successful_demos_in_mode += 1
                    success_this_demo = True

                except Exception as e:
                    # Any failure (IK, simulation, drift, I/O) just costs one attempt.
                    error_msg = str(e)
                    if attempt < max_attempts_per_demo:
                        print(f"\t  Attempt {attempt} failed: {error_msg[:50]}...")
                    else:
                        print(f"\t  FAILED after {max_attempts_per_demo} attempts: {error_msg}")
                        import traceback
                        traceback.print_exc()

            if not success_this_demo:
                print(f"\t  WARNING: Could not generate demo after {max_attempts_per_demo} attempts!")
                print(f"\t  Skipping this demo slot for mode {mode_idx}")
                break  # Move to next mode if we can't generate more demos

        print(f"  Mode {mode_idx} complete: {successful_demos_in_mode}/{demos_per_mode} demos")

    # Save all metadata for training split
    np.save(os.path.join(variation_path, "train", "train_metadata.npy"), all_metadata)

    print(f"\n  Generated {episode_idx} episodes")
    print(f"  Saved metadata to train_metadata.npy")

    return all_metadata


def run_test_split(task_env, cfg, variation_path, cup_pos, cup_ori):
    """
    Generate test demos at approach angles that are not part of training.

    Currently a single 45 degree angle is generated, at the first configured
    grasp height (``GRASP_HEIGHTS[0]``) and without noise. Each angle gets up to
    50 attempts; an angle that never succeeds is reported and skipped.

    Args:
        task_env: task environment handed to the grasp_utils helpers.
        cfg: namespace providing ``control_point_radius`` and ``save_video``.
        variation_path: output root; episodes are written to
            ``<variation_path>/test/episodes/episode<N>``.
        cup_pos, cup_ori: not read; kept so existing callers keep working. The
            cup pose is re-queried after every environment reset.

    Returns:
        list of metadata dicts, one per saved episode (also written to
        ``<variation_path>/test/test_metadata.npy``).
    """
    # Extend this list (e.g. 135, 225, 315) to cover more novel angles.
    test_angles_deg = [45]  # Start with just 45 degrees
    test_height = GRASP_HEIGHTS[0]  # Use same height as training

    # EE distance from the cup center used for validation; meant to match the
    # grasp_utils.py default.
    gripper_offset = 0.04

    print(f"\n{'='*70}")
    print("Generating TEST split (novel angles)")
    print(f"  Test angles: {test_angles_deg} degrees")
    print(f"  Test height: {test_height}m")
    print(f"  Fixed cup position: {FIXED_CUP_POSITION}")
    print(f"{'='*70}\n")

    episodes_path = os.path.join(variation_path, "test", "episodes")
    check_and_make(episodes_path)

    all_metadata = []
    episode_idx = 0
    max_attempts = 50

    # Get front camera for video overlay
    custom_cam = task_env._scene._cam_front if cfg.save_video else None

    for angle_deg in test_angles_deg:
        angle_rad = np.radians(angle_deg)
        print(f"\n--- Test angle: {angle_deg}° ---")

        attempt = 0
        success = False

        while attempt < max_attempts and not success:
            attempt += 1

            try:
                # Reset robot to default position
                reset_robot_to_default(task_env)
                task_env.reset()

                # Set cup to fixed position
                if FIXED_CUP_POSITION is not None:
                    set_cup_position(task_env, FIXED_CUP_POSITION)

                current_cup_pos, current_cup_ori = get_cup_position(task_env)

                # Verify the cup really sits at the fixed position. Only
                # meaningful when one is configured; without this guard a None
                # FIXED_CUP_POSITION made every attempt fail with a TypeError.
                if FIXED_CUP_POSITION is not None:
                    cup_drift = np.linalg.norm(current_cup_pos[:2] - FIXED_CUP_POSITION[:2])
                    if cup_drift > 0.005:  # 5mm tolerance
                        raise RuntimeError(f"Cup drifted {cup_drift:.4f}m from fixed position")

                # Generate trajectory
                demo, traj_metadata = generate_grasp_trajectory(
                    task_env,
                    start_pos=np.zeros(3),
                    cup_pos=current_cup_pos,
                    cup_ori=current_cup_ori,
                    approach_angle=angle_rad,
                    grasp_height=test_height,
                    control_point_radius=cfg.control_point_radius,
                    waypoint_params=None,
                    phase_steps=None,
                    steps_per_point=5,
                )

                if len(demo) == 0:
                    raise RuntimeError("Empty trajectory")

                task_success = check_task_success(task_env)
                # Recorded in the metadata but not used as an acceptance criterion.
                initial_cup_z = traj_metadata.get("initial_cup_z", current_cup_pos[2])
                grasp_success = check_grasp_success_manual(task_env, initial_cup_z)

                if not task_success:
                    print(f"  Attempt {attempt}: Task failed, retrying...")
                    continue

                # Validate grasp position is within expected range
                grasp_valid, actual_angle, actual_radius, grasp_error = validate_grasp_position(
                    demo, current_cup_pos, angle_rad, gripper_offset,
                    angle_tolerance_deg=20.0,
                    radius_tolerance=0.015
                )
                if not grasp_valid:
                    print(f"  Attempt {attempt}: Grasp position invalid: {grasp_error}, retrying...")
                    continue

                ee_trace = traj_metadata.get("trace")

                # Save demo
                episode_path = os.path.join(episodes_path, f"episode{episode_idx}")
                os.makedirs(episode_path, exist_ok=True)
                save_demo(
                    demo, episode_path, save_video=cfg.save_video,
                    ee_trace=ee_trace,
                    camera=custom_cam,
                    phase_labels=traj_metadata.get("phase_labels")
                )

                # Save metadata. Trajectory metadata is merged last, so its keys
                # take precedence; the bulky per-step entries are left out.
                metadata = {
                    'mode': -1,  # Test mode
                    'demo_in_mode': 0,
                    'approach_angle': angle_rad,
                    'grasp_height': test_height,
                    'test_angle_deg': angle_deg,
                    'cup_pos': current_cup_pos.tolist(),
                    'cup_ori': current_cup_ori.tolist(),
                    'task_success': task_success,
                    'grasp_success': grasp_success,
                    'fixed_cup_position': FIXED_CUP_POSITION.tolist() if FIXED_CUP_POSITION is not None else None,
                    'home_joints': HOME_JOINTS.tolist(),
                    **{k: v for k, v in traj_metadata.items()
                       if k not in ['trace', 'phase_labels', 'gripper_states', 'waypoints']},
                }
                np.save(os.path.join(episode_path, "metadata.npy"), metadata)

                if ee_trace is not None:
                    np.save(os.path.join(episode_path, "ee_trajectory.npy"), ee_trace)
                all_metadata.append(metadata)

                print(f"  SUCCESS! Test demo at {angle_deg}° saved (attempt {attempt})")
                episode_idx += 1
                success = True

            except Exception as e:
                # Any failure (IK, simulation, drift, I/O) just costs one attempt.
                if attempt < max_attempts:
                    print(f"  Attempt {attempt} failed: {str(e)[:50]}...")
                else:
                    # Last attempt: show the full error and where it came from,
                    # matching the training split's reporting.
                    print(f"  Attempt {attempt} failed: {e}")
                    import traceback
                    traceback.print_exc()

        if not success:
            print(f"  FAILED to generate test demo at {angle_deg}° after {max_attempts} attempts")

    # Save test metadata
    np.save(os.path.join(variation_path, "test", "test_metadata.npy"), all_metadata)

    print(f"\n  Generated {episode_idx} test episodes")
    return all_metadata


def _build_obs_config(save_video):
    """Return an ObservationConfig with low-dim state only, plus front RGB when recording video."""
    obs_config = ObservationConfig()
    obs_config.set_all(False)
    obs_config.joint_positions = True
    obs_config.joint_velocities = True
    obs_config.gripper_open = True
    obs_config.gripper_pose = True
    obs_config.task_low_dim_state = True

    if save_video:
        # Use front camera with custom position for good view
        obs_config.front_camera.rgb = True
        obs_config.front_camera.image_size = CAMERA_IMAGE_SIZE

    return obs_config


def _build_action_mode(joint_action_mode):
    """Return the action mode for ``joint_action_mode`` ("abs" -> joint positions, else velocities)."""
    if joint_action_mode != "abs":
        # NOTE(review): every value other than "abs" - including the advertised
        # "delta" - selects joint velocities. Kept as-is; confirm the intent.
        return MoveArmThenGripper(JointVelocity(), Discrete())

    # Action bounds: 8 lower limits and 8 ranges.
    # NOTE(review): presumably 7 arm joints + gripper - TODO confirm.
    act_min = np.array([-2.8973, -1.7628, -2.8973, -3.0718,
                        -2.8973, -0.0175, -2.8973, 0.0], dtype=np.float32)
    act_range = np.array([5.7946, 3.5256, 5.7946, 3.0020,
                          5.7946, 3.7700, 5.7946, 1.0], dtype=np.float32)

    class CustomMoveArmThenGripper(MoveArmThenGripper):
        def action_bounds(self):
            return (act_min, act_min + act_range)

    return CustomMoveArmThenGripper(JointPosition(True), Discrete())


def main(argv):
    """
    Entry point: generate the train and test splits for the pick_up_cup task.

    Launches a headless RLBench environment, writes the canonical parameters and
    run configuration to ``<save_path>/variation<cup_variation>``, generates both
    splits and prints a summary. The environment is shut down even when
    generation fails.

    Args:
        argv: remaining command-line arguments passed by ``absl.app.run`` (unused).
    """
    print(f"{'='*70}")
    print("GRASP (PICK_UP_CUP) DATASET GENERATOR")
    print(f"{'='*70}")
    print(f"Save path: {FLAGS.save_path}")
    print(f"Joint action mode: {FLAGS.joint_action_mode}")
    print(f"Grasp modes: {FLAGS.num_modes}")
    print(f"Demos per mode: {FLAGS.demos_per_mode}")
    print(f"Total episodes: {FLAGS.num_modes * FLAGS.demos_per_mode}")
    print(f"Control point radius: {FLAGS.control_point_radius}m")
    print(f"Cup variation: {FLAGS.cup_variation}")
    print(f"Fixed cup position: {FIXED_CUP_POSITION}")
    print(f"Angle noise: {np.degrees(FLAGS.angle_noise):.1f} deg")
    print(f"Save video: {FLAGS.save_video}")
    print(f"{'='*70}\n")

    # Generate canonical mode parameters
    canonical_params = generate_canonical_mode_params()
    print(f"Generated {len(canonical_params)} canonical mode parameters:")
    print(f"  Approach angles: {APPROACH_ANGLES_DEG} degrees")
    print(f"  Grasp heights: {GRASP_HEIGHTS} meters")

    # Setup environment
    obs_config = _build_obs_config(FLAGS.save_video)
    action_mode = _build_action_mode(FLAGS.joint_action_mode)

    rlbench_env = Environment(action_mode=action_mode, obs_config=obs_config, headless=True)
    rlbench_env.launch()

    # From here on the simulator is running: make sure it is always shut down,
    # including when data generation raises.
    try:
        # Get task
        task_class = task_file_to_task_class("pick_up_cup")
        task_env = rlbench_env.get_task(task_class)
        task_env.set_variation(FLAGS.cup_variation)

        # Initialize episode and get cup position
        task_env.reset()
        cup_pos, cup_ori = get_cup_position(task_env)
        print(f"\nCup position: {cup_pos}")
        print(f"Cup orientation: {cup_ori}")

        # Create output directory
        variation_path = os.path.join(
            FLAGS.save_path,
            f"variation{FLAGS.cup_variation}"
        )
        check_and_make(variation_path)

        # Get camera and set custom position for visualization
        front_cam = task_env._scene._cam_front
        front_cam.set_position(CAMERA_POSITION)
        front_cam.set_orientation(CAMERA_ORIENTATION)

        # Get the camera matrix after positioning
        camera_matrix = front_cam.get_matrix()[:3, :3]
        camera_fov = front_cam.get_perspective_angle()

        # Save canonical parameters and configuration
        np.save(os.path.join(variation_path, "canonical_params.npy"), canonical_params)
        config = {
            'cup_variation': FLAGS.cup_variation,
            'control_point_radius': FLAGS.control_point_radius,
            'num_modes': FLAGS.num_modes,
            'demos_per_mode': FLAGS.demos_per_mode,
            'angle_noise': FLAGS.angle_noise,
            'height_noise': FLAGS.height_noise,
            'cup_pos': cup_pos.tolist(),
            'cup_ori': cup_ori.tolist(),
            'approach_angles_deg': APPROACH_ANGLES_DEG,
            'grasp_heights': GRASP_HEIGHTS,
            'phase_steps': DEFAULT_PHASE_STEPS,
            'camera_position': CAMERA_POSITION,
            'camera_orientation': CAMERA_ORIENTATION,
            'camera_matrix': camera_matrix.tolist(),
            'camera_fov': camera_fov,
            # CRITICAL: Fixed positions for evaluation reproducibility
            'fixed_cup_position': FIXED_CUP_POSITION.tolist() if FIXED_CUP_POSITION is not None else None,
            'home_joints': HOME_JOINTS.tolist(),
        }
        np.save(os.path.join(variation_path, "config.npy"), config)

        # Create config namespace (plain attributes, decoupled from absl FLAGS)
        cfg = SimpleNamespace(
            save_video=FLAGS.save_video,
            control_point_radius=FLAGS.control_point_radius,
            angle_noise=FLAGS.angle_noise,
            height_noise=FLAGS.height_noise,
            cup_variation=FLAGS.cup_variation,
            num_modes=FLAGS.num_modes,
            demos_per_mode=FLAGS.demos_per_mode,
        )

        # Generate training split
        train_metadata = run_training_split(
            task_env, cfg, variation_path, canonical_params, cup_pos, cup_ori
        )

        # Generate test split (45 degree novel angle)
        test_metadata = run_test_split(
            task_env, cfg, variation_path, cup_pos, cup_ori
        )

        # Summary
        n_train = len(train_metadata)
        n_test = len(test_metadata)
        n_test_success = sum(1 for m in test_metadata if m.get("task_success", False))

        print(f"\n{'='*70}")
        print("DATA GENERATION COMPLETE")
        print(f"  Training episodes: {n_train} (all successful by design)")
        print(f"  Test episodes: {n_test} (success: {n_test_success}/{n_test})")
        print(f"  Total: {n_train + n_test} episodes")
        print(f"  Data saved to: {variation_path}")
        print(f"{'='*70}\n")
    finally:
        rlbench_env.shutdown()


if __name__ == "__main__":
    # Script entry point: force the "spawn" multiprocessing start method (overriding
    # any method already set), then let absl parse the flags and call main().
    # NOTE(review): no child process is started in the visible code; presumably
    # "spawn" avoids fork-related problems with the simulator - confirm.
    import multiprocessing as mp
    mp.set_start_method("spawn", force=True)
    app.run(main)
