
import re
import string
from collections import Counter

from efficiency_benchmark.dependencies.lm_eval.base import Task, rf
from efficiency_benchmark.dependencies.lm_eval.metrics import f1_score, mean

_CITATION = 


def normalize_answer(s):
    """Canonicalise an answer string for token-level comparison.

    Steps, applied in this order:
      1. lower-case the text;
      2. delete every character listed in ``string.punctuation``;
      3. replace the standalone words "a", "an" and "the" with a space;
      4. collapse all whitespace runs to single spaces and strip the ends.
    """
    lowered = s.lower()
    # One C-level pass that deletes each punctuation character (maps it to None).
    without_punct = lowered.translate(str.maketrans("", "", string.punctuation))
    without_articles = re.sub(r"\b(a|an|the)\b", " ", without_punct)
    return " ".join(without_articles.split())


def categorise_answer(answer_blob):
    """Map one annotator answer record to an ``(answer, answer_type)`` pair.

    Fields of *answer_blob* are checked in this priority order:

    * truthy ``unanswerable``      -> ``("unanswerable", "unanswerable")``
    * truthy ``yes_no``            -> ``("yes", "bool")``
    * truthy ``free_form_answer``  -> ``(free_form_answer, "free form answer")``
    * truthy ``extractive_spans``  -> ``(extractive_spans, "extractive_spans")``
    * ``yes_no`` that ``is False`` -> ``("no", "bool")``

    Raises:
        ValueError: if none of the above applies (e.g. ``yes_no`` is ``None``
            and both ``free_form_answer`` and ``extractive_spans`` are empty).
            Without this, the function would return ``None`` and the caller's
            tuple unpacking would fail with an unhelpful ``TypeError``.
    """
    if answer_blob["unanswerable"]:
        return "unanswerable", "unanswerable"
    if answer_blob["yes_no"]:
        return "yes", "bool"
    if answer_blob["free_form_answer"]:
        return answer_blob["free_form_answer"], "free form answer"
    if answer_blob["extractive_spans"]:
        return answer_blob["extractive_spans"], "extractive_spans"
    # `is False` deliberately distinguishes an explicit False from other falsy
    # values such as None (presumably "not a yes/no question" — confirm
    # against the dataset schema).
    if answer_blob["yes_no"] is False:
        return "no", "bool"
    raise ValueError(f"Cannot categorise QASPER answer: {answer_blob!r}")


def token_f1_score(prediction, ground_truth):
    """Return the bag-of-tokens F1 between *prediction* and *ground_truth*.

    Both strings go through ``normalize_answer`` and are split on whitespace;
    the overlap is the size of the multiset intersection of the two token
    lists. Returns the int ``0`` when nothing overlaps, otherwise a float.
    """
    pred_toks = normalize_answer(prediction).split()
    gold_toks = normalize_answer(ground_truth).split()
    shared = Counter(pred_toks) & Counter(gold_toks)
    overlap = sum(shared.values())
    if not overlap:
        # Also covers empty predictions/references, avoiding division by zero.
        return 0
    precision = overlap / len(pred_toks)
    recall = overlap / len(gold_toks)
    return 2 * precision * recall / (precision + recall)


