import numpy as np
import torch
import pandas as pd
import torch
from scipy import sparse
import time
import tqdm
import random
import math
import os

import argparse


def _str2bool(value):
    """Convert a command-line token to a bool.

    ``type=bool`` cannot be used for this: argparse would call ``bool("False")``,
    and every non-empty string is truthy, so ``--flag False`` would enable the flag.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value (true/false), got '{value}'")


# Command-line interface of the script.
parser = argparse.ArgumentParser(description='Compute basic ID estimation metrics')
parser.add_argument('--data', type=str, default='mm_sim', help='data name. options are mm_sim, bonemarrow, brain')
parser.add_argument('--modality', type=str, default='rna', help='data modality. options are rna, atac, protein, rna-atac, rna-protein, atac-protein, all')
parser.add_argument('--stage', type=str, default='noisy', help='stage of the data in the data generation process. only valid for mm_sim. options are noisy, raw, processed.')
parser.add_argument('--n_batches', type=int, default=3, help='number of batches (10k samples each)')
# Boolean options accept an explicit value ("--norm True" / "--norm False");
# giving the option without a value ("--norm") enables it.
parser.add_argument('--single_batch', type=_str2bool, nargs='?', const=True, default=False, help='if True, only one batch (noise in data) is used for the computation. useful for debugging')
parser.add_argument('--norm', type=_str2bool, nargs='?', const=True, default=False, help='if True, the data is normalized before computing the metrics')
#parser.add_argument('--steps', type=int, default=10, help='number of steps for the latent dimension search')
#parser.add_argument('--latent_min', type=int, default=10, help='minimum latent dimension')
#parser.add_argument('--latent_max', type=int, default=1000, help='maximum latent dimension')
parser.add_argument('--seed', type=int, default=0, help='random seed for reproducibility')
parser.add_argument('--gpu', type=int, default=0, help='gpu id to use')
parser.add_argument('--multi_gpu', action='store_true', help='Use multiple GPUs if available')
parser.add_argument('--gpu_ids', type=str, default='', help='Comma-separated list of GPU IDs to use (e.g., "0,1,2"). If empty, all available GPUs will be used.')
parser.add_argument('--epochs', type=int, default=1000, help='number of epochs for training')
parser.add_argument('--ae_depth', type=int, default=2, help='depth of the autoencoder')
parser.add_argument('--ae_width', type=float, default=0.5, help='width of the autoencoder')
parser.add_argument('--lr', type=float, default=1e-5, help='learning rate of the autoencoder')
parser.add_argument('--weight_decay', type=float, default=1e-5, help='weight decay of the autoencoder')
parser.add_argument('--dropout', type=float, default=0.1, help='dropout rate of the autoencoder')
parser.add_argument('--batch_size', type=int, default=512, help='batch size for training')
#parser.add_argument('--floor', type=bool, default=False, help='if True, the r_square is floored to 1 decimal')
#parser.add_argument('--stop_at', type=int, default=10, help='stop when the difference between min and max dims is less than this value')
#parser.add_argument('--rank_ratio', type=float, default=0.5, help='ratio for rank reduction in the weight matrices (between 0 and 1)')
#parser.add_argument('--initial_rank_ratio', type=float, default=1.0, help='Initial rank ratio for adaptive rank reduction (1.0 = full rank)')
#parser.add_argument('--min_rank_ratio', type=float, default=0.1, help='Minimum rank ratio (lower bound for rank reduction)')
#parser.add_argument('--rank_reduction_frequency', type=int, default=10, help='How often to try reducing rank (in epochs)')
#parser.add_argument('--rank_reduction_threshold', type=float, default=0.01, help='Energy threshold for rank reduction')
#parser.add_argument('--warmup_epochs', type=int, default=0, help='Number of epochs to train before starting rank reduction')
#parser.add_argument('--reduce_on_best_loss', action='store_true', help='Only reduce rank when loss is at or better than best loss')
args = parser.parse_args()

###
# load data
###

# File-name stems of the simulated ('mm_sim') data for every single modality and
# stage; the file of batch i is "<stem>_batch_<i>.npz".
MM_SIM_FILE_STEMS = {
    'rna': {
        'noisy': 'observed_transcription',
        'raw': 'potential_transcription',
        'processed': 'real_transcription',
    },
    'atac': {
        'noisy': 'peaks',
        'raw': 'open_chromatin',
        'processed': 'peaks_nonoise',
    },
    'protein': {
        'noisy': 'prot_counts',
        'raw': 'prots_translated',
        'processed': 'prots_real',
    },
}

# Single modalities behind every --modality option, in the order in which their
# feature blocks are concatenated.
MM_SIM_MODALITIES = {
    'rna': ('rna',),
    'atac': ('atac',),
    'protein': ('protein',),
    'rna-atac': ('rna', 'atac'),
    'rna-protein': ('rna', 'protein'),
    'atac-protein': ('atac', 'protein'),
    'all': ('rna', 'atac', 'protein'),
}


def _load_npz_dense(path):
    """Load a scipy sparse matrix stored as .npz and return it as a dense torch tensor."""
    return torch.tensor(sparse.load_npz(path).toarray())


def _to_dense_tensor(matrix):
    """Densify a sparse matrix (anything offering ``todense()``) into a torch tensor."""
    return torch.tensor(np.asarray(matrix.todense()))


def _row_normalize(matrix):
    """Divide every row of a 2-D tensor by its L2 norm.

    The norm is computed on a float32 copy because ``torch.norm`` rejects integer
    tensors (count data). NOTE: all-zero rows are not special-cased.
    """
    return matrix / torch.norm(matrix.float(), dim=1, keepdim=True)


if args.data == 'mm_sim':
    data_dir = './01_data/mm_sim/'

    # Validate the request before touching any file.
    if args.modality not in MM_SIM_MODALITIES:
        raise ValueError("data modality not supported")
    if args.modality == 'all':
        # the combination of all three modalities is only available for the noisy stage
        if args.stage != 'noisy':
            raise ValueError("stage not supported for all modality.")
    elif args.stage not in ('noisy', 'raw', 'processed'):
        raise ValueError("stage not supported. options are noisy, raw, processed")
    file_stems = [MM_SIM_FILE_STEMS[name][args.stage] for name in MM_SIM_MODALITIES[args.modality]]

    data = []
    for i in range(args.n_batches):
        parts = [_load_npz_dense(f"{data_dir}{stem}_batch_{i}.npz") for stem in file_stems]
        if len(parts) == 1:
            # single modality: used as stored (--norm only affects combined modalities)
            data.append(parts[0])
        else:
            if args.norm:
                # normalize each modality on its own before stacking the features
                parts = [_row_normalize(part) for part in parts]
            data.append(torch.cat(parts, dim=1))
    data = torch.cat(data, dim=0)

    if args.single_batch:
        # load the batch info
        metadata = pd.concat([pd.read_csv(data_dir + f"causal_variables_batch_{i}.csv") for i in range(args.n_batches)])
        if args.modality == 'rna':
            indices = np.where(metadata['mrna_batch_effect'] == 0.0)[0]
        elif args.modality == 'protein':
            indices = np.where(metadata['prot_batch_effect'] == 0.0)[0]
        else:
            # no batch-effect column is consulted for the other modalities: keep everything
            indices = np.arange(data.shape[0])
        print(f"using {len(indices)} samples from batch 0")
        data = data[indices, :]
elif args.data == 'bonemarrow':
    if args.modality not in ('rna', 'atac', 'rna-atac'):
        raise ValueError("data modality not supported")
    data_dir = '../../data/singlecell/'
    import anndata as ad
    data_file = ad.read_h5ad(data_dir + "human_bonemarrow.h5ad")
    # index of the first ATAC feature; the columns before it are read as RNA
    modality_switch = np.where(data_file.var['modality'] == 'ATAC')[0][0]
    if args.single_batch:
        # take the donor-site combination that gives the most samples
        data_file = data_file[(data_file.obs['covariate_Site'] == 'site4') & (data_file.obs['DonorID'] == 19593)]
    if args.modality == 'rna':
        data = _to_dense_tensor(data_file.layers['counts'][:, :modality_switch])
    elif args.modality == 'atac':
        data = _to_dense_tensor(data_file.layers['counts'][:, modality_switch:])
    elif args.norm:
        # rna-atac, each modality normalized on its own before concatenation
        data = torch.cat(
            [
                _row_normalize(_to_dense_tensor(data_file.layers['counts'][:, :modality_switch])),
                _row_normalize(_to_dense_tensor(data_file.layers['counts'][:, modality_switch:])),
            ],
            dim=1,
        )
    else:
        # rna-atac without normalization: the stored matrix already holds both modalities
        data = _to_dense_tensor(data_file.layers['counts'])
    if not args.single_batch:
        data = data[:args.n_batches*10000, :]
    data = data.cpu()
else:
    # without this branch an unknown data set only surfaced as a NameError on `data`
    raise ValueError(f"data '{args.data}' not supported. options are mm_sim, bonemarrow")
data = data.float()

n_samples = data.shape[0]
max_components = min(min(data.shape[0], data.shape[1]), 10000)
print(f"loaded data '{args.data}' with shape {data.shape}")

############################
# functions
############################

import torch.nn as nn
import torch
import torch.nn.functional as F
import tqdm
import gc
import numpy as np

# Pick the device(s) used for training.
# `multi_gpu` starts as the user's request and is switched off below whenever
# the request cannot be honoured.
multi_gpu = args.multi_gpu
if not (args.multi_gpu and torch.cuda.is_available() and torch.cuda.device_count() > 1):
    # One GPU (selected with --gpu) or, without CUDA, the CPU.
    device = torch.device('cpu' if not torch.cuda.is_available() else f'cuda:{args.gpu}')
    print(f"Using single device: {device}")
    multi_gpu = False
elif not args.gpu_ids:
    raise ValueError("When using multi-GPU, please specify --gpu_ids with comma-separated GPU IDs (e.g., '0,1,2').")
else:
    # Explicit GPU list. CUDA_VISIBLE_DEVICES is deliberately left alone because
    # setting it would renumber the devices.
    gpu_ids = [int(token) for token in args.gpu_ids.split(',')]

    # The first listed GPU acts as the primary device.
    primary_gpu = gpu_ids[0]
    device = torch.device(f'cuda:{primary_gpu}')
    print(f"Using multiple GPUs: {gpu_ids} with primary GPU {primary_gpu}")

    # This script only keeps DataParallel enabled when cuda:0 is part of the
    # requested list; otherwise it continues on the primary GPU alone.
    if 0 not in gpu_ids:
        print("Warning: DataParallel requires cuda:0 to be available. Since it's not in your GPU list, falling back to single GPU mode.")
        multi_gpu = False

# Seed torch, numpy and the stdlib RNG (in that order) for reproducibility.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(args.seed)

class AdaptiveRankReducedLinear(nn.Module):
    """
    Linear layer with adaptive rank reduction as described in the paper:
    "Rank-Reduced Neural Networks for Data Compression" (https://arxiv.org/pdf/2405.13980)

    The weight is kept in factorized form ``W = U @ V`` with
    ``U: (out_features, max_rank)`` and ``V: (max_rank, in_features)``.
    Only the first ``active_dims`` columns of ``U`` / rows of ``V`` take part in
    the forward pass, so the effective weight has rank at most ``active_dims``.

    Parameters:
    - in_features: size of each input sample
    - out_features: size of each output sample
    - initial_rank_ratio: fraction of the maximum rank that is active at start
    - min_rank: lower bound for the rank (at least 1)
    - bias: if False, the layer has no additive bias
    """
    def __init__(self, in_features, out_features, initial_rank_ratio=1.0, min_rank=1, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.min_rank = max(1, min_rank)

        # Calculate maximum possible rank
        self.max_rank = min(in_features, out_features)

        # Start with full rank or specified initial rank (never above the maximum)
        self.current_rank = min(self.max_rank, max(1, int(self.max_rank * initial_rank_ratio)))

        # Create factorized weight matrices at full dimension (values are set in init_parameters)
        self.U = nn.Parameter(torch.empty(out_features, self.max_rank))
        self.V = nn.Parameter(torch.empty(self.max_rank, in_features))

        # Keep track of active dimensions
        self.active_dims = self.current_rank

        # Keep track of singular values for adaptive rank reduction
        self.register_buffer('singular_values', torch.ones(self.max_rank))

        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)

        self.init_parameters()

    def init_parameters(self):
        """Xavier-initialize both factors and zero the bias (if present)."""
        nn.init.xavier_uniform_(self.U)
        nn.init.xavier_uniform_(self.V)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def _effective_weight(self):
        """Return the weight that forward() applies: the product of the active factor slices."""
        k = self.active_dims
        return torch.matmul(self.U[:, :k], self.V[:k, :])

    def reduce_rank(self, new_rank):
        """Reduce the effective rank by zeroing out smallest singular values.

        The effective weight is re-factorized through an SVD; the singular values
        from position ``new_rank`` onwards are dropped. ``new_rank`` is clamped to
        ``[min_rank, max_rank]``. Always returns True.
        """
        new_rank = min(self.max_rank, max(self.min_rank, int(new_rank)))
        with torch.no_grad():
            # Decompose the weight that is actually in use. Inactive columns/rows
            # of U and V must not leak into the decomposition.
            W = self._effective_weight()
            U, S, Vh = torch.linalg.svd(W, full_matrices=False)

            # Store singular values (before truncation) for monitoring
            self.singular_values.copy_(S)

            # Split each kept singular value evenly between both factors;
            # dropped components become zero columns of U / zero rows of V.
            sqrt_S = torch.sqrt(S)
            sqrt_S[new_rank:] = 0

            # Update parameters while maintaining original dimensions
            self.U.data.copy_(U * sqrt_S.unsqueeze(0))
            self.V.data.copy_(sqrt_S.unsqueeze(1) * Vh)

            # Update rank bookkeeping
            self.active_dims = new_rank
            self.current_rank = new_rank

        return True

    def get_rank_reduction_info(self):
        """Return the singular values (descending) of the effective weight matrix."""
        with torch.no_grad():
            _, S, _ = torch.linalg.svd(self._effective_weight(), full_matrices=False)
            return S

    def forward(self, x):
        # Compute W = U * V on the fly from the active dimensions only
        return F.linear(x, self._effective_weight(), self.bias)

    def extra_repr(self):
        return f'in_features={self.in_features}, out_features={self.out_features}, current_rank={self.active_dims}'

class AdaptiveRankReducedAE(torch.nn.Module):
    """
    Fully connected autoencoder whose bottleneck (last encoder layer) is an
    AdaptiveRankReducedLinear; all other layers are standard linear layers
    followed by ReLU and optional dropout.

    Parameters:
    - input_dim: number of input features
    - latent_dim: size of the latent representation
    - depth: number of linear layers in the encoder and in the decoder
    - width: hidden size as a fraction of the (possibly conv-reduced) input size
    - dropout: dropout rate after each hidden activation (0.0 disables dropout)
    - initial_rank_ratio, min_rank: forwarded to the bottleneck layer

    Inputs with more than 100000 features are first shortened by a strided 1D
    convolution and restored by a transposed convolution at the decoder output.
    """
    def __init__(self, input_dim, latent_dim, depth=2, width=0.5, dropout=0.0, 
                 initial_rank_ratio=1.0, min_rank=1):
        super().__init__()

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        # Plain list on purpose: the layers are already registered through
        # self.encoder, this only remembers which ones support rank reduction.
        self.adaptive_layers = []

        hidden_dim = int(width * input_dim)
        ff_input_dim = input_dim
        self.convolution = False

        print(f"Creating AdaptiveRankReducedAE with\n   input_dim={input_dim}, latent_dim={latent_dim}, "
              f"depth={depth}, width={width}, dropout={dropout}")
        print(f"   hidden_dim: {hidden_dim}, ff_input_dim: {ff_input_dim}")
        print(f"   initial_rank_ratio: {initial_rank_ratio}, min_rank: {min_rank}")

        # Large input dimension handling with convolutional block
        if input_dim > 100000:
            print(f"Input dimension {input_dim} is too large, using convolutional block to reduce it.")
            padding = 0
            kernel_size = 3
            stride = 2
            # Use a 1D convolutional layer to reduce the input dimension
            self.encoder.append(torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=kernel_size, stride=stride, padding=padding))
            self.encoder.append(torch.nn.Flatten())
            reduced_dim = (input_dim + 2 * padding - kernel_size) // stride + 1
            print(f"Reduced input dimension from {input_dim} to {reduced_dim} using convolutional block.")
            hidden_dim = int(width * reduced_dim)
            ff_input_dim = reduced_dim
            self.convolution = True

        for i in range(depth):
            is_first = i == 0
            is_last = i == (depth - 1)
            # The first encoder layer reads the input and the first decoder layer
            # reads the latent code; every later layer reads a hidden activation.
            # Deriving the sizes this way also wires depth=1 correctly.
            encoder_in = ff_input_dim if is_first else hidden_dim
            decoder_in = latent_dim if is_first else hidden_dim

            if is_last:
                # Bottleneck layer - THIS is the only place to use AdaptiveRankReducedLinear
                encoder_layer = AdaptiveRankReducedLinear(
                    encoder_in, latent_dim,
                    initial_rank_ratio=initial_rank_ratio,
                    min_rank=min_rank
                )
                self.encoder.append(encoder_layer)
                self.adaptive_layers.append(encoder_layer)

                # Final decoder layer - standard linear
                self.decoder.append(nn.Linear(decoder_in, ff_input_dim))
            else:
                # Hidden layers - all standard linear
                self.encoder.append(nn.Linear(encoder_in, hidden_dim))
                self.decoder.append(nn.Linear(decoder_in, hidden_dim))

                # Add activation
                self.encoder.append(nn.ReLU())
                self.decoder.append(nn.ReLU())

                # Add dropout if specified
                if dropout > 0.0:
                    self.encoder.append(nn.Dropout(dropout))
                    self.decoder.append(nn.Dropout(dropout))

        if self.convolution:
            # Add a final transposed convolution to upsample back to the original
            # input dimension. The strided convolution floors its output length,
            # so output_padding restores the feature that would otherwise be
            # missing (e.g. for even input_dim with kernel 3 / stride 2).
            output_padding = input_dim - ((reduced_dim - 1) * stride - 2 * padding + kernel_size)
            self.decoder.append(torch.nn.ConvTranspose1d(in_channels=1, out_channels=1, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding))
            self.decoder.append(torch.nn.Flatten())

    def reduce_rank(self, reduction_ratio=0.9, threshold=0.01):
        """Reduce rank of all adaptive layers based on singular value importance.

        For every adaptive layer two candidate ranks are computed and the larger
        (more conservative) one is applied if it is below the current rank:
        - the smallest rank whose leading singular values hold at least
          ``1 - threshold`` of the squared-singular-value energy
        - ``current_rank * reduction_ratio`` (rounded down)
        Neither candidate goes below the layer's ``min_rank``.

        Returns True if the rank of at least one layer was reduced.
        """
        changes_made = False

        for layer in self.adaptive_layers:
            current_rank = layer.active_dims
            if current_rank <= layer.min_rank:
                continue  # Already at minimum rank

            # Get singular values
            S = layer.get_rank_reduction_info()

            # Calculate normalized cumulative energy
            energy = S**2
            cumulative_energy = torch.cumsum(energy / energy.sum(), dim=0)

            # Components whose cumulative energy is still below the target do not
            # suffice on their own; one more component is needed to reach it.
            n_below = int(torch.sum(cumulative_energy < (1.0 - threshold)).item())
            target_rank = max(layer.min_rank, min(n_below + 1, len(S)))

            # Alternative: just reduce by fixed ratio, but not below min_rank
            ratio_rank = max(layer.min_rank, int(current_rank * reduction_ratio))

            # Take the larger of the two approaches
            new_rank = max(target_rank, ratio_rank)

            # Only reduce if new rank is smaller than current
            if new_rank < current_rank:
                layer.reduce_rank(new_rank)
                changes_made = True

        return changes_made

    def get_total_rank(self):
        """Return total rank across all adaptive layers"""
        return sum(layer.active_dims for layer in self.adaptive_layers)

    def encode(self, x):
        """Map a batch of inputs to latent codes."""
        if self.convolution:
            # Conv1d expects (batch, channels, length)
            x = x.view(x.shape[0], 1, -1)
        for layer in self.encoder:
            x = layer(x)
        return x

    def decode(self, x):
        """Map a batch of latent codes back to input space."""
        for layer in self.decoder:
            if self.convolution and isinstance(layer, nn.ConvTranspose1d):
                # ConvTranspose1d expects (batch, channels, length)
                x = x.view(x.shape[0], 1, -1)
            x = layer(x)
        return x

    def forward(self, x):
        """Encode and decode ``x`` (reconstruction)."""
        x = self.encode(x)
        x = self.decode(x)
        return x

def parallel_linear_regression(x, y, n_samples, n_samples_train, n_epochs=500, early_stopping=50):
    """
    Fit one linear layer that regresses all columns of y on x at once and
    return the R² of every column of y on the held-out rows.

    Parameters:
    - x: 2-D tensor of predictors (one row per sample)
    - y: 2-D tensor of targets (one row per sample)
    - n_samples: rows [n_samples_train:n_samples] form the validation split
    - n_samples_train: rows [:n_samples_train] form the training split
    - n_epochs: maximum number of training epochs
    - early_stopping: stop once the best validation loss is older than this many epochs

    Returns:
    - tensor with one R² value per column of y, computed on the validation split

    Uses the module-level `device`, `multi_gpu` and `args` settings.
    """
    import tqdm

    y_val_all = y[n_samples_train:n_samples]
    y_mean = y_val_all.mean(dim=0)

    # With --gpu_ids the requested devices are used as they are; the script never
    # sets CUDA_VISIBLE_DEVICES, so the ids are real device indices.
    if args.gpu_ids:
        device_ids = [int(token) for token in args.gpu_ids.split(',')]
    else:
        device_ids = None  # let DataParallel use every visible GPU

    # Adjust batch size if using multiple GPUs
    batch_size = 128
    if multi_gpu:
        num_gpus = len(device_ids) if device_ids is not None else torch.cuda.device_count()
        batch_size = batch_size * num_gpus
        print(f"Using batch size {batch_size} for linear regression with {num_gpus} GPUs")

    # loaders
    train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x[:n_samples_train], y[:n_samples_train]), batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x[n_samples_train:n_samples], y_val_all), batch_size=batch_size, shuffle=False)

    # set up a linear layer to use for parallel regression - explicitly use float32
    linear = nn.Linear(x.shape[1], y.shape[1]).to(device).float()

    # Enable multi-GPU for linear model if available
    if multi_gpu:
        try:
            # The layer already lives on `device`, i.e. on the first requested GPU,
            # which is where DataParallel expects it (device_ids[0]).
            if device_ids is not None:
                linear = nn.DataParallel(linear, device_ids=device_ids)
            else:
                linear = nn.DataParallel(linear)
        except Exception as e:
            # best effort: keep going with the unwrapped layer on `device`
            print(f"Failed to use DataParallel for linear model: {e}")
            print("Falling back to single GPU")
            linear = linear.to(device)

    optimizer = torch.optim.Adam(linear.parameters(), lr=0.0001, weight_decay=0)
    loss_fn = nn.MSELoss()

    # train the linear layer
    val_losses = []
    pbar = tqdm.tqdm(range(n_epochs))
    for epoch in pbar:
        for x_batch, y_batch in train_loader:
            x_batch, y_batch = x_batch.to(device).float(), y_batch.to(device).float()  # Force float32
            optimizer.zero_grad()
            y_pred = linear(x_batch)
            loss = loss_fn(y_pred, y_batch)
            loss.backward()
            optimizer.step()
        val_loss = 0
        with torch.no_grad():
            for x_val, y_val in val_loader:
                x_val, y_val = x_val.to(device).float(), y_val.to(device).float()  # Force float32
                val_loss += loss_fn(linear(x_val), y_val).item()
        mean_val_loss = val_loss / len(val_loader)
        val_losses.append(mean_val_loss)
        pbar.set_postfix({'val loss': round(mean_val_loss, 4)})
        # stop when none of the last `early_stopping` epochs reached the best loss
        if epoch > early_stopping and min(val_losses[-early_stopping:]) > min(val_losses):
            print("Early stopping in linear regression at epoch ", epoch)
            break

    # When using DataParallel for prediction, we need to handle it differently
    with torch.no_grad():
        if multi_gpu:
            # Process in batches to avoid OOM
            test_data = x[n_samples_train:n_samples].to(device)
            all_preds = [
                linear(test_data[start:start + batch_size]).cpu()
                for start in range(0, test_data.size(0), batch_size)
            ]
            y_pred = torch.cat(all_preds, dim=0)
        else:
            y_pred = linear(x[n_samples_train:n_samples].to(device)).cpu()

    y_pred = y_pred.detach()

    # R² per target column: 1 - residual sum of squares / total sum of squares
    r_squares = 1 - (((y_val_all - y_pred)**2).sum(0) / ((y_val_all - y_mean)**2).sum(0))

    # Clean up
    del linear, optimizer, train_loader, val_loader
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return r_squares

def train_overcomplete_ae(data, n_samples_train, latent_dim, epochs=100, early_stopping=50, 
                         lr=0.001, batch_size=128, ae_depth=2, ae_width=0.5, dropout=0.0, wd=1e-5, 
                         initial_rank_ratio=1.0, min_rank=1, 
                         rank_schedule=None, rank_reduction_frequency=10, 
                         rank_reduction_threshold=0.01, warmup_epochs=0,
                         patience=10,
                         reduce_on_best_loss='false'):
    """
    Train an autoencoder with adaptive rank reduction and score its latent space.

    The model is trained on data[:n_samples_train] with an MSE reconstruction
    loss. At the scheduled epochs the model's reduce_rank() is called. After
    training, the training samples are encoded and the latent representations
    are passed to parallel_linear_regression together with the training data.

    Parameters:
    - data: Input data tensor (samples along the first axis)
    - n_samples_train: Number of samples to use for training
    - latent_dim: Dimension of the latent space
    - epochs: Maximum number of training epochs
    - early_stopping: Number of epochs for early stopping patience
    - lr: Learning rate
    - batch_size: Batch size for training (rounded down to a multiple of the
      GPU count when DataParallel is active)
    - ae_depth: Depth of the autoencoder
    - ae_width: Width multiplier for hidden layers
    - dropout: Dropout rate
    - wd: Weight decay
    - initial_rank_ratio: Initial rank ratio (1.0 = full rank)
    - min_rank: Lower bound on the rank, forwarded to AdaptiveRankReducedAE
    - rank_schedule: Custom schedule for rank reduction (epochs at which to reduce)
    - rank_reduction_frequency: How often to try reducing rank (in epochs)
    - rank_reduction_threshold: Energy threshold for rank reduction
    - warmup_epochs: Number of epochs to train before starting rank reduction
    - patience: Number of epochs without a new best loss required before a
      reduction is allowed in 'stagnation' mode
    - reduce_on_best_loss: 'true' reduces only when the current epoch's loss is
      not worse than the best loss, 'stagnation' reduces only once the loss has
      not improved for `patience` epochs, anything else ('false') reduces at
      every scheduled epoch. Booleans are accepted and treated like their
      lower-case string form.

    Returns:
    - tuple (mean of the last five training losses,
             mean of the finite R^2 values from the linear regression,
             rank_history: initial total rank followed by the total rank
             after every reduction that changed the model)
    """
    # multi_gpu is reassigned below when DataParallel cannot be set up,
    # so it has to be declared global.
    global multi_gpu
    
    # Create model with adaptive rank reduction
    model = AdaptiveRankReducedAE(
        data.shape[1], latent_dim, depth=ae_depth, width=ae_width, 
        dropout=dropout, initial_rank_ratio=initial_rank_ratio, 
        min_rank=min_rank
    ).to(device)
    
    # Handle multi-GPU setup
    if multi_gpu:
        try:
            # An empty --gpu_ids means "use every visible GPU"; parsing the
            # empty string with int() would raise and silently disable
            # multi-GPU mode, so resolve the id list explicitly.
            if args.gpu_ids:
                device_ids = [int(gpu_id) for gpu_id in args.gpu_ids.split(',')]
            else:
                device_ids = list(range(torch.cuda.device_count()))
            
            # If we need cuda:0 but it's not available, disable multi_gpu.
            # This also covers the case of no visible GPUs (empty id list).
            if 0 not in device_ids:
                raise RuntimeError("DataParallel requires cuda:0 which is not available.")
            
            # Ensure batch size is divisible by number of GPUs. A batch size
            # smaller than the GPU count is left untouched, because rounding
            # it down would give an invalid batch size of 0.
            num_gpus = len(device_ids)
            adjusted_batch_size = (batch_size // num_gpus) * num_gpus
            if adjusted_batch_size != batch_size and adjusted_batch_size > 0:
                print(f"Adjusted batch size from {batch_size} to {adjusted_batch_size} to be divisible by {num_gpus} GPUs")
                batch_size = adjusted_batch_size
            
            # Ensure model is on cuda:0 for DataParallel
            cuda0_device = torch.device('cuda:0')
            model = model.to(cuda0_device)
            
            # Double-check all parameters are on cuda:0
            for param in model.parameters():
                if param.device != cuda0_device:
                    param.data = param.data.to(cuda0_device)
            
            # Wrap model with DataParallel - explicitly specify device_ids
            model = nn.DataParallel(model, device_ids=device_ids)
            print(f"Using DataParallel across GPUs: {','.join(str(gpu_id) for gpu_id in device_ids)}")
        except Exception as e:
            # Deliberate best-effort: any setup failure falls back to one device.
            print(f"Failed to use DataParallel: {e}")
            print(f"Falling back to single GPU mode on {device}")
            multi_gpu = False
            model = model.to(device)
    
    # The rank-reduction API lives on the underlying model, which DataParallel
    # hides behind .module.
    core_model = model.module if multi_gpu else model
    
    # Create optimizer and loss function
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    loss_fn = torch.nn.MSELoss()
    
    # Create data loader
    train_data = data[:n_samples_train]
    data_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    
    # Default rank reduction schedule if none provided
    if rank_schedule is None:
        # Reduce rank every rank_reduction_frequency epochs, but start after warmup period
        rank_schedule = range(warmup_epochs + rank_reduction_frequency, 
                              epochs, 
                              rank_reduction_frequency)
    # Set membership keeps the per-epoch schedule check O(1)
    scheduled_epochs = set(rank_schedule)
    
    # Normalise the mode so that booleans and differently-cased strings work
    reduction_mode = str(reduce_on_best_loss).lower()
    
    # Train the model
    train_losses = []
    best_loss = float('inf')
    patience_counter = 0
    rank_history = [core_model.get_total_rank()]
    
    pbar = tqdm.tqdm(range(epochs))
    for epoch in pbar:
        # Training phase
        model.train()
        train_loss = 0.0
        
        for x in data_loader:
            x = x.to(device)
            
            # Forward pass
            x_hat = model(x)
            
            # Compute loss
            loss = loss_fn(x_hat, x)
            
            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            
        train_loss /= len(data_loader)
        train_losses.append(train_loss)
        
        # Update best loss
        if train_loss < best_loss:
            best_loss = train_loss
            patience_counter = 0  # Reset patience counter
        else:
            patience_counter += 1
        
        # Apply rank reduction at scheduled epochs, respecting warmup period
        if epoch in scheduled_epochs:
            # Check if we should reduce rank based on loss condition
            if reduction_mode == 'true':
                should_reduce = not (train_loss > best_loss)
            elif reduction_mode == 'stagnation':
                should_reduce = not (patience_counter < patience)
            else:
                should_reduce = True
            
            if should_reduce:
                changes_made = core_model.reduce_rank(reduction_ratio=0.9, threshold=rank_reduction_threshold)
                
                # Only record the rank when the reduction actually changed something
                if changes_made:
                    rank_history.append(core_model.get_total_rank())
        
        # Update progress bar with both loss and current rank information
        pbar.set_postfix({
            'loss': round(train_loss, 4),
            'rank': core_model.get_total_rank(),
            'best_loss': round(best_loss, 4)
        })
        # Stop once the best loss has not been reached within the last
        # `early_stopping` epochs
        if epoch > early_stopping and min(train_losses[-early_stopping:]) > min(train_losses):
            print(f"Early stopping at epoch {epoch}")
            break
    
    # Calculate latent representations in batches
    reps_list = []
    model.eval()
    with torch.no_grad():
        for i in range(0, n_samples_train, batch_size):
            end_idx = min(i + batch_size, n_samples_train)
            x_batch = data[i:end_idx].to(device)
            
            # encode() is called on the underlying model, bypassing DataParallel
            reps_list.append(core_model.encode(x_batch).cpu())
            
            # Free memory
            del x_batch
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    # Combine latent representations from all batches
    reps = torch.cat(reps_list, dim=0)
    
    # Drop every reference to the model before the regression allocates its own
    del model, core_model, optimizer, loss_fn, data_loader
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    # Linear regression evaluation
    # NOTE(review): the positional arguments are presumably
    # (x, y, n_samples, n_samples_train) - confirm against the definition.
    r_squares = parallel_linear_regression(reps, data[:n_samples_train], n_samples_train, int(n_samples_train*0.9), n_epochs=500, early_stopping=50)
    
    # remove all nan and inf values
    r_squares = r_squares[torch.isfinite(r_squares)]
    
    # Free memory
    del reps, reps_list
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    return np.mean(train_losses[-5:]), r_squares.mean().item(), rank_history

############################
# run
############################

# Only the sweep below needs itertools (it flattens the parameter grid).
import itertools

# Replace the latent dimension search with a hyperparameter sweep over rank reduction parameters
print("Starting hyperparameter sweep for rank reduction parameters")

# Fixed latent dimension (use the max as a starting point)
#latent_dim = args.latent_max
#print(f"Using fixed latent dimension: {latent_dim}")

# Define parameter ranges for the sweep
#latent_dims = [50,100,500,1000]
latent_dims = [500]
initial_rank_ratios = [1.0]  # Starting rank ratio
#rank_reduction_frequencies = [5, 10, 20]  # How often to reduce rank (in epochs)
rank_reduction_frequencies = [5]  # How often to reduce rank (in epochs)
#rank_reduction_thresholds = [0.001, 0.01, 0.1]  # Energy threshold for pruning (version 1)
rank_reduction_thresholds = [0.005, 0.05]  # Energy threshold for pruning (version 2)
#warmup_epochs_options = [0, 100]  # Options for warmup epochs
warmup_epochs_options = [0]  # Options for warmup epochs
#reduce_on_best_loss_options = ['false', 'true', 'stagnation']  # Options for conditional reduction
reduce_on_best_loss_options = [False]  # Options for conditional reduction
data_fraction = 0.9  # Fraction of data to use for training
early_stopping = 100

# Initialize arrays to store results
all_params = []
all_settings = []  # one dict of swept parameter values per finished configuration
all_losses = []
all_r_squares = []
all_final_ranks = []
all_times = []

# Create a filename based on the latent dimension instead of min-max range
out_file = f"03_results/reports/reducedrank/rank_sweep_{args.data}-{args.modality}-{args.stage}_n{n_samples}"
out_file += f"patience{early_stopping}_seed{args.seed}_run1-thresholds_v2"
#out_file += f"patience{early_stopping}_seed{args.seed}_run3-reductionschedule"
out_file += ".csv"

# Create the report directory up front so a missing folder cannot make the
# first to_csv call fail after a full training run.
os.makedirs(os.path.dirname(out_file), exist_ok=True)

# measure the execution time
start_time = time.time()

# Build the full grid once. The right-most option list varies fastest, which
# matches the order of the equivalent nested loops (latent_dim outermost).
sweep_grid = list(itertools.product(
    latent_dims,
    initial_rank_ratios,
    rank_reduction_frequencies,
    rank_reduction_thresholds,
    warmup_epochs_options,
    reduce_on_best_loss_options,
))

# Run the hyperparameter sweep
# The total includes latent_dims so the i/N progress stays correct when more
# than one latent dimension is swept.
total_configs = len(sweep_grid)
print(f"Total configurations to test: {total_configs}")

for config_counter, (latent_dim, init_ratio, red_freq, red_thresh, warmup, reduce_on_best) in enumerate(sweep_grid, start=1):
    config_start_time = time.time()
    
    # Print progress information
    print(f"\nConfiguration {config_counter}/{total_configs}:")
    print(f"  latent_dim: {latent_dim}")
    print(f"  initial_rank_ratio: {init_ratio}")
    print(f"  rank_reduction_frequency: {red_freq}")
    print(f"  rank_reduction_threshold: {red_thresh}")
    print(f"  warmup_epochs: {warmup}")
    print(f"  reduce_on_best_loss: {reduce_on_best}")
    
    # Train the autoencoder with these parameters
    train_loss, r_square, rank_history = train_overcomplete_ae(
        data, int(data.shape[0]*data_fraction), latent_dim, 
        epochs=args.epochs, lr=args.lr, batch_size=args.batch_size, 
        ae_depth=args.ae_depth, ae_width=args.ae_width, 
        dropout=args.dropout, wd=args.weight_decay,
        early_stopping=early_stopping,
        initial_rank_ratio=init_ratio,
        #min_rank_ratio=args.min_rank_ratio,
        rank_reduction_frequency=red_freq,
        rank_reduction_threshold=red_thresh,
        warmup_epochs=warmup,
        patience=10,
        reduce_on_best_loss=reduce_on_best
    )
    
    # Get the final rank (last value in rank history)
    final_rank = rank_history[-1] if rank_history else 0
    
    # Calculate time taken for this configuration
    config_time = (time.time() - config_start_time) / 60  # in minutes
    
    # Store the results
    param_str = f"init={init_ratio},freq={red_freq},thresh={red_thresh},warmup={warmup},best_loss={reduce_on_best},latent={latent_dim}"
    all_params.append(param_str)
    # Keep the raw values so the CSV columns do not have to be recovered by
    # re-splitting param_str on ',' and '='.
    all_settings.append({
        'initial_rank_ratio': init_ratio,
        'rank_reduction_frequency': red_freq,
        'rank_reduction_threshold': red_thresh,
        'warmup_epochs': warmup,
        'reduce_on_best_loss': reduce_on_best,
        'latent_dim': latent_dim,
    })
    all_losses.append(train_loss)
    all_r_squares.append(r_square)
    all_final_ranks.append(final_rank)
    all_times.append(config_time)
    
    print(f"Results: loss={train_loss:.4f}, r_square={r_square:.4f}, final_rank={final_rank}")
    print(f"Time taken: {config_time:.2f} minutes")
    
    # Save intermediate results after each configuration, so a crash later in
    # the sweep does not lose the configurations already finished.
    intermediate_metrics = {
        'params': all_params,
        'initial_rank_ratio': [s['initial_rank_ratio'] for s in all_settings],
        'rank_reduction_frequency': [s['rank_reduction_frequency'] for s in all_settings],
        'rank_reduction_threshold': [s['rank_reduction_threshold'] for s in all_settings],
        'warmup_epochs': [s['warmup_epochs'] for s in all_settings],
        'reduce_on_best_loss': [s['reduce_on_best_loss'] for s in all_settings],
        'latent_dim': [s['latent_dim'] for s in all_settings],
        'train_loss': all_losses,
        'r_square': all_r_squares,
        'final_rank': all_final_ranks,
        'time': all_times,
    }
    pd.DataFrame(intermediate_metrics).to_csv(out_file)

end_time = time.time()
total_time = (end_time - start_time) / 60  # in minutes

print(f"\nHyperparameter sweep completed in {total_time:.2f} minutes")
print(f"Results saved to {out_file}")