
import inspect
from itertools import zip_longest

import efficiency_benchmark.dependencies.lm_eval.datasets.coqa.coqa
import transformers.data.metrics.squad_metrics as squad_metrics
from efficiency_benchmark.dependencies.lm_eval.base import Task, mean, rf

# BibTeX entry for the CoQA paper.
# NOTE(review): the original assignment had no right-hand side (a SyntaxError);
# this value was restored from the upstream lm-evaluation-harness CoQA task —
# confirm it matches the project's intended citation text.
_CITATION = """
@misc{reddy2018coqa,
    title={CoQA: A Conversational Question Answering Challenge},
    author={Siva Reddy and Danqi Chen and Christopher D. Manning},
    year={2018},
    eprint={1808.07042},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
"""


class CoQA(Task):
    """Conversational question-answering task over CoQA documents.

    Each document holds a ``story``, a list of questions and a parallel list
    of answers. The prompt contains the story and every question/answer turn
    except the final answer, which the model has to generate. The generation
    is scored with SQuAD-style exact match and F1.
    """

    VERSION = 1
    # File path of the ``coqa`` dataset module imported at the top of the file.
    DATASET_PATH = inspect.getfile(efficiency_benchmark.dependencies.lm_eval.datasets.coqa.coqa)
    DATASET_NAME = None

    def has_training_docs(self):
        """Return True: a ``train`` split is exposed by `training_docs`."""
        return True

    def has_validation_docs(self):
        """Return True: a ``validation`` split is exposed by `validation_docs`."""
        return True

    def has_test_docs(self):
        """Return False: this task exposes no test split."""
        return False

    def training_docs(self):
        """Return the ``train`` split of the loaded dataset."""
        return self.dataset["train"]

    def validation_docs(self):
        """Return the ``validation`` split of the loaded dataset."""
        return self.dataset["validation"]

    def test_docs(self):
        """Return None; there is no test split (see `has_test_docs`)."""
        return None

    def doc_to_text(self, doc):
        """Build the prompt: the story followed by the Q/A turns.

        Questions are paired with all answers except the last one. A question
        without a paired answer (the final turn, when the document has as many
        answers as questions) ends the prompt with a bare ``"A:"`` for the
        model to complete.
        """
        questions = doc["questions"]["input_text"]
        # Withhold the last answer: it is the target of the final turn.
        known_answers = doc["answers"]["input_text"][:-1]

        parts = [doc["story"] + "\n\n"]
        for q, a in zip_longest(questions, known_answers):
            parts.append(f"Q: {q}\n\n")
            parts.append(f"A: {a}\n\n" if a is not None else "A:")
        # Join once instead of growing the string with repeated ``+=``.
        return "".join(parts)

    def should_decontaminate(self):
        """Return True: this task provides a decontamination query."""
        return True

    def doc_to_decontamination_query(self, doc):
        """Return the story plus all questions (newline-separated) as one string."""
        return doc["story"] + " " + "\n".join(doc["questions"]["input_text"])

    @classmethod
    def get_answers(cls, doc, turn_id):
        """Return the distinct gold answers for the 1-based turn ``turn_id``.

        The primary answer comes first, followed by each entry from
        ``doc["additional_answers"]`` (if present and non-empty) whose
        lower-cased text has not been seen yet.
        """
        answers = [doc["answers"]["input_text"][turn_id - 1]]

        additional_answers = doc.get("additional_answers")
        if additional_answers:
            # Track lower-cased answers so the case-insensitive duplicate
            # check does not re-lowercase the whole list on every iteration.
            seen = {answers[0].lower()}
            for extra in additional_answers.values():
                candidate = extra["input_text"][turn_id - 1]
                lowered = candidate.lower()
                if lowered not in seen:
                    seen.add(lowered)
                    answers.append(candidate)
        return answers

    @classmethod
    def get_answer_choice(cls, raw_text):
        """Map an answer string to a category code.

        Returns ``"0"`` when ``raw_text`` is exactly ``"unknown"``, ``"1"`` /
        ``"2"`` when it normalizes (via ``squad_metrics.normalize_answer``) to
        ``"yes"`` / ``"no"``, and ``"3"`` for anything else.
        """
        # The "unknown" check is intentionally done on the raw text, before
        # normalization, exactly as the comparison order requires.
        if raw_text == "unknown":
            return "0"
        normalized = squad_metrics.normalize_answer(raw_text)
        if normalized == "yes":
            return "1"
        if normalized == "no":
            return "2"
        return "3"

    @staticmethod
    def compute_scores(gold_list, pred):
        """Score ``pred`` against the gold answers with exact match and F1.

        With several gold answers, each one is left out in turn and the
        prediction is scored against the best match among the remaining ones;
        the per-round maxima are averaged over the number of gold answers.
        With a single gold answer the prediction is scored against it directly.

        Args:
            gold_list: List of gold answer strings.
            pred: Predicted answer string.

        Returns:
            Dict with keys ``"em"`` and ``"f1"``. An empty ``gold_list``
            yields ``0.0`` for both instead of raising from ``max()`` on an
            empty sequence.
        """
        if not gold_list:
            return {"em": 0.0, "f1": 0.0}

        n_gold = len(gold_list)
        if n_gold == 1:
            reference_sets = [gold_list]
        else:
            # Leave-one-out: every round drops one reference.
            reference_sets = [gold_list[:i] + gold_list[i + 1 :] for i in range(n_gold)]

        em_sum = 0.0
        f1_sum = 0.0
        for references in reference_sets:
            em_sum += max(squad_metrics.compute_exact(a, pred) for a in references)
            f1_sum += max(squad_metrics.compute_f1(a, pred) for a in references)

        return {
            "em": em_sum / n_gold,
            "f1": f1_sum / n_gold,
        }

    def doc_to_target(self, doc, turnid=None):
        """Return the gold answer for the 1-based turn ``turnid``, space-prefixed.

        When ``turnid`` is None the last turn (number of questions) is used.
        """
        if turnid is None:
            turnid = len(doc["questions"]["input_text"])
        raw_text = doc["answers"]["input_text"][turnid - 1]
        return " " + raw_text

    def construct_requests(self, doc, ctx):
        """Return a greedy-generation request on ``ctx`` that stops at ``"\\nQ:"``."""
        cont_request = rf.greedy_until(ctx, ["\nQ:"])
        return cont_request

    def process_results(self, doc, results):
        """Compute ``f1`` and ``em`` for the final turn of ``doc``.

        Only the first line of the stripped first result is used as the
        prediction; it is compared with the gold answers of the last turn.
        """
        turn_id = len(doc["questions"]["input_text"])
        gold_list = self.get_answers(doc, turn_id)
        pred = results[0].strip().split("\n")[0]

        scores = self.compute_scores(gold_list, pred)

        return {
            "f1": scores["f1"],
            "em": scores["em"],
        }

    def higher_is_better(self):
        """Return, per metric name, whether larger values are better."""
        return {
            "f1": True,
            "em": True,
        }

    def aggregation(self):
        """Return, per metric name, the function used to aggregate scores."""
        return {
            "f1": mean,
            "em": mean,
        }
