import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import tqdm
import gc
from src.models.larrp_multimodal import AdaptiveRankReducedAE
from src.models.larrp_multimodal_vae import AdaptiveRankReducedVAE
from src.functions.linear_probing import parallel_linear_regression
from src.data.loading import MMSimData
from src.functions.loss import frobenius_loss_shared, frobenius_loss_crossmodal
from src.visualization.logging import plot_training_state, create_training_movie
from src.functions.pretrain_mm_sim import pretrain_overcomplete_ae

def _modality_r_squared(original, reconstruction, modality_idx, verbose=False):
    """
    Return the feature-averaged R² for a single modality as a Python float.

    Parameters:
    - original: ground-truth tensor; features are along dim 1 and samples
      along dim 0 (statistics are reduced over dim 0)
    - reconstruction: model output with the same shape as ``original``
    - modality_idx: index of the modality, used only in warning messages
    - verbose: print a warning when features are excluded

    Features whose mean is exactly zero or non-finite (NaN/Inf) are excluded
    from the average. If no feature survives, the Pearson correlation of the
    flattened tensors is returned instead (note: r, not r²); 0.0 is returned
    when the data contains NaN/Inf or the correlation is undefined.
    """
    # Mean is taken on the original device before moving to CPU.
    original_mean = original.mean(dim=0).cpu()
    original = original.cpu()
    reconstruction = reconstruction.cpu()

    has_zero_mean = bool(torch.any(original_mean == 0))
    has_non_finite = bool(torch.any(~torch.isfinite(original)))
    if verbose and has_zero_mean:
        print(f"   Warning: zeros found in original_mean for modality {modality_idx}. Excluding those dimensions.")
    if verbose and has_non_finite:
        print(f"   Warning: NaN or Inf values found in original data for modality {modality_idx}. Handling them.")

    # Apply both filters together: previously a zero-mean column short-circuited
    # the NaN/Inf handling, so a NaN column elsewhere made the average NaN.
    keep = (original_mean != 0) & torch.isfinite(original_mean)

    if not bool(torch.any(keep)):
        # No usable feature. With non-finite data (or too few values) the
        # correlation is meaningless, so report 0.0.
        if has_non_finite or original.numel() < 2:
            return 0.0
        corr = torch.corrcoef(torch.stack((original.flatten(), reconstruction.flatten())))[0, 1]
        return 0.0 if bool(torch.isnan(corr)) else corr.item()

    ssr = ((original - reconstruction) ** 2).sum(0)[keep]
    ss_tot = ((original - original_mean) ** 2).sum(0)[keep]
    # The epsilon avoids 0/0. NOTE(review): a constant non-zero feature has
    # ss_tot == 0, so its R² becomes roughly -ssr/1e-9 and can dominate the mean.
    r_squared = 1 - ((ssr + 1e-9) / (ss_tot + 1e-9))
    return r_squared.mean().item()


def compute_direct_r_squared(model, data, device, multi_gpu=False, verbose=False):
    """
    Compute R² based on direct model reconstruction performance.

    The model is run once (without gradients) on all modalities. For each
    modality, a per-feature R² is computed and averaged across features.

    Parameters:
    - model: The trained model. Called as ``model(list_of_tensors)`` and
      expected to return ``(reconstructions, something)``, where
      ``reconstructions`` has one tensor per modality.
    - data: Input data list [modality1, modality2, ...] of tensors
    - device: Device to run computation on
    - multi_gpu: Whether model is wrapped with DataParallel. Currently unused;
      kept for API compatibility (the wrapper's forward is called the same way).
    - verbose: Print warnings when features are excluded from the average

    Returns:
    - List of float R² values, one per modality. When every feature of a
      modality is excluded, that entry is a correlation-based fallback
      (or 0.0); see ``_modality_r_squared``.

    The model's train/eval mode is restored before returning.
    """
    was_training = model.training
    model.eval()
    r_squared_values = []
    try:
        with torch.no_grad():
            data_tensors = [d.to(device) for d in data]
            reconstructions, _ = model(data_tensors)
            for i, (original, reconstruction) in enumerate(zip(data_tensors, reconstructions)):
                r_squared_values.append(
                    _modality_r_squared(original, reconstruction, i, verbose=verbose)
                )
    finally:
        # Do not leak eval mode (dropout/batch-norm behaviour) into the caller.
        model.train(was_training)
    return r_squared_values

