from evaluation_suite import *


def parse_choice(prediction):
    """Normalize a raw model prediction; used by the suites' ``parse_prediction``.

    Only surrounding whitespace is removed; no choice letter is extracted.
    The normalized value is also printed as a debugging aid.

    Args:
        prediction: the raw prediction string.

    Returns:
        ``prediction`` with leading and trailing whitespace removed.
    """
    parsed = prediction.strip()
    print("parse prediction as", parsed)
    return parsed


# Answer-only in-context examples, prepended by
# AnswerOnlyTrackingShuffledObjectsSuite.make_query when shots == 2. Each
# example is a question immediately followed by its final answer sentence.
# Per the original note, examples 20 and 167 (presumably dataset instance ids)
# are ensured to be excluded from the mini, small, medium and large splits.
# The "\\n" sequences are a literal backslash + "n" in this Python string, not
# a newline character.
# NOTE(review): the second example calls Bob's and Claire's gifts "ball"s while
# the setup speaks of presents; kept verbatim because editing prompt text
# changes what the model sees.
ao_few_shot_samples = """Q: Alice, Bob, and Claire are playing a game. At the start of the game, they are each holding a ball: Alice has a blue ball, Bob has a yellow ball, and Claire has a green ball. As the game progresses, pairs of players trade balls. First, Bob and Claire swap balls. Then, Bob and Alice swap balls. Finally, Bob and Claire swap balls. At the end of the game, Claire has the blue ball.\\n

Q: Alice, Bob, and Claire are holding a white elephant gift exchange. At the start of the event, they are each holding a present of a different color: Alice has a red present, Bob has a black ball, and Claire has a pink ball. As the event progresses, pairs of people swap gifts. First, Alice and Claire swap their gifts. Then, Bob and Alice swap their gifts. Finally, Alice and Claire swap their gifts. At the end of the event, Claire has the black ball.\\n
"""

# Chain-of-thought in-context examples, prepended by
# CotTrackingShuffledObjectsSuite.make_query when shots == 2. Each example
# restates the swaps after "A: Let's think step by step.", interleaving
# "- This means, now ..." state lines, and ends with "Therefore, <answer>".
# Per the original note, examples 20 and 167 (presumably dataset instance ids)
# are ensured to be excluded from the mini, small, medium and large splits.
# The "\\n" sequences are a literal backslash + "n" in this Python string.
# NOTE(review): the text contains a double period ("yellow ball..") and calls
# some gifts "ball"s; kept verbatim because editing prompt text changes what
# the model sees.
cot_few_shot_samples = """Q: Alice, Bob, and Claire are playing a game. At the start of the game, they are each holding a ball: Alice has a blue ball, Bob has a yellow ball, and Claire has a green ball. \\n

As the game progresses, pairs of players trade balls. First, Bob and Claire swap balls. Then, Bob and Alice swap balls. Finally, Bob and Claire swap balls. 

A: Let's think step by step.
As the game progresses, pairs of players trade balls. First, Bob and Claire swap balls. 
- This means, now  Bob has a green ball and Claire has a yellow ball. Whereas  Alice still has a blue ball.
Then, Bob and Alice swap balls.
- This means, now  Bob has a blue ball and Alice has a green ball. Whereas  Claire still has a yellow ball.. Finally, Bob and Claire swap balls. 
Overall this means, now Alice has a yellow ball, Bob has a green ball, and Claire has a blue ball. Therefore, At the end of the game, Claire has the blue ball\\n

Q: Alice, Bob, and Claire are holding a white elephant gift exchange. At the start of the event, they are each holding a present of a different color: Alice has a red present, Bob has a black ball, and Claire has a pink ball.\\n

As the event progresses, pairs of people swap gifts. First, Alice and Claire swap their gifts. Then, Bob and Alice swap their gifts. Finally, Alice and Claire swap their gifts. 

A: Let's think step by step.
As the event progresses, pairs of people swap gifts. First, Alice and Claire swap their gifts.
- This means, now Alice has a pink present and Claire has a red present. Whereas Bob still has a black ball.
Then, Bob and Alice swap their gifts. 
- This means, now Alice has a black ball and Bob has a pink present. Whereas Claire still has a red present.
Finally, Alice and Claire swap their gifts. 
Overall this means, Alice has a red present, Bob has a pink present, and Claire has a black ball. Therefore, At the end of the event, Claire has the black ball\\n
"""

