import torch
from accelerate.utils import set_seed
from datasets import load_dataset
from fire import Fire
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    T5TokenizerFast
)

from mcsmoe.data import (
    Seq2SeqDataPreProcessor,
    tokenize_seq2seq,
    TASK_MAPPING_DATASET_ARGUMENTS,
    DataCollatorForSeq2Seq,
    EXTRA_KEYS_FOR_EVAL,
    get_evaluate_fn,
)
from mcsmoe.merging.utils import load_merged_switch_transformers_from_checkpoint

# Fix the random seed once, at import time, using accelerate's `set_seed`
# helper (presumably so repeated evaluation runs are comparable).
set_seed(42)


def evaluate_merged_switch_downstream(
        checkpoint: str = None,
        task: str = None
) -> None:
    """Evaluate a merged Switch Transformers checkpoint on one downstream task.

    The function loads the tokenizer and the merged model, preprocesses and
    tokenizes the task's ``"validation"`` split, runs a single forward pass
    per batch (predictions are the per-position argmax of the logits; no
    ``generate`` call is made), and hands predictions, padded labels and
    optional example ids to the task's evaluation function. The parameter
    count and the evaluation result are printed; nothing is returned.

    Args:
        checkpoint: Passed to ``T5TokenizerFast.from_pretrained`` and to
            ``load_merged_switch_transformers_from_checkpoint``. If the
            tokenizer cannot be loaded from it (``OSError``), the
            ``"google/switch-base-32"`` tokenizer is used instead.
        task: Key into ``TASK_MAPPING_DATASET_ARGUMENTS``; also forwarded to
            ``Seq2SeqDataPreProcessor`` and ``get_evaluate_fn``.

    Raises:
        KeyError: If ``task`` is not a key of
            ``TASK_MAPPING_DATASET_ARGUMENTS``.
    """
    # Resolve the task before anything else so that a missing or misspelled
    # task name fails immediately instead of after the model has been loaded.
    try:
        dataset_arguments = TASK_MAPPING_DATASET_ARGUMENTS[task]
    except KeyError:
        raise KeyError(
            f"Unknown task {task!r}: it is not a key of TASK_MAPPING_DATASET_ARGUMENTS."
        ) from None

    try:
        tokenizer = T5TokenizerFast.from_pretrained(checkpoint)
    except OSError:
        # The checkpoint does not provide loadable tokenizer files; fall back
        # to the base Switch tokenizer.
        tokenizer = T5TokenizerFast.from_pretrained("google/switch-base-32")
    model = load_merged_switch_transformers_from_checkpoint(
        checkpoint=checkpoint,
        sanity_check=True
    )

    raw_dataset = load_dataset(*dataset_arguments)
    raw_eval_dataset = raw_dataset["validation"]
    eval_dataset = raw_eval_dataset.map(
        Seq2SeqDataPreProcessor(benchmark=task, keep_specific_keys=EXTRA_KEYS_FOR_EVAL),
        batched=True,
        num_proc=6,
        remove_columns=raw_eval_dataset.column_names
    )
    tokenized_eval_dataset = eval_dataset.map(
        lambda x: tokenize_seq2seq(tokenizer=tokenizer, batch=x, keep_other_keys=True),
        num_proc=6,
        batched=True,
        remove_columns=eval_dataset.column_names,
        load_from_cache_file=False
    )
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer,
                                           max_length=tokenizer.model_max_length,
                                           return_tensors='pt',
                                           keys_to_ignore=EXTRA_KEYS_FOR_EVAL)
    eval_dataloader = DataLoader(
        tokenized_eval_dataset,
        shuffle=False,
        collate_fn=data_collator,
        batch_size=64,
        num_workers=8
    )

    evaluate_fn = get_evaluate_fn(
        task=task,
        tokenizer=tokenizer,
        raw_eval_dataset=raw_eval_dataset
    )

    # Use the GPU when one is present, otherwise run on CPU instead of
    # crashing on a hard-coded "cuda" device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    model.to(device)

    # Tasks whose evaluation needs per-example ids, mapped to the batch key
    # that carries them. Resolved once instead of on every batch.
    id_key = {"squad": "id", "copa": "idx", "multirc": "idx"}.get(task)
    output_ids = [] if id_key is not None else None
    output_predictions = []
    padded_label_batches = []
    target_width = tokenizer.model_max_length

    for eval_batch in tqdm(eval_dataloader, desc="Evaluating"):
        # Pull out the bookkeeping keys; they are not fed to the model.
        extra_keys_eval_batch = {
            key: eval_batch.pop(key)
            for key in list(eval_batch.keys())
            if key in EXTRA_KEYS_FOR_EVAL
        }
        eval_batch = {k: v.to(device) for k, v in eval_batch.items()}
        with torch.no_grad():
            outputs = model(**eval_batch)

        # Right-pad the labels with -100 up to tokenizer.model_max_length so
        # that batches of different widths can be joined into one tensor.
        # NOTE(review): -100 is presumably the ignored-label id expected by
        # evaluate_fn - confirm against its implementation.
        eval_labels = eval_batch['labels']
        label_padding = eval_labels.new_full(
            (eval_labels.shape[0], target_width - eval_labels.shape[1]), -100
        )
        padded_label_batches.append(torch.cat([eval_labels, label_padding], dim=-1))

        output_predictions += outputs.logits.argmax(dim=-1).tolist()
        if id_key is not None:
            output_ids += extra_keys_eval_batch[id_key]

    # One (num_examples, model_max_length) tensor, identical to stacking the
    # individual padded rows.
    output_labels = torch.cat(padded_label_batches, dim=0)
    eval_res = evaluate_fn(predictions=output_predictions, labels=output_labels, ids=output_ids)

    print("[Evaluation]Model number of parameters: ", model.num_parameters())
    print("[Evaluation]Result", eval_res)


# Script entry point: python-fire exposes the function's arguments
# (`checkpoint`, `task`) as command-line arguments.
if __name__ == '__main__':
    Fire(evaluate_merged_switch_downstream)
