from evaluation_suite import *

@suite("boolean_expressions_cot")
class VanillaCotBooleanExpressionEvaluationSuite(EvaluationSuite):
    """Chain-of-thought variant of the boolean-expression suite.

    With zero shots the query asks the model to "think step-by-step" and
    leaves a free-form ``[RESPONSE]`` hole. Otherwise the first line of the
    instance template is cut at "The first one is '" and continued with a
    ``[COT]`` hole, optionally preceded by worked few-shot examples.
    """

    def __init__(self, size):
        """Register the suite and define its worked few-shot examples.

        Args:
            size: Forwarded unchanged to ``EvaluationSuite.__init__``.
        """
        super().__init__("boolean_expressions_cot", os.path.join("generated_tasks", "boolean_expressions.json"), size)

        # Each sample ends with a literal backslash-n (two characters): the
        # samples are wrapped in double quotes inside the query, so the newline
        # must stay an escape sequence rather than become a real line break.
        self.few_shot_samples = [
            "To evaluate the logical expression '(True and False)' you split it into multiple subexpressions. The first subexpression is 'True and False' and evaluates to 'False'. To conclude, the full expression '(True and False)' evaluates to 'False'.\\n",
            "To evaluate the logical expression '(False and not False)' you split it into multiple subexpressions. The first subexpression is 'not False' and evaluates to 'True'. The next one is 'False and True' and evaluates to 'False'. To conclude, the full expression '(False and not False)' evaluates to 'False'.\\n"
        ]

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Build the query text for one task instance.

        Args:
            instance: Mapping whose "template" entry holds the prompt text; it
                must contain "To evaluate the logical expression '<expr>'".
            model: Model identifier placed in the FROM clause.
            decoder: Value of the ``dclib_decoder`` argument of ``BEAM``.
            kwargs: Extra decoder arguments, rendered as sorted ``key=value``
                pairs after ``dclib_decoder``.
            shots: Number of few-shot samples to prepend; 0 selects the
                zero-shot "think step-by-step" prompt.

        Returns:
            The complete query as a single string.

        Raises:
            IndexError: If the template lacks the expression marker.
        """
        normalized = "\n".join(
            lstrip_spaces(line) for line in instance["template"].split("\n")
        ).replace("\t", "    ")
        # Extracted before branching so a malformed template fails the same
        # way regardless of the shot count.
        expr = normalized.split("To evaluate the logical expression '")[1].split("'")[0]

        if shots == 0:
            template = f"""
