from evaluation_suite import *
import re

def parse_shape(s):
    """Normalize a shape answer to a compact string such as ``(2,3,4)``.

    Non-string inputs are converted with ``str()`` first. Trailing periods
    and surrounding whitespace are removed, then every character that is not
    a digit, a parenthesis or a comma is dropped.

    Args:
        s: The raw shape; usually a string, any other object is stringified.

    Returns:
        The input reduced to the characters ``0-9``, ``(``, ``)`` and ``,``.
    """
    if not isinstance(s, str):
        s = str(s)
    s = s.rstrip(".").strip()
    # Raw string with unescaped parentheses: inside a character class they
    # need no backslash, and "\(" in a normal string literal is an invalid
    # escape sequence that newer Python versions warn about.
    return re.sub(r"[^0-9(),]", "", s)

# Few-shot prompt samples for the matrix-shape task.
# 369, 575 ensured to be excluded from mini, small, medium, large
# (presumably the indices of the task instances these samples were taken
# from, kept out of every dataset size so they are never evaluated -- TODO
# confirm).
#
# Samples are terminated with "\\n", which in these non-raw literals is a
# backslash followed by the letter "n", not a real newline. cot_few_shot and
# ao_few_shot are pasted into a triple-quoted string of the generated query
# by the make_query methods below, so the escape is presumably resolved by
# the query language rather than by Python -- TODO confirm.

# Step-by-step format: every operation is directly followed by the shape of
# its result. The second sample has no "\\n" terminator.
# NOTE(review): not referenced by any suite in the visible part of this file.
few_shot_samples = """Subtract a matrix of shape (2,2,3,3) from a matrix of shape (2,2,3,3). The result is a matrix of shape (2,2,3,3). Add the result to a matrix of shape (2,2,3,3). The result is a matrix of shape (2,2,3,3). The result is a matrix of shape (2,2,3,3).\\n

Multiply a matrix of shape (2,2,2,2) with a matrix of shape (2,2,2,3). The result is a matrix of shape (2,2,2,3). Subtract the result from a matrix of shape (2,2,2,3). The result is a matrix of shape (2,2,2,3). Transpose the result. The result is a matrix of shape (3,2,2,2). The result is a matrix of shape (3,2,2,2).
"""

# Chain-of-thought format: the full list of operations, the cue "To determine
# the shape of the result, let's think step-by-step:", then the same
# operations again with the shape after each step and the final shape.
cot_few_shot = """Subtract a matrix of shape (2,2,3,3) from a matrix of shape (2,2,3,3). Add the result to a matrix of shape (2,2,3,3). To determine the shape of the result, let's think step-by-step: Subtract a matrix of shape (2,2,3,3) from a matrix of shape (2,2,3,3). The result is a matrix of shape (2,2,3,3). Add the result to a matrix of shape (2,2,3,3). The result is a matrix of shape (2,2,3,3). The result is a matrix of shape (2,2,3,3).\\n

Multiply a matrix of shape (2,2,2,2) with a matrix of shape (2,2,2,3). Subtract the result from a matrix of shape (2,2,2,3). Transpose the result. To determine the shape of the result, let's think step-by-step: Multiply a matrix of shape (2,2,2,2) with a matrix of shape (2,2,2,3). The result is a matrix of shape (2,2,2,3). Subtract the result from a matrix of shape (2,2,2,3). The result is a matrix of shape (2,2,2,3). Transpose the result. The result is a matrix of shape (3,2,2,2). The result is a matrix of shape (3,2,2,2).\\n
"""

# Answer-only format: the list of operations followed directly by
# "The shape of the overall result is ...", with no intermediate shapes.
ao_few_shot = """Subtract a matrix of shape (2,2,3,3) from a matrix of shape (2,2,3,3). Add the result to a matrix of shape (2,2,3,3). The shape of the overall result is (2,2,3,3).\\n

Multiply a matrix of shape (2,2,2,2) with a matrix of shape (2,2,2,3). Subtract the result from a matrix of shape (2,2,2,3). Transpose the result. The shape of the overall result is (3,2,2,2).\\n
"""

