from evaluation_suite import *
import re
# Follow the README instructions to download the StrategyQA dataset.

# samples 01c3faf4915a44133f60, 028eda7ded7825edb6eb

# Answer-only (no reasoning) two-shot exemplars; StrategyQAao wraps this text in
# a """...""" prompt string and prepends it to the query when shots == 2.
# NOTE: "\\n" below is a literal backslash-n in the Python string; once embedded
# in the LMQL string literal it presumably acts as a newline escape — TODO confirm.
ao_few_shot_samples = """Q: Was Iggy Pop named after his father?
A: The answer (yes or no) is yes.\\n

Q: Is a felony jury enough people for a Bunco game?
A: The answer (yes or no) is yes.\\n
"""

# Chain-of-thought two-shot exemplars (bullet-list reasoning, then the answer);
# StrategyQAcot and StrategyQAmvar2 wrap this text in a """...""" prompt string
# and prepend it to the query when shots == 2.
# NOTE: "\\n" is a literal backslash-n here, presumably interpreted as a newline
# escape inside the LMQL string literal — TODO confirm.
few_shot_samples = """Q: Was Iggy Pop named after his father?
A: Let's think step by step.
- Iggy Pop's birth name was James Newell Osterberg Jr.
- The father of Iggy Pop was James Newell Osterberg Sr.
- According to this, they have the first and last name in common
Overall, this means that the answer is yes. Therefore, the answer (yes or no) is yes.\\n

Q: Is a felony jury enough people for a Bunco game?
A: Let's think step by step.
- Felonies and other serious crimes have a jury of 12 people.
- Bunco is a parlour game requiring 12 or more players.
- Therefore, a felony usually has enough people for a Bunco game.
Overall, this means that the answer is yes. Therefore, the answer (yes or no) is yes.\\n
"""

class StrategyQA(EvaluationSuite):
    """Base suite for StrategyQA yes/no questions.

    Loads ``generated_tasks/strategyqa_train.json`` and turns free-form model
    output into a "True"/"False"/"None" string. Subclasses supply
    ``make_query`` to build the LMQL query for each instance.
    """

    def __init__(self, tag, size):
        super().__init__(tag, "generated_tasks/strategyqa_train.json", size)

        # No instance-level few-shot samples: the subclasses embed the
        # module-level exemplar strings directly into their templates instead.
        self.few_shot_samples = []

    def instances(self):
        """Load the dataset and return it as ``{index: instance}``.

        Each instance gains a ``"target"`` key holding ``str(instance["answer"])``;
        presumably this is compared with the string returned by
        ``parse_prediction`` (e.g. "True"/"False" for boolean answers — TODO
        confirm against the dataset schema).
        """
        # JSON is UTF-8; don't depend on the platform's locale encoding.
        with open(self.file, encoding="utf-8") as f:
            instances = json.load(f)
        for instance in instances:
            instance["target"] = str(instance["answer"])
        return dict(enumerate(instances))

    def parse_prediction(self, prediction, model_result):
        """Extract a boolean answer from a model completion.

        The completion is split on every non-letter character and scanned from
        the end; the last occurrence of "yes"/"true" or "no"/"false"
        (case-insensitive) decides the answer.

        Args:
            prediction: The model output text.
            model_result: Unused; part of the suite interface.

        Returns:
            "True", "False", or "None" if no yes/no/true/false word is found.
        """
        print("original prediction is: ", prediction)
        positive_words = {"yes", "true"}
        negative_words = {"no", "false"}

        res = None
        for word in reversed(re.split('[^a-zA-Z]', prediction)):
            print("parsing", word)
            lowered = word.lower()
            if lowered in positive_words:
                res = True
                break
            if lowered in negative_words:
                res = False
                break

        return str(res)


