import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.cuda.amp import autocast, GradScaler
from torchvision import transforms
from diffusers import AutoencoderKL
from PIL import Image
import numpy as np
import os
import argparse
import glob
import matplotlib.pyplot as plt
import requests
import h5py
from tqdm import tqdm

# ==========================================
# 0. DATASET DOWNLOAD & EXTRACTION
# ==========================================
def download_and_extract_nyu(data_root):
    """Download and extract NYU Depth V2 dataset if not already present.

    The labeled subset is fetched as a single ``.mat`` file (an HDF5
    container) into ``data_root`` and unpacked into two sibling folders:

    * ``data_root/images/NNNN.png``        -- colour frames
    * ``data_root/depths/NNNN_depth.png``  -- 16-bit depth maps in millimetres

    The function returns early when both folders already hold more than
    1400 PNG files, and skips the download when the ``.mat`` file is already
    on disk.

    Args:
        data_root: Directory that holds (or will hold) the ``.mat`` file and
            the extracted ``images``/``depths`` folders.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeouts.
    """
    # Official URL for the Labeled Dataset (subset of 1449 aligned pairs)
    url = "http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/nyu_depth_v2_labeled.mat"
    filename = "nyu_depth_v2_labeled.mat"
    filepath = os.path.join(data_root, filename)

    # Output directories
    img_dir = os.path.join(data_root, "images")
    depth_dir = os.path.join(data_root, "depths")

    # --- 1. Check if already extracted ---
    if os.path.isdir(img_dir) and os.path.isdir(depth_dir):
        num_images = sum(1 for f in os.listdir(img_dir) if f.endswith('.png'))
        num_depths = sum(1 for f in os.listdir(depth_dir) if f.endswith('.png'))

        if num_images > 1400 and num_depths > 1400:
            print(f"✓ Dataset already extracted: {num_images} images, {num_depths} depth maps")
            return

    # --- 2. Download ---
    if not os.path.exists(filepath):
        print(f"File not found. Downloading {filename} (approx. 2.8 GB)...")
        os.makedirs(data_root, exist_ok=True)

        # Stream into a temporary ".part" file and rename only on success, so
        # an interrupted download is never mistaken for a complete file on
        # the next run.
        partial_path = filepath + ".part"
        try:
            # timeout = (connect seconds, seconds between received bytes)
            with requests.get(url, stream=True, timeout=(10, 60)) as response:
                # Fail loudly instead of saving an HTML error page as the dataset.
                response.raise_for_status()
                total_size = int(response.headers.get('content-length', 0))

                with open(partial_path, 'wb') as f, tqdm(
                    total=total_size or None, unit='B', unit_scale=True, desc=filename
                ) as bar:
                    for chunk in response.iter_content(chunk_size=1024 * 1024):  # 1MB chunks
                        if chunk:
                            f.write(chunk)
                            bar.update(len(chunk))
            os.replace(partial_path, filepath)
        finally:
            # Only still present if the download failed part-way.
            if os.path.exists(partial_path):
                os.remove(partial_path)
        print("Download complete.")
    else:
        print(f"Found {filename}, skipping download.")

    # --- 3. Extract ---
    print("Extracting Images and Depth Maps from .mat file...")
    os.makedirs(img_dir, exist_ok=True)
    os.makedirs(depth_dir, exist_ok=True)

    # The MAT file is actually an HDF5 container
    with h5py.File(filepath, 'r') as f:
        # Note: h5py reads Matlab arrays transposed!
        # Matlab 'images' is (H, W, 3, N) -> h5py sees (N, 3, W, H)
        # Matlab 'depths' is (H, W, N)    -> h5py sees (N, W, H)
        images = f['images']
        depths = f['depths']
        num_samples = images.shape[0]

        print(f"Processing {num_samples} samples...")

        for i in tqdm(range(num_samples)):
            # --- Extract RGB ---
            # Shape (3, W, H) -> Transpose to (H, W, 3)
            img_data = np.transpose(images[i], (2, 1, 0))
            img = Image.fromarray(img_data.astype('uint8'))
            img.save(os.path.join(img_dir, f"{i:04d}.png"))

            # --- Extract Depth ---
            # Shape (W, H) -> Transpose to (H, W); data is in meters (float)
            depth_data = np.transpose(depths[i], (1, 0))

            # Save as 16-bit PNG (millimeters) to preserve precision.
            # The "_depth" suffix is what DepthDataset's glob looks for.
            depth_mm = (depth_data * 1000).astype('uint16')
            depth_img = Image.fromarray(depth_mm, mode='I;16')
            depth_img.save(os.path.join(depth_dir, f"{i:04d}_depth.png"))

    print(f"Done! Data located in: {data_root}")
    print(f"  Images: {img_dir}")
    print(f"  Depths: {depth_dir}")

