"""
Training script for pick_place trajectory encoder/decoder with reconstruction, KL, and DTW losses.

Adapted from RLBench_close_drawer/encoder/train_close_drawer_encoder.py for the pick-and-place task.

Dataset structure:
- 160 trajectories: 80 REACH + 80 CARRY (alternating: reach_0, carry_0, reach_1, carry_1, ...)
- Each trajectory has 64 steps
- States: 22 dims, Actions: 8 dims
"""
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import wandb

# Set matplotlib backend to Agg (non-interactive) before importing pyplot
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

from tqdm import tqdm
import argparse

# Local imports
from trajectory_encoder import TrajectoryVAE
from soft_dtw import SoftDTW
from full_trajectory_dataset import FullTrajectoryDataset


def compute_reconstruction_loss(pred_states, pred_actions, gt_states, gt_actions):
    """
    Sum of the MSE reconstruction errors for actions and (optionally) states.

    Args:
        pred_states: (B, horizon-1, state_dim) predicted states for steps 1..horizon-1,
            or None when no state prediction is available
        pred_actions: (B, horizon, action_dim) predicted actions for every step
        gt_states: (B, horizon, state_dim) ground truth states
        gt_actions: (B, horizon, action_dim) ground truth actions

    Returns:
        Scalar tensor: action MSE + state MSE (the state term is 0 when
        pred_states is None).
    """
    # Every action in the window is reconstructed.
    action_term = F.mse_loss(pred_actions, gt_actions)

    if pred_states is None:
        # No state head output: contribute a zero on the same device as the actions.
        state_term = torch.tensor(0.0, device=pred_actions.device)
    else:
        # The first state is given as input, so targets start at step 1.
        state_term = F.mse_loss(pred_states, gt_states[:, 1:, :])

    return action_term + state_term


def compute_kl_loss(z_mean, z_logvar):
    """
    Batch-averaged KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior.

    Pulls the latent distribution toward a standard Gaussian so the codes stay
    bounded while still leaving room for distinct trajectory modes.

    Args:
        z_mean: (B, latent_dim) - posterior mean (unbounded)
        z_logvar: (B, latent_dim) - posterior log variance

    Returns:
        Scalar tensor: KL divergence summed over latent dims, averaged over the batch.
    """
    # Per-dimension term of: KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    variance = z_logvar.exp()
    per_dim = 1 + z_logvar - z_mean.pow(2) - variance
    per_sample = -0.5 * per_dim.sum(dim=1)  # (B,)
    return per_sample.mean()


def build_local_frame_batched(start_pos, end_pos):
    """
    Construct an orthonormal frame aligned with each start->end segment.

    Batched counterpart of utils.build_local_frame.

    Args:
        start_pos: (B, 3) - start positions
        end_pos: (B, 3) - end positions

    Returns:
        line_dir: (B, 3) - unit vector along the segment
        perp1: (B, 3) - world-up with its along-segment part removed, normalized
            (world +Y is used instead when the segment is nearly vertical)
        perp2: (B, 3) - normalized cross product line_dir x perp1
        dist: (B,) - segment length, clamped to at least 1e-6
    """
    eps = 1e-6
    batch = start_pos.shape[0]

    segment = end_pos - start_pos  # (B, 3)
    # Clamp the length so coincident endpoints do not divide by zero.
    dist = segment.norm(dim=1, keepdim=True).clamp(min=eps)  # (B, 1)
    line_dir = segment / dist  # (B, 3)

    def reject_from_line(axis_xyz):
        """Remove the along-segment component of a world axis; return remainder and its norm."""
        axis = torch.tensor(axis_xyz, device=start_pos.device, dtype=start_pos.dtype)
        axis = axis.unsqueeze(0).expand(batch, -1)  # (B, 3)
        along = (axis * line_dir).sum(dim=1, keepdim=True)  # (B, 1)
        remainder = axis - along * line_dir  # (B, 3)
        return remainder, remainder.norm(dim=1, keepdim=True)

    # First perpendicular: world up (+Z) projected onto the plane normal to the segment.
    perp1, perp1_len = reject_from_line([0.0, 0.0, 1.0])

    # A (nearly) vertical segment leaves almost nothing of world-up after projection.
    nearly_vertical = perp1_len.squeeze(1) < eps  # (B,)
    if not nearly_vertical.any():
        perp1 = perp1 / perp1_len
    else:
        # Fall back to world forward (+Y) for the degenerate rows only.
        fallback, fallback_len = reject_from_line([0.0, 1.0, 0.0])
        fallback = fallback / fallback_len.clamp(min=eps)
        regular = perp1 / perp1_len.clamp(min=eps)
        perp1 = torch.where(nearly_vertical.unsqueeze(1), fallback, regular)

    # Second perpendicular completes the frame.
    perp2 = torch.cross(line_dir, perp1, dim=1)  # (B, 3)
    perp2 = perp2 / perp2.norm(dim=1, keepdim=True).clamp(min=eps)

    return line_dir, perp1, perp2, dist.squeeze(1)


