
import collections

import datasets
import numpy as np
from efficiency_benchmark.dependencies.lm_eval.base import Task, rf
from efficiency_benchmark.dependencies.lm_eval.metrics import mean

_CITATION = 


class each:
    """Right-hand operand that maps a function over an iterable via ``>>``.

    ``iterable >> each(func)`` evaluates to ``[func(x) for x in iterable]``
    whenever the left operand does not implement ``>>`` itself.
    """

    def __init__(self, f):
        """Store *f*, the callable applied to every element."""
        self.f = f

    def __rrshift__(self, other):
        """Apply the stored callable to each item of *other*; return a list."""
        func = self.f
        return [func(item) for item in other]


class RACE(Task):
    """Multiple-choice reading-comprehension task built on the RACE dataset.

    Dataset rows that share the same article are grouped into one document.
    The last problem of a document is the one that gets scored; the problems
    before it are rendered, together with their gold answers, as part of the
    prompt.
    """

    VERSION = 1
    DATASET_PATH = "race"
    DATASET_NAME = "high"

    # Memo of collated splits. This dict lives on the class, so it is shared
    # by every instance and subclass; entries are therefore keyed by
    # (DATASET_PATH, DATASET_NAME, split) to keep configurations apart.
    cache = {}
    # Maps an answer letter to its index in a problem's "options" list.
    letter_to_num = {"A": 0, "B": 1, "C": 2, "D": 3}

    def has_training_docs(self):
        """Return True: a training split is exposed via `training_docs`."""
        return True

    def has_validation_docs(self):
        """Return True: a validation split is exposed via `validation_docs`."""
        return True

    def has_test_docs(self):
        """Return True: a test split is exposed via `test_docs`."""
        return True

    def _collate_data(self, set):
        """Load one split and group its rows by article.

        Args:
            set: Name of the split used to index the loaded dataset (this
                class passes "train", "validation" and "test"). The name
                shadows the builtin but is kept so existing callers that
                pass it by keyword continue to work.

        Returns:
            A list of documents, one per distinct article, each of the form
            ``{"article": ..., "problems": [{"question", "answer",
            "options"}, ...]}``. Documents and problems keep the order in
            which they first appear in the split. The list is cached and the
            same object is returned on later calls.
        """
        # Key on the dataset configuration as well as the split: `cache` is
        # a class attribute, so keying on the split alone would hand a
        # subclass with a different DATASET_NAME another configuration's data.
        cache_key = (self.DATASET_PATH, self.DATASET_NAME, set)
        if cache_key in self.cache:
            return self.cache[cache_key]

        dataset = datasets.load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME)

        # The dataset holds one row per question; collect the rows that
        # belong to the same article.
        rows_by_article = collections.defaultdict(list)
        for row in dataset[set]:
            rows_by_article[row["article"]].append(row)

        docs = [
            {
                "article": article,
                "problems": [
                    {
                        "question": row["question"],
                        "answer": row["answer"],
                        "options": row["options"],
                    }
                    for row in rows
                ],
            }
            for article, rows in rows_by_article.items()
        ]

        self.cache[cache_key] = docs
        return docs

    def training_docs(self):
        """Return the collated documents of the "train" split."""
        return self._collate_data("train")

    def validation_docs(self):
        """Return the collated documents of the "validation" split."""
        return self._collate_data("validation")

    def test_docs(self):
        """Return the collated documents of the "test" split."""
        return self._collate_data("test")

    @classmethod
    def get_answer_option(cls, problem):
        """Return the text of the option selected by the problem's answer letter.

        Raises:
            KeyError: If ``problem["answer"]`` is not one of "A".."D".
        """
        answer_index = cls.letter_to_num[problem["answer"]]
        return problem["options"][answer_index]

    @classmethod
    def last_problem(cls, doc):
        """Return the final problem of *doc*, i.e. the one that is scored."""
        return doc["problems"][-1]

    def doc_to_text(self, doc):
        """Build the prompt for *doc*.

        The prompt is the article, followed by every problem except the last
        one rendered together with its gold answer, followed by the bare
        question of the last problem.
        """
        parts = ["Article: " + doc["article"] + "\n\n"]
        for problem in doc["problems"][:-1]:
            question = problem["question"]
            answer_text = self.get_answer_option(problem)
            if question.endswith("  _  ."):
                # NOTE(review): `[-5:]` keeps only the trailing " _  ." and
                # drops the question itself; `[:-5]` was probably intended.
                # Left unchanged because altering it changes the prompts (and
                # therefore scores) produced under the current VERSION.
                parts.append(question[-5:] + answer_text + "\n")
            else:
                parts.append("Question: " + question + "\n")
                parts.append("Answer: " + answer_text + "\n")
        parts.append(self.last_problem(doc)["question"])
        return "".join(parts)

    def should_decontaminate(self):
        """Return True: this task provides a decontamination query."""
        return True

    def doc_to_decontamination_query(self, doc):
        """Return the article text as the decontamination query."""
        return doc["article"]

    def doc_to_target(self, doc):
        """Return the gold answer text of the last problem, with a leading space."""
        return " " + self.get_answer_option(self.last_problem(doc))

    def construct_requests(self, doc, ctx):
        """Build one log-likelihood request per answer option of the last problem.

        Args:
            doc: A document as produced by `_collate_data`.
            ctx: The context string each option is scored against.

        Returns:
            A list with one entry per option, in option order: the first
            element of ``rf.loglikelihood(ctx, " " + option)``.
        """
        problem = self.last_problem(doc)
        # Iterate over the options actually present instead of assuming
        # there are exactly four.
        return [rf.loglikelihood(ctx, " " + option)[0] for option in problem["options"]]

    def process_results(self, doc, results):
        """Score a single document.

        Args:
            doc: The document the requests were built from.
            results: One value per request returned by `construct_requests`
                (presumably in the same order -- confirm against the
                harness).

        Returns:
            ``{"acc": 1}`` when the index of the largest result equals the
            index of the gold answer, otherwise ``{"acc": 0}``.
        """
        gold = self.letter_to_num[self.last_problem(doc)["answer"]]
        pred = np.argmax(results)
        return {"acc": int(pred == gold)}

    def aggregation(self):
        """Return the aggregation function for each metric (mean for "acc")."""
        return {"acc": mean}

    def higher_is_better(self):
        """Return, per metric, whether a larger value is better."""
        return {"acc": True}
