import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, random_split, Dataset
import math
import numpy as np
import tqdm
import gc
from src.visualization.logging import plot_training_state, create_training_movie
from src.models.larrp_unimodal import AdaptiveRankReducedLinear
import matplotlib.pyplot as plt
import pandas as pd

class AVMNISTDataset(Dataset):
    """Paired image/audio dataset for AV-MNIST with light preprocessing.

    Samples are returned as ``[image, audio]``.  When ``full_spectrum`` is
    False the audio item is averaged over its first axis; otherwise the audio
    item is returned untouched.
    """

    def __init__(self, data_tuple, full_spectrum=False):
        # Expected layout: (images, audio, labels). Labels are not stored.
        raw_images, raw_audio = data_tuple[0], data_tuple[1]
        self.images = torch.FloatTensor(raw_images)
        self.audio = torch.FloatTensor(raw_audio)
        self.data = [self.images, self.audio]

        # Bookkeeping attributes read by code outside this class.
        self.n_modalities = 2
        self.n_samples = self.images.shape[0]
        self.n_features = [self.images.shape[1], self.audio.shape[1]]
        self.total_features = sum(self.n_features)
        self.mask = None
        self.full_spectrum = full_spectrum

    def __len__(self):
        """Number of samples (first dimension of the image tensor)."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``[image_feature, audio_feature]`` for sample ``idx``."""
        # The image is handed over as stored; any reshaping for a CNN is left
        # to the model's forward pass.
        image_feature = self.images[idx]

        # Averaged mode collapses the first axis of the per-sample audio
        # (e.g. a (112, 112) spectrogram becomes a (112,) vector).
        audio_item = self.audio[idx]
        audio_feature = audio_item if self.full_spectrum else audio_item.mean(dim=0)

        return [image_feature, audio_feature]

