import resource
from typing import Optional

import torch
from accelerate.utils import set_seed
from datasets import load_dataset
from evaluate import load
from fire import Fire
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
)

from mcsmoe.data import (
    CasualZeroShotDataPreProcessor,
    tokenize_casual_zero_shot,
    TASK_MAPPING_DATASET_ARGUMENTS,
    DataCollatorForLanguageModeling,
    build_index_for_dataset,
    gather_predictions_references_by_casual_lm_loss
)
from mcsmoe.merging.utils import load_merged_fsgpt_moe_from_checkpoint

# Seed the random number generators (via accelerate's ``set_seed``) with a fixed value.
set_seed(42)

# Make sure the soft limit on open file descriptors is at least 2048 when the
# hard limit allows it (presumably needed by the multi-process dataset mapping
# and DataLoader workers used below -- TODO confirm).
# The soft limit is only ever raised: an already-higher limit is left alone,
# and the target is clamped to the hard limit, because asking for a soft limit
# above the hard limit makes ``setrlimit`` raise ``ValueError``.
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
_soft_nofile, _hard_nofile = rlimit
if _hard_nofile == resource.RLIM_INFINITY:
    _target_nofile = 2048
else:
    _target_nofile = min(2048, _hard_nofile)
# RLIM_INFINITY can be represented as a negative number, so it has to be
# excluded explicitly before the numeric comparison.
if _soft_nofile != resource.RLIM_INFINITY and _soft_nofile < _target_nofile:
    resource.setrlimit(resource.RLIMIT_NOFILE, (_target_nofile, _hard_nofile))


