import argparse
import pathlib
import matplotlib.pyplot as plt
import numpy as np
import pybullet as p
import pybullet_data
from tqdm import tqdm
import time
import imageio
import torch
import yaml

from sde_lib import SDE_Brownian_manifolds
from sampling import SDE_sampler_manifolds_ULLA_EM
from manifolds.Robot import Manifold_Robot
from utils import set_seed_everywhere


# ==============================================================================
# SECTION 1: RRT Path Planning Class
# ==============================================================================
class RRT:
    """Rapidly-exploring Random Tree planner over end-effector positions.

    Tree vertices are Cartesian points. An edge is accepted only if the arm,
    placed via inverse kinematics at a few interpolated points along it, has
    no contact with any obstacle body. Random samples are drawn from
    ``self.bounds`` with z pinned to 0.1.

    NOTE: collision checking calls ``p.resetJointState`` on ``robot_id``, so
    planning overwrites the robot's joint state in ``client_id``.
    """

    class Node:
        """Tree vertex: a Cartesian point plus a link to its parent vertex."""

        def __init__(self, point):
            # Always store floats: `_steer` does arithmetic on these arrays and
            # an all-integer start/goal (e.g. [0, 0, 0]) must not break it.
            self.point = np.array(point, dtype=float)
            self.parent = None

    def __init__(self, start, goal, obstacles, client_id, robot_id, joint_indices, ee_link_index, step_size=0.05, max_iter=5000):
        """
        Args:
            start, goal: 3-element Cartesian points.
            obstacles: PyBullet body ids checked for contact with the robot.
            client_id: PyBullet physics client id.
            robot_id: body id of the arm.
            joint_indices: joints that receive the IK solution during
                collision checks (in IK-output order).
            ee_link_index: link whose position is driven to each checked point.
            step_size: maximum edge length; also the goal-reached radius.
            max_iter: number of sampling iterations before giving up.
        """
        self.start = self.Node(start)
        self.goal = self.Node(goal)
        self.obstacles = obstacles
        self.client = client_id
        self.robot_id = robot_id
        self.joint_indices = joint_indices
        self.ee_link_index = ee_link_index
        # End-effector orientation used for every IK solve (Euler [pi, 0, 0]).
        self.fixed_orientation = p.getQuaternionFromEuler([np.pi, 0, 0])
        self.step_size = step_size
        self.max_iter = max_iter
        self.nodes = [self.start]
        # Sampling box; rows are x, y, z as [min, max].
        self.bounds = np.array([[0.2, 0.6], [-0.7, 0.7], [0.05, 0.15]])

    def find_path(self):
        """Grow the tree until the goal is reached.

        Returns:
            (K, 3) array of points from start to goal, or None if no path was
            found within ``max_iter`` iterations.
        """
        for _ in range(self.max_iter):
            # 10% goal bias: occasionally steer straight toward the goal.
            if np.random.rand() > 0.1:
                rnd_point = self._get_random_point()
            else:
                rnd_point = self.goal.point
            nearest_node = self._get_nearest_node(rnd_point)
            new_point = self._steer(nearest_node.point, rnd_point)
            if self._is_collision(nearest_node.point, new_point):
                continue
            new_node = self.Node(new_point)
            new_node.parent = nearest_node
            self.nodes.append(new_node)
            gap = np.linalg.norm(new_node.point - self.goal.point)
            if gap <= self.step_size:
                # The final hop into the goal is an edge too: connect only when it
                # is collision-free (a zero-length hop was already checked above).
                if gap == 0 or not self._is_collision(new_node.point, self.goal.point):
                    self.goal.parent = new_node
                    return self._reconstruct_path()
        return None

    def _get_random_point(self):
        """Sample a point uniformly inside ``self.bounds`` with z fixed at 0.1."""
        point = np.random.uniform(self.bounds[:, 0], self.bounds[:, 1])
        point[2] = 0.1  # Keep z-coordinate fixed
        return point

    def _get_nearest_node(self, point):
        """Return the tree node closest (Euclidean) to ``point``."""
        distances = [np.linalg.norm(node.point - point) for node in self.nodes]
        return self.nodes[np.argmin(distances)]

    def _steer(self, from_point, to_point):
        """Move from ``from_point`` toward ``to_point`` by at most ``step_size``."""
        direction = to_point - from_point
        distance = np.linalg.norm(direction)
        # Already within reach (including the zero-length case): jump straight
        # to the target rather than normalizing, which would divide by zero.
        if distance < self.step_size:
            return to_point
        return from_point + (direction / distance) * self.step_size

    def _is_collision(self, start_point, end_point):
        """Return True if the arm touches an obstacle along the segment.

        The segment is sampled at ``num_checks`` evenly spaced points
        (excluding ``start_point``, including ``end_point``). For each point,
        IK is solved, the result is written into the robot, and contacts with
        every obstacle body are queried.
        """
        num_checks = 3
        for i in range(1, num_checks + 1):
            interp_point = start_point + (end_point - start_point) * (i / num_checks)
            joint_poses = p.calculateInverseKinematics(
                self.robot_id, self.ee_link_index, interp_point, self.fixed_orientation, physicsClientId=self.client)
            # Defensive: treat a missing IK result as a collision.
            if joint_poses is None:
                return True
            # Assumes the leading IK outputs line up with `joint_indices`
            # (true when those are the first movable joints) - TODO confirm.
            for j, joint_index in enumerate(self.joint_indices):
                p.resetJointState(self.robot_id, joint_index, joint_poses[j], physicsClientId=self.client)
            p.performCollisionDetection(physicsClientId=self.client)
            for obs_id in self.obstacles:
                if len(p.getContactPoints(bodyA=self.robot_id, bodyB=obs_id, physicsClientId=self.client)) > 0:
                    return True
        return False

    def _reconstruct_path(self):
        """Follow parent links from the goal back to the start; return start->goal points."""
        path = []
        current = self.goal
        while current is not None:
            path.append(current.point)
            current = current.parent
        return np.array(path[::-1])

