

import sys
from pathlib import Path
project_root = str(Path(__file__).resolve().parent.parent.parent)
if project_root not in sys.path:
    sys.path.append(project_root)
import project_config

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, random_split, Dataset
import math
import numpy as np
import tqdm
import gc
from src.visualization.logging import plot_training_state, create_training_movie
from src.models.larrp_multimodal_cnn import AdaptiveRankReducedAE_CNN, MMSimData
import matplotlib.pyplot as plt
import pandas as pd
import os

def compute_direct_r_squared(model, data, device, multi_gpu=False, verbose=False):
    """
    Compute R² based on direct model reconstruction performance

    Every modality is flattened to (N, D). R² is computed per flattened
    feature column as 1 - SS_res / SS_tot and averaged over the columns that
    are kept:
    - If some column means of the original are exactly zero, those columns are
      dropped. If all of them are zero, the Pearson correlation between the
      flattened original and reconstruction is returned instead.
    - If the original contains NaN/Inf, only columns with a finite mean are used.
    - Otherwise columns whose reconstruction contains NaN/Inf, or whose total
      sum of squares is below 1e-3, are dropped; negative per-column R² values
      are clamped to 0 before averaging.
    A modality with no usable column gets an R² of 0.0.

    Parameters:
    - model: The trained model; called as model(list_of_tensors) and expected
      to return a pair whose first element is the list of reconstructions
    - data: Input data list [modality1, modality2, ...]
    - device: Device to run computation on
    - multi_gpu: Whether model is wrapped with DataParallel (not used by this
      function; kept for interface compatibility)
    - verbose: Print diagnostics about mismatched shapes and dropped columns

    Returns:
    - List of R² values (Python floats) for each modality
    """
    eps = 1e-9          # keeps the SS_res / SS_tot ratio finite for constant columns
    min_ss_tot = 1e-3   # columns with less total variation give unstable R²

    def _per_feature_r2(orig, recon, mean, keep):
        # Per-column R² restricted to the columns selected by the 1-D mask `keep`.
        ssr = ((orig - recon) ** 2).sum(0)[keep]
        ss_tot = ((orig - mean) ** 2).sum(0)[keep]
        return 1 - ((ssr + eps) / (ss_tot + eps))

    model.eval()
    r_squared_values = []

    with torch.no_grad():
        # Get model predictions
        data_tensors = [d.to(device) for d in data]
        reconstructions, _ = model(data_tensors)

        # Calculate R² for each modality
        for i, (original, reconstruction) in enumerate(zip(data_tensors, reconstructions)):
            # Move to CPU and flatten to (N, D); reshape copies only when a view is impossible
            orig_flat = original.cpu().reshape(original.shape[0], -1)
            recon_flat = reconstruction.cpu().reshape(reconstruction.shape[0], -1)

            # Align batch dimension if needed
            n_min = min(orig_flat.shape[0], recon_flat.shape[0])
            orig_flat = orig_flat[:n_min]
            recon_flat = recon_flat[:n_min]

            # Align feature dimension by truncation if necessary
            if orig_flat.shape[1] != recon_flat.shape[1]:
                min_feat = min(orig_flat.shape[1], recon_flat.shape[1])
                if verbose:
                    print(f"   Debug: modality {i} feature size mismatch (orig={orig_flat.shape[1]}, recon={recon_flat.shape[1]}). Truncating to {min_feat} features for R² computation.")
                orig_flat = orig_flat[:, :min_feat]
                recon_flat = recon_flat[:, :min_feat]

            original_mean = orig_flat.mean(dim=0)
            zero_mean = original_mean == 0

            if torch.any(zero_mean):
                if verbose:
                    print(f"   Warning: zeros found in original_mean for modality {i}. Removing samples.")
                keep = ~zero_mean
                if not torch.any(keep):
                    # If all means are zero, use correlation as fallback
                    r_squared = torch.corrcoef(torch.stack((orig_flat.flatten(), recon_flat.flatten())))[0, 1]
                    if torch.isnan(r_squared):
                        r_squared = torch.tensor(0.0)
                else:
                    # Calculate R² only for non-zero mean dimensions
                    r_squared = _per_feature_r2(orig_flat, recon_flat, original_mean, keep).mean()
            elif not torch.isfinite(orig_flat).all():
                if verbose:
                    print(f"   Warning: NaN or Inf values found in original data for modality {i}. Handling them.")
                # A column mean is finite only if the whole original column is usable
                keep = torch.isfinite(original_mean)
                if not torch.any(keep):
                    r_squared = torch.tensor(0.0)
                else:
                    r_squared = _per_feature_r2(orig_flat, recon_flat, original_mean, keep).mean()
            else:
                # The original is finite here; only the reconstruction can still hold NaN/Inf.
                ssr = ((orig_flat - recon_flat) ** 2).sum(0)
                ss_tot = ((orig_flat - original_mean) ** 2).sum(0)

                # Validity is decided per column so the mask matches the 1-D sums above.
                finite_cols = torch.isfinite(recon_flat).all(dim=0)
                if verbose and not finite_cols.all():
                    print(f"   Warning: NaN or Inf values found in data for modality {i}. Handling them.")

                stable_cols = ss_tot >= min_ss_tot
                if verbose and not stable_cols.all():
                    print(f"   Warning: Very small ss_tot values found for modality {i}. This may lead to unstable R² values.")

                keep = finite_cols & stable_cols
                if not torch.any(keep):
                    # Nothing usable: report 0 instead of a ratio of near-zero quantities
                    r_squared = torch.tensor(0.0)
                else:
                    per_feature = 1 - ((ssr[keep] + eps) / (ss_tot[keep] + eps))
                    negative = per_feature < 0
                    if torch.any(negative):
                        if verbose:
                            fraction_negative = negative.sum().item() / per_feature.numel()
                            print(f"   Warning: {fraction_negative:.2%} Negative R² values detected for modality {i}. This may indicate poor reconstruction.")
                        per_feature = torch.clamp(per_feature, min=0.0)
                    r_squared = per_feature.mean()  # Average across dimensions

            r_squared_values.append(r_squared.item())

    return r_squared_values

def compute_direct_explained_variance(model, data, device, multi_gpu=False, verbose=False):
    """
    Compute explained variance score based on direct model reconstruction performance

    For each modality the original and the reconstruction are flattened to
    (N, D), cropped to a common shape, and scored over all remaining entries as
    1 - Var(reconstruction - original) / (Var(original) + 1e-9). Entries that
    are NaN/Inf in either tensor are left out; negative scores are clamped to 0.

    Parameters:
    - model: The trained model; called as model(list_of_tensors) and expected
      to return a pair whose first element is the list of reconstructions
    - data: Input data list [modality1, modality2, ...]
    - device: Device to run computation on
    - multi_gpu: Whether model is wrapped with DataParallel (not used by this function)
    - verbose: Print shape diagnostics and warnings

    Returns:
    - List of explained variance values for each modality
    """
    def _score(target, prediction):
        # Explained variance over every element of the two tensors.
        return 1 - (torch.var(prediction - target) / (torch.var(target) + 1e-9))

    model.eval()
    scores = []

    with torch.no_grad():
        # Run the model once on all modalities
        inputs = [tensor.to(device) for tensor in data]
        outputs, _ = model(inputs)

        for i, (original, reconstruction) in enumerate(zip(inputs, outputs)):
            original_cpu = original.cpu()
            reconstruction_cpu = reconstruction.cpu()

            # Collapse everything after the batch axis into a single feature axis
            orig_flat = original_cpu.reshape(original_cpu.shape[0], -1)
            recon_flat = reconstruction_cpu.reshape(reconstruction_cpu.shape[0], -1)

            if verbose:
                print(f"   Debug: modality {i} original shape {original_cpu.shape} -> flat {orig_flat.shape}; recon shape {reconstruction_cpu.shape} -> flat {recon_flat.shape}")

            # Crop both tensors to the shorter batch
            n_min = min(orig_flat.shape[0], recon_flat.shape[0])
            orig_flat, recon_flat = orig_flat[:n_min], recon_flat[:n_min]

            # Crop both tensors to the narrower feature axis
            if orig_flat.shape[1] != recon_flat.shape[1]:
                min_feat = min(orig_flat.shape[1], recon_flat.shape[1])
                if verbose:
                    print(f"   Warning: modality {i} feature size mismatch (orig={orig_flat.shape[1]}, recon={recon_flat.shape[1]}). Truncating to {min_feat} features for explained variance.")
                orig_flat, recon_flat = orig_flat[:, :min_feat], recon_flat[:, :min_feat]

            valid_mask = torch.isfinite(orig_flat) & torch.isfinite(recon_flat)
            if valid_mask.all():
                # Clean data: score the full flattened tensors
                explained_variance = _score(orig_flat, recon_flat)
            else:
                if verbose:
                    print(f"   Warning: NaN or Inf values found in flattened data for modality {i}. Handling them.")
                if valid_mask.sum() == 0:
                    explained_variance = torch.tensor(0.0)
                else:
                    # Score only the entries that are finite in both tensors
                    explained_variance = _score(orig_flat[valid_mask], recon_flat[valid_mask])

            # A negative score means the reconstruction is worse than a constant; report 0
            if explained_variance < 0:
                if verbose:
                    print(f"   Warning: Negative explained variance value detected for modality {i}: {explained_variance}. Clamping to 0.")
                explained_variance = torch.clamp(explained_variance, min=0.0)

            scores.append(explained_variance.item())

    return scores

import os

