from evaluation_suite import *
import re
# Follow the README instructions to download the StrategyQA dataset; the
# suites below read it from generated_tasks/strategyqa_train.json.


class StrategyQA(EvaluationSuite):
    """Shared base for the StrategyQA yes/no evaluation suites.

    Loads the task instances from ``generated_tasks/strategyqa_train.json``
    and maps free-form model output onto the labels ``"True"``, ``"False"``
    or ``"None"``. The prompt itself is built by the ``make_query`` method of
    each subclass.
    """

    def __init__(self, tag, size):
        """Forward *tag*, the dataset path and *size* to ``EvaluationSuite``."""
        super().__init__(tag, "generated_tasks/strategyqa_train.json", size)

        # The instruct-style prompts have no few-shot samples.
        self.few_shot_samples = []

    def instances(self):
        """Return the dataset as a ``{index: instance}`` dict.

        Every instance gets an extra ``"target"`` key holding
        ``str(instance["answer"])`` so it can be compared with the string
        returned by ``parse_prediction``.
        """
        # NOTE(review): self.file is presumably set by EvaluationSuite.__init__
        # from the path passed in __init__ -- confirm against the base class.
        # JSON is UTF-8 by specification, so do not rely on the locale default.
        with open(self.file, encoding="utf-8") as f:
            instances = json.load(f)
        for instance in instances:
            instance["target"] = str(instance["answer"])
        return dict(enumerate(instances))

    def parse_prediction(self, prediction, model_result):
        """Extract the final yes/no verdict from *prediction*.

        The text is split on every non-letter character and scanned from the
        end; the first word equal (case-insensitively) to ``yes``/``true``
        yields ``"True"``, the first equal to ``no``/``false`` yields
        ``"False"``. If no such word exists the result is ``"None"``.
        *model_result* is accepted but not used.
        """
        print("original prediction is: ", prediction)
        # Scanning from the back, take the first yes/no/true/false found.
        for word in reversed(re.split('[^a-zA-Z]', prediction)):
            print("parsing", word)
            lowered = word.lower()
            if lowered in ["yes", "true"]:
                res = True
                break
            if lowered in ["no", "false"]:
                res = False
                break
        else:
            # Loop finished without a break: no verdict word was present.
            res = None

        return str(res)


@suite("StrategyQA@cot")
class StrategyQAcot(StrategyQA):
    """Chain-of-thought suite: one free ``[COT]`` hole, then the answer."""

    def __init__(self, size):
        super().__init__("StrategyQA@cot", size)

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Return the query text for a single instance.

        The prompt asks the question, lets the model reason in ``[COT]`` and
        constrains ``answer`` to one of the allowed strings in the WHERE
        clause. Note that *kwargs* is modified in place:
        ``openai_chunksize`` is set to 64 before the decoder arguments are
        rendered.
        """
        question = instance["question"]

        prompt = f"\"\"\"Q: {question}\\n\"\"\"\n"
        prompt += "\"A: Let's think step by step.\\n[COT] \\n Therefore, the answer (yes or no) is[answer]\""

        if shots > 0:
            # Each few-shot sample becomes its own quoted line above the prompt.
            quoted_shots = [f'"{s}"' for s in self.few_shot_samples[:shots]]
            prompt = "\n".join(quoted_shots) + "\n" + prompt

        kwargs["openai_chunksize"] = 64
        allowed_answers = [" yes", " no", " false", " true"]

        # Render as ", k1=v1, k2=v2" (sorted), or "" when there are no kwargs.
        rendered = sorted(f"{k}={v}" for k, v in kwargs.items())
        decoder_args = "".join(f", {arg}" for arg in rendered)

        query = "\n".join([
            f'BEAM(dclib_decoder="{decoder}"{decoder_args})',
            f"    {indent(prompt)}",
            "FROM ",
            f'    "{model}"',
        ])

        if allowed_answers:
            query += f"\nWHERE\n    answer in {allowed_answers}\n"
        return query



@suite("StrategyQA@ao")
class StrategyQAao(StrategyQA):
    """Answer-only suite: the model is asked for the verdict directly."""

    def __init__(self, size):
        super().__init__("StrategyQA@ao", size)

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Return the query text for a single instance.

        The prompt asks the question and immediately requests ``[answer]``,
        which the WHERE clause restricts to the allowed strings. Note that
        *kwargs* is modified in place: ``openai_chunksize`` is set to 1
        before the decoder arguments are rendered.
        """
        q = instance["question"]

        template = f"\"\"\"Q: {q}\\n\"\"\"\n"
        template += "\"A: The answer (yes or no) is[answer]\""

        if shots > 0:
            # Each few-shot sample becomes its own quoted line above the prompt.
            template = "\n".join(f'"{s}"' for s in self.few_shot_samples[:shots]) + "\n" + template

        kwargs["openai_chunksize"] = 1
        choices = [" yes", " no", " false", " true"]

        # Render as ", k1=v1, k2=v2" (sorted), or "" when there are no kwargs.
        additional_args = ", ".join(sorted(f"{k}={v}" for k, v in kwargs.items()))
        if additional_args:
            additional_args = ", " + additional_args

        # The line after the template is "FROM " with a trailing space, as in
        # the original query text.
        query = "\n".join([
            f'BEAM(dclib_decoder="{decoder}"{additional_args})',
            f"    {indent(template)}",
            "FROM ",
            f'    "{model}"',
        ])

        if choices:
            query += f"\nWHERE\n    answer in {choices}\n"
        return query