# ==============================================================================
# SECTION 2: Environment Setup, Data Generation, and Visualization
# ==============================================================================
def create_pybullet_env(connection_mode=p.DIRECT):
    """Connect to PyBullet and build the planning scene.

    The scene holds a ground plane, a Franka Panda with a fixed base at the
    origin, and two static (mass 0) collision spheres.

    Args:
        connection_mode: PyBullet connection mode; ``p.DIRECT`` (headless) by default.

    Returns:
        ``(client_id, robot_id, joint_indices, ee_link_index, obstacle_ids,
        obstacle_positions, obstacle_radius)``. Here ``joint_indices`` are the
        first seven revolute joints of the robot and ``ee_link_index`` is 11.
    """
    cid = p.connect(connection_mode)
    p.setAdditionalSearchPath(pybullet_data.getDataPath(), physicsClientId=cid)
    p.setGravity(0, 0, -9.8, physicsClientId=cid)
    p.loadURDF("plane.urdf", physicsClientId=cid)
    panda_id = p.loadURDF(
        "franka_panda/panda.urdf",
        basePosition=[0, 0, 0],
        useFixedBase=True,
        physicsClientId=cid,
    )

    # Keep only revolute joints, then the first seven of them (the arm).
    revolute_ids = []
    for jid in range(p.getNumJoints(panda_id, physicsClientId=cid)):
        joint_type = p.getJointInfo(panda_id, jid, physicsClientId=cid)[2]
        if joint_type == p.JOINT_REVOLUTE:
            revolute_ids.append(jid)
    arm_joint_ids = revolute_ids[:7]

    ee_link = 11  # link index used as the end-effector frame in this script

    sphere_centres = [[0.4, -0.3, 0.1], [0.4, 0.3, 0.1]]
    sphere_radius = 0.1
    sphere_ids = []
    for centre in sphere_centres:
        collision_shape = p.createCollisionShape(p.GEOM_SPHERE, radius=sphere_radius, physicsClientId=cid)
        sphere_ids.append(
            p.createMultiBody(
                baseMass=0,
                baseCollisionShapeIndex=collision_shape,
                basePosition=centre,
                physicsClientId=cid,
            )
        )

    return cid, panda_id, arm_joint_ids, ee_link, sphere_ids, sphere_centres, sphere_radius

def generate_3d_path(client, robot_info, obstacles, condition, start_pos, end_pos):
    """Plan a piecewise RRT path start -> waypoint -> waypoint -> end.

    ``condition == 1`` uses the S-shape waypoints, any other value uses the
    reverse S-shape. Each leg is planned with its own ``RRT``. The first point
    of every leg after the first is dropped, because it duplicates the
    previous leg's last point.

    Returns:
        (K, 3) stacked path, or None as soon as any leg fails to plan.
    """
    robot_id, joint_indices, ee_link_index = robot_info
    s_waypoints = [[0.15, -0.4, 0.1], [0.65, 0.4, 0.1]]          # S-shape
    reverse_s_waypoints = [[0.65, -0.4, 0.1], [0.15, 0.4, 0.1]]  # Reverse S-shape
    via = s_waypoints if condition == 1 else reverse_s_waypoints
    waypoints = [start_pos, *via, end_pos]

    legs = []
    for leg_idx, (leg_start, leg_goal) in enumerate(zip(waypoints[:-1], waypoints[1:])):
        planner = RRT(leg_start, leg_goal, obstacles, client, robot_id, joint_indices, ee_link_index)
        leg = planner.find_path()
        if leg is None:
            return None
        legs.append(leg[1:] if leg_idx else leg)
    return np.vstack(legs)

def _wrap_into_limits(angles, lower, upper):
    """Shift each angle by a multiple of 2*pi so it lies in [lower[i], upper[i]].

    Angles already inside their interval are kept unchanged. Otherwise the
    multiple of 2*pi that brings the angle closest to the interval centre is
    removed. Returns a new float array, or None if any angle still falls
    outside its interval after wrapping.
    """
    wrapped = np.zeros(len(angles))
    for i, (angle, lo, hi) in enumerate(zip(angles, lower, upper)):
        if lo <= angle <= hi:
            wrapped[i] = angle
            continue
        centre = (lo + hi) / 2.0
        candidate = angle - np.round((angle - centre) / (2 * np.pi)) * (2 * np.pi)
        if not (lo <= candidate <= hi):
            return None
        wrapped[i] = candidate
    return wrapped