def train_overcomplete_ae(data, n_samples_train, latent_dim, device, args, epochs=100, early_stopping=50, 
                         lr=0.001, batch_size=128, ae_depth=2, ae_width=0.5, dropout=0.0, wd=1e-5, 
                         initial_rank_ratio=1.0, min_rank=10, 
                         rank_schedule=None, rank_reduction_frequency=10, 
                         rank_reduction_threshold=0.01, warmup_epochs=0,
                         patience=10, reduce_on_best_loss='rsquare', r_square_threshold=0.9,
                         threshold_type='relative', compressibility_type='linear', reduction_criterion='r_squared',
                         include_l1=False, l1_weight=0.0, include_ortholoss=False,
                         l1_start_weight=0.0, l1_step_size=1.0, rank_or_sparse='rank',
                         verbose=True, compute_jacobian=False, model_name=None, 
                         recon_loss_balancing=False, ortho_loss_balancing=False,
                         ortho_loss_start_weight=0.0, ortho_loss_end_weight=1.0, ortho_loss_anneal_epochs=None, ortho_loss_warmup=None,
                         l2_norm_adaptivelayers=None, sharedwhenall=True
                         ):
    """
    Train an autoencoder with adaptive rank reduction
    
    Parameters:
    - data: Input data tensor
    - n_samples_train: Number of samples to use for training
    - latent_dim: Dimension of the latent space
    - epochs: Maximum number of training epochs
    - early_stopping: Number of epochs for early stopping patience
    - lr: Learning rate
    - batch_size: Batch size for training
    - ae_depth: Depth of the autoencoder
    - ae_width: Width multiplier for hidden layers
    - dropout: Dropout rate
    - wd: Weight decay
    - initial_rank_ratio: Initial rank ratio (1.0 = full rank)
    - min_rank_ratio: Minimum rank ratio (lower bound)
    - rank_schedule: Custom schedule for rank reduction (epochs at which to reduce)
    - rank_reduction_frequency: How often to try reducing rank (in epochs)
    - rank_reduction_threshold: Energy threshold for rank reduction
    - warmup_epochs: Number of epochs to train before starting rank reduction
    - reduce_on_best_loss: Only reduce rank when loss is at or better than best loss
    - r_square_threshold: R² threshold for rank reduction decisions
    - threshold_type: 'relative' (multiply by initial R²) or 'absolute' (use threshold directly)
    - compressibility_type: 'linear' (linear probing R²) or 'direct' (reconstruction R²)
    - reduction_criterion: 'r_squared' (use R²), 'train_loss' (use training loss), 'val_loss' (use validation loss)
                          Only relevant when compressibility_type='direct'. Default: 'r_squared'
    - recon_loss_balancing: Whether to apply adaptive loss balancing across modalities (default: False)
    """
    # Declare multi_gpu as global so it can be accessed
    multi_gpu = args.multi_gpu if hasattr(args, 'multi_gpu') else False
    
    # Define max gradient norm for clipping
    max_grad_norm = 1.0
    
    # Create model with adaptive rank reduction
    if rank_or_sparse == 'sparse':
        raise NotImplementedError("Sparse autoencoder training not implemented in this function.")
    else:
        input_dims = [d.shape[1] for d in data]
        if isinstance(latent_dim, int):
            latent_dims = [latent_dim] * (len(input_dims) + 1) # adding one for the shared space
        elif isinstance(latent_dim, list):
            if (len(latent_dim) == 1) & (len(input_dims) > 1):
                latent_dims = [latent_dim[0]] * (len(input_dims) + 1)
            else:
                latent_dims = latent_dim
        model = AdaptiveRankReducedAE(
            input_dims, latent_dims, depth=ae_depth, width=ae_width, 
            dropout=dropout, initial_rank_ratio=initial_rank_ratio, 
            min_rank=min_rank
        ).to(device)
        # print the device the model is on
        print(f"Model is on device: {next(model.parameters()).device}")
    
    # Handle multi-GPU setup
    if multi_gpu:
        # Adjust batch size to be divisible by number of GPUs
        if args.gpu_ids:
            num_gpus = len(args.gpu_ids.split(','))
        else:
            num_gpus = torch.cuda.device_count()
            
        # Ensure batch size is divisible by number of GPUs
        if batch_size % num_gpus != 0:
            original_batch_size = batch_size
            batch_size = (batch_size // num_gpus) * num_gpus
            if verbose:
                print(f"Adjusted batch size from {original_batch_size} to {batch_size} to be divisible by {num_gpus} GPUs")
            
        try:
            # If we need cuda:0 but it's not available, disable multi_gpu
            if 0 not in [int(id) for id in args.gpu_ids.split(',')]:
                raise RuntimeError("DataParallel requires cuda:0 which is not available.")
                
            # Ensure model is on cuda:0 for DataParallel
            cuda0_device = torch.device('cuda:0')
            model = model.to(cuda0_device)
            
            # Double-check all parameters are on cuda:0
            for param in model.parameters():
                if param.device != cuda0_device:
                    param.data = param.data.to(cuda0_device)
                    
            # Wrap model with DataParallel - explicitly specify device_ids
            model = nn.DataParallel(model, device_ids=[int(id) for id in args.gpu_ids.split(',')])
            if verbose:
                print(f"Using DataParallel across GPUs: {args.gpu_ids}")
        except Exception as e:
            print(f"Failed to use DataParallel: {e}")
            print(f"Falling back to single GPU mode on {device}")
            multi_gpu = False
            model = model.to(device)
    
    # Create optimizer and loss function
    if l2_norm_adaptivelayers is not None:
        # Use AdamW with separate weight decay for adaptive layers
        adaptive_params = []
        #for layer in model.adaptive_layers if not multi_gpu else model.module.adaptive_layers:
        #    adaptive_params.extend(list(layer.parameters()))
        # also add the first decoder layers
        for i in range(len(model.decoders)):
            adaptive_params.extend(list(model.decoders[i][0].parameters()))
        
        # Get all other parameters (excluding adaptive layers)
        all_params = set(model.parameters())
        adaptive_params_set = set(adaptive_params)
        other_params = list(all_params - adaptive_params_set)
        
        # Create parameter groups with different weight decay
        param_groups = [
            {'params': other_params, 'weight_decay': wd},
            {'params': adaptive_params, 'weight_decay': l2_norm_adaptivelayers}
        ]
        optimizer = torch.optim.AdamW(param_groups, lr=lr)
    else:
        # Use standard Adam optimizer
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    
    loss_fn = torch.nn.MSELoss()
    
    # Create data loader
    # careful with the non-paired data because of how it is concatenated
    # first randomize the rows
    data_indices = torch.randperm(data[0].shape[0])
    train_indices = data_indices[:n_samples_train]
    val_indices = data_indices[n_samples_train:]
    train_data = [d[train_indices] for d in data]  # Randomize rows
    train_data = MMSimData(train_data)
    data_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_data = [data[i][val_indices] for i in range(len(data))]  # Split data into validation set
    val_data = MMSimData(val_data)
    val_data_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False)
    n_samples = data[0].shape[0]
    n_samples_val = n_samples - n_samples_train
    
    # Default rank reduction schedule if none provided
    if rank_schedule is None:
        # Reduce rank every rank_reduction_frequency epochs, but start after warmup period
        rank_schedule = list(range(warmup_epochs + rank_reduction_frequency, 
                                 epochs, 
                                 rank_reduction_frequency))
    initial_squares = [None] * len(data) # per modality
    initial_losses = [None] * len(data) # per modality (for loss-based criteria)
    start_reduction = False
    current_rsquare_per_mod = [None] * len(data)
    current_loss_per_mod = [None] * len(data)  # for loss-based criteria
    bottom_reached = False
    space_sims = None
    break_counter = 0
    
    # Train the model
    train_losses = []
    val_losses = []
    r_squares = []
    min_ranks = [layer.active_dims for layer in model.adaptive_layers]
    best_loss = float('inf')
    patience_counter = 0
    #if rank_or_sparse == 'rank':
        #rank_history = {'total_rank':[model.get_total_rank() if hasattr(model, 'get_total_rank') else 
        #            (model.module.get_total_rank() if multi_gpu else 0)],
        #            'ranks':[', '.join(str(layer.active_dims) for layer in model.adaptive_layers)],
        #            'epoch':[0],
        #            'loss':[float('inf')],
        #            'val_loss':[float('inf')],
        #            'rsquare':[]}
        #rank_history = {'total_rank':[],
        #            'ranks':[],
        #            'epoch':[],
        #            'loss':[],
        #            'val_loss':[]}
    #else:
    #    rank_history = {'rank':model.latent_dim, 'epoch':[0], 'loss':[float('inf')], 'rsquare':[]}
        #current_rank = model.latent_dim
    
    # Initialize loss scaling factors for dynamic loss balancing
    loss_scales = torch.ones(len(data), device=device)
    #loss_scales[1] = 0.1
    initial_losses = torch.zeros(len(data), device=device)
    loss_history = {f'mod_{i}_loss': [] for i in range(len(data))}
    
    # Initialize loss balancer for reconstruction losses
    if recon_loss_balancing:
        modality_loss_emas = [None] * len(data)
        ema_decay = 0.9
    
    # Initialize loss balancer if needed
    if include_ortholoss:
        if ortho_loss_balancing:
            recon_loss_ema = None
            ortho_loss_ema = None
            ema_decay = 0.9
        elif ortho_loss_anneal_epochs is not None:
            if ortho_loss_warmup is not None and ortho_loss_warmup > 0:
                ortho_loss_weights = np.zeros(ortho_loss_warmup)
                # create ortho loss weight vector
                ortho_loss_weights = np.concatenate([ortho_loss_weights, np.linspace(ortho_loss_start_weight, ortho_loss_end_weight, ortho_loss_anneal_epochs)])
            else:
                ortho_loss_weights = np.linspace(ortho_loss_start_weight, ortho_loss_end_weight, ortho_loss_anneal_epochs)
            # pad the rest with the end weight for the max epochs
            ortho_loss_weights = np.concatenate([ortho_loss_weights, np.ones(epochs - ortho_loss_anneal_epochs) * ortho_loss_end_weight])

    ### plotting test
    # Initialize variables for plotting
    last_batch_data = None
    last_batch_labels = None
    plot_save_dir = "./03_results/plots/temp_latent_plots/" + model_name if model_name else "./03_results/plots/temp_latent_plots/"
    ###

    patience_counter = 0
    pbar = tqdm.tqdm(range(epochs))
    for epoch in pbar:
        # Training phase
        model.train()
        train_loss = 0.0
        val_loss = 0.0
        total_ortho_loss = 0.0
        per_modality_losses = [0.0] * len(data)
        
        for batch_idx, (x, mask) in enumerate(data_loader):
            # if mask is nan, set to None
            if isinstance(mask, torch.Tensor) and torch.isnan(mask).all():
                mask = None
            ### plotting test
            # Store last batch for plotting
            last_batch_data = [x_m.clone() for x_m in x]
            # Get labels if they exist in the dataset
            if hasattr(train_data, 'labels') and train_data.labels is not None:
                start_idx = batch_idx * batch_size
                end_idx = min(start_idx + batch_size, len(train_data.labels))
                last_batch_labels = train_data.labels[start_idx:end_idx].clone()
            else:
                last_batch_labels = None
            ###
            
            loss = torch.tensor(0.0, device=device)
            total_loss = torch.tensor(0.0, device=device)
            x = [x_m.to(device) for x_m in x]
            
            # Forward pass
            x_hat, h_list = model(x)
            
            #if start_reduction:
            if include_l1:
                ortho_loss = torch.mean(torch.abs(h_list[0])) + sum(torch.mean(torch.abs(h_m)) for h_m in h_list[1])
            elif include_ortholoss:
                #ortho_loss = frobenius_loss_shared(h_list[0], orthogonal_weight=l1_weight)
                ortho_loss = torch.tensor(0.0, device=device)
                temp_ranks = [layer.active_dims for layer in model.adaptive_layers]
                for j in range(len(h_list[1])): # in case there are more than two modalities
                    ortho_loss += frobenius_loss_crossmodal(h_list[0], h_list[1][j], temp_ranks[0], temp_ranks[j+1])
                    for k in range(j+1, len(h_list[1])):
                        ortho_loss += frobenius_loss_crossmodal(h_list[1][j], h_list[1][k], temp_ranks[j+1], temp_ranks[k+1])
            else:
                ortho_loss = torch.tensor(0.0, device=device)
            total_ortho_loss += ortho_loss.item()

            # Calculate separate losses for each modality
            modality_losses = []
            
            # Extract masks for each modality
            modality_masks = []
            if mask is not None:
                start_idx = 0
                for i, x_m in enumerate(x):
                    end_idx = start_idx + x_m.shape[1]
                    modality_masks.append(mask[:, start_idx:end_idx])
                    start_idx = end_idx
                # sanity check: see how many samples are masked
                #print([f"Modality {i} data shape: {x_m.shape}, mask shape: {mask_i.shape}, fraction unmasked: {mask_i.sum() / mask_i.numel()}" for i, mask_i in enumerate(modality_masks)])
            else:
                modality_masks = [None] * len(x)
            
            # Calculate per-modality MSE losses
            for i, (x_m, x_hat_m) in enumerate(zip(x, x_hat)):
                # Skip if the modality is not present in this batch
                #if x_m.shape[0] == 0:
                #    modality_losses.append(torch.tensor(0.0, device=device))
                #    continue
                
                # Compute MSE loss for this modality with mask if provided
                if modality_masks[i] is not None:
                    m_loss = F.mse_loss(x_hat_m[modality_masks[i]], x_m[modality_masks[i]])
                else:
                    m_loss = F.mse_loss(x_hat_m, x_m)
                
                # Check for NaN 
                if torch.isnan(m_loss):
                    if verbose:
                        print(f"Warning: NaN loss detected for modality {i}")
                    m_loss = torch.tensor(0.0, device=device)
                
                modality_losses.append(m_loss)
                per_modality_losses[i] += m_loss.item()
            
            # Apply reconstruction loss balancing if enabled
            if recon_loss_balancing:
                # Update exponential moving averages for each modality
                for i, m_loss in enumerate(modality_losses):
                    if modality_loss_emas[i] is None:
                        modality_loss_emas[i] = m_loss.item()
                    else:
                        modality_loss_emas[i] = ema_decay * modality_loss_emas[i] + (1 - ema_decay) * m_loss.item()
                
                # Calculate balanced loss using the minimum EMA as reference
                min_ema = min(ema for ema in modality_loss_emas if ema is not None and ema > 0)
                for i, m_loss in enumerate(modality_losses):
                    if modality_loss_emas[i] > 0:
                        balance_scale = min_ema / modality_loss_emas[i]
                        loss += balance_scale * m_loss
                    else:
                        loss += m_loss
            else:
                # Standard loss computation without balancing
                for i, m_loss in enumerate(modality_losses):
                    loss += loss_scales[i] * m_loss
            
            # Adaptive loss balancing
            if include_ortholoss:
                if ortho_loss_balancing:
                    # Update exponential moving averages
                    if recon_loss_ema is None:
                        recon_loss_ema = loss.item()
                        ortho_loss_ema = ortho_loss.item()
                    else:
                        recon_loss_ema = ema_decay * recon_loss_ema + (1 - ema_decay) * loss.item()
                        ortho_loss_ema = ema_decay * ortho_loss_ema + (1 - ema_decay) * ortho_loss.item()
                    
                    # Balance losses to have similar magnitudes
                    if ortho_loss_ema > 0:
                        ortho_scale = recon_loss_ema / ortho_loss_ema
                        total_loss += loss + ortho_scale * ortho_loss
                    else:
                        total_loss += loss
                elif ortho_loss_anneal_epochs is not None:
                    ortho_weight = ortho_loss_weights[epoch]
                    total_loss += loss + ortho_weight * ortho_loss
                else:
                    total_loss += loss + ortho_loss
            else:
                total_loss += loss
            
            # Store initial losses for scaling (first batch of first epoch)
            #if epoch == 0 and batch_idx == 0:
            #    for i, m_loss in enumerate(modality_losses):
            #        if m_loss > 0:
            #            initial_losses[i] = m_loss.detach()
            #    # Prevent division by zero
            #    initial_losses = torch.clamp(initial_losses, min=1e-8)
            #    print(f"Initial modality losses: {initial_losses.cpu().numpy()}")
            
            # Calculate dynamic loss scales based on current loss values
            #with torch.no_grad():
            #    for i, m_loss in enumerate(modality_losses):
            #        if initial_losses[i] > 0 and m_loss > 0:
            #            # Use moving average to update scales
            #            target_scale = initial_losses[i] / (m_loss + 1e-8)
            #            loss_scales[i] = 0.9 * loss_scales[i] + 0.1 * target_scale
            
            # Apply loss scales and calculate total loss
            #total_loss = torch.tensor(0.0, device=device, requires_grad=True)
            #for i, m_loss in enumerate(modality_losses):
            #    if m_loss > 0:
            #        # Apply scaling to balance modality contributions
            #        total_loss = total_loss + loss_scales[i].detach() * m_loss
            
            # Backward pass and optimize
            optimizer.zero_grad()
            total_loss.backward()
            
            # Apply gradient clipping
            #torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            
            optimizer.step()
            train_loss += loss.item()
        
        # Average losses
        train_loss /= len(data_loader)
        if start_reduction and include_ortholoss:
            total_ortho_loss /= len(data_loader)
        per_modality_losses = [loss / len(data_loader) for loss in per_modality_losses]
        train_losses.append(train_loss)
        
        # Store per-modality losses in history
        for i, loss in enumerate(per_modality_losses):
            loss_history[f'mod_{i}_loss'].append(loss)
        
        # Validation phase with similar safeguards
        with torch.no_grad():
            for x_val, mask in val_data_loader:
                x_val = [x_m.to(device) for x_m in x_val]
                x_val_hat, _ = model(x_val)

                modality_masks = []
                if mask is not None:
                    start_idx = 0
                    for i, x_m in enumerate(x_val):
                        end_idx = start_idx + x_m.shape[1]
                        modality_masks.append(mask[:, start_idx:end_idx])
                        start_idx = end_idx
                else:
                    modality_masks = [None] * len(x_val)
                
                # Calculate validation loss
                val_batch_loss = 0.0
                for i, (x_m, x_hat_m) in enumerate(zip(x_val, x_val_hat)):
                    if modality_masks[i] is not None:
                        m_loss = F.mse_loss(x_hat_m[modality_masks[i]], x_m[modality_masks[i]])
                    else:
                        m_loss = F.mse_loss(x_hat_m, x_m)
                    if not torch.isnan(m_loss):
                        val_batch_loss += m_loss.item()
                
                val_loss += val_batch_loss / len(x_val)
                
        val_loss /= len(val_data_loader)
        val_losses.append(val_loss)

        log_dict = {
            'loss': round(train_loss, 4),
            'mod_losses': [round(l, 3) for l in per_modality_losses],
            #'ortho_loss': round(total_ortho_loss, 4) if (start_reduction and include_ortholoss) else 'N/A',
            #'scales': loss_scales.detach().cpu().numpy().round(2),
            'ranks': [layer.active_dims for layer in model.adaptive_layers] if hasattr(model, 'adaptive_layers') 
                    else (model.module.adaptive_layers if multi_gpu else []),
            #'best_loss': round(best_loss, 4),
            'current_rsquare': [round(current_rsquare_per_mod[i], 3) if current_rsquare_per_mod[i] is not None else 'N/A' for i in range(len(data))],
            #'mi': [round(mi, 3) for mi in mutual_info] if mutual_info is not None else 'N/A',
            #'sim': [round(sim, 3) for sim in space_sims] if space_sims is not None else 'N/A',
            'patience': patience_counter,
        }
        if include_ortholoss:
            log_dict.update({'ortho_loss': round(total_ortho_loss, 4) if include_ortholoss else 'N/A'})
        if recon_loss_balancing and all(ema is not None for ema in modality_loss_emas):
            # Show the balance scales for reconstruction losses
            min_ema = min(ema for ema in modality_loss_emas if ema > 0)
            balance_scales = [round(min_ema / ema, 3) if ema > 0 else 1.0 for ema in modality_loss_emas]
            log_dict.update({'balance_scales': balance_scales})
        pbar.set_postfix(log_dict)
        
        # Update best loss
        if train_loss < best_loss:
            best_loss = train_loss
            if reduce_on_best_loss in ['true', 'stagnation']:
                patience_counter = 0  # Reset patience counter
        else:
            if reduce_on_best_loss in ['true', 'stagnation']:
                patience_counter += 1
        
        # Apply rank reduction at scheduled epochs, respecting warmup period
        if (epoch in rank_schedule) & (start_reduction) & (break_counter == 0):
            # Check if we should reduce rank based on loss condition
            #should_reduce = True
            #should_increase = False
            #if reduce_on_best_loss == 'true' and train_loss > best_loss:
            #    should_reduce = False
            #    # Don't print a separate message, will show in progress bar
            #elif reduce_on_best_loss == 'stagnation' and patience_counter < patience:
            #    should_reduce = False
            #    # Don't print a separate message, will show in progress bar
            if (reduce_on_best_loss == 'rsquare') & (start_reduction):
                ###
                # get the r_square values per modality
                ###
                with torch.no_grad():
                    #encoded_per_modality = model.encode_modalities([d[n_samples_train:].to(device) for d in data])
                    encoded_per_modality = model.encode_modalities([val_data.data[i].to(device) for i in range(len(data))])
                    if not compute_jacobian:
                        encoded_per_space_shared, encoded_per_space_specific = model.encode([val_data.data[i].to(device) for i in range(len(data))])
                        encoded_per_space = [encoded_per_space_shared] + list(encoded_per_space_specific)
                    else:
                        encoded_per_space, contractive_losses = model.encode([val_data.data[i].to(device) for i in range(len(data))], compute_jacobian=compute_jacobian)
                current_rsquares = []
                modalities_to_reduce = []
                modalities_to_increase = []
                mask = val_data.mask
                modality_masks_data = []
                modality_masks_latent = []
                modality_masks_space = []
                if mask is not None:
                    start_idx = 0
                    for j, x_m in enumerate(x_val):
                        end_idx = start_idx + x_m.shape[1]
                        #modality_masks.append(mask[:, start_idx:end_idx])
                        # expand it to match the encoded shape
                        temp_mask = mask[:, start_idx]
                        temp_mask = temp_mask.unsqueeze(1).expand(-1, encoded_per_modality[j].shape[1])
                        modality_masks_latent.append(temp_mask)
                        modality_masks_data.append(mask[:, start_idx:end_idx])
                        modality_masks_space.append(mask[:, start_idx])
                        start_idx = end_idx
                
                # Calculate R² or loss based on compressibility_type and reduction_criterion
                if compressibility_type == 'linear':
                    # Original linear probing approach (always uses R²)
                    for i, encoded in enumerate(encoded_per_modality):
                        nonzero_mod_data = (val_data.data[i][modality_masks_data[i]]).view(-1, val_data.data[i].shape[1])
                        encoded = (encoded[(modality_masks_latent[i][:,:encoded.shape[1]]).to(device)]).view(-1, encoded.shape[1])
                        n_temp_samples = nonzero_mod_data.shape[0]
                        initial_square = parallel_linear_regression(encoded, 
                                                                    nonzero_mod_data.to(device),
                                                                    n_temp_samples, 
                                                                    int(n_temp_samples*0.9), 
                                                                    device,
                                                                    args,
                                                                    n_epochs=500, 
                                                                    early_stopping=50)
                        current_rsquares.append(initial_square.mean().item())
                        current_rsquare_per_mod[i] = initial_square.mean().item()
                elif compressibility_type == 'direct':
                    if reduction_criterion == 'r_squared':
                        # Direct reconstruction R² approach (original behavior)
                        val_data_list = [val_data.data[i] for i in range(len(data))]
                        direct_r_squared_values = compute_direct_r_squared(model, val_data_list, device, multi_gpu)
                        
                        for i, r_squared_val in enumerate(direct_r_squared_values):
                            current_rsquares.append(r_squared_val)
                            current_rsquare_per_mod[i] = r_squared_val
                    elif reduction_criterion in ['train_loss', 'val_loss']:
                        # Loss-based approach - use per-modality reconstruction losses
                        val_data_list = [val_data.data[i] for i in range(len(data))]
                        
                        # Calculate per-modality losses
                        model.eval()
                        with torch.no_grad():
                            val_data_tensors = [d.to(device) for d in val_data_list]
                            reconstructions, _ = model(val_data_tensors)
                            
                            modality_losses = []
                            for i, (original, reconstruction) in enumerate(zip(val_data_tensors, reconstructions)):
                                # Calculate MSE loss for this modality
                                if val_data.mask is not None:
                                    # Apply mask if available
                                    mask_mod = modality_masks_data[i].to(device)
                                    loss = torch.nn.functional.mse_loss(
                                        reconstruction[mask_mod], 
                                        original[mask_mod]
                                    )
                                else:
                                    loss = torch.nn.functional.mse_loss(reconstruction, original)
                                modality_losses.append(loss.item())
                                current_loss_per_mod[i] = loss.item()
                        
                        # For loss-based criteria, we work with losses instead of R²
                        # We'll still use current_rsquares for compatibility but store losses
                        current_rsquares = modality_losses.copy()  # Store losses in rsquares for compatibility
                    else:
                        raise ValueError(f"reduction_criterion must be 'r_squared', 'train_loss', or 'val_loss', got {reduction_criterion}")
                else:
                    raise ValueError(f"compressibility_type must be 'linear' or 'direct', got {compressibility_type}")
                    
                r_squares.append(current_rsquares)

                ###
                # determine what modalities to reduce or increase
                ###
                if (len(r_squares) >= min(10, int(patience/2))) and patience_counter >= min(10, int(patience/2)):
                    for i in range(len(current_rsquare_per_mod)):
                        i_rsquares = [r[i] for r in r_squares[-min(10, int(patience/2)):]]
                        
                        # Handle different comparison logic for loss vs R²
                        if compressibility_type == 'direct' and reduction_criterion in ['train_loss', 'val_loss']:
                            # For loss: lower is better, so we reduce if loss is high (above threshold)
                            if all(r > min_rsquares[i] for r in i_rsquares) and not bottom_reached:
                                modalities_to_increase.append(i)
                            elif current_rsquare_per_mod[i] < min_rsquares[i]:  # loss below threshold
                                modalities_to_reduce.append(i)
                        else:
                            # For R²: higher is better (original logic)
                            if all(r < min_rsquares[i] for r in i_rsquares) and not bottom_reached:
                                modalities_to_increase.append(i)
                            elif current_rsquare_per_mod[i] > min_rsquares[i]:
                                modalities_to_reduce.append(i)
                    #if len(modalities_to_increase) == len(current_rsquare_per_mod):
                    #    bottom_reached = True
                elif (len(r_squares) >= 1):# and (patience_counter >= 1):
                    for i in range(len(current_rsquare_per_mod)):
                        #if len(r_squares) >= 3:
                        #    i_rsquares = [r[i] for r in r_squares[-3:]]
                        #else:
                        #    i_rsquares = None
                        # Handle different comparison logic for loss vs R²
                        if compressibility_type == 'direct' and reduction_criterion in ['train_loss', 'val_loss']:
                            # For loss: reduce if loss is below threshold (good performance)
                            if current_rsquare_per_mod[i] < min_rsquares[i]:
                                modalities_to_reduce.append(i)
                            elif current_rsquare_per_mod[i] > min_rsquares[i] and not bottom_reached:
                                modalities_to_increase.append(i)
                            #if i_rsquares is not None and all(r > min_rsquares[i] for r in i_rsquares) and not bottom_reached:
                            #    modalities_to_increase.append(i)
                        else:
                            # For R²: reduce if R² is above threshold (original logic)
                            if current_rsquare_per_mod[i] > min_rsquares[i]:
                                modalities_to_reduce.append(i)
                            elif current_rsquare_per_mod[i] < min_rsquares[i] and not bottom_reached:
                                modalities_to_increase.append(i)
                            #if i_rsquares is not None and all(r < min_rsquares[i] for r in i_rsquares) and not bottom_reached:
                            #    modalities_to_increase.append(i)

                # if all modalities can be reduced, we set min and max ranks
                if len(modalities_to_reduce) == len(current_rsquare_per_mod):
                    current_ranks = [layer.active_dims for layer in model.adaptive_layers]
                    for i, cr in enumerate(current_ranks):
                        if cr <= min_ranks[i]:
                            min_ranks[i] = cr
                            # if we are still above the thresholds, the ranks maximum should not be larger than the sum of current ranks
                            model.adaptive_layers[i].max_rank = min(sum(current_ranks), model.adaptive_layers[i].max_rank)
                ###
                # set the patience counters and layers to reduce or increase
                ###
                layers_to_reduce = []
                layers_to_increase = []
                if (len(modalities_to_reduce) == 0) and (len(modalities_to_increase) == 0):
                    #patience_counter += 1
                    pass
                elif (len(modalities_to_reduce) > 0) and (len(modalities_to_increase) > 0):
                    # problem
                    #raise ValueError("Both modalities to reduce and increase found, check your logic.")
                    # no increasing yet, but no decreasing the shared either
                    layers_to_reduce = [i + 1 for i in modalities_to_reduce]
                    layers_to_increase = [i + 1 for i in modalities_to_increase]
                    #print(f"Increasing layers {layers_to_increase} and reducing layers {layers_to_reduce}")
                else:
                    if len(modalities_to_increase) > 0:
                        if len(modalities_to_increase) == len(current_rsquare_per_mod):
                            # if all modalities are below the threshold, increase ranks of all layers
                            layers_to_increase = [i for i in range(len(model.adaptive_layers))]
                            #layers_to_increase = []
                        else:
                            # if not, the shared does not need to be increased
                            #if sharedwhenall:
                            #    layers_to_increase = [i + 1 for i in modalities_to_increase]
                            #else:
                            #    layers_to_increase = [0] + [i + 1 for i in modalities_to_increase]
                            layers_to_increase = [i + 1 for i in modalities_to_increase]
                            #layers_to_increase = []
                        #patience_counter = 0
                        #print(f"Increasing layers {layers_to_increase}")
                    if len(modalities_to_reduce) > 0:
                        # if all modalities are below the threshold, reduce ranks of all layers
                        if len(modalities_to_reduce) == len(initial_squares):
                            #layers_to_reduce = [i for i in range(len(model.adaptive_layers))]
                            #reduce_shared = np.random.rand() < 0.5
                            reduce_shared = True
                            if reduce_shared:
                                layers_to_reduce = [0] + [i + 1 for i in modalities_to_reduce]
                            else:
                                layers_to_reduce = [i + 1 for i in modalities_to_reduce]
                        else:
                            # roll a dice whether we also try to reduce the shared layer
                            #reduce_shared = np.random.rand() < 0.5
                            if sharedwhenall:
                                reduce_shared = False
                            else:
                                reduce_shared = True
                            if reduce_shared:
                                layers_to_reduce = [0] + [i + 1 for i in modalities_to_reduce]
                            else:

                                layers_to_reduce = [i + 1 for i in modalities_to_reduce]
                            #layers_to_reduce = [0] + [i + 1 for i in modalities_to_reduce]  # +1 for shared layer
                        #patience_counter = 0
                if verbose:
                    if compressibility_type == 'direct' and reduction_criterion in ['train_loss', 'val_loss']:
                        print(f"{reduction_criterion} values: {current_rsquares}, reducing rank for modalities {modalities_to_reduce}, layers {layers_to_reduce}, increasing rank for modalities {modalities_to_increase}, layers {layers_to_increase}")
                    else:
                        print(f"R-squared values: {current_rsquares}, reducing rank for modalities {modalities_to_reduce}, layers {layers_to_reduce}, increasing rank for modalities {modalities_to_increase}, layers {layers_to_increase}")
                if compute_jacobian:
                    #print(f"Contractive losses: {contractive_losses}")
                    # reduce this list to the one with the highest contractive loss
                    #if len(layers_to_reduce) == len(model.adaptive_layers):
                    #    # sample 50:50 whether to include the shared space
                    #    if np.random.rand() < 0.9:
                    #        layers_to_reduce = range(1,len(model.adaptive_layers))
                    valid_contractive_losses = [contractive_losses[i] for i in layers_to_reduce]
                    max_contractive_loss = max(valid_contractive_losses) if len(valid_contractive_losses) > 0 else None
                    #max_contractive_loss = min(valid_contractive_losses) if len(valid_contractive_losses) > 0 else None
                    if max_contractive_loss is not None:
                        ### #newest change
                        #if len(layers_to_reduce) == len(model.adaptive_layers):
                        #    # check whether it is the shared or modality-specific layer that has max contractive loss
                        #    max_layer = 0 if contractive_losses[0] == max_contractive_loss else 1
                        #    if max_layer == 0:
                        #        layers_to_reduce = [0]
                        #    else:
                        #        layers_to_reduce = [i + 1 for i in layers_to_reduce]
                        #else: ###
                        layers_to_reduce = [i for i in layers_to_reduce if contractive_losses[i] == max_contractive_loss]
                    else:
                        layers_to_reduce = []
                    if verbose:
                        print(f"Contractive losses: {contractive_losses}, reducing rank for layers {layers_to_reduce} with max loss {max_contractive_loss}")
                
            
            #if should_reduce:
            any_changes_made = False
            #changes_made = False
            if len(layers_to_reduce) > 0:
                # Apply rank reduction
                if multi_gpu:
                    changes_made = model.module.reduce_rank(reduction_ratio=0.9, threshold=rank_reduction_threshold, layer_ids=layers_to_reduce)
                else:
                    changes_made = model.reduce_rank(reduction_ratio=0.9, threshold=rank_reduction_threshold, layer_ids=layers_to_reduce)
                if changes_made:
                    any_changes_made = True
            if len(layers_to_increase) > 0:
                # Apply rank increase
                #print(f"Increasing rank for layer {layers_to_increase}")
                if multi_gpu:
                    changes_made = model.module.increase_rank(increase_ratio=1.1, layer_ids=layers_to_increase)
                else:
                    changes_made = model.increase_rank(increase_ratio=1.1, layer_ids=layers_to_increase)
                if changes_made:
                    any_changes_made = True
                    break_counter = patience # give model more time to re-learn the added dimensions
            #else:
            #    changes_made = False
            
            if any_changes_made:
                patience_counter = 0  # Reset patience counter if rank was changed
                
                ### Plotting test
                # Plot state after rank change
                if last_batch_data is not None and model_name is not None:
                    plot_training_state(model, last_batch_data, last_batch_labels, epoch, 
                                      multi_gpu, plot_save_dir, device, verbose=verbose)
                ###
            else:
                patience_counter += 1

            # Get new rank but don't print separate message
            total_rank_after = model.module.get_total_rank() if multi_gpu else model.get_total_rank()
                
            # Store current rank in history
            rank_history['total_rank'].append(total_rank_after)
            rank_history['ranks'].append(', '.join(str(layer.active_dims) for layer in model.adaptive_layers))
            rank_history['epoch'].append(epoch)
            for i in range(len(encoded_per_modality)):
                if reduce_on_best_loss == 'rsquare':
                    rank_history[f'rsquare {i}'].append(current_rsquares[i])
            rank_history['loss'].append(train_loss)
            rank_history['val_loss'].append(val_loss)

            # also get mutual information between all spaces
            valid_spaces = []
            for encoded in encoded_per_space:
                valid_spaces_temp = []
                for i in range(len(encoded_per_modality)):
                    mask = modality_masks_space[i]
                    #normalized_encoded = (encoded - encoded.min() + 1e-9) / (encoded.max() - encoded.min() + 1e-9)
                    if mask is not None:
                        #valid_spaces_temp.append(normalized_encoded[mask])
                        valid_spaces_temp.append(encoded[mask])
                    else:
                        #valid_spaces_temp.append(normalized_encoded)
                        valid_spaces_temp.append(encoded)
                # stack the valid spaces
                valid_spaces_temp = torch.vstack(valid_spaces_temp)
                valid_spaces.append(valid_spaces_temp)
            # compute mutual information for each pair of spaces
            #mutual_info = []
            space_sims = []
            current_space_ranks = [layer.active_dims for layer in model.adaptive_layers]
            for i in range(len(valid_spaces)):
                #entropy_i = - torch.sum(valid_spaces[i][:,:current_space_ranks[i]] * torch.log(valid_spaces[i][:,:current_space_ranks[i]] + 1e-9)) #/ np.log(2)
                for j in range(i + 1, len(valid_spaces)):
                    #entropy_j = - torch.sum(valid_spaces[j][:,:current_space_ranks[j]] * torch.log(valid_spaces[j][:,:current_space_ranks[j]] + 1e-9)) #/ np.log(2)
                    #outer_product = valid_spaces[i][:,:current_space_ranks[i]] @ valid_spaces[j][:,:current_space_ranks[j]].T
                    #mi = (entropy_i + entropy_j - torch.sum(outer_product * torch.log(outer_product + 1e-9))) / math.log(2)
                    #mi = torch.mean(torch.abs(valid_spaces[i][:,:current_space_ranks[i]] @ valid_spaces[j][:,:current_space_ranks[j]].T))
                    #mutual_info.append(mi.item())
                    dot_product = valid_spaces[i][:,:current_space_ranks[i]].T @ valid_spaces[j][:,:current_space_ranks[j]]
                    #sim = torch.mean(dot_product / (torch.norm(valid_spaces[i][:,:current_space_ranks[i]], dim=0) * torch.norm(valid_spaces[j][:,:current_space_ranks[j]], dim=0) + 1e-9))
                    #sim = torch.sum(dot_product) / (torch.norm(valid_spaces[i][:,:current_space_ranks[i]]) * torch.norm(valid_spaces[j][:,:current_space_ranks[j]]) + 1e-9)
                    norm_product_matrix = torch.outer(torch.norm(valid_spaces[i][:,:current_space_ranks[i]], dim=0), torch.norm(valid_spaces[j][:,:current_space_ranks[j]], dim=0))
                    sim = torch.mean(dot_product / norm_product_matrix)
                    space_sims.append(sim.item())
            #rank_history['MI (0-1,0-2,1-2)'].append(', '.join([f"{mi:.2f}" for mi in mutual_info]))
            rank_history['Sim (0-1,0-2,1-2)'].append(', '.join([f"{sim:.4f}" for sim in space_sims]))
        else:
            if (epoch in rank_schedule) & (start_reduction) & (break_counter > 0):
                break_counter -= 1
            if epoch % 100 == 0:
                if last_batch_data is not None and model_name is not None:
                    plot_training_state(model, last_batch_data, last_batch_labels, epoch, 
                                        multi_gpu, plot_save_dir, device, verbose=verbose)

        # Update progress bar with both loss and current rank information
        #if rank_or_sparse == 'rank':
        #    current_rank = model.module.get_total_rank() if multi_gpu else model.get_total_rank()
        
        # Get normalized weights for display
        if multi_gpu:
            weights = model.module.modality_weights
        else:
            weights = model.modality_weights
            
        pos_weights = F.softplus(weights)
        norm_weights = (pos_weights / (pos_weights.sum() + 1e-8)).detach().cpu().numpy().round(3)
        
        if (epoch > early_stopping) & (min(val_losses[-early_stopping:]) > min(val_losses)) & (start_reduction is False):
            #if (epoch > early_stopping) & (start_reduction is False): # for quick testing
            #print(f"Early stopping at epoch {epoch}")
            rank_history = {'total_rank':[model.get_total_rank() if hasattr(model, 'get_total_rank') else 
                (model.module.get_total_rank() if multi_gpu else 0)],
                'ranks':[', '.join(str(layer.active_dims) for layer in model.adaptive_layers)],
                'epoch':[epoch],
                'loss':[train_loss],
                'val_loss':[val_loss]
            }
            
            #break
            start_reduction = True  # Start rank reduction after early stopping
            break_counter = 0 # start with no breaks (only used when increasing layers)
            if verbose:
                print(f"Patience exceeded at epoch {epoch}, starting rank reduction")
            
            ### Plotting test
            # Plot initial state when starting rank reduction
            if last_batch_data is not None and model_name is not None:
                plot_training_state(model, last_batch_data, last_batch_labels, epoch, 
                                  multi_gpu, plot_save_dir, device, verbose=verbose)
            ###

            #if start_reduction & (reduce_on_best_loss == 'rsquare') & (initial_square is None):
            with torch.no_grad():
                #encoded_per_modality = model.encode_modalities([d[n_samples_train:].to(device) for d in data])
                encoded_per_modality = model.encode_modalities([val_data.data[i].to(device) for i in range(len(data))])
                encoded_per_space_shared, encoded_per_space_specific = model.encode([val_data.data[i].to(device) for i in range(len(data))])
                encoded_per_space = [encoded_per_space_shared] + list(encoded_per_space_specific)
            min_rsquares = []
            mask = val_data.mask
            modality_masks_latent = []
            modality_masks_data = []
            modality_masks_space = []
            if mask is not None:
                start_idx = 0
                for j, x_m in enumerate(x_val):
                    end_idx = start_idx + x_m.shape[1]
                    temp_mask = mask[:, start_idx]
                    # expand it to match the encoded shape
                    temp_mask = temp_mask.unsqueeze(1).expand(-1, encoded_per_modality[j].shape[1])
                    modality_masks_latent.append(temp_mask)
                    modality_masks_data.append(mask[:, start_idx:end_idx])
                    modality_masks_space.append(mask[:, start_idx])
                    start_idx = end_idx
            
            # Calculate R² or loss based on compressibility_type and reduction_criterion
            if compressibility_type == 'linear':
                # Original linear probing approach
                for i, encoded in enumerate(encoded_per_modality):
                    # get the mask for the validation set
                    nonzero_mod_data = (val_data.data[i][modality_masks_data[i]]).view(-1, val_data.data[i].shape[1])
                    encoded = (encoded[(modality_masks_latent[i][:,:encoded.shape[1]]).to(device)]).view(-1, encoded.shape[1])
                    n_temp_samples = nonzero_mod_data.shape[0]
                    initial_square = parallel_linear_regression(encoded, 
                                                                nonzero_mod_data.to(device),
                                                                n_temp_samples, 
                                                                int(n_temp_samples*0.9), 
                                                                device,
                                                                args,
                                                                n_epochs=500, 
                                                                early_stopping=50)
                    initial_squares[i] = initial_square.mean().item()
                    
                    # Calculate threshold based on threshold_type
                    if threshold_type == 'relative':
                        min_rsquares.append(initial_square.mean().item() * r_square_threshold)
                    elif threshold_type == 'absolute':
                        min_rsquares.append(initial_square.mean().item() - r_square_threshold)
                    else:
                        raise ValueError(f"threshold_type must be 'relative' or 'absolute', got {threshold_type}")
                        
                    current_rsquare_per_mod[i] = initial_square.mean().item()
                    rank_history[f'rsquare {i}'] = [initial_square.mean().item()]
                    
            elif compressibility_type == 'direct':
                if reduction_criterion == 'r_squared':
                    # Direct reconstruction R² approach (original behavior)
                    val_data_list = [val_data.data[i] for i in range(len(data))]
                    direct_r_squared_values = compute_direct_r_squared(model, val_data_list, device, multi_gpu)
                    
                    for i, r_squared_val in enumerate(direct_r_squared_values):
                        initial_squares[i] = r_squared_val
                        
                        # Calculate threshold based on threshold_type
                        if threshold_type == 'relative':
                            min_rsquares.append(r_squared_val * r_square_threshold)
                        elif threshold_type == 'absolute':
                            min_rsquares.append(r_squared_val - r_square_threshold)
                        else:
                            raise ValueError(f"threshold_type must be 'relative' or 'absolute', got {threshold_type}")
                            
                        current_rsquare_per_mod[i] = r_squared_val
                        rank_history[f'rsquare {i}'] = [r_squared_val]
                elif reduction_criterion in ['train_loss', 'val_loss']:
                    # Loss-based approach - calculate initial losses
                    val_data_list = [val_data.data[i] for i in range(len(data))]
                    
                    # Calculate per-modality losses
                    model.eval()
                    with torch.no_grad():
                        val_data_tensors = [d.to(device) for d in val_data_list]
                        reconstructions, _ = model(val_data_tensors)
                        
                        for i, (original, reconstruction) in enumerate(zip(val_data_tensors, reconstructions)):
                            # Calculate MSE loss for this modality
                            if val_data.mask is not None:
                                # Apply mask if available
                                mask_mod = modality_masks_data[i].to(device)
                                loss = torch.nn.functional.mse_loss(
                                    reconstruction[mask_mod], 
                                    original[mask_mod]
                                )
                            else:
                                loss = torch.nn.functional.mse_loss(reconstruction, original)
                            
                            initial_losses[i] = loss.item()
                            initial_squares[i] = loss.item()  # Store in initial_squares for compatibility
                            
                            # Calculate threshold based on threshold_type
                            # For loss: lower is better, so thresholds work differently
                            if threshold_type == 'relative':
                                # relative threshold means multiply by factor
                                min_rsquares.append(loss.item() * r_square_threshold)
                            elif threshold_type == 'absolute':
                                # absolute threshold means add/subtract value
                                min_rsquares.append(loss.item() + r_square_threshold)  # Note: + for loss
                            else:
                                raise ValueError(f"threshold_type must be 'relative' or 'absolute', got {threshold_type}")
                                
                            current_rsquare_per_mod[i] = loss.item()  # Store in rsquare for compatibility
                            current_loss_per_mod[i] = loss.item()
                            rank_history[f'rsquare {i}'] = [loss.item()]  # Store in rsquare history for compatibility
                else:
                    raise ValueError(f"reduction_criterion must be 'r_squared', 'train_loss', or 'val_loss', got {reduction_criterion}")
            else:
                raise ValueError(f"compressibility_type must be 'linear' or 'direct', got {compressibility_type}")
                #print(f"Initial R-squared value: {initial_square.mean().item()}, setting threshold to {min_rsquare}")
            #min_rsquare /= len(encoded_per_modality)
            if verbose:
                if compressibility_type == 'direct' and reduction_criterion in ['train_loss', 'val_loss']:
                    print(f"Initial {reduction_criterion} values: {[rank_history[f'rsquare {i}'] for i in range(len(encoded_per_modality))]}, setting {threshold_type} thresholds to {min_rsquares}")
                else:
                    print(f"Initial R-squared values: {[rank_history[f'rsquare {i}'] for i in range(len(encoded_per_modality))]}, setting {threshold_type} thresholds to {min_rsquares}")
            print(f"Initial R-squared values: {[rank_history[f'rsquare {i}'] for i in range(len(encoded_per_modality))]}, setting {threshold_type} thresholds to {min_rsquares}")
            
            # also get mutual information between all spaces
            valid_spaces = []
            for encoded in encoded_per_space:
                valid_spaces_temp = []
                for i in range(len(encoded_per_modality)):
                    mask = modality_masks_space[i]
                    # normalize the encoded to fit between 0 and 1 so that I don't get NaNs in the entropy for negative values
                    #normalized_encoded = (encoded - encoded.min() + 1e-9) / (encoded.max() - encoded.min() + 1e-9)
                    if mask is not None:
                        #valid_spaces_temp.append(normalized_encoded[mask])
                        valid_spaces_temp.append(encoded[mask])
                    else:
                        #valid_spaces_temp.append(normalized_encoded)
                        valid_spaces_temp.append(encoded)
                # stack the valid spaces
                # print the shapes of the temp spaces
                valid_spaces_temp = torch.vstack(valid_spaces_temp)
                valid_spaces.append(valid_spaces_temp)
            # compute mutual information for each pair of spaces
            space_sims = []
            current_space_ranks = [layer.active_dims for layer in model.adaptive_layers]
            for i in range(len(valid_spaces)):
                # H(X)
                #entropy_i = - torch.sum(valid_spaces[i][:,:current_space_ranks[i]] * torch.log(valid_spaces[i][:,:current_space_ranks[i]] + 1e-9)) #/ np.log(2)
                for j in range(i + 1, len(valid_spaces)):
                    # H(Y)
                    #entropy_j = - torch.sum(valid_spaces[j][:,:current_space_ranks[j]] * torch.log(valid_spaces[j][:,:current_space_ranks[j]] + 1e-9)) #/ np.log(2)
                    # can I treat the outer product as the joint distribution?
                    # H(X,Y)
                    #outer_product = valid_spaces[i][:,:current_space_ranks[i]] @ valid_spaces[j][:,:current_space_ranks[j]].T
                    #entropy_ij = - torch.sum(outer_product * torch.log(outer_product + 1e-9))
                    # mutual information
                    #mi = (entropy_i + entropy_j - entropy_ij) / math.log(2)
                    #mi = torch.mean(torch.abs(valid_spaces[i][:,:current_space_ranks[i]] @ valid_spaces[j][:,:current_space_ranks[j]].T))
                    dot_product = valid_spaces[i][:,:current_space_ranks[i]].T @ valid_spaces[j][:,:current_space_ranks[j]]
                    #magnitudes = torch.sqrt(torch.sum(valid_spaces[i][:,:current_space_ranks[i]] ** 2, dim=0)) * torch.sqrt(torch.sum(valid_spaces[j][:,:current_space_ranks[j]] ** 2, dim=0))
                    #sim = torch.mean(dot_product / (magnitude + 1e-9))
                    #sim = torch.mean(dot_product / (torch.norm(valid_spaces[i][:,:current_space_ranks[i]], dim=0) * torch.norm(valid_spaces[j][:,:current_space_ranks[j]], dim=0) + 1e-9))
                    norm_product_matrix = torch.outer(torch.norm(valid_spaces[i][:,:current_space_ranks[i]], dim=0), torch.norm(valid_spaces[j][:,:current_space_ranks[j]], dim=0))
                    sim = torch.mean(dot_product / norm_product_matrix)
                    space_sims.append(sim.item())
            rank_history['Sim (0-1,0-2,1-2)'] = [', '.join([f"{sim:.4f}" for sim in space_sims])]
            if verbose:
                print(f"Cosine similarity between spaces: {space_sims}")

        
        # early stopping but conditioned on rank reduction
        if (epoch > early_stopping) & (min(val_losses[-early_stopping:]) > min(val_losses)) & (start_reduction is True) & (patience_counter >= patience):
            if verbose:
                print(f"Early stopping at epoch {epoch} with best loss {best_loss} and ranks {rank_history['ranks'][-1]}")
            break
    
    # final plotting
    if last_batch_data is not None and model_name is not None:
        plot_training_state(model, last_batch_data, last_batch_labels, epoch, 
                            multi_gpu, plot_save_dir, device, verbose=verbose)
    
    # Calculate latent representations in batches
    #'''
    n_samples = data[0].shape[0]
    final_ranks = [layer.active_dims for layer in model.adaptive_layers]
    reps = [torch.empty((n_samples, final_ranks[i]), device=device) for i in range(len(final_ranks))]
    model.eval()
    with torch.no_grad():
        for i in range(0, n_samples, batch_size):
            end_idx = min(i + batch_size, n_samples)
            x_batch = [data[j][i:end_idx].to(device) for j in range(len(data))]
            
            # If using DataParallel, need to access module directly or handle the encoding differently
            if multi_gpu:
                batch_reps = model.module.encode(x_batch)#.cpu()
            else:
                batch_reps = model.encode(x_batch)#.cpu()
            batch_rep_list = [batch_reps[0]] + [batch_reps[1][j] for j in range(len(batch_reps[1]))]
                
            # No need to convert dtype
            for j in range(len(reps)):
                reps[j][i:end_idx,:] = batch_rep_list[j][:,:final_ranks[j]].cpu()
            
            # Free memory
            del x_batch, batch_reps
            torch.cuda.empty_cache() if torch.cuda.is_available() else None
    
    # Create movie from plots at the end of training (only if model_name is provided)
    if model_name is not None:
        create_training_movie(plot_save_dir)
    
    # Combine latent representations from all batches
    #reps = torch.cat(reps_list, dim=0)

    '''
    # empty cache
    del model, optimizer, loss_fn, data_loader
    torch.cuda.empty_cache() if torch.cuda.is_available() else None
    gc.collect()

    # Linear regression evaluation
    r_squares = parallel_linear_regression(reps, data[:n_samples_train], n_samples_train, int(n_samples_train*0.9), n_epochs=500, early_stopping=50, verbose=True)
    
    # remove all nan and inf values
    r_squares = r_squares[torch.isfinite(r_squares)]
    
    # Free memory
    del reps, reps_list
    torch.cuda.empty_cache() if torch.cuda.is_available() else None
    gc.collect()
    '''

    #return np.mean(train_losses[-5:]), r_squares.mean().item(), rank_history, train_losses
    return model, reps, np.mean(train_losses[-5:]), r_squares[-1], rank_history, [train_losses, val_losses]