# ==========================================
# 1. MODEL: Strong Depth Branch (SD-VAE)
# ==========================================
class StrongDepthBranch(nn.Module):
    """Depth autoencoder built around a pretrained Stable Diffusion VAE.

    A 1-channel depth map in [0, 1] is repeated to 3 channels, pushed through
    the VAE encoder and (optionally) flattened and projected to a
    ``latent_dim`` vector. Decoding reverses those steps and averages the
    3 output channels back to one.

    When ``freeze_vae`` is true the VAE weights do not receive gradients and
    the VAE is kept in eval mode even after ``model.train()``. Gradients
    still flow *through* the frozen decoder so the projection layers can be
    trained.
    """

    def __init__(self, latent_dim, model_name="stabilityai/sd-vae-ft-mse", freeze_vae=True, use_projections=True):
        """
        Adapts the 83M param Stable Diffusion VAE for 1-channel Depth.
        Learns to project Depth -> VAE Latent Space -> Depth.

        Args:
            latent_dim: Dimension for projection layers
            model_name: HuggingFace model name for VAE
            freeze_vae: Whether to freeze VAE parameters
            use_projections: Whether to include linear projection layers
        """
        super().__init__()
        print(f"Loading SD-VAE: {model_name}...")
        self.vae = AutoencoderKL.from_pretrained(model_name)
        self.use_projections = use_projections
        # Remembered so train()/encode() can key off "frozen" instead of the
        # VAE's train/eval flag, which model.train() would otherwise flip.
        self.freeze_vae = freeze_vae

        if freeze_vae:
            # Freeze the massive VAE to save memory and force the
            # linear layers to adapt the depth data to the VAE's "language".
            for param in self.vae.parameters():
                param.requires_grad = False
            self.vae.eval()

        # VAE downsamples by 8. Latent channels = 4.
        # We assume 256x256 input -> 32x32 feature map.
        self.spatial_dim = 32
        self.vae_flat_dim = 4 * self.spatial_dim * self.spatial_dim
        print(f"VAE latent flat dim: {self.vae_flat_dim}")

        # Trainable Projections (optional)
        if use_projections:
            self.to_latent = nn.Linear(self.vae_flat_dim, latent_dim)
            self.from_latent = nn.Linear(latent_dim, self.vae_flat_dim)
        else:
            self.to_latent = None
            self.from_latent = None

        self.scale_factor = 0.18215

    def train(self, mode=True):
        """Set train/eval mode, keeping a frozen VAE in eval mode.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, like ``nn.Module.train``.
        """
        super().train(mode)
        if getattr(self, "freeze_vae", False):
            self.vae.eval()
        return self

    def encode(self, x):
        """Encode depth maps to latents.

        Args:
            x: Tensor of shape (B, 1, H, W) with values in [0, 1].

        Returns:
            ``(B, latent_dim)`` when projections are enabled, otherwise the
            scaled VAE latent as returned by the VAE encoder.

        Raises:
            ValueError: If projections are enabled and the flattened VAE
                latent does not have ``vae_flat_dim`` elements (i.e. the
                input is not the assumed 256x256).
        """
        # 1. Adapter: 1ch -> 3ch (Repeat)
        # The VAE expects 3 channels. We just duplicate the depth map.
        x_3ch = x.repeat(1, 3, 1, 1)

        # 2. Normalize [0, 1] -> [-1, 1] for VAE
        x_norm = 2.0 * x_3ch - 1.0

        # A frozen encoder has no trainable weights, so building its graph is
        # wasted memory unless the input itself needs gradients.
        track_grad = torch.is_grad_enabled() and (not self.freeze_vae or x.requires_grad)
        with torch.set_grad_enabled(track_grad):
            # Get Mode (Deterministic encoding)
            z_vae = self.vae.encode(x_norm).latent_dist.mode()
            z_vae = z_vae * self.scale_factor

        if not self.use_projections:
            return z_vae

        z_flat = torch.flatten(z_vae, start_dim=1)
        if z_flat.shape[1] != self.vae_flat_dim:
            raise ValueError(
                f"Flattened VAE latent has {z_flat.shape[1]} elements but the projection "
                f"expects {self.vae_flat_dim}; input spatial size {tuple(x.shape[-2:])} "
                f"does not match the assumed {8 * self.spatial_dim}x{8 * self.spatial_dim}."
            )
        return self.to_latent(z_flat)

    def decode(self, z):
        """Decode latents back to depth maps.

        Args:
            z: Output of :meth:`encode`.

        Returns:
            Tensor with one channel and values clamped to [0, 1].
        """
        # Project back
        if self.use_projections:
            z_flat = self.from_latent(z)
            z_vae = z_flat.view(-1, 4, self.spatial_dim, self.spatial_dim)
        else:
            z_vae = z

        z_vae = z_vae / self.scale_factor

        # Never wrap the decoder in no_grad here: even when the VAE weights
        # are frozen, gradients must pass through it to reach the projection
        # layers. Callers that only evaluate use torch.no_grad() themselves.
        x_hat_3ch = self.vae.decode(z_vae).sample

        # 3. Adapter: 3ch -> 1ch (Mean)
        x_hat = torch.mean(x_hat_3ch, dim=1, keepdim=True)

        # Convert [-1, 1] back to [0, 1]
        x_hat = (x_hat / 2.0 + 0.5).clamp(0, 1)
        return x_hat

