import numpy as np
from sklearn.preprocessing import normalize
from sklearn.feature_extraction.text import TfidfTransformer
from scipy.sparse import issparse
from sklearn.model_selection import train_test_split
import anndata as ad
import pandas as pd
import os
import pdb
import torch
import json
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import Dataset, DataLoader, TensorDataset, random_split

from .deepnet_tools import custom_collate_fn


class RNA_ATAC_Dataset(Dataset):
    """Paired RNA/ATAC dataset that serves one cell at a time as sparse COO tensors."""

    def __init__(self, rna_adata: ad.AnnData, atac_adata: ad.AnnData, genes: np.array, regions: np.array):
        """
        Dataset for RNA and ATAC data from AnnData objects.

        Args:
            rna_adata (anndata.AnnData): RNA data.
            atac_adata (anndata.AnnData): ATAC data.
            genes (np.array): Column selector applied to ``rna_adata.X``.
                An empty selector keeps every column.
            regions (np.array): Column selector applied to ``atac_adata.X``.
                An empty selector keeps every column.
        """
        assert rna_adata.shape[0] == atac_adata.shape[0], "RNA and ATAC data must have the same number of cells."

        # Data matrices
        if len(genes) > 0:
            self.rna_data = rna_adata.X[:, genes]
        else:
            self.rna_data = rna_adata.X

        if len(regions) > 0:
            self.atac_data = atac_adata.X[:, regions]
        else:
            self.atac_data = atac_adata.X

        print(f"RNA data shape: {self.rna_data.shape}")
        print(f"ATAC data shape: {self.atac_data.shape}")

        # Metadata: boolean per-cell flags, or None when the column is absent.
        self.rna_metadata = self._modality_flags(rna_adata.obs, 'RNA')
        self.atac_metadata = self._modality_flags(atac_adata.obs, 'ATAC')
        self.size = rna_adata.shape[0]

    @staticmethod
    def _modality_flags(obs, column):
        """Return ``obs[column]`` as a boolean Series, or None if the column is missing."""
        if column not in obs.columns:
            # Keep None (not the scalar ``None == 'True'``) so that
            # __getitem__ can detect the missing metadata.
            return None
        # astype(str) lets both the string 'True' and a real boolean True count as set.
        return obs[column].astype(str) == 'True'

    @staticmethod
    def _row_to_sparse_tensor(matrix, idx):
        """Slice row ``idx`` of ``matrix`` once and return it as a float32 sparse COO tensor."""
        row = matrix[idx]
        if hasattr(row, "tocoo"):
            coo = row.tocoo()
            indices = torch.from_numpy(np.vstack((coo.row, coo.col)).astype(np.int64))
            values = torch.from_numpy(np.asarray(coo.data, dtype=np.float32))
            return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape))
        # Dense storage: build the same (1, n_features) layout the sparse path produces.
        dense = torch.from_numpy(np.atleast_2d(np.asarray(row, dtype=np.float32)))
        return dense.to_sparse()

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        """
        Return ``(rna_tensor, atac_tensor, rna_modality, atac_modality, idx)`` for one cell.

        The two tensors are sparse COO tensors; the modality entries are the
        boolean flags of that cell, or None when the flag column is absent.
        """
        rna_tensor = self._row_to_sparse_tensor(self.rna_data, idx)
        atac_tensor = self._row_to_sparse_tensor(self.atac_data, idx)

        # Modality metadata (if available)
        rna_modality = self.rna_metadata.iloc[idx] if self.rna_metadata is not None else None
        atac_modality = self.atac_metadata.iloc[idx] if self.atac_metadata is not None else None

        return rna_tensor, atac_tensor, rna_modality, atac_modality, idx


class patchseq_Dataset(Dataset):
    def __init__(self, *data_modalities, labels=None):
        """
        Dataset that serves aligned rows from several AnnData-like modalities.

        Args:
            *data_modalities: Variable number of data modalities (e.g., RNA, ATAC, etc.).
                Each one is read through its ``shape``, ``obs`` and ``X`` attributes.
            labels: Optional key into ``obs``; when given, the matching label is
                added to every sample (taken from the last modality passed).
        """
        n_cells = data_modalities[0].shape[0]
        assert all(m.shape[0] == n_cells for m in data_modalities), \
            "All modalities must have the same number of samples (cells)."

        self.data_modalities = data_modalities
        self.add_labels = labels is not None
        self.metadata = {}
        for modality in self.data_modalities:
            obs = modality.obs
            if self.add_labels:
                self.label = obs[labels]

            # Collect whichever of the T / E / M flag columns this modality carries.
            for flag in ('T', 'E', 'M'):
                if flag in obs.columns:
                    self.metadata[flag] = obs[flag]

        self.size = n_cells

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        """Return ``(*data_tensors, *flags, [label,] idx)`` for row ``idx``."""
        sample = [torch.FloatTensor(m.X[idx]) for m in self.data_modalities]
        sample.extend(column.iloc[idx] for column in self.metadata.values())

        if self.add_labels:
            sample.append(self.label.iloc[idx])

        sample.append(idx)
        return tuple(sample)
    
      
