from evaluation_suite import *
import re

# Optionally signed number: digits with optional comma-separated groups and an
# optional fractional part, or a bare fraction such as ".5". Every alternative
# consumes at least one character per step, so matching cannot backtrack
# exponentially on digit-free input.
_NUMBER_PATTERN = re.compile(r"-?(?:[0-9]+(?:,[0-9]+)*(?:\.[0-9]+)?|\.[0-9]+)")


def parse_num(prediction):
    """Return the last number mentioned in ``prediction`` as a string.

    Commas (thousands separators) are removed from the extracted number, and
    a "-" directly in front of the digits is kept as its sign. A trailing
    sentence period is not part of the number ("... is -234." -> "-234"),
    while a full fractional part is ("3.14" -> "3.14").

    Args:
        prediction: The model output. Non-string values are not parsed.

    Returns:
        ``str(prediction)`` if ``prediction`` is not a string; the unchanged
        ``prediction`` if it contains no number; otherwise the last number
        found, without commas.
    """
    if not isinstance(prediction, str):
        return str(prediction)

    # The pattern has no capturing groups, so findall yields whole matches.
    numbers = _NUMBER_PATTERN.findall(prediction)
    if not numbers:
        return prediction

    number = numbers[-1].replace(",", "")
    # Echo the extracted value to stdout, as this helper has always done.
    print(number)
    return number


