import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import tqdm
import gc
import os
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from src.models.larrp_ninfea import AdaptiveRankReducedAE_NInFEA, AdaptiveRankReducedAE_NInFEA_2Mods, AdaptiveRankReducedAE_MM
from src.functions.linear_probing import parallel_linear_regression
from src.data.loading import MMSimData, PairedMultimodalData
from src.functions.loss import frobenius_loss_shared, frobenius_loss_crossmodal
from src.visualization.logging import plot_training_state, create_training_movie
from src.functions.pretrain_mm_sim import pretrain_overcomplete_ae

def _split_dataset_item(dataset_item):
    """
    Split a dataset item into (data, mask).

    A 2-tuple is interpreted as (data, mask); anything else is treated as the
    data alone and the returned mask is None.
    """
    if isinstance(dataset_item, tuple) and len(dataset_item) == 2:
        data, mask = dataset_item
        return data, mask
    return dataset_item, None


def _min_max_scale(values):
    """Rescale an array to [0, 1]; the epsilon avoids division by zero for constant arrays."""
    return (values - values.min()) / (values.max() - values.min() + 1e-8)


def _draw_image_pair(ax_orig, ax_recon, orig_np, recon_np):
    """
    Show a 3-D original/reconstruction pair as images on two separate axes.

    The layout is guessed from the original: a leading axis of size <= 4 is
    taken to be the channel axis (C, H, W); otherwise (H, W, C) is assumed.
    """
    if orig_np.shape[0] <= 4:  # Likely (C, H, W)
        orig_display = np.transpose(orig_np, (1, 2, 0))
        recon_display = np.transpose(recon_np, (1, 2, 0))
    else:  # Likely (H, W, C)
        orig_display = orig_np
        recon_display = recon_np

    if orig_display.shape[2] == 1:  # Grayscale
        ax_orig.imshow(orig_display.squeeze(2), cmap='gray', alpha=0.7)
        ax_recon.imshow(recon_display.squeeze(2), cmap='Reds', alpha=0.5)
    else:  # Color image: normalize each array to [0, 1] for display
        ax_orig.imshow(_min_max_scale(orig_display), alpha=0.7)
        ax_recon.imshow(_min_max_scale(recon_display), alpha=0.5)


def _draw_signal(ax, orig_np, recon_np, mask_np):
    """
    Draw a 1-D or 2-D original/reconstruction pair on a single axis.

    - 1-D data: original and reconstruction are drawn as lines; a 1-D mask,
      if given, selects the points to draw.
    - 2-D data: the longer axis is taken to be the sequence axis (ties are
      treated as channel-first). With a mask, every channel is drawn as a pair
      of lines restricted to the valid positions; without a mask, every channel
      is drawn as a scatter of original vs. reconstructed values.
    - Any other dimensionality: nothing is drawn.
    """
    if orig_np.ndim == 1:
        x = np.arange(len(orig_np))
        if mask_np is not None and mask_np.ndim == 1:
            x, orig_np, recon_np = x[mask_np], orig_np[mask_np], recon_np[mask_np]
        ax.plot(x, orig_np, 'b-', alpha=0.7, label='Original', linewidth=2)
        ax.plot(x, recon_np, 'r-', alpha=0.7, label='Reconstructed', linewidth=1)
        return

    if orig_np.ndim != 2:
        return

    # Normalize to (seq_length, n_channels) so one loop serves both layouts.
    if orig_np.shape[0] <= orig_np.shape[1]:  # Likely (n_channels, seq_length)
        orig_np = orig_np.T
        recon_np = recon_np.T
        if mask_np is not None and mask_np.ndim == 2:
            mask_np = mask_np.T

    seq_len, n_channels = orig_np.shape
    x = np.arange(seq_len)
    for ch in range(n_channels):
        if mask_np is None:
            # Scatter of original vs. reconstructed values for this channel.
            ax.scatter(orig_np[:, ch], recon_np[:, ch], alpha=0.7)
            continue

        if mask_np.ndim == 2:
            # Fall back to the first mask column when the mask has fewer channels.
            valid_mask = mask_np[:, ch] if mask_np.shape[1] > ch else mask_np[:, 0]
        else:
            valid_mask = mask_np
        valid_indices = valid_mask.astype(bool)
        ax.plot(x[valid_indices], orig_np[valid_indices, ch],
                alpha=0.7, label=f'Orig Ch{ch+1}', linewidth=2)
        ax.plot(x[valid_indices], recon_np[valid_indices, ch],
                '--', alpha=0.7, label=f'Recon Ch{ch+1}', linewidth=1)


def plot_ninfea_reconstructions(model, train_dataset, val_dataset, device, output_path, 
                               modality_names=None, n_samples=4):
    """
    Plot reconstructions for NInFEA model showing original vs reconstructed data.

    Randomly chosen training samples are passed through the model and each
    modality is drawn in its own column: 3-D arrays as an original/reconstruction
    image pair, 1-D and 2-D arrays as signals on a single axis. Samples whose
    forward pass raises are reported and skipped. Modalities with more than
    three dimensions are reported and skipped.

    Parameters:
    - model: Trained NInFEA model; called as model(list_of_batched_tensors) and
      expected to return (reconstructions, _)
    - train_dataset: Training dataset; items are either the per-modality data
      or a (data, mask) 2-tuple
    - val_dataset: Validation dataset (indices are sampled, but validation
      plotting is currently disabled)
    - device: Device to run inference on; if it differs from the model's
      device, the model's device is used instead
    - output_path: Path to save the plot
    - modality_names: List of modality names for labeling
      (default: 'Modality 1', 'Modality 2', ...)
    - n_samples: Number of samples to plot (default 4)

    Side effects: puts the model in eval mode, draws from NumPy's global RNG,
    prints progress messages and writes the figure to output_path.
    """
    model.eval()
    
    # Ensure model is on the correct device
    # Handle DataParallel wrapped models
    if hasattr(model, 'module'):
        model_device = next(model.module.parameters()).device
    else:
        model_device = next(model.parameters()).device
    
    if model_device != device:
        print(f"Warning: Model is on {model_device}, but device parameter is {device}. Using model device.")
        device = model_device
    
    # Get sample data to determine structure
    sample_item = train_dataset[0]
    sample_data = sample_item[0] if isinstance(sample_item, tuple) else sample_item
    n_modalities = len(sample_data)
    
    if modality_names is None:
        modality_names = [f'Modality {i+1}' for i in range(n_modalities)]
    
    # Sample indices for plotting
    train_indices = np.random.choice(len(train_dataset), min(n_samples, len(train_dataset)), replace=False)
    # NOTE: validation plotting is currently disabled, so these indices are unused.
    # The draw is kept so the global NumPy RNG stream seen by callers is unchanged.
    val_indices = np.random.choice(len(val_dataset), min(n_samples, len(val_dataset)), replace=False)
    
    # Each sample occupies two grid rows. The lower half of the grid is reserved
    # for validation samples and is left empty while validation plotting is disabled.
    fig = plt.figure(figsize=(8*n_modalities, 3*2*n_samples))
    try:
        gs = gridspec.GridSpec(2*2*n_samples, n_modalities, figure=fig, hspace=0.3, wspace=0.3)
        
        with torch.no_grad():
            # Plot training samples
            for i, idx in enumerate(train_indices):
                data, mask = _split_dataset_item(train_dataset[idx])
                # Add batch dimension and move to the model's device
                data_tensors = [d.unsqueeze(0).to(device, non_blocking=True) for d in data]
                
                # Best effort: a failing sample is reported and skipped
                try:
                    reconstructions, _ = model(data_tensors)
                except Exception as e:
                    print(f"Error in model forward pass: {e}")
                    continue
                
                # Plot each modality
                for j, (orig, recon) in enumerate(zip(data, reconstructions)):
                    print(f"Plotting sample {i+1}, modality {j+1}")
                    
                    orig_np = orig.cpu().numpy()
                    recon_np = recon.squeeze(0).cpu().numpy()  # Remove batch dimension
                    
                    # Handle mask if available
                    if mask is not None and j < len(mask) and mask[j] is not None:
                        mask_np = mask[j].cpu().numpy().astype(bool)
                    else:
                        mask_np = None
                    
                    title = f'{modality_names[j]} (Train {i+1})'
                    
                    if orig_np.ndim > 3:
                        # imshow cannot render arrays with more than 3 dimensions
                        print(f"Modality {j+1} has shape {orig_np.shape}; "
                              f"only 3-D arrays can be shown as images, skipping.")
                        continue
                    
                    if orig_np.ndim == 3:
                        print(f"Modality {j+1} has shape {orig_np.shape}, treating as image data.")
                        # Original on the upper row, reconstruction on the lower row
                        ax1 = fig.add_subplot(gs[2*i, j])
                        ax2 = fig.add_subplot(gs[2*i+1, j])
                        _draw_image_pair(ax1, ax2, orig_np, recon_np)
                        ax1.set_title(title)
                        for image_ax in (ax1, ax2):
                            image_ax.set_xticks([])
                            image_ax.set_yticks([])
                        continue
                    
                    # Signal data shares one axis spanning both rows of the sample
                    ax = fig.add_subplot(gs[(2*i):(2*i+2), j])
                    if orig_np.ndim == 1:
                        print(f"Modality {j+1} has shape {orig_np.shape}, treating as 1D signal.")
                    elif orig_np.ndim == 2:
                        print(f"Modality {j+1} has shape {orig_np.shape}, treating as 2D signal.")
                    _draw_signal(ax, orig_np, recon_np, mask_np)
                    
                    ax.set_title(title)
                    ax.grid(True, alpha=0.3)
                    if i == 0:  # Add legend only to first row
                        # Scatter plots carry no labels; skip the legend then
                        handles, _ = ax.get_legend_handles_labels()
                        if handles:
                            ax.legend(fontsize='small')
        
        # Add overall title
        fig.suptitle('NInFEA Multimodal Reconstructions: Original (blue/solid) vs Reconstructed (red/dashed)', 
                    fontsize=14)
        
        # Save the plot
        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving failed
        plt.close(fig)
    
    print(f"Reconstruction plot saved to: {output_path}")