@suite("StrategyQA@multivar")
class StrategyQAmvar(StrategyQA):
    """Multi-variable suite: up to five enumerated ``[THOUGHT]`` steps."""

    def __init__(self, size):
        super().__init__("StrategyQA@multivar", size)

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Return the query text for a single instance.

        The generated query loops over " First" ... " Fifth", filling one
        ``[THOUGHT]`` per step and leaving the loop early once a thought
        mentions "therefore", "according" or "answer"; it then asks for
        ``[answer]``, restricted to " yes"/" no". Note that *kwargs* is
        modified in place: ``openai_chunksize`` is set to 64 before the
        decoder arguments are rendered.
        """
        question = instance["question"]

        # The question is a triple-quoted literal on its own line, followed by
        # the answer lead-in as a separate literal (same layout as the sibling
        # suites). The previous escapes put the line break inside the closing
        # triple quote, which left the generated quotes unbalanced.
        template = f"\"\"\"Q: {question}\\n\"\"\"\n"
        template += "\"A: Let's think step by step.\"\n"

        if shots > 0:
            # Each few-shot sample becomes its own quoted line above the prompt.
            template = "\n".join(f'"{s}"' for s in self.few_shot_samples[:shots]) + "\n" + template

        kwargs["openai_chunksize"] = 64
        # Unlike the cot/ao suites, only yes/no are offered as answers here.
        choices = [" yes", " no"]

        # Render as ", k1=v1, k2=v2" (sorted), or "" when there are no kwargs.
        additional_args = ", ".join(sorted(f"{k}={v}" for k, v in kwargs.items()))
        if additional_args:
            additional_args = ", " + additional_args

        query = f"""
BEAM(dclib_decoder="{decoder}"{additional_args})
    {indent(template)}
    for i in [" First", " Second", " Third", " Fourth", " Fifth"]:
        "{{i}},[THOUGHT]"
        if "therefore" in THOUGHT.lower(): break
        if "according" in THOUGHT.lower(): break
        if "answer" in THOUGHT.lower(): break
    " Therefore, the answer (yes or no) is[answer]"
FROM
    "{model}"
