from evaluation_suite import *

def parse_choice(prediction):
    """Extract a single answer-choice letter from generated answer text.

    Parentheses are removed and surrounding dots and whitespace are stripped;
    the first remaining character is returned. For example " (B).",
    " B) 20/7." and an answer starting with a newline before "B." all
    yield "B".

    Args:
        prediction: the raw answer text produced by the model.

    Returns:
        The first significant character of ``prediction``, or the marker
        string ``"<parse_choice could not parse '...'>"`` (containing the
        original text) when nothing remains after cleaning.
    """
    cleaned = prediction.replace("(", "").replace(")", "")
    # Strip newlines/tabs too, not just spaces and dots: otherwise an answer
    # that begins with a newline would be reduced to that newline character.
    cleaned = cleaned.strip(". \t\r\n\f\v")
    if not cleaned:
        return "<parse_choice could not parse '{}'>".format(prediction)
    return cleaned[0]

# Two chain-of-thought demonstrations taken from AQuA test.json, verified not to
# be contained in the mini, small or medium splits.
# Inserted verbatim inside an LMQL triple-quoted string when shots == 2 by the
# AQUA@cot and AQUA@dash_multivar suites. Each "\\n" is a literal backslash-n in
# Python (presumably turned into an extra newline by LMQL, as with the "\\n"
# escapes in the query templates -- confirm).
# NOTE(review): the "Answer Choices" lines read "(A) ... B) ..." with unbalanced
# parentheses, unlike the " (A) ... (B) ..." format the suites generate for the
# test question. Confirm whether this is intentional; editing the prompt would
# change benchmark results.
few_shot_samples = """Q: 10kg of a mixture contains 30% sand and 70% clay. In order to make the mixture contain equal quantities of clay and sand how much of the mixture is to be removed and replaced with pure sand?
Answer Choices: (A) 10/7 B) 20/7 C) 30/7 D) 40/7 E) 50/7
A: Let's think step by step.
- The mixture contains 3kg sand and 7 kg clay.
- For the mixture to be in equal quantities, there should be 2 kg of clay removed.
- Clay and sand are in the ratio 7:3
So part of sand to be removed = 2*3/7 = 6/7
So the answer is that the total mixture to be removed = 2 + 6/7 = 20/7.
Overall this means, the answer is B. Therefore, among A through E, the answer is B.\\n
Q: 30 is subtracted from a number, it is reduced to its one third. What is the value of 50% of that number?
Answer Choices: (A) 22.5 B) 84 C) 21 D) 24 E) 25
A: Let's think step by step.
- 2/3 x = 30 => x = 45
- 45 * 1/2 = 22.5
- So the answer is 22.5.
Overall this means, the answer is 22.5. Therefore, among A through E, the answer is A.\\n
"""

# Answer-only variant of the two demonstrations above: same questions from AQuA
# test.json (verified not to be contained in the mini, small or medium splits),
# but with no reasoning, just the final choice letter.
# Inserted verbatim inside an LMQL triple-quoted string when shots == 2 by the
# AQUA@ao suite. Each "\\n" is a literal backslash-n in Python (presumably an
# LMQL newline escape -- confirm).
# NOTE(review): same "(A) ... B) ..." unbalanced-parenthesis choice format as
# few_shot_samples; confirm this is intentional before editing the prompt.
few_shot_ao = """Q: 10kg of a mixture contains 30% sand and 70% clay. In order to make the mixture contain equal quantities of clay and sand how much of the mixture is to be removed and replaced with pure sand?
Answer Choices: (A) 10/7 B) 20/7 C) 30/7 D) 40/7 E) 50/7
A: Among A through E, the answer is B.\\n
Q: 30 is subtracted from a number, it is reduced to its one third. What is the value of 50% of that number?
Answer Choices: (A) 22.5 B) 84 C) 21 D) 24 E) 25
A: Among A through E, the answer is A.\\n

"""