def pretrain_overcomplete_ae(train_data, val_data, latent_dim, device, args, epochs=100, early_stopping=50, 
                         lr=0.001, batch_size=128, ae_depth=2, ae_width=0.5, dropout=0.0, wd=1e-5, 
                         initial_rank_ratio=1.0, min_rank=10, 
                         patience=10, verbose=True, recon_loss_balancing=False, lr_schedule=None,
                         input_shapes=None):
    """
    Train an autoencoder for mm_sim pretraining (no rank reduction).

    Builds an AdaptiveRankReducedAE_CNN, optionally wraps it in DataParallel,
    and minimises the sum of per-modality MSE reconstruction losses with Adam.

    Parameters:
    - train_data: Training data list of tensors (one per modality)
    - val_data: Validation data list of tensors (one per modality)
    - latent_dim: Dimension of the latent space. An int (or a one-element
      list/tuple when there are several modalities) is repeated for every
      modality plus the shared space; a longer list is passed through unchanged
    - device: Device the model and batches are moved to
    - args: Namespace-like object; the optional attributes `multi_gpu`,
      `gpu_ids` (comma-separated string) and `num_workers` are read
    - epochs: Maximum number of training epochs
    - early_stopping: Window (in epochs) for early stopping: training stops once
      more than this many epochs have run and the best validation loss within
      the last `early_stopping` epochs is worse than the best one overall
    - lr: Learning rate
    - batch_size: Batch size for training
    - ae_depth: Depth of the autoencoder
    - ae_width: Width multiplier for hidden layers
    - dropout: Dropout rate
    - wd: Weight decay
    - initial_rank_ratio: Initial rank ratio (1.0 = full rank)
    - min_rank: Minimum rank
    - patience: Not used by this function; kept for interface compatibility
    - verbose: Print progress
    - recon_loss_balancing: Adaptive loss balancing across modalities (each
      modality loss is scaled by min_EMA / its own loss EMA)
    - lr_schedule: None, 'linear', 'step' or 'cosine'; any other value keeps
      the learning rate constant
    - input_shapes: One shape per modality; the first three entries of each
      shape are multiplied to obtain the flat input size
      (presumably (C, H, W) - TODO confirm against the model)

    Returns:
    - (model, [train_losses, val_losses]) with one loss value per completed epoch

    Raises:
    - ValueError: if input_shapes is None
    - TypeError: if latent_dim is neither an int nor a list/tuple
    """
    multi_gpu = getattr(args, 'multi_gpu', False)
    
    # Create CNN model with adaptive rank reduction
    if input_shapes is None:
        raise ValueError("input_shapes must be provided for CNN model")
    
    input_dims = [shape[0] * shape[1] * shape[2] for shape in input_shapes]  # For compatibility
    
    n_latent_spaces = len(input_dims) + 1  # adding one for the shared space
    if isinstance(latent_dim, int):
        latent_dims = [latent_dim] * n_latent_spaces
    elif isinstance(latent_dim, (list, tuple)):
        if len(latent_dim) == 1 and len(input_dims) > 1:
            latent_dims = [latent_dim[0]] * n_latent_spaces
        else:
            latent_dims = latent_dim if isinstance(latent_dim, list) else list(latent_dim)
    else:
        # Fail here with a clear message instead of a NameError further down
        raise TypeError(f"latent_dim must be an int or a list of ints, got {type(latent_dim).__name__}")
    
    model = AdaptiveRankReducedAE_CNN(
        input_dims, latent_dims, input_shapes, depth=ae_depth, width=ae_width, 
        dropout=dropout, initial_rank_ratio=initial_rank_ratio, 
        min_rank=min_rank
    ).to(device)
    # print the device the model is on
    print(f"Model is on device: {next(model.parameters()).device}")
    # print the number of parameters
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model has {total_params} parameters")
    
    # Handle multi-GPU setup
    if multi_gpu:
        try:
            # Resolve the GPU list once; without explicit ids use every visible GPU
            gpu_ids_arg = getattr(args, 'gpu_ids', None)
            if gpu_ids_arg:
                device_ids = [int(gpu_id) for gpu_id in gpu_ids_arg.split(',')]
            else:
                device_ids = list(range(torch.cuda.device_count()))

            # If we need cuda:0 but it's not available, disable multi_gpu
            if 0 not in device_ids:
                raise RuntimeError("DataParallel requires cuda:0 which is not available.")
                
            # Ensure model is on cuda:0 for DataParallel
            model = model.to(torch.device('cuda:0'))
                    
            # Wrap model with DataParallel - explicitly specify device_ids
            model = nn.DataParallel(model, device_ids=device_ids)
            if verbose:
                print(f"Using DataParallel across GPUs: {gpu_ids_arg}")

            # Ensure batch size is divisible by number of GPUs (and never drops to 0)
            num_gpus = len(device_ids)
            if batch_size % num_gpus != 0:
                original_batch_size = batch_size
                batch_size = max(num_gpus, (batch_size // num_gpus) * num_gpus)
                if verbose:
                    print(f"Adjusted batch size from {original_batch_size} to {batch_size} to be divisible by {num_gpus} GPUs")
        except Exception as e:
            # Best effort: any failure while setting up DataParallel falls back to one device
            print(f"Failed to use DataParallel: {e}")
            print(f"Falling back to single GPU mode on {device}")
            multi_gpu = False
            model = model.to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

    # Setup learning rate scheduler if requested
    scheduler = None
    if lr_schedule == 'linear':
        linear_lr_cls = getattr(torch.optim.lr_scheduler, 'LinearLR', None)
        if linear_lr_cls is not None:
            # Use LinearLR when available (PyTorch >= 1.11)
            scheduler = linear_lr_cls(optimizer, start_factor=1.0, end_factor=0.001, total_iters=1000)
        else:
            # Fallback to LambdaLR for older PyTorch versions
            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: max(0.0, 1.0 - (epoch + 1) / float(max(1, epochs))))
    elif lr_schedule == 'step':
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 100, 1000], gamma=0.1)
    elif lr_schedule == 'cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)

    # Create data loaders from pre-split data
    train_dataset = MMSimData(train_data)
    val_dataset = MMSimData(val_data)
    
    # num_workers comes from args if available
    num_workers = getattr(args, 'num_workers', 0)
    data_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_workers
    )
    val_data_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=num_workers
    )
    
    # Train the model
    train_losses = []
    val_losses = []
    best_loss = float('inf') 
    
    # Static per-modality loss weights (all ones) used when balancing is disabled
    n_modalities = len(train_data)
    loss_scales = torch.ones(n_modalities, device=device)
    
    # Initialize loss balancer for reconstruction losses
    if recon_loss_balancing:
        modality_loss_emas = [None] * n_modalities
        ema_decay = 0.9

    pbar = tqdm.tqdm(range(epochs))
    for epoch in pbar:
        # Training phase
        model.train()
        train_loss = 0.0
        
        for x in data_loader:
            x = [x_m.to(device, non_blocking=True) for x_m in x]
            
            # Forward pass
            x_hat, _ = model(x)

            # Calculate per-modality MSE losses
            modality_losses = []
            for i, (x_m, x_hat_m) in enumerate(zip(x, x_hat)):
                # For CNN, use MSE loss directly on 4D tensors
                m_loss = F.mse_loss(x_hat_m, x_m, reduction='mean')
                
                # A NaN modality is dropped from this step by replacing it with a constant zero
                if torch.isnan(m_loss):
                    if verbose:
                        print(f"Warning: NaN loss detected for modality {i}")
                    m_loss = torch.tensor(0.0, device=device)
                
                modality_losses.append(m_loss)
            
            loss = torch.tensor(0.0, device=device)
            if recon_loss_balancing:
                # Update exponential moving averages for each modality
                for i, m_loss in enumerate(modality_losses):
                    current = m_loss.item()
                    previous = modality_loss_emas[i]
                    if previous is None:
                        modality_loss_emas[i] = current
                    else:
                        modality_loss_emas[i] = ema_decay * previous + (1 - ema_decay) * current
                
                # Calculate balanced loss using the minimum EMA as reference;
                # if no EMA is positive yet there is nothing to balance against
                positive_emas = [ema for ema in modality_loss_emas if ema is not None and ema > 0]
                min_ema = min(positive_emas) if positive_emas else None
                for i, m_loss in enumerate(modality_losses):
                    ema = modality_loss_emas[i]
                    if min_ema is not None and ema > 0:
                        loss = loss + (min_ema / ema) * m_loss
                    else:
                        loss = loss + m_loss
            else:
                # Standard loss computation without balancing
                for i, m_loss in enumerate(modality_losses):
                    loss = loss + loss_scales[i] * m_loss

            # Backward pass and optimize (normal training)
            if loss.requires_grad:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            elif verbose:
                # Every modality loss was replaced by a constant, so there is no graph to backpropagate
                print("Warning: no differentiable loss for this batch; skipping optimizer step")
            train_loss += loss.item()
        
        # Average losses
        train_loss /= len(data_loader)
        train_losses.append(train_loss)
        
        # Validation phase: switch to eval mode so dropout (and any normalisation
        # layers) behave deterministically and are not updated by validation data
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for x_val in val_data_loader:
                x_val = [x_m.to(device, non_blocking=True) for x_m in x_val]
                x_val_hat, _ = model(x_val)
                
                # Calculate validation loss (mean over modalities, NaN modalities skipped)
                val_batch_loss = 0.0
                for x_m, x_hat_m in zip(x_val, x_val_hat):
                    m_loss = F.mse_loss(x_hat_m, x_m, reduction='mean')
                    if not torch.isnan(m_loss):
                        val_batch_loss += m_loss.item()
                
                val_loss += val_batch_loss / len(x_val)
        # Restore train mode so the returned model is in the same mode as before this change
        model.train()
                
        val_loss /= len(val_data_loader)
        val_losses.append(val_loss)

        # Step the scheduler once per epoch (if configured)
        if scheduler is not None:
            scheduler.step()

        # Track the best training loss (reported in the early-stopping message)
        if train_loss < best_loss:
            best_loss = train_loss

        # Stop when the best validation loss lies outside the last `early_stopping` epochs
        if epoch > early_stopping and min(val_losses[-early_stopping:]) > min(val_losses):
            print(f"Early stopping at epoch {epoch} with best loss {best_loss}")
            break
        
        # Update progress bar, including the current learning rate
        if scheduler is not None:
            current_lr = scheduler.get_last_lr()[0]
            pbar.set_description(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}, LR: {current_lr:.4e}")
        else:
            current_lr = optimizer.param_groups[0]['lr']
            pbar.set_description(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}, LR: {current_lr:.2e}")
    
    return model, [train_losses, val_losses]