class AdaptiveRankReducedAE_AvMnist(nn.Module):
    """
    Multimodal autoencoder for AV-MNIST with adaptive-rank latent layers.

    Each modality has its own encoder/decoder (a CNN for images; an MLP or a
    CNN for audio, depending on ``full_spectrum``).  The encoder outputs feed
    ``AdaptiveRankReducedLinear`` layers: one shared layer that sees the
    concatenation of all encoder outputs, and one modality-specific layer per
    modality that sees only that modality's encoder output.  Each decoder
    receives the concatenation ``[shared, specific]``.

    Layout of ``self.adaptive_layers``: index 0 is the shared layer and index
    ``1 + m`` is the specific layer of modality ``m``.
    """
    def __init__(self, input_dims, latent_dims, depth=2, width=0.5, dropout=0.0,
                 initial_rank_ratio=1.0, min_rank=1, full_spectrum=False):
        """
        Args:
            input_dims (tuple): Dimensions of the input for each modality (e.g., (784, 112)).
                Only ``len(input_dims)`` and ``input_dims[1]`` (averaged-audio MLP) are read.
            latent_dims (tuple): Dimensions for [mod1_specific, mod2_specific, shared] subspaces.
            depth, width, dropout: Accepted for interface compatibility; not used by this class.
            initial_rank_ratio (float): Forwarded to every ``AdaptiveRankReducedLinear``.
            min_rank (int): Forwarded to every ``AdaptiveRankReducedLinear``.
            full_spectrum (bool): Whether to use full spectrum (112x112) audio data instead of averaged (112).
        """
        super().__init__()
        self.full_spectrum = full_spectrum
        n_modalities = len(input_dims)

        # --- 1. Define Specialized Encoders for Each Modality ---
        # One (initially empty) ModuleList per modality.  The registration order
        # of these containers is kept stable so parameter/state_dict ordering
        # does not change.
        self.encoders = nn.ModuleList([nn.ModuleList() for _ in range(n_modalities)])
        self.decoders = nn.ModuleList([nn.ModuleList() for _ in range(n_modalities)])
        self.adaptive_layers = nn.ModuleList()  # Track adaptive rank layers for rank reduction

        # Image Encoder (CNN based on MultiBench MVAE example)
        # Spatial sizes: 28 -> 24 (conv k=5) -> 12 (pool) -> 8 (conv k=5) -> 4 (pool)
        img_latent_dim = latent_dims[0]
        self.channel_size = 16
        self.n_pixel_after_conv = 4 * 4 * self.channel_size * 4  # 4x4 spatial, 64 channels -> 1024
        self.img_encoder = nn.Sequential(
            nn.Conv2d(1, self.channel_size, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(self.channel_size, self.channel_size*4, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Flatten(),
            nn.Linear(self.n_pixel_after_conv, img_latent_dim*4),
            nn.ReLU(),
            nn.Linear(img_latent_dim*4, img_latent_dim),
            nn.Sigmoid()
        )
        # The same module is reachable as ``self.img_encoder`` and ``self.encoders[0][0]``.
        self.encoders[0].append(self.img_encoder)

        # Audio Encoder - different architecture for full spectrum vs averaged audio
        audio_latent_dim = latent_dims[1]
        if full_spectrum:
            # For full spectrum (112x112), use CNN similar to image but with one extra conv layer
            self.audio_channel_size = 16
            # 112 -> 108 (conv k=5) -> 54 (pool) -> 50 (conv k=5) -> 25 (pool) -> 21 (conv k=5) -> 10 (pool)
            self.audio_n_pixel_after_conv = 10 * 10 * (self.audio_channel_size * 4)
            print(f"Audio flattened size after convs: {self.audio_n_pixel_after_conv}")
            self.audio_encoder = nn.Sequential(
                nn.Conv2d(1, self.audio_channel_size, kernel_size=5, padding=0),  # 112x112 -> 108x108
                nn.ReLU(),
                nn.MaxPool2d(2, 2),  # 108x108 -> 54x54
                nn.Conv2d(self.audio_channel_size, self.audio_channel_size*2, kernel_size=5, padding=0),  # 54x54 -> 50x50
                nn.ReLU(),
                nn.MaxPool2d(2, 2),  # 50x50 -> 25x25
                nn.Conv2d(self.audio_channel_size*2, self.audio_channel_size*4, kernel_size=5, padding=0),  # 25x25 -> 21x21
                nn.ReLU(),
                nn.MaxPool2d(2, 2),  # 21x21 -> 10x10 (floor of 21 / 2)
                nn.Flatten(),
                nn.Linear(self.audio_n_pixel_after_conv, audio_latent_dim*8),
                nn.ReLU(),
                nn.Linear(audio_latent_dim*8, audio_latent_dim),
                nn.Sigmoid()
            )
        else:
            # For averaged audio (112), use MLP
            self.audio_encoder = nn.Sequential(
                nn.Linear(input_dims[1], audio_latent_dim*4),
                nn.ReLU(),
                nn.Linear(audio_latent_dim*4, audio_latent_dim),
                nn.Sigmoid()
            )
        self.encoders[1].append(self.audio_encoder)

        # --- 2. Define Layers for Shared and Specific Subspaces ---
        # The shared layer maps the concatenated encoder outputs to the shared
        # space (last entry of latent_dims); each specific layer maps one
        # modality's encoder output onto a space of the same size.
        shared_latent_dim = latent_dims[2]
        shared_layer = AdaptiveRankReducedLinear(
            sum(latent_dims[:n_modalities]), latent_dims[-1],
            initial_rank_ratio=initial_rank_ratio,
            min_rank=min_rank
        )
        self.adaptive_layers.append(shared_layer)
        for i in range(n_modalities):
            specific_layer = AdaptiveRankReducedLinear(
                latent_dims[i], latent_dims[i],
                initial_rank_ratio=initial_rank_ratio,
                min_rank=min_rank
            )
            self.adaptive_layers.append(specific_layer)
        # Learnable per-modality weights, initialised to 1.  Not used inside this
        # class; presumably consumed by the training loop for loss balancing.
        self.modality_weights = nn.Parameter(torch.ones(n_modalities), requires_grad=True)

        # --- 3. Define Specialized Decoders for Each Modality ---
        # ConvTranspose2d output size:
        #   out = (in - 1) * stride - 2 * padding + kernel_size + output_padding
        # Image Decoder (Transposed CNN) takes the shared + image-specific latent vectors.
        # Spatial sizes: 4 -> 12 -> 28
        self.img_decoder = nn.Sequential(
            nn.Linear(shared_latent_dim + img_latent_dim, 2*(shared_latent_dim + img_latent_dim)),
            nn.ReLU(),
            nn.Linear(2*(shared_latent_dim + img_latent_dim), self.n_pixel_after_conv),
            nn.ReLU(),
            nn.Unflatten(1, (self.channel_size*4, 4, 4)),
            nn.ConvTranspose2d(self.channel_size*4, self.channel_size, kernel_size=5, stride=2, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(self.channel_size, 1, kernel_size=5, stride=2, output_padding=1),
            nn.Sigmoid() # To ensure output is between 0 and 1
        )
        self.decoders[0].append(self.img_decoder)

        # Audio Decoder - different architecture for full spectrum vs averaged audio
        if full_spectrum:
            # Transposed CNN; the final Upsample guarantees an exact 112x112 output.
            self.audio_decoder = nn.Sequential(
                nn.Linear(shared_latent_dim + audio_latent_dim, 4*(shared_latent_dim + audio_latent_dim)),
                nn.ReLU(),
                nn.Linear(4*(shared_latent_dim + audio_latent_dim), self.audio_n_pixel_after_conv),
                nn.ReLU(),
                nn.Unflatten(1, (self.audio_channel_size*4, 10, 10)),
                # 10x10 -> 22x22: (10 - 1) * 2 - 2 * 1 + 5 + 1 = 22
                nn.ConvTranspose2d(self.audio_channel_size*4, self.audio_channel_size*2, kernel_size=5, stride=2, padding=1, output_padding=1),
                nn.ReLU(),
                # 22x22 -> 46x46: (22 - 1) * 2 - 2 * 1 + 5 + 1 = 46
                nn.ConvTranspose2d(self.audio_channel_size*2, self.audio_channel_size, kernel_size=5, stride=2, padding=1, output_padding=1),
                nn.ReLU(),
                # 46x46 -> 96x96: (46 - 1) * 2 - 2 * 1 + 7 + 1 = 96, then resized to 112x112
                nn.ConvTranspose2d(self.audio_channel_size, 1, kernel_size=7, stride=2, padding=1, output_padding=1),
                nn.Upsample(size=(112, 112), mode='bilinear', align_corners=False),  # Ensure exact 112x112
                nn.Sigmoid()  # To ensure output is between 0 and 1
            )
        else:
            # For averaged audio (112), use MLP (no output non-linearity)
            self.audio_decoder = nn.Sequential(
                nn.Linear(shared_latent_dim + audio_latent_dim, 2*(shared_latent_dim + audio_latent_dim)),
                nn.ReLU(),
                nn.Linear(2*(shared_latent_dim + audio_latent_dim), input_dims[1])
            )
        self.decoders[1].append(self.audio_decoder)

    def forward(self, x):
        """Encode then decode ``x``; returns ``(reconstructions, latent)``."""
        h = self.encode(x)
        x_hat = self.decode(h)
        return x_hat, h

    def encode(self, x, compute_jacobian=False):
        """
        Encode a list of modality tensors into shared and specific latents.

        Args:
            x (list): ``[image, audio]`` tensors.  Images may be ``(B, 784)``,
                ``(B, 28, 28)`` or ``(B, 1, 28, 28)``.  Audio may be flattened,
                ``(B, 112, 112)`` or ``(B, 1, 112, 112)`` in full-spectrum mode,
                or ``(B, 112)`` / ``(B, 112, K)`` in averaged mode.
            compute_jacobian (bool): If True, also return a contractive-penalty
                estimate per adaptive layer.

        Returns:
            ``(h_shared, [h_specific_0, h_specific_1])`` when
            ``compute_jacobian`` is False; otherwise
            ``([h_shared, h_specific_0, h_specific_1], contractive_losses)``
            where ``contractive_losses`` is a list of Python floats.
        """
        h_concat = []
        for m, x_m in enumerate(x):
            if m == 0:
                # Image: bring the input to (batch, 1, 28, 28)
                if x_m.dim() == 2:
                    x_m = x_m.view(-1, 1, 28, 28)
                elif x_m.dim() == 3:
                    x_m = x_m.unsqueeze(1)
            elif m == 1:
                if self.full_spectrum:
                    # Bring the input to (batch, 1, 112, 112)
                    if x_m.dim() == 3:
                        x_m = x_m.unsqueeze(1)
                    elif x_m.dim() != 4:
                        # flattened (or otherwise shaped) input
                        x_m = x_m.contiguous().view(-1, 1, 112, 112)
                elif x_m.dim() > 2:
                    # Averaged audio: collapse the last dimension.
                    # NOTE(review): AVMNISTDataset averages over the first axis
                    # of each spectrogram, not the last - confirm which is intended.
                    x_m = x_m.mean(dim=-1)
            for layer in self.encoders[m]:
                x_m = layer(x_m)
            h_concat.append(x_m)
        h = torch.cat(h_concat, dim=1)

        if not compute_jacobian:
            h_shared = self.adaptive_layers[0](h)
            specific_outputs = [
                layer(h_concat[i]) for i, layer in enumerate(self.adaptive_layers[1:])
            ]
            return (h_shared, specific_outputs)

        # Jacobian mode: collect activations together with each layer's weights.
        h_split = [self.adaptive_layers[0](h)]
        weights = [self.adaptive_layers[0].get_weights()]
        for i, layer in enumerate(self.adaptive_layers[1:]):
            h_split.append(layer(h_concat[i]))
            weights.append(layer.get_weights())

        contractive_losses = []
        for activation, weight in zip(h_split, weights):
            batch_size = activation.shape[0]
            # Derivative of a ReLU evaluated at the layer output (for a sigmoid
            # it would be h * (1 - h)).
            # NOTE(review): assumes the adaptive layers behave like ReLU units
            # and that get_weights() is laid out (out_features, in_features) - confirm.
            derivative = (activation > 0).float()
            # sum_j ( h'(a)_j^2 * sum_i W_ji^2 ), summed over the batch
            w_squared = torch.sum(weight**2, dim=1)
            d_squared = derivative**2
            contractive_loss_layer = torch.sum(d_squared * w_squared.unsqueeze(0))
            # Average over the batch and store as a plain float (detached).
            contractive_losses.append((contractive_loss_layer / batch_size).detach().cpu().item())
        return h_split, contractive_losses

    def decode(self, h):
        """
        Decode ``(h_shared, h_specific)`` into one reconstruction per modality.

        Returns:
            list: image reconstruction flattened to ``(B, 784)``; audio
            reconstruction as ``(B, 1, 112, 112)`` in full-spectrum mode or as
            produced by the audio decoder otherwise.
        """
        h_shared, h_specific = h
        x_hat = []
        for m, h_m in enumerate(h_specific):
            out = torch.cat([h_shared, h_m], dim=1)
            for layer in self.decoders[m]:
                out = layer(out)
            if m == 0:
                # Image: flatten to 784 dimensions
                out = out.view(-1, 784)
            elif m == 1 and self.full_spectrum and out.dim() == 2:
                # A flattened audio output is restored to (batch, 1, 112, 112);
                # reshape handles non-contiguous tensors without a fallback path.
                out = out.reshape(-1, 1, 112, 112)
            x_hat.append(out)
        return x_hat

    def encode_modalities(self, x):
        """Return, per modality, the concatenation ``[h_shared, h_specific_m]``."""
        h_shared, h_specific = self.encode(x)
        return [torch.cat([h_shared, h_m], dim=1) for h_m in h_specific]

    def reduce_rank(self, reduction_ratio=0.9, threshold=0.01, layer_ids=(), dim=0):
        """
        Reduce the rank of the selected adaptive layers based on singular-value energy.

        Only layers whose index is in ``layer_ids`` are considered; with the
        default empty selection no layer is changed.  For each selected layer
        the new rank is the number of leading singular values whose cumulative
        normalized energy stays below ``1 - threshold`` (never below the
        layer's ``min_rank``).

        Args:
            reduction_ratio (float): Currently unused; kept for interface compatibility.
            threshold (float): Fraction of energy that may be discarded.
            layer_ids (sequence of int): Indices into ``self.adaptive_layers``.
            dim (int): Forwarded to the layer's ``reduce_rank``.

        Returns:
            bool: True if at least one layer was reduced.
        """
        changes_made = False

        for i, layer in enumerate(self.adaptive_layers):
            if i not in layer_ids:
                continue
            # Get singular values
            S = layer.get_rank_reduction_info()

            if len(S) <= layer.min_rank:
                continue  # Already at minimum rank

            # Normalized cumulative energy of the singular values
            energy = S**2
            cumulative_energy = torch.cumsum(energy / energy.sum(), dim=0)

            # Rank that preserves the requested energy, floored at min_rank
            new_rank = max(layer.min_rank,
                           torch.sum(cumulative_energy < (1.0 - threshold)).item())

            # Only reduce if new rank is smaller than current
            if new_rank < layer.active_dims:
                layer.reduce_rank(new_rank, dim=dim, which_dims=None)
                changes_made = True

        return changes_made

    def increase_rank(self, increment=None, increase_ratio=1.1, layer_ids=(), dim=0):
        """
        Increase the rank of the selected adaptive layers.

        Only layers whose index is in ``layer_ids`` are touched; with the
        default empty selection no layer is changed.  The arguments are
        forwarded to each layer's ``increase_rank`` with ``mode='multimodal'``.

        Returns:
            bool: True if at least one layer reported a change.
        """
        changes_made = False

        for i, layer in enumerate(self.adaptive_layers):
            if i not in layer_ids:
                continue
            if layer.increase_rank(increment=increment, increase_ratio=increase_ratio, dim=dim, mode='multimodal'):
                changes_made = True

        return changes_made

    def get_total_rank(self):
        """Return total rank across all adaptive layers"""
        return sum(layer.active_dims for layer in self.adaptive_layers)

def compute_direct_r_squared(model, data, device, multi_gpu=False, verbose=False):
    """
    Compute R² based on direct model reconstruction performance

    The model is put in eval mode and called once on the whole ``data`` list.
    Each modality is flattened to ``(N, D)``, truncated to a common shape if
    original and reconstruction disagree, and scored per feature as
    ``1 - SSR / SS_tot``; the per-feature scores are averaged.

    Special cases, evaluated in this order:
    - Some feature means are exactly zero: only features with non-zero mean
      are scored (no clamping).  If every mean is zero, the Pearson correlation
      between flattened original and reconstruction is returned instead.
    - The original contains NaN/Inf: only features with a finite mean are scored.
    - Otherwise: features whose reconstruction contains NaN/Inf, or (when the
      reconstruction is finite) whose ``SS_tot`` is below 1e-3, are dropped;
      negative per-feature scores are clamped to 0.  If no feature remains the
      score is 0.

    Parameters:
    - model: The trained model; ``model(list_of_tensors)`` must return
      ``(reconstructions, latent)``
    - data: Input data list [modality1, modality2, ...]
    - device: Device to run computation on
    - multi_gpu: Unused; kept for interface compatibility
    - verbose: Print diagnostics about truncation and special cases

    Returns:
    - List of R² values (Python floats) for each modality
    """
    eps = 1e-9  # keeps the ratio defined when SS_tot is zero
    model.eval()
    r_squared_values = []

    with torch.no_grad():
        # Get model predictions
        data_tensors = [d.to(device) for d in data]
        reconstructions, _ = model(data_tensors)

        # Calculate R² for each modality
        for i, (original, reconstruction) in enumerate(zip(data_tensors, reconstructions)):
            # Move to CPU and flatten to (N, D); reshape copies only when a view is impossible
            original_cpu = original.cpu()
            reconstruction_cpu = reconstruction.cpu()
            orig_flat = original_cpu.reshape(original_cpu.shape[0], -1)
            recon_flat = reconstruction_cpu.reshape(reconstruction_cpu.shape[0], -1)

            # Align batch dimension if needed
            if orig_flat.shape[0] != recon_flat.shape[0]:
                n_min = min(orig_flat.shape[0], recon_flat.shape[0])
                orig_flat = orig_flat[:n_min]
                recon_flat = recon_flat[:n_min]

            # Align feature dimension by truncation if necessary
            if orig_flat.shape[1] != recon_flat.shape[1]:
                min_feat = min(orig_flat.shape[1], recon_flat.shape[1])
                if verbose:
                    print(f"   Debug: modality {i} feature size mismatch (orig={orig_flat.shape[1]}, recon={recon_flat.shape[1]}). Truncating to {min_feat} features for R² computation.")
                orig_flat = orig_flat[:, :min_feat]
                recon_flat = recon_flat[:, :min_feat]

            # Per-feature mean of the original data
            original_mean = orig_flat.mean(dim=0)

            if torch.any(original_mean == 0):
                if verbose:
                    print(f"   Warning: zeros found in original_mean for modality {i}. Removing samples.")
                non_zero_mean = original_mean != 0
                if non_zero_mean.sum() == 0:
                    # If all means are zero, use correlation as fallback
                    r_squared = torch.corrcoef(torch.stack((orig_flat.flatten(), recon_flat.flatten())))[0, 1]
                    if torch.isnan(r_squared):
                        r_squared = torch.tensor(0.0)
                else:
                    # Calculate R² only for non-zero mean dimensions
                    ssr = ((orig_flat - recon_flat)**2).sum(0)[non_zero_mean]
                    ss_tot = ((orig_flat - original_mean)**2).sum(0)[non_zero_mean]
                    r_squared = (1 - ((ssr + eps) / (ss_tot + eps))).mean()
            elif torch.any(torch.isnan(orig_flat)) or torch.any(torch.isinf(orig_flat)):
                if verbose:
                    print(f"   Warning: NaN or Inf values found in original data for modality {i}. Handling them.")
                # Keep only features whose mean is finite
                valid_mask = ~torch.isnan(original_mean) & ~torch.isinf(original_mean)
                if valid_mask.sum() == 0:
                    # If no valid values, set R² to 0
                    r_squared = torch.tensor(0.0)
                else:
                    ssr = ((orig_flat - recon_flat)**2).sum(0)[valid_mask]
                    ss_tot = ((orig_flat - original_mean)**2).sum(0)[valid_mask]
                    r_squared = (1 - ((ssr + eps) / (ss_tot + eps))).mean()
            else:
                # The original is finite here (the branch above caught the rest)
                ssr = ((orig_flat - recon_flat)**2).sum(0)
                ss_tot = ((orig_flat - original_mean)**2).sum(0)

                recon_finite = torch.isfinite(recon_flat)
                if not recon_finite.all():
                    if verbose:
                        print(f"   Warning: NaN or Inf values found in data for modality {i}. Handling them.")
                    # Score only features whose reconstruction is finite for
                    # every sample; the mask is per feature so it matches the
                    # shape of ssr / ss_tot.
                    valid_mask = recon_finite.all(dim=0)
                else:
                    # Features with (almost) no variance give unstable ratios
                    valid_mask = ss_tot >= 1e-3
                    if verbose and not valid_mask.all():
                        print(f"   Warning: Very small ss_tot values found for modality {i}. This may lead to unstable R² values.")

                if valid_mask.sum() == 0:
                    # Nothing left to score; do not fall through to the ratio
                    r_squared = torch.tensor(0.0)
                else:
                    r_squared = 1 - ((ssr[valid_mask] + eps) / (ss_tot[valid_mask] + eps))
                    if torch.any(r_squared < 0):
                        if verbose: # print the fraction of negative values
                            n_negative = (r_squared < 0).sum().item()
                            total = r_squared.numel()
                            fraction_negative = n_negative / total
                            print(f"   Warning: {fraction_negative:.2%} Negative R² values detected for modality {i}. This may indicate poor reconstruction.")
                        r_squared = torch.clamp(r_squared, min=0.0)
                    r_squared = r_squared.mean()  # Average across dimensions

            r_squared_values.append(r_squared.item())

    return r_squared_values

def compute_direct_explained_variance(model, data, device, multi_gpu=False, verbose=False):
    """
    Score how much variance of each modality the model's reconstruction explains.

    The model is switched to eval mode and run once on all modalities.  Every
    modality is flattened to ``(N, D)`` (truncating to the common shape if the
    original and the reconstruction differ) and scored over all elements as
    ``1 - Var(reconstruction - original) / (Var(original) + 1e-9)``.
    Non-finite elements are excluded; negative scores are clamped to 0.

    Parameters:
    - model: The trained model; called as ``model(list_of_tensors)`` and
      expected to return ``(reconstructions, latent)``
    - data: Input data list [modality1, modality2, ...]
    - device: Device to run computation on
    - multi_gpu: Not used by this function
    - verbose: Print shape diagnostics and warnings

    Returns:
    - List of explained variance values (Python floats), one per modality
    """
    model.eval()
    explained_variance_values = []

    with torch.no_grad():
        inputs = [tensor.to(device) for tensor in data]
        outputs, _ = model(inputs)

        for i, (target, prediction) in enumerate(zip(inputs, outputs)):
            original_cpu = target.cpu()
            reconstruction_cpu = prediction.cpu()

            # (N, ...) -> (N, D); reshape yields a view whenever one exists
            orig_flat = original_cpu.reshape(original_cpu.shape[0], -1)
            recon_flat = reconstruction_cpu.reshape(reconstruction_cpu.shape[0], -1)

            if verbose:
                print(f"   Debug: modality {i} original shape {original_cpu.shape} -> flat {orig_flat.shape}; recon shape {reconstruction_cpu.shape} -> flat {recon_flat.shape}")

            # Keep only the samples both tensors have
            n_min = min(orig_flat.shape[0], recon_flat.shape[0])
            orig_flat, recon_flat = orig_flat[:n_min], recon_flat[:n_min]

            # Keep only the features both tensors have
            if orig_flat.shape[1] != recon_flat.shape[1]:
                min_feat = min(orig_flat.shape[1], recon_flat.shape[1])
                if verbose:
                    print(f"   Warning: modality {i} feature size mismatch (orig={orig_flat.shape[1]}, recon={recon_flat.shape[1]}). Truncating to {min_feat} features for explained variance.")
                orig_flat = orig_flat[:, :min_feat]
                recon_flat = recon_flat[:, :min_feat]

            # Element-wise validity: finite in both original and reconstruction
            valid_mask = torch.isfinite(orig_flat) & torch.isfinite(recon_flat)
            if valid_mask.all():
                residual_var = torch.var(recon_flat - orig_flat)
                explained_variance = 1 - (residual_var / (torch.var(orig_flat) + 1e-9))
            else:
                if verbose:
                    print(f"   Warning: NaN or Inf values found in flattened data for modality {i}. Handling them.")
                if valid_mask.sum() == 0:
                    explained_variance = torch.tensor(0.0)
                else:
                    original_valid = orig_flat[valid_mask]
                    reconstruction_valid = recon_flat[valid_mask]
                    residual_var = torch.var(reconstruction_valid - original_valid)
                    explained_variance = 1 - (residual_var / (torch.var(original_valid) + 1e-9))

            # A negative score means the reconstruction is worse than a constant
            if explained_variance < 0:
                if verbose:
                    print(f"   Warning: Negative explained variance value detected for modality {i}: {explained_variance}. Clamping to 0.")
                explained_variance = torch.clamp(explained_variance, min=0.0)

            explained_variance_values.append(explained_variance.item())

    return explained_variance_values

import os

def _avmnist_modality_bce_losses(targets, predictions, full_spectrum):
    """
    Compute the mean binary cross-entropy reconstruction loss of each modality.

    Parameters:
    - targets: Sequence of target tensors, one per modality
    - predictions: Sequence of reconstructed tensors, aligned with ``targets``
    - full_spectrum: If True, the second modality (index 1) is treated as a
      112x112 spectrogram and both target and prediction are given a channel
      dimension before the loss is computed

    Returns:
    - List of scalar loss tensors, one per (target, prediction) pair
    """
    losses = []
    for i, (x_m, x_hat_m) in enumerate(zip(targets, predictions)):
        if i == 1 and full_spectrum:
            # Target: (batch, 112, 112) -> (batch, 1, 112, 112)
            if x_m.dim() == 3:
                x_m = x_m.unsqueeze(1)
            # Some decoders may still output flattened vectors; reshape to image
            if x_hat_m.dim() == 2:
                x_hat_m = x_hat_m.view(x_hat_m.shape[0], 1, 112, 112)
        losses.append(F.binary_cross_entropy(x_hat_m, x_m, reduction='mean'))
    return losses


def pretrain_overcomplete_ae(data, n_samples_train, latent_dim, device, args, epochs=100, early_stopping=50,
                         lr=0.001, batch_size=128, ae_depth=2, ae_width=0.5, dropout=0.0, wd=1e-5,
                         initial_rank_ratio=1.0, min_rank=10,
                         patience=10, verbose=True, recon_loss_balancing=False, paired=False, lr_schedule=None,
                         full_spectrum=False):
    """
    Pretrain an AV-MNIST autoencoder with a BCE reconstruction loss (no rank reduction).

    Parameters:
    - data: Sequence of per-modality arrays/tensors; index 0 is used as images and
      index 1 as audio (any further entries are ignored by the dataset wrapper)
    - n_samples_train: Number of samples to use for training; the remainder is
      used for validation
    - latent_dim: Latent size; an int or a one-element list/tuple is replicated for
      every modality plus the shared space, a longer list/tuple is used as given
    - device: Device the model and batches are moved to
    - args: Namespace-like object; ``multi_gpu``, ``gpu_ids`` and ``num_workers``
      are read from it when present
    - epochs: Maximum number of training epochs
    - early_stopping: Window (in epochs) used by the validation-based early stopping
    - lr: Learning rate
    - batch_size: Batch size for training
    - ae_depth: Depth of the autoencoder
    - ae_width: Width multiplier for hidden layers
    - dropout: Dropout rate
    - wd: Weight decay
    - initial_rank_ratio: Initial rank ratio (1.0 = full rank)
    - min_rank: Minimum rank
    - patience: Unused; kept for interface compatibility
    - verbose: Print progress and warnings
    - recon_loss_balancing: Adaptive loss balancing across modalities
    - paired: If True, samples are split with a random permutation; otherwise the
      first ``n_samples_train`` rows are used for training
    - lr_schedule: None, 'linear' or 'step'
    - full_spectrum: Use the full 112x112 audio spectrogram instead of its average

    Returns:
    - model: The trained model (wrapped in ``nn.DataParallel`` in multi-GPU mode)
    - [train_losses, val_losses]: Per-epoch loss curves. NOTE: the training loss is
      the (scaled) sum over modalities, the validation loss is the mean over modalities.
    - data_indices: The random permutation used for the split, or None when
      ``paired`` is False

    Raises:
    - TypeError: If ``latent_dim`` is neither an int nor a list/tuple
    """
    multi_gpu = getattr(args, 'multi_gpu', False)

    # Create model with adaptive rank reduction
    # Only rank-based AE supported
    input_dims = (784, 112*112 if full_spectrum else 112)  # 112x112=12544 for full spectrum, 112 for averaged
    n_latent_spaces = len(input_dims) + 1  # adding one for the shared space
    if isinstance(latent_dim, int):
        latent_dims = [latent_dim] * n_latent_spaces
    elif isinstance(latent_dim, (list, tuple)):
        if len(latent_dim) == 1 and len(input_dims) > 1:
            latent_dims = [latent_dim[0]] * n_latent_spaces
        else:
            latent_dims = list(latent_dim)
    else:
        raise TypeError(f"latent_dim must be an int or a list/tuple of ints, got {type(latent_dim).__name__}")
    model = AdaptiveRankReducedAE_AvMnist(
        input_dims, latent_dims, depth=ae_depth, width=ae_width,
        dropout=dropout, initial_rank_ratio=initial_rank_ratio,
        min_rank=min_rank, full_spectrum=full_spectrum
    ).to(device)
    print(model)
    # print the device the model is on
    print(f"Model is on device: {next(model.parameters()).device}")
    # print the number of parameters
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model has {total_params} parameters")

    # Handle multi-GPU setup
    if multi_gpu:
        gpu_ids = getattr(args, 'gpu_ids', None)
        if gpu_ids:
            num_gpus = len(gpu_ids.split(','))
        else:
            num_gpus = torch.cuda.device_count()

        # Ensure batch size is divisible by number of GPUs (and never drops to 0)
        if num_gpus > 0 and batch_size % num_gpus != 0:
            original_batch_size = batch_size
            batch_size = max(num_gpus, (batch_size // num_gpus) * num_gpus)
            if verbose:
                print(f"Adjusted batch size from {original_batch_size} to {batch_size} to be divisible by {num_gpus} GPUs")

        try:
            # A missing/invalid gpu_ids raises here and triggers the single-GPU fallback
            device_ids = [int(gpu_id) for gpu_id in gpu_ids.split(',')]
            # If we need cuda:0 but it's not available, disable multi_gpu
            if 0 not in device_ids:
                raise RuntimeError("DataParallel requires cuda:0 which is not available.")

            # Ensure model is on cuda:0 for DataParallel
            model = model.to(torch.device('cuda:0'))

            # Wrap model with DataParallel - explicitly specify device_ids
            model = nn.DataParallel(model, device_ids=device_ids)
            if verbose:
                print(f"Using DataParallel across GPUs: {gpu_ids}")
        except Exception as e:
            print(f"Failed to use DataParallel: {e}")
            print(f"Falling back to single GPU mode on {device}")
            multi_gpu = False
            model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

    # Setup learning rate scheduler if requested
    scheduler = None
    if lr_schedule == 'linear':
        try:
            # Use LinearLR when available (PyTorch >= 1.11)
            scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=1000)
        except AttributeError:
            # Fallback to LambdaLR for older PyTorch versions
            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: max(0.0, 1.0 - (epoch + 1) / float(max(1, epochs))))
    elif lr_schedule == 'step':
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 100, 1000], gamma=0.1)

    # Create data loaders
    # careful with the non-paired data because of how it is concatenated
    if paired:
        # first randomize the rows
        data_indices = torch.randperm(data[0].shape[0])
        train_indices = data_indices[:n_samples_train]
        val_indices = data_indices[n_samples_train:]
    else:
        data_indices = None
        train_indices = slice(0, n_samples_train)
        val_indices = slice(n_samples_train, None)
    # Use num_workers from args if available
    num_workers = getattr(args, 'num_workers', 0)
    train_data = AVMNISTDataset([d[train_indices] for d in data], full_spectrum=full_spectrum)
    data_loader = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_workers
    )
    val_data = AVMNISTDataset([d[val_indices] for d in data], full_spectrum=full_spectrum)
    val_data_loader = torch.utils.data.DataLoader(
        val_data, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=num_workers
    )

    train_losses = []
    val_losses = []
    best_loss = float('inf')

    # Static per-modality loss weights (all ones), used when balancing is disabled
    loss_scales = torch.ones(len(data), device=device)

    # Running averages of each modality loss, used for reconstruction loss balancing
    if recon_loss_balancing:
        modality_loss_emas = [None] * len(data)
        ema_decay = 0.9

    pbar = tqdm.tqdm(range(epochs))
    for epoch in pbar:
        # Training phase
        model.train()
        train_loss = 0.0

        for x in data_loader:
            x = [x_m.to(device, non_blocking=True) for x_m in x]

            # Forward pass
            x_hat, _ = model(x)

            # Calculate per-modality BCE losses, neutralising NaNs
            modality_losses = []
            for i, m_loss in enumerate(_avmnist_modality_bce_losses(x, x_hat, full_spectrum)):
                if torch.isnan(m_loss):
                    if verbose:
                        print(f"Warning: NaN loss detected for modality {i}")
                    m_loss = torch.tensor(0.0, device=device)
                modality_losses.append(m_loss)

            loss = torch.tensor(0.0, device=device)
            if recon_loss_balancing:
                # Update exponential moving averages for each modality
                for i, m_loss in enumerate(modality_losses):
                    if modality_loss_emas[i] is None:
                        modality_loss_emas[i] = m_loss.item()
                    else:
                        modality_loss_emas[i] = ema_decay * modality_loss_emas[i] + (1 - ema_decay) * m_loss.item()

                # Calculate balanced loss using the minimum positive EMA as reference.
                # min_ema is None only when no EMA is positive, in which case every
                # modality takes the unscaled branch below.
                min_ema = min(
                    (ema for ema in modality_loss_emas if ema is not None and ema > 0),
                    default=None,
                )
                for i, m_loss in enumerate(modality_losses):
                    if modality_loss_emas[i] > 0:
                        loss = loss + (min_ema / modality_loss_emas[i]) * m_loss
                    else:
                        loss = loss + m_loss
            else:
                # Standard loss computation without balancing
                for i, m_loss in enumerate(modality_losses):
                    loss = loss + loss_scales[i] * m_loss

            # Backward pass and optimize
            optimizer.zero_grad()
            if loss.requires_grad:
                loss.backward()
                optimizer.step()
            elif verbose:
                # Every modality loss was NaN and replaced by a constant: nothing to backpropagate
                print("Warning: skipping optimizer step because no modality produced a valid loss")
            train_loss += loss.item()

        # Average losses
        train_loss /= len(data_loader)
        train_losses.append(train_loss)

        # Validation phase: eval mode so dropout (and any normalisation layers) are inactive
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for x_val in val_data_loader:
                x_val = [x_m.to(device, non_blocking=True) for x_m in x_val]
                x_val_hat, _ = model(x_val)

                val_batch_loss = 0.0
                for m_loss in _avmnist_modality_bce_losses(x_val, x_val_hat, full_spectrum):
                    if not torch.isnan(m_loss):
                        val_batch_loss += m_loss.item()

                val_loss += val_batch_loss / len(x_val)
        # Restore train mode so the returned model is in the same mode as before this fix
        model.train()

        # An empty validation split yields NaN instead of a ZeroDivisionError
        n_val_batches = len(val_data_loader)
        val_loss = val_loss / n_val_batches if n_val_batches > 0 else float('nan')
        val_losses.append(val_loss)

        # Step the scheduler once per epoch (if configured)
        if scheduler is not None:
            try:
                scheduler.step()
            except Exception as e:
                # Stay best-effort, but do not hide the failure
                if verbose:
                    print(f"Warning: scheduler.step() failed: {e}")

        # Update best (training) loss
        if train_loss < best_loss:
            best_loss = train_loss

        # Stop when the best validation loss was not reached within the last `early_stopping` epochs
        if epoch > early_stopping and min(val_losses[-early_stopping:]) > min(val_losses):
            print(f"Early stopping at epoch {epoch} with best loss {best_loss}")
            break

        # Update progress bar, including the current learning rate
        if scheduler is not None:
            current_lr = scheduler.get_last_lr()[0]
        else:
            current_lr = optimizer.param_groups[0]['lr']
        pbar.set_description(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}, LR: {current_lr:.2e}")

    return model, [train_losses, val_losses], data_indices