class scDataset(Dataset):
    def __init__(self, adata: ad.AnnData, features: np.array, metadata=None):
        """
        Single-matrix dataset that serves one row at a time as a sparse COO tensor.

        Args:
            adata (anndata.AnnData): RNA/ATAC data.
            features (np.array): Column selector applied to ``adata.X``.
                An empty selector keeps every column.
            metadata: Optional key into ``adata.obs``. When given, the matching
                value is returned by ``__getitem__`` in place of the row index.
        """

        # Data matrices
        if len(features) > 0:
            self.data = adata.X[:, features]
        else:
            self.data = adata.X

        self.size = adata.shape[0]
        if metadata is not None:
            self.metadata = adata.obs[metadata].to_numpy()
        else:
            self.metadata = None

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        """Return ``(data_tensor, metadata[idx])`` if metadata was requested, else ``(data_tensor, idx)``."""
        # Slice the row once; the previous code indexed the matrix twice per sample.
        row = self.data[idx]
        if hasattr(row, "tocoo"):
            sc_coo = row.tocoo()
            indices = torch.from_numpy(np.vstack((sc_coo.row, sc_coo.col)).astype(np.int64))
            values = torch.from_numpy(np.asarray(sc_coo.data, dtype=np.float32))
            data_tensor = torch.sparse_coo_tensor(indices, values, torch.Size(sc_coo.shape))
        else:
            # Dense storage has no .row/.col; build the same (1, n_features) sparse layout.
            dense = torch.from_numpy(np.atleast_2d(np.asarray(row, dtype=np.float32)))
            data_tensor = dense.to_sparse()

        if self.metadata is not None:
            return data_tensor, self.metadata[idx]
        else:
            return data_tensor, idx
    


def normalize_cellxgene(x):
    """Scale every row of a cell x gene matrix to unit L1 norm.

    Args:
        x (np.array): cell x gene matrix (cells along axis=0, genes along axis=1)

    Returns:
        The result of ``sklearn.preprocessing.normalize`` applied row-wise
        with the L1 norm.
    """
    row_normalized = normalize(x, norm='l1', axis=1)
    return row_normalized


def logcpm(x, scaler=1e4) -> np.array:
    """Log-normalize a count matrix: log1p(row-wise L1-normalized counts * scaler).

    Args:
        x (np.array): cell x gene matrix (cells along axis=0, genes along axis=1)
        scaler (float, optional): factor applied to the normalized counts
            before the log transform

    Returns:
        normalized log gene expression matrix
    """
    scaled = normalize_cellxgene(x) * scaler
    return np.log1p(scaled)


def sparse_std(x, axis=None):
    """Population standard deviation of a sparse matrix, computed without densifying.

    Uses ``sqrt(E[x^2] - E[x]^2)`` on a copy of ``x``; the input is not modified.

    Args:
        x: sparse matrix exposing ``copy()``, ``.data`` and ``mean(axis)``.
        axis: axis forwarded to ``mean`` (None reduces over all entries).

    Returns:
        Standard deviation along ``axis``, in whatever container ``x.mean`` returns.
    """
    squared = x.copy()
    squared.data **= 2
    variance = squared.mean(axis) - np.square(x.mean(axis))
    # Round-off can push the variance of (near-)constant data slightly below
    # zero, which would turn the square root into NaN; clamp it at zero.
    return np.sqrt(np.maximum(variance, 0))


