import torch
import torch.nn as nn
import numpy as np
import tqdm
from torch.utils.data import TensorDataset, DataLoader


class ImageAutoencoder(nn.Module):
    """
    Convolutional autoencoder for 28x28 single-channel (MNIST-style) images.

    Encoder: two (5x5 conv -> ReLU -> 2x2 max-pool) stages, then two linear
    layers producing an unbounded latent vector (no output activation).
    Decoder: two linear layers, then two stride-2 transposed convolutions
    ending in a sigmoid, so reconstructions lie in [0, 1].

    NOTE(review): the original comments state this mirrors the image branch
    of the multimodal model; that model is not visible here - confirm.

    Parameters:
    - latent_dim: Size of the latent code returned by the encoder
    """
    def __init__(self, latent_dim=100):
        super().__init__()

        base_channels = 16
        wide_channels = base_channels * 4
        self.channel_size = base_channels
        # Spatial size through the encoder: 28 -> 24 -> 12 -> 8 -> 4,
        # leaving a 4x4 map with 4 * channel_size channels.
        self.n_pixel_after_conv = 4 * 4 * base_channels * 4
        flat_features = self.n_pixel_after_conv

        encoder_layers = [
            nn.Conv2d(1, base_channels, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(base_channels, wide_channels, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Flatten(),
            nn.Linear(flat_features, latent_dim * 4),
            nn.ReLU(),
            nn.Linear(latent_dim * 4, latent_dim),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Transposed convolutions grow the map 4 -> 12 -> 28.
        decoder_layers = [
            nn.Linear(latent_dim, 2 * latent_dim),
            nn.ReLU(),
            nn.Linear(2 * latent_dim, flat_features),
            nn.ReLU(),
            nn.Unflatten(1, (wide_channels, 4, 4)),
            nn.ConvTranspose2d(wide_channels, base_channels, kernel_size=5,
                               stride=2, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(base_channels, 1, kernel_size=5,
                               stride=2, output_padding=1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """
        Encode and reconstruct a batch of images.

        Accepts (batch, 784), (batch, 28, 28) or (batch, 1, 28, 28) input.

        Returns:
        - x_hat: Reconstruction flattened to (batch, 784)
        - h: Latent code of shape (batch, latent_dim)
        """
        n_dims = x.dim()
        if n_dims == 2:
            # Flat pixels -> single-channel image
            x = x.view(-1, 1, 28, 28)
        elif n_dims == 3:
            # Add the missing channel axis
            x = x.unsqueeze(1)

        latent = self.encoder(x)
        reconstruction = self.decoder(latent)
        return reconstruction.view(-1, 784), latent


class ResidualBlock(nn.Module):
    """
    Two-convolution residual unit.

    Main path: 3x3 conv -> BatchNorm -> ReLU -> 3x3 conv -> BatchNorm.
    The result is added to the skip connection and passed through a ReLU.

    Parameters:
    - in_channels: Channels of the incoming feature map
    - out_channels: Channels produced by both convolutions
    - stride: Stride of the first convolution (second always uses 1)
    - downsample: Optional module applied to the input on the skip path so
      its shape matches the main path; the input is used as-is when None
    """
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        shared = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(in_channels, out_channels, stride=stride, **shared)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, stride=1, **shared)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        """Return ReLU(main_path(x) + skip(x))."""
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))

        # Project the input only when a downsample module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        residual += shortcut
        return self.relu(residual)


class AudioAutoencoder(nn.Module):
    """
    Autoencoder for audio spectrogram features, in two variants.

    - full_spectrum=True: ResNet-style convolutional encoder (7x7 stem,
      max-pool, then four stages of two residual blocks each) for 112x112
      spectrograms, with a transposed-convolution decoder ending in a
      sigmoid.
    - full_spectrum=False: small MLP for 112-dimensional averaged spectra,
      with a sigmoid on the latent code and a linear (unbounded) output.

    Parameters:
    - latent_dim: Size of the latent code
    - full_spectrum: Select the convolutional (True) or MLP (False) variant
    """
    def __init__(self, latent_dim=100, full_spectrum=False):
        super().__init__()
        self.full_spectrum = full_spectrum

        if not full_spectrum:
            # MLP variant for averaged audio (112 values per sample)
            self.encoder = nn.Sequential(
                nn.Linear(112, latent_dim * 4),
                nn.ReLU(),
                nn.Linear(latent_dim * 4, latent_dim),
                nn.Sigmoid(),
            )
            self.decoder = nn.Sequential(
                nn.Linear(latent_dim, 2 * latent_dim),
                nn.ReLU(),
                nn.Linear(2 * latent_dim, 112),
            )
            return

        # Convolutional variant. Spatial size through the encoder:
        # 112 -> 56 (stem) -> 28 (pool) -> 28 -> 14 -> 7 -> 4.
        bottleneck_features = 4 * 4 * 512

        encoder_layers = [
            nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            self._make_layer(64, 64, blocks=2),
        ]
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            encoder_layers.append(self._make_layer(c_in, c_out, blocks=2, stride=2))
        encoder_layers += [
            nn.Flatten(),
            nn.Linear(bottleneck_features, latent_dim),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Every transposed convolution below doubles the spatial size, so the
        # raw decoder output is 4 * 2**6 = 256 pixels per side; decode()
        # resizes it to 112x112.
        decoder_layers = [
            nn.Linear(latent_dim, bottleneck_features),
            nn.ReLU(),
            nn.Unflatten(1, (512, 4, 4)),
        ]
        for c_in, c_out in ((512, 512), (512, 256), (256, 128), (128, 64), (64, 32)):
            decoder_layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2,
                                   padding=1, output_padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        decoder_layers += [
            nn.ConvTranspose2d(32, 1, kernel_size=7, stride=2, padding=3,
                               output_padding=1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def _make_layer(self, in_channels, out_channels, blocks, stride=1):
        """
        Build one residual stage of ``blocks`` ResidualBlock modules.

        Only the first block changes stride/channel count; it receives a
        1x1 conv + BatchNorm projection on its skip path when the shape
        changes.
        """
        shape_changes = stride != 1 or in_channels != out_channels
        projection = None
        if shape_changes:
            projection = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

        head = ResidualBlock(in_channels, out_channels, stride, projection)
        tail = [ResidualBlock(out_channels, out_channels) for _ in range(1, blocks)]
        return nn.Sequential(head, *tail)

    def encode(self, x):
        """
        Map input to the latent code.

        Full spectrum: accepts (batch, 112*112), (batch, 112, 112) or
        (batch, 1, 112, 112). Averaged: accepts (batch, 112); inputs with
        more than two dims are averaged over their last axis first.
        """
        if not self.full_spectrum:
            if x.dim() > 2:
                x = x.mean(dim=-1)
            return self.encoder(x)

        n_dims = x.dim()
        if n_dims == 2:
            x = x.view(-1, 1, 112, 112)
        elif n_dims == 3:
            x = x.unsqueeze(1)
        return self.encoder(x)

    def decode(self, h):
        """
        Map a latent code back to input space.

        Full spectrum output is (batch, 1, 112, 112), bilinearly resized if
        the decoder produced a different spatial size; averaged output is
        the raw decoder result.
        """
        out = self.decoder(h)
        if not self.full_spectrum:
            return out

        height, width = out.shape[2], out.shape[3]
        if height != 112 or width != 112:
            out = nn.functional.interpolate(out, size=(112, 112),
                                            mode='bilinear', align_corners=False)
        return out

    def forward(self, x):
        """
        Encode then decode.

        Returns:
        - x_hat: Reconstruction; flattened to (batch, -1) for full spectrum
        - h: Latent code
        """
        latent = self.encode(x)
        reconstruction = self.decode(latent)

        if self.full_spectrum:
            batch = reconstruction.size(0)
            reconstruction = reconstruction.view(batch, -1)

        return reconstruction, latent


def _describe_latent_dim(model):
    """Best-effort lookup of the latent size for logging; 'unknown' if not found."""
    if hasattr(model, 'encoder_fc'):
        return model.encoder_fc[-1].out_features
    encoder = getattr(model, 'encoder', None)
    if encoder is not None and hasattr(encoder, '__iter__'):
        # Last layer exposing out_features (skips trailing activations)
        for layer in reversed(list(encoder)):
            if hasattr(layer, 'out_features'):
                return layer.out_features
    return 'unknown'


def _align_target(x, x_hat):
    """Reshape the target to the reconstruction's shape when only the layout differs."""
    if x.shape != x_hat.shape and x.numel() == x_hat.numel():
        return x.reshape(x_hat.shape)
    return x


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` (training if ``optimizer`` is given); return the sample-weighted loss sum."""
    training = optimizer is not None
    model.train(training)
    total = 0.0
    with torch.set_grad_enabled(training):
        for batch in loader:
            x = batch[0].to(device)
            if training:
                optimizer.zero_grad()
            x_hat, _ = model(x)
            x = _align_target(x, x_hat)
            loss = criterion(x_hat, x)
            if training:
                loss.backward()
                optimizer.step()
            # Weight by batch size so the epoch average is per sample
            total += loss.item() * x.size(0)
    return total


def train_unimodal_autoencoder(model, data, device, epochs=100, early_stopping=50,
                                lr=0.001, batch_size=128, wd=1e-5, patience=10,
                                verbose=True, num_workers=0, lr_min=1e-6):
    """
    Train a unimodal autoencoder with early stopping.

    The first 90% of ``data`` (in the given order) is used for training and
    the remainder for validation. The model is optimised with Adam, cosine
    learning-rate annealing and binary cross entropy, so both the model's
    reconstruction and the data must lie in [0, 1]. Targets are reshaped to
    the reconstruction's shape whenever the element counts agree (e.g.
    image batches against a flattened reconstruction).

    Parameters:
    - model: Autoencoder whose forward returns (reconstruction, latent)
    - data: Input data tensor with samples along the first axis
    - device: Device to run on
    - epochs: Maximum number of training epochs
    - early_stopping: Unused; kept for backward compatibility. Early
      stopping is controlled by ``patience``
    - lr: Learning rate
    - batch_size: Batch size
    - wd: Weight decay
    - patience: Stop after this many consecutive epochs without a new best
      validation loss
    - verbose: Print training progress
    - num_workers: Number of workers for data loader
    - lr_min: Minimum learning rate for cosine annealing (default: 1e-6)

    Returns:
    - model: The trained model, holding the weights of the final epoch run
      (not necessarily those of the best validation epoch)
    - train_losses: Training loss history
    - val_losses: Validation loss history

    Raises:
    - ValueError: If ``data`` is too small to yield a training sample
    """
    # Split data into train and validation
    n_samples = data.shape[0]
    n_train = int(0.9 * n_samples)
    if n_train < 1:
        raise ValueError(
            f"Need at least 2 samples for a train/validation split, got {n_samples}"
        )

    train_data = data[:n_train]
    val_data = data[n_train:]

    # Only CPU tensors can be pinned, and pinning is pointless without CUDA.
    pin_memory = torch.cuda.is_available() and not data.is_cuda

    train_loader = DataLoader(TensorDataset(train_data), batch_size=batch_size,
                              shuffle=True, pin_memory=pin_memory,
                              num_workers=num_workers)
    val_loader = DataLoader(TensorDataset(val_data), batch_size=batch_size,
                            shuffle=False, pin_memory=pin_memory,
                            num_workers=num_workers)

    # Setup optimizer and scheduler
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=lr_min)
    criterion = nn.BCELoss()  # Requires reconstructions and targets in [0, 1]

    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    patience_counter = 0

    if verbose:
        print(f"\nTraining {model.__class__.__name__}")
        print(f"Train samples: {len(train_data)}, Val samples: {len(val_data)}")
        print(f"Latent dim: {_describe_latent_dim(model)}")

    pbar = tqdm.tqdm(range(epochs)) if verbose else range(epochs)

    for epoch in pbar:
        train_loss = _run_epoch(model, train_loader, criterion, device, optimizer)
        train_loss /= len(train_data)
        train_losses.append(train_loss)

        val_loss = _run_epoch(model, val_loader, criterion, device)
        val_loss /= len(val_data)
        val_losses.append(val_loss)

        # Step scheduler
        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']

        # Early stopping check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1

        if verbose:
            pbar.set_description(
                f"Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.6f} | "
                f"Val Loss: {val_loss:.6f} | Best: {best_val_loss:.6f} | "
                f"Patience: {patience_counter}/{patience} | LR: {current_lr:.2e}"
            )

        # Stop if patience exceeded
        if patience_counter >= patience:
            if verbose:
                print(f"\nEarly stopping at epoch {epoch+1}")
            break

    if verbose:
        print(f"Training complete. Best val loss: {best_val_loss:.6f}")

    return model, train_losses, val_losses


def train_multimodal_ae_with_pretrained(data, n_samples_train, latent_dim, device, args, 
                         pretrained_encoders=None, pretrained_decoders=None,
                         epochs=100, early_stopping=50, 
                         lr=0.001, batch_size=128, ae_depth=2, ae_width=0.5, dropout=0.0, wd=1e-5, 
                         initial_rank_ratio=1.0, min_rank=10, 
                         rank_schedule=None, rank_reduction_frequency=10, 
                         rank_reduction_threshold=0.01, warmup_epochs=0,
                         patience=10, reduce_on_best_loss='rsquare', r_square_threshold=0.9,
                         threshold_type='relative', compressibility_type='linear', reduction_criterion='r_squared',
                         include_l1=False, l1_weight=0.0, include_ortholoss=False,
                         l1_start_weight=0.0, l1_step_size=1.0, rank_or_sparse='rank',
                         verbose=True, compute_jacobian=False, model_name=None, pretrained_name=None,
                         recon_loss_balancing=False, ortho_loss_balancing=False,
                         ortho_loss_start_weight=0.0, ortho_loss_end_weight=1.0, ortho_loss_anneal_epochs=None, ortho_loss_warmup=None,
                         l2_norm_adaptivelayers=None, sharedwhenall=True, paired=False, lr_schedule=None,
                         decision_metric='R2', full_spectrum=False, freeze_models=False):
    """
    Train a multimodal autoencoder with pretrained encoders and adaptive rank reduction.
    This is adapted from train_overcomplete_ae_with_pretrained but uses the 
    AdaptiveRankReducedAE_AvMnist_Pretrained model class.
    """
    import torch.nn.functional as F
    from src.functions.train_avmnist import (AVMNISTDataset, compute_direct_r_squared, 
                                              compute_direct_explained_variance,
                                              plot_image_reconstruction, plot_audio_scatter,
                                              plot_modal_image_reconstruction)
    
    # Determine whether multi-GPU mode is requested
    multi_gpu = getattr(args, 'multi_gpu', False)
    
    # Create model with pretrained components
    input_dims = (784, 112*112 if full_spectrum else 112)
    if isinstance(latent_dim, int):
        latent_dims = [latent_dim] * (len(input_dims) + 1)  # adding one for the shared space
    elif isinstance(latent_dim, list):
        if (len(latent_dim) == 1) & (len(input_dims) > 1):
            latent_dims = [latent_dim[0]] * (len(input_dims) + 1)
        else:
            latent_dims = latent_dim
    
    model = AdaptiveRankReducedAE_AvMnist_Pretrained(
        input_dims, latent_dims, 
        pretrained_encoders=pretrained_encoders,
        pretrained_decoders=pretrained_decoders,
        depth=ae_depth, width=ae_width, 
        dropout=dropout, initial_rank_ratio=initial_rank_ratio, 
        min_rank=min_rank, full_spectrum=full_spectrum
    ).to(device)
    
    #print(model)
    print(f"Model is on device: {next(model.parameters()).device}")
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model has {total_params} parameters")
    
    # Handle multi-GPU setup
    if multi_gpu:
        if args.gpu_ids:
            num_gpus = len(args.gpu_ids.split(','))
        else:
            num_gpus = torch.cuda.device_count()
            
        if batch_size % num_gpus != 0:
            original_batch_size = batch_size
            batch_size = (batch_size // num_gpus) * num_gpus
            if verbose:
                print(f"Adjusted batch size from {original_batch_size} to {batch_size} to be divisible by {num_gpus} GPUs")
            
        try:
            if 0 not in [int(id) for id in args.gpu_ids.split(',')]:
                raise RuntimeError("DataParallel requires cuda:0 which is not available.")
                
            cuda0_device = torch.device('cuda:0')
            model = model.to(cuda0_device)
            
            for param in model.parameters():
                if param.device != cuda0_device:
                    param.data = param.data.to(cuda0_device)
                    
            model = nn.DataParallel(model, device_ids=[int(id) for id in args.gpu_ids.split(',')])
            if verbose:
                print(f"Using DataParallel across GPUs: {args.gpu_ids}")
        except Exception as e:
            print(f"Failed to use DataParallel: {e}")
            print(f"Falling back to single GPU mode on {device}")
            multi_gpu = False
            model = model.to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    
    # Setup learning rate scheduler
    scheduler = None
    if lr_schedule == 'linear':
        try:
            scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, end_factor=0.001, total_iters=2000)
        except Exception:
            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: max(0.0, 1.0 - (epoch + 1) / float(max(1, epochs))))
    elif lr_schedule == 'step':
        scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, end_factor=0.001, total_iters=2000)
    
    # Create data loaders
    if paired:
        data_indices_path = f"./03_results/models/pretrained_models/{pretrained_name}_data_indices.pt"
        data_indices = torch.load(data_indices_path)
        train_indices = data_indices[:n_samples_train]
        val_indices = data_indices[n_samples_train:]
    else:
        print("Using non-paired data splitting")
        train_indices = slice(0, n_samples_train)
        val_indices = slice(n_samples_train, None)
    
    train_data = [d[train_indices] for d in data]
    train_data = AVMNISTDataset(train_data, full_spectrum=full_spectrum)
    num_workers = getattr(args, 'num_workers', 0)
    data_loader = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_workers
    )
    
    val_data = [data[i][val_indices] for i in range(len(data))]
    val_data = AVMNISTDataset(val_data, full_spectrum=full_spectrum)
    val_data_loader = torch.utils.data.DataLoader(
        val_data, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=num_workers
    )
    n_samples = data[0].shape[0]
    n_samples_val = n_samples - n_samples_train
    
    # Default rank reduction schedule
    if rank_schedule is None:
        rank_schedule = list(range(warmup_epochs + rank_reduction_frequency, 
                                 epochs, 
                                 rank_reduction_frequency))
    
    initial_squares = [None] * len(data)
    initial_losses = [None] * len(data)
    start_reduction = False
    current_rsquare_per_mod = [None] * len(data)
    current_loss_per_mod = [None] * len(data)
    bottom_reached = False
    space_sims = None
    break_counter = 0
    
    # Training state
    train_losses = []
    val_losses = []
    r_squares = []
    min_ranks = [layer.active_dims for layer in model.adaptive_layers]
    best_loss = float('inf')
    
    loss_scales = torch.ones(len(data), device=device)
    initial_losses = torch.zeros(len(data), device=device)
    loss_history = {f'mod_{i}_loss': [] for i in range(len(data))}
    
    if recon_loss_balancing:
        modality_loss_emas = [None] * len(data)
        ema_decay = 0.9
    
    patience_counter = 0
    initial_training_complete = False  # Track if initial early stopping phase is complete
    best_val_loss = float('inf')  # Track best validation loss for initial early stopping
    early_stopping_counter = 0  # Counter for initial early stopping
    pbar = tqdm.tqdm(range(model.epoch, epochs))
    mask = None
    
    # Freeze pretrained encoders/decoders for initial warmup if they exist
    pretrained_params_frozen = False
    if pretrained_encoders is not None:
        if verbose:
            print(f"\nFreezing pretrained encoders and decoders for first {patience} epochs to train adaptive layers...")
        
        # Freeze encoders
        for encoder_module_list in model.encoders:
            for encoder in encoder_module_list:
                for param in encoder.parameters():
                    param.requires_grad = False
        
        # Freeze decoders (but create new ones, so this might not apply)
        # Only freeze if we're reusing pretrained decoders
        if pretrained_decoders is not None:
            for decoder_module_list in model.decoders:
                for decoder in decoder_module_list:
                    for param in decoder.parameters():
                        param.requires_grad = False
        
        pretrained_params_frozen = True
        
        # Recreate optimizer to exclude frozen parameters
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        optimizer = torch.optim.Adam(trainable_params, lr=lr, weight_decay=wd)
        
        # Recreate scheduler if it exists
        if lr_schedule == 'linear':
            try:
                scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, end_factor=0.001, total_iters=2000)
            except Exception:
                scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: max(0.0, 1.0 - (epoch + 1) / float(max(1, epochs))))
        elif lr_schedule == 'step':
            scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, end_factor=0.001, total_iters=2000)
        
        if verbose:
            trainable_params_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
            frozen_params_count = sum(p.numel() for p in model.parameters() if not p.requires_grad)
            print(f"Trainable parameters: {trainable_params_count:,}")
            print(f"Frozen parameters: {frozen_params_count:,}")
    
    for epoch in pbar:
        # Unfreeze pretrained parameters after initial warmup period
        if (not freeze_models) and (pretrained_params_frozen and epoch >= patience):
            if verbose:
                print(f"\nUnfreezing pretrained encoders and decoders at epoch {epoch}...")
            
            # Unfreeze encoders
            for encoder_module_list in model.encoders:
                for encoder in encoder_module_list:
                    for param in encoder.parameters():
                        param.requires_grad = True
            
            # Unfreeze decoders if they were frozen
            if pretrained_decoders is not None:
                for decoder_module_list in model.decoders:
                    for decoder in decoder_module_list:
                        for param in decoder.parameters():
                            param.requires_grad = True
            
            pretrained_params_frozen = False
            
            # Recreate optimizer to include all parameters
            optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
            
            # Recreate scheduler if it exists
            if lr_schedule == 'linear':
                try:
                    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, end_factor=0.001, total_iters=2000)
                except Exception:
                    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: max(0.0, 1.0 - (epoch + 1) / float(max(1, epochs))))
            elif lr_schedule == 'step':
                scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, end_factor=0.001, total_iters=2000)
            
            if verbose:
                trainable_params_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
                print(f"All parameters now trainable: {trainable_params_count:,}")
        
        # Training phase
        model.train()
        train_loss = 0.0
        val_loss = 0.0
        total_ortho_loss = 0.0
        per_modality_losses = [0.0] * len(data)
        
        for batch_idx, x in enumerate(data_loader):
            last_batch_data = [x_m.clone() for x_m in x]
            if hasattr(train_data, 'labels') and train_data.labels is not None:
                start_idx = batch_idx * batch_size
                end_idx = min(start_idx + batch_size, len(train_data.labels))
                last_batch_labels = train_data.labels[start_idx:end_idx].clone()
            else:
                last_batch_labels = None
            
            loss = torch.tensor(0.0, device=device)
            total_loss = torch.tensor(0.0, device=device)
            x = [x_m.to(device, non_blocking=True) for x_m in x]
            
            # Forward pass
            x_hat, h_list = model(x)
            
            ortho_loss = torch.tensor(0.0, device=device)
            total_ortho_loss += ortho_loss.item()
            
            # Calculate separate losses for each modality
            modality_losses = []
            
            # Extract masks for each modality
            modality_masks = []
            if mask is not None:
                start_idx = 0
                for i, x_m in enumerate(x):
                    end_idx = start_idx + x_m.shape[1]
                    modality_masks.append(mask[:, start_idx:end_idx])
                    start_idx = end_idx
            else:
                modality_masks = [None] * len(x)
            
            # Calculate per-modality MSE losses
            for i, (x_m, x_hat_m) in enumerate(zip(x, x_hat)):
                # Ensure target and prediction have matching shapes
                if i == 1 and full_spectrum:
                    if x_m.dim() == 3:
                        x_m_reshaped = x_m.unsqueeze(1)
                    else:
                        x_m_reshaped = x_m
                    if x_hat_m.dim() == 2:
                        x_hat_m = x_hat_m.view(x_hat_m.shape[0], 1, 112, 112)
                else:
                    x_m_reshaped = x_m
                
                # Compute BCE loss for this modality
                if modality_masks[i] is not None:
                    m_loss = F.binary_cross_entropy(x_hat_m[modality_masks[i]], x_m_reshaped[modality_masks[i]], reduction='mean')
                else:
                    m_loss = F.binary_cross_entropy(x_hat_m, x_m_reshaped, reduction='mean')
                
                # Check for NaN
                if torch.isnan(m_loss):
                    if verbose:
                        print(f"Warning: NaN loss detected for modality {i}")
                    m_loss = torch.tensor(0.0, device=device)
                
                modality_losses.append(m_loss)
                per_modality_losses[i] += m_loss.item()
            
            # Apply reconstruction loss balancing if enabled
            if recon_loss_balancing:
                for i, m_loss in enumerate(modality_losses):
                    if modality_loss_emas[i] is None:
                        modality_loss_emas[i] = m_loss.item()
                    else:
                        modality_loss_emas[i] = ema_decay * modality_loss_emas[i] + (1 - ema_decay) * m_loss.item()
                
                min_ema = min(ema for ema in modality_loss_emas if ema is not None and ema > 0)
                for i, m_loss in enumerate(modality_losses):
                    if modality_loss_emas[i] > 0:
                        balance_scale = min_ema / modality_loss_emas[i]
                        loss += balance_scale * m_loss
                    else:
                        loss += m_loss
            else:
                for i, m_loss in enumerate(modality_losses):
                    loss += loss_scales[i] * m_loss
            
            total_loss += loss
            
            # Backward pass and optimize
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            train_loss += loss.item()
        
        # Average losses
        train_loss /= len(data_loader)
        if start_reduction and include_ortholoss:
            total_ortho_loss /= len(data_loader)
        per_modality_losses = [loss / len(data_loader) for loss in per_modality_losses]
        train_losses.append(train_loss)
        
        # Store per-modality losses in history
        for i, loss in enumerate(per_modality_losses):
            loss_history[f'mod_{i}_loss'].append(loss)
        
        # Validation phase
        with torch.no_grad():
            for x_val in val_data_loader:
                x_val = [x_m.to(device, non_blocking=True) for x_m in x_val]
                x_val_hat, _ = model(x_val)
                
                modality_masks = []
                if mask is not None:
                    start_idx = 0
                    for i, x_m in enumerate(x_val):
                        end_idx = start_idx + x_m.shape[1]
                        modality_masks.append(mask[:, start_idx:end_idx])
                        start_idx = end_idx
                else:
                    modality_masks = [None] * len(x_val)
                
                # Calculate validation loss
                val_batch_loss = 0.0
                for i, (x_m, x_hat_m) in enumerate(zip(x_val, x_val_hat)):
                    if i == 1 and full_spectrum:
                        if x_m.dim() == 3:
                            x_m_reshaped = x_m.unsqueeze(1)
                        else:
                            x_m_reshaped = x_m
                        if x_hat_m.dim() == 2:
                            x_hat_m = x_hat_m.view(x_hat_m.shape[0], 1, 112, 112)
                    else:
                        x_m_reshaped = x_m
                    
                    if modality_masks[i] is not None:
                        m_loss = F.binary_cross_entropy(x_hat_m[modality_masks[i]], x_m_reshaped[modality_masks[i]], reduction='mean')
                    else:
                        m_loss = F.binary_cross_entropy(x_hat_m, x_m_reshaped, reduction='mean')
                    if not torch.isnan(m_loss):
                        val_batch_loss += m_loss.item()
                
                val_loss += val_batch_loss / len(x_val)
        
        val_loss /= len(val_data_loader)
        val_losses.append(val_loss)
        
        # Step the scheduler
        if scheduler is not None:
            try:
                scheduler.step()
            except Exception:
                pass
        
        # Track initial early stopping (before rank reduction starts)
        if not initial_training_complete:
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                early_stopping_counter = 0
            else:
                early_stopping_counter += 1
            
            # Check if initial early stopping is reached
            if early_stopping_counter >= early_stopping:
                initial_training_complete = True
                if verbose:
                    print(f"\nInitial training complete at epoch {epoch}. Val loss stopped improving.")
                    print(f"Best val loss: {best_val_loss:.6f}")
                    print("Starting rank reduction phase...")
        
        log_dict = {
            'loss': round(train_loss, 4),
            'mod_losses': [round(l, 3) for l in per_modality_losses],
            'ranks': [layer.active_dims for layer in model.adaptive_layers] if hasattr(model, 'adaptive_layers') 
                    else (model.module.adaptive_layers if multi_gpu else []),
            'current_rsquare': [round(current_rsquare_per_mod[i], 3) if current_rsquare_per_mod[i] is not None else 'N/A' for i in range(len(data))],
            'patience': patience_counter,
            'early_stop': early_stopping_counter if not initial_training_complete else 'complete',
        }
        if recon_loss_balancing and all(ema is not None for ema in modality_loss_emas):
            min_ema = min(ema for ema in modality_loss_emas if ema > 0)
            balance_scales = [round(min_ema / ema, 3) if ema > 0 else 1.0 for ema in modality_loss_emas]
            log_dict.update({'balance_scales': balance_scales})
        pbar.set_postfix(log_dict)
        
        # Update best loss for rank reduction phase
        if train_loss < best_loss:
            best_loss = train_loss
            if reduce_on_best_loss in ['true', 'stagnation']:
                patience_counter = 0
        else:
            if reduce_on_best_loss in ['true', 'stagnation']:
                patience_counter += 1
        
        # Only start rank reduction after initial training is complete
        if (start_reduction is False) and initial_training_complete and (epoch >= model.epoch + patience):
            rank_history = {
                'total_rank':[model.get_total_rank() if hasattr(model, 'get_total_rank') else (model.module.get_total_rank() if multi_gpu else 0)],
                'ranks':[', '.join(str(layer.active_dims) for layer in model.adaptive_layers)],
                'epoch':[model.epoch],
                'loss':[train_losses[-1]],
                'val_loss':[val_losses[-1]]
            }
            
            start_reduction = True
            break_counter = 0
            
            min_rsquares = []
            mask = val_data.mask
            modality_masks_latent = []
            modality_masks_data = []
            modality_masks_space = []
            
            # Build validation subset
            n_sub = int(0.1 * train_data.data[0].shape[0])
            audio_subset = train_data.data[1][:n_sub]
            
            try:
                model_full_spectrum = getattr(model, 'full_spectrum', False)
            except Exception:
                model_full_spectrum = full_spectrum
            
            if model_full_spectrum:
                try:
                    orig_audio_all = data[1]
                    orig_audio_subset = orig_audio_all[train_indices][:n_sub]
                except Exception:
                    orig_audio_subset = None
                
                if orig_audio_subset is not None and orig_audio_subset.dim() > 2:
                    audio_subset = orig_audio_subset.view(orig_audio_subset.shape[0], -1)
                else:
                    if audio_subset.dim() > 2:
                        audio_subset = audio_subset.view(audio_subset.shape[0], -1)
                    else:
                        if verbose:
                            print("   Warning: model expects full-spectrum audio but only averaged audio available.")
            else:
                if not full_spectrum and audio_subset.dim() > 2:
                    audio_subset = torch.mean(audio_subset, dim=1)
            
            val_data_list = [train_data.data[0][:n_sub].to(device), audio_subset.to(device)]
            
            if verbose:
                print(f"   Debug: val_data_list[0] shape (images) = {val_data_list[0].shape}")
                print(f"   Debug: val_data_list[1] shape (audio)  = {val_data_list[1].shape}")
            
            if decision_metric == 'ExVarScore':
                direct_r_squared_values = compute_direct_explained_variance(model, val_data_list, device, multi_gpu, verbose=verbose)
            else:
                direct_r_squared_values = compute_direct_r_squared(model, val_data_list, device, multi_gpu, verbose=verbose)
            
            for i, r_squared_val in enumerate(direct_r_squared_values):
                initial_squares[i] = r_squared_val
                
                if threshold_type == 'relative':
                    min_rsquares.append(r_squared_val * r_square_threshold)
                elif threshold_type == 'absolute':
                    min_rsquares.append(r_squared_val - r_square_threshold)
                else:
                    raise ValueError(f"threshold_type must be 'relative' or 'absolute', got {threshold_type}")
                
                current_rsquare_per_mod[i] = r_squared_val
                rank_history[f'rsquare {i}'] = [r_squared_val]
            
            if verbose:
                print(f"Initial R-squared values: {[rank_history[f'rsquare {i}'] for i in range(len(current_rsquare_per_mod))]}, setting {threshold_type} thresholds to {min_rsquares}")
        
        # Apply rank reduction at scheduled epochs
        if (epoch in rank_schedule) & (start_reduction) & (break_counter == 0):
            if (reduce_on_best_loss == 'rsquare') & (start_reduction):
                current_rsquares = []
                modalities_to_reduce = []
                modalities_to_increase = []
                mask = val_data.mask
                modality_masks_data = []
                modality_masks_latent = []
                modality_masks_space = []
                
                if not full_spectrum:
                    val_data_list = [train_data.data[0][:int(0.1 * train_data.data[0].shape[0])].to(device), torch.mean(train_data.data[1], dim=1)[:int(0.1 * train_data.data[0].shape[0])].to(device)]
                else:
                    audio_subset = train_data.data[1][:int(0.1 * train_data.data[0].shape[0])]
                    val_data_list = [train_data.data[0][:int(0.1 * train_data.data[0].shape[0])].to(device), audio_subset.to(device)]
                
                if decision_metric == 'ExVarScore':
                    direct_r_squared_values = compute_direct_explained_variance(model, val_data_list, device, multi_gpu)
                else:
                    direct_r_squared_values = compute_direct_r_squared(model, val_data_list, device, multi_gpu)
                
                for i, r_squared_val in enumerate(direct_r_squared_values):
                    current_rsquares.append(r_squared_val)
                    current_rsquare_per_mod[i] = r_squared_val
                
                r_squares.append(current_rsquares)
                max_rquares = [max(r_squares, key=lambda x: x[i])[i] for i in range(len(current_rsquare_per_mod))] if len(r_squares) > 0 else initial_squares
                if threshold_type == 'relative':
                    min_rsquares = [r * r_square_threshold for r in max_rquares]
                elif threshold_type == 'absolute':
                    min_rsquares = [r - r_square_threshold for r in max_rquares]
                
                # Determine what modalities to reduce or increase
                #if (len(r_squares) >= min(10, int(patience/2))) and patience_counter >= min(10, int(patience/2)):
                if (len(r_squares) >= min(10, int(patience/2))) and patience_counter >= (patience-1):
                    for i in range(len(current_rsquare_per_mod)):
                        i_rsquares = [r[i] for r in r_squares[-min(10, int(patience/2)):]]
                        
                        if compressibility_type == 'direct' and reduction_criterion in ['train_loss', 'val_loss']:
                            if all(r > min_rsquares[i] for r in i_rsquares) and not bottom_reached:
                                modalities_to_increase.append(i)
                            elif current_rsquare_per_mod[i] < min_rsquares[i]:
                                modalities_to_reduce.append(i)
                        else:
                            if all(r < min_rsquares[i] for r in i_rsquares) and not bottom_reached:
                                modalities_to_increase.append(i)
                            elif current_rsquare_per_mod[i] > min_rsquares[i]:
                                modalities_to_reduce.append(i)
                elif (len(r_squares) >= 1):
                    for i in range(len(current_rsquare_per_mod)):
                        if compressibility_type == 'direct' and reduction_criterion in ['train_loss', 'val_loss']:
                            if current_rsquare_per_mod[i] < min_rsquares[i]:
                                modalities_to_reduce.append(i)
                            #elif current_rsquare_per_mod[i] > min_rsquares[i] and not bottom_reached:
                            #    modalities_to_increase.append(i)
                        else:
                            if current_rsquare_per_mod[i] > min_rsquares[i]:
                                modalities_to_reduce.append(i)
                            #elif current_rsquare_per_mod[i] < min_rsquares[i] and not bottom_reached:
                            #    modalities_to_increase.append(i)
                
                # Set min and max ranks
                if len(modalities_to_reduce) == len(current_rsquare_per_mod):
                    current_ranks = [layer.active_dims for layer in model.adaptive_layers]
                    for i, cr in enumerate(current_ranks):
                        if cr <= min_ranks[i]:
                            min_ranks[i] = cr
                        model.adaptive_layers[i].max_rank = min(sum(current_ranks), max(int(1.5*current_ranks[i]), current_ranks[i]+1), model.adaptive_layers[i].max_rank)
                    print(f"Adjusting maximum ranks to {[layer.max_rank for layer in model.adaptive_layers]}")
                
                if len(modalities_to_increase) == len(current_rsquare_per_mod):
                    current_ranks = [layer.active_dims for layer in model.adaptive_layers]
                    for i, cr in enumerate(current_ranks):
                        if cr <= min_ranks[i]:
                            min_ranks[i] = cr
                        model.adaptive_layers[i].min_rank = min_ranks[i]
                    print(f"Adjusting minimum ranks to {[layer.min_rank for layer in model.adaptive_layers]}")
                
                # Set the layers to reduce or increase
                layers_to_reduce = []
                layers_to_increase = []
                if (len(modalities_to_reduce) == 0) and (len(modalities_to_increase) == 0):
                    pass
                elif (len(modalities_to_reduce) > 0) and (len(modalities_to_increase) > 0):
                    layers_to_reduce = [i + 1 for i in modalities_to_reduce]
                    #layers_to_increase = [0] + [i + 1 for i in modalities_to_increase]
                    layers_to_increase = [i + 1 for i in modalities_to_increase]
                    #model.adaptive_layers[0].min_rank = model.adaptive_layers[0].active_dims + 1
                    for i in modalities_to_increase:
                        model.adaptive_layers[i + 1].min_rank = model.adaptive_layers[i + 1].active_dims + 1
                    print(f"Adjusting minimum ranks to {[layer.min_rank for layer in model.adaptive_layers]}")
                else:
                    if len(modalities_to_increase) > 0:
                        if len(modalities_to_increase) == len(current_rsquare_per_mod):
                            layers_to_increase = [i for i in range(len(model.adaptive_layers))]
                        else:
                            layers_to_increase = [0] + [i + 1 for i in modalities_to_increase]
                            for i in modalities_to_increase:
                                model.adaptive_layers[i + 1].min_rank = model.adaptive_layers[i + 1].active_dims + 1
                            model.adaptive_layers[0].min_rank = model.adaptive_layers[0].active_dims + 1
                            print(f"Adjusting minimum ranks to {[layer.min_rank for layer in model.adaptive_layers]}")
                    if len(modalities_to_reduce) > 0:
                        if len(modalities_to_reduce) == len(initial_squares):
                            reduce_shared = True
                            if reduce_shared:
                                layers_to_reduce = [0] + [i + 1 for i in modalities_to_reduce]
                            else:
                                layers_to_reduce = [i + 1 for i in modalities_to_reduce]
                        else:
                            layers_to_reduce = [i + 1 for i in modalities_to_reduce]
                
                if verbose:
                    if compressibility_type == 'direct' and reduction_criterion in ['train_loss', 'val_loss']:
                        print(f"{reduction_criterion} values: {current_rsquares}, reducing rank for modalities {modalities_to_reduce}, layers {layers_to_reduce}, increasing rank for modalities {modalities_to_increase}, layers {layers_to_increase}")
                    else:
                        print(f"R-squared values: {current_rsquares}, reducing rank for modalities {modalities_to_reduce}, layers {layers_to_reduce}, increasing rank for modalities {modalities_to_increase}, layers {layers_to_increase}")
            
            # Apply rank changes
            any_changes_made = False
            if len(layers_to_reduce) > 0:
                if multi_gpu:
                    changes_made = model.module.reduce_rank(reduction_ratio=0.9, threshold=rank_reduction_threshold, layer_ids=layers_to_reduce)
                else:
                    changes_made = model.reduce_rank(reduction_ratio=0.9, threshold=rank_reduction_threshold, layer_ids=layers_to_reduce)
                if changes_made:
                    any_changes_made = True
            
            if len(layers_to_increase) > 0:
                if multi_gpu:
                    changes_made = model.module.increase_rank(increase_ratio=1.1, layer_ids=layers_to_increase)
                else:
                    changes_made = model.increase_rank(increase_ratio=1.1, layer_ids=layers_to_increase)
                if changes_made:
                    any_changes_made = True
                    break_counter = patience
            
            if any_changes_made:
                patience_counter = 0
            else:
                patience_counter += 1
            
            total_rank_after = model.module.get_total_rank() if multi_gpu else model.get_total_rank()
            
            # Store current rank in history
            rank_history['total_rank'].append(total_rank_after)
            rank_history['ranks'].append(', '.join(str(layer.active_dims) for layer in model.adaptive_layers))
            rank_history['epoch'].append(epoch)
            for i in range(len(current_rsquare_per_mod)):
                if reduce_on_best_loss == 'rsquare':
                    rank_history[f'rsquare {i}'].append(current_rsquares[i])
            rank_history['loss'].append(train_loss)
            rank_history['val_loss'].append(val_loss)
        else:
            if (epoch in rank_schedule) & (start_reduction) & (break_counter > 0):
                break_counter -= 1
        
        # Get normalized weights for display
        if multi_gpu:
            weights = model.module.modality_weights
        else:
            weights = model.modality_weights
        
        pos_weights = F.softplus(weights)
        norm_weights = (pos_weights / (pos_weights.sum() + 1e-8)).detach().cpu().numpy().round(3)
        
        # Early stopping
        if (epoch > early_stopping) & (start_reduction is True) & (patience_counter >= patience):
            if verbose:
                print(f"Early stopping at epoch {epoch} with best loss {best_loss} and ranks {rank_history['ranks'][-1]}")
            break
    
    # Calculate latent representations in batches
    n_samples = data[0].shape[0]
    final_ranks = [layer.active_dims for layer in model.adaptive_layers]
    reps = [torch.empty((n_samples, final_ranks[i]), device=device) for i in range(len(final_ranks))]
    model.eval()
    with torch.no_grad():
        for i in range(0, n_samples, batch_size):
            end_idx = min(i + batch_size, n_samples)
            x_batch = [data[j][i:end_idx].to(device) for j in range(len(data))]
            if not full_spectrum:
                x_batch[1] = torch.mean(x_batch[1], dim=1)
            
            if multi_gpu:
                batch_reps = model.module.encode(x_batch)
            else:
                batch_reps = model.encode(x_batch)
            batch_rep_list = [batch_reps[0]] + [batch_reps[1][j] for j in range(len(batch_reps[1]))]
            
            for j in range(len(reps)):
                reps[j][i:end_idx,:] = batch_rep_list[j][:,:final_ranks[j]].cpu()
            
            del x_batch, batch_reps
            torch.cuda.empty_cache() if torch.cuda.is_available() else None
    
    # Save the model
    if model_name:
        model_path = f"./03_results/models/{model_name}.pt"
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        if multi_gpu:
            torch.save(model.module.state_dict(), model_path)
        else:
            torch.save(model.state_dict(), model_path)
        if verbose:
            print(f"Saved trained model to {model_path}")
    
    try:
        avg_train_loss = np.mean(train_losses[-5:])
    except:
        avg_train_loss = np.mean(train_losses)
    try:
        last_rsquare = r_squares[-1]
    except:
        last_rsquare = [None]
    
    return model, reps, avg_train_loss, last_rsquare, rank_history, [train_losses, val_losses]


class AdaptiveRankReducedAE_AvMnist_Pretrained(nn.Module):
    """
    Multimodal autoencoder with pretrained encoders/decoders for each modality.
    Uses specialized encoders/decoders and learns shared and modality-specific subspaces.

    Data flow:
        encode: each modality is passed through its own encoder; the concatenated
                encoder outputs feed the shared adaptive layer (adaptive_layers[0])
                and each individual encoder output feeds its own specific adaptive
                layer (adaptive_layers[m + 1]).
        decode: for each modality the shared and specific codes are concatenated,
                mapped back to ``latent_dims[m]`` features by ``map_back_layers[m]``
                and then passed through that modality's decoder.
    """
    def __init__(self, input_dims, latent_dims, pretrained_encoders=None, pretrained_decoders=None,
                 depth=2, width=0.5, dropout=0.0, initial_rank_ratio=1.0, min_rank=1, full_spectrum=False):
        """
        Args:
            input_dims (tuple): Dimensions of the input for each modality (e.g., (784, 112)).
            latent_dims (tuple): Dimensions for [mod1_specific, mod2_specific, shared] subspaces.
            pretrained_encoders (list): List of pretrained encoder modules (optional).
            pretrained_decoders (list): List of pretrained decoder modules (optional).
            depth, width, dropout: Accepted for signature compatibility; not used by this class.
            initial_rank_ratio (float): Forwarded to every AdaptiveRankReducedLinear layer.
            min_rank (int): Forwarded to every AdaptiveRankReducedLinear layer.
            full_spectrum (bool): Whether to use full spectrum (112x112) audio data.

        Pretrained modules are only used when BOTH ``pretrained_encoders`` and
        ``pretrained_decoders`` are given; otherwise fresh encoders/decoders are built.
        """
        super().__init__()
        self.full_spectrum = full_spectrum
        self.epoch = 0  # Track current epoch for checkpointing

        n_modalities = len(input_dims)

        # --- 1. Use Pretrained Encoders/Decoders or Create New Ones ---
        self.encoders = nn.ModuleList([nn.ModuleList() for _ in range(n_modalities)])
        self.decoders = nn.ModuleList([nn.ModuleList() for _ in range(n_modalities)])

        if pretrained_encoders is not None and pretrained_decoders is not None:
            # Use pretrained encoders and decoders
            for i in range(n_modalities):
                self.encoders[i].append(pretrained_encoders[i])
                self.decoders[i].append(pretrained_decoders[i])
        else:
            # Create encoders/decoders from scratch
            img_latent_dim = latent_dims[0]
            self.channel_size = 16
            self.n_pixel_after_conv = 4 * 4 * self.channel_size * 4

            self.img_encoder = nn.Sequential(
                nn.Conv2d(1, self.channel_size, kernel_size=5),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
                nn.Conv2d(self.channel_size, self.channel_size*4, kernel_size=5),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
                nn.Flatten(),
                nn.Linear(self.n_pixel_after_conv, img_latent_dim*4),
                nn.ReLU(),
                nn.Linear(img_latent_dim*4, img_latent_dim),
                nn.Sigmoid()
            )
            self.encoders[0].append(self.img_encoder)

            audio_latent_dim = latent_dims[1]
            if full_spectrum:
                self.audio_channel_size = 16
                self.audio_n_pixel_after_conv = 10 * 10 * (self.audio_channel_size * 4)

                self.audio_encoder = nn.Sequential(
                    nn.Conv2d(1, self.audio_channel_size, kernel_size=5, padding=0),
                    nn.ReLU(),
                    nn.MaxPool2d(2, 2),
                    nn.Conv2d(self.audio_channel_size, self.audio_channel_size*2, kernel_size=5, padding=0),
                    nn.ReLU(),
                    nn.MaxPool2d(2, 2),
                    nn.Conv2d(self.audio_channel_size*2, self.audio_channel_size*4, kernel_size=5, padding=0),
                    nn.ReLU(),
                    nn.MaxPool2d(2, 2),
                    nn.Flatten(),
                    nn.Linear(self.audio_n_pixel_after_conv, audio_latent_dim*8),
                    nn.ReLU(),
                    nn.Linear(audio_latent_dim*8, audio_latent_dim),
                    nn.Sigmoid()
                )
            else:
                self.audio_encoder = nn.Sequential(
                    nn.Linear(input_dims[1], audio_latent_dim*4),
                    nn.ReLU(),
                    nn.Linear(audio_latent_dim*4, audio_latent_dim),
                    nn.Sigmoid()
                )
            self.encoders[1].append(self.audio_encoder)

            # Decoders
            # decode() always runs map_back_layers[m] before the decoder, and that
            # layer outputs latent_dims[m] features. The first Linear of each decoder
            # must therefore accept latent_dims[m] inputs (not shared + specific),
            # otherwise the forward pass fails with a shape mismatch. Hidden widths
            # are kept as they were.
            img_hidden = 2*(latent_dims[2] + img_latent_dim)
            self.img_decoder = nn.Sequential(
                nn.Linear(img_latent_dim, img_hidden),
                nn.ReLU(),
                nn.Linear(img_hidden, self.n_pixel_after_conv),
                nn.ReLU(),
                nn.Unflatten(1, (self.channel_size*4, 4, 4)),
                nn.ConvTranspose2d(self.channel_size*4, self.channel_size, kernel_size=5, stride=2, output_padding=1),
                nn.ReLU(),
                nn.ConvTranspose2d(self.channel_size, 1, kernel_size=5, stride=2, output_padding=1),
                nn.Sigmoid()
            )
            self.decoders[0].append(self.img_decoder)

            if full_spectrum:
                audio_hidden = 4*(latent_dims[2] + audio_latent_dim)
                self.audio_decoder = nn.Sequential(
                    nn.Linear(audio_latent_dim, audio_hidden),
                    nn.ReLU(),
                    nn.Linear(audio_hidden, self.audio_n_pixel_after_conv),
                    nn.ReLU(),
                    nn.Unflatten(1, (self.audio_channel_size*4, 10, 10)),
                    nn.ConvTranspose2d(self.audio_channel_size*4, self.audio_channel_size*2, kernel_size=5, stride=2, padding=1, output_padding=1),
                    nn.ReLU(),
                    nn.ConvTranspose2d(self.audio_channel_size*2, self.audio_channel_size, kernel_size=5, stride=2, padding=1, output_padding=1),
                    nn.ReLU(),
                    nn.ConvTranspose2d(self.audio_channel_size, 1, kernel_size=7, stride=2, padding=1, output_padding=1),
                    nn.Upsample(size=(112, 112), mode='bilinear', align_corners=False),
                    nn.Sigmoid()
                )
            else:
                # NOTE(review): unlike the other decoders this one has no final
                # Sigmoid, so its output is unbounded -- confirm this matches the
                # reconstruction loss used for averaged audio.
                audio_hidden = 2*(latent_dims[2] + audio_latent_dim)
                self.audio_decoder = nn.Sequential(
                    nn.Linear(audio_latent_dim, audio_hidden),
                    nn.ReLU(),
                    nn.Linear(audio_hidden, input_dims[1])
                )
            self.decoders[1].append(self.audio_decoder)

        # --- 2. Define Layers for Shared and Specific Subspaces ---
        # Imported locally, as in the original code.
        from src.models.larrp_unimodal import AdaptiveRankReducedLinear

        self.adaptive_layers = nn.ModuleList()

        # Shared layer: consumes the concatenation of all encoder outputs
        shared_layer = AdaptiveRankReducedLinear(
            sum(latent_dims[:n_modalities]), latent_dims[-1],
            initial_rank_ratio=initial_rank_ratio,
            min_rank=min_rank
        )
        self.adaptive_layers.append(shared_layer)

        # Specific layers: one per modality, at index m + 1
        for i in range(n_modalities):
            specific_layer = AdaptiveRankReducedLinear(
                latent_dims[i], latent_dims[i],
                initial_rank_ratio=initial_rank_ratio,
                min_rank=min_rank
            )
            self.adaptive_layers.append(specific_layer)

        # Map [shared, specific] back to the width each decoder expects
        self.map_back_layers = nn.ModuleList()
        for i in range(n_modalities):
            map_back_layer = nn.Linear(latent_dims[-1] + latent_dims[i], latent_dims[i])
            self.map_back_layers.append(map_back_layer)

        # Initialize modality weights for loss balancing
        self.modality_weights = nn.Parameter(torch.ones(n_modalities), requires_grad=True)

    def forward(self, x):
        """Encode then decode.

        Args:
            x (list): One tensor per modality.

        Returns:
            tuple: ``(x_hat, h)`` where ``x_hat`` is the list of reconstructions
            returned by :meth:`decode` and ``h`` is the output of :meth:`encode`.
        """
        h = self.encode(x)
        x_hat = self.decode(h)
        return x_hat, h

    def encode(self, x, compute_jacobian=False):
        """Encode every modality and project onto shared/specific subspaces.

        Args:
            x (list): One tensor per modality; index 0 is treated as the image,
                index 1 as the audio.
            compute_jacobian (bool): If True, skip the adaptive layers and return
                the concatenated encoder outputs instead.

        Returns:
            tuple: ``(h_shared, specific_outputs)`` by default, or ``(h, None)``
            when ``compute_jacobian`` is True.
        """
        h_concat = []
        for m, x_m in enumerate(x):
            if m == 0:
                # Image: accept (batch, 784) or (batch, 1, 28, 28)
                if x_m.dim() == 2:
                    x_m = x_m.view(-1, 1, 28, 28)
                elif x_m.dim() == 3:
                    x_m = x_m.unsqueeze(1)
            elif m == 1:
                # Audio handling
                if self.full_spectrum:
                    if x_m.dim() == 2:
                        x_m = x_m.view(-1, 1, 112, 112)
                    elif x_m.dim() == 3:
                        x_m = x_m.unsqueeze(1)
                else:
                    if x_m.dim() > 2:
                        # NOTE(review): averages over the LAST axis; confirm this is
                        # the same axis callers average over when they pre-reduce
                        # the audio themselves.
                        x_m = x_m.mean(dim=-1)

            for layer in self.encoders[m]:
                x_m = layer(x_m)
            h_concat.append(x_m)

        h = torch.cat(h_concat, dim=1)

        if compute_jacobian:
            # No adaptive projection in this mode; callers get the raw concat.
            return h, None

        h_shared = self.adaptive_layers[0](h)
        specific_outputs = [
            self.adaptive_layers[i + 1](h_m) for i, h_m in enumerate(h_concat)
        ]
        return h_shared, specific_outputs

    def decode(self, h):
        """Reconstruct every modality from ``(h_shared, h_specific)``.

        Args:
            h (tuple): ``(h_shared, h_specific)`` as returned by :meth:`encode`.

        Returns:
            list: One reconstruction per modality. Modality 0 is flattened to
            784 columns; in full-spectrum mode modality 1 is resized to
            112x112 if the decoder produced a different spatial size.
        """
        h_shared, h_specific = h
        x_hat = []
        for m, h_m in enumerate(h_specific):
            h_concat = torch.cat([h_shared, h_m], dim=1)
            h_concat = self.map_back_layers[m](h_concat)
            for layer in self.decoders[m]:
                h_concat = layer(h_concat)
            if m == 0:
                # Image: flatten to 784 dimensions
                h_concat = h_concat.view(-1, 784)
            elif m == 1 and self.full_spectrum:
                # Audio in full spectrum mode: expects a 4-D decoder output
                if h_concat.shape[2] != 112 or h_concat.shape[3] != 112:
                    h_concat = nn.functional.interpolate(h_concat, size=(112, 112), mode='bilinear', align_corners=False)
            x_hat.append(h_concat)
        return x_hat

    def encode_modalities(self, x):
        """Return, per modality, the shared code concatenated with its specific code."""
        h_shared, h_specific = self.encode(x)
        return [torch.cat([h_shared, h_m], dim=1) for h_m in h_specific]

    def reduce_rank(self, reduction_ratio=0.8, threshold=0.01, layer_ids=None, dim=0):
        """Reduce the rank of the selected adaptive layers based on singular value energy.

        Args:
            reduction_ratio (float): Lower bound on the new rank as a fraction of
                the layer's current ``active_dims``.
            threshold (float): Energy fraction that may be discarded; the number of
                singular values whose cumulative energy stays below
                ``1 - threshold`` is another lower bound on the new rank.
            layer_ids (iterable of int, optional): Indices into ``adaptive_layers``
                to consider. ``None`` (the default) selects no layers.
            dim (int): Forwarded to the layer's ``reduce_rank``.

        Returns:
            bool: True if at least one layer's rank was reduced.
        """
        if layer_ids is None:
            layer_ids = ()
        changes_made = False
        for i, layer in enumerate(self.adaptive_layers):
            if i not in layer_ids:
                continue
            S = layer.get_rank_reduction_info()
            if len(S) <= layer.min_rank:
                continue
            energy = S**2
            normalized_energy = energy / energy.sum()
            cumulative_energy = torch.cumsum(normalized_energy, dim=0)
            # The target is the largest of the three lower bounds, so a single call
            # never shrinks a layer by more than reduction_ratio.
            target_rank = max(layer.min_rank,
                              torch.sum(cumulative_energy < (1.0 - threshold)).item(),
                              int(layer.active_dims * reduction_ratio))
            if target_rank < layer.active_dims:
                layer.reduce_rank(target_rank, dim=dim, which_dims=None)
                changes_made = True
        return changes_made

    def increase_rank(self, increment=None, increase_ratio=1.1, layer_ids=None, dim=0):
        """Increase the rank of the selected adaptive layers.

        Args:
            increment, increase_ratio, dim: Forwarded to each layer's ``increase_rank``.
            layer_ids (iterable of int, optional): Indices into ``adaptive_layers``
                to consider. ``None`` (the default) selects no layers.

        Returns:
            bool: True if at least one layer reported a change.
        """
        if layer_ids is None:
            layer_ids = ()
        changes_made = False
        for i, layer in enumerate(self.adaptive_layers):
            if i not in layer_ids:
                continue
            if layer.increase_rank(increment=increment, increase_ratio=increase_ratio, dim=dim, mode='multimodal'):
                changes_made = True
        return changes_made

    def get_total_rank(self):
        """Return total rank across all adaptive layers"""
        return sum(layer.active_dims for layer in self.adaptive_layers)