@suite("matrixshapes@multivar")
class MatrixShapesEvaluationSuite(EvaluationSuite):
    """Matrix-shape suite that uses each task's template and condition as-is.

    The query is a BEAM query made of the optional few-shot block, the
    instance template, the model name and the instance's own WHERE condition.
    """

    def __init__(self, size):
        """Initialize the base suite with this suite's name, task file and size."""
        task_file = os.path.join("generated_tasks", "matrixshapes.json")
        super().__init__("matrixshapes@multivar", task_file, size)

        self.few_shot_samples = []

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Build the query text for one task instance.

        Args:
            instance: Mapping providing the "template" and "condition" entries.
            model: Model identifier placed in the FROM clause.
            decoder: Value passed as ``dclib_decoder`` to BEAM.
            kwargs: Extra decoder arguments, rendered as sorted ``key=value``
                pairs after the decoder argument.
            shots: Number of few-shot samples; only 0 and 2 are accepted.

        Returns:
            The query as a string with surrounding whitespace stripped.
        """
        # Clean every template line with lstrip_spaces, then expand tabs.
        stripped_lines = map(lstrip_spaces, instance["template"].split("\n"))
        template = "\n".join(stripped_lines).replace("\t", "    ")
        condition = instance["condition"]

        # Each extra argument carries its own leading ", " so that the result
        # can be appended directly after the decoder argument.
        rendered = sorted(f"{key}={value}" for key, value in kwargs.items())
        additional_args = "".join(f", {arg}" for arg in rendered)

        assert shots in (0, 2), f"shots should be 0 or 2, but is {shots}"
        # NOTE(review): this suite prepends cot_few_shot, although the
        # module-level few_shot_samples looks like the matching format for
        # this template -- confirm that this is intended.
        samples = f"\"\"\"{cot_few_shot}\"\"\"\n" if shots == 2 else ""

        return f"""
BEAM(dclib_decoder="{decoder}"{additional_args})
    {indent(samples)}
    {indent(template)}
FROM 
    "{model}"
WHERE
    {condition}
        """.strip()

    def parse_prediction(self, prediction, model_result):
        """Return the normalized shape string for ``prediction``.

        ``model_result`` is accepted for interface compatibility and unused.
        """
        return parse_shape(prediction)

@suite("matrixshapes@cot")
class VanillaCoTExpressionEvaluationSuite(EvaluationSuite):
    """Matrix-shape suite that prompts with a chain-of-thought cue.

    The per-step result placeholders are removed from the task template and
    the final one is replaced by the step-by-step cue followed by a single
    ``[answer]`` placeholder. The generated query has no WHERE clause.
    """

    def __init__(self, size):
        """Initialize the base suite with this suite's name, task file and size."""
        super().__init__("matrixshapes@cot", os.path.join("generated_tasks", "matrixshapes.json"), size)

        self.few_shot_samples = []

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Build the chain-of-thought query text for one task instance.

        Args:
            instance: Mapping providing the "template" entry.
            model: Model identifier placed in the FROM clause.
            decoder: Value passed as ``dclib_decoder`` to BEAM.
            kwargs: Extra decoder arguments, rendered as sorted ``key=value``
                pairs after the decoder argument.
            shots: Number of few-shot samples; only 0 and 2 are accepted.

        Returns:
            The query as a string with surrounding whitespace stripped.
        """
        template = "\n".join([lstrip_spaces(l) for l in instance["template"].split("\n")]).replace("\t", "    ")

        # Drop the intermediate result placeholders, keeping only the
        # operations themselves.
        template = template.replace(" The result is a matrix of shape[var]", "")
        # Collapse the doubled periods left behind, then cut the last
        # character (presumably the template's closing quote, since one is
        # appended again below -- TODO confirm against the task file).
        template = template.replace("..", ".")[:-1]
        # Turn the final result sentence into the chain-of-thought cue.
        template = template.replace(" The result is a matrix of shape[answer]", " To determine the shape of the result, let's think step-by-step:[answer]\"")

        additional_args = sorted([f"{k}={v}" for k, v in kwargs.items()])
        additional_args = ", ".join(additional_args)
        if additional_args:
            additional_args = ", " + additional_args

        samples = ""
        assert shots == 2 or shots == 0, f"shots should be 0 or 2, but is {shots}"
        if shots == 2:
            samples = f"\"\"\"{cot_few_shot}\"\"\"\n"

        return f"""
