
import re

from efficiency_benchmark.dependencies.lm_eval.base import Task, rf
from efficiency_benchmark.dependencies.lm_eval.metrics import mean

# BibTeX entry for the GSM8K paper (Cobbe et al., 2021).
_CITATION = """
@misc{cobbe2021training,
      title={Training Verifiers to Solve Math Word Problems},
      author={Karl Cobbe and Vineet Kosaraju and Mohammad Bavarian and Jacob Hilton and Reiichiro Nakano and Christopher Hesse and John Schulman},
      year={2021},
      eprint={2110.14168},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
"""


# Matches a final-answer marker of the form "#### <number>". Group 1 captures an
# optionally negative number that may contain "." and "," characters; callers
# remove the commas before comparing answers.
ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
# Sentinel returned when no answer matching ANS_RE can be found in a string.
INVALID_ANS = "[invalid]"


class GradeSchoolMath8K(Task):
    """GSM8K math word problems, scored by greedy generation.

    Each document is prompted as "Question: <question>" followed by a newline
    and "Answer:". The model's completion is scored correct when the number
    extracted from it via ``ANS_RE`` equals the number extracted the same way
    from the document's gold ``"answer"`` field.
    """

    VERSION = 0
    DATASET_PATH = "gsm8k"
    DATASET_NAME = "main"

    def has_training_docs(self):
        """Return True; training documents come from the ``train`` split."""
        return True

    def has_validation_docs(self):
        """Return False; this task exposes no validation split."""
        return False

    def has_test_docs(self):
        """Return True; test documents come from the ``test`` split."""
        return True

    def training_docs(self):
        """Return the ``train`` split of the loaded dataset."""
        return self.dataset["train"]

    def validation_docs(self):
        """Not available for this task (``has_validation_docs`` is False).

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def test_docs(self):
        """Return the ``test`` split of the loaded dataset."""
        return self.dataset["test"]

    def doc_to_text(self, doc):
        """Format the document's ``"question"`` as a Question/Answer prompt."""
        return "Question: " + doc["question"] + "\nAnswer:"

    def doc_to_target(self, doc):
        """Return the gold ``"answer"`` text with a leading space after "Answer:"."""
        return " " + doc["answer"]

    def construct_requests(self, doc, ctx):
        """Build a single greedy-generation request.

        Args:
            doc: The document being evaluated (not used directly here).
            ctx: The prompt string the harness built for ``doc``.

        Returns:
            The request created by ``rf.greedy_until``; its result is read as
            the completion string in ``process_results``.
        """
        # Generation stops at the first newline.
        # NOTE(review): if gold solutions span several lines before their final
        # ANS_RE marker, a completion cut at "\n" will rarely contain that
        # marker. Confirm this stop sequence is intended; changing it changes
        # reported scores and would warrant bumping VERSION.
        completion = rf.greedy_until(ctx, ["\n"])
        return completion

    def _extract_answer(self, completion):
        """Extract the number matched by ``ANS_RE`` from ``completion``.

        Commas are removed so that e.g. "1,000" and "1000" compare equal.

        Returns:
            The extracted number as a string, or ``INVALID_ANS`` if
            ``ANS_RE`` does not match.
        """
        match = ANS_RE.search(completion)
        if match is None:
            return INVALID_ANS
        return match.group(1).strip().replace(",", "")

    def _is_correct(self, completion, answer):
        """Return whether ``completion`` yields the same answer as gold ``answer``.

        Raises:
            ValueError: if no answer can be extracted from the gold ``answer``.
        """
        gold = self._extract_answer(answer)
        # Explicit raise rather than ``assert``: under ``python -O`` an assert is
        # removed, and an unparseable gold (INVALID_ANS) would then compare equal
        # to any unparseable completion and be silently scored as correct.
        if gold == INVALID_ANS:
            raise ValueError("No ground truth answer found in the document.")
        return self._extract_answer(completion) == gold

    def process_results(self, doc, results):
        """Score one document.

        Args:
            doc: The evaluated document; ``doc["answer"]`` is the gold solution.
            results: Results of the request from ``construct_requests``;
                ``results[0]`` is the generated completion.

        Returns:
            ``{"acc": bool}`` — whether the extracted answers match.
        """
        completion = results[0]
        answer = doc["answer"]
        return {"acc": self._is_correct(completion, answer)}

    def aggregation(self):
        """Return the aggregation function per metric: ``acc`` is averaged."""
        return {"acc": mean}

    def higher_is_better(self):
        """Return, per metric, whether a higher value is better (True for ``acc``)."""
        return {"acc": True}
