import numpy as np
import torch
import pandas as pd
import torch
from scipy import sparse
import time
import tqdm
import random
from scipy.stats import wasserstein_distance as scipy_wasserstein
from scipy.optimize import minimize_scalar, brentq
from scipy.integrate import quad

import argparse


def _str2bool(value):
    """Convert a command-line token into a real boolean.

    argparse's ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value (including ``"False"``) is parsed as True. This parser
    accepts the usual spellings instead and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # the empty string stays False, as it was under bool('')
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line interface of the script. The two boolean options accept an
# explicit value (`--norm True`, `--norm False`) or can be given as a bare
# flag (`--norm`), which means True.
parser = argparse.ArgumentParser(description='Compute basic ID estimation metrics')
parser.add_argument('--data', type=str, default='mm_sim', help='data name. options are mm_sim, bonemarrow, brain')
parser.add_argument('--modality', type=str, default='rna', help='data modality. options are rna, atac, protein, rna-atac, rna-protein, atac-protein, all')
parser.add_argument('--n_batches', type=int, default=3, help='number of batches (10k samples each)')
parser.add_argument('--single_batch', type=_str2bool, nargs='?', const=True, default=False, help='if True, only one batch (noise in data) is used for the computation. useful for debugging')
parser.add_argument('--norm', type=_str2bool, nargs='?', const=True, default=False, help='if True, the data is normalized before computing the metrics')
parser.add_argument('--wasserstein_samples', type=int, default=1000, help='number of data points to sample for Wasserstein distance calculation')
parser.add_argument('--wasserstein_alpha', type=float, default=0.9, help='threshold for cumulative variance in Wasserstein method')
parser.add_argument('--seed', type=int, default=0, help='random seed')
args = parser.parse_args()

###
# load data
###

def _l2_normalize_rows(matrix):
    """Divide every row of ``matrix`` by its Euclidean norm."""
    # The norm is taken on a float32 copy so that integer count tensors are
    # accepted as well. Rows that are entirely zero become NaN (0 / 0).
    return matrix / torch.norm(matrix.float(), dim=1, keepdim=True)


def _dense_tensor(sparse_block):
    """Densify a sparse matrix slice and wrap it in a torch tensor."""
    return torch.tensor(np.asarray(sparse_block.todense()))


# file name prefix of each simulated modality; the batch index and ".npz" are appended
_MM_SIM_FILE_PREFIX = {
    'rna': 'observed_transcription_batch_',
    'atac': 'peaks_batch_',
    'protein': 'prot_counts_batch_',
}
# modalities that are concatenated feature-wise (dim=1) for each --modality value
_MM_SIM_MODALITY_PARTS = {
    'rna': ('rna',),
    'atac': ('atac',),
    'protein': ('protein',),
    'rna-atac': ('rna', 'atac'),
    'rna-protein': ('rna', 'protein'),
    'atac-protein': ('atac', 'protein'),
    'all': ('rna', 'atac', 'protein'),
}

if args.data == 'mm_sim':
    data_dir = './01_data/mm_sim/'
    if args.modality not in _MM_SIM_MODALITY_PARTS:
        raise ValueError("data modality not supported")
    modality_parts = _MM_SIM_MODALITY_PARTS[args.modality]
    data = []
    for i in range(args.n_batches):
        batch_parts = [
            torch.tensor(sparse.load_npz(data_dir + f"{_MM_SIM_FILE_PREFIX[part]}{i}.npz").toarray())
            for part in modality_parts
        ]
        # Normalisation only applies to multi-modal inputs: each modality is
        # row-normalised on its own before concatenation. This now includes
        # 'rna-atac', which previously ignored --norm while the output file
        # was still tagged "_norm".
        if args.norm and len(batch_parts) > 1:
            batch_parts = [_l2_normalize_rows(part) for part in batch_parts]
        data.append(batch_parts[0] if len(batch_parts) == 1 else torch.cat(batch_parts, dim=1))
    # stack the batches sample-wise
    data = torch.cat(data, dim=0)
    if args.single_batch:
        # load the batch info
        metadata = pd.concat([pd.read_csv(data_dir + f"causal_variables_batch_{i}.csv") for i in range(args.n_batches)])
        # keep only the samples whose batch-effect column is exactly 0.0;
        # other modalities have no such filter here and keep every sample
        if args.modality == 'rna':
            indices = np.where(metadata['mrna_batch_effect'] == 0.0)[0]
        elif args.modality == 'protein':
            indices = np.where(metadata['prot_batch_effect'] == 0.0)[0]
        else:
            indices = np.arange(data.shape[0])
        print(f"using {len(indices)} samples from batch 0")
        data = data[indices, :]
elif args.data == 'bonemarrow':
    if args.modality not in ('rna', 'atac', 'rna-atac'):
        raise ValueError("data modality not supported")
    data_dir = '../../data/singlecell/'
    import anndata as ad
    data_file = ad.read_h5ad(data_dir + "human_bonemarrow.h5ad")
    # index of the first feature labelled 'ATAC'; columns before it are used
    # as RNA and columns from it onwards as ATAC
    modality_switch = np.where(data_file.var['modality'] == 'ATAC')[0][0]
    if args.single_batch:
        # take the donor-site combination that gives the most samples
        # (per the original author's note)
        data_file = data_file[(data_file.obs['covariate_Site'] == 'site4') & (data_file.obs['DonorID'] == 19593)]
    if args.modality == 'rna':
        data = _dense_tensor(data_file.layers['counts'][:, :modality_switch])
    elif args.modality == 'atac':
        data = _dense_tensor(data_file.layers['counts'][:, modality_switch:])
    elif args.norm:
        # rna-atac, each modality row-normalised separately before concatenation
        data = torch.cat(
            [
                _l2_normalize_rows(_dense_tensor(data_file.layers['counts'][:, :modality_switch])),
                _l2_normalize_rows(_dense_tensor(data_file.layers['counts'][:, modality_switch:])),
            ],
            dim=1,
        )
    else:
        data = _dense_tensor(data_file.layers['counts'])
    if not args.single_batch:
        # cap the number of cells at n_batches * 10k
        data = data[:args.n_batches*10000, :]
    data = data.cpu()
else:
    # Without this branch an unknown dataset only failed later with a
    # NameError on `data`.
    raise ValueError(f"dataset '{args.data}' not supported (available loaders: mm_sim, bonemarrow)")
n_samples = data.shape[0]
# upper bound on the number of components: limited by both matrix dimensions and by 10000
max_components = min(data.shape[0], data.shape[1], 10000)
print(f"loaded data '{args.data}' with shape {data.shape}")

###
# check if it already has been computed
###

# Result path: dataset, modality, sample count and seed are encoded in the
# name, followed by optional tags for the single-batch and normalised runs.
out_file = (
    f"03_results/reports/id_baseline_metrics_{args.data}_{args.modality}_n{n_samples}_seed{args.seed}_v3"
    + ("_singlebatch" if args.single_batch else "")
    + ("_norm" if args.norm else "")
    + ".csv"
)

import os
# If a report already exists, remember which methods it contains so they can be skipped.
already_computed = os.path.exists(out_file)
methods_done = []
if already_computed:
    df_out_metrics = pd.read_csv(out_file, index_col=0)
    methods_done = df_out_metrics['method'].unique()
    print(f"metrics already computed: {methods_done}. checking if there are missing ones...")

###
# set random states
###
# Seed every random number generator in use with the same value, in this
# order: NumPy, PyTorch (CPU), PyTorch (all CUDA devices), stdlib `random`.
for _seed_rng in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all, random.seed):
    _seed_rng(args.seed)

###
# compute metrics
###
#N = 2(alpha + 1)n samples
#N/n = 2(alpha + 1)
# N/(2n) = alpha + 1
# alpha = N/(2n) - 1

def wasserstein_distance(data, n_samples=1000, alpha=0.9, max_dim=10000):
    """
    Estimate intrinsic dimensionality by matching the pairwise-distance CDF.

    The empirical CDF of the mean-normalised pairwise Euclidean distances is
    compared with a model CDF whose density is proportional to
    r^(d-1) * exp(-r^2/2), for d = 1, 2, ... The dimension with the smallest
    distance between the two curves is returned.

    NOTE(review): as in the original code, ``scipy.stats.wasserstein_distance``
    is applied to the two arrays of CDF *values* (treated as 1-D samples), not
    to the distance samples themselves - confirm this is the intended score.

    Parameters:
    - data: Input data matrix (n_points x n_features); anything
      ``np.asarray`` can convert after row indexing
    - n_samples: Number of data points to sample (without replacement, using
      the global NumPy RNG); all points are used if there are not more
    - alpha: Early-stopping level; the search stops once the best distance
      found so far is below ``1 - alpha``
    - max_dim: Maximum dimension to consider (also capped by n_features)

    Returns:
    - Estimated intrinsic dimension, or 0 when no estimate is possible
      (fewer than two points, or all pairwise distances are zero)
    """
    # pdist returns the condensed upper-triangular distances directly, so the
    # full n x n matrix is never materialised
    from scipy.spatial.distance import pdist

    # Sample data if needed
    if n_samples < data.shape[0]:
        indices = np.random.choice(data.shape[0], n_samples, replace=False)
        sample_data = data[indices]
    else:
        sample_data = data
    sample_data = np.asarray(sample_data, dtype=float)
    if sample_data.shape[0] < 2:
        return 0

    # Pairwise distances, normalised by their mean
    distances = pdist(sample_data, metric='euclidean')
    mean_distance = np.mean(distances)
    if not np.isfinite(mean_distance) or mean_distance == 0:
        # all points coincide (or the data are not finite): nothing to fit
        return 0
    distances = np.sort(distances / mean_distance)
    empirical_cdf = np.arange(1, len(distances) + 1) / len(distances)
    smallest, largest = distances[0], distances[-1]

    # Everything below is independent of the candidate dimension and is
    # therefore computed once instead of once per loop iteration.
    grid_size = 1000
    empirical_points = np.linspace(smallest, largest, grid_size)
    empirical_cdf_interp = np.interp(empirical_points, distances, empirical_cdf, left=0.0, right=1.0)
    r_values = np.linspace(0, largest, grid_size)
    half_r_squared = r_values ** 2 / 2
    with np.errstate(divide='ignore'):
        log_r = np.log(r_values)  # -inf at r = 0, which maps back to a density of 0

    # Test different dimensions and keep the best match
    best_dim = 0
    min_dist = float('inf')

    for d in range(1, min(max_dim, data.shape[1]) + 1):
        # Model density r^(d-1) * exp(-r^2/2), evaluated in log space:
        # the direct power overflows to inf (and then NaN) for large d.
        if d == 1:
            # r^0 == 1 everywhere, including r = 0 (avoids 0 * -inf)
            log_pdf = -half_r_squared
        else:
            log_pdf = (d - 1) * log_r - half_r_squared
        theoretical_pdf = np.exp(log_pdf - np.max(log_pdf))
        theoretical_cdf = np.cumsum(theoretical_pdf / np.sum(theoretical_pdf))

        w_dist = scipy_wasserstein(empirical_cdf_interp, theoretical_cdf)

        if w_dist < min_dist:
            min_dist = w_dist
            best_dim = d

        # Early stopping if we've found a good dimension
        if min_dist < (1 - alpha):
            break

    return best_dim

def marchenko_pastur_pdf(x, gamma, sigma_sq):
    """
    Evaluate the continuous Marchenko-Pastur density.

    Args:
        x (float or np.ndarray): Point(s) at which the density is evaluated.
        gamma (float): The aspect ratio p/n.
        sigma_sq (float): The variance of the noise.

    Returns:
        np.ndarray: Density values with the shape of ``x`` (a 0-d array for a
        scalar input); zero outside the support [lambda_minus, lambda_plus].
    """
    root_gamma = np.sqrt(gamma)
    edge_hi = sigma_sq * (1 + root_gamma)**2
    edge_lo = sigma_sq * (1 - root_gamma)**2

    points = np.asanyarray(x)
    inside = (points >= edge_lo) & (points <= edge_hi)

    # x appears in the denominator; an exact zero is replaced by a tiny value
    # so that the division stays defined
    denom_points = np.where(points == 0, 1e-9, points)

    density = np.zeros_like(points, dtype=float)
    numerator = np.sqrt((edge_hi - points[inside]) * (points[inside] - edge_lo))
    density[inside] = numerator / (2 * np.pi * sigma_sq * gamma * denom_points[inside])

    return density

def marchenko_pastur_cdf(x, gamma, sigma_sq):
    """
    Calculates the Marchenko-Pastur CDF by numerically integrating the PDF.

    Args:
        x (float): Point at which the CDF is evaluated (scalar only).
        gamma (float): The aspect ratio p/n.
        sigma_sq (float): The variance of the noise.

    Returns:
        float: 0.0 at or below the lower support edge, 1.0 at or above the
        upper edge, otherwise the integral of the density from the lower edge
        to ``x``; ``np.nan`` if the numerical integration raises.

    NOTE(review): only the continuous density is integrated. For gamma > 1 the
    law also has a point mass at zero that is not added here, so the value
    jumps to 1.0 at the upper edge - confirm this is acceptable for callers.
    """
    lambda_minus = sigma_sq * (1 - np.sqrt(gamma))**2
    if x <= lambda_minus:
        return 0.0

    lambda_plus = sigma_sq * (1 + np.sqrt(gamma))**2
    # Decide the saturated case before integrating: previously the integral
    # was computed and then thrown away here, and a failing integration made
    # this branch return NaN instead of 1.0.
    if x >= lambda_plus:
        return 1.0

    try:
        val, _ = quad(marchenko_pastur_pdf, lambda_minus, x, args=(gamma, sigma_sq), epsabs=1e-5)
    except Exception:
        # Best effort: callers filter NaN quantiles. Note that quad reports
        # accuracy problems as warnings, which are not caught here.
        return np.nan

    return val

def marchenko_pastur_quantile(q, gamma, sigma_sq):
    """
    Invert the Marchenko-Pastur CDF numerically.

    Returns the lower support edge for q <= 0 and the upper edge for q >= 1.
    Otherwise the point where CDF(x) == q is searched for between the two
    edges; ``np.nan`` is returned when no root can be found.
    """
    root_gamma = np.sqrt(gamma)
    support_lo = sigma_sq * (1 - root_gamma)**2
    support_hi = sigma_sq * (1 + root_gamma)**2

    if q <= 0.0:
        return support_lo
    if q >= 1.0:
        return support_hi

    def cdf_gap(x):
        # signed distance between the CDF at x and the requested probability
        return marchenko_pastur_cdf(x, gamma, sigma_sq) - q

    try:
        # bracketed root search over the whole support
        return brentq(cdf_gap, support_lo, support_hi)
    except (ValueError, RuntimeError):
        # The search failed: accept a support edge if the CDF there is
        # already within 1e-4 of q (lower edge is tried first).
        for edge in (support_lo, support_hi):
            if abs(cdf_gap(edge)) < 1e-4:
                return edge
        return np.nan

def fit_marchenko_pastur(bulk_eigenvalues, gamma, p, start_index):
    """
    Estimate the noise variance from a quantile-quantile regression.

    The bulk eigenvalues are regressed (degree-1 polynomial fit) on the
    quantiles of a unit-variance Marchenko-Pastur law, taken at probabilities
    derived from each eigenvalue's rank in the full spectrum.

    Args:
        bulk_eigenvalues (np.ndarray): Eigenvalues used for the fit, sorted ascending.
        gamma (float): The aspect ratio p/n.
        p (int): The total number of eigenvalues.
        start_index (int): Rank of the first bulk eigenvalue in the full sorted list.

    Returns:
        tuple (float, float): The fitted slope (used as sigma^2) and the
        intercept. If fewer than two quantiles are usable, the fit fails, or
        the slope is not positive, ``(mean of bulk_eigenvalues, 0.0)`` is
        returned instead.
    """
    def _fallback():
        # used whenever the regression cannot provide a usable variance
        return np.mean(bulk_eigenvalues), 0.0

    # Probabilities from the global ranks start_index .. start_index + n_bulk - 1
    bulk_size = len(bulk_eigenvalues)
    global_ranks = np.arange(start_index, start_index + bulk_size)
    probabilities = (global_ranks + 0.5) / p

    # Matching quantiles of the standard (sigma_sq = 1) distribution
    std_quantiles = np.array(
        [marchenko_pastur_quantile(prob, gamma, sigma_sq=1.0) for prob in probabilities]
    )

    # Quantiles that could not be computed come back as NaN and are dropped
    usable = ~np.isnan(std_quantiles)
    if np.sum(usable) < 2:
        return _fallback()

    try:
        slope, intercept = np.polyfit(std_quantiles[usable], bulk_eigenvalues[usable], 1)
    except np.linalg.LinAlgError:
        return _fallback()

    if slope <= 0:
        return _fallback()

    return slope, intercept

def bema_intrinsic_dimension(data, bulk_percentile=0.60):
    """
    Estimates intrinsic dimensionality using Bulk Edge Marchenko-Pastur Analysis (BEMA).

    The eigenvalues of the sample covariance are computed, a Marchenko-Pastur
    law is fitted to the central part of the spectrum to estimate the noise
    variance, and the eigenvalues above the resulting upper edge
    ``sigma^2 * (1 + sqrt(p/n))^2`` are counted.

    Args:
        data: Input data matrix (n_samples x n_features), NumPy array or torch tensor
        bulk_percentile: Fraction of eigenvalues to use as "bulk" for MP fitting (default 0.60)
                        E.g., 0.60 means use middle 60% (from 20th to 80th percentile)

    Returns:
        int: Estimated intrinsic dimension (number of spike eigenvalues)

    Raises:
        ValueError: if ``data`` is not 2-D or has fewer than two samples.

    NOTE(review): when n_features > n_samples only the n_samples eigenvalues of
    the Gram matrix enter the quantile fit, while gamma = p/n > 1 is still
    used - confirm that the rank/probability convention of the fit is the
    intended one in that regime.
    """
    if isinstance(data, torch.Tensor):
        data = data.detach().cpu().numpy()

    n, p = data.shape  # raises ValueError for anything that is not 2-D
    if n < 2:
        raise ValueError("at least two samples (rows) are required to estimate a covariance")
    gamma = p / n

    if n > p:
        # More samples than features: use standard covariance (p x p)
        cov_matrix = np.atleast_2d(np.cov(data.T))
    else:
        # More features than samples: the n x n Gram matrix of the
        # feature-centred data has the same non-zero eigenvalues as the p x p
        # covariance. (np.cov(data) would centre each sample across features
        # instead and does not share that spectrum.)
        centered = np.asarray(data, dtype=float)
        centered = centered - centered.mean(axis=0, keepdims=True)
        cov_matrix = centered @ centered.T / (n - 1)

    # The matrix is symmetric, so the Hermitian solver applies: it returns
    # real eigenvalues in ascending order and skips the eigenvectors.
    eigenvalues_all = np.linalg.eigvalsh(cov_matrix)

    # Fit Marchenko-Pastur to the 'bulk' eigenvalues
    # Center the bulk window around the median
    lower_percentile = (1.0 - bulk_percentile) / 2.0
    upper_percentile = 1.0 - lower_percentile
    start_index = int(len(eigenvalues_all) * lower_percentile)
    end_index = int(len(eigenvalues_all) * upper_percentile)
    bulk_eigenvalues = eigenvalues_all[start_index:end_index]

    # Fit the distribution
    sigma_sq_fit, _ = fit_marchenko_pastur(bulk_eigenvalues, gamma, len(eigenvalues_all), start_index)

    # Upper edge of the fitted noise spectrum
    lambda_plus = sigma_sq_fit * (1 + np.sqrt(gamma))**2

    # Count spikes (eigenvalues above threshold)
    return int(np.count_nonzero(eigenvalues_all > lambda_plus))

if not already_computed:
    print("computing metrics...")

# One list per report column; every evaluated configuration adds one entry to each list.
out_metrics = {column: [] for column in ('method', 'time', 'dim', 'params')}

"""
# perform a PCA on the data
if 'PCA' not in methods_done:
    print("computing PCA")
    from sklearn.decomposition import PCA
    thresholds = [0.8, 0.9, 0.95]
    times = []
    dims = []
    for threshold in tqdm.tqdm(thresholds):
        start_time = time.time()
        pca = PCA(n_components=max_components, random_state=0)
        pca.fit(data)
        end_time = time.time()
        times.append((end_time - start_time)/60)
        dims.append(np.where(np.cumsum(pca.explained_variance_ratio_) >= threshold)[0][0])
    out_metrics['method'].extend(['PCA']*len(thresholds))
    out_metrics['time'].extend(times)
    out_metrics['dim'].extend(dims)
    out_metrics['params'].extend([f'threshold={threshold}' for threshold in thresholds])
