
import numpy as np
from efficiency_benchmark.dependencies.lm_eval.base import Task, rf
from efficiency_benchmark.dependencies.lm_eval.metrics import mean

_CITATION = 


class WinogradSchemaChallenge273(Task):
    """Winograd Schema Challenge (WSC273), scored by partial evaluation.

    Each record carries a ``text`` containing an ambiguous ``pronoun`` at
    character offset ``pronoun_loc``, candidate referents in ``options`` and the
    index of the correct one in ``label``. For every option, the text before
    the pronoun is rebuilt with that option substituted in. The model's
    log-likelihood of the rest of the sentence (the part after the pronoun) is
    then compared across options. The best-scoring option is the prediction.
    """

    VERSION = 0
    DATASET_PATH = "winograd_wsc"
    DATASET_NAME = "wsc273"

    # Capitalized leading words of an option. They are lowercased when the
    # option is substituted for a pronoun that does not start a sentence
    # (see ``__normalize_option``).
    upper_pronouns = [
        "A",
        "An",
        "The",
        "She",
        "He",
        "It",
        "They",
        "My",
        "His",
        "Her",
        "Their",
    ]

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def has_test_docs(self):
        return True

    def test_docs(self):
        """Return a lazy iterator over the processed ``test`` split."""
        return map(self._process_doc, self.dataset["test"])

    def _process_doc(self, doc):
        """Normalize a raw record so it can be used for partial evaluation.

        Collapses double spaces in ``text`` and shifts ``pronoun_loc`` so it
        still points at the pronoun in the collapsed text. Also normalizes
        every answer option (see ``__normalize_option``).

        :param doc: raw dataset record with ``text``, ``pronoun``,
            ``pronoun_loc``, ``options`` and ``label`` keys.
        :returns: a new dict; the input record is left unmodified.
        """
        text = doc["text"]
        pronoun_loc = doc["pronoun_loc"]
        # Collapsing spaces that precede the pronoun moves it to a smaller
        # offset. The new offset is the length of the collapsed prefix, so
        # ``text[:pronoun_loc]`` still ends right before the pronoun.
        # NOTE: a single ``replace`` pass turns a run of 3 spaces into 2; the
        # prefix is collapsed the same way, so the offsets stay consistent.
        collapsed_prefix = text[:pronoun_loc].replace("  ", " ")

        processed = dict(doc)
        processed["text"] = text.replace("  ", " ")
        processed["pronoun_loc"] = len(collapsed_prefix)
        # Normalization reads the updated text/offset, so it must run after
        # the two assignments above.
        processed["options"] = [
            self.__normalize_option(processed, option) for option in doc["options"]
        ]
        return processed

    def __normalize_option(self, doc, option):
        """Adapt an answer option so it reads naturally in place of the pronoun.

        - When the pronoun is one of ``my/his/her/our/their``, ``'s`` is
          appended to the option.
        - When the option's first word is in ``upper_pronouns`` and the
          pronoun does not start a sentence, that first word is lowercased.

        :param doc: record whose ``text``/``pronoun_loc`` are already
            space-normalized.
        :param option: candidate referent string.
        :returns: the normalized option string.
        """
        if doc["pronoun"].lower() in ["my", "his", "her", "our", "their"]:
            # NOTE(review): "her" can also be an object pronoun ("told her"),
            # where appending "'s" is not grammatical.
            option += "'s"

        words = option.split()
        if not words:
            # Nothing to recase (e.g. an empty option).
            return option
        first_word = words[0]

        pronoun_loc = doc["pronoun_loc"]
        # The character two positions back is where a sentence-ending period
        # sits in ". He ...". A pronoun within the first two characters is
        # treated as starting the text. This also avoids a negative index
        # that would wrap around to the end of the string.
        start_of_sentence = pronoun_loc < 2 or doc["text"][pronoun_loc - 2] == "."
        if not start_of_sentence and first_word in self.upper_pronouns:
            # Lowercase only the leading word. Without a count, ``replace``
            # would also rewrite later occurrences of the same substring.
            return option.replace(first_word, first_word.lower(), 1)
        return option

    def fewshot_examples(self, k, rnd):
        """Sample ``k`` few-shot examples.

        The task has no training or validation split, so examples are drawn
        from the processed test set. The list is cached on the instance after
        the first call.
        """
        if self._fewshot_docs is None:
            self._fewshot_docs = list(self.test_docs())

        return rnd.sample(list(self._fewshot_docs), k)

    def doc_to_text(self, doc):
        """Return the text before the pronoun with the gold option substituted."""
        return self.partial_context(doc, doc["options"][doc["label"]])

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["text"]

    @classmethod
    def partial_context(cls, doc, option):
        """Return ``text`` up to the pronoun, followed by ``option``.

        The pronoun itself and everything after it are dropped; the scored
        continuation is produced by ``partial_target``.
        """
        return doc["text"][: doc["pronoun_loc"]] + option

    def doc_to_target(self, doc):
        return self.partial_target(doc)

    @classmethod
    def partial_target(cls, doc):
        """Return the text following the pronoun, stripped, with one leading space."""
        start_index = doc["pronoun_loc"] + len(doc["pronoun"])
        return " " + doc["text"][start_index:].strip()

    def construct_requests(self, doc, ctx):
        """Build one log-likelihood request per option.

        Every request scores the same continuation (the text after the
        pronoun). Only the context differs: it ends with the text before the
        pronoun, with that option substituted in.

        :param doc: processed record (see ``_process_doc``).
        :param ctx: context string; its last blank-line-separated segment is
            replaced by the option-specific partial context.
        :returns: list of request handles, one per entry in ``doc["options"]``.
        """
        target = self.partial_target(doc)
        lls = []
        for option in doc["options"]:
            partial_ctx = self.partial_context(doc, option)
            full_ctx = self.append_context(ctx, partial_ctx)
            # presumably element 0 is the log-likelihood value of the request
            # (vs. an is-greedy flag) -- confirm against ``rf.loglikelihood``.
            lls.append(rf.loglikelihood(full_ctx, target)[0])
        return lls

    @classmethod
    def append_context(cls, ctx, partial_ctx):
        """Replace the last blank-line-separated segment of ``ctx`` with ``partial_ctx``.

        Any earlier segments (assumed to be few-shot examples -- confirm
        against the caller) are kept. If ``ctx`` has no blank-line separator,
        ``partial_ctx`` is returned alone.
        """
        ctx = ctx.split("\n\n")
        # Drop the final segment; it is re-created with the option substituted.
        ctx.pop()
        return "\n\n".join([*ctx, partial_ctx]) if ctx else partial_ctx

    def process_results(self, doc, results):
        """Score a document: correct when the highest-likelihood option is the label.

        :param results: per-option log-likelihoods, in ``doc["options"]`` order.
        :returns: ``{"acc": bool}``. A plain ``bool`` is returned rather than
            ``numpy.bool_`` so the result stays JSON-serializable.
        """
        return {"acc": bool(np.argmax(results) == doc["label"])}

    def aggregation(self):
        """Aggregate per-document accuracy with the arithmetic mean."""
        return {"acc": mean}

    def higher_is_better(self):
        """Higher accuracy is better."""
        return {"acc": True}