# Multi-variable in-context examples, prepended by
# TrackingShuffledObjectsSuiteMV2.make_query when shots == 2. After each swap a
# "- This means, now ... Whereas ..." line records the new state, and the answer
# line starts with "A: Overall this means", mirroring that suite's query layout.
# Per the original note, examples 20 and 167 (presumably dataset instance ids)
# are ensured to be excluded from the mini, small, medium and large splits.
# The "\\n" sequences are a literal backslash + "n" in this Python string.
# NOTE(review): the text contains a double period ("yellow ball..") and calls
# some gifts "ball"s; kept verbatim because editing prompt text changes what
# the model sees.
few_shot_samples = """Q: Alice, Bob, and Claire are playing a game. At the start of the game, they are each holding a ball: Alice has a blue ball, Bob has a yellow ball, and Claire has a green ball.\\n

As the game progresses, pairs of players trade balls. First, Bob and Claire swap balls. 
- This means, now  Bob has a green ball and Claire has a yellow ball. Whereas  Alice still has a blue ball.
Then, Bob and Alice swap balls.
- This means, now  Bob has a blue ball and Alice has a green ball. Whereas  Claire still has a yellow ball.. Finally, Bob and Claire swap balls. 
A: Overall this means, now Alice has a yellow ball, Bob has a green ball, and Claire has a blue ball. Therefore, At the end of the game, Claire has the blue ball\\n

Q: Alice, Bob, and Claire are holding a white elephant gift exchange. At the start of the event, they are each holding a present of a different color: Alice has a red present, Bob has a black ball, and Claire has a pink ball.\\n

As the event progresses, pairs of people swap gifts. First, Alice and Claire swap their gifts.
- This means, now Alice has a pink present and Claire has a red present. Whereas Bob still has a black ball.
Then, Bob and Alice swap their gifts. 
- This means, now Alice has a black ball and Bob has a pink present. Whereas Claire still has a red present.
Finally, Alice and Claire swap their gifts. 
A: Overall this means, Alice has a red present, Bob has a pink present, and Claire has a black ball. Therefore,  At the end of the event, Claire has the black ball\\n
"""

