"""
Generalization test for grasp (pick_up_cup) trajectory encoder.

Test: 12 approach angles (0, 30, 60, ..., 330) at fixed grasp height=0.12.
Original training uses 4 angles (0, 90, 180, 270), so 8 are interpolated.

Pipeline per angle:
1. Launch RLBench sim, generate grasp trajectory at the given approach angle
2. Extract states (22-dim) and actions (8-dim) from observations
3. Normalize using existing normalization.npz (min-max to [-1, 1])
4. Apply make_trajectory_relative
5. Encode with trained VAE -> z (1, 2)

Also encodes all original 40 training trajectories for comparison visualization.
"""
import os
import sys
import torch
import numpy as np

# Set matplotlib backend before importing pyplot
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

from tqdm import tqdm
import argparse

# Add parent directories for imports
sys.path.insert(0, os.path.dirname(__file__))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'make_dataset'))

from trajectory_encoder import TrajectoryVAE
from train_grasp_encoder import make_trajectory_relative


def load_encoder(checkpoint_path, config=None):
    """Load a trained TrajectoryVAE from a checkpoint file.

    Args:
        checkpoint_path: Path to a checkpoint readable by ``torch.load``. It
            must contain ``'model_state_dict'``; ``'config'``, ``'epoch'`` and
            ``'loss'`` are optional.
        config: Optional dict of model hyperparameters. When ``None`` the
            checkpoint's ``'config'`` entry is used (or an empty dict if the
            entry is missing or ``None``). Keys that are absent fall back to
            the defaults hard-coded below.

    Returns:
        Tuple ``(model, config)``: the model with weights loaded and switched
        to eval mode, and the config dict that was used to build it.
    """
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    if config is None:
        # `or {}` also covers checkpoints that stored an explicit None config.
        config = checkpoint.get('config') or {}

    model = TrajectoryVAE(
        state_dim=config.get('state_dim', 22),
        action_dim=config.get('action_dim', 8),
        hidden_dim=config.get('hidden_dim', 128),
        num_layers=config.get('num_layers', 4),
        num_heads=config.get('num_heads', 4),
        latent_dim=config.get('latent_dim', 2),
        horizon=config.get('horizon', 107)  # Training used horizon=107
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # The loss may be missing (or not numeric); only apply the float format
    # when it can actually be converted, otherwise report it as unknown.
    try:
        loss_str = f"{float(checkpoint.get('loss')):.4f}"
    except (TypeError, ValueError):
        loss_str = 'unknown'

    print(f"Loaded encoder from {checkpoint_path}")
    print(f"  Epoch: {checkpoint.get('epoch', 'unknown')}")
    print(f"  Loss: {loss_str}")

    return model, config


def load_normalization(normalization_path):
    """Load min-max normalization statistics from an ``.npz`` file.

    Args:
        normalization_path: Path to an npz archive containing the arrays
            ``obs_min``, ``obs_max``, ``action_min`` and ``action_max``.

    Returns:
        Dict with the four loaded arrays plus ``obs_range`` and
        ``action_range`` (max - min). Range entries smaller than 1e-6 are
        replaced by 1.0 so that constant dimensions do not cause a division
        by (near) zero during normalization.
    """
    # np.load returns a lazy NpzFile that keeps the archive open; read every
    # array inside the context manager so the file handle is released.
    with np.load(normalization_path) as norm:
        obs_min = norm['obs_min']
        obs_max = norm['obs_max']
        action_min = norm['action_min']
        action_max = norm['action_max']

    obs_range = obs_max - obs_min
    obs_range = np.where(obs_range < 1e-6, 1.0, obs_range)

    action_range = action_max - action_min
    action_range = np.where(action_range < 1e-6, 1.0, action_range)

    print(f"Loaded normalization from {normalization_path}")
    return {
        'obs_min': obs_min, 'obs_max': obs_max, 'obs_range': obs_range,
        'action_min': action_min, 'action_max': action_max, 'action_range': action_range,
    }


def normalize_trajectory(states, actions, norm_stats):
    """Min-max scale raw states and actions and return them as float32.

    Values equal to the stored minimum map to -1 and values equal to
    minimum + range map to +1. ``norm_stats`` must provide ``obs_min``,
    ``obs_range``, ``action_min`` and ``action_range``.
    """
    def _rescale(values, lower, span):
        # Same arithmetic order as 2 * (x - min) / range - 1.
        return 2.0 * (values - lower) / span - 1.0

    scaled_states = _rescale(states, norm_stats['obs_min'], norm_stats['obs_range'])
    scaled_actions = _rescale(actions, norm_stats['action_min'], norm_stats['action_range'])
    return scaled_states.astype(np.float32), scaled_actions.astype(np.float32)


def encode_trajectories(model, trajectories, device='cpu', horizon=107):
    """
    Encode trajectories using the trained encoder.
    Trajectories should already be normalized.

    Args:
        model: TrajectoryVAE model
        trajectories: iterable of dicts with 'states' and 'actions' keys
            (torch tensors with time as the first dimension)
        device: torch device
        horizon: maximum number of timesteps fed to the encoder; longer
            trajectories are truncated. Defaults to 107, the horizon the
            encoder was trained with (see run_training.sh) -- its
            PositionalEncoding max_len=215 supports up to 107 timesteps
            (107*2=214 tokens).

    Returns:
        z: numpy array of shape (N, latent_dim)

    Raises:
        ValueError: if ``trajectories`` is empty.
    """
    model = model.to(device)
    model.eval()

    all_z = []

    with torch.no_grad():
        for traj in tqdm(trajectories, desc="Encoding trajectories"):
            # Slicing is a no-op for trajectories already within the horizon.
            states = traj['states'][:horizon]  # (T, state_dim)
            actions = traj['actions'][:horizon]  # (T, action_dim)

            states = states.unsqueeze(0).to(device)  # (1, T, state_dim)
            actions = actions.unsqueeze(0).to(device)  # (1, T, action_dim)

            # Convert to relative coordinates
            states_rel, actions_rel = make_trajectory_relative(states, actions)

            # Encode
            z = model.encode(states_rel, actions_rel)
            all_z.append(z.cpu().numpy())

    if not all_z:
        # np.concatenate would raise a less helpful ValueError here.
        raise ValueError("No trajectories to encode")

    return np.concatenate(all_z, axis=0)


def encode_training_data(model, dataset_path, device='cpu', horizon=107):
    """
    Encode all training trajectories from the normalized dataset.

    The dataset stores every trajectory flattened back-to-back; individual
    trajectories are recovered by walking ``traj_lengths``.

    Args:
        model: TrajectoryVAE model
        dataset_path: path to an npz archive with 'states', 'actions' and
            'traj_lengths' arrays
        device: torch device
        horizon: maximum number of timesteps fed to the encoder; longer
            trajectories are truncated. Defaults to 107, the horizon the
            encoder was trained with.

    Returns:
        z: numpy array with one embedding row per trajectory. Mode labels are
        NOT returned here; they are loaded separately from episode metadata.

    Raises:
        ValueError: if the dataset contains no trajectories.
    """
    # Read the arrays eagerly so the npz archive handle is closed right away.
    with np.load(dataset_path) as data:
        states_all = data['states']  # (N*T, 22) flattened
        actions_all = data['actions']  # (N*T, 8) flattened
        traj_lengths = data['traj_lengths']  # (N,)

    print(f"Loaded training dataset: {len(traj_lengths)} trajectories")

    model = model.to(device)
    model.eval()

    all_z = []

    # Reconstruct individual trajectories from flattened data
    start = 0
    with torch.no_grad():
        for traj_len in tqdm(traj_lengths, desc="Encoding training data"):
            stop = start + int(traj_len)
            states = torch.from_numpy(states_all[start:stop]).float()
            actions = torch.from_numpy(actions_all[start:stop]).float()
            start = stop

            # Truncate to the encoder horizon (no-op for shorter trajectories)
            states = states[:horizon].unsqueeze(0).to(device)
            actions = actions[:horizon].unsqueeze(0).to(device)

            states_rel, actions_rel = make_trajectory_relative(states, actions)
            z = model.encode(states_rel, actions_rel)
            all_z.append(z.cpu().numpy())

    if not all_z:
        # np.concatenate would raise a less helpful ValueError here.
        raise ValueError(f"No trajectories found in {dataset_path}")

    return np.concatenate(all_z, axis=0)


def load_training_mode_labels(episodes_path, num_episodes):
    """Read per-episode mode ids and approach angles (in degrees).

    Episode folders are visited in plain lexicographic order (episode0,
    episode1, episode10, ..., episode19, episode2, ...) because that is the
    ordering convert_to_npz.py uses; a numeric sort would misalign labels.

    Args:
        episodes_path: directory holding the ``episode*`` sub-directories.
        num_episodes: keep only the first ``num_episodes`` folders, or all of
            them when ``None``.

    Returns:
        ``(modes, angles)`` numpy arrays. Episodes without ``metadata.npy``
        get mode -1 and angle 0.0; metadata lacking a key uses the same
        fallbacks (mode -1, angle 0.0).
    """
    folder_names = [
        name for name in os.listdir(episodes_path)
        if name.startswith('episode') and os.path.isdir(os.path.join(episodes_path, name))
    ]
    # Lexicographic on purpose -- see docstring.
    folder_names.sort()

    if num_episodes is not None:
        folder_names = folder_names[:num_episodes]

    mode_list, angle_list = [], []
    for name in folder_names:
        meta_file = os.path.join(episodes_path, name, 'metadata.npy')
        if not os.path.exists(meta_file):
            mode_list.append(-1)
            angle_list.append(0.0)
            continue

        # metadata.npy holds a pickled dict wrapped in a 0-d object array.
        meta = np.load(meta_file, allow_pickle=True).item()
        mode_list.append(meta.get('mode', -1))
        angle_list.append(np.degrees(meta.get('approach_angle', 0.0)))

    return np.array(mode_list), np.array(angle_list)


def visualize_12angle_embedding(z, angles_deg, save_path, title_suffix=''):
    """
    Plot the angle-test embeddings as a closed ring, colored by angle (HSV).

    Args:
        z: embeddings; columns 0 and 1 are used as plot coordinates.
        angles_deg: approach angle in degrees for each row of ``z``.
        save_path: where the figure is written.
        title_suffix: text appended to the plot title.
    """
    _fig, ax = plt.subplots(1, 1, figsize=(8, 8))

    # Map each angle onto the cyclic HSV colormap (0..360 deg -> 0..1).
    hues = (np.array(angles_deg) % 360) / 360.0
    point_colors = plt.cm.hsv(hues)

    # Connect the points in increasing-angle order and repeat the first
    # point at the end so the polyline closes into a ring.
    ring = z[np.argsort(angles_deg)]
    ring = np.vstack([ring, ring[:1]])
    ax.plot(ring[:, 0], ring[:, 1], 'k-', alpha=0.3, linewidth=1, zorder=1)

    # One marker per tested angle, drawn above the ring line.
    for idx in range(len(angles_deg)):
        ax.scatter(z[idx, 0], z[idx, 1], c=[point_colors[idx]], s=300, alpha=0.9,
                   edgecolors='black', linewidths=1.0, zorder=3)

    ax.set_xlabel('z[0]', fontsize=14)
    ax.set_ylabel('z[1]', fontsize=14)
    ax.set_title(f'12-Angle Generalization Test{title_suffix}', fontsize=14)
    ax.grid(True, alpha=0.3)
    ax.set_aspect('equal')

    # Colorbar translating hue back to degrees.
    angle_mappable = plt.cm.ScalarMappable(cmap='hsv', norm=plt.Normalize(0, 360))
    angle_mappable.set_array([])
    colorbar = plt.colorbar(angle_mappable, ax=ax, shrink=0.8)
    colorbar.set_label('Approach Angle (degrees)', fontsize=12)

    plt.tight_layout()
    plt.savefig(save_path, dpi=200, bbox_inches='tight')
    plt.close()
    print(f"Saved 12-angle embedding to {save_path}")


def visualize_original_4mode_embedding(z_train, modes, angles_deg, save_path):
    """
    Scatter-plot the training embeddings with one categorical color per mode.

    Args:
        z_train: training embeddings; columns 0 and 1 are plotted.
        modes: mode id per row of ``z_train``; negative ids are not drawn.
        angles_deg: accepted for interface compatibility; not used here.
        save_path: where the figure is written.
    """
    _fig, ax = plt.subplots(1, 1, figsize=(8, 8))

    # Fixed palette for the 4 modes (0=red, 1=green, 2=blue, 3=orange)
    palette = ['#e41a1c', '#4daf4a', '#377eb8', '#ff7f00']
    legend_names = ['0° (mode 0)', '90° (mode 1)', '180° (mode 2)', '270° (mode 3)']

    for mode_value in np.unique(modes):
        # Negative ids mark episodes without a known mode.
        if mode_value < 0:
            continue

        mode_id = int(mode_value)
        if mode_id < len(legend_names):
            legend_text = legend_names[mode_id]
        else:
            legend_text = f'Mode {mode_id}'

        selected = modes == mode_value
        ax.scatter(z_train[selected, 0], z_train[selected, 1],
                   c=palette[mode_id % len(palette)], s=100, alpha=0.7,
                   edgecolors='black', linewidths=0.5, label=legend_text, zorder=3)

    ax.set_xlabel('z[0]', fontsize=14)
    ax.set_ylabel('z[1]', fontsize=14)
    ax.set_title('Original 4-Mode Training Embeddings', fontsize=14)
    ax.grid(True, alpha=0.3)
    ax.set_aspect('equal')
    ax.legend(loc='best', fontsize=11)

    plt.tight_layout()
    plt.savefig(save_path, dpi=200, bbox_inches='tight')
    plt.close()
    print(f"Saved original 4-mode embedding to {save_path}")


def visualize_combined_embedding(z_train, modes, z_12angles, angles_12, save_path):
    """
    Overlay the angle-test ring on top of the training embeddings.

    Args:
        z_train: training embeddings; columns 0 and 1 are plotted.
        modes: mode id per row of ``z_train``; negative ids are not drawn.
        z_12angles: embeddings of the angle-test trajectories.
        angles_12: approach angle in degrees for each row of ``z_12angles``.
        save_path: where the figure is written.
    """
    _fig, ax = plt.subplots(1, 1, figsize=(10, 10))

    # --- Training data: small, translucent markers in the background ---
    palette = ['#e41a1c', '#4daf4a', '#377eb8', '#ff7f00']
    legend_names = ['0° train', '90° train', '180° train', '270° train']

    for mode_value in np.unique(modes):
        if mode_value < 0:
            continue

        mode_id = int(mode_value)
        if mode_id < len(legend_names):
            legend_text = legend_names[mode_id]
        else:
            legend_text = f'Mode {mode_id}'

        selected = modes == mode_value
        ax.scatter(z_train[selected, 0], z_train[selected, 1],
                   c=palette[mode_id % len(palette)], s=60, alpha=0.4,
                   edgecolors='gray', linewidths=0.3, label=legend_text, zorder=2)

    # --- Angle-test ring: large, bright markers in the foreground ---
    hues = (np.array(angles_12) % 360) / 360.0
    point_colors = plt.cm.hsv(hues)

    # Polyline through the points in angle order, closed on the first point.
    ring = z_12angles[np.argsort(angles_12)]
    ring = np.vstack([ring, ring[:1]])
    ax.plot(ring[:, 0], ring[:, 1], 'k-', alpha=0.4, linewidth=1.5, zorder=1)

    for idx, angle in enumerate(angles_12):
        # Only the first point contributes a legend entry.
        legend_text = f'{angle:.0f}°' if idx == 0 else None
        ax.scatter(z_12angles[idx, 0], z_12angles[idx, 1], c=[point_colors[idx]],
                   s=250, alpha=0.9, edgecolors='black', linewidths=1.5, zorder=4,
                   label=legend_text)

    ax.set_xlabel('z[0]', fontsize=14)
    ax.set_ylabel('z[1]', fontsize=14)
    ax.set_title('Training Data + 12-Angle Generalization', fontsize=14)
    ax.grid(True, alpha=0.3)
    ax.set_aspect('equal')

    # Colorbar translating the ring's hue back to degrees.
    angle_mappable = plt.cm.ScalarMappable(cmap='hsv', norm=plt.Normalize(0, 360))
    angle_mappable.set_array([])
    colorbar = plt.colorbar(angle_mappable, ax=ax, shrink=0.7)
    colorbar.set_label('Approach Angle (degrees)', fontsize=12)

    ax.legend(loc='best', fontsize=9, ncol=2)

    plt.tight_layout()
    plt.savefig(save_path, dpi=200, bbox_inches='tight')
    plt.close()
    print(f"Saved combined embedding to {save_path}")

def _demo_to_arrays(demo):
    """Stack the per-step observations of a demo into float32 (states, actions)."""
    states = []
    actions = []
    for obs_step in demo:
        states.append(np.concatenate([
            obs_step.joint_positions,
            obs_step.joint_velocities,
            [obs_step.gripper_open],
            obs_step.gripper_pose
        ]))

        if hasattr(obs_step, 'misc') and 'joint_position_action' in obs_step.misc:
            actions.append(obs_step.misc['joint_position_action'])
        else:
            # No recorded command for this step: fall back to the observed
            # joint positions plus gripper state.
            actions.append(np.concatenate([obs_step.joint_positions, [obs_step.gripper_open]]))

    return np.array(states, dtype=np.float32), np.array(actions, dtype=np.float32)


def _make_angle_record(states, actions, norm_stats, angle_deg, actual_angle_deg, success):
    """Normalize a raw trajectory and wrap it in the dict format used for encoding."""
    states_norm, actions_norm = normalize_trajectory(states, actions, norm_stats)
    return {
        'states': torch.from_numpy(states_norm).float(),
        'actions': torch.from_numpy(actions_norm).float(),
        'angle_deg': angle_deg,
        'actual_angle_deg': actual_angle_deg,
        'success': success,
    }


def generate_12angle_trajectories(output_dir, norm_stats, num_angles=12):
    """
    Generate grasp trajectories at evenly spaced approach angles using RLBench.

    For each of the ``num_angles`` target angles (i * 360 / num_angles):
    1. Reuse the episode saved under ``<output_dir>/12angle_test/episode<i>``
       if it is complete (states, actions and metadata all present)
    2. Otherwise generate a trajectory via simulation, retrying up to 20 times
       per angle and falling back to small angle offsets if the exact angle
       keeps failing
    3. Extract states/actions and save the raw episode to disk
    4. Normalize using the provided normalization statistics

    The simulator is always shut down before returning, including when an
    error propagates out of this function.

    Args:
        output_dir: base directory; episodes are written to its
            ``12angle_test`` sub-directory.
        norm_stats: normalization dict as consumed by ``normalize_trajectory``.
        num_angles: number of evenly spaced target angles.

    Returns:
        ``(angle_trajectories, angle_values)``: a list of dicts with keys
        'states', 'actions' (normalized torch tensors), 'angle_deg' (target
        angle), 'actual_angle_deg' (angle actually used) and 'success', plus a
        numpy array of the target angles. Angles for which every attempt
        failed are omitted from both.
    """
    # NOTE(review): several of these names (Process, Manager, RenderMode,
    # ConfigurationPathError, IKError, HOME_JOINTS) are not used below; the
    # imports are kept unchanged in case they are relied on for side effects.
    from multiprocessing import Process, Manager
    from pyrep.const import RenderMode
    from pyrep.errors import ConfigurationPathError, IKError

    from rlbench import ObservationConfig
    from rlbench.action_modes.action_mode import MoveArmThenGripper
    from rlbench.action_modes.arm_action_modes import JointPosition
    from rlbench.action_modes.gripper_action_modes import Discrete
    from rlbench.backend.utils import task_file_to_task_class
    from rlbench.environment import Environment

    from grasp_utils import (
        get_cup_position,
        set_cup_position,
        reset_robot_to_default,
        generate_grasp_trajectory,
        check_task_success,
        check_grasp_success_manual,
    )
    from grasp_config import (
        FIXED_CUP_POSITION,
        HOME_JOINTS,
        CONTROL_POINT_RADIUS,
        GRASP_HEIGHTS,
        CUP_VARIATION,
    )

    # Create output directory
    angle_test_dir = os.path.join(output_dir, '12angle_test')
    os.makedirs(angle_test_dir, exist_ok=True)

    grasp_height = GRASP_HEIGHTS[0]  # 0.12

    # Setup environment: only low-dimensional observations are enabled
    obs_config = ObservationConfig()
    obs_config.set_all(False)
    obs_config.joint_positions = True
    obs_config.joint_velocities = True
    obs_config.gripper_open = True
    obs_config.gripper_pose = True
    obs_config.task_low_dim_state = True

    # Action bounds: 7 joint values followed by the gripper entry.
    # NOTE(review): presumably the arm's joint limits -- confirm the source.
    ACT_MIN = np.array([-2.8973, -1.7628, -2.8973, -3.0718,
                        -2.8973, -0.0175, -2.8973, 0.0], dtype=np.float32)
    ACT_RANGE = np.array([5.7946, 3.5256, 5.7946, 3.0020,
                          5.7946, 3.7700, 5.7946, 1.0], dtype=np.float32)

    class CustomMoveArmThenGripper(MoveArmThenGripper):
        def action_bounds(self):
            return (ACT_MIN, ACT_MIN + ACT_RANGE)

    action_mode = CustomMoveArmThenGripper(JointPosition(True), Discrete())

    rlbench_env = Environment(action_mode=action_mode, obs_config=obs_config, headless=True)
    rlbench_env.launch()

    angle_trajectories = []
    angle_values = []

    # Everything after launch() runs under try/finally so the simulator is
    # released even if generation or loading raises.
    try:
        task_class = task_file_to_task_class("pick_up_cup")
        task_env = rlbench_env.get_task(task_class)
        task_env.set_variation(CUP_VARIATION)

        # Initialize
        task_env.reset()

        print(f"\nFixed cup position: {FIXED_CUP_POSITION}")
        print(f"Grasp height: {grasp_height}")
        print(f"Control point radius: {CONTROL_POINT_RADIUS}")

        # =====================================================================
        # Generate trajectories at 12 angles
        # =====================================================================
        print("\n" + "=" * 60)
        print(f"Generating {num_angles} approach angle trajectories")
        print("=" * 60)

        max_attempts = 20
        # Noise offsets to try if exact angle fails (±5°, ±3°, ±2°, ±1°)
        noise_offsets_deg = [0, -5, 5, -3, 3, -2, 2, -1, 1]

        for i in range(num_angles):
            target_angle_deg = i * (360.0 / num_angles)

            print(f"\nAngle {i + 1}/{num_angles}: {target_angle_deg:.1f} degrees")

            episode_dir = os.path.join(angle_test_dir, f'episode{i}')
            states_path = os.path.join(episode_dir, 'states.npy')
            actions_path = os.path.join(episode_dir, 'actions.npy')
            metadata_path = os.path.join(episode_dir, 'metadata.npy')

            # Skip if a COMPLETE episode already exists on disk. metadata.npy
            # is written last, so requiring all three files means a run that
            # was interrupted mid-save is regenerated instead of crashing.
            # NOTE(review): episodes are keyed by index only, so a cache made
            # with a different num_angles is reused as-is -- confirm intended.
            if all(os.path.exists(p) for p in (states_path, actions_path, metadata_path)):
                metadata = np.load(metadata_path, allow_pickle=True).item()
                states = np.load(states_path)
                actions = np.load(actions_path)
                # Older metadata has no target angle; use the actual one.
                t_angle = metadata.get('target_angle_deg', metadata['approach_angle_deg'])
                angle_trajectories.append(_make_angle_record(
                    states, actions, norm_stats,
                    angle_deg=t_angle,
                    actual_angle_deg=metadata['approach_angle_deg'],
                    success=metadata.get('task_success', False),
                ))
                angle_values.append(t_angle)
                print(f"  Already exists on disk, skipping (angle={metadata['approach_angle_deg']:.1f}°)")
                continue

            success = False

            for offset in noise_offsets_deg:
                trial_angle_deg = target_angle_deg + offset
                trial_angle_rad = np.radians(trial_angle_deg)

                if offset != 0:
                    print(f"  Trying offset {offset:+.0f}° -> {trial_angle_deg:.1f}°")

                for attempt in range(max_attempts):
                    try:
                        reset_robot_to_default(task_env)
                        task_env.reset()

                        if FIXED_CUP_POSITION is not None:
                            set_cup_position(task_env, FIXED_CUP_POSITION)

                        current_cup_pos, current_cup_ori = get_cup_position(task_env)

                        demo, traj_metadata = generate_grasp_trajectory(
                            task_env,
                            start_pos=np.zeros(3),
                            cup_pos=current_cup_pos,
                            cup_ori=current_cup_ori,
                            approach_angle=trial_angle_rad,
                            grasp_height=grasp_height,
                            control_point_radius=CONTROL_POINT_RADIUS,
                            waypoint_params=None,
                            phase_steps=None,
                            steps_per_point=5,
                        )

                        if len(demo) == 0:
                            raise RuntimeError("Empty trajectory")

                        task_success = check_task_success(task_env)
                        initial_cup_z = traj_metadata.get("initial_cup_z", current_cup_pos[2])
                        grasp_success = check_grasp_success_manual(task_env, initial_cup_z)

                        print(f"  Task success: {task_success}, Grasp success: {grasp_success}")

                        # Extract states and actions
                        states, actions = _demo_to_arrays(demo)

                        # Save raw trajectory (metadata last: its presence
                        # marks the episode as complete for the resume check)
                        trace = traj_metadata.get("trace")
                        os.makedirs(episode_dir, exist_ok=True)
                        np.save(states_path, states)
                        np.save(actions_path, actions)
                        if trace is not None:
                            np.save(os.path.join(episode_dir, 'ee_trajectory.npy'), trace)
                        np.save(metadata_path, {
                            'approach_angle': trial_angle_rad,
                            'approach_angle_deg': trial_angle_deg,
                            'target_angle_deg': target_angle_deg,
                            'angle_offset_deg': offset,
                            'grasp_height': grasp_height,
                            'task_success': task_success,
                            'grasp_success': grasp_success,
                            'cup_pos': current_cup_pos.tolist(),
                        })

                        # Normalize and record; the TARGET angle is used for
                        # ring plotting even when an offset was needed.
                        angle_trajectories.append(_make_angle_record(
                            states, actions, norm_stats,
                            angle_deg=target_angle_deg,
                            actual_angle_deg=trial_angle_deg,
                            success=task_success,
                        ))
                        angle_values.append(target_angle_deg)

                        offset_str = f" (offset {offset:+.0f}°)" if offset != 0 else ""
                        print(f"  Demo: {len(demo)} steps, SUCCESS (attempt {attempt + 1}){offset_str}")
                        success = True
                        break

                    except Exception as e:
                        # Deliberate best-effort retry: any failure of this
                        # attempt is reported and the next attempt is tried.
                        if attempt < max_attempts - 1:
                            print(f"  Attempt {attempt + 1} failed: {str(e)[:80]}...")
                        elif offset == noise_offsets_deg[-1]:
                            print(f"  FAILED after all offsets and {max_attempts} attempts: {e}")
                        else:
                            print(f"  Offset {offset:+.0f}° failed after {max_attempts} attempts")

                if success:
                    break

            if not success:
                print(f"  WARNING: Could not generate trajectory for angle {target_angle_deg}!")
    finally:
        rlbench_env.shutdown()

    return angle_trajectories, np.array(angle_values)


def load_trajectories_from_disk(output_dir, norm_stats):
    """Load previously generated trajectories from disk and normalize.

    Only sub-directories of ``<output_dir>/12angle_test`` named
    ``episode<number>`` are considered; they are visited in numeric order.
    Episodes missing any of states.npy / actions.npy / metadata.npy (e.g.
    from an interrupted generation run) are skipped with a warning.

    Args:
        output_dir: base directory that contains ``12angle_test``.
        norm_stats: normalization dict as consumed by ``normalize_trajectory``.

    Returns:
        ``(angle_trajectories, angle_values)``: a list of dicts with keys
        'states', 'actions' (normalized torch tensors), 'angle_deg',
        'actual_angle_deg' and 'success', plus a numpy array of the plotted
        angles. Both are empty when the directory does not exist.
    """
    angle_test_dir = os.path.join(output_dir, '12angle_test')

    angle_trajectories = []
    angle_values = []

    if not os.path.isdir(angle_test_dir):
        return angle_trajectories, np.array(angle_values)

    prefix = 'episode'
    # Require a purely numeric suffix so stray entries such as
    # 'episode_backup' do not break the integer sort key.
    episode_dirs = sorted(
        (
            name for name in os.listdir(angle_test_dir)
            if name.startswith(prefix)
            and name[len(prefix):].isdecimal()
            and os.path.isdir(os.path.join(angle_test_dir, name))
        ),
        key=lambda name: int(name[len(prefix):])
    )

    for ep_dir in episode_dirs:
        ep_path = os.path.join(angle_test_dir, ep_dir)
        states_path = os.path.join(ep_path, 'states.npy')
        actions_path = os.path.join(ep_path, 'actions.npy')
        metadata_path = os.path.join(ep_path, 'metadata.npy')

        if not all(os.path.exists(p) for p in (states_path, actions_path, metadata_path)):
            print(f"  WARNING: Skipping incomplete episode {ep_path}")
            continue

        states = np.load(states_path)
        actions = np.load(actions_path)
        metadata = np.load(metadata_path, allow_pickle=True).item()

        # Normalize
        states_norm, actions_norm = normalize_trajectory(states, actions, norm_stats)

        # Use target angle for plotting (falls back to approach_angle_deg for old data)
        target_angle = metadata.get('target_angle_deg', metadata['approach_angle_deg'])
        angle_trajectories.append({
            'states': torch.from_numpy(states_norm).float(),
            'actions': torch.from_numpy(actions_norm).float(),
            'angle_deg': target_angle,
            'actual_angle_deg': metadata['approach_angle_deg'],
            'success': metadata.get('task_success', False),
        })
        angle_values.append(target_angle)

    return angle_trajectories, np.array(angle_values)


def main():
    """Command-line entry point for the 12-angle generalization test.

    Steps:
    1. Parse arguments and create the output / plots directories
    2. Generate the angle-test trajectories (or load them with
       ``--skip_generation``)
    3. Encode them, then encode the original training data
    4. Save the embeddings and labels as ``.npy`` files and write three plots

    Raises:
        ValueError: if the number of training labels found under
            ``--episodes_path`` differs from the number of encoded training
            trajectories (the labels could not be aligned with the
            embeddings). The angle-test outputs are written before this
            check.
    """
    parser = argparse.ArgumentParser(
        description='12-angle generalization test for grasp encoder'
    )
    parser.add_argument('--encoder_path', type=str, required=True,
                        help='Path to encoder checkpoint')
    parser.add_argument('--normalization_path', type=str, required=True,
                        help='Path to normalization.npz')
    parser.add_argument('--dataset_path', type=str, required=True,
                        help='Path to train_normalized.npz (for encoding original 4-mode data)')
    parser.add_argument('--episodes_path', type=str, required=True,
                        help='Path to episodes directory (for loading mode labels)')
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save test data and visualizations')
    parser.add_argument('--device', type=str, default='cuda:0',
                        help='Device to use for encoding')
    parser.add_argument('--skip_generation', action='store_true',
                        help='Skip trajectory generation and load from disk')
    parser.add_argument('--num_angles', type=int, default=12,
                        help='Number of angles to test')

    args = parser.parse_args()

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)
    plots_dir = os.path.join(args.output_dir, 'plots')
    os.makedirs(plots_dir, exist_ok=True)

    # Load normalization stats
    norm_stats = load_normalization(args.normalization_path)

    # Generate or load trajectories
    if args.skip_generation:
        print("Loading trajectories from disk...")
        angle_trajs, angle_vals = load_trajectories_from_disk(args.output_dir, norm_stats)
    else:
        print("Generating test trajectories...")
        angle_trajs, angle_vals = generate_12angle_trajectories(
            args.output_dir, norm_stats, num_angles=args.num_angles
        )

    print(f"\nLoaded {len(angle_trajs)} angle test trajectories")

    if len(angle_trajs) == 0:
        print("No trajectories to encode!")
        return

    # Load encoder; the requested device is only honored when CUDA is
    # available, otherwise encoding runs on the CPU.
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    model, _config = load_encoder(args.encoder_path)
    model = model.to(device)

    # =========================================================================
    # Encode 12-angle test trajectories
    # =========================================================================
    print("\nEncoding 12-angle test trajectories...")
    z_12angles = encode_trajectories(model, angle_trajs, device)

    # Save embeddings
    np.save(os.path.join(args.output_dir, 'z_12angles.npy'), z_12angles)
    np.save(os.path.join(args.output_dir, 'angles_12.npy'), angle_vals)

    print("\n12-angle test embeddings:")
    for angle, z in zip(angle_vals, z_12angles):
        print(f"  angle={angle:6.1f}: z=[{z[0]:7.4f}, {z[1]:7.4f}]")

    # =========================================================================
    # Encode original training data
    # =========================================================================
    print("\nEncoding original training data...")
    z_train = encode_training_data(model, args.dataset_path, device)

    # Load mode labels
    num_train = len(z_train)
    modes, train_angles = load_training_mode_labels(args.episodes_path, num_train)

    np.save(os.path.join(args.output_dir, 'z_train.npy'), z_train)
    np.save(os.path.join(args.output_dir, 'train_modes.npy'), modes)
    np.save(os.path.join(args.output_dir, 'train_angles.npy'), train_angles)

    # =========================================================================
    # Visualize
    # =========================================================================
    visualize_12angle_embedding(
        z_12angles, angle_vals,
        os.path.join(plots_dir, 'embedding_12angles.png')
    )

    # The training plots select rows of z_train with a boolean mask built
    # from `modes`; a length mismatch would otherwise surface as an obscure
    # indexing error deep inside the plotting code.
    if len(modes) != num_train:
        raise ValueError(
            f"Found {len(modes)} labeled episode folders in {args.episodes_path} "
            f"but {num_train} encoded training trajectories; "
            "mode labels cannot be aligned with the embeddings."
        )

    visualize_original_4mode_embedding(
        z_train, modes, train_angles,
        os.path.join(plots_dir, 'embedding_original_4modes.png')
    )

    visualize_combined_embedding(
        z_train, modes, z_12angles, angle_vals,
        os.path.join(plots_dir, 'embedding_combined.png')
    )

    print(f"\nResults saved to {args.output_dir}")
    print(f"Plots saved to {plots_dir}")
    print("\nKey output files:")
    print(f"  z_12angles.npy: ({z_12angles.shape}) - z embeddings for 12 angles")
    print(f"  angles_12.npy: ({angle_vals.shape}) - angle values in degrees")
    print(f"  z_train.npy: ({z_train.shape}) - z embeddings for training data")


if __name__ == '__main__':
    # Force the "spawn" multiprocessing start method before main() runs.
    # NOTE(review): presumably required by the simulator / CUDA stack, which
    # do not tolerate forked processes -- confirm.
    from multiprocessing import set_start_method

    set_start_method("spawn", force=True)
    main()