def train_overcomplete_ae_with_pretrained(data, n_samples_train, latent_dim, device, args, epochs=100, early_stopping=50, 
                         lr=0.001, batch_size=128, ae_depth=2, ae_width=0.5, dropout=0.0, wd=1e-5, 
                         initial_rank_ratio=1.0, min_rank=10, 
                         rank_schedule=None, rank_reduction_frequency=10, 
                         rank_reduction_threshold=0.01, warmup_epochs=0,
                         patience=10, reduce_on_best_loss='rsquare', r_square_threshold=0.9,
                         threshold_type='relative', compressibility_type='linear', reduction_criterion='r_squared',
                         include_l1=False, l1_weight=0.0, include_ortholoss=False,
                         l1_start_weight=0.0, l1_step_size=1.0, rank_or_sparse='rank',
                         verbose=True, compute_jacobian=False, model_name=None, pretrained_name=None,
                         recon_loss_balancing=False, ortho_loss_balancing=False,
                         ortho_loss_start_weight=0.0, ortho_loss_end_weight=1.0, ortho_loss_anneal_epochs=None, ortho_loss_warmup=None,
                         l2_norm_adaptivelayers=None, sharedwhenall=True, paired=False, lr_schedule=None,
                         decision_metric='R2', full_spectrum=False
                         ):
    """
    Train an autoencoder with adaptive rank reduction
    
    Parameters:
    - data: Input data tensor
    - n_samples_train: Number of samples to use for training
    - latent_dim: Dimension of the latent space
    - epochs: Maximum number of training epochs
    - early_stopping: Number of epochs for early stopping patience
    - lr: Learning rate
    - batch_size: Batch size for training
    - ae_depth: Depth of the autoencoder
    - ae_width: Width multiplier for hidden layers
    - dropout: Dropout rate
    - wd: Weight decay
    - initial_rank_ratio: Initial rank ratio (1.0 = full rank)
    - min_rank_ratio: Minimum rank ratio (lower bound)
    - rank_schedule: Custom schedule for rank reduction (epochs at which to reduce)
    - rank_reduction_frequency: How often to try reducing rank (in epochs)
    - rank_reduction_threshold: Energy threshold for rank reduction
    - warmup_epochs: Number of epochs to train before starting rank reduction
    - reduce_on_best_loss: Only reduce rank when loss is at or better than best loss
    - r_square_threshold: R² threshold for rank reduction decisions
    - threshold_type: 'relative' (multiply by initial R²) or 'absolute' (use threshold directly)
    - compressibility_type: 'linear' (linear probing R²) or 'direct' (reconstruction R²)
    - reduction_criterion: 'r_squared' (use R²), 'train_loss' (use training loss), 'val_loss' (use validation loss)
                          Only relevant when compressibility_type='direct'. Default: 'r_squared'
    - recon_loss_balancing: Whether to apply adaptive loss balancing across modalities (default: False)
    """

    # determine whether multi-GPU mode is requested (make available early)
    multi_gpu = getattr(args, 'multi_gpu', False)

    # check if there is an existing pretrained model for the seed, early stopping, and training hyperparameters (lr, wd, batch size, model architecture)
    pretrained_model_path = f"./03_results/models/pretrained_models/{pretrained_name}.pt" if pretrained_name else None
    if pretrained_model_path and os.path.exists(pretrained_model_path):
        print(f"Found existing pretrained model at {pretrained_model_path}. Loading...")
        #input_dims = [d.shape[1] for d in data]
        input_dims = (784, 112*112 if full_spectrum else 112)
        if isinstance(latent_dim, int):
            latent_dims = [latent_dim] * (len(input_dims) + 1) # adding one for the shared space
        elif isinstance(latent_dim, list):
            if (len(latent_dim) == 1) & (len(input_dims) > 1):
                latent_dims = [latent_dim[0]] * (len(input_dims) + 1)
            else:
                latent_dims = latent_dim
        model = AdaptiveRankReducedAE_AvMnist(
            input_dims, latent_dims, depth=ae_depth, width=ae_width, 
            dropout=dropout, initial_rank_ratio=initial_rank_ratio, 
            min_rank=min_rank, full_spectrum=full_spectrum
        )
        model.load_state_dict(torch.load(pretrained_model_path, weights_only=False))
        # make sure that the weights are changed
        model.eval()
        for param in model.parameters():
            param.requires_grad = True
        print(f"Loaded pretrained model from {pretrained_model_path}")
        # also load the loss curves
        loss_curve_path = pretrained_model_path.replace('.pt', '_loss_curve.npy')
        train_val_losses = np.load(loss_curve_path, allow_pickle=True)
        train_losses = train_val_losses[0].tolist()
        val_losses = train_val_losses[1].tolist()
        print(f"Loaded loss curves from {loss_curve_path}")
        # print last losses
        print(f"Last training loss: {train_losses[-1]}, last validation loss: {val_losses[-1]}")
        model.epoch = len(train_losses)
        # If reconstruction plots do not exist, create and save example reconstructions
        try:
            image_plot_path = pretrained_model_path.replace('.pt', '_image_recon.png')
            audio_plot_path = pretrained_model_path.replace('.pt', '_audio_recon.png')
            if (not os.path.exists(image_plot_path)) or (not os.path.exists(audio_plot_path)):
                # ensure model parameters are on the same device as inputs
                try:
                    model = model.to(device)
                except Exception:
                    pass
                model.eval()
                with torch.no_grad():
                    # take a small sample from the provided data to generate example reconstructions
                    n_plot = min(8, getattr(data[0], '__len__', lambda: data[0].shape[0])() if hasattr(data[0], '__len__') else data[0].shape[0])
                    rng = np.random.default_rng(seed=42)
                    sample_idx = rng.choice(data[0].shape[0], size=n_plot, replace=False)
                    # prepare images
                    imgs = data[0][sample_idx]
                    imgs_t = torch.FloatTensor(imgs).to(device)
                    # prepare audio: if audio is 3D (e.g. spectrogram), average last dim like dataset does (unless full_spectrum)
                    aud_arr = np.array(data[1][sample_idx])
                    if aud_arr.ndim > 2 and not full_spectrum:
                        aud_arr = aud_arr.mean(axis=1)
                    elif aud_arr.ndim > 2 and full_spectrum:
                        # Flatten the spectrogram for full spectrum mode
                        aud_arr = aud_arr.reshape(aud_arr.shape[0], -1)
                    aud_t = torch.FloatTensor(aud_arr).to(device)

                    # Use module if DataParallel wrapping already present
                    if multi_gpu and hasattr(model, 'module'):
                        reconstructions, _ = model.module([imgs_t, aud_t])
                    else:
                        reconstructions, _ = model([imgs_t, aud_t])

                    # Save plots
                    plot_image_reconstruction(imgs_t.cpu().numpy(), reconstructions[0].cpu().numpy(), n=n_plot, out_path=image_plot_path)
                    # For audio: if full_spectrum, plot as images; otherwise use PCA scatter
                    if full_spectrum:
                        # reconstructions[1] may be flattened vectors or image tensors
                        recon_audio_np = reconstructions[1].cpu().numpy()
                        # If recon is flattened, try to reshape to (N, 1, 112, 112) or (N, 112, 112)
                        if recon_audio_np.ndim == 2 and recon_audio_np.shape[1] == 112*112:
                            recon_audio_img = recon_audio_np.reshape(-1, 112*112)
                        else:
                            recon_audio_img = recon_audio_np
                        # original audio images available in aud_t before flattening; convert to numpy
                        orig_audio_img = aud_t.cpu().numpy()
                        plot_modal_image_reconstruction(orig_audio_img, recon_audio_img, image_shape=(112,112), n=n_plot, out_path=audio_plot_path)
                    else:
                        plot_audio_scatter(aud_t.cpu().numpy(), reconstructions[1].cpu().numpy(), out_path=audio_plot_path)
                print(f"Saved reconstruction plots to {image_plot_path} and {audio_plot_path}")
        except Exception as e:
            print(f"Warning: Could not save reconstruction plots: {e}")
    else:
        if pretrained_model_path:
            print("No pretrained model found. Training from scratch.")
            model, [train_losses, val_losses], data_indices = pretrain_overcomplete_ae(
                data, n_samples_train, latent_dim, device, args, epochs=int(epochs/2), early_stopping=early_stopping,
                lr=lr, batch_size=batch_size, ae_depth=ae_depth, ae_width=ae_width, dropout=dropout, wd=wd,
                initial_rank_ratio=initial_rank_ratio, min_rank=min_rank, lr_schedule=lr_schedule,
                verbose=verbose, full_spectrum=full_spectrum
            )
            # Save the pretrained model and loss curves
            os.makedirs(os.path.dirname(pretrained_model_path), exist_ok=True)
            torch.save(model.state_dict(), pretrained_model_path)
            # Also save loss curves
            loss_curve_path = pretrained_model_path.replace('.pt', '_loss_curve.npy')
            np.save(loss_curve_path, np.array([train_losses, val_losses]))
            # Also save a PNG of the pretraining loss curves (train & val)
            try:
                loss_png_path = pretrained_model_path.replace('.pt', '_pretrain_loss_curve.png')
                plt.figure(figsize=(6, 4))
                plt.plot(np.arange(1, len(train_losses) + 1), train_losses, label='train')
                plt.plot(np.arange(1, len(val_losses) + 1), val_losses, label='val')
                plt.xlabel('Epoch')
                plt.ylabel('Loss')
                plt.title('Pretraining Loss Curves')
                plt.legend()
                plt.tight_layout()
                plt.savefig(loss_png_path, dpi=150)
                plt.close()
            except Exception as e:
                print(f"Warning: Could not save pretraining loss PNG: {e}")
            # also save data_indices
            if data_indices is not None:
                data_indices_path = pretrained_model_path.replace('.pt', '_data_indices.pt')
                torch.save(data_indices, data_indices_path)
            # Save example plots for reconstruction
            try:
                saved = False
                with torch.no_grad():
                    # Determine a small batch size for plotting
                    n_plot = min(8, data[0].shape[0] if hasattr(data[0], 'shape') else (len(data[0]) if hasattr(data[0], '__len__') else 8))
                    rng = np.random.default_rng(seed=42)
                    sample_idx = rng.choice(data[0].shape[0], size=n_plot, replace=False)

                    # Prepare image tensor
                    imgs = data[0][sample_idx]
                    if not torch.is_tensor(imgs):
                        imgs_t = torch.FloatTensor(np.array(imgs)).to(device)
                    else:
                        imgs_t = imgs.to(device)

                    # Prepare audio tensor (handle spectrograms by averaging last dim unless full_spectrum)
                    aud = data[1][sample_idx]
                    if isinstance(aud, np.ndarray):
                        aud_arr = aud
                        if aud_arr.ndim > 2 and not full_spectrum:
                            aud_arr = aud_arr.mean(axis=1)
                        elif aud_arr.ndim > 2 and full_spectrum:
                            aud_arr = aud_arr.reshape(aud_arr.shape[0], -1)
                        aud_t = torch.FloatTensor(aud_arr).to(device)
                    elif torch.is_tensor(aud):
                        if aud.dim() > 2 and not full_spectrum:
                            aud_t = aud.mean(dim=1).to(device)
                        elif aud.dim() > 2 and full_spectrum:
                            aud_t = aud.view(aud.shape[0], -1).to(device)
                        else:
                            aud_t = aud.to(device)
                    else:
                        aud_t = torch.FloatTensor(np.array(aud)).to(device)

                    last_batch_data = [imgs_t, aud_t]

                    # Run through model (handle DataParallel)
                    if multi_gpu and hasattr(model, 'module'):
                        model.module.eval()
                        reconstructions, _ = model.module(last_batch_data)
                    else:
                        model.eval()
                        reconstructions, _ = model(last_batch_data)

                    # Plot image reconstructions
                    plot_image_reconstruction(
                        last_batch_data[0].cpu().numpy(), 
                        reconstructions[0].cpu().numpy(),
                        n=n_plot,
                        out_path=pretrained_model_path.replace('.pt', '_image_recon.png')
                    )
                    # Plot audio: image reconstruction if full_spectrum, else PCA scatter
                    if full_spectrum:
                        plot_modal_image_reconstruction(
                            last_batch_data[1].cpu().numpy(),
                            reconstructions[1].cpu().numpy(),
                            image_shape=(112,112), n=n_plot,
                            out_path=pretrained_model_path.replace('.pt', '_audio_recon.png')
                        )
                    else:
                        plot_audio_scatter(
                            last_batch_data[1].cpu().numpy(),
                            reconstructions[1].cpu().numpy(),
                            out_path=pretrained_model_path.replace('.pt', '_audio_recon.png')
                        )
                    saved = True

                if saved:
                    print(f"Saved reconstruction plots to {pretrained_model_path.replace('.pt', '_image_recon.png')} and {pretrained_model_path.replace('.pt', '_audio_recon.png')}")
            except Exception as e:
                print(f"Warning: Could not save reconstruction plots: {e}")
            
            model.epoch = len(train_losses)
        else:
            raise ValueError("model_name must be provided to save/load pretrained models.")
    model.to(device)
    print(f"Model is on device: {next(model.parameters()).device}")
    
    # Handle multi-GPU setup
    if multi_gpu:
        # Adjust batch size to be divisible by number of GPUs
        if args.gpu_ids:
            num_gpus = len(args.gpu_ids.split(','))
        else:
            num_gpus = torch.cuda.device_count()
            
        # Ensure batch size is divisible by number of GPUs
        if batch_size % num_gpus != 0:
            original_batch_size = batch_size
            batch_size = (batch_size // num_gpus) * num_gpus
            if verbose:
                print(f"Adjusted batch size from {original_batch_size} to {batch_size} to be divisible by {num_gpus} GPUs")
            
        try:
            # If we need cuda:0 but it's not available, disable multi_gpu
            if 0 not in [int(id) for id in args.gpu_ids.split(',')]:
                raise RuntimeError("DataParallel requires cuda:0 which is not available.")
                
            # Ensure model is on cuda:0 for DataParallel
            cuda0_device = torch.device('cuda:0')
            model = model.to(cuda0_device)
            
            # Double-check all parameters are on cuda:0
            for param in model.parameters():
                if param.device != cuda0_device:
                    param.data = param.data.to(cuda0_device)
                    
            # Wrap model with DataParallel - explicitly specify device_ids
            model = nn.DataParallel(model, device_ids=[int(id) for id in args.gpu_ids.split(',')])
            if verbose:
                print(f"Using DataParallel across GPUs: {args.gpu_ids}")
        except Exception as e:
            print(f"Failed to use DataParallel: {e}")
            print(f"Falling back to single GPU mode on {device}")
            multi_gpu = False
            model = model.to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

    # Setup learning rate scheduler if requested
    scheduler = None
    if lr_schedule == 'linear':
        try:
            # Use LinearLR when available (PyTorch >= 1.11)
            scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, end_factor=0.001, total_iters=2000)
        except Exception:
            # Fallback to LambdaLR for older PyTorch versions
            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: max(0.0, 1.0 - (epoch + 1) / float(max(1, epochs))))
    elif lr_schedule == 'step':
        scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, end_factor=0.001, total_iters=2000)

    # Create data loader
    # careful with the non-paired data because of how it is concatenated
    # first randomize the rows
    if paired:
        # load the saved data_indices if it exists
        data_indices_path = f"./03_results/models/pretrained_models/{pretrained_name}_data_indices.pt"
        data_indices = torch.load(data_indices_path)
        #data_indices = torch.randperm(data[0].shape[0])
        train_indices = data_indices[:n_samples_train]
        val_indices = data_indices[n_samples_train:]
    else:
        print("Using non-paired data splitting")
        # use the first n_samples_train samples for training and the rest for validation
        train_indices = slice(0, n_samples_train)
        val_indices = slice(n_samples_train, None)
    train_data = [d[train_indices] for d in data]  # Randomize rows
    train_data = AVMNISTDataset(train_data, full_spectrum=full_spectrum)
    # pin_memory is always enabled; num_workers is read from args when present (defaults to 0)
    num_workers = getattr(args, 'num_workers', 0)
    data_loader = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_workers
    )
    val_data = [data[i][val_indices] for i in range(len(data))]  # Split data into validation set
    val_data = AVMNISTDataset(val_data, full_spectrum=full_spectrum)
    val_data_loader = torch.utils.data.DataLoader(
        val_data, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=num_workers
    )
    n_samples = data[0].shape[0]
    n_samples_val = n_samples - n_samples_train
    
    # Default rank reduction schedule if none provided
    if rank_schedule is None:
        # Reduce rank every rank_reduction_frequency epochs, but start after warmup period
        rank_schedule = list(range(warmup_epochs + rank_reduction_frequency, 
                                 epochs, 
                                 rank_reduction_frequency))
    initial_squares = [None] * len(data) # per modality
    initial_losses = [None] * len(data) # per modality (for loss-based criteria)
    start_reduction = False
    current_rsquare_per_mod = [None] * len(data)
    current_loss_per_mod = [None] * len(data)  # for loss-based criteria
    bottom_reached = False
    space_sims = None
    break_counter = 0
    
    # Train the model
    train_losses = []
    val_losses = []
    r_squares = []
    min_ranks = [layer.active_dims for layer in model.adaptive_layers]
    best_loss = float('inf')
    
    # Initialize loss scaling factors for dynamic loss balancing
    loss_scales = torch.ones(len(data), device=device)
    #loss_scales[1] = 0.1
    initial_losses = torch.zeros(len(data), device=device)
    loss_history = {f'mod_{i}_loss': [] for i in range(len(data))}
    
    # Initialize loss balancer for reconstruction losses
    if recon_loss_balancing:
        modality_loss_emas = [None] * len(data)
        ema_decay = 0.9

    patience_counter = 0
    pbar = tqdm.tqdm(range(model.epoch, epochs))
    mask = None
    for epoch in pbar:
        # Training phase
        model.train()
        train_loss = 0.0
        val_loss = 0.0
        total_ortho_loss = 0.0
        per_modality_losses = [0.0] * len(data)
        
        for batch_idx, x in enumerate(data_loader):
            ### plotting test
            # Store last batch for plotting
            last_batch_data = [x_m.clone() for x_m in x]
            # Get labels if they exist in the dataset
            if hasattr(train_data, 'labels') and train_data.labels is not None:
                start_idx = batch_idx * batch_size
                end_idx = min(start_idx + batch_size, len(train_data.labels))
                last_batch_labels = train_data.labels[start_idx:end_idx].clone()
            else:
                last_batch_labels = None
            ###
            
            loss = torch.tensor(0.0, device=device)
            total_loss = torch.tensor(0.0, device=device)
            x = [x_m.to(device, non_blocking=True) for x_m in x]
            
            # Forward pass
            x_hat, h_list = model(x)
            
            ortho_loss = torch.tensor(0.0, device=device)
            total_ortho_loss += ortho_loss.item()

            # Calculate separate losses for each modality
            modality_losses = []
            
            # Extract masks for each modality
            modality_masks = []
            if mask is not None:
                start_idx = 0
                for i, x_m in enumerate(x):
                    end_idx = start_idx + x_m.shape[1]
                    modality_masks.append(mask[:, start_idx:end_idx])
                    start_idx = end_idx
            else:
                modality_masks = [None] * len(x)
            
            # Calculate per-modality BCE (binary cross-entropy) losses
            for i, (x_m, x_hat_m) in enumerate(zip(x, x_hat)):
                # Ensure target and prediction have matching shapes
                if i == 1 and full_spectrum:
                    # For audio in full spectrum mode, ensure channel dim on target
                    if x_m.dim() == 3:
                        x_m_reshaped = x_m.unsqueeze(1)
                    else:
                        x_m_reshaped = x_m
                    # Ensure prediction also has channel dim
                    if x_hat_m.dim() == 2:
                        x_hat_m = x_hat_m.view(x_hat_m.shape[0], 1, 112, 112)
                else:
                    x_m_reshaped = x_m
                
                # Compute BCE loss for this modality with mask if provided
                if modality_masks[i] is not None:
                    m_loss = F.binary_cross_entropy(x_hat_m[modality_masks[i]], x_m_reshaped[modality_masks[i]], reduction='mean')
                else:
                    m_loss = F.binary_cross_entropy(x_hat_m, x_m_reshaped, reduction='mean')
                
                # Check for NaN 
                if torch.isnan(m_loss):
                    if verbose:
                        print(f"Warning: NaN loss detected for modality {i}")
                    m_loss = torch.tensor(0.0, device=device)
                
                modality_losses.append(m_loss)
                per_modality_losses[i] += m_loss.item()
            
            # Apply reconstruction loss balancing if enabled
            if recon_loss_balancing:
                # Update exponential moving averages for each modality
                for i, m_loss in enumerate(modality_losses):
                    if modality_loss_emas[i] is None:
                        modality_loss_emas[i] = m_loss.item()
                    else:
                        modality_loss_emas[i] = ema_decay * modality_loss_emas[i] + (1 - ema_decay) * m_loss.item()
                
                # Calculate balanced loss using the minimum EMA as reference
                min_ema = min(ema for ema in modality_loss_emas if ema is not None and ema > 0)
                for i, m_loss in enumerate(modality_losses):
                    if modality_loss_emas[i] > 0:
                        balance_scale = min_ema / modality_loss_emas[i]
                        loss += balance_scale * m_loss
                    else:
                        loss += m_loss
            else:
                # Standard loss computation without balancing
                for i, m_loss in enumerate(modality_losses):
                    loss += loss_scales[i] * m_loss
            
            total_loss += loss
            
            # Backward pass and optimize
            optimizer.zero_grad()
            total_loss.backward()

            optimizer.step()
            train_loss += loss.item()
        
        # Average losses
        train_loss /= len(data_loader)
        if start_reduction and include_ortholoss:
            total_ortho_loss /= len(data_loader)
        per_modality_losses = [loss / len(data_loader) for loss in per_modality_losses]
        train_losses.append(train_loss)
        
        # Store per-modality losses in history
        for i, loss in enumerate(per_modality_losses):
            loss_history[f'mod_{i}_loss'].append(loss)
        
        # Validation phase with similar safeguards
        with torch.no_grad():
            for x_val in val_data_loader:
                x_val = [x_m.to(device, non_blocking=True) for x_m in x_val]
                x_val_hat, _ = model(x_val)

                modality_masks = []
                if mask is not None:
                    start_idx = 0
                    for i, x_m in enumerate(x_val):
                        end_idx = start_idx + x_m.shape[1]
                        modality_masks.append(mask[:, start_idx:end_idx])
                        start_idx = end_idx
                else:
                    modality_masks = [None] * len(x_val)
                
                # Calculate validation loss
                val_batch_loss = 0.0
                for i, (x_m, x_hat_m) in enumerate(zip(x_val, x_val_hat)):
                    # Ensure target and prediction have matching shapes
                    if i == 1 and full_spectrum:
                        if x_m.dim() == 3:
                            x_m_reshaped = x_m.unsqueeze(1)
                        else:
                            x_m_reshaped = x_m
                        if x_hat_m.dim() == 2:
                            x_hat_m = x_hat_m.view(x_hat_m.shape[0], 1, 112, 112)
                    else:
                        x_m_reshaped = x_m
                    
                    if modality_masks[i] is not None:
                        m_loss = F.binary_cross_entropy(x_hat_m[modality_masks[i]], x_m_reshaped[modality_masks[i]], reduction='mean')
                    else:
                        m_loss = F.binary_cross_entropy(x_hat_m, x_m_reshaped, reduction='mean')
                    if not torch.isnan(m_loss):
                        val_batch_loss += m_loss.item()
                
                val_loss += val_batch_loss / len(x_val)
                
        val_loss /= len(val_data_loader)
        val_losses.append(val_loss)

        # Step the scheduler once per epoch (if configured)
        if scheduler is not None:
            try:
                scheduler.step()
            except Exception:
                # scheduler.step() may expect different inputs for some schedulers; ignore failures to remain conservative
                pass

        log_dict = {
            'loss': round(train_loss, 4),
            'mod_losses': [round(l, 3) for l in per_modality_losses],
            'ranks': [layer.active_dims for layer in model.adaptive_layers] if hasattr(model, 'adaptive_layers') 
                    else (model.module.adaptive_layers if multi_gpu else []),
            'current_rsquare': [round(current_rsquare_per_mod[i], 3) if current_rsquare_per_mod[i] is not None else 'N/A' for i in range(len(data))],
            'patience': patience_counter,
        }
        if recon_loss_balancing and all(ema is not None for ema in modality_loss_emas):
            # Show the balance scales for reconstruction losses
            min_ema = min(ema for ema in modality_loss_emas if ema > 0)
            balance_scales = [round(min_ema / ema, 3) if ema > 0 else 1.0 for ema in modality_loss_emas]
            log_dict.update({'balance_scales': balance_scales})
        pbar.set_postfix(log_dict)
        
        # Update best loss
        if train_loss < best_loss:
            best_loss = train_loss
            if reduce_on_best_loss in ['true', 'stagnation']:
                patience_counter = 0  # Reset patience counter
        else:
            if reduce_on_best_loss in ['true', 'stagnation']:
                patience_counter += 1

        if (start_reduction is False) and (epoch == model.epoch + patience): # giving the optimizer a start
            rank_history = {
                'total_rank':[model.get_total_rank() if hasattr(model, 'get_total_rank') else (model.module.get_total_rank() if multi_gpu else 0)],
                'ranks':[', '.join(str(layer.active_dims) for layer in model.adaptive_layers)],
                'epoch':[model.epoch],
                'loss':[train_losses[-1]],
                'val_loss':[val_losses[-1]]
            }
            
            #break
            start_reduction = True  # Start rank reduction after early stopping
            break_counter = 0 # start with no breaks (only used when increasing layers)

            #with torch.no_grad():
            #    val_data_tensors = [val_data.data[0].to(device), torch.mean(val_data.data[1], dim=1).to(device)]
            #    #encoded_per_modality = model.encode_modalities([val_data.data[i].to(device) for i in range(len(data))])
            #    encoded_per_modality = model.encode_modalities(val_data_tensors)
            #    #encoded_per_space_shared, encoded_per_space_specific = model.encode([val_data.data[i].to(device) for i in range(len(data))])
            #    encoded_per_space_shared, encoded_per_space_specific = model.encode(val_data_tensors)
            #    encoded_per_space = [encoded_per_space_shared] + list(encoded_per_space_specific)
            min_rsquares = []
            mask = val_data.mask
            modality_masks_latent = []
            modality_masks_data = []
            modality_masks_space = []
            #if mask is not None:
            #    start_idx = 0
            #    for j, x_m in enumerate(x_val):
            #        end_idx = start_idx + x_m.shape[1]
            #        temp_mask = mask[:, start_idx]
            #        # expand it to match the encoded shape
            #        temp_mask = temp_mask.unsqueeze(1).expand(-1, encoded_per_modality[j].shape[1])
            #        modality_masks_latent.append(temp_mask)
            #        modality_masks_data.append(mask[:, start_idx:end_idx])
            #        modality_masks_space.append(mask[:, start_idx])
            #        start_idx = end_idx
            
            # Direct reconstruction R² approach (original behavior)
            #val_data_list = [val_data.data[i] for i in range(len(data))]
            #val_data_list = [val_data.data[0].to(device), torch.mean(val_data.data[1], dim=1).to(device)]
            # Build a small evaluation subset from the first 10% of the training samples
            n_sub = int(0.1 * train_data.data[0].shape[0])
            audio_subset = train_data.data[1][:n_sub]

            # If the model expects full spectrum audio, prefer the original input data's full-spectrum
            # representation (if available in `data`). Otherwise, convert/flatten as best-effort.
            try:
                model_full_spectrum = getattr(model, 'full_spectrum', False)
            except Exception:
                model_full_spectrum = full_spectrum

            if model_full_spectrum:
                # Prefer original full-spectrum audio from the `data` argument (before dataset wrapping)
                try:
                    orig_audio_all = data[1]
                    orig_audio_subset = orig_audio_all[train_indices][:n_sub]
                except Exception:
                    orig_audio_subset = None

                if orig_audio_subset is not None and orig_audio_subset.dim() > 2:
                    # Flatten to (N, H*W)
                    audio_subset = orig_audio_subset.view(orig_audio_subset.shape[0], -1)
                else:
                    # Fall back: if current audio_subset has spatial dims, flatten; otherwise warn
                    if audio_subset.dim() > 2:
                        audio_subset = audio_subset.view(audio_subset.shape[0], -1)
                    else:
                        if verbose:
                            print("   Warning: model expects full-spectrum audio but only averaged audio available for the subset.")
                        # Leave as-is (will likely mismatch) and let the metric function handle truncation
            else:
                # Model expects averaged audio (1D). If we have 2D spectrograms, average along freq axis
                if not full_spectrum and audio_subset.dim() > 2:
                    audio_subset = torch.mean(audio_subset, dim=1)

            # Move tensors to device and build val list
            val_data_list = [train_data.data[0][:n_sub].to(device), audio_subset.to(device)]

            if verbose:
                print(f"   Debug: val_data_list[0] shape (images) = {val_data_list[0].shape}")
                print(f"   Debug: val_data_list[1] shape (audio)  = {val_data_list[1].shape}")

            # Also query model recon shapes once for debugging (no side effects)
            if verbose:
                try:
                    with torch.no_grad():
                        recon_debug, _ = model([val_data_list[0], val_data_list[1]])
                        recon_shapes = [r.shape for r in recon_debug]
                        print(f"   Debug: model reconstructions shapes = {recon_shapes}")
                except Exception as e:
                    print(f"   Debug: could not run model for recon shape check: {e}")
            if decision_metric == 'ExVarScore':
                direct_r_squared_values = compute_direct_explained_variance(model, val_data_list, device, multi_gpu, verbose=verbose)
            else:  # Default to R2
                direct_r_squared_values = compute_direct_r_squared(model, val_data_list, device, multi_gpu, verbose=verbose)
            
            for i, r_squared_val in enumerate(direct_r_squared_values):
                initial_squares[i] = r_squared_val
                
                # Calculate threshold based on threshold_type
                if threshold_type == 'relative':
                    min_rsquares.append(r_squared_val * r_square_threshold)
                elif threshold_type == 'absolute':
                    min_rsquares.append(r_squared_val - r_square_threshold)
                else:
                    raise ValueError(f"threshold_type must be 'relative' or 'absolute', got {threshold_type}")
                    
                current_rsquare_per_mod[i] = r_squared_val
                rank_history[f'rsquare {i}'] = [r_squared_val]
                
            if verbose:
                #print(f"Initial R-squared values: {[rank_history[f'rsquare {i}'] for i in range(len(encoded_per_modality))]}, setting {threshold_type} thresholds to {min_rsquares}")
                print(f"Initial R-squared values: {[rank_history[f'rsquare {i}'] for i in range(len(current_rsquare_per_mod))]}, setting {threshold_type} thresholds to {min_rsquares}")
            #print(f"Initial R-squared values: {[rank_history[f'rsquare {i}'] for i in range(len(encoded_per_modality))]}, setting {threshold_type} thresholds to {min_rsquares}")
            
        # Apply rank reduction at scheduled epochs, respecting warmup period
        if (epoch in rank_schedule) & (start_reduction) & (break_counter == 0):
            if (reduce_on_best_loss == 'rsquare') & (start_reduction):
                ###
                # get the r_square values per modality
                ###
                #with torch.no_grad():
                #    val_data_tensors = [val_data.data[0].to(device), torch.mean(val_data.data[1], dim=1).to(device)]
                #    #encoded_per_modality = model.encode_modalities([d[n_samples_train:].to(device) for d in data])
                #    #encoded_per_modality = model.encode_modalities([val_data.data[i].to(device) for i in range(len(data))])
                #    encoded_per_modality = model.encode_modalities(val_data_tensors)
                #    if not compute_jacobian:
                #        #encoded_per_space_shared, encoded_per_space_specific = model.encode([val_data.data[i].to(device) for i in range(len(data))])
                #        encoded_per_space_shared, encoded_per_space_specific = model.encode(val_data_tensors)
                #        encoded_per_space = [encoded_per_space_shared] + list(encoded_per_space_specific)
                #    else:
                #        encoded_per_space, contractive_losses = model.encode([val_data.data[i].to(device) for i in range(len(data))], compute_jacobian=compute_jacobian)
                current_rsquares = []
                modalities_to_reduce = []
                modalities_to_increase = []
                mask = val_data.mask
                modality_masks_data = []
                modality_masks_latent = []
                modality_masks_space = []
                #if mask is not None:
                #    start_idx = 0
                #    for j, x_m in enumerate(x_val):
                #        end_idx = start_idx + x_m.shape[1]
                #        #modality_masks.append(mask[:, start_idx:end_idx])
                #        # expand it to match the encoded shape
                #        temp_mask = mask[:, start_idx]
                #        temp_mask = temp_mask.unsqueeze(1).expand(-1, encoded_per_modality[j].shape[1])
                #        modality_masks_latent.append(temp_mask)
                #        modality_masks_data.append(mask[:, start_idx:end_idx])
                #        modality_masks_space.append(mask[:, start_idx])
                #        start_idx = end_idx
                
                # Direct reconstruction R² approach (original behavior)
                #val_data_list = [val_data.data[i] for i in range(len(data))]
                #val_data_list = [val_data.data[0].to(device), torch.mean(val_data.data[1], dim=1).to(device)]
                if not full_spectrum:
                    val_data_list = [train_data.data[0][:int(0.1 * train_data.data[0].shape[0])].to(device), torch.mean(train_data.data[1], dim=1)[:int(0.1 * train_data.data[0].shape[0])].to(device)]
                else:
                    # if full_spectrum, pass the audio subset through unchanged (no averaging or flattening is applied here)
                    audio_subset = train_data.data[1][:int(0.1 * train_data.data[0].shape[0])]
                    val_data_list = [train_data.data[0][:int(0.1 * train_data.data[0].shape[0])].to(device), audio_subset.to(device)]
                if decision_metric == 'ExVarScore':
                    direct_r_squared_values = compute_direct_explained_variance(model, val_data_list, device, multi_gpu)
                else:  # Default to R2
                    direct_r_squared_values = compute_direct_r_squared(model, val_data_list, device, multi_gpu)
                
                for i, r_squared_val in enumerate(direct_r_squared_values):
                    current_rsquares.append(r_squared_val)
                    current_rsquare_per_mod[i] = r_squared_val
                        
                r_squares.append(current_rsquares)
                max_rquares = [max(r_squares, key=lambda x: x[i])[i] for i in range(len(current_rsquare_per_mod))] if len(r_squares) > 0 else initial_squares
                if threshold_type == 'relative':
                    min_rsquares = [r * r_square_threshold for r in max_rquares]
                elif threshold_type == 'absolute':
                    min_rsquares = [r - r_square_threshold for r in max_rquares]


                ###
                # determine what modalities to reduce or increase
                ###
                if (len(r_squares) >= min(10, int(patience/2))) and patience_counter >= min(10, int(patience/2)):
                    for i in range(len(current_rsquare_per_mod)):
                        i_rsquares = [r[i] for r in r_squares[-min(10, int(patience/2)):]]
                        
                        # Handle different comparison logic for loss vs R²
                        if compressibility_type == 'direct' and reduction_criterion in ['train_loss', 'val_loss']:
                            # For loss: lower is better, so increase the rank when the recent values all stay above the threshold
                            if all(r > min_rsquares[i] for r in i_rsquares) and not bottom_reached:
                                modalities_to_increase.append(i)
                            elif current_rsquare_per_mod[i] < min_rsquares[i]:  # loss below threshold
                                modalities_to_reduce.append(i)
                        else:
                            # For R²: higher is better (original logic)
                            if all(r < min_rsquares[i] for r in i_rsquares) and not bottom_reached:
                                modalities_to_increase.append(i)
                            elif current_rsquare_per_mod[i] > min_rsquares[i]:
                                modalities_to_reduce.append(i)
                elif (len(r_squares) >= 1):# and (patience_counter >= 1):
                    for i in range(len(current_rsquare_per_mod)):
                        if compressibility_type == 'direct' and reduction_criterion in ['train_loss', 'val_loss']:
                            # For loss: reduce if loss is below threshold (good performance)
                            if current_rsquare_per_mod[i] < min_rsquares[i]:
                                modalities_to_reduce.append(i)
                            elif current_rsquare_per_mod[i] > min_rsquares[i] and not bottom_reached:
                                modalities_to_increase.append(i)
                        else:
                            # For R²: reduce if R² is above threshold (original logic)
                            if current_rsquare_per_mod[i] > min_rsquares[i]:
                                modalities_to_reduce.append(i)
                            elif current_rsquare_per_mod[i] < min_rsquares[i] and not bottom_reached:
                                modalities_to_increase.append(i)

                # if all modalities can be reduced, we set min and max ranks
                if len(modalities_to_reduce) == len(current_rsquare_per_mod):
                    current_ranks = [layer.active_dims for layer in model.adaptive_layers]
                    for i, cr in enumerate(current_ranks):
                        if cr <= min_ranks[i]:
                            min_ranks[i] = cr
                            # if we are still above the thresholds, the ranks maximum should not be larger than the sum of current ranks
                        #model.adaptive_layers[i].max_rank = min(sum(current_ranks), model.adaptive_layers[i].max_rank)
                        model.adaptive_layers[i].max_rank = min(sum(current_ranks), max(int(1.5*current_ranks[i]), current_ranks[i]+1), model.adaptive_layers[i].max_rank)
                    print(f"Adjusting maximum ranks to {[layer.max_rank for layer in model.adaptive_layers]}")
                if len(modalities_to_increase) == len(current_rsquare_per_mod):
                    # set minima
                    current_ranks = [layer.active_dims for layer in model.adaptive_layers]
                    for i, cr in enumerate(current_ranks):
                        if cr <= min_ranks[i]:
                            min_ranks[i] = cr
                        # when every modality needs more capacity, pin each layer's minimum rank to the tracked minimum
                        model.adaptive_layers[i].min_rank = min_ranks[i]
                    #bottom_reached = True
                    print(f"Adjusting minimum ranks to {[layer.min_rank for layer in model.adaptive_layers]}")
                ###
                # set the patience counters and layers to reduce or increase
                ###
                layers_to_reduce = []
                layers_to_increase = []
                if (len(modalities_to_reduce) == 0) and (len(modalities_to_increase) == 0):
                    #patience_counter += 1
                    pass
                elif (len(modalities_to_reduce) > 0) and (len(modalities_to_increase) > 0):
                    # mixed case: reduce only the modality-specific layers that can shrink, and grow the shared layer (0) plus the layers of modalities that need more capacity
                    layers_to_reduce = [i + 1 for i in modalities_to_reduce]
                    layers_to_increase = [0] + [i + 1 for i in modalities_to_increase]
                    # raise the minimum rank of the shared layer and of each modality to be increased to its current rank + 1
                    model.adaptive_layers[0].min_rank = model.adaptive_layers[0].active_dims + 1
                    for i in modalities_to_increase:
                        model.adaptive_layers[i + 1].min_rank = model.adaptive_layers[i + 1].active_dims + 1
                    print(f"Adjusting minimum ranks to {[layer.min_rank for layer in model.adaptive_layers]}")
                else:
                    if len(modalities_to_increase) > 0:
                        if len(modalities_to_increase) == len(current_rsquare_per_mod):
                            # if all modalities are below the threshold, increase ranks of all layers
                            layers_to_increase = [i for i in range(len(model.adaptive_layers))]
                        else:
                            layers_to_increase = [0] + [i + 1 for i in modalities_to_increase]
                            for i in modalities_to_increase:
                                model.adaptive_layers[i + 1].min_rank = model.adaptive_layers[i + 1].active_dims + 1
                            model.adaptive_layers[0].min_rank = model.adaptive_layers[0].active_dims + 1
                            print(f"Adjusting minimum ranks to {[layer.min_rank for layer in model.adaptive_layers]}")
                    if len(modalities_to_reduce) > 0:
                        # if every modality is marked for reduction, reduce the shared layer (0) as well
                        if len(modalities_to_reduce) == len(initial_squares):
                            reduce_shared = True
                            if reduce_shared:
                                layers_to_reduce = [0] + [i + 1 for i in modalities_to_reduce]
                            else:
                                layers_to_reduce = [i + 1 for i in modalities_to_reduce]
                        else:
                            # roll a dice whether we also try to reduce the shared layer
                            #if sharedwhenall:
                            #    reduce_shared = False
                            #else:
                            #    reduce_shared = True
                            #if reduce_shared:
                            #    layers_to_reduce = [0] + [i + 1 for i in modalities_to_reduce]
                            #else:
                                layers_to_reduce = [i + 1 for i in modalities_to_reduce]
                if verbose:
                    if compressibility_type == 'direct' and reduction_criterion in ['train_loss', 'val_loss']:
                        print(f"{reduction_criterion} values: {current_rsquares}, reducing rank for modalities {modalities_to_reduce}, layers {layers_to_reduce}, increasing rank for modalities {modalities_to_increase}, layers {layers_to_increase}")
                    else:
                        print(f"R-squared values: {current_rsquares}, reducing rank for modalities {modalities_to_reduce}, layers {layers_to_reduce}, increasing rank for modalities {modalities_to_increase}, layers {layers_to_increase}")
                if compute_jacobian:
                    valid_contractive_losses = [contractive_losses[i] for i in layers_to_reduce]
                    max_contractive_loss = max(valid_contractive_losses) if len(valid_contractive_losses) > 0 else None
                    #max_contractive_loss = min(valid_contractive_losses) if len(valid_contractive_losses) > 0 else None
                    if max_contractive_loss is not None:
                        layers_to_reduce = [i for i in layers_to_reduce if contractive_losses[i] == max_contractive_loss]
                    else:
                        layers_to_reduce = []
                    if verbose:
                        print(f"Contractive losses: {contractive_losses}, reducing rank for layers {layers_to_reduce} with max loss {max_contractive_loss}")
                
            
            #if should_reduce:
            any_changes_made = False
            #changes_made = False
            if len(layers_to_reduce) > 0:
                # Apply rank reduction
                if multi_gpu:
                    changes_made = model.module.reduce_rank(reduction_ratio=0.9, threshold=rank_reduction_threshold, layer_ids=layers_to_reduce)
                else:
                    changes_made = model.reduce_rank(reduction_ratio=0.9, threshold=rank_reduction_threshold, layer_ids=layers_to_reduce)
                if changes_made:
                    any_changes_made = True
            if len(layers_to_increase) > 0:
                # Apply rank increase
                #print(f"Increasing rank for layer {layers_to_increase}")
                if multi_gpu:
                    changes_made = model.module.increase_rank(increase_ratio=1.1, layer_ids=layers_to_increase)
                else:
                    changes_made = model.increase_rank(increase_ratio=1.1, layer_ids=layers_to_increase)
                if changes_made:
                    any_changes_made = True
                    break_counter = patience # give model more time to re-learn the added dimensions
            #else:
            #    changes_made = False
            
            if any_changes_made:
                patience_counter = 0  # Reset patience counter if rank was changed
            else:
                patience_counter += 1

            # Get new rank but don't print separate message
            total_rank_after = model.module.get_total_rank() if multi_gpu else model.get_total_rank()
                
            # Store current rank in history
            rank_history['total_rank'].append(total_rank_after)
            rank_history['ranks'].append(', '.join(str(layer.active_dims) for layer in model.adaptive_layers))
            rank_history['epoch'].append(epoch)
            #for i in range(len(encoded_per_modality)):
            for i in range(len(current_rsquare_per_mod)):
                if reduce_on_best_loss == 'rsquare':
                    rank_history[f'rsquare {i}'].append(current_rsquares[i])
            rank_history['loss'].append(train_loss)
            rank_history['val_loss'].append(val_loss)

            # also get mutual information between all spaces
            #valid_spaces = []
            #for encoded in encoded_per_space:
            #    valid_spaces_temp = []
            #    for i in range(len(encoded_per_modality)):
            #        #normalized_encoded = (encoded - encoded.min() + 1e-9) / (encoded.max() - encoded.min() + 1e-9)
            #        if mask is not None:
            #            temp_mask = modality_masks_space[i]
            #            #valid_spaces_temp.append(normalized_encoded[mask])
            #            valid_spaces_temp.append(encoded[temp_mask])
            #        else:
            #            #valid_spaces_temp.append(normalized_encoded)
            #            valid_spaces_temp.append(encoded)
            #    # stack the valid spaces
            #    valid_spaces_temp = torch.vstack(valid_spaces_temp)
            #    valid_spaces.append(valid_spaces_temp)
        else:
            if (epoch in rank_schedule) & (start_reduction) & (break_counter > 0):
                break_counter -= 1
        
        # Get normalized weights for display
        if multi_gpu:
            weights = model.module.modality_weights
        else:
            weights = model.modality_weights
            
        pos_weights = F.softplus(weights)
        norm_weights = (pos_weights / (pos_weights.sum() + 1e-8)).detach().cpu().numpy().round(3)

        
        # early stopping but conditioned on rank reduction
        #if (epoch > early_stopping) & (min(val_losses[-early_stopping:]) > min(val_losses)) & (start_reduction is True) & (patience_counter >= patience):
        if (epoch > early_stopping) & (start_reduction is True) & (patience_counter >= patience):
            if verbose:
                print(f"Early stopping at epoch {epoch} with best loss {best_loss} and ranks {rank_history['ranks'][-1]}")
            break
    
    # Calculate latent representations in batches
    #'''
    n_samples = data[0].shape[0]
    final_ranks = [layer.active_dims for layer in model.adaptive_layers]
    reps = [torch.empty((n_samples, final_ranks[i]), device=device) for i in range(len(final_ranks))]
    model.eval()
    with torch.no_grad():
        for i in range(0, n_samples, batch_size):
            end_idx = min(i + batch_size, n_samples)
            x_batch = [data[j][i:end_idx].to(device) for j in range(len(data))]
            if not full_spectrum:
                x_batch[1] = torch.mean(x_batch[1], dim=1)
            
            # If using DataParallel, need to access module directly or handle the encoding differently
            if multi_gpu:
                batch_reps = model.module.encode(x_batch)#.cpu()
            else:
                batch_reps = model.encode(x_batch)#.cpu()
            batch_rep_list = [batch_reps[0]] + [batch_reps[1][j] for j in range(len(batch_reps[1]))]
                
            # No need to convert dtype
            for j in range(len(reps)):
                reps[j][i:end_idx,:] = batch_rep_list[j][:,:final_ranks[j]].cpu()
            
            # Free memory
            del x_batch, batch_reps
            torch.cuda.empty_cache() if torch.cuda.is_available() else None
    
    # save the model
    if model_name:
        model_path = f"./03_results/models/{model_name}.pt"
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        if multi_gpu:
            torch.save(model.module.state_dict(), model_path)
        else:
            torch.save(model.state_dict(), model_path)
        if verbose:
            print(f"Saved trained model to {model_path}")
    
    try:
        avg_train_loss = np.mean(train_losses[-5:])
    except:
        avg_train_loss = np.mean(train_losses)
    try:
        last_rsquare = r_squares[-1]
    except:
        last_rsquare = [None]

    return model, reps, avg_train_loss, last_rsquare, rank_history, [train_losses, val_losses]

