import torch
import torch.nn as nn
import numpy as np
import tqdm
import gc
import matplotlib.pyplot as plt
from src.models.larrp_unimodal import AdaptiveRankReducedAE
from src.functions.linear_probing import parallel_linear_regression
import os

def plot_loss_curves(train_losses, val_losses, save_path, title="Loss Curves", pretraining_epochs=None):
    """
    Plot training and validation loss curves and save the figure to disk.

    Parameters:
    - train_losses: List of training losses (one value per epoch)
    - val_losses: List of validation losses; plotted against the same epoch
      axis as train_losses, so it must have the same length
    - save_path: Path to save the plot; missing parent directories are created.
      A bare file name (no directory component) saves into the current
      working directory.
    - title: Title for the plot
    - pretraining_epochs: Number of pretraining epochs (for combined plots);
      when given and > 0 a dashed vertical marker is drawn at that epoch
    """
    fig = plt.figure(figsize=(10, 6))
    try:
        epochs = range(1, len(train_losses) + 1)

        plt.plot(epochs, train_losses, 'b-', label='Training Loss', alpha=0.7)
        plt.plot(epochs, val_losses, 'r-', label='Validation Loss', alpha=0.7)

        # Add vertical line to separate pretraining and training phases
        if pretraining_epochs is not None and pretraining_epochs > 0:
            plt.axvline(x=pretraining_epochs, color='gray', linestyle='--', alpha=0.7,
                        label=f'End of Pretraining (epoch {pretraining_epochs})')

        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title(title)
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.tight_layout()

        # os.path.dirname() returns '' for a bare file name and
        # os.makedirs('') raises FileNotFoundError, so only create the
        # directory when there is one to create.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving failed,
        # so repeated calls do not accumulate open figures.
        plt.close(fig)
    print(f"Loss curve plot saved to {save_path}")