def train_overcomplete_ae_with_pretrained(train_data, val_data, latent_dim, device, args, epochs=100, early_stopping=50, 
                         lr=0.001, batch_size=128, ae_depth=2, ae_width=0.5, dropout=0.0, wd=1e-5, 
                         initial_rank_ratio=1.0, min_rank=10, 
                         rank_schedule=None, rank_reduction_frequency=10, 
                         rank_reduction_threshold=0.01, warmup_epochs=0,
                         patience=10, reduce_on_best_loss='rsquare', r_square_threshold=0.9,
                         threshold_type='relative', compressibility_type='linear', reduction_criterion='r_squared',
                         include_l1=False, l1_weight=0.0, include_ortholoss=False,
                         l1_start_weight=0.0, l1_step_size=1.0, rank_or_sparse='rank',
                         verbose=True, compute_jacobian=False, model_name=None, pretrained_name=None,
                         recon_loss_balancing=False, ortho_loss_balancing=False,
                         ortho_loss_start_weight=0.0, ortho_loss_end_weight=1.0, ortho_loss_anneal_epochs=None, ortho_loss_warmup=None,
                         l2_norm_adaptivelayers=None, sharedwhenall=True, lr_schedule='constant',
                         decision_metric='R2', input_shapes=None, end_lr=1e-5
                         ):
    """
    Train an autoencoder with adaptive rank reduction
    
    Parameters:
    - train_data: Training data list of tensors (one per modality)
    - val_data: Validation data list of tensors (one per modality)
    - latent_dim: Dimension of the latent space
    - epochs: Maximum number of training epochs
    - early_stopping: Number of epochs for early stopping patience
    - lr: Learning rate
    - batch_size: Batch size for training
    - ae_depth: Depth of the autoencoder
    - ae_width: Width multiplier for hidden layers
    - dropout: Dropout rate
    - wd: Weight decay
    - initial_rank_ratio: Initial rank ratio (1.0 = full rank)
    - min_rank: Minimum rank (lower bound), forwarded to the model constructor
    - rank_schedule: Custom schedule for rank reduction (epochs at which to reduce)
    - rank_reduction_frequency: How often to try reducing rank (in epochs)
    - rank_reduction_threshold: Energy threshold for rank reduction
    - warmup_epochs: Number of epochs to train before starting rank reduction
    - reduce_on_best_loss: Rank-reduction mode string (default 'rsquare'); 'true' and 'stagnation' track a patience counter on the training loss
    - r_square_threshold: R² threshold for rank reduction decisions
    - threshold_type: 'relative' (multiply by initial R²) or 'absolute' (use threshold directly)
    - compressibility_type: 'linear' (linear probing R²) or 'direct' (reconstruction R²)
    - reduction_criterion: 'r_squared' (use R²), 'train_loss' (use training loss), 'val_loss' (use validation loss)
                          Only relevant when compressibility_type='direct'. Default: 'r_squared'
    - recon_loss_balancing: Whether to apply adaptive loss balancing across modalities (default: False)
    """

    # determine whether multi-GPU mode is requested (make available early)
    multi_gpu = getattr(args, 'multi_gpu', False)

    # check if there is an existing pretrained model for the seed, early stopping, and training hyperparameters (lr, wd, batch size, model architecture)
    pretrained_model_path = f"{project_config.RESULTS_DIR}/so2sat/pretrained_{pretrained_name}.pt" if pretrained_name else None
    if pretrained_model_path and os.path.exists(pretrained_model_path):
        print(f"Found existing pretrained model at {pretrained_model_path}. Loading...")
        if input_shapes is None:
            raise ValueError("input_shapes must be provided for CNN model")
        
        input_dims = [shape[0] * shape[1] * shape[2] for shape in input_shapes]
        
        if isinstance(latent_dim, int):
            latent_dims = [latent_dim] * (len(input_dims) + 1) # adding one for the shared space
        elif isinstance(latent_dim, list):
            if (len(latent_dim) == 1) & (len(input_dims) > 1):
                latent_dims = [latent_dim[0]] * (len(input_dims) + 1)
            else:
                latent_dims = latent_dim
        
        model = AdaptiveRankReducedAE_CNN(
            input_dims, latent_dims, input_shapes, depth=ae_depth, width=ae_width, 
            dropout=dropout, initial_rank_ratio=initial_rank_ratio, 
            min_rank=min_rank
        )
        model.load_state_dict(torch.load(pretrained_model_path, weights_only=False))
        # make sure the loaded weights stay trainable (requires_grad=True on every parameter)
        model.eval()
        for param in model.parameters():
            param.requires_grad = True
        print(f"Loaded pretrained model from {pretrained_model_path}")
        # also load the loss curves
        loss_curve_path = pretrained_model_path.replace('.pt', '_loss_curve.npy')
        train_val_losses = np.load(loss_curve_path, allow_pickle=True)
        train_losses = train_val_losses[0].tolist()
        val_losses = train_val_losses[1].tolist()
        print(f"Loaded loss curves from {loss_curve_path}")
        # print last losses
        print(f"Last training loss: {train_losses[-1]}, last validation loss: {val_losses[-1]}")
        model.epoch = len(train_losses)
        # If reconstruction plots do not exist, create and save example reconstructions
        try:
            mod0_plot_path = pretrained_model_path.replace(project_config.RESULTS_DIR, './03_results/train_plots/').replace('.pt', '_mod0_recon.png')
            mod1_plot_path = pretrained_model_path.replace(project_config.RESULTS_DIR, './03_results/train_plots/').replace('.pt', '_mod1_recon.png')
            if (not os.path.exists(mod0_plot_path)) or (not os.path.exists(mod1_plot_path)):
                # ensure model parameters are on the same device as inputs
                try:
                    model = model.to(device)
                except Exception:
                    pass
                model.eval()
                with torch.no_grad():
                    # take a small sample from the provided data to generate example reconstructions
                    n_plot = min(8, train_data[0].shape[0])
                    rng = np.random.default_rng(seed=42)
                    sample_idx = rng.choice(train_data[0].shape[0], size=n_plot, replace=False)
                    
                    # Get sample data for both modalities
                    mod0_sample = train_data[0][sample_idx].to(device)
                    mod1_sample = train_data[1][sample_idx].to(device)

                    # Use module if DataParallel wrapping already present
                    if multi_gpu and hasattr(model, 'module'):
                        reconstructions, _ = model.module([mod0_sample, mod1_sample])
                    else:
                        reconstructions, _ = model([mod0_sample, mod1_sample])

                    # Save plots (basic visualization for CNN outputs)
                    # This would need proper plotting functions for image data
                    print(f"Reconstruction shapes: mod0={reconstructions[0].shape}, mod1={reconstructions[1].shape}")
                print(f"Note: Reconstruction plot generation skipped for CNN model")
        except Exception as e:
            print(f"Warning: Could not save reconstruction plots: {e}")
    else:
        if pretrained_model_path:
            print("No pretrained model found. Training from scratch.")
            model, [train_losses, val_losses] = pretrain_overcomplete_ae(
                train_data, val_data, latent_dim, device, args, epochs=int(epochs/2), early_stopping=early_stopping,
                lr=lr, batch_size=batch_size, ae_depth=ae_depth, ae_width=ae_width, dropout=dropout, wd=wd,
                initial_rank_ratio=initial_rank_ratio, min_rank=min_rank, lr_schedule=lr_schedule,
                verbose=verbose, input_shapes=input_shapes
            )
            # Save the pretrained model and loss curves
            os.makedirs(os.path.dirname(pretrained_model_path), exist_ok=True)
            torch.save(model.state_dict(), pretrained_model_path)
            # Also save loss curves
            loss_curve_path = pretrained_model_path.replace('.pt', '_loss_curve.npy')
            np.save(loss_curve_path, np.array([train_losses, val_losses]))
            # Also save a PNG of the pretraining loss curves (train & val)
            try:
                loss_png_path = pretrained_model_path.replace(project_config.RESULTS_DIR, './03_results/train_plots/').replace('.pt', '_pretrain_loss_curve.png')
                os.makedirs(os.path.dirname(loss_png_path), exist_ok=True)
                plt.figure(figsize=(6, 4))
                plt.plot(np.arange(1, len(train_losses) + 1), train_losses, label='train')
                plt.plot(np.arange(1, len(val_losses) + 1), val_losses, label='val')
                plt.xlabel('Epoch')
                plt.ylabel('Loss')
                plt.title('Pretraining Loss Curves')
                plt.legend()
                plt.tight_layout()
                plt.savefig(loss_png_path, dpi=150)
                plt.close()
            except Exception as e:
                print(f"Warning: Could not save pretraining loss PNG: {e}")
            # Save example plots for reconstruction
            try:
                saved = False
                with torch.no_grad():
                    # Determine a small batch size for plotting
                    n_plot = min(8, train_data[0].shape[0])
                    rng = np.random.default_rng(seed=42)
                    sample_idx = rng.choice(train_data[0].shape[0], size=n_plot, replace=False)

                    # Prepare modality tensors
                    mod0 = train_data[0][sample_idx].to(device)
                    mod1 = train_data[1][sample_idx].to(device)
                    
                    last_batch_data = [mod0, mod1]

                    # Run through model (handle DataParallel)
                    if multi_gpu and hasattr(model, 'module'):
                        model.module.eval()
                        reconstructions, _ = model.module(last_batch_data)
                    else:
                        model.eval()
                        reconstructions, _ = model(last_batch_data)

                    # Basic info logging (plotting would need proper visualization for CNN outputs)
                    print(f"Reconstruction shapes: mod0={reconstructions[0].shape}, mod1={reconstructions[1].shape}")
                    saved = True

                if saved:
                    print(f"Note: Detailed reconstruction plots skipped for CNN model")
            except Exception as e:
                print(f"Warning: Could not save reconstruction plots: {e}")
            
            model.epoch = len(train_losses)
            
            # Save reconstruction plots for both modalities
            try:
                # Need to create datasets from the train_data and val_data that are in scope
                temp_train_dataset = MMSimData(train_data)
                temp_val_dataset = MMSimData(val_data)
                
                model.eval()
                with torch.no_grad():
                    # Get 8 train and 8 test samples
                    n_plot = 8
                    
                    # Train samples
                    train_idx = torch.randperm(len(temp_train_dataset))[:n_plot]
                    train_samples = [temp_train_dataset.data[i][train_idx].to(device) for i in range(len(train_data))]
                    
                    # Test samples
                    test_idx = torch.randperm(len(temp_val_dataset))[:n_plot]
                    test_samples = [temp_val_dataset.data[i][test_idx].to(device) for i in range(len(train_data))]
                    
                    # Get reconstructions
                    if multi_gpu and hasattr(model, 'module'):
                        train_recon, _ = model.module(train_samples)
                        test_recon, _ = model.module(test_samples)
                    else:
                        train_recon, _ = model(train_samples)
                        test_recon, _ = model(test_samples)
                    
                    # Plot modality 0 (radar - channel 4 only)
                    mod0_path = pretrained_model_path.replace(project_config.RESULTS_DIR, './03_results/train_plots/').replace('.pt', '_mod0_recon.png')
                    os.makedirs(os.path.dirname(mod0_path), exist_ok=True)
                    fig, axes = plt.subplots(4, n_plot, figsize=(n_plot * 1.5, 6))
                    for i in range(n_plot):
                        # Train original (channel 4)
                        axes[0, i].imshow(train_samples[0][i, 4].cpu().numpy(), cmap='gray')
                        axes[0, i].axis('off')
                        if i == 0:
                            axes[0, i].set_title('Train', fontsize=10, loc='left')
                        # Train reconstruction (channel 4)
                        axes[1, i].imshow(train_recon[0][i, 4].cpu().numpy(), cmap='gray')
                        axes[1, i].axis('off')
                        if i == 0:
                            axes[1, i].set_title('Train Recon', fontsize=10, loc='left')
                        # Test original (channel 4)
                        axes[2, i].imshow(test_samples[0][i, 4].cpu().numpy(), cmap='gray')
                        axes[2, i].axis('off')
                        if i == 0:
                            axes[2, i].set_title('Test', fontsize=10, loc='left')
                        # Test reconstruction (channel 4)
                        axes[3, i].imshow(test_recon[0][i, 4].cpu().numpy(), cmap='gray')
                        axes[3, i].axis('off')
                        if i == 0:
                            axes[3, i].set_title('Test Recon', fontsize=10, loc='left')
                    plt.tight_layout()
                    plt.savefig(mod0_path, dpi=150, bbox_inches='tight')
                    plt.close()
                    print(f"Saved radar reconstruction plot to {mod0_path}")
                    
                    # Plot modality 1 (optical - RGB)
                    mod1_path = pretrained_model_path.replace(project_config.RESULTS_DIR, './03_results/train_plots/').replace('.pt', '_mod1_recon.png')
                    os.makedirs(os.path.dirname(mod1_path), exist_ok=True)
                    fig, axes = plt.subplots(4, n_plot, figsize=(n_plot * 1.5, 6))
                    for i in range(n_plot):
                        # Train original (RGB)
                        img_train = train_samples[1][i].permute(1, 2, 0).cpu().numpy()
                        # Normalize to [0, 1] for display
                        img_train = (img_train - img_train.min()) / (img_train.max() - img_train.min() + 1e-8)
                        axes[0, i].imshow(img_train)
                        axes[0, i].axis('off')
                        if i == 0:
                            axes[0, i].set_title('Train', fontsize=10, loc='left')
                        # Train reconstruction (RGB)
                        img_train_recon = train_recon[1][i].permute(1, 2, 0).cpu().numpy()
                        img_train_recon = (img_train_recon - img_train_recon.min()) / (img_train_recon.max() - img_train_recon.min() + 1e-8)
                        axes[1, i].imshow(img_train_recon)
                        axes[1, i].axis('off')
                        if i == 0:
                            axes[1, i].set_title('Train Recon', fontsize=10, loc='left')
                        # Test original (RGB)
                        img_test = test_samples[1][i].permute(1, 2, 0).cpu().numpy()
                        img_test = (img_test - img_test.min()) / (img_test.max() - img_test.min() + 1e-8)
                        axes[2, i].imshow(img_test)
                        axes[2, i].axis('off')
                        if i == 0:
                            axes[2, i].set_title('Test', fontsize=10, loc='left')
                        # Test reconstruction (RGB)
                        img_test_recon = test_recon[1][i].permute(1, 2, 0).cpu().numpy()
                        img_test_recon = (img_test_recon - img_test_recon.min()) / (img_test_recon.max() - img_test_recon.min() + 1e-8)
                        axes[3, i].imshow(img_test_recon)
                        axes[3, i].axis('off')
                        if i == 0:
                            axes[3, i].set_title('Test Recon', fontsize=10, loc='left')
                    plt.tight_layout()
                    plt.savefig(mod1_path, dpi=150, bbox_inches='tight')
                    plt.close()
                    print(f"Saved optical reconstruction plot to {mod1_path}")
            except Exception as e:
                print(f"Warning: Could not save reconstruction plots: {e}")
        else:
            raise ValueError("model_name must be provided to save/load pretrained models.")
    model.to(device)
    print(f"Model is on device: {next(model.parameters()).device}")
    
    # Handle multi-GPU setup
    if multi_gpu:
        # Adjust batch size to be divisible by number of GPUs
        if args.gpu_ids:
            num_gpus = len(args.gpu_ids.split(','))
        else:
            num_gpus = torch.cuda.device_count()
            
        # Ensure batch size is divisible by number of GPUs
        if batch_size % num_gpus != 0:
            original_batch_size = batch_size
            batch_size = (batch_size // num_gpus) * num_gpus
            if verbose:
                print(f"Adjusted batch size from {original_batch_size} to {batch_size} to be divisible by {num_gpus} GPUs")
            
        try:
            # If we need cuda:0 but it's not available, disable multi_gpu
            if 0 not in [int(id) for id in args.gpu_ids.split(',')]:
                raise RuntimeError("DataParallel requires cuda:0 which is not available.")
                
            # Ensure model is on cuda:0 for DataParallel
            cuda0_device = torch.device('cuda:0')
            model = model.to(cuda0_device)
            
            # Double-check all parameters are on cuda:0
            for param in model.parameters():
                if param.device != cuda0_device:
                    param.data = param.data.to(cuda0_device)
                    
            # Wrap model with DataParallel - explicitly specify device_ids
            model = nn.DataParallel(model, device_ids=[int(id) for id in args.gpu_ids.split(',')])
            if verbose:
                print(f"Using DataParallel across GPUs: {args.gpu_ids}")
        except Exception as e:
            print(f"Failed to use DataParallel: {e}")
            print(f"Falling back to single GPU mode on {device}")
            multi_gpu = False
            model = model.to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

    # Setup learning rate scheduler if requested
    scheduler = None
    if lr_schedule == 'cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=epochs, eta_min=end_lr
        )
    elif lr_schedule == 'linear':
        try:
            # Use LinearLR when available (PyTorch >= 1.11)
            scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=end_lr/lr, total_iters=epochs)
        except Exception:
            # Fallback to LambdaLR for older PyTorch versions
            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: max(end_lr/lr, 1.0 - (epoch + 1) / float(max(1, epochs))))
    elif lr_schedule == 'step':
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[int(epochs*0.5), int(epochs*0.75)], gamma=0.1)

    # Create data loaders from pre-split data
    train_dataset = MMSimData(train_data)
    val_dataset = MMSimData(val_data)
    
    # Use pin_memory and num_workers from args if available
    num_workers = getattr(args, 'num_workers', 0)
    data_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_workers
    )
    val_data_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=num_workers
    )
    
    # Default rank reduction schedule if none provided
    if rank_schedule is None:
        # Reduce rank every rank_reduction_frequency epochs, but start after warmup period
        rank_schedule = list(range(warmup_epochs + rank_reduction_frequency, 
                                 epochs, 
                                 rank_reduction_frequency))
    initial_squares = [None] * len(train_data) # per modality
    initial_losses = [None] * len(train_data) # per modality (for loss-based criteria)
    start_reduction = False
    current_rsquare_per_mod = [None] * len(train_data)
    current_loss_per_mod = [None] * len(train_data)  # for loss-based criteria
    bottom_reached = False
    space_sims = None
    break_counter = 0
    
    # Train the model
    train_losses = []
    val_losses = []
    r_squares = []
    min_ranks = [layer.active_dims for layer in model.adaptive_layers]
    best_loss = float('inf')
    
    # Initialize loss scaling factors for dynamic loss balancing
    loss_scales = torch.ones(len(train_data), device=device)
    #loss_scales[1] = 0.1
    initial_losses = torch.zeros(len(train_data), device=device)
    loss_history = {f'mod_{i}_loss': [] for i in range(len(train_data))}
    
    # Initialize loss balancer for reconstruction losses
    if recon_loss_balancing:
        modality_loss_emas = [None] * len(train_data)
        ema_decay = 0.9

    patience_counter = 0
    pbar = tqdm.tqdm(range(model.epoch, epochs))
    for epoch in pbar:
        # Training phase
        model.train()
        train_loss = 0.0
        val_loss = 0.0
        total_ortho_loss = 0.0
        per_modality_losses = [0.0] * len(train_data)
        
        for batch_idx, x in enumerate(data_loader):
            ### plotting test
            # Store last batch for plotting
            last_batch_data = [x_m.clone() for x_m in x]
            # Get labels if they exist in the dataset
            last_batch_labels = None
            ###
            
            loss = torch.tensor(0.0, device=device)
            total_loss = torch.tensor(0.0, device=device)
            x = [x_m.to(device, non_blocking=True) for x_m in x]
            
            # Forward pass
            x_hat, h_list = model(x)
            
            ortho_loss = torch.tensor(0.0, device=device)
            total_ortho_loss += ortho_loss.item()

            # Calculate separate losses for each modality
            modality_losses = []
            
            # Calculate per-modality MSE losses
            for i, (x_m, x_hat_m) in enumerate(zip(x, x_hat)):
                # For CNN, use MSE loss directly on 4D tensors
                m_loss = F.mse_loss(x_hat_m, x_m, reduction='mean')
                
                # Check for NaN 
                if torch.isnan(m_loss):
                    if verbose:
                        print(f"Warning: NaN loss detected for modality {i}")
                    m_loss = torch.tensor(0.0, device=device)
                
                modality_losses.append(m_loss)
                per_modality_losses[i] += m_loss.item()
            
            # Apply reconstruction loss balancing if enabled
            if recon_loss_balancing:
                # Update exponential moving averages for each modality
                for i, m_loss in enumerate(modality_losses):
                    if modality_loss_emas[i] is None:
                        modality_loss_emas[i] = m_loss.item()
                    else:
                        modality_loss_emas[i] = ema_decay * modality_loss_emas[i] + (1 - ema_decay) * m_loss.item()
                
                # Calculate balanced loss using the minimum EMA as reference
                min_ema = min(ema for ema in modality_loss_emas if ema is not None and ema > 0)
                for i, m_loss in enumerate(modality_losses):
                    if modality_loss_emas[i] > 0:
                        balance_scale = min_ema / modality_loss_emas[i]
                        loss += balance_scale * m_loss
                    else:
                        loss += m_loss
            else:
                # Standard loss computation without balancing
                for i, m_loss in enumerate(modality_losses):
                    loss += loss_scales[i] * m_loss
            
            total_loss += loss

            # Backward pass and optimize (normal training)
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            train_loss += loss.item()
        
        # Average losses
        train_loss /= len(data_loader)
        if start_reduction and include_ortholoss:
            total_ortho_loss /= len(data_loader)
        per_modality_losses = [loss / len(data_loader) for loss in per_modality_losses]
        train_losses.append(train_loss)
        
        # Store per-modality losses in history
        for i, loss in enumerate(per_modality_losses):
            loss_history[f'mod_{i}_loss'].append(loss)
        
        # Validation phase with similar safeguards
        with torch.no_grad():
            for x_val in val_data_loader:
                x_val = [x_m.to(device, non_blocking=True) for x_m in x_val]
                x_val_hat, _ = model(x_val)
                
                # Calculate validation loss
                val_batch_loss = 0.0
                for i, (x_m, x_hat_m) in enumerate(zip(x_val, x_val_hat)):
                    m_loss = F.mse_loss(x_hat_m, x_m, reduction='mean')
                    if not torch.isnan(m_loss):
                        val_batch_loss += m_loss.item()
                
                val_loss += val_batch_loss / len(x_val)
                
        val_loss /= len(val_data_loader)
        val_losses.append(val_loss)

        # Step the scheduler once per epoch (if configured)
        if scheduler is not None:
            #try:
            scheduler.step()
            #except Exception:
            #    # scheduler.step() may expect different inputs for some schedulers; ignore failures to remain conservative
            #    pass

        log_dict = {
            'loss': round(train_loss, 4),
            'mod_losses': [round(l, 3) for l in per_modality_losses],
            'ranks': [layer.active_dims for layer in model.adaptive_layers] if hasattr(model, 'adaptive_layers') 
                    else (model.module.adaptive_layers if multi_gpu else []),
            'current_rsquare': [round(current_rsquare_per_mod[i], 3) if current_rsquare_per_mod[i] is not None else 'N/A' for i in range(len(train_data))],
            'patience': patience_counter,
        }
        if recon_loss_balancing and all(ema is not None for ema in modality_loss_emas):
            # Show the balance scales for reconstruction losses
            min_ema = min(ema for ema in modality_loss_emas if ema > 0)
            balance_scales = [round(min_ema / ema, 3) if ema > 0 else 1.0 for ema in modality_loss_emas]
            log_dict.update({'balance_scales': balance_scales})
        pbar.set_postfix(log_dict)
        
        # Update best loss
        if train_loss < best_loss:
            best_loss = train_loss
            if reduce_on_best_loss in ['true', 'stagnation']:
                patience_counter = 0  # Reset patience counter
        else:
            if reduce_on_best_loss in ['true', 'stagnation']:
                patience_counter += 1

        if (start_reduction is False) and (epoch == model.epoch + patience): # giving the optimizer a start
            rank_history = {
                'total_rank':[model.get_total_rank() if hasattr(model, 'get_total_rank') else (model.module.get_total_rank() if multi_gpu else 0)],
                'ranks':[', '.join(str(layer.active_dims) for layer in model.adaptive_layers)],
                'epoch':[model.epoch],
                'loss':[train_losses[-1]],
                'val_loss':[val_losses[-1]]
            }
            
            #break
            start_reduction = True  # Start rank reduction after early stopping
            break_counter = 0 # start with no breaks (only used when increasing layers)

            #with torch.no_grad():
            #    val_data_tensors = [val_data.data[0].to(device), torch.mean(val_data.data[1], dim=1).to(device)]
            #    #encoded_per_modality = model.encode_modalities([val_data.data[i].to(device) for i in range(len(data))])
            #    encoded_per_modality = model.encode_modalities(val_data_tensors)
            #    #encoded_per_space_shared, encoded_per_space_specific = model.encode([val_data.data[i].to(device) for i in range(len(data))])
            #    encoded_per_space_shared, encoded_per_space_specific = model.encode(val_data_tensors)
            #    encoded_per_space = [encoded_per_space_shared] + list(encoded_per_space_specific)
            min_rsquares = []
            #if mask is not None:
            #    start_idx = 0
            #    for j, x_m in enumerate(x_val):
            #        end_idx = start_idx + x_m.shape[1]
            #        temp_mask = mask[:, start_idx]
            #        # expand it to match the encoded shape
            #        temp_mask = temp_mask.unsqueeze(1).expand(-1, encoded_per_modality[j].shape[1])
            #        modality_masks_latent.append(temp_mask)
            #        modality_masks_data.append(mask[:, start_idx:end_idx])
            #        modality_masks_space.append(mask[:, start_idx])
            #        start_idx = end_idx
            
            # Direct reconstruction R² approach (original behavior)
            #val_data_list = [val_data.data[i] for i in range(len(data))]
            #val_data_list = [val_data.data[0].to(device), torch.mean(val_data.data[1], dim=1).to(device)]
            # Build a small evaluation subset from the first 10% of the training samples (not from val_data)
            n_sub = int(0.1 * train_data[0].shape[0])
            
            # Both modalities are already in proper tensor format (4D for images)
            val_data_list = [train_data[0][:n_sub].to(device), train_data[1][:n_sub].to(device)]

            if verbose:
                print(f"   Debug: val_data_list[0] shape (images) = {val_data_list[0].shape}")
                print(f"   Debug: val_data_list[1] shape (audio)  = {val_data_list[1].shape}")

            # Also query model recon shapes once for debugging (no side effects)
            if verbose:
                try:
                    with torch.no_grad():
                        recon_debug, _ = model([val_data_list[0], val_data_list[1]])
                        recon_shapes = [r.shape for r in recon_debug]
                        print(f"   Debug: model reconstructions shapes = {recon_shapes}")
                except Exception as e:
                    print(f"   Debug: could not run model for recon shape check: {e}")
            if decision_metric == 'ExVarScore':
                direct_r_squared_values = compute_direct_explained_variance(model, val_data_list, device, multi_gpu, verbose=verbose)
            else:  # Default to R2
                direct_r_squared_values = compute_direct_r_squared(model, val_data_list, device, multi_gpu, verbose=verbose)
            
            for i, r_squared_val in enumerate(direct_r_squared_values):
                initial_squares[i] = r_squared_val
                
                # Calculate threshold based on threshold_type
                if threshold_type == 'relative':
                    min_rsquares.append(r_squared_val * r_square_threshold)
                elif threshold_type == 'absolute':
                    min_rsquares.append(r_squared_val - r_square_threshold)
                else:
                    raise ValueError(f"threshold_type must be 'relative' or 'absolute', got {threshold_type}")
                    
                current_rsquare_per_mod[i] = r_squared_val
                rank_history[f'rsquare {i}'] = [r_squared_val]
            max_rsquares = initial_squares.copy()
                
            if verbose:
                #print(f"Initial R-squared values: {[rank_history[f'rsquare {i}'] for i in range(len(encoded_per_modality))]}, setting {threshold_type} thresholds to {min_rsquares}")
                print(f"Initial R-squared values: {[rank_history[f'rsquare {i}'] for i in range(len(current_rsquare_per_mod))]}, setting {threshold_type} thresholds to {min_rsquares}")
            #print(f"Initial R-squared values: {[rank_history[f'rsquare {i}'] for i in range(len(encoded_per_modality))]}, setting {threshold_type} thresholds to {min_rsquares}")
            
        # Apply rank reduction at scheduled epochs, respecting warmup period
        if (epoch in rank_schedule) & (start_reduction) & (break_counter == 0):
            if (reduce_on_best_loss == 'rsquare') & (start_reduction):
                ###
                # get the r_square values per modality
                ###
                #with torch.no_grad():
                #    val_data_tensors = [val_data.data[0].to(device), torch.mean(val_data.data[1], dim=1).to(device)]
                #    #encoded_per_modality = model.encode_modalities([d[n_samples_train:].to(device) for d in data])
                #    #encoded_per_modality = model.encode_modalities([val_data.data[i].to(device) for i in range(len(data))])
                #    encoded_per_modality = model.encode_modalities(val_data_tensors)
                #    if not compute_jacobian:
                #        #encoded_per_space_shared, encoded_per_space_specific = model.encode([val_data.data[i].to(device) for i in range(len(data))])
                #        encoded_per_space_shared, encoded_per_space_specific = model.encode(val_data_tensors)
                #        encoded_per_space = [encoded_per_space_shared] + list(encoded_per_space_specific)
                #    else:
                #        encoded_per_space, contractive_losses = model.encode([val_data.data[i].to(device) for i in range(len(data))], compute_jacobian=compute_jacobian)
                current_rsquares = []
                modalities_to_reduce = []
                modalities_to_increase = []
                #if mask is not None:
                #    start_idx = 0
                #    for j, x_m in enumerate(x_val):
                #        end_idx = start_idx + x_m.shape[1]
                #        #modality_masks.append(mask[:, start_idx:end_idx])
                #        # expand it to match the encoded shape
                #        temp_mask = mask[:, start_idx]
                #        temp_mask = temp_mask.unsqueeze(1).expand(-1, encoded_per_modality[j].shape[1])
                #        modality_masks_latent.append(temp_mask)
                #        modality_masks_data.append(mask[:, start_idx:end_idx])
                #        modality_masks_space.append(mask[:, start_idx])
                #        start_idx = end_idx
                
                # Direct reconstruction R² approach
                n_sub = int(0.1 * train_data[0].shape[0])
                val_data_list = [train_data[0][:n_sub].to(device), train_data[1][:n_sub].to(device)]
                if decision_metric == 'ExVarScore':
                    direct_r_squared_values = compute_direct_explained_variance(model, val_data_list, device, multi_gpu)
                else:  # Default to R2
                    direct_r_squared_values = compute_direct_r_squared(model, val_data_list, device, multi_gpu)
                
                for i, r_squared_val in enumerate(direct_r_squared_values):
                    current_rsquares.append(r_squared_val)
                    current_rsquare_per_mod[i] = r_squared_val
                        
                r_squares.append(current_rsquares)
                
                # Update min_rsquares based on maximum observed (allows threshold to follow best performance)
                #max_rquares = [max(r_squares, key=lambda x: x[i])[i] for i in range(len(current_rsquare_per_mod))] if len(r_squares) > 0 else initial_squares
                update_max = False
                for i, r in enumerate(current_rsquares):
                    if r > max_rsquares[i]:
                        max_rsquares[i] = r
                        update_max = True
                if update_max:
                    if threshold_type == 'relative':
                        min_rsquares = [r * r_square_threshold for r in max_rsquares]
                    elif threshold_type == 'absolute':
                        # For R²: subtract threshold; for loss: add threshold
                        if compressibility_type == 'direct' and reduction_criterion in ['train_loss', 'val_loss']:
                            min_rsquares = [r + r_square_threshold for r in max_rsquares]
                        else:
                            min_rsquares = [r - r_square_threshold for r in max_rsquares]
                    if verbose:
                        print(f"Updated threshold {min_rsquares} based on new max {max_rsquares}")

                ###
                # determine what modalities to reduce or increase
                ###
                if (len(r_squares) >= min(10, int(patience/2))) and patience_counter >= min(10, int(patience/2)):
                    for i in range(len(current_rsquare_per_mod)):
                        i_rsquares = [r[i] for r in r_squares[-min(10, int(patience/2)):]]
                        
                        # Handle different comparison logic for loss vs R²
                        if compressibility_type == 'direct' and reduction_criterion in ['train_loss', 'val_loss']:
                            # For loss: lower is better, so we reduce if loss is high (above threshold)
                            if all(r > min_rsquares[i] for r in i_rsquares) and not bottom_reached:
                                modalities_to_increase.append(i)
                            elif current_rsquare_per_mod[i] < min_rsquares[i]:  # loss below threshold
                                modalities_to_reduce.append(i)
                        else:
                            # For R²: higher is better (original logic)
                            if all(r < min_rsquares[i] for r in i_rsquares) and not bottom_reached:
                                modalities_to_increase.append(i)
                            elif current_rsquare_per_mod[i] > min_rsquares[i]:
                                modalities_to_reduce.append(i)
                elif (len(r_squares) >= 1):# and (patience_counter >= 1):
                    for i in range(len(current_rsquare_per_mod)):
                        if compressibility_type == 'direct' and reduction_criterion in ['train_loss', 'val_loss']:
                            # For loss: reduce if loss is below threshold (good performance)
                            if current_rsquare_per_mod[i] < min_rsquares[i]:
                                modalities_to_reduce.append(i)
                            elif current_rsquare_per_mod[i] > min_rsquares[i] and not bottom_reached:
                                modalities_to_increase.append(i)
                        else:
                            # For R²: reduce if R² is above threshold (original logic)
                            if current_rsquare_per_mod[i] > min_rsquares[i]:
                                modalities_to_reduce.append(i)
                            elif current_rsquare_per_mod[i] < min_rsquares[i] and not bottom_reached:
                                modalities_to_increase.append(i)

                # if all modalities can be reduced, we set min and max ranks
                if len(modalities_to_reduce) == len(current_rsquare_per_mod):
                    current_ranks = [layer.active_dims for layer in model.adaptive_layers]
                    for i, cr in enumerate(current_ranks):
                        if cr <= min_ranks[i]:
                            min_ranks[i] = cr
                            # if we are still above the thresholds, the ranks maximum should not be larger than the sum of current ranks
                        #model.adaptive_layers[i].max_rank = min(sum(current_ranks), model.adaptive_layers[i].max_rank)
                        model.adaptive_layers[i].max_rank = min(sum(current_ranks), max(int(1.5*current_ranks[i]), current_ranks[i]+1), model.adaptive_layers[i].max_rank)
                    print(f"Adjusting maximum ranks to {[layer.max_rank for layer in model.adaptive_layers]}")
                if len(modalities_to_increase) == len(current_rsquare_per_mod):
                    # set minima
                    current_ranks = [layer.active_dims for layer in model.adaptive_layers]
                    for i, cr in enumerate(current_ranks):
                        if cr <= min_ranks[i]:
                            min_ranks[i] = cr
                        # if we are increasing all ranks, we can also increase the maximum ranks
                        model.adaptive_layers[i].min_rank = min_ranks[i]
                    #bottom_reached = True
                    print(f"Adjusting minimum ranks to {[layer.min_rank for layer in model.adaptive_layers]}")
                ###
                # set the patience counters and layers to reduce or increase
                ###
                layers_to_reduce = []
                layers_to_increase = []
                if (len(modalities_to_reduce) == 0) and (len(modalities_to_increase) == 0):
                    #patience_counter += 1
                    pass
                elif (len(modalities_to_reduce) > 0) and (len(modalities_to_increase) > 0):
                    # no increasing yet, but no decreasing the shared either
                    layers_to_reduce = [i + 1 for i in modalities_to_reduce]
                    layers_to_increase = [0] + [i + 1 for i in modalities_to_increase]
                    # set the min for the modality to be increased to current rank
                    model.adaptive_layers[0].min_rank = model.adaptive_layers[0].active_dims + 1
                    for i in modalities_to_increase:
                        model.adaptive_layers[i + 1].min_rank = model.adaptive_layers[i + 1].active_dims + 1
                    print(f"Adjusting minimum ranks to {[layer.min_rank for layer in model.adaptive_layers]}")
                else:
                    if len(modalities_to_increase) > 0:
                        if len(modalities_to_increase) == len(current_rsquare_per_mod):
                            # if all modalities are below the threshold, increase ranks of all layers
                            layers_to_increase = [i for i in range(len(model.adaptive_layers))]
                        else:
                            layers_to_increase = [0] + [i + 1 for i in modalities_to_increase]
                            for i in modalities_to_increase:
                                model.adaptive_layers[i + 1].min_rank = model.adaptive_layers[i + 1].active_dims + 1
                            model.adaptive_layers[0].min_rank = model.adaptive_layers[0].active_dims + 1
                            print(f"Adjusting minimum ranks to {[layer.min_rank for layer in model.adaptive_layers]}")
                    if len(modalities_to_reduce) > 0:
                        # if all modalities are below the threshold, reduce ranks of all layers
                        if len(modalities_to_reduce) == len(initial_squares):
                            reduce_shared = True
                            if reduce_shared:
                                layers_to_reduce = [0] + [i + 1 for i in modalities_to_reduce]
                            else:
                                layers_to_reduce = [i + 1 for i in modalities_to_reduce]
                        else:
                            # roll a dice whether we also try to reduce the shared layer
                            #if sharedwhenall:
                            #    reduce_shared = False
                            #else:
                            #    reduce_shared = True
                            #if reduce_shared:
                            #    layers_to_reduce = [0] + [i + 1 for i in modalities_to_reduce]
                            #else:
                                layers_to_reduce = [i + 1 for i in modalities_to_reduce]
                if verbose:
                    if compressibility_type == 'direct' and reduction_criterion in ['train_loss', 'val_loss']:
                        print(f"{reduction_criterion} values: {current_rsquares}, reducing rank for modalities {modalities_to_reduce}, layers {layers_to_reduce}, increasing rank for modalities {modalities_to_increase}, layers {layers_to_increase}")
                    else:
                        print(f"R-squared values: {current_rsquares}, reducing rank for modalities {modalities_to_reduce}, layers {layers_to_reduce}, increasing rank for modalities {modalities_to_increase}, layers {layers_to_increase}")
                if compute_jacobian:
                    valid_contractive_losses = [contractive_losses[i] for i in layers_to_reduce]
                    max_contractive_loss = max(valid_contractive_losses) if len(valid_contractive_losses) > 0 else None
                    #max_contractive_loss = min(valid_contractive_losses) if len(valid_contractive_losses) > 0 else None
                    if max_contractive_loss is not None:
                        layers_to_reduce = [i for i in layers_to_reduce if contractive_losses[i] == max_contractive_loss]
                    else:
                        layers_to_reduce = []
                    if verbose:
                        print(f"Contractive losses: {contractive_losses}, reducing rank for layers {layers_to_reduce} with max loss {max_contractive_loss}")
                
            
            #if should_reduce:
            any_changes_made = False
            #changes_made = False
            if len(layers_to_reduce) > 0:
                # Apply rank reduction
                if multi_gpu:
                    changes_made = model.module.reduce_rank(reduction_ratio=0.9, threshold=rank_reduction_threshold, layer_ids=layers_to_reduce)
                else:
                    changes_made = model.reduce_rank(reduction_ratio=0.9, threshold=rank_reduction_threshold, layer_ids=layers_to_reduce)
                if changes_made:
                    any_changes_made = True
            if len(layers_to_increase) > 0:
                # Apply rank increase
                #print(f"Increasing rank for layer {layers_to_increase}")
                if multi_gpu:
                    changes_made = model.module.increase_rank(increase_ratio=1.1, layer_ids=layers_to_increase)
                else:
                    changes_made = model.increase_rank(increase_ratio=1.1, layer_ids=layers_to_increase)
                if changes_made:
                    any_changes_made = True
                    break_counter = patience # give model more time to re-learn the added dimensions
            #else:
            #    changes_made = False
            
            if any_changes_made:
                patience_counter = 0  # Reset patience counter if rank was changed
            else:
                patience_counter += 1

            # Get new rank but don't print separate message
            total_rank_after = model.module.get_total_rank() if multi_gpu else model.get_total_rank()
                
            # Store current rank in history
            rank_history['total_rank'].append(total_rank_after)
            rank_history['ranks'].append(', '.join(str(layer.active_dims) for layer in model.adaptive_layers))
            rank_history['epoch'].append(epoch)
            #for i in range(len(encoded_per_modality)):
            for i in range(len(current_rsquare_per_mod)):
                if reduce_on_best_loss == 'rsquare':
                    rank_history[f'rsquare {i}'].append(current_rsquares[i])
            rank_history['loss'].append(train_loss)
            rank_history['val_loss'].append(val_loss)
            
            # Save checkpoint after rank change to preserve progress
            if any_changes_made and model_name:
                checkpoint_path = f"{project_config.RESULTS_DIR}/so2sat/{model_name}_checkpoint_intermediate.pt"
                os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
                if multi_gpu:
                    torch.save(model.module.state_dict(), checkpoint_path)
                else:
                    torch.save(model.state_dict(), checkpoint_path)
                
                # Save training history (losses and rank_history)
                history_path = checkpoint_path.replace('.pt', '_history.npz')
                np.savez(history_path,
                         train_losses=np.array(train_losses),
                         val_losses=np.array(val_losses),
                         rank_history=rank_history,
                         loss_history=loss_history)
                
                if verbose:
                    print(f"   Saved checkpoint to {checkpoint_path} and history to {history_path}")

            # also get mutual information between all spaces
            #valid_spaces = []
            #for encoded in encoded_per_space:
            #    valid_spaces_temp = []
            #    for i in range(len(encoded_per_modality)):
            #        #normalized_encoded = (encoded - encoded.min() + 1e-9) / (encoded.max() - encoded.min() + 1e-9)
            #        if mask is not None:
            #            temp_mask = modality_masks_space[i]
            #            #valid_spaces_temp.append(normalized_encoded[mask])
            #            valid_spaces_temp.append(encoded[temp_mask])
            #        else:
            #            #valid_spaces_temp.append(normalized_encoded)
            #            valid_spaces_temp.append(encoded)
            #    # stack the valid spaces
            #    valid_spaces_temp = torch.vstack(valid_spaces_temp)
            #    valid_spaces.append(valid_spaces_temp)
        else:
            if (epoch in rank_schedule) & (start_reduction) & (break_counter > 0):
                break_counter -= 1
        
        # Get normalized weights for display
        if multi_gpu:
            weights = model.module.modality_weights
        else:
            weights = model.modality_weights
            
        pos_weights = F.softplus(weights)
        norm_weights = (pos_weights / (pos_weights.sum() + 1e-8)).detach().cpu().numpy().round(3)

        
        # early stopping but conditioned on rank reduction
        #if (epoch > early_stopping) & (min(val_losses[-early_stopping:]) > min(val_losses)) & (start_reduction is True) & (patience_counter >= patience):
        if (epoch > early_stopping) & (start_reduction is True) & (patience_counter >= patience):
            if verbose:
                print(f"Early stopping at epoch {epoch} with best loss {best_loss} and ranks {rank_history['ranks'][-1]}")
            break
    
    # Calculate latent representations in batches (only for training data)
    #'''
    n_samples = train_data[0].shape[0]
    final_ranks = [layer.active_dims for layer in model.adaptive_layers]
    reps = [torch.empty((n_samples, final_ranks[i]), device=device) for i in range(len(final_ranks))]
    model.eval()
    with torch.no_grad():
        for i in range(0, n_samples, batch_size):
            end_idx = min(i + batch_size, n_samples)
            x_batch = [train_data[j][i:end_idx].to(device) for j in range(len(train_data))]
            
            # If using DataParallel, need to access module directly or handle the encoding differently
            if multi_gpu:
                batch_reps = model.module.encode(x_batch)#.cpu()
            else:
                batch_reps = model.encode(x_batch)#.cpu()
            batch_rep_list = [batch_reps[0]] + [batch_reps[1][j] for j in range(len(batch_reps[1]))]
                
            # No need to convert dtype
            for j in range(len(reps)):
                reps[j][i:end_idx,:] = batch_rep_list[j][:,:final_ranks[j]].cpu()
            
            # Free memory
            del x_batch, batch_reps
            torch.cuda.empty_cache() if torch.cuda.is_available() else None
    
    # save the model
    if model_name:
        model_path = f"./03_results/models/{model_name}.pt"
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        if multi_gpu:
            torch.save(model.module.state_dict(), model_path)
        else:
            torch.save(model.state_dict(), model_path)
        if verbose:
            print(f"Saved trained model to {model_path}")
        
        # Save final reconstruction plots for both modalities
        try:
            model.eval()
            with torch.no_grad():
                # Get 8 train and 8 test samples
                n_plot = 8
                
                # Train samples
                train_idx = torch.randperm(len(train_dataset))[:n_plot]
                train_samples = [train_dataset.data[i][train_idx].to(device) for i in range(len(train_data))]
                
                # Test samples
                test_idx = torch.randperm(len(val_dataset))[:n_plot]
                test_samples = [val_dataset.data[i][test_idx].to(device) for i in range(len(train_data))]
                
                # Get reconstructions
                if multi_gpu and hasattr(model, 'module'):
                    train_recon, _ = model.module(train_samples)
                    test_recon, _ = model.module(test_samples)
                else:
                    train_recon, _ = model(train_samples)
                    test_recon, _ = model(test_samples)
                
                # Plot modality 0 (radar - channel 4 only)
                mod0_path = model_path.replace('./03_results/models/', './03_results/train_plots/').replace('.pt', '_mod0_recon_final.png')
                os.makedirs(os.path.dirname(mod0_path), exist_ok=True)
                fig, axes = plt.subplots(4, n_plot, figsize=(n_plot * 1.5, 6))
                for i in range(n_plot):
                    # Train original (channel 4)
                    axes[0, i].imshow(train_samples[0][i, 4].cpu().numpy(), cmap='gray')
                    axes[0, i].axis('off')
                    if i == 0:
                        axes[0, i].set_title('Train', fontsize=10, loc='left')
                    # Train reconstruction (channel 4)
                    axes[1, i].imshow(train_recon[0][i, 4].cpu().numpy(), cmap='gray')
                    axes[1, i].axis('off')
                    if i == 0:
                        axes[1, i].set_title('Train Recon', fontsize=10, loc='left')
                    # Test original (channel 4)
                    axes[2, i].imshow(test_samples[0][i, 4].cpu().numpy(), cmap='gray')
                    axes[2, i].axis('off')
                    if i == 0:
                        axes[2, i].set_title('Test', fontsize=10, loc='left')
                    # Test reconstruction (channel 4)
                    axes[3, i].imshow(test_recon[0][i, 4].cpu().numpy(), cmap='gray')
                    axes[3, i].axis('off')
                    if i == 0:
                        axes[3, i].set_title('Test Recon', fontsize=10, loc='left')
                plt.tight_layout()
                plt.savefig(mod0_path, dpi=150, bbox_inches='tight')
                plt.close()
                print(f"Saved final radar reconstruction plot to {mod0_path}")
                
                # Plot modality 1 (optical - RGB)
                mod1_path = model_path.replace('./03_results/models/', './03_results/train_plots/').replace('.pt', '_mod1_recon_final.png')
                os.makedirs(os.path.dirname(mod1_path), exist_ok=True)
                fig, axes = plt.subplots(4, n_plot, figsize=(n_plot * 1.5, 6))
                for i in range(n_plot):
                    # Train original (RGB)
                    img_train = train_samples[1][i].permute(1, 2, 0).cpu().numpy()
                    # Normalize to [0, 1] for display
                    img_train = (img_train - img_train.min()) / (img_train.max() - img_train.min() + 1e-8)
                    axes[0, i].imshow(img_train)
                    axes[0, i].axis('off')
                    if i == 0:
                        axes[0, i].set_title('Train', fontsize=10, loc='left')
                    # Train reconstruction (RGB)
                    img_train_recon = train_recon[1][i].permute(1, 2, 0).cpu().numpy()
                    img_train_recon = (img_train_recon - img_train_recon.min()) / (img_train_recon.max() - img_train_recon.min() + 1e-8)
                    axes[1, i].imshow(img_train_recon)
                    axes[1, i].axis('off')
                    if i == 0:
                        axes[1, i].set_title('Train Recon', fontsize=10, loc='left')
                    # Test original (RGB)
                    img_test = test_samples[1][i].permute(1, 2, 0).cpu().numpy()
                    img_test = (img_test - img_test.min()) / (img_test.max() - img_test.min() + 1e-8)
                    axes[2, i].imshow(img_test)
                    axes[2, i].axis('off')
                    if i == 0:
                        axes[2, i].set_title('Test', fontsize=10, loc='left')
                    # Test reconstruction (RGB)
                    img_test_recon = test_recon[1][i].permute(1, 2, 0).cpu().numpy()
                    img_test_recon = (img_test_recon - img_test_recon.min()) / (img_test_recon.max() - img_test_recon.min() + 1e-8)
                    axes[3, i].imshow(img_test_recon)
                    axes[3, i].axis('off')
                    if i == 0:
                        axes[3, i].set_title('Test Recon', fontsize=10, loc='left')
                plt.tight_layout()
                plt.savefig(mod1_path, dpi=150, bbox_inches='tight')
                plt.close()
                print(f"Saved final optical reconstruction plot to {mod1_path}")
        except Exception as e:
            print(f"Warning: Could not save final reconstruction plots: {e}")
    
    try:
        avg_train_loss = np.mean(train_losses[-5:])
    except:
        avg_train_loss = np.mean(train_losses)
    try:
        last_rsquare = r_squares[-1]
    except:
        last_rsquare = [None]

    return model, reps, avg_train_loss, last_rsquare, rank_history, [train_losses, val_losses]