@suite("tracking_shuffled_objects@multivar2")
class TrackingShuffledObjectsSuiteMV2(EvaluationSuite):
    """Tracking-shuffled-objects suite with explicit per-swap state lines.

    Every ``Now[state_var] `` marker in an instance template is expanded into a
    ``- This means, now[state_var] Whereas [state_var]`` line. The query then
    asks for an ``Overall this means[CONCLUSION]`` summary before the final
    answer sentence. Supports 0 or 2 shots (module-level ``few_shot_samples``).
    """

    def __init__(self, size):
        super().__init__("tracking_shuffled_objects@multivar2", os.path.join("generated_tasks", "tracking_shuffled_objects.json"), size)

    def instances(self):
        """Load the task file and return its instances keyed by instance id.

        ``self.file`` is presumably set by ``EvaluationSuite.__init__`` from the
        path passed above (TODO confirm). Each returned value holds the raw
        ``template``, ``condition`` and ``target_scores`` of the sample, plus
        ``target``: the single choice scored 1, with trailing "." removed.

        Raises:
            AssertionError: if a sample does not have exactly one choice scored 1.
        """
        # Explicit encoding: the platform default encoding is locale-dependent.
        with open(self.file, "r", encoding="utf-8") as f:
            data = json.load(f)

        instances = {}
        for i, instance_data in data["instances"].items():
            target = [k for k, v in instance_data["target_scores"].items() if v == 1]
            # Fires for zero as well as for several matches, hence "exactly one".
            assert len(target) == 1, "Expected exactly one choice with target score 1 for dataset sample {}".format(instance_data)

            instances[i] = {
                "template": instance_data["template"],
                "condition": instance_data["condition"],
                "target_scores": instance_data["target_scores"],
                "target": target[0].rstrip("."),
            }
        return instances

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Build the query text for one instance.

        Args:
            instance: one value returned by ``instances()``.
            model: model name placed in the ``FROM`` clause.
            decoder: passed as ``dclib_decoder``; for ``"beam_search"`` the
                ``WHERE`` clause additionally requires fewer than 24 tokens
                for ``state_var``.
            kwargs: dict of extra decoder arguments, rendered sorted as
                ``key=value`` after ``dclib_decoder``.
            shots: number of in-context examples; must be 0 or 2.

        Returns:
            The query as a string.
        """
        # Left-strip every template line (lstrip_spaces presumably comes from
        # the evaluation_suite star import) and expand tabs to four spaces.
        template = "\n".join(lstrip_spaces(line) for line in instance["template"].split("\n")).replace("\t", "    ")
        template = f'''"""Q: {template}"""'''

        assert shots == 2 or shots == 0, f"shots should be 0 or 2, but is {shots}"
        if shots == 2:
            template = f"\"\"\"{few_shot_samples}\"\"\"\n" + template

        template = template.replace("Now[state_var] ", "\\n- This means, now[state_var] Whereas [state_var]\\n")

        # Cut at the last ".": the tail (ending in the closing triple quotes)
        # is re-attached after the CONCLUSION line in the query below.
        head, _, last_sentence = template.rpartition(".")
        template = head + "."

        # Drop any "." sitting right before a quote (choice strings may end in
        # "."; targets are rstrip(".")-ed too). That also turns
        # STOPS_AT(state_var, ".") into STOPS_AT(state_var, ""), which the
        # second replace restores.
        condition = instance["condition"].replace('."', '"')
        condition = condition.replace(' STOPS_AT(state_var, "")', ' STOPS_AT(state_var, ".")')

        additional_args = ", ".join(sorted(f"{k}={v}" for k, v in kwargs.items()))
        if additional_args:
            additional_args = ", " + additional_args

        q = f"""
BEAM(dclib_decoder="{decoder}"{additional_args})
    {indent(template)}
    A: Overall this means[CONCLUSION] Therefore, {last_sentence}
FROM 
    "{model}"
