from evaluation_suite import *
import re

# A run of digits and commas, an optional decimal point, then at least one digit.
# The trailing ``[0-9]+`` keeps all fractional digits in one match; with a single
# ``[0-9]`` the text "3.14" would be split into the matches "3.1" and "4".
_NUMBER_RE = re.compile(r"[0-9,]*\.?[0-9]+")


def parse_num(prediction):
    """Extract the last number that appears in ``prediction``.

    Commas are treated as thousands separators and removed, e.g.
    ``"it costs $1,000.50"`` -> ``"1000.50"``. A leading sign is not captured.

    Args:
        prediction: Model output; converted with ``str()`` first.

    Returns:
        The last number found, as a string without commas, or ``str(prediction)``
        unchanged when it contains no digit.
    """
    text = str(prediction)
    numbers = _NUMBER_RE.findall(text)  # no capturing groups -> full matches
    if not numbers:
        return text
    return numbers[-1].replace(",", "")

@suite("GSM8@cot")
class Gsm8(EvaluationSuite):
    """GSM8K with a zero-shot chain-of-thought prompt ("Let's think step by step.")."""

    def __init__(self, size):
        super().__init__("GSM8@cot", "tasks/gsm8.json", size)

        # The instruct variant has no few-shot samples.
        self.few_shot_samples = []

    def instances(self):
        """Load the task file (one JSON object per line) and attach a numeric ``target``.

        The target is the text after the last ``###`` in ``answer``. Commas are
        thousands separators (e.g. ``#### 1,600`` -> ``1600.0``), matching how
        ``parse_num`` normalises predictions.

        Returns:
            dict mapping the 0-based line index to the instance dict.
        """
        with open(self.file, encoding="utf-8") as f:
            instances = [json.loads(line) for line in f if line.strip()]
        for i in instances:
            i["target"] = float(i["answer"].rsplit("###", 1)[1].replace(",", ""))
        return dict(enumerate(instances))

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Build the LMQL query for one instance.

        ``kwargs`` are rendered as extra ``BEAM(...)`` arguments (sorted so the
        query text is stable), with ``openai_chunksize`` always forced to 128.
        The first ``shots`` few-shot samples, if any, are emitted as prompt
        statements before the question.
        """
        q = instance["question"].strip()

        # Copy rather than mutate the caller's dict, so the override does not
        # leak into other queries or suites built from the same kwargs.
        kwargs = {**kwargs, "openai_chunksize": 128}

        additional_args = ", ".join(sorted(f"{k}={v}" for k, v in kwargs.items()))
        if additional_args:
            additional_args = ", " + additional_args

        # One indented prompt-string line per sample; empty when shots == 0, so
        # the zero-shot query text stays exactly as before.
        few_shot = ""
        if shots > 0:
            few_shot = "".join(f'    "{s}"\n' for s in self.few_shot_samples[:shots])

        return f"""
BEAM(dclib_decoder="{decoder}"{additional_args})
{few_shot}    \"\"\"Q: {q}
    
    A: Let's think step by step.[COT] Therefore, the answer (arabic numerals) is[answer]
    \"\"\"
FROM 
    "{model}"