def plot_image_reconstruction(original_imgs, recon_imgs, n=8, out_path=None):
    """
    Plot original images (top row) above their reconstructions (bottom row).

    Parameters:
    - original_imgs: Array of flattened images; every row is reshaped to
      28x28 for display, so each image must hold 784 values.
    - recon_imgs: Array of reconstructions with the same layout.
    - n: Maximum number of image pairs to draw. Clamped to the number of
      pairs actually available in both inputs.
    - out_path: If truthy, the figure is saved to this path at 150 dpi.

    Raises:
    - ValueError: If there is not a single image pair to plot.
    """
    # Never index past the end of either input.
    n = min(n, len(original_imgs), len(recon_imgs))
    if n < 1:
        raise ValueError("plot_image_reconstruction needs at least one image pair to plot")

    orig = original_imgs[:n]
    recon = recon_imgs[:n]

    # squeeze=False keeps `axes` two-dimensional even for a single column,
    # so axes[row, col] indexing is valid for every n >= 1.
    fig, axes = plt.subplots(2, n, figsize=(n * 1.5, 3), squeeze=False)
    try:
        for col in range(n):
            for row, imgs in enumerate((orig, recon)):
                ax = axes[row, col]
                ax.imshow(imgs[col].reshape(28, 28), cmap='gray')
                ax.axis('off')
        fig.tight_layout()
        if out_path:
            fig.savefig(out_path, dpi=150)
    finally:
        # Release the figure even if drawing or saving failed.
        plt.close(fig)

