import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy import sparse
import time
import tqdm
import random
import math
import os
import gc
import argparse

# --- Argument Parser ---
def _str2bool(value):
    """Convert a command-line token such as ``True``/``false``/``1``/``no`` to a bool.

    ``argparse`` with ``type=bool`` simply calls ``bool()`` on the raw string, so
    every non-empty value -- including ``"False"`` -- is parsed as ``True``. This
    converter interprets the text instead.

    Args:
        value: Raw command-line string (an actual ``bool`` is passed through).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept falsy, as it was under ``type=bool``.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(description='Compute Intrinsic Dimension using ARD-VAE')
parser.add_argument('--data', type=str, default='mm_sim', help='Data name (e.g., mm_sim, bonemarrow)')
parser.add_argument('--modality', type=str, default='rna', help='Data modality (e.g., rna, atac, protein)')
parser.add_argument('--stage', type=str, default='noisy', help='Stage for mm_sim data (noisy, raw, processed)')
parser.add_argument('--n_batches', type=int, default=3, help='Number of data batches to load (10k samples each)')
# Boolean options accept an explicit value (``--norm True``) or act as a bare flag (``--norm``).
parser.add_argument('--single_batch', type=_str2bool, nargs='?', const=True, default=False,
                    help='If True, use only one noise-free batch for debugging')
parser.add_argument('--norm', type=_str2bool, nargs='?', const=True, default=False,
                    help='If True, normalize data')
parser.add_argument('--seed', type=int, default=0, help='Random seed for reproducibility')
parser.add_argument('--gpu', type=int, default=0, help='GPU ID to use for single GPU')
parser.add_argument('--multi_gpu', action='store_true', help='Use multiple GPUs if available')
parser.add_argument('--gpu_ids', type=str, default='', help='Comma-separated list of GPU IDs for multi-GPU')
parser.add_argument('--epochs', type=int, default=5000, help='Number of training epochs')
parser.add_argument('--batch_size', type=int, default=512, help='Training batch size')
parser.add_argument('--lr', type=float, default=1e-5, help='Learning rate')
parser.add_argument('--weight_decay', type=float, default=1e-5, help='Weight decay for optimizer')
parser.add_argument('--ard_latent_dim', type=int, default=500, help='Latent dimension for ARD-VAE (should be larger than expected ID)')
parser.add_argument('--kl_weight', type=float, default=1.0, help='KL weight for the VAE loss (ARD KL scaling)')
args = parser.parse_args()

# --- Load Data ---
# File-name stems of the mm_sim matrices, keyed by single modality and then by stage.
_MM_SIM_FILE_STEMS = {
    'rna': {
        'noisy': 'observed_transcription',
        'raw': 'potential_transcription',
        'processed': 'real_transcription',
    },
    'atac': {
        'noisy': 'peaks',
        'raw': 'open_chromatin',
        'processed': 'peaks_nonoise',
    },
    'protein': {
        'noisy': 'prot_counts',
        'raw': 'prots_translated',
        'processed': 'prots_real',
    },
}
# Single modalities that are concatenated column-wise (in this order) for each --modality value.
_MM_SIM_MODALITY_PARTS = {
    'rna': ('rna',),
    'atac': ('atac',),
    'protein': ('protein',),
    'rna-atac': ('rna', 'atac'),
    'rna-protein': ('rna', 'protein'),
    'atac-protein': ('atac', 'protein'),
    'all': ('rna', 'atac', 'protein'),
}


def _load_npz_as_tensor(path):
    """Load a SciPy sparse ``.npz`` matrix from ``path`` and return it as a dense tensor."""
    return torch.tensor(sparse.load_npz(path).toarray())


def _l2_normalize_rows(mat):
    """Scale every row of ``mat`` to unit L2 norm; all-zero rows are left at zero."""
    mat = mat.float()  # the norm is computed in floating point, also for count matrices
    norms = torch.norm(mat, dim=1, keepdim=True)
    # Dividing an all-zero row by its zero norm would fill it with NaN.
    return mat / torch.where(norms > 0, norms, torch.ones_like(norms))


if args.data == 'mm_sim':
    data_dir = './01_data/mm_sim/'
    # Validate the request up front so every modality reports a bad stage the same way.
    if args.modality not in _MM_SIM_MODALITY_PARTS:
        raise ValueError("data modality not supported")
    if args.modality == 'all':
        # Only the noisy stage is supported for the three-way combination.
        if args.stage != 'noisy':
            raise ValueError("stage not supported for all modality.")
    elif args.stage not in ('noisy', 'raw', 'processed'):
        raise ValueError("stage not supported. options are noisy, raw, processed")
    if args.n_batches < 1:
        raise ValueError("n_batches must be at least 1")

    modality_parts = _MM_SIM_MODALITY_PARTS[args.modality]
    data = []
    for i in range(args.n_batches):
        blocks = [
            _load_npz_as_tensor(data_dir + f"{_MM_SIM_FILE_STEMS[part][args.stage]}_batch_{i}.npz")
            for part in modality_parts
        ]
        if len(blocks) == 1:
            # Single modalities are used as loaded; --norm only affects combined modalities.
            data.append(blocks[0])
            continue
        if args.norm:
            # Normalise each modality separately before concatenating them.
            blocks = [_l2_normalize_rows(block) for block in blocks]
        data.append(torch.cat(blocks, dim=1))
    data = torch.cat(data, dim=0)

    if args.single_batch:
        metadata = pd.concat([pd.read_csv(data_dir + f"causal_variables_batch_{i}.csv") for i in range(args.n_batches)])
        # Keep only the rows without a batch effect for the modality at hand;
        # other modalities keep every row.
        if args.modality == 'rna':
            indices = np.where(metadata['mrna_batch_effect'] == 0.0)[0]
        elif args.modality == 'protein':
            indices = np.where(metadata['prot_batch_effect'] == 0.0)[0]
        else:
            indices = np.arange(data.shape[0])
        print(f"Using {len(indices)} samples from batch 0")
        data = data[indices, :]
elif args.data == 'bonemarrow':
    data_dir = '../../data/singlecell/'
    import anndata as ad
    data_file = ad.read_h5ad(data_dir + "human_bonemarrow.h5ad")
    # Index of the first ATAC feature; the columns before it are treated as RNA.
    modality_switch = np.where(data_file.var['modality'] == 'ATAC')[0][0]
    if args.single_batch:
        data_file = data_file[(data_file.obs['covariate_Site'] == 'site4') & (data_file.obs['DonorID'] == 19593)]
    if args.modality == 'rna':
        data = torch.tensor(np.asarray(data_file.layers['counts'][:, :modality_switch].todense()))
    elif args.modality == 'atac':
        data = torch.tensor(np.asarray(data_file.layers['counts'][:, modality_switch:].todense()))
    elif args.modality == 'rna-atac':
        data = torch.tensor(np.asarray(data_file.layers['counts'].todense()))
        if args.norm:
            data = torch.cat(
                [
                    _l2_normalize_rows(data[:, :modality_switch]),
                    _l2_normalize_rows(data[:, modality_switch:]),
                ],
                dim=1,
            )
    else:
        raise ValueError("data modality not supported")
    if not args.single_batch:
        data = data[:args.n_batches*10000, :]
else:
    # Without this branch an unknown name would surface later as a NameError on `data`.
    raise ValueError(f"data '{args.data}' not supported. options are mm_sim, bonemarrow")
data = data.float()
n_samples = data.shape[0]
print(f"Loaded data '{args.data}' with shape {data.shape}")

# --- Device Configuration ---
# Multi-GPU mode needs the flag, a working CUDA runtime and more than one visible device.
multi_gpu = args.multi_gpu and torch.cuda.is_available() and torch.cuda.device_count() > 1
if not multi_gpu:
    # Single-device path: the requested GPU when CUDA is available, otherwise the CPU.
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    print(f"Using single device: {device}")
elif not args.gpu_ids:
    raise ValueError("When using multi-GPU, please specify --gpu_ids.")
else:
    gpu_ids = [int(gpu_id) for gpu_id in args.gpu_ids.split(',')]
    if 0 in gpu_ids:
        device = torch.device(f'cuda:{gpu_ids[0]}')
        print(f"Using multiple GPUs: {gpu_ids} with primary GPU {gpu_ids[0]}")
    else:
        # The training code places the wrapped model on cuda:0, so a list without
        # GPU 0 is downgraded to single-GPU mode on the first listed device.
        print("Warning: DataParallel requires cuda:0. Falling back to single GPU mode.")
        multi_gpu = False
        device = torch.device(f'cuda:{gpu_ids[0]}')

# --- Set Random Seed ---
# Seed PyTorch, NumPy and the standard-library RNG (in that order) with the same value.
for _seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    _seed_rng(args.seed)

# NOTE: The triple-quoted string below holds a disabled, earlier version of ARD_VAE
# (a single learnable log-precision `log_alpha` per latent dimension). Because it is
# only a string expression, the code inside it is not executed; the ARD_VAE class
# defined after it is the one the script uses.
'''
# --- ARD-VAE Model and Functions ---
class ARD_VAE(nn.Module):
    """
    Variational Autoencoder with an Automatic Relevance Determination (ARD) prior.
    This architecture learns which latent dimensions are necessary for reconstruction,
    effectively estimating the intrinsic dimension of the data.
    """
    def __init__(self, input_dim, latent_dim, hidden_dim=400):
        super(ARD_VAE, self).__init__()
        self.encoder_layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU()
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.decoder_layers = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )
        # log_alpha is the log precision of the ARD prior, initialized to a small value.
        self.log_alpha = nn.Parameter(torch.randn(1, latent_dim) * 0.01)

    def encode(self, x):
        h = self.encoder_layers(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        return self.decoder_layers(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decode(z)
        return x_recon, mu, logvar, self.log_alpha

    def get_id(self, threshold=0.01):
        """
        Calculates the Intrinsic Dimension by counting active latent variables.
        A variable is 'active' if its corresponding precision (alpha) is high,
        meaning 1/alpha is below the specified threshold.
        """
        with torch.no_grad():
            log_alpha = self.log_alpha.squeeze()
            # Active dims are where: 1/exp(log_alpha) < threshold
            # Which simplifies to: log_alpha > -log(threshold)
            active_dims = (log_alpha > -math.log(threshold)).sum().item()
            return active_dims
'''

class ARD_VAE(nn.Module):
    """
    Variational Autoencoder with a hierarchical prior that results in a
    Student's t-distribution on the latent variables.

    Every latent dimension has two learnable hyper-parameters, ``log_alpha`` and
    ``log_beta`` (logs of the Inverse-Gamma parameters of its prior variance).
    The ratio beta/alpha is used as the per-dimension variance estimate by
    ``get_id`` and ``get_id_nonparametric``.

    Args:
        input_dim: Number of input features.
        latent_dim: Number of latent dimensions.
        hidden_dim: Width of the two hidden layers in encoder and decoder.
    """
    def __init__(self, input_dim, latent_dim, hidden_dim=400):
        super().__init__()
        # Encoder trunk shared by the mean and log-variance heads.
        self.encoder_layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU()
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        self.decoder_layers = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )

        # Learnable Inverse-Gamma parameters, stored as logs so that alpha and beta
        # stay positive; both start close to log(1) = 0.
        self.log_alpha = nn.Parameter(torch.randn(1, latent_dim) * 0.01)
        self.log_beta = nn.Parameter(torch.randn(1, latent_dim) * 0.01)

    def encode(self, x):
        """Return the posterior mean and log-variance for the inputs ``x``."""
        h = self.encoder_layers(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)`` (reparameterisation trick)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes ``z`` to reconstructions."""
        return self.decoder_layers(z)

    def forward(self, x):
        """Encode, sample and decode.

        Returns:
            Tuple ``(x_recon, mu, logvar, log_alpha, log_beta)``; the last two are
            the model's prior parameters, returned for use in the loss.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decode(z)
        return x_recon, mu, logvar, self.log_alpha, self.log_beta

    def get_id(self, threshold=0.01):
        """
        Calculates the Intrinsic Dimension by counting active latent variables
        based on the variance estimate beta/alpha of the prior.
        A variable is 'active' if beta/alpha is below the threshold.

        NOTE(review): this counts *low*-variance dimensions, whereas
        ``get_id_nonparametric`` ranks dimensions by *high* weighted variance --
        confirm which convention is intended.

        Args:
            threshold: Positive cut-off applied to beta/alpha.

        Returns:
            Number of latent dimensions with beta/alpha < threshold.
        """
        with torch.no_grad():
            # Compare in log space: log(beta/alpha) = log_beta - log_alpha.
            log_expected_variance = self.log_beta.squeeze() - self.log_alpha.squeeze()
            return (log_expected_variance < math.log(threshold)).sum().item()

    def get_id_nonparametric(self, data_loader, device, variance_threshold=0.99, batch_size=100):
        """
        Non-parametric ID estimation using Jacobian-weighted variance.
        This method doesn't require a threshold on variance values, instead using
        percentage of explained variance (typically 99%).

        Based on the ARD-VAE paper approach:
        1. Compute estimated variance: sigma^2 = b/a for each dimension
        2. Compute the decoder Jacobian J = d x_hat / d mu at the encoded means
        3. Weight vector: w_i = mean over samples of ||J[:, i]||^2, i.e. the squared
           norm of the Jacobian column belonging to latent dimension i
        4. Weighted variance: sigma^2_w = w * sigma^2
        5. Count dimensions needed to explain variance_threshold of total variance

        Each Jacobian column is obtained with one Jacobian-vector product along the
        corresponding latent axis, evaluated for all samples at once.

        Args:
            data_loader: Iterable yielding input batches; only the first batch is used.
            device: Device to run computations on.
            variance_threshold: Fraction of variance to explain (default: 0.99).
            batch_size: Maximum number of samples used for the Jacobian computation.

        Returns:
            Tuple ``(intrinsic_dim, details)`` where ``details`` is a dict of NumPy
            arrays with the keys 'weights', 'estimated_variance',
            'weighted_variance', 'sorted_indices' and 'cumulative_percentage'.

        Raises:
            ValueError: If the first batch of ``data_loader`` contains no samples.
        """
        self.eval()

        # 1. Estimated variance per latent dimension: sigma^2 = b/a
        with torch.no_grad():
            alpha = torch.exp(self.log_alpha)
            beta = torch.exp(self.log_beta)
            estimated_variance = (beta / alpha).reshape(-1)  # Shape: (latent_dim,)

        # 2. Encode (at most batch_size samples of) the first batch to its posterior means.
        x_batch = next(iter(data_loader)).to(device)[:batch_size]
        n_samples = x_batch.size(0)
        if n_samples == 0:
            raise ValueError("data_loader yielded an empty batch")
        latent_dim = self.log_alpha.size(1)

        with torch.no_grad():
            mu, _ = self.encode(x_batch)

        # 3. Squared column norms of the decoder Jacobian, averaged over samples.
        weights = torch.zeros(latent_dim, device=device)
        for i in range(latent_dim):
            # Tangent along latent axis i for every sample: the JVP is then
            # d x_hat / d mu_i with shape (n_samples, output_dim).
            direction = torch.zeros_like(mu)
            direction[:, i] = 1.0
            _, jac_column = torch.autograd.functional.jvp(self.decode, mu, direction)
            weights[i] = jac_column.pow(2).sum() / n_samples

        # 4. Weighted variance
        weights = weights.cpu()
        estimated_variance = estimated_variance.cpu()
        weighted_variance = weights * estimated_variance

        # 5. Sort (descending) and accumulate the explained fraction.
        sorted_indices = torch.argsort(weighted_variance, descending=True)
        sorted_weighted_var = weighted_variance[sorted_indices]
        cumulative_percentage = torch.cumsum(sorted_weighted_var, dim=0) / sorted_weighted_var.sum()

        # Dimensions strictly below the threshold plus the one that crosses it,
        # clamped to the valid range [1, latent_dim].
        intrinsic_dim = (cumulative_percentage < variance_threshold).sum().item() + 1
        intrinsic_dim = max(1, min(latent_dim, intrinsic_dim))

        return intrinsic_dim, {
            'weights': weights.numpy(),
            'estimated_variance': estimated_variance.numpy(),
            'weighted_variance': weighted_variance.numpy(),
            'sorted_indices': sorted_indices.numpy(),
            'cumulative_percentage': cumulative_percentage.numpy()
        }

# NOTE: Two disabled earlier versions of the loss follow -- one commented out with `#`,
# one wrapped in a triple-quoted string. Neither is executed; the script trains with
# student_t_ard_vae_loss.
#def ard_vae_loss_function(x_recon, x, mu, logvar, log_alpha, kl_weight=1.0):
#    """
#    Calculates the VAE loss with an ARD prior (negative ELBO).
#    Loss = Reconstruction Loss + KL(q(z|x) || p(z)) + KL_ARD
#    """
#    recon_loss = F.mse_loss(x_recon, x, reduction='sum')
#    kl_z = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
#    kl_ard = 0.5 * torch.sum(log_alpha.exp() * mu.pow(2) - log_alpha)
#    return recon_loss + kl_weight * kl_z + kl_ard