def train_overcomplete_ae(data, n_samples_train, latent_dim, device, args, epochs=100, early_stopping=50, 
                         lr=0.001, batch_size=128, ae_depth=2, ae_width=0.5, dropout=0.0, wd=1e-5, 
                         initial_rank_ratio=1.0, min_rank=10, 
                         rank_schedule=None, rank_reduction_frequency=10, 
                         rank_reduction_threshold=0.01, warmup_epochs=0,
                         patience=10, reduce_on_best_loss='rsquare', r_square_threshold=0.9,
                         l1_start_weight=0.0, l1_step_size=1.0, rank_or_sparse='rank',
                         verbose=True
                         ):
    """
    Train an autoencoder with adaptive rank reduction

    The model is trained on ``data[:n_samples_train]`` and validated on the
    remaining rows. Once the validation loss has stopped improving (see
    ``early_stopping``) the rank-reduction phase starts: at every epoch listed
    in ``rank_schedule`` the rank is reduced (or, at most once, increased
    again) according to ``reduce_on_best_loss``. Training stops when the
    validation loss is still stagnating and ``patience_counter`` has reached
    ``patience``, or after ``epochs`` epochs.

    Parameters:
    - data: Input data tensor (rows are samples; data.shape[1] is the input size)
    - n_samples_train: Number of samples to use for training; the rest is validation
    - latent_dim: Dimension of the latent space
    - device: Device the mini-batches are moved to
    - args: Namespace-like object; ``multi_gpu`` and ``gpu_ids`` (comma-separated
      string) are read if present, and it is forwarded to parallel_linear_regression
    - epochs: Maximum number of training epochs
    - early_stopping: Length (in epochs) of the validation-loss window; when the
      best validation loss is older than this window, rank reduction starts
    - lr: Learning rate
    - batch_size: Batch size for training
    - ae_depth: Depth of the autoencoder
    - ae_width: Width multiplier for hidden layers
    - dropout: Dropout rate
    - wd: Weight decay
    - initial_rank_ratio: Initial rank ratio (1.0 = full rank)
    - min_rank: Minimum rank (lower bound), forwarded to the model
    - rank_schedule: Custom schedule for rank reduction (epochs at which to reduce)
    - rank_reduction_frequency: How often to try reducing rank (in epochs)
    - rank_reduction_threshold: Threshold forwarded to the model's reduce_rank()
    - warmup_epochs: Offset of the first epoch in the default rank schedule
    - patience: Patience-counter value required before training may stop
    - reduce_on_best_loss: Reduction policy at scheduled epochs:
        'true'       -> skip when the training loss is above the best loss
        'stagnation' -> skip until the patience counter reaches `patience`
        'rsquare'    -> probe the validation latents with
                        parallel_linear_regression and compare to a threshold
        anything else -> always reduce
    - r_square_threshold: Factor applied to the initial probe score to obtain
      the threshold used by the 'rsquare' policy
    - l1_start_weight: Currently unused by this function
    - l1_step_size: Only used by the (not yet reachable) sparse branch
    - rank_or_sparse: 'rank' is the supported mode; 'sparse' raises NotImplementedError
    - verbose: Print progress messages

    Returns:
    - (model, reps, mean of the last five training losses, mean finite probe
      score on the training latents, rank_history dict, [train_losses, val_losses])

    Raises:
    - NotImplementedError: if rank_or_sparse == 'sparse'
    """
    multi_gpu = args.multi_gpu if hasattr(args, 'multi_gpu') else False
    
    # Create model with adaptive rank reduction
    if rank_or_sparse == 'sparse':
        raise NotImplementedError("Sparse autoencoder is not implemented yet.")
    else:
        model = AdaptiveRankReducedAE(
            data.shape[1], latent_dim, depth=ae_depth, width=ae_width, 
            dropout=dropout, initial_rank_ratio=initial_rank_ratio, 
            min_rank=min_rank
        ).to(device)
    
    # Handle multi-GPU setup
    if multi_gpu:
        gpu_ids_arg = getattr(args, 'gpu_ids', None)
        if gpu_ids_arg:
            num_gpus = len(gpu_ids_arg.split(','))
        else:
            num_gpus = torch.cuda.device_count()
            
        # Ensure batch size is divisible by number of GPUs. Never round down
        # to zero (batch_size < num_gpus) and never divide by zero GPUs.
        if num_gpus > 0 and batch_size % num_gpus != 0:
            original_batch_size = batch_size
            batch_size = max(num_gpus, (batch_size // num_gpus) * num_gpus)
            if verbose:
                print(f"Adjusted batch size from {original_batch_size} to {batch_size} to be divisible by {num_gpus} GPUs")
            
        try:
            # Without an explicit gpu_ids string, use every visible device
            # (this matches how num_gpus was derived above).
            if gpu_ids_arg:
                device_ids = [int(gpu_id) for gpu_id in gpu_ids_arg.split(',')]
            else:
                device_ids = list(range(num_gpus))

            # If we need cuda:0 but it's not available, disable multi_gpu
            if 0 not in device_ids:
                raise RuntimeError("DataParallel requires cuda:0 which is not available.")
                
            # Ensure model is on cuda:0 for DataParallel
            cuda0_device = torch.device('cuda:0')
            model = model.to(cuda0_device)
            
            # Double-check all parameters are on cuda:0
            for param in model.parameters():
                if param.device != cuda0_device:
                    param.data = param.data.to(cuda0_device)
                    
            # Wrap model with DataParallel - explicitly specify device_ids
            model = nn.DataParallel(model, device_ids=device_ids)
            if verbose:
                print(f"Using DataParallel across GPUs: {gpu_ids_arg if gpu_ids_arg else device_ids}")
        except Exception as e:
            # Deliberate best-effort: any failure here degrades to single GPU.
            if verbose:
                print(f"Failed to use DataParallel: {e}")
                print(f"Falling back to single GPU mode on {device}")
            multi_gpu = False
            model = model.to(device)

    # nn.DataParallel does not expose custom methods (encode, reduce_rank, ...),
    # so every call to those goes through the unwrapped module.
    core_model = model.module if multi_gpu else model
    
    # Create optimizer and loss function
    # NOTE(review): the optimizer is built once; if reduce_rank()/increase_rank()
    # replace parameter tensors they would no longer be optimized - confirm
    # against the model implementation.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    loss_fn = torch.nn.MSELoss()

    # Linear learning rate decay from lr to 0 over `epochs` scheduler steps
    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.0, total_iters=epochs)

    # find out whether the data is too sparse for Rsquare or not
    # (sparsity = fraction of entries that are not strictly positive)
    data_sparsity = 1 - (data > 0).float().mean()
    if data_sparsity < 0.9:
        if verbose:
            print(f"Sparsity of the data is {data_sparsity:.2f}, using linear regression with R^2.")
        data_sparsity = False
    else:
        if verbose:
            print(f"Sparsity of the data is {data_sparsity:.2f}, using linear regression with RMSE.")
        data_sparsity = True
    
    # Create data loaders
    train_data = data[:n_samples_train]
    data_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_data = data[n_samples_train:]
    val_data_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False)
    n_samples = data.shape[0]
    n_samples_val = n_samples - n_samples_train

    def _probe_validation_score():
        """Encode the validation split and return the mean linear-probe score."""
        val_inputs = data[n_samples_train:].to(device)
        with torch.no_grad():
            encoded = core_model.encode(val_inputs)
        scores = parallel_linear_regression(encoded,
                                            val_inputs,
                                            n_samples_val,
                                            int(n_samples_val*0.9),
                                            device,
                                            args,
                                            n_epochs=500,
                                            early_stopping=50,
                                            sparse=data_sparsity)
        return scores.mean().item()
    
    # Default rank reduction schedule if none provided
    if rank_schedule is None:
        # Reduce rank every rank_reduction_frequency epochs, but start after warmup period
        rank_schedule = list(range(warmup_epochs + rank_reduction_frequency, 
                                 epochs, 
                                 rank_reduction_frequency))
    current_rsquare = None
    start_reduction = False
    bottom_reached = False  # set once the rank has been increased again
    # Number of recent probe scores inspected before the rank may be increased:
    # either 10 or half of patience, whichever is smaller
    window = min(10, int(patience/2))
    
    # Train the model
    train_losses = []
    val_losses = []
    r_squares = []
    best_loss = float('inf')
    patience_counter = 0
    if rank_or_sparse == 'rank':
        rank_history = {'rank':[core_model.get_total_rank() if hasattr(core_model, 'get_total_rank') else 0],
                    'epoch':[0],
                    'loss':[float('inf')],
                    'val_loss':[float('inf')],
                    'rsquare':[]}
    else:
        # NOTE(review): scaffolding for a sparse mode. 'sparse' raises above and
        # this dict has no list for 'rank' nor a 'val_loss' key, so the
        # per-epoch bookkeeping below cannot work for other values yet.
        rank_history = {'rank':core_model.latent_dim, 'epoch':[0], 'loss':[float('inf')], 'rsquare':[]}
        current_rank = core_model.latent_dim
    
    pbar = tqdm.tqdm(range(epochs))
    for epoch in pbar:
        # Training phase
        model.train()
        train_loss = 0.0
        val_loss = 0.0

        for x in data_loader:
            x = x.to(device)

            # Forward pass
            x_hat = model(x)

            # Compute loss
            loss = loss_fn(x_hat, x)

            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        train_loss /= len(data_loader)
        train_losses.append(train_loss)

        # Step the learning rate scheduler
        scheduler.step()

        # Validation phase. Switch to eval mode so dropout is disabled; the
        # model stays in eval mode for the probes below and is put back into
        # train mode at the top of the next epoch.
        model.eval()
        with torch.no_grad():
            for x_val in val_data_loader:
                x_val = x_val.to(device)
                val_loss += loss_fn(model(x_val), x_val).item()

        val_loss /= len(val_data_loader)
        val_losses.append(val_loss)

        # Update best loss
        if train_loss < best_loss:
            best_loss = train_loss
            if reduce_on_best_loss in ['true', 'stagnation']:
                patience_counter = 0  # Reset patience counter
        else:
            if reduce_on_best_loss in ['true', 'stagnation']:
                patience_counter += 1
        
        # Update progress bar with both loss and current rank information
        if rank_or_sparse == 'rank':
            current_rank = core_model.get_total_rank()
        pbar.set_postfix({
            'loss': round(train_loss, 4),
            'rank': current_rank,
            'best_loss': round(best_loss, 4),
            'current_rsquare': current_rsquare,
        })
        
        # Apply rank changes at scheduled epochs once the reduction phase started
        if (epoch in rank_schedule) and start_reduction:
            # Check if we should reduce rank based on loss condition
            should_reduce = True
            should_increase = False
            if reduce_on_best_loss == 'true' and train_loss > best_loss:
                should_reduce = False
            elif reduce_on_best_loss == 'stagnation' and patience_counter < patience:
                should_reduce = False
            elif reduce_on_best_loss == 'rsquare':
                # get the r_square value and determine if it is above the threshold
                # NOTE(review): with sparse data the log message says the probe
                # reports RMSE; the '<' comparisons below are the same for both
                # metrics - confirm the intended direction for the sparse case.
                current_rsquare = _probe_validation_score()
                r_squares.append(current_rsquare)
                if (len(r_squares) >= window) and patience_counter >= window:
                    if all(r < min_rsquare for r in r_squares[-window:]) and (not bottom_reached): # only allow increasing back once
                        # increase the rank
                        should_reduce = False
                        should_increase = True
                        if verbose:
                            print(f"R-squared {current_rsquare} is below threshold {min_rsquare}, increasing rank")
                        bottom_reached = True
                    elif current_rsquare < min_rsquare:
                        should_reduce = False
                        patience_counter += 1
                elif (len(r_squares) >= 1) and (patience_counter >= 1):
                    if current_rsquare < min_rsquare:
                        should_reduce = False
                        patience_counter += 1
                else:
                    # Not enough data to determine if we should reduce rank
                    should_reduce = False
                    patience_counter += 1
            
            if should_reduce:
                if rank_or_sparse == 'rank':
                    # Apply rank reduction; the history below records the rank
                    # measured before this reduction (current_rank is refreshed
                    # at the start of the next epoch's bookkeeping).
                    changes_made = core_model.reduce_rank(reduction_ratio=0.9, threshold=rank_reduction_threshold)
                else:
                    # NOTE(review): sparse scaffolding, not reachable today;
                    # changes_made is never assigned on this path.
                    # For sparse autoencoder, just update L1 weight
                    model.update_l1_weight(l1_step_size)
                    if verbose:
                        print(f"Updated L1 weight to {model.l1_weight}")
                    total_rank_after = model.get_n_active_neurons(model.get_activations(data[:n_samples_train]))
                    current_rank = total_rank_after
                
                # if no changes were made, we need to increase the patience counter
                if changes_made:
                    patience_counter = 0
                else:
                    patience_counter += 1
            if should_increase:
                changes_made = core_model.increase_rank(reduction_ratio=0.9)
                current_rank = core_model.get_total_rank()
                if changes_made:
                    patience_counter = 0
                else:
                    patience_counter += 1
        rank_history['rank'].append(current_rank)
        rank_history['epoch'].append(epoch)
        rank_history['rsquare'].append(current_rsquare)
        rank_history['loss'].append(train_loss)
        rank_history['val_loss'].append(val_loss)

        # True when the best validation loss is older than the last
        # `early_stopping` epochs.
        val_stagnated = (epoch > early_stopping) and (min(val_losses[-early_stopping:]) > min(val_losses))
        
        if val_stagnated and (start_reduction is False):
            start_reduction = True  # Start rank reduction instead of stopping
            if verbose:
                print(f"Patience exceeded at epoch {epoch}, starting rank reduction")

            # Reference probe score; the reduction threshold is derived from it
            current_rsquare = _probe_validation_score()
            if data_sparsity:
                min_rsquare = current_rsquare * (2 - r_square_threshold)
            else:
                min_rsquare = current_rsquare * r_square_threshold
            rank_history['rsquare'].append(current_rsquare)
            if verbose:
                print(f"Initial R-squared value: {current_rsquare}, setting threshold to {min_rsquare}")
            r_squares.append(current_rsquare)

        # early stopping but conditioned on rank reduction
        if val_stagnated and (start_reduction is True) and (patience_counter >= patience):
            if verbose:
                print(f"Early stopping at epoch {epoch} with best loss {best_loss} and rank {current_rank}")
            break
    
    # Calculate latent representations in batches
    reps_list = []
    model.eval()
    with torch.no_grad():
        for i in range(0, n_samples_train, batch_size):
            end_idx = min(i + batch_size, n_samples_train)
            x_batch = data[i:end_idx].to(device)
            
            batch_reps = core_model.encode(x_batch).cpu()
            reps_list.append(batch_reps)
            
            # Free memory
            del x_batch, batch_reps
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    
    # Combine latent representations from all batches
    reps = torch.cat(reps_list, dim=0)
    
    # empty cache
    del data_loader, val_data_loader, optimizer, loss_fn
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    # Linear regression evaluation on the training latents
    final_scores = parallel_linear_regression(reps, data[:n_samples_train], n_samples_train, int(n_samples_train*0.9), device, args, n_epochs=500, early_stopping=50, verbose=verbose, sparse=data_sparsity)
    
    # remove all nan and inf values
    final_scores = final_scores[torch.isfinite(final_scores)]
    
    # Free memory
    del reps_list
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    return model, reps, np.mean(train_losses[-5:]), final_scores.mean().item(), rank_history, [train_losses, val_losses]


