

import sys
from pathlib import Path
# Put the directory three levels above this file on sys.path (only if it is not
# already there) so the top-level `project_config` import and the `src.*`
# imports below can be resolved.
project_root = str(Path(__file__).resolve().parent.parent.parent)
if project_root not in sys.path:
    sys.path.append(project_root)
import project_config

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, random_split, Dataset
import math
import numpy as np
import tqdm
import gc
from src.visualization.logging import plot_training_state, create_training_movie
from src.models.larrp_multimodal_cnn import AdaptiveRankReducedAE_CNN, MMSimData
import matplotlib.pyplot as plt
import pandas as pd
import os
from transformers import T5Tokenizer
from transformers import BertTokenizer

def compute_direct_r_squared(model, data, device, multi_gpu=False, verbose=False):
    """
    Compute R² based on direct model reconstruction performance.

    Two input forms are supported:

    * A ``DataLoader`` yielding dict batches with ``'image'``, ``'input_ids'``
      and ``'attention_mask'`` keys. The model is called as
      ``model(images, input_ids, attention_mask)`` and must return
      ``((img_recon, text_logits), _)``. Image R² is pooled over all pixels.
      Text R² compares token ids with the argmax of the logits and uses
      non-padding positions (id != 0) only.
    * A list of tensors ``[modality1, modality2, ...]``. The model is called
      once as ``model(data_tensors)`` and must return ``(reconstructions, _)``.
      R² is computed per feature (column of the ``(N, D)`` flattening) and
      then averaged.

    Parameters:
    - model: The trained model
    - data: Either DataLoader or list of tensors [modality1, modality2, ...]
    - device: Device to run computation on
    - multi_gpu: Whether model is wrapped with DataParallel (not used here)
    - verbose: Print diagnostic information

    Returns:
    - List of R² values (Python floats), one per modality
    """
    from torch.utils.data import DataLoader
    model.eval()

    # Image-text datasets arrive as a DataLoader of dict batches.
    if isinstance(data, DataLoader):
        return _r2_from_loader(model, data, device, verbose)

    # Tensor-list path: a single forward pass over all modalities.
    r_squared_values = []
    with torch.no_grad():
        data_tensors = [d.to(device) for d in data]
        reconstructions, _ = model(data_tensors)
        for i, (original, reconstruction) in enumerate(zip(data_tensors, reconstructions)):
            orig_flat, recon_flat = _r2_flatten_pair(original, reconstruction, i, verbose)
            r_squared_values.append(_r2_single_modality(orig_flat, recon_flat, i, verbose))
    return r_squared_values


def _r2_from_loader(model, loader, device, verbose):
    """Pooled R² for image/text dict batches, using two passes over ``loader``."""
    n_modalities = 2  # image-text always has 2 modalities

    # First pass: per-feature means. Only the data is needed, so batches are
    # not moved to the device.
    img_sum = 0.0
    text_sum = 0
    text_count = 0
    n_samples = 0
    for batch in loader:
        images = batch['image']
        token_ids = batch['input_ids'].flatten(start_dim=1).cpu()
        n_samples += images.shape[0]
        img_sum = img_sum + images.flatten(start_dim=1).sum(dim=0).cpu()
        # Padding ids are 0, so they do not change the sum. They must not be
        # counted either: the text mean has to be taken over the same
        # non-padding tokens that enter SSR/SST in the second pass.
        text_sum = text_sum + token_ids.sum(dim=0)
        text_count = text_count + (token_ids != 0).sum(dim=0)

    if n_samples == 0:
        # Nothing to evaluate; same result as the "SST too small" case.
        return [0.0] * n_modalities

    img_mean = img_sum / n_samples
    # Positions that are padding everywhere get mean 0; they are masked out below.
    text_mean = text_sum.float() / text_count.clamp(min=1)

    # Second pass: accumulate SSR and SST per modality.
    sum_squared_residuals = [0.0] * n_modalities
    sum_squared_total = [0.0] * n_modalities
    first_batch_debug = True
    with torch.no_grad():
        for batch in loader:
            images = batch['image'].to(device)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            (img_recon, text_logits), _ = model(images, input_ids, attention_mask)

            # For text, convert logits to tokens for comparison
            text_recon = torch.argmax(text_logits, dim=-1)

            img_flat = images.flatten(start_dim=1).cpu()
            img_recon_flat = img_recon.flatten(start_dim=1).cpu()
            text_flat = input_ids.flatten(start_dim=1).cpu()
            text_recon_flat = text_recon.flatten(start_dim=1).cpu()

            # Mask for non-padding tokens (0 is padding in BERT)
            text_mask = (input_ids != 0).flatten(start_dim=1).cpu()

            if first_batch_debug and verbose:
                print(f"\n=== R² Computation Debug (First Batch) ===")
                print(f"Image data - shape: {img_flat.shape}, min: {img_flat.min().item():.4f}, max: {img_flat.max().item():.4f}, mean: {img_flat.mean().item():.4f}")
                print(f"Image recon - shape: {img_recon_flat.shape}, min: {img_recon_flat.min().item():.4f}, max: {img_recon_flat.max().item():.4f}, mean: {img_recon_flat.mean().item():.4f}")
                print(f"Text data - shape: {text_flat.shape}, min: {text_flat.min().item()}, max: {text_flat.max().item()}, mean: {text_flat.float().mean().item():.4f}")
                print(f"Text recon - shape: {text_recon_flat.shape}, min: {text_recon_flat.min().item()}, max: {text_recon_flat.max().item()}, mean: {text_recon_flat.float().mean().item():.4f}")
                print(f"Text non-padding tokens: {text_mask.sum().item()} / {text_mask.numel()}")
                print(f"Modality means - Image: {img_mean.mean().item():.4f}, Text: {text_mean.mean().item():.4f}")
                first_batch_debug = False

            # Images: all pixels
            sum_squared_residuals[0] += ((img_flat - img_recon_flat) ** 2).sum().item()
            sum_squared_total[0] += ((img_flat - img_mean) ** 2).sum().item()

            # Text: non-padding tokens only
            text_valid = text_flat.float()[text_mask]
            text_recon_valid = text_recon_flat.float()[text_mask]
            text_mean_valid = text_mean.unsqueeze(0).expand_as(text_flat)[text_mask]
            sum_squared_residuals[1] += ((text_valid - text_recon_valid) ** 2).sum().item()
            sum_squared_total[1] += ((text_valid - text_mean_valid) ** 2).sum().item()

    r_squared_values = []
    if verbose:
        print(f"\n=== R² Final Computation ===")
    for i in range(n_modalities):
        if verbose:
            print(f"Modality {i}: SSR={sum_squared_residuals[i]:.4f}, SST={sum_squared_total[i]:.4f}")
        if sum_squared_total[i] > 1e-6:
            r_squared = 1.0 - (sum_squared_residuals[i] / (sum_squared_total[i] + 1e-9))
            r_squared = max(0.0, r_squared)  # Clamp negatives to 0
        else:
            r_squared = 0.0
        if verbose:
            print(f"Modality {i}: R²={r_squared:.4f}")
        r_squared_values.append(r_squared)
    return r_squared_values


def _r2_flatten_pair(original, reconstruction, modality_idx, verbose):
    """Return ``(orig, recon)`` as CPU ``(N, D)`` tensors truncated to a common N and D."""
    orig_flat = original.cpu().reshape(original.shape[0], -1)
    recon_flat = reconstruction.cpu().reshape(reconstruction.shape[0], -1)

    # Align batch dimension
    n_min = min(orig_flat.shape[0], recon_flat.shape[0])
    orig_flat, recon_flat = orig_flat[:n_min], recon_flat[:n_min]

    # Align feature dimension by truncation if necessary
    if orig_flat.shape[1] != recon_flat.shape[1]:
        min_feat = min(orig_flat.shape[1], recon_flat.shape[1])
        if verbose:
            print(f"   Debug: modality {modality_idx} feature size mismatch (orig={orig_flat.shape[1]}, recon={recon_flat.shape[1]}). Truncating to {min_feat} features for R² computation.")
        orig_flat = orig_flat[:, :min_feat]
        recon_flat = recon_flat[:, :min_feat]
    return orig_flat, recon_flat


def _r2_single_modality(orig_flat, recon_flat, modality_idx, verbose):
    """Per-feature R² of ``recon_flat`` against ``orig_flat`` (both (N, D)), averaged over features."""
    i = modality_idx
    original_mean = orig_flat.mean(dim=0)

    # Features whose mean is exactly zero are excluded.
    if torch.any(original_mean == 0):
        if verbose:
            print(f"   Warning: zeros found in original_mean for modality {i}. Removing samples.")
        non_zero_mean = original_mean != 0
        if non_zero_mean.sum() == 0:
            # If all means are zero, use correlation as fallback
            r_squared = torch.corrcoef(torch.stack((orig_flat.flatten(), recon_flat.flatten())))[0, 1]
            return 0.0 if torch.isnan(r_squared) else r_squared.item()
        ssr = ((orig_flat - recon_flat) ** 2).sum(0)[non_zero_mean]
        ss_tot = ((orig_flat - original_mean) ** 2).sum(0)[non_zero_mean]
        return (1 - ((ssr + 1e-9) / (ss_tot + 1e-9))).mean().item()

    # Non-finite originals: keep only features whose mean is finite.
    if not torch.isfinite(orig_flat).all():
        if verbose:
            print(f"   Warning: NaN or Inf values found in original data for modality {i}. Handling them.")
        valid_features = torch.isfinite(original_mean)
        if valid_features.sum() == 0:
            return 0.0
        ssr = ((orig_flat - recon_flat) ** 2).sum(0)[valid_features]
        ss_tot = ((orig_flat - original_mean) ** 2).sum(0)[valid_features]
        return (1 - ((ssr + 1e-9) / (ss_tot + 1e-9))).mean().item()

    # From here on the originals are finite; only the reconstruction may not be.
    sq_err = (orig_flat - recon_flat) ** 2
    sq_dev = (orig_flat - original_mean) ** 2
    finite = torch.isfinite(recon_flat)
    if not finite.all():
        if verbose:
            print(f"   Warning: NaN or Inf values found in data for modality {i}. Handling them.")
        if not finite.any():
            return 0.0
        # Mask element-wise *before* summing so NaN/Inf never reach the
        # per-feature sums, then drop features with no valid entry at all.
        sq_err = sq_err.masked_fill(~finite, 0.0)
        sq_dev = sq_dev.masked_fill(~finite, 0.0)
        has_valid = finite.any(dim=0)
        ssr = sq_err.sum(0)[has_valid]
        ss_tot = sq_dev.sum(0)[has_valid]
    else:
        ssr = sq_err.sum(0)
        ss_tot = sq_dev.sum(0)

    # Features with (near-)constant originals give unstable R²; exclude them.
    small = ss_tot < 1e-3
    if torch.any(small):
        if verbose:
            print(f"   Warning: Very small ss_tot values found for modality {i}. This may lead to unstable R² values.")
        if small.all():
            return 0.0
        ssr, ss_tot = ssr[~small], ss_tot[~small]

    r_squared = 1 - ((ssr + 1e-9) / (ss_tot + 1e-9))
    # Per-feature negatives (worse than predicting the mean) are clamped to 0.
    if torch.any(r_squared < 0):
        if verbose:
            fraction_negative = (r_squared < 0).sum().item() / r_squared.numel()
            print(f"   Warning: {fraction_negative:.2%} Negative R² values detected for modality {i}. This may indicate poor reconstruction.")
        r_squared = torch.clamp(r_squared, min=0.0)
    return r_squared.mean().item()

def compute_direct_explained_variance(model, data, device, multi_gpu=False, verbose=False):
    """
    Compute explained variance score based on direct model reconstruction performance.

    The score is ``1 - Var(recon - orig) / Var(orig)``, clamped at 0. Variances
    are unbiased and pooled over every element of a modality.

    Two input forms are supported:

    * A ``DataLoader`` yielding dict batches with ``'image'``, ``'input_ids'``
      and ``'attention_mask'`` keys. The model is called as
      ``model(images, input_ids, attention_mask)`` and must return
      ``((img_recon, text_logits), _)``. Text is compared as token ids (argmax
      of the logits) over non-padding positions (id != 0) only. Statistics
      are accumulated batch by batch, so memory use does not grow with the
      dataset size.
    * A list of tensors ``[modality1, modality2, ...]``. The model is called
      once as ``model(data_tensors)`` and must return ``(reconstructions, _)``.

    Parameters:
    - model: The trained model
    - data: Either DataLoader or list of tensors [modality1, modality2, ...]
    - device: Device to run computation on
    - multi_gpu: Whether model is wrapped with DataParallel (not used here)
    - verbose: Print diagnostic information

    Returns:
    - List of explained variance values (Python floats), one per modality
    """
    from torch.utils.data import DataLoader
    model.eval()

    # Image-text datasets arrive as a DataLoader of dict batches.
    if isinstance(data, DataLoader):
        return _ev_from_loader(model, data, device, verbose)

    # Tensor-list path: a single forward pass over all modalities.
    explained_variance_values = []
    with torch.no_grad():
        data_tensors = [d.to(device) for d in data]
        reconstructions, _ = model(data_tensors)

        for i, (original, reconstruction) in enumerate(zip(data_tensors, reconstructions)):
            original_cpu = original.cpu()
            reconstruction_cpu = reconstruction.cpu()

            # Flatten originals and reconstructions to (N, D)
            orig_flat = original_cpu.reshape(original_cpu.shape[0], -1)
            recon_flat = reconstruction_cpu.reshape(reconstruction_cpu.shape[0], -1)

            if verbose:
                print(f"   Debug: modality {i} original shape {original_cpu.shape} -> flat {orig_flat.shape}; recon shape {reconstruction_cpu.shape} -> flat {recon_flat.shape}")

            # Align batch dimension
            n_min = min(orig_flat.shape[0], recon_flat.shape[0])
            orig_flat, recon_flat = orig_flat[:n_min], recon_flat[:n_min]

            # Align feature dimension by truncation if necessary
            if orig_flat.shape[1] != recon_flat.shape[1]:
                min_feat = min(orig_flat.shape[1], recon_flat.shape[1])
                if verbose:
                    print(f"   Warning: modality {i} feature size mismatch (orig={orig_flat.shape[1]}, recon={recon_flat.shape[1]}). Truncating to {min_feat} features for explained variance.")
                orig_flat = orig_flat[:, :min_feat]
                recon_flat = recon_flat[:, :min_feat]

            # Only elements that are finite in both tensors take part.
            valid_mask = torch.isfinite(orig_flat) & torch.isfinite(recon_flat)
            if not valid_mask.all():
                if verbose:
                    print(f"   Warning: NaN or Inf values found in flattened data for modality {i}. Handling them.")
                if not valid_mask.any():
                    explained_variance = torch.tensor(0.0)
                else:
                    original_valid = orig_flat[valid_mask]
                    reconstruction_valid = recon_flat[valid_mask]
                    explained_variance = 1 - (torch.var(reconstruction_valid - original_valid) / (torch.var(original_valid) + 1e-9))
            else:
                explained_variance = 1 - (torch.var(recon_flat - orig_flat) / (torch.var(orig_flat) + 1e-9))

            # Negative values mean a fit worse than the mean; clamp to 0.
            if explained_variance < 0:
                if verbose:
                    print(f"   Warning: Negative explained variance value detected for modality {i}: {explained_variance}. Clamping to 0.")
                explained_variance = torch.clamp(explained_variance, min=0.0)

            explained_variance_values.append(explained_variance.item())

    return explained_variance_values


def _ev_merge(stats, values):
    """Fold ``values`` into running ``(count, mean, M2)`` (Chan et al. parallel update, float64)."""
    v = values.double().flatten()
    n_b = v.numel()
    if n_b == 0:
        return stats
    n_a, mean_a, m2_a = stats
    mean_b = v.mean().item()
    m2_b = ((v - mean_b) ** 2).sum().item()
    n = n_a + n_b
    delta = mean_b - mean_a
    return (n, mean_a + delta * n_b / n, m2_a + m2_b + delta * delta * n_a * n_b / n)


def _ev_variance(stats):
    """Unbiased variance from ``(count, mean, M2)``; NaN for fewer than 2 values, like ``torch.var``."""
    n, _, m2 = stats
    return m2 / (n - 1) if n > 1 else float('nan')


def _ev_from_loader(model, loader, device, verbose):
    """Explained variance for image/text dict batches using streaming moments."""
    n_modalities = 2  # image-text always has 2 modalities
    # Running (count, mean, M2) of the originals and of the residuals
    # (recon - orig) per modality. This avoids keeping the whole dataset in memory.
    orig_stats = [(0, 0.0, 0.0)] * n_modalities
    diff_stats = [(0, 0.0, 0.0)] * n_modalities

    first_batch_debug = True
    with torch.no_grad():
        for batch in loader:
            images = batch['image'].to(device)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            (img_recon, text_logits), _ = model(images, input_ids, attention_mask)

            # For text, convert logits to tokens
            text_recon = torch.argmax(text_logits, dim=-1)

            img_flat = images.flatten(start_dim=1).cpu()
            img_recon_flat = img_recon.flatten(start_dim=1).cpu()
            text_flat = input_ids.flatten(start_dim=1).cpu().float()
            text_recon_flat = text_recon.flatten(start_dim=1).cpu().float()

            # Mask for non-padding tokens (0 is padding in BERT)
            text_mask = (input_ids != 0).flatten(start_dim=1).cpu()

            if first_batch_debug and verbose:
                print(f"\n=== Explained Variance Computation Debug (First Batch) ===")
                print(f"Image data - shape: {img_flat.shape}, min: {img_flat.min().item():.4f}, max: {img_flat.max().item():.4f}, mean: {img_flat.mean().item():.4f}")
                print(f"Image recon - shape: {img_recon_flat.shape}, min: {img_recon_flat.min().item():.4f}, max: {img_recon_flat.max().item():.4f}, mean: {img_recon_flat.mean().item():.4f}")
                print(f"Text data - shape: {text_flat.shape}, min: {text_flat.min().item():.4f}, max: {text_flat.max().item():.4f}, mean: {text_flat.mean().item():.4f}")
                print(f"Text recon - shape: {text_recon_flat.shape}, min: {text_recon_flat.min().item():.4f}, max: {text_recon_flat.max().item():.4f}, mean: {text_recon_flat.mean().item():.4f}")
                print(f"Text non-padding tokens: {text_mask.sum().item()} / {text_mask.numel()}")
                first_batch_debug = False

            # Image modality: all pixels
            orig_stats[0] = _ev_merge(orig_stats[0], img_flat)
            diff_stats[0] = _ev_merge(diff_stats[0], img_recon_flat - img_flat)

            # Text modality: non-padding tokens only
            text_valid = text_flat[text_mask]
            orig_stats[1] = _ev_merge(orig_stats[1], text_valid)
            diff_stats[1] = _ev_merge(diff_stats[1], text_recon_flat[text_mask] - text_valid)

    explained_variance_values = []
    if verbose:
        print(f"\n=== Explained Variance Final Computation ===")
    for i in range(n_modalities):
        var_diff = _ev_variance(diff_stats[i])
        var_orig = _ev_variance(orig_stats[i])

        if verbose:
            if i == 1:
                print(f"Modality {i} (text, non-padding only): var_diff={var_diff:.4f}, var_orig={var_orig:.4f}, n_valid_tokens={orig_stats[i][0]}")
            else:
                print(f"Modality {i}: var_diff={var_diff:.4f}, var_orig={var_orig:.4f}")

        # NaN (fewer than two values) fails this comparison and yields 0.0.
        if var_orig > 1e-6:
            explained_var = max(0.0, 1.0 - (var_diff / (var_orig + 1e-9)))
        else:
            explained_var = 0.0

        if verbose:
            print(f"Modality {i}: Explained Variance={explained_var:.4f}")

        explained_variance_values.append(explained_var)

    return explained_variance_values

import os

def pretrain_overcomplete_ae(train_loader, val_loader, model, device, args=None, epochs=100, early_stopping=50, 
                         lr=0.001, batch_size=128, ae_depth=2, ae_width=0.5, dropout=0.0, wd=1e-5, 
                         initial_rank_ratio=1.0, min_rank=10, 
                         patience=10, verbose=True, recon_loss_balancing=False, lr_schedule=None,
                         input_shapes=None, save_frequency=None, pretrained_name=None, warmup_epochs=10):
    """
    Train an autoencoder for mm_sim pretraining (no rank reduction).
    
    Parameters:
    - train_loader: DataLoader for training data
    - val_loader: DataLoader for validation data
    - model: The model to train (already initialized)
    - device: Device to train on
    - args: Optional arguments namespace
    """
    # Ensure args exists
    if args is None:
        from types import SimpleNamespace
        args = SimpleNamespace()
    
    multi_gpu = getattr(args, 'multi_gpu', False)
    
    # Get number of modalities from a sample batch
    sample_batch = next(iter(train_loader))
    n_modalities = len(sample_batch)
    
    # Model is already provided
    print(f"Model is on device: {next(model.parameters()).device}")
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model has {total_params} parameters")
    
    # Handle multi-GPU setup if needed
    if multi_gpu and not isinstance(model, nn.DataParallel):
        try:
            if hasattr(args, 'gpu_ids') and args.gpu_ids:
                if 0 not in [int(id) for id in args.gpu_ids.split(',')]:
                    raise RuntimeError("DataParallel requires cuda:0 which is not available.")
                cuda0_device = torch.device('cuda:0')
                model = model.to(cuda0_device)
                for param in model.parameters():
                    if param.device != cuda0_device:
                        param.data = param.data.to(cuda0_device)
                model = nn.DataParallel(model, device_ids=[int(id) for id in args.gpu_ids.split(',')])
                if verbose:
                    print(f"Using DataParallel across GPUs: {args.gpu_ids}")
        except Exception as e:
            print(f"Failed to use DataParallel: {e}")
            print(f"Falling back to single GPU mode on {device}")
            multi_gpu = False
            model = model.to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

    # Setup learning rate scheduler with linear warmup
    scheduler = None
    warmup_scheduler = None
    main_scheduler = None
    
    # Create warmup scheduler (linear warmup from 0 to lr over warmup_epochs)
    if warmup_epochs > 0:
        warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=0.01, end_factor=1.0, total_iters=warmup_epochs
        )
    
    # Create main scheduler after warmup
    if lr_schedule == 'linear':
        try:
            main_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=max(1, epochs - warmup_epochs))
        except Exception:
            main_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: max(0.0, 1.0 - (epoch + 1) / float(max(1, epochs - warmup_epochs))))
    elif lr_schedule == 'step':
        main_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 100, 1000], gamma=0.1)
    elif lr_schedule == 'cosine':
        main_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(1, epochs - warmup_epochs), eta_min=1e-6)
    
    # Combine warmup and main scheduler
    if warmup_epochs > 0 and main_scheduler is not None:
        scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer, 
            schedulers=[warmup_scheduler, main_scheduler],
            milestones=[warmup_epochs]
        )
    elif warmup_epochs > 0:
        scheduler = warmup_scheduler
    elif main_scheduler is not None:
        scheduler = main_scheduler

    # Use the provided data loaders directly
    data_loader = train_loader
    val_data_loader = val_loader
    
    start_reduction = False
    
    # Train the model
    train_losses = []
    val_losses = []
    best_loss = float('inf') 
    
    # Initialize loss scaling factors for dynamic loss balancing
    loss_scales = torch.ones(n_modalities, device=device)
    #loss_scales[1] = 0.1
    loss_history = {f'mod_{i}_loss': [] for i in range(n_modalities)}
    
    # Initialize loss balancer for reconstruction losses
    if recon_loss_balancing:
        modality_loss_emas = [None] * n_modalities
        ema_decay = 0.9

    patience_counter = 0
    pbar = tqdm.tqdm(range(epochs))
    for epoch in pbar:
        # Training phase
        model.train()
        train_loss = 0.0
        val_loss = 0.0
        total_ortho_loss = 0.0
        per_modality_losses = [0.0] * n_modalities
        
        for batch_idx, batch in enumerate(data_loader):
            # Extract from dictionary batch
            images = batch['image'].to(device, non_blocking=True)
            input_ids = batch['input_ids'].to(device, non_blocking=True)
            attention_mask = batch['attention_mask'].to(device, non_blocking=True)
            labels = batch['label']
            
            # Forward pass
            x_hat, h_list = model(images, input_ids, attention_mask)
            
            ortho_loss = torch.tensor(0.0, device=device)
            total_ortho_loss += ortho_loss.item()

            # Calculate separate losses for each modality
            modality_losses = []
            
            # Modality 0: Image reconstruction (MSE loss)
            img_loss = F.mse_loss(x_hat[0], images, reduction='mean')
            modality_losses.append(img_loss)
            
            # Modality 1: Text reconstruction (Cross-Entropy loss)
            # x_hat[1] is logits of shape (B, Seq, Vocab)
            # input_ids is shape (B, Seq)
            text_logits = x_hat[1]
            text_loss = F.cross_entropy(
                text_logits.reshape(-1, text_logits.size(-1)),  # (B*Seq, Vocab)
                input_ids.reshape(-1),  # (B*Seq,)
                ignore_index=0,  # Ignore padding token (BERT pad token is 0)
                reduction='mean'
            )
            """
            if torch.isnan(img_loss):
                if verbose:
                    print(f"Warning: NaN loss detected for image modality")
                img_loss = torch.tensor(0.0, device=device)
            """
            modality_losses.append(img_loss)
            per_modality_losses[0] += img_loss.item()

            """
            if torch.isnan(text_loss):
                if verbose:
                    print(f"Warning: NaN loss detected for text modality")
                text_loss = torch.tensor(0.0, device=device)
            """
            modality_losses.append(text_loss)
            per_modality_losses[1] += text_loss.item()
            
            # Apply reconstruction loss balancing if enabled
            if recon_loss_balancing:
                # Update exponential moving averages for each modality
                for i, m_loss in enumerate(modality_losses):
                    if modality_loss_emas[i] is None:
                        modality_loss_emas[i] = m_loss.item()
                    else:
                        modality_loss_emas[i] = ema_decay * modality_loss_emas[i] + (1 - ema_decay) * m_loss.item()
                
                # Calculate balanced loss using the minimum EMA as reference
                min_ema = min(ema for ema in modality_loss_emas if ema is not None and ema > 0)
                loss = sum(m_loss * (min_ema / (modality_loss_emas[i] + 1e-8)) 
                          for i, m_loss in enumerate(modality_losses))
            else:
                # Simple sum of modality losses
                loss = sum(modality_losses)
            
            total_loss = loss

            # Backward pass and optimize (normal training; do not skip)
            optimizer.zero_grad()
            total_loss.backward()
            # Gradient clipping to improve stability
            try:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            except Exception:
                # If model is DataParallel, clip parameters of module
                if hasattr(model, 'module'):
                    torch.nn.utils.clip_grad_norm_(model.module.parameters(), max_norm=1.0)
                else:
                    raise
            optimizer.step()
            train_loss += loss.item()
        
        # Average losses
        train_loss /= len(data_loader)
        # Ortho loss is not used in pretraining
        per_modality_losses = [loss / len(data_loader) for loss in per_modality_losses]
        train_losses.append(train_loss)
        
        # Store per-modality losses in history
        for i, loss in enumerate(per_modality_losses):
            loss_history[f'mod_{i}_loss'].append(loss)
        
        # Validation phase with similar safeguards
        with torch.no_grad():
            for batch_val in val_data_loader:
                # Extract from dictionary batch
                images_val = batch_val['image'].to(device, non_blocking=True)
                input_ids_val = batch_val['input_ids'].to(device, non_blocking=True)
                attention_mask_val = batch_val['attention_mask'].to(device, non_blocking=True)
                
                # Forward pass
                x_val_hat, _ = model(images_val, input_ids_val, attention_mask_val)
                
                # Calculate validation loss
                val_batch_loss = 0.0
                
                # Image loss
                img_loss = F.mse_loss(x_val_hat[0], images_val, reduction='mean')
                if not torch.isnan(img_loss):
                    val_batch_loss += img_loss.item()
                
                # Text loss
                text_logits = x_val_hat[1]
                text_loss = F.cross_entropy(
                    text_logits.reshape(-1, text_logits.size(-1)),
                    input_ids_val.reshape(-1),
                    ignore_index=0,
                    reduction='mean'
                )
                if not torch.isnan(text_loss):
                    val_batch_loss += text_loss.item()
                
                val_loss += val_batch_loss / 2.0  # Average over 2 modalities
                
        val_loss /= len(val_data_loader)
        val_losses.append(val_loss)

        # Step the scheduler once per epoch (if configured)
        if scheduler is not None:
            #try:
            scheduler.step()
            #except Exception:
            #    # scheduler.step() may expect different inputs for some schedulers; ignore failures to remain conservative
            #    pass

        # Update best loss
        if train_loss < best_loss:
            best_loss = train_loss
            patience_counter = 0  # Reset patience counter
        else:
            patience_counter += 1

        if (epoch > early_stopping) & (min(val_losses[-early_stopping:]) > min(val_losses)) & (start_reduction is False):
            print(f"Early stopping at epoch {epoch} with best loss {best_loss}")
            break
        
        # Periodic checkpoint saving - save to the main pretrained path so it gets updated
        effective_save_freq = save_frequency if save_frequency is not None else 10
        if pretrained_name is not None and (epoch + 1) % effective_save_freq == 0:
            if 'food' in pretrained_name.lower():
                dataset_folder = 'food101'
            elif 'so2sat' in pretrained_name.lower():
                dataset_folder = 'so2sat'
            else:
                dataset_folder = 'models'
            checkpoint_path = f"{project_config.RESULTS_DIR}/{dataset_folder}/pretrained_{pretrained_name}.pt"
            os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
            torch.save(model.state_dict(), checkpoint_path)
            # Also save loss curves
            loss_curve_path = checkpoint_path.replace('.pt', '_loss_curve.npy')
            np.save(loss_curve_path, np.array([train_losses, val_losses]))
            print(f"Saved checkpoint at epoch {epoch+1} to {checkpoint_path}")
        
        # Update progress bar
        # include current learning rate in progress bar
        if scheduler is not None:
            current_lr = scheduler.get_last_lr()[0]
            pbar.set_description(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}, LR: {current_lr:.4e}")
        else:
            try:
                current_lr = optimizer.param_groups[0]['lr']
                pbar.set_description(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}, LR: {current_lr:.2e}")
            except Exception:
                pbar.set_description(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}")
    
    return model, [train_losses, val_losses]

def train_overcomplete_ae_with_pretrained(train_loader, val_loader, model, device, latent_dim=None, args=None, epochs=100, early_stopping=50, 
                         lr=0.001, batch_size=128, ae_depth=2, ae_width=0.5, dropout=0.0, wd=1e-5, 
                         initial_rank_ratio=1.0, min_rank=10, 
                         rank_schedule=None, rank_reduction_frequency=10, 
                         rank_reduction_threshold=0.01, warmup_epochs=0,
                         patience=10, reduce_on_best_loss='rsquare', r_square_threshold=0.9,
                         threshold_type='relative', compressibility_type='linear', reduction_criterion='r_squared',
                         include_l1=False, l1_weight=0.0, include_ortholoss=False,
                         l1_start_weight=0.0, l1_step_size=1.0, rank_or_sparse='rank',
                         verbose=True, compute_jacobian=False, model_name=None, pretrained_name=None,
                         recon_loss_balancing=False, ortho_loss_balancing=False,
                         ortho_loss_start_weight=0.0, ortho_loss_end_weight=1.0, ortho_loss_anneal_epochs=None, ortho_loss_warmup=None,
                         l2_norm_adaptivelayers=None, sharedwhenall=True, lr_schedule='constant',
                         decision_metric='R2', input_shapes=None, end_lr=1e-5, save_frequency=None,
                         from_unimodal=False, unimodal_seed=42
                         ):
    """
    Train an autoencoder with adaptive rank reduction
    
    Parameters:
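    - train_loader: DataLoader for the training set (its .dataset is also used to sample examples for reconstruction plots)
    - val_loader: DataLoader for the validation set
    - model: The multimodal autoencoder to train; its weights may be overwritten from a saved pretrained checkpoint
    - device: Device to run computation on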
    - latent_dim: Dimension of the latent space
    - epochs: Maximum number of training epochs
    - early_stopping: Number of epochs for early stopping patience
    - lr: Learning rate
    - batch_size: Batch size for training
    - ae_depth: Depth of the autoencoder
    - ae_width: Width multiplier for hidden layers
    - dropout: Dropout rate
    - wd: Weight decay
    - initial_rank_ratio: Initial rank ratio (1.0 = full rank)
    - min_rank: Minimum rank (lower bound for rank reduction)
    - rank_schedule: Custom schedule for rank reduction (epochs at which to reduce)
    - rank_reduction_frequency: How often to try reducing rank (in epochs)
    - rank_reduction_threshold: Energy threshold for rank reduction
    - warmup_epochs: Number of epochs to train before starting rank reduction
    - reduce_on_best_loss: Only reduce rank when loss is at or better than best loss
    - r_square_threshold: R² threshold for rank reduction decisions
    - threshold_type: 'relative' (multiply by initial R²) or 'absolute' (use threshold directly)
    - compressibility_type: 'linear' (linear probing R²) or 'direct' (reconstruction R²)
    - reduction_criterion: 'r_squared' (use R²), 'train_loss' (use training loss), 'val_loss' (use validation loss)
                          Only relevant when compressibility_type='direct'. Default: 'r_squared'
    - recon_loss_balancing: Whether to apply adaptive loss balancing across modalities (default: False)
    """

    # Ensure args exists
    if args is None:
        from types import SimpleNamespace
        args = SimpleNamespace()
    
    # Extract datasets for accessing raw data when needed
    train_dataset = train_loader.dataset
    val_dataset = val_loader.dataset
    train_data = getattr(train_dataset, 'data', train_dataset)
    val_data = getattr(val_dataset, 'data', val_dataset)
    
    # Get number of modalities (image-text multimodal model always has 2 modalities)
    n_modalities = 2
    
    multi_gpu = getattr(args, 'multi_gpu', False)

    # If from_unimodal is True, load unimodal pretrained weights instead of multimodal pretrained
    if from_unimodal:
        print("\n" + "="*60)
        print("Loading pretrained unimodal models...")
        print("="*60)
        
        # Determine dataset folder from pretrained_name
        if 'food' in pretrained_name.lower():
            dataset_folder = 'food101'
        elif 'so2sat' in pretrained_name.lower():
            dataset_folder = 'so2sat'
        else:
            dataset_folder = 'models'
        
        results_dir = f'{project_config.RESULTS_DIR}/{dataset_folder}'
        
        # Load image autoencoder
        img_ae_path = os.path.join(results_dir, f'image_ae_seed{unimodal_seed}.pth')
        if os.path.exists(img_ae_path):
            img_checkpoint = torch.load(img_ae_path, map_location='cpu')
            # Extract encoder and decoder weights
            img_encoder_dict = {k.replace('encoder_conv.', 'encoder_conv.').replace('to_latent.', 'to_latent.'): v 
                               for k, v in img_checkpoint['state_dict'].items() 
                               if k.startswith('encoder_conv') or k.startswith('to_latent')}
            img_decoder_dict = {k.replace('from_latent.', 'from_latent.').replace('decoder_conv.', 'decoder_conv.'): v 
                               for k, v in img_checkpoint['state_dict'].items() 
                               if k.startswith('from_latent') or k.startswith('decoder_conv')}
            
            # Load into img_branch
            model.img_branch.encoder_conv.load_state_dict(
                {k.replace('encoder_conv.', ''): v for k, v in img_encoder_dict.items() if k.startswith('encoder_conv')}, strict=False)
            model.img_branch.to_latent.load_state_dict(
                {k.replace('to_latent.', ''): v for k, v in img_encoder_dict.items() if k.startswith('to_latent')}, strict=False)
            model.img_branch.from_latent.load_state_dict(
                {k.replace('from_latent.', ''): v for k, v in img_decoder_dict.items() if k.startswith('from_latent')}, strict=False)
            model.img_branch.decoder_conv.load_state_dict(
                {k.replace('decoder_conv.', ''): v for k, v in img_decoder_dict.items() if k.startswith('decoder_conv')}, strict=False)
            
            print(f"✓ Loaded image encoder/decoder from {img_ae_path}")
        else:
            print(f"  Warning: Image autoencoder not found at {img_ae_path}")
        
        # Load text autoencoder (check for transformer or bert variant)
        text_ae_path = os.path.join(results_dir, f'text_ae_transformer_seed{unimodal_seed}.pth')
        if not os.path.exists(text_ae_path):
            text_ae_path = os.path.join(results_dir, f'text_ae_bert_seed{unimodal_seed}.pth')
        
        if os.path.exists(text_ae_path):
            text_checkpoint = torch.load(text_ae_path, map_location='cpu')
            
            # Load trainable transformer weights if not using BERT
            if hasattr(model.text_branch, 'trainable_branch'):
                # Load state dict but skip incompatible keys (like positional encodings with different seq lengths)
                state_dict = text_checkpoint['state_dict']
                model_state = model.text_branch.trainable_branch.state_dict()
                
                # Filter out keys with size mismatches
                compatible_state = {}
                for key, value in state_dict.items():
                    if key in model_state:
                        if value.shape == model_state[key].shape:
                            compatible_state[key] = value
                        else:
                            print(f"  Skipping {key}: shape mismatch ({value.shape} vs {model_state[key].shape})")
                
                model.text_branch.trainable_branch.load_state_dict(compatible_state, strict=False)
                print(f"✓ Loaded text trainable transformer from {text_ae_path} ({len(compatible_state)}/{len(state_dict)} keys)")
            elif hasattr(model.text_branch, 'to_latent'):
                # Load the projection and decoder weights (BERT encoder is frozen anyway)
                text_proj_dict = {k: v for k, v in text_checkpoint['state_dict'].items() 
                                 if k.startswith('to_latent') or k.startswith('from_latent') or 
                                    k.startswith('decoder') or k.startswith('output_head')}
                model.text_branch.load_state_dict(text_proj_dict, strict=False)
                print(f"✓ Loaded text decoder from {text_ae_path}")
        else:
            print(f"  Warning: Text autoencoder not found at {text_ae_path}")
        
        # Freeze encoders and decoders - will be unfrozen after warmup
        if warmup_epochs > 0:
            print("\nFreezing image and text branches for warmup phase...")
            for param in model.img_branch.parameters():
                param.requires_grad = False
            for param in model.text_branch.parameters():
                param.requires_grad = False
            
            # Keep adaptive layers trainable
            for layer in model.adaptive_layers:
                for param in layer.parameters():
                    param.requires_grad = True
            
            # check if the fusion layers exist and set them to trainable (explicitly)
            if hasattr(model, 'img_fusion'):
                for param in model.img_fusion.parameters():
                    param.requires_grad = True
            if hasattr(model, 'text_fusion'):
                for param in model.text_fusion.parameters():
                    param.requires_grad = True
            
            print("✓ Only adaptive and fusion layers will be trained during warmup")
        
        print("✓ Unimodal weights loaded successfully")
        print("="*60 + "\n")
        
        # Skip multimodal pretrained loading, initialize for training
        train_losses = []
        val_losses = []
        model.epoch = 0
        pretrained_model_path = None
    else:
        # check if there is an existing pretrained model for the seed, early stopping, and training hyperparameters (lr, wd, batch size, model architecture)
        # Determine which dataset folder to use based on pretrained_name
        if 'food' in pretrained_name.lower():
            dataset_folder = 'food101'
        elif 'so2sat' in pretrained_name.lower():
            dataset_folder = 'so2sat'
        else:
            dataset_folder = 'models'  # fallback
        pretrained_model_path = f"{project_config.RESULTS_DIR}/{dataset_folder}/pretrained_{pretrained_name}.pt" if pretrained_name else None
    
    if pretrained_model_path and os.path.exists(pretrained_model_path):
        print(f"Found existing pretrained model at {pretrained_model_path}. Loading...")
        # Use strict=False to allow loading models with dynamically created layers (e.g., img_fusion, text_fusion)
        state_dict = torch.load(pretrained_model_path, weights_only=False)
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        if unexpected_keys:
            print(f"Note: Unexpected keys in checkpoint (will be ignored): {unexpected_keys}")
        if missing_keys:
            print(f"Note: Missing keys in model (will use random initialization): {missing_keys}")
        # make sure that the weights are changed
        model.eval()
        for param in model.parameters():
            param.requires_grad = True
        print(f"Loaded pretrained model from {pretrained_model_path}")
        # also load the loss curves
        loss_curve_path = pretrained_model_path.replace('.pt', '_loss_curve.npy')
        train_val_losses = np.load(loss_curve_path, allow_pickle=True)
        train_losses = train_val_losses[0].tolist()
        val_losses = train_val_losses[1].tolist()
        print(f"Loaded loss curves from {loss_curve_path}")
        # print last losses
        print(f"Last training loss: {train_losses[-1]}, last validation loss: {val_losses[-1]}")
        model.epoch = len(train_losses)
        # If reconstruction plots do not exist, create and save example reconstructions
        try:
            mod0_plot_path = pretrained_model_path.replace(project_config.RESULTS_DIR, './03_results/train_plots/').replace('.pt', '_image_recon.png')
            mod1_plot_path = pretrained_model_path.replace(project_config.RESULTS_DIR, './03_results/train_plots/').replace('.pt', '_text_recon.txt')
            if (not os.path.exists(mod0_plot_path)) or (not os.path.exists(mod1_plot_path)):
                # ensure model parameters are on the same device as inputs
                try:
                    model = model.to(device)
                except Exception:
                    pass
                model.eval()
                with torch.no_grad():
                    # take a small sample from the provided data to generate example reconstructions
                    # Support multiple dataset types: dataset.data as list/tuple of tensors OR DataLoader that yields dict batches
                    dataset = getattr(train_loader, 'dataset', None)
                    n_plot = 8
                    # Prefer dataset.data if present and structured as list/tuple
                    if (dataset is not None) and hasattr(dataset, 'data') and isinstance(dataset.data, (list, tuple)):
                        try:
                            total_n = dataset.data[0].shape[0]
                        except Exception:
                            total_n = len(dataset)
                        n_plot = min(8, total_n)
                        rng = np.random.default_rng(seed=42)
                        sample_idx = rng.choice(total_n, size=n_plot, replace=False)
                        mod0_sample = dataset.data[0][sample_idx].to(device)
                        mod1_sample = dataset.data[1][sample_idx].to(device)
                    else:
                        # Fall back to sampling from a batch yielded by the DataLoader
                        batch = next(iter(train_loader))
                        if isinstance(batch, dict):
                            # Expect keys like 'image', 'input_ids', 'attention_mask'
                            n_total = batch.get('image').shape[0]
                            n_plot = min(8, n_total)
                            mod0_sample = batch['image'][:n_plot].to(device)
                            # For text modality use input_ids (model.encode expects input_ids, attention_mask)
                            # If attention_mask missing, create one
                            if 'input_ids' in batch:
                                mod1_sample = batch['input_ids'][:n_plot].to(device)
                            else:
                                # Fallback: try second element
                                mod1_sample = batch[list(batch.keys())[1]][:n_plot].to(device)
                        else:
                            # batch is likely a tuple/list of modality tensors
                            n_total = batch[0].shape[0]
                            n_plot = min(8, n_total)
                            mod0_sample = batch[0][:n_plot].to(device)
                            mod1_sample = batch[1][:n_plot].to(device)

                    # Use module if DataParallel wrapping already present
                    model_callable = model.module if (multi_gpu and hasattr(model, 'module')) else model

                    # Try to call the model with image-text signature first: (images, input_ids, attention_mask)
                    called = False
                    try:
                        # If mod1_sample looks like token ids (2D tensor), create attention mask
                        if isinstance(mod1_sample, torch.Tensor) and mod1_sample.dim() == 2:
                            attention_mask = (mod1_sample != 0).to(mod1_sample.device)
                        else:
                            attention_mask = None

                        if attention_mask is not None:
                            reconstructions, _ = model_callable(mod0_sample, mod1_sample, attention_mask)
                            called = True
                        else:
                            # Try calling with two positional args (images, text)
                            try:
                                reconstructions, _ = model_callable(mod0_sample, mod1_sample)
                                called = True
                            except TypeError:
                                called = False
                    except TypeError:
                        called = False

                    # Fallback to list-style input for older models
                    if not called:
                        try:
                            reconstructions, _ = model_callable([mod0_sample, mod1_sample])
                        except Exception as e:
                            # Re-raise with context for easier debugging
                            raise RuntimeError(f"Failed to call model for reconstruction plotting: {e}")

                    # Save plots (basic visualization for CNN outputs)
                    # This would need proper plotting functions for image data
                    try:
                        print(f"Debug: Reconstruction types: {[type(r) for r in reconstructions]}")
                        print(f"Debug: Reconstruction shapes: mod0={getattr(reconstructions[0], 'shape', None)}, mod1={getattr(reconstructions[1], 'shape', None)}")

                        # Save a simple side-by-side image reconstruction PNG
                        img_path = pretrained_model_path.replace(project_config.RESULTS_DIR, './03_results/train_plots/').replace('.pt', '_image_recon.png')
                        os.makedirs(os.path.dirname(img_path), exist_ok=True)
                        n_plot_local = min(8, mod0_sample.shape[0])
                        fig, axes = plt.subplots(2, n_plot_local, figsize=(n_plot_local * 1.5, 3))
                        for i in range(n_plot_local):
                            try:
                                img_orig = mod0_sample[i].permute(1, 2, 0).cpu().numpy()
                                img_orig = (img_orig - img_orig.min()) / (img_orig.max() - img_orig.min() + 1e-8)
                                axes[0, i].imshow(img_orig)
                                axes[0, i].axis('off')
                            except Exception as e:
                                axes[0, i].text(0.5, 0.5, 'orig N/A', ha='center')
                                axes[0, i].axis('off')

                            try:
                                img_rec = reconstructions[0][i].permute(1, 2, 0).cpu().numpy()
                                img_rec = (img_rec - img_rec.min()) / (img_rec.max() - img_rec.min() + 1e-8)
                                axes[1, i].imshow(img_rec)
                                axes[1, i].axis('off')
                            except Exception as e:
                                axes[1, i].text(0.5, 0.5, 'rec N/A', ha='center')
                                axes[1, i].axis('off')

                        plt.tight_layout()
                        plt.savefig(img_path, dpi=150, bbox_inches='tight')
                        plt.close()
                        print(f"Saved CNN image reconstructions to {img_path}")

                        # Save text reconstructions if available
                        try:
                            from transformers import BertTokenizer
                            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
                            text_logits = reconstructions[1]
                            # If logits, take argmax; if already ids, use directly
                            if isinstance(text_logits, torch.Tensor) and text_logits.dim() == 3:
                                pred_ids = torch.argmax(text_logits, dim=-1)
                            else:
                                pred_ids = text_logits

                            txt_path = pretrained_model_path.replace(project_config.RESULTS_DIR, './03_results/train_plots/').replace('.pt', '_text_recon.txt')
                            os.makedirs(os.path.dirname(txt_path), exist_ok=True)
                            with open(txt_path, 'w', encoding='utf-8') as f:
                                for i in range(min(8, pred_ids.shape[0])):
                                    try:
                                        orig_text = tokenizer.decode(mod1_sample[i], skip_special_tokens=True) if isinstance(mod1_sample, torch.Tensor) else str(mod1_sample[i])
                                    except Exception:
                                        orig_text = 'N/A'
                                    try:
                                        recon_text = tokenizer.decode(pred_ids[i], skip_special_tokens=True) if isinstance(pred_ids, torch.Tensor) else str(pred_ids[i])
                                    except Exception:
                                        recon_text = 'N/A'
                                    f.write(f"Sample {i+1}:\n")
                                    f.write(f"  Original: {orig_text}\n")
                                    f.write(f"  Reconstructed: {recon_text}\n\n")
                            print(f"Saved CNN text reconstructions to {txt_path}")
                        except Exception as e:
                            print(f"Warning: Could not save CNN text reconstructions: {e}")
                    except Exception as e:
                        print(f"Warning: Failed during CNN reconstruction plotting: {e}")
        except Exception as e:
            print(f"Warning: Could not save reconstruction plots: {e}")
    else:
        if pretrained_model_path:
            print("No pretrained model found. Training from scratch.")
            model, [train_losses, val_losses] = pretrain_overcomplete_ae(
                train_loader, val_loader, model, device, epochs=int(epochs/2), early_stopping=early_stopping,
                lr=lr, batch_size=batch_size, ae_depth=ae_depth, ae_width=ae_width, dropout=dropout, wd=wd,
                initial_rank_ratio=initial_rank_ratio, min_rank=min_rank, lr_schedule=lr_schedule,
                verbose=verbose, input_shapes=input_shapes, save_frequency=save_frequency, pretrained_name=pretrained_name,
                warmup_epochs=min(10, int(epochs/20))  # Use 10 epochs or 5% of total epochs, whichever is smaller
            )
            # Save the pretrained model and loss curves
            os.makedirs(os.path.dirname(pretrained_model_path), exist_ok=True)
            torch.save(model.state_dict(), pretrained_model_path)
            # Also save loss curves
            loss_curve_path = pretrained_model_path.replace('.pt', '_loss_curve.npy')
            np.save(loss_curve_path, np.array([train_losses, val_losses]))
            # Also save a PNG of the pretraining loss curves (train & val)
            try:
                loss_png_path = pretrained_model_path.replace(project_config.RESULTS_DIR, './03_results/train_plots/').replace('.pt', '_pretrain_loss_curve.png')
                os.makedirs(os.path.dirname(loss_png_path), exist_ok=True)
                plt.figure(figsize=(6, 4))
                plt.plot(np.arange(1, len(train_losses) + 1), train_losses, label='train')
                plt.plot(np.arange(1, len(val_losses) + 1), val_losses, label='val')
                plt.xlabel('Epoch')
                plt.ylabel('Loss')
                plt.title('Pretraining Loss Curves')
                plt.legend()
                plt.tight_layout()
                plt.savefig(loss_png_path, dpi=150)
                plt.close()
            except Exception as e:
                print(f"Warning: Could not save pretraining loss PNG: {e}")
            # Save example plots for reconstruction
            try:
                model.eval()
                with torch.no_grad():
                    # Get sample batch from train loader
                    sample_batch = next(iter(train_loader))
                    n_plot = min(8, len(sample_batch['image']))
                    
                    images = sample_batch['image'][:n_plot].to(device)
                    input_ids = sample_batch['input_ids'][:n_plot].to(device)
                    attention_mask = sample_batch['attention_mask'][:n_plot].to(device)

                    # Run through model (handle DataParallel)
                    if multi_gpu and hasattr(model, 'module'):
                        reconstructions, _ = model.module(images, input_ids, attention_mask)
                    else:
                        reconstructions, _ = model(images, input_ids, attention_mask)

                    img_recon, text_logits = reconstructions
                    
                    # Save image reconstructions
                    img_path = pretrained_model_path.replace(project_config.RESULTS_DIR, './03_results/train_plots/').replace('.pt', '_image_recon.png')
                    os.makedirs(os.path.dirname(img_path), exist_ok=True)
                    fig, axes = plt.subplots(2, n_plot, figsize=(n_plot * 1.5, 3))
                    for i in range(n_plot):
                        # Original image
                        img_orig = images[i].permute(1, 2, 0).cpu().numpy()
                        img_orig = (img_orig - img_orig.min()) / (img_orig.max() - img_orig.min() + 1e-8)
                        axes[0, i].imshow(img_orig)
                        axes[0, i].axis('off')
                        if i == 0:
                            axes[0, i].set_ylabel('Original', fontsize=10)
                        # Reconstructed image
                        img_rec = img_recon[i].permute(1, 2, 0).cpu().numpy()
                        img_rec = (img_rec - img_rec.min()) / (img_rec.max() - img_rec.min() + 1e-8)
                        axes[1, i].imshow(img_rec)
                        axes[1, i].axis('off')
                        if i == 0:
                            axes[1, i].set_ylabel('Reconstructed', fontsize=10)
                    plt.tight_layout()
                    plt.savefig(img_path, dpi=150, bbox_inches='tight')
                    plt.close()
                    print(f"Saved image reconstructions to {img_path}")
                    
                    # Save text reconstructions
                    from transformers import BertTokenizer
                    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
                    
                    # Get predicted tokens (greedy decode)
                    pred_ids = torch.argmax(text_logits, dim=-1)  # (batch, seq_len)
                    
                    txt_path = pretrained_model_path.replace(project_config.RESULTS_DIR, './03_results/train_plots/').replace('.pt', '_text_recon.txt')
                    os.makedirs(os.path.dirname(txt_path), exist_ok=True)
                    with open(txt_path, 'w', encoding='utf-8') as f:
                        for i in range(n_plot):
                            # Decode original
                            orig_text = tokenizer.decode(input_ids[i], skip_special_tokens=True)
                            # Decode reconstruction
                            recon_text = tokenizer.decode(pred_ids[i], skip_special_tokens=True)
                            
                            f.write(f"Sample {i+1}:\n")
                            f.write(f"  Original: {orig_text}\n")
                            f.write(f"  Reconstructed: {recon_text}\n\n")
                    print(f"Saved text reconstructions to {txt_path}")
                    
            except Exception as e:
                print(f"Warning: Could not save reconstruction plots: {e}")
                import traceback
                traceback.print_exc()
            
            model.epoch = len(train_losses)
            
            # Save reconstruction plots for both modalities
            try:
                # Need to create datasets from the train_data and val_data that are in scope
                temp_train_dataset = MMSimData(train_data)
                temp_val_dataset = MMSimData(val_data)
                
                model.eval()
                with torch.no_grad():
                    # Get 8 train and 8 test samples
                    n_plot = 8
                    
                    # Train samples
                    train_idx = torch.randperm(len(temp_train_dataset))[:n_plot]
                    train_samples = [temp_train_dataset.data[i][train_idx].to(device) for i in range(len(train_data))]
                    
                    # Test samples
                    test_idx = torch.randperm(len(temp_val_dataset))[:n_plot]
                    test_samples = [temp_val_dataset.data[i][test_idx].to(device) for i in range(len(train_data))]
                    
                    # Get reconstructions
                    if multi_gpu and hasattr(model, 'module'):
                        train_recon, _ = model.module(train_samples)
                        test_recon, _ = model.module(test_samples)
                    else:
                        train_recon, _ = model(train_samples)
                        test_recon, _ = model(test_samples)
                    
                    # Plot modality 0 (radar - channel 4 only)
                    mod0_path = pretrained_model_path.replace(project_config.RESULTS_DIR, './03_results/train_plots/').replace('.pt', '_image_recon.png')
                    os.makedirs(os.path.dirname(mod0_path), exist_ok=True)
                    fig, axes = plt.subplots(4, n_plot, figsize=(n_plot * 1.5, 6))
                    for i in range(n_plot):
                        # Train original (channel 4)
                        axes[0, i].imshow(train_samples[0][i, 4].cpu().numpy(), cmap='gray')
                        axes[0, i].axis('off')
                        if i == 0:
                            axes[0, i].set_title('Train', fontsize=10, loc='left')
                        # Train reconstruction (channel 4)
                        axes[1, i].imshow(train_recon[0][i, 4].cpu().numpy(), cmap='gray')
                        axes[1, i].axis('off')
                        if i == 0:
                            axes[1, i].set_title('Train Recon', fontsize=10, loc='left')
                        # Test original (channel 4)
                        axes[2, i].imshow(test_samples[0][i, 4].cpu().numpy(), cmap='gray')
                        axes[2, i].axis('off')
                        if i == 0:
                            axes[2, i].set_title('Test', fontsize=10, loc='left')
                        # Test reconstruction (channel 4)
                        axes[3, i].imshow(test_recon[0][i, 4].cpu().numpy(), cmap='gray')
                        axes[3, i].axis('off')
                        if i == 0:
                            axes[3, i].set_title('Test Recon', fontsize=10, loc='left')
                    plt.tight_layout()
                    plt.savefig(mod0_path, dpi=150, bbox_inches='tight')
                    plt.close()
                    print(f"Saved radar reconstruction plot to {mod0_path}")
                    
                    # Plot modality 1 (optical - RGB)
                    mod1_path = pretrained_model_path.replace(project_config.RESULTS_DIR, './03_results/train_plots/').replace('.pt', '_text_recon.txt')
                    os.makedirs(os.path.dirname(mod1_path), exist_ok=True)
                    fig, axes = plt.subplots(4, n_plot, figsize=(n_plot * 1.5, 6))
                    for i in range(n_plot):
                        # Train original (RGB)
                        img_train = train_samples[1][i].permute(1, 2, 0).cpu().numpy()
                        # Normalize to [0, 1] for display
                        img_train = (img_train - img_train.min()) / (img_train.max() - img_train.min() + 1e-8)
                        axes[0, i].imshow(img_train)
                        axes[0, i].axis('off')
                        if i == 0:
                            axes[0, i].set_title('Train', fontsize=10, loc='left')
                        # Train reconstruction (RGB)
                        img_train_recon = train_recon[1][i].permute(1, 2, 0).cpu().numpy()
                        img_train_recon = (img_train_recon - img_train_recon.min()) / (img_train_recon.max() - img_train_recon.min() + 1e-8)
                        axes[1, i].imshow(img_train_recon)
                        axes[1, i].axis('off')
                        if i == 0:
                            axes[1, i].set_title('Train Recon', fontsize=10, loc='left')
                        # Test original (RGB)
                        img_test = test_samples[1][i].permute(1, 2, 0).cpu().numpy()
                        img_test = (img_test - img_test.min()) / (img_test.max() - img_test.min() + 1e-8)
                        axes[2, i].imshow(img_test)
                        axes[2, i].axis('off')
                        if i == 0:
                            axes[2, i].set_title('Test', fontsize=10, loc='left')
                        # Test reconstruction (RGB)
                        img_test_recon = test_recon[1][i].permute(1, 2, 0).cpu().numpy()
                        img_test_recon = (img_test_recon - img_test_recon.min()) / (img_test_recon.max() - img_test_recon.min() + 1e-8)
                        axes[3, i].imshow(img_test_recon)
                        axes[3, i].axis('off')
                        if i == 0:
                            axes[3, i].set_title('Test Recon', fontsize=10, loc='left')
                    plt.tight_layout()
                    plt.savefig(mod1_path, dpi=150, bbox_inches='tight')
                    plt.close()
                    print(f"Saved optical reconstruction plot to {mod1_path}")
            except Exception as e:
                print(f"Warning: Could not save reconstruction plots: {e}")
        elif from_unimodal:
            # Unimodal weights already loaded above, no pretraining needed
            print("Skipping multimodal pretraining - using unimodal initialization.")
        else:
            raise ValueError("model_name must be provided to save/load pretrained models.")
    
    # Model already provided and initialized
    model.to(device)
    print(f"Model is on device: {next(model.parameters()).device}")
    
    # Handle multi-GPU setup if needed
    if multi_gpu and not isinstance(model, nn.DataParallel):
        try:
            if hasattr(args, 'gpu_ids') and args.gpu_ids:
                if 0 not in [int(id) for id in args.gpu_ids.split(',')]:
                    raise RuntimeError("DataParallel requires cuda:0 which is not available.")
                cuda0_device = torch.device('cuda:0')
                model = model.to(cuda0_device)
                for param in model.parameters():
                    if param.device != cuda0_device:
                        param.data = param.data.to(cuda0_device)
                model = nn.DataParallel(model, device_ids=[int(id) for id in args.gpu_ids.split(',')])
                if verbose:
                    print(f"Using DataParallel across GPUs: {args.gpu_ids}")
        except Exception as e:
            print(f"Failed to use DataParallel: {e}")
            print(f"Falling back to single GPU mode on {device}")
            multi_gpu = False
            model = model.to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

    # Setup learning rate scheduler with linear warmup
    scheduler = None
    warmup_scheduler = None
    main_scheduler = None
    
    # Create warmup scheduler (linear warmup from 0 to lr over warmup_epochs)
    if warmup_epochs > 0:
        warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=0.01, end_factor=1.0, total_iters=warmup_epochs
        )
    
    # Create main scheduler after warmup
    if lr_schedule == 'cosine':
        main_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=max(1, epochs - warmup_epochs), eta_min=end_lr
        )
    elif lr_schedule == 'linear':
        try:
            # Use LinearLR when available (PyTorch >= 1.11)
            main_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=end_lr/lr, total_iters=max(1, epochs - warmup_epochs))
        except Exception:
            # Fallback to LambdaLR for older PyTorch versions
            main_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: max(end_lr/lr, 1.0 - (epoch + 1) / float(max(1, epochs - warmup_epochs))))
    elif lr_schedule == 'step':
        main_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[int((epochs - warmup_epochs)*0.5) + warmup_epochs, int((epochs - warmup_epochs)*0.75) + warmup_epochs], gamma=0.1)
    
    # Combine warmup and main scheduler
    if warmup_epochs > 0 and main_scheduler is not None:
        scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer, 
            schedulers=[warmup_scheduler, main_scheduler],
            milestones=[warmup_epochs]
        )
    elif warmup_epochs > 0:
        scheduler = warmup_scheduler
    elif main_scheduler is not None:
        scheduler = main_scheduler

    # Use the provided data loaders directly
    data_loader = train_loader
    val_data_loader = val_loader
    
    # Default rank reduction schedule if none provided
    if rank_schedule is None:
        # Reduce rank every rank_reduction_frequency epochs, but start after warmup period
        rank_schedule = list(range(warmup_epochs + rank_reduction_frequency, 
                                 epochs, 
                                 rank_reduction_frequency))
    initial_squares = [None] * n_modalities # per modality
    initial_losses = [None] * n_modalities # per modality (for loss-based criteria)
    start_reduction = False
    current_rsquare_per_mod = [None] * n_modalities
    current_loss_per_mod = [None] * n_modalities  # for loss-based criteria
    bottom_reached = False
    space_sims = None
    break_counter = 0
    
    # Train the model
    train_losses = []
    val_losses = []
    r_squares = []
    min_ranks = [layer.active_dims for layer in model.adaptive_layers]
    best_loss = float('inf')
    
    # Initialize loss scaling factors for dynamic loss balancing
    loss_scales = torch.ones(n_modalities, device=device)
    #loss_scales[1] = 0.1
    initial_losses = torch.zeros(n_modalities, device=device)
    loss_history = {f'mod_{i}_loss': [] for i in range(n_modalities)}
    
    # Initialize loss balancer for reconstruction losses
    if recon_loss_balancing:
        modality_loss_emas = [None] * n_modalities
        ema_decay = 0.9

    patience_counter = 0
    pbar = tqdm.tqdm(range(model.epoch, epochs))
    for epoch in pbar:
        # Training phase
        model.train()
        train_loss = 0.0
        val_loss = 0.0
        total_ortho_loss = 0.0
        per_modality_losses = [0.0] * n_modalities
        
        for batch_idx, batch in enumerate(data_loader):
            # Extract from dictionary batch
            images = batch['image'].to(device, non_blocking=True)
            input_ids = batch['input_ids'].to(device, non_blocking=True)
            attention_mask = batch['attention_mask'].to(device, non_blocking=True)
            labels = batch['label']
            
            loss = torch.tensor(0.0, device=device)
            total_loss = torch.tensor(0.0, device=device)
            
            # Forward pass
            x_hat, h_list = model(images, input_ids, attention_mask)
            
            ortho_loss = torch.tensor(0.0, device=device)
            total_ortho_loss += ortho_loss.item()

            # Calculate separate losses for each modality
            modality_losses = []
            
            # Modality 0: Image reconstruction (MSE loss)
            img_loss = F.mse_loss(x_hat[0], images, reduction='mean')
            modality_losses.append(img_loss)
            
            # Modality 1: Text reconstruction (Cross-Entropy loss)
            text_logits = x_hat[1]
            text_loss = F.cross_entropy(
                text_logits.reshape(-1, text_logits.size(-1)),
                input_ids.reshape(-1),
                ignore_index=0,
                reduction='mean'
            )
            modality_losses.append(text_loss)
            
            # Check for NaNs in individual modality losses BEFORE summing
            """
            has_nan = False
            if torch.isnan(img_loss).any():
                if verbose:
                    print(f"Warning: NaN loss detected for image modality")
                has_nan = True
            if torch.isnan(text_loss).any():
                if verbose:
                    print(f"Warning: NaN loss detected for text modality")
                has_nan = True
            
            # Skip this batch if any modality has NaN
            if has_nan:
                # stop the training loop
                print("NaN detected in losses, stopping training.")
                break
            """
            
            # Safe to accumulate per-modality losses now
            per_modality_losses[0] += img_loss.item()
            per_modality_losses[1] += text_loss.item()
            
            # Apply reconstruction loss balancing if enabled
            if recon_loss_balancing:
                # Update exponential moving averages for each modality
                for i, m_loss in enumerate(modality_losses):
                    if modality_loss_emas[i] is None:
                        modality_loss_emas[i] = m_loss.item()
                    else:
                        modality_loss_emas[i] = ema_decay * modality_loss_emas[i] + (1 - ema_decay) * m_loss.item()
                
                # Calculate balanced loss using the minimum EMA as reference
                min_ema = min(ema for ema in modality_loss_emas if ema is not None and ema > 0)
                for i, m_loss in enumerate(modality_losses):
                    if modality_loss_emas[i] > 0:
                        balance_scale = min_ema / modality_loss_emas[i]
                        loss += balance_scale * m_loss
                    else:
                        loss += m_loss
            else:
                # Standard loss computation without balancing
                for i, m_loss in enumerate(modality_losses):
                    loss += loss_scales[i] * m_loss
            
            total_loss += loss

            # Backward pass and optimize (normal training)
            optimizer.zero_grad()
            total_loss.backward()
            # Gradient clipping to improve stability
            try:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            except Exception:
                if hasattr(model, 'module'):
                    torch.nn.utils.clip_grad_norm_(model.module.parameters(), max_norm=1.0)
                else:
                    raise
            optimizer.step()
            train_loss += loss.item()

        #if has_nan:
        #    break
        
        # Average losses
        train_loss /= len(data_loader)
        if start_reduction and include_ortholoss:
            total_ortho_loss /= len(data_loader)
        per_modality_losses = [loss / len(data_loader) for loss in per_modality_losses]
        train_losses.append(train_loss)
        
        # Store per-modality losses in history
        for i, loss in enumerate(per_modality_losses):
            loss_history[f'mod_{i}_loss'].append(loss)
        
        # Validation phase with similar safeguards
        with torch.no_grad():
            for batch_val in val_data_loader:
                # Extract from dictionary batch
                images_val = batch_val['image'].to(device, non_blocking=True)
                input_ids_val = batch_val['input_ids'].to(device, non_blocking=True)
                attention_mask_val = batch_val['attention_mask'].to(device, non_blocking=True)
                
                # Forward pass
                x_val_hat, _ = model(images_val, input_ids_val, attention_mask_val)
                
                # Calculate validation loss
                val_batch_loss = 0.0
                
                # Image loss
                img_loss = F.mse_loss(x_val_hat[0], images_val, reduction='mean')
                if not torch.isnan(img_loss):
                    val_batch_loss += img_loss.item()
                
                # Text loss
                text_logits = x_val_hat[1]
                text_loss = F.cross_entropy(
                    text_logits.reshape(-1, text_logits.size(-1)),
                    input_ids_val.reshape(-1),
                    ignore_index=0,
                    reduction='mean'
                )
                if not torch.isnan(text_loss):
                    val_batch_loss += text_loss.item()
                
                val_loss += val_batch_loss / 2.0  # Average over 2 modalities
                
        val_loss /= len(val_data_loader)
        val_losses.append(val_loss)

        # Step the scheduler once per epoch (if configured)
        if scheduler is not None:
            #try:
            scheduler.step()
            #except Exception:
            #    # scheduler.step() may expect different inputs for some schedulers; ignore failures to remain conservative
            #    pass

        # Unfreeze encoders/decoders after warmup when using unimodal initialization
        """
        if from_unimodal and warmup_epochs > 0 and epoch == warmup_epochs - 1:
            print("\n" + "="*60)
            print(f"Warmup phase complete at epoch {epoch+1}")
            print("Unfreezing image and text branches for full training...")
            target_model = model.module if isinstance(model, nn.DataParallel) else model
            for param in target_model.img_branch.parameters():
                param.requires_grad = True
            for param in target_model.text_branch.parameters():
                param.requires_grad = True
            print("✓ All model parameters are now trainable")
            print("="*60 + "\n")
        """

        # Get current learning rate
        current_lr = optimizer.param_groups[0]['lr']
        
        log_dict = {
            'loss': round(train_loss, 4),
            'lr': f'{current_lr:.2e}',
            'mod_losses': [round(l, 3) for l in per_modality_losses],
            'ranks': [layer.active_dims for layer in model.adaptive_layers] if hasattr(model, 'adaptive_layers') 
                    else (model.module.adaptive_layers if multi_gpu else []),
            'current_rsquare': [round(current_rsquare_per_mod[i], 3) if current_rsquare_per_mod[i] is not None else 'N/A' for i in range(n_modalities)],
            'patience': patience_counter,
        }
        if recon_loss_balancing and all(ema is not None for ema in modality_loss_emas):
            # Show the balance scales for reconstruction losses
            min_ema = min(ema for ema in modality_loss_emas if ema > 0)
            balance_scales = [round(min_ema / ema, 3) if ema > 0 else 1.0 for ema in modality_loss_emas]
            log_dict.update({'balance_scales': balance_scales})
        pbar.set_postfix(log_dict)
        
        # Update best loss
        if train_loss < best_loss:
            best_loss = train_loss
            if reduce_on_best_loss in ['true', 'stagnation']:
                patience_counter = 0  # Reset patience counter
        else:
            if reduce_on_best_loss in ['true', 'stagnation']:
                patience_counter += 1

        #if (start_reduction is False) and (epoch == model.epoch + patience): # giving the optimizer a start
        if (start_reduction is False) and (epoch == model.epoch + warmup_epochs): # giving the optimizer a start
            rank_history = {
                'total_rank':[model.get_total_rank() if hasattr(model, 'get_total_rank') else (model.module.get_total_rank() if multi_gpu else 0)],
                'ranks':[', '.join(str(layer.active_dims) for layer in model.adaptive_layers)],
                'epoch':[model.epoch],
                'loss':[train_losses[-1]],
                'val_loss':[val_losses[-1]]
            }
            
            #break
            start_reduction = True  # Start rank reduction after early stopping
            break_counter = 0 # start with no breaks (only used when increasing layers)

            #with torch.no_grad():
            #    val_data_tensors = [val_data.data[0].to(device), torch.mean(val_data.data[1], dim=1).to(device)]
            #    #encoded_per_modality = model.encode_modalities([val_data.data[i].to(device) for i in range(len(data))])
            #    encoded_per_modality = model.encode_modalities(val_data_tensors)
            #    #encoded_per_space_shared, encoded_per_space_specific = model.encode([val_data.data[i].to(device) for i in range(len(data))])
            #    encoded_per_space_shared, encoded_per_space_specific = model.encode(val_data_tensors)
            #    encoded_per_space = [encoded_per_space_shared] + list(encoded_per_space_specific)
            min_rsquares = []
            #if mask is not None:
            #    start_idx = 0
            #    for j, x_m in enumerate(x_val):
            #        end_idx = start_idx + x_m.shape[1]
            #        temp_mask = mask[:, start_idx]
            #        # expand it to match the encoded shape
            #        temp_mask = temp_mask.unsqueeze(1).expand(-1, encoded_per_modality[j].shape[1])
            #        modality_masks_latent.append(temp_mask)
            #        modality_masks_data.append(mask[:, start_idx:end_idx])
            #        modality_masks_space.append(mask[:, start_idx])
            #        start_idx = end_idx
            
            # Direct reconstruction R² approach (use validation DataLoader)
            if verbose:
                print("   Computing R² from validation data...")
            
            if decision_metric == 'ExVarScore':
                direct_r_squared_values = compute_direct_explained_variance(model, train_loader, device, multi_gpu, verbose=verbose)
            else:  # Default to R2
                direct_r_squared_values = compute_direct_r_squared(model, train_loader, device, multi_gpu, verbose=verbose)
            
            for i, r_squared_val in enumerate(direct_r_squared_values):
                initial_squares[i] = r_squared_val
                
                # Calculate threshold based on threshold_type
                if threshold_type == 'relative':
                    min_rsquares.append(r_squared_val * r_square_threshold)
                elif threshold_type == 'absolute':
                    min_rsquares.append(r_squared_val - r_square_threshold)
                else:
                    raise ValueError(f"threshold_type must be 'relative' or 'absolute', got {threshold_type}")
                    
                current_rsquare_per_mod[i] = r_squared_val
                rank_history[f'rsquare {i}'] = [r_squared_val]
            max_rsquares = initial_squares.copy()
                
            if verbose:
                #print(f"Initial R-squared values: {[rank_history[f'rsquare {i}'] for i in range(len(encoded_per_modality))]}, setting {threshold_type} thresholds to {min_rsquares}")
                print(f"Initial R-squared values: {[rank_history[f'rsquare {i}'] for i in range(len(current_rsquare_per_mod))]}, setting {threshold_type} thresholds to {min_rsquares}")
            #print(f"Initial R-squared values: {[rank_history[f'rsquare {i}'] for i in range(len(encoded_per_modality))]}, setting {threshold_type} thresholds to {min_rsquares}")
            
        # Apply rank reduction at scheduled epochs, respecting warmup period
        if (epoch in rank_schedule) & (start_reduction) & (break_counter == 0):
            if (reduce_on_best_loss == 'rsquare') & (start_reduction):
                ###
                # get the r_square values per modality
                ###
                #with torch.no_grad():
                #    val_data_tensors = [val_data.data[0].to(device), torch.mean(val_data.data[1], dim=1).to(device)]
                #    #encoded_per_modality = model.encode_modalities([d[n_samples_train:].to(device) for d in data])
                #    #encoded_per_modality = model.encode_modalities([val_data.data[i].to(device) for i in range(len(data))])
                #    encoded_per_modality = model.encode_modalities(val_data_tensors)
                #    if not compute_jacobian:
                #        #encoded_per_space_shared, encoded_per_space_specific = model.encode([val_data.data[i].to(device) for i in range(len(data))])
                #        encoded_per_space_shared, encoded_per_space_specific = model.encode(val_data_tensors)
                #        encoded_per_space = [encoded_per_space_shared] + list(encoded_per_space_specific)
                #    else:
                #        encoded_per_space, contractive_losses = model.encode([val_data.data[i].to(device) for i in range(len(data))], compute_jacobian=compute_jacobian)
                current_rsquares = []
                modalities_to_reduce = []
                modalities_to_increase = []
                #if mask is not None:
                #    start_idx = 0
                #    for j, x_m in enumerate(x_val):
                #        end_idx = start_idx + x_m.shape[1]
                #        #modality_masks.append(mask[:, start_idx:end_idx])
                #        # expand it to match the encoded shape
                #        temp_mask = mask[:, start_idx]
                #        temp_mask = temp_mask.unsqueeze(1).expand(-1, encoded_per_modality[j].shape[1])
                #        modality_masks_latent.append(temp_mask)
                #        modality_masks_data.append(mask[:, start_idx:end_idx])
                #        modality_masks_space.append(mask[:, start_idx])
                #        start_idx = end_idx
                
                # Direct reconstruction R² approach (use validation DataLoader)
                if decision_metric == 'ExVarScore':
                    direct_r_squared_values = compute_direct_explained_variance(model, train_loader, device, multi_gpu)
                else:  # Default to R2
                    direct_r_squared_values = compute_direct_r_squared(model, train_loader, device, multi_gpu)
                
                for i, r_squared_val in enumerate(direct_r_squared_values):
                    current_rsquares.append(r_squared_val)
                    current_rsquare_per_mod[i] = r_squared_val
                        
                r_squares.append(current_rsquares)
                
                # Update min_rsquares based on maximum observed (allows threshold to follow best performance)
                #max_rquares = [max(r_squares, key=lambda x: x[i])[i] for i in range(len(current_rsquare_per_mod))] if len(r_squares) > 0 else initial_squares
                update_max = False
                for i, r in enumerate(current_rsquares):
                    if r > max_rsquares[i]:
                        max_rsquares[i] = r
                        update_max = True
                if update_max:
                    if threshold_type == 'relative':
                        min_rsquares = [r * r_square_threshold for r in max_rsquares]
                    elif threshold_type == 'absolute':
                        # For R²: subtract threshold; for loss: add threshold
                        if compressibility_type == 'direct' and reduction_criterion in ['train_loss', 'val_loss']:
                            min_rsquares = [r + r_square_threshold for r in max_rsquares]
                        else:
                            min_rsquares = [r - r_square_threshold for r in max_rsquares]
                    if verbose:
                        print(f"Updated threshold {min_rsquares} based on new max {max_rsquares}")

                ###
                # determine what modalities to reduce or increase
                ###
                if (len(r_squares) >= min(10, int(patience/2))) and patience_counter >= min(10, int(patience/2)):
                    for i in range(len(current_rsquare_per_mod)):
                        i_rsquares = [r[i] for r in r_squares[-min(10, int(patience/2)):]]
                        
                        # Handle different comparison logic for loss vs R²
                        if compressibility_type == 'direct' and reduction_criterion in ['train_loss', 'val_loss']:
                            # For loss: lower is better, so we reduce if loss is high (above threshold)
                            if all(r > min_rsquares[i] for r in i_rsquares) and not bottom_reached:
                                modalities_to_increase.append(i)
                            elif current_rsquare_per_mod[i] < min_rsquares[i]:  # loss below threshold
                                modalities_to_reduce.append(i)
                        else:
                            # For R²: higher is better (original logic)
                            if all(r < min_rsquares[i] for r in i_rsquares) and not bottom_reached:
                                modalities_to_increase.append(i)
                            elif current_rsquare_per_mod[i] > min_rsquares[i]:
                                modalities_to_reduce.append(i)
                elif (len(r_squares) >= 1):# and (patience_counter >= 1):
                    for i in range(len(current_rsquare_per_mod)):
                        if compressibility_type == 'direct' and reduction_criterion in ['train_loss', 'val_loss']:
                            # For loss: reduce if loss is below threshold (good performance)
                            if current_rsquare_per_mod[i] < min_rsquares[i]:
                                modalities_to_reduce.append(i)
                            elif current_rsquare_per_mod[i] > min_rsquares[i] and not bottom_reached:
                                modalities_to_increase.append(i)
                        else:
                            # For R²: reduce if R² is above threshold (original logic)
                            if current_rsquare_per_mod[i] > min_rsquares[i]:
                                modalities_to_reduce.append(i)
                            elif current_rsquare_per_mod[i] < min_rsquares[i] and not bottom_reached:
                                modalities_to_increase.append(i)

                # if all modalities can be reduced, we set min and max ranks
                if len(modalities_to_reduce) == len(current_rsquare_per_mod):
                    current_ranks = [layer.active_dims for layer in model.adaptive_layers]
                    for i, cr in enumerate(current_ranks):
                        if cr <= min_ranks[i]:
                            min_ranks[i] = cr
                            # if we are still above the thresholds, the ranks maximum should not be larger than the sum of current ranks
                        #model.adaptive_layers[i].max_rank = min(sum(current_ranks), model.adaptive_layers[i].max_rank)
                        model.adaptive_layers[i].max_rank = min(sum(current_ranks), max(int(1.5*current_ranks[i]), current_ranks[i]+1), model.adaptive_layers[i].max_rank)
                    print(f"Adjusting maximum ranks to {[layer.max_rank for layer in model.adaptive_layers]}")
                if len(modalities_to_increase) == len(current_rsquare_per_mod):
                    # set minima
                    current_ranks = [layer.active_dims for layer in model.adaptive_layers]
                    for i, cr in enumerate(current_ranks):
                        if cr <= min_ranks[i]:
                            min_ranks[i] = cr
                        # if we are increasing all ranks, we can also increase the maximum ranks
                        model.adaptive_layers[i].min_rank = min_ranks[i]
                    #bottom_reached = True
                    print(f"Adjusting minimum ranks to {[layer.min_rank for layer in model.adaptive_layers]}")
                ###
                # set the patience counters and layers to reduce or increase
                ###
                layers_to_reduce = []
                layers_to_increase = []
                if (len(modalities_to_reduce) == 0) and (len(modalities_to_increase) == 0):
                    #patience_counter += 1
                    pass
                elif (len(modalities_to_reduce) > 0) and (len(modalities_to_increase) > 0):
                    # no increasing yet, but no decreasing the shared either
                    layers_to_reduce = [i + 1 for i in modalities_to_reduce]
                    layers_to_increase = [0] + [i + 1 for i in modalities_to_increase]
                    # set the min for the modality to be increased to current rank
                    model.adaptive_layers[0].min_rank = model.adaptive_layers[0].active_dims + 1
                    for i in modalities_to_increase:
                        model.adaptive_layers[i + 1].min_rank = model.adaptive_layers[i + 1].active_dims + 1
                    print(f"Adjusting minimum ranks to {[layer.min_rank for layer in model.adaptive_layers]}")
                else:
                    if len(modalities_to_increase) > 0:
                        if len(modalities_to_increase) == len(current_rsquare_per_mod):
                            # if all modalities are below the threshold, increase ranks of all layers
                            layers_to_increase = [i for i in range(len(model.adaptive_layers))]
                        else:
                            layers_to_increase = [0] + [i + 1 for i in modalities_to_increase]
                            for i in modalities_to_increase:
                                model.adaptive_layers[i + 1].min_rank = model.adaptive_layers[i + 1].active_dims + 1
                            model.adaptive_layers[0].min_rank = model.adaptive_layers[0].active_dims + 1
                            print(f"Adjusting minimum ranks to {[layer.min_rank for layer in model.adaptive_layers]}")
                    if len(modalities_to_reduce) > 0:
                        # if all modalities are below the threshold, reduce ranks of all layers
                        if len(modalities_to_reduce) == len(initial_squares):
                            reduce_shared = True
                            if reduce_shared:
                                layers_to_reduce = [0] + [i + 1 for i in modalities_to_reduce]
                            else:
                                layers_to_reduce = [i + 1 for i in modalities_to_reduce]
                        else:
                            # roll a dice whether we also try to reduce the shared layer
                            #if sharedwhenall:
                            #    reduce_shared = False
                            #else:
                            #    reduce_shared = True
                            #if reduce_shared:
                            #    layers_to_reduce = [0] + [i + 1 for i in modalities_to_reduce]
                            #else:
                                layers_to_reduce = [i + 1 for i in modalities_to_reduce]
                if verbose:
                    if compressibility_type == 'direct' and reduction_criterion in ['train_loss', 'val_loss']:
                        print(f"{reduction_criterion} values: {current_rsquares}, reducing rank for modalities {modalities_to_reduce}, layers {layers_to_reduce}, increasing rank for modalities {modalities_to_increase}, layers {layers_to_increase}")
                    else:
                        print(f"R-squared values: {current_rsquares}, reducing rank for modalities {modalities_to_reduce}, layers {layers_to_reduce}, increasing rank for modalities {modalities_to_increase}, layers {layers_to_increase}")
                if compute_jacobian:
                    valid_contractive_losses = [contractive_losses[i] for i in layers_to_reduce]
                    max_contractive_loss = max(valid_contractive_losses) if len(valid_contractive_losses) > 0 else None
                    #max_contractive_loss = min(valid_contractive_losses) if len(valid_contractive_losses) > 0 else None
                    if max_contractive_loss is not None:
                        layers_to_reduce = [i for i in layers_to_reduce if contractive_losses[i] == max_contractive_loss]
                    else:
                        layers_to_reduce = []
                    if verbose:
                        print(f"Contractive losses: {contractive_losses}, reducing rank for layers {layers_to_reduce} with max loss {max_contractive_loss}")
                
            
            #if should_reduce:
            any_changes_made = False
            #changes_made = False
            if len(layers_to_reduce) > 0:
                # Apply rank reduction
                if multi_gpu:
                    changes_made = model.module.reduce_rank(reduction_ratio=0.9, threshold=rank_reduction_threshold, layer_ids=layers_to_reduce)
                else:
                    changes_made = model.reduce_rank(reduction_ratio=0.9, threshold=rank_reduction_threshold, layer_ids=layers_to_reduce)
                if changes_made:
                    any_changes_made = True
            if len(layers_to_increase) > 0:
                # Apply rank increase
                #print(f"Increasing rank for layer {layers_to_increase}")
                if multi_gpu:
                    changes_made = model.module.increase_rank(increase_ratio=1.1, layer_ids=layers_to_increase)
                else:
                    changes_made = model.increase_rank(increase_ratio=1.1, layer_ids=layers_to_increase)
                if changes_made:
                    any_changes_made = True
                    break_counter = patience # give model more time to re-learn the added dimensions
            #else:
            #    changes_made = False
            
            if any_changes_made:
                patience_counter = 0  # Reset patience counter if rank was changed
            else:
                patience_counter += 1

            # Get new rank but don't print separate message
            total_rank_after = model.module.get_total_rank() if multi_gpu else model.get_total_rank()
                
            # Store current rank in history
            rank_history['total_rank'].append(total_rank_after)
            rank_history['ranks'].append(', '.join(str(layer.active_dims) for layer in model.adaptive_layers))
            rank_history['epoch'].append(epoch)
            #for i in range(len(encoded_per_modality)):
            for i in range(len(current_rsquare_per_mod)):
                if reduce_on_best_loss == 'rsquare':
                    rank_history[f'rsquare {i}'].append(current_rsquares[i])
            rank_history['loss'].append(train_loss)
            rank_history['val_loss'].append(val_loss)

            # also get mutual information between all spaces
            #valid_spaces = []
            #for encoded in encoded_per_space:
            #    valid_spaces_temp = []
            #    for i in range(len(encoded_per_modality)):
            #        #normalized_encoded = (encoded - encoded.min() + 1e-9) / (encoded.max() - encoded.min() + 1e-9)
            #        if mask is not None:
            #            temp_mask = modality_masks_space[i]
            #            #valid_spaces_temp.append(normalized_encoded[mask])
            #            valid_spaces_temp.append(encoded[temp_mask])
            #        else:
            #            #valid_spaces_temp.append(normalized_encoded)
            #            valid_spaces_temp.append(encoded)
            #    # stack the valid spaces
            #    valid_spaces_temp = torch.vstack(valid_spaces_temp)
            #    valid_spaces.append(valid_spaces_temp)
        else:
            if (epoch in rank_schedule) & (start_reduction) & (break_counter > 0):
                break_counter -= 1
        
        # early stopping but conditioned on rank reduction
        #if (epoch > early_stopping) & (min(val_losses[-early_stopping:]) > min(val_losses)) & (start_reduction is True) & (patience_counter >= patience):
        if (epoch > early_stopping) & (start_reduction is True) & (patience_counter >= patience):
            if verbose:
                print(f"Early stopping at epoch {epoch} with best loss {best_loss} and ranks {rank_history['ranks'][-1]}")
            break
        
        # Periodic checkpoint saving - save to the main model path so it gets updated
        effective_save_freq = save_frequency if save_frequency is not None else 10
        if model_name is not None and (epoch + 1) % effective_save_freq == 0:
            checkpoint_path = f"./03_results/models/{model_name}.pt"
            os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
            if multi_gpu:
                torch.save(model.module.state_dict(), checkpoint_path)
            else:
                torch.save(model.state_dict(), checkpoint_path)
            print(f"Saved checkpoint at epoch {epoch+1} to {checkpoint_path}")

            # Also save intermediate reconstruction plots (train & val) every save_frequency epochs
            try:
                n_plot = 6
                # Get a small batch from train and val loaders
                train_batch = next(iter(train_loader))
                val_batch = next(iter(val_loader))

                # Support dict-style batches (image-text) or tuple/list batches
                if isinstance(train_batch, dict):
                    train_images = train_batch['image'][:n_plot].to(device)
                    train_input_ids = train_batch['input_ids'][:n_plot].to(device)
                    train_attention_mask = train_batch['attention_mask'][:n_plot].to(device) if 'attention_mask' in train_batch else (train_input_ids != 0).to(device)
                else:
                    train_images = train_batch[0][:n_plot].to(device)
                    train_input_ids = train_batch[1][:n_plot].to(device)
                    train_attention_mask = (train_input_ids != 0).to(device)

                if isinstance(val_batch, dict):
                    val_images = val_batch['image'][:n_plot].to(device)
                    val_input_ids = val_batch['input_ids'][:n_plot].to(device)
                    val_attention_mask = val_batch['attention_mask'][:n_plot].to(device) if 'attention_mask' in val_batch else (val_input_ids != 0).to(device)
                else:
                    val_images = val_batch[0][:n_plot].to(device)
                    val_input_ids = val_batch[1][:n_plot].to(device)
                    val_attention_mask = (val_input_ids != 0).to(device)

                # Run model to get reconstructions
                if multi_gpu and hasattr(model, 'module'):
                    train_recon, _ = model.module(train_images, train_input_ids, train_attention_mask)
                    val_recon, _ = model.module(val_images, val_input_ids, val_attention_mask)
                else:
                    train_recon, _ = model(train_images, train_input_ids, train_attention_mask)
                    val_recon, _ = model(val_images, val_input_ids, val_attention_mask)

                train_img_recon, train_text_logits = train_recon
                val_img_recon, val_text_logits = val_recon

                # Save image reconstructions to ./03_results/train_plots/
                img_path = checkpoint_path.replace('./03_results/models/', './03_results/train_plots/').replace('.pt', f'_image_recon_epoch{epoch+1}.png')
                os.makedirs(os.path.dirname(img_path), exist_ok=True)
                fig, axes = plt.subplots(4, n_plot, figsize=(n_plot * 1.5, 6))
                for i in range(n_plot):
                    # Train original
                    img = train_images[i].permute(1, 2, 0).detach().cpu().numpy()
                    img = (img - img.min()) / (img.max() - img.min() + 1e-8)
                    axes[0, i].imshow(img)
                    axes[0, i].axis('off')
                    if i == 0:
                        axes[0, i].set_ylabel('Train', fontsize=10)
                    # Train reconstruction
                    img = train_img_recon[i].permute(1, 2, 0).detach().cpu().numpy()
                    img = (img - img.min()) / (img.max() - img.min() + 1e-8)
                    axes[1, i].imshow(img)
                    axes[1, i].axis('off')
                    if i == 0:
                        axes[1, i].set_ylabel('Train Recon', fontsize=10)
                    # Val original
                    img = val_images[i].permute(1, 2, 0).detach().cpu().numpy()
                    img = (img - img.min()) / (img.max() - img.min() + 1e-8)
                    axes[2, i].imshow(img)
                    axes[2, i].axis('off')
                    if i == 0:
                        axes[2, i].set_ylabel('Val', fontsize=10)
                    # Val reconstruction
                    img = val_img_recon[i].permute(1, 2, 0).detach().cpu().numpy()
                    img = (img - img.min()) / (img.max() - img.min() + 1e-8)
                    axes[3, i].imshow(img)
                    axes[3, i].axis('off')
                    if i == 0:
                        axes[3, i].set_ylabel('Val Recon', fontsize=10)
                plt.tight_layout()
                plt.savefig(img_path, dpi=150, bbox_inches='tight')
                plt.close()
                print(f"Saved intermediate image reconstructions to {img_path}")

                # Save text reconstructions
                try:
                    from transformers import BertTokenizer
                    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
                    train_pred_ids = torch.argmax(train_text_logits, dim=-1)
                    val_pred_ids = torch.argmax(val_text_logits, dim=-1)
                    txt_path = checkpoint_path.replace('./03_results/models/', './03_results/train_plots/').replace('.pt', f'_text_recon_epoch{epoch+1}.txt')
                    os.makedirs(os.path.dirname(txt_path), exist_ok=True)
                    with open(txt_path, 'w', encoding='utf-8') as f:
                        f.write('=== TRAIN SAMPLES ===\n\n')
                        for i in range(n_plot):
                            orig = tokenizer.decode(train_input_ids[i], skip_special_tokens=True)
                            recon = tokenizer.decode(train_pred_ids[i], skip_special_tokens=True)
                            f.write(f"Sample {i+1}:\n  Original: {orig}\n  Reconstructed: {recon}\n\n")
                        f.write('\n=== VAL SAMPLES ===\n\n')
                        for i in range(n_plot):
                            orig = tokenizer.decode(val_input_ids[i], skip_special_tokens=True)
                            recon = tokenizer.decode(val_pred_ids[i], skip_special_tokens=True)
                            f.write(f"Sample {i+1}:\n  Original: {orig}\n  Reconstructed: {recon}\n\n")
                    print(f"Saved intermediate text reconstructions to {txt_path}")
                except Exception as e:
                    print(f"Warning: Could not save intermediate text reconstructions: {e}")
            except Exception as e:
                print(f"Warning: Could not save intermediate reconstruction plots at epoch {epoch+1}: {e}")
    
    # Calculate latent representations in batches (only for training data)
    # Build representations with proper index mapping to handle both full datasets and subsets
    n_samples = len(train_loader.dataset)
    final_ranks = [layer.active_dims for layer in model.adaptive_layers]
    
    # Collect all representations and indices first
    all_reps = [[] for _ in range(len(final_ranks))]
    all_indices = []
    
    model.eval()
    with torch.no_grad():
        for batch in train_loader:
            # Extract data from batch dictionary
            images = batch['image'].to(device, non_blocking=True)
            input_ids = batch['input_ids'].to(device, non_blocking=True)
            attention_mask = batch['attention_mask'].to(device, non_blocking=True)
            indices = batch['idx']  # Original indices from dataset
            
            # If using DataParallel, need to access module directly
            if multi_gpu:
                batch_reps = model.module.encode(images, input_ids, attention_mask)
            else:
                batch_reps = model.encode(images, input_ids, attention_mask)
            batch_rep_list = [batch_reps[0].detach().cpu()] + [batch_reps[1][j].detach().cpu() for j in range(len(batch_reps[1]))]
            
            # Store representations and indices
            for j in range(len(final_ranks)):
                all_reps[j].append(batch_rep_list[j][:, :final_ranks[j]])
            all_indices.append(indices)
            
            # Free memory
            del images, input_ids, attention_mask, batch_reps
            torch.cuda.empty_cache() if torch.cuda.is_available() else None
    
    # Concatenate all batches
    all_indices = torch.cat(all_indices)
    for j in range(len(final_ranks)):
        all_reps[j] = torch.cat(all_reps[j], dim=0)
    
    # Create index mapping: original_idx -> position in subset
    # This handles both full datasets (identity mapping) and random_split subsets
    unique_indices = torch.unique(all_indices, sorted=True)
    idx_to_pos = {int(idx): pos for pos, idx in enumerate(unique_indices)}
    
    # Create final reps array and fill it in the correct order
    reps = [torch.empty((len(unique_indices), final_ranks[i])) for i in range(len(final_ranks))]
    for i, orig_idx in enumerate(all_indices):
        pos = idx_to_pos[int(orig_idx)]
        for j in range(len(reps)):
            reps[j][pos, :] = all_reps[j][i, :]
    reps = [rep for rep in reps]  # Move to CPU after full collection
    
    # Convert indices to numpy for easier saving/loading
    sorted_indices = unique_indices.cpu().numpy()
    print(f"Computed final latent representations for training data (n={len(sorted_indices)}).")
    
    # save the model
    if model_name:
        model_path = f"./03_results/models/{model_name}.pt"
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        if multi_gpu:
            torch.save(model.module.state_dict(), model_path)
        else:
            torch.save(model.state_dict(), model_path)
        if verbose:
            print(f"Saved trained model to {model_path}")
        
        # Save final reconstruction plots for both modalities
        try:
            model.eval()
            with torch.no_grad():
                n_plot = 8
                
                # Get train samples
                train_batch = next(iter(train_loader))
                train_images = train_batch['image'][:n_plot].to(device)
                train_input_ids = train_batch['input_ids'][:n_plot].to(device)
                train_attention_mask = train_batch['attention_mask'][:n_plot].to(device)
                
                # Get val samples
                val_batch = next(iter(val_loader))
                val_images = val_batch['image'][:n_plot].to(device)
                val_input_ids = val_batch['input_ids'][:n_plot].to(device)
                val_attention_mask = val_batch['attention_mask'][:n_plot].to(device)
                
                # Get reconstructions
                if multi_gpu and hasattr(model, 'module'):
                    train_recon, _ = model.module(train_images, train_input_ids, train_attention_mask)
                    val_recon, _ = model.module(val_images, val_input_ids, val_attention_mask)
                else:
                    train_recon, _ = model(train_images, train_input_ids, train_attention_mask)
                    val_recon, _ = model(val_images, val_input_ids, val_attention_mask)
                
                train_img_recon, train_text_logits = train_recon
                val_img_recon, val_text_logits = val_recon
                
                # Save image reconstructions
                img_path = model_path.replace('./03_results/models/', './03_results/train_plots/').replace('.pt', '_image_recon_final.png')
                os.makedirs(os.path.dirname(img_path), exist_ok=True)
                fig, axes = plt.subplots(4, n_plot, figsize=(n_plot * 1.5, 6))
                for i in range(n_plot):
                    # Train original
                    img = train_images[i].permute(1, 2, 0).cpu().numpy()
                    img = (img - img.min()) / (img.max() - img.min() + 1e-8)
                    axes[0, i].imshow(img)
                    axes[0, i].axis('off')
                    if i == 0:
                        axes[0, i].set_ylabel('Train', fontsize=10)
                    # Train reconstruction
                    img = train_img_recon[i].permute(1, 2, 0).cpu().numpy()
                    img = (img - img.min()) / (img.max() - img.min() + 1e-8)
                    axes[1, i].imshow(img)
                    axes[1, i].axis('off')
                    if i == 0:
                        axes[1, i].set_ylabel('Train Recon', fontsize=10)
                    # Val original
                    img = val_images[i].permute(1, 2, 0).cpu().numpy()
                    img = (img - img.min()) / (img.max() - img.min() + 1e-8)
                    axes[2, i].imshow(img)
                    axes[2, i].axis('off')
                    if i == 0:
                        axes[2, i].set_ylabel('Val', fontsize=10)
                    # Val reconstruction
                    img = val_img_recon[i].permute(1, 2, 0).cpu().numpy()
                    img = (img - img.min()) / (img.max() - img.min() + 1e-8)
                    axes[3, i].imshow(img)
                    axes[3, i].axis('off')
                    if i == 0:
                        axes[3, i].set_ylabel('Val Recon', fontsize=10)
                plt.tight_layout()
                plt.savefig(img_path, dpi=150, bbox_inches='tight')
                plt.close()
                print(f"Saved final image reconstructions to {img_path}")
                
                # Save text reconstructions
                from transformers import BertTokenizer
                tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
                
                train_pred_ids = torch.argmax(train_text_logits, dim=-1)
                val_pred_ids = torch.argmax(val_text_logits, dim=-1)
                
                txt_path = model_path.replace('./03_results/models/', './03_results/train_plots/').replace('.pt', '_text_recon_final.txt')
                os.makedirs(os.path.dirname(txt_path), exist_ok=True)
                with open(txt_path, 'w', encoding='utf-8') as f:
                    f.write("=== TRAIN SAMPLES ===\n\n")
                    for i in range(n_plot):
                        orig = tokenizer.decode(train_input_ids[i], skip_special_tokens=True)
                        recon = tokenizer.decode(train_pred_ids[i], skip_special_tokens=True)
                        f.write(f"Sample {i+1}:\n")
                        f.write(f"  Original: {orig}\n")
                        f.write(f"  Reconstructed: {recon}\n\n")
                    
                    f.write("\n=== VAL SAMPLES ===\n\n")
                    for i in range(n_plot):
                        orig = tokenizer.decode(val_input_ids[i], skip_special_tokens=True)
                        recon = tokenizer.decode(val_pred_ids[i], skip_special_tokens=True)
                        f.write(f"Sample {i+1}:\n")
                        f.write(f"  Original: {orig}\n")
                        f.write(f"  Reconstructed: {recon}\n\n")
                print(f"Saved final text reconstructions to {txt_path}")
                
        except Exception as e:
            print(f"Warning: Could not save final reconstruction plots: {e}")
            import traceback
            traceback.print_exc()
    
    try:
        avg_train_loss = np.mean(train_losses[-5:])
    except:
        avg_train_loss = np.mean(train_losses)
    try:
        last_rsquare = r_squares[-1]
    except:
        last_rsquare = [None]

    return model, reps, avg_train_loss, last_rsquare, rank_history, [train_losses, val_losses], sorted_indices

def plot_image_reconstruction(original_imgs, recon_imgs, n=8, out_path=None, image_shape=(28, 28)):
    """Plot original and reconstructed images side-by-side.

    The top row shows originals; the bottom row shows the matching
    reconstructions, rendered with a gray colormap.

    Parameters:
    - original_imgs: array of flattened images, one image per row, i.e.
      shape (N, H*W) — e.g. (N, 784) for 28x28 images. Values are expected
      in [0, 1] (imshow autoscales each panel regardless).
    - recon_imgs: reconstructed images with the same layout as original_imgs.
    - n: maximum number of image pairs to plot; clamped to the number of
      pairs actually available in both arrays.
    - out_path: if truthy, the figure is saved there at dpi=150; otherwise
      the figure is built and closed without being saved.
    - image_shape: (H, W) used to un-flatten each row. Defaults to (28, 28),
      the previously hard-coded shape.

    Raises:
    - ValueError: if there is nothing to plot (empty input or n <= 0).
    """
    H, W = image_shape
    # Never index past the shorter of the two arrays.
    n = max(0, min(n, len(original_imgs), len(recon_imgs)))
    if n == 0:
        raise ValueError("plot_image_reconstruction: no images to plot (empty input or n <= 0)")
    orig = original_imgs[:n]
    recon = recon_imgs[:n]
    # squeeze=False keeps `axes` 2-D even when n == 1, so axes[row, col] is always valid.
    fig, axes = plt.subplots(2, n, figsize=(n * 1.5, 3), squeeze=False)
    try:
        for i in range(n):
            axes[0, i].imshow(orig[i].reshape(H, W), cmap='gray')
            axes[0, i].axis('off')
            axes[1, i].imshow(recon[i].reshape(H, W), cmap='gray')
            axes[1, i].axis('off')
        fig.tight_layout()
        if out_path:
            fig.savefig(out_path, dpi=150)
    finally:
        # Release the figure even if reshaping or saving fails, so repeated
        # calls (e.g. once per epoch) do not accumulate open figures.
        plt.close(fig)

def plot_modal_image_reconstruction(original_imgs, recon_imgs, image_shape=(28,28), n=8, out_path=None):
    """Plot original and reconstructed modality images for arbitrary image shapes.

    The top row shows originals; the bottom row shows the matching
    reconstructions, rendered with a gray colormap.

    Parameters:
    - original_imgs, recon_imgs: arrays in any of these layouts:
        - (N, H*W)      flattened images (reshaped using image_shape)
        - (N, H, W)     2-D images
        - (N, 1, H, W)  single-channel, channel-first
        - (N, H, W, 1)  single-channel, channel-last
    - image_shape: (H, W); only used to un-flatten (N, H*W) input.
    - n: maximum number of image pairs to plot; clamped to the number of
      pairs actually available in both arrays.
    - out_path: if truthy, the figure is saved there at dpi=150; otherwise
      the figure is built and closed without being saved.

    Raises:
    - ValueError: for unsupported array shapes, or if there is nothing to
      plot (empty input or n <= 0).
    """
    H, W = image_shape

    def ensure_hw(arr):
        """Return `arr` as a stack of 2-D images, shape (N, H, W)."""
        # Channel-first check runs first so (N, 1, H, W) input behaves exactly as before.
        if arr.ndim == 4 and arr.shape[1] == 1:
            return arr[:, 0]
        if arr.ndim == 4 and arr.shape[-1] == 1:
            return arr[..., 0]
        if arr.ndim == 3:
            return arr
        if arr.ndim == 2:
            # Assume each row is a flattened H*W image.
            return arr.reshape(-1, H, W)
        raise ValueError('Unsupported array shape for image reconstruction plotting: ' + str(arr.shape))

    # Never index past the shorter of the two arrays.
    n = max(0, min(n, len(original_imgs), len(recon_imgs)))
    if n == 0:
        raise ValueError("plot_modal_image_reconstruction: no images to plot (empty input or n <= 0)")

    # Validate/reshape before creating the figure so bad shapes do not leave a figure open.
    orig_hw = ensure_hw(original_imgs[:n])
    recon_hw = ensure_hw(recon_imgs[:n])

    # squeeze=False keeps `axes` 2-D even when n == 1, so axes[row, col] is always valid.
    fig, axes = plt.subplots(2, n, figsize=(n * 2, 3), squeeze=False)
    try:
        for i in range(n):
            axes[0, i].imshow(orig_hw[i], cmap='gray')
            axes[0, i].axis('off')
            axes[1, i].imshow(recon_hw[i], cmap='gray')
            axes[1, i].axis('off')
        fig.tight_layout()
        if out_path:
            fig.savefig(out_path, dpi=150)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

def train_overcomplete_ae(train_loader, val_loader, model, device, latent_dim=None, args=None, epochs=100, early_stopping=50, 
                         lr=0.001, batch_size=128, ae_depth=2, ae_width=0.5, dropout=0.0, wd=1e-5, 
                         initial_rank_ratio=1.0, min_rank=10, 
                         rank_schedule=None, rank_reduction_frequency=10, 
                         rank_reduction_threshold=0.01, warmup_epochs=0,
                         patience=10, reduce_on_best_loss='rsquare', r_square_threshold=0.9,
                         threshold_type='relative', compressibility_type='linear', reduction_criterion='r_squared',
                         include_l1=False, l1_weight=0.0, include_ortholoss=False,
                         l1_start_weight=0.0, l1_step_size=1.0, rank_or_sparse='rank',
                         verbose=True, compute_jacobian=False, model_name=None, pretrained_name=None,
                         recon_loss_balancing=False, ortho_loss_balancing=False,
                         ortho_loss_start_weight=0.0, ortho_loss_end_weight=1.0, ortho_loss_anneal_epochs=None, ortho_loss_warmup=None,
                         l2_norm_adaptivelayers=None, sharedwhenall=True, lr_schedule='constant',
                         decision_metric='R2', input_shapes=None, end_lr=1e-5, save_frequency=None,
                         from_unimodal=False, unimodal_seed=42, post_warmup_lr=None, tokenizer=None
                         ):
    """
    Train an autoencoder with adaptive rank reduction
    
    Parameters:
    - train_data: Training data list of tensors (one per modality)
    - val_data: Validation data list of tensors (one per modality)
    - latent_dim: Dimension of the latent space
    - epochs: Maximum number of training epochs
    - early_stopping: Number of epochs for early stopping patience
    - lr: Learning rate
    - batch_size: Batch size for training
    - ae_depth: Depth of the autoencoder
    - ae_width: Width multiplier for hidden layers
    - dropout: Dropout rate
    - wd: Weight decay
    - initial_rank_ratio: Initial rank ratio (1.0 = full rank)
    - min_rank_ratio: Minimum rank ratio (lower bound)
    - rank_schedule: Custom schedule for rank reduction (epochs at which to reduce)
    - rank_reduction_frequency: How often to try reducing rank (in epochs)
    - rank_reduction_threshold: Energy threshold for rank reduction
    - warmup_epochs: Number of epochs to train before starting rank reduction
    - reduce_on_best_loss: Only reduce rank when loss is at or better than best loss
    - r_square_threshold: R² threshold for rank reduction decisions
    - threshold_type: 'relative' (multiply by initial R²) or 'absolute' (use threshold directly)
    - compressibility_type: 'linear' (linear probing R²) or 'direct' (reconstruction R²)
    - reduction_criterion: 'r_squared' (use R²), 'train_loss' (use training loss), 'val_loss' (use validation loss)
                          Only relevant when compressibility_type='direct'. Default: 'r_squared'
    - recon_loss_balancing: Whether to apply adaptive loss balancing across modalities (default: False)
    """

    # Ensure args exists
    if args is None:
        from types import SimpleNamespace
        args = SimpleNamespace()
    
    # Extract datasets for accessing raw data when needed
    train_dataset = train_loader.dataset
    val_dataset = val_loader.dataset
    train_data = getattr(train_dataset, 'data', train_dataset)
    val_data = getattr(val_dataset, 'data', val_dataset)
    
    # Get number of modalities (image-text multimodal model always has 2 modalities)
    n_modalities = 2
    
    multi_gpu = getattr(args, 'multi_gpu', False)

    # Model already provided and initialized
    # Initialize epoch counter if not already set (e.g., from checkpoint)
    if not hasattr(model, 'epoch'):
        model.epoch = 0
    
    model.to(device)
    print(f"Model is on device: {next(model.parameters()).device}")
    
    # Handle multi-GPU setup if needed
    if multi_gpu and not isinstance(model, nn.DataParallel):
        try:
            if hasattr(args, 'gpu_ids') and args.gpu_ids:
                if 0 not in [int(id) for id in args.gpu_ids.split(',')]:
                    raise RuntimeError("DataParallel requires cuda:0 which is not available.")
                cuda0_device = torch.device('cuda:0')
                model = model.to(cuda0_device)
                for param in model.parameters():
                    if param.device != cuda0_device:
                        param.data = param.data.to(cuda0_device)
                model = nn.DataParallel(model, device_ids=[int(id) for id in args.gpu_ids.split(',')])
                if verbose:
                    print(f"Using DataParallel across GPUs: {args.gpu_ids}")
        except Exception as e:
            print(f"Failed to use DataParallel: {e}")
            print(f"Falling back to single GPU mode on {device}")
            multi_gpu = False
            model = model.to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

    # Setup learning rate scheduler with warmup then (optionally) constant post-warmup LR
    scheduler = None
    warmup_scheduler = None
    main_scheduler = None

    # Create warmup scheduler (linear change from `lr` -> `end_lr` over `warmup_epochs`)
    if warmup_epochs > 0:
        try:
            warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=1.0, end_factor=max(float(end_lr) / float(lr), 0.0), total_iters=warmup_epochs
            )
        except Exception:
            # Fallback for older PyTorch: use LambdaLR that linearly interpolates between 1.0 and end_lr/lr
            def _warmup_lambda(epoch):
                t = float(min(epoch + 1, warmup_epochs)) / float(max(1, warmup_epochs))
                return 1.0 + (float(end_lr) / float(lr) - 1.0) * t
            warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=_warmup_lambda)

    # If a post-warmup constant LR is specified, use it as the main scheduler (constant factor)
    if post_warmup_lr is not None:
        post_lr = float(post_warmup_lr)
        try:
            main_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: post_lr / float(lr))
        except Exception:
            main_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: post_lr / float(lr))
    else:
        # Create main scheduler after warmup according to existing policies
        if lr_schedule == 'cosine':
            main_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=max(1, epochs - warmup_epochs), eta_min=end_lr
            )
        elif lr_schedule == 'linear':
            try:
                # Use LinearLR when available (PyTorch >= 1.11)
                main_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=end_lr / lr, total_iters=max(1, epochs - warmup_epochs))
            except Exception:
                # Fallback to LambdaLR for older PyTorch versions
                main_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: max(end_lr / lr, 1.0 - (epoch + 1) / float(max(1, epochs - warmup_epochs))))
        elif lr_schedule == 'step':
            main_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[int((epochs - warmup_epochs) * 0.5) + warmup_epochs, int((epochs - warmup_epochs) * 0.75) + warmup_epochs], gamma=0.1)

    # Combine warmup and main scheduler
    if warmup_epochs > 0 and main_scheduler is not None:
        scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer,
            schedulers=[warmup_scheduler, main_scheduler],
            milestones=[warmup_epochs]
        )
    elif warmup_epochs > 0:
        scheduler = warmup_scheduler
    elif main_scheduler is not None:
        scheduler = main_scheduler

    # Use the provided data loaders directly
    data_loader = train_loader
    val_data_loader = val_loader
    
    # Default rank reduction schedule if none provided
    if rank_schedule is None:
        # Reduce rank every rank_reduction_frequency epochs, but start after warmup period
        rank_schedule = list(range(warmup_epochs + rank_reduction_frequency, 
                                 epochs, 
                                 rank_reduction_frequency))
    initial_squares = [None] * n_modalities # per modality
    initial_losses = [None] * n_modalities # per modality (for loss-based criteria)
    start_reduction = False
    current_rsquare_per_mod = [None] * n_modalities
    current_loss_per_mod = [None] * n_modalities  # for loss-based criteria
    bottom_reached = False
    space_sims = None
    break_counter = 0
    
    # Train the model
    train_losses = []
    val_losses = []
    r_squares = []
    min_ranks = [layer.active_dims for layer in model.adaptive_layers]
    best_loss = float('inf')
    
    # Initialize loss scaling factors for dynamic loss balancing
    loss_scales = torch.ones(n_modalities, device=device)
    #loss_scales[1] = 0.1
    initial_losses = torch.zeros(n_modalities, device=device)
    loss_history = {f'mod_{i}_loss': [] for i in range(n_modalities)}
    
    # Initialize loss balancer for reconstruction losses
    if recon_loss_balancing:
        modality_loss_emas = [None] * n_modalities
        ema_decay = 0.9

    patience_counter = 0
    pbar = tqdm.tqdm(range(model.epoch, epochs))
    for epoch in pbar:
        # Training phase
        model.train()
        train_loss = 0.0
        val_loss = 0.0
        total_ortho_loss = 0.0
        per_modality_losses = [0.0] * n_modalities
        
        for batch_idx, batch in enumerate(data_loader):
            optimizer.zero_grad()

            # Extract from dictionary batch
            images = batch['image'].to(device, non_blocking=True)
            input_ids = batch['input_ids'].to(device, non_blocking=True)
            attention_mask = batch['attention_mask'].to(device, non_blocking=True)
            labels = batch['label']
            
            loss = torch.tensor(0.0, device=device)
            total_loss = torch.tensor(0.0, device=device)
            
            # Forward pass
            x_hat, h_list = model(images, input_ids, attention_mask)
            
            ortho_loss = torch.tensor(0.0, device=device)
            total_ortho_loss += ortho_loss.item()

            # Calculate separate losses for each modality
            modality_losses = []
            
            # Modality 0: Image reconstruction (MSE loss)
            img_loss = F.mse_loss(x_hat[0], images, reduction='mean')
            modality_losses.append(img_loss)
            
            # Modality 1: Text reconstruction (Cross-Entropy loss)
            text_logits = x_hat[1]
            text_loss = F.cross_entropy(
                text_logits.reshape(-1, text_logits.size(-1)),
                input_ids.reshape(-1),
                ignore_index=0,
                reduction='mean'
            )
            modality_losses.append(text_loss)
            
            # Check for NaNs in individual modality losses BEFORE summing
            """
            has_nan = False
            if torch.isnan(img_loss).any():
                if verbose:
                    print(f"Warning: NaN loss detected for image modality")
                has_nan = True
            if torch.isnan(text_loss).any():
                if verbose:
                    print(f"Warning: NaN loss detected for text modality")
                has_nan = True
            
            # Skip this batch if any modality has NaN
            if has_nan:
                # stop the training loop
                print("NaN detected in losses, stopping training.")
                break
            """
            
            # Safe to accumulate per-modality losses now
            per_modality_losses[0] += img_loss.item()
            per_modality_losses[1] += text_loss.item()
            
            # Apply reconstruction loss balancing if enabled
            if recon_loss_balancing:
                # Update exponential moving averages for each modality
                for i, m_loss in enumerate(modality_losses):
                    if modality_loss_emas[i] is None:
                        modality_loss_emas[i] = m_loss.item()
                    else:
                        modality_loss_emas[i] = ema_decay * modality_loss_emas[i] + (1 - ema_decay) * m_loss.item()
                
                # Calculate balanced loss using the minimum EMA as reference
                min_ema = min(ema for ema in modality_loss_emas if ema is not None and ema > 0)
                for i, m_loss in enumerate(modality_losses):
                    if modality_loss_emas[i] > 0:
                        balance_scale = min_ema / modality_loss_emas[i]
                        loss += balance_scale * m_loss
                    else:
                        loss += m_loss
            else:
                # Standard loss computation without balancing
                for i, m_loss in enumerate(modality_losses):
                    loss += loss_scales[i] * m_loss
            
            total_loss += loss

            # Backward pass and optimize (normal training)
            total_loss.backward()
            # print the gradients for trainable parameters
            #if epoch == 0 and batch_idx == 0:
            #    for name, param in model.named_parameters():
            #        if param.requires_grad and param.grad is not None:
            #            print(f"Gradient for {name}: {param.grad.norm().item()}")
            # Gradient clipping to improve stability
            try:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            except Exception:
                if hasattr(model, 'module'):
                    torch.nn.utils.clip_grad_norm_(model.module.parameters(), max_norm=1.0)
                else:
                    raise
            optimizer.step()
            train_loss += loss.item()

        #if has_nan:
        #    break
        
        # Average losses
        train_loss /= len(data_loader)
        if start_reduction and include_ortholoss:
            total_ortho_loss /= len(data_loader)
        per_modality_losses = [loss / len(data_loader) for loss in per_modality_losses]
        train_losses.append(train_loss)
        
        # Store per-modality losses in history
        for i, loss in enumerate(per_modality_losses):
            loss_history[f'mod_{i}_loss'].append(loss)
        
        # Validation phase with similar safeguards
        with torch.no_grad():
            for batch_val in val_data_loader:
                # Extract from dictionary batch
                images_val = batch_val['image'].to(device, non_blocking=True)
                input_ids_val = batch_val['input_ids'].to(device, non_blocking=True)
                attention_mask_val = batch_val['attention_mask'].to(device, non_blocking=True)
                
                # Forward pass
                x_val_hat, _ = model(images_val, input_ids_val, attention_mask_val)
                
                # Calculate validation loss
                val_batch_loss = 0.0
                
                # Image loss
                img_loss = F.mse_loss(x_val_hat[0], images_val, reduction='mean')
                if not torch.isnan(img_loss):
                    val_batch_loss += img_loss.item()
                
                # Text loss
                text_logits = x_val_hat[1]
                text_loss = F.cross_entropy(
                    text_logits.reshape(-1, text_logits.size(-1)),
                    input_ids_val.reshape(-1),
                    ignore_index=0,
                    reduction='mean'
                )
                if not torch.isnan(text_loss):
                    val_batch_loss += text_loss.item()
                
                val_loss += val_batch_loss / 2.0  # Average over 2 modalities
                
        val_loss /= len(val_data_loader)
        val_losses.append(val_loss)

        # Step the scheduler once per epoch (if configured)
        if scheduler is not None:
            #try:
            scheduler.step()
            #except Exception:
            #    # scheduler.step() may expect different inputs for some schedulers; ignore failures to remain conservative
            #    pass

        # Unfreeze encoders/decoders after warmup when using unimodal initialization
        """
        if from_unimodal and warmup_epochs > 0 and epoch == warmup_epochs - 1:
            print("\n" + "="*60)
            print(f"Warmup phase complete at epoch {epoch+1}")
            print("Unfreezing image and text branches for full training...")
            target_model = model.module if isinstance(model, nn.DataParallel) else model
            for param in target_model.img_branch.parameters():
                param.requires_grad = True
            for param in target_model.text_branch.parameters():
                param.requires_grad = True
            print("✓ All model parameters are now trainable")
            print("="*60 + "\n")
        """

        # Get current learning rate
        current_lr = optimizer.param_groups[0]['lr']
        
        log_dict = {
            'loss': round(train_loss, 4),
            'lr': f'{current_lr:.2e}',
            'mod_losses': [round(l, 3) for l in per_modality_losses],
            'ranks': [layer.active_dims for layer in model.adaptive_layers] if hasattr(model, 'adaptive_layers') 
                    else (model.module.adaptive_layers if multi_gpu else []),
            'current_rsquare': [round(current_rsquare_per_mod[i], 3) if current_rsquare_per_mod[i] is not None else 'N/A' for i in range(n_modalities)],
            'patience': patience_counter,
        }
        if recon_loss_balancing and all(ema is not None for ema in modality_loss_emas):
            # Show the balance scales for reconstruction losses
            min_ema = min(ema for ema in modality_loss_emas if ema > 0)
            balance_scales = [round(min_ema / ema, 3) if ema > 0 else 1.0 for ema in modality_loss_emas]
            log_dict.update({'balance_scales': balance_scales})
        pbar.set_postfix(log_dict)
        
        # Update best loss
        if train_loss < best_loss:
            best_loss = train_loss
            if reduce_on_best_loss in ['true', 'stagnation']:
                patience_counter = 0  # Reset patience counter
        else:
            if reduce_on_best_loss in ['true', 'stagnation']:
                patience_counter += 1

        #if (start_reduction is False) and (epoch == model.epoch + patience): # giving the optimizer a start
        if (start_reduction is False) and (epoch == model.epoch + warmup_epochs): # giving the optimizer a start
            rank_history = {
                'total_rank':[model.get_total_rank() if hasattr(model, 'get_total_rank') else (model.module.get_total_rank() if multi_gpu else 0)],
                'ranks':[', '.join(str(layer.active_dims) for layer in model.adaptive_layers)],
                'epoch':[model.epoch],
                'loss':[train_losses[-1]],
                'val_loss':[val_losses[-1]]
            }
            
            #break
            start_reduction = True  # Start rank reduction after early stopping
            break_counter = 0 # start with no breaks (only used when increasing layers)

            #with torch.no_grad():
            #    val_data_tensors = [val_data.data[0].to(device), torch.mean(val_data.data[1], dim=1).to(device)]
            #    #encoded_per_modality = model.encode_modalities([val_data.data[i].to(device) for i in range(len(data))])
            #    encoded_per_modality = model.encode_modalities(val_data_tensors)
            #    #encoded_per_space_shared, encoded_per_space_specific = model.encode([val_data.data[i].to(device) for i in range(len(data))])
            #    encoded_per_space_shared, encoded_per_space_specific = model.encode(val_data_tensors)
            #    encoded_per_space = [encoded_per_space_shared] + list(encoded_per_space_specific)
            min_rsquares = []
            #if mask is not None:
            #    start_idx = 0
            #    for j, x_m in enumerate(x_val):
            #        end_idx = start_idx + x_m.shape[1]
            #        temp_mask = mask[:, start_idx]
            #        # expand it to match the encoded shape
            #        temp_mask = temp_mask.unsqueeze(1).expand(-1, encoded_per_modality[j].shape[1])
            #        modality_masks_latent.append(temp_mask)
            #        modality_masks_data.append(mask[:, start_idx:end_idx])
            #        modality_masks_space.append(mask[:, start_idx])
            #        start_idx = end_idx
            
            # Direct reconstruction R² approach (use validation DataLoader)
            if verbose:
                print("   Computing R² from validation data...")
            
            if decision_metric == 'ExVarScore':
                direct_r_squared_values = compute_direct_explained_variance(model, train_loader, device, multi_gpu, verbose=verbose)
            else:  # Default to R2
                direct_r_squared_values = compute_direct_r_squared(model, train_loader, device, multi_gpu, verbose=verbose)
            
            for i, r_squared_val in enumerate(direct_r_squared_values):
                initial_squares[i] = r_squared_val
                
                # Calculate threshold based on threshold_type
                if threshold_type == 'relative':
                    min_rsquares.append(r_squared_val * r_square_threshold)
                elif threshold_type == 'absolute':
                    min_rsquares.append(r_squared_val - r_square_threshold)
                else:
                    raise ValueError(f"threshold_type must be 'relative' or 'absolute', got {threshold_type}")
                    
                current_rsquare_per_mod[i] = r_squared_val
                rank_history[f'rsquare {i}'] = [r_squared_val]
            max_rsquares = initial_squares.copy()
                
            if verbose:
                #print(f"Initial R-squared values: {[rank_history[f'rsquare {i}'] for i in range(len(encoded_per_modality))]}, setting {threshold_type} thresholds to {min_rsquares}")
                print(f"Initial R-squared values: {[rank_history[f'rsquare {i}'] for i in range(len(current_rsquare_per_mod))]}, setting {threshold_type} thresholds to {min_rsquares}")
            #print(f"Initial R-squared values: {[rank_history[f'rsquare {i}'] for i in range(len(encoded_per_modality))]}, setting {threshold_type} thresholds to {min_rsquares}")
            
        # Apply rank reduction at scheduled epochs, respecting warmup period
        if (epoch in rank_schedule) & (start_reduction) & (break_counter == 0):
            # Save model checkpoint after computing R²
            if model_name:
                latest_path = f"./03_results/models/{model_name}_latest.pt"
                os.makedirs(os.path.dirname(latest_path), exist_ok=True)
                if multi_gpu:
                    torch.save(model.module.state_dict(), latest_path)
                else:
                    torch.save(model.state_dict(), latest_path)
                if verbose:
                    print(f"Saved latest checkpoint to {latest_path}")
            if (reduce_on_best_loss == 'rsquare') & (start_reduction):
                ###
                # get the r_square values per modality
                ###
                #with torch.no_grad():
                #    val_data_tensors = [val_data.data[0].to(device), torch.mean(val_data.data[1], dim=1).to(device)]
                #    #encoded_per_modality = model.encode_modalities([d[n_samples_train:].to(device) for d in data])
                #    #encoded_per_modality = model.encode_modalities([val_data.data[i].to(device) for i in range(len(data))])
                #    encoded_per_modality = model.encode_modalities(val_data_tensors)
                #    if not compute_jacobian:
                #        #encoded_per_space_shared, encoded_per_space_specific = model.encode([val_data.data[i].to(device) for i in range(len(data))])
                #        encoded_per_space_shared, encoded_per_space_specific = model.encode(val_data_tensors)
                #        encoded_per_space = [encoded_per_space_shared] + list(encoded_per_space_specific)
                #    else:
                #        encoded_per_space, contractive_losses = model.encode([val_data.data[i].to(device) for i in range(len(data))], compute_jacobian=compute_jacobian)
                current_rsquares = []
                modalities_to_reduce = []
                modalities_to_increase = []
                #if mask is not None:
                #    start_idx = 0
                #    for j, x_m in enumerate(x_val):
                #        end_idx = start_idx + x_m.shape[1]
                #        #modality_masks.append(mask[:, start_idx:end_idx])
                #        # expand it to match the encoded shape
                #        temp_mask = mask[:, start_idx]
                #        temp_mask = temp_mask.unsqueeze(1).expand(-1, encoded_per_modality[j].shape[1])
                #        modality_masks_latent.append(temp_mask)
                #        modality_masks_data.append(mask[:, start_idx:end_idx])
                #        modality_masks_space.append(mask[:, start_idx])
                #        start_idx = end_idx
                
                # Direct reconstruction R² approach (use validation DataLoader)
                if decision_metric == 'ExVarScore':
                    direct_r_squared_values = compute_direct_explained_variance(model, train_loader, device, multi_gpu)
                else:  # Default to R2
                    direct_r_squared_values = compute_direct_r_squared(model, train_loader, device, multi_gpu)
                
                for i, r_squared_val in enumerate(direct_r_squared_values):
                    current_rsquares.append(r_squared_val)
                    current_rsquare_per_mod[i] = r_squared_val
                        
                r_squares.append(current_rsquares)
                
                # Update min_rsquares based on maximum observed (allows threshold to follow best performance)
                #max_rquares = [max(r_squares, key=lambda x: x[i])[i] for i in range(len(current_rsquare_per_mod))] if len(r_squares) > 0 else initial_squares
                update_max = False
                for i, r in enumerate(current_rsquares):
                    if r > max_rsquares[i]:
                        max_rsquares[i] = r
                        update_max = True
                if update_max:
                    if threshold_type == 'relative':
                        min_rsquares = [r * r_square_threshold for r in max_rsquares]
                    elif threshold_type == 'absolute':
                        # For R²: subtract threshold; for loss: add threshold
                        if compressibility_type == 'direct' and reduction_criterion in ['train_loss', 'val_loss']:
                            min_rsquares = [r + r_square_threshold for r in max_rsquares]
                        else:
                            min_rsquares = [r - r_square_threshold for r in max_rsquares]
                    if verbose:
                        print(f"Updated threshold {min_rsquares} based on new max {max_rsquares}")

                ###
                # determine what modalities to reduce or increase
                ###
                if (len(r_squares) >= min(10, int(patience/2))) and patience_counter >= min(10, int(patience/2)):
                    for i in range(len(current_rsquare_per_mod)):
                        i_rsquares = [r[i] for r in r_squares[-min(10, int(patience/2)):]]
                        
                        # Handle different comparison logic for loss vs R²
                        if compressibility_type == 'direct' and reduction_criterion in ['train_loss', 'val_loss']:
                            # For loss: lower is better, so we reduce if loss is high (above threshold)
                            if all(r > min_rsquares[i] for r in i_rsquares) and not bottom_reached:
                                modalities_to_increase.append(i)
                            elif current_rsquare_per_mod[i] < min_rsquares[i]:  # loss below threshold
                                modalities_to_reduce.append(i)
                        else:
                            # For R²: higher is better (original logic)
                            if all(r < min_rsquares[i] for r in i_rsquares) and not bottom_reached:
                                modalities_to_increase.append(i)
                            elif current_rsquare_per_mod[i] > min_rsquares[i]:
                                modalities_to_reduce.append(i)
                elif (len(r_squares) >= 1):# and (patience_counter >= 1):
                    for i in range(len(current_rsquare_per_mod)):
                        if compressibility_type == 'direct' and reduction_criterion in ['train_loss', 'val_loss']:
                            # For loss: reduce if loss is below threshold (good performance)
                            if current_rsquare_per_mod[i] < min_rsquares[i]:
                                modalities_to_reduce.append(i)
                            elif current_rsquare_per_mod[i] > min_rsquares[i] and not bottom_reached:
                                modalities_to_increase.append(i)
                        else:
                            # For R²: reduce if R² is above threshold (original logic)
                            if current_rsquare_per_mod[i] > min_rsquares[i]:
                                modalities_to_reduce.append(i)
                            elif current_rsquare_per_mod[i] < min_rsquares[i] and not bottom_reached:
                                modalities_to_increase.append(i)

                # if all modalities can be reduced, we set min and max ranks
                if len(modalities_to_reduce) == len(current_rsquare_per_mod):
                    current_ranks = [layer.active_dims for layer in model.adaptive_layers]
                    for i, cr in enumerate(current_ranks):
                        if cr <= min_ranks[i]:
                            min_ranks[i] = cr
                            # if we are still above the thresholds, the ranks maximum should not be larger than the sum of current ranks
                        #model.adaptive_layers[i].max_rank = min(sum(current_ranks), model.adaptive_layers[i].max_rank)
                        model.adaptive_layers[i].max_rank = min(sum(current_ranks), max(int(1.5*current_ranks[i]), current_ranks[i]+1), model.adaptive_layers[i].max_rank)
                    print(f"Adjusting maximum ranks to {[layer.max_rank for layer in model.adaptive_layers]}")
                if len(modalities_to_increase) == len(current_rsquare_per_mod):
                    # set minima
                    current_ranks = [layer.active_dims for layer in model.adaptive_layers]
                    for i, cr in enumerate(current_ranks):
                        if cr <= min_ranks[i]:
                            min_ranks[i] = cr
                        # if we are increasing all ranks, we can also increase the maximum ranks
                        model.adaptive_layers[i].min_rank = min_ranks[i]
                    #bottom_reached = True
                    print(f"Adjusting minimum ranks to {[layer.min_rank for layer in model.adaptive_layers]}")
                ###
                # set the patience counters and layers to reduce or increase
                ###
                layers_to_reduce = []
                layers_to_increase = []
                if (len(modalities_to_reduce) == 0) and (len(modalities_to_increase) == 0):
                    #patience_counter += 1
                    pass
                elif (len(modalities_to_reduce) > 0) and (len(modalities_to_increase) > 0):
                    # no increasing yet, but no decreasing the shared either
                    layers_to_reduce = [i + 1 for i in modalities_to_reduce]
                    layers_to_increase = [0] + [i + 1 for i in modalities_to_increase]
                    # set the min for the modality to be increased to current rank
                    model.adaptive_layers[0].min_rank = model.adaptive_layers[0].active_dims + 1
                    for i in modalities_to_increase:
                        model.adaptive_layers[i + 1].min_rank = model.adaptive_layers[i + 1].active_dims + 1
                    print(f"Adjusting minimum ranks to {[layer.min_rank for layer in model.adaptive_layers]}")
                else:
                    if len(modalities_to_increase) > 0:
                        if len(modalities_to_increase) == len(current_rsquare_per_mod):
                            # if all modalities are below the threshold, increase ranks of all layers
                            layers_to_increase = [i for i in range(len(model.adaptive_layers))]
                        else:
                            layers_to_increase = [0] + [i + 1 for i in modalities_to_increase]
                            for i in modalities_to_increase:
                                model.adaptive_layers[i + 1].min_rank = model.adaptive_layers[i + 1].active_dims + 1
                            model.adaptive_layers[0].min_rank = model.adaptive_layers[0].active_dims + 1
                            print(f"Adjusting minimum ranks to {[layer.min_rank for layer in model.adaptive_layers]}")
                    if len(modalities_to_reduce) > 0:
                        # if all modalities are below the threshold, reduce ranks of all layers
                        if len(modalities_to_reduce) == len(initial_squares):
                            reduce_shared = True
                            if reduce_shared:
                                layers_to_reduce = [0] + [i + 1 for i in modalities_to_reduce]
                            else:
                                layers_to_reduce = [i + 1 for i in modalities_to_reduce]
                        else:
                            # roll a dice whether we also try to reduce the shared layer
                            #if sharedwhenall:
                            #    reduce_shared = False
                            #else:
                            #    reduce_shared = True
                            #if reduce_shared:
                            #    layers_to_reduce = [0] + [i + 1 for i in modalities_to_reduce]
                            #else:
                                layers_to_reduce = [i + 1 for i in modalities_to_reduce]
                if verbose:
                    if compressibility_type == 'direct' and reduction_criterion in ['train_loss', 'val_loss']:
                        print(f"{reduction_criterion} values: {current_rsquares}, reducing rank for modalities {modalities_to_reduce}, layers {layers_to_reduce}, increasing rank for modalities {modalities_to_increase}, layers {layers_to_increase}")
                    else:
                        print(f"R-squared values: {current_rsquares}, reducing rank for modalities {modalities_to_reduce}, layers {layers_to_reduce}, increasing rank for modalities {modalities_to_increase}, layers {layers_to_increase}")
                if compute_jacobian:
                    valid_contractive_losses = [contractive_losses[i] for i in layers_to_reduce]
                    max_contractive_loss = max(valid_contractive_losses) if len(valid_contractive_losses) > 0 else None
                    #max_contractive_loss = min(valid_contractive_losses) if len(valid_contractive_losses) > 0 else None
                    if max_contractive_loss is not None:
                        layers_to_reduce = [i for i in layers_to_reduce if contractive_losses[i] == max_contractive_loss]
                    else:
                        layers_to_reduce = []
                    if verbose:
                        print(f"Contractive losses: {contractive_losses}, reducing rank for layers {layers_to_reduce} with max loss {max_contractive_loss}")
                
            
            #if should_reduce:
            any_changes_made = False
            #changes_made = False
            if len(layers_to_reduce) > 0:
                # Apply rank reduction
                if multi_gpu:
                    changes_made = model.module.reduce_rank(reduction_ratio=0.9, threshold=rank_reduction_threshold, layer_ids=layers_to_reduce)
                else:
                    changes_made = model.reduce_rank(reduction_ratio=0.9, threshold=rank_reduction_threshold, layer_ids=layers_to_reduce)
                if changes_made:
                    any_changes_made = True
            if len(layers_to_increase) > 0:
                # Apply rank increase
                #print(f"Increasing rank for layer {layers_to_increase}")
                if multi_gpu:
                    changes_made = model.module.increase_rank(increase_ratio=1.1, layer_ids=layers_to_increase)
                else:
                    changes_made = model.increase_rank(increase_ratio=1.1, layer_ids=layers_to_increase)
                if changes_made:
                    any_changes_made = True
                    break_counter = patience # give model more time to re-learn the added dimensions
            #else:
            #    changes_made = False
            
            if any_changes_made:
                patience_counter = 0  # Reset patience counter if rank was changed
            else:
                patience_counter += 1

            # Get new rank but don't print separate message
            total_rank_after = model.module.get_total_rank() if multi_gpu else model.get_total_rank()
                
            # Store current rank in history
            rank_history['total_rank'].append(total_rank_after)
            rank_history['ranks'].append(', '.join(str(layer.active_dims) for layer in model.adaptive_layers))
            rank_history['epoch'].append(epoch)
            #for i in range(len(encoded_per_modality)):
            for i in range(len(current_rsquare_per_mod)):
                if reduce_on_best_loss == 'rsquare':
                    rank_history[f'rsquare {i}'].append(current_rsquares[i])
            rank_history['loss'].append(train_loss)
            rank_history['val_loss'].append(val_loss)

            # also get mutual information between all spaces
            #valid_spaces = []
            #for encoded in encoded_per_space:
            #    valid_spaces_temp = []
            #    for i in range(len(encoded_per_modality)):
            #        #normalized_encoded = (encoded - encoded.min() + 1e-9) / (encoded.max() - encoded.min() + 1e-9)
            #        if mask is not None:
            #            temp_mask = modality_masks_space[i]
            #            #valid_spaces_temp.append(normalized_encoded[mask])
            #            valid_spaces_temp.append(encoded[temp_mask])
            #        else:
            #            #valid_spaces_temp.append(normalized_encoded)
            #            valid_spaces_temp.append(encoded)
            #    # stack the valid spaces
            #    valid_spaces_temp = torch.vstack(valid_spaces_temp)
            #    valid_spaces.append(valid_spaces_temp)
        else:
            if (epoch in rank_schedule) & (start_reduction) & (break_counter > 0):
                break_counter -= 1
        
        # early stopping but conditioned on rank reduction
        #if (epoch > early_stopping) & (min(val_losses[-early_stopping:]) > min(val_losses)) & (start_reduction is True) & (patience_counter >= patience):
        if (epoch > early_stopping) & (start_reduction is True) & (patience_counter >= patience):
            if verbose:
                print(f"Early stopping at epoch {epoch} with best loss {best_loss} and ranks {rank_history['ranks'][-1]}")
            break
        
        # Periodic checkpoint saving - save to the main model path so it gets updated
        effective_save_freq = save_frequency if save_frequency is not None else 10
        if model_name is not None and (epoch + 1) % effective_save_freq == 0:
            checkpoint_path = f"./03_results/models/{model_name}.pt"
            os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
            if multi_gpu:
                torch.save(model.module.state_dict(), checkpoint_path)
            else:
                torch.save(model.state_dict(), checkpoint_path)
            print(f"Saved checkpoint at epoch {epoch+1} to {checkpoint_path}")

            # Also save intermediate reconstruction plots (train & val) every save_frequency epochs
            try:
                n_plot = 6
                # Get a small batch from train and val loaders
                train_batch = next(iter(train_loader))
                val_batch = next(iter(val_loader))

                # Support dict-style batches (image-text) or tuple/list batches
                if isinstance(train_batch, dict):
                    train_images = train_batch['image'][:n_plot].to(device)
                    train_input_ids = train_batch['input_ids'][:n_plot].to(device)
                    train_attention_mask = train_batch['attention_mask'][:n_plot].to(device) if 'attention_mask' in train_batch else (train_input_ids != 0).to(device)
                else:
                    train_images = train_batch[0][:n_plot].to(device)
                    train_input_ids = train_batch[1][:n_plot].to(device)
                    train_attention_mask = (train_input_ids != 0).to(device)

                if isinstance(val_batch, dict):
                    val_images = val_batch['image'][:n_plot].to(device)
                    val_input_ids = val_batch['input_ids'][:n_plot].to(device)
                    val_attention_mask = val_batch['attention_mask'][:n_plot].to(device) if 'attention_mask' in val_batch else (val_input_ids != 0).to(device)
                else:
                    val_images = val_batch[0][:n_plot].to(device)
                    val_input_ids = val_batch[1][:n_plot].to(device)
                    val_attention_mask = (val_input_ids != 0).to(device)

                # Run model to get reconstructions
                if multi_gpu and hasattr(model, 'module'):
                    train_recon, _ = model.module(train_images, train_input_ids, train_attention_mask)
                    val_recon, _ = model.module(val_images, val_input_ids, val_attention_mask)
                else:
                    train_recon, _ = model(train_images, train_input_ids, train_attention_mask)
                    val_recon, _ = model(val_images, val_input_ids, val_attention_mask)

                train_img_recon, train_text_logits = train_recon
                val_img_recon, val_text_logits = val_recon

                # Save image reconstructions to ./03_results/train_plots/
                img_path = checkpoint_path.replace('./03_results/models/', './03_results/train_plots/').replace('.pt', f'_image_recon_epoch{epoch+1}.png')
                os.makedirs(os.path.dirname(img_path), exist_ok=True)
                fig, axes = plt.subplots(4, n_plot, figsize=(n_plot * 1.5, 6))
                for i in range(n_plot):
                    # Train original
                    img = train_images[i].permute(1, 2, 0).detach().cpu().numpy()
                    img = (img - img.min()) / (img.max() - img.min() + 1e-8)
                    axes[0, i].imshow(img)
                    axes[0, i].axis('off')
                    if i == 0:
                        axes[0, i].set_ylabel('Train', fontsize=10)
                    # Train reconstruction
                    img = train_img_recon[i].permute(1, 2, 0).detach().cpu().numpy()
                    img = (img - img.min()) / (img.max() - img.min() + 1e-8)
                    axes[1, i].imshow(img)
                    axes[1, i].axis('off')
                    if i == 0:
                        axes[1, i].set_ylabel('Train Recon', fontsize=10)
                    # Val original
                    img = val_images[i].permute(1, 2, 0).detach().cpu().numpy()
                    img = (img - img.min()) / (img.max() - img.min() + 1e-8)
                    axes[2, i].imshow(img)
                    axes[2, i].axis('off')
                    if i == 0:
                        axes[2, i].set_ylabel('Val', fontsize=10)
                    # Val reconstruction
                    img = val_img_recon[i].permute(1, 2, 0).detach().cpu().numpy()
                    img = (img - img.min()) / (img.max() - img.min() + 1e-8)
                    axes[3, i].imshow(img)
                    axes[3, i].axis('off')
                    if i == 0:
                        axes[3, i].set_ylabel('Val Recon', fontsize=10)
                plt.tight_layout()
                plt.savefig(img_path, dpi=150, bbox_inches='tight')
                plt.close()
                print(f"Saved intermediate image reconstructions to {img_path}")

                # Save text reconstructions
                try:
                    # Use the tokenizer passed to the function (defaults to BERT if not provided)
                    if tokenizer is None:
                        from transformers import BertTokenizer
                        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
                    train_pred_ids = torch.argmax(train_text_logits, dim=-1)
                    val_pred_ids = torch.argmax(val_text_logits, dim=-1)
                    txt_path = checkpoint_path.replace('./03_results/models/', './03_results/train_plots/').replace('.pt', f'_text_recon_epoch{epoch+1}.txt')
                    os.makedirs(os.path.dirname(txt_path), exist_ok=True)
                    with open(txt_path, 'w', encoding='utf-8') as f:
                        f.write('=== TRAIN SAMPLES ===\n\n')
                        for i in range(n_plot):
                            orig = tokenizer.decode(train_input_ids[i], skip_special_tokens=True)
                            recon = tokenizer.decode(train_pred_ids[i], skip_special_tokens=True)
                            f.write(f"Sample {i+1}:\n  Original: {orig}\n  Reconstructed: {recon}\n\n")
                        f.write('\n=== VAL SAMPLES ===\n\n')
                        for i in range(n_plot):
                            orig = tokenizer.decode(val_input_ids[i], skip_special_tokens=True)
                            recon = tokenizer.decode(val_pred_ids[i], skip_special_tokens=True)
                            f.write(f"Sample {i+1}:\n  Original: {orig}\n  Reconstructed: {recon}\n\n")
                    print(f"Saved intermediate text reconstructions to {txt_path}")
                except Exception as e:
                    print(f"Warning: Could not save intermediate text reconstructions: {e}")
            except Exception as e:
                print(f"Warning: Could not save intermediate reconstruction plots at epoch {epoch+1}: {e}")
    
    # Calculate latent representations in batches (only for training data)
    # Build representations with proper index mapping to handle both full datasets and subsets
    n_samples = len(train_loader.dataset)
    final_ranks = [layer.active_dims for layer in model.adaptive_layers]
    
    # Collect all representations and indices first
    all_reps = [[] for _ in range(len(final_ranks))]
    all_indices = []
    
    model.eval()
    with torch.no_grad():
        for batch in train_loader:
            # Extract data from batch dictionary
            images = batch['image'].to(device, non_blocking=True)
            input_ids = batch['input_ids'].to(device, non_blocking=True)
            attention_mask = batch['attention_mask'].to(device, non_blocking=True)
            indices = batch['idx']  # Original indices from dataset
            
            # If using DataParallel, need to access module directly
            if multi_gpu:
                batch_reps = model.module.encode(images, input_ids, attention_mask)
            else:
                batch_reps = model.encode(images, input_ids, attention_mask)
            batch_rep_list = [batch_reps[0].detach().cpu()] + [batch_reps[1][j].detach().cpu() for j in range(len(batch_reps[1]))]
            
            # Store representations and indices
            for j in range(len(final_ranks)):
                all_reps[j].append(batch_rep_list[j][:, :final_ranks[j]])
            all_indices.append(indices)
            
            # Free memory
            del images, input_ids, attention_mask, batch_reps
            torch.cuda.empty_cache() if torch.cuda.is_available() else None
    
    # Concatenate all batches
    all_indices = torch.cat(all_indices)
    for j in range(len(final_ranks)):
        all_reps[j] = torch.cat(all_reps[j], dim=0)
    
    # Create index mapping: original_idx -> position in subset
    # This handles both full datasets (identity mapping) and random_split subsets
    unique_indices = torch.unique(all_indices, sorted=True)
    idx_to_pos = {int(idx): pos for pos, idx in enumerate(unique_indices)}
    
    # Create final reps array and fill it in the correct order
    reps = [torch.empty((len(unique_indices), final_ranks[i])) for i in range(len(final_ranks))]
    for i, orig_idx in enumerate(all_indices):
        pos = idx_to_pos[int(orig_idx)]
        for j in range(len(reps)):
            reps[j][pos, :] = all_reps[j][i, :]
    reps = [rep for rep in reps]  # Move to CPU after full collection
    
    # Convert indices to numpy for easier saving/loading
    sorted_indices = unique_indices.cpu().numpy()
    print(f"Computed final latent representations for training data (n={len(sorted_indices)}).")
    
    # save the model
    if model_name:
        model_path = f"./03_results/models/{model_name}.pt"
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        if multi_gpu:
            torch.save(model.module.state_dict(), model_path)
        else:
            torch.save(model.state_dict(), model_path)
        if verbose:
            print(f"Saved trained model to {model_path}")
        
        # Save final reconstruction plots for both modalities
        try:
            model.eval()
            with torch.no_grad():
                n_plot = 8
                
                # Get train samples
                train_batch = next(iter(train_loader))
                train_images = train_batch['image'][:n_plot].to(device)
                train_input_ids = train_batch['input_ids'][:n_plot].to(device)
                train_attention_mask = train_batch['attention_mask'][:n_plot].to(device)
                
                # Get val samples
                val_batch = next(iter(val_loader))
                val_images = val_batch['image'][:n_plot].to(device)
                val_input_ids = val_batch['input_ids'][:n_plot].to(device)
                val_attention_mask = val_batch['attention_mask'][:n_plot].to(device)
                
                # Get reconstructions
                if multi_gpu and hasattr(model, 'module'):
                    train_recon, _ = model.module(train_images, train_input_ids, train_attention_mask)
                    val_recon, _ = model.module(val_images, val_input_ids, val_attention_mask)
                else:
                    train_recon, _ = model(train_images, train_input_ids, train_attention_mask)
                    val_recon, _ = model(val_images, val_input_ids, val_attention_mask)
                
                train_img_recon, train_text_logits = train_recon
                val_img_recon, val_text_logits = val_recon
                
                # Save image reconstructions
                img_path = model_path.replace('./03_results/models/', './03_results/train_plots/').replace('.pt', '_image_recon_final.png')
                os.makedirs(os.path.dirname(img_path), exist_ok=True)
                fig, axes = plt.subplots(4, n_plot, figsize=(n_plot * 1.5, 6))
                for i in range(n_plot):
                    # Train original
                    img = train_images[i].permute(1, 2, 0).cpu().numpy()
                    img = (img - img.min()) / (img.max() - img.min() + 1e-8)
                    axes[0, i].imshow(img)
                    axes[0, i].axis('off')
                    if i == 0:
                        axes[0, i].set_ylabel('Train', fontsize=10)
                    # Train reconstruction
                    img = train_img_recon[i].permute(1, 2, 0).cpu().numpy()
                    img = (img - img.min()) / (img.max() - img.min() + 1e-8)
                    axes[1, i].imshow(img)
                    axes[1, i].axis('off')
                    if i == 0:
                        axes[1, i].set_ylabel('Train Recon', fontsize=10)
                    # Val original
                    img = val_images[i].permute(1, 2, 0).cpu().numpy()
                    img = (img - img.min()) / (img.max() - img.min() + 1e-8)
                    axes[2, i].imshow(img)
                    axes[2, i].axis('off')
                    if i == 0:
                        axes[2, i].set_ylabel('Val', fontsize=10)
                    # Val reconstruction
                    img = val_img_recon[i].permute(1, 2, 0).cpu().numpy()
                    img = (img - img.min()) / (img.max() - img.min() + 1e-8)
                    axes[3, i].imshow(img)
                    axes[3, i].axis('off')
                    if i == 0:
                        axes[3, i].set_ylabel('Val Recon', fontsize=10)
                plt.tight_layout()
                plt.savefig(img_path, dpi=150, bbox_inches='tight')
                plt.close()
                print(f"Saved final image reconstructions to {img_path}")
                
                # Save text reconstructions
                #from transformers import BertTokenizer
                tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
                #tokenizer = T5Tokenizer.from_pretrained('google/t5-efficient-tiny')
                
                train_pred_ids = torch.argmax(train_text_logits, dim=-1)
                val_pred_ids = torch.argmax(val_text_logits, dim=-1)
                
                txt_path = model_path.replace('./03_results/models/', './03_results/train_plots/').replace('.pt', '_text_recon_final.txt')
                os.makedirs(os.path.dirname(txt_path), exist_ok=True)
                with open(txt_path, 'w', encoding='utf-8') as f:
                    f.write("=== TRAIN SAMPLES ===\n\n")
                    for i in range(n_plot):
                        orig = tokenizer.decode(train_input_ids[i], skip_special_tokens=True)
                        recon = tokenizer.decode(train_pred_ids[i], skip_special_tokens=True)
                        f.write(f"Sample {i+1}:\n")
                        f.write(f"  Original: {orig}\n")
                        f.write(f"  Reconstructed: {recon}\n\n")
                    
                    f.write("\n=== VAL SAMPLES ===\n\n")
                    for i in range(n_plot):
                        orig = tokenizer.decode(val_input_ids[i], skip_special_tokens=True)
                        recon = tokenizer.decode(val_pred_ids[i], skip_special_tokens=True)
                        f.write(f"Sample {i+1}:\n")
                        f.write(f"  Original: {orig}\n")
                        f.write(f"  Reconstructed: {recon}\n\n")
                print(f"Saved final text reconstructions to {txt_path}")
                
        except Exception as e:
            print(f"Warning: Could not save final reconstruction plots: {e}")
            import traceback
            traceback.print_exc()
    
    try:
        avg_train_loss = np.mean(train_losses[-5:])
    except:
        avg_train_loss = np.mean(train_losses)
    try:
        last_rsquare = r_squares[-1]
    except:
        last_rsquare = [None]

    return model, reps, avg_train_loss, last_rsquare, rank_history, [train_losses, val_losses], sorted_indices