from geneformer import TranscriptomeTokenizer
import os
import scanpy as sc
from collections import Counter
from geneformer import EmbExtractor
import pandas as pd
import numpy as np

def data_preparation_geneformer(input_adata):
    """Tokenize an h5ad file for Geneformer, once per organ.

    The file is read with scanpy, its ``obs`` index is copied into an
    ``index`` column, and the organ name is taken from the first entry of
    ``obs['tissue_general']``. If ``downstream_analysis/<organ>`` does not
    exist yet, it is created, the annotated AnnData is written into it and
    the directory is tokenized with ``TranscriptomeTokenizer``. If the
    directory already exists, tokenization is skipped and a message is
    printed.

    Args:
        input_adata: Path to the input ``.h5ad`` file.

    Returns:
        A ``(dataset_path, organ, adata)`` tuple: the string
        ``"model/<organ>/<organ>.dataset"``, the organ name, and the loaded
        AnnData object (with the added ``obs['index']`` column).

    Raises:
        IndexError: If ``obs['tissue_general']`` is empty.
        OSError: If the working directory cannot be created for a reason
            other than already existing, or if writing the h5ad file fails.
    """
    file_name = os.path.basename(input_adata)
    adata = sc.read(input_adata)
    # Copy the obs index into a regular column so the tokenizer can carry it
    # over as a custom attribute (see the "index" mapping below).
    adata.obs['index'] = list(adata.obs.index)
    # Only the first observed value is used as the organ name.
    organ = adata.obs['tissue_general'].iloc[0]
    work_dir = "downstream_analysis/{}".format(organ)

    try:
        # The existence of the directory is the "already processed" marker.
        # makedirs also creates a missing parent directory, which previously
        # surfaced as a misleading "already processed" message.
        os.makedirs(work_dir)
    except FileExistsError:
        print("The h5ad file was already processed")
    else:
        # Kept outside the try body so that genuine write/tokenization
        # failures propagate instead of being reported as "already processed".
        tk = TranscriptomeTokenizer(
            {
                "cell_type": "cell_type",
                "tissue_general": "organ",
                "donor_id": "donor_id",
                "index": "index",
            },
            nproc=16,
        )
        adata.write("{}/{}".format(work_dir, file_name))
        tk.tokenize_data(
            work_dir,
            work_dir,
            "{}".format(organ),
            file_format="h5ad",
        )

    # NOTE(review): tokenization output is directed to
    # downstream_analysis/<organ>, but the returned path points under
    # model/<organ> -- confirm against the caller that this is intended.
    return "model/{}/{}.dataset".format(organ, organ), organ, adata

def extract_emb_geneformer(original_index_list,num_classes,output_dir,dataset_path,dataset_organ,output_adata,adata):
    """Extract Geneformer cell embeddings and save them in a given row order.

    Runs ``EmbExtractor.extract_embs`` on ``dataset_path``, reorders the
    resulting rows so that their ``index`` label follows
    ``original_index_list``, strips the label columns and writes the
    remaining values to ``<output_adata>/embeds_geneformer.npy``.

    Args:
        original_index_list: Cell identifiers in the desired output order;
            matched against the ``index`` label column of the embeddings.
        num_classes: Unused by this function; kept for interface
            compatibility.
        output_dir: Passed to ``extract_embs`` as both its first and its
            third positional argument.
        dataset_path: Passed to ``extract_embs`` as its second positional
            argument (the tokenized dataset).
        dataset_organ: Passed to ``extract_embs`` as its fourth positional
            argument.
        output_adata: Directory in which ``embeds_geneformer.npy`` is saved.
        adata: Unused by this function; kept for interface compatibility.

    Raises:
        ValueError: If the number of extracted embeddings differs from
            ``len(original_index_list)``, or if an extracted cell's
            identifier is not present in ``original_index_list``.
    """
    label_columns = ["cell_type", "donor_id", "index"]

    embex = EmbExtractor(
        max_ncells=20000,
        emb_layer=0,
        emb_label=list(label_columns),
        labels_to_plot=["cell_type", "donor_id"],
        forward_batch_size=10,
        nproc=2,
    )
    embs = embex.extract_embs(
        output_dir,
        dataset_path,
        "{}".format(output_dir),
        "{}".format(dataset_organ),
    )
    embs = pd.DataFrame(embs)

    # Fail early with a clear message: the extractor is capped at max_ncells,
    # so it may return fewer rows than there are requested identifiers.
    expected_rows = len(original_index_list)
    if len(embs) != expected_rows:
        raise ValueError(
            "Extracted {} embeddings but original_index_list has {} entries".format(
                len(embs), expected_rows
            )
        )

    # An ordered categorical makes sort_values follow original_index_list
    # instead of lexicographic order.
    embs['index'] = embs['index'].astype("category")
    embs['index'] = embs['index'].cat.set_categories(original_index_list)
    # Identifiers absent from original_index_list become NaN here and would
    # silently misalign the saved rows.
    if embs['index'].isna().any():
        raise ValueError(
            "Some extracted cells have an 'index' label that is not in original_index_list"
        )

    embs = embs.sort_values(["index"])
    embs.index = original_index_list
    embs.columns = [str(i) for i in embs.columns]

    # Drop the label columns by name rather than by position so the result
    # does not depend on how many labels were requested or where they sit.
    npy_output = embs.drop(columns=label_columns).to_numpy()
    np.save(os.path.join(output_adata, "embeds_geneformer.npy"), npy_output)