def normalize_ee_trajectory(ee_pos):
    """
    Express an end-effector path in coordinates relative to its own start and end.

    Output channels:
    - progress: projection onto the start->end direction divided by the path
      length (0 at the first step, 1 at the last)
    - perp1 offset: deviation from the straight line along perp1, divided by path length
    - perp2 offset: deviation from the straight line along perp2, divided by path length

    The intent is that REACH and CARRY trajectories generated with the same CP
    params map to the same normalized shape regardless of absolute position.

    Args:
        ee_pos: (B, T, 3) - end-effector positions

    Returns:
        ee_normalized: (B, T, 3) - [progress, perp1_offset, perp2_offset]
    """
    # Unpacking doubles as a check that the input is rank 3.
    _batch, _steps, _ = ee_pos.shape

    first = ee_pos[:, 0, :]  # (B, 3)
    last = ee_pos[:, -1, :]  # (B, 3)

    # Frame aligned with the straight segment from the first to the last point.
    line_dir, perp1, perp2, dist = build_local_frame_batched(first, last)

    origin = first[:, None, :]  # (B, 1, 3)
    scale = dist[:, None, None]  # (B, 1, 1)

    # Fraction of the way along the segment for every timestep.
    displacement = ee_pos - origin  # (B, T, 3)
    progress = (displacement * line_dir[:, None, :]).sum(dim=2, keepdim=True) / scale  # (B, T, 1)

    # Deviation of each point from the corresponding point on the straight segment.
    chord = (last - first)[:, None, :]  # (B, 1, 3)
    deviation = ee_pos - (origin + progress * chord)  # (B, T, 3)

    def scaled_component(axis):
        """Deviation along a frame axis, in units of path length."""
        return (deviation * axis[:, None, :]).sum(dim=2, keepdim=True) / scale  # (B, T, 1)

    return torch.cat(
        [progress, scaled_component(perp1), scaled_component(perp2)],
        dim=2,
    )  # (B, T, 3)


def make_trajectory_relative(states, actions):
    """
    Strip absolute pose information from a batch of trajectories.

    - Joint positions and joint velocities are offset by their value at step 0.
    - EE positions (state indices 15:18) are replaced by trajectory-relative
      coordinates [progress, perp1 offset, perp2 offset], normalized against
      both the start and the end of the path.
    - Action joint targets are offset by the step-0 joint positions.
    - Gripper-open flags (state index 14, action index 7) and the quaternion
      (state indices 18:22) are copied through unchanged.

    State format: joint_pos(7) + joint_vel(7) + gripper_open(1) + gripper_pose(7)
    Action format: joint_pos(7) + gripper_open(1)

    Args:
        states: (B, T, 22) - absolute states
        actions: (B, T, 8) - absolute actions

    Returns:
        states_rel: (B, T, 22) - relative states
        actions_rel: (B, T, 8) - relative actions
    """
    joint_pos = slice(0, 7)
    joint_vel = slice(7, 14)
    ee_xyz = slice(15, 18)

    # Work on copies so the caller's tensors are left untouched.
    rel_states, rel_actions = states.clone(), actions.clone()

    # Step-0 references, kept with a singleton time axis for broadcasting.
    pos_ref = states[:, 0:1, joint_pos]  # (B, 1, 7)
    vel_ref = states[:, 0:1, joint_vel]  # (B, 1, 7)

    rel_states[:, :, joint_pos] = states[:, :, joint_pos] - pos_ref
    rel_states[:, :, joint_vel] = states[:, :, joint_vel] - vel_ref

    # EE position becomes [progress, perp1, perp2] in the path-aligned frame.
    rel_states[:, :, ee_xyz] = normalize_ee_trajectory(states[:, :, ee_xyz])

    # Actions are joint targets, so they share the joint-position reference.
    rel_actions[:, :, joint_pos] = actions[:, :, joint_pos] - pos_ref

    return rel_states, rel_actions