# ==========================================
# 2. DATASET LOADER (NYU Depth Style)
# ==========================================
class DepthDataset(Dataset):
    """Depth-map dataset read from PNG files under a root directory.

    Files are discovered by a recursive search for ``*depth*.png``; if that
    finds nothing, every ``*.png`` directly inside ``root_dir`` is used
    instead. The sorted file list is cut at 90%: the first part is the
    ``'train'`` split, the remainder is returned for any other ``split``
    value.
    """

    def __init__(self, root_dir, split='train', size=256):
        """
        Args:
            root_dir: Directory searched for depth PNGs
                (e.g. root_dir/depths/0001_depth.png).
            split: 'train' selects the first 90% of files; anything else
                selects the last 10%.
            size: Side length every depth map is resized to.

        Raises:
            FileNotFoundError: If no PNG files are found.
        """
        self.root_dir = root_dir
        self.size = size

        # Preferred layout: any file with "depth" in its name, at any depth.
        candidates = sorted(
            glob.glob(os.path.join(root_dir, '**', '*depth*.png'), recursive=True)
        )
        if not candidates:
            # Fallback: plain PNGs sitting directly in root_dir.
            candidates = sorted(glob.glob(os.path.join(root_dir, '*.png')))
        if not candidates:
            raise FileNotFoundError(f"No depth images found in {root_dir}")

        # Deterministic 90/10 split on the sorted list.
        cut = int(0.9 * len(candidates))
        self.files = candidates[:cut] if split == 'train' else candidates[cut:]

        print(f"Found {len(self.files)} depth maps for {split}.")

    def __len__(self):
        """Number of depth maps in this split."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load one depth map as a float tensor of shape (1, size, size) in [0, 1]."""
        side = self.size

        # Nearest-neighbour resize avoids blending depth values across edges.
        raw = Image.open(self.files[idx]).resize((side, side), Image.NEAREST)
        depth_m = np.array(raw).astype(np.float32)

        # Heuristic: a maximum above 255 means the file stores millimeters.
        if depth_m.max() > 255.0:
            depth_m = depth_m / 1000.0

        # Clip to 10 m and scale to [0, 1].
        normalized = np.clip(depth_m, 0, 10.0) / 10.0

        # Add the channel dimension: (H, W) -> (1, H, W).
        return torch.from_numpy(normalized).unsqueeze(0)