"""
# BEMA (Bulk Edge Marchenko-Pastur Analysis)
# Runs the estimator once per bulk fraction and records the runtime (in
# minutes), the estimated dimension and the parameter setting of each run.
if 'BEMA' not in methods_done:
    print("computing BEMA")
    bulk_percentile_options = [0.80, 0.9, 0.95, 0.99]
    bema_runs = []
    for bulk_pct in tqdm.tqdm(bulk_percentile_options):
        t_start = time.time()
        estimated_dim = bema_intrinsic_dimension(data, bulk_percentile=bulk_pct)
        t_stop = time.time()
        bema_runs.append(((t_stop - t_start)/60, estimated_dim, f'bulk_percentile={bulk_pct:.2f}'))
    # results are only written to the report once every configuration has finished
    for elapsed_minutes, estimated_dim, setting in bema_runs:
        out_metrics['method'].append('BEMA')
        out_metrics['time'].append(elapsed_minutes)
        out_metrics['dim'].append(estimated_dim)
        out_metrics['params'].append(setting)

"""
# svd
if 'SVD' not in methods_done:
    print("computing svd")
    # SVD with threshold 0 (matrix rank - non-zero singular values)
    start_time = time.time()
    rank = torch.linalg.matrix_rank(torch.tensor(data).float()).item()
    end_time = time.time()
    out_metrics['method'].append('SVD')
    out_metrics['time'].append((end_time - start_time)/60)
    out_metrics['dim'].append(rank)
    out_metrics['params'].append('threshold=0')
    
    # SVD with cumulative energy thresholds
    start_time = time.time()
    U, S, Vh = torch.linalg.svd(torch.tensor(data).float(), full_matrices=False)
    singular_values_energy = (S ** 2) / torch.sum(S ** 2)
    cumulative_energy = torch.cumsum(singular_values_energy, dim=0)
    end_time = time.time()
    
    for threshold in [0.8, 0.9, 0.95]:
        try:
            svd_dim = torch.where(cumulative_energy >= threshold)[0][0].item() + 1
        except:
            svd_dim = np.nan
        out_metrics['method'].append('SVD')
        out_metrics['time'].append((end_time - start_time)/60)
        out_metrics['dim'].append(svd_dim)
        out_metrics['params'].append(f'threshold={threshold}')