WHERE
    STOPS_AT(answer, "\\n")
        """.strip()

    def parse_prediction(self, prediction, model_result):
        """Return the last number in ``prediction`` (commas removed) via ``parse_num``."""
        return parse_num(prediction)


@suite("GSM8@ao")
class Gsm8AnswerOnly(EvaluationSuite):
    """GSM8K answer-only variant: the model answers directly, with no reasoning step."""

    def __init__(self, size):
        super().__init__("GSM8@ao", "tasks/gsm8.json", size)

        # The instruct variant has no few-shot samples.
        self.few_shot_samples = []

    def instances(self):
        """Load the task file (one JSON object per line) and attach a numeric ``target``.

        The target is the text after the last ``###`` in ``answer``. Commas are
        thousands separators (e.g. ``#### 1,600`` -> ``1600.0``), matching how
        ``parse_num`` normalises predictions.

        Returns:
            dict mapping the 0-based line index to the instance dict.
        """
        with open(self.file, encoding="utf-8") as f:
            instances = [json.loads(line) for line in f if line.strip()]
        for i in instances:
            i["target"] = float(i["answer"].rsplit("###", 1)[1].replace(",", ""))
        return dict(enumerate(instances))

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Build the answer-only LMQL query for one instance.

        ``kwargs`` are rendered as extra ``BEAM(...)`` arguments (sorted so the
        query text is stable), with ``openai_chunksize`` always forced to 128.

        NOTE(review): ``shots`` is ignored here. Unlike the other GSM8 variants,
        this query has no WHERE clause bounding ``answer``. Confirm both are intended.
        """
        q = instance["question"].strip()

        # This variant uses no few-shot prefix.
        template = ""

        # Copy rather than mutate the caller's dict, so the override does not
        # leak into other queries or suites built from the same kwargs.
        kwargs = {**kwargs, "openai_chunksize": 128}

        additional_args = ", ".join(sorted(f"{k}={v}" for k, v in kwargs.items()))
        if additional_args:
            additional_args = ", " + additional_args

        return f"""
BEAM(dclib_decoder="{decoder}"{additional_args})
    {indent(template)}
    \"\"\"Q: {q}
    The answer (arabic numerals) is[answer]
    \"\"\"
FROM 
    "{model}"
""".strip()

    def parse_prediction(self, prediction, model_result):
        """Return the last number in ``prediction`` (commas removed) via ``parse_num``."""
        return parse_num(prediction)


@suite("GSM8@multivar")
class Gsm8Multivar(EvaluationSuite):
    """GSM8K with up to 10 numbered reasoning steps, a conclusion, then the answer."""

    def __init__(self, size):
        super().__init__("GSM8@multivar", "tasks/gsm8.json", size)

        # The instruct variant has no few-shot samples.
        self.few_shot_samples = []

    def instances(self):
        """Load the task file (one JSON object per line) and attach a numeric ``target``.

        The target is the text after the last ``###`` in ``answer``. Commas are
        thousands separators (e.g. ``#### 1,600`` -> ``1600.0``), matching how
        ``parse_num`` normalises predictions.

        Returns:
            dict mapping the 0-based line index to the instance dict.
        """
        with open(self.file, encoding="utf-8") as f:
            instances = [json.loads(line) for line in f if line.strip()]
        for i in instances:
            i["target"] = float(i["answer"].rsplit("###", 1)[1].replace(",", ""))
        return dict(enumerate(instances))

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Build the multi-step LMQL query for one instance.

        The model writes numbered THOUGHT lines (at most 10, stopping early once
        a thought mentions "answer"), then a CONCLUSION, then the answer.
        ``answer`` stops at the first ".", so a decimal answer would be cut at
        its decimal point.

        ``kwargs`` are rendered as extra ``BEAM(...)`` arguments (sorted so the
        query text is stable), with ``openai_chunksize`` always forced to 128.
        The first ``shots`` few-shot samples, if any, are emitted as prompt
        statements before the question.
        """
        q = instance["question"].strip()

        # Copy rather than mutate the caller's dict, so the override does not
        # leak into other queries or suites built from the same kwargs.
        kwargs = {**kwargs, "openai_chunksize": 128}

        additional_args = ", ".join(sorted(f"{k}={v}" for k, v in kwargs.items()))
        if additional_args:
            additional_args = ", " + additional_args

        # One indented prompt-string line per sample; empty when shots == 0, so
        # the zero-shot query text is unaffected.
        few_shot = ""
        if shots > 0:
            few_shot = "".join(f'    "{s}"\n' for s in self.few_shot_samples[:shots])

        # The final prompt string is opened and closed with a single quote.
        # It previously ended in a stray triple quote.
        return f"""
BEAM(dclib_decoder="{decoder}"{additional_args})
{few_shot}    \"\"\"Q: {q}\\n\"\"\"
    "A: Let's think step by step.\\n"
    for i in range(10):
      "({{i+1}})[THOUGHT]"
      if not THOUGHT.endswith("\\n"): "\\n"
      if "answer" in THOUGHT: break
    "Overall this means,[CONCLUSION] Therefore, the answer (arabic numerals) is[answer]"
FROM 
    "{model}"
WHERE
    STOPS_AT(answer, ".") and STOPS_AT(THOUGHT, "\\n") and STOPS_AT(CONCLUSION, "\\n") and STOPS_AT(CONCLUSION, ".")
        """.strip()

    def parse_prediction(self, prediction, model_result):
        """Return the last number in ``prediction`` (commas removed) via ``parse_num``."""
        return parse_num(prediction)