def convert_path_to_joint_angles(path_3d, client, robot_info, previous_q = None):
    """Convert a Cartesian end-effector path into 7 arm joint angles per point.

    IK is solved for every point of ``path_3d`` with a fixed end-effector
    orientation (Euler [pi, 0, 0]). Each solve uses PyBullet's null-space IK
    with the previous step's solution as the rest pose. This biases
    consecutive solutions toward each other and gives a smoother trajectory.
    Each solution is then wrapped, by multiples of 2*pi, into hard-coded
    custom joint limits.

    Args:
        path_3d: iterable of 3D target positions.
        client: PyBullet physics client id.
        robot_info: ``(robot_id, joint_indices, ee_link_index)``; the joint
            indices are not used here.
        previous_q: optional 7 joint angles used as the rest pose for the
            first point; defaults to a fixed reference pose.

    Returns:
        (len(path_3d), 7) array of joint angles, or None if IK returns nothing
        or any angle cannot be wrapped into the custom limits.
    """
    robot_id, _, ee_link_index = robot_info
    fixed_orientation = p.getQuaternionFromEuler([np.pi, 0, 0])
    num_joints = p.getNumJoints(robot_id, physicsClientId=client)
    joint_info = [p.getJointInfo(robot_id, i, physicsClientId=client) for i in range(num_joints)]

    # PyBullet only applies null-space IK (limits + rest poses) when lowerLimits,
    # upperLimits, jointRanges and restPoses each have exactly one entry per
    # movable (non-fixed) joint. With any other length it silently falls back to
    # plain IK and the rest pose is ignored. The Panda URDF has prismatic finger
    # joints besides the 7 arm joints, so the lists span all movable joints.
    movable_ids = [j[0] for j in joint_info if j[2] != p.JOINT_FIXED]
    arm_ids = [j[0] for j in joint_info if j[2] == p.JOINT_REVOLUTE][:7]
    # Position of each arm joint inside the DoF-ordered IK vectors.
    arm_cols = [movable_ids.index(jid) for jid in arm_ids]

    ll_physical = np.array([joint_info[jid][8] for jid in movable_ids])
    ul_physical = np.array([joint_info[jid][9] for jid in movable_ids])
    jr = ul_physical - ll_physical
    # Non-arm DoFs (e.g. fingers) keep their current position as rest pose.
    base_rest = np.array([p.getJointState(robot_id, jid, physicsClientId=client)[0] for jid in movable_ids])

    # Strict custom limits the output must satisfy (pi/2 wider than the base values).
    ll_custom = np.array([-2.9671, -1.8326, -2.9671, -3.1416, -2.9671, -0.0873, -2.9671]) - np.pi/2
    ul_custom = np.array([2.9671, 1.8326, 2.9671, 0.0, 2.9671, 3.8223, 2.9671]) + np.pi/2

    # Rest pose for the very first point when no previous solution is given.
    initial_rest_pose = [0, -0.785, 0, -2.356, 0, 1.571, 0.785]
    last_valid_poses = np.asarray(previous_q if previous_q is not None else initial_rest_pose, dtype=float)

    joint_path = []
    for point in path_3d:
        rest_poses = base_rest.copy()
        rest_poses[arm_cols] = last_valid_poses
        joint_poses = p.calculateInverseKinematics(
            robot_id, ee_link_index, point, fixed_orientation,
            lowerLimits=ll_physical.tolist(), upperLimits=ul_physical.tolist(),
            jointRanges=jr.tolist(), restPoses=rest_poses.tolist(),
            residualThreshold=1e-5, maxNumIterations=2000, physicsClientId=client)
        if joint_poses is None:
            return None  # IK solver failed

        processed_poses = _wrap_into_limits(np.asarray(joint_poses)[arm_cols], ll_custom, ul_custom)
        if processed_poses is None:
            return None  # an angle cannot satisfy the custom limits: reject the path
        joint_path.append(processed_poses)
        # Seed the next solve with this solution.
        last_valid_poses = processed_poses.copy()
    return np.array(joint_path)

def resample_trajectory(trajectory, num_points):
    """Resample a polyline to ``num_points`` points evenly spaced by arc length.

    Vertices that advance less than 1e-6 in cumulative arc length from their
    predecessor are dropped first, so the interpolation abscissa is strictly
    increasing. Returns None when fewer than two distinct vertices remain.
    """
    from scipy.interpolate import interp1d

    seg_lengths = np.linalg.norm(np.diff(trajectory, axis=0), axis=1)
    arc = np.concatenate(([0], np.cumsum(seg_lengths)))

    keep = np.concatenate(([0], np.flatnonzero(np.diff(arc) > 1e-6) + 1))
    if keep.size < 2:
        return None

    arc = arc[keep]
    trajectory = trajectory[keep]
    resampler = interp1d(arc, trajectory, axis=0, kind='slinear')
    return resampler(np.linspace(0, arc[-1], num_points))