@suite("AQUA@cot")
class AquaCot(EvaluationSuite):
    """AQuA multiple-choice suite with a single free-form chain-of-thought hole.

    The query asks the model to fill ``[COT]`` after "Let's think step by
    step." and then ``[answer]`` after "the answer is"; the answer text is
    reduced to a choice letter by ``parse_choice``.
    """

    def __init__(self, size):
        super().__init__("AQUA@cot", "tasks/zero-shot-cot/AQuA/test.json", size)

    def instances(self):
        """Load the AQuA test file, parsing each non-blank line with ``json.loads``.

        Returns:
            A dict mapping the 0-based record index to the record. Each record
            gets a ``target`` key copied from its ``correct`` field.
        """
        # Assumes the dataset file is UTF-8 encoded; blank lines (e.g. an extra
        # trailing newline) are skipped instead of failing in json.loads.
        with open(self.file, encoding="utf-8") as f:
            records = [json.loads(line) for line in f if line.strip()]
        for record in records:
            record["target"] = record["correct"]
        return dict(enumerate(records))

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Build the LMQL query for one AQuA instance.

        Args:
            instance: record from ``instances()``; ``question`` and ``options``
                are read, each option being a ``"<letter>)<text>"`` string.
            model: model name placed in the ``FROM`` clause.
            decoder: value passed as ``dclib_decoder`` to ``BEAM``.
            kwargs: extra ``BEAM`` arguments, rendered as sorted ``k=v`` pairs.
                Mutated in place: ``openai_chunksize`` is set to 128.
            shots: 0 for zero-shot, 2 to prepend the module-level
                ``few_shot_samples`` demonstrations.

        Returns:
            The LMQL query source as a string.

        Raises:
            AssertionError: if ``shots`` is neither 0 nor 2.
        """
        q = instance["question"]
        template = ""

        assert shots == 2 or shots == 0, f"shots should be 0 or 2, but is {shots}"
        # The demonstrations come only from the module-level string; this class
        # keeps no per-instance list of samples, so nothing else is prepended.
        if shots == 2:
            template += f"\"\"\"{few_shot_samples}\"\"\"\n"

        template += f"\"\"\"Q: {q}\\n\"\"\"\n"
        template += "\"Answer Choices:"
        for option in instance["options"]:
            letter, text = option.split(")", 1)
            template += f" ({letter}) {text}"
        template += "\\n\"\n"
        # NOTE(review): COT has no STOPS_AT constraint in the WHERE clause, so
        # its length is presumably bounded only by decoder/model limits -- confirm.
        template += "\"A: Let's think step by step.[COT] Therefore, among A through E, the answer is[answer]\""

        kwargs["openai_chunksize"] = 128
        # ", k1=v1, k2=v2" (sorted), or "" when there are no extra arguments.
        additional_args = "".join(
            f", {arg}" for arg in sorted(f"{k}={v}" for k, v in kwargs.items())
        )

        return f"""
BEAM(dclib_decoder="{decoder}"{additional_args})
    {indent(template)}
FROM 
    "{model}"
WHERE
    STOPS_AT(answer, ".")
        """.strip()

    def parse_prediction(self, prediction, model_result):
        """Reduce the generated answer text to a choice letter via ``parse_choice``.

        ``model_result`` is not used.
        """
        return parse_choice(prediction)


@suite("AQUA@multivar")
class AquaMultiVarCot(EvaluationSuite):
    """AQuA suite where reasoning is split into up to four ``THOUGHT`` steps.

    Steps are introduced by "Firstly", "Secondly", ... and the loop stops early
    once a thought mentions "Therefore" or "According". A ``CONCLUSION`` and the
    final ``answer`` follow. The answer text is reduced to a choice letter by
    ``parse_choice``.
    """

    def __init__(self, size):
        super().__init__("AQUA@multivar", "tasks/zero-shot-cot/AQuA/test.json", size)

        # this suite uses no few-shot demonstrations
        self.few_shot_samples = []

    def instances(self):
        """Load the AQuA test file, parsing each non-blank line with ``json.loads``.

        Returns:
            A dict mapping the 0-based record index to the record. Each record
            gets a ``target`` key copied from its ``correct`` field.
        """
        # Assumes the dataset file is UTF-8 encoded; blank lines (e.g. an extra
        # trailing newline) are skipped instead of failing in json.loads.
        with open(self.file, encoding="utf-8") as f:
            records = [json.loads(line) for line in f if line.strip()]
        for record in records:
            record["target"] = record["correct"]
        return dict(enumerate(records))

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Build the multi-variable LMQL query for one AQuA instance.

        Args:
            instance: record from ``instances()``; ``question`` and ``options``
                are read, each option being a ``"<letter>)<text>"`` string.
            model: model name placed in the ``FROM`` clause.
            decoder: value passed as ``dclib_decoder`` to ``BEAM``.
            kwargs: extra ``BEAM`` arguments, rendered as sorted ``k=v`` pairs.
                Mutated in place: ``openai_chunksize`` is set (128 when
                ``shots == 0``, otherwise 1024).
            shots: only selects ``openai_chunksize``; no demonstrations are
                inserted into this query.

        Returns:
            The LMQL query source as a string.
        """
        # Newlines are escaped so the question stays on one line of the query.
        q = instance["question"].replace("\n", "\\n")

        choices = "Answer Choices:"
        for option in instance["options"]:
            letter, text = option.split(")", 1)
            choices += f" ({letter}) {text}"

        kwargs["openai_chunksize"] = 128 if shots == 0 else 1024
        # ", k1=v1, k2=v2" (sorted), or "" when there are no extra arguments.
        additional_args = "".join(
            f", {arg}" for arg in sorted(f"{k}={v}" for k, v in kwargs.items())
        )

        # "{{i}}" renders as "{i}" so LMQL interpolates its own loop variable.
        return f"""
BEAM(dclib_decoder="{decoder}"{additional_args})
    \"\"\"Q: {q}\\n\"\"\"
    "{choices}\\n"
    "A: Let's think step by step. \\n"
    for i in ["Firstly", " Secondly", " Thirdly", " Fourthly"]:
      "{{i}},[THOUGHT]"
      if "Therefore" in THOUGHT: break
      if "According" in THOUGHT: break
    " Overall this means,[CONCLUSION] Therefore, among A through E, the answer is[answer]"
FROM 
    "{model}"
WHERE
    STOPS_AT(answer, ".") and STOPS_AT(THOUGHT, "\\n") and STOPS_AT(THOUGHT, ".") and STOPS_AT(CONCLUSION, "\\n") and STOPS_AT(CONCLUSION, ".")
        """.strip()

    def parse_prediction(self, prediction, model_result):
        """Reduce the generated answer text to a choice letter via ``parse_choice``.

        ``self`` and ``model_result`` are not used.
        """
        return parse_choice(prediction)