def plot_modal_image_reconstruction(original_imgs, recon_imgs, image_shape=(28,28), n=8, out_path=None):
    """
    Plot original and reconstructed modality images for arbitrary image shapes.

    Parameters:
    - original_imgs: Array of images shaped (N, H*W), (N, H, W) or (N, 1, H, W).
    - recon_imgs: Array of reconstructions in any of the same layouts.
    - image_shape: Tuple (H, W); only used to un-flatten (N, H*W) inputs.
    - n: Maximum number of image pairs to draw. Clamped to the number of
      pairs actually available in both inputs.
    - out_path: If truthy, the figure is saved to this path at 150 dpi.

    Raises:
    - ValueError: If there is no image pair to plot, or if an input has a
      layout other than the three listed above.
    """
    H, W = image_shape

    # Never index past the end of either input.
    n = min(n, len(original_imgs), len(recon_imgs))
    if n < 1:
        raise ValueError("plot_modal_image_reconstruction needs at least one image pair to plot")

    orig = original_imgs[:n]
    recon = recon_imgs[:n]

    def ensure_hw(arr):
        # Bring any supported layout to (N, H, W).
        if arr.ndim == 4 and arr.shape[1] == 1:
            # single-channel images: drop the channel axis
            return arr[:, 0]
        if arr.ndim == 3:
            return arr
        if arr.ndim == 2:
            # assume flattened
            return arr.reshape(-1, H, W)
        raise ValueError('Unsupported array shape for image reconstruction plotting: ' + str(arr.shape))

    orig_hw = ensure_hw(orig)
    recon_hw = ensure_hw(recon)

    # squeeze=False keeps `axes` two-dimensional even for a single column,
    # so axes[row, col] indexing is valid for every n >= 1.
    fig, axes = plt.subplots(2, n, figsize=(n * 2, 3), squeeze=False)
    try:
        for col in range(n):
            for row, imgs in enumerate((orig_hw, recon_hw)):
                ax = axes[row, col]
                ax.imshow(imgs[col], cmap='gray')
                ax.axis('off')
        fig.tight_layout()
        if out_path:
            fig.savefig(out_path, dpi=150)
    finally:
        # Release the figure even if drawing or saving failed.
        plt.close(fig)