def extract_normalized_ee_trajectory(states_rel):
    """
    Slice the normalized end-effector channels out of relative states.

    make_trajectory_relative() writes [progress, perp1_offset, perp2_offset]
    into state indices 15:18, so no further processing is needed here.

    Args:
        states_rel: (B, T, 22) - relative states (after make_trajectory_relative)

    Returns:
        (B, T, 3) - normalized trajectory [progress, perp1, perp2]
    """
    ee_channels = slice(15, 18)
    return states_rel[:, :, ee_channels]


def compute_dtw_loss(z, states_rel, gamma=0.00001):
    """
    Encourage latent distances to mirror soft-DTW distances between trajectories.

    Each sample is paired with its predecessor in the batch (a roll by one),
    giving B pairs. For every pair the latent L2 distance and the soft-DTW
    distance between the normalized EE trajectories are computed; both sets of
    distances are divided by their batch maximum and compared with MSE.

    Args:
        z: (B, latent_dim) - latent embeddings
        states_rel: (B, T, state_dim) - RELATIVE states (after make_trajectory_relative)
        gamma: smoothness parameter for soft-DTW

    Returns:
        Scalar tensor. A zero tensor (with requires_grad=True) is returned when
        the batch holds fewer than two samples, since no pair can be formed.
    """
    if z.shape[0] < 2:
        return torch.tensor(0.0, device=z.device, requires_grad=True)

    # [progress, perp1, perp2] per step, already in trajectory-relative coordinates.
    ee_traj = extract_normalized_ee_trajectory(states_rel)  # (B, T, 3)

    # Pair sample i with sample i-1 (wrapping around) by shifting along the batch axis.
    z_prev = z.roll(1, 0)
    ee_traj_prev = ee_traj.roll(1, 0)

    # Latent-space L2 distance; the epsilon keeps sqrt differentiable at zero.
    latent_dist = torch.sqrt(((z - z_prev) ** 2).sum(dim=1) + 1e-8)  # (B,)

    # Soft-DTW distance for the same pairs.
    soft_dtw = SoftDTW(gamma=gamma, normalize=True)
    traj_dist = soft_dtw(ee_traj, ee_traj_prev)  # (B,)

    # Rescale each set by its largest value so the two are comparable.
    latent_scaled = latent_dist / (latent_dist.max() + 1e-8)
    traj_scaled = traj_dist / (traj_dist.max() + 1e-8)

    return F.mse_loss(latent_scaled, traj_scaled)


def _candidate_metadata_files(metadata_path):
    """Return the places train_metadata.npy may live, most likely location first."""
    filename = 'train_metadata.npy'

    # An explicit .npy file is used as-is.
    if os.path.isfile(metadata_path) and metadata_path.endswith('.npy'):
        return [metadata_path]

    in_dir = os.path.join(metadata_path, filename)
    in_train_subdir = os.path.join(metadata_path, 'train', filename)
    # <root>/processed/<file>  ->  <root>/train/train_metadata.npy
    sibling_train = os.path.join(
        os.path.dirname(os.path.dirname(metadata_path)), 'train', filename
    )

    # The substring tests only choose which location is tried FIRST. They can be
    # fooled by unrelated path components (e.g. "trainer_runs/", "preprocessed/"),
    # so the remaining conventional locations are kept as fallbacks.
    if 'processed' in metadata_path:
        ordered = [sibling_train, in_dir, in_train_subdir]
    elif 'train' in metadata_path:
        ordered = [in_dir, in_train_subdir]
    else:
        ordered = [in_train_subdir, in_dir]

    # Drop duplicates while keeping priority order.
    return list(dict.fromkeys(ordered))