def pretrain_overcomplete_ae_ninfea(train_dataset, val_dataset, latent_dim, device, args, epochs=100, early_stopping=50, 
                         lr=0.001, batch_size=128, ae_depth=2, hidden_dim=512, dropout=0.0, wd=1e-5, 
                         initial_rank_ratio=1.0, min_rank=10, 
                         patience=10, verbose=True, recon_loss_balancing=False, paired=False, conv_depth=3,
                         lr_schedule=None, lr_schedule_step_size=None, lr_schedule_gamma=0.1):
    """
    Train a NInFEA autoencoder for pretraining (no rank reduction).
    
    Parameters:
    - train_dataset: Training dataset with masks
    - val_dataset: Validation dataset with masks
    - latent_dim: Dimension of the latent space
    - epochs: Maximum number of training epochs
    - early_stopping: Number of epochs for early stopping patience
    - lr: Learning rate
    - batch_size: Batch size for training
    - ae_depth: Depth of the autoencoder
    - ae_hidden: Hidden dimension size for autoencoder layers
    - dropout: Dropout rate
    - wd: Weight decay
    - initial_rank_ratio: Initial rank ratio (1.0 = full rank)
    - min_rank: Minimum rank
    - patience: Early stopping patience
    - verbose: Print progress
    - recon_loss_balancing: Adaptive loss balancing across modalities
    - paired: Whether to use paired data splitting
    - conv_depth: Number of convolutional layers for NInFEA architecture
    - lr_schedule: Learning rate schedule type ('linear', 'step', 'cosine', None)
    - lr_schedule_step_size: Step size for StepLR scheduler (epochs)
    - lr_schedule_gamma: Multiplication factor for learning rate decay
    """
    # Declare multi_gpu as global so it can be accessed
    multi_gpu = args.multi_gpu if hasattr(args, 'multi_gpu') else False
    
    # Extract shapes from first sample
    sample_item = train_dataset[0]
    if isinstance(sample_item, tuple):
        sample_data = sample_item[0]
        sample_masks = sample_item[1] if len(sample_item) > 1 else None
    else:
        sample_data = sample_item
        sample_masks = None
    # Handle both 2D and 3D data: 2D -> (seq_length, n_channels), 3D -> (height, width, n_channels)
    input_shapes = []
    for d in sample_data:
        if len(d.shape) == 2:
            # 2D data: (seq_length, n_channels)
            input_shapes.append((d.shape[0], d.shape[1]))
        elif len(d.shape) == 3:
            # 3D data: (height, width, n_channels)
            input_shapes.append((d.shape[0], d.shape[1], d.shape[2]))
        else:
            #raise ValueError(f"Unsupported data shape: {d.shape}. Expected 2D or 3D.")
            input_shapes.append(d.shape)  # Use original shape if not 2D/3D
    
    # Create NInFEA model with adaptive rank reduction
    if isinstance(latent_dim, int):
        latent_dims = [latent_dim] * (len(input_shapes) + 1) # adding one for the shared space
    elif isinstance(latent_dim, list):
        if (len(latent_dim) == 1) & (len(input_shapes) > 1):
            latent_dims = [latent_dim[0]] * (len(input_shapes) + 1)
        else:
            latent_dims = latent_dim
    
    # Use hidden_dim directly (no calculation from ae_width)

    if train_dataset.n_modalities > 2:
        #model = AdaptiveRankReducedAE_NInFEA(
        model = AdaptiveRankReducedAE_MM(
            input_shapes, latent_dims, depth=ae_depth, hidden_dim=hidden_dim, 
            dropout=dropout, initial_rank_ratio=initial_rank_ratio, 
            min_rank=min_rank, conv_depth=conv_depth
        )
    elif train_dataset.n_modalities == 2:
        model = AdaptiveRankReducedAE_NInFEA_2Mods(
            input_shapes, latent_dims, depth=ae_depth, hidden_dim=hidden_dim, 
            dropout=dropout, initial_rank_ratio=initial_rank_ratio, 
            min_rank=min_rank, conv_depth=conv_depth
        )
    else:
        raise ValueError("NInFEA pretraining requires at least 2 modalities.")
    
    # Check model size before moving to GPU
    total_params = sum(p.numel() for p in model.parameters())
    total_size_mb = total_params * 4 / (1024**2)  # Assuming float32
    print(f"Model has {total_params:,} parameters ({total_size_mb:.1f} MB)")
    
    # Check GPU memory before moving model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        free_mem = torch.cuda.get_device_properties(device).total_memory - torch.cuda.memory_allocated(device)
        print(f"GPU free memory before model transfer: {free_mem / (1024**3):.2f} GB")
    
    model = model.to(device)
    
    # Check GPU memory after moving model
    if torch.cuda.is_available():
        allocated_mem = torch.cuda.memory_allocated(device) / (1024**3)
        cached_mem = torch.cuda.memory_reserved(device) / (1024**3)
        print(f"GPU memory allocated after model transfer: {allocated_mem:.2f} GB")
        print(f"GPU memory cached: {cached_mem:.2f} GB")
    
    # Add epoch tracking attribute
    model.epoch = 0
    
    # print the device the model is on
    print(f"Model is on device: {next(model.parameters()).device}")
    
    # Handle multi-GPU setup
    if multi_gpu:
        # Adjust batch size to be divisible by number of GPUs
        if args.gpu_ids:
            num_gpus = len(args.gpu_ids.split(','))
        else:
            num_gpus = torch.cuda.device_count()
            
        # Ensure batch size is divisible by number of GPUs
        if batch_size % num_gpus != 0:
            original_batch_size = batch_size
            batch_size = (batch_size // num_gpus) * num_gpus
            if verbose:
                print(f"Adjusted batch size from {original_batch_size} to {batch_size} to be divisible by {num_gpus} GPUs")
            
        try:
            # If we need cuda:0 but it's not available, disable multi_gpu
            if 0 not in [int(id) for id in args.gpu_ids.split(',')]:
                raise RuntimeError("DataParallel requires cuda:0 which is not available.")
                
            # Ensure model is on cuda:0 for DataParallel
            cuda0_device = torch.device('cuda:0')
            model = model.to(cuda0_device)
            
            # Double-check all parameters are on cuda:0
            for param in model.parameters():
                if param.device != cuda0_device:
                    param.data = param.data.to(cuda0_device)
                    
            # Wrap model with DataParallel - explicitly specify device_ids
            model = nn.DataParallel(model, device_ids=[int(id) for id in args.gpu_ids.split(',')])
            if verbose:
                print(f"Using DataParallel across GPUs: {args.gpu_ids}")
        except Exception as e:
            print(f"Failed to use DataParallel: {e}")
            print(f"Falling back to single GPU mode on {device}")
            multi_gpu = False
            model = model.to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    
    # Create data loaders
    num_workers = getattr(args, 'num_workers', 0)
    
    data_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, pin_memory=False, num_workers=num_workers
    )
    val_data_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, pin_memory=False, num_workers=num_workers
    )
    n_samples_train = len(train_dataset)
    n_samples_val = len(val_dataset)
    
    # Create learning rate scheduler if specified  
    scheduler = None
    if lr_schedule is not None:
        if lr_schedule.lower() == 'linear':
            # Linear decay from initial LR to 0 over total epochs
            scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=1.0, end_factor=0.001, total_iters=1000
            )
        elif lr_schedule.lower() == 'step':
            # Step decay - reduce LR by gamma every step_size epochs
            step_size = lr_schedule_step_size if lr_schedule_step_size is not None else 30
            scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer, step_size=step_size, gamma=lr_schedule_gamma
            )
        elif lr_schedule.lower() == 'cosine':
            # Cosine annealing - smooth decay following cosine curve
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=epochs, eta_min=0.0
            )
        else:
            print(f"Warning: Unknown lr_schedule '{lr_schedule}'. Available options: 'linear', 'step', 'cosine'")
    
    # Initialize loss scaling factors for dynamic loss balancing
    # use loss weights if provided in args
    if hasattr(args, 'loss_weights'):
        loss_scales = args.loss_weights
    else:
        loss_scales = [1.0] * len(sample_data)
    #loss_scales = torch.ones(len(sample_data), device=device)
    loss_history = {f'mod_{i}_loss': [] for i in range(len(sample_data))}  

    
    start_reduction = False
    
    # Train the model
    train_losses = []
    val_losses = []
    best_loss = float('inf') 

    # Initialize loss balancer for reconstruction losses
    if recon_loss_balancing:
        modality_loss_emas = [None] * len(sample_data)
        ema_decay = 0.9

    patience_counter = 0
    pbar = tqdm.tqdm(range(epochs))
    for epoch in pbar:
        # Training phase
        model.train()
        train_loss = 0.0
        val_loss = 0.0
        total_ortho_loss = 0.0
        per_modality_losses = [0.0] * len(sample_data)
        
        for batch_idx, batch_data in enumerate(data_loader):
            # Handle both masked and unmasked data
            if isinstance(batch_data, tuple) and len(batch_data) == 2:
                x, mask = batch_data  # MMSimData format
            else:
                x = batch_data  # PairedMultimodalData format
                mask = None
                

            
            loss = torch.tensor(0.0, device=device)
            total_loss = torch.tensor(0.0, device=device)
            x = [x_m.to(device, non_blocking=True) for x_m in x]
            
            # Forward pass
            x_hat, h_encoded = model(x)  # NInFEA returns hierarchical encoded dict
            
            # No orthogonal loss in pretraining
            ortho_loss = torch.tensor(0.0, device=device)
            total_ortho_loss += ortho_loss.item()

            # Calculate separate losses for each modality
            modality_losses = []
            
            # Extract masks for each modality 
            modality_masks = []
            if mask is not None:
                if isinstance(mask, list):
                    # PaddedMultimodalDataWithMasks format - separate mask per modality
                    modality_masks = [m.to(device, non_blocking=True) for m in mask]
                else:
                    # MMSimData format - single concatenated mask, need to split by modality
                    start_idx = 0
                    for i, x_m in enumerate(x):
                        end_idx = start_idx + x_m.shape[1] * x_m.shape[2]  # seq_length * n_channels
                        modality_masks.append(mask[:, start_idx:end_idx].view(mask.shape[0], x_m.shape[1], x_m.shape[2]))
                        start_idx = end_idx
            else:
                modality_masks = [None] * len(x)
            
            # Calculate per-modality MSE losses
            for i, (x_m, x_hat_m) in enumerate(zip(x, x_hat)):
                # Compute MSE loss for this modality with mask if provided
                if modality_masks[i] is not None:
                    m_loss = F.mse_loss(x_hat_m[modality_masks[i]], x_m[modality_masks[i]])
                else:
                    m_loss = F.mse_loss(x_hat_m, x_m)
                
                # Check for NaN 
                if torch.isnan(m_loss):
                    if verbose:
                        print(f"Warning: NaN loss detected for modality {i}")
                    m_loss = torch.tensor(0.0, device=device)
                
                modality_losses.append(m_loss)
                per_modality_losses[i] += m_loss.item()
            
            # Apply reconstruction loss balancing if enabled
            if recon_loss_balancing:
                # Update exponential moving averages for each modality
                for i, m_loss in enumerate(modality_losses):
                    if modality_loss_emas[i] is None:
                        modality_loss_emas[i] = m_loss.item()
                    else:
                        modality_loss_emas[i] = ema_decay * modality_loss_emas[i] + (1 - ema_decay) * m_loss.item()
                
                # Calculate balanced loss using the minimum EMA as reference
                min_ema = min(ema for ema in modality_loss_emas if ema is not None and ema > 0)
                for i, m_loss in enumerate(modality_losses):
                    if modality_loss_emas[i] > 0:
                        balance_scale = min_ema / modality_loss_emas[i]
                        loss += balance_scale * m_loss
                    else:
                        loss += m_loss
            else:
                # Standard loss computation without balancing
                for i, m_loss in enumerate(modality_losses):
                    loss += loss_scales[i] * m_loss
            
            total_loss += loss
            
            # Backward pass and optimize
            optimizer.zero_grad()
            total_loss.backward()
            
            optimizer.step()
            train_loss += loss.item()
        
        # Average losses
        train_loss /= len(data_loader)
        # Ortho loss is not used in pretraining
        per_modality_losses = [loss / len(data_loader) for loss in per_modality_losses]
        train_losses.append(train_loss)
        
        # Store per-modality losses in history
        for i, loss in enumerate(per_modality_losses):
            loss_history[f'mod_{i}_loss'].append(loss)
        
        # Validation phase with similar safeguards
        with torch.no_grad():
            for val_batch_data in val_data_loader:
                # Handle both masked and unmasked data
                if isinstance(val_batch_data, tuple) and len(val_batch_data) == 2:
                    x_val, mask = val_batch_data
                else:
                    x_val = val_batch_data
                    mask = None
                    
                x_val = [x_m.to(device, non_blocking=True) for x_m in x_val]
                x_val_hat, _ = model(x_val)

                modality_masks = []
                if mask is not None:
                    if isinstance(mask, list):
                        # List of masks from PaddedMultimodalDataWithMasks (one per modality)
                        modality_masks = [m.to(device, non_blocking=True) for m in mask]
                    else:
                        # Concatenated mask from MMSimData (need to split by modality)
                        start_idx = 0
                        for i, x_m in enumerate(x_val):
                            end_idx = start_idx + x_m.shape[1] * x_m.shape[2]  # seq_length * n_channels
                            modality_masks.append(mask[:, start_idx:end_idx].view(mask.shape[0], x_m.shape[1], x_m.shape[2]))
                            start_idx = end_idx
                else:
                    modality_masks = [None] * len(x_val)
                
                # Calculate validation loss
                val_batch_loss = 0.0
                for i, (x_m, x_hat_m) in enumerate(zip(x_val, x_val_hat)):
                    if modality_masks[i] is not None:
                        m_loss = F.mse_loss(x_hat_m[modality_masks[i]], x_m[modality_masks[i]])
                    else:
                        m_loss = F.mse_loss(x_hat_m, x_m)
                    if not torch.isnan(m_loss):
                        val_batch_loss += m_loss.item()
                
                val_loss += val_batch_loss / len(x_val)
                
        val_loss /= len(val_data_loader)
        val_losses.append(val_loss)
        
        # Update progress bar
        log_dict = {
            'loss': round(train_loss, 6),
            'val_loss': round(val_loss, 6),
            'mod_losses': [round(l, 3) for l in per_modality_losses],
            'best_loss': round(best_loss, 6),
        }
        
        # Add current learning rate to progress bar if scheduler is used
        if scheduler is not None:
            current_lr = optimizer.param_groups[0]['lr']
            log_dict['lr'] = f"{current_lr:.2e}"
        if recon_loss_balancing and all(ema is not None for ema in modality_loss_emas):
            min_ema = min(ema for ema in modality_loss_emas if ema > 0)
            balance_scales = [round(min_ema / ema, 3) if ema > 0 else 1.0 for ema in modality_loss_emas]
            log_dict.update({'balance_scales': balance_scales})
        pbar.set_postfix(log_dict)
        
        # Update best loss
        if train_loss < best_loss:
            best_loss = train_loss
            patience_counter = 0  # Reset patience counter
        else:
            patience_counter += 1
        
        # Step the learning rate scheduler if provided
        if scheduler is not None:
            scheduler.step()
        
        if (epoch > early_stopping) & (min(val_losses[-early_stopping:]) > min(val_losses)) & (start_reduction is False):
            if verbose:
                print(f"Early stopping at epoch {epoch} with best loss {best_loss}")
            break
    
    return model, [train_losses, val_losses], None