# scikit
import skdim
data = data.cpu().numpy()

# lPCA
if 'lPCA' not in methods_done:
    print("computing lPCA")
    alpha_options = [0.0001, 0.001, 0.01, 0.05, 0.1, 0.5, 0.9]
    times = []
    dims = []
    for alpha in tqdm.tqdm(alpha_options):
        start_time = time.time()
        pca_dim = skdim.id.lPCA(alphaFO=alpha).fit(data).dimension_
        end_time = time.time()
        times.append((end_time - start_time)/60)
        dims.append(pca_dim)
    out_metrics['method'].extend(['lPCA']*len(alpha_options))
    out_metrics['time'].extend(times)
    out_metrics['dim'].extend(dims)
    out_metrics['params'].extend([f'alpha={alpha}' for alpha in alpha_options])

if 'ICA' not in methods_done:
    from sklearn.decomposition import FastICA
    print("computing ICA")
    thresholds = [1e-3, 1e-4, 1e-5]
    times = []
    dims = []
    for threshold in tqdm.tqdm(thresholds):
        start_time = time.time()
        ica = FastICA(n_components=max_components, random_state=0, tol=threshold)
        ica.fit(data)
        end_time = time.time()
        times.append((end_time - start_time)/60)
        dims.append(np.where(np.cumsum(ica.explained_variance_ratio_) >= threshold)[0][0])
    out_metrics['method'].extend(['ICA']*len(thresholds))
    out_metrics['time'].extend(times)
    out_metrics['dim'].extend(dims)
    out_metrics['params'].extend([f'threshold={threshold}' for threshold in thresholds])

