import numpy as np
import torch
import pandas as pd
import torch
from scipy import sparse
import time
import tqdm
import random
import math
import os

import argparse


def _str2bool(value):
    """Convert a command-line token such as ``True``/``false``/``1``/``no`` to a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so ``--norm False`` would silently
    switch normalization ON.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # the empty string was the only falsy input under ``type=bool``; keep it falsy
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value (True/False), got {value!r}")


# Command-line interface of the script; the parsed values live in the module-level `args`.
parser = argparse.ArgumentParser(description='Compute basic ID estimation metrics')
# --- data selection ---
parser.add_argument('--data', type=str, default='mm_sim', help='data name. options are mm_sim, bonemarrow, brain')
parser.add_argument('--modality', type=str, default='rna', help='data modality. options are rna, atac, protein, rna-atac, rna-protein, atac-protein, all')
parser.add_argument('--stage', type=str, default='noisy', help='stage of the data in the data generation process. only valid for mm_sim. options are noisy, raw, processed.')
parser.add_argument('--n_batches', type=int, default=3, help='number of batches (10k samples each)')
parser.add_argument('--single_batch', type=_str2bool, default=False, help='if True, only one batch (noise in data) is used for the computation. useful for debugging')
parser.add_argument('--norm', type=_str2bool, default=False, help='if True, the data is normalized before computing the metrics')
# --- latent dimension search ---
parser.add_argument('--steps', type=int, default=10, help='number of steps for the latent dimension search')
parser.add_argument('--latent_min', type=int, default=10, help='minimum latent dimension')
parser.add_argument('--latent_max', type=int, default=1000, help='maximum latent dimension')
parser.add_argument('--seed', type=int, default=0, help='random seed for reproducibility')
# --- hardware ---
parser.add_argument('--gpu', type=int, default=0, help='gpu id to use')
parser.add_argument('--multi_gpu', action='store_true', help='Use multiple GPUs if available')
parser.add_argument('--gpu_ids', type=str, default='', help='Comma-separated list of GPU IDs to use (e.g., "0,1,2"). If empty, all available GPUs will be used.')
# --- autoencoder training ---
parser.add_argument('--epochs', type=int, default=1000, help='number of epochs for training')
parser.add_argument('--ae_depth', type=int, default=2, help='depth of the autoencoder')
parser.add_argument('--ae_width', type=float, default=0.5, help='width of the autoencoder')
parser.add_argument('--lr', type=float, default=1e-5, help='learning rate of the autoencoder')
parser.add_argument('--weight_decay', type=float, default=1e-5, help='weight decay of the autoencoder')
parser.add_argument('--dropout', type=float, default=0.1, help='dropout rate of the autoencoder')
parser.add_argument('--batch_size', type=int, default=512, help='batch size for training')
#parser.add_argument('--floor', type=bool, default=False, help='if True, the r_square is floored to 1 decimal')
# --- stopping criteria of the search ---
parser.add_argument('--stop_at', type=int, default=10, help='stop when the difference between min and max dims is less than this value')
parser.add_argument('--total_decrease_threshold', type=float, default=0.8, help='threshold for total decrease in loss to consider a latent dimension valid')
parser.add_argument('--step_threshold', type=float, default=0.1, help='threshold for step increase in loss to consider a latent dimension valid')
args = parser.parse_args()

###
# load data
###

# File-name stem of every simulated modality at each stage of the data generation
# process; the file on disk is "<stem>_batch_<i>.npz".
_MM_SIM_FILE_STEMS = {
    'rna': {
        'noisy': 'observed_transcription',
        'raw': 'potential_transcription',
        'processed': 'real_transcription',
    },
    'atac': {
        'noisy': 'peaks',
        'raw': 'open_chromatin',
        'processed': 'peaks_nonoise',
    },
    'protein': {
        'noisy': 'prot_counts',
        'raw': 'prots_translated',
        'processed': 'prots_real',
    },
}

# Which single modalities make up each value accepted by --modality (column order matters).
_MM_SIM_MODALITY_PARTS = {
    'rna': ('rna',),
    'atac': ('atac',),
    'protein': ('protein',),
    'rna-atac': ('rna', 'atac'),
    'rna-protein': ('rna', 'protein'),
    'atac-protein': ('atac', 'protein'),
    'all': ('rna', 'atac', 'protein'),
}


def _l2_normalize_rows(x):
    """Return ``x`` as float32 with every row scaled to unit L2 norm.

    The cast to float comes first because ``torch.norm`` rejects integer tensors.
    The norm is clamped so that all-zero rows stay zero instead of becoming NaN (0/0).
    """
    x = x.float()
    return x / torch.norm(x, dim=1, keepdim=True).clamp_min(1e-12)


def _load_mm_sim_batch(data_dir, batch_idx, modality, stage, normalize):
    """Load one simulated batch as a dense (samples x features) tensor.

    For combined modalities the per-modality matrices are concatenated along the
    feature axis; with ``normalize`` each of them is row-normalized first so that no
    modality dominates through its scale. Single-modality data is returned as stored.

    Raises:
        ValueError: for an unknown modality or stage, or for 'all' with a stage
            other than 'noisy'.
    """
    if modality not in _MM_SIM_MODALITY_PARTS:
        raise ValueError("data modality not supported")
    if modality == 'all' and stage != 'noisy':
        raise ValueError("stage not supported for all modality.")
    parts = []
    for part in _MM_SIM_MODALITY_PARTS[modality]:
        stems = _MM_SIM_FILE_STEMS[part]
        if stage not in stems:
            raise ValueError("stage not supported. options are noisy, raw, processed")
        path = data_dir + f"{stems[stage]}_batch_{batch_idx}.npz"
        parts.append(torch.tensor(sparse.load_npz(path).toarray()))
    if len(parts) == 1:
        return parts[0]
    if normalize:
        parts = [_l2_normalize_rows(part) for part in parts]
    return torch.cat(parts, dim=1)


def _dense_tensor(matrix):
    """Convert a sparse matrix (anything offering ``.todense()``) to a dense tensor."""
    return torch.tensor(np.asarray(matrix.todense()))


if args.data == 'mm_sim':
    data_dir = './01_data/mm_sim/'
    data = []
    for i in range(args.n_batches):
        data.append(_load_mm_sim_batch(data_dir, i, args.modality, args.stage, args.norm))
    data = torch.cat(data, dim=0)
    if args.single_batch:
        # load the batch info and keep only the samples without a batch effect
        metadata = pd.concat([pd.read_csv(data_dir + f"causal_variables_batch_{i}.csv") for i in range(args.n_batches)])
        if args.modality == 'rna':
            indices = np.where(metadata['mrna_batch_effect'] == 0.0)[0]
        elif args.modality == 'protein':
            indices = np.where(metadata['prot_batch_effect'] == 0.0)[0]
        else:
            # no batch-effect column is consulted for the other modalities: keep everything
            indices = np.arange(data.shape[0])
        print(f"using {len(indices)} samples from batch 0")
        data = data[indices, :]
elif args.data == 'bonemarrow':
    data_dir = '../../data/singlecell/'
    import anndata as ad
    data_file = ad.read_h5ad(data_dir + "human_bonemarrow.h5ad")
    # index of the first ATAC feature; the columns before it are used as RNA
    modality_switch = np.where(data_file.var['modality'] == 'ATAC')[0][0]
    if args.single_batch:
        # take the donor-site combination that gives the most samples
        data_file = data_file[(data_file.obs['covariate_Site'] == 'site4') & (data_file.obs['DonorID'] == 19593)]
    if args.modality == 'rna':
        data = _dense_tensor(data_file.layers['counts'][:, :modality_switch])
    elif args.modality == 'atac':
        data = _dense_tensor(data_file.layers['counts'][:, modality_switch:])
    elif args.modality == 'rna-atac':
        if args.norm:
            # normalize each modality on its own before concatenating
            data_temp_a = _l2_normalize_rows(_dense_tensor(data_file.layers['counts'][:, :modality_switch]))
            data_temp_b = _l2_normalize_rows(_dense_tensor(data_file.layers['counts'][:, modality_switch:]))
            data = torch.cat([data_temp_a, data_temp_b], dim=1)
        else:
            data = _dense_tensor(data_file.layers['counts'])
    else:
        raise ValueError("data modality not supported")
    if not args.single_batch:
        data = data[:args.n_batches*10000, :]
    data = data.cpu()
else:
    # without this branch an unknown data name only surfaced as a NameError on `data`
    raise ValueError(f"data '{args.data}' not supported. options are mm_sim, bonemarrow")
data = data.float()

n_samples = data.shape[0]
max_components = min(min(data.shape[0], data.shape[1]), 10000)
print(f"loaded data '{args.data}' with shape {data.shape}")

############################
# functions
############################

import torch.nn as nn
import torch
import tqdm
import gc

# Select the training device(s). After this section `device` is the primary
# torch.device and `multi_gpu` is True only if DataParallel is going to be used.
_cuda_ok = torch.cuda.is_available()
multi_gpu = args.multi_gpu
if not (args.multi_gpu and _cuda_ok and torch.cuda.device_count() > 1):
    # Single GPU or CPU
    device = torch.device(f'cuda:{args.gpu}' if _cuda_ok else 'cpu')
    print(f"Using single device: {device}")
    multi_gpu = False
elif not args.gpu_ids:
    raise ValueError("When using multi-GPU, please specify --gpu_ids with comma-separated GPU IDs (e.g., '0,1,2').")
else:
    # Use the explicitly requested GPUs. CUDA_VISIBLE_DEVICES is deliberately left
    # untouched, because setting it would re-index the devices.
    gpu_ids = list(map(int, args.gpu_ids.split(',')))

    # The first GPU in the list acts as the primary device
    primary_gpu = gpu_ids[0]
    device = torch.device(f'cuda:{primary_gpu}')
    print(f"Using multiple GPUs: {gpu_ids} with primary GPU {primary_gpu}")

    # This script treats cuda:0 as mandatory for DataParallel; without it in the
    # list it falls back to the primary GPU alone.
    # NOTE(review): DataParallel itself only needs the module on device_ids[0] - confirm
    # whether this restriction is still wanted.
    if 0 not in gpu_ids:
        print("Warning: DataParallel requires cuda:0 to be available. Since it's not in your GPU list, falling back to single GPU mode.")
        multi_gpu = False

# seed every random number generator used by the script
for _seed_everything in (torch.manual_seed, np.random.seed, random.seed):
    _seed_everything(args.seed)

# I will try with running an overcomplete autoencoder with varying latent dimensions
class OvercompleteAE(torch.nn.Module):
    """Fully connected autoencoder with a configurable latent dimension.

    The encoder maps ``input_dim -> hidden_dim -> ... -> latent_dim`` using ``depth``
    linear layers with ReLU (and optional dropout) in between; the decoder mirrors it.
    Inputs wider than ``CONV_INPUT_THRESHOLD`` are first shortened by a strided
    ``Conv1d`` and expanded again by a matching ``ConvTranspose1d`` at the end of the
    decoder, which keeps the first/last linear layers small.

    Args:
        input_dim: number of input features.
        latent_dim: size of the bottleneck representation.
        depth: number of linear layers in the encoder (and in the decoder).
        width: hidden layer size as a fraction of the (possibly reduced) input size.
        dropout: dropout probability after each hidden ReLU; 0 disables dropout.
    """

    # inputs with more features than this go through the convolutional reduction
    CONV_INPUT_THRESHOLD = 100000

    def __init__(self, input_dim, latent_dim, depth=2, width=0.5, dropout=0.0):
        super().__init__()

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        # at least one hidden unit, otherwise tiny inputs would produce zero-width layers
        hidden_dim = max(1, int(width * input_dim))
        ff_input_dim = input_dim
        self.convolution = False
        print(f"Creating OvercompleteAE with\n   input_dim={input_dim}, latent_dim={latent_dim}, depth={depth}, width={width}, dropout={dropout}")
        print(f"   hidden_dim: {hidden_dim}, ff_input_dim: {ff_input_dim}")

        # if the input dim is too large, use a convolutional block to reduce the input dimension in a parameter-efficient way
        if input_dim > self.CONV_INPUT_THRESHOLD:
            print(f"Input dimension {input_dim} is too large, using convolutional block to reduce it.")
            padding = 0
            kernel_size = 3
            stride = 2
            self.encoder.append(torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=kernel_size, stride=stride, padding=padding))
            self.encoder.append(torch.nn.Flatten())
            # output length of the Conv1d above
            reduced_dim = (input_dim + 2 * padding - kernel_size) // stride + 1
            print(f"Reduced input dimension from {input_dim} to {reduced_dim} using convolutional block.")
            hidden_dim = max(1, int(width * reduced_dim))
            ff_input_dim = reduced_dim
            self.convolution = True

        for i in range(depth):
            is_first = i == 0
            is_last = i == depth - 1
            # The first layer always starts from the input (encoder) or the latent code
            # (decoder) and the last layer always ends there, including for depth == 1
            # where first and last coincide.
            self.encoder.append(torch.nn.Linear(ff_input_dim if is_first else hidden_dim,
                                                latent_dim if is_last else hidden_dim))
            self.decoder.append(torch.nn.Linear(latent_dim if is_first else hidden_dim,
                                                ff_input_dim if is_last else hidden_dim))
            if not is_last:
                self.encoder.append(torch.nn.ReLU())
                self.decoder.append(torch.nn.ReLU())
                # Add dropout after ReLU activations if dropout rate > 0
                if dropout > 0.0:
                    self.encoder.append(torch.nn.Dropout(dropout))
                    self.decoder.append(torch.nn.Dropout(dropout))

        if self.convolution:
            # Upsample back to the original input dimension. The strided convolution
            # drops one position for even input sizes; output_padding restores it so
            # that the reconstruction has exactly `input_dim` features.
            restored_dim = (reduced_dim - 1) * stride - 2 * padding + kernel_size
            self.decoder.append(torch.nn.ConvTranspose1d(in_channels=1, out_channels=1, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=input_dim - restored_dim))
            self.decoder.append(torch.nn.Flatten())

    def encode(self, x):
        """Map a (batch, input_dim) tensor to its (batch, latent_dim) representation."""
        if self.convolution:
            # Conv1d expects (batch, channels, length)
            x = x.view(x.shape[0], 1, -1)
        for layer in self.encoder:
            x = layer(x)
        return x

    def decode(self, x):
        """Map a (batch, latent_dim) tensor back to (batch, input_dim)."""
        for layer in self.decoder:
            if self.convolution and isinstance(layer, nn.ConvTranspose1d):
                # the linear layers work on (batch, features); add the channel axis back
                x = x.view(x.shape[0], 1, -1)
            x = layer(x)
        return x

    def forward(self, x):
        """Reconstruct ``x`` by encoding and then decoding it."""
        return self.decode(self.encode(x))

def parallel_linear_regression(x, y, n_samples, n_samples_train, n_epochs=500, early_stopping=50):
    """Fit one linear layer from ``x`` to every column of ``y`` and score it per column.

    Rows ``[:n_samples_train]`` are used for fitting; rows
    ``[n_samples_train:n_samples]`` are used for early stopping and for the score.
    Relies on the module-level ``device``, ``args`` and ``multi_gpu``.

    Args:
        x: 2-D tensor of predictors (one row per sample).
        y: 2-D tensor of regression targets with the same number of rows.
        n_samples: end (exclusive) of the validation rows.
        n_samples_train: number of leading rows used for fitting.
        n_epochs: maximum number of passes over the training rows.
        early_stopping: stop once the best validation loss is older than this many epochs.

    Returns:
        1-D tensor with one R² value per column of ``y``, computed on the validation
        rows. Columns without variance on those rows give non-finite values.
    """
    # only read here; declared global to mirror train_overcomplete_ae
    global multi_gpu

    import tqdm

    y_mean = y[n_samples_train:n_samples].mean(dim=0)

    # Scale the batch size with the number of GPUs so each replica sees 128 samples
    batch_size = 128
    device_ids = None  # None lets DataParallel use every visible GPU
    if multi_gpu:
        if args.gpu_ids:
            device_ids = [int(gpu) for gpu in args.gpu_ids.split(',')]
            num_gpus = len(device_ids)
        else:
            num_gpus = torch.cuda.device_count()
        batch_size = batch_size * num_gpus
        print(f"Using batch size {batch_size} for linear regression with {num_gpus} GPUs")

    # loaders
    train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x[:n_samples_train], y[:n_samples_train]), batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x[n_samples_train:n_samples], y[n_samples_train:n_samples]), batch_size=batch_size, shuffle=False)

    # a single linear layer regresses all targets at once - explicitly use float32
    linear = nn.Linear(x.shape[1], y.shape[1]).to(device).float()

    if multi_gpu:
        try:
            # CUDA_VISIBLE_DEVICES is never set by this script, so the requested ids
            # are the real device indices and must be passed as they are. The layer
            # already lives on `device`, i.e. the first requested GPU, which is where
            # DataParallel expects the parameters (device_ids[0]).
            linear = nn.DataParallel(linear, device_ids=device_ids)
        except Exception as e:
            # best effort: keep training on the primary device
            print(f"Failed to use DataParallel for linear model: {e}")
            print(f"Falling back to single GPU")
            linear = linear.to(device)

    optimizer = torch.optim.Adam(linear.parameters(), lr=0.0001, weight_decay=0)
    loss_fn = nn.MSELoss()

    # train the linear layer
    val_losses = []
    pbar = tqdm.tqdm(range(n_epochs))
    for epoch in pbar:
        for x_batch, y_batch in train_loader:
            x_batch, y_batch = x_batch.to(device).float(), y_batch.to(device).float()  # Force float32
            optimizer.zero_grad()
            loss = loss_fn(linear(x_batch), y_batch)
            loss.backward()
            optimizer.step()

        val_loss = 0
        with torch.no_grad():
            for x_val, y_val in val_loader:
                x_val, y_val = x_val.to(device).float(), y_val.to(device).float()  # Force float32
                val_loss += loss_fn(linear(x_val), y_val).item()
        val_losses.append(val_loss / len(val_loader))
        pbar.set_postfix({'val loss': round(val_loss / len(val_loader), 4)})
        # stop when none of the last `early_stopping` epochs reached the overall best loss
        if epoch > early_stopping and min(val_losses[-early_stopping:]) > min(val_losses):
            print("Early stopping in linear regression at epoch ", epoch)
            break

    # predict the validation rows
    with torch.no_grad():
        if multi_gpu:
            # Process in batches to avoid OOM
            test_data = x[n_samples_train:n_samples].to(device)
            all_preds = [
                linear(test_data[start:start + batch_size]).cpu()
                for start in range(0, test_data.size(0), batch_size)
            ]
            y_pred = torch.cat(all_preds, dim=0)
        else:
            y_pred = linear(x[n_samples_train:n_samples].to(device)).cpu()

    y_pred = y_pred.detach()

    # R² per target column: 1 - residual sum of squares / total sum of squares
    r_squares = 1 - (((y[n_samples_train:n_samples] - y_pred)**2).sum(0) / ((y[n_samples_train:n_samples] - y_mean)**2).sum(0))

    # Clean up
    del linear, optimizer, train_loader, val_loader
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return r_squares

def train_overcomplete_ae(data, n_samples_train, latent_dim, epochs=100, lr=0.001, batch_size=128, ae_depth=2, ae_width=0.5, dropout=0.0, wd=1e-5):
    """Train an OvercompleteAE with ``latent_dim`` units and score its representation.

    Only the first ``n_samples_train`` rows of ``data`` are used. The last 10% of
    them are held out as validation set (early stopping, reported loss); the model
    is fitted on the remaining rows. Afterwards the latent codes of all
    ``n_samples_train`` rows are regressed linearly onto the inputs.
    Relies on the module-level ``device``, ``args`` and ``multi_gpu``; ``multi_gpu``
    is switched off globally when DataParallel cannot be set up.

    Args:
        data: 2-D float tensor (samples x features).
        n_samples_train: number of leading rows of ``data`` to work with.
        latent_dim: size of the autoencoder bottleneck.
        epochs: maximum number of training epochs.
        lr: Adam learning rate.
        batch_size: mini-batch size (lowered to a multiple of the GPU count if needed).
        ae_depth, ae_width, dropout: architecture options passed to OvercompleteAE.
        wd: Adam weight decay.

    Returns:
        Tuple ``(val_loss, r_square)``: the mean validation MSE of the last (up to)
        five epochs and the mean of the finite per-feature R² values of the linear
        regression from latent codes to inputs.
    """
    # Declared global because the fallback below reassigns it
    global multi_gpu

    model = OvercompleteAE(data.shape[1], latent_dim, depth=ae_depth, width=ae_width, dropout=dropout).to(device)

    # Enable multi-GPU training if available
    if multi_gpu:
        if args.gpu_ids:
            num_gpus = len(args.gpu_ids.split(','))
        else:
            num_gpus = torch.cuda.device_count()

        # Ensure batch size is divisible by number of GPUs (and never drops to zero)
        if batch_size % num_gpus != 0:
            original_batch_size = batch_size
            batch_size = max(num_gpus, (batch_size // num_gpus) * num_gpus)
            print(f"Adjusted batch size from {original_batch_size} to {batch_size} to be divisible by {num_gpus} GPUs")

        try:
            device_ids = [int(gpu) for gpu in args.gpu_ids.split(',')]
            # same policy as the module-level device setup: no multi-GPU without cuda:0
            if 0 not in device_ids:
                raise RuntimeError("DataParallel requires cuda:0 which is not available.")

            # DataParallel expects the parameters on device_ids[0], which is not
            # necessarily cuda:0 (e.g. --gpu_ids 1,0); forcing cuda:0 made forward() fail.
            model = model.to(torch.device(f'cuda:{device_ids[0]}'))
            model = nn.DataParallel(model, device_ids=device_ids)
            print(f"Using DataParallel across GPUs: {args.gpu_ids}")
        except Exception as e:
            # best effort: continue on the primary device
            print(f"Failed to use DataParallel: {e}")
            print(f"Falling back to single GPU mode on {device}")
            multi_gpu = False
            model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    loss_fn = torch.nn.MSELoss()

    # Hold out the last 10% of the training rows for validation. They must not be
    # part of the training loader, otherwise early stopping and the reported
    # "validation" loss are measured on samples the model was fitted to.
    val_size = int(n_samples_train * 0.1)
    n_fit = n_samples_train - val_size
    train_data = data[:n_fit]
    val_data = data[n_fit:n_samples_train]
    data_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False)

    # Train the model
    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    patience = 50  # Early stopping patience
    patience_counter = 0

    pbar = tqdm.tqdm(range(epochs))
    for epoch in pbar:
        # Training phase
        model.train()
        train_loss = 0.0
        for x in data_loader:
            x = x.to(device)
            loss = loss_fn(model(x), x)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss /= len(data_loader)
        train_losses.append(train_loss)

        # Validation phase
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for x_val in val_loader:
                x_val = x_val.to(device)
                val_loss += loss_fn(model(x_val), x_val).item()
        val_loss /= len(val_loader)
        val_losses.append(val_loss)

        pbar.set_postfix({'train_loss': round(train_loss, 4), 'val_loss': round(val_loss, 4)})

        # Early stopping: count epochs since the last improvement of the validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch}")
            break

    # Latent representations of all rows, computed in batches to bound GPU memory
    encoder = model.module if multi_gpu else model  # DataParallel hides encode() behind .module
    reps_list = []
    model.eval()
    with torch.no_grad():
        for start in range(0, n_samples_train, batch_size):
            x_batch = data[start:start + batch_size].to(device)
            reps_list.append(encoder.encode(x_batch).cpu())
            del x_batch
    reps = torch.cat(reps_list, dim=0)

    # free the model before the regression allocates its own
    del model, encoder, optimizer, loss_fn, data_loader
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    # Linear regression evaluation (90/10 split of the rows)
    r_squares = parallel_linear_regression(reps, data[:n_samples_train], n_samples_train, int(n_samples_train*0.9), n_epochs=500, early_stopping=50)

    # remove all nan and inf values
    r_squares = r_squares[torch.isfinite(r_squares)]

    # Free memory
    del reps, reps_list
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    return np.mean(val_losses[-5:]), r_squares.mean().item()  # Return validation loss instead of train loss

############################
# run
############################

# We search for the intrinsic latent dimension by bisection: the autoencoder is
# trained at the upper and lower limit first and then repeatedly in the middle of
# the current bracket; the validation loss decides which half is kept.
steps = args.steps
data_fraction = 0.9
batch_size = args.batch_size
latent_min = args.latent_min
# check if ae_width * input_dim is too big (greater than 10000)
if args.ae_width * data.shape[1] > 10000:
    # choose an ae_width that is 10000 / input_dim
    args.ae_width = round(10000 / data.shape[1], 2)
    print(f"Setting ae_width to {args.ae_width} as it was too big for the input dimension {data.shape[1]}")

# The upper limit comes straight from --latent_max. (An earlier variant raised it to
# the number of PCA components explaining 90% of the variance for noisy data; that
# code path is disabled.)
latent_max = args.latent_max

latent_dims = [latent_max, latent_min]
losses = []
r_squares = []
print("Starting with limits: ", latent_min, latent_max)

total_decrease_threshold = args.total_decrease_threshold
step_threshold = args.step_threshold
out_file = f"03_results/reports/cliff/id_rcliff_{args.data}-{args.modality}-{args.stage}_n{n_samples}"
if args.single_batch:
    out_file += "_singlebatch"
if args.norm:
    out_file += "_norm"
#if args.floor:
#    out_file += "_floor"
out_file += f"_latent{args.latent_min}-{latent_max}_steps{args.steps}_seed{args.seed}"
out_file += f"_totalthreshold{total_decrease_threshold}"
out_file += ".csv"

# Create the report directory up front: failing in to_csv() only after all the
# training has finished would throw the results away.
_out_dir = os.path.dirname(out_file)
if _out_dir:
    os.makedirs(_out_dir, exist_ok=True)


def _fit_latent_dim(latent_dim):
    """Train one autoencoder with `latent_dim` units; return (validation loss, mean R²)."""
    return train_overcomplete_ae(
        data, int(data.shape[0]*data_fraction), latent_dim,
        epochs=args.epochs, lr=args.lr, batch_size=batch_size,
        ae_depth=args.ae_depth, ae_width=args.ae_width,
        dropout=args.dropout, wd=args.weight_decay
    )


# measure the execution time
start_time = time.time()

# upper limit: serves as the reference loss of the search
train_losses_temp, r_squares_temp = _fit_latent_dim(latent_max)
losses.append(train_losses_temp)
r_squares.append(r_squares_temp)
train_losses_max = train_losses_temp
r_squares_max = r_squares_temp
above_threshold = [1]
# round the max r_square to 1 decimal (but dont round up)
#if args.floor:
#    r_squares_max = math.floor(r_squares_temp*10)/10
#else:
#    r_squares_max = round(r_squares_temp,2)
print("Latent dim: {}, train loss: {}, r_square: {}".format(latent_max, train_losses_temp, r_squares_temp))
# a latent dimension counts as valid while its loss stays below this multiple of the reference loss
total_loss_threshold = train_losses_max * (1 + 1 - total_decrease_threshold)
print(f"Total loss threshold for valid latent dimension: {total_loss_threshold}")

# lower limit
train_losses_temp, r_squares_temp = _fit_latent_dim(latent_min)
losses.append(train_losses_temp)
r_squares.append(r_squares_temp)
train_losses_min = train_losses_temp
r_squares_min = r_squares_temp
above_threshold.append(int(train_losses_temp <= total_loss_threshold))
print("Latent dim: {}, train loss: {}, r_square: {}".format(latent_min, train_losses_temp, r_squares_temp))

print("Starting search for optimal latent dimension")

while steps > 0:
    next_latent = int((latent_min + latent_max) / 2)
    print("Next latent dimension: ", next_latent)
    val_loss_temp, r_squares_temp = _fit_latent_dim(next_latent)
    latent_dims.append(next_latent)
    losses.append(val_loss_temp)
    r_squares.append(r_squares_temp)
    above_threshold.append(int(val_loss_temp <= total_loss_threshold))
    print("Latent dim: {}, val loss: {}, r_square: {}".format(next_latent, val_loss_temp, r_squares_temp))

    # The midpoint becomes the new upper limit only if its loss is below the total
    # threshold AND within `step_threshold` of the loss at the current upper limit;
    # in every other case it is too small and becomes the new lower limit.
    if val_loss_temp <= total_loss_threshold and val_loss_temp <= (1+step_threshold)*train_losses_max:
        latent_max = next_latent
        train_losses_max = val_loss_temp
        r_squares_max = r_squares_temp
    else:
        latent_min = next_latent
        train_losses_min = val_loss_temp
        r_squares_min = r_squares_temp

    steps -= 1
    if (latent_max - latent_min) < args.stop_at:
        print("Stopping search as the difference between min and max dims is less than ", args.stop_at)
        break
end_time = time.time()

# report the middle of the final bracket, plus/minus half the stop_at value
final_latent = int((latent_min + latent_max) / 2)
print(f"Final latent dimension: {final_latent} +- {args.stop_at/2}")

out_metrics = {
    'latent_dim': latent_dims,
    'val_loss': losses,
    'r_square': r_squares,
    'above_threshold': above_threshold,
    # total wall-clock time in minutes, repeated on every row
    'time': [(end_time - start_time)/60]* len(latent_dims),
}

# save metrics
df_out_metrics = pd.DataFrame(out_metrics)
df_out_metrics.to_csv(out_file)