@suite("multiarith@cot")
class VanillaCotBooleanExpressionEvaluationSuite(EvaluationSuite):
    """Chain-of-thought suite over the generated boolean-expression tasks.

    NOTE(review): the suite is registered as "multiarith@cot" although its
    task file and prompts are about boolean expressions -- confirm that the
    registered name is intended.
    """

    def __init__(self, size):
        """Pass the task name, task file and ``size`` on to the base class and
        define the few-shot demonstrations."""
        task_file = os.path.join("generated_tasks", "boolean_expressions.json")
        super().__init__("boolean_expressions_cot", task_file, size)

        # Worked examples that can be placed in front of the prompt, in order.
        self.few_shot_samples = [
            "To evaluate the logical expression '(True and False)' you split it into multiple subexpressions. The first subexpression is 'True and False' and evaluates to 'False'. To conclude, the full expression '(True and False)' evaluates to 'False'.\\n",
            "To evaluate the logical expression '(False and not False)' you split it into multiple subexpressions. The first subexpression is 'not False' and evaluates to 'True'. The next one is 'False and True' and evaluates to 'False'. To conclude, the full expression '(False and not False)' evaluates to 'False'.\\n",
        ]

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Assemble the query text for one task instance.

        Args:
            instance: Mapping whose "template" entry holds the prompt
                template. The expression is the text between
                "To evaluate the logical expression '" and the next quote.
            model: Value written into the FROM clause.
            decoder: Value written as ``dclib_decoder``.
            kwargs: Extra decoder arguments, rendered as sorted ``key=value``
                pairs after the decoder.
            shots: With 0 a direct step-by-step question is asked; otherwise
                the first template line, cut at "The first one is '", is used
                and (for positive values) that many few-shot samples are
                prepended.

        Returns:
            The query as a single string.

        Raises:
            IndexError: If the template lacks the expression marker.
        """
        stripped_lines = [lstrip_spaces(line) for line in instance["template"].split("\n")]
        normalized = "\n".join(stripped_lines).replace("\t", "    ")
        expr = normalized.split("To evaluate the logical expression '")[1].split("'")[0]

        if shots == 0:
            # The "\\n" stays a literal backslash-n inside the quoted prompt.
            zero_shot_lines = (
                f"\"What does the logical expression '{expr}' evaluate to . Let's think step-by-step.\\n\"",
                "\"[RESPONSE]\"",
            )
            prompt = "\n".join(zero_shot_lines)
        else:
            first_line = normalized.split("\n")[0]
            prompt = first_line.split("The first one is '")[0] + "The first one is '[COT]\""

        if shots > 0:
            demonstrations = [f'"{sample}"' for sample in self.few_shot_samples[:shots]]
            prompt = "\n".join(demonstrations) + "\n" + prompt

        # Each extra argument is appended as ", key=value", sorted by its text.
        rendered_args = sorted(f"{key}={value}" for key, value in kwargs.items())
        extra_args = "".join(f", {arg}" for arg in rendered_args)

        query_lines = [
            f'BEAM(dclib_decoder="{decoder}"{extra_args})',
            f"    {indent(prompt)}",
            "FROM ",  # the trailing space is part of the emitted text
            f'    "{model}"',
        ]
        return "\n".join(query_lines)

    def parse_prediction(self, prediction, model_result):
        """Reduce ``prediction`` with ``parse_num`` and log the step to stdout.

        ``model_result`` is not used.
        """
        parsed = parse_num(prediction)
        print("parse prediction", prediction, "parsed", parsed, "p", prediction)
        return parsed


@suite("multiarith2@cot")
class InstructMultistepArith(EvaluationSuite):
    """Zero-shot chain-of-thought suite for the generated multistep-arithmetic tasks.

    The prompt asks for the value of the expression, leaves room for free
    reasoning in ``[COT]`` and reads the result from ``[answer]``.
    """

    def __init__(self, size):
        """Pass the task name, task file and ``size`` on to the base class."""
        super().__init__("multiarith2@cot", os.path.join("generated_tasks", "multistep_arithmetic.json"), size)

        # No few-shot samples are defined for this suite.
        self.few_shot_samples = []

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Assemble the query text for one task instance.

        Args:
            instance: Mapping whose "template" entry holds the prompt
                template. Only the expression is taken from it: the text
                between "To evaluate the expression '" and the next quote on
                the first line.
            model: Value written into the FROM clause.
            decoder: Value written as ``dclib_decoder``.
            kwargs: Extra decoder arguments, rendered as sorted ``key=value``
                pairs. The mapping itself is left untouched; the query always
                carries ``openai_chunksize=128``.
            shots: Accepted for interface compatibility; the prompt is fixed
                and does not include few-shot samples.

        Returns:
            The query as a single string.

        Raises:
            IndexError: If the template lacks the expression marker.
        """
        normalized = "\n".join(lstrip_spaces(line) for line in instance["template"].split("\n")).replace("\t", "    ")
        truncated = normalized.split("\n")[0].split("The first one is '")[0] + "The first one is '[COT]\""
        expr = truncated.split("To evaluate the expression '")[1].split("'")[0]

        # Work on a copy so the caller's kwargs do not pick up the override.
        query_args = {**kwargs, "openai_chunksize": 128}
        additional_args = "".join(f", {arg}" for arg in sorted(f"{k}={v}" for k, v in query_args.items()))

        query_lines = [
            f'BEAM(dclib_decoder="{decoder}"{additional_args})',
            f"    \"\"\"Q: Evaluate the arithmetic expression '{expr}'.",
            "    A: Let's think step by step.",
            '    [COT] Therefore, the answer (arabic numerals) is[answer]"""',
            "FROM ",  # the trailing space is part of the emitted text
            f'    "{model}"',
        ]
        return "\n".join(query_lines)

    def parse_prediction(self, prediction, model_result):
        """Reduce ``prediction`` with ``parse_num`` and log the step to stdout.

        ``model_result`` is not used.
        """
        parsed = parse_num(prediction)
        print("parse prediction", prediction, "parsed", parsed, "p", prediction)
        return parsed