if 'KernelPCA' not in methods_done:
    from sklearn.decomposition import KernelPCA
    print("computing KernelPCA")
    kernels = ['linear', 'poly', 'rbf', 'sigmoid', 'cosine']
    times = []
    dims = []
    for kernel in tqdm.tqdm(kernels):
        start_time = time.time()
        kpca = KernelPCA(n_components=max_components, kernel=kernel, random_state=0, remove_zero_eig=True)
        kpca.fit(data)
        end_time = time.time()
        times.append((end_time - start_time)/60)
        dims.append(len(kpca.eigenvalues_))
    out_metrics['method'].extend(['KernelPCA']*len(kernels))
    out_metrics['time'].extend(times)
    out_metrics['dim'].extend(dims)
    out_metrics['params'].extend([f'kernel={kernel}' for kernel in kernels])

# Correlation Integral
if 'CorrInt' not in methods_done:
    '''
    Estimating the fractal dimension.
    '''
    print("computing CorrInt")
    k1_options = [2, 5, 10, 20, 50, 100]
    k2_options = [2, 5, 10, 20, 50, 100]
    times = []
    dims = []
    for k1 in tqdm.tqdm(k1_options):
        for k2 in k2_options:
            if k2 <= k1:
                continue
            start_time = time.time()
            pca_dim = skdim.id.CorrInt(k1=k1,k2=k2).fit(data).dimension_
            end_time = time.time()
            times.append((end_time - start_time)/60)
            dims.append(pca_dim)
    out_metrics['method'].extend(['CorrInt']*len(times))
    out_metrics['time'].extend(times)
    out_metrics['dim'].extend(dims)
    out_metrics['params'].extend([f'k1={k1}, k2={k2}' for k1 in k1_options for k2 in k2_options if k2 > k1])