def reorder_genes(x, eps, chunksize, binary):
    """Rank the columns of ``x`` by standard deviation, highest first.

    The standard deviation is computed chunk by chunk over the columns.

    Args:
        x: matrix with one column per gene (dense or scipy sparse).
        eps: columns whose std is not strictly greater than this are dropped
            from the returned index array.
        chunksize: number of columns processed per chunk.
        binary: if True, the std is computed on the ``x != 0`` indicator.

    Returns:
        (indices, stds): ``indices`` holds the kept columns in descending std
        order; ``stds`` holds the descending stds of ALL columns (it is not
        filtered by ``eps``, so it can be longer than ``indices``).
    """
    n_genes = x.shape[1]
    n_chunks = int(n_genes // chunksize) + 1
    chunk_stds = []

    for chunk_id in range(n_chunks):
        print(f"{chunk_id} -------------------------------")
        start = chunk_id * chunksize
        stop = np.min((n_genes, (chunk_id + 1) * chunksize))

        block = x[:, start:stop]
        if binary:
            block = (block != 0).astype(int)

        std_fn = sparse_std if issparse(block) else np.std
        chunk_stds.append(std_fn(block, axis=0))

    all_std = np.concatenate([np.array(s).flatten() for s in chunk_stds])

    order = np.argsort(all_std)
    sorted_std = all_std[order]
    kept = order[sorted_std > eps]
    return kept[::-1], sorted_std[::-1]


def get_HVG(x, thr=0.1, binary=True, chunksize=10000):
    """Return the column indices ranked by ``reorder_genes`` plus the sorted stds.

    Args:
        x: matrix with one column per gene.
        thr: std threshold forwarded to ``reorder_genes`` as ``eps``.
        binary: forwarded to ``reorder_genes``.
        chunksize: forwarded to ``reorder_genes``.

    Returns:
        (gene_index, g_std): flattened index array and the std array exactly
        as returned by ``reorder_genes``.
    """
    ranked_genes, ranked_std = reorder_genes(x, thr, chunksize, binary)
    gene_index = np.asarray(ranked_genes).flatten()
    return gene_index, ranked_std


def tfidf(x, scaler=1e4):
    """Apply an L1-normalized TF-IDF transform to ``x`` and multiply it by ``scaler``."""
    transformer = TfidfTransformer(use_idf=True, norm='l1')
    weighted = transformer.fit_transform(x)
    return weighted * scaler

def get_HAP(x, thr=0.1, binary=True):
    """Flag columns whose column sum lies strictly between two row-count fractions.

    Args:
        x: matrix with rows along axis 0.
        thr: a column is kept when its sum is strictly greater than
            ``n_rows * thr`` and strictly less than ``n_rows * (1 - thr)``.
        binary: if True, sum the ``x != 0`` indicator instead of the raw values.

    Returns:
        1-D boolean mask with one entry per column.
    """
    source = (x != 0).astype(int) if binary else x
    counts = np.array(np.sum(source, axis=0)).reshape(-1)

    n_rows = x.shape[0]
    lower = n_rows * thr
    upper = n_rows * (1 - thr)
    return np.logical_and(counts > lower, counts < upper)


def split_data_Kfold(class_label, K_fold):
    """Build K stratified train/test index splits.

    For every class, fold ``k`` takes the k-th contiguous slice of
    ``len(class) // K_fold`` samples as test data and the remaining samples of
    that class as training data. Samples left over by the integer division
    are therefore always in the training part. Each fold's test and train
    indices are shuffled with the global NumPy RNG.

    Args:
        class_label: 1-D sequence of class labels, one per sample.
        K_fold (int): number of folds.

    Returns:
        (train_ind, test_ind): two lists of ``K_fold`` integer index arrays.
    """
    # Accept plain lists too: ``list == label`` would otherwise match nothing.
    class_label = np.asarray(class_label)
    uniq_label = np.unique(class_label)

    test_parts = [[] for _ in range(K_fold)]
    train_parts = [[] for _ in range(K_fold)]

    # Split the data to train and test keeping the same ratio for all classes
    for label in uniq_label:
        label_indices = np.flatnonzero(class_label == label)
        # Integer division avoids the float round-off of int((1 / K) * n).
        test_size = len(label_indices) // K_fold

        for fold in range(K_fold):
            ind_0 = fold * test_size
            ind_1 = (1 + fold) * test_size
            # Integer arrays keep the index dtype even when a slice is empty
            # (a class with fewer than K_fold samples).
            test_parts[fold].append(label_indices[ind_0:ind_1])
            train_parts[fold].append(
                np.concatenate((label_indices[:ind_0], label_indices[ind_1:]))
            )

    train_ind, test_ind = [], []
    for fold in range(K_fold):
        fold_test = np.concatenate(test_parts[fold])
        fold_train = np.concatenate(train_parts[fold])

        # Shuffle the indices (test first, then train, as before)
        order = np.arange(len(fold_test))
        np.random.shuffle(order)
        test_ind.append(fold_test[order])

        order = np.arange(len(fold_train))
        np.random.shuffle(order)
        train_ind.append(fold_train[order])

    return train_ind, test_ind


def load_data(file, gene_file='', n_gene=0):
    """Load an .h5ad file into a dict with a dense expression matrix and gene ids.

    Args:
        file: path of the .h5ad file.
        gene_file: optional CSV path. The first column whose name contains
            "gene" (case-insensitive) lists the genes to keep, in that order.
        n_gene: if > 0, keep only the first ``n_gene`` genes (applied after
            the ``gene_file`` selection).

    Returns:
        dict with ``'log1p'`` (cells x genes dense matrix taken from
        ``adata.X``) and ``'gene_id'`` (values of ``adata.var.index``).

    Raises:
        ValueError: if ``gene_file`` has no "gene" column, or lists genes
            that are not present in the data.
    """
    adata = ad.read_h5ad(file)
    data = dict()
    # X may be stored sparse or dense; only sparse matrices need toarray().
    data['log1p'] = adata.X.toarray() if issparse(adata.X) else np.asarray(adata.X)
    data['gene_id'] = adata.var.index.values

    if gene_file:
        df_ = pd.read_csv(gene_file)
        # Use the first column whose name includes "gene"
        gene_col = next((key for key in df_.keys() if 'gene' in str(key).lower()), None)
        if gene_col is None:
            raise ValueError(f"No column containing 'gene' found in {gene_file}")
        gene_list = df_[gene_col].values

        # Map each gene id to its first position once instead of scanning per gene.
        first_pos = {}
        for pos, gid in enumerate(data['gene_id']):
            first_pos.setdefault(gid, pos)

        missing = [gg for gg in gene_list if gg not in first_pos]
        if missing:
            raise ValueError(f"{len(missing)} gene(s) from {gene_file} not found in the data, e.g. {missing[:5]}")

        gene_index = [first_pos[gg] for gg in gene_list]
        data['log1p'] = data['log1p'][:, gene_index]
        data['gene_id'] = data['gene_id'][gene_index]

    if n_gene > 0:
        data['log1p'] = data['log1p'][:, :n_gene]
        data['gene_id'] = data['gene_id'][:n_gene]

    print(f"Number of cells: {data['log1p'].shape[0]}, Number of genes: {data['log1p'].shape[1]}")

    return data
    
    

def get_data(x, train_size, additional_val, seed=0):
    """Split ``x`` into train / (optional validation) / test parts.

    Args:
        x: data with samples along axis 0.
        train_size (int): number of samples in the first train split; the
            remaining ``x.shape[0] - train_size`` samples form the test set.
        additional_val (bool): if True, a validation set of the same size as
            the test set is carved out of the training part.
        seed (int): ``random_state`` used for every split.

    Returns:
        (train, val, test, train_ind, val_ind, test_ind); ``val`` and
        ``val_ind`` are empty lists when ``additional_val`` is False.
    """
    n_test = x.shape[0] - train_size
    all_ind = np.arange(x.shape[0])
    train_cpm, test_cpm, train_ind, test_ind = train_test_split(
        x, all_ind, train_size=train_size, test_size=n_test, random_state=seed)

    if not additional_val:
        return train_cpm, [], test_cpm, train_ind, [], test_ind

    train_cpm, val_cpm, train_ind, val_ind = train_test_split(
        train_cpm, train_ind, train_size=train_size - n_test, test_size=n_test, random_state=seed)
    return train_cpm, val_cpm, test_cpm, train_ind, val_ind, test_ind



def get_loaders(x, label=[], batch_size=128, train_size=0.9, n_aug_smp=0, netA=None, aug_param=0., device=None, additional_val=False):
    """Build DataLoaders over ``x`` for all data, train, test and (optionally) validation.

    Args:
        x: samples x features array (converted with ``torch.FloatTensor``).
        label: optional per-sample labels; when non-empty the split is done
            separately inside every class. (The list default is never mutated.)
        batch_size: batch size of the train, validation and all-data loaders
            (the test loader uses batch size 1).
        train_size (float): fraction of samples (per class, if labelled) used for training.
        n_aug_smp (int): number of extra copies of the training set to append.
        netA: optional augmenter called as ``netA(data, noise, True, device)``;
            its second output, masked to the non-zero input entries, is used
            as the extra copy. Without it the training set is duplicated as is.
        aug_param: must provide ``aug_param['num_n']`` (noise width) when ``netA`` is used.
        device: device on which the noise is created and to which the augmenter input is moved.
        additional_val (bool): whether to carve out a validation set.

    Returns:
        (alldata_loader, train_loader, test_loader, validation_loader,
        test_ind, val_ind); ``validation_loader`` and ``val_ind`` are empty
        lists when ``additional_val`` is False. Sample indices travel through
        the loaders as float tensors.
    """
    label = np.asarray(label)

    if len(label) > 0:
        train_ind, val_ind, test_ind = [], [], []
        for ll in np.unique(label):
            indx = np.where(label == ll)[0]
            tt_size = int(train_size * len(indx))
            # Split only this class's rows: the returned sub-indices are
            # positions inside ``indx``, so the split must not run on all of x.
            _, _, _, train_subind, val_subind, test_subind = get_data(x[indx], tt_size, additional_val)
            train_ind.append(indx[train_subind])
            test_ind.append(indx[test_subind])
            if additional_val:
                val_ind.append(indx[val_subind])

        train_ind = np.concatenate(train_ind)
        test_ind = np.concatenate(test_ind)
        train_set = x[train_ind, :]
        test_set = x[test_ind, :]

        if additional_val:
            val_ind = np.concatenate(val_ind)
            val_set = x[val_ind, :]

    else:
        tt_size = int(train_size * x.shape[0])
        train_set, val_set, test_set, train_ind, val_ind, test_ind = get_data(x, tt_size, additional_val)

    train_set_torch = torch.FloatTensor(train_set)
    train_ind_torch = torch.FloatTensor(train_ind)
    if n_aug_smp > 0:
        aug_sets = [train_set_torch]
        aug_inds = [train_ind_torch]
        if netA:
            # Mask of expressed entries; generated values are kept only there.
            nonzero_mask = (train_set_torch > 1e-4).float()
            # ``device is None`` (not truthiness) so that device index 0 is honoured.
            net_input = train_set_torch if device is None else train_set_torch.to(device)

        for _ in range(n_aug_smp):
            if netA:
                noise = 0.1 * torch.randn(train_set_torch.shape[0], aug_param['num_n'], device=device)
                _, gen_data = netA(net_input, noise, True, device)
                # Bring the generated data back to the CPU before combining it
                # with the CPU-resident mask.
                aug_sets.append(gen_data.detach().cpu() * nonzero_mask)
            else:
                aug_sets.append(train_set_torch)

            aug_inds.append(train_ind_torch)

        train_data = TensorDataset(torch.cat(aug_sets, 0), torch.cat(aug_inds, 0))
    else:
        train_data = TensorDataset(train_set_torch, train_ind_torch)

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True, pin_memory=True)

    test_set_torch = torch.FloatTensor(test_set)
    test_ind_torch = torch.FloatTensor(test_ind)
    test_data = TensorDataset(test_set_torch, test_ind_torch)
    test_loader = DataLoader(test_data, batch_size=1, shuffle=True, drop_last=False, pin_memory=True)

    data_set_torch = torch.FloatTensor(x)
    all_ind_torch = torch.FloatTensor(range(x.shape[0]))
    all_data = TensorDataset(data_set_torch, all_ind_torch)
    alldata_loader = DataLoader(all_data, batch_size=batch_size, shuffle=False, drop_last=False, pin_memory=True)

    if additional_val:
        val_set_torch = torch.FloatTensor(val_set)
        val_ind_torch = torch.FloatTensor(val_ind)
        validation_data = TensorDataset(val_set_torch, val_ind_torch)
        validation_loader = DataLoader(validation_data, batch_size=batch_size, shuffle=True, drop_last=False, pin_memory=True)
    else:
        validation_loader = []
        val_ind = []

    return alldata_loader, train_loader, test_loader, validation_loader, test_ind, val_ind