def compute_direct_r_squared_unimodal(model, data, device, multi_gpu=False, verbose=False, metric='R2'):
    """
    Compute a reconstruction-quality metric from a direct forward pass of the
    model on unimodal data.

    Parameters:
    - model: The trained model; called as model(data) after model.eval()
    - data: Input data tensor (reduced along dim 0, i.e. rows are samples)
    - device: Device to run the forward pass on
    - multi_gpu: Whether model is wrapped with DataParallel (currently unused;
      the model is only called, which works for wrapped and unwrapped models)
    - verbose: Print a warning when dimensions are excluded from the R² average
    - metric: One of 'R2', 'MSE', 'RMSE', 'ExVarScore', 'McFaddenR2'

    Returns:
    - The metric as a Python float. For 'R2' the per-dimension values are
      averaged over dimensions whose mean is non-zero and, when the data
      contains NaN/Inf, finite. Returns None for an unrecognized metric.
    """
    model.eval()
    
    with torch.no_grad():
        # Get model predictions
        data_tensor = data.to(device)
        reconstruction = model(data_tensor)
        
        # All metric arithmetic is done on CPU copies
        original_mean = data_tensor.mean(dim=0).cpu()
        original_cpu = data_tensor.cpu()
        reconstruction_cpu = reconstruction.cpu()

        if metric == 'R2':
            has_zero_mean = bool(torch.any(original_mean == 0))
            has_non_finite = not bool(torch.isfinite(original_cpu).all())

            if verbose and has_zero_mean:
                print(f"   Warning: zeros found in original_mean. Removing samples.")
            if verbose and has_non_finite:
                print(f"   Warning: NaN or Inf values found in original data. Handling them.")

            # One combined mask so that data with BOTH zero-mean dimensions and
            # NaN/Inf dimensions no longer lets the NaN dimensions through
            # (NaN != 0 is True, which previously poisoned the average).
            valid_dims = original_mean != 0
            if has_non_finite:
                valid_dims = valid_dims & torch.isfinite(original_mean)

            if not bool(valid_dims.any()):
                if has_non_finite:
                    # No usable dimension at all
                    r_squared = torch.tensor(0.0)
                else:
                    # If all means are zero, use correlation as fallback
                    r_squared = torch.corrcoef(torch.stack((original_cpu.flatten(), reconstruction_cpu.flatten())))[0, 1]
                    if torch.isnan(r_squared):
                        r_squared = torch.tensor(0.0)
            else:
                # Per-dimension R² with a small epsilon against division by
                # zero, averaged across the retained dimensions
                ssr = ((original_cpu - reconstruction_cpu)**2).sum(0)[valid_dims]
                ss_tot = ((original_cpu - original_mean)**2).sum(0)[valid_dims]
                r_squared = (1 - ((ssr + 1e-9) / (ss_tot + 1e-9))).mean()
            return r_squared.item()

        elif metric == 'MSE':
            return ((original_cpu - reconstruction_cpu)**2).mean().item()
        elif metric == 'RMSE':
            return torch.sqrt(((original_cpu - reconstruction_cpu)**2).mean()).item()
        elif metric == 'ExVarScore':
            # Variances are taken over all elements, not per dimension
            ex_var = 1 - (torch.var(reconstruction_cpu - original_cpu) / (torch.var(original_cpu) + 1e-9))
            return ex_var.item()
        elif metric == 'McFaddenR2':
            return mcfadden_r_squared(original_cpu, reconstruction_cpu).item()

    # Unrecognized metric: kept as a silent None for backward compatibility
    return None

def mcfadden_r_squared(y_true, y_pred):
    """
    Compute McFadden's R² for regression tasks.

    The log-likelihoods are taken as -0.5 * (sum of squared errors), computed
    per column (dim 0 is reduced), for the fitted model and for an
    intercept-only null model that predicts the column mean.

    Parameters:
    - y_true: True values (torch tensor, 1D or 2D; 1D is treated as one column)
    - y_pred: Predicted values (torch tensor, same layout as y_true)
    
    Returns:
    - McFadden's R² value (0-dim tensor), averaged across columns
    """
    # Ensure inputs are 2D tensors
    if y_true.dim() == 1:
        y_true = y_true.unsqueeze(1)
    if y_pred.dim() == 1:
        y_pred = y_pred.unsqueeze(1)
    
    # Calculate log-likelihood of the null model (intercept only)
    mean_y = torch.mean(y_true, dim=0)
    ll_null = -0.5 * torch.sum((y_true - mean_y) ** 2, dim=0)
    
    # Calculate log-likelihood of the fitted model
    ll_model = -0.5 * torch.sum((y_true - y_pred) ** 2, dim=0)
    
    # Compute McFadden's R²
    with torch.no_grad():
        # Both log-likelihoods are <= 0, so the stabilising epsilon must be
        # SUBTRACTED to keep the denominator negative. Adding it flipped the
        # sign for (near-)constant targets and produced values far above 1.
        r_squared = 1 - (ll_model / (ll_null - 1e-9))
        r_squared = r_squared.mean()  # Average across dimensions if multi-dimensional
    
    return r_squared