# FisherS
if 'FisherS' not in methods_done:
    print("computing FisherS")
    thresholds = [10, 100, 1000, max_components]
    times = []
    dims = []
    for threshold in tqdm.tqdm(thresholds):
        start_time = time.time()
        pca_dim = skdim.id.FisherS().fit(data).dimension_
        end_time = time.time()
        times.append((end_time - start_time)/60)
        dims.append(pca_dim)
    out_metrics['method'].extend(['FisherS']*len(thresholds))
    out_metrics['time'].extend(times)
    out_metrics['dim'].extend(dims)
    out_metrics['params'].extend([f'threshold={threshold}' for threshold in thresholds])

# MindML
if 'MiND_ML' not in methods_done:
    print("computing MiND_ML")
    k_options = [2, 5, 10, 20, 50, 100]
    times = []
    dims = []
    for k in tqdm.tqdm(k_options):
        start_time = time.time()
        pca_dim = skdim.id.MiND_ML(k=k,D=10000).fit(data).dimension_ # D is an upper bound for the ID
        end_time = time.time()
        times.append((end_time - start_time)/60)
        dims.append(pca_dim)
    out_metrics['method'].extend(['MiND_ML']*len(k_options))
    out_metrics['time'].extend(times)
    out_metrics['dim'].extend(dims)
    out_metrics['params'].extend([f'k={k}' for k in k_options])

