
import strings as S
import pandas as pd
import csv

def _make_config(scenario, subscenario, params):
    """Wrap scenario-specific ``params`` in the common config envelope.

    Key order matches what this generator has always produced: the scenario
    name, the subscenario name, then ``params`` stored under the scenario's
    own name.
    """
    return {
        S.prompt_parameter.SCENARIO: scenario,
        S.prompt_parameter.SUBSCENARIO: subscenario,
        scenario: params,
    }


def _choice_configs():
    """Build the multiple-choice configs.

    For 2, 4 and 6 choices: one REGULAR config without a previous result,
    followed by one REPEATED_INDEPENDENT config per possible previous answer
    letter ("A", "B", ...).
    """
    param, sub = S.prompt_parameter, S.subscenario_names
    scenario = S.scenario_names.CHOICE
    configs = []
    for number_of_choices in (2, 4, 6):
        configs.append(_make_config(scenario, sub.REGULAR, {
            param.NUMBER: number_of_choices,
            param.PREVIOUS_RESULT: None,
        }))
        for offset in range(number_of_choices):
            configs.append(_make_config(scenario, sub.REPEATED_INDEPENDENT, {
                param.NUMBER: number_of_choices,
                param.PREVIOUS_RESULT: chr(ord("A") + offset),
            }))
    return configs


def _stat_configs(csv_path):
    """Build one REGULAR statistics config per row of the CSV at ``csv_path``.

    The CSV is indexed by its ``id`` column and must provide ``prompt``,
    ``question``, ``answer``, ``choices``, ``values`` and ``config`` columns.
    The last three hold Python expressions that are decoded with ``eval``.
    """
    table = pd.read_csv(csv_path, index_col="id", sep=",", quoting=csv.QUOTE_NONNUMERIC)
    # NOTE(review): ``eval`` executes arbitrary code stored in the CSV. Only feed
    # trusted, locally generated files here. If the cells are plain literals,
    # ``ast.literal_eval`` would be a safer replacement — confirm before switching.
    for column in ("config", "choices", "values"):
        table[column] = table[column].apply(eval)
    copied_columns = ("prompt", "question", "answer", "choices", "values", "config")
    configs = []
    for _, row in table.iterrows():
        params = {column: row[column] for column in copied_columns}
        params["tokenizer"] = None
        configs.append(_make_config(S.scenario_names.STATS, S.subscenario_names.REGULAR, params))
    return configs


def _preference_configs():
    """Build the preference configs.

    For each ordered option pair (the focus is the first option) and each
    bias in 1-3: one REGULAR config, then one REPEATED_INDEPENDENT config per
    option used as the previous result.
    """
    param, sub = S.prompt_parameter, S.subscenario_names
    scenario = S.scenario_names.PREFERENCES

    def params(options, focus, bias, previous_result):
        # Copy ``options`` so configs never share (and can never co-mutate) one list.
        return {
            param.OPTIONS: list(options),
            param.FOCUS: focus,
            param.BIAS: bias,
            param.PREVIOUS_RESULT: previous_result,
        }

    option_orders = [
        ["Left", "Right"], ["Right", "Left"],
        ["Heads", "Tails"], ["Tails", "Heads"],
        ["X", "Y"], ["Y", "X"],
    ]
    configs = []
    for options in option_orders:
        focus = options[0]
        for bias in (1, 2, 3):
            # NOTE(review): REGULAR configs carry the placeholder previous result
            # "Right" even when it is not one of the options. It is inherited from
            # the template the old code copied and is kept so generated configs are
            # unchanged. Consider ``None``, as the choice and dice scenarios use.
            configs.append(_make_config(scenario, sub.REGULAR, params(options, focus, bias, "Right")))
            for previous_value in options:
                configs.append(_make_config(
                    scenario, sub.REPEATED_INDEPENDENT, params(options, focus, bias, previous_value)
                ))
    return configs


def _coins_configs():
    """Build the coin-flip configs.

    For 2-4 coins, each focus face and each bias in (1, 3, 5): one REGULAR
    config, then REPEATED_DEPENDENT and REPEATED_INDEPENDENT configs for every
    previous result 0..number_of_coins.
    """
    param, sub = S.prompt_parameter, S.subscenario_names
    scenario = S.scenario_names.COINS

    def params(number_of_coins, focus, bias, previous_result):
        return {
            param.NUMBER: number_of_coins,
            param.FOCUS: focus,
            param.BIAS: bias,
            param.PREVIOUS_RESULT: previous_result,
        }

    configs = []
    for number_of_coins in (2, 3, 4):
        for focus in (S.coin_faces.HEADS, S.coin_faces.TAILS):
            for bias in (1, 3, 5):
                # NOTE(review): REGULAR configs keep the placeholder previous result 7
                # inherited from the old template config, which can exceed the number
                # of coins. It is preserved so generated configs are unchanged.
                configs.append(_make_config(scenario, sub.REGULAR, params(number_of_coins, focus, bias, 7)))
                for subscenario in (sub.REPEATED_DEPENDENT, sub.REPEATED_INDEPENDENT):
                    for previous_value in range(number_of_coins + 1):
                        configs.append(_make_config(
                            scenario, subscenario, params(number_of_coins, focus, bias, previous_value)
                        ))
    return configs