@suite("multiarith2@ao")
class AoInstructMultistepArith(EvaluationSuite):
    """Suite for the generated multistep-arithmetic tasks that asks for the answer directly.

    The prompt goes straight from "Let's think step by step." to the
    ``[answer]`` placeholder, which is constrained to stop at ".".
    """

    def __init__(self, size):
        """Pass the task name, task file and ``size`` on to the base class."""
        super().__init__("multiarith2@ao", os.path.join("generated_tasks", "multistep_arithmetic.json"), size)

        # No few-shot samples are defined for this suite.
        self.few_shot_samples = []

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Assemble the query text for one task instance.

        Args:
            instance: Mapping whose "template" entry holds the prompt
                template. Only the expression is taken from it: the text
                between "To evaluate the expression '" and the next quote on
                the first line.
            model: Value written into the FROM clause.
            decoder: Value written as ``dclib_decoder``.
            kwargs: Extra decoder arguments, rendered as sorted ``key=value``
                pairs. The mapping itself is left untouched; the query always
                carries ``openai_chunksize=8``.
            shots: Accepted for interface compatibility; the prompt is fixed
                and does not include few-shot samples.

        Returns:
            The query as a single string.

        Raises:
            IndexError: If the template lacks the expression marker.
        """
        normalized = "\n".join(lstrip_spaces(line) for line in instance["template"].split("\n")).replace("\t", "    ")
        truncated = normalized.split("\n")[0].split("The first one is '")[0] + "The first one is '[COT]\""
        expr = truncated.split("To evaluate the expression '")[1].split("'")[0]

        # Work on a copy so the caller's kwargs do not pick up the override.
        query_args = {**kwargs, "openai_chunksize": 8}
        additional_args = "".join(f", {arg}" for arg in sorted(f"{k}={v}" for k, v in query_args.items()))

        query_lines = [
            f'BEAM(dclib_decoder="{decoder}"{additional_args})',
            f"    \"\"\"Q: Evaluate the arithmetic expression '{expr}'.",
            "    A: Let's think step by step.",
            '    The answer (arabic numerals) is[answer]"""',
            "FROM ",  # the trailing space is part of the emitted text
            f'    "{model}"',
            "WHERE",
            '    STOPS_AT(answer, ".")',
        ]
        return "\n".join(query_lines)

    def parse_prediction(self, prediction, model_result):
        """Reduce ``prediction`` with ``parse_num`` and log the step to stdout.

        ``model_result`` is not used.
        """
        parsed = parse_num(prediction)
        print("parse prediction", prediction, "parsed", parsed, "p", prediction)
        return parsed