# MLE
if 'MLE' not in methods_done:
    print("computing MLE")
    sigma_options = [0, 0.001, 0.01, 0.1]
    k_options = [2, 5, 10, 20, 50, 100]
    times = []
    dims = []
    for sigma in tqdm.tqdm(sigma_options):
        for k in k_options:
            start_time = time.time()
            pca_dim = skdim.id.MLE(sigma=sigma, K=k, n=data.shape[1]).fit(data).dimension_
            end_time = time.time()
            times.append((end_time - start_time)/60)
            dims.append(pca_dim)
    out_metrics['method'].extend(['MLE']*len(sigma_options)*len(k_options))
    out_metrics['time'].extend(times)
    out_metrics['dim'].extend(dims)
    out_metrics['params'].extend([f'sigma={sigma}, k={k}' for sigma in sigma_options for k in k_options])

# MOM
if 'MOM' not in methods_done:
    print("computing MOM")
    start_time = time.time()
    pca_dim = skdim.id.MOM().fit(data).dimension_
    end_time = time.time()
    out_metrics['method'].append('MOM')
    out_metrics['time'].append((end_time - start_time)/60)
    out_metrics['dim'].append(pca_dim)
    out_metrics['params'].append(None)

# TLE
if 'TLE' not in methods_done:
    print("computing TLE")
    epsilons = [1e-10, 1e-5, 1e-4, 1e-3, 1e-2, 0.1, 1.0]
    times = []
    dims = []
    for epsilon in tqdm.tqdm(epsilons):
        start_time = time.time()
        pca_dim = skdim.id.TLE(epsilon=epsilon).fit(data).dimension_
        end_time = time.time()
        times.append((end_time - start_time)/60)
        dims.append(pca_dim)
    out_metrics['method'].extend(['TLE']*len(epsilons))
    out_metrics['time'].extend(times)
    out_metrics['dim'].extend(dims)
    out_metrics['params'].extend([f'epsilon={epsilon}' for epsilon in epsilons])