def plot_image_reconstruction(original_imgs, recon_imgs, n=8, out_path=None):
    """Plot original images (top row) above their reconstructions (bottom row).

    Parameters:
    - original_imgs: array of shape (N, 784) with values in [0,1]; every row
      is reshaped to 28x28 for display.
    - recon_imgs: array with the same layout holding the reconstructions.
    - n: maximum number of image pairs to show. It is clamped to the number
      of samples actually available in both arrays.
    - out_path: if given, the figure is saved there with dpi=150. The figure
      is closed in any case.

    Raises:
    - ValueError: if there is not at least one image pair to plot.
    """
    # Never index past the end of either array.
    n = min(n, len(original_imgs), len(recon_imgs))
    if n < 1:
        raise ValueError('plot_image_reconstruction needs at least one image pair to plot')

    # squeeze=False keeps `axes` two-dimensional even for a single column,
    # so axes[row, col] indexing is valid for every n >= 1.
    fig, axes = plt.subplots(2, n, figsize=(n * 1.5, 3), squeeze=False)
    for col in range(n):
        for row, imgs in enumerate((original_imgs, recon_imgs)):
            axes[row, col].imshow(imgs[col].reshape(28, 28), cmap='gray')
            axes[row, col].axis('off')
    plt.tight_layout()
    if out_path:
        fig.savefig(out_path, dpi=150)
    plt.close(fig)