BEAM(dclib_decoder="{decoder}"{additional_args})
    {indent(samples)}
    {indent(template)}
FROM 
    "{model}"
        """.strip()

    def parse_prediction(self, prediction, model_result):
        """Extract the shape from the text between the last "(" and last ")".

        Args:
            prediction: Unused; the shape is read from ``model_result``.
            model_result: Full result text to search.

        Returns:
            The normalized shape string, or "<none>" when no parenthesized
            span can be found.
        """
        open_shape = model_result.rfind("(")
        close_shape = model_result.rfind(")")
        # Also reject a "(" that is never closed (last "(" after last ")"):
        # the slice would be empty and "" would be returned instead of the
        # "<none>" sentinel.
        if open_shape == -1 or close_shape < open_shape:
            return "<none>"
        return parse_shape(model_result[open_shape:close_shape + 1])

@suite("matrixshapes@ao")
class AnswerOnlyMatrixshapes(EvaluationSuite):
    """Matrix-shape suite that asks for the final shape only.

    The per-step result placeholders are removed from the task template and
    the final one is replaced by "The shape of the overall result is"
    followed by a single ``[answer]`` placeholder that stops at ")".
    """

    name = "matrixshapes@ao"

    def __init__(self, size):
        """Initialize the base suite with this suite's name, task file and size."""
        super().__init__("matrixshapes@ao", os.path.join("generated_tasks", "matrixshapes.json"), size)

        self.few_shot_samples = []

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Build the answer-only query text for one task instance.

        Args:
            instance: Mapping providing the "template" entry.
            model: Model identifier placed in the FROM clause.
            decoder: Value passed as ``dclib_decoder`` to BEAM.
            kwargs: Extra decoder arguments, rendered as sorted ``key=value``
                pairs after the decoder argument.
            shots: Number of few-shot samples; only 0 and 2 are accepted.

        Returns:
            The query as a string with surrounding whitespace stripped.
        """
        template = "\n".join([lstrip_spaces(l) for l in instance["template"].split("\n")]).replace("\t", "    ")

        # Drop the intermediate result placeholders, keeping only the
        # operations themselves.
        template = template.replace(" The result is a matrix of shape[var]", "")
        # Collapse the doubled periods left behind, then cut the last
        # character (presumably the template's closing quote, since one is
        # appended again below -- TODO confirm against the task file).
        template = template.replace("..", ".")[:-1]
        # Turn the final result sentence into a direct question for the shape.
        template = template.replace(" The result is a matrix of shape[answer]", " The shape of the overall result is[answer]\"")

        additional_args = sorted([f"{k}={v}" for k, v in kwargs.items()])
        additional_args = ", ".join(additional_args)
        if additional_args:
            additional_args = ", " + additional_args

        samples = ""
        assert shots == 2 or shots == 0, f"shots should be 0 or 2, but is {shots}"
        if shots == 2:
            samples = f"\"\"\"{ao_few_shot}\"\"\"\n"

        return f"""
BEAM(dclib_decoder="{decoder}"{additional_args})
    {indent(samples)}
    {indent(template)}
FROM 
    "{model}"
WHERE
    STOPS_AT(answer, ")")
        """.strip()

    def parse_prediction(self, prediction, model_result):
        """Extract the shape from the text between the last "(" and last ")".

        Args:
            prediction: Unused; the shape is read from ``model_result``.
            model_result: Full result text to search.

        Returns:
            The normalized shape string, or "<none>" when no parenthesized
            span can be found.
        """
        open_shape = model_result.rfind("(")
        close_shape = model_result.rfind(")")
        # Also reject a "(" that is never closed (last "(" after last ")"):
        # the slice would be empty and "" would be returned instead of the
        # "<none>" sentinel.
        if open_shape == -1 or close_shape < open_shape:
            return "<none>"
        return parse_shape(model_result[open_shape:close_shape + 1])