# TwoNN
if 'TwoNN' not in methods_done:
    print("computing TwoNN")
    discard_fractions = [0.0, 0.1, 0.2, 0.3, 0.5, 0.7, 0.9]
    times = []
    dims = []
    for discard_fraction in tqdm.tqdm(discard_fractions):
        start_time = time.time()
        pca_dim = skdim.id.TwoNN(discard_fraction=discard_fraction).fit(data).dimension_
        end_time = time.time()
        times.append((end_time - start_time)/60)
        dims.append(pca_dim)
    out_metrics['method'].extend(['TwoNN']*len(discard_fractions))
    out_metrics['time'].extend(times)
    out_metrics['dim'].extend(dims)
    out_metrics['params'].extend([f'discard_fraction={discard_fraction}' for discard_fraction in discard_fractions])
#"""

# skipping the ones that take long right now
"""
# kNN
if 'kNN' not in methods_done:
    print("computing kNN")
    m_options = [1, 2, 5]
    gamma_options = [1, 2, 5]
    times = []
    dims = []
    for m in tqdm.tqdm(m_options):
        for gamma in gamma_options:
            start_time = time.time()
            pca_dim = skdim.id.kNN(M=m, gamma=gamma).fit(data).dimension_
            end_time = time.time()
            times.append((end_time - start_time)/60)
            dims.append(pca_dim)
    out_metrics['method'].extend(['kNN']*len(m_options)*len(gamma_options))
    out_metrics['time'].extend(times)
    out_metrics['dim'].extend(dims)
    out_metrics['params'].extend([f'm={m}, gamma={gamma}' for m in m_options for gamma in gamma_options])

# ESS (angle-based)
if 'ESS' not in methods_done:
    print("computing ESS")
    d_options = [1, 2, 5, 10]
    times = []
    dims = []
    for d in tqdm.tqdm(d_options):
        start_time = time.time()
        pca_dim = skdim.id.ESS(d=d).fit(data).dimension_
        end_time = time.time()
        times.append((end_time - start_time)/60)
        dims.append(pca_dim)
    out_metrics['method'].extend(['ESS']*len(d_options))
    out_metrics['time'].extend(times)
    out_metrics['dim'].extend(dims)
    out_metrics['params'].extend([f'd={d}' for d in d_options])

if 'MADA' not in methods_done:
    print("computing MADA")
    start_time = time.time()
    pca_dim = skdim.id.MADA().fit(data).dimension_
    end_time = time.time()
    out_metrics['method'].append('MADA')
    out_metrics['time'].append((end_time - start_time)/60)
    out_metrics['dim'].append(pca_dim)
    out_metrics['params'].append(None)
"""