def plot_modal_image_reconstruction(original_imgs, recon_imgs, image_shape=(28,28), n=8, out_path=None):
    """Plot original and reconstructed modality images for arbitrary image shapes.

    Parameters:
    - original_imgs, recon_imgs: numpy arrays shaped either
        - (N, H*W)      flattened images, reshaped to (N, H, W)
        - (N, H, W)     used as-is
        - (N, 1, H, W)  the singleton channel axis is dropped
    - image_shape: tuple (H, W), only used to un-flatten 2-D inputs.
    - n: maximum number of image pairs to show. It is clamped to the number
      of samples actually available in both arrays.
    - out_path: if given, the figure is saved there with dpi=150. The figure
      is closed in any case.

    Raises:
    - ValueError: if there is no image pair to plot, or if an input has an
      unsupported shape.
    """
    H, W = image_shape

    # Never index past the end of either array.
    n = min(n, len(original_imgs), len(recon_imgs))
    if n < 1:
        raise ValueError('plot_modal_image_reconstruction needs at least one image pair to plot')

    def ensure_hw(arr):
        # Bring any supported layout to (N, H, W).
        if arr.ndim == 4 and arr.shape[1] == 1:
            return arr[:, 0]
        if arr.ndim == 3:
            return arr
        if arr.ndim == 2:
            # assume flattened
            return arr.reshape(-1, H, W)
        raise ValueError('Unsupported array shape for image reconstruction plotting: ' + str(arr.shape))

    orig_hw = ensure_hw(original_imgs[:n])
    recon_hw = ensure_hw(recon_imgs[:n])

    # squeeze=False keeps `axes` two-dimensional even for a single column,
    # so axes[row, col] indexing is valid for every n >= 1.
    fig, axes = plt.subplots(2, n, figsize=(n * 2, 3), squeeze=False)
    for col in range(n):
        axes[0, col].imshow(orig_hw[col], cmap='gray')
        axes[0, col].axis('off')
        axes[1, col].imshow(recon_hw[col], cmap='gray')
        axes[1, col].axis('off')
    plt.tight_layout()
    if out_path:
        fig.savefig(out_path, dpi=150)
    plt.close(fig)