WHERE
    {condition} and STOPS_AT(CONCLUSION, "\\n") and STOPS_AT(CONCLUSION, ".") {" and len(TOKENS(state_var)) < 24" if decoder=="beam_search" else ""}
        """.strip()
        return q



@suite("tracking_shuffled_objects@multivar")
class TrackingShuffledObjectsSuite(EvaluationSuite):
    """Tracking-shuffled-objects suite where each state marker reads "As a result,".

    Every ``Now[state_var] `` marker in an instance template becomes
    ``As a result, [state_var]``, and the query asks for an
    ``Overall this means[CONCLUSION]`` summary before the final answer
    sentence. In-context examples come from ``self.few_shot_samples``.
    """

    def __init__(self, size):
        super().__init__("tracking_shuffled_objects@multivar", os.path.join("generated_tasks", "tracking_shuffled_objects.json"), size)

        # In-context examples; make_query wraps each in triple quotes and
        # prepends up to `shots` of them. Empty here, so shots > 0 fails the
        # assertion in make_query unless this list is populated elsewhere.
        self.few_shot_samples = []

    def instances(self):
        """Load the task file and return its instances keyed by instance id.

        ``self.file`` is presumably set by ``EvaluationSuite.__init__`` from the
        path passed above (TODO confirm). Each returned value holds the raw
        ``template``, ``condition`` and ``target_scores`` of the sample, plus
        ``target``: the single choice scored 1, with trailing "." removed.

        Raises:
            AssertionError: if a sample does not have exactly one choice scored 1.
        """
        # Explicit encoding: the platform default encoding is locale-dependent.
        with open(self.file, "r", encoding="utf-8") as f:
            data = json.load(f)

        instances = {}
        for i, instance_data in data["instances"].items():
            target = [k for k, v in instance_data["target_scores"].items() if v == 1]
            # Fires for zero as well as for several matches, hence "exactly one".
            assert len(target) == 1, "Expected exactly one choice with target score 1 for dataset sample {}".format(instance_data)

            instances[i] = {
                "template": instance_data["template"],
                "condition": instance_data["condition"],
                "target_scores": instance_data["target_scores"],
                "target": target[0].rstrip("."),
            }
        return instances

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Build the query text for one instance.

        Args:
            instance: one value returned by ``instances()``.
            model: model name placed in the ``FROM`` clause.
            decoder: passed as ``dclib_decoder``.
            kwargs: dict of extra decoder arguments, rendered sorted as
                ``key=value`` after ``dclib_decoder``.
            shots: number of in-context examples taken from
                ``self.few_shot_samples`` (0 for none).

        Returns:
            The query as a string.
        """
        # Left-strip every template line (lstrip_spaces presumably comes from
        # the evaluation_suite star import) and expand tabs to four spaces.
        template = "\n".join(lstrip_spaces(line) for line in instance["template"].split("\n")).replace("\t", "    ")
        template = f'''"""Q: {template}"""'''
        template = template.replace("Now[state_var] ", "As a result, [state_var]")

        # Cut at the last ".": the tail (ending in the closing triple quotes)
        # is re-attached after the CONCLUSION line in the query below.
        head, _, last_sentence = template.rpartition(".")
        template = head + "."

        # Drop any "." sitting right before a quote (choice strings may end in
        # "."). That also turns STOPS_AT(state_var, ".") into
        # STOPS_AT(state_var, ""), which the second replace restores.
        condition = instance["condition"].replace('."', '"')
        condition = condition.replace(' STOPS_AT(state_var, "")', ' STOPS_AT(state_var, ".")')

        if shots > 0:
            assert len(self.few_shot_samples) > 0, f"No few shot samples available for {self.task}"
            template = "\n".join(f'"""{s}"""' for s in self.few_shot_samples[:shots]) + "\n" + template

        additional_args = ", ".join(sorted(f"{k}={v}" for k, v in kwargs.items()))
        if additional_args:
            additional_args = ", " + additional_args

        q = f"""
BEAM(dclib_decoder="{decoder}"{additional_args})
    {indent(template)}
    A: Overall this means[CONCLUSION] Therefore, {last_sentence}
FROM 
    "{model}"
WHERE
    {condition} and STOPS_AT(CONCLUSION, "\\n") and STOPS_AT(CONCLUSION, ".")
        """.strip()
        return q