def multiome_loader(
                    rna_adata: str,
                    atac_adata: str,
                    batch_size: int,
                    num_workers: int = 4,
                    training_size: float = 0.9,
                    world_size: int = 1,
                    use_dist_sampler: bool = False,
                    only_shared: bool = False,
                    genes: np.array = np.array([]),
                    regions: np.array = np.array([]),
                    rank: int = 0,
                    random_seed: int = 0,
                    prefetch_factor: int = 2,
                    ):
    """
    Prepares train / test / all-data DataLoaders for paired RNA and ATAC data.

    Args:
        rna_adata (str): Path to the RNA AnnData file (.h5ad).
        atac_adata (str): Path to the ATAC AnnData file (.h5ad).
        batch_size (int): Batch size for the DataLoader.
        num_workers (int): Requested number of worker processes (capped at the number of usable CPU cores).
        training_size (float): Fraction of the data to use for training.
        world_size (int): Number of processes in the distributed setup.
        use_dist_sampler (bool): Whether to use the distributed sampler.
        only_shared (bool): Whether to use only the shared cells between RNA and ATAC data.
        genes (np.array): Column selector for the RNA matrix (empty keeps all columns).
        regions (np.array): Column selector for the ATAC matrix (empty keeps all columns).
        rank (int): Rank of the current process.
        random_seed (int): Random seed for data splitting.
        prefetch_factor (int): Number of batches loaded in advance by each worker.

    Returns:
        tuple: ``(train_loader, test_loader, alldata_loader)``.
    """
    # Load AnnData files
    print("Loading AnnData files ...")
    rna_adata = ad.read_h5ad(rna_adata, backed='r')
    atac_adata = ad.read_h5ad(atac_adata, backed='r')

    if only_shared:
        print("Using only shared cells between RNA and ATAC data.")
        rna_adata = rna_adata[(rna_adata.obs['RNA'] == 'True') & (rna_adata.obs['ATAC'] == 'True')]
        atac_adata = atac_adata[(atac_adata.obs['RNA'] == 'True') & (atac_adata.obs['ATAC'] == 'True')]

    print(f"RNA data shape: {rna_adata.shape}")
    print(f"ATAC data shape: {atac_adata.shape}")

    # Create the dataset
    dataset = RNA_ATAC_Dataset(rna_adata, atac_adata, genes, regions)
    del rna_adata, atac_adata

    train_size = int(training_size * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size], generator=torch.Generator().manual_seed(random_seed))

    # Determine the number of usable CPU cores. sched_getaffinity respects
    # the CPU set assigned to this process but is not available on every
    # platform, so fall back to the plain core count.
    if hasattr(os, "sched_getaffinity"):
        available_cores = len(os.sched_getaffinity(0))
    else:
        available_cores = os.cpu_count() or 1
    if num_workers > available_cores:
        print(f"Number of available CPU cores: {available_cores}")

    # Set the number of workers to the minimum of the available cores or the suggested max number
    num_workers = min(available_cores, num_workers)

    if world_size > 1 or use_dist_sampler:
        train_sampler = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True, seed=random_seed)
        test_sampler = DistributedSampler(test_dataset, rank=rank, num_replicas=world_size, shuffle=False, seed=random_seed)
        all_sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=False, seed=random_seed)
        print(f"Number of workers: {num_workers} (distributed sampler)")
    else:
        train_sampler = None
        test_sampler = None
        all_sampler = None
        print(f"Number of workers: {num_workers}")

    # Options shared by the three loaders. persistent_workers and
    # prefetch_factor are only valid with worker processes, so they are
    # left out when loading happens in the main process.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=custom_collate_fn,
    )
    if num_workers > 0:
        loader_kwargs.update(persistent_workers=True, prefetch_factor=prefetch_factor)

    # Create the DataLoaders
    train_loader = DataLoader(
                            train_dataset,
                            shuffle=(train_sampler is None),  # Shuffle only if no sampler
                            drop_last=True,
                            sampler=train_sampler,
                            **loader_kwargs,
                            )

    test_loader = DataLoader(
                            test_dataset,
                            shuffle=False,
                            drop_last=False,
                            sampler=test_sampler,
                            **loader_kwargs,
                            )

    alldata_loader = DataLoader(
                                dataset,
                                shuffle=False,
                                drop_last=False,
                                sampler=all_sampler,
                                **loader_kwargs,
                                )

    return train_loader, test_loader, alldata_loader



