import os
# Restrict which GPU(s) this process may use. This is set before the
# transformers / geneformer imports below because CUDA_VISIBLE_DEVICES is
# generally only honoured if it is set before CUDA is initialised.
GPU_NUMBER = [0]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, GPU_NUMBER))
# Verbose NCCL (multi-GPU communication) logging.
os.environ["NCCL_DEBUG"] = "INFO"
from collections import Counter
import datetime
import pickle
import subprocess
import seaborn as sns; sns.set()
from datasets import load_from_disk,concatenate_datasets
from sklearn.metrics import accuracy_score, f1_score
from transformers import BertForSequenceClassification
from transformers import Trainer
from transformers.training_args import TrainingArguments
from geneformer import DataCollatorForCellClassification
from general_utils.h5ad_to_dataset import data_preparation_geneformer
import argparse

def parse_args(argv=None):
    """Parse the command-line options for Geneformer cell-type fine-tuning.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``input_adata``, ``pretrained_model_path`` and
        ``model_save_name``.

    Raises:
        SystemExit: if a required option is missing (argparse's behaviour).
    """
    parse = argparse.ArgumentParser(description='Cell Embedding Retrieval from data')
    # Every option below is used unconditionally by the script (data
    # preparation, model loading, output directory). A missing value would
    # otherwise fail much later and obscurely (e.g. an output directory named
    # ".../None"), so require them at parse time.
    parse.add_argument('--input_adata', required=True, type=str, help='Input file path')
    parse.add_argument('--pretrained_model_path', required=True, type=str, help='Input file path for pretrained geneformer model')
    parse.add_argument('--model_save_name', required=True, type=str, help='Saved model name')
    return parse.parse_args(argv)

# Read CLI options, convert the input AnnData into an on-disk dataset and load
# it back. dataset_organ and original_index_list are kept as module-level names
# but are not referenced later in this script.
args = parse_args()
dataset_path, dataset_organ, adata = data_preparation_geneformer(args.input_adata)
original_index_list = list(adata.obs.index)
train_dataset = load_from_disk(dataset_path)
# Per-organ results, appended in lockstep by the loop that follows.
dataset_list, evalset_list, organ_list, target_dict_list = [], [], [], []

for organ in Counter(train_dataset["organ"]).keys():
    # One fine-tuning task per distinct organ, in first-seen order.
    # "bone_marrow" is not a task of its own; its cells are folded into the
    # "immune" task through organ_ids.
    if organ in ["bone_marrow"]:
        continue
    elif organ=="immune":
        organ_ids = ["immune","bone_marrow"]
        organ_list += ["immune"]
    else:
        organ_ids = [organ]
        organ_list += [organ]

    def if_organ(example):
        # Keep cells whose organ belongs to the current task.
        return example["organ"] in organ_ids

    # Restrict to this task's cells. Previously organ_ids was computed but
    # never applied, so every task (and the immune/bone_marrow merge) used the
    # whole dataset. For a single-organ dataset this filter keeps every row.
    trainset_organ = train_dataset.filter(if_organ, num_proc=16)

    # Drop rare cell types: keep only types with more than 0.5% of this task's
    # cells.
    celltype_counter = Counter(trainset_organ["cell_type"])
    total_cells = sum(celltype_counter.values())
    cells_to_keep = [k for k,v in celltype_counter.items() if v>(0.005*total_cells)]

    def if_not_rare_celltype(example):
        return example["cell_type"] in cells_to_keep

    trainset_organ_subset = trainset_organ.filter(if_not_rare_celltype, num_proc=16)
    # Fixed seed so the positional train/eval split below is reproducible.
    trainset_organ_shuffled = trainset_organ_subset.shuffle(seed=42)

    # Expose the target as a "label" column and drop the organ column.
    trainset_organ_shuffled = trainset_organ_shuffled.rename_column("cell_type","label")
    trainset_organ_shuffled = trainset_organ_shuffled.remove_columns("organ")

    # Map each cell-type name to an integer id (first-seen order).
    target_names = list(Counter(trainset_organ_shuffled["label"]).keys())
    target_name_id_dict = {name: idx for idx, name in enumerate(target_names)}
    target_dict_list += [target_name_id_dict]
    # Class count for the current task (previously this always read the first
    # task's dict via target_dict_list[0]).
    num_classes = len(target_name_id_dict)

    def classes_to_ids(example):
        example["label"] = target_name_id_dict[example["label"]]
        return example

    labeled_trainset = trainset_organ_shuffled.map(classes_to_ids, num_proc=16)
    # Train on the first 30% of the shuffled cells and evaluate on the last
    # 30%; the middle ~40% is not used.
    # NOTE(review): confirm the 30/30 split (rather than e.g. 80/20) is intended.
    n_cells = len(labeled_trainset)
    labeled_train_split = labeled_trainset.select(list(range(0, round(n_cells*0.3))))
    labeled_eval_split = labeled_trainset.select(list(range(round(n_cells*0.7), n_cells)))

    # Evaluate only on labels that appear in the training split.
    trained_labels = list(Counter(labeled_train_split["label"]).keys())

    def if_trained_label(example):
        return example["label"] in trained_labels

    labeled_eval_split_subset = labeled_eval_split.filter(if_trained_label, num_proc=16)
    dataset_list += [labeled_train_split]
    evalset_list += [labeled_eval_split_subset]