def evaluate_merged_downstream_zero_shot(
        task: str,
        checkpoint: Optional[str] = None,
        eval_batch_size: int = 32,
):
    """Evaluate a merged FSGPT-MoE checkpoint zero-shot on one downstream task.

    The task's validation split is pre-processed and tokenized, every tokenized
    row is scored with the model's language-modeling cross-entropy loss, and the
    per-row losses are handed to ``gather_predictions_references_by_casual_lm_loss``
    to obtain predictions and references. The metric result is printed; nothing
    is returned.

    Args:
        task: Key into ``TASK_MAPPING_DATASET_ARGUMENTS`` selecting the dataset
            (and, when available, the metric) to load.
        checkpoint: Path or identifier passed both to the tokenizer loader and
            to ``load_merged_fsgpt_moe_from_checkpoint``.
        eval_batch_size: Batch size of the evaluation ``DataLoader``.

    Raises:
        KeyError: If ``task`` is not a key of ``TASK_MAPPING_DATASET_ARGUMENTS``.
        ValueError: For ``multirc``/``hotpotqa``, if the number of distinct
            example ids does not match the number of predictions.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    except OSError:
        # The checkpoint directory may not ship tokenizer files; fall back to
        # the default tokenizer location.
        print(f"[Evaluation warning] No tokenizer found in checkpoint {checkpoint}, "
              f"using hf-private-path instead.")
        tokenizer = AutoTokenizer.from_pretrained("hf-private-path")
    # The bookkeeping columns are excluded from collation as model inputs; they
    # are popped from every batch in the evaluation loop below.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
        keys_to_ignore=["answer_idx", "choice_idx", "idx"]
    )
    model = load_merged_fsgpt_moe_from_checkpoint(
        checkpoint=checkpoint,
        sanity_check=False
    )
    print(f"Model number of parameters: {model.num_parameters()}")

    raw_dataset = load_dataset(*TASK_MAPPING_DATASET_ARGUMENTS[task])
    # MNLI has no plain "validation" split; the matched validation split is used.
    eval_dataset = raw_dataset["validation"] if task != "mnli" else raw_dataset["validation_matched"]

    eval_dataset = build_index_for_dataset(eval_dataset)
    eval_dataset = eval_dataset.map(
        CasualZeroShotDataPreProcessor(benchmark=task),
        batched=True,
        num_proc=8,
        remove_columns=eval_dataset.column_names,
        load_from_cache_file=False
    )
    tokenized_eval_dataset = eval_dataset.map(
        lambda x: tokenize_casual_zero_shot(tokenizer=tokenizer, batch=x),
        num_proc=8,
        batched=True,
        remove_columns=eval_dataset.column_names,
        load_from_cache_file=False
    )
    eval_dataloader = DataLoader(
        tokenized_eval_dataset,
        shuffle=False,
        batch_size=eval_batch_size,
        num_workers=4,
        collate_fn=data_collator
    )

    model.eval()
    ids_list = []
    answer_ids_list = []
    choice_ids_list = []
    losses_list = []
    try:
        metric = load(*TASK_MAPPING_DATASET_ARGUMENTS[task])
    except FileNotFoundError:
        print(f"[Evaluation warning] No metric found for task {task}, using accuracy instead.")
        metric = load("accuracy")

    # The loss module is stateless, so build it once instead of once per batch.
    loss_fct = torch.nn.CrossEntropyLoss()
    for eval_batch in tqdm(eval_dataloader, total=len(eval_dataloader), desc="Evaluating"):
        # Strip the bookkeeping entries before the batch is moved to the GPU
        # and unpacked into the model call.
        ids_list += eval_batch.pop("idx")
        answer_ids_list += eval_batch.pop("answer_idx")
        choice_ids_list += eval_batch.pop("choice_idx")
        eval_batch = {k: v.cuda() for k, v in eval_batch.items()}
        with torch.no_grad():
            with torch.autocast("cuda"):
                logits = model(**eval_batch).logits
            # Next-token targets: position t is scored against label t + 1, and
            # the last position (which has no successor) gets the pad token id.
            labels = eval_batch["labels"]
            shift_labels = labels.new_zeros(labels.shape)
            shift_labels[:, :-1] = labels[:, 1:].clone()
            shift_labels[:, -1] = model.config.pad_token_id
            shift_labels = shift_labels.to(logits.device)
            # One loss value per row, so rows can be compared against each
            # other by the gather helper.
            losses = torch.stack([
                loss_fct(row_logits, row_labels)
                for row_logits, row_labels in zip(logits, shift_labels)
            ])
            losses_list += losses.tolist()

    predictions_references = gather_predictions_references_by_casual_lm_loss(
        ids_list=ids_list,
        answer_ids_list=answer_ids_list,
        choice_ids_list=choice_ids_list,
        losses_list=losses_list,
    )
    predictions = predictions_references["predictions"]
    references = predictions_references["references"]

    if task in ("multirc", "hotpotqa"):
        # These metrics need an id attached to every prediction. The ids are
        # recovered from the per-row ids collected above (unique values, in
        # first-seen order).
        # NOTE(review): this assumes the gather helper emits exactly one
        # prediction per distinct id, in that same order -- confirm against
        # its implementation. It also assumes the ids are hashable.
        plain_ids = (
            i.item() if isinstance(i, torch.Tensor) else i for i in ids_list
        )
        output_ids = list(dict.fromkeys(plain_ids))
        if len(output_ids) != len(predictions):
            raise ValueError(
                f"Cannot attach example ids to predictions for task {task}: "
                f"found {len(output_ids)} distinct ids but {len(predictions)} predictions."
            )
        if task == "multirc":
            predictions = [
                {'prediction': p, 'idx': example_id}
                for p, example_id in zip(predictions, output_ids)
            ]
        else:
            predictions = [
                {'prediction_text': p, 'id': example_id}
                for p, example_id in zip(predictions, output_ids)
            ]
            references = [
                # answer_start is not used in the evaluation, so fake it
                {'answers': {'text': [r], 'answer_start': [2333]}, 'id': example_id}
                for r, example_id in zip(references, output_ids)
            ]

    eval_res = metric.compute(predictions=predictions, references=references)
    print(f"{task} evaluation result: {eval_res}")


# Script entry point: python-fire exposes the evaluation function's parameters
# as command-line arguments.
if __name__ == "__main__":
    Fire(evaluate_merged_downstream_zero_shot)