"What does the logical expression '{expr}' evaluate to . Let's think step-by-step.\\n"
"[RESPONSE]"
""".strip()
        else:
            # Keep only the first line up to the start of the reasoning and let
            # the model fill in the rest through the [COT] hole.
            template = normalized.split("\n")[0].split("The first one is '")[0] + "The first one is '[COT]\""

        if shots > 0:
            examples = "\n".join(f'"{sample}"' for sample in self.few_shot_samples[:shots])
            template = examples + "\n" + template

        # Sorted so the generated query does not depend on dict ordering.
        rendered_args = ", ".join(sorted(f"{key}={value}" for key, value in kwargs.items()))
        additional_args = f", {rendered_args}" if rendered_args else ""

        # Assembled line by line; note that the FROM line carries a trailing
        # space, exactly as in the query text this suite has always produced.
        return "\n".join([
            f'BEAM(dclib_decoder="{decoder}"{additional_args})',
            f"    {indent(template)}",
            "FROM ",
            f'    "{model}"',
        ])

    def parse_prediction(self, prediction, model_result):
        """Extract the final 'True'/'False' verdict from the model output.

        Only the text after the last "To conclude" is inspected when that
        phrase is present; otherwise the whole output is. The occurrence
        closest to the end of that text wins.

        Args:
            prediction: Fallback value, returned unchanged when neither
                literal occurs in the inspected text.
            model_result: Raw text produced by the model.

        Returns:
            "True", "False", or ``prediction`` when no verdict is found.
        """
        # rpartition yields the whole string as the tail when the marker is
        # absent, which covers both cases in one step.
        conclusion = model_result.rpartition("To conclude")[2]

        # rfind also considers a match at index 0, which the previous manual
        # backwards scan skipped (e.g. an output that is just "False").
        last_true = conclusion.rfind("True")
        last_false = conclusion.rfind("False")
        if last_true == -1 and last_false == -1:
            return prediction
        return "True" if last_true > last_false else "False"

@suite("boolean_expressions_instruct")
class InstructBooleanExpressions(EvaluationSuite):
    """Instruction-style variant of the boolean-expression suite.

    The query states the task as an explicit instruction, asks the model to
    begin its last sentence with 'To conclude', and leaves a ``[COT]`` hole
    for the reasoning.
    """

    def __init__(self, size):
        """Register the suite.

        Args:
            size: Forwarded unchanged to ``EvaluationSuite.__init__``.
        """
        super().__init__("boolean_expressions_instruct", os.path.join("generated_tasks", "boolean_expressions.json"), size)

        # The instruct variant has no few-shot samples.
        self.few_shot_samples = []

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Build the instruction query for one task instance.

        Args:
            instance: Mapping whose "template" entry holds the prompt text; its
                first line must contain
                "To evaluate the logical expression '<expr>'".
            model: Model identifier placed in the FROM clause.
            decoder: Value of the ``dclib_decoder`` argument of ``BEAM``.
            kwargs: Extra decoder arguments, rendered as sorted ``key=value``
                pairs after ``dclib_decoder``.
            shots: Number of few-shot samples to prepend. The sample list is
                empty, so a positive value only prepends an empty line.

        Returns:
            The complete query as a single string.

        Raises:
            IndexError: If the expression marker cannot be found.
        """
        normalized = "\n".join(
            lstrip_spaces(line) for line in instance["template"].split("\n")
        ).replace("\t", "    ")
        # The expression is read from the first template line only, cut at the
        # point where the reasoning starts.
        first_line = normalized.split("\n")[0].split("The first one is '")[0] + "The first one is '[COT]\""
        expr = first_line.split("To evaluate the logical expression '")[1].split("'")[0]

        # NOTE(review): "Evalute" is misspelled in the prompt. It is kept
        # byte-for-byte because changing the prompt changes what the model is
        # asked -- confirm before correcting it.
        template = f"\"Evalute the following logical expression by evaluating subexpressions first. When you are done, begin the last sentence with 'To conclude' and end on a new line. The logical expression is '{expr}'. The first subexpression is '[COT]\""

        if shots > 0:
            examples = "\n".join(f'"{sample}"' for sample in self.few_shot_samples[:shots])
            template = examples + "\n" + template

        # Sorted so the generated query does not depend on dict ordering.
        rendered_args = ", ".join(sorted(f"{key}={value}" for key, value in kwargs.items()))
        additional_args = f", {rendered_args}" if rendered_args else ""

        # Assembled line by line; note that the FROM line carries a trailing
        # space, exactly as in the query text this suite has always produced.
        return "\n".join([
            f'BEAM(dclib_decoder="{decoder}"{additional_args})',
            f"    {indent(template)}",
            "FROM ",
            f'    "{model}"',
        ])

    def parse_prediction(self, prediction, model_result):
        """Extract the final true/false verdict, ignoring letter case.

        Only the text after the last "To conclude" is inspected when that
        phrase is present; otherwise the whole output is. The occurrence
        closest to the end of that text wins.

        Args:
            prediction: Fallback value, returned unchanged when neither word
                occurs in the inspected text.
            model_result: Raw text produced by the model.

        Returns:
            "True", "False", or ``prediction`` when no verdict is found.
        """
        # rpartition yields the whole string as the tail when the marker is
        # absent, which covers both cases in one step.
        conclusion = model_result.rpartition("To conclude")[2].lower()

        # rfind also considers a match at index 0, which the previous manual
        # backwards scan skipped (e.g. an output that is just "false").
        last_true = conclusion.rfind("true")
        last_false = conclusion.rfind("false")
        if last_true == -1 and last_false == -1:
            return prediction
        return "True" if last_true > last_false else "False"