'''
def ard_vae_loss_function(x_recon, x, mu, logvar, log_alpha, kl_weight=1.0):
    """
    Calculates the VAE loss with an ARD prior (negative ELBO).
    Loss = Reconstruction Loss + KL(q(z|x) || p(z|alpha))
    """
    # Reconstruction Loss (e.g., Mean Squared Error)
    recon_loss = F.mse_loss(x_recon, x, reduction='sum')

    # Correct KL Divergence for ARD Prior
    # alpha is the precision, the inverse of variance
    alpha = torch.exp(log_alpha)
    var = torch.exp(logvar)
    
    # D_kl(q(z|x) || p(z|alpha))
    kl_div = 0.5 * torch.sum(alpha * (mu.pow(2) + var) - log_alpha - logvar - 1)

    return recon_loss + kl_weight * kl_div
'''

def student_t_ard_vae_loss(x_recon, x, mu, logvar, log_alpha, log_beta, kl_weight=1.0):
    """
    Negative-ELBO style loss for the VAE with the Inverse-Gamma (Student's t) ARD prior.

    The reconstruction term is the element-wise mean squared error; the KL term is
    summed over latent dimensions, averaged over the batch and scaled by ``kl_weight``.

    Args:
        x_recon: Decoder output.
        x: Target input batch.
        mu: Posterior means.
        logvar: Posterior log-variances.
        log_alpha: Log of the Inverse-Gamma alpha parameter per latent dimension.
        log_beta: Log of the Inverse-Gamma beta parameter per latent dimension.
        kl_weight: Multiplier applied to the KL term.

    Returns:
        Scalar tensor ``recon_loss + kl_weight * kl_div``.
    """
    # Mean (not summed) reconstruction error.
    reconstruction = F.mse_loss(x_recon, x, reduction='mean')

    prior_alpha = log_alpha.exp()
    prior_beta = log_beta.exp()

    # alpha/beta is E[1/sigma^2] under the Inverse-Gamma; it weights the
    # posterior second moment mu^2 + var.
    second_moment = mu.pow(2) + torch.exp(logvar)
    precision_term = (prior_alpha / prior_beta) * second_moment

    # log(beta) - digamma(alpha) is E[log sigma^2] under the Inverse-Gamma.
    log_variance_term = prior_beta.log() - torch.digamma(prior_alpha)

    # Per-sample KL: sum over latent dimensions, then average over the batch.
    per_sample = torch.sum(precision_term + log_variance_term - logvar - 1, dim=1)
    kl_div = 0.5 * per_sample.mean()

    return reconstruction + kl_weight * kl_div

