"""
Training functions for MultiBench CNN-based multimodal autoencoders with adaptive rank reduction.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import gc
import os
import math
import random

from src.models.larrp_multibench import AdaptiveRankReducedAE_MultiBench, AdaptiveRankReducedAE_Static
from src.functions.linear_probing import parallel_linear_regression
from src.data.multibench_loader import _process_1, _process_2


def compute_reconstruction_loss(reconstructions, targets, masks=None):
    """
    Compute the mean per-modality reconstruction MSE, with optional masking.

    Args:
        reconstructions: List of tensors [(batch, seq_len, n_features), ...]
        targets: List of tensors with the same shapes as ``reconstructions``
        masks: Optional list with one entry per modality (entries may be None).
            Each mask may have the full element shape or any shape that
            broadcasts to it via ``expand_as`` (e.g. (batch, seq_len, 1)).
            The MSE of a masked modality is averaged over masked *elements*.

    Returns:
        Scalar tensor: the unweighted mean over modalities of each
        modality's MSE.

    Raises:
        ValueError: if ``reconstructions`` is empty.
    """
    n_modalities = len(reconstructions)
    if n_modalities == 0:
        raise ValueError("compute_reconstruction_loss requires at least one modality")

    per_modality = []
    for m, recon in enumerate(reconstructions):
        mse = F.mse_loss(recon, targets[m], reduction='none')

        mask = masks[m] if masks is not None else None
        if mask is not None:
            # Expand a broadcastable mask to the element shape so that the
            # denominator counts every masked element, not just mask entries.
            mask = mask.float().expand_as(mse)
            per_modality.append((mse * mask).sum() / (mask.sum() + 1e-8))
        else:
            # Average over all positions
            per_modality.append(mse.mean())

    return sum(per_modality) / n_modalities


def compute_direct_r_squared_multibench(model, val_loader, device, multi_gpu=False, verbose=False):
    """
    Compute R² based on direct model reconstruction performance for MultiBench data.

    For each modality, R² = 1 - SS_res / SS_tot over every (unmasked) element
    seen in ``val_loader``. SS_tot is taken around the *global* target mean.
    Target moments are merged batch-by-batch with the Chan et al. parallel
    update, so the result does not depend on how the data is batched.

    Parameters:
    - model: The trained model; ``model(data)`` must return
      ``(reconstructions, latents)``
    - val_loader: Iterable of batches in any format handled by ``_normalize_batch``
    - device: Device to run computation on
    - multi_gpu: Whether model is wrapped with DataParallel
    - verbose: Unused; kept for API compatibility

    Returns:
    - List of R² values for each modality, clipped below at 0.0. A modality
      with zero target variance or no unmasked elements gets 0.0.

    The model's train/eval mode is restored before returning.
    """
    # Access model attributes (handle DataParallel)
    model_module = model.module if multi_gpu and hasattr(model, 'module') else model
    n_modalities = getattr(model_module, 'n_modalities', None)
    if n_modalities is None:
        # adaptive_layers includes shared + modality-specific layers
        n_modalities = max(1, len(model_module.adaptive_layers) - 1)

    # Running [count, mean, M2] of targets, and residual sum of squares, per modality
    target_stats = [[0, 0.0, 0.0] for _ in range(n_modalities)]
    residual_ss = [0.0] * n_modalities

    def _normalize_batch(batch):
        """Normalize different collate outputs into (data_list, masks).

        Supported collate outputs:
        - _process_1: (processed_input_list, processed_input_lengths, inds, labels)
        - _process_2: (tensor_mod1, tensor_mod2, tensor_mod3, labels)
        - (data, masks)
        - fallback: iterable of modality tensors, or a single tensor
        """
        masks = None
        if isinstance(batch, (tuple, list)):
            if len(batch) >= 1 and isinstance(batch[0], list):
                # _process_1: first element is the list of modality tensors
                data = batch[0]
            elif len(batch) >= 3:
                # _process_2 / MM-IMDb: last element is labels
                data = list(batch[:-1])
            elif len(batch) == 2:
                data, masks = batch
                data = list(data) if isinstance(data, (tuple, list)) else [data]
            else:
                data = list(batch)
        else:
            data = [batch]

        data = [d.to(device, non_blocking=True) for d in data]
        if masks is not None:
            masks = [mk.to(device, non_blocking=True) if mk is not None else None for mk in masks]
        return data, masks

    def _update_moments(stats, values):
        """Merge a flat tensor into running [count, mean, M2] (Chan et al.)."""
        n_b = values.numel()
        if n_b == 0:
            # Nothing selected (e.g. all-False mask); var() would be NaN here.
            return
        mean_b = values.mean().item()
        m2_b = ((values - mean_b) ** 2).sum().item()
        n_a, mean_a, m2_a = stats
        n = n_a + n_b
        delta = mean_b - mean_a
        stats[0] = n
        stats[1] = mean_a + delta * n_b / n
        stats[2] = m2_a + m2_b + delta * delta * n_a * n_b / n

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in val_loader:
                data, masks = _normalize_batch(batch)

                # Forward pass
                reconstructions, _ = model(data)

                for m in range(n_modalities):
                    target = data[m]
                    recon = reconstructions[m]

                    # Keep only masked elements (boolean indexing needs a bool mask)
                    if masks is not None and masks[m] is not None:
                        mask = masks[m].bool()
                        target = target[mask]
                        recon = recon[mask]

                    # reshape handles non-contiguous tensors
                    target_flat = target.reshape(-1)
                    recon_flat = recon.reshape(-1)

                    _update_moments(target_stats[m], target_flat)
                    residual_ss[m] += torch.sum((target_flat - recon_flat) ** 2).item()
    finally:
        model.train(was_training)

    r_squared_values = []
    for (count, _, total_ss), res_ss in zip(target_stats, residual_ss):
        if count > 0 and total_ss > 0:
            # R² <= 1 always (res_ss >= 0); clip negatives to 0
            r_squared_values.append(max(0.0, 1.0 - res_ss / total_ss))
        else:
            r_squared_values.append(0.0)

    return r_squared_values


def compute_direct_explained_variance_multibench(model, val_loader, device, multi_gpu=False, verbose=False):
    """
    Compute explained variance score based on direct model reconstruction performance for MultiBench data.

    For each modality, EV = 1 - Var(y - y_hat) / Var(y). Both variances are
    taken over every (unmasked) element seen in ``val_loader`` around their
    *global* means. Moments are merged batch-by-batch with the Chan et al.
    parallel update, so the loader is iterated exactly once and the result
    does not depend on batch boundaries.

    Parameters:
    - model: The trained model; ``model(data)`` must return
      ``(reconstructions, latents)``
    - val_loader: Iterable of batches in any format handled by ``_normalize_batch``
    - device: Device to run computation on
    - multi_gpu: Whether model is wrapped with DataParallel
    - verbose: Unused; kept for API compatibility

    Returns:
    - List of explained variance values for each modality, clamped to [0, 1].
      A modality with zero target variance or no unmasked elements gets 0.0.

    The model's train/eval mode is restored before returning.
    """
    # Access model attributes (handle DataParallel)
    model_module = model.module if multi_gpu and hasattr(model, 'module') else model
    n_modalities = getattr(model_module, 'n_modalities', None)
    if n_modalities is None:
        # adaptive_layers includes shared + modality-specific layers
        n_modalities = max(1, len(model_module.adaptive_layers) - 1)

    # Running [count, mean, M2] for targets and residuals, per modality
    target_stats = [[0, 0.0, 0.0] for _ in range(n_modalities)]
    residual_stats = [[0, 0.0, 0.0] for _ in range(n_modalities)]

    def _normalize_batch(batch):
        """Normalize different collate outputs into (data_list, masks)."""
        masks = None
        if isinstance(batch, (tuple, list)):
            if len(batch) >= 1 and isinstance(batch[0], list):
                # _process_1: first element is the list of modality tensors
                data = batch[0]
            elif len(batch) >= 3:
                # _process_2 / MM-IMDb: last element is labels
                data = list(batch[:-1])
            elif len(batch) == 2:
                data, masks = batch
                data = list(data) if isinstance(data, (tuple, list)) else [data]
            else:
                data = list(batch)
        else:
            data = [batch]

        data = [d.to(device, non_blocking=True) for d in data]
        if masks is not None:
            masks = [mk.to(device, non_blocking=True) if mk is not None else None for mk in masks]
        return data, masks

    def _update_moments(stats, values):
        """Merge a flat tensor into running [count, mean, M2] (Chan et al.)."""
        n_b = values.numel()
        if n_b == 0:
            return
        mean_b = values.mean().item()
        m2_b = ((values - mean_b) ** 2).sum().item()
        n_a, mean_a, m2_a = stats
        n = n_a + n_b
        delta = mean_b - mean_a
        stats[0] = n
        stats[1] = mean_a + delta * n_b / n
        stats[2] = m2_a + m2_b + delta * delta * n_a * n_b / n

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in val_loader:
                data, masks = _normalize_batch(batch)

                # Forward pass
                reconstructions, _ = model(data)

                for m in range(n_modalities):
                    target = data[m]
                    recon = reconstructions[m]

                    # Select masked elements instead of zeroing them, so
                    # masked-out positions do not bias either variance.
                    if masks is not None and masks[m] is not None:
                        mask = masks[m].bool()
                        target = target[mask]
                        recon = recon[mask]

                    target_flat = target.reshape(-1)
                    recon_flat = recon.reshape(-1)

                    _update_moments(target_stats[m], target_flat)
                    _update_moments(residual_stats[m], target_flat - recon_flat)
    finally:
        model.train(was_training)

    explained_variance_values = []
    for (count, _, target_m2), (_, _, residual_m2) in zip(target_stats, residual_stats):
        if count > 0 and target_m2 > 0:
            # Both variances share the same count, so Var ratio == M2 ratio.
            explained_var = 1.0 - residual_m2 / target_m2
            explained_variance_values.append(max(0.0, min(1.0, explained_var)))
        else:
            explained_variance_values.append(0.0)

    return explained_variance_values


def train_multibench_ae(train_dataset, val_dataset, input_shapes=None, input_dims=None, 
                        latent_dim=200, device=None, args=None,
                        epochs=1000, lr=1e-4, batch_size=32, 
                        conv_channels=[64, 128, 256], kernel_sizes=[3, 3, 3],
                        dropout=0.1, wd=1e-5, l2norm=0.0,
                        initial_rank_ratio=1.0, min_rank=10,
                        rank_schedule=None, rank_reduction_frequency=10, 
                        rank_reduction_threshold=0.01,
                        warmup_epochs=100, patience=10,
                        reduce_on_best_loss='rsquare', r_square_threshold=0.9, 
                        threshold_type='relative', compressibility_type='direct',
                        reduction_criterion='r_squared', decision_metric='R2',
                        early_stopping=50, verbose=True, model_name=None,
                        sharedwhenall=True, activation=None):
    """
    Train MultiBench multimodal autoencoder with adaptive rank reduction.
    Supports both temporal (3D) and static (2D) feature datasets.
    Implements the same sophisticated rank reduction/increase logic as train_overcomplete_ae.
    
    Args:
        train_dataset: Training dataset (torch.utils.data.Dataset)
        val_dataset: Validation dataset
        input_shapes: List of (seq_len, n_features) tuples for temporal data (or None)
        input_dims: List of feature dimensions for static data (or None)
        latent_dim: Latent dimension (int or list)
        device: torch device
        args: Training arguments object
        epochs: Number of training epochs
        lr: Learning rate
        batch_size: Batch size
        conv_channels: List of channel sizes for conv layers
        kernel_sizes: List of kernel sizes for conv layers
        dropout: Dropout rate
        wd: Weight decay
        l2norm: L2 norm regularization weight for first decoder layer
        initial_rank_ratio: Initial rank ratio for adaptive layers
        min_rank: Minimum rank for adaptive layers
        rank_schedule: Custom schedule for rank reduction (epochs at which to reduce)
        rank_reduction_frequency: How often to attempt rank reduction (epochs)
        rank_reduction_threshold: Threshold for rank reduction
        warmup_epochs: Number of epochs before rank reduction starts
        patience: Patience for rank reduction
        reduce_on_best_loss: 'true', 'stagnation', or 'rsquare'
        r_square_threshold: R² threshold for rank reduction
        threshold_type: 'relative' or 'absolute'
        compressibility_type: 'linear' (linear probing R²) or 'direct' (reconstruction R²)
        reduction_criterion: 'r_squared', 'train_loss', or 'val_loss'
        decision_metric: 'R2' (coefficient of determination) or 'ExVarScore' (explained variance)
        early_stopping: Early stopping patience
        verbose: Print training progress
        model_name: Optional name for saving checkpoints
        sharedwhenall: Whether to always include shared layer in reductions
        activation: Optional output activation ('tanh', 'sigmoid', 'softmax', or None)
    
    Returns:
        Tuple of (model, representations, final_loss, r_squares, rank_history, loss_curves)
    """
    
    # Declare multi_gpu
    multi_gpu = args.multi_gpu if hasattr(args, 'multi_gpu') else False
    
    # Determine if we're using static or temporal model
    from src.data.multibench_loader import MMIMDbDataset
    use_static_model = isinstance(train_dataset, MMIMDbDataset) or input_dims is not None
    
    # Create model
    if use_static_model:
        # Static features (MM-IMDb) - use standard feedforward AE
        print("Using AdaptiveRankReducedAE_Static for static features...")
        model = AdaptiveRankReducedAE_Static(
            input_dims=input_dims,
            latent_dims=latent_dim,
            depth=2,
            width=1.0,
            dropout=dropout,
            initial_rank_ratio=initial_rank_ratio,
            min_rank=min_rank
        )
        n_modalities = len(input_dims)
    else:
        # Temporal features (Affect datasets) - use CNN/GRU/Transformer AE
        print("Using AdaptiveRankReducedAE_MultiBench for temporal features...")
        model = AdaptiveRankReducedAE_MultiBench(
            input_shapes=input_shapes,
            latent_dims=latent_dim,
            conv_channels=conv_channels,
            kernel_sizes=kernel_sizes,
            dropout=dropout,
            initial_rank_ratio=initial_rank_ratio,
            min_rank=min_rank,
            model_type=args.model_type,
            gru_hidden_dim=args.gru_hidden_dim,
            gru_num_layers=args.gru_num_layers,
            attn_num_heads=args.attn_num_heads,
            attn_num_layers=args.attn_num_layers,
            activation=activation
        )
        n_modalities = len(input_shapes)
    
    model = model.to(device)
    print(f"Model is on device: {next(model.parameters()).device}")
    
    # Handle multi-GPU
    if multi_gpu:
        if hasattr(args, 'gpu_ids') and args.gpu_ids:
            gpu_ids = [int(i) for i in args.gpu_ids.split(',')]
        else:
            gpu_ids = list(range(torch.cuda.device_count()))
        
        try:
            # Ensure model is on cuda:0 for DataParallel
            cuda0_device = torch.device('cuda:0')
            model = model.to(cuda0_device)
            
            # Wrap model with DataParallel
            model = nn.DataParallel(model, device_ids=gpu_ids)
            if verbose:
                print(f"Using DataParallel across GPUs: {gpu_ids}")
        except Exception as e:
            print(f"Failed to use DataParallel: {e}")
            print(f"Falling back to single GPU mode on {device}")
            multi_gpu = False
            model = model.to(device)
    
    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    
    # Data loaders
    num_workers = getattr(args, 'num_workers', 4)
    
    # Choose appropriate collate function based on dataset type
    from src.data.multibench_loader import _process_mmimdb, _process_1, _process_2
    
    if isinstance(train_dataset, MMIMDbDataset):
        # MM-IMDb uses specialized collate function
        collate_fn_train = _process_mmimdb
        collate_fn_val = _process_mmimdb
    else:
        # Affect datasets use _process_1 or _process_2
        collate_fn_train = _process_2 if getattr(train_dataset, 'max_pad', False) else _process_1
        collate_fn_val = _process_2 if getattr(val_dataset, 'max_pad', False) else _process_1

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True, collate_fn=collate_fn_train
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True, collate_fn=collate_fn_val
    )
    
    # Default rank reduction schedule if none provided
    if rank_schedule is None:
        rank_schedule = list(range(warmup_epochs + rank_reduction_frequency, 
                                   epochs, 
                                   rank_reduction_frequency))
    
    # Initialize tracking variables (matching train_overcomplete_ae)
    initial_squares = [None] * n_modalities
    initial_losses = [None] * n_modalities
    start_reduction = False
    current_rsquare_per_mod = [None] * n_modalities
    current_loss_per_mod = [None] * n_modalities
    min_rsquares = []
    bottom_reached = False
    space_sims = None
    break_counter = 0
    
    # Training history
    train_losses = []
    val_losses = []
    r_squares = []
    min_ranks = [layer.active_dims for layer in model.adaptive_layers] if not multi_gpu else \
                [layer.active_dims for layer in model.module.adaptive_layers]
    best_loss = float('inf')
    patience_counter = 0
    
    # Initialize rank history dictionary
    rank_history = {
        'total_rank': [],
        'ranks': [],
        'epoch': [],
        'loss': [],
        'val_loss': [],
        'Sim (0-1,0-2,1-2)': []
    }
    for i in range(n_modalities):
        rank_history[f'rsquare {i}'] = []
    
    if verbose:
        print(f"\nStarting training for {epochs} epochs...")
        print(f"Training samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}")
        print(f"Batch size: {batch_size}, Learning rate: {lr}")
        print(f"Warmup epochs: {warmup_epochs}, Rank reduction frequency: {rank_reduction_frequency}")
        print(f"R² threshold: {r_square_threshold} ({threshold_type})")
        print(f"Compressibility type: {compressibility_type}, Reduction criterion: {reduction_criterion}")
    
    # Training loop
    pbar = tqdm(range(epochs))
    for epoch in pbar:
        # Training phase
        model.train()
        train_loss_epoch = 0.0
        per_modality_losses = [0.0] * n_modalities
        n_train_batches = 0
        
        for batch in train_loader:
            # Normalize batch into list of modality tensors + optional masks
            # Handle different batch formats from different collate functions:
            # - MM-IMDb: (text, image, labels) - 3 tensors
            # - Affect _process_1: (data_list, lengths, ids, labels) - list + extras
            # - Affect _process_2: (vision, audio, text, labels) - 4 tensors
            
            if isinstance(batch, (tuple, list)):
                # Check if first element is a list (from _process_1)
                if len(batch) >= 1 and isinstance(batch[0], list):
                    data = batch[0]
                    masks = None
                # MM-IMDb format: (text, image, labels) or Affect _process_2: (v, a, t, labels)
                # Last element is labels, rest are modalities
                elif len(batch) >= 2 and all(isinstance(b, torch.Tensor) for b in batch):
                    # All elements are tensors - take all but last as modalities
                    data = list(batch[:-1])
                    masks = None
                # If (data, masks) format
                elif len(batch) == 2:
                    data, masks = batch
                    if isinstance(data, (tuple, list)):
                        data = list(data)
                else:
                    data = list(batch)
                    masks = None
            else:
                data = [batch]
                masks = None

            # Move to device
            data = [d.to(device, non_blocking=True) for d in data]
            if masks is not None:
                masks = [m.to(device, non_blocking=True) if m is not None else None for m in masks]
            
            # Forward pass
            reconstructions, latents = model(data)
            
            # Calculate per-modality losses
            modality_losses = []
            for m in range(n_modalities):
                recon = reconstructions[m]
                target = data[m]
                
                # Apply mask if provided
                if masks is not None and masks[m] is not None:
                    mask = masks[m].float()
                    mse = F.mse_loss(recon * mask, target * mask, reduction='sum') / (mask.sum() + 1e-8)
                else:
                    mse = F.mse_loss(recon, target)
                
                # Check for NaN
                if torch.isnan(mse):
                    if verbose:
                        print(f"Warning: NaN loss detected for modality {m}")
                    mse = torch.tensor(0.0, device=device)
                
                modality_losses.append(mse)
                per_modality_losses[m] += mse.item()
            
            # Combine modality losses
            total_loss = sum(modality_losses) / n_modalities
            
            # Add L2 norm regularization for first decoder layer if specified
            if l2norm > 0:
                l2_penalty = 0.0
                adaptive_layers = model.module.adaptive_layers if multi_gpu else model.adaptive_layers
                for layer in adaptive_layers:
                    if hasattr(layer, 'decoder') and hasattr(layer.decoder, 'fc1'):
                        l2_penalty += torch.norm(layer.decoder.fc1.weight, p=2)
                total_loss = total_loss + l2norm * l2_penalty
            
            # Backward pass
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            
            train_loss_epoch += total_loss.item()
            n_train_batches += 1
        
        train_loss_epoch /= n_train_batches
        per_modality_losses = [loss / n_train_batches for loss in per_modality_losses]
        train_losses.append(train_loss_epoch)
        
        # Validation phase
        model.eval()
        val_loss_epoch = 0.0
        n_val_batches = 0
        
        with torch.no_grad():
            for batch in val_loader:
                # Normalize val batch same as train
                if isinstance(batch, (tuple, list)):
                    if len(batch) >= 1 and isinstance(batch[0], list):
                        data = batch[0]
                        masks = None
                    elif len(batch) >= 3:
                        data = list(batch[:-1])
                        masks = None
                    elif len(batch) == 2:
                        data, masks = batch
                        if isinstance(data, (tuple, list)):
                            data = list(data)
                    else:
                        data = list(batch)
                        masks = None
                else:
                    data = [batch]
                    masks = None

                data = [d.to(device, non_blocking=True) for d in data]
                if masks is not None:
                    masks = [m.to(device, non_blocking=True) if m is not None else None for m in masks]

                # Forward pass
                reconstructions, _ = model(data)
                
                # Calculate validation loss
                val_batch_loss = 0.0
                for m in range(n_modalities):
                    recon = reconstructions[m]
                    target = data[m]
                    
                    # Apply mask if provided
                    if masks is not None and masks[m] is not None:
                        mask = masks[m].float()
                        mse = F.mse_loss(recon * mask, target * mask, reduction='sum') / (mask.sum() + 1e-8)
                    else:
                        mse = F.mse_loss(recon, target)
                    
                    if not torch.isnan(mse):
                        val_batch_loss += mse.item()
                
                val_loss_epoch += val_batch_loss / n_modalities
                n_val_batches += 1
        
        val_loss_epoch /= n_val_batches
        val_losses.append(val_loss_epoch)
        
        # Update progress bar
        adaptive_layers = model.module.adaptive_layers if multi_gpu else model.adaptive_layers
        log_dict = {
            'loss': round(train_loss_epoch, 4),
            'val_loss': round(val_loss_epoch, 4),
            'mod_losses': [round(l, 3) for l in per_modality_losses],
            'ranks': [layer.active_dims for layer in adaptive_layers],
            'current_rsquare': [round(current_rsquare_per_mod[i], 3) if current_rsquare_per_mod[i] is not None else 'N/A' for i in range(n_modalities)],
            'patience': patience_counter,
        }
        pbar.set_postfix(log_dict)
        
        # Update best loss
        if train_loss_epoch < best_loss:
            best_loss = train_loss_epoch
            if reduce_on_best_loss in ['true', 'stagnation']:
                patience_counter = 0
        else:
            if reduce_on_best_loss in ['true', 'stagnation']:
                patience_counter += 1
        
        # Check if we should start rank reduction (early stopping for initial training)
        if (epoch > early_stopping) and (min(val_losses[-early_stopping:]) > min(val_losses)) and (start_reduction is False):
            start_reduction = True
            break_counter = 0
            if verbose:
                print(f"Patience exceeded at epoch {epoch}, starting rank reduction")
            
            # Initialize R² tracking based on compressibility_type
            model_access = model.module if multi_gpu else model
            
            if compressibility_type == 'direct':
                if reduction_criterion == 'r_squared':
                    # Direct reconstruction R² or Explained Variance
                    if decision_metric == 'ExVarScore':
                        #direct_r_squared_values = compute_direct_explained_variance_multibench(model, val_loader, device, multi_gpu)
                        # use the first 10% of the train data
                        train_subset = torch.utils.data.Subset(train_loader.dataset, range(int(0.1 * len(train_loader.dataset))))
                        train_subset_loader = torch.utils.data.DataLoader(train_subset, batch_size=train_loader.batch_size, shuffle=False)
                        direct_r_squared_values = compute_direct_explained_variance_multibench(model, train_subset_loader, device, multi_gpu)
                    else:  # Default to R2
                        #direct_r_squared_values = compute_direct_r_squared_multibench(model, val_loader, device, multi_gpu)
                        train_subset = torch.utils.data.Subset(train_loader.dataset, range(int(0.1 * len(train_loader.dataset))))
                        train_subset_loader = torch.utils.data.DataLoader(train_subset, batch_size=train_loader.batch_size, shuffle=False)
                        direct_r_squared_values = compute_direct_r_squared_multibench(model, train_subset_loader, device, multi_gpu)
                    
                    for i, r_squared_val in enumerate(direct_r_squared_values):
                        initial_squares[i] = r_squared_val
                        current_rsquare_per_mod[i] = r_squared_val
                        rank_history[f'rsquare {i}'] = [r_squared_val]
                        
                        # Calculate threshold
                        if threshold_type == 'relative':
                            min_rsquares.append(r_squared_val * r_square_threshold)
                        elif threshold_type == 'absolute':
                            min_rsquares.append(r_squared_val - r_square_threshold)
                
                elif reduction_criterion in ['train_loss', 'val_loss']:
                    # Loss-based approach
                    for m in range(n_modalities):
                        loss_val = per_modality_losses[m] if reduction_criterion == 'train_loss' else val_loss_epoch / n_modalities
                        initial_losses[m] = loss_val
                        initial_squares[m] = loss_val
                        current_rsquare_per_mod[m] = loss_val
                        current_loss_per_mod[m] = loss_val
                        rank_history[f'rsquare {m}'] = [loss_val]
                        
                        # Calculate threshold (for loss, lower is better)
                        if threshold_type == 'relative':
                            min_rsquares.append(loss_val * r_square_threshold)
                        elif threshold_type == 'absolute':
                            min_rsquares.append(loss_val + r_square_threshold)
            max_rsquares = initial_squares.copy()
            if verbose:
                if compressibility_type == 'direct' and reduction_criterion in ['train_loss', 'val_loss']:
                    print(f"Initial {reduction_criterion} values: {[rank_history[f'rsquare {i}'][0] for i in range(n_modalities)]}, setting {threshold_type} thresholds to {min_rsquares}")
                else:
                    print(f"Initial R-squared values: {[rank_history[f'rsquare {i}'][0] for i in range(n_modalities)]}, setting {threshold_type} thresholds to {min_rsquares}")
        
        # Apply rank reduction at scheduled epochs
        if (epoch in rank_schedule) and start_reduction and (break_counter == 0):
            # Calculate R² or loss metrics
            model_access = model.module if multi_gpu else model
            
            if reduce_on_best_loss == 'rsquare' and start_reduction:
                current_rsquares = []
                modalities_to_reduce = []
                modalities_to_increase = []
                
                # Calculate metrics based on compressibility_type
                if compressibility_type == 'direct':
                    if reduction_criterion == 'r_squared':
                        train_subset = torch.utils.data.Subset(train_loader.dataset, range(int(0.1 * len(train_loader.dataset))))
                        train_subset_loader = torch.utils.data.DataLoader(train_subset, batch_size=train_loader.batch_size, shuffle=False)
                        if decision_metric == 'ExVarScore':
                            #direct_r_squared_values = compute_direct_explained_variance_multibench(model, val_loader, device, multi_gpu)
                            direct_r_squared_values = compute_direct_explained_variance_multibench(model, train_subset_loader, device, multi_gpu)
                        else:  # Default to R2
                            #direct_r_squared_values = compute_direct_r_squared_multibench(model, val_loader, device, multi_gpu)
                            direct_r_squared_values = compute_direct_r_squared_multibench(model, train_subset_loader, device, multi_gpu)
                        for i, r_squared_val in enumerate(direct_r_squared_values):
                            current_rsquares.append(r_squared_val)
                            current_rsquare_per_mod[i] = r_squared_val
                    
                    elif reduction_criterion in ['train_loss', 'val_loss']:
                        for m in range(n_modalities):
                            loss_val = per_modality_losses[m] if reduction_criterion == 'train_loss' else val_loss_epoch / n_modalities
                            current_rsquares.append(loss_val)
                            current_rsquare_per_mod[m] = loss_val
                            current_loss_per_mod[m] = loss_val
                
                r_squares.append(current_rsquares)
                
                # Update min_rsquares based on maximum observed (allows threshold to follow best performance)
                update_max = False
                for i, r in enumerate(current_rsquares):
                    if r > max_rsquares[i]:
                        max_rsquares[i] = r
                        update_max = True
                if update_max:
                    if threshold_type == 'relative':
                        min_rsquares = [r * r_square_threshold for r in max_rsquares]
                    elif threshold_type == 'absolute':
                        # For R²: subtract threshold; for loss: add threshold
                        if compressibility_type == 'direct' and reduction_criterion in ['train_loss', 'val_loss']:
                            min_rsquares = [r + r_square_threshold for r in max_rsquares]
                        else:
                            min_rsquares = [r - r_square_threshold for r in max_rsquares]
                    if verbose:
                        print(f"Updated threshold {min_rsquares} based on new max {max_rsquares}")
                
                # Determine which modalities to reduce or increase
                if (len(r_squares) >= min(10, int(patience/2))) and patience_counter >= min(10, int(patience/2)):
                    for i in range(len(current_rsquare_per_mod)):
                        i_rsquares = [r[i] for r in r_squares[-min(10, int(patience/2)):]]
                        
                        # Handle different comparison logic for loss vs R²
                        if compressibility_type == 'direct' and reduction_criterion in ['train_loss', 'val_loss']:
                            # For loss: lower is better
                            if all(r > min_rsquares[i] for r in i_rsquares) and not bottom_reached:
                                modalities_to_increase.append(i)
                            elif current_rsquare_per_mod[i] < min_rsquares[i]:
                                modalities_to_reduce.append(i)
                        else:
                            # For R²: higher is better
                            if all(r < min_rsquares[i] for r in i_rsquares) and not bottom_reached:
                                modalities_to_increase.append(i)
                            elif current_rsquare_per_mod[i] > min_rsquares[i]:
                                modalities_to_reduce.append(i)
                
                elif (len(r_squares) >= 1):
                    for i in range(len(current_rsquare_per_mod)):
                        # Handle different comparison logic for loss vs R²
                        if compressibility_type == 'direct' and reduction_criterion in ['train_loss', 'val_loss']:
                            if current_rsquare_per_mod[i] < min_rsquares[i]:
                                modalities_to_reduce.append(i)
                            elif current_rsquare_per_mod[i] > min_rsquares[i] and not bottom_reached:
                                modalities_to_increase.append(i)
                        else:
                            if current_rsquare_per_mod[i] > min_rsquares[i]:
                                modalities_to_reduce.append(i)
                            elif current_rsquare_per_mod[i] < min_rsquares[i] and not bottom_reached:
                                modalities_to_increase.append(i)
                
                # Set min and max ranks if all modalities can be reduced
                if len(modalities_to_reduce) == len(current_rsquare_per_mod):
                    current_ranks = [layer.active_dims for layer in adaptive_layers]
                    for i, cr in enumerate(current_ranks):
                        if cr <= min_ranks[i]:
                            min_ranks[i] = cr
                            adaptive_layers[i].max_rank = min(sum(current_ranks), adaptive_layers[i].max_rank)
                
                # Determine layers to reduce or increase
                layers_to_reduce = []
                layers_to_increase = []
                
                if (len(modalities_to_reduce) == 0) and (len(modalities_to_increase) == 0):
                    pass
                elif (len(modalities_to_reduce) > 0) and (len(modalities_to_increase) > 0):
                    layers_to_reduce = [i + 1 for i in modalities_to_reduce]
                    layers_to_increase = [i + 1 for i in modalities_to_increase]
                else:
                    if len(modalities_to_increase) > 0:
                        if len(modalities_to_increase) == len(current_rsquare_per_mod):
                            layers_to_increase = [i for i in range(len(adaptive_layers))]
                        else:
                            layers_to_increase = [i + 1 for i in modalities_to_increase]
                    
                    if len(modalities_to_reduce) > 0:
                        if len(modalities_to_reduce) == len(initial_squares):
                            reduce_shared = True
                            if reduce_shared:
                                layers_to_reduce = [0] + [i + 1 for i in modalities_to_reduce]
                            else:
                                layers_to_reduce = [i + 1 for i in modalities_to_reduce]
                        else:
                            reduce_shared = False if sharedwhenall else True
                            if reduce_shared:
                                layers_to_reduce = [0] + [i + 1 for i in modalities_to_reduce]
                            else:
                                layers_to_reduce = [i + 1 for i in modalities_to_reduce]
                
                if verbose:
                    if compressibility_type == 'direct' and reduction_criterion in ['train_loss', 'val_loss']:
                        print(f"{reduction_criterion} values: {current_rsquares}, reducing rank for modalities {modalities_to_reduce}, layers {layers_to_reduce}, increasing rank for modalities {modalities_to_increase}, layers {layers_to_increase}")
                    else:
                        print(f"R-squared values: {current_rsquares}, reducing rank for modalities {modalities_to_reduce}, layers {layers_to_reduce}, increasing rank for modalities {modalities_to_increase}, layers {layers_to_increase}")
            
            # Apply rank changes
            any_changes_made = False
            
            if len(layers_to_reduce) > 0:
                if multi_gpu:
                    changes_made = model.module.reduce_rank(reduction_ratio=0.9, threshold=rank_reduction_threshold, layer_ids=layers_to_reduce)
                else:
                    changes_made = model.reduce_rank(reduction_ratio=0.9, threshold=rank_reduction_threshold, layer_ids=layers_to_reduce)
                if changes_made:
                    any_changes_made = True
            
            if len(layers_to_increase) > 0:
                if multi_gpu:
                    changes_made = model.module.increase_rank(increase_ratio=1.1, layer_ids=layers_to_increase)
                else:
                    changes_made = model.increase_rank(increase_ratio=1.1, layer_ids=layers_to_increase)
                if changes_made:
                    any_changes_made = True
                    break_counter = patience
            
            if any_changes_made:
                patience_counter = 0
            else:
                patience_counter += 1
            
            # Store current rank in history
            total_rank_after = model.module.get_total_rank() if multi_gpu else model.get_total_rank()
            rank_history['total_rank'].append(total_rank_after)
            rank_history['ranks'].append(', '.join(str(layer.active_dims) for layer in adaptive_layers))
            rank_history['epoch'].append(epoch)
            for i in range(n_modalities):
                if reduce_on_best_loss == 'rsquare':
                    rank_history[f'rsquare {i}'].append(current_rsquares[i])
            rank_history['loss'].append(train_loss_epoch)
            rank_history['val_loss'].append(val_loss_epoch)
            
            # Compute space similarities (placeholder for compatibility)
            space_sims = [0.0] * 3  # Placeholder
            rank_history['Sim (0-1,0-2,1-2)'].append(', '.join([f"{sim:.4f}" for sim in space_sims]))
        
        else:
            if (epoch in rank_schedule) and start_reduction and (break_counter > 0):
                break_counter -= 1
        
        # Early stopping after rank reduction has started
        if (epoch > early_stopping) and (min(val_losses[-early_stopping:]) > min(val_losses)) and start_reduction and (patience_counter >= patience):
            if verbose:
                print(f"Early stopping at epoch {epoch} with best loss {best_loss} and ranks {rank_history['ranks'][-1] if rank_history['ranks'] else 'N/A'}")
            break
    
    # Extract final representations
    # IMPORTANT: Create a non-shuffled dataloader to ensure representations are in the same order as the original data
    model.eval()
    
    # Create non-shuffled dataloader for representation extraction
    representation_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=False,
        num_workers=train_loader.num_workers if hasattr(train_loader, 'num_workers') else 0,
        pin_memory=train_loader.pin_memory if hasattr(train_loader, 'pin_memory') else False
    )
    
    # Allocate space for shared + modality-specific latents. Some models return
    # (shared, [specifics]) while others return a list of per-modality latents.
    # Allocate n_modalities + 1 so we can safely append shared at index 0 when present.
    all_latents = [[] for _ in range(n_modalities + 1)]
    
    with torch.no_grad():
        for batch in representation_loader:
            # Normalize batch for encoding
            if isinstance(batch, (tuple, list)):
                if len(batch) >= 1 and isinstance(batch[0], list):
                    data = batch[0]
                elif len(batch) >= 3:
                    data = list(batch[:-1])
                elif len(batch) == 2:
                    data, _ = batch
                    if isinstance(data, (tuple, list)):
                        data = list(data)
                else:
                    data = list(batch)
            else:
                data = [batch]

            # Move to device
            data = [d.to(device, non_blocking=True) for d in data]

            # Encode
            model_access = model.module if multi_gpu else model
            latents_batch = model_access.encode(data)
            
            # latents_batch is either [z_shared, [z_specific_0, z_specific_1, ...]]
            # or just [z_0, z_1, ...] depending on model structure
            if isinstance(latents_batch, tuple) and len(latents_batch) == 2:
                # Has shared and specific: [shared, [specific_list]]
                z_shared, z_specifics = latents_batch
                # Store shared + specifics
                all_latents[0].append(z_shared.cpu())
                for m, z_spec in enumerate(z_specifics):
                    all_latents[m + 1].append(z_spec.cpu())
            else:
                # Just a list of latents per modality
                for m, z in enumerate(latents_batch):
                    all_latents[m].append(z.cpu())
    
    # Concatenate latents (now in same order as original training data, not shuffled)
    representations = [torch.cat(latents, dim=0).numpy() for latents in all_latents if len(latents) > 0]
    
    # IMPORTANT: Truncate representations to active dimensions only
    # The adaptive layers output full latent_dim but only active_dims are meaningful
    adaptive_layers = model.module.adaptive_layers if multi_gpu else model.adaptive_layers
    active_dims_list = [layer.active_dims for layer in adaptive_layers]
    
    print(f"\nTruncating representations to active dimensions:")
    print(f"  Active dims: {active_dims_list}")
    
    truncated_representations = []
    for i, (rep, active_dim) in enumerate(zip(representations, active_dims_list)):
        truncated_rep = rep[:, :active_dim]  # Keep only first active_dim columns
        truncated_representations.append(truncated_rep)
        rep_type = "Shared" if i == 0 else f"Modality {i-1} specific"
        print(f"  {rep_type}: {rep.shape} -> {truncated_rep.shape} (kept {active_dim}/{rep.shape[1]} dims)")
    
    representations = truncated_representations
    
    # Get final metrics
    final_loss = train_losses[-1]
    final_r_squares = r_squares[-1] if r_squares else [None] * n_modalities
    
    # Clean up
    del train_loader, val_loader, representation_loader
    gc.collect()
    torch.cuda.empty_cache()
    
    model_access = model.module if multi_gpu else model
    return model_access, representations, final_loss, final_r_squares, rank_history, (train_losses, val_losses)