def load_control_point_metadata(metadata_path):
    """
    Load control point metadata from pick-and-place dataset.

    The pick-and-place dataset stores metadata at the train folder level,
    not per-episode. The processed dataset alternates REACH and CARRY:
    trajectory 0 = REACH from episode 0
    trajectory 1 = CARRY from episode 0
    trajectory 2 = REACH from episode 1
    ...
    so every episode record is emitted twice, once per phase.

    Accepted forms of ``metadata_path``:
    - a path to an existing ``.npy`` file (used directly)
    - a train folder containing ``train_metadata.npy``
    - a dataset root containing ``train/train_metadata.npy``
    - a path inside a ``processed`` folder whose sibling ``train`` folder holds
      the metadata

    Args:
        metadata_path: Path to train folder (containing train_metadata.npy),
            or one of the alternatives listed above

    Returns:
        dict with keys 'modes', 'cp_params', 'phase_types', 'reach_cps',
        'carry_cps' and 'num_trajectories', or None if the file is not found
        or cannot be parsed (the error is printed, not raised, because the
        metadata is only used for visualization).
    """
    try:
        candidates = _candidate_metadata_files(metadata_path)
        meta_file = next((path for path in candidates if os.path.exists(path)), None)
        if meta_file is None:
            print(f"Metadata file not found: {', '.join(candidates)}")
            return None

        # NOTE: allow_pickle=True unpickles arbitrary objects; only point this
        # at metadata files from a trusted source.
        episode_metadata = np.load(meta_file, allow_pickle=True)

        # Build per-trajectory metadata (each episode produces 2 trajectories)
        modes = []
        cp_params = []
        phase_types = []  # 'reach' or 'carry'
        reach_cps = []
        carry_cps = []

        for ep_meta in episode_metadata:
            # Same episode record for both phases; only the phase label differs.
            for phase in ('reach', 'carry'):
                modes.append(ep_meta['mode'])
                cp_params.append(ep_meta['current_cp_params'])
                phase_types.append(phase)
                reach_cps.append(ep_meta['reach_cp'])
                carry_cps.append(ep_meta['carry_cp'])

        return {
            'modes': np.array(modes),
            'cp_params': cp_params,
            'phase_types': phase_types,
            'reach_cps': np.array(reach_cps),
            'carry_cps': np.array(carry_cps),
            'num_trajectories': len(modes)
        }

    except Exception as e:
        # Best effort: training can proceed without metadata, so report and move on.
        print(f"Error loading control point metadata: {e}")
        import traceback
        traceback.print_exc()
        return None


