# filename: utils.py
"""
Shared utilities for dataset generation with multiple targets.

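Contains:
  - Common flags definition
  - Environment setup (observation config, action mode)
  - Target grid generation (3D cube for train, interleaved for eval)
  - Demo saving utilities (pickled low-dim observations, optional MP4 video
    with the end-effector trajectory drawn on top)
  - Trajectory generation (guarantees exactly num_steps observations)
  - Control-point parameterization and IK / collision pre-filtering
  - Curve functions (quadratic / cubic Bezier) and the trajectory-local frame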
"""

import os
import pickle
import numpy as np
import imageio

from pyrep.errors import ConfigurationPathError, IKError

from rlbench import ObservationConfig
from rlbench.action_modes.action_mode import MoveArmThenGripper
from rlbench.action_modes.arm_action_modes import JointVelocity, JointPosition
from rlbench.action_modes.gripper_action_modes import Discrete
from rlbench.backend.utils import task_file_to_task_class
from rlbench.environment import Environment
import rlbench.backend.task as task

from absl import flags

# Alias for absl's global flag registry. define_common_flags() registers its
# flags through flags.DEFINE_* without an explicit registry, so generator
# scripts can read the values through this name (e.g. FLAGS.save_path).
FLAGS = flags.FLAGS


# ============================================================================
# Common Flags (shared by all generator scripts)
# ============================================================================

def define_common_flags():
    """Register the absl flags shared by every dataset-generator script.

    Call this exactly once, at module load of the generator script.

    ``save_path`` defaults to the ``DPPO_DATA_DIR`` environment variable when
    it is set. Otherwise it defaults to a ``data`` directory inside the parent
    of the directory that contains this file.

    NOTE(review): the ``joint_action_mode`` help text lists "delta", but
    create_action_mode() only special-cases "abs". Any other value, including
    "delta", currently yields joint-velocity control.
    """
    project_root = os.path.dirname(os.path.dirname(__file__))
    fallback_save_path = os.path.join(project_root, "data")
    save_path_default = os.environ.get("DPPO_DATA_DIR", fallback_save_path)

    # (definer, flag name, default value, help text), registered in this order.
    specs = (
        (flags.DEFINE_string, "save_path", save_path_default,
         "Where to save the demos."),
        (flags.DEFINE_list, "tasks", [], "The tasks to collect."),
        (flags.DEFINE_string, "joint_action_mode", "abs",
         'Joint action mode: "vel", "abs", or "delta".'),
        (flags.DEFINE_bool, "save_video", False,
         "Whether to save video recordings."),
        (flags.DEFINE_integer, "num_steps", 64,
         "Number of steps per trajectory (exact)."),
        # Multi-target flags (3D cube)
        (flags.DEFINE_float, "target_xy_span", 0.15,
         "Total XY span (meters) around the base target for sampling train targets."),
        (flags.DEFINE_float, "target_z_span", 0.15,
         "Total Z span (meters) around the base target for sampling train targets."),
        (flags.DEFINE_integer, "train_targets_per_axis", 3,
         "Number of training target positions per axis (3D grid over XYZ)."),
    )
    for define, name, default, help_text in specs:
        define(name, default, help_text)


# ============================================================================
# Constants
# ============================================================================

# Fixed start and base target positions.
# End-effector positions in world coordinates (units presumably meters, matching
# the meter-valued span flags). NOTE(review): nothing in this file section reads
# these two; presumably the generator scripts pass them to generate_trajectory /
# generate_3d_target_grid.
DEFAULT_START_POS = np.array([0.27851078, -0.00815551, 1.4719069])
DEFAULT_BASE_TARGET_POS = np.array([0.36239344, -0.12145063, 1.11076617])

# Panda home joint configuration (7 arm joints; radians assumed).
# move_robot_to_start() resets the arm to this pose before solving IK for the
# start position.
HOME_JOINTS = np.array([0, 0, 0, -1.57, 0, 1.57, 0.785])

# Joint limits for absolute position mode.
# 8 entries: presumably the 7 arm joints followed by the gripper, whose range is
# [0, 1]. Upper limits are ACT_MIN + ACT_RANGE; create_action_mode("abs") uses
# exactly (ACT_MIN, ACT_MIN + ACT_RANGE) as its action bounds.
ACT_MIN = np.array([-2.8973, -1.7628, -2.8973, -3.0718,
                    -2.8973, -0.0175, -2.8973, 0.0], dtype=np.float32)
ACT_RANGE = np.array([5.7946, 3.5256, 5.7946, 3.0020,
                      5.7946, 3.7700, 5.7946, 1.0], dtype=np.float32)


# ============================================================================
# File / Directory Utilities
# ============================================================================

def check_and_make(dir_path: str) -> None:
    """Create ``dir_path`` (including missing parents) if it does not exist.

    ``exist_ok=True`` avoids the check-then-create race: with a separate
    ``os.path.exists`` test, a parallel generator process could create the
    directory in between and make ``os.makedirs`` raise FileExistsError.
    A path that exists but is not a directory still raises FileExistsError.
    """
    os.makedirs(dir_path, exist_ok=True)


# ============================================================================
# Trajectory Visualization on Video
# ============================================================================

def project_world_to_image(points_3d, camera, image_size):
    """
    Project 3D world points to 2D pixel coordinates with a pinhole model.

    The focal length is derived from the camera's perspective angle, applied
    to the image height. If the angle cannot be queried, 60 degrees is used.
    The principal point is the image centre.

    Args:
        points_3d: np.ndarray of shape (N, 3) or list of 3D points
        camera: VisionSensor from PyRep (needs get_position, get_matrix and,
            optionally, get_perspective_angle)
        image_size: tuple (height, width)

    Returns:
        list with one entry per input point: an integer (u, v) pixel tuple, or
        None when the point is at or behind the camera (camera-frame z <= 0.01).
        Pixels are floored, so points left of / above the image get negative
        coordinates rather than being truncated to 0.
    """
    # Get camera pose
    cam_pos = np.array(camera.get_position())
    cam_matrix = camera.get_matrix()[:3, :3]  # Rotation part of the pose matrix

    # Camera intrinsics - try to get FOV from camera
    try:
        fov_deg = camera.get_perspective_angle()
        fov = fov_deg * np.pi / 180.0
    except Exception:
        # Best-effort fallback; unlike a bare except, this does not swallow
        # KeyboardInterrupt / SystemExit.
        fov = 60.0 * np.pi / 180.0

    f = image_size[0] / (2.0 * np.tan(fov / 2.0))
    cx, cy = image_size[1] / 2.0, image_size[0] / 2.0

    projected = []
    for p in points_3d:
        # Transform point to camera frame
        p_cam = cam_matrix.T @ (np.array(p) - cam_pos)

        # In CoppeliaSim/PyRep, camera looks along +Z axis in its local frame
        x_cam, y_cam, z_cam = p_cam[0], p_cam[1], p_cam[2]

        if z_cam > 0.01:
            # Standard pinhole projection
            u = cx - f * x_cam / z_cam
            v = cy - f * y_cam / z_cam
            # floor (not int(), which truncates toward zero) keeps values in
            # (-1, 0) negative so callers' bounds checks reject them.
            projected.append((int(np.floor(u)), int(np.floor(v))))
        else:
            projected.append(None)

    return projected


# Phase -> colour lookup for the trajectory overlay. The tuples are RGB: the
# frames drawn on are ``obs.front_rgb`` images and are written out again by
# imageio as RGB. cv2 writes the tuple's channel values verbatim, so a BGR
# table would render the wrong colours.
PHASE_COLORS = {
    "reach": (0, 255, 0),              # Green - learned
    "descend": (255, 165, 0),          # Orange
    "grasp": (0, 0, 255),              # Blue
    "lift": (0, 255, 255),             # Cyan
    "carry": (255, 255, 0),            # Yellow - learned
    "descend_release": (255, 0, 255),  # Magenta
    "release": (128, 0, 128),          # Purple
}


def overlay_trajectory_on_frames(frames, ee_trace, camera, phase_labels=None):
    """
    Overlay the end-effector trajectory on video frames.

    For frame ``k``:
      - the polyline through trace points ``0..min(k, N-1)`` is drawn;
      - each segment is coloured by the phase label of its end point;
      - the current point is marked with a white dot with a black outline.
    Segments with an endpoint behind the camera or outside the image are
    skipped.

    Args:
        frames: list of RGB images (H, W, 3)
        ee_trace: np.ndarray of shape (N, 3), EE positions
        camera: VisionSensor for projection
        phase_labels: list of str, phase name for each trace point. Unknown
            labels, or points beyond the list, are drawn white.

    Returns:
        list of frames with the trajectory overlay (copies; inputs untouched).
        If ``frames`` or ``ee_trace`` is empty, the ``frames`` argument itself
        is returned.
    """
    if len(frames) == 0 or len(ee_trace) == 0:
        return frames

    # Imported lazily: cv2 is only needed when there is something to draw.
    import cv2

    height, width = frames[0].shape[:2]
    default_color = (255, 255, 255)  # White

    def in_bounds(pt):
        return 0 <= pt[0] < width and 0 <= pt[1] < height

    # Project all trajectory points to 2D once; reused for every frame.
    projected = project_world_to_image(ee_trace, camera, (height, width))

    frames_with_overlay = []
    for frame_idx, frame in enumerate(frames):
        frame_overlay = frame.copy()

        # Frames and trace points might not be 1:1; clamp to the last point.
        trace_idx = min(frame_idx, len(projected) - 1)

        for i in range(1, trace_idx + 1):
            p1, p2 = projected[i - 1], projected[i]
            if p1 is None or p2 is None:
                continue
            if not (in_bounds(p1) and in_bounds(p2)):
                continue

            if phase_labels is not None and i < len(phase_labels):
                color = PHASE_COLORS.get(phase_labels[i], default_color)
            else:
                color = default_color

            cv2.line(frame_overlay, p1, p2, color, thickness=2, lineType=cv2.LINE_AA)

        # Draw current position marker
        curr_pt = projected[trace_idx]
        if curr_pt is not None and in_bounds(curr_pt):
            cv2.circle(frame_overlay, curr_pt, 6, (255, 255, 255), -1)  # White filled
            cv2.circle(frame_overlay, curr_pt, 6, (0, 0, 0), 2)  # Black outline

        frames_with_overlay.append(frame_overlay)

    return frames_with_overlay


def save_demo(demo, episode_path, save_video=False, ee_trace=None, camera=None, phase_labels=None):
    """
    Persist one demo to ``episode_path``.

    The whole ``demo`` list is always pickled to ``low_dim_obs.pkl``. When
    ``save_video`` is set, the non-None ``front_rgb`` images of the
    observations are written to ``video.mp4`` (30 fps, libx264). If ``ee_trace``
    and ``camera`` are also given and the trace is non-empty, the trajectory is
    drawn on the frames first. No video is written if no frames are found.

    Args:
        demo: list of observations
        episode_path: str, existing directory to write into
        save_video: bool, whether to also write the RGB video
        ee_trace: np.ndarray of shape (N, 3), EE positions to overlay on video
        camera: VisionSensor used for the 3D->2D projection of ``ee_trace``
        phase_labels: list of str, per-trace-point phase names (for coloring)
    """
    pickle_path = os.path.join(episode_path, "low_dim_obs.pkl")
    with open(pickle_path, "wb") as fh:
        pickle.dump(demo, fh)

    if not save_video:
        return

    frames = [
        obs.front_rgb
        for obs in demo
        if hasattr(obs, 'front_rgb') and obs.front_rgb is not None
    ]
    if not frames:
        return

    wants_overlay = ee_trace is not None and camera is not None and len(ee_trace) > 0
    if wants_overlay:
        frames = overlay_trajectory_on_frames(frames, ee_trace, camera, phase_labels)

    video_path = os.path.join(episode_path, "video.mp4")
    imageio.mimwrite(video_path, frames, fps=30, codec='libx264', quality=8)
    print(f"\t  Video saved: {video_path} ({len(frames)} frames)")


# ============================================================================
# Environment Setup
# ============================================================================

def create_obs_config(save_video=False, image_size=256, include_task_low_dim=True):
    """
    Build the RLBench ObservationConfig used by the dataset generators.

    Every observation channel is disabled first. Then the following are
    switched on: joint positions, joint velocities, gripper-open state and
    gripper pose. Two channels are optional.

    Args:
        save_video: bool, also enable front-camera RGB at
            ``image_size`` x ``image_size`` (needed for video saving)
        image_size: int, side length of the front-camera image
        include_task_low_dim: bool, also enable ``task_low_dim_state``
            (object/target poses, used for goal-conditioned learning)

    Returns:
        ObservationConfig
    """
    config = ObservationConfig()
    config.set_all(False)

    # Proprioceptive channels that are always recorded.
    for channel in ("joint_positions", "joint_velocities", "gripper_open", "gripper_pose"):
        setattr(config, channel, True)

    if include_task_low_dim:
        config.task_low_dim_state = True

    if save_video:
        front = config.front_camera
        front.rgb = True
        front.image_size = [image_size, image_size]

    return config


def create_action_mode(joint_action_mode="abs"):
    """
    Create the RLBench action mode for the requested joint control scheme.

    "abs" returns absolute joint-position control plus a discrete gripper. Its
    ``action_bounds`` are overridden to ``(ACT_MIN, ACT_MIN + ACT_RANGE)``.
    Every other value returns joint-velocity control plus a discrete gripper.

    NOTE(review): "delta" is listed in the flag help, but it takes the
    velocity branch here.
    """
    if joint_action_mode != "abs":
        return MoveArmThenGripper(JointVelocity(), Discrete())

    class CustomMoveArmThenGripper(MoveArmThenGripper):
        # Report the fixed joint/gripper limits as the action-space bounds.
        def action_bounds(self):
            return (ACT_MIN, ACT_MIN + ACT_RANGE)

    return CustomMoveArmThenGripper(JointPosition(True), Discrete())


def get_task_classes(task_names):
    """
    Resolve task names to RLBench task classes.

    The available task names are the ``*.py`` files in ``task.TASKS_PATH``
    (excluding ``__init__.py``), with the ``.py`` suffix removed.

    Args:
        task_names: list of str. If empty, every available task is returned
            (in ``os.listdir`` order); otherwise each name is validated and the
            classes are returned in the given order.

    Returns:
        list of task classes

    Raises:
        ValueError: if any requested name is not an available task.
    """
    suffix = ".py"
    # Slice off only the trailing suffix. str.replace(".py", "") would also
    # strip any ".py" occurring earlier in the filename.
    task_files = [
        fname[: -len(suffix)]
        for fname in os.listdir(task.TASKS_PATH)
        if fname != "__init__.py" and fname.endswith(suffix)
    ]

    if len(task_names) > 0:
        known = set(task_files)
        for t in task_names:
            if t not in known:
                raise ValueError(f"Task {t} not recognised!")
        task_files = task_names

    return [task_file_to_task_class(t) for t in task_files]


# ============================================================================
# Target Grid Generation (3D Cube)
# ============================================================================

def generate_3d_target_grid(base_target_pos, xy_span, z_span, targets_per_axis):
    """
    Generate a 3D cube grid of target positions centred on ``base_target_pos``.

    If |z_span| < 1e-6, a 2D grid (XY only, at the base Z) is generated to
    avoid duplicate positions.

    Args:
        base_target_pos: np.ndarray(3,), center of the grid
        xy_span: float, total span in X and Y directions (meters)
        z_span: float, total span in Z direction (meters)
        targets_per_axis: int >= 1, number of targets per axis
            (e.g., 3 -> 3x3x3 = 27 targets)

    Returns:
        train_targets: list of np.ndarray(3,), training target positions,
            iterated X-major, then Y, then Z.
        eval_targets: list of np.ndarray(3,), evaluation targets at the
            midpoints between adjacent train grid values. With a single
            train target, they are +/-2 cm offsets in X and Y, plus Z when
            z_span is non-zero.

    Raises:
        ValueError: if targets_per_axis < 1.
    """
    if targets_per_axis < 1:
        raise ValueError(f"targets_per_axis must be >= 1, got {targets_per_axis}")

    base = np.asarray(base_target_pos)
    flat_z = abs(z_span) < 1e-6

    def centered_axis(span):
        # np.linspace(a, b, 1) returns [a], which would put a lone target at the
        # -span/2 corner instead of the centre; use 0 offset in that case.
        if targets_per_axis == 1:
            return np.array([0.0])
        return np.linspace(-span / 2.0, span / 2.0, targets_per_axis)

    axis_xy = centered_axis(xy_span)
    axis_z = np.array([0.0]) if flat_z else centered_axis(z_span)

    # Train targets: full grid (3D or 2D depending on z_span)
    train_targets = [
        base + np.array([dx, dy, dz])
        for dx in axis_xy
        for dy in axis_xy
        for dz in axis_z
    ]

    if targets_per_axis > 1:
        # Eval targets: midpoints between train grid points (interleaved grid)
        axis_xy_eval = 0.5 * (axis_xy[:-1] + axis_xy[1:])
        axis_z_eval = np.array([0.0]) if flat_z else 0.5 * (axis_z[:-1] + axis_z[1:])
        eval_targets = [
            base + np.array([dx, dy, dz])
            for dx in axis_xy_eval
            for dy in axis_xy_eval
            for dz in axis_z_eval
        ]
    else:
        # Degenerate case: the only train target is the base itself.
        offsets = [
            np.array([0.02, 0.0, 0.0]),
            np.array([-0.02, 0.0, 0.0]),
            np.array([0.0, 0.02, 0.0]),
            np.array([0.0, -0.02, 0.0]),
        ]
        if not flat_z:
            offsets.extend([
                np.array([0.0, 0.0, 0.02]),
                np.array([0.0, 0.0, -0.02]),
            ])
        eval_targets = [base + off for off in offsets]

    return train_targets, eval_targets


def print_target_grid_info(train_targets, eval_targets, xy_span, z_span):
    """Print a summary of a target grid to stdout.

    The summary has a banner, the half-spans in millimetres and the target
    counts. For each non-empty target list it also prints the per-axis
    (X, Y, Z) min/max.
    """
    rule = "=" * 60

    def _report_ranges(label, targets):
        # Skip empty lists: there is no range to report.
        if len(targets) == 0:
            return
        pts = np.array(targets)
        for col, axis_name in enumerate("XYZ"):
            lo, hi = pts[:, col].min(), pts[:, col].max()
            print(f"  {label} {axis_name} range: [{lo:.3f}, {hi:.3f}]")

    print(f"\n{rule}")
    print("TARGET GRID (3D Cube)")
    print(rule)
    print(f"  XY span: +/-{xy_span / 2 * 1000:.1f}mm")
    print(f"  Z span: +/-{z_span / 2 * 1000:.1f}mm")
    print(f"  Train targets: {len(train_targets)}")
    print(f"  Eval targets: {len(eval_targets)}")

    _report_ranges("Train", train_targets)
    _report_ranges("Eval", eval_targets)
    print(f"{rule}\n")


# ============================================================================
# Trajectory Generation
# ============================================================================

def move_robot_to_start(task_env, start_pos):
    """
    Move the robot's tip to ``start_pos``.

    Steps:
      1. Reset the arm to HOME_JOINTS.
      2. Solve IK for ``start_pos`` while keeping the tip orientation it has
         at home.
      3. Teleport the arm to that solution and verify the tip is within
         1 cm (0.01) of the requested position.

    Returns:
        desired_ori: tip orientation read after one extra simulation step;
        callers keep it fixed for the rest of the trajectory.

    Raises:
        RuntimeError: if the tip ends up more than 0.01 from ``start_pos``.
        Errors raised by ``solve_ik`` (e.g. IKError) propagate unchanged.
    """
    scene = task_env._scene
    arm = scene.robot.arm
    tip = arm.get_tip()

    def settle(n_steps=30):
        # Let the simulation run a fixed number of steps after a teleport.
        for _ in range(n_steps):
            scene.pyrep.step()

    # Start from the home configuration (presumably so IK is always solved
    # from the same seed pose).
    arm.set_joint_positions(HOME_JOINTS)
    settle()

    start_joints = arm.solve_ik(start_pos, euler=tip.get_orientation())
    arm.set_joint_positions(start_joints)
    settle()

    pos_error = np.linalg.norm(tip.get_position() - start_pos)
    if pos_error > 0.01:
        raise RuntimeError(
            f"Failed to reach start position (error: {pos_error*1000:.1f}mm)"
        )

    scene.pyrep.step()
    return tip.get_orientation()


def execute_trajectory_step(task_env, target_ee_pos, desired_ori):
    """
    Execute one trajectory step: solve IK, command it, and step the simulation.

    The IK solution is sent as joint *target* positions. The simulator and the
    task are then stepped 5 times before the observation is read.

    Returns:
        obs: observation after the step. ``obs.misc['joint_position_action']``
            is set to the commanded joint positions concatenated with
            ``obs.gripper_open``.
        joint_positions: the commanded joint positions

    Raises:
        IKError or ConfigurationPathError: if IK fails (raised by solve_ik
            before anything is commanded).
    """
    scene = task_env._scene
    arm = scene.robot.arm

    joint_positions = arm.solve_ik(target_ee_pos, euler=desired_ori)
    arm.set_joint_target_positions(joint_positions)

    for _ in range(5):
        scene.pyrep.step()
        scene.task.step()

    obs = scene.get_observation()

    # Cover both a missing ``misc`` attribute and ``misc=None``; hasattr alone
    # would let a None through and the item assignment below would fail.
    if getattr(obs, 'misc', None) is None:
        obs.misc = {}
    obs.misc['joint_position_action'] = np.concatenate(
        [joint_positions, np.array([obs.gripper_open])]
    )

    return obs, joint_positions


def generate_trajectory(task_env, start_pos, target_pos, num_steps, curve_fn):
    """
    Generate a trajectory with exactly num_steps observations.

    The robot is first moved to ``start_pos`` (move_robot_to_start), and that
    observation is recorded as step 0. Step ``i`` (1 <= i < num_steps) then
    tracks ``curve_fn(i / (num_steps - 1))``, so t is evenly spaced in (0, 1].
    The orientation captured at the start pose is held throughout.

    Args:
        task_env: RLBench task environment
        start_pos: np.ndarray(3,), start EE position
        target_pos: np.ndarray(3,), target EE position. Not read here;
            presumably ``curve_fn`` already ends at it.
        num_steps: int >= 1, exact number of observations to generate
        curve_fn: function(t) -> np.ndarray(3,), maps t in [0,1] to EE position

    Returns:
        demo: list of observations, length = num_steps. A step whose IK fails
        reuses the previous observation object.
        NOTE(review): the step-0 observation comes straight from
        get_observation(), so unlike later steps it carries no
        misc['joint_position_action'].

    Raises:
        ValueError: if num_steps < 1 (checked before the robot is moved).
        RuntimeError: if more than num_steps // 4 steps fail IK, or if the
            start position cannot be reached.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    # Move to start position
    desired_ori = move_robot_to_start(task_env, start_pos)

    # Record initial observation (step 0)
    demo = [task_env._scene.get_observation()]

    successful_steps = 1  # Already have step 0
    failed_steps = 0
    max_failures = num_steps // 4  # Allow up to 25% failures

    for i in range(1, num_steps):
        t = i / (num_steps - 1)  # Only reached when num_steps >= 2
        target_ee_pos = curve_fn(t)

        try:
            obs, _ = execute_trajectory_step(task_env, target_ee_pos, desired_ori)
        except (IKError, ConfigurationPathError):
            failed_steps += 1
            if failed_steps > max_failures:
                raise RuntimeError(
                    f"Too many IK failures: {failed_steps}/{num_steps}"
                )
            # Repeat the previous observation to keep the exact length.
            demo.append(demo[-1])
        else:
            demo.append(obs)
            successful_steps += 1

    # Internal invariant: one observation is appended per loop iteration.
    assert len(demo) == num_steps, f"Expected {num_steps} steps, got {len(demo)}"

    if failed_steps > 0:
        print(f"\t  IK: {successful_steps} ok, {failed_steps} failed (used prev obs)")

    return demo


# ============================================================================
# Canonical Control Point Grid (4 angles x 2 distances = 8 points)
# ============================================================================

def generate_canonical_control_point_params(n_per_axis=None):
    """
    Build the fixed, canonical table of control-point parameters.

    Each row is ``(angle, dist_frac, pos_frac)``:
      - angle: one of 0, 90, 180, 270 degrees, stored in RADIANS
      - dist_frac: 0.5 or 1.0 (fraction of the control-point radius)
      - pos_frac: always 0.5 (middle of the start->target line)

    Rows are ordered distance-major. Rows 0-3 are the four angles at
    dist_frac=0.5; rows 4-7 are the same angles at dist_frac=1.0.

    Args:
        n_per_axis: Ignored; accepted only for backward compatibility.

    Returns:
        params: np.ndarray of shape (8, 3) with (angle, dist_frac, pos_frac)
    """
    angle_choices_deg = (0.0, 90.0, 180.0, 270.0)
    dist_choices = (0.5, 1.0)
    midpoint_frac = 0.5

    return np.array([
        [np.radians(deg), frac, midpoint_frac]
        for frac in dist_choices
        for deg in angle_choices_deg
    ])


def compute_control_point_from_params(start_pos, target_pos, radius, angle, dist_frac, pos_frac):
    """
    Compute a control point given canonical parameters.

    Args:
        start_pos: np.ndarray(3,), start position
        target_pos: np.ndarray(3,), target position
        radius: float, maximum offset as fraction of path length (e.g., 0.1 = 10% of path)
        angle: float, angle around the line (radians), measured in the
            (perp1, perp2) plane of build_local_frame
        dist_frac: float, fraction of radius for offset distance (0-1)
        pos_frac: float, fraction along start->target line for base position

    Returns:
        control_point: np.ndarray(3,). When start and target coincide, a copy
        of ``start_pos`` is returned (the offset would be zero anyway), which
        avoids the NaNs a degenerate local frame would produce.

    Note:
        The offset distance is computed as dist_frac * radius * path_length.
        The same (angle, dist_frac) therefore produces the same NORMALIZED
        curve shape regardless of the actual path length.
    """
    line_vec = target_pos - start_pos
    path_length = np.linalg.norm(line_vec)
    if path_length == 0.0:
        # No direction to build a frame from; the scaled offset is 0.
        return np.array(start_pos, dtype=float)

    _, perp1, perp2 = build_local_frame(start_pos, target_pos)

    # Base position along the line
    base_pos = start_pos + pos_frac * line_vec

    # Offset perpendicular to the line, SCALED BY PATH LENGTH so that
    # offset / path_length is constant.
    offset_dist = dist_frac * radius * path_length
    offset = offset_dist * (np.cos(angle) * perp1 + np.sin(angle) * perp2)

    return base_pos + offset


def get_control_points_for_target(start_pos, target_pos, radius, canonical_params, indices):
    """
    Materialize the control points for one target from a subset of canonical params.

    Args:
        start_pos: np.ndarray(3,)
        target_pos: np.ndarray(3,)
        radius: float, forwarded to compute_control_point_from_params
        canonical_params: np.ndarray of shape (N, 3) from generate_canonical_control_point_params
        indices: list of int, rows of ``canonical_params`` to use, in order

    Returns:
        control_points: list of np.ndarray(3,), one per entry of ``indices``
    """
    selected_rows = (canonical_params[i] for i in indices)
    return [
        compute_control_point_from_params(
            start_pos, target_pos, radius, angle, dist_frac, pos_frac
        )
        for angle, dist_frac, pos_frac in selected_rows
    ]


def prefilter_control_points(robot_arm, start_pos, target_pos, orientation,
                              canonical_params, control_point_radius,
                              num_samples=10, require_orientation=True):
    """
    Keep only control points whose curve is IK-reachable at every sample.

    For each row of ``canonical_params``, the curve (parabola3D) is sampled at
    ``num_samples`` evenly spaced t in (0, 1]; t=0 is skipped because it is
    the start pose. Each sample is tested with ``robot_arm.solve_ik``. Only
    the solver is called; no joint positions are set here.

    Args:
        robot_arm: PyRep arm object (has solve_ik method)
        start_pos: np.ndarray(3,), start position
        target_pos: np.ndarray(3,), target position
        orientation: np.ndarray(3,), desired euler orientation (or None for position-only)
        canonical_params: np.ndarray of shape (N, 3) from generate_canonical_control_point_params
        control_point_radius: float, radius for control point offset
        num_samples: int, number of points to sample along curve for IK testing
        require_orientation: bool, pass ``euler=orientation`` to the solver
            (ignored when orientation is None)

    Returns:
        valid_indices: list of int, rows with zero IK failures
        results: dict keyed by row index with valid / ik_failures /
            total_samples / angle_deg / dist_frac / pos_frac
    """
    sample_ts = np.linspace(0, 1, num_samples + 1)[1:]
    ik_kwargs = {"euler": orientation} if (require_orientation and orientation is not None) else {}

    valid_indices, results = [], {}
    for cp_idx, (angle, dist_frac, pos_frac) in enumerate(canonical_params):
        ctrl = compute_control_point_from_params(
            start_pos, target_pos, control_point_radius, angle, dist_frac, pos_frac
        )

        n_fail = 0
        for t in sample_ts:
            probe = parabola3D(start_pos, target_pos, ctrl, t)
            try:
                robot_arm.solve_ik(probe, **ik_kwargs)
            except (IKError, ConfigurationPathError):
                n_fail += 1

        ok = n_fail == 0
        if ok:
            valid_indices.append(cp_idx)
        results[cp_idx] = {
            "valid": ok,
            "ik_failures": n_fail,
            "total_samples": len(sample_ts),
            "angle_deg": np.degrees(angle),
            "dist_frac": dist_frac,
            "pos_frac": pos_frac,
        }

    return valid_indices, results


def prefilter_control_points_fast(robot_arm, start_pos, target_pos, orientation,
                                   canonical_params, control_point_radius,
                                   require_orientation=True):
    """
    Cheap pre-filter: test IK only at the curve's midpoint and endpoint.

    A control point is kept when ``robot_arm.solve_ik`` succeeds at both
    t=0.5 and t=1.0. The endpoint is not tried if the midpoint fails. This is
    an approximation of prefilter_control_points and can accept curves that
    fail elsewhere.

    Args:
        robot_arm: PyRep arm object
        start_pos, target_pos, orientation: trajectory parameters
        canonical_params: control point parameters
        control_point_radius: float
        require_orientation: bool, pass ``euler=orientation`` to the solver
            (ignored when orientation is None)

    Returns:
        valid_indices: list of valid control point indices
    """
    use_orientation = require_orientation and orientation is not None
    accepted = []

    for idx, (angle, dist_frac, pos_frac) in enumerate(canonical_params):
        ctrl = compute_control_point_from_params(
            start_pos, target_pos, control_point_radius, angle, dist_frac, pos_frac
        )
        try:
            probes = [parabola3D(start_pos, target_pos, ctrl, t) for t in (0.5, 1.0)]
            for probe in probes:
                if use_orientation:
                    robot_arm.solve_ik(probe, euler=orientation)
                else:
                    robot_arm.solve_ik(probe)
        except (IKError, ConfigurationPathError):
            continue  # This control point fails
        accepted.append(idx)

    return accepted


def prefilter_control_points_with_collision(robot_arm, start_pos, target_pos, orientation,
                                             canonical_params, control_point_radius,
                                             num_samples=10, require_orientation=True):
    """
    Pre-filter control points by testing BOTH IK feasibility AND collision-free.

    Intended to reduce jitter: paths that pass IK but collide make the physics
    engine resolve penetrations during playback.

    For each control point, the curve is sampled at ``num_samples`` t values
    in (0, 1]. At each sample:
    1. IK is solved with solve_ik_via_jacobian, seeded from the previous
       sample's solution (the arm is set to the previous joints first).
    2. The arm is set to the solution and check_arm_collision() is queried.

    The arm's joint positions are restored to their values at call time
    before returning. This also happens when an unexpected exception escapes,
    so the simulator is never left in a probe configuration.

    Args:
        robot_arm: PyRep arm object (get_joint_positions, set_joint_positions,
            solve_ik_via_jacobian, check_arm_collision)
        start_pos: np.ndarray(3,), start position
        target_pos: np.ndarray(3,), target position
        orientation: np.ndarray(3,), desired euler orientation
        canonical_params: np.ndarray of shape (N, 3)
        control_point_radius: float, radius for control point offset
        num_samples: int, number of points to sample along curve
        require_orientation: bool, if True test with orientation constraint

    Returns:
        valid_indices: list of int, indices of control points that are IK-feasible AND collision-free
        results: dict with detailed results for each control point
    """
    valid_indices = []
    results = {}

    t_values = np.linspace(0, 1, num_samples + 1)[1:]  # Skip t=0

    # Snapshot the configuration so it can be restored afterwards.
    original_joints = list(robot_arm.get_joint_positions())

    try:
        for cp_idx, (angle, dist_frac, pos_frac) in enumerate(canonical_params):
            control_point = compute_control_point_from_params(
                start_pos, target_pos, control_point_radius, angle, dist_frac, pos_frac
            )

            ik_failures = 0
            collision_count = 0
            prev_joints = original_joints.copy()

            for t in t_values:
                curve_pos = parabola3D(start_pos, target_pos, control_point, t)
                try:
                    # Seed the Jacobian solver from the previous solution for continuity.
                    robot_arm.set_joint_positions(prev_joints, disable_dynamics=True)

                    if require_orientation and orientation is not None:
                        joint_positions = robot_arm.solve_ik_via_jacobian(curve_pos, euler=orientation)
                    else:
                        joint_positions = robot_arm.solve_ik_via_jacobian(curve_pos)

                    robot_arm.set_joint_positions(joint_positions, disable_dynamics=True)
                    if robot_arm.check_arm_collision():
                        collision_count += 1

                    prev_joints = list(joint_positions)
                except (IKError, ConfigurationPathError):
                    ik_failures += 1

            # Valid only if NO IK failures AND NO collisions
            is_valid = (ik_failures == 0) and (collision_count == 0)
            if is_valid:
                valid_indices.append(cp_idx)

            results[cp_idx] = {
                "valid": is_valid,
                "ik_failures": ik_failures,
                "collision_count": collision_count,
                "total_samples": len(t_values),
                "angle_deg": np.degrees(angle),
                "dist_frac": dist_frac,
                "pos_frac": pos_frac,
            }
    finally:
        # Always undo the probe configurations, even on unexpected errors.
        robot_arm.set_joint_positions(original_joints, disable_dynamics=True)

    return valid_indices, results


# ============================================================================
# Curve Functions
# ============================================================================

def parabola3D(S, E, C, t):
    """Evaluate a quadratic Bezier from ``S`` to ``E`` that passes through ``C``.

    ``C`` is NOT the Bezier control point. The actual control point is
    ``2*C - (S + E)/2``, chosen so that the curve hits ``C`` exactly at
    ``t = 0.5``. The curve equals ``S`` at ``t = 0`` and ``E`` at ``t = 1``.
    """
    s = 1 - t
    bezier_ctrl = 2 * C - 0.5 * (S + E)
    return s * s * S + 2 * s * t * bezier_ctrl + t * t * E


def cubic_bezier3D(S, E, C1, C2, t):
    """Evaluate a cubic Bezier curve at parameter ``t`` (Bernstein form).

    ``S`` and ``E`` are the endpoints (the value at t=0 and t=1). ``C1`` and
    ``C2`` are the inner control points, which the curve generally does not
    pass through.
    """
    s = 1.0 - t
    w_start = s ** 3
    w_c1 = 3 * (s ** 2) * t
    w_c2 = 3 * s * (t ** 2)
    w_end = t ** 3
    return w_start * S + w_c1 * C1 + w_c2 * C2 + w_end * E


def build_local_frame(start_pos, target_pos):
    """
    Return (line_vec_norm, perp1, perp2) forming an orthonormal frame.

    The axes are relative to the trajectory direction rather than fixed world
    axes. As a result, the same (angle, dist_frac) produces the same
    normalized curve shape for any start/target pair.

      - line_vec_norm: unit vector along start -> target
      - perp1: the component of world +Z orthogonal to the line, normalized
        (so angle=0 means "upward" relative to the trajectory). If the line is
        within 1e-6 of vertical, world +Y is used as the reference instead.
      - perp2: line_vec_norm x perp1, normalized (right-hand rule). With the
        +Z reference this vector is horizontal.
    """
    direction = target_pos - start_pos
    axis = direction / np.linalg.norm(direction)

    # Reference "up" vectors in priority order: world +Z, then world +Y.
    references = (np.array([0.0, 0.0, 1.0]), np.array([0.0, 1.0, 0.0]))
    for ref in references:
        # Gram-Schmidt: remove the component of ref along the line.
        projected = ref - np.dot(ref, axis) * axis
        length = np.linalg.norm(projected)
        if length < 1e-6:
            continue  # ref is (nearly) parallel to the line; try the next one
        break

    perp1 = projected / length
    perp2 = np.cross(axis, perp1)
    perp2 = perp2 / np.linalg.norm(perp2)

    return axis, perp1, perp2


def sample_control_point(start_pos, end_pos, radius,
                         angle_range=None, distance_range=None):
    """
    Randomly place a control point around the midpoint of start -> end.

    The point lies in the plane through the midpoint that is perpendicular to
    the line, using the perp1/perp2 axes of build_local_frame. Two values are
    drawn from NumPy's global RNG (``np.random``), angle first and then
    distance:
      - angle: uniform in ``angle_range`` (degrees, converted to radians),
        or uniform in [0, 2*pi) when None;
      - distance: ``radius * U(distance_range)``, or uniform in [0, radius)
        when None.

    Args:
        start_pos: np.ndarray(3,), start EE position.
        end_pos: np.ndarray(3,), target EE position.
        radius: Maximum radius from the line (meters).
        angle_range: (min_angle, max_angle) in degrees around the line.
        distance_range: (min_frac, max_frac) fractions of `radius` for distance.

    Returns:
        np.ndarray(3,), the sampled control point.
    """
    center = 0.5 * (start_pos + end_pos)
    _, perp1, perp2 = build_local_frame(start_pos, end_pos)

    if angle_range is None:
        angle = np.random.uniform(0, 2 * np.pi)
    else:
        angle = np.random.uniform(np.radians(angle_range[0]), np.radians(angle_range[1]))

    if distance_range is None:
        dist = np.random.uniform(0, radius)
    else:
        dist = radius * np.random.uniform(distance_range[0], distance_range[1])

    direction = np.cos(angle) * perp1 + np.sin(angle) * perp2
    return center + dist * direction
