from evaluation_suite import *
import string
import random

def make_choices(choices):
    """Render answer options as one line of lettered entries.

    Args:
        choices: Sequence of option texts. Labels run from A to L, so
            indexing fails with ``IndexError`` beyond 12 options.

    Returns:
        str: Entries of the form ``(A) text`` joined by single spaces;
        the empty string when ``choices`` is empty.
    """
    labels = "ABCDEFGHIJKL"
    rendered = []
    for position, text in enumerate(choices):
        rendered.append(f"({labels[position]}) {text}")
    return " ".join(rendered)

def parse_choice(prediction):
    """Extract the answer letter from a raw model prediction.

    Parentheses are removed, surrounding periods and whitespace are
    stripped, and the first remaining character is returned.

    Args:
        prediction: Raw answer text, e.g. ``" (B)."``.

    Returns:
        str: The first remaining character, or the marker
        ``"<parse_choice could not parse '...'>"`` (quoting the unmodified
        input) when nothing is left after cleaning.
    """
    cleaned = prediction.replace("(", "").replace(")", "")
    # Strip every kind of surrounding whitespace, not just spaces, so that
    # predictions starting with a newline or tab still yield their letter.
    cleaned = cleaned.strip(". \t\r\n")
    if not cleaned:
        return "<parse_choice could not parse '{}'>".format(prediction)
    return cleaned[0]

@suite("date_understanding@cot")
class DateUnderstandingEvaluationSuite(EvaluationSuite):
    """Chain-of-thought suite for the date_understanding task.

    Each query shows the question with lettered answer choices, lets the
    model produce free-form reasoning (``[COT]``) and then asks for the
    final ``[answer]``.
    """

    def __init__(self, size, name="date_understanding@cot"):
        """Initialise the suite and its few-shot demonstrations.

        Args:
            size: Forwarded unchanged to ``EvaluationSuite.__init__``.
            name: Suite name forwarded to ``EvaluationSuite.__init__``.
        """
        print("size is", size)
        super().__init__(name, os.path.join("tasks", "date_understanding.json"), size)

        # Each sample ends with an escaped backslash-n, i.e. the two literal
        # characters "\" and "n" at runtime, followed by a real newline.
        self.few_shot_samples = [
"""Q: 2015 is coming in 36 hours. What is the date one week from today in MM/DD/YYYY?
A: If 2015 is coming in 36 hours, then it is coming in 2 days. 2 days before 01/01/2015 is 12/30/2014, so today is 12/30/2014. So one week from today will be 01/05/2015. So the answer is 01/05/2015.\\n
""",

"""Q: The first day of 2019 is a Tuesday, and today is the first Monday of 2019. What is the date today in MM/DD/YYYY?
A: If the first day of 2019 was Tuesday, then 01/01/2019 was a Tuesday. Today is the first monday, would be six days later. So today is 01/07/2019. So the answer is 01/07/2019.\\n
""",

"""Q: The concert was scheduled to be on 06/01/1943, but was delayed by one day to today. What is the date 10 days ago in MM/DD/YYYY?
A: One day after 06/01/1943 is 06/02/1943, so today is 06/02/1943. 10 days before today is 05/23/1943. So the answer is 05/23/1943.\\n
""",

"""Q: It is 4/19/1969 today. What is the date 24 hours later in MM/DD/YYYY?
A: Today is 04/19/1969. 24 hours later is one day after today, which would be 04/20/1969. So the answer is 04/20/1969.\\n
""",

"""Q: Jane thought today is 3/11/2002, but today is in fact Mar 12, which is 1 day later. What is the date 24 hours later in MM/DD/YYYY?
A: Today is 03/12/2002. So the date 24 hours later will be 03/13/2002. So the answer is 03/13/2002.\\n
""",

"""Q: Jane was born on the last day of Feburary in 2001. Today is her 16-year-old birthday. What is the date yesterday in MM/DD/YYYY?
A: The last day of February is the 28th, so Jane was born on 02/28/2001. Today is her 16-year old birthday, so today is 02/28/2017. So yesterday was 02/27/2017. So the answer is 02/27/2017.\\n
"""
        ]

    def parse_prediction(self, prediction, model_result):
        """Reduce a raw prediction to its answer letter via ``parse_choice``.

        ``model_result`` is accepted but not used.
        """
        return parse_choice(prediction)

    def instances(self):
        """Load the task file and build one record per example.

        Returns:
            dict: Maps each example's position in ``data["examples"]`` to a
            dict with the keys ``template`` (the example's ``input``),
            ``target_scores`` (the raw choice-to-score mapping), ``target``
            (letter of the choice scored 1) and ``choices`` (choice texts in
            shuffled order).

        Raises:
            AssertionError: If an example does not have exactly one choice
                with score 1.
        """
        # self.file is presumably set by EvaluationSuite.__init__ from the
        # path passed in __init__ -- verify against the base class.
        with open(self.file, "r") as f:
            data = json.load(f)

        choice_letters = "ABCDEFGHIJKL"
        instances = {}

        for i, instance_data in enumerate(data["examples"]):
            choices = sorted(instance_data["target_scores"].items())
            # Seeding right before every shuffle makes the permutation depend
            # only on the number of choices, so letter assignment is
            # reproducible across runs. Note this resets the global `random`
            # state on every iteration.
            random.seed(42)
            random.shuffle(choices)

            target = [c for c, p in choices if p == 1]
            assert len(target) == 1, "Expected exactly one choice with score 1 for dataset sample {}, found {}".format(i, len(target))
            choices = [c for c, p in choices]

            instances[i] = {
                "template": instance_data["input"],
                "target_scores": instance_data["target_scores"],
                "target": choice_letters[choices.index(target[0])],
                "choices": choices
            }
        return instances

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Render the query text for one instance.

        Args:
            instance: Record as produced by ``instances()``; the keys
                ``template`` and ``choices`` are read.
            model: Model identifier placed in the FROM clause.
            decoder: Value inserted as the ``dclib_decoder`` argument.
            kwargs: Extra decoder arguments, rendered as ``key=value`` pairs
                sorted alphabetically.
            shots: Number of few-shot samples to prepend; 0 prepends none.

        Returns:
            str: The complete query text with surrounding whitespace removed.
        """
        template = f"""Q: {instance["template"]}"""
        choices = instance["choices"]

        # NOTE(review): every sample is wrapped in its own triple quotes while
        # the template is later embedded right after an opening triple quote
        # in the query below -- confirm the query language accepts this.
        if shots > 0:
            template = "\n".join(f'"""{s}"""' for s in self.few_shot_samples[:shots]) + "\n" + template

        additional_args = ", ".join(sorted(f"{k}={v}" for k, v in kwargs.items()))
        if additional_args:
            additional_args = ", " + additional_args

        choice_label = "ABCDEFGHIJKL"
        end_letter = choice_label[len(choices) - 1]

        q = f"""