def unimodal_loader(
                    data_file: str,
                    batch_size: int,
                    num_workers: int = 4,
                    training_size: float = 0.9,
                    world_size: int = 1,
                    use_dist_sampler: bool = False,
                    only_shared: bool = False,
                    features: np.array = np.array([]),
                    rank: int = 0,
                    random_seed: int = 0,
                    prefetch_factor: int = 2,
                    ):
    """
    Prepares train / test / all-data DataLoaders for a single AnnData file.

    Args:
        data_file (str): Path to the AnnData file (.h5ad).
        batch_size (int): Batch size for the DataLoader.
        num_workers (int): Requested number of worker processes (capped at the number of usable CPU cores).
        training_size (float): Fraction of the data to use for training.
        world_size (int): Number of processes in the distributed setup.
        use_dist_sampler (bool): Whether to use the distributed sampler.
        only_shared (bool): Whether to keep only cells flagged 'True' in both the 'RNA' and 'ATAC' obs columns.
        features (np.array): Column selector for the data matrix (empty keeps all columns).
        rank (int): Rank of the current process.
        random_seed (int): Random seed for data splitting.
        prefetch_factor (int): Number of batches loaded in advance by each worker.

    Returns:
        tuple: ``(train_loader, test_loader, alldata_loader)``.
    """
    # Load AnnData file. The data is only read here, so open it read-only
    # rather than in read/write mode.
    print("Loading AnnData files ...")
    adata = ad.read_h5ad(data_file, backed='r')

    if only_shared:
        print("Using only shared cells between RNA and ATAC data.")
        adata = adata[(adata.obs['RNA'] == 'True') & (adata.obs['ATAC'] == 'True')]

    print(f"data shape: {adata.shape}")

    # Create the dataset
    dataset = scDataset(adata, features)
    del adata

    train_size = int(training_size * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size], generator=torch.Generator().manual_seed(random_seed))

    # Determine the number of usable CPU cores. sched_getaffinity respects
    # the CPU set assigned to this process but is not available on every
    # platform, so fall back to the plain core count.
    if hasattr(os, "sched_getaffinity"):
        available_cores = len(os.sched_getaffinity(0))
    else:
        available_cores = os.cpu_count() or 1
    if num_workers > available_cores:
        print(f"Number of available CPU cores: {available_cores}")

    # Set the number of workers to the minimum of the available cores or the suggested max number
    num_workers = min(available_cores, num_workers)

    if world_size > 1 or use_dist_sampler:
        train_sampler = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True, seed=random_seed)
        test_sampler = DistributedSampler(test_dataset, rank=rank, num_replicas=world_size, shuffle=False, seed=random_seed)
        all_sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=False, seed=random_seed)
        print(f"Number of workers: {num_workers} (distributed sampler)")
    else:
        train_sampler = None
        test_sampler = None
        all_sampler = None
        print(f"Number of workers: {num_workers}")

    # Options shared by the three loaders. persistent_workers and
    # prefetch_factor are only valid with worker processes, so they are
    # left out when loading happens in the main process.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=custom_collate_fn,
    )
    if num_workers > 0:
        loader_kwargs.update(persistent_workers=True, prefetch_factor=prefetch_factor)

    # Create the DataLoaders
    train_loader = DataLoader(
                            train_dataset,
                            shuffle=(train_sampler is None),  # Shuffle only if no sampler
                            drop_last=True,
                            sampler=train_sampler,
                            **loader_kwargs,
                            )

    test_loader = DataLoader(
                            test_dataset,
                            shuffle=False,
                            drop_last=False,
                            sampler=test_sampler,
                            **loader_kwargs,
                            )

    alldata_loader = DataLoader(
                                dataset,
                                shuffle=False,
                                drop_last=False,
                                sampler=all_sampler,
                                **loader_kwargs,
                                )

    return train_loader, test_loader, alldata_loader