@suite("tracking_shuffled_objects@multivar_cot")
class MVCotTrackingShuffledObjectsSuite(EvaluationSuite):
    """Multi-variable chain-of-thought suite with choice-constrained answers.

    Every ``Now[state_var] `` marker in an instance template becomes
    ``Now[state_var] This means that[state_var] ``. The query asks for an
    ``Overall this means[CONCLUSION]`` summary, and ``answer`` is constrained
    to the choices parsed from the instance condition. In-context examples
    come from ``self.few_shot_samples``.
    """

    def __init__(self, size):
        super().__init__("tracking_shuffled_objects@multivar_cot", os.path.join("generated_tasks", "tracking_shuffled_objects.json"), size)

        # In-context examples; make_query wraps each in triple quotes and
        # prepends up to `shots` of them. Empty here, so shots > 0 fails the
        # assertion in make_query unless this list is populated elsewhere.
        self.few_shot_samples = []

    def instances(self):
        """Load the task file and return its instances keyed by instance id.

        ``self.file`` is presumably set by ``EvaluationSuite.__init__`` from the
        path passed above (TODO confirm). Each returned value holds the raw
        ``template``, ``condition`` and ``target_scores`` of the sample, plus
        ``target``: the single choice scored 1, with trailing "." removed.

        Raises:
            AssertionError: if a sample does not have exactly one choice scored 1.
        """
        # Explicit encoding: the platform default encoding is locale-dependent.
        with open(self.file, "r", encoding="utf-8") as f:
            data = json.load(f)

        instances = {}
        for i, instance_data in data["instances"].items():
            target = [k for k, v in instance_data["target_scores"].items() if v == 1]
            # Fires for zero as well as for several matches, hence "exactly one".
            assert len(target) == 1, "Expected exactly one choice with target score 1 for dataset sample {}".format(instance_data)

            instances[i] = {
                "template": instance_data["template"],
                "condition": instance_data["condition"],
                "target_scores": instance_data["target_scores"],
                "target": target[0].rstrip("."),
            }
        return instances

    def get_choices(self, instance):
        """Return the options listed in the ``answer in [...]`` constraint.

        The bracketed list in ``instance["condition"]`` is split on ``", "``,
        quotes are removed, and each option is whitespace-stripped with any
        trailing "." dropped.
        NOTE(review): options containing "]" or "|" would be mis-split.
        """
        condition = instance["condition"]
        choices = condition.split("answer in [")[1].split("]")[0].replace('", "', '|').replace('"', '').split("|")
        return [c.strip().rstrip(".") for c in choices]

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Build the query text for one instance.

        Args:
            instance: one value returned by ``instances()``.
            model: model name placed in the ``FROM`` clause.
            decoder: passed as ``dclib_decoder``; for ``"beam_search"`` the
                ``WHERE`` clause additionally requires fewer than 24 tokens
                for ``state_var``.
            kwargs: dict of extra decoder arguments, rendered sorted as
                ``key=value`` after ``dclib_decoder``.
            shots: number of in-context examples taken from
                ``self.few_shot_samples`` (0 for none).

        Returns:
            The query as a string.
        """
        # Left-strip every template line (lstrip_spaces presumably comes from
        # the evaluation_suite star import) and expand tabs to four spaces.
        template = "\n".join(lstrip_spaces(line) for line in instance["template"].split("\n")).replace("\t", "    ")
        template = template.replace("Now[state_var] ", "Now[state_var] This means that[state_var] ")

        # Cut at the last ".": the tail is re-attached after the CONCLUSION
        # line in the query below.
        head, _, last_sentence = template.rpartition(".")
        template = f"Q: {head}."

        choices = self.get_choices(instance)

        if shots > 0:
            assert len(self.few_shot_samples) > 0, f"No few shot samples available for {self.task}"
            template = "\n".join(f'"""{s}"""' for s in self.few_shot_samples[:shots]) + "\n" + template

        additional_args = ", ".join(sorted(f"{k}={v}" for k, v in kwargs.items()))
        if additional_args:
            additional_args = ", " + additional_args

        # Each choice is offered with a leading space, matching how the answer
        # follows the preceding word in the template.
        q = f"""
BEAM(dclib_decoder="{decoder}"{additional_args})
    \"\"\"{indent(template)}
    A: Overall this means[CONCLUSION] Therefore, {last_sentence}\"\"\"
FROM 
    "{model}"