def verify_trajectory_z_height(trajectory_7d, client, robot_info, target_z=0.1, tol=0.001):
    """Check that the end-effector stays at a constant height along a trajectory.

    For every joint configuration, the angles are written into the robot with
    ``p.resetJointState`` and the z of ``getLinkState(...)[0]`` for the
    end-effector link is recorded. The trajectory passes when both the
    minimum and maximum z are within ``tol`` of ``target_z``.

    NOTE(review): ``getLinkState(...)[0]`` is the link's centre-of-mass
    position, while IK targets the link frame. These coincide only if the link
    has no inertial offset - confirm for the end-effector link.

    Args:
        trajectory_7d: sequence of joint-angle vectors.
        client: PyBullet physics client id.
        robot_info: ``(robot_id, joint_indices, ee_link_index)``.
        target_z: expected end-effector height (default 0.1).
        tol: allowed absolute deviation from ``target_z`` (default 0.001).

    Returns:
        ``(ok, message)``. An empty trajectory fails instead of raising.
    """
    robot_id, joint_indices, ee_link_index = robot_info
    if len(trajectory_7d) == 0:
        return False, "Z FAIL: empty trajectory"
    z_values = []
    for joint_angles in trajectory_7d:
        for j, joint_index in enumerate(joint_indices):
            p.resetJointState(robot_id, joint_index, joint_angles[j], physicsClientId=client)
        link_state = p.getLinkState(robot_id, ee_link_index, physicsClientId=client)
        z_values.append(link_state[0][2])
    min_z, max_z = min(z_values), max(z_values)
    in_band = abs(min_z - target_z) < tol and abs(max_z - target_z) < tol
    status = "Z OK" if in_band else "Z FAIL"
    return in_band, f"{status}: [{min_z:.4f}, {max_z:.4f}]"

# --- S-shape verification: single centerline crossing plus a no-hook-back check ---
def is_valid_s_shape(path_3d, condition, slack=0.02, min_points=10):
    """Check that a path crosses the x = 0.4 centerline exactly once, in the right direction.

    1. Sign-change check: each point's x is compared with a centerline at
       ``0.4 +/- slack``, whose offset flips with the sign of y. The path must
       start on the side required by ``condition`` (c=1: x below the line,
       "left"; c=0: x above it, "right") and change side exactly once. Points
       lying exactly on the line are ignored, and at least ``min_points``
       off-line points are required.
    2. No-hook-back check: every point with y > 0.1 must stay on the far side
       of x = 0.4 (c=1: x >= 0.4; c=0: x <= 0.4).

    Args:
        path_3d: (N, >=2) array of points; columns 0 and 1 are x and y.
        condition: 1 for the S-shape (left to right), 0 for the reverse S.
        slack: offset of the dynamic centerline from 0.4.
        min_points: minimum number of off-centerline points. With the default
            of 10, paths with fewer than 10 points are always rejected, so
            callers that resample to fewer time steps must lower it.

    Returns:
        ``(ok, message)``.
    """
    center_x_nominal = 0.4
    y_coords = path_3d[:, 1]

    # --- Part 1: sign changes relative to the y-dependent centerline ---
    if condition == 0: # Reverse S (R -> L)
        dynamic_center_x = np.where(y_coords < 0, center_x_nominal - slack, center_x_nominal + slack)
    else: # S-shape (L -> R)
        dynamic_center_x = np.where(y_coords < 0, center_x_nominal + slack, center_x_nominal - slack)

    relative_x = path_3d[:, 0] - dynamic_center_x
    signs = np.sign(relative_x)
    signs = signs[signs != 0] # Remove points exactly on the centerline

    if len(signs) < min_points:
        return False, "S-shape FAIL: Path too close to dynamic center"

    start_sign = signs[0]
    if condition == 1 and start_sign != -1:
        return False, f"S-shape FAIL: c=1 must start on the left (sign={start_sign})"
    if condition == 0 and start_sign != 1:
        return False, f"S-shape FAIL: c=0 must start on the right (sign={start_sign})"

    sign_changes = np.count_nonzero(np.diff(signs))

    if sign_changes != 1:
        return False, f"S-shape FAIL: Sign changes = {sign_changes} (expected 1)"

    # --- Part 2: "No Hook-Back" check on the latter part of the path (y > 0.1) ---
    upper_path_mask = y_coords > 0.1
    if not np.any(upper_path_mask):
        return False, "S-shape FAIL: Path does not reach upper region"

    upper_path_x_coords = path_3d[upper_path_mask][:, 0]

    if condition == 0: # Reverse S (R -> L), latter part should be x <= 0.4
        if np.any(upper_path_x_coords > center_x_nominal):
            return False, "S-shape FAIL: c=0 hooked back to the right"

    elif condition == 1: # S-shape (L -> R), latter part should be x >= 0.4
        if np.any(upper_path_x_coords < center_x_nominal):
            return False, "S-shape FAIL: c=1 hooked back to the left"

    return True, f"S-shape OK: Sign changes = {sign_changes}"

def validate_values_with_config(config_path, calc_ll, calc_ul):
    """Check script-computed joint limits against the values in a YAML config.

    Reads ``problem.joint_lower_limits`` and ``problem.joint_upper_limits``
    from the file and compares them element-wise (``np.allclose``) with
    ``calc_ll`` / ``calc_ul``. A key that is absent or empty in the YAML is
    not checked. Validation is skipped entirely when ``config_path`` is falsy.

    Raises:
        FileNotFoundError: ``config_path`` does not exist.
        ValueError: a YAML value differs from the script value; a different
            length also counts as a mismatch.
    """
    if not config_path:
        print("\n--config_path not provided, skipping validation against YAML file.")
        return

    print(f"\n--- Validating calculated values against config file: {config_path} ---")

    try:
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"Error: The specified config file was not found at {config_path}") from err

    # An empty file loads as None, and a bare `problem:` key loads as None too.
    problem_config = (config or {}).get('problem') or {}

    mismatches = []
    for key, calc_vals in (('joint_lower_limits', calc_ll), ('joint_upper_limits', calc_ul)):
        yaml_vals = problem_config.get(key)
        if not yaml_vals:
            continue
        # Compare shapes first: np.allclose would raise a broadcasting error
        # (or silently broadcast a scalar) instead of reporting a mismatch.
        if np.shape(calc_vals) != np.shape(yaml_vals) or not np.allclose(calc_vals, yaml_vals):
            mismatches.append(f"{key} mismatch!\n  Script: {calc_vals}\n  YAML:   {yaml_vals}")

    if mismatches:
        error_message = "\n\n" + "="*60 + "\nVALIDATION ERROR: Mismatch between script and YAML config!\n" + "="*60
        for msg in mismatches:
            error_message += f"\n- {msg}"
        error_message += "\n\nPlease update the YAML file or the script to ensure consistency."
        raise ValueError(error_message)

    print("Validation successful: All relevant values in script and YAML file match.")
    print("-" * 55 + "\n")