class QASPER(Task):
    """QASPER question-answering task, prompted with a paper's title and abstract.

    Each dataset record is flattened by ``_process_doc`` into one observation
    per (question, annotator answer) pair. Only yes/no (``"bool"``) and
    free-form answers are scored; unanswerable and extractive-span
    observations issue no requests and therefore contribute no metrics.
    """

    VERSION = 0
    DATASET_PATH = "qasper"
    DATASET_NAME = None

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        # No test split is used for this task.
        return False

    def doc_to_text(self, doc):
        """Build the prompt: title, abstract, the question, then an ``A:`` cue."""
        return (
            "TITLE: "
            + doc["title"]
            + "\n"
            + "ABSTRACT: "
            + doc["abstract"]
            + "\n\n"
            + "Q: "
            + doc["question"]
            + "\n\n"
            + "A:"
        )

    def doc_to_target(self, doc):
        """Return the target continuation: a leading space plus the answer.

        List-valued answers (extractive spans from ``categorise_answer``) are
        joined with ``", "``.
        """
        answer = doc["answer"]
        if isinstance(answer, list):
            answer = ", ".join(answer)
        return " " + answer

    def training_docs(self):
        """Yield flattened observations from the ``train`` split."""
        # NOTE(review): self.dataset is presumably populated by the Task base
        # class from DATASET_PATH / DATASET_NAME — confirm.
        for doc in self.dataset["train"]:
            yield from self._process_doc(doc)

    def validation_docs(self):
        """Yield flattened observations from the ``validation`` split."""
        for doc in self.dataset["validation"]:
            yield from self._process_doc(doc)

    def _process_doc(self, doc):
        """Flatten one paper record into per-answer observations.

        ``doc["qas"]["question"]`` and ``doc["qas"]["answers"]`` are walked in
        parallel; every entry of ``answers[i]["answer"]`` becomes its own
        observation carrying the paper title/abstract, the question and the
        ``(answer, answer_type)`` pair from ``categorise_answer``.
        """
        obs_list = []
        for question, answer_list in zip(doc["qas"]["question"], doc["qas"]["answers"]):
            for answer_blob in answer_list["answer"]:
                answer, answer_type = categorise_answer(answer_blob)
                obs_list.append(
                    {
                        "title": doc["title"],
                        "abstract": doc["abstract"],
                        "question": question,
                        "answer": answer,
                        "answer_type": answer_type,
                    }
                )
        return obs_list

    def process_results(self, doc, results):
        """Score one observation.

        Args:
            doc: an observation produced by ``_process_doc``.
            results: the values the harness returns for the requests built by
                ``construct_requests`` — two log-likelihoods (for " yes" and
                " no") on ``"bool"`` docs, one generated string on free-form
                docs, and nothing for every other answer type.

        Returns:
            ``{"f1_yesno": (gold, pred)}`` for yes/no docs (both ints, 1 = yes),
            ``{"f1_abstractive": token_f1}`` for free-form docs, otherwise ``{}``.
        """
        if not results:
            return {}

        answer_type = doc["answer_type"]
        if answer_type == "bool":
            # Unpack according to the answer type (not len(results)) so a
            # shape mismatch fails right here instead of as a later NameError.
            ll_yes, ll_no = results
            gold = 1 if doc["answer"] == "yes" else 0
            # int() keeps gold and pred the same type in the pairs f1_score sees.
            pred = int(ll_yes > ll_no)
            return {"f1_yesno": (gold, pred)}
        if answer_type == "free form answer":
            (generation,) = results
            return {"f1_abstractive": token_f1_score(generation, doc["answer"])}
        return {}

    def aggregation(self):
        """Aggregate yes/no pairs with ``f1_score`` and free-form scores with ``mean``."""
        return {
            "f1_yesno": f1_score,
            "f1_abstractive": mean,
        }

    def construct_requests(self, doc, ctx):
        """Build the model requests for one observation.

        * free-form answers: one greedy generation stopped at a newline;
        * yes/no answers: log-likelihoods of the continuations " yes" and " no";
        * anything else (unanswerable, extractive spans): no requests.
        """
        answer_type = doc["answer_type"]
        # Exact comparison: the previous `in ("bool")` was a substring test on
        # a plain string (parentheses do not make a tuple).
        if answer_type == "free form answer":
            return [rf.greedy_until(ctx, ["\n"])]
        if answer_type == "bool":
            ll_yes, _ = rf.loglikelihood(ctx, " yes")
            ll_no, _ = rf.loglikelihood(ctx, " no")
            return [ll_yes, ll_no]
        return []

    def higher_is_better(self):
        """Both metrics are F1-style scores, so larger values are better."""
        return {
            "f1_yesno": True,
            "f1_abstractive": True,
        }