WHERE
    STOPS_AT(CONCLUSION, ".") and STOPS_AT(state_var, ".") and STOPS_AT(state_var, "\\n") and
    answer in [{", ".join(f'" {c}"' for c in choices)}] {" and len(TOKENS(state_var)) < 24" if decoder=="beam_search" else ""}
        """.strip()
        return q

    def parse_prediction(self, prediction, model_result):
        """Normalize a raw prediction via ``parse_choice`` (``model_result`` is unused)."""
        return parse_choice(prediction)

def make_choices(choices):
    """Render *choices* as one line of lettered options.

    Up to 26 options are labelled A-Z. More than 26 raise IndexError.
    Note: make_choices is defined again later in this module, and that later
    definition is the one in effect after import.

    >>> make_choices(["red", "blue"])
    '(A) red (B) blue'
    """
    labels = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    return " ".join(f"({labels[i]}) {choice}" for i, choice in enumerate(choices))

@suite("tracking_shuffled_objects@cot")
class CotTrackingShuffledObjectsSuite(EvaluationSuite):
    """Plain chain-of-thought suite with choice-constrained answers.

    ``Now[state_var] `` markers are removed from the instance template. The
    model writes free-form reasoning into ``COT`` after "Let's think step by
    step." and then completes the final sentence, where ``answer`` is
    constrained to the choices parsed from the condition. Supports 0 or 2
    shots (module-level ``cot_few_shot_samples``).
    """

    def __init__(self, size):
        super().__init__("tracking_shuffled_objects@cot", os.path.join("generated_tasks", "tracking_shuffled_objects.json"), size)

        # Not read by make_query, which uses the module-level
        # cot_few_shot_samples instead; kept for interface compatibility.
        self.few_shot_samples = []

    def instances(self):
        """Load the task file and return its instances keyed by instance id.

        ``self.file`` is presumably set by ``EvaluationSuite.__init__`` from the
        path passed above (TODO confirm). Each returned value holds the raw
        ``template``, ``condition`` and ``target_scores`` of the sample, plus
        ``target``: the single choice scored 1, with trailing "." removed.

        Raises:
            AssertionError: if a sample does not have exactly one choice scored 1.
        """
        # Explicit encoding: the platform default encoding is locale-dependent.
        with open(self.file, "r", encoding="utf-8") as f:
            data = json.load(f)

        instances = {}
        for i, instance_data in data["instances"].items():
            target = [k for k, v in instance_data["target_scores"].items() if v == 1]
            # Fires for zero as well as for several matches, hence "exactly one".
            assert len(target) == 1, "Expected exactly one choice with target score 1 for dataset sample {}".format(instance_data)

            instances[i] = {
                "template": instance_data["template"],
                "condition": instance_data["condition"],
                "target_scores": instance_data["target_scores"],
                "target": target[0].rstrip("."),
            }
        return instances

    def get_choices(self, instance):
        """Return the options listed in the ``answer in [...]`` constraint.

        The bracketed list in ``instance["condition"]`` is split on ``", "``,
        quotes are removed, and each option is whitespace-stripped with any
        trailing "." dropped.
        NOTE(review): options containing "]" or "|" would be mis-split.
        """
        condition = instance["condition"]
        choices = condition.split("answer in [")[1].split("]")[0].replace('", "', '|').replace('"', '').split("|")
        return [c.strip().rstrip(".") for c in choices]

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Build the query text for one instance.

        Args:
            instance: one value returned by ``instances()``.
            model: model name placed in the ``FROM`` clause.
            decoder: passed as ``dclib_decoder``; for ``"beam_search"`` the
                ``WHERE`` clause additionally requires fewer than 150 tokens
                for ``COT``.
            kwargs: dict of extra decoder arguments, rendered sorted as
                ``key=value`` after ``dclib_decoder``.
            shots: number of in-context examples; must be 0 or 2.

        Returns:
            The query as a string.
        """
        # Left-strip every template line (lstrip_spaces presumably comes from
        # the evaluation_suite star import) and expand tabs to four spaces.
        template = "\n".join(lstrip_spaces(line) for line in instance["template"].split("\n")).replace("\t", "    ")
        # Plain chain of thought: drop the intermediate-state markers entirely.
        template = template.replace("Now[state_var] ", "")

        # Cut at the last ".": the tail becomes the sentence completed after
        # "[COT] Therefore," in the query below.
        head, _, last_sentence = template.rpartition(".")
        template = head + "."

        choices = self.get_choices(instance)

        additional_args = ", ".join(sorted(f"{k}={v}" for k, v in kwargs.items()))
        if additional_args:
            additional_args = ", " + additional_args

        samples = ""
        assert shots == 2 or shots == 0, f"shots should be 0 or 2, but is {shots}"
        if shots == 2:
            samples = f"\"\"\"{cot_few_shot_samples}\"\"\"\n"

        q = f"""
BEAM(dclib_decoder="{decoder}"{additional_args})
    {indent(samples)}
    \"\"\"Q: {indent(template)}
    A: Let's think step by step.
    [COT] Therefore,{last_sentence}.
    \"\"\"
FROM 
    "{model}"
WHERE
    answer in [{", ".join(f'" {c}"' for c in choices)}] {" and len(TOKENS(COT)) < 150" if decoder=="beam_search" else ""}
        """.strip()
        return q

    def parse_prediction(self, prediction, model_result):
        """Normalize a raw prediction via ``parse_choice`` (``model_result`` is unused)."""
        return parse_choice(prediction)