BEAM(dclib_decoder="{decoder}"{additional_args})
    \"\"\"{template}
    Answer Choices: {make_choices(choices)}
    A: Let's think step-by-step.
    [COT] Therefore, among A through {end_letter}, the answer is[answer]\"\"\"
FROM 
    "{model}"
        """.strip()
        return q

@suite("date_understanding@ao")
class AnswerOnlyDateUnderstandingEvaluationSuite(EvaluationSuite):
    """Answer-only suite for the date_understanding task.

    Each query shows the question with lettered answer choices and asks
    directly for ``[answer]``, which the WHERE clause restricts to the
    letters of the available choices.
    """

    def __init__(self, size, name="date_understanding@ao"):
        """Initialise the suite and its few-shot demonstrations.

        Args:
            size: Forwarded unchanged to ``EvaluationSuite.__init__``.
            name: Suite name forwarded to ``EvaluationSuite.__init__``.
        """
        print("size is", size)
        super().__init__(name, os.path.join("tasks", "date_understanding.json"), size)

        # Each sample ends with an escaped backslash-n, i.e. the two literal
        # characters "\" and "n" at runtime, followed by a real newline.
        self.few_shot_samples = [
"""Q: 2015 is coming in 36 hours. What is the date one week from today in MM/DD/YYYY?
A: If 2015 is coming in 36 hours, then it is coming in 2 days. 2 days before 01/01/2015 is 12/30/2014, so today is 12/30/2014. So one week from today will be 01/05/2015. So the answer is 01/05/2015.\\n
""",

"""Q: The first day of 2019 is a Tuesday, and today is the first Monday of 2019. What is the date today in MM/DD/YYYY?
A: If the first day of 2019 was Tuesday, then 01/01/2019 was a Tuesday. Today is the first monday, would be six days later. So today is 01/07/2019. So the answer is 01/07/2019.\\n
""",

"""Q: The concert was scheduled to be on 06/01/1943, but was delayed by one day to today. What is the date 10 days ago in MM/DD/YYYY?
A: One day after 06/01/1943 is 06/02/1943, so today is 06/02/1943. 10 days before today is 05/23/1943. So the answer is 05/23/1943.\\n
""",

"""Q: It is 4/19/1969 today. What is the date 24 hours later in MM/DD/YYYY?
A: Today is 04/19/1969. 24 hours later is one day after today, which would be 04/20/1969. So the answer is 04/20/1969.\\n
""",

"""Q: Jane thought today is 3/11/2002, but today is in fact Mar 12, which is 1 day later. What is the date 24 hours later in MM/DD/YYYY?
A: Today is 03/12/2002. So the date 24 hours later will be 03/13/2002. So the answer is 03/13/2002.\\n
""",

"""Q: Jane was born on the last day of Feburary in 2001. Today is her 16-year-old birthday. What is the date yesterday in MM/DD/YYYY?
A: The last day of February is the 28th, so Jane was born on 02/28/2001. Today is her 16-year old birthday, so today is 02/28/2017. So yesterday was 02/27/2017. So the answer is 02/27/2017.\\n
"""
        ]

    def parse_prediction(self, prediction, model_result):
        """Reduce a raw prediction to its answer letter via ``parse_choice``.

        ``model_result`` is accepted but not used.
        """
        return parse_choice(prediction)

    def instances(self):
        """Load the task file and build one record per example.

        Returns:
            dict: Maps each example's position in ``data["examples"]`` to a
            dict with the keys ``template`` (the example's ``input``),
            ``target_scores`` (the raw choice-to-score mapping), ``target``
            (letter of the choice scored 1) and ``choices`` (choice texts in
            shuffled order).

        Raises:
            AssertionError: If an example does not have exactly one choice
                with score 1.
        """
        # self.file is presumably set by EvaluationSuite.__init__ from the
        # path passed in __init__ -- verify against the base class.
        with open(self.file, "r") as f:
            data = json.load(f)

        choice_letters = "ABCDEFGHIJKL"
        instances = {}

        for i, instance_data in enumerate(data["examples"]):
            choices = sorted(instance_data["target_scores"].items())
            # Seeding right before every shuffle makes the permutation depend
            # only on the number of choices, so letter assignment is
            # reproducible across runs. Note this resets the global `random`
            # state on every iteration.
            random.seed(42)
            random.shuffle(choices)

            target = [c for c, p in choices if p == 1]
            assert len(target) == 1, "Expected exactly one choice with score 1 for dataset sample {}, found {}".format(i, len(target))
            choices = [c for c, p in choices]

            instances[i] = {
                "template": instance_data["input"],
                "target_scores": instance_data["target_scores"],
                "target": choice_letters[choices.index(target[0])],
                "choices": choices
            }
        return instances

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Render the query text for one instance.

        Args:
            instance: Record as produced by ``instances()``; the keys
                ``template`` and ``choices`` are read.
            model: Model identifier placed in the FROM clause.
            decoder: Value inserted as the ``dclib_decoder`` argument.
            kwargs: Extra decoder arguments, rendered as ``key=value`` pairs
                sorted alphabetically.
            shots: Number of few-shot samples to prepend; 0 prepends none.

        Returns:
            str: The complete query text with surrounding whitespace removed.
        """
        template = f"""Q: {instance["template"]}"""
        choices = instance["choices"]

        # NOTE(review): every sample is wrapped in its own triple quotes while
        # the template is later embedded right after an opening triple quote
        # in the query below -- confirm the query language accepts this.
        if shots > 0:
            template = "\n".join(f'"""{s}"""' for s in self.few_shot_samples[:shots]) + "\n" + template

        additional_args = ", ".join(sorted(f"{k}={v}" for k, v in kwargs.items()))
        if additional_args:
            additional_args = ", " + additional_args

        choice_label = "ABCDEFGHIJKL"
        end_letter = choice_label[len(choices) - 1]
        # Allowed answers are the used letters, each with a leading space
        # (the prompt ends in "is" with no trailing space).
        choice_letters = ", ".join(f'" {c}"' for c in choice_label[:len(choices)])

        q = f"""
BEAM(dclib_decoder="{decoder}"{additional_args})
    \"\"\"{template}
    Answer Choices: {make_choices(choices)}
    Among A through {end_letter}, the answer is[answer]\"\"\"
FROM 
    "{model}"
WHERE
    answer in [{choice_letters}]
        """.strip()
        return q



@suite("date_understanding@multivar")
class MV1DateUnderstandingEvaluationSuite(EvaluationSuite):
    """Multi-variable reasoning suite for the date_understanding task.

    The query generates up to five ``[THOUGHT]`` steps (stopping early once
    a thought contains "Therefore" or "According"), then a ``[CONCLUSION]``
    and finally the ``[answer]``.
    """

    def __init__(self, size, name="date_understanding@multivar"):
        """Initialise the suite; it defines no few-shot demonstrations.

        Args:
            size: Forwarded unchanged to ``EvaluationSuite.__init__``.
            name: Suite name forwarded to ``EvaluationSuite.__init__``.
        """
        print("size is", size)
        super().__init__(name, os.path.join("tasks", "date_understanding.json"), size)

        self.few_shot_samples = []

    def parse_prediction(self, prediction, model_result):
        """Reduce a raw prediction to its answer letter via ``parse_choice``.

        ``model_result`` is accepted but not used.
        """
        return parse_choice(prediction)

    def instances(self):
        """Load the task file and build one record per example.

        Returns:
            dict: Maps each example's position in ``data["examples"]`` to a
            dict with the keys ``template`` (the example's ``input``),
            ``target_scores`` (the raw choice-to-score mapping), ``target``
            (letter of the choice scored 1) and ``choices`` (choice texts in
            shuffled order).

        Raises:
            AssertionError: If an example does not have exactly one choice
                with score 1.
        """
        # self.file is presumably set by EvaluationSuite.__init__ from the
        # path passed in __init__ -- verify against the base class.
        with open(self.file, "r") as f:
            data = json.load(f)

        choice_letters = "ABCDEFGHIJKL"
        instances = {}

        for i, instance_data in enumerate(data["examples"]):
            choices = sorted(instance_data["target_scores"].items())
            # Seeding right before every shuffle makes the permutation depend
            # only on the number of choices, so letter assignment is
            # reproducible across runs. Note this resets the global `random`
            # state on every iteration.
            random.seed(42)
            random.shuffle(choices)

            target = [c for c, p in choices if p == 1]
            assert len(target) == 1, "Expected exactly one choice with score 1 for dataset sample {}, found {}".format(i, len(target))
            choices = [c for c, p in choices]

            instances[i] = {
                "template": instance_data["input"],
                "target_scores": instance_data["target_scores"],
                "target": choice_letters[choices.index(target[0])],
                "choices": choices
            }
        return instances

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Render the query text for one instance.

        Args:
            instance: Record as produced by ``instances()``; the keys
                ``template`` and ``choices`` are read.
            model: Model identifier placed in the FROM clause.
            decoder: Value inserted as the ``dclib_decoder`` argument.
            kwargs: Extra decoder arguments, rendered as ``key=value`` pairs
                sorted alphabetically.
            shots: Number of few-shot samples to prepend. This suite's
                sample list is empty, so a positive value only prepends a
                newline to the question.

        Returns:
            str: The complete query text with surrounding whitespace removed.
        """
        template = f"""Q: {instance["template"]}"""
        choices = instance["choices"]

        if shots > 0:
            template = "\n".join(f'"""{s}"""' for s in self.few_shot_samples[:shots]) + "\n" + template

        additional_args = ", ".join(sorted(f"{k}={v}" for k, v in kwargs.items()))
        if additional_args:
            additional_args = ", " + additional_args

        choice_label = "ABCDEFGHIJKL"
        end_letter = choice_label[len(choices) - 1]

        q = f"""
BEAM(dclib_decoder="{decoder}"{additional_args})
    \"\"\"{template}
    Answer Choices: {make_choices(choices)}
    A: Let's think step by step.\"\"\"
    for i in range(5):
      "-[THOUGHT]"
      if "Therefore" in THOUGHT: break
      if "According" in THOUGHT: break
    " Overall this means,[CONCLUSION] Therefore, among A through {end_letter}, the answer is[answer]"
FROM 
    "{model}"
WHERE
    STOPS_AT(THOUGHT, "\\n") and STOPS_AT(THOUGHT, ".") and STOPS_AT(CONCLUSION, "\\n") and STOPS_AT(CONCLUSION, ".")
        """.strip()
        return q



@suite("date_understanding@multivar2")
class MV2DateUnderstandingEvaluationSuite(EvaluationSuite):
    """Multi-variable reasoning suite with one thought per line.

    The query generates up to five ``[THOUGHT]`` steps, each introduced on
    a new line and followed by a line break when the thought does not
    already end with one. It stops early once a thought contains
    "Therefore" or "According", then emits a ``[CONCLUSION]`` and the
    final ``[answer]``.
    """

    def __init__(self, size, name="date_understanding@multivar2"):
        """Initialise the suite; it defines no few-shot demonstrations.

        Args:
            size: Forwarded unchanged to ``EvaluationSuite.__init__``.
            name: Suite name forwarded to ``EvaluationSuite.__init__``.
        """
        print("size is", size)
        super().__init__(name, os.path.join("tasks", "date_understanding.json"), size)

        self.few_shot_samples = []

    def parse_prediction(self, prediction, model_result):
        """Reduce a raw prediction to its answer letter via ``parse_choice``.

        ``model_result`` is accepted but not used.
        """
        return parse_choice(prediction)

    def instances(self):
        """Load the task file and build one record per example.

        Returns:
            dict: Maps each example's position in ``data["examples"]`` to a
            dict with the keys ``template`` (the example's ``input``),
            ``target_scores`` (the raw choice-to-score mapping), ``target``
            (letter of the choice scored 1) and ``choices`` (choice texts in
            shuffled order).

        Raises:
            AssertionError: If an example does not have exactly one choice
                with score 1.
        """
        # self.file is presumably set by EvaluationSuite.__init__ from the
        # path passed in __init__ -- verify against the base class.
        with open(self.file, "r") as f:
            data = json.load(f)

        choice_letters = "ABCDEFGHIJKL"
        instances = {}

        for i, instance_data in enumerate(data["examples"]):
            choices = sorted(instance_data["target_scores"].items())
            # Seeding right before every shuffle makes the permutation depend
            # only on the number of choices, so letter assignment is
            # reproducible across runs. Note this resets the global `random`
            # state on every iteration.
            random.seed(42)
            random.shuffle(choices)

            target = [c for c, p in choices if p == 1]
            assert len(target) == 1, "Expected exactly one choice with score 1 for dataset sample {}, found {}".format(i, len(target))
            choices = [c for c, p in choices]

            instances[i] = {
                "template": instance_data["input"],
                "target_scores": instance_data["target_scores"],
                "target": choice_letters[choices.index(target[0])],
                "choices": choices
            }
        return instances

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Render the query text for one instance.

        Args:
            instance: Record as produced by ``instances()``; the keys
                ``template`` and ``choices`` are read.
            model: Model identifier placed in the FROM clause.
            decoder: Value inserted as the ``dclib_decoder`` argument.
            kwargs: Extra decoder arguments, rendered as ``key=value`` pairs
                sorted alphabetically.
            shots: Number of few-shot samples to prepend. This suite's
                sample list is empty, so a positive value only prepends a
                newline to the question.

        Returns:
            str: The complete query text with surrounding whitespace removed.
        """
        template = f"""Q: {instance["template"]}"""
        choices = instance["choices"]

        if shots > 0:
            template = "\n".join(f'"""{s}"""' for s in self.few_shot_samples[:shots]) + "\n" + template

        additional_args = ", ".join(sorted(f"{k}={v}" for k, v in kwargs.items()))
        if additional_args:
            additional_args = ", " + additional_args

        choice_label = "ABCDEFGHIJKL"
        end_letter = choice_label[len(choices) - 1]

        q = f"""
BEAM(dclib_decoder="{decoder}"{additional_args})
    \"\"\"{template}
    Answer Choices: {make_choices(choices)}
    A: Let's think step by step.\"\"\"
    for i in range(5):
      "\\n-[THOUGHT]"
      if not THOUGHT.endswith("\\n"): "\\n"
      if "Therefore" in THOUGHT: break
      if "According" in THOUGHT: break
    " Overall this means,[CONCLUSION] Therefore, among A through {end_letter}, the answer is[answer]"
FROM 
    "{model}"
WHERE
    STOPS_AT(THOUGHT, "\\n") and STOPS_AT(THOUGHT, ".") and STOPS_AT(CONCLUSION, "\\n") and STOPS_AT(CONCLUSION, ".")
        """.strip()
        return q


# NOTE(review): this suite reuses the name "date_understanding@multivar2",
# which MV2DateUnderstandingEvaluationSuite in this file also uses. If
# suite() registers classes by name, one of the two presumably shadows the
# other -- confirm whether a distinct name was intended.
@suite("date_understanding@multivar2")
class MVDateUnderstandingEvaluationSuite(EvaluationSuite):
    """Multi-variable reasoning suite with one thought per line.

    The query generates up to five ``[THOUGHT]`` steps, each introduced on
    a new line and followed by a line break when the thought does not
    already end with one. It stops early once a thought contains
    "Therefore" or "According", then emits a ``[CONCLUSION]`` and the
    final ``[answer]``.
    """

    def __init__(self, size, name="date_understanding@multivar2"):
        """Initialise the suite; it defines no few-shot demonstrations.

        Args:
            size: Forwarded unchanged to ``EvaluationSuite.__init__``.
            name: Suite name forwarded to ``EvaluationSuite.__init__``.
        """
        print("size is", size)
        super().__init__(name, os.path.join("tasks", "date_understanding.json"), size)

        self.few_shot_samples = []

    def parse_prediction(self, prediction, model_result):
        """Reduce a raw prediction to its answer letter via ``parse_choice``.

        ``model_result`` is accepted but not used.
        """
        return parse_choice(prediction)

    def instances(self):
        """Load the task file and build one record per example.

        Returns:
            dict: Maps each example's position in ``data["examples"]`` to a
            dict with the keys ``template`` (the example's ``input``),
            ``target_scores`` (the raw choice-to-score mapping), ``target``
            (letter of the choice scored 1) and ``choices`` (choice texts in
            shuffled order).

        Raises:
            AssertionError: If an example does not have exactly one choice
                with score 1.
        """
        # self.file is presumably set by EvaluationSuite.__init__ from the
        # path passed in __init__ -- verify against the base class.
        with open(self.file, "r") as f:
            data = json.load(f)

        choice_letters = "ABCDEFGHIJKL"
        instances = {}

        for i, instance_data in enumerate(data["examples"]):
            choices = sorted(instance_data["target_scores"].items())
            # Seeding right before every shuffle makes the permutation depend
            # only on the number of choices, so letter assignment is
            # reproducible across runs. Note this resets the global `random`
            # state on every iteration.
            random.seed(42)
            random.shuffle(choices)

            target = [c for c, p in choices if p == 1]
            assert len(target) == 1, "Expected exactly one choice with score 1 for dataset sample {}, found {}".format(i, len(target))
            choices = [c for c, p in choices]

            instances[i] = {
                "template": instance_data["input"],
                "target_scores": instance_data["target_scores"],
                "target": choice_letters[choices.index(target[0])],
                "choices": choices
            }
        return instances

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Render the query text for one instance.

        Args:
            instance: Record as produced by ``instances()``; the keys
                ``template`` and ``choices`` are read.
            model: Model identifier placed in the FROM clause.
            decoder: Value inserted as the ``dclib_decoder`` argument.
            kwargs: Extra decoder arguments, rendered as ``key=value`` pairs
                sorted alphabetically.
            shots: Number of few-shot samples to prepend. This suite's
                sample list is empty, so a positive value only prepends a
                newline to the question.

        Returns:
            str: The complete query text with surrounding whitespace removed.
        """
        template = f"""Q: {instance["template"]}"""
        choices = instance["choices"]

        if shots > 0:
            template = "\n".join(f'"""{s}"""' for s in self.few_shot_samples[:shots]) + "\n" + template

        additional_args = ", ".join(sorted(f"{k}={v}" for k, v in kwargs.items()))
        if additional_args:
            additional_args = ", " + additional_args

        choice_label = "ABCDEFGHIJKL"
        end_letter = choice_label[len(choices) - 1]

        q = f"""
BEAM(dclib_decoder="{decoder}"{additional_args})
    \"\"\"{template}
    Answer Choices: {make_choices(choices)}
    A: Let's think step by step.\"\"\"
    for i in range(5):
      "\\n-[THOUGHT]"
      if not THOUGHT.endswith("\\n"): "\\n"
      if "Therefore" in THOUGHT: break
      if "According" in THOUGHT: break
    " Overall this means,[CONCLUSION] Therefore, among A through {end_letter}, the answer is[answer]"
FROM
    "{model}"
WHERE
    STOPS_AT(THOUGHT, "\\n") and STOPS_AT(THOUGHT, ".") and STOPS_AT(CONCLUSION, "\\n") and STOPS_AT(CONCLUSION, ".")
        """.strip()
        return q