@suite("AQUA@dash_multivar")
class AquaDashMultiVarCot(EvaluationSuite):
    """AQuA suite where reasoning is generated as up to eight dash-prefixed steps.

    Each ``THOUGHT`` follows "\\n-", and the loop stops once a thought
    mentions "answer". A ``CONCLUSION`` and the final ``answer`` follow. The
    answer text is reduced to a choice letter by ``parse_choice``.
    """

    def __init__(self, size):
        super().__init__("AQUA@dash_multivar", "tasks/zero-shot-cot/AQuA/test.json", size)

        # no list-based few-shot samples; 2-shot demos come from the module string
        self.few_shot_samples = []

    def instances(self):
        """Load the AQuA test file, parsing each non-blank line with ``json.loads``.

        Returns:
            A dict mapping the 0-based record index to the record. Each record
            gets a ``target`` key copied from its ``correct`` field.
        """
        # Assumes the dataset file is UTF-8 encoded; blank lines (e.g. an extra
        # trailing newline) are skipped instead of failing in json.loads.
        with open(self.file, encoding="utf-8") as f:
            records = [json.loads(line) for line in f if line.strip()]
        for record in records:
            record["target"] = record["correct"]
        return dict(enumerate(records))

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Build the dash-list LMQL query for one AQuA instance.

        Args:
            instance: record from ``instances()``; ``question`` and ``options``
                are read, each option being a ``"<letter>)<text>"`` string.
            model: model name placed in the ``FROM`` clause.
            decoder: value passed as ``dclib_decoder`` to ``BEAM``.
            kwargs: extra ``BEAM`` arguments, rendered as sorted ``k=v`` pairs.
                Mutated in place: ``openai_chunksize`` is set to 32.
            shots: 0 for zero-shot, 2 to prepend the module-level
                ``few_shot_samples`` demonstrations.

        Returns:
            The LMQL query source as a string.

        Raises:
            AssertionError: if ``shots`` is neither 0 nor 2.
        """
        # Newlines are escaped so the question stays on one line of the query.
        q = instance["question"].replace("\n", "\\n")

        template = ""

        assert shots == 2 or shots == 0, f"shots should be 0 or 2, but is {shots}"
        if shots == 2:
            template = f"\"\"\"{few_shot_samples}\"\"\"\n"

        choices = "Answer Choices:"
        for option in instance["options"]:
            letter, text = option.split(")", 1)
            choices += f" ({letter}) {text}"

        kwargs["openai_chunksize"] = 32
        # ", k1=v1, k2=v2" (sorted), or "" when there are no extra arguments.
        additional_args = "".join(
            f", {arg}" for arg in sorted(f"{k}={v}" for k, v in kwargs.items())
        )

        # Letter of the last option, e.g. "E" for five options.
        choice_label = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
        end_letter = choice_label[len(instance["options"]) - 1]

        return f"""