@suite("multiarith2@guided")
class MultistepGuidedArithmeticsGuided(EvaluationSuite):
    """Suite for the generated multistep-arithmetic tasks that uses each
    instance's own prompt template and WHERE condition.

    NOTE(review): the suite is registered as "multiarith2@guided" but passes
    "multiarith@guided" to the base class -- confirm which name is intended.
    """

    def __init__(self, size):
        """Pass the task name, task file and ``size`` on to the base class."""
        super().__init__("multiarith@guided", os.path.join("generated_tasks", "multistep_arithmetic.json"), size)

        # No few-shot samples are defined for this suite.
        self.few_shot_samples = []

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Assemble the query text for one task instance.

        Args:
            instance: Mapping with a "template" entry (the prompt, used in
                full after whitespace normalisation) and a "condition" entry
                (written into the WHERE clause).
            model: Value written into the FROM clause.
            decoder: Value written as ``dclib_decoder``.
            kwargs: Extra decoder arguments, rendered as sorted ``key=value``
                pairs. The mapping itself is left untouched; the query always
                carries ``openai_chunksize=8``.
            shots: For positive values, that many few-shot samples followed
                by a newline are placed in front of the prompt.

        Returns:
            The query as a single string, stripped of surrounding whitespace.
        """
        template = "\n".join(lstrip_spaces(line) for line in instance["template"].split("\n")).replace("\t", "    ")
        condition = instance["condition"]

        # Work on a copy so the caller's kwargs do not pick up the override.
        query_args = {**kwargs, "openai_chunksize": 8}

        if shots > 0:
            template = "\n".join(f'"{s}"' for s in self.few_shot_samples[:shots]) + "\n" + template

        additional_args = "".join(f", {arg}" for arg in sorted(f"{k}={v}" for k, v in query_args.items()))

        query_lines = [
            f'BEAM(dclib_decoder="{decoder}"{additional_args})',
            f"    {indent(template)}",
            "FROM ",  # the trailing space is part of the emitted text
            f'    "{model}"',
            "WHERE",
            f"    {condition}",
        ]
        # The condition ends the query, so trailing whitespace in it is dropped.
        return "\n".join(query_lines).strip()

    def parse_prediction(self, prediction, model_result):
        """Drop trailing periods, then surrounding whitespace, from ``prediction``.

        ``model_result`` is not used.
        """
        return prediction.rstrip(".").strip()


@suite("multiarith2@multivar")
class MultistepArithmeticsGuided(EvaluationSuite):
    """Suite for the generated multistep-arithmetic tasks that collects the
    reasoning as up to ten numbered ``[THOUGHT]`` lines before the answer.

    NOTE(review): the suite is registered as "multiarith2@multivar" but passes
    "multiarith@multivar" to the base class -- confirm which name is intended.
    """

    def __init__(self, size):
        """Pass the task name, task file and ``size`` on to the base class."""
        super().__init__("multiarith@multivar", os.path.join("generated_tasks", "multistep_arithmetic.json"), size)

        # No few-shot samples are defined for this suite.
        self.few_shot_samples = []

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Assemble the query text for one task instance.

        Args:
            instance: Mapping whose "template" entry holds the prompt
                template. Only the expression is taken from it: the text
                between "To evaluate the expression '" and the next quote on
                the first line.
            model: Value written into the FROM clause.
            decoder: Value written as ``dclib_decoder``.
            kwargs: Extra decoder arguments, rendered as sorted ``key=value``
                pairs. The mapping itself is left untouched; the query always
                carries ``openai_chunksize=128``.
            shots: Accepted for interface compatibility; the prompt is fixed
                and does not include few-shot samples.

        Returns:
            The query as a single string.

        Raises:
            IndexError: If the template lacks the expression marker.
        """
        normalized = "\n".join(lstrip_spaces(line) for line in instance["template"].split("\n")).replace("\t", "    ")
        truncated = normalized.split("\n")[0].split("The first one is '")[0] + "The first one is '[COT]\""
        expr = truncated.split("To evaluate the expression '")[1].split("'")[0]

        # Work on a copy so the caller's kwargs do not pick up the override.
        query_args = {**kwargs, "openai_chunksize": 128}
        additional_args = "".join(f", {arg}" for arg in sorted(f"{k}={v}" for k, v in query_args.items()))

        # Every "\\n" below is emitted as a literal backslash-n in the query.
        query_lines = [
            f'BEAM(dclib_decoder="{decoder}"{additional_args})',
            f"    \"\"\"Q: Evaluate the arithmetic expression '{expr}'.",
            '    A: Let\'s think step by step.\\n"""',
            "    for i in range(10):",
            '      "({i+1})[THOUGHT]"',
            '      if not THOUGHT.endswith("\\n"): "\\n"',
            '      if "answer" in THOUGHT: break',
            # NOTE(review): this line opens with one quote but closes with
            # three; kept exactly as before -- confirm it is intended.
            '    "Overall this means,[CONCLUSION] Therefore, the answer (arabic numerals) is[answer]"""',
            "FROM ",  # the trailing space is part of the emitted text
            f'    "{model}"',
            "WHERE",
            '    STOPS_AT(THOUGHT, "\\n") and STOPS_AT(CONCLUSION, "\\n") and STOPS_AT(CONCLUSION, ".")',
        ]
        return "\n".join(query_lines)

    def parse_prediction(self, prediction, model_result):
        """Drop trailing periods, then surrounding whitespace, from ``prediction``.

        ``model_result`` is not used.
        """
        return prediction.rstrip(".").strip()
    
if __name__ == "__main__":
    # Manual smoke check: show what is parsed from a sample answer sentence.
    sample_answer = "The answer is -234."
    print(parse_num(sample_answer))