import os
def train_overcomplete_ae_with_pretrained(data, n_samples_train, latent_dim, device, args, epochs=100, early_stopping=50, 
                         lr=0.001, batch_size=128, ae_depth=2, ae_width=0.5, dropout=0.0, wd=1e-5, 
                         initial_rank_ratio=1.0, min_rank=10, 
                         rank_schedule=None, rank_reduction_frequency=10, 
                         rank_reduction_threshold=0.01, warmup_epochs=0,
                         patience=10, reduce_on_best_loss='rsquare', r_square_threshold=0.9,
                         threshold_type='relative', compressibility_type='linear', reduction_criterion='r_squared',
                         include_l1=False, l1_weight=0.0, include_ortholoss=False,
                         l1_start_weight=0.0, l1_step_size=1.0, rank_or_sparse='rank',
                         verbose=True, compute_jacobian=False, model_name=None, pretrained_name=None,
                         recon_loss_balancing=False, ortho_loss_balancing=False,
                         ortho_loss_start_weight=0.0, ortho_loss_end_weight=1.0, ortho_loss_anneal_epochs=None, ortho_loss_warmup=None,
                         l2_norm_adaptivelayers=None, sharedwhenall=True, paired=False
                         ):
    """
    Train an autoencoder with adaptive rank reduction
    
    Parameters:
    - data: Input data tensor
    - n_samples_train: Number of samples to use for training
    - latent_dim: Dimension of the latent space
    - epochs: Maximum number of training epochs
    - early_stopping: Number of epochs for early stopping patience
    - lr: Learning rate
    - batch_size: Batch size for training
    - ae_depth: Depth of the autoencoder
    - ae_width: Width multiplier for hidden layers
    - dropout: Dropout rate
    - wd: Weight decay
    - initial_rank_ratio: Initial rank ratio (1.0 = full rank)
    - min_rank_ratio: Minimum rank ratio (lower bound)
    - rank_schedule: Custom schedule for rank reduction (epochs at which to reduce)
    - rank_reduction_frequency: How often to try reducing rank (in epochs)
    - rank_reduction_threshold: Energy threshold for rank reduction
    - warmup_epochs: Number of epochs to train before starting rank reduction
    - reduce_on_best_loss: Only reduce rank when loss is at or better than best loss
    - r_square_threshold: R² threshold for rank reduction decisions
    - threshold_type: 'relative' (multiply by initial R²) or 'absolute' (use threshold directly)
    - compressibility_type: 'linear' (linear probing R²) or 'direct' (reconstruction R²)
    - reduction_criterion: 'r_squared' (use R²), 'train_loss' (use training loss), 'val_loss' (use validation loss)
                          Only relevant when compressibility_type='direct'. Default: 'r_squared'
    - recon_loss_balancing: Whether to apply adaptive loss balancing across modalities (default: False)
    """

    # check if there is an existing pretrained model for the seed, early stopping, and training hyperparameters (lr, wd, batch size, model architecture)
    pretrained_model_path = f"./03_results/models/pretrained_models/{pretrained_name}.pt" if pretrained_name else None
    if pretrained_model_path and os.path.exists(pretrained_model_path):
        print(f"Found existing pretrained model at {pretrained_model_path}. Loading...")
        input_dims = [d.shape[1] for d in data]
        if isinstance(latent_dim, int):
            latent_dims = [latent_dim] * (len(input_dims) + 1) # adding one for the shared space
        elif isinstance(latent_dim, list):
            if (len(latent_dim) == 1) & (len(input_dims) > 1):
                latent_dims = [latent_dim[0]] * (len(input_dims) + 1)
            else:
                latent_dims = latent_dim
        model = AdaptiveRankReducedAE(
            input_dims, latent_dims, depth=ae_depth, width=ae_width, 
            dropout=dropout, initial_rank_ratio=initial_rank_ratio, 
            min_rank=min_rank
        )
        model.load_state_dict(torch.load(pretrained_model_path, weights_only=False))
        # make sure that the weights are changed
        model.eval()
        for param in model.parameters():
            param.requires_grad = True
        print(f"Loaded pretrained model from {pretrained_model_path}")
        # also load the loss curves
        loss_curve_path = pretrained_model_path.replace('.pt', '_loss_curve.npy')
        train_val_losses = np.load(loss_curve_path, allow_pickle=True)
        train_losses = train_val_losses[0].tolist()
        val_losses = train_val_losses[1].tolist()
        print(f"Loaded loss curves from {loss_curve_path}")
        # print last losses
        print(f"Last training loss: {train_losses[-1]}, last validation loss: {val_losses[-1]}")
        model.epoch = len(train_losses)
    else:
        if pretrained_model_path:
            print("No pretrained model found. Training from scratch.")
            model, [train_losses, val_losses], data_indices = pretrain_overcomplete_ae(
                data, n_samples_train, latent_dim, device, args, epochs=epochs, early_stopping=early_stopping,
                lr=lr, batch_size=batch_size, ae_depth=ae_depth, ae_width=ae_width, dropout=dropout, wd=wd,
                initial_rank_ratio=initial_rank_ratio, min_rank=min_rank,
                verbose=verbose
            )
            # Save the pretrained model and loss curves
            os.makedirs(os.path.dirname(pretrained_model_path), exist_ok=True)
            torch.save(model.state_dict(), pretrained_model_path)
            # Also save loss curves
            loss_curve_path = pretrained_model_path.replace('.pt', '_loss_curve.npy')
            np.save(loss_curve_path, np.array([train_losses, val_losses]))
            # also save data_indices
            if data_indices is not None:
                data_indices_path = pretrained_model_path.replace('.pt', '_data_indices.pt')
                torch.save(data_indices, data_indices_path)
            print(f"Saved pretrained model to {pretrained_model_path} and loss curves to {loss_curve_path}")
            model.epoch = len(train_losses)
        else:
            raise ValueError("model_name must be provided to save/load pretrained models.")
    # Declare multi_gpu as global so it can be accessed
    multi_gpu = args.multi_gpu if hasattr(args, 'multi_gpu') else False
    
    model.to(device)
    print(f"Model is on device: {next(model.parameters()).device}")
    
    # Handle multi-GPU setup
    if multi_gpu:
        # Adjust batch size to be divisible by number of GPUs
        if args.gpu_ids:
            num_gpus = len(args.gpu_ids.split(','))
        else:
            num_gpus = torch.cuda.device_count()
            
        # Ensure batch size is divisible by number of GPUs
        if batch_size % num_gpus != 0:
            original_batch_size = batch_size
            batch_size = (batch_size // num_gpus) * num_gpus
            if verbose:
                print(f"Adjusted batch size from {original_batch_size} to {batch_size} to be divisible by {num_gpus} GPUs")
            
        try:
            # If we need cuda:0 but it's not available, disable multi_gpu
            if 0 not in [int(id) for id in args.gpu_ids.split(',')]:
                raise RuntimeError("DataParallel requires cuda:0 which is not available.")
                
            # Ensure model is on cuda:0 for DataParallel
            cuda0_device = torch.device('cuda:0')
            model = model.to(cuda0_device)
            
            # Double-check all parameters are on cuda:0
            for param in model.parameters():
                if param.device != cuda0_device:
                    param.data = param.data.to(cuda0_device)
                    
            # Wrap model with DataParallel - explicitly specify device_ids
            model = nn.DataParallel(model, device_ids=[int(id) for id in args.gpu_ids.split(',')])
            if verbose:
                print(f"Using DataParallel across GPUs: {args.gpu_ids}")
        except Exception as e:
            print(f"Failed to use DataParallel: {e}")
            print(f"Falling back to single GPU mode on {device}")
            multi_gpu = False
            model = model.to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    
    # Create data loader
    # careful with the non-paired data because of how it is concatenated
    # first randomize the rows
    if paired:
        # load the saved data_indices if it exists
        data_indices_path = f"./03_results/models/pretrained_models/{pretrained_name}_data_indices.pt"
        data_indices = torch.load(data_indices_path)
        #data_indices = torch.randperm(data[0].shape[0])
        train_indices = data_indices[:n_samples_train]
        val_indices = data_indices[n_samples_train:]
    else:
        print("Using non-paired data splitting")
        # use the first n_samples_train samples for training and the rest for validation
        train_indices = slice(0, n_samples_train)
        val_indices = slice(n_samples_train, None)
    train_data = [d[train_indices] for d in data]  # Randomize rows
    train_data = MMSimData(train_data)
    # Use pin_memory and num_workers from args if available
    num_workers = getattr(args, 'num_workers', 0)
    data_loader = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_workers
    )
    val_data = [data[i][val_indices] for i in range(len(data))]  # Split data into validation set
    val_data = MMSimData(val_data)
    val_data_loader = torch.utils.data.DataLoader(
        val_data, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=num_workers
    )
    n_samples = data[0].shape[0]
    n_samples_val = n_samples - n_samples_train
    
    # Default rank reduction schedule if none provided
    if rank_schedule is None:
        # Reduce rank every rank_reduction_frequency epochs, but start after warmup period
        rank_schedule = list(range(warmup_epochs + rank_reduction_frequency, 
                                 epochs, 
                                 rank_reduction_frequency))
    initial_squares = [None] * len(data) # per modality
    initial_losses = [None] * len(data) # per modality (for loss-based criteria)
    start_reduction = False
    current_rsquare_per_mod = [None] * len(data)
    current_loss_per_mod = [None] * len(data)  # for loss-based criteria
    bottom_reached = False
    space_sims = None
    break_counter = 0
    
    # Train the model
    train_losses = []
    val_losses = []
    r_squares = []
    min_ranks = [layer.active_dims for layer in model.adaptive_layers]
    best_loss = float('inf')
    
    # Initialize loss scaling factors for dynamic loss balancing
    loss_scales = torch.ones(len(data), device=device)
    #loss_scales[1] = 0.1
    initial_losses = torch.zeros(len(data), device=device)
    loss_history = {f'mod_{i}_loss': [] for i in range(len(data))}
    
    # Initialize loss balancer for reconstruction losses
    if recon_loss_balancing:
        modality_loss_emas = [None] * len(data)
        ema_decay = 0.9

    patience_counter = 0
    pbar = tqdm.tqdm(range(model.epoch, epochs))
    for epoch in pbar:
        # Training phase
        model.train()
        train_loss = 0.0
        val_loss = 0.0
        total_ortho_loss = 0.0
        per_modality_losses = [0.0] * len(data)
        
        for batch_idx, (x, mask) in enumerate(data_loader):
            ### plotting test
            # Store last batch for plotting
            last_batch_data = [x_m.clone() for x_m in x]
            # Get labels if they exist in the dataset
            if hasattr(train_data, 'labels') and train_data.labels is not None:
                start_idx = batch_idx * batch_size
                end_idx = min(start_idx + batch_size, len(train_data.labels))
                last_batch_labels = train_data.labels[start_idx:end_idx].clone()
            else:
                last_batch_labels = None
            ###
            
            loss = torch.tensor(0.0, device=device)
            total_loss = torch.tensor(0.0, device=device)
            x = [x_m.to(device, non_blocking=True) for x_m in x]
            
            # Forward pass
            x_hat, h_list = model(x)
            
            ortho_loss = torch.tensor(0.0, device=device)
            total_ortho_loss += ortho_loss.item()

            # Calculate separate losses for each modality
            modality_losses = []
            
            # Extract masks for each modality
            modality_masks = []
            if mask is not None:
                start_idx = 0
                for i, x_m in enumerate(x):
                    end_idx = start_idx + x_m.shape[1]
                    modality_masks.append(mask[:, start_idx:end_idx])
                    start_idx = end_idx
                # sanity check: see how many samples are masked
                #print([f"Modality {i} data shape: {x_m.shape}, mask shape: {mask_i.shape}, fraction unmasked: {mask_i.sum() / mask_i.numel()}" for i, mask_i in enumerate(modality_masks)])
            else:
                modality_masks = [None] * len(x)
            
            # Calculate per-modality MSE losses
            for i, (x_m, x_hat_m) in enumerate(zip(x, x_hat)):
                # Compute MSE loss for this modality with mask if provided
                if modality_masks[i] is not None:
                    m_loss = F.mse_loss(x_hat_m[modality_masks[i]], x_m[modality_masks[i]])
                else:
                    m_loss = F.mse_loss(x_hat_m, x_m)
                
                # Check for NaN 
                if torch.isnan(m_loss):
                    if verbose:
                        print(f"Warning: NaN loss detected for modality {i}")
                    m_loss = torch.tensor(0.0, device=device)
                
                modality_losses.append(m_loss)
                per_modality_losses[i] += m_loss.item()
            
            # Apply reconstruction loss balancing if enabled
            if recon_loss_balancing:
                # Update exponential moving averages for each modality
                for i, m_loss in enumerate(modality_losses):
                    if modality_loss_emas[i] is None:
                        modality_loss_emas[i] = m_loss.item()
                    else:
                        modality_loss_emas[i] = ema_decay * modality_loss_emas[i] + (1 - ema_decay) * m_loss.item()
                
                # Calculate balanced loss using the minimum EMA as reference
                min_ema = min(ema for ema in modality_loss_emas if ema is not None and ema > 0)
                for i, m_loss in enumerate(modality_losses):
                    if modality_loss_emas[i] > 0:
                        balance_scale = min_ema / modality_loss_emas[i]
                        loss += balance_scale * m_loss
                    else:
                        loss += m_loss
            else:
                # Standard loss computation without balancing
                for i, m_loss in enumerate(modality_losses):
                    loss += loss_scales[i] * m_loss
            
            total_loss += loss
            
            # Backward pass and optimize
            optimizer.zero_grad()
            total_loss.backward()

            optimizer.step()
            train_loss += loss.item()
        
        # Average losses
        train_loss /= len(data_loader)
        if start_reduction and include_ortholoss:
            total_ortho_loss /= len(data_loader)
        per_modality_losses = [loss / len(data_loader) for loss in per_modality_losses]
        train_losses.append(train_loss)
        
        # Store per-modality losses in history
        for i, loss in enumerate(per_modality_losses):
            loss_history[f'mod_{i}_loss'].append(loss)
        
        # Validation phase with similar safeguards
        with torch.no_grad():
            for x_val, mask in val_data_loader:
                x_val = [x_m.to(device, non_blocking=True) for x_m in x_val]
                x_val_hat, _ = model(x_val)

                modality_masks = []
                if mask is not None:
                    start_idx = 0
                    for i, x_m in enumerate(x_val):
                        end_idx = start_idx + x_m.shape[1]
                        modality_masks.append(mask[:, start_idx:end_idx])
                        start_idx = end_idx
                else:
                    modality_masks = [None] * len(x_val)
                
                # Calculate validation loss
                val_batch_loss = 0.0
                for i, (x_m, x_hat_m) in enumerate(zip(x_val, x_val_hat)):
                    if modality_masks[i] is not None:
                        m_loss = F.mse_loss(x_hat_m[modality_masks[i]], x_m[modality_masks[i]])
                    else:
                        m_loss = F.mse_loss(x_hat_m, x_m)
                    if not torch.isnan(m_loss):
                        val_batch_loss += m_loss.item()
                
                val_loss += val_batch_loss / len(x_val)
                
        val_loss /= len(val_data_loader)
        val_losses.append(val_loss)

        log_dict = {
            'loss': round(train_loss, 4),
            'mod_losses': [round(l, 3) for l in per_modality_losses],
            'ranks': [layer.active_dims for layer in model.adaptive_layers] if hasattr(model, 'adaptive_layers') 
                    else (model.module.adaptive_layers if multi_gpu else []),
            'current_rsquare': [round(current_rsquare_per_mod[i], 3) if current_rsquare_per_mod[i] is not None else 'N/A' for i in range(len(data))],
            'patience': patience_counter,
        }
        if recon_loss_balancing and all(ema is not None for ema in modality_loss_emas):
            # Show the balance scales for reconstruction losses
            min_ema = min(ema for ema in modality_loss_emas if ema > 0)
            balance_scales = [round(min_ema / ema, 3) if ema > 0 else 1.0 for ema in modality_loss_emas]
            log_dict.update({'balance_scales': balance_scales})
        pbar.set_postfix(log_dict)
        
        # Update best loss
        if train_loss < best_loss:
            best_loss = train_loss
            if reduce_on_best_loss in ['true', 'stagnation']:
                patience_counter = 0  # Reset patience counter
        else:
            if reduce_on_best_loss in ['true', 'stagnation']:
                patience_counter += 1

        if (start_reduction is False) and (epoch == model.epoch + patience): # giving the optimizer a start
            rank_history = {
                'total_rank':[model.get_total_rank() if hasattr(model, 'get_total_rank') else (model.module.get_total_rank() if multi_gpu else 0)],
                'ranks':[', '.join(str(layer.active_dims) for layer in model.adaptive_layers)],
                'epoch':[model.epoch],
                'loss':[train_losses[-1]],
                'val_loss':[val_losses[-1]]
            }
            
            #break
            start_reduction = True  # Start rank reduction after early stopping
            break_counter = 0 # start with no breaks (only used when increasing layers)

            with torch.no_grad():
                encoded_per_modality = model.encode_modalities([val_data.data[i].to(device) for i in range(len(data))])
                encoded_per_space_shared, encoded_per_space_specific = model.encode([val_data.data[i].to(device) for i in range(len(data))])
                encoded_per_space = [encoded_per_space_shared] + list(encoded_per_space_specific)
            min_rsquares = []
            mask = val_data.mask
            modality_masks_latent = []
            modality_masks_data = []
            modality_masks_space = []
            if mask is not None:
                start_idx = 0
                for j, x_m in enumerate(x_val):
                    end_idx = start_idx + x_m.shape[1]
                    temp_mask = mask[:, start_idx]
                    # expand it to match the encoded shape
                    temp_mask = temp_mask.unsqueeze(1).expand(-1, encoded_per_modality[j].shape[1])
                    modality_masks_latent.append(temp_mask)
                    modality_masks_data.append(mask[:, start_idx:end_idx])
                    modality_masks_space.append(mask[:, start_idx])
                    start_idx = end_idx
            
            # Direct reconstruction R² approach (original behavior)
            val_data_list = [val_data.data[i] for i in range(len(data))]
            direct_r_squared_values = compute_direct_r_squared(model, val_data_list, device, multi_gpu)
            
            for i, r_squared_val in enumerate(direct_r_squared_values):
                initial_squares[i] = r_squared_val
                
                # Calculate threshold based on threshold_type
                if threshold_type == 'relative':
                    min_rsquares.append(r_squared_val * r_square_threshold)
                elif threshold_type == 'absolute':
                    min_rsquares.append(r_squared_val - r_square_threshold)
                else:
                    raise ValueError(f"threshold_type must be 'relative' or 'absolute', got {threshold_type}")
                    
                current_rsquare_per_mod[i] = r_squared_val
                rank_history[f'rsquare {i}'] = [r_squared_val]
                
            if verbose:
                print(f"Initial R-squared values: {[rank_history[f'rsquare {i}'] for i in range(len(encoded_per_modality))]}, setting {threshold_type} thresholds to {min_rsquares}")
            print(f"Initial R-squared values: {[rank_history[f'rsquare {i}'] for i in range(len(encoded_per_modality))]}, setting {threshold_type} thresholds to {min_rsquares}")
            
        # Apply rank reduction at scheduled epochs, respecting warmup period
        if (epoch in rank_schedule) & (start_reduction) & (break_counter == 0):
            if (reduce_on_best_loss == 'rsquare') & (start_reduction):
                ###
                # get the r_square values per modality
                ###
                with torch.no_grad():
                    #encoded_per_modality = model.encode_modalities([d[n_samples_train:].to(device) for d in data])
                    encoded_per_modality = model.encode_modalities([val_data.data[i].to(device) for i in range(len(data))])
                    if not compute_jacobian:
                        encoded_per_space_shared, encoded_per_space_specific = model.encode([val_data.data[i].to(device) for i in range(len(data))])
                        encoded_per_space = [encoded_per_space_shared] + list(encoded_per_space_specific)
                    else:
                        encoded_per_space, contractive_losses = model.encode([val_data.data[i].to(device) for i in range(len(data))], compute_jacobian=compute_jacobian)
                current_rsquares = []
                modalities_to_reduce = []
                modalities_to_increase = []
                mask = val_data.mask
                modality_masks_data = []
                modality_masks_latent = []
                modality_masks_space = []
                if mask is not None:
                    start_idx = 0
                    for j, x_m in enumerate(x_val):
                        end_idx = start_idx + x_m.shape[1]
                        #modality_masks.append(mask[:, start_idx:end_idx])
                        # expand it to match the encoded shape
                        temp_mask = mask[:, start_idx]
                        temp_mask = temp_mask.unsqueeze(1).expand(-1, encoded_per_modality[j].shape[1])
                        modality_masks_latent.append(temp_mask)
                        modality_masks_data.append(mask[:, start_idx:end_idx])
                        modality_masks_space.append(mask[:, start_idx])
                        start_idx = end_idx
                
                # Direct reconstruction R² approach (original behavior)
                val_data_list = [val_data.data[i] for i in range(len(data))]
                direct_r_squared_values = compute_direct_r_squared(model, val_data_list, device, multi_gpu)
                
                for i, r_squared_val in enumerate(direct_r_squared_values):
                    current_rsquares.append(r_squared_val)
                    current_rsquare_per_mod[i] = r_squared_val
                        
                r_squares.append(current_rsquares)

                ###
                # determine what modalities to reduce or increase
                ###
                if (len(r_squares) >= min(10, int(patience/2))) and patience_counter >= min(10, int(patience/2)):
                    for i in range(len(current_rsquare_per_mod)):
                        i_rsquares = [r[i] for r in r_squares[-min(10, int(patience/2)):]]
                        
                        # Handle different comparison logic for loss vs R²
                        if compressibility_type == 'direct' and reduction_criterion in ['train_loss', 'val_loss']:
                            # For loss: lower is better, so we reduce if loss is high (above threshold)
                            if all(r > min_rsquares[i] for r in i_rsquares) and not bottom_reached:
                                modalities_to_increase.append(i)
                            elif current_rsquare_per_mod[i] < min_rsquares[i]:  # loss below threshold
                                modalities_to_reduce.append(i)
                        else:
                            # For R²: higher is better (original logic)
                            if all(r < min_rsquares[i] for r in i_rsquares) and not bottom_reached:
                                modalities_to_increase.append(i)
                            elif current_rsquare_per_mod[i] > min_rsquares[i]:
                                modalities_to_reduce.append(i)
                elif (len(r_squares) >= 1):# and (patience_counter >= 1):
                    for i in range(len(current_rsquare_per_mod)):
                        if compressibility_type == 'direct' and reduction_criterion in ['train_loss', 'val_loss']:
                            # For loss: reduce if loss is below threshold (good performance)
                            if current_rsquare_per_mod[i] < min_rsquares[i]:
                                modalities_to_reduce.append(i)
                            elif current_rsquare_per_mod[i] > min_rsquares[i] and not bottom_reached:
                                modalities_to_increase.append(i)
                        else:
                            # For R²: reduce if R² is above threshold (original logic)
                            if current_rsquare_per_mod[i] > min_rsquares[i]:
                                modalities_to_reduce.append(i)
                            elif current_rsquare_per_mod[i] < min_rsquares[i] and not bottom_reached:
                                modalities_to_increase.append(i)

                # if all modalities can be reduced, we set min and max ranks
                if len(modalities_to_reduce) == len(current_rsquare_per_mod):
                    current_ranks = [layer.active_dims for layer in model.adaptive_layers]
                    for i, cr in enumerate(current_ranks):
                        if cr <= min_ranks[i]:
                            min_ranks[i] = cr
                            # if we are still above the thresholds, the ranks maximum should not be larger than the sum of current ranks
                            model.adaptive_layers[i].max_rank = min(sum(current_ranks), model.adaptive_layers[i].max_rank)
                if len(modalities_to_increase) == len(current_rsquare_per_mod):
                    # set minima
                    current_ranks = [layer.active_dims for layer in model.adaptive_layers]
                    for i, cr in enumerate(current_ranks):
                        if cr <= min_ranks[i]:
                            min_ranks[i] = cr
                        # if we are increasing all ranks, we can also increase the maximum ranks
                        model.adaptive_layers[i].min_rank = min_ranks[i]
                ###
                # set the patience counters and layers to reduce or increase
                ###
                layers_to_reduce = []
                layers_to_increase = []
                if (len(modalities_to_reduce) == 0) and (len(modalities_to_increase) == 0):
                    #patience_counter += 1
                    pass
                elif (len(modalities_to_reduce) > 0) and (len(modalities_to_increase) > 0):
                    # no increasing yet, but no decreasing the shared either
                    layers_to_reduce = [i + 1 for i in modalities_to_reduce]
                    layers_to_increase = [i + 1 for i in modalities_to_increase]
                else:
                    if len(modalities_to_increase) > 0:
                        if len(modalities_to_increase) == len(current_rsquare_per_mod):
                            # if all modalities are below the threshold, increase ranks of all layers
                            layers_to_increase = [i for i in range(len(model.adaptive_layers))]
                        else:
                            layers_to_increase = [i + 1 for i in modalities_to_increase]
                    if len(modalities_to_reduce) > 0:
                        # if all modalities are below the threshold, reduce ranks of all layers
                        if len(modalities_to_reduce) == len(initial_squares):
                            reduce_shared = True
                            if reduce_shared:
                                layers_to_reduce = [0] + [i + 1 for i in modalities_to_reduce]
                            else:
                                layers_to_reduce = [i + 1 for i in modalities_to_reduce]
                        else:
                            # roll a dice whether we also try to reduce the shared layer
                            if sharedwhenall:
                                reduce_shared = False
                            else:
                                reduce_shared = True
                            if reduce_shared:
                                layers_to_reduce = [0] + [i + 1 for i in modalities_to_reduce]
                            else:

                                layers_to_reduce = [i + 1 for i in modalities_to_reduce]
                if verbose:
                    if compressibility_type == 'direct' and reduction_criterion in ['train_loss', 'val_loss']:
                        print(f"{reduction_criterion} values: {current_rsquares}, reducing rank for modalities {modalities_to_reduce}, layers {layers_to_reduce}, increasing rank for modalities {modalities_to_increase}, layers {layers_to_increase}")
                    else:
                        print(f"R-squared values: {current_rsquares}, reducing rank for modalities {modalities_to_reduce}, layers {layers_to_reduce}, increasing rank for modalities {modalities_to_increase}, layers {layers_to_increase}")
                if compute_jacobian:
                    valid_contractive_losses = [contractive_losses[i] for i in layers_to_reduce]
                    max_contractive_loss = max(valid_contractive_losses) if len(valid_contractive_losses) > 0 else None
                    #max_contractive_loss = min(valid_contractive_losses) if len(valid_contractive_losses) > 0 else None
                    if max_contractive_loss is not None:
                        layers_to_reduce = [i for i in layers_to_reduce if contractive_losses[i] == max_contractive_loss]
                    else:
                        layers_to_reduce = []
                    if verbose:
                        print(f"Contractive losses: {contractive_losses}, reducing rank for layers {layers_to_reduce} with max loss {max_contractive_loss}")
                
            
            #if should_reduce:
            any_changes_made = False
            #changes_made = False
            if len(layers_to_reduce) > 0:
                # Apply rank reduction
                if multi_gpu:
                    changes_made = model.module.reduce_rank(reduction_ratio=0.9, threshold=rank_reduction_threshold, layer_ids=layers_to_reduce)
                else:
                    changes_made = model.reduce_rank(reduction_ratio=0.9, threshold=rank_reduction_threshold, layer_ids=layers_to_reduce)
                if changes_made:
                    any_changes_made = True
            if len(layers_to_increase) > 0:
                # Apply rank increase
                #print(f"Increasing rank for layer {layers_to_increase}")
                if multi_gpu:
                    changes_made = model.module.increase_rank(increase_ratio=1.1, layer_ids=layers_to_increase)
                else:
                    changes_made = model.increase_rank(increase_ratio=1.1, layer_ids=layers_to_increase)
                if changes_made:
                    any_changes_made = True
                    break_counter = patience # give model more time to re-learn the added dimensions
            #else:
            #    changes_made = False
            
            if any_changes_made:
                patience_counter = 0  # Reset patience counter if rank was changed
            else:
                patience_counter += 1

            # Get new rank but don't print separate message
            total_rank_after = model.module.get_total_rank() if multi_gpu else model.get_total_rank()
                
            # Store current rank in history
            rank_history['total_rank'].append(total_rank_after)
            rank_history['ranks'].append(', '.join(str(layer.active_dims) for layer in model.adaptive_layers))
            rank_history['epoch'].append(epoch)
            for i in range(len(encoded_per_modality)):
                if reduce_on_best_loss == 'rsquare':
                    rank_history[f'rsquare {i}'].append(current_rsquares[i])
            rank_history['loss'].append(train_loss)
            rank_history['val_loss'].append(val_loss)

            # also get mutual information between all spaces
            valid_spaces = []
            for encoded in encoded_per_space:
                valid_spaces_temp = []
                for i in range(len(encoded_per_modality)):
                    mask = modality_masks_space[i]
                    #normalized_encoded = (encoded - encoded.min() + 1e-9) / (encoded.max() - encoded.min() + 1e-9)
                    if mask is not None:
                        #valid_spaces_temp.append(normalized_encoded[mask])
                        valid_spaces_temp.append(encoded[mask])
                    else:
                        #valid_spaces_temp.append(normalized_encoded)
                        valid_spaces_temp.append(encoded)
                # stack the valid spaces
                valid_spaces_temp = torch.vstack(valid_spaces_temp)
                valid_spaces.append(valid_spaces_temp)
        else:
            if (epoch in rank_schedule) & (start_reduction) & (break_counter > 0):
                break_counter -= 1
        
        # Get normalized weights for display
        if multi_gpu:
            weights = model.module.modality_weights
        else:
            weights = model.modality_weights
            
        pos_weights = F.softplus(weights)
        norm_weights = (pos_weights / (pos_weights.sum() + 1e-8)).detach().cpu().numpy().round(3)

        
        # early stopping but conditioned on rank reduction
        if (epoch > early_stopping) & (min(val_losses[-early_stopping:]) > min(val_losses)) & (start_reduction is True) & (patience_counter >= patience):
            if verbose:
                print(f"Early stopping at epoch {epoch} with best loss {best_loss} and ranks {rank_history['ranks'][-1]}")
            break
    
    # Calculate latent representations in batches
    #'''
    n_samples = data[0].shape[0]
    final_ranks = [layer.active_dims for layer in model.adaptive_layers]
    reps = [torch.empty((n_samples, final_ranks[i]), device=device) for i in range(len(final_ranks))]
    model.eval()
    with torch.no_grad():
        for i in range(0, n_samples, batch_size):
            end_idx = min(i + batch_size, n_samples)
            x_batch = [data[j][i:end_idx].to(device) for j in range(len(data))]
            
            # If using DataParallel, need to access module directly or handle the encoding differently
            if multi_gpu:
                batch_reps = model.module.encode(x_batch)#.cpu()
            else:
                batch_reps = model.encode(x_batch)#.cpu()
            batch_rep_list = [batch_reps[0]] + [batch_reps[1][j] for j in range(len(batch_reps[1]))]
                
            # No need to convert dtype
            for j in range(len(reps)):
                reps[j][i:end_idx,:] = batch_rep_list[j][:,:final_ranks[j]].cpu()
            
            # Free memory
            del x_batch, batch_reps
            torch.cuda.empty_cache() if torch.cuda.is_available() else None
    
    # save the model
    if model_name:
        model_path = f"./03_results/models/{model_name}.pt"
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        if multi_gpu:
            torch.save(model.module.state_dict(), model_path)
        else:
            torch.save(model.state_dict(), model_path)
        if verbose:
            print(f"Saved trained model to {model_path}")
    
    try:
        avg_train_loss = np.mean(train_losses[-5:])
    except:
        avg_train_loss = np.mean(train_losses)
    try:
        last_rsquare = r_squares[-1]
    except:
        last_rsquare = [None]

    return model, reps, avg_train_loss, last_rsquare, rank_history, [train_losses, val_losses]