def train_ard_vae(data, n_samples_train, latent_dim, epochs, lr, batch_size, wd, kl_weight=1.0, early_stopping=100):
    """
    Trains the ARD-VAE model to estimate the intrinsic dimension.

    Uses the module-level ``device``, ``multi_gpu`` and ``args`` settings; ``multi_gpu``
    is switched off globally if the DataParallel setup fails.

    Args:
        data: 2-D tensor of samples; the first ``n_samples_train`` rows are used.
        n_samples_train: Number of leading rows of ``data`` to train on.
        latent_dim: Size of the (over-complete) latent space.
        epochs: Maximum number of training epochs.
        lr: Adam learning rate.
        batch_size: Mini-batch size.
        wd: Adam weight decay.
        kl_weight: Multiplier of the KL term in the loss.
        early_stopping: Patience in epochs; a falsy value disables early stopping.

    Returns:
        Tuple ``(final_loss, estimated_ids, nonparam_ids, nonparam_details)``:
        the mean epoch loss over the last (up to) 10 epochs, the threshold-based ID
        estimates, the Jacobian-weighted ID estimates (``-1`` where the estimate
        failed) and the detail arrays of the successful Jacobian-weighted estimates.
    """
    global multi_gpu
    model = ARD_VAE(data.shape[1], latent_dim, hidden_dim=1000).to(device)

    if multi_gpu:
        try:
            gpu_ids_list = [int(gpu_id) for gpu_id in args.gpu_ids.split(',')]
            model = nn.DataParallel(model.to('cuda:0'), device_ids=gpu_ids_list)
            print(f"Using DataParallel for ARD-VAE across GPUs: {gpu_ids_list}")
        except Exception as e:
            print(f"Failed DataParallel setup: {e}. Falling back to single GPU.")
            multi_gpu = False
            # The failure may have happened before wrapping, so only unwrap if needed.
            model = getattr(model, 'module', model).to(device)

    # The bare ARD_VAE: owns the prior parameters and the ID-estimation methods.
    model_ref = model.module if multi_gpu else model
    run_device = 'cuda:0' if multi_gpu else device

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    data_loader = torch.utils.data.DataLoader(data[:n_samples_train], batch_size=batch_size, shuffle=True)

    train_losses, best_loss = [], float('inf')
    pbar = tqdm.tqdm(range(epochs), desc="Training ARD-VAE")

    for epoch in pbar:
        model.train()
        total_loss = 0.0
        for x in data_loader:
            x = x.to(run_device)
            optimizer.zero_grad()
            x_recon, mu, logvar, _, _ = model(x)
            # Take the prior parameters from the underlying module: DataParallel
            # concatenates per-replica outputs along dim 0, which would turn the
            # (1, latent_dim) parameters into (n_gpus, latent_dim) tensors.
            loss = student_t_ard_vae_loss(
                x_recon, x, mu, logvar, model_ref.log_alpha, model_ref.log_beta, kl_weight=kl_weight
            )
            loss.backward()
            optimizer.step()
            # The loss is a batch mean, so weight it by the batch size before
            # dividing by the number of samples below.
            total_loss += loss.item() * x.size(0)

        avg_loss = total_loss / max(1, len(data_loader.dataset))
        train_losses.append(avg_loss)
        if avg_loss < best_loss:
            best_loss = avg_loss

        pbar.set_postfix({'loss': f"{avg_loss:.4f}", 'best_loss': f"{best_loss:.4f}"})

        # Stop once none of the last `early_stopping` epochs matched the best earlier loss.
        if (early_stopping and epoch > early_stopping
                and min(train_losses[-early_stopping:]) > min(train_losses[:-early_stopping])):
            print(f"Early stopping at epoch {epoch}")
            break

    final_loss = np.mean(train_losses[-10:]) if train_losses else float('inf')

    # Estimate IDs across multiple thresholds (parametric method)
    thresholds_to_try = [0.01, 0.05, 0.1]
    estimated_ids = {f'threshold_{t}': int(model_ref.get_id(threshold=t)) for t in thresholds_to_try}

    # Estimate ID using non-parametric method (Jacobian-weighted variance)
    variance_thresholds_to_try = [0.99]
    nonparam_ids = {}
    nonparam_details = {}

    print("\nComputing non-parametric ID estimates using Jacobian weighting...")
    for var_thresh in variance_thresholds_to_try:
        try:
            est_id, details = model_ref.get_id_nonparametric(
                data_loader=data_loader,
                device=run_device,
                variance_threshold=var_thresh,
                batch_size=100
            )
            nonparam_ids[f'var_threshold_{var_thresh}'] = est_id
            nonparam_details[f'var_threshold_{var_thresh}'] = details
            print(f"  Variance threshold {var_thresh}: ID = {est_id}")
        except Exception as e:
            # Best effort: a failed estimate is reported as -1 instead of aborting the run.
            print(f"  Error computing non-parametric ID for threshold {var_thresh}: {e}")
            nonparam_ids[f'var_threshold_{var_thresh}'] = -1

    # Release model memory before returning.
    del model, model_ref, optimizer, data_loader
    gc.collect()
    torch.cuda.empty_cache()

    return final_loss, estimated_ids, nonparam_ids, nonparam_details

