from evaluation_suite import *
import re

# A number is a run of digits and (thousands-separator) commas, an optional
# decimal point, then at least one digit. The trailing `[0-9]+` (rather than a
# single `[0-9]`) keeps multi-digit fractions such as "3.14" in one match.
# Compiled once at import time instead of on every call.
_NUMBER_PATTERN = re.compile(r"[0-9,]*\.?[0-9]+")


def parse_num(prediction):
    """Extract the last number mentioned in a model prediction.

    Commas are stripped from the matched text, so ``"1,234.5"`` becomes
    ``"1234.5"``. Fractional parts of any length are kept, so ``"3.14"``
    yields ``"3.14"``.

    Args:
        prediction: the raw prediction text.

    Returns:
        The last number found in ``prediction``, as a string with commas
        removed, or ``prediction`` unchanged when it contains no digits.
    """
    # The pattern has no capturing groups, so findall returns whole matches.
    matches = _NUMBER_PATTERN.findall(prediction)
    if not matches:
        return prediction
    # Chain-of-thought output ends with the final answer, so the last
    # number in the text is the one we want.
    return matches[-1].replace(",", "")


@suite("AddSub@cot")
class AddSub(EvaluationSuite):
    """Zero-shot chain-of-thought evaluation suite for the AddSub dataset.

    Registered under the name ``"AddSub@cot"`` and backed by
    ``tasks/zero-shot-cot/AddSub/AddSub.json``.
    """

    def __init__(self, size):
        super().__init__("AddSub@cot", "tasks/zero-shot-cot/AddSub/AddSub.json", size)

        # This zero-shot CoT prompt uses no few-shot samples.
        self.few_shot_samples = []

    def instances(self):
        """Load the dataset and return its examples keyed by position.

        Each example dict is given a ``"target"`` entry copied from the
        first element of its ``"lSolutions"`` list.

        Returns:
            dict mapping 0-based example index to the example dict.
        """
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale.
        with open(self.file, encoding="utf-8") as f:
            examples = json.load(f)["examples"]
        for example in examples:
            example["target"] = example["lSolutions"][0]
        return dict(enumerate(examples))

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Build the LMQL query text for one AddSub instance.

        Args:
            instance: example dict; its ``"sQuestion"`` text is embedded in
                the prompt.
            model: model identifier placed in the ``FROM`` clause.
            decoder: value passed as ``dclib_decoder`` to ``BEAM``.
            kwargs: extra decoder arguments, rendered as sorted ``key=value``
                pairs; ``openai_chunksize`` is always forced to 1024. The
                caller's dict is not modified.
            shots: number of entries of ``self.few_shot_samples`` to prepend.

        Returns:
            The query text with surrounding whitespace stripped.
        """
        q = instance["sQuestion"]

        template = f"\"{q}\\n\"\n"
        template += "\"A: Let's think step by step.[COT] Therefore, the answer (arabic numerals) is[answer]\""

        if shots > 0:
            template = "\n".join(f'"{s}"' for s in self.few_shot_samples[:shots]) + "\n" + template

        # Copy instead of assigning into `kwargs`, so a dict shared by the
        # caller (e.g. across suites or instances) is not changed as a side
        # effect. The rendered arguments are identical either way.
        decoder_args = {**kwargs, "openai_chunksize": 1024}
        additional_args = ", ".join(sorted(f"{k}={v}" for k, v in decoder_args.items()))
        if additional_args:
            additional_args = ", " + additional_args

        return f"""
BEAM(dclib_decoder="{decoder}"{additional_args})
    {indent(template)}
FROM 
    "{model}"
        """.strip()

    def parse_prediction(self, prediction, model_result):
        """Return the last number in ``prediction`` (see ``parse_num``).

        ``model_result`` is accepted for interface compatibility and unused.
        """
        return parse_num(prediction)