def load_smartseq(data_file):
    """Load the ``'logcpm'`` array of an .npz archive into an AnnData object.

    The observations get a synthetic ``specimen_id`` (1e3 + row number) and
    an ``SS`` flag column that is True for every row.

    Args:
        data_file: path of the .npz archive holding a ``'logcpm'`` array.

    Returns:
        anndata.AnnData with ``X`` set to the ``'logcpm'`` array.
    """
    # Read the array once and close the archive; every key access on the
    # opened archive loads the array from disk again.
    with np.load(data_file) as data:
        logcpm_mat = data['logcpm']

    obs_df = pd.DataFrame()
    obs_df['specimen_id'] = 1e3 + np.arange(logcpm_mat.shape[0])
    obs_df['SS'] = np.ones(len(obs_df), dtype=bool)
    andata_smseq = ad.AnnData(X=logcpm_mat, obs=obs_df)
    return andata_smseq


def load_EM_arbors(data_file):
    """Load the ``'arbors'`` array of an .npz archive into an AnnData object.

    Every sample is flattened to one row of ``X``; the original array shape
    is stored in ``uns['origin_shape']``. The observations get a synthetic
    ``specimen_id`` (1e6 + row number) and an ``EM`` flag column that is True
    for every row.

    Args:
        data_file: path of the .npz archive holding an ``'arbors'`` array.

    Returns:
        anndata.AnnData with the flattened arbors as ``X``.
    """
    # Read the array once and close the archive; every key access on the
    # opened archive loads the array from disk again.
    with np.load(data_file) as data:
        arbors = data['arbors']

    n_samples = arbors.shape[0]
    obs_df = pd.DataFrame()
    obs_df['specimen_id'] = 1e6 + np.arange(n_samples)
    obs_df['EM'] = np.ones(len(obs_df), dtype=bool)
    andata_EM = ad.AnnData(X=arbors.reshape(n_samples, -1), obs=obs_df, uns={'origin_shape': arbors.shape})
    return andata_EM
    