def _r_squared_scalar(target, prediction):
    """
    Return the coefficient of determination (1 - SS_res / SS_tot) as a Python float.

    When the target has zero variance (SS_tot == 0) the ratio is undefined; in that
    case 1.0 is returned for an exact reconstruction and 0.0 otherwise. The result is
    always a plain float so callers never need to call `.item()` themselves.
    """
    ss_res = ((target - prediction) ** 2).sum()
    ss_tot = ((target - target.mean()) ** 2).sum()
    if ss_tot == 0:
        return 1.0 if ss_res == 0 else 0.0
    return (1 - ss_res / ss_tot).item()


def _masked_sample_r_squared(sample_original, sample_reconstruction, valid_positions):
    """
    Return the R² of a single sample restricted to the positions flagged in `valid_positions`.

    Samples with fewer than two dimensions are scored directly. Otherwise the last
    dimension is treated as the channel axis: R² is computed per channel and averaged
    over the channels that have at least one valid position (0.0 if there are none).
    A mask with the same rank as the sample is sliced per channel; a lower-rank mask
    is applied unchanged to every channel.
    """
    if sample_original.dim() < 2:
        return _r_squared_scalar(sample_original[valid_positions],
                                 sample_reconstruction[valid_positions])

    mask_has_channel_axis = valid_positions.dim() == sample_original.dim()
    channel_scores = []
    for ch in range(sample_original.shape[-1]):
        ch_mask = valid_positions[..., ch] if mask_has_channel_axis else valid_positions
        if ch_mask.sum() == 0:
            # Nothing valid in this channel - leave it out of the average
            continue
        channel_scores.append(
            _r_squared_scalar(sample_original[..., ch][ch_mask],
                              sample_reconstruction[..., ch][ch_mask])
        )

    if not channel_scores:
        return 0.0
    return sum(channel_scores) / len(channel_scores)


def compute_direct_r_squared(model, data, device, masks=None, multi_gpu=False, verbose=False):
    """
    Compute R² based on direct model reconstruction performance with mask support

    The model is switched to eval mode and called once on all modalities; the first
    element of its return value is taken as the list of reconstructions.

    Parameters:
    - model: The trained model; called as model(list_of_tensors)
    - data: Input data list [modality1, modality2, ...]
    - device: Device to run computation on
    - masks: List of masks for each modality (optional). Non-zero entries mark valid
      positions. Entries may be None, and the list may be shorter than `data`;
      modalities without a mask use the unmasked path.
    - multi_gpu: Whether model is wrapped with DataParallel (currently unused here)
    - verbose: If True, print the R² of every modality

    Returns:
    - List of R² values (Python floats) for each modality
      * with a mask: per-sample R² (averaged over channels), averaged over all samples
        that have at least one valid position; 0.0 if no sample qualifies
      * without a mask: a single R² over all finite entries of the flattened modality;
        0.0 if there are no finite entries
    """
    model.eval()
    r_squared_values = []

    with torch.no_grad():
        # Get model predictions
        data_tensors = [d.to(device) for d in data]
        reconstructions, _ = model(data_tensors)

        for i, (original, reconstruction) in enumerate(zip(data_tensors, reconstructions)):
            original_cpu = original.cpu()
            reconstruction_cpu = reconstruction.cpu()

            has_mask = masks is not None and i < len(masks) and masks[i] is not None
            if has_mask:
                mask_i = masks[i].cpu().bool()
                sample_scores = []

                for sample_idx in range(original_cpu.shape[0]):
                    valid_positions = mask_i[sample_idx]
                    if valid_positions.sum() == 0:
                        # No valid positions, skip this sample
                        continue
                    sample_scores.append(
                        _masked_sample_r_squared(original_cpu[sample_idx],
                                                 reconstruction_cpu[sample_idx],
                                                 valid_positions)
                    )

                # Average R² across all valid samples for this modality
                r_squared = sum(sample_scores) / len(sample_scores) if sample_scores else 0.0
            else:
                # No mask - pool every entry of the modality into one R²
                orig_flat = original_cpu.flatten()
                recon_flat = reconstruction_cpu.flatten()

                # Drop positions where either tensor is NaN/Inf
                finite = torch.isfinite(orig_flat) & torch.isfinite(recon_flat)
                if finite.sum() == 0:
                    r_squared = 0.0
                else:
                    r_squared = _r_squared_scalar(orig_flat[finite], recon_flat[finite])

            r_squared_values.append(r_squared)

            if verbose:
                print(f"   Modality {i}: R² = {r_squared:.4f}")

    return r_squared_values