def _dice_pair_is_excluded(pair, number_of_dice, faces):
    """Return True if the two-observation combination ``pair`` must be skipped."""
    obs = S.observation_names

    def has_both(first, second):
        return first in pair and second in pair

    # Opposite observations are never combined.
    if has_both(obs.EVEN, obs.ODD) or has_both(obs.SMALLER_THAN_MIDDLE, obs.LARGER_THAN_MIDDLE):
        return True
    # Extra exclusions for a single four-sided die (presumably because these
    # combinations leave only one possible face — confirm).
    if faces == 4 and number_of_dice == 1:
        return (
            has_both(obs.EVEN, obs.SMALLER_THAN_MIDDLE)
            or has_both(obs.ODD, obs.SMALLER_THAN_MIDDLE)
            or has_both(obs.SMALLER_THAN_MIDDLE, obs.NOT_ONE)
        )
    return False


def _dice_configs():
    """Build the dice configs.

    For 1-3 dice and each face count, the configs are, in this order:
    - one REGULAR config;
    - REPEATED_INDEPENDENT, then REPEATED_DEPENDENT, configs for every
      previous result number_of_dice..number_of_dice * faces;
    - for each observation, an OBSERVATIONS config with that single
      observation, immediately followed by its allowed ordered pairs.
    """
    param, sub, obs = S.prompt_parameter, S.subscenario_names, S.observation_names
    scenario = S.scenario_names.DICE
    all_observations = [
        obs.EVEN,
        obs.ODD,
        obs.SMALLER_THAN_MIDDLE,
        obs.LARGER_THAN_MIDDLE,
        obs.NOT_ONE,
        obs.NOT_MIDDLE,
    ]

    def params(number_of_dice, faces, previous_result, observations):
        return {
            param.NUMBER: number_of_dice,
            param.FACES: faces,
            param.PREVIOUS_RESULT: previous_result,
            param.OBSERVATIONS: observations,
        }

    configs = []
    for number_of_dice in (1, 2, 3):
        for faces in (4, 6, 8, 10, 12):
            configs.append(_make_config(scenario, sub.REGULAR, params(number_of_dice, faces, None, [])))
            previous_values = range(number_of_dice, number_of_dice * faces + 1)
            for subscenario in (sub.REPEATED_INDEPENDENT, sub.REPEATED_DEPENDENT):
                for previous_value in previous_values:
                    configs.append(_make_config(
                        scenario, subscenario, params(number_of_dice, faces, previous_value, [])
                    ))
            for observation in all_observations:
                configs.append(_make_config(
                    scenario, sub.OBSERVATIONS, params(number_of_dice, faces, None, [observation])
                ))
                # Ordered pairs: both (a, b) and (b, a) are generated when allowed.
                for observation2 in [o for o in all_observations if o != observation]:
                    pair = [observation, observation2]
                    if _dice_pair_is_excluded(pair, number_of_dice, faces):
                        continue
                    configs.append(_make_config(
                        scenario, sub.OBSERVATIONS, params(number_of_dice, faces, None, pair)
                    ))
    return configs


def get_scenarios(include_stat=True, include_preference=True, include_coins=True, include_dice=True, include_choice=True, stat_csv_path="dataset/prompts_stA_stat.csv"):
    """Generate all prompt configurations to test, grouped by scenario.

    Args:
        include_stat: build the statistics configs (read from ``stat_csv_path``).
        include_preference: build the preference configs.
        include_coins: build the coin-flip configs.
        include_dice: build the dice configs.
        include_choice: build the multiple-choice configs.
        stat_csv_path: CSV file the statistics configs are read from; only
            opened when ``include_stat`` is true.

    Returns:
        A dict that maps "choice", "stat", "preference", "coins" and "dice"
        (only the enabled ones, in that order) to lists of config dicts.

    Side effects:
        Prints the number of configs generated for each enabled scenario.
    """
    builders = [
        (include_choice, "choice", _choice_configs),
        (include_stat, "stat", lambda: _stat_configs(stat_csv_path)),
        (include_preference, "preference", _preference_configs),
        (include_coins, "coins", _coins_configs),
        (include_dice, "dice", _dice_configs),
    ]
    scenarios = {}
    for enabled, key, build in builders:
        if not enabled:
            continue
        configs = build()
        print(f"Total {key} configs: {len(configs)}")
        scenarios[key] = configs
    return scenarios


if __name__ == "__main__":
    # Build every scenario and report the overall number of configs.
    generated = get_scenarios()
    total_configs = sum(len(configs) for configs in generated.values())
    print(f"Total scenarios generated: {total_configs}")