def load_npz_patchseq(data_file, anno_file, exclude_type=[], min_num=0):
    """Load patch-seq data from .npz archives into three AnnData objects.

    Cells are regrouped by cluster label (in sorted label order). Clusters
    labelled ``'nan'``, clusters with fewer than ``min_num`` cells and
    clusters listed in ``exclude_type`` are dropped. Every array stored in
    ``data_file`` is filtered with the same cell selection.

    Args:
        data_file: path of the .npz archive. The keys read here are
            ``'cluster_label'``, ``'subclass'``, ``'specimen_id'``,
            ``'logcpm'``, ``'pca-ipfx'`` and ``'arbors'``.
        anno_file: path of the .npz archive providing ``'gene_ids'`` and
            ``'pca-ipfx_features'``.
        exclude_type: cluster labels (compared after stripping whitespace) to
            remove. (The list default is never mutated.)
        min_num (int): minimum number of cells a cluster needs to be kept.

    Returns:
        (adata_T, adata_E, adata_M): AnnData objects built from ``'logcpm'``,
        ``'pca-ipfx'`` and the flattened ``'arbors'`` respectively. They share
        the obs columns ``cluster_label``, ``subclass``, ``cluster_id``
        (1-based rank of the sorted label) and ``specimen_id``, plus an
        all-True ``T`` / ``E`` / ``M`` flag column.
    """
    # Materialise every array exactly once and close the archives: indexing
    # an opened archive re-reads the array from disk on each access.
    with np.load(data_file) as data_npz:
        raw = {k: np.array(data_npz[k]) for k in data_npz.keys()}
    with np.load(anno_file) as annotations:
        gene_id = np.array([c.strip() for c in annotations['gene_ids']])
        e_features = annotations['pca-ipfx_features']

    # Keep the clusters that are not 'nan' and are large enough.
    raw_labels = raw['cluster_label']
    uniq_clusters, counts = np.unique(raw_labels, return_counts=True)
    uniq_clusters = uniq_clusters[(uniq_clusters != 'nan') & (counts >= min_num)]

    # Cell order: cluster by cluster, original order inside each cluster.
    subclass_ind = np.concatenate([np.flatnonzero(raw_labels == tt) for tt in uniq_clusters])
    data = {k: v[subclass_ind] for k, v in raw.items()}

    data['cluster_label'] = np.array([c.strip() for c in data['cluster_label']])

    if len(exclude_type) > 0:
        excluded = set(exclude_type)
        # dtype=int keeps the index array usable even when nothing is left.
        keep_ind = np.array(
            [i for i, c in enumerate(data['cluster_label']) if c not in excluded],
            dtype=int,
        )
        data = {k: v[keep_ind] for k, v in data.items()}

    # 1-based id following the sorted order of the remaining labels.
    _, inverse = np.unique(data['cluster_label'], return_inverse=True)
    cluster_id = (np.asarray(inverse).reshape(-1) + 1).astype(int)

    def _make_obs(modality_flag):
        # Fresh obs table per modality; only the flag column differs.
        obs_df = pd.DataFrame()
        obs_df['cluster_label'] = data['cluster_label']
        obs_df['subclass'] = data['subclass']
        obs_df['cluster_id'] = cluster_id
        obs_df['specimen_id'] = data['specimen_id']
        obs_df[modality_flag] = np.ones(len(obs_df), dtype=bool)
        return obs_df

    adata_T = ad.AnnData(X=data['logcpm'], obs=_make_obs('T'), var=gene_id)
    adata_E = ad.AnnData(X=data['pca-ipfx'], obs=_make_obs('E'), var=e_features)

    arbors = data['arbors']
    adata_M = ad.AnnData(X=arbors.reshape(arbors.shape[0], -1), obs=_make_obs('M'), uns={'origin_shape': arbors.shape})

    return adata_T, adata_E, adata_M



def patchseq_loader(
                    data: list,
                    batch_size: int,
                    label: str = '',
                    num_workers: int = 4,
                    training_size: float = 0.9,
                    world_size: int = 1,
                    use_dist_sampler: bool = False,
                    only_shared: bool = False,
                    rank: int = 0,
                    random_seed: int = 0,
                    prefetch_factor: int = 2,
                    ):
    """
    Builds train / test / full-data DataLoaders over a ``patchseq_Dataset``.

    Args:
        data (list): Positional arguments forwarded to ``patchseq_Dataset``
            (unpacked with ``*data``).
        batch_size (int): Batch size for every DataLoader.
        label (str): If non-empty, forwarded to ``patchseq_Dataset`` as ``labels``.
        num_workers (int): Requested number of worker processes; capped at the
            number of CPU cores available to this process. ``0`` loads data in
            the main process.
        training_size (float): Fraction of the data to use for training.
        world_size (int): Number of processes in the distributed setup.
        use_dist_sampler (bool): Whether to force the distributed sampler even
            when ``world_size`` is 1.
        only_shared (bool): Not used by this function; kept so existing callers
            that pass it keep working.
        rank (int): Rank of the current process.
        random_seed (int): Seed for the train/test split and for the
            distributed samplers.
        prefetch_factor (int): Number of batches loaded in advance by each
            worker. Ignored when the effective ``num_workers`` is 0.

    Returns:
        tuple: ``(train_loader, test_loader, alldata_loader)``. The train loader
        shuffles and drops the last incomplete batch; the other two do neither.
    """

    # Create the dataset
    if label == '':
        dataset = patchseq_Dataset(*data)
    else:
        dataset = patchseq_Dataset(*data, labels=label)

    # Seeded split so every process/rerun gets the same partition
    train_size = int(training_size * len(dataset))
    test_size = len(dataset) - train_size
    split_generator = torch.Generator().manual_seed(random_seed)
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size], generator=split_generator)

    # Determine the number of available CPU cores. sched_getaffinity counts the
    # CPUs this process is allowed to run on, but it only exists on some Unix
    # platforms, so fall back to os.cpu_count() elsewhere instead of crashing.
    try:
        available_cores = len(os.sched_getaffinity(0))
    except AttributeError:
        available_cores = os.cpu_count() or 1
    if num_workers > available_cores:
        print(f"Number of available CPU cores: {available_cores}")

    # Set the number of workers to the minimum of the available cores or the suggested max number
    num_workers = min(available_cores, num_workers)

    if world_size > 1 or use_dist_sampler:
        train_sampler = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True, seed=random_seed)
        test_sampler = DistributedSampler(test_dataset, rank=rank, num_replicas=world_size, shuffle=False, seed=random_seed)
        all_sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=False, seed=random_seed)
        print(f"Number of workers: {num_workers} (distributed sampler)")
    else:
        train_sampler = None
        test_sampler = None
        all_sampler = None
        print(f"Number of workers: {num_workers}")

    # Options shared by the three loaders
    loader_kwargs = {
        'batch_size': batch_size,
        'num_workers': num_workers,
        'pin_memory': True,
        'collate_fn': custom_collate_fn,
    }
    # DataLoader rejects persistent_workers / prefetch_factor when there are no
    # worker processes, so only pass them in the multiprocessing case.
    if num_workers > 0:
        loader_kwargs['persistent_workers'] = True
        loader_kwargs['prefetch_factor'] = prefetch_factor

    # Create the DataLoaders
    train_loader = DataLoader(
                            train_dataset,
                            shuffle=(train_sampler is None),  # Shuffle only if no sampler
                            drop_last=True,
                            sampler=train_sampler,
                            **loader_kwargs,
                            )

    test_loader = DataLoader(
                            test_dataset,
                            shuffle=False,
                            drop_last=False,
                            sampler=test_sampler,
                            **loader_kwargs,
                            )

    alldata_loader = DataLoader(
                                dataset,
                                shuffle=False,
                                drop_last=False,
                                sampler=all_sampler,
                                **loader_kwargs,
                                )

    return train_loader, test_loader, alldata_loader