def plot_audio_scatter(original_audio, recon_audio, out_path=None):
    """Scatter original against reconstructed audio features in two dimensions.

    Both inputs are numpy arrays of shape (N, D). When D > 2 a 2-component
    PCA is fitted on the originals and the same projection is applied to the
    reconstructions; otherwise the first two columns are plotted directly.
    The figure is saved to `out_path` (dpi=150) when one is given and is
    always closed afterwards.
    """
    from sklearn.decomposition import PCA

    needs_projection = original_audio.shape[1] > 2
    if needs_projection:
        # Fit on the originals only so both clouds share one coordinate system.
        projector = PCA(n_components=2)
        points_orig = projector.fit_transform(original_audio)
        points_recon = projector.transform(recon_audio)
    else:
        points_orig, points_recon = original_audio[:, :2], recon_audio[:, :2]

    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    for points, tag in ((points_orig, 'original'), (points_recon, 'recon')):
        ax.scatter(points[:, 0], points[:, 1], s=4, alpha=0.6, label=tag)
    ax.legend()
    ax.set_title('Audio original vs recon (PCA 2D)')
    if out_path:
        fig.savefig(out_path, dpi=150)
    plt.close(fig)


def _posttrain_reconstruction_losses(x_in, x_out, n_modalities):
    """Return the per-modality mean binary-cross-entropy reconstruction losses.

    Modality 0 (the image) is viewed as (-1, 1, 28, 28) on both sides before
    the loss is taken; every other modality is compared as-is.
    """
    losses = []
    for i in range(n_modalities):
        target, recon = x_in[i], x_out[i]
        if i == 0:  # Image modality - reshape for CNN
            target = target.view(-1, 1, 28, 28)
            recon = recon.view(-1, 1, 28, 28)
        losses.append(F.binary_cross_entropy(recon, target, reduction='mean'))
    return losses