# ==========================================
# 3. TRAINING SCRIPT
# ==========================================
def train_stage1_vae(args):
    """Stage 1: Finetune VAE only on depth maps (no projection layers).

    Ensures the dataset is present, builds train/val loaders, finetunes all
    VAE weights with an MSE reconstruction loss and saves the VAE state dict
    whenever the validation loss improves. A reconstruction plot is written
    before training and on every improvement.

    Args:
        args: Parsed CLI namespace. Reads ``gpu``, ``save_dir``,
            ``data_root``, ``batch_size``, ``latent_dim``, ``lr_stage1``,
            ``epochs_stage1`` and ``mixed_precision``.

    Returns:
        Path of the best VAE checkpoint written during this run, or ``None``
        when no checkpoint was written (zero epochs, or the validation loss
        never became a finite improvement). ``None`` is what
        ``train_stage2_projections`` accepts to mean "use the pretrained VAE".
    """
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    os.makedirs(args.save_dir, exist_ok=True)
    checkpoint_path = os.path.join(args.save_dir, "depth_vae_stage1_best.pth")

    # Plots go to <parent of this file's directory>/03_results/train_plots/nyudepthv2
    plot_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), '03_results', 'train_plots', 'nyudepthv2')
    os.makedirs(plot_dir, exist_ok=True)

    # 0. Check and download dataset if needed
    print("="*80)
    print("STAGE 1: Finetuning VAE on Depth Maps")
    print("="*80)
    print("Checking dataset availability...")
    download_and_extract_nyu(args.data_root)
    print("="*80 + "\n")

    # 1. Data
    print("Initializing Data...")
    train_ds = DepthDataset(args.data_root, split='train')
    val_ds = DepthDataset(args.data_root, split='val')

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, num_workers=4)

    # 2. Model (VAE unfrozen, no projections)
    print("Initializing Model (VAE unfrozen, no projections)...")
    model = StrongDepthBranch(
        latent_dim=args.latent_dim,
        freeze_vae=False,  # Unfreeze VAE
        use_projections=False  # No projection layers
    ).to(device)

    # Print model info
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("\nModel Statistics:")
    print(f"  Total parameters: {total_params:,}")
    print(f"  Trainable parameters: {trainable_params:,}")
    print(f"  Non-trainable parameters: {total_params - trainable_params:,}")

    # Print sample image shape
    sample_batch = next(iter(train_loader))
    print("\nData Statistics:")
    print(f"  Input shape: {sample_batch.shape} (batch, channels, height, width)")
    print(f"  Training samples: {len(train_ds)}")
    print(f"  Validation samples: {len(val_ds)}")
    print(f"  Batch size: {args.batch_size}\n")

    # Optimizer (Train VAE parameters only)
    optimizer = torch.optim.AdamW(
        model.vae.parameters(),  # Only VAE params
        lr=args.lr_stage1
    )
    criterion = nn.MSELoss()

    # Mixed precision scaler
    # NOTE(review): loss scaling is normally only needed for float16, not
    # bfloat16; kept unchanged here.
    scaler = GradScaler() if args.mixed_precision else None
    if args.mixed_precision:
        print("Using mixed precision training (BF16)")

    # Save initial reconstruction (before training)
    print("Saving initial reconstruction (before training)...")
    model.eval()
    with torch.no_grad():
        sample_batch = next(iter(val_loader)).to(device)
        z_init = model.encode(sample_batch)
        x_hat_init = model.decode(z_init)
        save_plot(sample_batch, x_hat_init, -1, plot_dir, prefix="stage1_initial")

    # 3. Loop
    print(f"Starting Stage 1 Training ({args.epochs_stage1} epochs)...")
    best_loss = float('inf')
    # Tracks whether THIS run wrote the checkpoint, so a stale file from an
    # earlier run is never handed to stage 2 by accident.
    checkpoint_written = False

    for epoch in range(args.epochs_stage1):
        model.train()
        train_loss = 0.0

        for batch in train_loader:
            x = batch.to(device)

            optimizer.zero_grad()

            # Forward (VAE only, no projections)
            if args.mixed_precision:
                with autocast(dtype=torch.bfloat16):
                    z = model.encode(x)
                    x_hat = model.decode(z)
                    loss = criterion(x_hat, x)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                z = model.encode(x)
                x_hat = model.decode(z)
                loss = criterion(x_hat, x)

                loss.backward()
                optimizer.step()

            train_loss += loss.item()

        # Keep the last training batch for plotting; x / x_hat are reused by
        # the validation loop below.
        train_sample = x.detach()
        train_recon = x_hat.detach()

        avg_train = train_loss / len(train_loader)

        # Val
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                x = batch.to(device)
                z = model.encode(x)
                x_hat = model.decode(z)
                loss = criterion(x_hat, x)
                val_loss += loss.item()
        avg_val = val_loss / len(val_loader)

        print(f"[Stage 1] Epoch {epoch+1}/{args.epochs_stage1} | Train MSE: {avg_train:.6f} | Val MSE: {avg_val:.6f}")

        # Save Best
        if avg_val < best_loss:
            best_loss = avg_val
            torch.save(model.vae.state_dict(), checkpoint_path)
            checkpoint_written = True

            # Save visual check (both train and val); x / x_hat are the last
            # validation batch at this point.
            save_plot_train_val(train_sample, train_recon, x, x_hat, epoch, plot_dir, prefix="stage1")

    print(f"\nStage 1 Complete! Best Val MSE: {best_loss:.6f}")
    if not checkpoint_written:
        print("WARNING: Stage 1 did not write a checkpoint (no epoch improved the validation loss).")
        return None
    return checkpoint_path