def plot_trajectories(trajectories_3d_c0, trajectories_3d_c1, obstacles_pos, radius, save_path):
    """Save a top-down (x-y) plot of every end-effector path plus the obstacles.

    c=0 paths are drawn in magenta and c=1 paths in gold, both at alpha 0.1.
    Obstacles are drawn as green circles of ``radius``. The figure is written
    to ``save_path`` at 300 dpi and then closed.
    """
    from matplotlib.lines import Line2D

    print(f"Visualizing and saving results to: {save_path}...")
    plt.figure(figsize=(10, 10))
    axes = plt.gca()

    groups = (
        (trajectories_3d_c0, 'magenta', "Plotting c=0 (R to L)"),
        (trajectories_3d_c1, 'gold', "Plotting c=1 (L to R)"),
    )
    for paths, colour, progress_label in groups:
        for xy in tqdm(paths, desc=progress_label):
            axes.plot(xy[:, 0], xy[:, 1], color=colour, alpha=0.1)

    for centre in obstacles_pos:
        axes.add_patch(plt.Circle((centre[0], centre[1]), radius=radius, color='green'))

    plt.title("Top-Down View of 7-DoF Arm End-Effector Trajectories")
    plt.xlabel("X position")
    plt.ylabel("Y position")
    plt.grid(True)
    plt.axis('equal')

    # Proxy artists give a compact legend instead of one entry per plotted path.
    handles = [
        Line2D([0], [0], color='magenta', lw=2, label='c=0 (Right to Left)'),
        Line2D([0], [0], color='gold', lw=2, label='c=1 (Left to Right)'),
        plt.Circle((0, 0), 0.1, color='green', label='Obstacles'),
    ]
    axes.legend(handles=handles)
    plt.savefig(save_path, dpi=300)
    plt.close()
    print("Visualization complete.")

def record_trajectory_to_mp4_headless(trajectory_7d, condition, output_dir, joint_indices):
    """Render a joint-space trajectory offscreen and save it as an MP4.

    A separate DIRECT-mode scene is built with ``create_pybullet_env``. For
    every step, the given joints are set with ``p.resetJointState`` and one
    640x480 RGB frame is captured from a fixed camera. Frames are written at
    30 fps to ``output_dir / f"trajectory_c{condition}.mp4"``. The recording
    client is always disconnected, even if rendering or encoding fails.

    Args:
        trajectory_7d: sequence of joint-angle vectors, one per frame.
        condition: condition label, used in the file name and messages.
        output_dir: ``pathlib.Path`` of the output directory.
        joint_indices: joints that receive the angles, in order.
    """
    print(f"Generating MP4 video for trajectory c={condition} (Headless Mode)...")
    rec_client, rec_robot_id, _, _, _, _, _ = create_pybullet_env(p.DIRECT)
    mp4_path = str(output_dir / f"trajectory_c{condition}.mp4")
    width, height = 640, 480
    try:
        view_matrix = p.computeViewMatrix(
            cameraEyePosition=[0.9, 0, 0.6], cameraTargetPosition=[0.4, 0, 0.2], cameraUpVector=[0, 0, 1], physicsClientId=rec_client)
        projection_matrix = p.computeProjectionMatrixFOV(
            fov=60.0, aspect=float(width) / height, nearVal=0.1, farVal=100.0, physicsClientId=rec_client)
        with imageio.get_writer(mp4_path, fps=30) as writer:
            for joint_angles in tqdm(trajectory_7d, desc=f"Recording c={condition}"):
                for j, joint_index in enumerate(joint_indices):
                    p.resetJointState(rec_robot_id, joint_index, joint_angles[j], physicsClientId=rec_client)
                _, _, rgb_pixels, _, _ = p.getCameraImage(
                    width, height, view_matrix, projection_matrix, physicsClientId=rec_client)
                # The RGBA buffer may be a flat list or an array; normalise it to a
                # uint8 (H, W, 4) image and drop the alpha channel.
                frame = np.asarray(rgb_pixels, dtype=np.uint8).reshape((height, width, 4))[:, :, :3]
                writer.append_data(frame)
    finally:
        # Release the recording physics client on success and on failure.
        p.disconnect(physicsClientId=rec_client)
    print(f"MP4 video saved to: {mp4_path}")