def visualize_embeddings(z, epoch, cp_metadata=None, save_path=None):
    """
    Draw a three-panel figure of the latent embeddings.

    Panels (when metadata covering every embedding is supplied):
    1. embeddings colored by control point mode
    2. embeddings colored by phase type (REACH vs CARRY)
    3. control point parameters (angle vs distance fraction) colored by mode

    Without usable metadata the first panel is colored by trajectory index and
    the other two show a placeholder message. Embeddings with more than two
    dimensions are projected to 2D with t-SNE first.

    Args:
        z: (N, latent_dim) numpy array of embeddings
        epoch: epoch number shown in the first panel's title
        cp_metadata: dict with 'modes', 'cp_params', 'phase_types' from
            load_control_point_metadata, or None
        save_path: if truthy, the figure is also written to this path

    Returns:
        The matplotlib figure (the caller is responsible for closing it).
    """
    latent_dim = z.shape[1]
    projected = latent_dim > 2

    if projected:
        # Imported lazily so scikit-learn is only needed for >2D latents.
        from sklearn.manifold import TSNE
        print(f"Applying t-SNE to reduce {latent_dim}D embeddings to 2D...")
        reducer = TSNE(n_components=2, random_state=42, perplexity=min(30, len(z) - 1))
        z_vis = reducer.fit_transform(z)
        title_suffix = f' (t-SNE from {latent_dim}D)'
    else:
        z_vis, title_suffix = z, ''

    fig, (ax_mode, ax_phase, ax_params) = plt.subplots(1, 3, figsize=(18, 5))

    n_points = len(z_vis)
    marker_style = dict(s=50, alpha=0.7, edgecolors='black', linewidths=0.5)

    # Metadata is only usable if it has an entry for every embedded trajectory.
    have_labels = cp_metadata is not None and len(cp_metadata['modes']) >= n_points

    if not have_labels:
        # Fallback: color the points by their position in the dataset.
        by_index = ax_mode.scatter(z_vis[:, 0], z_vis[:, 1], c=np.arange(n_points),
                                   cmap='viridis', **marker_style)
        plt.colorbar(by_index, ax=ax_mode, label='Trajectory Index')
        for empty_ax in (ax_phase, ax_params):
            empty_ax.text(0.5, 0.5, 'No CP metadata available',
                          transform=empty_ax.transAxes, ha='center', va='center')
    else:
        from matplotlib.patches import Patch

        modes = cp_metadata['modes'][:n_points]
        cp_params = cp_metadata['cp_params'][:n_points]
        phase_types = cp_metadata['phase_types'][:n_points]

        # Panel 1: one tab10 color per distinct mode.
        unique_modes = np.unique(modes)
        palette = plt.cm.tab10(np.linspace(0, 1, len(unique_modes)))
        color_of_mode = dict(zip(unique_modes, palette))
        colors_by_mode = [color_of_mode[m] for m in modes]

        ax_mode.scatter(z_vis[:, 0], z_vis[:, 1], c=colors_by_mode, **marker_style)
        mode_handles = [Patch(facecolor=color_of_mode[m], label=f'Mode {m}') for m in unique_modes]
        ax_mode.legend(handles=mode_handles, loc='upper right', fontsize=8)
        ax_mode.set_title(f'Embeddings by Mode (Epoch {epoch}){title_suffix}', fontsize=11)

        # Panel 2: REACH in blue, CARRY in red.
        color_of_phase = {'reach': 'blue', 'carry': 'red'}
        colors_by_phase = [color_of_phase[p] for p in phase_types]

        ax_phase.scatter(z_vis[:, 0], z_vis[:, 1], c=colors_by_phase, **marker_style)
        phase_handles = [Patch(facecolor='blue', label='REACH'),
                         Patch(facecolor='red', label='CARRY')]
        ax_phase.legend(handles=phase_handles, loc='upper right', fontsize=10)
        ax_phase.set_title(f'Embeddings by Phase Type{title_suffix}', fontsize=11)

        # Panel 3: CP params, first entry converted to degrees, second plotted as-is.
        angles = [np.degrees(params[0]) for params in cp_params]
        dists = [params[1] for params in cp_params]

        ax_params.scatter(angles, dists, c=colors_by_mode, **marker_style)
        ax_params.set_xlabel('Angle (degrees)', fontsize=11)
        ax_params.set_ylabel('Distance fraction', fontsize=11)
        ax_params.set_title('Control Point Parameters', fontsize=11)
        ax_params.set_xlim(-20, 380)
        ax_params.set_ylim(0, 1.2)
        ax_params.grid(True, alpha=0.3)

    # Both embedding panels share axis labels: raw latent dims or t-SNE components.
    if projected:
        x_label, y_label = 'Component 1', 'Component 2'
    else:
        x_label, y_label = 'z[0]', 'z[1]'
    for embedding_ax in (ax_mode, ax_phase):
        embedding_ax.set_xlabel(x_label, fontsize=11)
        embedding_ax.set_ylabel(y_label, fontsize=11)
        embedding_ax.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')

    return fig