def train_stage2_projections(args, vae_checkpoint):
    """Stage 2: Add and train projection layers with frozen VAE.

    Builds the model with a frozen VAE plus linear projections, optionally
    loads finetuned VAE weights, and trains only the projection layers with
    an MSE reconstruction loss and a linearly decaying learning rate. The
    full model state dict is saved whenever the validation loss improves.

    Args:
        args: Parsed CLI namespace. Reads ``gpu``, ``save_dir``,
            ``data_root``, ``batch_size``, ``latent_dim``, ``lr_stage2``,
            ``epochs_stage2`` and ``mixed_precision``.
        vae_checkpoint: Path to a VAE state dict from stage 1, or ``None`` to
            keep the pretrained VAE weights.
    """
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    # Stage 2 can run on its own (--stage 2), so it must not rely on stage 1
    # having created the checkpoint directory.
    os.makedirs(args.save_dir, exist_ok=True)
    checkpoint_path = os.path.join(args.save_dir, "depth_ae_stage2_best.pth")

    # Plots go to <parent of this file's directory>/03_results/train_plots/nyudepthv2
    plot_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), '03_results', 'train_plots', 'nyudepthv2')
    os.makedirs(plot_dir, exist_ok=True)

    print("\n" + "="*80)
    print("STAGE 2: Training Projection Layers (Frozen VAE)")
    print("="*80)

    # 1. Data
    print("Initializing Data...")
    train_ds = DepthDataset(args.data_root, split='train')
    val_ds = DepthDataset(args.data_root, split='val')

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, num_workers=4)

    # 2. Model (VAE frozen, with projections)
    if vae_checkpoint is not None:
        print(f"Initializing Model (loading VAE from {vae_checkpoint})...")
    else:
        print("Initializing Model (using pretrained VAE without stage 1 finetuning)...")

    model = StrongDepthBranch(
        latent_dim=args.latent_dim,
        freeze_vae=True,  # Freeze VAE
        use_projections=True  # Add projection layers
    ).to(device)

    # Load finetuned VAE weights if checkpoint provided
    if vae_checkpoint is not None:
        model.vae.load_state_dict(torch.load(vae_checkpoint, map_location=device))
        print("Loaded finetuned VAE weights from Stage 1")
    else:
        print("Using pretrained Stable Diffusion VAE (no stage 1 checkpoint)")

    # Print model info
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    vae_params = sum(p.numel() for p in model.vae.parameters())
    projection_params = sum(p.numel() for p in model.to_latent.parameters()) + sum(p.numel() for p in model.from_latent.parameters())
    print("\nModel Statistics:")
    print(f"  Total parameters: {total_params:,}")
    print(f"  VAE parameters (frozen): {vae_params:,}")
    print(f"  Projection parameters (trainable): {projection_params:,}")
    print(f"  Trainable parameters: {trainable_params:,}")

    # Print sample image shape
    sample_batch = next(iter(train_loader))
    print("\nData Statistics:")
    print(f"  Input shape: {sample_batch.shape} (batch, channels, height, width)")
    print(f"  Training samples: {len(train_ds)}")
    print(f"  Validation samples: {len(val_ds)}")
    print(f"  Batch size: {args.batch_size}\n")

    # Optimizer (Only train projection layers)
    projection_weights = list(model.to_latent.parameters()) + list(model.from_latent.parameters())
    optimizer = torch.optim.AdamW(projection_weights, lr=args.lr_stage2)
    criterion = nn.MSELoss()

    # Linear LR Scheduler: decay from lr_stage2 to lr_stage2/100.
    # The scheduler is stepped once per epoch, so epoch indices seen during
    # training are 0 .. epochs_stage2-1. Dividing by (epochs - 1) makes the
    # final training epoch run at exactly lr_stage2/100, as the log line
    # below announces.
    decay_steps = max(args.epochs_stage2 - 1, 1)

    def lr_factor(epoch):
        return 1.0 - min(epoch / decay_steps, 1.0) * 0.99  # Goes from 1.0 to 0.01

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_factor)

    # Mixed precision scaler
    # NOTE(review): loss scaling is normally only needed for float16, not
    # bfloat16; kept unchanged here.
    scaler = GradScaler() if args.mixed_precision else None
    if args.mixed_precision:
        print("Using mixed precision training (BF16)")

    # Save initial reconstruction (before training)
    print("Saving initial reconstruction (before training)...")
    model.eval()
    with torch.no_grad():
        sample_batch = next(iter(val_loader)).to(device)
        z_init = model.encode(sample_batch)
        x_hat_init = model.decode(z_init)
        save_plot(sample_batch, x_hat_init, -1, plot_dir, prefix="stage2_initial")

    # 3. Loop
    print(f"Starting Stage 2 Training ({args.epochs_stage2} epochs)...")
    print(f"Learning rate schedule: {args.lr_stage2:.2e} -> {args.lr_stage2/100:.2e}")
    best_loss = float('inf')

    for epoch in range(args.epochs_stage2):
        model.train()
        train_loss = 0.0

        for batch in train_loader:
            x = batch.to(device)

            optimizer.zero_grad()

            # Forward (frozen VAE + trainable projections)
            if args.mixed_precision:
                with autocast(dtype=torch.bfloat16):
                    z = model.encode(x)
                    x_hat = model.decode(z)
                    loss = criterion(x_hat, x)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                z = model.encode(x)
                x_hat = model.decode(z)
                loss = criterion(x_hat, x)

                loss.backward()
                optimizer.step()

            train_loss += loss.item()

        # Keep the last training batch for plotting; x / x_hat are reused by
        # the validation loop below.
        train_sample = x.detach()
        train_recon = x_hat.detach()

        avg_train = train_loss / len(train_loader)

        # Val
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                x = batch.to(device)
                z = model.encode(x)
                x_hat = model.decode(z)
                loss = criterion(x_hat, x)
                val_loss += loss.item()
        avg_val = val_loss / len(val_loader)

        current_lr = optimizer.param_groups[0]['lr']
        print(f"[Stage 2] Epoch {epoch+1}/{args.epochs_stage2} | Train MSE: {avg_train:.6f} | Val MSE: {avg_val:.6f} | LR: {current_lr:.2e}")

        # Step the scheduler
        scheduler.step()

        # Save Best
        if avg_val < best_loss:
            best_loss = avg_val
            torch.save(model.state_dict(), checkpoint_path)

            # Save visual check (both train and val); x / x_hat are the last
            # validation batch at this point.
            save_plot_train_val(train_sample, train_recon, x, x_hat, epoch, plot_dir, prefix="stage2")

    print(f"\nStage 2 Complete! Best Val MSE: {best_loss:.6f}")
    print(f"Final model saved to: {checkpoint_path}")