"""
if 'DANCo' not in methods_done:
    print("computing DANCo")
    start_time = time.time()
    pca_dim = skdim.id.DANCo().fit(data).dimension_
    end_time = time.time()
    out_metrics['method'].append('DANCo')
    out_metrics['time'].append((end_time - start_time)/60)
    out_metrics['dim'].append(pca_dim)
    out_metrics['params'].append(None)


if 'Wasserstein' not in methods_done:
    print("computing Wasserstein")
    start_time = time.time()
    n_w_samples = [100, 1000, 10000]
    n_w_alphas = [0.2, 0.5, 0.7, 0.8, 0.9, 0.95, 0.99]
    times = []
    dims = []
    pca_dim = wasserstein_distance(data, n_w_samples, n_w_alphas)
    for n_w_sample in n_w_samples:
        for n_w_alpha in n_w_alphas:
            start_time = time.time()
            pca_dim = wasserstein_distance(data, n_samples=n_w_sample, alpha=n_w_alpha)
            end_time = time.time()
            times.append((end_time - start_time)/60)
            dims.append(pca_dim)
    out_metrics['method'].extend(['Wasserstein']*len(n_w_samples)*len(n_w_alphas))
    out_metrics['time'].extend(times)
    out_metrics['dim'].extend(dims)
    out_metrics['params'].extend([f'n_samples={n_w_sample}, alpha={n_w_alpha}' for n_w_sample in n_w_samples for n_w_alpha in n_w_alphas])
"""

# save metrics
# Either start a fresh report or append the newly computed rows to the one
# loaded earlier; when nothing new was computed the loaded report is
# written back unchanged.
if len(methods_done) == 0:
    df_out_metrics = pd.DataFrame(out_metrics)
elif len(out_metrics['method']) > 0:
    # concat and reindex
    df_out_metrics = pd.concat([df_out_metrics, pd.DataFrame(out_metrics)]).reset_index(drop=True)
# Create the report directory if needed, so that a missing folder cannot
# discard the results after the whole computation has run.
out_dir = os.path.dirname(out_file)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
df_out_metrics.to_csv(out_file)