def train_epoch(model, dataloader, optimizer, device, epoch, log_freq=50, kl_weight=0.001, dtw_weight=0.1):
    """
    Run one optimization pass over the dataloader.

    Each batch is converted to trajectory-relative coordinates, passed through
    the model, and optimized on
    ``recon + kl_weight * kl + dtw_weight * dtw`` with gradient-norm clipping
    at 1.0. The DTW term is computed on the posterior mean (``z_mean``), not
    on the sampled ``z``.

    Args:
        model: module whose forward returns a dict with 'z_mean', 'z_logvar',
            'z', 'pred_states' and 'pred_actions'
        dataloader: yields batches exposing ``.actions`` and ``.conditions['state']``
        optimizer: optimizer over the model parameters
        device: device the batch tensors are moved to
        epoch: epoch number, used for the progress bar and logging
        log_freq: log batch metrics to wandb every ``log_freq`` batches
        kl_weight: weight of the KL term
        dtw_weight: weight of the DTW term

    Returns:
        (avg_loss, avg_recon_loss, avg_kl_loss, avg_dtw_loss) averaged over the
        batches that were actually used; all zeros if none were.
    """
    model.train()

    running = {'total': 0.0, 'recon': 0.0, 'kl': 0.0, 'dtw': 0.0}
    used_batches = 0

    progress = tqdm(dataloader, desc=f"Epoch {epoch}", leave=False)

    for batch_idx, batch in enumerate(progress):
        actions = batch.actions.to(device)  # (B, horizon, action_dim)
        states = batch.conditions['state'].to(device)  # (B, horizon, state_dim)

        # Skip batches whose state sequence is shorter than the action sequence.
        if states.shape[1] < actions.shape[1]:
            continue

        # Absolute pose information is removed before encoding.
        states_rel, actions_rel = make_trajectory_relative(states, actions)
        outputs = model(states_rel, actions_rel)

        # Reconstruction targets are the relative coordinates as well.
        recon_loss = compute_reconstruction_loss(
            outputs['pred_states'], outputs['pred_actions'], states_rel, actions_rel
        )
        kl_loss = compute_kl_loss(outputs['z_mean'], outputs['z_logvar'])
        dtw_loss = compute_dtw_loss(outputs['z_mean'], states_rel, gamma=0.00001)

        loss = recon_loss + kl_weight * kl_loss + dtw_weight * dtw_loss

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Pull the scalars to the host once and reuse them below.
        batch_values = {
            'total': loss.item(),
            'recon': recon_loss.item(),
            'kl': kl_loss.item(),
            'dtw': dtw_loss.item(),
        }
        for key, value in batch_values.items():
            running[key] += value
        used_batches += 1

        progress.set_postfix({
            'total': f"{batch_values['total']:.4f}",
            'recon': f"{batch_values['recon']:.4f}",
            'kl': f"{batch_values['kl']:.6f}",
            'dtw': f"{batch_values['dtw']:.4f}"
        })

        if batch_idx % log_freq == 0:
            wandb.log({
                'batch_total_loss': batch_values['total'],
                'batch_recon_loss': batch_values['recon'],
                'batch_kl_loss': batch_values['kl'],
                'batch_dtw_loss': batch_values['dtw'],
                'epoch': epoch,
                'batch': batch_idx
            })

    # No usable batch: report zeros rather than dividing by zero.
    if used_batches == 0:
        return 0, 0, 0, 0

    return (
        running['total'] / used_batches,
        running['recon'] / used_batches,
        running['kl'] / used_batches,
        running['dtw'] / used_batches,
    )


def collect_embeddings(model, dataset, device, batch_size=256):
    """
    Encode the whole dataset and return the embeddings as one numpy array.

    A fresh, unshuffled dataloader is built so that row i of the result
    corresponds to trajectory i of the dataset (as long as no batch is
    skipped). The model is switched to eval mode and gradients are disabled.

    Args:
        model: module exposing ``encode(states_rel, actions_rel)``
        dataset: dataset whose batches expose ``.actions`` and ``.conditions['state']``
        device: device the batch tensors are moved to
        batch_size: number of trajectories per encoding batch

    Returns:
        numpy array of the per-batch ``encode`` outputs concatenated along axis 0.
    """
    model.eval()

    ordered_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,  # keep dataset order
        num_workers=0,
        pin_memory=True
    )

    chunks = []
    with torch.no_grad():
        for batch in tqdm(ordered_loader, desc="Collecting embeddings", leave=False):
            actions = batch.actions.to(device)
            states = batch.conditions['state'].to(device)

            # Same skip rule as in training: state sequence shorter than the actions.
            if states.shape[1] < actions.shape[1]:
                continue

            states_rel, actions_rel = make_trajectory_relative(states, actions)
            latent = model.encode(states_rel, actions_rel)
            chunks.append(latent.cpu().numpy())

    return np.concatenate(chunks, axis=0)


