# filename: dataset_generator_pick_place_cp.py
"""
Dataset generator for pick-and-place with control-point shaping for REACH and CARRY arcs.

NEW APPROACH (from debug_reach_grasp_lift_carry_release.py):
  - Uses 7-phase trajectory with Bezier curves for REACH and CARRY
  - Hard-coded phases for DESCEND, GRASP, LIFT, DESCEND_RELEASE, RELEASE
  - Uses gripper.grasp() and gripper.release() for reliable object attachment
  - Only REACH and CARRY phases are learned by diffusion model

Tasks:
  - stack_blocks: Stack blocks on target (default)
  - meat_off_grill: Take chicken/steak off grill, place on side plate

Key design:
  - 7 phases total: REACH (64), DESCEND (8), GRASP (8), LIFT (8), CARRY (64), DESCEND_RELEASE (8), RELEASE (8)
  - Total trajectory: 168 steps
  - Gripper open/close is scripted based on phase transitions
  - Uses gripper.grasp() for reliable object attachment during grasp
  - Uses gripper.release() for reliable object detachment during release

Control Point Sampling:
  - Canonical params: 4 angles (0°, 90°, 180°, 270°) x 2 distances (0.5, 1.0) = 8 CPs
  - Position fixed at 0.5 (middle of trajectory)
  - Valid CPs pre-filtered for IK feasibility, collision, and jitter (in stack_blocks_init.npz)
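  - Only CPs that are valid for BOTH REACH and CARRY are used as modes
  - Within a demo the SAME CP params (angle, distance, position) shape both REACH and CARRY
  - Each mode yields demos_per_mode demos: 1 with the base CP plus noisy variations
    (angle/distance noise, rejection-sampled); a mode is skipped if its base CP demo fails
  - Generates up to 8 x demos_per_mode demos in total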

Data Saving:
  - Each episode saved to episodes/episode{idx}/low_dim_obs.pkl
  - Metadata includes: cp_idx, waypoints, phase_indices, success flag
  - Videos saved optionally to episode{idx}/video.mp4
"""

import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'custom_tasks'))

from rlbench.environment import Environment
from pyrep.errors import ConfigurationPathError, IKError
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

import numpy as np

from absl import app
from absl import flags

from utils import (
    # Flags
    define_common_flags, FLAGS,
    # File utilities
    check_and_make, save_demo,
    # Environment setup
    create_obs_config, create_action_mode, get_task_classes,
    # Target grid
    generate_3d_target_grid, print_target_grid_info,
    # Canonical control points
    generate_canonical_control_point_params,
    compute_control_point_from_params,
    parabola3D,
    HOME_JOINTS,
)

from grasp_and_lift import (
    descend_and_grasp,
    lift,
    descend_and_release,
    capture_frame,
)

from pick_and_place_utils import (
    extract_demo_waypoints,
    DEFAULT_GRASP_ORIENTATION,
)

# Define common flags
# NOTE(review): main() reads FLAGS.tasks, FLAGS.save_path, FLAGS.save_video and
# FLAGS.joint_action_mode, none of which are defined in this file; presumably
# they are registered by this call - confirm against utils.define_common_flags.
define_common_flags()

# Script-specific flags
# Radius handed to compute_control_point_from_params for both REACH and CARRY.
flags.DEFINE_float("control_point_radius", 0.05,
                   "Control point sampling radius (meters). Default 0.05m for pick-place.")
# NOTE(review): object_xy_span / object_z_span are registered here but are not
# read anywhere in this file.
flags.DEFINE_float("object_xy_span", 0.04,
                   "XY span for object position variation (meters).")
flags.DEFINE_float("object_z_span", 0.0,
                   "Z span for object position variation (meters). 0 = 2D grid.")
# main() aborts with an error message when this flag is False.
flags.DEFINE_bool("use_precomputed_init", True,
                  "Whether to use precomputed init data from stack_blocks_init.npz")

# Noisy demo generation parameters
flags.DEFINE_integer("demos_per_mode", 10,
                     "Number of demos per CP mode (1 original + N-1 noisy). Default 10.")
flags.DEFINE_float("cp_angle_noise", 0.1745,
                   "Noise range for angle (radians) around base control point. Default 0.1745 (~10 degrees).")
flags.DEFINE_float("cp_dist_noise", 0.1,
                   "Noise range for distance fraction around base control point. Default 0.1.")
flags.DEFINE_integer("max_noise_attempts", 50,
                     "Max attempts to find a valid noisy sample before skipping. Default 50.")

# Supported tasks for pick-and-place
# main() only prints a warning (it does not abort) for a task outside this list.
SUPPORTED_TASKS = ["stack_blocks", "meat_off_grill"]
# Used by main() when FLAGS.tasks is empty.
DEFAULT_TASK = "stack_blocks"

# Phase configuration (matches debug_reach_grasp_lift_carry_release.py)
# Insertion order is significant: generate_full_trajectory() derives the
# (start, end) index range of each phase by iterating this dict in order.
PHASE_STEPS = {
    "reach": 64,           # Phase 1: Bezier curve start -> pregrasp (LEARNED)
    "descend": 8,          # Phase 2: pregrasp -> grasp (linear Z)
    "grasp": 8,            # Phase 3: hold + close gripper + gripper.grasp()
    "lift": 8,             # Phase 4: grasp -> lift (linear Z)
    "carry": 64,           # Phase 5: Bezier curve lift -> prerelease (LEARNED)
    "descend_release": 8,  # Phase 6: prerelease -> release (linear Z)
    "release": 8,          # Phase 7: hold + open gripper + gripper.release()
}
TOTAL_STEPS = sum(PHASE_STEPS.values())  # 168


def _record_joint_action(obs, joint_positions, gripper_value):
    """Store the commanded arm joints plus gripper value in ``obs.misc``."""
    if not hasattr(obs, 'misc'):
        obs.misc = {}
    obs.misc['joint_position_action'] = np.concatenate([joint_positions, [gripper_value]])


def _run_bezier_phase(task_env, phase_name, start_pos, end_pos, control_point,
                      orientation, gripper_value, prev_joints, steps_per_point,
                      demo, trace, phase_labels, gripper_states):
    """
    Drive the arm along one control-point-shaped curve (REACH or CARRY).

    For each of PHASE_STEPS[phase_name] curve samples the target is solved
    with Jacobian IK (seeded from the previous solution), the joints are set
    directly, and the simulation is stepped ``steps_per_point`` times.
    ``demo``, ``trace``, ``phase_labels`` and ``gripper_states`` are appended
    to in place, one entry per curve sample.

    Returns:
        (prev_joints, ik_failures): the last successfully solved joint
        configuration and the number of samples whose IK failed.
    """
    robot = task_env._scene.robot
    tip = robot.arm.get_tip()
    gripper = robot.gripper

    n_steps = PHASE_STEPS[phase_name]
    ik_failures = 0

    for i in range(n_steps):
        # max(..., 1) guards against division by zero for a 1-step phase.
        t = i / max(n_steps - 1, 1)
        target_pos = parabola3D(start_pos, end_pos, control_point, t)

        try:
            # Re-seed the solver from the last good configuration so that
            # consecutive IK solutions stay close to each other.
            robot.arm.set_joint_positions(prev_joints, disable_dynamics=True)
            joint_positions = robot.arm.solve_ik_via_jacobian(target_pos, euler=orientation)
            robot.arm.set_joint_positions(joint_positions, disable_dynamics=True)
            robot.arm.set_joint_target_velocities([0] * 7)

            gripper.actuate(gripper_value, 0.2)

            for _ in range(steps_per_point):
                task_env._scene.pyrep.step()

            obs = task_env._scene.get_observation()
            _record_joint_action(obs, joint_positions, gripper_value)
            demo.append(obs)

            trace.append(tip.get_position().copy())
            prev_joints = list(joint_positions)

        except (IKError, ConfigurationPathError):
            # Keep the per-step lists aligned: repeat the previous observation
            # and the previous EE position instead of dropping the step.
            ik_failures += 1
            demo.append(demo[-1])
            trace.append(trace[-1].copy() if trace else start_pos.copy())

        phase_labels.append(phase_name)
        gripper_states.append(gripper_value)

    return prev_joints, ik_failures


def _append_scripted_observations(task_env, phase_result, demo):
    """
    Append one observation per step reported by a scripted-phase helper.

    NOTE(review): the scripted helper has already returned when this runs and
    the simulation is not stepped here, so every observation appended for the
    phase is a snapshot of the scene as it is *after* the phase; only the
    gripper value in the recorded action varies per step. Confirm this is the
    intended training signal.
    """
    robot = task_env._scene.robot
    steps = zip(phase_result['trace'],
                phase_result['phase_labels'],
                phase_result['gripper_states'])
    for _pos, _phase, grip in steps:
        obs = task_env._scene.get_observation()
        _record_joint_action(obs, list(robot.arm.get_joint_positions()), grip)
        demo.append(obs)


def generate_full_trajectory(task_env, home_pos, pregrasp_pos, grasp_pos,
                             reach_cp, orientation,
                             prerelease_pos, release_pos, carry_cp,
                             object_shape_name='stack_blocks_target0',
                             steps_per_point=5):
    """
    Generate full pick-and-place trajectory with physics.

    This is adapted from debug_reach_grasp_lift_carry_release.py.

    Phases:
    1. REACH: home → pregrasp (Bezier curve) - learned by diffusion
    2. DESCEND + GRASP: pregrasp → grasp, close gripper - hard-coded
    3. LIFT: grasp → lift position - hard-coded
    4. CARRY: lift → prerelease (Bezier curve) - learned by diffusion
    5. DESCEND_RELEASE + RELEASE: prerelease → release, open gripper - hard-coded

    Args:
        task_env: RLBench task environment; its ``_scene`` is used directly.
        home_pos: start position of the REACH curve.
        pregrasp_pos: end of REACH; also used as the lift position.
        grasp_pos: position at which the object is grasped.
        reach_cp: control point shaping the REACH curve.
        orientation: end-effector Euler orientation used for every IK solve.
        prerelease_pos: end position of the CARRY curve.
        release_pos: position at which the object is released.
        carry_cp: control point shaping the CARRY curve.
        object_shape_name: scene name of the manipulated object, forwarded to
            the scripted grasp/lift/release helpers.
        steps_per_point: simulation steps executed per trajectory point.

    Returns:
        demo: list of observations (initial observation first, then one per
            recorded trajectory step).
        metadata: dict with keys ``trace`` (actual EE positions), ``phase_labels``,
            ``gripper_states``, ``phase_indices`` (nominal ranges derived from
            PHASE_STEPS), ``ik_failures``, ``object_lifted`` and
            ``object_released``.
    """
    robot = task_env._scene.robot

    # Lift position is same XY as pregrasp, same Z as pregrasp
    # NOTE(review): run_split() computes carry_cp from init_data['lift_pos'];
    # if that differs from pregrasp_pos, the CARRY control point and the CARRY
    # start used here will not correspond to each other - confirm.
    lift_pos = pregrasp_pos.copy()

    trace = []  # Actual EE positions after IK+physics
    phase_labels = []
    gripper_states = []
    ik_failures = 0

    prev_joints = list(robot.arm.get_joint_positions())

    # Initial observation
    demo = [task_env._scene.get_observation()]

    # ========================================================================
    # Phase 1: REACH (Bezier curve) - home → pregrasp, gripper open
    # ========================================================================
    prev_joints, failures = _run_bezier_phase(
        task_env, "reach", home_pos, pregrasp_pos, reach_cp, orientation,
        1.0, prev_joints, steps_per_point,
        demo, trace, phase_labels, gripper_states
    )
    ik_failures += failures

    # ========================================================================
    # Phase 2 & 3: DESCEND + GRASP (hard-coded)
    # ========================================================================
    grasp_result = descend_and_grasp(
        task_env, pregrasp_pos, grasp_pos, orientation,
        object_shape_name=object_shape_name,
        descend_steps=PHASE_STEPS["descend"],
        grasp_steps=PHASE_STEPS["grasp"],
        steps_per_point=steps_per_point,
        capture_video=False,
        verbose=False
    )
    _append_scripted_observations(task_env, grasp_result, demo)

    trace.extend(grasp_result['trace'])
    phase_labels.extend(grasp_result['phase_labels'])
    gripper_states.extend(grasp_result['gripper_states'])
    ik_failures += grasp_result['ik_failures']

    # ========================================================================
    # Phase 4: LIFT (hard-coded)
    # ========================================================================
    lift_result = lift(
        task_env, grasp_pos, lift_pos, orientation,
        object_shape_name=object_shape_name,
        lift_steps=PHASE_STEPS["lift"],
        steps_per_point=steps_per_point,
        capture_video=False,
        verbose=False,
        prev_joints=grasp_result['prev_joints']
    )
    _append_scripted_observations(task_env, lift_result, demo)

    trace.extend(lift_result['trace'])
    phase_labels.extend(lift_result['phase_labels'])
    gripper_states.extend(lift_result['gripper_states'])
    ik_failures += lift_result['ik_failures']
    object_lifted = lift_result['object_lifted']
    prev_joints = lift_result['prev_joints']

    # ========================================================================
    # Phase 5: CARRY (Bezier curve) - lift → prerelease, gripper closed
    # ========================================================================
    prev_joints, failures = _run_bezier_phase(
        task_env, "carry", lift_pos, prerelease_pos, carry_cp, orientation,
        0.0, prev_joints, steps_per_point,
        demo, trace, phase_labels, gripper_states
    )
    ik_failures += failures

    # ========================================================================
    # Phase 6 & 7: DESCEND_RELEASE + RELEASE (hard-coded)
    # ========================================================================
    release_result = descend_and_release(
        task_env, prerelease_pos, release_pos, orientation,
        object_shape_name=object_shape_name,
        descend_steps=PHASE_STEPS["descend_release"],
        release_steps=PHASE_STEPS["release"],
        steps_per_point=steps_per_point,
        capture_video=False,
        verbose=False,
        prev_joints=prev_joints
    )
    _append_scripted_observations(task_env, release_result, demo)

    trace.extend(release_result['trace'])
    phase_labels.extend(release_result['phase_labels'])
    gripper_states.extend(release_result['gripper_states'])
    ik_failures += release_result['ik_failures']
    object_released = release_result['object_released']

    # Build phase indices: consecutive [start, end) ranges in PHASE_STEPS order.
    phase_indices = {}
    idx = 0
    for phase_name, steps in PHASE_STEPS.items():
        phase_indices[phase_name] = (idx, idx + steps)
        idx += steps

    metadata = {
        "trace": np.array(trace),
        "phase_labels": phase_labels,
        "gripper_states": gripper_states,
        "phase_indices": phase_indices,
        "ik_failures": ik_failures,
        "object_lifted": object_lifted,
        "object_released": object_released,
    }

    return demo, metadata


def sample_noisy_cp_params(base_params, angle_noise, dist_noise):
    """
    Draw a perturbed copy of a control point's (angle, distance, position).

    Uniform noise is added to the angle and to the distance fraction; the
    position fraction is passed through untouched. Randomness comes from the
    global ``np.random`` state, with the angle drawn before the distance.

    Args:
        base_params: tuple/list of (angle, dist_frac, pos_frac)
        angle_noise: half-width of the uniform angle perturbation (radians)
        dist_noise: half-width of the uniform distance-fraction perturbation

    Returns:
        tuple of (noisy_angle, noisy_dist, pos_frac), where noisy_angle is
        wrapped into [0, 2*pi) and noisy_dist is clipped to [0.1, 1.0].
    """
    base_angle, base_dist, fixed_pos = base_params
    full_turn = 2 * np.pi

    # Angle: perturb, then wrap back into a single revolution.
    angle_jitter = np.random.uniform(-angle_noise, angle_noise)
    wrapped_angle = (base_angle + angle_jitter) % full_turn

    # Distance fraction: perturb, then keep it inside the allowed band.
    dist_jitter = np.random.uniform(-dist_noise, dist_noise)
    bounded_dist = np.clip(base_dist + dist_jitter, 0.1, 1.0)

    return (wrapped_angle, bounded_dist, fixed_pos)


def _reset_with_fixed_seed(task_env, seed=42):
    """
    Reset the task under a fixed NumPy seed without disturbing the caller's RNG.

    The global NumPy RNG state is saved, the seed is applied for the duration
    of ``task_env.reset()``, and the saved state is restored afterwards - also
    when the reset raises, so that later noise sampling never continues from
    the fixed seed.
    """
    rng_state = np.random.get_state()
    np.random.seed(seed)
    try:
        task_env.reset()
    finally:
        np.random.set_state(rng_state)


def _setup_single_block_scene(task_env, robot, object_pos):
    """
    Reduce the stack_blocks scene to a single target block at ``object_pos``.

    The remaining target blocks and the distractors are moved far away from
    the workspace, and the task's success conditions are replaced so that one
    detected block (with nothing grasped) counts as success.
    """
    from rlbench.backend.conditions import DetectedSeveralCondition, NothingGrasped

    target_block = Shape('stack_blocks_target0')
    target_block.set_position(object_pos)

    # Hide other blocks. Best effort: a shape that cannot be looked up or
    # moved is skipped, but KeyboardInterrupt/SystemExit still propagate.
    for i in range(1, 4):
        try:
            other_block = Shape(f'stack_blocks_target{i}')
            other_block.set_position([10, 10, 0])
        except Exception:
            pass
    for i in range(4):
        try:
            distractor = Shape(f'stack_blocks_distractor{i}')
            distractor.set_position([10, 10 + i*0.1, 0])
        except Exception:
            pass

    # Override success condition (only require 1 block)
    success_detector = ProximitySensor('stack_blocks_success')
    task_env._task._success_conditions = [
        DetectedSeveralCondition([target_block], success_detector, 1),
        NothingGrasped(robot.gripper)
    ]


def run_split(tag, task_env, task_name, episodes_path, start_episode_idx,
              init_data, canonical_params, control_point_radius,
              valid_reach_indices, valid_carry_indices, save_video,
              demos_per_mode, cp_angle_noise, cp_dist_noise, max_noise_attempts):
    """
    Generate demos for one split (train/eval) with noisy variations.

    Control Point Sampling Strategy:
    - Modes are the CP indices present in BOTH valid_reach_indices and
      valid_carry_indices
    - For each mode (anchor CP), generate demos_per_mode demos
    - Demo 0: use base CP params (no noise), single attempt; if it fails the
      whole mode is skipped
    - Demo 1+: sample noise, apply SAME noisy params to both REACH and CARRY
    - Use reject sampling: an attempt is rejected when anything in it raises,
      when there are more than 5 IK failures, or when the RLBench task success
      check fails; up to max_noise_attempts attempts are made per demo

    Data Saving:
    - Each accepted episode is written via save_demo() to
      episodes_path/episode{idx}
    - Per-episode metadata dicts are returned to the caller (not written here)

    Args:
        tag: split name stored in the metadata and printed in the banner.
        task_env: RLBench task environment.
        task_name: task identifier; selects the object shape name and whether
            the stack_blocks scene setup is applied.
        episodes_path: directory receiving the episode folders.
        start_episode_idx: index assigned to the first accepted episode.
        init_data: mapping with the precomputed positions, grasp orientation,
            object position and target position.
        canonical_params: indexable collection of (angle, dist, pos) base params.
        control_point_radius: radius passed to compute_control_point_from_params.
        valid_reach_indices, valid_carry_indices: CP indices usable per phase.
        save_video: whether save_demo should also write a video.
        demos_per_mode: demos to generate per mode.
        cp_angle_noise, cp_dist_noise: noise half-widths for noisy demos.
        max_noise_attempts: attempts allowed per noisy demo.

    Returns:
        (episode_idx, metadata_list, all_cp_indices): the next free episode
        index, one metadata dict per accepted episode, and one
        (reach_cp_idx, carry_cp_idx) tuple per accepted episode.
    """
    check_and_make(episodes_path)

    # Load positions from init_data
    home_pos = init_data['home_pos']
    pregrasp_pos = init_data['pregrasp_pos']
    grasp_pos = init_data['grasp_pos']
    lift_pos = init_data['lift_pos']
    prerelease_pos = init_data['prerelease_pos']
    release_pos = init_data['release_pos']
    orientation = init_data['grasp_orientation']
    FIXED_OBJECT_POS = init_data['object_pos']
    TARGET_POS = init_data['target_pos']

    # Determine object shape name based on task
    if task_name == "stack_blocks":
        object_shape_name = 'stack_blocks_target0'
    elif task_name == "meat_off_grill":
        object_shape_name = 'chicken'  # or 'steak'
    else:
        object_shape_name = 'target0'

    metadata_list = []
    all_cp_indices = []  # Will store tuples of (reach_cp_idx, carry_cp_idx)
    episode_idx = start_episode_idx

    # Only use CPs that are valid for BOTH reach and carry
    valid_both = set(valid_reach_indices) & set(valid_carry_indices)
    mode_indices = sorted(valid_both)
    n_modes = len(mode_indices)

    print(f"\n{'='*70}")
    print(f"Generating {tag.upper()} split")
    print(f"  Modes (valid for both REACH and CARRY): {n_modes}")
    print(f"  Demos per mode: {demos_per_mode} (1 original + {demos_per_mode-1} noisy)")
    print(f"  Total target demos: {n_modes * demos_per_mode}")
    print(f"  CP noise: angle={cp_angle_noise:.2f}rad (~{np.degrees(cp_angle_noise):.1f}°), dist={cp_dist_noise:.2f}")
    print(f"  Max noise attempts: {max_noise_attempts}")
    print(f"  Mode indices: {mode_indices}")
    print(f"{'='*70}\n")

    for mode_idx, cp_idx in enumerate(mode_indices):
        base_params = canonical_params[cp_idx]
        base_angle, base_dist, base_pos = base_params

        print(f"\n--- Mode {mode_idx}/{n_modes} (CP index: {cp_idx}) ---")
        print(f"    Base params: angle={np.degrees(base_angle):.1f}°, dist={base_dist:.2f}, pos={base_pos:.2f}")

        for demo_in_mode in range(demos_per_mode):
            print(f"  Demo {demo_in_mode}/{demos_per_mode} (Episode {episode_idx})", end="")

            # First demo of a mode: base params, one attempt only.
            # Subsequent demos: reject sampling over noisy params.
            add_noise = demo_in_mode != 0
            max_attempts = max_noise_attempts if add_noise else 1

            attempt = 0
            success_this_demo = False

            while attempt < max_attempts and not success_this_demo:
                attempt += 1

                try:
                    # Deterministic reset; the caller-visible RNG stream is
                    # left untouched so noise sampling keeps varying.
                    _reset_with_fixed_seed(task_env, seed=42)

                    # Get current CP params (with or without noise) - AFTER reset to ensure varied noise
                    if add_noise:
                        current_params = sample_noisy_cp_params(base_params, cp_angle_noise, cp_dist_noise)
                        if attempt > 1:
                            print(f"\n    Resample attempt {attempt}/{max_attempts}", end="")
                    else:
                        current_params = tuple(base_params)

                    current_angle, current_dist, current_pos = current_params

                    robot = task_env._scene.robot
                    gripper = robot.gripper

                    # Move to HOME position
                    robot.arm.set_joint_positions(HOME_JOINTS, disable_dynamics=True)
                    robot.arm.set_joint_target_velocities([0] * 7)
                    gripper.actuate(1.0, 0.2)

                    for _ in range(20):
                        task_env._scene.pyrep.step()

                    # Fix object position
                    if task_name == "stack_blocks":
                        _setup_single_block_scene(task_env, robot, FIXED_OBJECT_POS)

                    # Step physics to let objects settle (matching init script)
                    for _ in range(5):
                        task_env._scene.pyrep.step()

                    # Compute control points using SAME params for both REACH and CARRY
                    reach_cp = compute_control_point_from_params(
                        home_pos, pregrasp_pos, control_point_radius,
                        current_angle, current_dist, current_pos
                    )
                    carry_cp = compute_control_point_from_params(
                        lift_pos, prerelease_pos, control_point_radius,
                        current_angle, current_dist, current_pos
                    )

                    # Generate trajectory
                    demo, traj_metadata = generate_full_trajectory(
                        task_env, home_pos, pregrasp_pos, grasp_pos,
                        reach_cp, orientation,
                        prerelease_pos, release_pos, carry_cp,
                        object_shape_name=object_shape_name,
                        steps_per_point=5
                    )

                    # Check for IK failures
                    if traj_metadata["ik_failures"] > 5:
                        raise RuntimeError(f"Too many IK failures: {traj_metadata['ik_failures']}")

                    # Check RLBench task success - reject if task failed
                    task_success, _ = task_env._task.success()
                    if not task_success:
                        raise RuntimeError("Task failed (block not placed on target)")

                    # Save demo
                    episode_path = os.path.join(episodes_path, f"episode{episode_idx}")
                    os.makedirs(episode_path, exist_ok=True)

                    # Get camera for trajectory overlay if saving video
                    front_cam = task_env._scene._cam_front if save_video else None
                    save_demo(
                        demo, episode_path, save_video=save_video,
                        ee_trace=traj_metadata.get("trace"),
                        camera=front_cam,
                        phase_labels=traj_metadata.get("phase_labels")
                    )

                    # Build metadata
                    metadata = {
                        "tag": tag,
                        "task_name": task_name,
                        "mode": mode_idx,
                        "demo_in_mode": demo_in_mode,
                        "cp_index": int(cp_idx),
                        "base_cp_params": list(base_params),
                        "current_cp_params": list(current_params),
                        "with_noise": add_noise,
                        "noise_attempt": attempt if add_noise else 0,
                        "home_pos": home_pos.tolist(),
                        "pregrasp_pos": pregrasp_pos.tolist(),
                        "grasp_pos": grasp_pos.tolist(),
                        "lift_pos": lift_pos.tolist(),
                        "prerelease_pos": prerelease_pos.tolist(),
                        "release_pos": release_pos.tolist(),
                        "object_pos": FIXED_OBJECT_POS.tolist(),
                        "target_pos": TARGET_POS.tolist(),
                        "reach_cp": reach_cp.tolist(),
                        "carry_cp": carry_cp.tolist(),
                        "orientation": orientation.tolist(),
                        # Copy so the stored record cannot alias the
                        # module-level configuration dict.
                        "phase_steps": dict(PHASE_STEPS),
                        "phase_indices": traj_metadata["phase_indices"],
                        "ik_failures": traj_metadata["ik_failures"],
                        "object_lifted": traj_metadata["object_lifted"],
                        "object_released": traj_metadata["object_released"],
                        "task_success": task_success,
                        "episode_idx": int(episode_idx),
                        "num_steps": len(demo),
                    }
                    metadata_list.append(metadata)
                    all_cp_indices.append((cp_idx, cp_idx))  # Same CP for both phases

                    lifted = "Y" if traj_metadata["object_lifted"] else "N"
                    released = "Y" if traj_metadata["object_released"] else "N"
                    task_ok = "Y" if task_success else "N"
                    noise_str = f"noise={add_noise}" if not add_noise else f"angle={np.degrees(current_angle):.1f}°,dist={current_dist:.2f}"
                    print(f" -> {len(demo)} steps | lift={lifted}, rel={released}, task={task_ok} | {noise_str}")

                    episode_idx += 1
                    success_this_demo = True

                except Exception as e:
                    # Per-attempt boundary of the reject sampler: any failure
                    # rejects this attempt. Only the final one is reported;
                    # earlier ones are retried with freshly sampled noise.
                    if attempt >= max_attempts:
                        print(f" -> FAILED after {max_attempts} attempts: {e}")

            if not success_this_demo and demo_in_mode == 0:
                # Base params failed - skip entire mode
                print(f"\n    Mode {mode_idx} SKIPPED (base params failed)")
                break

    return episode_idx, metadata_list, all_cp_indices


def main(argv):
    """
    Entry point: generate the train split of the pick-and-place dataset.

    Picks the task from FLAGS.tasks (falling back to DEFAULT_TASK), loads the
    precomputed init data from stack_blocks_init.npz next to this file,
    launches a headless RLBench environment, writes the configuration files
    and the train split under FLAGS.save_path, and shuts the environment down
    again - also when generation raises.

    Args:
        argv: positional command-line arguments passed by absl (unused).
    """
    # Determine task
    if len(FLAGS.tasks) > 0:
        task_name = FLAGS.tasks[0]
        if task_name not in SUPPORTED_TASKS:
            print(f"WARNING: Task '{task_name}' not in supported list: {SUPPORTED_TASKS}")
    else:
        task_name = DEFAULT_TASK

    print(f"Using task: {task_name}")

    tasks = get_task_classes([task_name])

    # Generate canonical control point parameters (4 angles x 2 distances = 8)
    canonical_params = generate_canonical_control_point_params()
    n_canonical = len(canonical_params)

    print(f"\n{'='*60}")
    print("PICK-AND-PLACE DATASET GENERATOR (NEW 7-PHASE APPROACH)")
    print(f"{'='*60}")
    print(f"  Task: {task_name}")
    print(f"  ---")
    print(f"  Canonical control points: {n_canonical} (4 angles x 2 distances)")
    print(f"    Angles: 0°, 90°, 180°, 270°")
    print(f"    Distances: 0.5, 1.0")
    print(f"    Position: fixed at 0.5")
    print(f"  Control point radius: {FLAGS.control_point_radius}m")
    print(f"  ---")
    print(f"  Demos per mode: {FLAGS.demos_per_mode} (1 original + {FLAGS.demos_per_mode-1} noisy)")
    print(f"  CP noise: angle={FLAGS.cp_angle_noise:.2f}rad (~{np.degrees(FLAGS.cp_angle_noise):.1f}°), dist={FLAGS.cp_dist_noise:.2f}")
    print(f"  Max noise attempts: {FLAGS.max_noise_attempts}")
    print(f"  ---")
    print(f"  Phase configuration:")
    for phase, steps in PHASE_STEPS.items():
        learned = "(LEARNED)" if phase in ["reach", "carry"] else "(hard-coded)"
        print(f"    {phase}: {steps} steps {learned}")
    print(f"  Total steps per demo: {TOTAL_STEPS}")
    print(f"{'='*60}\n")

    # Load precomputed init data
    # NOTE(review): the same stack_blocks_init.npz is loaded regardless of
    # task_name - confirm this is intended for meat_off_grill.
    if FLAGS.use_precomputed_init:
        init_file = os.path.join(os.path.dirname(__file__), "stack_blocks_init.npz")
        if not os.path.exists(init_file):
            print(f"ERROR: Init file not found: {init_file}")
            print(f"Please run stack_block_init.py first to generate it.")
            return

        print(f"Loading precomputed init data from: {init_file}")
        # Materialise every array while the archive is open, then close it.
        with np.load(init_file) as npz:
            init_data = dict(npz)

        # Unwrap 0-d arrays (saved scalars) into plain Python floats.
        for key in init_data:
            if isinstance(init_data[key], np.ndarray) and init_data[key].ndim == 0:
                init_data[key] = float(init_data[key])

        valid_reach_indices = list(init_data['reach_valid_indices'])
        valid_carry_indices = list(init_data['carry_valid_indices'])

        print(f"  Valid REACH CPs: {len(valid_reach_indices)}")
        print(f"  Valid CARRY CPs: {len(valid_carry_indices)}")
    else:
        print("ERROR: Must use precomputed init data for stack_blocks task")
        return

    # Setup environment
    obs_config = create_obs_config(save_video=FLAGS.save_video, include_task_low_dim=True)
    action_mode = create_action_mode(FLAGS.joint_action_mode)

    rlbench_env = Environment(
        action_mode=action_mode, obs_config=obs_config, headless=True
    )
    rlbench_env.launch()

    # From here on the simulator is running; make sure it is always shut down.
    try:
        for t_cls in tasks:
            task_env = rlbench_env.get_task(t_cls)
            task_env.set_variation(0)

            variation_path = os.path.join(
                FLAGS.save_path, task_env.get_name(), "variation0"
            )
            check_and_make(variation_path)

            # Save configuration
            np.save(os.path.join(variation_path, "canonical_cp_params.npy"), canonical_params)
            np.save(os.path.join(variation_path, "phase_steps.npy"), PHASE_STEPS)

            task_info = {
                "task_name": task_name,
                "phase_steps": PHASE_STEPS,
                "total_steps": TOTAL_STEPS,
                "control_point_radius": FLAGS.control_point_radius,
                "learned_phases": ["reach", "carry"],
                "hardcoded_phases": ["descend", "grasp", "lift", "descend_release", "release"],
                "demos_per_mode": FLAGS.demos_per_mode,
                "cp_angle_noise": FLAGS.cp_angle_noise,
                "cp_dist_noise": FLAGS.cp_dist_noise,
            }
            np.save(os.path.join(variation_path, "task_info.npy"), task_info)

            # Setup paths (train only)
            train_root = os.path.join(variation_path, "train")
            check_and_make(train_root)

            train_episodes_path = os.path.join(train_root, "episodes")
            check_and_make(train_episodes_path)

            # TRAIN (generates n_modes * demos_per_mode demos)
            print("\n" + "="*40 + " TRAIN " + "="*40)
            train_episode_idx, train_metadata, train_cp_indices = run_split(
                "train", task_env, task_name, train_episodes_path, 0,
                init_data, canonical_params, FLAGS.control_point_radius,
                valid_reach_indices, valid_carry_indices, FLAGS.save_video,
                FLAGS.demos_per_mode, FLAGS.cp_angle_noise, FLAGS.cp_dist_noise,
                FLAGS.max_noise_attempts
            )

            np.save(os.path.join(train_root, "train_metadata.npy"), train_metadata)
            np.save(os.path.join(train_root, "train_cp_indices.npy"), np.array(train_cp_indices))  # Shape: (N, 2) for (reach, carry)

            # Summary
            n_success_train = sum(1 for m in train_metadata if m.get("task_success", False))

            print(f"\n{'='*60}")
            print("COMPLETE")
            print(f"  Task: {task_name}")
            print(f"  Train demos: {len(train_metadata)} ({n_success_train} task-successful)")
            print(f"  Steps per demo: {TOTAL_STEPS}")
            print(f"  Saved to: {variation_path}")
            print(f"{'='*60}\n")
    finally:
        rlbench_env.shutdown()


# Script entry point: only runs when the file is executed directly.
if __name__ == "__main__":
    import multiprocessing as mp
    # Select the "spawn" start method before absl hands control to main();
    # force=True replaces a start method that may already have been set.
    # NOTE(review): presumably chosen because forked children are unsafe with
    # the simulator process - confirm.
    mp.set_start_method("spawn", force=True)
    app.run(main)