def scLoader(
            adata: ad.AnnData,
            features: np.array,
            metadata: str = None,
            batch_size: int = 512,
            num_workers: int = 4,
            training_size: float = 0.9,
            world_size: int = 1,
            use_dist_sampler: bool = False,
            only_shared: bool = False,
            rank: int = 0,
            random_seed: int = 0,
            prefetch_factor: int = 2,
            ):
    """
    Builds train / test / full-data DataLoaders over an ``scDataset``.

    Args:
        adata (anndata.AnnData): Data object forwarded to ``scDataset``.
        features (np.array): Feature selection forwarded to ``scDataset``.
        metadata (str): Metadata argument forwarded to ``scDataset``.
        batch_size (int): Batch size for every DataLoader.
        num_workers (int): Requested number of worker processes; capped at the
            number of CPU cores available to this process. ``0`` loads data in
            the main process.
        training_size (float): Fraction of the data to use for training.
        world_size (int): Number of processes in the distributed setup.
        use_dist_sampler (bool): Whether to force the distributed sampler even
            when ``world_size`` is 1.
        only_shared (bool): Not used by this function; kept so existing callers
            that pass it keep working.
        rank (int): Rank of the current process.
        random_seed (int): Seed for the train/test split and for the
            distributed samplers.
        prefetch_factor (int): Number of batches loaded in advance by each
            worker. Ignored when the effective ``num_workers`` is 0.

    Returns:
        tuple: ``(train_loader, test_loader, alldata_loader)``. The train loader
        shuffles and drops the last incomplete batch; the other two do neither.
    """

    # Create the dataset
    dataset = scDataset(adata, features, metadata)

    # Seeded split so every process/rerun gets the same partition
    train_size = int(training_size * len(dataset))
    test_size = len(dataset) - train_size
    split_generator = torch.Generator().manual_seed(random_seed)
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size], generator=split_generator)

    # Determine the number of available CPU cores. sched_getaffinity counts the
    # CPUs this process is allowed to run on, but it only exists on some Unix
    # platforms, so fall back to os.cpu_count() elsewhere instead of crashing.
    try:
        available_cores = len(os.sched_getaffinity(0))
    except AttributeError:
        available_cores = os.cpu_count() or 1
    if num_workers > available_cores:
        print(f"Number of available CPU cores: {available_cores}")

    # Set the number of workers to the minimum of the available cores or the suggested max number
    num_workers = min(available_cores, num_workers)

    if world_size > 1 or use_dist_sampler:
        train_sampler = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True, seed=random_seed)
        test_sampler = DistributedSampler(test_dataset, rank=rank, num_replicas=world_size, shuffle=False, seed=random_seed)
        all_sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=False, seed=random_seed)
        print(f"Number of workers: {num_workers} (distributed sampler)")
    else:
        train_sampler = None
        test_sampler = None
        all_sampler = None
        print(f"Number of workers: {num_workers}")

    # Options shared by the three loaders
    loader_kwargs = {
        'batch_size': batch_size,
        'num_workers': num_workers,
        'pin_memory': True,
        'collate_fn': custom_collate_fn,
    }
    # DataLoader rejects persistent_workers / prefetch_factor when there are no
    # worker processes, so only pass them in the multiprocessing case.
    if num_workers > 0:
        loader_kwargs['persistent_workers'] = True
        loader_kwargs['prefetch_factor'] = prefetch_factor

    # Create the DataLoaders
    train_loader = DataLoader(
                            train_dataset,
                            shuffle=(train_sampler is None),  # Shuffle only if no sampler
                            drop_last=True,
                            sampler=train_sampler,
                            **loader_kwargs,
                            )

    test_loader = DataLoader(
                            test_dataset,
                            shuffle=False,
                            drop_last=False,
                            sampler=test_sampler,
                            **loader_kwargs,
                            )

    alldata_loader = DataLoader(
                                dataset,
                                shuffle=False,
                                drop_last=False,
                                sampler=all_sampler,
                                **loader_kwargs,
                                )

    return train_loader, test_loader, alldata_loader


def load_param_file(file_path):
    """
    Reads a text file line by line and decodes each non-blank line.

    Each line is stripped of surrounding whitespace; blank lines are skipped.
    A line that is valid JSON is replaced by its decoded value, any other line
    is kept as the stripped string.

    Args:
        file_path: Path of the UTF-8 text file to read.

    Returns:
        list: One entry per non-blank line, in file order.
    """
    def _decode(text):
        # Fall back to the raw text when the line is not valid JSON
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            return text

    with open(file_path, 'r', encoding='utf-8') as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [_decode(entry) for entry in stripped_lines if entry]