
import inspect

import efficiency_benchmark.dependencies.lm_eval.datasets.arithmetic.arithmetic
from efficiency_benchmark.dependencies.lm_eval.base import Task, rf
from efficiency_benchmark.dependencies.lm_eval.metrics import mean

_CITATION = 


class Arithmetic(Task):
    """Base task shared by every arithmetic evaluation suite.

    Each document is read via its ``context`` (the prompt) and ``completion``
    (the expected answer) keys. Only a validation split is exposed. The
    reported ``acc`` metric is the mean of the per-document flag taken from
    the second component of ``rf.loglikelihood`` (presumably "the completion
    is the model's greedy continuation" -- confirm against ``rf``).

    Concrete subclasses select the dataset configuration by setting
    ``DATASET_NAME``.
    """

    VERSION = 0
    # Filesystem path of the bundled ``datasets.arithmetic.arithmetic`` module,
    # resolved from the imported module object rather than hard-coded.
    DATASET_PATH = inspect.getfile(
        efficiency_benchmark.dependencies.lm_eval.datasets.arithmetic.arithmetic
    )

    # ---- split availability -------------------------------------------------

    def has_training_docs(self):
        """No training split is provided."""
        return False

    def has_validation_docs(self):
        """The validation split is the only one provided."""
        return True

    def has_test_docs(self):
        """No test split is provided."""
        return False

    # ---- split accessors ----------------------------------------------------

    def training_docs(self):
        """Placeholder: there is no training split (see ``has_training_docs``)."""
        return NotImplemented

    def validation_docs(self):
        """Return the ``validation`` split of the loaded dataset."""
        validation_split = self.dataset["validation"]
        return validation_split

    def test_docs(self):
        """Placeholder: there is no test split (see ``has_test_docs``)."""
        return NotImplemented

    # ---- document formatting ------------------------------------------------

    def doc_to_text(self, doc):
        """Return the prompt shown to the model."""
        prompt = doc["context"]
        return prompt

    def should_decontaminate(self):
        """Opt this task into decontamination checks."""
        return True

    def doc_to_decontamination_query(self, doc):
        """Use the prompt text as the decontamination query."""
        query = doc["context"]
        return query

    def doc_to_target(self, doc):
        """Return the expected completion for ``doc``."""
        answer = doc["completion"]
        return answer

    # ---- requests and scoring -----------------------------------------------

    def construct_requests(self, doc, ctx):
        """Build the single request scored for ``doc``.

        Only the second element of the log-likelihood request pair is kept;
        the log-likelihood value itself is not used for scoring.
        """
        _, is_prediction = rf.loglikelihood(ctx, doc["completion"])
        return is_prediction

    def process_results(self, doc, results):
        """Map the single request result onto the ``acc`` metric."""
        # Destructuring enforces that exactly one result is returned.
        [is_prediction] = results
        return {"acc": is_prediction}

    def aggregation(self):
        """Aggregate ``acc`` across documents by averaging."""
        return {"acc": mean}

    def higher_is_better(self):
        """A larger ``acc`` is better."""
        return {"acc": True}


class Arithmetic2DPlus(Arithmetic):
    """Two-digit addition suite (``arithmetic_2da`` configuration)."""

    DATASET_NAME = "arithmetic_2da"


class Arithmetic2DMinus(Arithmetic):
    """Two-digit subtraction suite (``arithmetic_2ds`` configuration)."""

    DATASET_NAME = "arithmetic_2ds"


class Arithmetic3DPlus(Arithmetic):
    """Three-digit addition suite (``arithmetic_3da`` configuration)."""

    DATASET_NAME = "arithmetic_3da"


class Arithmetic3DMinus(Arithmetic):
    """Three-digit subtraction suite (``arithmetic_3ds`` configuration)."""

    DATASET_NAME = "arithmetic_3ds"


class Arithmetic4DPlus(Arithmetic):
    """Four-digit addition suite (``arithmetic_4da`` configuration)."""

    DATASET_NAME = "arithmetic_4da"


class Arithmetic4DMinus(Arithmetic):
    """Four-digit subtraction suite (``arithmetic_4ds`` configuration)."""

    DATASET_NAME = "arithmetic_4ds"


class Arithmetic5DPlus(Arithmetic):
    """Five-digit addition suite (``arithmetic_5da`` configuration)."""

    DATASET_NAME = "arithmetic_5da"


class Arithmetic5DMinus(Arithmetic):
    """Five-digit subtraction suite (``arithmetic_5ds`` configuration)."""

    DATASET_NAME = "arithmetic_5ds"


class Arithmetic2DMultiplication(Arithmetic):
    """Two-digit multiplication suite (``arithmetic_2dm`` configuration)."""

    DATASET_NAME = "arithmetic_2dm"


class Arithmetic1DComposite(Arithmetic):
    """Single-digit composite-expression suite (``arithmetic_1dc`` configuration)."""

    DATASET_NAME = "arithmetic_1dc"
