"""
HellaSwag: Can a Machine Really Finish Your Sentence?
XXXX

Hellaswag is a commonsense inference challenge dataset. Though its questions are
trivial for humans (>95% accuracy), state-of-the-art models struggle (<48%). This is
achieved via Adversarial Filtering (AF), a data collection paradigm wherein a
series of discriminators iteratively select an adversarial set of machine-generated
wrong answers. AF proves to be surprisingly robust. The key insight is to scale up
the length and complexity of the dataset examples towards a critical 'Goldilocks'
zone wherein generated text is ridiculous to humans, yet often misclassified by
state-of-the-art models.

Homepage: XXXX
"""

import re

from oe_eval.tasks.base_task import MultipleChoiceTask
from oe_eval.tasks.utils import make_mcq_prompt

_CITATION = """
@inproceedings{zellers2019hellaswag,
    title={HellaSwag: Can a Machine Really Finish Your Sentence?},
    author={Zellers, Rowan and Holtzman, Ari and Bisk, Yonatan and Farhadi, Ali and Choi, Yejin},
    booktitle ={Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics},
    year={2019}
}
"""


class HellaSwag(MultipleChoiceTask):
    VERSION = 0
    TASK_CONFIG_DEFAULTS: dict = {
        "dataset_path": "hellaswag",
        "native_id_field": "ind",
        "primary_metric": "acc_per_char",
        "split": "validation",
    }

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        if self._training_docs is None:
            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
        return self._training_docs

    def validation_docs(self):
        return map(self._process_doc, self.dataset["validation"])

    def unconditioned_prompt(self):
        return "Answer:"

    def _process_doc(self, doc):
        ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
        out_doc = {
            "ind": doc["ind"],
            "query": self.preprocess(doc["activity_label"] + ": " + ctx),
            "choices": [self.preprocess(ending) for ending in doc["endings"]],
            "gold": int(doc["label"]),
        }
        return out_doc

    @classmethod
    def preprocess(cls, text):
        # Health: How to cope with suicidal thoughts. Put off any plans. Promise yourself that you'll wait 48 hours before doing anything. Remember, thoughts don't have the power to force you to act. Sometimes extreme pain can distort our perception. Waiting before taking action will give your mind time to clear.
        text = text.strip()
        # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
        # text = text.replace(" [title]", ". ") # original code creates double full stop when . [title]
        text = re.sub("\\.? \\[title\\]", ". ", text)
        text = re.sub("\\[.*?\\]", "", text)
        text = text.replace("  ", " ")
        return text

    def doc_to_text(self, doc):
        return doc["query"]


class HellaSwagMC(HellaSwag):
    TASK_CONFIG_DEFAULTS: dict = {
        "dataset_path": "hellaswag",
        "native_id_field": "ind",
        "primary_metric": "acc_raw",
        "split": "validation",
    }
    # Include answer choices in prompt, answer is just the single letter A, B, ... E.g.,
    # Health: How to choose dental floss. Choose a thick floss when you have large gaps. If you have large spaces between your teeth, pick an extra thick floss. Some options include dental tape or super dental floss.
    # Choose the best continuation:
    #  A. If you have small gaps between your teeth, a thicker floss might work fine. You should also avoid super floss, as it'll damage the fine plastic coating of your teeth.
    #  B. Choosing a thicker floss will help ensure that you are actually flossing all the surfaces of your teeth and makes flossing easier. You'll know you have gaps if a normal dental floss slides in very easily, and you see ample space around it.
    #  C. However, since dental floss is so wide and heavy, you'll need to pick a thinner one which is less brittle than regular dental floss. If your mouth has a very narrow gap between your teeth and your gums, choose an extra floss.
    #  D. You can pick between super and super thick floss with two or four slits in the end to allow blood to be expelled. Weigh your dental floss before you purchase it so you know how much dental floss you'll need.
    # Answer: B

    def _process_doc(self, doc):
        letter_indices = ["A", "B", "C", "D", "E"]
        query = make_mcq_prompt(
            f"{doc['activity_label']}: {doc['ctx_a']} {doc['ctx_b'].capitalize()}",
            doc["endings"],
            question_prefix="",
            choices_prefix="Choose the best continuation:\n",
        )
        query = self.preprocess(query)
        out_doc = {
            "query": query,
            "choices": letter_indices[: len(doc["endings"])],
            "gold": int(doc["label"]) if doc["label"] != "" else -1,
        }
        return out_doc

    def unconditioned_prompt(self):
        # Don't need unconditioned normalization here
        return None
