
import inspect
import warnings

import efficiency_benchmark.dependencies.lm_eval.datasets.headqa.headqa
from efficiency_benchmark.dependencies.lm_eval.base import MultipleChoiceTask

# BibTeX entry to cite for this task (the original assignment had no value,
# which is a SyntaxError that prevents the module from importing at all).
_CITATION = """
@misc{liu2020interpretable,
      title={Interpretable Multi-Step Reasoning with Knowledge Extraction on Complex Healthcare Question Answering},
      author={Ye Liu and Shaika Chowdhury and Chenwei Zhang and Cornelia Caragea and Philip S. Yu},
      year={2020},
      eprint={2008.02434},
      archivePrefix={arXiv},
      primaryClass={cs.AI}
}
"""


class HeadQABase(MultipleChoiceTask):
    """Shared HEAD-QA multiple-choice task; subclasses pick the language config.

    The dataset is loaded through the bundled ``headqa`` loading script (its
    file path is used as ``DATASET_PATH``); concrete subclasses select the
    config by setting ``DATASET_NAME``.
    """

    VERSION = 0
    # Point the dataset loader at the local HEAD-QA loading script file.
    DATASET_PATH = inspect.getfile(efficiency_benchmark.dependencies.lm_eval.datasets.headqa.headqa)

    def has_training_docs(self):
        """Report that a ``train`` split is available."""
        return True

    def has_validation_docs(self):
        """Report that a ``validation`` split is available."""
        return True

    def has_test_docs(self):
        """Report that a ``test`` split is available."""
        return True

    def training_docs(self):
        """Return the processed ``train`` split as a list, cached after the first call."""
        # NOTE(review): ``_training_docs`` is presumably initialised to None by
        # the base Task constructor — confirm against MultipleChoiceTask/Task.
        if self._training_docs is None:
            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
        return self._training_docs

    def validation_docs(self):
        """Return a lazy iterator of processed ``validation`` documents."""
        return map(self._process_doc, self.dataset["validation"])

    def test_docs(self):
        """Return a lazy iterator of processed ``test`` documents."""
        return map(self._process_doc, self.dataset["test"])

    def _process_doc(self, doc):
        """Convert a raw HEAD-QA record into the multiple-choice document format.

        Args:
            doc: raw record providing ``qid``, ``qtext``, ``ra`` (the id of the
                right answer) and ``answers``, a list of dicts each holding
                ``atext`` and — presumably — an ``aid`` id.

        Returns:
            dict with ``id``, ``query`` (prompt text ending in ``"Answer:"``),
            ``choices`` (answer texts in the record's order) and ``gold`` (the
            index of the correct entry in ``choices``).
        """
        answers = doc["answers"]
        right_aid = int(doc["ra"])
        # ``ra`` identifies the correct answer by its ``aid``. Resolve it by id
        # rather than assuming the answers are stored in aid order starting at
        # 1; fall back to that positional convention when no aid matches.
        gold = next(
            (
                index
                for index, answer in enumerate(answers)
                if answer.get("aid") is not None and int(answer["aid"]) == right_aid
            ),
            right_aid - 1,
        )
        return {
            "id": doc["qid"],
            "query": "Question: " + doc["qtext"] + "\nAnswer:",
            "choices": [answer["atext"] for answer in answers],
            "gold": gold,
        }

    def doc_to_text(self, doc):
        """Return the prompt text for a processed document."""
        return doc["query"]

    def should_decontaminate(self):
        """Opt this task into decontamination checks."""
        return True

    def doc_to_decontamination_query(self, doc):
        """Return the text used for decontamination (the same prompt as ``doc_to_text``)."""
        return doc["query"]


class HeadQAEn(HeadQABase):
    """HEAD-QA task evaluated on the English (``"en"``) dataset config."""

    DATASET_NAME = "en"


class HeadQAEs(HeadQABase):
    """HEAD-QA task evaluated on the Spanish (``"es"``) dataset config."""

    DATASET_NAME = "es"



class HeadQAEsDeprecated(HeadQABase):
    """Legacy alias for the Spanish HEAD-QA task, kept for backward compatibility.

    Instantiating it emits a notice directing users to ``headqa_es`` or
    ``headqa_en`` instead.
    """

    DATASET_NAME = "es"

    def __init__(self, *args, **kwargs):
        # Forward any constructor options so this alias accepts whatever the
        # base task accepts, instead of rejecting every argument.
        super().__init__(*args, **kwargs)
        # FutureWarning is displayed by default (DeprecationWarning is hidden
        # outside __main__), so end users running evaluations still see it;
        # stacklevel=2 attributes the warning to the code constructing the task.
        warnings.warn(
            "WARNING: headqa is deprecated. Please use headqa_es or headqa_en instead. See https://github.com/EleutherAI/lm-evaluation-harness/pull/240 for more info.",
            FutureWarning,
            stacklevel=2,
        )