def check_start_reduction_ninfea(model, val_data, device, multi_gpu, patience, epoch,
                                r_square_threshold, threshold_type, val_masks=None, verbose=False):
    """
    Check if we should start rank reduction for NInFEA model.
    Adapted from the multimodal version for hierarchical architecture.
    Handles arbitrary number of modalities (>2) with hierarchical adaptive layers.

    Rank reduction is triggered exactly once, on the epoch equal to
    `model.epoch + patience`. On that epoch the per-modality reconstruction R² on the
    validation data is measured and turned into per-modality minimum R² thresholds:
    - threshold_type == 'relative': threshold = R² * r_square_threshold
    - threshold_type == 'absolute': threshold = R² - r_square_threshold

    Parameters:
    - model: Model exposing `epoch` and `adaptive_layers` (each with `active_dims`);
      `get_total_rank()` and `adaptive_layer_map` are used when present
    - val_data: Dataset whose `.data` is a list with one tensor per modality
    - device: Device the validation tensors are moved to
    - multi_gpu: Forwarded to compute_direct_r_squared
    - patience: Number of epochs after `model.epoch` at which reduction starts
    - epoch: Current epoch
    - r_square_threshold: Factor ('relative') or offset ('absolute') for the thresholds
    - threshold_type: 'relative' or 'absolute'
    - val_masks: Optional per-modality masks forwarded to compute_direct_r_squared
    - verbose: If True, print a summary of the architecture and thresholds

    Returns:
        tuple: (start_reduction, rank_history, initial_squares, min_rsquares,
                current_rsquare_per_mod, 0)
        On every other epoch: (False, None, None, None, None, 0).
        The trailing 0 is returned unconditionally (presumably a counter reset for
        the caller - confirm against the call site).

    Raises:
        ValueError: on the trigger epoch, if threshold_type is neither 'relative'
        nor 'absolute'.
    """
    if epoch != model.epoch + patience:  # giving the optimizer a start
        return False, None, None, None, None, 0

    # Fail fast: validate before running the model on the whole validation set
    if threshold_type not in ('relative', 'absolute'):
        raise ValueError(f"threshold_type must be 'relative' or 'absolute', got {threshold_type}")

    # Get the number of modalities dynamically
    n_modalities = len(val_data.data)

    # Initialize rank history with hierarchical layer structure
    rank_history = {
        'total_rank': [model.get_total_rank() if hasattr(model, 'get_total_rank') else 0],
        'ranks': [', '.join(str(layer.active_dims) for layer in model.adaptive_layers)],
        'epoch': [model.epoch],
        'loss': [0.0],  # Will be filled by caller
        'val_loss': [0.0]  # Will be filled by caller
    }

    # compute_direct_r_squared performs the (single) forward pass itself; the extra
    # forward pass that only produced unused encodings has been dropped.
    val_data_list = [val_data.data[i].to(device) for i in range(n_modalities)]

    # Direct reconstruction R² approach - this gives per-modality R² values
    direct_r_squared_values = compute_direct_r_squared(model, val_data_list, device, val_masks, multi_gpu)

    # Ensure we have R² values for all modalities
    if len(direct_r_squared_values) != n_modalities:
        if verbose:
            print(f"Warning: Expected {n_modalities} R² values, got {len(direct_r_squared_values)}")
        # Pad with zeros or truncate as needed
        padded = list(direct_r_squared_values) + [0.0] * n_modalities
        direct_r_squared_values = padded[:n_modalities]

    # Per-modality tracking; separate list objects so the caller can mutate them independently
    initial_squares = list(direct_r_squared_values)
    current_rsquare_per_mod = list(direct_r_squared_values)
    if threshold_type == 'relative':
        min_rsquares = [r * r_square_threshold for r in direct_r_squared_values]
    else:
        min_rsquares = [r - r_square_threshold for r in direct_r_squared_values]
    for i, r_squared_val in enumerate(direct_r_squared_values):
        rank_history[f'rsquare {i}'] = [r_squared_val]

    # Add hierarchical layer information to rank history
    has_layer_map = hasattr(model, 'adaptive_layer_map')
    if has_layer_map:
        rank_history['layer_names'] = list(model.adaptive_layer_map.keys())
        rank_history['n_modalities'] = n_modalities
    rank_history['hierarchical_structure'] = has_layer_map

    if verbose:
        print("NInFEA Hierarchical Architecture:")
        print(f"  - Number of modalities: {n_modalities}")
        print(f"  - Number of adaptive layers: {len(model.adaptive_layers)}")
        if has_layer_map:
            print(f"  - Layer structure: {list(model.adaptive_layer_map.keys())}")
        print(f"  - Initial R-squared values: {direct_r_squared_values}")
        print(f"  - Setting {threshold_type} thresholds to {min_rsquares}")

    print(f"Initial R-squared values for {n_modalities} modalities: {direct_r_squared_values}, setting {threshold_type} thresholds to {min_rsquares}")

    # start_reduction is True: rank reduction begins from this epoch on
    return True, rank_history, initial_squares, min_rsquares, current_rsquare_per_mod, 0

