"""
Training script for pick_place encoder with normalized EE input.

KEY DESIGN - solving mode collapse while ensuring REACH/CARRY same-CP overlap:

1. Encoder input: Trajectory-frame normalized EE (3D: progress, perp1, perp2)
   - This is IDENTICAL for REACH and CARRY with same CP params
   - So encoder will produce IDENTICAL z for same-CP trajectories

2. Decoder output: Full state (22D) + action (8D) trajectory
   - This is MUCH HARDER than reconstructing 3D curve
   - Forces z to be discriminative (can't get away with collapsed z)

3. Decoder context: s0 + trajectory_length + time
   - s0: WHERE the trajectory starts
   - trajectory_length: HOW FAR it goes
   - time: WHEN in the trajectory
   - z must encode HOW it curves (the mode/CP)

Losses:
- Reconstruction: MSE on full state+action trajectory
- KL: Regularizes q(z|x) toward the N(0, I) prior (weighted lightly, default 0.001)
- DTW: Forces z-space distances to match trajectory shape distances
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import wandb

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

from tqdm import tqdm
import argparse

from trajectory_encoder_normalized_ee import TrajectoryVAENormalizedEE
from soft_dtw import SoftDTW
from full_trajectory_dataset import FullTrajectoryDataset


def build_local_frame_batched(start_pos, end_pos):
    """
    Construct an orthonormal frame aligned with each start->end chord.

    The first axis points from ``start_pos`` to ``end_pos``. The second axis is
    world-up (+Z) with its component along the chord removed; for rows whose
    chord is (nearly) parallel to world-up, world-forward (+Y) is used instead.
    The third axis is ``cross(first, second)``.

    Args:
        start_pos: (B, 3) tensor of chord start points.
        end_pos: (B, 3) tensor of chord end points.

    Returns:
        Tuple ``(line_dir, perp1, perp2, dist)`` where the three axes are
        (B, 3) tensors and ``dist`` is the (B,) chord length, clamped to a
        minimum of 1e-6 so later divisions stay finite.
    """
    eps = 1e-6
    batch = start_pos.shape[0]

    def world_axis(components):
        # Constant world axis replicated across the batch.
        axis = torch.tensor(components, device=start_pos.device, dtype=start_pos.dtype)
        return axis.unsqueeze(0).expand(batch, -1)

    chord = end_pos - start_pos
    dist = torch.clamp(torch.norm(chord, dim=1, keepdim=True), min=eps)
    line_dir = chord / dist

    def reject_from_chord(axis):
        # Remove the component of `axis` that lies along the chord direction.
        along = torch.sum(axis * line_dir, dim=1, keepdim=True)
        return axis - along * line_dir

    up_component = reject_from_chord(world_axis([0.0, 0.0, 1.0]))
    up_norm = torch.norm(up_component, dim=1, keepdim=True)
    chord_is_vertical = up_norm.squeeze(1) < eps

    if not chord_is_vertical.any():
        perp1 = up_component / up_norm
    else:
        # At least one chord is parallel to world-up: fall back to world-forward
        # for those rows only, keeping the world-up based axis everywhere else.
        fwd_component = reject_from_chord(world_axis([0.0, 1.0, 0.0]))
        fwd_norm = torch.norm(fwd_component, dim=1, keepdim=True)
        fallback = fwd_component / torch.clamp(fwd_norm, min=eps)
        primary = up_component / torch.clamp(up_norm, min=eps)
        perp1 = torch.where(chord_is_vertical.unsqueeze(1), fallback, primary)

    perp2 = torch.cross(line_dir, perp1, dim=1)
    perp2 = perp2 / torch.norm(perp2, dim=1, keepdim=True).clamp(min=eps)

    return line_dir, perp1, perp2, dist.squeeze(1)


def normalize_ee_trajectory(ee_pos):
    """
    Express an end-effector path in the frame of its own start->end chord.

    Each point is decomposed into:
    - progress: projection onto the chord direction, divided by chord length
    - perp1_offset: offset from the chord along the first perpendicular axis,
      divided by chord length
    - perp2_offset: offset from the chord along the second perpendicular axis,
      divided by chord length

    The chord is the straight segment between the first and last samples, and
    the frame comes from ``build_local_frame_batched``.

    Args:
        ee_pos: 3-D tensor indexed as (batch, time, xyz).

    Returns:
        ``(ee_normalized, dist)``: the per-point ``[progress, perp1, perp2]``
        features concatenated on the last axis, and the per-trajectory chord
        length as returned by ``build_local_frame_batched``.
    """
    # Unpacking doubles as a rank check: anything but a 3-D tensor fails here.
    _num_traj, _num_steps, _num_coords = ee_pos.shape

    first_point = ee_pos[:, 0, :]
    last_point = ee_pos[:, -1, :]
    line_dir, perp1, perp2, dist = build_local_frame_batched(first_point, last_point)

    origin = first_point.unsqueeze(1)
    scale = dist.unsqueeze(1).unsqueeze(2)

    def scaled_projection(vectors, axis):
        # Dot each time step with a per-trajectory axis, then scale by chord length.
        return torch.sum(vectors * axis.unsqueeze(1), dim=2, keepdim=True) / scale

    progress = scaled_projection(ee_pos - origin, line_dir)

    # Foot of each point on the chord, then the residual perpendicular to it.
    chord = (last_point - first_point).unsqueeze(1)
    offset = ee_pos - (origin + progress * chord)

    components = [
        progress,
        scaled_projection(offset, perp1),
        scaled_projection(offset, perp2),
    ]
    return torch.cat(components, dim=2), dist


def extract_and_normalize_ee(states):
    """
    Slice the EE position out of full states and normalize it.

    Args:
        states: (B, T, 22) full state; columns 15:18 are read as the EE xyz.

    Returns:
        ee_normalized: (B, T, 3) trajectory-frame features [progress, perp1, perp2]
        trajectory_length: (B, 1) start-to-end distance reported by
            ``normalize_ee_trajectory``
    """
    ee_xyz = states[:, :, 15:18]
    normalized, chord_length = normalize_ee_trajectory(ee_xyz)
    # Add a trailing feature axis to the per-trajectory length.
    return normalized, chord_length.unsqueeze(1)


def make_trajectory_relative(states, actions):
    """
    Re-express a batch of trajectories relative to their first state (s0).

    State layout:  joint_pos(7) + joint_vel(7) + gripper_open(1) + gripper_pose(7)
    Action layout: joint_pos(7) + gripper_open(1)

    Joint positions, joint velocities and the gripper xyz (columns 15:18) have
    the corresponding s0 values subtracted. Action joint positions have the s0
    joint positions subtracted. All other columns are copied unchanged. The
    inputs are not modified.

    Returns:
        ``(states_rel, actions_rel)`` with the same shapes as the inputs.
    """
    states_rel = states.clone()
    actions_rel = actions.clone()

    # Keep the time axis (length 1) so the reference broadcasts over all steps.
    first_state = states[:, 0:1, :]

    # joint_pos, joint_vel, gripper xyz
    for cols in (slice(0, 7), slice(7, 14), slice(15, 18)):
        states_rel[:, :, cols] = states[:, :, cols] - first_state[:, :, cols]

    # Action joint targets share the s0 joint-position reference.
    actions_rel[:, :, 0:7] = actions[:, :, 0:7] - first_state[:, :, 0:7]

    return states_rel, actions_rel


def compute_reconstruction_loss(pred_states, pred_actions, gt_states_rel, gt_actions_rel):
    """
    Sum of action MSE and (optional) state MSE for a decoded trajectory.

    pred_states: (B, T-1, state_dim) predicted s1...s_{T-1}, or None to skip
        the state term
    pred_actions: (B, T, action_dim) predicted a0...a_{T-1}
    gt_states_rel: (B, T, state_dim) s0-relative ground-truth states; the first
        time step is dropped before comparison
    gt_actions_rel: (B, T, action_dim) s0-relative ground-truth actions
    """
    action_term = F.mse_loss(pred_actions, gt_actions_rel)

    if pred_states is None:
        # No state prediction: contribute an explicit zero on the same device.
        state_term = torch.tensor(0.0, device=pred_actions.device)
    else:
        state_term = F.mse_loss(pred_states, gt_states_rel[:, 1:, :])

    return action_term + state_term


def compute_kl_loss(z_mean, z_logvar):
    """Batch-mean KL(q(z|x) || N(0, I)) for a diagonal-Gaussian posterior."""
    per_dim = 1 + z_logvar - z_mean.pow(2) - z_logvar.exp()
    # Sum over latent dimensions first, then average across the batch.
    per_sample = -0.5 * torch.sum(per_dim, dim=1)
    return per_sample.mean()


def compute_dtw_loss(z, ee_normalized, gamma=0.00001):
    """
    Align latent distances with soft-DTW distances between normalized EE paths.

    Each sample is paired with its cyclic predecessor in the batch (a roll by
    one along dim 0). Both distance vectors are divided by their own batch
    maximum before the MSE is taken, so only relative distances are compared.

    Returns a zero scalar (with ``requires_grad=True``) when the batch holds
    fewer than two samples, since no pair can be formed.
    """
    if z.shape[0] < 2:
        return torch.tensor(0.0, device=z.device, requires_grad=True)

    z_shifted = torch.roll(z, 1, 0)
    ee_shifted = torch.roll(ee_normalized, 1, 0)

    # Euclidean distance in latent space; the epsilon keeps sqrt differentiable at 0.
    latent_gap = torch.sqrt(torch.sum((z - z_shifted) ** 2, dim=1) + 1e-8)

    shape_gap = SoftDTW(gamma=gamma, normalize=True)(ee_normalized, ee_shifted)

    latent_gap_unit = latent_gap / (torch.max(latent_gap) + 1e-8)
    shape_gap_unit = shape_gap / (torch.max(shape_gap) + 1e-8)

    return F.mse_loss(latent_gap_unit, shape_gap_unit)


def _metadata_file_candidates(metadata_path):
    """Return candidate metadata files in priority order (legacy guess first)."""
    if os.path.isfile(metadata_path) and metadata_path.endswith('.npy'):
        return [metadata_path]

    train_file = os.path.join('train', 'train_metadata.npy')
    parent = os.path.dirname(metadata_path)
    grandparent = os.path.dirname(parent)

    # Legacy heuristic, kept first so every path that used to resolve still
    # resolves to the same file.
    if 'processed' in metadata_path:
        primary = os.path.join(grandparent, train_file)
    elif 'train' in metadata_path:
        primary = os.path.join(metadata_path, 'train_metadata.npy')
    else:
        primary = os.path.join(metadata_path, train_file)

    # The heuristic above is a substring test, so an unrelated path component
    # (e.g. ".../trainer/...") can send it to the wrong place. Fall back to the
    # other locations that follow the same "<root>/train/train_metadata.npy"
    # layout before giving up.
    fallbacks = [
        os.path.join(metadata_path, 'train_metadata.npy'),
        os.path.join(metadata_path, train_file),
        os.path.join(parent, train_file),
        os.path.join(grandparent, train_file),
    ]

    candidates = [primary]
    for candidate in fallbacks:
        if candidate not in candidates:
            candidates.append(candidate)
    return candidates


def load_control_point_metadata(metadata_path):
    """
    Load control-point (CP) metadata used to colour the embedding plots.

    Args:
        metadata_path: Either the ``.npy`` metadata file itself, or a path from
            which its location is derived (a dataset root, its ``train``
            directory, or a file under ``processed``).

    Returns:
        A dict with keys ``modes`` (np.ndarray), ``cp_params`` (list),
        ``phase_types`` (list of 'reach'/'carry') and ``num_trajectories``.
        Every episode contributes two consecutive entries, REACH then CARRY,
        that share the episode's mode and CP params. Returns ``None`` if the
        file cannot be found or read; loading is best-effort because the
        metadata only feeds visualization.
    """
    try:
        candidates = _metadata_file_candidates(metadata_path)
        meta_file = next((c for c in candidates if os.path.exists(c)), None)

        if meta_file is None:
            print(f"Metadata file not found: {candidates[0]}")
            return None

        # NOTE(security): allow_pickle=True executes arbitrary pickled code.
        # Only point this at metadata files you produced yourself.
        episode_metadata = np.load(meta_file, allow_pickle=True)

        modes = []
        cp_params = []
        phase_types = []

        for ep_meta in episode_metadata:
            # One REACH and one CARRY trajectory per episode, same CP settings.
            for phase in ('reach', 'carry'):
                modes.append(ep_meta['mode'])
                cp_params.append(ep_meta['current_cp_params'])
                phase_types.append(phase)

        return {
            'modes': np.array(modes),
            'cp_params': cp_params,
            'phase_types': phase_types,
            'num_trajectories': len(modes)
        }

    except Exception as e:
        # Deliberately broad: a broken metadata file must not abort training.
        print(f"Error loading control point metadata: {e}")
        return None


def visualize_embeddings(z, epoch, cp_metadata=None, save_path=None):
    """
    Draw a three-panel overview of the latent embeddings.

    Panels: embeddings coloured by mode, embeddings coloured by phase
    (reach/carry), and the control-point parameters coloured by mode. Latents
    with more than two dimensions are first projected to 2-D with t-SNE.
    When ``cp_metadata`` is missing or has fewer entries than ``z``, only the
    first panel is filled, coloured by trajectory index.

    Args:
        z: (N, latent_dim) array of embeddings.
        epoch: epoch number shown in the first panel's title.
        cp_metadata: dict as produced by ``load_control_point_metadata``, or None.
        save_path: if truthy, the figure is also written to this path.

    Returns:
        The matplotlib figure; the caller is responsible for closing it.
    """
    latent_dim = z.shape[1]
    if latent_dim > 2:
        from sklearn.manifold import TSNE
        projector = TSNE(n_components=2, random_state=42, perplexity=min(30, len(z) - 1))
        z_vis = projector.fit_transform(z)
        title_suffix = f' (t-SNE from {latent_dim}D)'
    else:
        z_vis, title_suffix = z, ''

    fig, (ax_mode, ax_phase, ax_params) = plt.subplots(1, 3, figsize=(18, 5))

    num_points = len(z_vis)
    xs, ys = z_vis[:, 0], z_vis[:, 1]
    have_labels = cp_metadata is not None and len(cp_metadata['modes']) >= num_points

    if not have_labels:
        # Without metadata, fall back to colouring by position in the dataset.
        index_scatter = ax_mode.scatter(xs, ys, c=np.arange(num_points),
                                        cmap='viridis', s=50, alpha=0.7)
        plt.colorbar(index_scatter, ax=ax_mode, label='Trajectory Index')
    else:
        from matplotlib.patches import Patch

        outlined = dict(s=50, alpha=0.7, edgecolors='black', linewidths=0.5)

        # Metadata may cover more trajectories than were embedded; keep the prefix.
        modes = cp_metadata['modes'][:num_points]
        cp_params = cp_metadata['cp_params'][:num_points]
        phase_types = cp_metadata['phase_types'][:num_points]

        unique_modes = np.unique(modes)
        palette = plt.cm.tab10(np.linspace(0, 1, len(unique_modes)))
        mode_color_map = dict(zip(unique_modes, palette))
        mode_colors = [mode_color_map[m] for m in modes]

        # Panel 1: coloured by mode.
        ax_mode.scatter(xs, ys, c=mode_colors, **outlined)
        mode_handles = [Patch(facecolor=mode_color_map[m], label=f'Mode {m}') for m in unique_modes]
        ax_mode.legend(handles=mode_handles, loc='upper right', fontsize=8)
        ax_mode.set_title(f'Embeddings by Mode (Epoch {epoch}){title_suffix}', fontsize=11)

        # Panel 2: coloured by phase - REACH and CARRY should overlap!
        phase_color_map = {'reach': 'blue', 'carry': 'red'}
        ax_phase.scatter(xs, ys, c=[phase_color_map[p] for p in phase_types], **outlined)
        phase_handles = [Patch(facecolor='blue', label='REACH'),
                         Patch(facecolor='red', label='CARRY')]
        ax_phase.legend(handles=phase_handles, loc='upper right', fontsize=10)
        ax_phase.set_title(f'Embeddings by Phase Type{title_suffix}\n(REACH & CARRY should overlap!)', fontsize=11)

        # Panel 3: the CP parameters themselves (angle converted to degrees).
        ax_params.scatter([np.degrees(p[0]) for p in cp_params],
                          [p[1] for p in cp_params],
                          c=mode_colors, **outlined)
        ax_params.set_xlabel('Angle (degrees)', fontsize=11)
        ax_params.set_ylabel('Distance fraction', fontsize=11)
        ax_params.set_title('Control Point Parameters', fontsize=11)
        ax_params.set_xlim(-20, 380)
        ax_params.set_ylim(0, 1.2)
        ax_params.grid(True, alpha=0.3)

    for latent_axis in (ax_mode, ax_phase):
        latent_axis.set_xlabel('z[0]', fontsize=11)
        latent_axis.set_ylabel('z[1]', fontsize=11)
        latent_axis.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')

    return fig


def visualize_embeddings_by_phase(z, epoch, cp_metadata=None, save_dir=None):
    """
    Plot REACH and CARRY embeddings in two separate figures.

    Each figure has two panels: the first two latent coordinates coloured by
    mode, and the matching control-point parameters. When ``save_dir`` is
    given the figures are written as:
    - embeddings_reach_epoch_X.png
    - embeddings_carry_epoch_X.png

    Returns:
        ``(reach_fig, carry_fig)``, or ``(None, None)`` when ``cp_metadata`` is
        missing or has fewer entries than ``z``.
    """
    num_points = len(z)
    if cp_metadata is None or len(cp_metadata['modes']) < num_points:
        print("Warning: cp_metadata not available for phase-separated visualization")
        return None, None

    from matplotlib.patches import Patch

    modes = cp_metadata['modes'][:num_points]
    cp_params = cp_metadata['cp_params'][:num_points]
    phase_types = np.array(cp_metadata['phase_types'][:num_points])

    # A single palette over all modes keeps colours consistent across both figures.
    unique_modes = np.unique(modes)
    palette = plt.cm.tab10(np.linspace(0, 1, len(unique_modes)))
    mode_color_map = dict(zip(unique_modes, palette))
    outlined = dict(s=50, alpha=0.7, edgecolors='black', linewidths=0.5)

    def draw_phase(phase_name):
        # Select the trajectories that belong to this phase.
        mask = phase_types == phase_name
        z_phase = z[mask]
        point_colors = [mode_color_map[m] for m in modes[mask]]
        params_phase = [p for i, p in enumerate(cp_params) if mask[i]]

        fig, (ax_latent, ax_cp) = plt.subplots(1, 2, figsize=(14, 5))

        # Left: embeddings coloured by mode.
        ax_latent.scatter(z_phase[:, 0], z_phase[:, 1], c=point_colors, **outlined)
        handles = [Patch(facecolor=mode_color_map[m], label=f'Mode {m}') for m in unique_modes]
        ax_latent.legend(handles=handles, loc='upper right', fontsize=8)
        ax_latent.set_title(f'{phase_name.upper()} Embeddings by Mode (Epoch {epoch})', fontsize=12)
        ax_latent.set_xlabel('z[0]', fontsize=11)
        ax_latent.set_ylabel('z[1]', fontsize=11)
        ax_latent.grid(True, alpha=0.3)

        # Right: CP parameters (angle converted to degrees).
        ax_cp.scatter([np.degrees(p[0]) for p in params_phase],
                      [p[1] for p in params_phase],
                      c=point_colors, **outlined)
        ax_cp.set_xlabel('Angle (degrees)', fontsize=11)
        ax_cp.set_ylabel('Distance fraction', fontsize=11)
        ax_cp.set_title(f'{phase_name.upper()} Control Point Parameters', fontsize=12)
        ax_cp.set_xlim(-20, 380)
        ax_cp.set_ylim(0, 1.2)
        ax_cp.grid(True, alpha=0.3)

        plt.tight_layout()

        if save_dir:
            plt.savefig(os.path.join(save_dir, f'embeddings_{phase_name}_epoch_{epoch}.png'),
                        dpi=150, bbox_inches='tight')

        return fig

    reach_fig = draw_phase('reach')
    carry_fig = draw_phase('carry')
    return reach_fig, carry_fig


def train_epoch(model, dataloader, optimizer, device, epoch, log_freq=50, kl_weight=0.001, dtw_weight=0.1):
    """
    Run one optimisation pass over ``dataloader``.

    For every batch the normalized EE path is fed to the model together with
    the absolute first state, the trajectory length and the s0-relative
    targets. The loss is ``recon + kl_weight * kl + dtw_weight * dtw``.
    Gradients are clipped to a norm of 1.0. Batch losses are sent to wandb
    every ``log_freq`` batches.

    Returns:
        ``(avg_total, avg_recon, avg_kl, avg_dtw)`` averaged over batches
        (all zeros when the loader yields nothing).
    """
    model.train()

    loss_names = ('total', 'recon', 'kl', 'dtw')
    running = dict.fromkeys(loss_names, 0.0)
    num_batches = 0

    pbar = tqdm(dataloader, desc=f"Epoch {epoch}", leave=False)

    for batch_idx, batch in enumerate(pbar):
        states = batch.conditions['state'].to(device)  # (B, T, 22)
        actions = batch.actions.to(device)  # (B, T, 8)

        # Decoder context uses the absolute first state.
        s0 = states[:, 0, :]  # (B, 22)

        # Encoder input: trajectory-frame EE path.
        ee_normalized, trajectory_length = extract_and_normalize_ee(states)

        # Reconstruction targets are expressed relative to s0.
        states_rel, actions_rel = make_trajectory_relative(states, actions)

        outputs = model(ee_normalized, s0, trajectory_length, states_rel, actions_rel)

        recon_loss = compute_reconstruction_loss(
            outputs['pred_states'], outputs['pred_actions'], states_rel, actions_rel
        )
        kl_loss = compute_kl_loss(outputs['z_mean'], outputs['z_logvar'])
        # The DTW term uses the posterior mean rather than a sampled z.
        dtw_loss = compute_dtw_loss(outputs['z_mean'], ee_normalized, gamma=0.00001)

        loss = recon_loss + kl_weight * kl_loss + dtw_weight * dtw_loss

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        current = {
            'total': loss.item(),
            'recon': recon_loss.item(),
            'kl': kl_loss.item(),
            'dtw': dtw_loss.item(),
        }
        for name in loss_names:
            running[name] += current[name]
        num_batches += 1

        pbar.set_postfix({
            'total': f"{current['total']:.4f}",
            'recon': f"{current['recon']:.4f}",
            'kl': f"{current['kl']:.6f}",
            'dtw': f"{current['dtw']:.4f}"
        })

        if batch_idx % log_freq == 0:
            wandb.log({
                'batch_total_loss': current['total'],
                'batch_recon_loss': current['recon'],
                'batch_kl_loss': current['kl'],
                'batch_dtw_loss': current['dtw'],
                'epoch': epoch,
                'batch': batch_idx
            })

    if num_batches == 0:
        return 0, 0, 0, 0

    return tuple(running[name] / num_batches for name in loss_names)


def collect_embeddings(model, dataset, device, batch_size=256):
    """
    Encode every trajectory in ``dataset`` and return the stacked latents.

    The model is switched to eval mode and the dataset is read in order (no
    shuffling), so row i of the result corresponds to dataset item i.

    Returns:
        A NumPy array of the per-batch ``model.encode`` outputs concatenated
        along axis 0.
    """
    model.eval()

    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True
    )

    chunks = []
    with torch.no_grad():
        for batch in tqdm(loader, desc="Collecting embeddings", leave=False):
            batch_states = batch.conditions['state'].to(device)
            # Same EE normalization as in training; the length is not needed here.
            ee_normalized, _ = extract_and_normalize_ee(batch_states)
            chunks.append(model.encode(ee_normalized).cpu().numpy())

    return np.concatenate(chunks, axis=0)


def main():
    """
    Command-line entry point: train the normalized-EE trajectory VAE.

    Parses CLI arguments, builds the dataset/model/optimizer, then runs the
    epoch loop. Along the way it logs losses to wandb, periodically renders
    embedding visualizations, and writes checkpoints under
    ``<save_dir>/z<latent_dim>_normalized_ee/``.
    """
    parser = argparse.ArgumentParser(description='Train Pick-and-Place Encoder with Normalized EE Input')
    parser.add_argument('--horizon', type=int, default=64)
    parser.add_argument('--latent_dim', type=int, default=2)
    parser.add_argument('--hidden_dim', type=int, default=128)
    parser.add_argument('--num_layers', type=int, default=4)
    parser.add_argument('--num_heads', type=int, default=4)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--num_epochs', type=int, default=5000)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--kl_weight', type=float, default=0.001)
    parser.add_argument('--dtw_weight', type=float, default=1.0)
    parser.add_argument('--vis_freq', type=int, default=100)
    parser.add_argument('--save_freq', type=int, default=500)
    parser.add_argument('--dataset_path', type=str,
                        default='/scratch4/workspace/placeholder-hdp1/dppo/data/stack_blocks/variation0/processed/train_raw.npz')
    parser.add_argument('--metadata_path', type=str,
                        default='/scratch4/workspace/placeholder-hdp1/dppo/data/stack_blocks/variation0/train')
    parser.add_argument('--save_dir', type=str,
                        default='/scratch4/workspace/placeholder-hdp1/dppo/data/stack_blocks/variation0/encoder')
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--wandb_project', type=str, default='pick-place-normalized-ee-encoder')

    args = parser.parse_args()

    # Single source of truth for the run; also sent to wandb and stored in checkpoints.
    # state_dim/action_dim/log_freq are fixed here rather than exposed on the CLI.
    config = {
        'dataset_path': args.dataset_path,
        'metadata_path': args.metadata_path,
        'state_dim': 22,
        'action_dim': 8,
        'horizon': args.horizon,
        'latent_dim': args.latent_dim,
        'hidden_dim': args.hidden_dim,
        'num_layers': args.num_layers,
        'num_heads': args.num_heads,
        'batch_size': args.batch_size,
        'num_epochs': args.num_epochs,
        'learning_rate': args.lr,
        'kl_weight': args.kl_weight,
        'dtw_weight': args.dtw_weight,
        # Falls back to CPU whenever CUDA is unavailable, regardless of --device.
        'device': args.device if torch.cuda.is_available() else 'cpu',
        'log_freq': 10,
        'vis_freq': args.vis_freq,
        'save_freq': args.save_freq,
        'save_dir': args.save_dir
    }

    # Output layout: <save_dir>/z<latent_dim>_normalized_ee/{checkpoints,visualizations}
    os.makedirs(config['save_dir'], exist_ok=True)
    latent_dir = os.path.join(config['save_dir'], f'z{config["latent_dim"]}_normalized_ee')
    checkpoint_dir = os.path.join(latent_dir, 'checkpoints')
    visualization_dir = os.path.join(latent_dir, 'visualizations')
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(visualization_dir, exist_ok=True)

    wandb.init(
        project=args.wandb_project,
        name=f'norm_ee_vae_z{config["latent_dim"]}_h{config["hidden_dim"]}',
        config=config
    )

    device = torch.device(config['device'])
    print(f"Using device: {device}")

    # The dataset is constructed with device='cpu'; batches are moved to
    # `device` inside train_epoch / collect_embeddings.
    print(f"Loading dataset from {config['dataset_path']}")
    dataset = FullTrajectoryDataset(
        dataset_path=config['dataset_path'],
        horizon_steps=config['horizon'],
        device='cpu',
        max_n_episodes=None
    )
    print(f"Dataset size: {len(dataset)} trajectories")

    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=config['batch_size'], shuffle=True, num_workers=0, pin_memory=True
    )

    # Use new VAE with normalized EE input
    model = TrajectoryVAENormalizedEE(
        state_dim=config['state_dim'],
        action_dim=config['action_dim'],
        hidden_dim=config['hidden_dim'],
        num_layers=config['num_layers'],
        num_heads=config['num_heads'],
        latent_dim=config['latent_dim'],
        horizon=config['horizon']
    ).to(device)

    # Cosine schedule spans the whole run; it is stepped once per epoch below.
    optimizer = torch.optim.AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['num_epochs'], eta_min=1e-6)

    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Metadata is optional: when it is None the plots fall back to index colouring
    # and the per-phase figures are skipped.
    cp_metadata = load_control_point_metadata(config['metadata_path'])
    if cp_metadata is not None:
        print(f"Loaded CP metadata for {cp_metadata['num_trajectories']} trajectories")

    best_loss = float('inf')
    epoch_pbar = tqdm(range(1, config['num_epochs'] + 1), desc="Training", position=0)

    for epoch in epoch_pbar:
        avg_loss, avg_recon_loss, avg_kl_loss, avg_dtw_loss = train_epoch(
            model, dataloader, optimizer, device, epoch,
            log_freq=config['log_freq'],
            kl_weight=config['kl_weight'],
            dtw_weight=config['dtw_weight']
        )

        wandb.log({
            'epoch': epoch,
            'avg_total_loss': avg_loss,
            'avg_recon_loss': avg_recon_loss,
            'avg_kl_loss': avg_kl_loss,
            'avg_dtw_loss': avg_dtw_loss,
            'learning_rate': optimizer.param_groups[0]['lr']
        })

        epoch_pbar.set_postfix({
            'total': f'{avg_loss:.4f}',
            'recon': f'{avg_recon_loss:.4f}',
            'kl': f'{avg_kl_loss:.6f}',
            'dtw': f'{avg_dtw_loss:.4f}'
        })

        # Visualize on the first epoch and then every vis_freq epochs.
        if epoch % config['vis_freq'] == 0 or epoch == 1:
            all_z = collect_embeddings(model, dataset, device, batch_size=config['batch_size'])

            # Print z statistics
            print(f"\n  z stats: mean={all_z.mean(0)}, std={all_z.std(0)}, min={all_z.min(0)}, max={all_z.max(0)}")

            fig = visualize_embeddings(
                all_z, epoch, cp_metadata=cp_metadata,
                save_path=os.path.join(visualization_dir, f'embeddings_epoch_{epoch}.png')
            )
            wandb.log({'embeddings': wandb.Image(fig), 'epoch': epoch})
            # Figures are closed explicitly so they do not accumulate over the run.
            plt.close(fig)

            # Also generate separate REACH and CARRY visualizations
            reach_fig, carry_fig = visualize_embeddings_by_phase(
                all_z, epoch, cp_metadata=cp_metadata, save_dir=visualization_dir
            )
            if reach_fig is not None:
                wandb.log({'embeddings_reach': wandb.Image(reach_fig), 'epoch': epoch})
                plt.close(reach_fig)
            if carry_fig is not None:
                wandb.log({'embeddings_carry': wandb.Image(carry_fig), 'epoch': epoch})
                plt.close(carry_fig)

        # "Best" is judged on the average training loss of the epoch (no validation split here).
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss,
                'config': config
            }, os.path.join(checkpoint_dir, 'best_model.pt'))

        # Periodic snapshot, independent of whether this epoch was the best so far.
        if epoch % config['save_freq'] == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss,
                'config': config
            }, os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pt'))

        scheduler.step()

    epoch_pbar.close()

    # Final pass: same visualizations as during training, logged under "final_*" keys.
    print("\nGenerating final embedding visualization...")
    all_z = collect_embeddings(model, dataset, device, batch_size=config['batch_size'])
    print(f"Final z stats: mean={all_z.mean(0)}, std={all_z.std(0)}, min={all_z.min(0)}, max={all_z.max(0)}")

    fig = visualize_embeddings(
        all_z, config['num_epochs'], cp_metadata=cp_metadata,
        save_path=os.path.join(visualization_dir, 'embeddings_final.png')
    )
    wandb.log({'final_embeddings': wandb.Image(fig)})
    plt.close(fig)

    # Generate separate REACH and CARRY final visualizations
    reach_fig, carry_fig = visualize_embeddings_by_phase(
        all_z, config['num_epochs'], cp_metadata=cp_metadata, save_dir=visualization_dir
    )
    if reach_fig is not None:
        wandb.log({'final_embeddings_reach': wandb.Image(reach_fig)})
        plt.close(reach_fig)
    if carry_fig is not None:
        wandb.log({'final_embeddings_carry': wandb.Image(carry_fig)})
        plt.close(carry_fig)

    wandb.finish()
    print("Training complete!")

# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