# ==============================================================================
# SECTION 3: Prior Data Generation using SDE
# ==============================================================================
def gen_data(manifold, init, repeat_num=1, N=100000, T=1000, sigma=1.):
    """Run a forward Brownian-motion SDE on ``manifold`` from ``init`` and return its history.

    The SDE uses constant ``sigma`` and zero drift. It is integrated with
    ``SDE_sampler_manifolds_ULLA_EM`` (alpha=50, gamma=5) on the CPU, with
    ``repeat_num`` copies of ``init`` as starting states.

    Side effect: sets ``manifold.epsilon = 0.1``.

    Args:
        manifold: manifold object passed to the sampler.
        init: 1-D numpy array, the starting state.
        repeat_num: number of copies of ``init`` to simulate.
        N, T: number of SDE steps and end time.
        sigma: diffusion scale (used for both sigma_min and sigma_max).

    Returns:
        numpy array of the sampler's ``x_hist_all`` reshaped to
        ``(repeat_num, -1, init.shape[-1])``.
        NOTE(review): this reshape assumes the history is laid out chain-major;
        for repeat_num > 1 a step-major layout would interleave chains - confirm.
    """
    run_device = 'cpu'  # single-chain simulation: keep it on the CPU

    brownian = SDE_Brownian_manifolds(sigma_min=sigma, sigma_max=sigma, tau_min=1.0, tau_max=1.0, N=N, T=T)

    def _zero_drift(x):
        return torch.zeros_like(x)

    brownian.func_b = _zero_drift

    start_batch = torch.from_numpy(init).float().unsqueeze(0).repeat(repeat_num, 1).to(run_device)

    manifold.epsilon = 0.1
    sampler_kwargs = dict(alpha=50.0, gamma=5.0)

    _, _, extras = SDE_sampler_manifolds_ULLA_EM(
        brownian, manifold, start_batch, reverse=False, keep_quiet=False, **sampler_kwargs
    )

    history = extras["x_hist_all"].detach().cpu()
    return history.reshape(repeat_num, -1, init.shape[-1]).numpy()


def generate_robot_prior_data(args):
    """Build a prior dataset from a long forward SDE chain started at one data trajectory.

    The function proceeds in these steps:
    1. Load the trajectory dataset written by ``--mode rrt``; per step it holds
       (cos(theta), sin(theta)) features, i.e. 14 values.
    2. Save the dataset's first trajectory as the reference.
    3. Run the SDE from that reference in the dataset's own representation.
    4. Apply burn-in and thinning.
    5. Save the samples in 14D (cos, sin) form.

    Args:
        args: parsed CLI namespace; reads ``seed``, ``time_steps``, ``sde_N``,
            ``sde_T``, ``sigma``, ``burn_in`` and ``thinning``.

    Outputs (relative to the working directory):
        ../data/robot_arm/robot_7dof_joints_ref.npy    reference trajectory
        ../data/robot_arm/robot_7dof_joints_prior.npy  (num_samples, time_steps, 14)
    """
    print(f"\nGenerating prior dataset using a long forward SDE chain...")
    set_seed_everywhere(args.seed)
    output_name = "robot_7dof_joints"
    data_dir = pathlib.Path("../data/robot_arm")
    dataset_path = data_dir / f"{output_name}_paths.npy"
    prior_path = data_dir / f"{output_name}_prior.npy"
    ref_path = data_dir / f"{output_name}_ref.npy"

    if not dataset_path.exists():
        print(f"Error: Main dataset not found at: {dataset_path}. Please run with '--mode rrt' first.")
        return

    # 1. Load the 14D main dataset.
    print(f"Loading 14D dataset from: {dataset_path}")
    x_0_data_14d = np.load(dataset_path)

    xref = x_0_data_14d[0] # Select the first sample (in 14D) as the reference
    np.save(ref_path, xref) # Save the reference sample in 14D format
    print(f"Loaded {x_0_data_14d.shape[0]} trajectories. Selected one as reference and saved to {ref_path}.")

    # The SDE state is the flattened reference, so simulated samples have the
    # same per-step dimension as the dataset (14 for rrt-mode output).
    steps_in_data, state_dim = xref.shape[0], xref.shape[-1]
    if steps_in_data != args.time_steps:
        print(f"Error: dataset trajectories have {steps_in_data} time steps but --time_steps is {args.time_steps}.")
        return
    if state_dim not in (7, 14):
        print(f"Error: expected 7 (angles) or 14 (cos, sin) features per time step, got {state_dim}.")
        return

    # 2. Setup the manifold (works with 14D angle data)
    obstacle_info = [{'position': [0.4, -0.3, 0.205]}, {'position': [0.4, 0.3, 0.205]}] # z=0.205 to account for robot base height
    manifold = Manifold_Robot(obstacles_info=obstacle_info, time_steps=args.time_steps, safety_margin=0.00, target_ee_z=0.205, boundary_repulsion_rate=0.1)

    # 3. Generate the full, long simulation trajectory from the single reference sample
    data_full = gen_data(manifold, xref.flatten(), repeat_num=1, N=args.sde_N, T=args.sde_T, sigma=args.sigma)

    # 4. Apply burn-in and thinning to the simulation history
    data_tensor = torch.from_numpy(data_full).float().squeeze(0)

    if data_tensor.shape[0] > args.burn_in:
        x_prior_flat = data_tensor[args.burn_in::args.thinning]
    else:
        print("Warning: Not enough samples for burn-in. Using all available samples after simulation.")
        x_prior_flat = data_tensor

    # 5. Restore the trajectory layout using the SDE's real per-step dimension.
    # Reshaping 14D samples as 7D would split every trajectory in two and mix
    # the cos/sin features.
    x_prior = x_prior_flat.reshape(-1, args.time_steps, state_dim).numpy()

    # 6. Save in 14D (cos, sin) form; only genuine angle data needs conversion.
    if state_dim == 7:
        print("Reparameterizing final prior data to 14D for saving...")
        x_prior_14d = np.concatenate([np.cos(x_prior), np.sin(x_prior)], axis=-1)
    else:
        x_prior_14d = x_prior

    print(f"Final prior dataset contains {x_prior_14d.shape[0]} samples.")
    np.save(prior_path, x_prior_14d)
    print(f"Saved 14D prior dataset to: {prior_path}")
    return