@suite("tracking_shuffled_objects@ao")
class AnswerOnlyTrackingShuffledObjectsSuite(EvaluationSuite):
    """Answer-only suite: no reasoning, answer constrained to the parsed choices.

    ``Now[state_var] `` markers are removed from the instance template, and
    ``answer`` (presumably a hole inside the template - TODO confirm) is
    constrained to the choices parsed from the condition. Supports 0 or 2
    shots (module-level ``ao_few_shot_samples``).
    """

    def __init__(self, size):
        super().__init__("tracking_shuffled_objects@ao", os.path.join("generated_tasks", "tracking_shuffled_objects.json"), size)

        # Not read by make_query, which uses the module-level
        # ao_few_shot_samples instead; kept for interface compatibility.
        self.few_shot_samples = []

    def instances(self):
        """Load the task file and return its instances keyed by instance id.

        ``self.file`` is presumably set by ``EvaluationSuite.__init__`` from the
        path passed above (TODO confirm). Each returned value holds the raw
        ``template``, ``condition`` and ``target_scores`` of the sample, plus
        ``target``: the single choice scored 1, with trailing "." removed.

        Raises:
            AssertionError: if a sample does not have exactly one choice scored 1.
        """
        # Explicit encoding: the platform default encoding is locale-dependent.
        with open(self.file, "r", encoding="utf-8") as f:
            data = json.load(f)

        instances = {}
        for i, instance_data in data["instances"].items():
            target = [k for k, v in instance_data["target_scores"].items() if v == 1]
            # Fires for zero as well as for several matches, hence "exactly one".
            assert len(target) == 1, "Expected exactly one choice with target score 1 for dataset sample {}".format(instance_data)

            instances[i] = {
                "template": instance_data["template"],
                "condition": instance_data["condition"],
                "target_scores": instance_data["target_scores"],
                "target": target[0].rstrip("."),
            }
        return instances

    def get_choices(self, instance):
        """Return the options listed in the ``answer in [...]`` constraint.

        The bracketed list in ``instance["condition"]`` is split on ``", "``,
        quotes are removed, and each option is whitespace-stripped with any
        trailing "." dropped.
        NOTE(review): options containing "]" or "|" would be mis-split.
        """
        condition = instance["condition"]
        choices = condition.split("answer in [")[1].split("]")[0].replace('", "', '|').replace('"', '').split("|")
        return [c.strip().rstrip(".") for c in choices]

    def make_query(self, instance, model, decoder, kwargs, shots):
        """Build the query text for one instance.

        Args:
            instance: one value returned by ``instances()``.
            model: model name placed in the ``FROM`` clause.
            decoder: passed as ``dclib_decoder``.
            kwargs: dict of extra decoder arguments, rendered sorted as
                ``key=value`` after ``dclib_decoder``.
            shots: number of in-context examples; must be 0 or 2.

        Returns:
            The query as a string.
        """
        # Left-strip every template line (lstrip_spaces presumably comes from
        # the evaluation_suite star import) and expand tabs to four spaces.
        template = "\n".join(lstrip_spaces(line) for line in instance["template"].split("\n")).replace("\t", "    ")
        # Answer only: drop the intermediate-state markers entirely.
        template = template.replace("Now[state_var] ", "")

        choices = self.get_choices(instance)

        additional_args = ", ".join(sorted(f"{k}={v}" for k, v in kwargs.items()))
        if additional_args:
            additional_args = ", " + additional_args

        samples = ""
        assert shots == 2 or shots == 0, f"shots should be 0 or 2, but is {shots}"
        if shots == 2:
            samples = f"\"\"\"{ao_few_shot_samples}\"\"\"\n"

        q = f"""
BEAM(dclib_decoder="{decoder}"{additional_args})
    {indent(samples)}
    \"\"\"Q: {indent(template)}\"\"\"
FROM 
    "{model}"
WHERE
    answer in [{", ".join(f'" {c}"' for c in choices)}]
        """.strip()
        return q

    def parse_prediction(self, prediction, model_result):
        """Normalize a raw prediction via ``parse_choice`` (``model_result`` is unused)."""
        return parse_choice(prediction)



def make_choices(choices):
    """Render *choices* as one line of lettered options.

    Up to 26 options are labelled A-Z. More than 26 raise IndexError.
    Note: make_choices is also defined earlier in this module; being later,
    this definition is the one in effect after import.

    >>> make_choices(["red", "blue"])
    '(A) red (B) blue'
    """
    labels = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    return " ".join(f"({labels[i]}) {choice}" for i, choice in enumerate(choices))