
from functools import partial
from math import exp

import datasets
from efficiency_benchmark.dependencies.lm_eval.base import Task, rf
from packaging import version

_CITATION = 


def _squad_metric(predictions, references):
    """Score ``predictions`` against ``references`` with the ``squad_v2`` metric.

    The metric is loaded through ``datasets.load_metric`` on every call, and
    the full result dictionary of its ``compute`` method is returned.

    NOTE(review): ``datasets.load_metric`` is deprecated in recent ``datasets``
    releases in favour of the ``evaluate`` library; confirm the pinned version.
    """
    metric = datasets.load_metric("squad_v2")
    scores = metric.compute(predictions=predictions, references=references)
    return scores


def _squad_agg(key, items):
    """Reduce per-document ``(prediction, reference)`` pairs to one metric value.

    :param key: Name of the entry to read from the ``squad_v2`` result dict.
    :param items: Iterable of 2-tuples ``(prediction, reference)``.
    :returns: ``scores[key]`` from ``_squad_metric``, or 0 when that key is
        missing from the result.
    """
    # Transpose the list of pairs into two parallel sequences.
    preds, refs = zip(*items)
    scores = _squad_metric(predictions=preds, references=refs)
    return scores.get(key, 0)


class SQuAD2(Task):
    """SQuAD 2.0 question answering, scored with the ``squad_v2`` metric.

    Each document becomes a "Title / Background / Question / Answer:" prompt.
    The model's greedy continuation (stopped at a newline) is the predicted
    answer. The likelihood of the continuation ``" unanswerable"`` is passed
    to the metric as the no-answer probability.
    """

    VERSION = 1
    DATASET_PATH = "squad_v2"
    DATASET_NAME = None

    # Metric keys reported by this task, in emission order. Every key is
    # scored from the same (prediction, reference) pairs; only the entry read
    # from the squad_v2 result differs. Keeping them in one tuple keeps
    # process_results, aggregation and higher_is_better in sync.
    _METRIC_KEYS = (
        "exact",
        "f1",
        "HasAns_exact",
        "HasAns_f1",
        "NoAns_exact",
        "NoAns_f1",
        "best_exact",
        "best_f1",
    )

    # Evaluated once, when the class body runs at import time. This is an
    # explicit raise rather than `assert` so the guard is not stripped when
    # Python runs with -O.
    if version.parse(datasets.__version__) < version.parse("1.11.0"):
        raise RuntimeError("datasets v1.11.0 or later required for SQuAD")

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        """Return the ``"train"`` split of ``self.dataset``."""
        return self.dataset["train"]

    def validation_docs(self):
        """Return the ``"validation"`` split of ``self.dataset``."""
        return self.dataset["validation"]

    def doc_to_text(self, doc):
        """Build the prompt from the document's title, context and question."""
        return (
            "Title: "
            + doc["title"]
            + "\n\n"
            + "Background: "
            + doc["context"]
            + "\n\n"
            + "Question: "
            + doc["question"]
            + "\n\n"
            + "Answer:"
        )

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        """Use the passage (``doc["context"]``) as the decontamination query."""
        return doc["context"]

    def doc_to_target(self, doc):
        """Return the first gold answer, or ``"unanswerable"``, with a leading space."""
        answer_list = doc["answers"]["text"]
        if len(answer_list) > 0:
            answer = answer_list[0]
        else:
            # No gold answers: the target is the same literal whose
            # likelihood construct_requests measures.
            answer = "unanswerable"
        return " " + answer

    def construct_requests(self, doc, ctx):
        """Create the two requests needed to score one document.

        :param doc: The dataset row being evaluated.
        :param ctx: The prompt (including any few-shot examples) to condition on.
        :returns: ``(continuation, is_unanswerable)``. The first is a greedy
            generation stopped at ``"\\n"``. The second is the log-likelihood
            of ``" unanswerable"`` given ``ctx``.
        """
        continuation = rf.greedy_until(ctx, ["\n"])
        is_unanswerable = rf.loglikelihood(ctx, " " + "unanswerable")
        return continuation, is_unanswerable

    def process_results(self, doc, results):
        """Pair this document's prediction with its reference for every metric.

        :param doc: The dataset row being scored.
        :param results: Outputs of the requests from ``construct_requests``,
            presumably in the same order. The first is the generated string.
            The second is a 2-tuple whose first element is the log-likelihood
            of ``" unanswerable"`` (the second element is unused here).
        :returns: Dict mapping each metric key to a ``(prediction, reference)``
            pair in the format expected by the ``squad_v2`` metric.
        """
        continuation, (logprob_unanswerable, _) = results

        # The metric expects a probability, not a log-probability.
        no_answer_probability = exp(logprob_unanswerable)

        predictions = {
            "id": doc["id"],
            "prediction_text": continuation,
            "no_answer_probability": no_answer_probability,
        }

        references = {
            "id": doc["id"],
            "answers": doc["answers"],
        }

        pair = (predictions, references)
        return {key: pair for key in self._METRIC_KEYS}

    def aggregation(self):
        """Map each metric key to ``_squad_agg`` bound to that key.

        Each aggregator runs the full ``squad_v2`` computation over all
        collected pairs and reads its own key from the result.
        """
        return {key: partial(_squad_agg, key) for key in self._METRIC_KEYS}

    def higher_is_better(self):
        """Every reported SQuAD 2.0 metric is a score where larger is better."""
        return {key: True for key in self._METRIC_KEYS}