def _expand_latent_dims(latent_dim, n_modalities):
    """
    Expand a latent-dimension spec into one entry per latent space
    (the shared space plus one space per modality).

    Parameters:
    - latent_dim: int, or list/tuple of ints. An int, or a length-1 sequence
      when there is more than one modality, is broadcast to
      ``n_modalities + 1`` entries; any other sequence is used as given.
    - n_modalities: Number of input modalities

    Returns:
    - List of latent dimensions

    Raises:
    - TypeError: if latent_dim is neither an int nor a list/tuple
    """
    if isinstance(latent_dim, int):
        return [latent_dim] * (n_modalities + 1)
    if isinstance(latent_dim, (list, tuple)):
        if len(latent_dim) == 1 and n_modalities > 1:
            return [latent_dim[0]] * (n_modalities + 1)
        return list(latent_dim)
    raise TypeError(
        f"latent_dim must be an int or a list/tuple of ints, got {type(latent_dim).__name__}"
    )


def _ema_balance_scales(emas):
    """
    Return per-modality reconstruction-loss weights ``min_positive_ema / ema``.

    Entries whose EMA is None or non-positive get weight 1.0 (unbalanced).
    If no EMA is positive at all, every weight is 1.0 instead of failing on
    ``min()`` of an empty sequence.
    """
    positive = [ema for ema in emas if ema is not None and ema > 0]
    if not positive:
        return [1.0] * len(emas)
    min_ema = min(positive)
    return [min_ema / ema if ema is not None and ema > 0 else 1.0 for ema in emas]


