"""
"""
import re
from typing import List

import numpy as np

from lm_eval.api.instance import Instance
from lm_eval.api.task import ConfigurableTask
import string
# Matches the English articles "a", "an" and "the" as whole words (case-sensitive).
ARTICLES_REGEX = re.compile(r"\b(a|an|the)\b", re.UNICODE)
import evaluate
from functools import partial
import datasets
from lm_eval.api.metrics import mean

def contains_score(prediction, reference):
    """Score a generation by whether it names the gold option letter.

    :param prediction: str
        Text produced by the model.
    :param reference: tuple
        An ``(answer, options)`` pair. ``answer`` is the gold option label
        (e.g. ``"a"``). ``options`` is accepted so existing call sites keep
        working, but it is not consulted.
    :returns: int
        1 if ``answer`` immediately followed by ``)`` occurs anywhere in
        ``prediction``, otherwise 0.
    """
    answer, _options = reference
    # Plain case-sensitive substring test: "a)" also matches inside longer
    # text such as "data)".
    return int(answer + ")" in prediction)


class MathQA_Conv_NoInd(ConfigurableTask):
    """MathQA multiple-choice task evaluated through free-form generation.

    Each validation document is turned into a prompt made of the problem
    statement followed by its answer options. The model's generated text is
    scored with the single ``contains`` metric (see ``process_results``).
    """

    VERSION = 0
    DATASET_PATH = "allenai/math_qa"
    DATASET_NAME = "default"
    # Appended verbatim to the end of the prompt; read through ``self`` so a
    # subclass can override it.
    TEXT_SUFFIX = ""

    def __init__(self):
        super().__init__(config={"metadata": {"version": self.VERSION}})

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def validation_docs(self):
        """Return the ``validation`` split of the loaded dataset."""
        return self.dataset["validation"]

    def doc_to_text(self, doc):
        """Build the prompt for one document.

        :param doc:
            A document exposing the ``Problem`` and ``options`` fields.
        :returns: list[str]
            A one-element list holding the prompt text.
            NOTE(review): presumably one entry per conversation turn --
            confirm against whatever consumes this list.
        """
        prompt = (
            doc["Problem"]
            + " Choose the correct answer from the following options: "
            + doc["options"]
            + self.TEXT_SUFFIX
        )
        return [prompt]

    def doc_to_hint(self, doc):
        """Return the hint for ``doc``; this task never provides one."""
        return ""

    def doc_to_target(self, doc):
        """Return the gold answer label stored in the ``correct`` field."""
        return doc["correct"]

    def construct_requests(self, doc, ctx, **kwargs):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        return [
            Instance(
                request_type="generate_until",
                doc=doc,
                arguments=(ctx, {"max_gen_toks": 500}, ""),
                idx=0,
                **kwargs,
            )
        ]

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests. Only
            the first element (the generated text) is scored.
        """
        generated = results[0]
        return {
            "contains": contains_score(generated, (doc["correct"], doc["options"])),
        }

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {"contains": mean}

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        # Keys must mirror the metrics emitted by process_results/aggregation.
        return {"contains": True}


def f1_score(prediction: str, label: dict):
    """Compute the ``squad`` metric's F1 for one prediction/reference pair.

    :param prediction: str
        The predicted answer text.
    :param label: dict
        The reference record. Its ``id`` field is copied onto the prediction,
        and the whole record is passed to the metric as the single reference.
    :returns:
        The ``f1`` entry of the metric's result.
    """
    squad_metric = evaluate.load("squad")
    candidate = {"prediction_text": prediction, "id": label["id"]}
    scores = squad_metric.compute(predictions=[candidate], references=[label])
    return scores["f1"]