def train_overcomplete_ae2(data, n_samples_train, latent_dim, device, args, epochs=100, early_stopping=50, 
                          lr=0.001, batch_size=128, ae_depth=2, ae_width=0.5, dropout=0.0, wd=1e-5, 
                          initial_rank_ratio=1.0, min_rank=10, 
                          rank_schedule=None, rank_reduction_frequency=10, 
                          rank_reduction_threshold=0.01, warmup_epochs=0,
                          patience=10, r_square_threshold=0.9,
                          threshold_type='relative', rank_or_sparse='rank',
                          verbose=True, compute_jacobian=False, model_name=None, 
                          l2_norm_adaptivelayers=None, sharedwhenall=True,
                          return_detailed_history=False, lr_scheduler=False
                          ):
    """
    Train a unimodal autoencoder with adaptive rank reduction
    Simplified version based on multimodal train_overcomplete_ae with fixed conditions:
    - reduce_on_best_loss='rsquare'
    - compressibility_type='direct'
    - reduction_criterion='r_squared'
    - No orthogonal loss, L1 weights, or loss balancing

    Training runs in two phases. First the model is trained at its initial
    rank until the validation loss has not improved within the last
    `early_stopping` epochs. From then on, at every epoch listed in
    `rank_schedule`, the reconstruction R² on the validation split is compared
    with a threshold derived from the R² measured at the phase switch: the rank
    is reduced while R² stays above the threshold, and increased once if it
    stays below the threshold for long enough.

    Parameters:
    - data: Input data tensor; the first `n_samples_train` rows are used for
      training, the remaining rows for validation
    - n_samples_train: Number of samples to use for training
    - latent_dim: Dimension of the latent space
    - device: Device the model and the batches are moved to
    - args: Namespace-like object; `multi_gpu` and `gpu_ids` are read if present
    - epochs: Maximum number of training epochs
    - early_stopping: Number of epochs for early stopping patience
    - lr: Learning rate
    - batch_size: Batch size for training
    - ae_depth: Depth of the autoencoder
    - ae_width: Width multiplier for hidden layers
    - dropout: Dropout rate
    - wd: Weight decay
    - initial_rank_ratio: Initial rank ratio (1.0 = full rank)
    - min_rank: Minimum rank (lower bound)
    - rank_schedule: Custom schedule for rank reduction (epochs at which to reduce)
    - rank_reduction_frequency: How often to try reducing rank (in epochs);
      only used to build the default schedule
    - rank_reduction_threshold: Passed as `threshold` to the model's `reduce_rank`
    - warmup_epochs: Offset of the first epoch in the default schedule
    - patience: Number of scheduled checks without a rank change that are
      required before a rank increase / final early stopping; also the number
      of scheduled checks skipped after a rank increase
    - r_square_threshold: R² threshold for rank reduction decisions
    - threshold_type: 'relative' (threshold = initial R² * r_square_threshold)
      or 'absolute' (threshold = initial R² - r_square_threshold)
    - rank_or_sparse: Only 'rank' is supported; 'sparse' raises NotImplementedError
    - verbose: Print progress messages
    - compute_jacobian, model_name, sharedwhenall: Accepted for interface
      compatibility; not used by this function
    - l2_norm_adaptivelayers: If not None, use AdamW with this weight decay for
      the adaptive layers and `wd` for all other parameters
    - return_detailed_history: Also return per-epoch diagnostics (see Returns)
    - lr_scheduler: If True, decay the learning rate linearly and restart the
      decay after every rank change

    Returns:
    A tuple (model, reps, final_loss, final_rsquare, rank_history, loss_curves):
    - model: The trained model (wrapped in nn.DataParallel when multi-GPU is active)
    - reps: Encodings of the training rows, concatenated on the CPU
    - final_loss: Mean of the last five training losses
    - final_rsquare: Last computed R² (0.0 if it was never computed)
    - rank_history: Dict with the rank, loss and R² recorded at each scheduled check
    - loss_curves: [train_losses, val_losses]
    If `return_detailed_history` is True a seventh element is appended:
    (all_losses, all_rsquares, all_rsquare_epochs, all_ranks, all_lrs).

    Raises:
    - NotImplementedError: if rank_or_sparse == 'sparse'
    - ValueError: if threshold_type is neither 'relative' nor 'absolute'
    """
    multi_gpu = getattr(args, 'multi_gpu', False)
    
    # Create model with adaptive rank reduction
    if rank_or_sparse == 'sparse':
        raise NotImplementedError("Sparse autoencoder training not implemented in this function.")

    model = AdaptiveRankReducedAE(
        data.shape[1], latent_dim, depth=ae_depth, width=ae_width, 
        dropout=dropout, initial_rank_ratio=initial_rank_ratio, 
        min_rank=min_rank
    ).to(device)
    
    # Handle multi-GPU setup
    if multi_gpu:
        try:
            # Resolve the device ids once; without explicit gpu_ids fall back
            # to every visible CUDA device instead of failing on `.split`.
            gpu_ids = getattr(args, 'gpu_ids', None)
            if gpu_ids:
                device_ids = [int(gpu_id) for gpu_id in gpu_ids.split(',')]
            else:
                device_ids = list(range(torch.cuda.device_count()))
            num_gpus = len(device_ids)

            # Ensure batch size is divisible by number of GPUs (and never zero)
            if num_gpus > 0 and batch_size % num_gpus != 0:
                original_batch_size = batch_size
                batch_size = max(num_gpus, (batch_size // num_gpus) * num_gpus)
                if verbose:
                    print(f"Adjusted batch size from {original_batch_size} to {batch_size} to be divisible by {num_gpus} GPUs")

            # If we need cuda:0 but it's not available, disable multi_gpu
            if 0 not in device_ids:
                raise RuntimeError("DataParallel requires cuda:0 which is not available.")
                
            # Ensure model is on cuda:0 for DataParallel
            cuda0_device = torch.device('cuda:0')
            model = model.to(cuda0_device)
            
            # Double-check all parameters are on cuda:0
            for param in model.parameters():
                if param.device != cuda0_device:
                    param.data = param.data.to(cuda0_device)
                    
            # Wrap model with DataParallel - explicitly specify device_ids
            model = nn.DataParallel(model, device_ids=device_ids)
            if verbose:
                print(f"Using DataParallel across GPUs: {','.join(str(i) for i in device_ids)}")
        except Exception as e:
            print(f"Failed to use DataParallel: {e}")
            print(f"Falling back to single GPU mode on {device}")
            multi_gpu = False
            model = model.to(device)

    # nn.DataParallel does not forward custom attributes or methods, so every
    # model-specific call (adaptive_layers, get_total_rank, reduce_rank, ...)
    # has to go through the wrapped module.
    core_model = model.module if multi_gpu else model

    def describe_ranks():
        # Comma-separated active dimensions of all adaptive layers
        return ', '.join(str(layer.active_dims) for layer in core_model.adaptive_layers)
    
    # Create optimizer and loss function
    if l2_norm_adaptivelayers is not None:
        # Use AdamW with separate weight decay for adaptive layers
        adaptive_params = [p for layer in core_model.adaptive_layers for p in layer.parameters()]
        adaptive_param_ids = {id(p) for p in adaptive_params}
        # Keep the module's own parameter order so the groups are deterministic
        other_params = [p for p in model.parameters() if id(p) not in adaptive_param_ids]
        
        # Create parameter groups with different weight decay
        param_groups = [
            {'params': other_params, 'weight_decay': wd},
            {'params': adaptive_params, 'weight_decay': l2_norm_adaptivelayers}
        ]
        optimizer = torch.optim.AdamW(param_groups, lr=lr)
    else:
        # Use standard Adam optimizer
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    if lr_scheduler:
        start_lr_factor = 1.0
        end_lr_factor = 0.001
        scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=start_lr_factor, end_factor=end_lr_factor, total_iters=1000)

    loss_fn = torch.nn.MSELoss()
    
    # Create data loaders
    train_data = data[:n_samples_train]
    data_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_data = data[n_samples_train:]
    val_data_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False)
    
    # Default rank reduction schedule if none provided
    if rank_schedule is None:
        # Reduce rank every rank_reduction_frequency epochs, but start after warmup period
        rank_schedule = list(range(warmup_epochs + rank_reduction_frequency, 
                                 epochs, 
                                 rank_reduction_frequency))
    
    # State of the rank-adaptation phase
    initial_square = None
    current_rsquare = None
    start_reduction = False     # becomes True once the initial training has converged
    bottom_reached = False      # becomes True after the first successful rank increase
    min_rsquare = None          # R² threshold, set when rank adaptation starts
    break_counter = 0           # scheduled checks still to skip after a rank increase
    latest_change_epoch = -1
    patience_counter = 0        # scheduled checks since the last rank change
    r_squares = []
    r_squares_since_reduction = []
    
    # Training history
    train_losses = []
    val_losses = []
    best_loss = float('inf')
    all_losses = []
    all_rsquares = []
    all_rsquare_epochs = []
    all_ranks = []
    all_lrs = []
    
    rank_history = {'total_rank': [core_model.get_total_rank()],
                    'ranks': [describe_ranks()],
                    'epoch': [0],
                    'loss': [float('inf')],
                    'val_loss': [float('inf')],
                    'rsquare': []}
    
    pbar = tqdm.tqdm(range(epochs))
    for epoch in pbar:
        # Training phase
        model.train()
        train_loss = 0.0
        val_loss = 0.0
        all_lrs.append(optimizer.param_groups[0]['lr'])

        for x in data_loader:
            x = x.to(device)

            # Forward pass
            x_hat = model(x)

            # Compute loss
            loss = loss_fn(x_hat, x)

            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        train_loss /= len(data_loader)
        train_losses.append(train_loss)

        # Step the learning rate scheduler
        if lr_scheduler:
            scheduler.step()

        # Validation phase: evaluation mode so that dropout does not perturb
        # the validation loss that drives the early-stopping decisions.
        model.eval()
        with torch.no_grad():
            for x_val in val_data_loader:
                x_val = x_val.to(device)
                x_val_hat = model(x_val)
                val_loss += loss_fn(x_val_hat, x_val).item()
        # Restore training mode so the calls below see the model as before
        model.train()

        val_loss /= len(val_data_loader)
        val_losses.append(val_loss)
        all_losses.append(val_loss)

        # Update best loss
        if train_loss < best_loss:
            best_loss = train_loss
        
        # Update progress bar with both loss and current rank information
        current_rank = core_model.get_total_rank()
        all_ranks.append(current_rank)
        pbar.set_postfix({
            'loss': round(train_loss, 4),
            'rank': current_rank,
            'best_loss': round(best_loss, 4),
            'current_rsquare': round(current_rsquare, 3) if current_rsquare is not None else 'N/A',
            'patience': patience_counter,
        })
        
        # Rank changes are only considered at scheduled epochs once the
        # initial training phase is over.
        scheduled_check = start_reduction and (epoch in rank_schedule)

        if scheduled_check and break_counter == 0:
            # Get the r_square value using direct reconstruction R²
            current_rsquare = compute_direct_r_squared_unimodal(model, val_data, device, multi_gpu)
            r_squares.append(current_rsquare)
            all_rsquares.append(current_rsquare)
            all_rsquare_epochs.append(epoch)
            
            # Increase only after R² has stayed below the threshold for
            # (patience - 1) earlier checks and no increase has happened yet.
            should_increase = (
                len(r_squares_since_reduction) >= (patience - 1)
                and current_rsquare < min_rsquare
                and not bottom_reached
            )

            # Reduce while R² is above the threshold (good performance)
            should_reduce = current_rsquare > min_rsquare
            if should_reduce:
                r_squares_since_reduction = []
            else:
                r_squares_since_reduction.append(current_rsquare)
            
            changes_made = False
            if should_reduce:
                # Apply rank reduction
                changes_made = core_model.reduce_rank(reduction_ratio=0.9, threshold=rank_reduction_threshold)
                if changes_made and verbose:
                    print(f"Rank reduced at epoch {epoch}. New rank: {core_model.get_total_rank()}")
            elif should_increase:
                # Apply rank increase
                changes_made = core_model.increase_rank(increase_ratio=1.1)
                if changes_made:
                    if verbose:
                        print(f"Rank increased at epoch {epoch}. New rank: {core_model.get_total_rank()}")
                    break_counter = patience  # give model more time to re-learn the added dimensions
                    # A single successful increase marks the bottom: pin the
                    # first adaptive layer's minimum rank to its current size
                    # so it is not reduced below it again.
                    bottom_reached = True
                    first_layer = core_model.adaptive_layers[0]
                    first_layer.min_rank = first_layer.active_dims
            
            if changes_made:
                patience_counter = 0  # Reset patience counter if rank was changed
                latest_change_epoch = epoch
                # also reset the scheduler
                if lr_scheduler:
                    # Reset optimizer lr to initial value before re-creating scheduler
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr
                    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=start_lr_factor, end_factor=end_lr_factor, total_iters=1000)
            else:
                patience_counter += 1

            # Get new rank and store in history
            rank_history['total_rank'].append(core_model.get_total_rank())
            rank_history['ranks'].append(describe_ranks())
            rank_history['epoch'].append(epoch)
            rank_history['rsquare'].append(current_rsquare)
            rank_history['loss'].append(train_loss)
            rank_history['val_loss'].append(val_loss)
        elif scheduled_check and break_counter > 0:
            # If we're in a break period (break_counter > 0), count down
            break_counter -= 1

        # Switch to the rank-adaptation phase once the validation loss has not
        # reached a new minimum within the last `early_stopping` epochs.
        if (not start_reduction) and (epoch > early_stopping) and (min(val_losses[-early_stopping:]) > min(val_losses)):
            start_reduction = True  # Start rank reduction after early stopping
            break_counter = 0  # start with no breaks (only used when increasing layers)
            if verbose:
                print(f"Patience exceeded at epoch {epoch}, starting rank reduction")

            # Calculate initial R² threshold
            current_rsquare = compute_direct_r_squared_unimodal(model, val_data, device, multi_gpu)
            initial_square = current_rsquare
            
            # Calculate threshold based on threshold_type
            if threshold_type == 'relative':
                min_rsquare = initial_square * r_square_threshold
            elif threshold_type == 'absolute':
                min_rsquare = initial_square - r_square_threshold
            else:
                raise ValueError(f"threshold_type must be 'relative' or 'absolute', got {threshold_type}")
            
            rank_history['rsquare'].append(current_rsquare)
            if verbose:
                print(f"Initial R-squared value: {initial_square:.4f}, setting {threshold_type} threshold to {min_rsquare:.4f}")
            r_squares.append(current_rsquare)

        # Early stopping, conditioned on rank reduction: the loss is compared
        # only against epochs after the latest rank change so that training
        # can converge even if the loss is higher after pruning.
        if start_reduction and ((epoch - latest_change_epoch) >= early_stopping):
            if (epoch > early_stopping) and (min(val_losses[-early_stopping:]) > min(val_losses[(latest_change_epoch+1):])) and (patience_counter >= patience):
                if verbose:
                    print(f"Early stopping at epoch {epoch} with best loss {best_loss} and rank {rank_history['ranks'][-1]}")
                break
    
    # Calculate latent representations in batches
    reps_list = []
    model.eval()
    with torch.no_grad():
        for i in range(0, n_samples_train, batch_size):
            end_idx = min(i + batch_size, n_samples_train)
            x_batch = data[i:end_idx].to(device)
            
            # Encode through the unwrapped module (DataParallel has no `encode`)
            reps_list.append(core_model.encode(x_batch).cpu())
            
            # Free memory
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    
    # Combine latent representations from all batches
    reps = torch.cat(reps_list, dim=0)
    
    # Free memory
    del reps_list
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    final_rsquare = current_rsquare if current_rsquare is not None else 0.0
    results = (model, reps, np.mean(train_losses[-5:]), final_rsquare, rank_history, [train_losses, val_losses])

    if return_detailed_history:
        return results + ((all_losses, all_rsquares, all_rsquare_epochs, all_ranks, all_lrs),)

    return results