BEAM(dclib_decoder="{decoder}"{additional_args})
    {indent(template)}\"\"\"Q: {q}\\n\"\"\"
    "{choices}\\n"
    "A: Let's think step by step.\\n"
    for i in range(8):
      "\\n-[THOUGHT]"
      if not THOUGHT.endswith("\\n"): "\\n"
      if "answer" in THOUGHT: break
    " Overall this means,[CONCLUSION] Therefore, among A through {end_letter}, the answer is[answer]"
FROM 
    "{model}"
WHERE
    STOPS_AT(answer, ".") and STOPS_AT(THOUGHT, "\\n") and STOPS_AT(CONCLUSION, "\\n") and STOPS_AT(CONCLUSION, ".")
        """.strip()

    def parse_prediction(self, prediction, model_result):
        """Reduce the generated answer text to a choice letter via ``parse_choice``.

        ``model_result`` is not used.
        """
        return parse_choice(prediction)

@suite("AQUA@ao")
class AquaCotAo(EvaluationSuite):
    """AQuA answer-only suite: the model fills ``[answer]`` directly, without reasoning.

    The answer text is reduced to a choice letter by ``parse_choice``.
    """

    def __init__(self, size):
        super().__init__("AQUA@ao", "tasks/zero-shot-cot/AQuA/test.json", size)

        # no list-based few-shot samples; 2-shot demos come from few_shot_ao
        self.few_shot_samples = []

    def instances(self):
        """Load the AQuA test file, parsing each non-blank line with ``json.loads``.

        Returns:
            A dict mapping the 0-based record index to the record. Each record
            gets a ``target`` key copied from its ``correct`` field.
        """
        # Assumes the dataset file is UTF-8 encoded; blank lines (e.g. an extra
        # trailing newline) are skipped instead of failing in json.loads.
        with open(self.file, encoding="utf-8") as f:
            records = [json.loads(line) for line in f if line.strip()]
        for record in records:
            record["target"] = record["correct"]
        return dict(enumerate(records))

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Build the answer-only LMQL query for one AQuA instance.

        Args:
            instance: record from ``instances()``; ``question`` and ``options``
                are read, each option being a ``"<letter>)<text>"`` string.
            model: model name placed in the ``FROM`` clause.
            decoder: value passed as ``dclib_decoder`` to ``BEAM``.
            kwargs: extra ``BEAM`` arguments, rendered as sorted ``k=v`` pairs.
                Mutated in place: ``openai_chunksize`` is set to 16.
            shots: 0 for zero-shot, 2 to prepend the module-level
                ``few_shot_ao`` demonstrations.

        Returns:
            The LMQL query source as a string.

        Raises:
            AssertionError: if ``shots`` is neither 0 nor 2.
        """
        q = instance["question"]

        template = ""

        assert shots == 2 or shots == 0, f"shots should be 0 or 2, but is {shots}"
        # The demonstrations come only from the module-level string; the empty
        # self.few_shot_samples list contributes nothing.
        if shots == 2:
            template += f"\"\"\"{few_shot_ao}\"\"\"\n"

        template += f"\"\"\"Q: {q}\\n\"\"\"\n"
        template += "\"Answer Choices:"
        for option in instance["options"]:
            letter, text = option.split(")", 1)
            template += f" ({letter}) {text}"
        template += "\\n\"\n"
        template += "\"A: Among A through E, the answer is[answer]\""

        kwargs["openai_chunksize"] = 16
        # ", k1=v1, k2=v2" (sorted), or "" when there are no extra arguments.
        additional_args = "".join(
            f", {arg}" for arg in sorted(f"{k}={v}" for k, v in kwargs.items())
        )

        return f"""
BEAM(dclib_decoder="{decoder}"{additional_args})
    {indent(template)}
FROM 
    "{model}"
WHERE
    STOPS_AT(answer, ".") and STOPS_AT(answer, "\\n")
""".strip()

    def parse_prediction(self, prediction, model_result):
        """Reduce the generated answer text to a choice letter via ``parse_choice``.

        ``model_result`` is not used.
        """
        return parse_choice(prediction)

if __name__ == "__main__":
    # Quick manual check of answer parsing. parse_prediction reads neither
    # ``self`` nor ``model_result``, so None is passed for both.
    example_answer = "B: 3"
    print(AquaMultiVarCot.parse_prediction(None, example_answer, None))
