import os
# GPU indices exported to the process environment. The variables are set
# here, ahead of the transformers/geneformer imports that follow —
# presumably so CUDA only exposes these devices once it initialises
# (NOTE(review): confirm against the deployment setup).
GPU_NUMBER = [0]
os.environ.update(
    CUDA_VISIBLE_DEVICES=",".join(map(str, GPU_NUMBER)),
    NCCL_DEBUG="INFO",
)
from collections import Counter
import datetime
import pickle
import subprocess
import seaborn as sns; sns.set()
from datasets import load_from_disk,concatenate_datasets
from sklearn.metrics import accuracy_score, f1_score
from transformers import BertForSequenceClassification
from transformers import Trainer
from transformers.training_args import TrainingArguments
from geneformer import DataCollatorForCellClassification
from general_utils.h5ad_to_dataset import data_preparation_geneformer,extract_emb_geneformer
import argparse
from pathlib import Path


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the embedding-extraction script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            the original behaviour; passing a list allows the parser to be
            driven programmatically (e.g. from tests).

    Returns:
        A namespace with the attributes ``input_adata``, ``output_dir`` and
        ``model_save_dir``. Each one is a string, or ``None`` when the
        corresponding flag was not supplied.
    """
    parser = argparse.ArgumentParser(description='Cell Embedding Retrieval from data')
    parser.add_argument('--input_adata', default=None, type=str, help='Input file path')
    parser.add_argument('--output_dir', default=None, type=str, help='Output file directory')
    parser.add_argument('--model_save_dir', default=None, type=str, help='Saved model directory')

    return parser.parse_args(argv)

# Script body: read the CLI options, build the dataset from the input file,
# then extract embeddings with the model stored in --model_save_dir.
args = parse_args()
dataset_path, dataset_organ, adata = data_preparation_geneformer(args.input_adata)
# Keep the original observation index of the AnnData object; it is handed to
# the extraction helper below.
original_index_list = list(adata.obs.index)
# Loaded but not used further in this script; kept because the call also
# fails early if nothing was written to dataset_path.
train_dataset = load_from_disk(dataset_path)
path = args.model_save_dir
# --model_save_dir defaults to None and Path(None) raises TypeError, so only
# build the Path when a value was actually given. This lets a missing flag
# reach the explanatory message below instead of crashing.
obj = Path(path) if path is not None else None

if obj is not None and obj.exists():
    # NOTE(review): the literal 8 is passed positionally; presumably a batch
    # size — confirm against extract_emb_geneformer's signature.
    extract_emb_geneformer(
        original_index_list,
        8,
        "{}".format(path),
        dataset_path,
        dataset_organ,
        args.output_dir,
        adata,
    )
else:
    print("Please enter a correct path for a pretrained model.")