def post_train_multimodal_ae(data, n_samples_train, latent_dim, device, args, epochs=100, early_stopping=50,
                            lr=0.001, batch_size=128, ae_depth=2, ae_width=0.5, dropout=0.0, wd=1e-5,
                            patience=10, verbose=True, model_name=None, pretrained_name=None,
                            recon_loss_balancing=False, paired=False):
    """
    Post-training function that loads a pretrained multimodal model and continues training until early stopping.

    The model architecture is rebuilt from the data shapes and the given
    hyper-parameters, the weights are loaded from
    ``./03_results/models/pretrained_models/{pretrained_name}.pt``, and training
    continues with Adam on the per-modality MSE reconstruction losses.

    Parameters:
    - data: Input data tensor list
    - n_samples_train: Number of samples to use for training (the rest is used for validation)
    - latent_dim: Dimension of the latent space (int, or list/tuple of ints)
    - device: Device to run computation on
    - args: Arguments object; ``multi_gpu``, ``gpu_ids`` and ``num_workers`` are read if present
    - epochs: Maximum number of additional training epochs
    - early_stopping: Window (in epochs) for the validation-based early-stopping check
    - lr: Learning rate for continued training
    - batch_size: Batch size for training (rounded down to a multiple of the GPU count in multi-GPU mode)
    - ae_depth: Depth of the autoencoder (for model architecture verification)
    - ae_width: Width multiplier for hidden layers (for model architecture verification)
    - dropout: Dropout rate (for model architecture verification)
    - wd: Weight decay
    - patience: Number of epochs without training-loss improvement required before stopping
    - verbose: Print progress
    - model_name: Name to save the final model
    - pretrained_name: Name of the pretrained model to load
    - recon_loss_balancing: Whether to apply adaptive loss balancing across modalities
    - paired: Whether to use paired data splitting

    Returns:
    - model: The trained model
    - train_losses: List of training losses during post-training
    - val_losses: List of validation losses during post-training
      (NaN for every epoch if the validation split is empty)

    Raises:
    - ValueError: if pretrained_name is not given
    - FileNotFoundError: if the pretrained model file does not exist
    - TypeError: if latent_dim is neither an int nor a list/tuple
    """
    # Local imports: `os` keeps this function independent of a module-level import.
    import os
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import tqdm

    # Check for existing pretrained model
    if not pretrained_name:
        raise ValueError("pretrained_name must be provided to load pretrained model.")

    pretrained_model_path = f"./03_results/models/pretrained_models/{pretrained_name}.pt"
    if not os.path.exists(pretrained_model_path):
        raise FileNotFoundError(f"Pretrained model not found at {pretrained_model_path}")

    print(f"Loading pretrained model from {pretrained_model_path} for post-training...")

    # Recreate the architecture so the saved state_dict can be loaded into it.
    input_dims = [d.shape[1] for d in data]
    latent_dims = _expand_latent_dims(latent_dim, len(input_dims))

    model = AdaptiveRankReducedAE(
        input_dims, latent_dims, depth=ae_depth, width=ae_width,
        dropout=dropout, initial_rank_ratio=1.0, min_rank=10
    )

    # map_location='cpu' lets a checkpoint saved from a GPU load on any host;
    # the model is moved to `device` right after.
    model.load_state_dict(torch.load(pretrained_model_path, map_location='cpu', weights_only=False))
    model.eval()
    for param in model.parameters():
        param.requires_grad = True

    # Add epoch tracking attribute if it doesn't exist
    if not hasattr(model, 'epoch'):
        model.epoch = 0

    print(f"Loaded pretrained model. Current epoch: {model.epoch}")

    # Move to device and handle multi-GPU
    model.to(device)
    multi_gpu = getattr(args, 'multi_gpu', False)

    if multi_gpu:
        gpu_ids = getattr(args, 'gpu_ids', None)
        if gpu_ids:
            num_gpus = len(gpu_ids.split(','))
        else:
            num_gpus = torch.cuda.device_count()
        # Guard against a zero GPU count (avoids modulo by zero below).
        num_gpus = max(num_gpus, 1)

        if batch_size % num_gpus != 0:
            original_batch_size = batch_size
            # Never round down to 0, which DataLoader would reject.
            batch_size = max((batch_size // num_gpus) * num_gpus, num_gpus)
            if verbose:
                print(f"Adjusted batch size from {original_batch_size} to {batch_size} for multi-GPU")

        try:
            # A missing gpu_ids raises here and falls back to single-GPU mode.
            device_ids = [int(gpu_id) for gpu_id in gpu_ids.split(',')]
            if 0 not in device_ids:
                raise RuntimeError("DataParallel requires cuda:0")

            cuda0_device = torch.device('cuda:0')
            model = model.to(cuda0_device)

            for param in model.parameters():
                if param.device != cuda0_device:
                    param.data = param.data.to(cuda0_device)

            model = nn.DataParallel(model, device_ids=device_ids)
            if verbose:
                print(f"Using DataParallel across GPUs: {gpu_ids}")
        except Exception as e:
            print(f"Failed to use DataParallel: {e}")
            print(f"Falling back to single GPU mode")
            multi_gpu = False
            model = model.to(device)

    # Create optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

    # Create the train/validation split
    if paired:
        # Reuse the split saved alongside the pretrained model, if any.
        data_indices_path = f"./03_results/models/pretrained_models/{pretrained_name}_data_indices.pt"
        if os.path.exists(data_indices_path):
            data_indices = torch.load(data_indices_path, map_location='cpu')
        else:
            print("Warning: Paired splitting requested but no saved data indices found. Using random split.")
            data_indices = torch.randperm(data[0].shape[0])
        train_indices = data_indices[:n_samples_train]
        val_indices = data_indices[n_samples_train:]
    else:
        # Use simple sequential split
        train_indices = slice(0, n_samples_train)
        val_indices = slice(n_samples_train, None)

    num_workers = getattr(args, 'num_workers', 0)
    train_data = MMSimData([d[train_indices] for d in data])
    data_loader = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_workers
    )

    val_data = MMSimData([d[val_indices] for d in data])
    val_data_loader = torch.utils.data.DataLoader(
        val_data, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=num_workers
    )

    # Training loop
    train_losses = []
    val_losses = []
    best_loss = float('inf')
    patience_counter = 0

    # Exponential moving averages of each modality's loss, used for balancing
    if recon_loss_balancing:
        modality_loss_emas = [None] * len(data)
        ema_decay = 0.9

    print(f"Starting post-training for up to {epochs} epochs...")
    pbar = tqdm.tqdm(range(epochs))
    for epoch in pbar:
        # Training phase
        model.train()
        train_loss = 0.0
        val_loss = 0.0
        per_modality_losses = [0.0] * len(data)

        for x, _mask in data_loader:
            loss = torch.tensor(0.0, device=device)
            x = [x_m.to(device, non_blocking=True) for x_m in x]

            x_hat, _ = model(x)

            # Per-modality MSE; a NaN loss is replaced by a constant 0 so it
            # contributes nothing to the update.
            modality_losses = []
            for i, (x_m, x_hat_m) in enumerate(zip(x, x_hat)):
                m_loss = F.mse_loss(x_hat_m, x_m)
                if torch.isnan(m_loss):
                    if verbose:
                        print(f"Warning: NaN loss detected for modality {i}")
                    m_loss = torch.tensor(0.0, device=device)
                modality_losses.append(m_loss)
                per_modality_losses[i] += m_loss.item()

            if recon_loss_balancing:
                for i, m_loss in enumerate(modality_losses):
                    value = m_loss.item()
                    if modality_loss_emas[i] is None:
                        modality_loss_emas[i] = value
                    else:
                        modality_loss_emas[i] = ema_decay * modality_loss_emas[i] + (1 - ema_decay) * value
                # Scale each loss down to the smallest running loss so no modality dominates.
                for scale, m_loss in zip(_ema_balance_scales(modality_loss_emas), modality_losses):
                    loss += scale * m_loss
            else:
                for m_loss in modality_losses:
                    loss += m_loss

            optimizer.zero_grad()
            # If every modality loss was NaN-replaced, nothing requires grad;
            # skip backward instead of crashing.
            if loss.requires_grad:
                loss.backward()
                optimizer.step()
            train_loss += loss.item()

        # Average losses
        train_loss /= len(data_loader)
        per_modality_losses = [m_total / len(data_loader) for m_total in per_modality_losses]
        train_losses.append(train_loss)

        # Validation phase
        model.eval()
        with torch.no_grad():
            for x_val, _mask in val_data_loader:
                x_val = [x_m.to(device, non_blocking=True) for x_m in x_val]
                x_val_hat, _ = model(x_val)

                val_batch_loss = 0.0
                for x_m, x_hat_m in zip(x_val, x_val_hat):
                    m_loss = F.mse_loss(x_hat_m, x_m)
                    if not torch.isnan(m_loss):
                        val_batch_loss += m_loss.item()
                # Mean over modalities (len(x_val) is the number of modalities),
                # whereas the training loss above is a (possibly balanced) sum.
                val_loss += val_batch_loss / len(x_val)

        n_val_batches = len(val_data_loader)
        # With an empty validation split there is nothing to average: record NaN
        # (the validation-based early-stopping check below then never fires).
        val_loss = val_loss / n_val_batches if n_val_batches > 0 else float('nan')
        val_losses.append(val_loss)

        # Update progress bar
        log_dict = {
            'loss': round(train_loss, 4),
            'val_loss': round(val_loss, 4),
            'mod_losses': [round(l, 3) for l in per_modality_losses],
            'patience': patience_counter,
        }
        if recon_loss_balancing and all(ema is not None for ema in modality_loss_emas):
            balance_scales = [round(scale, 3) for scale in _ema_balance_scales(modality_loss_emas)]
            log_dict.update({'balance_scales': balance_scales})
        pbar.set_postfix(log_dict)

        # Patience counts epochs without improvement of the *training* loss
        if train_loss < best_loss:
            best_loss = train_loss
            patience_counter = 0
        else:
            patience_counter += 1

        # Stop when the last `early_stopping` epochs did not reach the best
        # validation loss AND the training loss has stalled for `patience` epochs.
        if (epoch > early_stopping) and (min(val_losses[-early_stopping:]) > min(val_losses)) and (patience_counter >= patience):
            if verbose:
                print(f"Early stopping at epoch {epoch} with best loss {best_loss}")
            break

    # Save the final model if model_name provided
    if model_name:
        os.makedirs("./03_results/models/", exist_ok=True)
        final_model_path = f"./03_results/models/{model_name}_post_trained.pt"
        if multi_gpu:
            torch.save(model.module.state_dict(), final_model_path)
        else:
            torch.save(model.state_dict(), final_model_path)
        if verbose:
            print(f"Post-trained model saved to {final_model_path}")

    # train_losses is empty when epochs == 0
    if verbose and train_losses:
        print(f"Post-training completed. Final training loss: {train_losses[-1]:.6f}, Final validation loss: {val_losses[-1]:.6f}")

    return model, train_losses, val_losses

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split, cross_validate
from sklearn.metrics import make_scorer

def compute_classification(z_np, y_np):
    """
    Measure how well integer labels can be predicted linearly from features.

    Fits a class-balanced logistic regression (liblinear solver,
    random_state=42) and reports accuracy.
    - If every class has at least 2 samples, returns the mean accuracy of
      k-fold cross-validation with k = min(5, smallest class size).
    - Otherwise some class has a single sample, so neither cross-validation
      nor a stratified split is possible. It then returns the accuracy on a
      single random 70/30 train-test split.

    Parameters:
    - z_np: Feature matrix, (n_samples, n_features) array-like
    - y_np: Label array (numpy); cast to int before use

    Returns:
    - Accuracy as a float
    """
    y_np = y_np.astype(int)

    # Smallest class size decides how many folds are possible.
    _, class_counts = np.unique(y_np, return_counts=True)
    min_samples_per_class = np.min(class_counts)
    max_cv_folds = min(5, min_samples_per_class)

    def _make_classifier():
        # Balanced class weights for robustness to label imbalance;
        # liblinear works well for small datasets.
        return LogisticRegression(
            max_iter=1000,
            class_weight='balanced',
            solver='liblinear',
            random_state=42
        )

    if max_cv_folds < 2:
        # Here min_samples_per_class == 1, so stratification is impossible as
        # well: use a plain random split.
        print(f"Using train-test split instead of cross-validation: {max_cv_folds} folds")
        X_train, X_test, y_train, y_test = train_test_split(
            z_np, y_np, test_size=0.3, random_state=42
        )
        model = _make_classifier()
        model.fit(X_train, y_train)
        return accuracy_score(y_test, model.predict(X_test))

    cv_results = cross_validate(
        _make_classifier(), z_np, y_np,
        cv=max_cv_folds,
        scoring={'accuracy': make_scorer(accuracy_score)},
        return_train_score=False
    )
    return cv_results['test_accuracy'].mean()


# ============================================================================
# VAE-specific functions
# ============================================================================

def compute_vae_loss(x_recon, x, vae_params, kl_weight=1.0, reduction='mean'):
    """
    Combine reconstruction error and KL divergence into a single VAE objective.

    The reconstruction term is the MSE of every modality, averaged over
    modalities. The KL term is taken only over the modality-specific Gaussian
    posteriors against a standard normal prior; the shared latent is
    deterministic and contributes no KL. Each KL is divided by the batch size
    (``mu.size(0)``) and then averaged over modalities.

    Args:
        x_recon: List of reconstructed outputs for each modality
        x: List of original inputs for each modality
        vae_params: Dictionary with 'z_shared' (deterministic), 'specific_mus', 'specific_logvars'
        kl_weight: Multiplier applied to the KL term
        reduction: 'mean' for element-averaged MSE; any other value sums the MSE

    Returns:
        total_loss, recon_loss, kl_loss
    """
    mse_mode = 'mean' if reduction == 'mean' else 'sum'

    per_modality_recon = [
        F.mse_loss(reconstructed, original, reduction=mse_mode)
        for original, reconstructed in zip(x, x_recon)
    ]
    recon_loss = sum(per_modality_recon) / len(x)

    mus = vae_params['specific_mus']
    logvars = vae_params['specific_logvars']
    per_modality_kl = []
    for mean_vec, log_variance in zip(mus, logvars):
        # Closed-form KL(N(mu, sigma^2) || N(0, 1)) summed over batch and dims
        batch_kl = -0.5 * torch.sum(1 + log_variance - mean_vec.pow(2) - log_variance.exp())
        per_modality_kl.append(batch_kl / mean_vec.size(0))
    kl_loss = sum(per_modality_kl) / len(mus)

    return recon_loss + kl_weight * kl_loss, recon_loss, kl_loss


def compute_direct_r_squared_vae(model, data, device, multi_gpu=False, verbose=False):
    """
    Compute R² based on VAE reconstruction performance.

    All data is encoded in one pass. Each modality is then decoded from the
    deterministic shared latent and the *means* of the modality-specific
    posteriors, with no sampling, so the score is deterministic.

    Per modality the score is the per-dimension R² = 1 - (SSR + 1e-9) / (SS_tot + 1e-9),
    averaged over dimensions, with these special cases:
    - Dimensions whose mean is exactly zero are dropped. If every dimension
      has zero mean, the Pearson correlation (not squared) of the flattened
      original and reconstruction is returned instead; it becomes 0.0 if the
      correlation is NaN.
    - Otherwise, if the data contains NaN/Inf, dimensions whose mean is
      NaN/Inf are dropped. The score is 0.0 if no dimension remains.

    Parameters:
    - model: The trained VAE model (optionally wrapped in DataParallel)
    - data: Input data list [modality1, modality2, ...]
    - device: Device to run computation on
    - multi_gpu: Whether model is wrapped with DataParallel. If so,
      encode/decode are called on ``model.module``, since the wrapper does not
      expose them.
    - verbose: Print warnings about zero-mean or non-finite data

    Returns:
    - List of R² values (Python floats), one per modality
    """
    model.eval()
    # DataParallel only wraps forward(); encode/decode live on the inner module.
    net = model.module if multi_gpu else model

    def _mean_r2(orig, recon, orig_mean, keep=None):
        # Per-dimension R² (optionally restricted to the `keep` mask), averaged over dimensions.
        ssr = ((orig - recon) ** 2).sum(0)
        ss_tot = ((orig - orig_mean) ** 2).sum(0)
        if keep is not None:
            ssr, ss_tot = ssr[keep], ss_tot[keep]
        return (1 - ((ssr + 1e-9) / (ss_tot + 1e-9))).mean()

    r_squared_values = []

    with torch.no_grad():
        data_tensors = [d.to(device) for d in data]

        # Deterministic shared latent plus posterior means for the specific spaces
        z_shared, specific_mus, _specific_logvars, _ = net.encode(data_tensors)
        reconstructions = net.decode(z_shared, specific_mus)

        for i, (original, reconstruction) in enumerate(zip(data_tensors, reconstructions)):
            original_mean = original.mean(dim=0).cpu()
            original_cpu = original.cpu()
            reconstruction_cpu = reconstruction.cpu()

            if torch.any(original_mean == 0):
                if verbose:
                    print(f"   Warning: zeros found in original_mean for modality {i}. Removing samples.")
                non_zero_mean = original_mean != 0
                if non_zero_mean.sum() == 0:
                    # Every dimension has zero mean: fall back to correlation
                    r_squared = torch.corrcoef(torch.stack((original_cpu.flatten(), reconstruction_cpu.flatten())))[0, 1]
                    if torch.isnan(r_squared):
                        r_squared = torch.tensor(0.0)
                else:
                    r_squared = _mean_r2(original_cpu, reconstruction_cpu, original_mean, non_zero_mean)
            elif torch.any(torch.isnan(original_cpu)) or torch.any(torch.isinf(original_cpu)):
                if verbose:
                    print(f"   Warning: NaN or Inf values found in original data for modality {i}. Handling them.")
                # A column containing NaN/Inf has a non-finite mean, so this mask drops it
                valid_mask = ~torch.isnan(original_mean) & ~torch.isinf(original_mean)
                if valid_mask.sum() == 0:
                    r_squared = torch.tensor(0.0)
                else:
                    r_squared = _mean_r2(original_cpu, reconstruction_cpu, original_mean, valid_mask)
            else:
                r_squared = _mean_r2(original_cpu, reconstruction_cpu, original_mean)

            r_squared_values.append(r_squared.item())

    return r_squared_values


def train_multimodal_vae(data, n_samples_train, latent_dim, device, args, epochs=100, early_stopping=50,
                         lr=0.001, batch_size=128, ae_depth=2, ae_width=0.5, dropout=0.0, wd=1e-5,
                         initial_rank_ratio=1.0, min_rank=10,
                         rank_reduction_frequency=10, rank_reduction_threshold=0.01,
                         warmup_epochs=0, patience=10,
                         r_square_threshold=0.9, threshold_type='relative',
                         kl_weight=1.0, kl_anneal_epochs=None,
                         verbose=True, model_name=None, **kwargs):
    """
    Train a multi-modal VAE with adaptive rank reduction
    
    Parameters:
    - data: List of input data tensors [modality1, modality2, ...]
    - n_samples_train: Number of samples to use for training
    - latent_dim: Dimension of latent spaces (can be int or list)
    - device: Device to train on
    - args: Arguments object with multi_gpu, gpu_ids, etc.
    - epochs: Maximum number of training epochs
    - early_stopping: Number of epochs for early stopping patience
    - lr: Learning rate
    - batch_size: Batch size for training
    - ae_depth: Depth of encoder/decoder
    - ae_width: Width multiplier for hidden layers
    - dropout: Dropout rate
    - wd: Weight decay
    - initial_rank_ratio: Initial rank ratio for adaptive layers
    - min_rank: Minimum rank for adaptive layers
    - rank_reduction_frequency: How often to try reducing rank
    - rank_reduction_threshold: Energy threshold for rank reduction
    - warmup_epochs: Number of epochs before rank reduction starts
    - patience: Patience for rank reduction based on performance
    - r_square_threshold: R² threshold for rank reduction
    - threshold_type: 'relative' or 'absolute'
    - kl_weight: Weight for KL divergence term
    - kl_anneal_epochs: Number of epochs to anneal KL weight from 0 to kl_weight (if None, use kl_weight from start)
    - verbose: Print detailed information
    - model_name: Name for the model (for saving)
    
    Returns:
    - model: Trained model
    - reps: Latent representations
    - final_loss: Final training loss
    - r_squares: Final R² values
    - rank_history: Dictionary with training history
    - loss_curves: Tuple of (train_losses, val_losses)
    """
    
    multi_gpu = args.multi_gpu if hasattr(args, 'multi_gpu') else False
    
    # Create model
    input_dims = [d.shape[1] for d in data]
    if isinstance(latent_dim, int):
        latent_dims = [latent_dim] * (len(input_dims) + 1)  # +1 for shared space
    elif isinstance(latent_dim, list):
        if len(latent_dim) == 1 and len(input_dims) > 1:
            latent_dims = [latent_dim[0]] * (len(input_dims) + 1)
        else:
            latent_dims = latent_dim
    
    model = AdaptiveRankReducedVAE(
        input_dims, latent_dims, depth=ae_depth, width=ae_width,
        dropout=dropout, initial_rank_ratio=initial_rank_ratio,
        min_rank=min_rank
    ).to(device)
    
    if verbose:
        print(f"Model is on device: {next(model.parameters()).device}")
    
    # Handle multi-GPU (simplified for now)
    if multi_gpu:
        if verbose:
            print("Multi-GPU training not fully implemented for VAE yet, using single GPU")
        multi_gpu = False
    
    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    
    # Prepare data - match the non-VAE training approach
    # Randomize the data indices first, then split into train/val
    data_indices = torch.randperm(data[0].shape[0])
    train_indices = data_indices[:n_samples_train]
    val_indices = data_indices[n_samples_train:]
    train_data = [d[train_indices] for d in data]
    val_data = [d[val_indices] for d in data]
    
    # Create DataLoader
    # We need to create a custom dataset for multiple modalities
    class MultiModalDataset(torch.utils.data.Dataset):
        def __init__(self, data_list):
            self.data_list = data_list
            self.n_samples = data_list[0].shape[0]
        
        def __len__(self):
            return self.n_samples
        
        def __getitem__(self, idx):
            return [d[idx] for d in self.data_list]
    
    train_dataset = MultiModalDataset(train_data)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        collate_fn=lambda batch: [torch.stack([item[i] for item in batch]) for i in range(len(batch[0]))]
    )
    
    val_dataset = MultiModalDataset(val_data)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        collate_fn=lambda batch: [torch.stack([item[i] for item in batch]) for i in range(len(batch[0]))]
    )
    
    # Training history
    train_losses = []
    val_losses = []
    train_recon_losses = []
    train_kl_losses = []
    rank_history = {
        'epoch': [],
        'ranks': [],
        'total_rank': []
    }
    
    # Add R² tracking for each modality
    for i in range(len(input_dims)):
        rank_history[f'rsquare {i}'] = []
    
    # Compute initial R² threshold if relative
    if threshold_type == 'relative':
        initial_r_squares = compute_direct_r_squared_vae(model, train_data, device, multi_gpu, verbose)
        absolute_threshold = [r * r_square_threshold for r in initial_r_squares]
        if verbose:
            print(f"Initial R²: {initial_r_squares}")
            print(f"Absolute R² thresholds: {absolute_threshold}")
    else:
        absolute_threshold = [r_square_threshold] * len(input_dims)
    
    best_val_loss = float('inf')
    patience_counter = 0
    epochs_since_rank_change = 0
    
    # Training loop
    pbar = tqdm.tqdm(range(epochs))
    for epoch in pbar:
        # Determine current KL weight (with annealing if specified)
        if kl_anneal_epochs is not None and epoch < kl_anneal_epochs:
            current_kl_weight = kl_weight * (epoch / kl_anneal_epochs)
        else:
            current_kl_weight = kl_weight
        
        # Training phase
        model.train()
        train_loss_epoch = 0.0
        train_recon_epoch = 0.0
        train_kl_epoch = 0.0
        
        for batch_data in train_loader:
            batch_data = [d.to(device) for d in batch_data]
            
            optimizer.zero_grad()
            x_recon, vae_params = model(batch_data)
            
            loss, recon_loss, kl_loss = compute_vae_loss(
                x_recon, batch_data, vae_params, 
                kl_weight=current_kl_weight, reduction='mean'
            )
            
            loss.backward()
            optimizer.step()
            
            train_loss_epoch += loss.item()
            train_recon_epoch += recon_loss.item()
            train_kl_epoch += kl_loss.item()
        
        train_loss_epoch /= len(train_loader)
        train_recon_epoch /= len(train_loader)
        train_kl_epoch /= len(train_loader)
        
        train_losses.append(train_loss_epoch)
        train_recon_losses.append(train_recon_epoch)
        train_kl_losses.append(train_kl_epoch)
        
        # Validation phase
        model.eval()
        val_loss_epoch = 0.0
        with torch.no_grad():
            for batch_data in val_loader:
                batch_data = [d.to(device) for d in batch_data]
                x_recon, vae_params = model(batch_data)
                loss, _, _ = compute_vae_loss(
                    x_recon, batch_data, vae_params,
                    kl_weight=current_kl_weight, reduction='mean'
                )
                val_loss_epoch += loss.item()
        val_loss_epoch /= len(val_loader)
        val_losses.append(val_loss_epoch)
        
        # Compute R²
        current_r_squares = compute_direct_r_squared_vae(model, train_data, device, multi_gpu, verbose=False)
        
        # Update progress bar
        pbar.set_postfix({
            'train_loss': f'{train_loss_epoch:.4f}',
            'val_loss': f'{val_loss_epoch:.4f}',
            'recon': f'{train_recon_epoch:.4f}',
            'kl': f'{train_kl_epoch:.4f}',
            'R²': f'{current_r_squares[0]:.3f},{current_r_squares[1]:.3f}',
            'rank': model.get_total_rank()
        })
        
        # Track history
        rank_history['epoch'].append(epoch)
        # Get current ranks for each adaptive layer
        current_ranks = [layer.active_dims for layer in model.adaptive_layers]
        rank_history['ranks'].append(current_ranks)
        rank_history['total_rank'].append(model.get_total_rank())
        for i, r2 in enumerate(current_r_squares):
            rank_history[f'rsquare {i}'].append(r2)
        
        # Early stopping check (only after KL warmup if annealing is used)
        min_epoch_for_early_stopping = kl_anneal_epochs if kl_anneal_epochs is not None else 0
        if epoch >= min_epoch_for_early_stopping:
            if val_loss_epoch < best_val_loss:
                best_val_loss = val_loss_epoch
                patience_counter = 0
            else:
                patience_counter += 1
            
            if patience_counter >= early_stopping:
                if verbose:
                    print(f"\nEarly stopping at epoch {epoch}")
                break
        else:
            # During warmup, just track best loss but don't increment patience
            if val_loss_epoch < best_val_loss:
                best_val_loss = val_loss_epoch
        
        # Rank reduction logic (after warmup and periodically)
        epochs_since_rank_change += 1
        if epoch >= warmup_epochs and epochs_since_rank_change >= rank_reduction_frequency:
            # Check if we should reduce rank based on R²
            can_reduce = all(r2 >= thresh for r2, thresh in zip(current_r_squares, absolute_threshold))
            
            if can_reduce:
                # Try to reduce rank
                layer_ids = list(range(len(model.adaptive_layers)))
                changes_made = model.reduce_rank(
                    reduction_ratio=0.9,
                    threshold=rank_reduction_threshold,
                    layer_ids=layer_ids,
                    dim=0
                )
                if changes_made:
                    if verbose:
                        print(f"\nRank reduced at epoch {epoch}. New total rank: {model.get_total_rank()}")
                    epochs_since_rank_change = 0
        
        # Garbage collection
        if epoch % 100 == 0:
            gc.collect()
            torch.cuda.empty_cache() if torch.cuda.is_available() else None
    
    # Get final representations
    model.eval()
    with torch.no_grad():
        reps = model.encode_modalities([d.to(device) for d in train_data])
        reps = [r.cpu() for r in reps]
    
    # Clean up
    gc.collect()
    torch.cuda.empty_cache() if torch.cuda.is_available() else None
    
    return model, reps, train_losses[-1], current_r_squares, rank_history, (train_losses, val_losses)