import numpy as np
import pandas as pd
import anndata as ad
import argparse
from tqdm import tqdm
from scipy.sparse import csr_matrix
import pdb
import time

from gmmot.utils.config_tools import get_paths


def get_data(class_label, n_gene=10000, saving=True):
    """Select informative genes from the ageing scRNA-seq dataset.

    Reads the h5ad file configured in the ``scrna`` section of
    ``pyproject.toml``. Optionally keeps only cells whose ``class_label``
    starts with *class_label*. Ranks genes by the standard deviation, across
    the ``age_cat`` time points, of their per-time-point mean expression.
    Keeps the top *n_gene* genes plus the marker genes read from the
    configured ``genes`` and ``opc_markers`` files.

    Args:
        class_label: Prefix used to select cell classes from
            ``obs['class_label']``. A falsy value keeps all cells and uses
            ``'all'`` as the tag in output file names.
        n_gene: Number of top-ranked (highest std) genes kept before the
            marker genes are added.
        saving: If True, write the per-gene std table, the reduced AnnData
            object and the selected gene list into the dataset directory.

    Returns:
        A new ``AnnData`` holding the selected cells. Its ``var`` index is
        the sorted union of top-ranked and marker gene names, and the columns
        of ``X`` are arranged in that same order.

    Raises:
        ValueError: If *class_label* is given but no cell class starts with it.
    """
    t0 = time.time()
    toml_file = 'pyproject.toml'
    dataset = 'ageing'
    age_key = 'age_cat'

    config = get_paths(toml_file=toml_file, sub_file='scrna')
    data_dir = config['paths']['data_path'] / dataset
    data_file = data_dir / config['scrna']['file_h5ad']

    print('Reading data files ...')
    adata = ad.read_h5ad(data_file)
    genes = adata.var.index.values

    print(f'Number of cells: {adata.X.shape[0]}')
    print(f'Number of genes: {adata.X.shape[1]}')

    gene_file = data_dir / config['scrna']['genes']
    opc_marker_file = data_dir / config['scrna']['opc_markers']

    marker_genes_ = pd.read_csv(gene_file, index_col=0)['x'].values
    opc_markers = pd.read_csv(opc_marker_file, index_col=0)['x'].values
    marker_genes = np.union1d(marker_genes_, opc_markers)

    # A marker gene that is not a column of the data cannot be selected.
    # Keeping it would make the var index longer than the selected columns.
    gene_set = set(genes)
    present = np.array([g in gene_set for g in marker_genes], dtype=bool)
    if not present.all():
        print(f'Warning: {np.count_nonzero(~present)} marker genes not found in the data are ignored')
        marker_genes = marker_genes[present]

    output_path = data_dir

    if class_label:
        col = 'class_label'
        unique_labels = adata.obs[col].unique()
        # str() guards against non-string labels such as NaN.
        types = [st for st in unique_labels if str(st).startswith(class_label)]
        if not types:
            raise ValueError(f'No cell class in obs[{col!r}] starts with {class_label!r}')
        adata = adata[adata.obs[col].isin(set(types)).values]
    else:
        class_label = 'all'

    # On a subsetted (view) AnnData, each access to .X re-slices the parent
    # matrix, so fetch it once and reuse it below.
    X = adata.X
    print(f'Final number of cells: {X.shape[0]}')

    data_df = adata.obs.copy()
    time_points = data_df[age_key].unique()
    age = np.array(data_df[age_key].values)

    print('Sorting genes ...')
    mean_gene_counts = []
    for t in tqdm(time_points, total=len(time_points)):
        rows = np.where(age == t)[0]
        # .mean() works for both sparse and dense X without densifying the
        # subset. asarray + ravel flattens the 1 x n_genes matrix that
        # scipy.sparse returns.
        mean_gene_counts.append(np.asarray(X[rows].mean(axis=0)).ravel())

    print(f"Time elapsed: {(time.time() - t0)/60:.2f} min")

    g_std = np.std(np.array(mean_gene_counts), axis=0)

    sorted_genes = genes[np.argsort(g_std)[::-1]]
    selected_genes = np.union1d(sorted_genes[:n_gene], marker_genes)
    print(f'Number of marker genes:{len(marker_genes)}, Number of selected genes:{len(selected_genes)}')

    if saving:
        df_gene = pd.DataFrame(g_std, index=genes, columns=['std'])
        df_gene.to_csv(output_path / f'{class_label}_sorted_genes_adult_to_aged.txt')

    # np.union1d returns names sorted, which generally differs from the column
    # order of X. Look up each selected name's column so that the data columns
    # line up with the var labels (dict lookup: O(1) per gene).
    gene_pos = {g: i for i, g in enumerate(genes)}
    index_genes = [gene_pos[g] for g in selected_genes]

    new_adata = ad.AnnData(X=X[:, index_genes], obs=data_df, var=pd.DataFrame(index=selected_genes))

    if saving:
        print('Saving the data ...')
        new_adata.write_h5ad(output_path / f'Mouse_Aging_10Xv3_{class_label}_nGene_{len(selected_genes)}.h5ad')
        df_gene = pd.DataFrame(selected_genes, columns=['genes'])
        df_gene.to_csv(output_path / f'{class_label}_ageing_genes.txt')

    return new_adata