WHERE
    STOPS_AT(THOUGHT, "\\n") and STOPS_AT(THOUGHT, ".") and answer in {choices}
        """.strip()
        return query


@suite("StrategyQA@multivar2")
class StrategyQAmvar2(StrategyQA):
    """Multi-variable suite: bulleted ``[THOUGHT]`` steps plus a conclusion."""

    def __init__(self, size):
        super().__init__("StrategyQA@multivar2", size)

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Return the query text for a single instance.

        The generated query produces up to five "-" bullet thoughts, leaving
        the loop once a thought contains "Therefore" or "According", then
        fills ``[CONCLUSION]`` and ``[answer]`` (restricted to " yes"/" no").
        Note that *kwargs* is modified in place: ``openai_chunksize`` is set
        to 64 before the decoder arguments are rendered.
        """
        question = instance["question"]

        prompt = f"\"\"\"Q: {question}\\n\"\"\"\n"
        prompt += "\"A: Let's think step by step.\\n\"\n"

        if shots > 0:
            # Each few-shot sample becomes its own quoted line above the prompt.
            quoted_shots = [f'"{s}"' for s in self.few_shot_samples[:shots]]
            prompt = "\n".join(quoted_shots) + "\n" + prompt

        kwargs["openai_chunksize"] = 64
        # Unlike the cot/ao suites, only yes/no are offered as answers here.
        allowed_answers = [" yes", " no"]

        # Render as ", k1=v1, k2=v2" (sorted), or "" when there are no kwargs.
        rendered = sorted(f"{k}={v}" for k, v in kwargs.items())
        decoder_args = "".join(f", {arg}" for arg in rendered)

        return f"""
BEAM(dclib_decoder="{decoder}"{decoder_args})
    {indent(prompt)}
    for i in range(5):
      "\\n-[THOUGHT]"
      if not THOUGHT.endswith("\\n"): "\\n"
      if "Therefore" in THOUGHT: break
      if "According" in THOUGHT: break
    " Overall this means,[CONCLUSION] Therefore, the answer (yes or no) is[answer]"
FROM
    "{model}"
WHERE
    STOPS_AT(THOUGHT, "\\n") and STOPS_AT(THOUGHT, ".") and answer in {allowed_answers} and
    STOPS_AT(CONCLUSION, "\\n") and STOPS_AT(CONCLUSION, ".")
        """.strip()


@suite("StrategyQA@multivar3")
class StrategyQAmvar3(StrategyQA):
    """Multi-variable suite: bullet thoughts, each expanded by sub-thoughts."""

    def __init__(self, size):
        super().__init__("StrategyQA@multivar3", size)

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Return the query text for a single instance.

        The generated query produces up to five "-" bullet thoughts; after
        each one it asks what that thought means and collects up to two "*"
        sub-thoughts. Both loops are left once a thought contains
        "Therefore" or "According". Finally ``[CONCLUSION]`` and ``[answer]``
        (restricted to " yes"/" no") are filled. Note that *kwargs* is
        modified in place: ``openai_chunksize`` is set to 64 before the
        decoder arguments are rendered.
        """
        question = instance["question"]

        prompt = f"\"\"\"Q: {question}\\n\"\"\"\n"
        prompt += "\"A: Let's think step by step.\\n\"\n"

        if shots > 0:
            # Each few-shot sample becomes its own quoted line above the prompt.
            quoted_shots = [f'"{s}"' for s in self.few_shot_samples[:shots]]
            prompt = "\n".join(quoted_shots) + "\n" + prompt

        kwargs["openai_chunksize"] = 64
        # Unlike the cot/ao suites, only yes/no are offered as answers here.
        allowed_answers = [" yes", " no"]

        # Render as ", k1=v1, k2=v2" (sorted), or "" when there are no kwargs.
        rendered = sorted(f"{k}={v}" for k, v in kwargs.items())
        decoder_args = "".join(f", {arg}" for arg in rendered)

        # NOTE(review): the inner loop in the query text reuses the name `i`
        # and the variable THOUGHT of the outer loop -- confirm this is
        # intended before changing the query.
        return f"""
BEAM(dclib_decoder="{decoder}"{decoder_args})
    {indent(prompt)}
    for i in range(5):
      "\\n-[THOUGHT]"
      if "Therefore" in THOUGHT: break
      if "According" in THOUGHT: break
      "But what does '{{THOUGHT.strip()}}' mean?:\\n"
      for i in range(2):
         "          *[THOUGHT]"
         if not THOUGHT.endswith("\\n"): "\\n"
         if "Therefore" in THOUGHT: break
         if "According" in THOUGHT: break
      if not THOUGHT.endswith("\\n"): "\\n"
    " Overall this means,[CONCLUSION] Therefore, the answer (yes or no) is[answer]"
FROM
    "{model}"
WHERE
    STOPS_AT(THOUGHT, "\\n") and STOPS_AT(THOUGHT, ".") and answer in {allowed_answers} and
    STOPS_AT(CONCLUSION, "\\n") and STOPS_AT(CONCLUSION, ".")
        """.strip()






