from typing import Optional

import torch
from accelerate.utils import set_seed
from datasets import load_dataset
from fire import Fire
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    T5TokenizerFast,
    SwitchTransformersForConditionalGeneration as HFSwitch,
    T5ForConditionalGeneration
)

from mcsmoe.data import (
    Seq2SeqDataPreProcessor,
    tokenize_seq2seq,
    TASK_MAPPING_DATASET_ARGUMENTS,
    DataCollatorForSeq2Seq,
    EXTRA_KEYS_FOR_EVAL,
    get_evaluate_fn,
    keep_only_supporting_facts_in_context_for_hotpotqa
)

# Seed every RNG (Python, NumPy, torch) once at import time so evaluation runs are reproducible.
set_seed(seed=42)


def evaluate_downstream(
        checkpoint: str = None,
        task: str = None,
        debug: Optional[bool] = False
):
    """
    Evaluate a seq2seq checkpoint on a downstream task and print the metrics.

    The model is run with teacher forcing (labels are passed to the forward
    call) and the per-position argmax of the logits is used as the prediction;
    predictions, labels and (for some tasks) example ids are then scored by the
    task-specific function returned by ``get_evaluate_fn``.

    Parameters
    ----------
    checkpoint: str
        Path (or hub name) of the checkpoint to evaluate. If no tokenizer can be
        loaded from it, the ``google/switch-base-32`` tokenizer is used instead.
        A checkpoint whose name contains ``"t5"`` is loaded as a dense T5 model,
        anything else as a Switch Transformers model.
    task: str
        Downstream task name, lowercase; must be a key of
        ``TASK_MAPPING_DATASET_ARGUMENTS``. For ``"mnli"`` the matched
        validation split is used.
    debug: bool
        If True, evaluate on (at most) the first 1000 training examples instead
        of the validation split.

    Raises
    ------
    ValueError
        If ``checkpoint`` is missing or ``task`` is not a known task.
    """
    if checkpoint is None:
        raise ValueError("`checkpoint` must be provided")
    if task not in TASK_MAPPING_DATASET_ARGUMENTS:
        raise ValueError(
            f"Unknown task {task!r}; expected one of {list(TASK_MAPPING_DATASET_ARGUMENTS)}"
        )

    try:
        tokenizer = T5TokenizerFast.from_pretrained(checkpoint)
    except OSError:
        # Checkpoints saved without tokenizer files fall back to the base Switch tokenizer.
        tokenizer = T5TokenizerFast.from_pretrained("google/switch-base-32")
    model_cls = T5ForConditionalGeneration if "t5" in checkpoint else HFSwitch
    model = model_cls.from_pretrained(checkpoint)

    # Pick the raw split to evaluate. The same split is handed to the evaluate
    # function, so id-based metrics look up references in the data actually evaluated.
    raw_dataset = load_dataset(*TASK_MAPPING_DATASET_ARGUMENTS[task])
    if debug:
        train_split = raw_dataset["train"]
        raw_eval_split = train_split.select(range(min(1000, len(train_split))))
    elif task == "mnli":
        # MNLI has no "validation" split; use the matched one.
        raw_eval_split = raw_dataset["validation_matched"]
    else:
        raw_eval_split = raw_dataset["validation"]

    eval_dataset = raw_eval_split
    if task == "hotpotqa":
        # This preprocessing step is specific to HotpotQA contexts.
        eval_dataset = eval_dataset.map(
            keep_only_supporting_facts_in_context_for_hotpotqa,
            batched=False,
            num_proc=6
        )
    eval_dataset = eval_dataset.map(
        Seq2SeqDataPreProcessor(benchmark=task, keep_specific_keys=EXTRA_KEYS_FOR_EVAL),
        batched=True,
        num_proc=6,
        remove_columns=eval_dataset.column_names
    )
    tokenized_eval_dataset = eval_dataset.map(
        lambda x: tokenize_seq2seq(tokenizer=tokenizer, batch=x, keep_other_keys=True),
        num_proc=6,
        batched=True,
        remove_columns=eval_dataset.column_names,
        load_from_cache_file=False
    )
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer,
                                           max_length=tokenizer.model_max_length,
                                           return_tensors='pt',
                                           keys_to_ignore=EXTRA_KEYS_FOR_EVAL)
    eval_dataloader = DataLoader(
        tokenized_eval_dataset,
        shuffle=False,
        collate_fn=data_collator,
        batch_size=64,
        num_workers=8
    )

    evaluate_fn = get_evaluate_fn(
        task=task,
        tokenizer=tokenizer,
        raw_eval_dataset=raw_eval_split
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.eval()
    model.to(device)

    # Tasks whose metric needs per-example ids, mapped to the batch key holding them.
    id_key = {
        "squad": "id",
        "squad_v2": "id",
        "hotpotqa": "id",
        "copa": "idx",
        "multirc": "idx",
    }.get(task)
    output_ids = [] if id_key is not None else None
    output_labels = []
    output_predictions = []
    for eval_batch in tqdm(eval_dataloader, total=len(eval_dataloader), desc="Evaluating"):
        # Strip bookkeeping keys so only model inputs are forwarded to the model.
        extra_keys_eval_batch = {}
        for key in list(eval_batch.keys()):
            if key in EXTRA_KEYS_FOR_EVAL:
                extra_keys_eval_batch[key] = eval_batch.pop(key)
        eval_batch = {k: v.to(device) for k, v in eval_batch.items()}
        with torch.no_grad():
            outputs = model(**eval_batch)
        eval_labels = eval_batch['labels']
        # Right-pad labels with -100 up to model_max_length so rows from batches
        # with different label lengths can be stacked into one tensor at the end.
        label_padding = torch.full(
            (eval_labels.shape[0], tokenizer.model_max_length - eval_labels.shape[1]),
            -100,
            dtype=eval_labels.dtype,
            device=eval_labels.device
        )
        output_labels.extend(torch.cat([eval_labels, label_padding], dim=-1))
        output_predictions.extend(outputs.logits.argmax(dim=-1).tolist())
        if id_key is not None:
            output_ids.extend(extra_keys_eval_batch[id_key])
    output_labels = torch.stack(output_labels, dim=0)
    eval_res = evaluate_fn(predictions=output_predictions, labels=output_labels, ids=output_ids)
    print(eval_res)


if __name__ == '__main__':
    # Expose evaluate_downstream's keyword arguments as command-line flags,
    # e.g. `--checkpoint=<path> --task=<task> --debug`.
    Fire(component=evaluate_downstream)