@suite("boolean_expressions")
class BooleanExpressionEvaluationSuite(EvaluationSuite):
    """Boolean-expression suite driven by each instance's own template.

    The query uses the instance template as-is (after whitespace
    normalization) and constrains decoding with the instance's condition in a
    WHERE clause.
    """

    def __init__(self, size):
        """Register the suite and define its worked few-shot examples.

        Args:
            size: Forwarded unchanged to ``EvaluationSuite.__init__``.
        """
        task_file = os.path.join("generated_tasks", "boolean_expressions.json")
        super().__init__("boolean_expressions", task_file, size)

        # Each sample ends with a literal backslash-n (two characters), since
        # the samples are wrapped in double quotes inside the query.
        self.few_shot_samples = [
            "To evaluate the logical expression '(True and False)' you split it into multiple subexpressions. The first subexpression is 'True and False' and evaluates to 'False'. To conclude, the full expression '(True and False)' evaluates to 'False'.\\n",
            "To evaluate the logical expression '(False and not False)' you split it into multiple subexpressions. The first subexpression is 'not False' and evaluates to 'True'. The next one is 'False and True' and evaluates to 'False'. To conclude, the full expression '(False and not False)' evaluates to 'False'.\\n"
        ]

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Build the constrained query for one task instance.

        Args:
            instance: Mapping with a "template" entry (prompt text) and a
                "condition" entry (placed in the WHERE clause).
            model: Model identifier placed in the FROM clause.
            decoder: Value of the ``dclib_decoder`` argument of ``BEAM``.
            kwargs: Extra decoder arguments, rendered as sorted ``key=value``
                pairs after ``dclib_decoder``.
            shots: Number of few-shot samples to prepend to the template.

        Returns:
            The complete query as a single string, stripped of surrounding
            whitespace.
        """
        stripped_lines = (lstrip_spaces(line) for line in instance["template"].split("\n"))
        prompt = "\n".join(stripped_lines).replace("\t", "    ")
        where_clause = instance["condition"]

        if shots > 0:
            quoted_examples = [f'"{sample}"' for sample in self.few_shot_samples[:shots]]
            prompt = "\n".join(quoted_examples) + "\n" + prompt

        # Sorted so the generated query does not depend on dict ordering.
        pairs = sorted(f"{key}={value}" for key, value in kwargs.items())
        extra = ", ".join(pairs)
        if extra:
            extra = ", " + extra

        query_lines = [
            f'BEAM(dclib_decoder="{decoder}"{extra})',
            f"    {indent(prompt)}",
            "FROM ",
            f'    "{model}"',
            "WHERE",
            f"    {where_clause}",
        ]
        # The final strip also removes any trailing whitespace contributed by
        # the condition itself.
        return "\n".join(query_lines).strip()

if __name__ == "__main__":
    # Ad-hoc run configuration. The whole dict is forwarded to the selected
    # suite's main() as keyword arguments; "size" is also passed to the
    # suite constructor.
    kwargs = {
        "decoder": "argmax",
        "shots": 0,
        "kwargs": {
            "max_length": 320,
            "openai_chunksize": 320,
        },
        "num_workers": 2,
        "size": "mini",
        # Uncomment to run the chain-of-thought suite instead:
        # "vanilla-cot": True
    }

    # "vanilla-cot" takes precedence over "instruct"; with neither flag set
    # the plain boolean-expression suite runs.
    if kwargs.get("vanilla-cot", False):
        suite_cls = VanillaCotBooleanExpressionEvaluationSuite
    elif kwargs.get("instruct", False):
        suite_cls = InstructBooleanExpressions
    else:
        suite_cls = BooleanExpressionEvaluationSuite

    suite_cls(kwargs.get("size", None)).main(**kwargs)