@suite("sqauni@cot")
class StrategyQAcot(StrategyQA):
    """StrategyQA suite with a free-form chain of thought ([COT]) before the answer."""

    def __init__(self, size):
        super().__init__("sqauni@cot", size)

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Render the chain-of-thought LMQL query for one StrategyQA instance.

        Args:
            instance: Dataset record; only ``instance["question"]`` is read.
            model: Model identifier placed in the ``FROM`` clause.
            decoder: Value for ``BEAM(dclib_decoder=...)``.
            kwargs: Extra decoder arguments rendered as sorted ``k=v`` pairs.
                Mutated in place: ``openai_chunksize`` is set to 64.
            shots: 0 or 2; with 2 the module-level ``few_shot_samples`` are
                prepended as a triple-quoted prompt string.

        Returns:
            The LMQL query string; ``answer`` is constrained to " yes"/" no".
        """
        question = instance["question"]
        assert shots in (0, 2), f"shots should be 0 or 2, but is {shots}"

        segments = []
        if shots == 2:
            segments.append(f'"""{few_shot_samples}"""\n')
        segments.append(f'"""Q: {question}\\n"""\n')
        segments.append("\"A: Let's think step by step.\\n[COT] \\n Therefore, the answer (yes or no) is[answer]\"")
        template = "".join(segments)

        # self.few_shot_samples is [] (StrategyQA.__init__), so this only
        # prepends a newline to the template.
        if shots > 0:
            legacy_shots = "\n".join(f'"{s}"' for s in self.few_shot_samples[:shots])
            template = legacy_shots + "\n" + template

        kwargs["openai_chunksize"] = 64
        choices = [" yes", " no"]

        extra = ", ".join(sorted(f"{key}={value}" for key, value in kwargs.items()))
        additional_args = f", {extra}" if extra else ""

        # Note the trailing space after FROM, kept verbatim.
        query = "\n".join([
            f'BEAM(dclib_decoder="{decoder}"{additional_args})',
            f"    {indent(template)}",
            "FROM ",
            f'    "{model}"',
        ])
        if choices:
            query += f"\nWHERE\n    answer in {choices}\n"
        return query



@suite("sqauni@ao")
class StrategyQAao(StrategyQA):
    """StrategyQA suite that asks for the yes/no answer directly, without reasoning."""

    def __init__(self, size):
        super().__init__("sqauni@ao", size)

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Build the answer-only LMQL query for one StrategyQA instance.

        Args:
            instance: Dataset record; only ``instance["question"]`` is read.
            model: Model identifier placed in the ``FROM`` clause.
            decoder: Value for ``BEAM(dclib_decoder=...)``.
            kwargs: Extra decoder arguments rendered as sorted ``k=v`` pairs.
                Mutated in place: ``openai_chunksize`` is set to 1.
            shots: 0 or 2; with 2 the module-level ``ao_few_shot_samples``
                are prepended as a triple-quoted prompt string.

        Returns:
            The LMQL query string; ``answer`` is constrained to ``choices``.

        Raises:
            AssertionError: If ``shots`` is neither 0 nor 2.
        """
        q = instance["question"]

        assert shots == 2 or shots == 0, f"shots should be 0 or 2, but is {shots}"

        template = ""
        if shots == 2:
            template += f"\"\"\"{ao_few_shot_samples}\"\"\"\n"

        template += f"\"\"\"Q: {q}\\n\"\"\"\n"
        template += "\"A: The answer (yes or no) is[answer]\""

        # self.few_shot_samples is [] (StrategyQA.__init__), so this only
        # prepends a newline to the template.
        if shots > 0:
            template = "\n".join(f'"{s}"' for s in self.few_shot_samples[:shots]) + "\n" + template

        kwargs["openai_chunksize"] = 1
        choices = [" yes", " no", " false", " true"]

        additional_args = ", ".join(sorted(f"{k}={v}" for k, v in kwargs.items()))
        if additional_args:
            additional_args = ", " + additional_args

        # Built line by line; note the trailing space after FROM, kept verbatim.
        query = "\n".join([
            f'BEAM(dclib_decoder="{decoder}"{additional_args})',
            f"    {indent(template)}",
            "FROM ",
            f'    "{model}"',
        ])
        if choices:
            query += f"\nWHERE\n    answer in {choices}\n"
        return query