# Key the per-organ results by organ name. The four lists are appended together
# on every non-skipped loop iteration, so index i of each belongs to
# organ_list[i].
trainset_dict = {name: split for name, split in zip(organ_list, dataset_list)}
traintargetdict_dict = {name: mapping for name, mapping in zip(organ_list, target_dict_list)}
evalset_dict = {name: split for name, split in zip(organ_list, evalset_list)}

def compute_metrics(pred):
    """Score one evaluation/prediction pass for the Hugging Face Trainer.

    Args:
        pred: Trainer prediction output exposing ``label_ids`` and
            ``predictions`` (per-class scores with the class axis last).

    Returns:
        dict with overall ``accuracy`` and unweighted-mean ``macro_f1``.
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)  # highest-scoring class per example
    return {
      'accuracy': accuracy_score(y_true, y_pred),
      'macro_f1': f1_score(y_true, y_pred, average='macro'),
    }

# Fine-tuning hyperparameters consumed by the training loop below.
max_lr = 5e-5                  # learning_rate
lr_schedule_fn = "linear"      # lr_scheduler_type
warmup_steps = 500
epochs = 3                     # num_train_epochs
geneformer_batch_size = 1      # per-device train and eval batch size

# Settings not referenced by the training loop below in this script.
max_input_size = 2048          # == 2 ** 11
freeze_layers = 0
# NOTE(review): CUDA_VISIBLE_DEVICES exposes only GPU_NUMBER, which does not
# match 4 GPUs.
num_gpus = 4
num_proc = 16                  # the dataset filter/map calls above hard-code 16
optimizer = "adamw"            # not passed to TrainingArguments here

# Fine-tune one classifier per organ collected above. For each organ:
# 1. Load the pretrained Geneformer with a classification head sized to that
#    organ's label set.
# 2. Train on its train split and evaluate on its eval split.
# 3. Save the raw predictions, the prediction metrics and the model under
#    output_dir.
for organ in organ_list:
    organ_trainset = trainset_dict[organ]
    organ_evalset = evalset_dict[organ]
    organ_label_dict = traintargetdict_dict[organ]
    # Aim for ~10 log lines per epoch. Clamp to >= 1: a very small train split
    # would otherwise round to 0, which step-based logging rejects.
    logging_steps = max(1, round(len(organ_trainset)/geneformer_batch_size/10))

    output_dir = "downstream_analysis/geneformer/{}".format(args.model_save_name)
    if len(organ_list) > 1:
        # All organs share one --model_save_name. Without a per-organ
        # subdirectory, the second organ would hit the "already saved" check
        # (or overwrite the first organ's model).
        output_dir = os.path.join(output_dir, organ)

    # Refuse to overwrite a finished run, and fail before the expensive model
    # load. Trainer.save_model writes pytorch_model.bin or model.safetensors
    # depending on the transformers version / save_safetensors setting.
    for weights_name in ("pytorch_model.bin", "model.safetensors"):
        if os.path.isfile(os.path.join(output_dir, weights_name)):
            raise FileExistsError("Model already saved to this directory.")

    # Create the directory (and any missing parents) without a shell: the
    # path embeds the user-supplied --model_save_name.
    os.makedirs(output_dir, exist_ok=True)

    model = BertForSequenceClassification.from_pretrained(
        args.pretrained_model_path,
        num_labels=len(organ_label_dict),
        output_attentions=False,
        output_hidden_states=False,
        # Allow the classification head to be initialised for num_labels even
        # if the checkpoint carries a head of a different shape.
        ignore_mismatched_sizes=True,
    ).to("cuda")

    training_args = {
        "learning_rate": max_lr,
        "do_train": True,
        "do_eval": True,
        # NOTE(review): newer transformers releases name this key
        # "eval_strategy"; confirm against the pinned version.
        "evaluation_strategy": "epoch",
        # Must match the evaluation strategy for load_best_model_at_end.
        "save_strategy": "epoch",
        "logging_steps": logging_steps,
        "group_by_length": True,
        "length_column_name": "length",
        "disable_tqdm": False,
        "lr_scheduler_type": lr_schedule_fn,
        "warmup_steps": warmup_steps,
        "weight_decay": 0.001,
        "per_device_train_batch_size": geneformer_batch_size,
        "per_device_eval_batch_size": geneformer_batch_size,
        "num_train_epochs": epochs,
        "load_best_model_at_end": True,
        "output_dir": output_dir,
    }

    training_args_init = TrainingArguments(**training_args)

    trainer = Trainer(
        model=model,
        args=training_args_init,
        data_collator=DataCollatorForCellClassification(),
        train_dataset=organ_trainset,
        eval_dataset=organ_evalset,
        compute_metrics=compute_metrics
    )

    trainer.train()
    predictions = trainer.predict(organ_evalset)

    # Store the predictions inside output_dir. The previous f-string lacked a
    # path separator and wrote "<name>predictions.pickle" next to the
    # directory instead.
    with open(os.path.join(output_dir, "predictions.pickle"), "wb") as fp:
        pickle.dump(predictions, fp)

    trainer.save_metrics("eval",predictions.metrics)
    trainer.save_model(output_dir)