def save_plot_train_val(train_orig, train_recon, val_orig, val_recon, epoch, save_dir, prefix=""):
    """Write a 4-row comparison figure of train/val inputs and reconstructions.

    Rows are, top to bottom: train original, train reconstruction, val
    original, val reconstruction. At most 4 columns are drawn, limited by the
    smaller of the train and val batch sizes. The image is saved to
    ``save_dir`` as ``{prefix}_recon_epoch_{epoch}.png`` (or without the
    prefix part when ``prefix`` is empty).
    """
    n_samples = min(4, train_orig.shape[0], val_orig.shape[0])

    # (row title, batch) in display order. .float() first so bfloat16
    # tensors can be converted to numpy.
    rows = [
        ("Train Original", train_orig[:n_samples].float().cpu().numpy()),
        ("Train Recon", train_recon[:n_samples].float().cpu().numpy()),
        ("Val Original", val_orig[:n_samples].float().cpu().numpy()),
        ("Val Recon", val_recon[:n_samples].float().cpu().numpy()),
    ]

    # squeeze=False keeps axes 2-D even when there is a single column.
    fig, axes = plt.subplots(4, n_samples, figsize=(2.5 * n_samples, 10), squeeze=False)

    for row_idx, (title, images) in enumerate(rows):
        for col_idx in range(n_samples):
            ax = axes[row_idx, col_idx]
            ax.imshow(images[col_idx, 0], cmap='inferno')
            ax.axis('off')
            # Only the first column carries the row label.
            if col_idx == 0:
                ax.set_title(title, fontsize=10)

    plt.tight_layout()
    if prefix:
        filename = f"{prefix}_recon_epoch_{epoch}.png"
    else:
        filename = f"recon_epoch_{epoch}.png"
    plt.savefig(os.path.join(save_dir, filename))
    plt.close()