# --- Main Execution Block ---
if __name__ == "__main__":
    # Reference point for the total script runtime reported at the end.
    start_time = time.time()
    data_fraction = 0.9
    n_samples_train = int(data.shape[0] * data_fraction)

    print("Running Intrinsic Dimension Estimation with ARD-VAE")
    out_file = f"03_results/reports/ard_vae_{args.data}-{args.modality}-{args.stage}_n{n_samples}_kl{args.kl_weight}_seed{args.seed}.csv"
    os.makedirs(os.path.dirname(out_file), exist_ok=True)

    # Time the training/estimation call separately from the rest of the script.
    config_start_time = time.time()
    final_loss, estimated_ids_all, nonparam_ids_all, nonparam_details_all = train_ard_vae(
        data=data,
        n_samples_train=n_samples_train,
        latent_dim=args.ard_latent_dim,
        epochs=args.epochs,
        lr=args.lr,
        batch_size=args.batch_size,
        wd=args.weight_decay,
        kl_weight=args.kl_weight,
        early_stopping=100,
    )
    config_time = (time.time() - config_start_time) / 60

    print("\n--- ARD-VAE Results ---")
    print(f"  Final Loss: {final_loss:.4f}")
    print(f"  Time Taken: {config_time:.2f} minutes")
    print("\n  Parametric ID estimates (threshold-based):")
    for name, estimate in estimated_ids_all.items():
        print(f"    {name}: {estimate}")
    print("\n  Non-parametric ID estimates (Jacobian-weighted variance):")
    for name, estimate in nonparam_ids_all.items():
        print(f"    {name}: {estimate}")

    # One CSV row: run configuration first, then one column per ID estimate.
    summary = {
        'data': args.data,
        'modality': args.modality,
        'stage': args.stage,
        'n_samples': n_samples,
        'latent_dim_start': args.ard_latent_dim,
        'kl_weight': args.kl_weight,
        'learning_rate': args.lr,
        'epochs': args.epochs,
        'seed': args.seed,
        'final_loss': final_loss,
        'time_min': config_time,
    }
    summary.update({f'parametric_id_{name}': estimate for name, estimate in estimated_ids_all.items()})
    summary.update({f'nonparametric_id_{name}': estimate for name, estimate in nonparam_ids_all.items()})
    results = {column: [value] for column, value in summary.items()}

    pd.DataFrame(results).to_csv(out_file, index=False)

    # Save the per-dimension arrays of the non-parametric method next to the CSV.
    # The arrays of all variance thresholds share one key space, so later
    # thresholds overwrite earlier ones.
    detail_file = out_file.replace('.csv', '_nonparam_details.npz')
    detail_arrays = {}
    for details in nonparam_details_all.values():
        detail_arrays.update(details)
    np.savez(detail_file, **detail_arrays)
    print(f"\nNon-parametric details saved to {detail_file}")

    total_time = (time.time() - start_time) / 60
    print(f"\nTotal script execution time: {total_time:.2f} minutes")
    print(f"Results saved to {out_file}")