def train_overcomplete_ae2_with_pretrained(data, n_samples_train, latent_dim, device, args, epochs=100, early_stopping=50, 
                          lr=0.001, batch_size=128, ae_depth=2, ae_width=0.5, dropout=0.0, wd=1e-5, 
                          initial_rank_ratio=1.0, min_rank=10, 
                          rank_schedule=None, rank_reduction_frequency=10, 
                          rank_reduction_threshold=0.01, warmup_epochs=0,
                          patience=10, r_square_threshold=0.9,
                          threshold_type='relative', rank_or_sparse='rank',
                          verbose=True, compute_jacobian=False, model_name=None, 
                          l2_norm_adaptivelayers=None, sharedwhenall=True,
                          return_detailed_history=False, lr_scheduler=False,
                          pretrained_name=None, distortion_metric='R2'
                          ):
    """
    Train a unimodal autoencoder with adaptive rank reduction
    Simplified version based on multimodal train_overcomplete_ae with fixed conditions:
    - reduce_on_best_loss='rsquare'
    - compressibility_type='direct'
    - reduction_criterion='r_squared'
    - No orthogonal loss, L1 weights, or loss balancing
    
    Parameters:
    - data: Input data tensor
    - n_samples_train: Number of samples to use for training
    - latent_dim: Dimension of the latent space
    - epochs: Maximum number of training epochs
    - early_stopping: Number of epochs for early stopping patience
    - lr: Learning rate
    - batch_size: Batch size for training
    - ae_depth: Depth of the autoencoder
    - ae_width: Width multiplier for hidden layers
    - dropout: Dropout rate
    - wd: Weight decay
    - initial_rank_ratio: Initial rank ratio (1.0 = full rank)
    - min_rank: Minimum rank (lower bound)
    - rank_schedule: Custom schedule for rank reduction (epochs at which to reduce)
    - rank_reduction_frequency: How often to try reducing rank (in epochs)
    - rank_reduction_threshold: Energy threshold for rank reduction
    - warmup_epochs: Number of epochs to train before starting rank reduction
    - patience: Patience for rank reduction
    - r_square_threshold: R² threshold for rank reduction decisions
    - threshold_type: 'relative' (multiply by initial R²) or 'absolute' (use threshold directly)
    """
    # check if there is an existing pretrained model for the seed, early stopping, and training hyperparameters (lr, wd, batch size, model architecture)
    pretrained_model_path = f"./03_results/models/pretrained_models/{pretrained_name}.pt" if pretrained_name else None
    if pretrained_model_path and os.path.exists(pretrained_model_path):
        print(f"Found existing pretrained model at {pretrained_model_path}. Loading...")
        model = AdaptiveRankReducedAE(
            data.shape[1], latent_dim, depth=ae_depth, width=ae_width, 
            dropout=dropout, initial_rank_ratio=initial_rank_ratio, 
            min_rank=min_rank
        )
        model.load_state_dict(torch.load(pretrained_model_path, weights_only=False))
        # make sure that the weights are changed
        model.eval()
        for param in model.parameters():
            param.requires_grad = True
        print(f"Loaded pretrained model from {pretrained_model_path}")
        # also load the loss curves
        loss_curve_path = pretrained_model_path.replace('.pt', '_loss_curve.npy')
        train_val_losses = np.load(loss_curve_path, allow_pickle=True)
        train_losses = train_val_losses[0].tolist()
        val_losses = train_val_losses[1].tolist()
        print(f"Loaded loss curves from {loss_curve_path}")
        # print last losses
        print(f"Last training loss: {train_losses[-1]}, last validation loss: {val_losses[-1]}")
        model.epoch = len(train_losses)
        # Save plot of pretraining loss curves
        plot_path = pretrained_model_path.replace('.pt', '_pretraining_loss_plot.png')
        if not os.path.exists(plot_path):
            plot_loss_curves(train_losses, val_losses, plot_path, title="Pretraining Loss Curves")
    else:
        if pretrained_model_path:
            print("No pretrained model found. Training from scratch.")
            from src.functions.pretrain_mm_sim import pretrain_overcomplete_ae_unimodal
            model, [train_losses, val_losses] = pretrain_overcomplete_ae_unimodal(
                data, n_samples_train, latent_dim, device, args, epochs=epochs, early_stopping=early_stopping,
                lr=lr, batch_size=batch_size, ae_depth=ae_depth, ae_width=ae_width, dropout=dropout, wd=wd,
                initial_rank_ratio=initial_rank_ratio, min_rank=min_rank,
                verbose=verbose
            )
            # Save the pretrained model and loss curves
            os.makedirs(os.path.dirname(pretrained_model_path), exist_ok=True)
            torch.save(model.state_dict(), pretrained_model_path)
            # Also save loss curves
            loss_curve_path = pretrained_model_path.replace('.pt', '_loss_curve.npy')
            np.save(loss_curve_path, np.array([train_losses, val_losses]))
            
            # Save plot of pretraining loss curves
            plot_path = pretrained_model_path.replace('.pt', '_pretraining_loss_plot.png')
            plot_loss_curves(train_losses, val_losses, plot_path, title="Pretraining Loss Curves")
            
            # also save data_indices
            print(f"Saved pretrained model to {pretrained_model_path} and loss curves to {loss_curve_path}")
            print(f"Saved pretraining loss plot to {plot_path}")
            model.epoch = len(train_losses)
        else:
            raise ValueError("model_name must be provided to save/load pretrained models.")
    train_losses_pretrain = train_losses.copy()
    val_losses_pretrain = val_losses.copy()

    # Declare multi_gpu as global so it can be accessed
    multi_gpu = args.multi_gpu if hasattr(args, 'multi_gpu') else False
    
    model.to(device)
    
    # Handle multi-GPU setup
    if multi_gpu:
        # Adjust batch size to be divisible by number of GPUs
        if args.gpu_ids:
            num_gpus = len(args.gpu_ids.split(','))
        else:
            num_gpus = torch.cuda.device_count()
            
        # Ensure batch size is divisible by number of GPUs
        if batch_size % num_gpus != 0:
            original_batch_size = batch_size
            batch_size = (batch_size // num_gpus) * num_gpus
            if verbose:
                print(f"Adjusted batch size from {original_batch_size} to {batch_size} to be divisible by {num_gpus} GPUs")
            
        try:
            # If we need cuda:0 but it's not available, disable multi_gpu
            if 0 not in [int(id) for id in args.gpu_ids.split(',')]:
                raise RuntimeError("DataParallel requires cuda:0 which is not available.")
                
            # Ensure model is on cuda:0 for DataParallel
            cuda0_device = torch.device('cuda:0')
            model = model.to(cuda0_device)
            
            # Double-check all parameters are on cuda:0
            for param in model.parameters():
                if param.device != cuda0_device:
                    param.data = param.data.to(cuda0_device)
                    
            # Wrap model with DataParallel - explicitly specify device_ids
            model = nn.DataParallel(model, device_ids=[int(id) for id in args.gpu_ids.split(',')])
            if verbose:
                print(f"Using DataParallel across GPUs: {args.gpu_ids}")
        except Exception as e:
            print(f"Failed to use DataParallel: {e}")
            print(f"Falling back to single GPU mode on {device}")
            multi_gpu = False
            model = model.to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    if lr_scheduler:
        start_lr_factor = 1.0
        end_lr_factor = 0.001
        scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=start_lr_factor, end_factor=end_lr_factor, total_iters=1000)

    loss_fn = torch.nn.MSELoss()
    
    # Create data loader
    train_data = data[:n_samples_train]
    data_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_data = data[n_samples_train:]
    val_data_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False)
    n_samples = data.shape[0]
    n_samples_val = n_samples - n_samples_train
    
    # Default rank reduction schedule if none provided
    if rank_schedule is None:
        # Reduce rank every rank_reduction_frequency epochs, but start after warmup period
        rank_schedule = list(range(warmup_epochs + rank_reduction_frequency, 
                                 epochs, 
                                 rank_reduction_frequency))
    
    initial_square = None
    current_rsquare = None
    start_reduction = False
    bottom_reached = False
    min_rsquare = None
    break_counter = 0
    increase_counter = 0
    latest_change_epoch = -1
    
    # Train the model
    train_losses = []
    val_losses = []
    r_squares = []
    best_loss = float('inf')
    patience_counter = 0
    all_losses = []
    all_rsquares = []
    all_rsquare_epochs = []
    all_ranks = []
    all_lrs = []
    r_squares_since_reduction = []
    
    rank_history = {'total_rank': [model.get_total_rank() if hasattr(model, 'get_total_rank') else 
                    (model.module.get_total_rank() if multi_gpu else 0)],
                    'ranks': [', '.join(str(layer.active_dims) for layer in model.adaptive_layers)],
                    'epoch': [0],
                    'loss': [float('inf')],
                    'val_loss': [float('inf')],
                    'rsquare': []}
    
    patience_counter = 0
    increase_patience_counter = 0
    pbar = tqdm.tqdm(range(model.epoch,epochs))
    for epoch in pbar:
        # Training phase
        model.train()
        train_loss = 0.0
        val_loss = 0.0
        all_lrs.append(optimizer.param_groups[0]['lr'])

        for x in data_loader:
            x = x.to(device)

            # Forward pass
            x_hat = model(x)

            # Compute loss
            loss = loss_fn(x_hat, x)

            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        train_loss /= len(data_loader)
        train_losses.append(train_loss)

        # Step the learning rate scheduler
        if lr_scheduler:
            scheduler.step()

        # Validation phase
        with torch.no_grad():
            for x_val in val_data_loader:
                x_val = x_val.to(device)
                x_val_hat = model(x_val)
                val_loss += loss_fn(x_val_hat, x_val).item()

        val_loss /= len(val_data_loader)
        val_losses.append(val_loss)
        all_losses.append(val_loss)

        # Update best loss
        if train_loss < best_loss:
            best_loss = train_loss
            #patience_counter = 0  # Reset patience counter
        #else:
        #    patience_counter += 1
        
        # Update progress bar with both loss and current rank information
        current_rank = model.module.get_total_rank() if multi_gpu else model.get_total_rank()
        all_ranks.append(current_rank)
        pbar.set_postfix({
            'loss': round(train_loss, 4),
            'rank': current_rank,
            'best_loss': round(best_loss, 4),
            'current_r2': round(current_rsquare, 3) if current_rsquare is not None else 'N/A',
            'patience': patience_counter,
            'break': break_counter
        })
        
        # Apply rank reduction at scheduled epochs, respecting warmup period
        if (epoch in rank_schedule) & (start_reduction) & (break_counter == 0):
            # Fixed: reduce_on_best_loss='rsquare', compressibility_type='direct', reduction_criterion='r_squared'
            
            # Get the r_square value using direct reconstruction R²
            #current_rsquare = compute_direct_r_squared_unimodal(model, val_data, device, multi_gpu)
            current_rsquare = compute_direct_r_squared_unimodal(model, train_data[:int(0.1*len(train_data))], device, multi_gpu, metric=distortion_metric)
            # update max rsquare if necessary
            if threshold_type == 'absolute':
                if current_rsquare > initial_square:
                    initial_square = current_rsquare
                    min_rsquare = initial_square - r_square_threshold
                    if verbose:
                        print(f"New initial R-squared value: {initial_square:.4f}, updating absolute threshold to {min_rsquare:.4f}")
            r_squares.append(current_rsquare)
            all_rsquares.append(current_rsquare)
            all_rsquare_epochs.append(epoch)
            
            should_reduce = False
            should_increase = False
            
            # as long as the r_squares are increasing again, we don't increase the rank
            #if len(r_squares_since_reduction) > 3 and (current_rsquare < np.mean(r_squares_since_reduction[-min(patience, len(r_squares_since_reduction)):])):
            #    if current_rsquare < min_rsquare:
            #        should_increase = True
            #elif len(r_squares_since_reduction) > 3 and (current_rsquare > np.mean(r_squares_since_reduction[-min(3, len(r_squares_since_reduction)):])):
            #elif len(r_squares_since_reduction) > 3 and (current_rsquare > r_squares_since_reduction[-1]):
            #    patience_counter = 0
            #if len(r_squares_since_reduction) > 3 and (current_rsquare > r_squares_since_reduction[-1]):
            #    patience_counter = 0
            #elif len(r_squares_since_reduction) > 3:
            #    if current_rsquare < min_rsquare:
            #        should_increase = True
            if len(r_squares_since_reduction) >= (patience-1):
                if (current_rsquare < min_rsquare) and (not bottom_reached):
                    should_increase = True
            if len(r_squares) >= 1:
                # For R²: reduce if R² is above threshold (good performance)
                if current_rsquare > min_rsquare:
                    should_reduce = True
                    model.adaptive_layers[0].max_rank = model.adaptive_layers[0].active_dims
                    r_squares_since_reduction = []
                else:
                    r_squares_since_reduction.append(current_rsquare)
                #elif current_rsquare < min_rsquare and not bottom_reached:
                #    should_increase = True
            #if patience_counter >= (patience-1) and (not bottom_reached):
            #    if current_rsquare < min_rsquare:
            #        should_increase = True
            
            changes_made = False
            if should_reduce:
                # Apply rank reduction
                if multi_gpu:
                    changes_made = model.module.reduce_rank(reduction_ratio=0.9, threshold=rank_reduction_threshold)
                else:
                    changes_made = model.reduce_rank(reduction_ratio=0.9, threshold=rank_reduction_threshold)
                    
                if changes_made:
                    if verbose:
                        print(f"Rank reduced at epoch {epoch}. New rank: {model.module.get_total_rank() if multi_gpu else model.get_total_rank()}")
            elif should_increase:
                # Apply rank increase
                if multi_gpu:
                    changes_made = model.module.increase_rank(increase_ratio=1.1)
                else:
                    changes_made = model.increase_rank(increase_ratio=1.1)
                    
                if changes_made:
                    if verbose:
                        print(f"Rank increased at epoch {epoch}. New rank: {model.module.get_total_rank() if multi_gpu else model.get_total_rank()}")
                    break_counter = patience  # give model more time to re-learn the added dimensions
                    increase_counter += 1
                    if increase_counter >= 3:
                        bottom_reached = True
                        model.adaptive_layers[0].min_rank = model.adaptive_layers[0].active_dims
                    # set the minimum rank to the previous rank to avoid decreasing again
                    #print(f"Setting minimum rank from {model.adaptive_layers[0].min_rank} to {model.adaptive_layers[0].active_dims} to avoid further decreases.")
                    #model.adaptive_layers[0].min_rank = model.adaptive_layers[0].active_dims
                    #print(f"New minimum rank is {model.adaptive_layers[0].min_rank}")
            
            if changes_made:
                patience_counter = 0  # Reset patience counter if rank was changed
                latest_change_epoch = epoch
                # also reset the scheduler
                if lr_scheduler:
                    # Reset optimizer lr to initial value before re-creating scheduler
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr
                    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=start_lr_factor, end_factor=end_lr_factor, total_iters=1000)
            else:
                patience_counter += 1

            # Get new rank and store in history
            total_rank_after = model.module.get_total_rank() if multi_gpu else model.get_total_rank()
            rank_history['total_rank'].append(total_rank_after)
            rank_history['ranks'].append(', '.join(str(layer.active_dims) for layer in model.adaptive_layers))
            rank_history['epoch'].append(epoch)
            rank_history['rsquare'].append(current_rsquare)
            rank_history['loss'].append(train_loss)
            rank_history['val_loss'].append(val_loss)
        else:
            # If we're in a break period (break_counter > 0), count down
            if (epoch in rank_schedule) & (start_reduction) & (break_counter > 0):
                break_counter -= 1

        if (start_reduction is False) and (epoch == model.epoch + early_stopping):
            start_reduction = True  # Start rank reduction after early stopping
            break_counter = 0  # start with no breaks (only used when increasing layers)
            if verbose:
                print(f"Patience exceeded at epoch {epoch}, starting rank reduction")

            # Calculate initial R² threshold
            #current_rsquare = compute_direct_r_squared_unimodal(model, val_data, device, multi_gpu)
            current_rsquare = compute_direct_r_squared_unimodal(model, train_data[:int(0.1*len(train_data))], device, multi_gpu, metric=distortion_metric)
            initial_square = current_rsquare
            
            # Calculate threshold based on threshold_type
            if threshold_type == 'relative':
                min_rsquare = initial_square * r_square_threshold
            elif threshold_type == 'absolute':
                min_rsquare = initial_square - r_square_threshold
            else:
                raise ValueError(f"threshold_type must be 'relative' or 'absolute', got {threshold_type}")
            
            rank_history['rsquare'].append(current_rsquare)
            if verbose:
                print(f"Initial R-squared value: {initial_square:.4f}, setting {threshold_type} threshold to {min_rsquare:.4f}")
            r_squares.append(current_rsquare)

        # early stopping but conditioned on rank reduction
        #if (epoch > early_stopping) & (min(val_losses[-early_stopping:]) > min(val_losses)) & (start_reduction is True) & (patience_counter >= patience):
        if start_reduction and ((epoch - latest_change_epoch) >= early_stopping):
            epochs_to_check = epoch - latest_change_epoch - 1 # making sure we don't have the previous epoch with lowest overall loss
            if (epoch > early_stopping) & (min(val_losses[-early_stopping:]) > min(val_losses[-epochs_to_check:])) & (patience_counter >= patience): # so that we can converge even if the loss is higher after pruning
                if verbose:
                    print(f"Early stopping at epoch {epoch} with best loss {best_loss} and rank {rank_history['ranks'][-1]}")
                break
    
    # Calculate latent representations in batches
    reps_list = []
    model.eval()
    with torch.no_grad():
        for i in range(0, n_samples_train, batch_size):
            end_idx = min(i + batch_size, n_samples_train)
            x_batch = data[i:end_idx].to(device)
            
            # If using DataParallel, need to access module directly
            if multi_gpu:
                batch_reps = model.module.encode(x_batch).cpu()
            else:
                batch_reps = model.encode(x_batch).cpu()
                
            reps_list.append(batch_reps)
            
            # Free memory
            del batch_reps
            torch.cuda.empty_cache() if torch.cuda.is_available() else None
    
    # Combine latent representations from all batches
    reps = torch.cat(reps_list, dim=0)
    
    # Free memory
    del reps_list
    torch.cuda.empty_cache() if torch.cuda.is_available() else None
    gc.collect()

    # Save combined loss curves plot (pretraining + training)
    if model_name:
        combined_train_losses = train_losses_pretrain + train_losses
        combined_val_losses = val_losses_pretrain + val_losses
        pretraining_epochs = len(train_losses_pretrain)
        
        combined_plot_path = f"03_results/models/{model_name}_combined_loss_plot.png"
        plot_loss_curves(combined_train_losses, combined_val_losses, combined_plot_path, 
                        title="Combined Loss Curves (Pretraining + Training)", 
                        pretraining_epochs=pretraining_epochs)
        print(f"Saved combined loss plot to {combined_plot_path}")

    if return_detailed_history:
        return model, reps, np.mean(train_losses[-5:]), current_rsquare if current_rsquare is not None else 0.0, rank_history, [train_losses, val_losses], (all_losses, all_rsquares, all_rsquare_epochs, all_ranks, all_lrs)

    return model, reps, np.mean(train_losses[-5:]), current_rsquare if current_rsquare is not None else 0.0, rank_history, [train_losses, val_losses]