def generate_rrt_data(args):
    """Generate RRT-planned S / reverse-S trajectories and save them as 14D data.

    For each condition (0 = reverse S, 1 = S), this repeatedly:
    1. plans a Cartesian end-effector path with RRT,
    2. resamples it to ``args.time_steps`` points,
    3. checks its S-shape,
    4. converts it to 7 joint angles via IK,
    5. checks the end-effector height.
    It stops once ``args.num_each`` trajectories have been collected. The
    joint angles are then saved as (cos(theta), sin(theta)) features together
    with the condition labels. A top-down plot is written, and MP4s too when
    ``args.record`` is set.

    NOTE(review): the retry loop has no attempt cap. If paths keep failing a
    check (e.g. an S-shape threshold that is unreachable for small
    ``time_steps``), it never terminates.

    Args:
        args: parsed CLI namespace; reads ``config_path``, ``num_each``,
            ``time_steps`` and ``record``.

    Outputs (relative to the working directory):
        ../data/robot_arm/robot_7dof_joints_paths.npy   (N, time_steps, 14) float32
        ../data/robot_arm/robot_7dof_joints_labels.npy  (N,) 0/1 condition labels
        ./datasets/figs/robot_7dof_joints/all_trajectories_7dof_rrt.png
        ./datasets/videos/robot_7dof_joints/trajectory_c{0,1}.mp4 (with --record)
    """
    output_name = "robot_7dof_joints"
    data_dir = pathlib.Path(f"../data/robot_arm")
    figs_dir = pathlib.Path(f"./datasets/figs/{output_name}")
    video_dir = pathlib.Path(f"./datasets/videos/{output_name}")
    data_dir.mkdir(parents=True, exist_ok=True); figs_dir.mkdir(parents=True, exist_ok=True); video_dir.mkdir(parents=True, exist_ok=True)

    client, robot_id, joint_indices, ee_link_index, obs_ids, obs_pos, obs_radius = create_pybullet_env(p.DIRECT)
    robot_info = (robot_id, joint_indices, ee_link_index)

    # Every trajectory runs from y = -0.5 to y = +0.5 at x = 0.4, z = 0.1.
    start_point = [0.4, -0.5, 0.1]
    end_point = [0.4, 0.5, 0.1]
    
    # Physical limits of the first 7 revolute joints, read from the URDF.
    num_joints = p.getNumJoints(robot_id, physicsClientId=client)
    joint_info = [p.getJointInfo(robot_id, i, physicsClientId=client) for i in range(num_joints)]
    ll = [j[8] for j in joint_info if j[2] == p.JOINT_REVOLUTE][:7]
    ul = [j[9] for j in joint_info if j[2] == p.JOINT_REVOLUTE][:7]
    ll_rounded = list(np.around(np.array(ll), 4))
    ul_rounded = list(np.around(np.array(ul), 4))
    print(f"\nJoint Lower Limits (ll): {ll_rounded}")
    print(f"Joint Upper Limits (ul): {ul_rounded}")
    print("You can copy these values into your YAML configuration file.")
    print("="*50 + "\n")
    # NOTE(review): the YAML keys are validated here against the *physical*
    # limits, but the analysis below suggests writing the *data* min/max under
    # the same keys (joint_lower_limits / joint_upper_limits). Following that
    # advice makes this validation fail on the next run - confirm which values
    # the YAML is meant to hold.
    validate_values_with_config(args.config_path, ll_rounded, ul_rounded)
    
    c0_7d, c0_3d, c1_7d, c1_3d = [], [], [], []
    for condition in [0, 1]:
        # Rejection sampling: retry until num_each trajectories pass every check.
        print(f"\nGenerating {args.num_each} 7D joint trajectories for condition c={condition}...")
        trajectories_7d, trajectories_3d = [], []
        with tqdm(total=args.num_each) as pbar:
            while len(trajectories_7d) < args.num_each:
                path_3d = generate_3d_path(client, robot_info, obs_ids, condition, start_point, end_point)
                if path_3d is None: continue
                resampled_3d = resample_trajectory(path_3d, args.time_steps)
                if resampled_3d is None: continue
                s_shape_ok, s_shape_msg = is_valid_s_shape(resampled_3d, condition)
                if not s_shape_ok:
                    pbar.set_description(f"{s_shape_msg}, retrying...")
                    continue
                path_7d = convert_path_to_joint_angles(resampled_3d, client, robot_info)
                if path_7d is None: continue
                z_ok, z_msg = verify_trajectory_z_height(path_7d, client, robot_info)
                if z_ok:
                    trajectories_7d.append(path_7d)
                    trajectories_3d.append(resampled_3d)
                    pbar.update(1)
                    pbar.set_description(f"Found {len(trajectories_7d)}/{args.num_each}")
                else:
                    pbar.set_description(f"{z_msg}, retrying...")
        if condition == 0:
            c0_7d, c0_3d = trajectories_7d, trajectories_3d
        else:
            c1_7d, c1_3d = trajectories_7d, trajectories_3d
    p.disconnect(client)

    # Videos use their own DIRECT client, so the planning client can be closed first.
    if args.record:
        if c0_7d: record_trajectory_to_mp4_headless(c0_7d[0], 0, video_dir, joint_indices)
        if c1_7d: record_trajectory_to_mp4_headless(c1_7d[0], 1, video_dir, joint_indices)

    # Stack c=0 then c=1; labels follow the same order.
    all_trajectories_7d = np.concatenate([np.array(c0_7d, dtype=np.float32), np.array(c1_7d, dtype=np.float32)], axis=0)
    all_labels = np.concatenate([np.zeros(len(c0_7d)), np.ones(len(c1_7d))])

    # --- Analyze Min/Max of the original 7D angle data ---
    print("\n" + "="*60)
    print("Analyzing the true min/max joint values (in radians) from all collected samples...")
    if all_trajectories_7d.shape[0] > 0:
        # Per-joint extremes over all trajectories and time steps.
        actual_min_vals = np.min(all_trajectories_7d, axis=(0, 1))
        actual_max_vals = np.max(all_trajectories_7d, axis=(0, 1))
        min_list = np.around(actual_min_vals, 4).tolist()
        max_list = np.around(actual_max_vals, 4).tolist()
        print("\n--- Analysis Complete ---")
        print("Copy the following lines into your YAML file's `problem` section (these are in radians):")
        print(f"\njoint_lower_limits: {min_list}")
        print(f"joint_upper_limits: {max_list}\n")
    else:
        print("No trajectories were generated, cannot analyze limits.")
    print("="*60 + "\n")
    
    # --- Reparameterize from 7D angles to 14D: [cos(q1..q7), sin(q1..q7)] per step ---
    print("Reparameterizing trajectories from 7D angles to 14D (cos, sin) format...")
    trajectories_cos = np.cos(all_trajectories_7d)
    trajectories_sin = np.sin(all_trajectories_7d)
    all_trajectories_14d = np.concatenate([trajectories_cos, trajectories_sin], axis=-1).astype(np.float32)

    print(f"\nTotal number of trajectories generated: {len(all_trajectories_14d)}")
    if all_trajectories_14d.shape[0] > 0:
        print(f"Shape of data to be saved (14D reparameterized): {all_trajectories_14d.shape}")
        dataset_path = data_dir / f"{output_name}_paths.npy"
        labels_path = data_dir / f"{output_name}_labels.npy"
        np.save(dataset_path, all_trajectories_14d)
        np.save(labels_path, all_labels)
        print(f"14D trajectory data saved to: {dataset_path}")
        print(f"Label data saved to: {labels_path}")
        plot_trajectories(c0_3d, c1_3d, obs_pos, obs_radius, figs_dir / "all_trajectories_7dof_rrt.png")
    else:
        print("\nError: Failed to generate any valid trajectories.")

    print("\n[7-DoF RRT-based data generation pipeline complete]")