def save_plot(orig, recon, epoch, save_dir, prefix=""):
    """Write a 2-row figure of inputs (top) and reconstructions (bottom).

    Args:
        orig: Batch of input tensors, indexed as ``[sample, channel, ...]``;
            channel 0 is plotted.
        recon: Batch of reconstructions, indexed the same way.
        epoch: Value embedded in the output file name.
        save_dir: Existing directory the PNG is written to.
        prefix: Optional file-name prefix; the file is named
            ``{prefix}_recon_epoch_{epoch}.png`` or ``recon_epoch_{epoch}.png``.
    """
    # Plot up to 4 samples (handle smaller batches on either side)
    n_samples = min(4, orig.shape[0], recon.shape[0])

    # detach(): tolerate tensors that still carry a graph.
    # float(): numpy cannot represent bfloat16, which mixed precision yields.
    orig = orig[:n_samples].detach().float().cpu().numpy()
    recon = recon[:n_samples].detach().float().cpu().numpy()

    # squeeze=False keeps axes 2-D even when there is a single column.
    fig, axes = plt.subplots(2, n_samples, figsize=(2.5 * n_samples, 5), squeeze=False)

    for i in range(n_samples):
        axes[0, i].imshow(orig[i, 0], cmap='inferno')
        axes[0, i].axis('off')
        axes[0, i].set_title("Original")

        axes[1, i].imshow(recon[i, 0], cmap='inferno')
        axes[1, i].axis('off')
        axes[1, i].set_title("Recon")

    fig.tight_layout()
    filename = f"{prefix}_recon_epoch_{epoch}.png" if prefix else f"recon_epoch_{epoch}.png"
    fig.savefig(os.path.join(save_dir, filename))
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)

if __name__ == '__main__':
    # Command-line entry point: parse options, then run stage 1, stage 2 or both.
    cli = argparse.ArgumentParser()

    # Paths and shared model/data settings
    cli.add_argument('--data_root', type=str, default='/ewsc/vschuste/data/mm_benchmarks/')
    cli.add_argument('--save_dir', type=str, default='/ewsc/vschuste/models/mm_benchmarks/')
    cli.add_argument('--latent_dim', type=int, default=1000)
    cli.add_argument('--batch_size', type=int, default=16)  # VAE uses VRAM, keep batch small

    # Stage 1: VAE finetuning
    cli.add_argument('--epochs_stage1', type=int, default=100,
                     help='Epochs for VAE finetuning')
    cli.add_argument('--lr_stage1', type=float, default=1e-5,
                     help='Learning rate for VAE finetuning (lower for pretrained model)')

    # Stage 2: Projection training
    cli.add_argument('--epochs_stage2', type=int, default=100,
                     help='Epochs for projection layer training')
    cli.add_argument('--lr_stage2', type=float, default=1e-3,
                     help='Learning rate for projection layers')

    # Run control
    cli.add_argument('--gpu', type=int, default=0)
    cli.add_argument('--stage', type=str, choices=['both', '1', '2'], default='both',
                     help='Which stage to run: both, 1 (VAE only), or 2 (projections only)')
    cli.add_argument('--stage1_checkpoint', type=str, default=None,
                     help='Path to stage 1 checkpoint (optional for stage=2, will use pretrained VAE if not provided)')
    cli.add_argument('--mixed_precision', action='store_true',
                     help='Use mixed precision training (BF16) to reduce memory requirements')

    args = cli.parse_args()

    run_stage1 = args.stage in ('both', '1')
    run_stage2 = args.stage in ('both', '2')

    # Stage 1 produces the VAE checkpoint; when it is skipped, fall back to
    # the user-supplied path (which may be None).
    vae_checkpoint = train_stage1_vae(args) if run_stage1 else args.stage1_checkpoint

    if run_stage2:
        # vae_checkpoint can be None - stage 2 then uses the pretrained VAE
        # without stage 1 finetuning.
        train_stage2_projections(args, vae_checkpoint)