def decide_rank_reduction_ninfea(model, val_data, device, multi_gpu, epoch, rank_schedule, 
                               start_reduction, break_counter, reduce_on_best_loss, min_rsquares,
                               r_squares, current_rsquare_per_mod, patience_counter, patience,
                               min_ranks, compressibility_type, reduction_criterion, r_square_threshold,
                               rank_reduction_threshold, sharedwhenall, val_masks=None, compute_jacobian=False, verbose=False):
    """
    Decide which ranks to reduce/increase for NInFEA hierarchical model.
    Adapted from the multimodal version.
    
    Returns:
        tuple: (any_changes_made, updated_rank_history_entry, updated_r_squares, updated_current_rsquare_per_mod)
    """
    if (epoch in rank_schedule) & (start_reduction) & (break_counter == 0):
        if (reduce_on_best_loss == 'rsquare') & (start_reduction):
            ###
            # get the r_square values per modality - handle arbitrary number of modalities
            ###
            n_modalities = len(val_data.data)
            
            with torch.no_grad():
                val_data_list = [val_data.data[i].to(device) for i in range(n_modalities)]
                _, h_encoded = model(val_data_list)  # NInFEA hierarchical encoding
                
            current_rsquares = []
            modalities_to_reduce = []
            modalities_to_increase = []
            
            # Direct reconstruction R² approach
            direct_r_squared_values = compute_direct_r_squared(model, val_data_list, device, val_masks, multi_gpu)
            
            # Ensure consistency in array lengths
            if len(direct_r_squared_values) != n_modalities:
                if verbose:
                    print(f"Warning: Expected {n_modalities} R² values, got {len(direct_r_squared_values)}")
                # Pad with previous values or zeros as needed
                while len(direct_r_squared_values) < n_modalities:
                    direct_r_squared_values.append(0.0)
                direct_r_squared_values = direct_r_squared_values[:n_modalities]
            
            # Ensure current_rsquare_per_mod has correct length
            if len(current_rsquare_per_mod) != n_modalities:
                current_rsquare_per_mod = [0.0] * n_modalities
            
            for i, r_squared_val in enumerate(direct_r_squared_values):
                current_rsquares.append(r_squared_val)
                if i < len(current_rsquare_per_mod):
                    current_rsquare_per_mod[i] = r_squared_val
                    
            r_squares.append(current_rsquares)

            # if the current rsquare is larger than the initial one, reset the min_rsquare
            for i in range(n_modalities):
                mod_rsquares = [r[i] for r in r_squares]
                max_rsquare = max(mod_rsquares)
                min_rsquares[i] = max_rsquare - r_square_threshold  # reset to relative threshold of max observed

            ###
            # determine what modalities to reduce or increase
            ###
            bottom_reached = model.bottom_reached if hasattr(model, 'bottom_reached') else False
            
            if (len(r_squares) >= min(10, int(patience/2))) and patience_counter >= min(10, int(patience/2)):
                for i in range(len(current_rsquare_per_mod)):
                    i_rsquares = [r[i] for r in r_squares[-min(10, int(patience/2)):]]
                    
                    # Handle different comparison logic for loss vs R²
                    if compressibility_type == 'direct' and reduction_criterion in ['train_loss', 'val_loss']:
                        # For loss: lower is better, so we reduce if loss is high (above threshold)
                        if all(r > min_rsquares[i] for r in i_rsquares) and not bottom_reached:
                            modalities_to_increase.append(i)
                        elif current_rsquare_per_mod[i] < min_rsquares[i]:  # loss below threshold
                            modalities_to_reduce.append(i)
                    else:
                        # For R²: higher is better (original logic)
                        if all(r < min_rsquares[i] for r in i_rsquares) and not bottom_reached:
                            modalities_to_increase.append(i)
                        elif current_rsquare_per_mod[i] > min_rsquares[i]:
                            modalities_to_reduce.append(i)
            elif (len(r_squares) >= 1):
                for i in range(len(current_rsquare_per_mod)):
                    if compressibility_type == 'direct' and reduction_criterion in ['train_loss', 'val_loss']:
                        # For loss: reduce if loss is below threshold (good performance)
                        if current_rsquare_per_mod[i] < min_rsquares[i]:
                            modalities_to_reduce.append(i)
                        elif current_rsquare_per_mod[i] > min_rsquares[i] and not bottom_reached:
                            modalities_to_increase.append(i)
                    else:
                        # For R²: reduce if R² is above threshold (original logic)
                        if current_rsquare_per_mod[i] > min_rsquares[i]:
                            modalities_to_reduce.append(i)
                        elif current_rsquare_per_mod[i] < min_rsquares[i] and not bottom_reached:
                            modalities_to_increase.append(i)

            # if all modalities can be reduced, we set min and max ranks
            if len(modalities_to_reduce) == len(current_rsquare_per_mod):
                current_ranks = [layer.active_dims for layer in model.adaptive_layers]
                for i, cr in enumerate(current_ranks):
                    if cr <= min_ranks[i]:
                        min_ranks[i] = cr
                        # if we are still above the thresholds, the ranks maximum should not be larger than the sum of current ranks
                        model.adaptive_layers[i].max_rank = min(sum(current_ranks), model.adaptive_layers[i].max_rank)
            if len(modalities_to_increase) == len(current_rsquare_per_mod):
                # set minima
                print("All modalities need increase, setting minimum ranks to current ranks")
                current_ranks = [layer.active_dims for layer in model.adaptive_layers]
                for i, cr in enumerate(current_ranks):
                    if cr <= min_ranks[i]:
                        min_ranks[i] = cr
                    # if we are increasing all ranks, we can also increase the maximum ranks
                    model.adaptive_layers[i].min_rank = min_ranks[i]
                model.bottom_reached = True
            
            ###
            # set the patience counters and layers to reduce or increase
            # Handle hierarchical NInFEA layer mapping properly
            ###
            layers_to_reduce = []
            layers_to_increase = []
            
            # Get the hierarchical layer structure for proper mapping
            if hasattr(model, 'adaptive_layer_map'):
                layer_names = list(model.adaptive_layer_map.keys())
                n_total_layers = len(model.adaptive_layers)
                
                # Find indices for global_shared, shared subspaces, and modality-specific layers
                global_shared_idx = None
                specific_layer_indices = []
                shared_subspace_indices = {}  # Maps modality combinations to layer indices
                
                for idx, layer_name in enumerate(layer_names):
                    if layer_name == 'global_shared':
                        global_shared_idx = idx
                    elif layer_name.startswith('specific_'):
                        modality_num = int(layer_name.split('_')[1])
                        specific_layer_indices.append((idx, modality_num))
                    elif layer_name.startswith('shared_'):
                        # Parse shared subspace name like 'shared_0_1' or 'shared_0_1_2'
                        modality_combo = tuple(int(x) for x in layer_name.split('_')[1:])
                        shared_subspace_indices[modality_combo] = idx
            else:
                # Fallback to simple mapping if hierarchical structure not available
                global_shared_idx = 0
                specific_layer_indices = [(i + 1, i) for i in range(n_modalities)]
                shared_subspace_indices = {}
            
            # Helper function to find shared subspaces that involve specific modalities
            def get_shared_subspaces_for_modalities(modality_set, shared_subspace_indices):
                """Find all shared subspaces that are subsets of the given modality set"""
                relevant_subspaces = []
                for combo, layer_idx in shared_subspace_indices.items():
                    if set(combo).issubset(set(modality_set)):
                        relevant_subspaces.append((combo, layer_idx))
                return relevant_subspaces
            
            if (len(modalities_to_reduce) == 0) and (len(modalities_to_increase) == 0):
                pass
            elif (len(modalities_to_reduce) > 0) and (len(modalities_to_increase) > 0):
                # Mixed case: reduce specific layers for modalities that can be reduced

                # should I allow the shared to be reduced???
                layers_to_reduce.append(global_shared_idx)
                # Also consider shared subspaces
                
                # Add specific layers for modalities to reduce
                for mod_idx in modalities_to_reduce:
                    for layer_idx, mod_num in specific_layer_indices:
                        if mod_num == mod_idx:
                            layers_to_reduce.append(layer_idx)
                            break
                
                # Add shared subspaces that only involve modalities to be reduced
                shared_to_reduce = get_shared_subspaces_for_modalities(modalities_to_reduce, shared_subspace_indices)
                for combo, layer_idx in shared_to_reduce:
                    # Only reduce shared subspace if ALL its constituent modalities are in modalities_to_reduce
                    if all(mod in modalities_to_reduce for mod in combo):
                        layers_to_reduce.append(layer_idx)
                
                # Add specific layers for modalities to increase
                for mod_idx in modalities_to_increase:
                    for layer_idx, mod_num in specific_layer_indices:
                        if mod_num == mod_idx:
                            layers_to_increase.append(layer_idx)
                            break
                            
                # Add shared subspaces that involve modalities to be increased
                shared_to_increase = get_shared_subspaces_for_modalities(modalities_to_increase, shared_subspace_indices)
                for combo, layer_idx in shared_to_increase:
                    # Increase shared subspace if ANY of its constituent modalities need increase
                    if any(mod in modalities_to_increase for mod in combo):
                        layers_to_increase.append(layer_idx)
            else:
                if len(modalities_to_increase) > 0:
                    if len(modalities_to_increase) == len(current_rsquare_per_mod):
                        # if all modalities are below the threshold, increase ranks of all layers
                        layers_to_increase = list(range(len(model.adaptive_layers)))
                        # start with the shared only
                        #layers_to_increase.append(global_shared_idx)
                    else:
                        # Increase specific layers for underperforming modalities
                        layers_to_increase.append(global_shared_idx)
                        for mod_idx in modalities_to_increase:
                            for layer_idx, mod_num in specific_layer_indices:
                                if mod_num == mod_idx:
                                    layers_to_increase.append(layer_idx)
                                    break
                        
                        # Increase shared subspaces that involve underperforming modalities
                        shared_to_increase = get_shared_subspaces_for_modalities(modalities_to_increase, shared_subspace_indices)
                        for combo, layer_idx in shared_to_increase:
                            # Increase shared subspace if ANY of its constituent modalities need increase
                            if any(mod in modalities_to_increase for mod in combo):
                                layers_to_increase.append(layer_idx)
                
                if len(modalities_to_reduce) > 0:
                    # if all modalities can be reduced, consider reducing shared layers too
                    if len(modalities_to_reduce) == len(current_rsquare_per_mod):
                        reduce_shared = True
                        if reduce_shared and global_shared_idx is not None:
                            layers_to_reduce.append(global_shared_idx)
                        
                        # Reduce all shared subspaces since all modalities can be reduced
                        for combo, layer_idx in shared_subspace_indices.items():
                            layers_to_reduce.append(layer_idx)
                        
                        # Also reduce specific layers
                        for mod_idx in modalities_to_reduce:
                            for layer_idx, mod_num in specific_layer_indices:
                                if mod_num == mod_idx:
                                    layers_to_reduce.append(layer_idx)
                                    break
                    else:
                        # Partial reduction: decide whether to include shared layers
                        if sharedwhenall:
                            reduce_shared = False
                        else:
                            reduce_shared = True
                        
                        if reduce_shared and global_shared_idx is not None:
                            layers_to_reduce.append(global_shared_idx)
                        
                        # Reduce shared subspaces that only involve modalities to be reduced
                        shared_to_reduce = get_shared_subspaces_for_modalities(modalities_to_reduce, shared_subspace_indices)
                        for combo, layer_idx in shared_to_reduce:
                            # Only reduce shared subspace if ALL its constituent modalities are in modalities_to_reduce
                            if all(mod in modalities_to_reduce for mod in combo):
                                layers_to_reduce.append(layer_idx)
                        
                        # Reduce specific layers for modalities that can be reduced
                        for mod_idx in modalities_to_reduce:
                            for layer_idx, mod_num in specific_layer_indices:
                                if mod_num == mod_idx:
                                    layers_to_reduce.append(layer_idx)
                                    break
            
            # Remove duplicates from layer lists
            layers_to_reduce = list(set(layers_to_reduce))
            layers_to_increase = list(set(layers_to_increase))
            
            if verbose:
                print(f"NInFEA Hierarchical Rank Reduction Decision:")
                if compressibility_type == 'direct' and reduction_criterion in ['train_loss', 'val_loss']:
                    print(f"  - {reduction_criterion} values ({n_modalities} modalities): {current_rsquares}")
                else:
                    print(f"  - R-squared values ({n_modalities} modalities): {current_rsquares}")
                print(f"  - Modalities to reduce: {modalities_to_reduce}")
                print(f"  - Modalities to increase: {modalities_to_increase}")
                
                if hasattr(model, 'adaptive_layer_map'):
                    layer_names = list(model.adaptive_layer_map.keys())
                    print(f"  - Available layers: {dict(enumerate(layer_names))}")
                    
                    # Categorize layers being modified
                    global_layers_reduce = []
                    shared_layers_reduce = []
                    specific_layers_reduce = []
                    global_layers_increase = []
                    shared_layers_increase = []
                    specific_layers_increase = []
                    
                    for layer_idx in layers_to_reduce:
                        layer_name = layer_names[layer_idx] if layer_idx < len(layer_names) else f"layer_{layer_idx}"
                        if layer_name == 'global_shared':
                            global_layers_reduce.append(layer_name)
                        elif layer_name.startswith('shared_'):
                            shared_layers_reduce.append(layer_name)
                        elif layer_name.startswith('specific_'):
                            specific_layers_reduce.append(layer_name)
                    
                    for layer_idx in layers_to_increase:
                        layer_name = layer_names[layer_idx] if layer_idx < len(layer_names) else f"layer_{layer_idx}"
                        if layer_name == 'global_shared':
                            global_layers_increase.append(layer_name)
                        elif layer_name.startswith('shared_'):
                            shared_layers_increase.append(layer_name)
                        elif layer_name.startswith('specific_'):
                            specific_layers_increase.append(layer_name)
                    
                    if layers_to_reduce:
                        print(f"  - REDUCING layers:")
                        if global_layers_reduce:
                            print(f"    * Global shared: {global_layers_reduce}")
                        if shared_layers_reduce:
                            print(f"    * Shared subspaces: {shared_layers_reduce}")
                        if specific_layers_reduce:
                            print(f"    * Modality-specific: {specific_layers_reduce}")
                    
                    if layers_to_increase:
                        print(f"  - INCREASING layers:")
                        if global_layers_increase:
                            print(f"    * Global shared: {global_layers_increase}")
                        if shared_layers_increase:
                            print(f"    * Shared subspaces: {shared_layers_increase}")
                        if specific_layers_increase:
                            print(f"    * Modality-specific: {specific_layers_increase}")
                else:
                    print(f"  - Layers to reduce: {layers_to_reduce}")
                    print(f"  - Layers to increase: {layers_to_increase}")
        
        # Apply rank changes
        any_changes_made = False
        if len(layers_to_reduce) > 0:
            # Apply rank reduction
            if multi_gpu:
                changes_made = model.module.reduce_rank(reduction_ratio=0.9, threshold=rank_reduction_threshold, layer_ids=layers_to_reduce)
            else:
                changes_made = model.reduce_rank(reduction_ratio=0.9, threshold=rank_reduction_threshold, layer_ids=layers_to_reduce)
            if changes_made:
                any_changes_made = True
        
        if len(layers_to_increase) > 0:
            # Apply rank increase
            if multi_gpu:
                changes_made = model.module.increase_rank(increase_ratio=1.1, layer_ids=layers_to_increase)
            else:
                changes_made = model.increase_rank(increase_ratio=1.1, layer_ids=layers_to_increase)
            if changes_made:
                any_changes_made = True
            break_counter = patience
        
        # Prepare rank history update
        total_rank_after = model.module.get_total_rank() if multi_gpu else model.get_total_rank()
        rank_history_entry = {
            'total_rank': total_rank_after,
            'ranks': ', '.join(str(layer.active_dims) for layer in model.adaptive_layers),
            'epoch': epoch,
            'rsquares': current_rsquares
        }
        
        return any_changes_made, rank_history_entry, r_squares, current_rsquare_per_mod, break_counter
    
    return False, None, None, None, break_counter