# ==============================================================================
# SECTION 4: Main Execution Logic
# ==============================================================================
if __name__ == "__main__":
    # Command-line entry point:
    #   --mode rrt   plan, filter and save joint-space trajectories
    #   --mode prior build the SDE prior dataset from the saved trajectories
    cli = argparse.ArgumentParser(description="Generate 7-DoF joint angle trajectories.")
    cli.add_argument(
        "--mode", type=str, default="rrt", choices=["rrt", "prior"],
        help="Mode of operation: 'rrt' to generate trajectories, 'prior' to generate prior dataset.",
    )
    cli.add_argument(
        "--config_path", type=str, default='../configs/experiment/robot.yaml',
        help="Path to the YAML configuration file to validate against.",
    )

    # Options for rrt mode.
    cli.add_argument(
        "--num_each", type=int, default=200,
        help="Number of trajectories for each condition (for rrt mode).",
    )
    cli.add_argument(
        "--time_steps", type=int, default=10,
        help="Number of time steps per trajectory.",
    )
    cli.add_argument(
        "--record", action="store_true",
        help="Record the first successful trajectory of each condition to an MP4 file (for rrt mode).",
    )

    # Options for prior (SDE) mode.
    cli.add_argument("--seed", type=int, default=1, help="Random seed (for prior mode).")
    cli.add_argument("--sde_N", type=int, default=20000, help="Number of SDE steps (for prior mode).")
    cli.add_argument("--sde_T", type=float, default=100.0, help="SDE end time (for prior mode).")
    cli.add_argument("--sigma", type=float, default=1.0, help="SDE sigma value (for prior mode).")
    cli.add_argument("--burn_in", type=int, default=1000, help="Burn-in samples for SDE chain (for prior mode).")
    cli.add_argument("--thinning", type=int, default=10, help="Thinning factor for SDE chain (for prior mode).")

    args = cli.parse_args()

    mode_handlers = {"rrt": generate_rrt_data, "prior": generate_robot_prior_data}
    handler = mode_handlers.get(args.mode)
    if handler is None:
        print(f"Error: Unknown mode '{args.mode}'")
    else:
        handler(args)