def posttrain_overcomplete_ae(model_path, data, n_samples_train, device, args, epochs=1000, early_stopping=50,
                            lr=1e-4, batch_size=128, wd=1e-5, patience=10, verbose=True,
                            recon_loss_balancing=False, paired=False, lr_schedule=None, model_name=None,trained_ranks=None, full_spectrum=False):
    """
    Load a saved model and continue training without rank reduction (post-training).

    The architecture is rebuilt with hard-coded hyperparameters (latent size 100,
    depth 2, width 0.5, dropout 0.0), the state dict is loaded, the ranks of the
    adaptive layers are fixed, and the model is trained on a random
    train/validation split with a per-modality BCE reconstruction loss.

    Parameters:
    - model_path: Path to the saved model state dict
    - data: Input data tensor list (index 0 = images, index 1 = audio)
    - n_samples_train: Number of samples to use for training; the remaining
      samples of the random permutation form the validation set
    - device: Device to run training on
    - args: Arguments object; `multi_gpu`, `gpu_ids` and `num_workers` are read
      from it when present
    - epochs: Maximum number of training epochs
    - early_stopping: Early stopping can only trigger after this epoch index
    - lr: Learning rate (typically lower than initial training)
    - batch_size: Batch size for training
    - wd: Weight decay
    - patience: Number of consecutive epochs without validation improvement
      tolerated before stopping
    - verbose: Print progress
    - recon_loss_balancing: Weight the modality losses by per-modality scales.
      The scales are initialised to one and never updated here, so this
      currently yields the same value as the plain sum.
    - paired: Whether data is paired (currently does not change the split)
    - lr_schedule: Learning rate schedule ('linear', 'exponential', 'cyclic' or None)
    - model_name: Name for saving the post-trained model, loss curves and
      representations; nothing of those is saved when it is falsy
    - trained_ranks: Ranks to assign to `active_dims` of the adaptive layers,
      one per layer. If None, the ranks found on the freshly loaded model are used.
    - full_spectrum: If True the audio input is used as given (input size
      112*112); otherwise the audio is averaged over dim 1 to 112 features.

    Returns:
    - (model, train_losses, val_losses); `model` is wrapped in DataParallel when
      multi-GPU was requested and successfully set up.
    """
    import os

    n_modalities = len(data)

    # Create model architecture (must exactly match original training parameters)
    input_dims = (784, 112*112 if full_spectrum else 112)  # 112x112=12544 for full spectrum, 112 for averaged

    # Use the EXACT same parameters as in 054_avmnist.py
    latent_dim = 100  # This was hardcoded in the original training
    latent_dims = [latent_dim] * (len(input_dims) + 1)  # adding one for the shared space

    model = AdaptiveRankReducedAE_AvMnist(
        input_dims, latent_dims,
        depth=2,        # ae_depth was 2
        width=0.5,      # ae_width was 0.5
        dropout=0.0,    # dropout was 0.0
        initial_rank_ratio=1.0,
        min_rank=1,
        full_spectrum=full_spectrum
    ).to(device)

    # Load the saved model state
    saved_state = torch.load(model_path, map_location=device)
    model.load_state_dict(saved_state)
    model.epoch = 0  # Reset epoch counter for post-training

    # Fix the ranks (freeze architecture).
    # NOTE(review): if `active_dims` is not part of the state dict, the fallback
    # below yields the initial ranks rather than the trained ones - pass
    # `trained_ranks` explicitly in that case.
    if trained_ranks is None:
        trained_ranks = [layer.active_dims for layer in model.adaptive_layers]
    for i, layer in enumerate(model.adaptive_layers):
        layer.active_dims = trained_ranks[i]

    if verbose:
        print(f"Loaded model from {model_path}")
        print(f"Model is on device: {next(model.parameters()).device}")
        print(f"Fixed model ranks to: {trained_ranks}")
        print(f"Total rank: {sum(trained_ranks)}")

    # Keep a handle on the bare module: DataParallel does not forward custom
    # attributes/methods such as `adaptive_layers` or `encode`.
    core_model = model

    # Handle multi-GPU setup. `multi_gpu` ends up True only if the model was
    # really wrapped, so later code never touches a non-existent `.module`.
    multi_gpu = bool(getattr(args, 'multi_gpu', False))
    if multi_gpu:
        gpu_ids = getattr(args, 'gpu_ids', None)
        num_gpus = len(gpu_ids.split(',')) if gpu_ids else torch.cuda.device_count()
        multi_gpu = False
        if num_gpus > 1:
            try:
                model = torch.nn.DataParallel(model)
                multi_gpu = True
                if verbose:
                    print(f"Using {num_gpus} GPUs for post-training")
            except Exception as e:
                print(f"Failed to use DataParallel: {e}")

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

    # Setup learning rate scheduler if requested
    scheduler = None
    if lr_schedule == 'linear':
        from torch.optim.lr_scheduler import LinearLR
        scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=0.00001, total_iters=epochs)
    elif lr_schedule == 'exponential':
        from torch.optim.lr_scheduler import ExponentialLR
        scheduler = ExponentialLR(optimizer, gamma=0.99)  # Decay
    elif lr_schedule == 'cyclic':
        from torch.optim.lr_scheduler import CyclicLR
        scheduler = CyclicLR(optimizer, base_lr=lr/1000, max_lr=lr, step_size_up=100, mode='triangular2')

    # Create data loaders from a random train/validation split
    n_samples = data[0].shape[0]
    indices = torch.randperm(n_samples)
    train_indices = indices[:n_samples_train]
    val_indices = indices[n_samples_train:]

    num_workers = getattr(args, 'num_workers', 0)
    train_data = AVMNISTDataset([d[train_indices] for d in data], full_spectrum=full_spectrum)
    data_loader = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_workers
    )
    val_data = AVMNISTDataset([d[val_indices] for d in data], full_spectrum=full_spectrum)
    val_data_loader = torch.utils.data.DataLoader(
        val_data, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=num_workers
    )

    # Training state
    train_losses = []
    val_losses = []
    best_loss = float('inf')
    patience_counter = 0

    # Loss scaling factors for loss balancing (constant ones, see docstring)
    loss_scales = torch.ones(n_modalities, device=device)

    # Plotting state: the first batch of the most recent epoch is used for the
    # latent plots; AVMNIST batches carry no labels.
    plot_batch_data = None
    plot_batch_labels = None
    plot_save_dir = f"./03_results/plots/temp_latent_plots/{model_name}_posttrain" if model_name else "./03_results/plots/temp_latent_plots/posttrain"
    os.makedirs(plot_save_dir, exist_ok=True)

    if verbose:
        print(f"Starting post-training for up to {epochs} epochs with early stopping after {early_stopping} epochs")
        print(f"Plots will be saved to: {plot_save_dir}")

    pbar = tqdm.tqdm(range(epochs))
    for epoch in pbar:
        model.train()
        epoch_loss = 0.0

        for batch_idx, x_batch in enumerate(data_loader):
            x_batch = [x.to(device) for x in x_batch]

            if batch_idx == 0:  # Store first batch of each epoch for plotting
                plot_batch_data = [x_m.clone() for x_m in x_batch]

            optimizer.zero_grad()

            # Forward pass on the bare module (as before, the DataParallel
            # wrapper is not used for the forward pass itself).
            x_recon = core_model.forward(x_batch)
            losses = _posttrain_reconstruction_losses(x_batch, x_recon[0], n_modalities)

            if recon_loss_balancing:
                total_loss = sum(loss * loss_scales[i] for i, loss in enumerate(losses))
            else:
                total_loss = sum(losses)

            total_loss.backward()
            optimizer.step()

            epoch_loss += total_loss.item()

        # Update learning rate scheduler (stepped once per epoch)
        if scheduler is not None:
            scheduler.step()

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for x_val in val_data_loader:
                x_val = [x.to(device) for x in x_val]
                x_val_recon = core_model.forward(x_val)
                val_loss += sum(
                    _posttrain_reconstruction_losses(x_val, x_val_recon[0], n_modalities)
                ).item()

        train_losses.append(epoch_loss / len(data_loader))
        val_losses.append(val_loss / len(val_data_loader))

        # Update progress bar
        pbar.set_postfix({
            'train_loss': f'{train_losses[-1]:.4f}',
            'val_loss': f'{val_losses[-1]:.4f}',
            'lr': f'{optimizer.param_groups[0]["lr"]:.2e}'
        })

        # Plot training state every 10 epochs (best effort)
        if epoch % 10 == 0 and plot_batch_data is not None:
            try:
                plot_training_state(model, plot_batch_data, plot_batch_labels, epoch,
                                  multi_gpu, plot_save_dir, device, verbose=False)
            except Exception as e:
                if verbose:
                    print(f"Warning: Could not create plot at epoch {epoch}: {e}")

        # Early stopping on the validation loss
        if val_losses[-1] < best_loss:
            best_loss = val_losses[-1]
            patience_counter = 0
        else:
            patience_counter += 1

        if epoch > early_stopping and patience_counter >= patience:
            if verbose:
                print(f"\nEarly stopping at epoch {epoch} with best validation loss {best_loss:.4f}")
            break

    # Calculate latent representations in batches (same as original training)
    n_samples = data[0].shape[0]
    final_ranks = [layer.active_dims for layer in core_model.adaptive_layers]
    reps = [torch.empty((n_samples, rank), device=device) for rank in final_ranks]
    model.eval()
    with torch.no_grad():
        for start in range(0, n_samples, batch_size):
            end_idx = min(start + batch_size, n_samples)
            x_batch = [d[start:end_idx].to(device) for d in data]

            # Match the training preprocessing: average the audio only when the
            # model was built for averaged audio, and only if it is not
            # already averaged (to avoid double averaging).
            if not full_spectrum and len(x_batch[1].shape) > 2:
                x_batch[1] = torch.mean(x_batch[1], dim=1)

            batch_reps = core_model.encode(x_batch)

            # reps[0] takes batch_reps[0]; reps[j] for j >= 1 takes
            # batch_reps[1][j-1]. Only the active dimensions are kept.
            for j, rank in enumerate(final_ranks):
                source = batch_reps[0] if j == 0 else batch_reps[1][j-1]
                reps[j][start:end_idx] = source[:, :rank]

    # Save the post-trained model
    if model_name:
        model_path_out = f"./03_results/models/{model_name}_posttrained.pth"
        os.makedirs(os.path.dirname(model_path_out), exist_ok=True)
        # Always save the bare module so the file loads without DataParallel.
        torch.save(core_model.state_dict(), model_path_out)
        if verbose:
            print(f"Saved post-trained model to {model_path_out}")

        # Save post-training loss curves as CSV
        loss_curves_path = f"./03_results/models/{model_name}_posttrained_loss_curves.csv"
        loss_df = pd.DataFrame({
            'train_loss': train_losses,
            'val_loss': val_losses
        })
        loss_df.to_csv(loss_curves_path, index=False)
        if verbose:
            print(f"Saved post-training loss curves to {loss_curves_path}")

        # Save the representations as numpy arrays (same format as original training)
        for i, rep in enumerate(reps):
            rep_save_path = f"./03_results/models/{model_name}_posttrained_rep{i}.npy"
            np.save(rep_save_path, rep.cpu().numpy())
            if verbose:
                print(f"Saved post-trained representation {i} to {rep_save_path}")

    # Create a reconstruction plot (original vs reconstruction) for a small set
    # of samples. Best effort: any failure only produces a warning.
    try:
        recon_plot_dir = "./03_results/plots"
        os.makedirs(recon_plot_dir, exist_ok=True)
        n_show = min(10, n_samples)
        with torch.no_grad():
            rng = np.random.default_rng(seed=42)  # fixed seed: same samples every run
            sample_idx = rng.choice(data[0].shape[0], size=n_show, replace=False)
            imgs = data[0][sample_idx].to(device)
            audio = data[1][sample_idx].to(device)
            # Average audio if necessary (match training preprocessing)
            if not full_spectrum and len(audio.shape) > 2:
                audio = torch.mean(audio, dim=1)
            out = core_model.forward([imgs, audio])
            recon_imgs = out[0][0]

        def as_image_stack(tensor):
            # Bring (N, 784) or (N, 1, H, W) tensors to a numpy stack of 2-D images.
            arr = tensor.cpu().numpy()
            if arr.ndim == 2 and arr.shape[1] == 784:
                arr = arr.reshape(-1, 28, 28)
            if arr.ndim == 4 and arr.shape[1] == 1:
                arr = arr[:, 0]
            return arr

        orig_np = as_image_stack(imgs)
        recon_np = as_image_stack(recon_imgs)

        cols = n_show
        fig, axes = plt.subplots(2, cols, figsize=(cols * 1.6, 3.2))
        axes = np.array(axes).reshape(2, cols)  # keep 2-D indexing for cols == 1
        for i in range(n_show):
            axes[0, i].imshow(orig_np[i], cmap='gray')
            axes[0, i].axis('off')
            axes[1, i].imshow(recon_np[i], cmap='gray')
            axes[1, i].axis('off')
        plt.tight_layout()
        recon_plot_path = os.path.join(recon_plot_dir, f"{model_name}_posttrain_reconstruction.png")
        fig.savefig(recon_plot_path, dpi=150, bbox_inches='tight')
        plt.close(fig)
        if verbose:
            print(f"Saved post-training reconstruction plot to: {recon_plot_path}")
    except Exception as e:
        if verbose:
            print(f"Warning: could not create post-training reconstruction plot: {e}")

    # Final plotting and movie creation (best effort)
    if plot_batch_data is not None:
        try:
            # Final plot
            plot_training_state(model, plot_batch_data, plot_batch_labels, epochs-1,
                              multi_gpu, plot_save_dir, device, verbose=verbose)

            # Create training movie
            create_training_movie(plot_save_dir)
            if verbose:
                print(f"Training movie created from plots in {plot_save_dir}")
        except Exception as e:
            if verbose:
                print(f"Warning: Could not create final plot or movie: {e}")

    return model, train_losses, val_losses