def train_overcomplete_ae_with_pretrained(train_dataset, val_dataset, latent_dim, device, args, epochs=100, early_stopping=50, 
                         lr=0.001, batch_size=128, ae_depth=2, hidden_dim=512, dropout=0.0, wd=1e-5, 
                         initial_rank_ratio=1.0, min_rank=10, 
                         rank_schedule=None, rank_reduction_frequency=10, 
                         rank_reduction_threshold=0.01, warmup_epochs=0,
                         patience=10, reduce_on_best_loss='rsquare', r_square_threshold=0.9,
                         threshold_type='relative', compressibility_type='linear', reduction_criterion='r_squared',
                         include_l1=False, l1_weight=0.0, include_ortholoss=False,
                         l1_start_weight=0.0, l1_step_size=1.0, rank_or_sparse='rank',
                         verbose=True, compute_jacobian=False, model_name=None, pretrained_name=None,
                         recon_loss_balancing=False, ortho_loss_balancing=False,
                         ortho_loss_start_weight=0.0, ortho_loss_end_weight=1.0, ortho_loss_anneal_epochs=None, ortho_loss_warmup=None,
                         l2_norm_adaptivelayers=None, sharedwhenall=True, paired=False, conv_depth=3,
                         lr_schedule=None, lr_schedule_step_size=None, lr_schedule_gamma=0.1
                         ):
    """
    Train a NInFEA autoencoder with adaptive rank reduction using pretrained weights
    
    Args:
        train_dataset: Training dataset (can be raw tensors or dataset object with masks)
        val_dataset: Validation dataset (can be raw tensors or dataset object with masks) 
        latent_dim: Latent dimension size
        device: Device to train on
        args: Training arguments object
        pretrained_name: Name of the pretrained model to load
        conv_depth: Convolutional layer depth for NInFEA
    """
    import os
    
    # Get sample to determine input shapes
    #sample_data, sample_masks = train_dataset[0]
    sample_item = train_dataset[0]
    if isinstance(sample_item, tuple):
        sample_data = sample_item[0]
        sample_masks = sample_item[1] if len(sample_item) > 1 else None
    else:
        sample_data = sample_item
        sample_masks = None
    # Handle both 2D and 3D data: 2D -> (seq_length, n_channels), 3D -> (height, width, n_channels)
    input_shapes = []
    for d in sample_data:
        if len(d.shape) == 2:
            # 2D data: (seq_length, n_channels)
            input_shapes.append((d.shape[0], d.shape[1]))
        elif len(d.shape) == 3:
            # 3D data: (height, width, n_channels)
            input_shapes.append((d.shape[0], d.shape[1], d.shape[2]))
        else:
            #raise ValueError(f"Unsupported data shape: {d.shape}. Expected 2D or 3D.")
            input_shapes.append(d.shape)  # Keep original shape if not 2D/3D
    
    # Check for existing pretrained model
    pretrained_model_path = f"./03_results/models/pretrained_models/{pretrained_name}.pt" if pretrained_name else None
    if pretrained_model_path and os.path.exists(pretrained_model_path):
        print(f"Found existing pretrained model at {pretrained_model_path}. Loading...")
        if isinstance(latent_dim, int):
            latent_dims = [latent_dim] * (len(input_shapes) + 1)
        elif isinstance(latent_dim, list):
            if (len(latent_dim) == 1) & (len(input_shapes) > 1):
                latent_dims = [latent_dim[0]] * (len(input_shapes) + 1)
            else:
                latent_dims = latent_dim
        if train_dataset.n_modalities > 2:
            #model = AdaptiveRankReducedAE_NInFEA(
            model = AdaptiveRankReducedAE_MM(
                input_shapes, latent_dims, depth=ae_depth, hidden_dim=hidden_dim, 
                dropout=dropout, initial_rank_ratio=initial_rank_ratio, 
                min_rank=min_rank, conv_depth=conv_depth
            )
        elif train_dataset.n_modalities == 2:
            model = AdaptiveRankReducedAE_NInFEA_2Mods(
                input_shapes, latent_dims, depth=ae_depth, hidden_dim=hidden_dim, 
                dropout=dropout, initial_rank_ratio=initial_rank_ratio, 
                min_rank=min_rank, conv_depth=conv_depth
            )
        else:
            raise ValueError("NInFEA pretraining requires at least 2 modalities.")    
        model.load_state_dict(torch.load(pretrained_model_path, weights_only=False))
        model.eval()
        for param in model.parameters():
            param.requires_grad = True
        print(f"Loaded pretrained model from {pretrained_model_path}")
        
        # Load loss curves
        loss_curve_path = pretrained_model_path.replace('.pt', '_loss_curve.npy')
        train_val_losses = np.load(loss_curve_path, allow_pickle=True)
        train_losses = train_val_losses[0].tolist()
        val_losses = train_val_losses[1].tolist()
        print(f"Loaded loss curves from {loss_curve_path}")
        print(f"Last training loss: {train_losses[-1]}, last validation loss: {val_losses[-1]}")
        model.epoch = len(train_losses)
        
        # Generate reconstruction plots if they don't exist yet
        plot_path = pretrained_model_path.replace('.pt', '_reconstructions.png')
        if not os.path.exists(plot_path):
            try:
                modality_names = getattr(args, 'modality_names', None)
                plot_ninfea_reconstructions(model, train_dataset, val_dataset, device, plot_path, 
                                          modality_names=modality_names)
            except Exception as e:
                print(f"Warning: Failed to generate reconstruction plot: {e}")
    else:
        if pretrained_model_path:
            print("No pretrained model found. Training from scratch.")
            
            # Run pretraining with datasets
            model, [train_losses, val_losses], data_indices = pretrain_overcomplete_ae_ninfea(
                train_dataset, val_dataset, latent_dim, device, args, epochs=epochs, early_stopping=early_stopping,
                lr=lr, batch_size=batch_size, ae_depth=ae_depth, hidden_dim=hidden_dim, dropout=dropout, wd=wd,
                initial_rank_ratio=initial_rank_ratio, min_rank=min_rank,
                verbose=verbose, conv_depth=conv_depth, paired=True
            )
            # Save the pretrained model and loss curves
            os.makedirs(os.path.dirname(pretrained_model_path), exist_ok=True)
            torch.save(model.state_dict(), pretrained_model_path)
            # Also save loss curves
            loss_curve_path = pretrained_model_path.replace('.pt', '_loss_curve.npy')
            np.save(loss_curve_path, np.array([train_losses, val_losses]))
            # also save data_indices
            if data_indices is not None:
                data_indices_path = pretrained_model_path.replace('.pt', '_data_indices.pt')
                torch.save(data_indices, data_indices_path)
            print(f"Saved pretrained model to {pretrained_model_path} and loss curves to {loss_curve_path}")
            model.epoch = len(train_losses)
            
            # Generate reconstruction plots after pretraining
            plot_path = pretrained_model_path.replace('.pt', '_reconstructions.png')
            try:
                modality_names = getattr(args, 'modality_names', None)
                plot_ninfea_reconstructions(model, train_dataset, val_dataset, device, plot_path, 
                                          modality_names=modality_names)
            except Exception as e:
                print(f"Warning: Failed to generate reconstruction plot: {e}")
        else:
            raise ValueError("pretrained_name must be provided to save/load pretrained models.")
    
    # Continue with rank reduction training adapted for NInFEA
    model.to(device)
    print(f"Model is on device: {next(model.parameters()).device}")
    
    # Handle multi-GPU setup
    multi_gpu = args.multi_gpu if hasattr(args, 'multi_gpu') else False
    if multi_gpu:
        # Adjust batch size to be divisible by number of GPUs
        if args.gpu_ids:
            num_gpus = len(args.gpu_ids.split(','))
        else:
            num_gpus = torch.cuda.device_count()
            
        # Ensure batch size is divisible by number of GPUs
        if batch_size % num_gpus != 0:
            original_batch_size = batch_size
            batch_size = (batch_size // num_gpus) * num_gpus
            if verbose:
                print(f"Adjusted batch size from {original_batch_size} to {batch_size} to be divisible by {num_gpus} GPUs")
            
        try:
            # If we need cuda:0 but it's not available, disable multi_gpu
            if 0 not in [int(id) for id in args.gpu_ids.split(',')]:
                raise RuntimeError("DataParallel requires cuda:0 which is not available.")
                
            # Ensure model is on cuda:0 for DataParallel
            cuda0_device = torch.device('cuda:0')
            model = model.to(cuda0_device)
            
            # Double-check all parameters are on cuda:0
            for param in model.parameters():
                if param.device != cuda0_device:
                    param.data = param.data.to(cuda0_device)
                    
            # Wrap model with DataParallel - explicitly specify device_ids
            model = nn.DataParallel(model, device_ids=[int(id) for id in args.gpu_ids.split(',')])
            if verbose:
                print(f"Using DataParallel across GPUs: {args.gpu_ids}")
        except Exception as e:
            print(f"Failed to use DataParallel: {e}")
            print(f"Falling back to single GPU mode on {device}")
            multi_gpu = False
            model = model.to(device)
    
    # if lr_schedule is provided, use smaller lr
    if lr_schedule is not None:
        lr = 0.1 * lr
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    
    # Create data loaders
    num_workers = getattr(args, 'num_workers', 0)
    
    # We already have train and validation datasets with masks
    data_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, pin_memory=False, num_workers=num_workers
    )
    val_data_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, pin_memory=False, num_workers=num_workers
    )
    n_samples = len(train_dataset) + len(val_dataset)
    n_samples_train = len(train_dataset)
    n_samples_val = len(val_dataset)
    
    # For compatibility, we need to set 'data' for later use in the function
    # Get a sample to understand data structure
    sample_item = train_dataset[0]
    if isinstance(sample_item, tuple):
        sample_data = sample_item[0]
    else:
        sample_data = sample_item
    #sample_data, _ = train_dataset[0]
    # Create a dummy data list for compatibility (this won't be used for actual training)
    data = [torch.zeros(n_samples, *tensor.shape) for tensor in sample_data]
    
    # Create val_data object for rank reduction functions
    # These functions expect val_data.data to contain validation tensors
    class ValDataWrapper:
        """Minimal container that exposes a collection of tensors via a ``data`` attribute."""
        def __init__(self, data_tensors):
            # Stored by reference (no copy); consumers read it back through ``.data``.
            self.data = data_tensors
    
    # Extract validation data and masks for R² computation
    def extract_validation_data_and_masks():
        """Stack a fixed subset of samples and masks, one tensor per modality.

        NOTE(review): despite the name, the samples are read from the first
        ``int(0.2 * len(train_dataset))`` items of ``train_dataset``, not from
        ``val_dataset`` -- confirm this is intended.

        Every dataset item is fetched exactly once and then split across
        modalities, so all modalities of a stacked row come from the same
        ``__getitem__`` call. An item may be either ``(sample, mask)`` (a
        tuple) or a bare ``sample``; ``sample`` and ``mask`` are indexed by
        modality. When no mask is available for a modality, an all-True
        boolean mask shaped like that modality's sample is used.

        Returns:
            tuple: ``(val_data_tensors, val_mask_tensors)``, two lists with one
            entry per modality in ``sample_data``; each entry is the
            ``torch.stack`` of the collected per-sample tensors.

        Raises:
            RuntimeError: propagated from ``torch.stack`` when the subset is
                empty (fewer than 5 items in ``train_dataset``).
        """
        n_modalities = len(sample_data)
        n_subset = int(0.2 * len(train_dataset))

        # One accumulator list per modality, filled in sample order.
        samples_per_mod = [[] for _ in range(n_modalities)]
        masks_per_mod = [[] for _ in range(n_modalities)]

        for sample_idx in range(n_subset):
            # Fetch once per sample: __getitem__ may be expensive (I/O, padding).
            item = train_dataset[sample_idx]
            if isinstance(item, tuple):
                sample = item[0]
                mask = item[1] if len(item) > 1 else None
            else:
                sample = item
                mask = None

            for mod_idx in range(n_modalities):
                mod_sample = sample[mod_idx]
                samples_per_mod[mod_idx].append(mod_sample)
                if mask is not None and mod_idx < len(mask):
                    masks_per_mod[mod_idx].append(mask[mod_idx])
                else:
                    # No mask for this modality: treat every element as valid.
                    masks_per_mod[mod_idx].append(torch.ones_like(mod_sample, dtype=torch.bool))

        val_data_tensors = [torch.stack(mod_samples) for mod_samples in samples_per_mod]
        val_mask_tensors = [torch.stack(mod_masks) for mod_masks in masks_per_mod]

        return val_data_tensors, val_mask_tensors
    
    val_data_tensors, val_mask_tensors = extract_validation_data_and_masks()
    val_data = ValDataWrapper(val_data_tensors)
    
    # Default rank reduction schedule if none provided
    if rank_schedule is None:
        # Reduce rank every rank_reduction_frequency epochs, but start after warmup period
        rank_schedule = list(range(warmup_epochs + rank_reduction_frequency, 
                                 epochs, 
                                 rank_reduction_frequency))
    initial_squares = [None] * len(data) # per modality
    initial_losses = [None] * len(data) # per modality (for loss-based criteria)
    start_reduction = False
    post_training = False
    current_rsquare_per_mod = [None] * len(data)
    current_loss_per_mod = [None] * len(data)  # for loss-based criteria
    bottom_reached = False
    space_sims = None
    break_counter = 0
    
    # Train the model
    r_squares = []
    min_ranks = [layer.active_dims for layer in model.adaptive_layers]
    best_loss = float('inf')
    
    # Initialize loss scaling factors for dynamic loss balancing
    #loss_scales = torch.ones(len(data), device=device)
    if hasattr(args, 'loss_weights'):
        loss_scales = args.loss_weights
    else:
        loss_scales = [1.0] * len(sample_data)
    loss_history = {f'mod_{i}_loss': [] for i in range(len(data))}
    
    # Initialize loss balancer for reconstruction losses
    if recon_loss_balancing:
        modality_loss_emas = [None] * len(data)
        ema_decay = 0.9

    patience_counter = 0
    pbar = tqdm.tqdm(range(model.epoch, epochs))
    for epoch in pbar:
        # Training phase
        model.train()
        train_loss = 0.0
        val_loss = 0.0
        total_ortho_loss = 0.0
        per_modality_losses = [0.0] * len(data)
        
        for batch_idx, batch_data in enumerate(data_loader):
            # Handle both masked and unmasked data flexibly
            if isinstance(batch_data, tuple) and len(batch_data) == 2:
                x, mask = batch_data  # Dataset with masks (PaddedMultimodalDataWithMasks, MMSimData)
            else:
                x = batch_data  # Dataset without masks (PairedMultimodalData)
                mask = None
            
            loss = torch.tensor(0.0, device=device)
            total_loss = torch.tensor(0.0, device=device)
            x = [x_m.to(device, non_blocking=True) for x_m in x]
            
            # Forward pass
            x_hat, h_encoded = model(x)  # NInFEA returns hierarchical encoded dict
            
            # No orthogonal loss in this version
            ortho_loss = torch.tensor(0.0, device=device)
            total_ortho_loss += ortho_loss.item()

            # Calculate separate losses for each modality
            modality_losses = []
            
            # Extract masks for each modality (adapted for NInFEA 3D tensors)
            modality_masks = []
            if mask is not None:
                if isinstance(mask, list):
                    # PaddedMultimodalDataWithMasks returns a list of masks (one per modality)
                    modality_masks = [m.to(device) if m is not None else None for m in mask]
                else:
                    # Legacy: mask is a single concatenated tensor
                    start_idx = 0
                    for i, x_m in enumerate(x):
                        end_idx = start_idx + x_m.shape[1] * x_m.shape[2]  # seq_length * n_channels
                        modality_masks.append(mask[:, start_idx:end_idx].view(mask.shape[0], x_m.shape[1], x_m.shape[2]))
                        start_idx = end_idx
            else:
                modality_masks = [None] * len(x)
            
            # Calculate per-modality MSE losses
            for i, (x_m, x_hat_m) in enumerate(zip(x, x_hat)):
                # Compute MSE loss for this modality with mask if provided
                if modality_masks[i] is not None:
                    m_loss = F.mse_loss(x_hat_m[modality_masks[i]], x_m[modality_masks[i]])
                else:
                    m_loss = F.mse_loss(x_hat_m, x_m)
                
                # Check for NaN 
                if torch.isnan(m_loss):
                    if verbose:
                        print(f"Warning: NaN loss detected for modality {i}")
                    m_loss = torch.tensor(0.0, device=device)
                
                modality_losses.append(m_loss)
                per_modality_losses[i] += m_loss.item()
            
            # Apply reconstruction loss balancing if enabled
            if recon_loss_balancing:
                # Update exponential moving averages for each modality
                for i, m_loss in enumerate(modality_losses):
                    if modality_loss_emas[i] is None:
                        modality_loss_emas[i] = m_loss.item()
                    else:
                        modality_loss_emas[i] = ema_decay * modality_loss_emas[i] + (1 - ema_decay) * m_loss.item()
                
                # Calculate balanced loss using the minimum EMA as reference
                min_ema = min(ema for ema in modality_loss_emas if ema is not None and ema > 0)
                for i, m_loss in enumerate(modality_losses):
                    if modality_loss_emas[i] > 0:
                        balance_scale = min_ema / modality_loss_emas[i]
                        loss += balance_scale * m_loss
                    else:
                        loss += m_loss
            else:
                # Standard loss computation without balancing
                for i, m_loss in enumerate(modality_losses):
                    loss += loss_scales[i] * m_loss
            
            total_loss += loss
            
            # Backward pass and optimize
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            train_loss += loss.item()
        
        # Average losses
        train_loss /= len(data_loader)
        per_modality_losses = [loss / len(data_loader) for loss in per_modality_losses]
        train_losses.append(train_loss)
        
        # Store per-modality losses in history
        for i, loss in enumerate(per_modality_losses):
            loss_history[f'mod_{i}_loss'].append(loss)
        
        # Validation phase
        with torch.no_grad():
            for val_batch_data in val_data_loader:
                # Handle both masked and unmasked data flexibly
                if isinstance(val_batch_data, tuple) and len(val_batch_data) == 2:
                    x_val, mask = val_batch_data  # Dataset with masks
                else:
                    x_val = val_batch_data  # Dataset without masks
                    mask = None
                
                x_val = [x_m.to(device, non_blocking=True) for x_m in x_val]
                x_val_hat, _ = model(x_val)

                modality_masks = []
                if mask is not None:
                    if isinstance(mask, list):
                        # PaddedMultimodalDataWithMasks returns a list of masks (one per modality)
                        modality_masks = [m.to(device) if m is not None else None for m in mask]
                    else:
                        # Legacy: mask is a single concatenated tensor
                        start_idx = 0
                        for i, x_m in enumerate(x_val):
                            end_idx = start_idx + x_m.shape[1] * x_m.shape[2]  # seq_length * n_channels
                            modality_masks.append(mask[:, start_idx:end_idx].view(mask.shape[0], x_m.shape[1], x_m.shape[2]))
                            start_idx = end_idx
                else:
                    modality_masks = [None] * len(x_val)
                
                # Calculate validation loss
                val_batch_loss = 0.0
                for i, (x_m, x_hat_m) in enumerate(zip(x_val, x_val_hat)):
                    if modality_masks[i] is not None:
                        m_loss = F.mse_loss(x_hat_m[modality_masks[i]], x_m[modality_masks[i]])
                    else:
                        m_loss = F.mse_loss(x_hat_m, x_m)
                    if not torch.isnan(m_loss):
                        val_batch_loss += m_loss.item()
                
                val_loss += val_batch_loss / len(x_val)
                
        val_loss /= len(val_data_loader)
        val_losses.append(val_loss)

        # Update progress bar
        log_dict = {
            'loss': round(train_loss, 4),
            'mod_losses': [round(l, 3) for l in per_modality_losses],
            'ranks': [layer.active_dims for layer in model.adaptive_layers] if hasattr(model, 'adaptive_layers') 
                    else [],
            'current_rsquare': [round(current_rsquare_per_mod[i], 3) if current_rsquare_per_mod[i] is not None else 'N/A' for i in range(len(data))],
            'patience': patience_counter,
            'break': break_counter,
        }
        if recon_loss_balancing and all(ema is not None for ema in modality_loss_emas):
            min_ema = min(ema for ema in modality_loss_emas if ema > 0)
            balance_scales = [round(min_ema / ema, 3) if ema > 0 else 1.0 for ema in modality_loss_emas]
            log_dict.update({'balance_scales': balance_scales})
        pbar.set_postfix(log_dict)
        
        # Update best loss
        #if train_loss < best_loss:
        #    best_loss = train_loss
        #    patience_counter = 0  # Reset patience counter
        #else:
        #    patience_counter += 1
        
        # Check if we should start reduction
        if not start_reduction:
            reduction_result = check_start_reduction_ninfea(
                model, val_data, device, multi_gpu, patience, epoch, 
                r_square_threshold, threshold_type, val_mask_tensors, verbose
            )
            if reduction_result[0]:  # start_reduction is True
                start_reduction, rank_history, initial_squares, min_rsquares, current_rsquare_per_mod, break_counter = reduction_result
                # Update rank history with current losses
                rank_history['loss'][-1] = train_loss
                rank_history['val_loss'][-1] = val_loss
        
        # Apply rank reduction logic
        if start_reduction:
            reduction_result = decide_rank_reduction_ninfea(
                model, val_data, device, multi_gpu, epoch, rank_schedule, 
                start_reduction, break_counter, reduce_on_best_loss, min_rsquares,
                r_squares, current_rsquare_per_mod, patience_counter, patience,
                min_ranks, compressibility_type, reduction_criterion, r_square_threshold,
                rank_reduction_threshold, sharedwhenall, val_mask_tensors, compute_jacobian, verbose
            )
            
            if reduction_result[0]:  # any_changes_made
                any_changes_made, rank_history_entry, r_squares, current_rsquare_per_mod, break_counter = reduction_result
                # Update rank history
                rank_history['total_rank'].append(rank_history_entry['total_rank'])
                rank_history['ranks'].append(rank_history_entry['ranks'])
                rank_history['epoch'].append(rank_history_entry['epoch'])
                rank_history['loss'].append(train_loss)
                rank_history['val_loss'].append(val_loss)
                for i, rsq in enumerate(rank_history_entry['rsquares']):
                    rank_history[f'rsquare {i}'].append(rsq)
                
                if any_changes_made:
                    patience_counter = 0  # Reset patience counter if rank was changed
                else:
                    patience_counter += 1
            elif reduction_result[1] is not None:  # rank_history_entry exists but no changes
                r_squares, current_rsquare_per_mod = reduction_result[2], reduction_result[3]
                patience_counter += 1
            #else:
            #    patience_counter += 1
        
        #break_counter = reduction_result[-1]
        # Handle break counter
        if (break_counter > 0) & (epoch in rank_schedule):
            break_counter -= 1
        
        # Early stopping but conditioned on rank reduction
        if (epoch > early_stopping) & (min(val_losses[-early_stopping:]) > min(val_losses)) & (start_reduction is True) & (patience_counter >= patience):
            if verbose:
                print(f"Done with rank reduction at epoch {epoch} with best loss {best_loss}")
            post_training = True
            start_reduction = False
        if (epoch > early_stopping) & (min(val_losses[-early_stopping:]) > min(val_losses)) & (post_training is True) & (patience_counter >= patience):
            if verbose:
                print(f"Early stopping at epoch {epoch} with best loss {best_loss}")
            break
    
    # Calculate latent representations in batches
    n_samples = data[0].shape[0]
    final_ranks = [layer.active_dims for layer in model.adaptive_layers]
    """
    reps = [torch.empty((n_samples, final_ranks[i]), device=device) for i in range(len(final_ranks))]
    model.eval()
    with torch.no_grad():
        for i in range(0, n_samples, batch_size):
            end_i = min(i + batch_size, n_samples)
            batch_data = [d[i:end_i].to(device) for d in data]
            _, h_encoded = model(batch_data)
            
            # Extract hierarchical representations properly using the adaptive_layer_map
            if hasattr(model, 'adaptive_layer_map') and isinstance(h_encoded, dict):
                # Map each layer in adaptive_layers to its corresponding representation in h_encoded
                for layer_name, layer_idx in model.adaptive_layer_map.items():
                    if layer_idx < len(reps):
                        if layer_name == 'global_shared' and 'global_shared' in h_encoded:
                            reps[layer_idx][i:end_i] = h_encoded['global_shared'].cpu()
                        elif layer_name.startswith('shared_') and layer_name in h_encoded:
                            reps[layer_idx][i:end_i] = h_encoded[layer_name].cpu()
                        elif layer_name.startswith('specific_'):
                            # Extract modality index from layer name
                            mod_idx = int(layer_name.split('_')[1])
                            if 'specific' in h_encoded and mod_idx < len(h_encoded['specific']):
                                reps[layer_idx][i:end_i] = h_encoded['specific'][mod_idx].cpu()
                            else:
                                # Fallback: use zeros
                                reps[layer_idx][i:end_i] = torch.zeros(end_i - i, final_ranks[layer_idx])
            else:
                # Fallback for cases without proper hierarchical structure
                for j, layer in enumerate(model.adaptive_layers):
                    if j < len(reps):
                        reps[j][i:end_i] = torch.randn(end_i - i, final_ranks[j])
    """
    reps = []
    # Save the model if model_name provided
    if model_name:
        os.makedirs("./03_results/models/", exist_ok=True)
        torch.save(model.state_dict(), f"./03_results/models/{model_name}.pt")
        if verbose:
            print(f"Model saved to ./03_results/models/{model_name}.pt")
    
    # Initialize rank_history if not created
    if 'rank_history' not in locals():
        # Get model's total rank safely
        if hasattr(model, 'get_total_rank'):
            total_rank = model.get_total_rank()
        elif hasattr(model, 'module') and hasattr(model.module, 'get_total_rank'):
            total_rank = model.module.get_total_rank()
        else:
            total_rank = 0
            
        # Get model's adaptive layers safely
        if hasattr(model, 'adaptive_layers'):
            adaptive_layers = model.adaptive_layers
        elif hasattr(model, 'module') and hasattr(model.module, 'adaptive_layers'):
            adaptive_layers = model.module.adaptive_layers
        else:
            adaptive_layers = []
            
        # Get current epoch safely
        if hasattr(model, 'epoch'):
            current_epoch = model.epoch
        elif hasattr(model, 'module') and hasattr(model.module, 'epoch'):
            current_epoch = model.module.epoch
        else:
            current_epoch = 0
            
        rank_history = {
            'total_rank': [total_rank],
            'ranks': [', '.join(str(layer.active_dims) for layer in adaptive_layers)],
            'epoch': [current_epoch],
            'loss': [train_losses[-1] if train_losses else 0.0],
            'val_loss': [val_losses[-1] if val_losses else 0.0]
        }

    try:
        avg_train_loss = np.mean(train_losses[-5:])
    except:
        avg_train_loss = np.mean(train_losses)
    try:
        last_rsquare = r_squares[-1]
    except:
        last_rsquare = [None]
    
    return model, reps, avg_train_loss, last_rsquare, rank_history, [train_losses, val_losses]