def train_overcomplete_ae3(data, n_samples_train, latent_dim, device, args, epochs=100, early_stopping=50,
                          lr=0.001, batch_size=128, ae_depth=2, ae_width=0.5, dropout=0.0, wd=1e-5,
                          initial_rank_ratio=1.0, min_rank=10,
                          rank_schedule=None, rank_reduction_frequency=10,
                          rank_reduction_threshold=0.01, warmup_epochs=0,
                          patience=10, r_square_threshold=0.9,
                          threshold_type='relative', rank_or_sparse='rank',
                          verbose=True, compute_jacobian=False, model_name=None,
                          l2_norm_adaptivelayers=None, sharedwhenall=True,
                          return_detailed_history=False, lr_scheduler=False,
                          decision_metric='R2', higher_is_better=True,
                          ):
    """
    Train a unimodal autoencoder with adaptive rank reduction
    Simplified version based on multimodal train_overcomplete_ae with fixed conditions:
    - reduce_on_best_loss='rsquare'
    - compressibility_type='direct'
    - reduction_criterion='r_squared'
    - No orthogonal loss, L1 weights, or loss balancing

    Training proceeds in two phases. First the model is trained with a plain MSE
    reconstruction loss until the best validation loss no longer falls inside the
    last `early_stopping` epochs; at that point the decision metric is measured on
    the validation split and a threshold is derived from it. From then on, at every
    epoch listed in `rank_schedule`, the metric is re-measured and the rank is
    reduced, increased or left unchanged depending on how it compares with the
    threshold.

    Parameters:
    - data: Input data tensor; rows [0, n_samples_train) are used for training and
      the remaining rows for validation
    - n_samples_train: Number of samples to use for training
    - latent_dim: Dimension of the latent space
    - device: Device the model and the batches are moved to
    - args: Namespace-like object; `args.multi_gpu` (optional) and `args.gpu_ids`
      (comma separated string, read when multi_gpu is enabled) are used
    - epochs: Maximum number of training epochs
    - early_stopping: Window (in epochs) used both to decide when rank reduction
      starts and as the minimum number of epochs since the last rank change before
      training may stop
    - lr: Learning rate
    - batch_size: Batch size for training (rounded to a multiple of the GPU count
      when multi_gpu is enabled)
    - ae_depth: Depth of the autoencoder
    - ae_width: Width multiplier for hidden layers
    - dropout: Dropout rate
    - wd: Weight decay
    - initial_rank_ratio: Initial rank ratio (1.0 = full rank)
    - min_rank: Minimum rank (lower bound)
    - rank_schedule: Custom schedule for rank reduction (epochs at which to reduce)
    - rank_reduction_frequency: How often to try reducing rank (in epochs); only
      used to build the default schedule
    - rank_reduction_threshold: Passed to `reduce_rank` as `threshold`
    - warmup_epochs: Offset of the first epoch in the default schedule
    - patience: Number of scheduled checks without a rank change required before
      training may stop; also the number of scheduled checks skipped after a rank
      increase, and (minus one) the number of non-reducing checks required before
      a rank increase is considered
    - r_square_threshold: Threshold parameter combined with the initial metric value
    - threshold_type: 'relative' (threshold = initial metric * r_square_threshold)
      or 'absolute' (threshold = initial metric - r_square_threshold)
    - rank_or_sparse: 'sparse' raises NotImplementedError; any other value trains
      the rank-adaptive model
    - verbose: Print progress messages
    - compute_jacobian, model_name, sharedwhenall, higher_is_better: Accepted for
      interface compatibility; not read by this function
    - l2_norm_adaptivelayers: If not None, AdamW is used with this weight decay for
      the adaptive layers and `wd` for all other parameters; otherwise Adam with `wd`
    - return_detailed_history: Also return the per-epoch history tuple
    - lr_scheduler: If True, use a LinearLR schedule (factor 1.0 -> 0.001 over 1000
      steps) that is restarted after every rank change
    - decision_metric: Passed as `metric` to compute_direct_r_squared_unimodal

    Returns:
    - model: The trained model (still wrapped in nn.DataParallel when multi-GPU
      mode is active)
    - reps: Latent representations of the training samples (CPU tensor)
    - Mean of the last five training losses
    - Last measured decision metric value (0.0 if it was never measured)
    - rank_history: Dict with keys 'total_rank', 'ranks', 'epoch', 'loss',
      'val_loss', 'rsquare'
    - [train_losses, val_losses]
    - Only if return_detailed_history: (all_losses, all_rsquares,
      all_rsquare_epochs, all_ranks, all_lrs)

    Raises:
    - NotImplementedError: if rank_or_sparse == 'sparse'
    - ValueError: if threshold_type is neither 'relative' nor 'absolute' (raised
      when rank reduction starts)
    """
    multi_gpu = args.multi_gpu if hasattr(args, 'multi_gpu') else False

    # Create model with adaptive rank reduction
    if rank_or_sparse == 'sparse':
        raise NotImplementedError("Sparse autoencoder training not implemented in this function.")
    model = AdaptiveRankReducedAE(
        data.shape[1], latent_dim, depth=ae_depth, width=ae_width,
        dropout=dropout, initial_rank_ratio=initial_rank_ratio,
        min_rank=min_rank
    ).to(device)

    # Handle multi-GPU setup
    if multi_gpu:
        if args.gpu_ids:
            num_gpus = len(args.gpu_ids.split(','))
        else:
            num_gpus = torch.cuda.device_count()
        # Guard against a zero GPU count (would divide by zero below)
        num_gpus = max(num_gpus, 1)

        # Ensure batch size is divisible by number of GPUs, never rounding down to 0
        if batch_size % num_gpus != 0:
            original_batch_size = batch_size
            batch_size = max((batch_size // num_gpus) * num_gpus, num_gpus)
            if verbose:
                print(f"Adjusted batch size from {original_batch_size} to {batch_size} to be divisible by {num_gpus} GPUs")

        try:
            device_ids = [int(gpu_id) for gpu_id in args.gpu_ids.split(',')]
            # If we need cuda:0 but it's not available, disable multi_gpu
            if 0 not in device_ids:
                raise RuntimeError("DataParallel requires cuda:0 which is not available.")

            # Ensure model is on cuda:0 for DataParallel
            cuda0_device = torch.device('cuda:0')
            model = model.to(cuda0_device)

            # Double-check all parameters are on cuda:0
            for param in model.parameters():
                if param.device != cuda0_device:
                    param.data = param.data.to(cuda0_device)

            # Wrap model with DataParallel - explicitly specify device_ids
            model = nn.DataParallel(model, device_ids=device_ids)
            if verbose:
                print(f"Using DataParallel across GPUs: {args.gpu_ids}")
        except Exception as e:
            # Deliberate best-effort: any setup failure falls back to single-device training
            print(f"Failed to use DataParallel: {e}")
            print(f"Falling back to single GPU mode on {device}")
            multi_gpu = False
            model = model.to(device)

    # nn.DataParallel does not forward custom attributes, so everything that is not a
    # plain forward pass (adaptive_layers, rank changes, encode) must go through the
    # wrapped module.
    core = model.module if multi_gpu else model

    def format_ranks():
        # Comma separated active dimensions of every adaptive layer
        return ', '.join(str(layer.active_dims) for layer in core.adaptive_layers)

    def measure_metric():
        # Decision metric on the validation split
        return compute_direct_r_squared_unimodal(model, val_data, device, multi_gpu, metric=decision_metric)

    # Create optimizer and loss function
    if l2_norm_adaptivelayers is not None:
        # Use AdamW with separate weight decay for adaptive layers
        adaptive_params = []
        for layer in core.adaptive_layers:
            adaptive_params.extend(list(layer.parameters()))

        # All other parameters, kept in model.parameters() order so that the
        # parameter group layout is the same on every run
        adaptive_param_ids = {id(p) for p in adaptive_params}
        other_params = [p for p in model.parameters() if id(p) not in adaptive_param_ids]

        param_groups = [
            {'params': other_params, 'weight_decay': wd},
            {'params': adaptive_params, 'weight_decay': l2_norm_adaptivelayers}
        ]
        optimizer = torch.optim.AdamW(param_groups, lr=lr)
    else:
        # Use standard Adam optimizer
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    if lr_scheduler:
        start_lr_factor = 1.0
        end_lr_factor = 0.001
        scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=start_lr_factor, end_factor=end_lr_factor, total_iters=1000)

    loss_fn = torch.nn.MSELoss()

    # Create data loaders
    train_data = data[:n_samples_train]
    data_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_data = data[n_samples_train:]
    val_data_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False)

    # Default rank reduction schedule if none provided
    if rank_schedule is None:
        # Reduce rank every rank_reduction_frequency epochs, but start after warmup period
        rank_schedule = list(range(warmup_epochs + rank_reduction_frequency,
                                   epochs,
                                   rank_reduction_frequency))

    # Rank-adaptation state
    initial_square = None
    current_rsquare = None
    start_reduction = False
    bottom_reached = False
    min_rsquare = None
    break_counter = 0
    increase_counter = 0
    latest_change_epoch = -1
    patience_counter = 0

    # Histories
    train_losses = []
    val_losses = []
    r_squares = []
    best_loss = float('inf')
    all_losses = []
    all_rsquares = []
    all_rsquare_epochs = []
    all_ranks = []
    all_lrs = []
    r_squares_since_reduction = []

    rank_history = {'total_rank': [core.get_total_rank()],
                    'ranks': [format_ranks()],
                    'epoch': [0],
                    'loss': [float('inf')],
                    'val_loss': [float('inf')],
                    'rsquare': []}

    pbar = tqdm.tqdm(range(epochs))
    for epoch in pbar:
        # Training phase
        model.train()
        train_loss = 0.0
        val_loss = 0.0
        all_lrs.append(optimizer.param_groups[0]['lr'])

        for x in data_loader:
            x = x.to(device)

            # Forward pass
            x_hat = model(x)

            # Compute loss
            loss = loss_fn(x_hat, x)

            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        train_loss /= len(data_loader)
        train_losses.append(train_loss)

        # Step the learning rate scheduler
        if lr_scheduler:
            scheduler.step()

        # Validation phase: eval mode so dropout does not perturb the validation loss
        # (train mode is restored at the top of the next epoch)
        model.eval()
        with torch.no_grad():
            for x_val in val_data_loader:
                x_val = x_val.to(device)
                x_val_hat = model(x_val)
                val_loss += loss_fn(x_val_hat, x_val).item()

        val_loss /= len(val_data_loader)
        val_losses.append(val_loss)
        all_losses.append(val_loss)

        # Update best loss (tracked on the training loss)
        if train_loss < best_loss:
            best_loss = train_loss

        # Update progress bar with both loss and current rank information
        current_rank = core.get_total_rank()
        all_ranks.append(current_rank)
        pbar.set_postfix({
            'loss': round(train_loss, 4),
            'rank': current_rank,
            'best_loss': round(best_loss, 4),
            'current_rsquare': round(current_rsquare, 3) if current_rsquare is not None else 'N/A',
            'patience': patience_counter,
        })

        # Apply rank changes at scheduled epochs once rank reduction has started
        on_schedule = start_reduction and (epoch in rank_schedule)
        if on_schedule and break_counter == 0:
            current_rsquare = measure_metric()
            r_squares.append(current_rsquare)
            all_rsquares.append(current_rsquare)
            all_rsquare_epochs.append(epoch)

            # NOTE(review): the direction of the comparison is tied to threshold_type,
            # not to `higher_is_better` (which this function never reads). In
            # 'relative' mode the rank is reduced while the metric is BELOW the
            # threshold, the opposite of 'absolute' mode - confirm this is intended
            # for the metric in use.
            if threshold_type == 'absolute':
                may_reduce = current_rsquare > min_rsquare
                wants_increase = current_rsquare < min_rsquare
            else:
                may_reduce = current_rsquare < min_rsquare
                wants_increase = current_rsquare > min_rsquare

            # The increase decision uses the number of non-reducing checks recorded
            # BEFORE the current one is added.
            should_increase = (
                len(r_squares_since_reduction) >= (patience - 1)
                and wants_increase
                and not bottom_reached
            )
            should_reduce = may_reduce

            if should_reduce:
                if threshold_type != 'absolute':
                    # Only the 'relative' branch caps max_rank at the current size
                    core.adaptive_layers[0].max_rank = core.adaptive_layers[0].active_dims
                r_squares_since_reduction = []
            else:
                r_squares_since_reduction.append(current_rsquare)

            changes_made = False
            if should_reduce:
                # Apply rank reduction
                changes_made = core.reduce_rank(reduction_ratio=0.9, threshold=rank_reduction_threshold)
                if changes_made and verbose:
                    print(f"Rank reduced at epoch {epoch}. New rank: {core.get_total_rank()}")
            elif should_increase:
                # Apply rank increase
                changes_made = core.increase_rank(increase_ratio=1.1)
                if changes_made:
                    if verbose:
                        print(f"Rank increased at epoch {epoch}. New rank: {core.get_total_rank()}")
                    break_counter = patience  # give model more time to re-learn the added dimensions
                    increase_counter += 1
                    if increase_counter >= 1:
                        # After the first successful increase, freeze the lower bound
                        bottom_reached = True
                        core.adaptive_layers[0].min_rank = core.adaptive_layers[0].active_dims

            if changes_made:
                patience_counter = 0  # Reset patience counter if rank was changed
                latest_change_epoch = epoch
                # also reset the scheduler
                if lr_scheduler:
                    # Reset optimizer lr to initial value before re-creating scheduler
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr
                    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=start_lr_factor, end_factor=end_lr_factor, total_iters=1000)
            else:
                patience_counter += 1

            # Get new rank and store in history
            rank_history['total_rank'].append(core.get_total_rank())
            rank_history['ranks'].append(format_ranks())
            rank_history['epoch'].append(epoch)
            rank_history['rsquare'].append(current_rsquare)
            rank_history['loss'].append(train_loss)
            rank_history['val_loss'].append(val_loss)
        elif on_schedule and break_counter > 0:
            # In a break period after a rank increase: count down
            break_counter -= 1

        # Start rank reduction once the best validation loss is older than the window
        if (
            not start_reduction
            and epoch > early_stopping
            and min(val_losses[-early_stopping:]) > min(val_losses)
        ):
            start_reduction = True  # Start rank reduction after early stopping
            break_counter = 0  # start with no breaks (only used when increasing layers)
            if verbose:
                print(f"Patience exceeded at epoch {epoch}, starting rank reduction")

            # Calculate initial threshold
            current_rsquare = measure_metric()
            initial_square = current_rsquare

            # Calculate threshold based on threshold_type
            if threshold_type == 'relative':
                min_rsquare = initial_square * r_square_threshold
            elif threshold_type == 'absolute':
                min_rsquare = initial_square - r_square_threshold
            else:
                raise ValueError(f"threshold_type must be 'relative' or 'absolute', got {threshold_type}")

            rank_history['rsquare'].append(current_rsquare)
            if verbose:
                print(f"Initial R-squared value: {initial_square:.4f}, setting {threshold_type} threshold to {min_rsquare:.4f}")
            r_squares.append(current_rsquare)

        # Early stopping, conditioned on rank reduction: only losses recorded after the
        # latest rank change are compared, so that we can converge even if the loss is
        # higher after pruning
        if start_reduction and ((epoch - latest_change_epoch) >= early_stopping):
            if (
                epoch > early_stopping
                and patience_counter >= patience
                and min(val_losses[-early_stopping:]) > min(val_losses[(latest_change_epoch + 1):])
            ):
                if verbose:
                    print(f"Early stopping at epoch {epoch} with best loss {best_loss} and rank {rank_history['ranks'][-1]}")
                break

    # Calculate latent representations in batches
    reps_list = []
    model.eval()
    with torch.no_grad():
        for i in range(0, n_samples_train, batch_size):
            end_idx = min(i + batch_size, n_samples_train)
            x_batch = data[i:end_idx].to(device)
            # encode() is not exposed by DataParallel, so use the underlying module
            reps_list.append(core.encode(x_batch).cpu())

    # Combine latent representations from all batches
    reps = torch.cat(reps_list, dim=0)

    # Free memory
    del reps_list
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    final_metric = current_rsquare if current_rsquare is not None else 0.0
    if return_detailed_history:
        return model, reps, np.mean(train_losses[-5:]), final_metric, rank_history, [train_losses, val_losses], (all_losses, all_rsquares, all_rsquare_epochs, all_ranks, all_lrs)

    return model, reps, np.mean(train_losses[-5:]), final_metric, rank_history, [train_losses, val_losses]