
from collections import defaultdict

import numpy as np
from efficiency_benchmark.dependencies.lm_eval.base import Task, rf

_CITATION = 


class MCTACO(Task):
    """MC-TACO task, scored as a per-candidate yes/no plausibility judgement.

    Each document pairs a sentence and a question with one candidate answer.
    The model is asked whether that answer is plausible. Candidates sharing the
    same sentence and question are regrouped (via ``_question2id``) by the
    module-level ``exact_match`` and ``f1`` aggregators.
    """

    VERSION = 0
    DATASET_PATH = "mc_taco"
    DATASET_NAME = None

    # ----- available splits ---------------------------------------------

    def has_training_docs(self):
        # This task exposes no training split.
        return False

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return True

    def validation_docs(self):
        split_name = "validation"
        return self.dataset[split_name]

    def test_docs(self):
        split_name = "test"
        return self.dataset[split_name]

    # ----- prompt construction ---------------------------------------------

    def doc_to_text(self, doc):
        """Render sentence, question and candidate answer, ending in ``Plausible:``."""
        sentence, question, answer = doc["sentence"], doc["question"], doc["answer"]
        return f"{sentence}\nQuestion: {question}\nAnswer: {answer}\nPlausible:"

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        """Return the question and the sentence joined by a single space."""
        return " ".join((doc["question"], doc["sentence"]))

    def doc_to_target(self, doc):
        """Return the gold continuation: `` no`` for label 0, `` yes`` for label 1."""
        verdict = ("no", "yes")[doc["label"]]
        return " " + verdict

    # ----- scoring -----------------------------------------------------------

    def construct_requests(self, doc, ctx):
        """Issue one loglikelihood request per verdict, ``" no"`` then ``" yes"``.

        The returned pair is in that same order, which ``process_results``
        relies on when unpacking.
        """
        collected = []
        for continuation in (" no", " yes"):
            ll, _ = rf.loglikelihood(ctx, continuation)
            collected.append(ll)
        return tuple(collected)

    def process_results(self, doc, results):
        """Convert the two loglikelihoods into a 0/1 prediction for this candidate.

        "yes" (1) is predicted only when its loglikelihood strictly exceeds that
        of "no"; ties resolve to 0. Both metrics receive the same
        ``(gold, pred, question_id)`` triple.
        """
        ll_no, ll_yes = results
        prediction = 1 if ll_yes > ll_no else 0
        triple = (doc["label"], prediction, self._question2id(doc))
        return {"em": triple, "f1": triple}

    def _question2id(self, doc):
        """Key shared by every candidate answer of one (sentence, question) pair."""
        return doc["sentence"] + " " + doc["question"]

    def aggregation(self):
        return dict(f1=f1, em=exact_match)

    def higher_is_better(self):
        return {metric: True for metric in ("f1", "em")}


def exact_match(items):
    """Strict accuracy: fraction of questions whose candidates are all judged correctly.

    Args:
        items: iterable of ``(gold, pred, question_id)`` triples, as emitted by
            ``MCTACO.process_results``.

    Returns:
        Mean over questions of 1 if every candidate's prediction equals its
        gold label and 0 otherwise. Returns ``nan`` when ``items`` is empty.
    """
    # Per question id, one boolean per candidate answer: was it judged correctly?
    accuracies = defaultdict(list)
    for gold, pred, question in items:
        accuracies[question].append(pred == gold)
    if not accuracies:
        # Strict accuracy is undefined without any questions; report NaN
        # explicitly rather than indexing into an empty transpose.
        return float("nan")
    return np.mean([int(all(accs)) for accs in accuracies.values()])


def f1(items):
    """Mean per-question F1 on the positive ("yes") class.

    Some MC-TACO questions have more than one plausible answer among their
    candidates, so F1 measures how well all of them are identified. Candidates
    are grouped by question id. Within a question, precision and recall are
    computed over positive labels/predictions. Each defaults to 1.0 when its
    denominator is zero, so a question with no gold and no predicted
    positives scores 1.0.

    Args:
        items: iterable of ``(gold, pred, question_id)`` triples, as emitted by
            ``MCTACO.process_results``; ``gold`` and ``pred`` are binary (0/1).

    Returns:
        The mean F1 over all questions. Returns ``nan`` when ``items`` is empty.
    """
    # Group gold labels and predictions by question id.
    gold_positives, pred_positives = defaultdict(list), defaultdict(list)
    for gold, pred, question in items:
        gold_positives[question].append(gold)
        pred_positives[question].append(pred)
    if not gold_positives:
        # F1 is undefined without any questions; report NaN explicitly.
        return float("nan")
    scores = []
    for question, golds in gold_positives.items():
        preds = pred_positives[question]
        gp, pp = sum(golds), sum(preds)
        tp = sum(np.logical_and(golds, preds))
        p = tp / pp if pp > 0.0 else 1.0
        r = tp / gp if gp > 0.0 else 1.0
        # A question with no overlap (p == r == 0) has F1 = 0 and must still
        # count toward the mean; dropping it would inflate the score.
        scores.append(2.0 * (p * r) / (p + r) if p + r > 0.0 else 0.0)
    return np.mean(scores)