def main():
    """
    Command-line entry point: train the trajectory VAE end to end.

    Steps: parse arguments, create output folders under
    ``<save_dir>/z<latent_dim>/``, start a wandb run, load the dataset, build
    the model / AdamW optimizer / cosine LR schedule, then train for
    ``num_epochs`` epochs while periodically plotting embeddings and saving
    checkpoints (plus a ``best_model.pt`` whenever the average loss improves).
    """
    # ---- Command-line options: (flag, type, default, help) ----
    cli_options = (
        ('--horizon', int, 64, 'Horizon length for trajectories (64 for pick_place REACH/CARRY)'),
        ('--latent_dim', int, 2, 'Latent dimension'),
        ('--hidden_dim', int, 128, 'Hidden dimension'),
        ('--num_layers', int, 4, 'Number of transformer layers'),
        ('--num_heads', int, 4, 'Number of attention heads'),
        ('--batch_size', int, 32, 'Batch size'),
        ('--num_epochs', int, 5000, 'Number of epochs'),
        ('--lr', float, 1e-4, 'Learning rate'),
        ('--kl_weight', float, 0.001, 'KL loss weight'),
        ('--dtw_weight', float, 1.0, 'DTW loss weight'),
        ('--vis_freq', int, 100, 'Visualization frequency (epochs)'),
        ('--save_freq', int, 500, 'Checkpoint save frequency (epochs)'),
        ('--dataset_path', str,
         '/scratch4/workspace/placeholder-hdp1/dppo/data/stack_blocks/variation0/processed/train_raw.npz',
         'Path to training dataset (use raw data, not normalized)'),
        ('--metadata_path', str,
         '/scratch4/workspace/placeholder-hdp1/dppo/data/stack_blocks/variation0/train',
         'Path to train folder containing train_metadata.npy'),
        ('--save_dir', str,
         '/scratch4/workspace/placeholder-hdp1/dppo/data/stack_blocks/variation0/encoder',
         'Directory to save checkpoints and visualizations'),
        ('--device', str, 'cuda:0', 'Device to use'),
        ('--wandb_project', str, 'pick-place-encoder', 'Wandb project name'),
    )
    parser = argparse.ArgumentParser(description='Train Pick-and-Place Trajectory Encoder VAE')
    for flag, value_type, default, help_text in cli_options:
        parser.add_argument(flag, type=value_type, default=default, help=help_text)
    args = parser.parse_args()

    # ---- Run configuration (also stored in wandb and in every checkpoint) ----
    config = {
        'dataset_path': args.dataset_path,
        'metadata_path': args.metadata_path,
        'state_dim': 22,
        'action_dim': 8,
        'horizon': args.horizon,
        'latent_dim': args.latent_dim,
        'hidden_dim': args.hidden_dim,
        'num_layers': args.num_layers,
        'num_heads': args.num_heads,
        'batch_size': args.batch_size,
        'num_epochs': args.num_epochs,
        'learning_rate': args.lr,
        'kl_weight': args.kl_weight,
        'dtw_weight': args.dtw_weight,
        # Fall back to CPU when CUDA is unavailable.
        'device': args.device if torch.cuda.is_available() else 'cpu',
        'log_freq': 10,
        'vis_freq': args.vis_freq,
        'save_freq': args.save_freq,
        'save_dir': args.save_dir
    }

    # ---- Output folders: <save_dir>/z<latent_dim>/{checkpoints,visualizations} ----
    os.makedirs(config['save_dir'], exist_ok=True)
    latent_dir = os.path.join(config['save_dir'], f"z{config['latent_dim']}")
    checkpoint_dir = os.path.join(latent_dir, 'checkpoints')
    visualization_dir = os.path.join(latent_dir, 'visualizations')
    for folder in (checkpoint_dir, visualization_dir):
        os.makedirs(folder, exist_ok=True)

    # ---- Experiment tracking ----
    wandb.init(
        project=args.wandb_project,
        name=f"pick_place_vae_z{config['latent_dim']}_h{config['hidden_dim']}",
        config=config
    )

    device = torch.device(config['device'])
    print(f"Using device: {device}")

    # ---- Data ----
    print(f"Loading dataset from {config['dataset_path']}")
    dataset = FullTrajectoryDataset(
        dataset_path=config['dataset_path'],
        horizon_steps=config['horizon'],
        device='cpu',
        max_n_episodes=None
    )
    print(f"Dataset size: {len(dataset)} trajectories")

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=0,
        pin_memory=True
    )

    # ---- Model, optimizer, schedule ----
    model = TrajectoryVAE(
        state_dim=config['state_dim'],
        action_dim=config['action_dim'],
        hidden_dim=config['hidden_dim'],
        num_layers=config['num_layers'],
        num_heads=config['num_heads'],
        latent_dim=config['latent_dim'],
        horizon=config['horizon']
    ).to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=1e-5
    )

    # Cosine decay over the full run, stepped once per epoch.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=config['num_epochs'],
        eta_min=1e-6
    )

    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # ---- Optional metadata used only to color the embedding plots ----
    cp_metadata = load_control_point_metadata(config['metadata_path'])
    if cp_metadata is None:
        print("Control point metadata not available - using default visualization")
    else:
        print(f"Loaded control point metadata for {cp_metadata['num_trajectories']} trajectories")

    def checkpoint_payload(at_epoch, loss_value):
        """Everything written to a checkpoint file."""
        return {
            'epoch': at_epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss_value,
            'config': config
        }

    def render_embeddings(at_epoch, filename):
        """Encode the dataset and draw (and save) the embedding figure."""
        latents = collect_embeddings(model, dataset, device, batch_size=config['batch_size'])
        return visualize_embeddings(
            latents,
            at_epoch,
            cp_metadata=cp_metadata,
            save_path=os.path.join(visualization_dir, filename)
        )

    # ---- Training loop ----
    best_loss = float('inf')
    epoch_pbar = tqdm(range(1, config['num_epochs'] + 1), desc="Training", position=0)

    for epoch in epoch_pbar:
        avg_loss, avg_recon_loss, avg_kl_loss, avg_dtw_loss = train_epoch(
            model, dataloader, optimizer, device, epoch,
            log_freq=config['log_freq'],
            kl_weight=config['kl_weight'],
            dtw_weight=config['dtw_weight']
        )

        wandb.log({
            'epoch': epoch,
            'avg_total_loss': avg_loss,
            'avg_recon_loss': avg_recon_loss,
            'avg_kl_loss': avg_kl_loss,
            'avg_dtw_loss': avg_dtw_loss,
            'learning_rate': optimizer.param_groups[0]['lr']
        })

        epoch_pbar.set_postfix({
            'total': f'{avg_loss:.4f}',
            'recon': f'{avg_recon_loss:.4f}',
            'kl': f'{avg_kl_loss:.6f}',
            'dtw': f'{avg_dtw_loss:.4f}'
        })

        # Plot on every vis_freq-th epoch and on the very first one.
        if epoch % config['vis_freq'] == 0 or epoch == 1:
            fig = render_embeddings(epoch, f'embeddings_epoch_{epoch}.png')
            wandb.log({
                'embeddings': wandb.Image(fig),
                'epoch': epoch
            })
            plt.close(fig)

        # Keep the checkpoint with the lowest average training loss so far.
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(
                checkpoint_payload(epoch, avg_loss),
                os.path.join(checkpoint_dir, 'best_model.pt')
            )

        # Periodic snapshot.
        if epoch % config['save_freq'] == 0:
            torch.save(
                checkpoint_payload(epoch, avg_loss),
                os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pt')
            )

        scheduler.step()

    epoch_pbar.close()

    # ---- Final plot ----
    print("\nGenerating final embedding visualization...")
    fig = render_embeddings(config['num_epochs'], 'embeddings_final.png')
    wandb.log({'final_embeddings': wandb.Image(fig)})
    plt.close(fig)

    wandb.finish()
    print("Training complete!")


# Start training only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
