"""
Generalization test for close_drawer trajectory encoder.

Tests:
1. 12 different angles at distance=1.0, position=0.5 (angle generalization)
2. 12 trajectories with angle 0 and 180, each with 6 distances (distance generalization)

Generates trajectories, encodes them, and visualizes the z embeddings.
"""
import os
import sys
import torch
import numpy as np
import pickle

# Set matplotlib backend before importing pyplot
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

from tqdm import tqdm
import argparse

# Add parent directories for imports
sys.path.insert(0, os.path.dirname(__file__))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'make_dataset'))

from trajectory_encoder import TrajectoryVAE
from train_close_drawer_encoder import make_trajectory_relative


def load_encoder(checkpoint_path, config=None):
    """Load a trained TrajectoryVAE from a checkpoint file.

    Args:
        checkpoint_path: path to a checkpoint saved with ``torch.save``. It must
            contain 'model_state_dict' and may contain 'config', 'epoch' and
            'loss'.
        config: optional dict of model hyperparameters. When None, the
            checkpoint's 'config' entry is used (an empty dict if it is absent
            or None), and missing keys fall back to the defaults below.

    Returns:
        (model, config): the model on CPU in eval mode, and the config dict used.
    """
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Use config from checkpoint or default; guard against an explicit None entry.
    if config is None:
        config = checkpoint.get('config') or {}

    model = TrajectoryVAE(
        state_dim=config.get('state_dim', 22),
        action_dim=config.get('action_dim', 8),
        hidden_dim=config.get('hidden_dim', 128),
        num_layers=config.get('num_layers', 4),
        num_heads=config.get('num_heads', 4),
        latent_dim=config.get('latent_dim', 2),
        horizon=config.get('horizon', 87)
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Formatting a missing loss ('unknown' / None) with ':.4f' would raise, so
    # only apply the numeric format when the value supports it.
    loss = checkpoint.get('loss')
    try:
        loss_str = f"{loss:.4f}"
    except (TypeError, ValueError):
        loss_str = 'unknown' if loss is None else str(loss)

    print(f"Loaded encoder from {checkpoint_path}")
    print(f"  Epoch: {checkpoint.get('epoch', 'unknown')}")
    print(f"  Loss: {loss_str}")

    return model, config


def encode_trajectories(model, trajectories, device='cpu', horizon=87):
    """
    Encode trajectories one at a time using the trained encoder.

    Args:
        model: TrajectoryVAE model (moved to `device` and set to eval mode here)
        trajectories: non-empty list of dicts with 'states' (T, state_dim) and
            'actions' (T, action_dim) torch tensors
        device: torch device
        horizon: maximum number of steps passed to the encoder. Longer
            trajectories are truncated to their first `horizon` steps. The
            default of 87 matches the horizon the encoder was trained with
            (raw demos have 88 steps, processed training data has 87).

    Returns:
        z: numpy array of shape (N, latent_dim)

    Raises:
        ValueError: if `trajectories` is empty.
    """
    if len(trajectories) == 0:
        raise ValueError("encode_trajectories: no trajectories to encode")

    model = model.to(device)
    model.eval()

    all_z = []

    with torch.no_grad():
        for traj in tqdm(trajectories, desc="Encoding trajectories"):
            # Slicing is a no-op for trajectories already <= horizon steps.
            states = traj['states'][:horizon].unsqueeze(0).to(device)  # (1, T, state_dim)
            actions = traj['actions'][:horizon].unsqueeze(0).to(device)  # (1, T, action_dim)

            # Convert to relative coordinates with the helper imported from the
            # training script.
            states_rel, actions_rel = make_trajectory_relative(states, actions)

            z = model.encode(states_rel, actions_rel)
            all_z.append(z.cpu().numpy())

    return np.concatenate(all_z, axis=0)


def get_12_colors():
    """Return 12 RGBA colors sampled from matplotlib's tab20 colormap.

    The colormap is sampled at k / 12 for k = 0..11, which is the same scheme
    the EE trajectory plots use, so colors line up across plots.
    """
    n_colors = 12
    palette = plt.cm.tab20
    return [palette(idx / n_colors) for idx in range(n_colors)]


def visualize_angle_test(z, angles, save_path):
    """
    Plot the angle-generalization embeddings and save the figure.

    Left panel: scatter of z[:, 0] vs z[:, 1], one point per angle, annotated
    with the angle in degrees. Right panel: polar plot of |z| against angle.
    Uses the same 12 colors as the EE trajectory plots; colors wrap around
    when there are more than 12 angles.

    Args:
        z: array of shape (N, latent_dim) with latent_dim >= 2
        angles: length-N sequence of angles in degrees
        save_path: output image path

    Returns:
        The matplotlib Figure (already closed after saving).
    """
    z = np.asarray(z)
    angles = np.asarray(angles, dtype=float)
    colors = get_12_colors()

    # Build one cartesian and one polar axes directly instead of overlaying a
    # polar axes on top of a throw-away cartesian one.
    fig = plt.figure(figsize=(14, 6))
    ax = fig.add_subplot(1, 2, 1)
    ax_polar = fig.add_subplot(1, 2, 2, projection='polar')

    theta = np.radians(angles)
    r = np.linalg.norm(z, axis=1)  # Distance from origin in latent space

    for i, angle in enumerate(angles):
        color = [colors[i % len(colors)]]
        # Left: z embeddings
        ax.scatter(z[i, 0], z[i, 1], c=color, s=100, alpha=0.8,
                   edgecolors='black', linewidths=0.5, label=f'{angle:.0f}°')
        ax.annotate(f'{angle:.0f}', (z[i, 0], z[i, 1]),
                    textcoords="offset points", xytext=(5, 5), fontsize=8)
        # Right: polar plot of angle vs |z|
        ax_polar.scatter(theta[i], r[i], c=color, s=100, alpha=0.8,
                         edgecolors='black', linewidths=0.5)

    ax.set_xlabel('z[0]', fontsize=12)
    ax.set_ylabel('z[1]', fontsize=12)
    ax.set_title('Angle Generalization Test\n(distance=1.0, position=0.5)', fontsize=12)
    ax.grid(True, alpha=0.3)
    ax.legend(loc='upper left', fontsize=7, ncol=2)

    ax_polar.set_title('z magnitude vs angle', fontsize=12)

    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved angle test visualization to {save_path}")

    return fig


def visualize_distance_test(z, params, save_path):
    """
    Plot the distance-generalization embeddings and save the figure.

    Left panel: scatter of z[:, 0] vs z[:, 1], one point per trajectory,
    annotated with its distance fraction. Right panel: |z| against distance
    fraction. Circles mark angle 0; squares mark any other angle (180 in this
    test). Points are colored by trajectory index with the same 12 colors as
    the EE trajectory plots, wrapping around after 12.

    Args:
        z: array of shape (N, latent_dim) with latent_dim >= 2
        params: length-N list of (angle, distance) tuples
        save_path: output image path

    Returns:
        The matplotlib Figure (already closed after saving).
    """
    z = np.asarray(z)
    colors = get_12_colors()
    z_mag = np.linalg.norm(z, axis=1)

    fig, (ax, ax2) = plt.subplots(1, 2, figsize=(14, 6))

    for i, (angle, dist) in enumerate(params):
        color = [colors[i % len(colors)]]
        marker = 'o' if angle == 0 else 's'
        # Left: z embeddings colored by trajectory index
        ax.scatter(z[i, 0], z[i, 1], c=color, s=100, alpha=0.8,
                   edgecolors='black', linewidths=0.5, marker=marker,
                   label=f'angle={angle:.0f}°, dist={dist:.2f}')
        ax.annotate(f'{dist:.2f}', (z[i, 0], z[i, 1]),
                    textcoords="offset points", xytext=(5, 5), fontsize=8)
        # Right: distance vs z magnitude, same color/marker as the left plot
        ax2.scatter(dist, z_mag[i], c=color, s=100, alpha=0.8,
                    edgecolors='black', linewidths=0.5, marker=marker)

    ax.set_xlabel('z[0]', fontsize=12)
    ax.set_ylabel('z[1]', fontsize=12)
    ax.set_title('Distance Generalization Test\n(angle=0,180, position=0.5)', fontsize=12)
    ax.legend(loc='upper left', fontsize=6, ncol=2)
    ax.grid(True, alpha=0.3)

    ax2.set_xlabel('Distance fraction', fontsize=12)
    ax2.set_ylabel('|z| (magnitude)', fontsize=12)
    ax2.set_title('z magnitude vs distance\n(circle=angle 0°, square=angle 180°)', fontsize=12)
    ax2.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved distance test visualization to {save_path}")

    return fig


def visualize_combined(z_angle, angles, z_dist, dist_params, save_path):
    """
    Side-by-side plot of both tests' embeddings, saved to `save_path`.

    Left: angle-test embeddings (circles, annotated with the angle).
    Right: distance-test embeddings (circle for angle 0, square otherwise,
    annotated with the distance fraction). Both use the same 12-color scheme
    as the EE trajectory plots, wrapping around after 12 points.

    Args:
        z_angle: (N_a, latent_dim) angle-test embeddings, latent_dim >= 2
        angles: length-N_a sequence of angles in degrees
        z_dist: (N_d, latent_dim) distance-test embeddings
        dist_params: length-N_d list of (angle, distance) tuples
        save_path: output image path

    Returns:
        The matplotlib Figure (already closed after saving).
    """
    colors = get_12_colors()
    fig, (ax_angle, ax_dist) = plt.subplots(1, 2, figsize=(14, 6))

    # Left: angle test
    for i, angle in enumerate(angles):
        ax_angle.scatter(z_angle[i, 0], z_angle[i, 1], c=[colors[i % len(colors)]], s=80, alpha=0.8,
                         edgecolors='black', linewidths=0.5, marker='o', label=f'{angle:.0f}°')
        ax_angle.annotate(f'{angle:.0f}', (z_angle[i, 0], z_angle[i, 1]),
                          textcoords="offset points", xytext=(3, 3), fontsize=7)
    ax_angle.set_xlabel('z[0]', fontsize=12)
    ax_angle.set_ylabel('z[1]', fontsize=12)
    ax_angle.set_title('Angle Test (dist=1.0)', fontsize=12)
    ax_angle.grid(True, alpha=0.3)
    ax_angle.legend(loc='upper left', fontsize=6, ncol=2)

    # Right: distance test
    for i, (angle, dist) in enumerate(dist_params):
        marker = 'o' if angle == 0 else 's'
        ax_dist.scatter(z_dist[i, 0], z_dist[i, 1], c=[colors[i % len(colors)]], s=80, alpha=0.8,
                        edgecolors='black', linewidths=0.5, marker=marker,
                        label=f'a={angle:.0f}°, d={dist:.2f}')
        ax_dist.annotate(f'{dist:.2f}', (z_dist[i, 0], z_dist[i, 1]),
                         textcoords="offset points", xytext=(3, 3), fontsize=7)
    ax_dist.set_xlabel('z[0]', fontsize=12)
    ax_dist.set_ylabel('z[1]', fontsize=12)
    ax_dist.set_title('Distance Test (angle=0,180)', fontsize=12)
    ax_dist.legend(loc='upper left', fontsize=5, ncol=2)
    ax_dist.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved combined visualization to {save_path}")

    return fig


def generate_test_trajectories_from_simulation(output_dir, num_angles=12, num_distances=6):
    """
    Generate test trajectories by running the RLBench close_drawer simulation.

    Test 1: `num_angles` angles evenly spaced over [0, 360) degrees at
            distance=1.0, position=0.5.
    Test 2: angles 0 and 180, each with `num_distances` distances evenly
            spaced over [0.2, 1.0], at position=0.5.

    Each successful episode is saved under
    <output_dir>/{angle_test,distance_test}/episode<k>/ as states.npy,
    actions.npy, ee_trajectory.npy and metadata.npy. An episode that raises is
    reported and skipped. The simulator is always shut down, even on error.

    Args:
        output_dir: root directory for the saved episodes
        num_angles: number of angles for Test 1
        num_distances: number of distances per angle for Test 2

    Returns:
        (angle_trajectories, angle_values, distance_trajectories, distance_params)
        The trajectory lists hold dicts with 'states' and 'actions' (float
        tensors), 'params' (angle_deg, dist_frac, pos_frac) and 'success'.
        angle_values is a numpy array of degrees. distance_params is a list
        of (angle_deg, dist_frac) tuples.
    """
    from pyrep.objects.joint import Joint

    from rlbench import ObservationConfig
    from rlbench.action_modes.action_mode import MoveArmThenGripper
    from rlbench.action_modes.arm_action_modes import JointPosition
    from rlbench.action_modes.gripper_action_modes import Discrete
    from rlbench.backend.utils import task_file_to_task_class
    from rlbench.environment import Environment

    from push_utils import (
        check_task_success,
        fix_cabinet_orientation,
        generate_push_trajectory,
        get_drawer_handle_position,
        reset_robot_to_default,
        set_drawer_open,
    )
    from close_drawer_config import DRAWER_VARIATION, DRAWER_OPEN_AMOUNT, CONTROL_POINT_RADIUS

    # Create output directories
    angle_test_dir = os.path.join(output_dir, 'angle_test')
    distance_test_dir = os.path.join(output_dir, 'distance_test')
    os.makedirs(angle_test_dir, exist_ok=True)
    os.makedirs(distance_test_dir, exist_ok=True)

    # Observation config: only the low-dimensional fields read below.
    obs_config = ObservationConfig()
    obs_config.set_all(False)
    obs_config.joint_positions = True
    obs_config.joint_velocities = True
    obs_config.gripper_open = True
    obs_config.gripper_pose = True
    obs_config.task_low_dim_state = True

    # Action bounds for 7 joint positions + gripper: [ACT_MIN, ACT_MIN + ACT_RANGE].
    ACT_MIN = np.array([-2.8973, -1.7628, -2.8973, -3.0718,
                        -2.8973, -0.0175, -2.8973, 0.0], dtype=np.float32)
    ACT_RANGE = np.array([5.7946, 3.5256, 5.7946, 3.0020,
                          5.7946, 3.7700, 5.7946, 1.0], dtype=np.float32)

    class CustomMoveArmThenGripper(MoveArmThenGripper):
        def action_bounds(self):
            return (ACT_MIN, ACT_MIN + ACT_RANGE)

    action_mode = CustomMoveArmThenGripper(JointPosition(True), Discrete())

    drawer_variation = DRAWER_VARIATION
    drawer_open_amount = DRAWER_OPEN_AMOUNT
    control_point_radius = CONTROL_POINT_RADIUS
    drawer_names = ['bottom', 'middle', 'top']
    pos_frac = 0.5

    def print_validation(trace, phase_indices, handle_pos):
        """Print reach/push diagnostics (same checks as the dataset generator)."""
        if trace is None or len(trace) == 0 or not phase_indices:
            return
        reach_range = phase_indices.get("reach", (0, 0))
        if reach_range[1] <= 0:
            return
        dist_to_handle = np.linalg.norm(trace[reach_range[1] - 1] - handle_pos)
        push_range = phase_indices.get("push", (0, 0))
        if push_range[1] <= push_range[0]:
            return
        push_distance = np.linalg.norm(trace[push_range[1] - 1] - trace[push_range[0]])
        print(f"  Validation:")
        print(f"    Handle position: {handle_pos}")
        print(f"    Reach end -> Handle distance: {dist_to_handle:.4f}m")
        print(f"    Push distance: {push_distance:.4f}m")

    def extract_states_actions(demo):
        """Stack per-step states and joint-position actions from a demo."""
        states, actions = [], []
        for step in demo:
            states.append(np.concatenate([
                step.joint_positions,
                step.joint_velocities,
                [step.gripper_open],
                step.gripper_pose
            ]))
            # Prefer the recorded action; otherwise use the observed joints + gripper.
            if hasattr(step, 'misc') and 'joint_position_action' in step.misc:
                actions.append(step.misc['joint_position_action'])
            else:
                actions.append(np.concatenate([step.joint_positions, [step.gripper_open]]))
        return np.array(states), np.array(actions)

    def run_episode(task_env, angle_deg, dist_frac, episode_dir):
        """Roll out one push trajectory, print diagnostics, save it to
        `episode_dir`, and return its trajectory dict. Raises on failure."""
        canonical_params = np.array([[np.radians(angle_deg), dist_frac, pos_frac]])

        reset_robot_to_default(task_env)
        task_env.reset()
        fix_cabinet_orientation(task_env)
        set_drawer_open(task_env, drawer_variation, drawer_open_amount)
        handle_pos, handle_ori = get_drawer_handle_position(task_env, drawer_variation)

        demo, traj_metadata = generate_push_trajectory(
            task_env,
            start_pos=np.zeros(3),
            handle_pos=handle_pos,
            handle_ori=handle_ori,
            cp_idx=0,
            canonical_params=canonical_params,
            control_point_radius=control_point_radius,
            waypoint_params=None,
            phase_steps=None,  # Use default (88 steps) like dataset generator
            steps_per_point=5,
            target_drawer_idx=drawer_variation,
        )
        if len(demo) == 0:
            raise RuntimeError("Empty trajectory")

        trace = traj_metadata.get("trace")
        print_validation(trace, traj_metadata.get("phase_indices", {}), handle_pos)

        task_success = check_task_success(task_env)
        drawer_joint = Joint(f'drawer_joint_{drawer_names[drawer_variation]}')
        final_drawer_pos = drawer_joint.get_joint_position()
        print(f"    Drawer joint after trajectory: {final_drawer_pos:.4f}m (need <0.04m for success)")
        print(f"    Task success: {task_success}")

        states, actions = extract_states_actions(demo)

        os.makedirs(episode_dir, exist_ok=True)
        np.save(os.path.join(episode_dir, 'states.npy'), states)
        np.save(os.path.join(episode_dir, 'actions.npy'), actions)
        np.save(os.path.join(episode_dir, 'ee_trajectory.npy'), trace)
        np.save(os.path.join(episode_dir, 'metadata.npy'), {
            'angle': angle_deg,
            'distance': dist_frac,
            'position': pos_frac,
            'success': task_success,
            'drawer_final_pos': final_drawer_pos
        })

        print(f"  Demo: {len(demo)} steps, success: {task_success}")
        return {
            'states': torch.from_numpy(states).float(),
            'actions': torch.from_numpy(actions).float(),
            'params': (angle_deg, dist_frac, pos_frac),
            'success': task_success
        }

    angle_trajectories = []
    angle_values = []
    distance_trajectories = []
    distance_params = []

    rlbench_env = Environment(action_mode=action_mode, obs_config=obs_config, headless=True)
    rlbench_env.launch()
    try:
        task_class = task_file_to_task_class("close_drawer")
        task_env = rlbench_env.get_task(task_class)
        # NOTE(review): variation 2 is hard-coded here, while the drawer used
        # below comes from DRAWER_VARIATION; presumably these must agree — confirm.
        task_env.set_variation(2)  # Top drawer

        # Initialize
        task_env.reset()
        fix_cabinet_orientation(task_env)
        set_drawer_open(task_env, drawer_variation, drawer_open_amount)
        handle_pos, _ = get_drawer_handle_position(task_env, drawer_variation)
        print(f"Handle position: {handle_pos}")

        # =====================================================================
        # Test 1: num_angles different angles at distance=1.0, position=0.5
        # =====================================================================
        print("\n" + "="*60)
        print("Test 1: Angle Generalization (distance=1.0, position=0.5)")
        print("="*60)

        for i in range(num_angles):
            angle_deg = i * (360.0 / num_angles)
            print(f"\nAngle {i+1}/{num_angles}: {angle_deg:.1f} degrees")
            try:
                traj = run_episode(task_env, angle_deg, 1.0,
                                   os.path.join(angle_test_dir, f'episode{i}'))
            except Exception as e:
                # Best-effort: report and move on to the next angle.
                print(f"  Failed: {e}")
                continue
            angle_trajectories.append(traj)
            angle_values.append(angle_deg)

        # =====================================================================
        # Test 2: angle 0 and 180, each with num_distances distances
        # =====================================================================
        print("\n" + "="*60)
        print("Test 2: Distance Generalization (angle=0,180, position=0.5)")
        print("="*60)

        # Episode indices stay contiguous: a failed episode's index is reused.
        episode_idx = 0
        for angle_deg in (0, 180):
            for dist_frac in np.linspace(0.2, 1.0, num_distances):
                print(f"\nEpisode {episode_idx+1}: angle={angle_deg:.0f}, distance={dist_frac:.2f}")
                try:
                    traj = run_episode(task_env, angle_deg, dist_frac,
                                       os.path.join(distance_test_dir, f'episode{episode_idx}'))
                except Exception as e:
                    print(f"  Failed: {e}")
                    continue
                distance_trajectories.append(traj)
                distance_params.append((angle_deg, dist_frac))
                episode_idx += 1
    finally:
        # Always release the simulator, even if setup or an episode loop raised.
        rlbench_env.shutdown()

    return angle_trajectories, np.array(angle_values), distance_trajectories, distance_params


def _episode_sort_key(name):
    """Sort 'episode<k>' names by the integer k, so episode2 comes before episode10."""
    suffix = name[len('episode'):]
    if suffix.isdigit():
        return (0, int(suffix), name)
    return (1, 0, name)


def _load_episodes(test_dir):
    """Load every episode<k>/ directory under `test_dir` in numeric episode order.

    Returns a list of (trajectory_dict, metadata_dict) pairs. The list is empty
    if `test_dir` does not exist.
    """
    if not os.path.isdir(test_dir):
        return []
    episode_names = sorted(
        (d for d in os.listdir(test_dir)
         if d.startswith('episode') and os.path.isdir(os.path.join(test_dir, d))),
        key=_episode_sort_key,
    )
    episodes = []
    for name in episode_names:
        ep_path = os.path.join(test_dir, name)
        states = np.load(os.path.join(ep_path, 'states.npy'))
        actions = np.load(os.path.join(ep_path, 'actions.npy'))
        # metadata.npy holds a pickled dict written by np.save; allow_pickle is
        # required, so only load directories produced by this script.
        metadata = np.load(os.path.join(ep_path, 'metadata.npy'), allow_pickle=True).item()
        episodes.append(({
            'states': torch.from_numpy(states).float(),
            'actions': torch.from_numpy(actions).float(),
            'params': (metadata['angle'], metadata['distance'], metadata['position']),
            'success': metadata.get('success'),
        }, metadata))
    return episodes


def load_trajectories_from_disk(output_dir):
    """Load previously generated trajectories from disk.

    Reads <output_dir>/angle_test/episode<k>/ and
    <output_dir>/distance_test/episode<k>/ in numeric episode order. A plain
    lexicographic sort would put episode10 before episode2 and scramble the
    index-based colors relative to the generation run.

    Returns:
        (angle_trajectories, angle_values, distance_trajectories, distance_params)
        Same structure as generate_test_trajectories_from_simulation.
    """
    angle_episodes = _load_episodes(os.path.join(output_dir, 'angle_test'))
    angle_trajectories = [traj for traj, _ in angle_episodes]
    angle_values = [meta['angle'] for _, meta in angle_episodes]

    distance_episodes = _load_episodes(os.path.join(output_dir, 'distance_test'))
    distance_trajectories = [traj for traj, _ in distance_episodes]
    distance_params = [(meta['angle'], meta['distance']) for _, meta in distance_episodes]

    return angle_trajectories, np.array(angle_values), distance_trajectories, distance_params


def main():
    """
    Command-line entry point.

    Generates the test trajectories in simulation, or loads them from
    --output_dir when --skip_generation is given. It then encodes them with
    the checkpoint at --encoder_path, saves the embeddings as .npy files in
    --output_dir and writes the plots to --output_dir/plots.
    """
    parser = argparse.ArgumentParser(description='Generalization test for close_drawer encoder')
    parser.add_argument('--encoder_path', type=str, required=True,
                        help='Path to encoder checkpoint')
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save test data and visualizations')
    parser.add_argument('--device', type=str, default='cuda:0',
                        help='Device to use for encoding')
    parser.add_argument('--skip_generation', action='store_true',
                        help='Skip trajectory generation and load from disk')
    parser.add_argument('--num_angles', type=int, default=12,
                        help='Number of angles to test')
    parser.add_argument('--num_distances', type=int, default=6,
                        help='Number of distances to test per angle')

    args = parser.parse_args()

    # Fail fast: simulation can take a long time, so check the checkpoint
    # exists before generating anything.
    if not os.path.isfile(args.encoder_path):
        parser.error(f"encoder checkpoint not found: {args.encoder_path}")

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)
    plots_dir = os.path.join(args.output_dir, 'plots')
    os.makedirs(plots_dir, exist_ok=True)

    # Generate or load trajectories
    if args.skip_generation:
        print("Loading trajectories from disk...")
        angle_trajs, angle_vals, dist_trajs, dist_params = load_trajectories_from_disk(args.output_dir)
    else:
        print("Generating test trajectories...")
        angle_trajs, angle_vals, dist_trajs, dist_params = generate_test_trajectories_from_simulation(
            args.output_dir, num_angles=args.num_angles, num_distances=args.num_distances
        )

    print(f"\nLoaded {len(angle_trajs)} angle test trajectories")
    print(f"Loaded {len(dist_trajs)} distance test trajectories")

    if not angle_trajs and not dist_trajs:
        print("No trajectories to encode!")
        return

    # Fall back to CPU only when CUDA was requested but is unavailable, so that
    # an explicitly requested non-CUDA device (e.g. 'cpu', 'mps') is respected.
    device_name = args.device
    if device_name.startswith('cuda') and not torch.cuda.is_available():
        print(f"CUDA not available; using CPU instead of {device_name}")
        device_name = 'cpu'
    device = torch.device(device_name)

    # Load encoder
    model, _ = load_encoder(args.encoder_path)
    model = model.to(device)

    z_angle = None
    z_dist = None

    # Encode angle test trajectories
    if angle_trajs:
        print("\nEncoding angle test trajectories...")
        z_angle = encode_trajectories(model, angle_trajs, device)

        # Save embeddings
        np.save(os.path.join(args.output_dir, 'z_angle_test.npy'), z_angle)
        np.save(os.path.join(args.output_dir, 'angles.npy'), angle_vals)

        # Visualize
        visualize_angle_test(z_angle, angle_vals, os.path.join(plots_dir, 'angle_test.png'))

        print(f"\nAngle test embeddings:")
        for angle, z in zip(angle_vals, z_angle):
            print(f"  angle={angle:6.1f}: z=[{z[0]:7.4f}, {z[1]:7.4f}]")

    # Encode distance test trajectories
    if dist_trajs:
        print("\nEncoding distance test trajectories...")
        z_dist = encode_trajectories(model, dist_trajs, device)

        # Save embeddings
        np.save(os.path.join(args.output_dir, 'z_distance_test.npy'), z_dist)
        np.save(os.path.join(args.output_dir, 'distance_params.npy'), np.array(dist_params))

        # Visualize
        visualize_distance_test(z_dist, dist_params, os.path.join(plots_dir, 'distance_test.png'))

        print(f"\nDistance test embeddings:")
        for (angle, dist), z in zip(dist_params, z_dist):
            print(f"  angle={angle:3.0f}, dist={dist:.2f}: z=[{z[0]:7.4f}, {z[1]:7.4f}]")

    # Combined visualization (only when both tests produced embeddings)
    if z_angle is not None and z_dist is not None:
        visualize_combined(z_angle, angle_vals, z_dist, dist_params,
                           os.path.join(plots_dir, 'combined_test.png'))

    print(f"\nResults saved to {args.output_dir}")
    print(f"Plots saved to {plots_dir}")


if __name__ == '__main__':
    # Use the "spawn" start method for any child processes (forced, so it
    # applies even if a start method was already set).
    from multiprocessing import set_start_method
    set_start_method("spawn", force=True)
    main()