def plot_audio_scatter(original_audio, recon_audio, out_path=None):
    """
    Create a 2-D scatter plot of original vs reconstructed audio features.

    Inputs with more than two feature columns are projected onto the first
    two principal components; the PCA is fitted on the originals only and
    then applied to the reconstructions. Inputs with at most two columns are
    plotted directly, zero-padded to two columns if necessary.

    Parameters:
    - original_audio: numpy array of shape (N, D).
    - recon_audio: numpy array of shape (N, D).
    - out_path: If truthy, the figure is saved to this path at 150 dpi.
    """
    # Local import: scikit-learn is only loaded when this function is called.
    from sklearn.decomposition import PCA

    def _as_two_columns(features):
        # Keep the first two feature columns; zero-pad when fewer exist so
        # the scatter calls below can always index columns 0 and 1.
        trimmed = features[:, :2]
        missing = 2 - trimmed.shape[1]
        if missing > 0:
            trimmed = np.pad(trimmed, ((0, 0), (0, missing)))
        return trimmed

    # Reduce to two dims using PCA if dimensionality > 2
    if original_audio.shape[1] > 2:
        pca = PCA(n_components=2)
        orig_2 = pca.fit_transform(original_audio)
        recon_2 = pca.transform(recon_audio)
    else:
        orig_2 = _as_two_columns(original_audio)
        recon_2 = _as_two_columns(recon_audio)

    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    try:
        ax.scatter(orig_2[:, 0], orig_2[:, 1], s=4, alpha=0.6, label='original')
        ax.scatter(recon_2[:, 0], recon_2[:, 1], s=4, alpha=0.6, label='recon')
        ax.legend()
        ax.set_title('Audio original vs recon (PCA 2D)')
        if out_path:
            fig.savefig(out_path, dpi=150)
    finally:
        # Release the figure even if plotting or saving failed.
        plt.close(fig)