@suite("sqauni@multivar")
class StrategyQAmvar(StrategyQA):
    """StrategyQA suite that decodes up to five enumerated reasoning steps."""

    def __init__(self, size):
        super().__init__("sqauni@multivar", size)

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Build the multi-variable chain-of-thought query for one instance.

        The query emits up to five steps (" First,[THOUGHT]", " Second,[THOUGHT]",
        ...), leaving the loop early once a thought mentions "therefore",
        "according" or "answer" (case-insensitive), then decodes ``answer``.

        Args:
            instance: Dataset record; only ``instance["question"]`` is read.
            model: Model identifier placed in the ``FROM`` clause.
            decoder: Value for ``BEAM(dclib_decoder=...)``.
            kwargs: Extra decoder arguments rendered as sorted ``k=v`` pairs.
                Mutated in place: ``openai_chunksize`` is set to 64.
            shots: Not validated here; no exemplars are added because
                ``self.few_shot_samples`` is empty (only a leading newline is
                prepended when ``shots > 0``).

        Returns:
            The LMQL query string; ``answer`` is constrained to " yes"/" no".
        """
        question = instance["question"]

        # The question is one triple-quoted prompt string, as in the other
        # suites. (Previously mis-escaped as `"""Q: ...\n"` + `"""A: ...`, which
        # closed the string early and broke query tokenization.)
        template = f"\"\"\"Q: {question}\\n\"\"\"\n"
        # No trailing \n: the " First," step continues this line.
        template += "\"A: Let's think step by step.\"\n"

        if shots > 0:
            template = "\n".join(f'"{s}"' for s in self.few_shot_samples[:shots]) + "\n" + template

        kwargs["openai_chunksize"] = 64
        choices = [" yes", " no"]
        additional_args = ", ".join(sorted(f"{k}={v}" for k, v in kwargs.items()))
        if additional_args:
            additional_args = ", " + additional_args

        # {{i}} renders as {i}, so LMQL substitutes the loop variable.
        query = f"""
BEAM(dclib_decoder="{decoder}"{additional_args})
    {indent(template)}
    for i in [" First", " Second", " Third", " Fourth", " Fifth"]:
        "{{i}},[THOUGHT]"
        if "therefore" in THOUGHT.lower(): break
        if "according" in THOUGHT.lower(): break
        if "answer" in THOUGHT.lower(): break
    " Therefore, the answer (yes or no) is[answer]"
FROM
    "{model}"
WHERE
    STOPS_AT(THOUGHT, "\\n") and STOPS_AT(THOUGHT, ".") and answer in {choices}
        """.strip()
        return query


@suite("sqauni@multivar2")
class StrategyQAmvar2(StrategyQA):
    """StrategyQA suite with up to five "-" bullet thoughts and a conclusion."""

    def __init__(self, size):
        super().__init__("sqauni@multivar2", size)

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Render the bullet-list reasoning query for one StrategyQA instance.

        Up to five "\\n-[THOUGHT]" bullets are decoded, leaving the loop once a
        thought contains "Therefore" or "According"; then a CONCLUSION and the
        constrained ``answer`` follow.

        Args:
            instance: Dataset record; only ``instance["question"]`` is read.
            model: Model identifier placed in the ``FROM`` clause.
            decoder: Value for ``BEAM(dclib_decoder=...)``.
            kwargs: Extra decoder arguments rendered as sorted ``k=v`` pairs.
                Mutated in place: ``openai_chunksize`` is set to 64.
            shots: 0 or 2; with 2 the module-level ``few_shot_samples`` are
                prepended as a triple-quoted prompt string.

        Returns:
            The LMQL query string.
        """
        question = instance["question"]
        assert shots in (0, 2), f"shots should be 0 or 2, but is {shots}"

        prompt_parts = []
        if shots == 2:
            prompt_parts.append(f'"""{few_shot_samples}"""\n')
        prompt_parts.append(f'"""Q: {question}\\n"""\n')
        prompt_parts.append("\"A: Let's think step by step.\\n\"\n")
        template = "".join(prompt_parts)

        # self.few_shot_samples is [] (StrategyQA.__init__), so this only
        # prepends a newline to the template.
        if shots > 0:
            legacy_shots = "\n".join(f'"{s}"' for s in self.few_shot_samples[:shots])
            template = legacy_shots + "\n" + template

        kwargs["openai_chunksize"] = 64
        choices = [" yes", " no"]
        extra = ", ".join(sorted(f"{key}={value}" for key, value in kwargs.items()))
        additional_args = f", {extra}" if extra else ""

        query = f"""
BEAM(dclib_decoder="{decoder}"{additional_args})
    {indent(template)}
    for i in range(5):
      "\\n-[THOUGHT]"
      if not THOUGHT.endswith("\\n"): "\\n"
      if "Therefore" in THOUGHT: break
      if "According" in THOUGHT: break
    " Overall this means,[CONCLUSION] Therefore, the answer (yes or no) is[answer]"
FROM
    "{model}"
WHERE
    STOPS_AT(THOUGHT, "\\n") and STOPS_AT(THOUGHT, ".") and answer in {choices} and
    STOPS_AT(CONCLUSION, "\\n") and STOPS_AT(CONCLUSION, ".")
        """.strip()
        return query


@suite("sqauni@multivar3")
class StrategyQAmvar3(StrategyQA):
    """StrategyQA suite where each bullet thought is followed by sub-thoughts."""

    def __init__(self, size):
        super().__init__("sqauni@multivar3", size)

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Render the nested-reasoning query for one StrategyQA instance.

        Up to five "-" thoughts are decoded. After each one, the prompt asks
        "But what does '<thought>' mean?" and decodes up to two "*" sub-thoughts.
        The loops exit early on "Therefore"/"According"; then a CONCLUSION and
        the constrained ``answer`` follow.

        Args:
            instance: Dataset record; only ``instance["question"]`` is read.
            model: Model identifier placed in the ``FROM`` clause.
            decoder: Value for ``BEAM(dclib_decoder=...)``.
            kwargs: Extra decoder arguments rendered as sorted ``k=v`` pairs.
                Mutated in place: ``openai_chunksize`` is set to 64.
            shots: Not validated; ``self.few_shot_samples`` is empty, so
                ``shots > 0`` only prepends a newline.

        Returns:
            The LMQL query string.
        """
        question = instance["question"]

        template = "".join([
            f'"""Q: {question}\\n"""\n',
            "\"A: Let's think step by step.\\n\"\n",
        ])

        if shots > 0:
            legacy_shots = "\n".join(f'"{s}"' for s in self.few_shot_samples[:shots])
            template = legacy_shots + "\n" + template

        kwargs["openai_chunksize"] = 64
        choices = [" yes", " no"]
        extra = ", ".join(sorted(f"{key}={value}" for key, value in kwargs.items()))
        additional_args = f", {extra}" if extra else ""

        # {{THOUGHT.strip()}} renders as {THOUGHT.strip()} for LMQL to evaluate.
        query = f"""
BEAM(dclib_decoder="{decoder}"{additional_args})
    {indent(template)}
    for i in range(5):
      "\\n-[THOUGHT]"
      if "Therefore" in THOUGHT: break
      if "According" in THOUGHT: break
      "But what does '{{THOUGHT.strip()}}' mean?:\\n"
      for i in range(2):
         "          *[THOUGHT]"
         if not THOUGHT.endswith("\\n"): "\\n"
         if "Therefore" in THOUGHT: break
         if "According" in THOUGHT: break
      if not THOUGHT.endswith("\\n"): "\\n"
    " Overall this means,[CONCLUSION] Therefore, the answer (yes or no) is[answer]"
FROM
    "{model}"
WHERE
    STOPS_AT(THOUGHT, "\\n") and STOPS_AT(THOUGHT, ".") and answer in {choices} and
    STOPS_AT(CONCLUSION, "\\n") and STOPS_AT(CONCLUSION, ".")
        """.strip()
        return query






