from prompts_builder.string_tools_choice import Prompt_Choice_StaA, Prompt_Choice_REVB
from prompts_builder.string_tools_coins import Prompt_Coins_StaA, Prompt_Coins_REVB
from prompts_builder.string_tools_preference import (
    Prompt_Preference_StaA,
    Prompt_Preference_REVB,
)
from prompts_builder.string_tools_dice import Prompt_Dice_StaA, Prompt_Dice_REVB
import utils_distributions as UD
import strings as S


def _unsupported_subscenario_StaA(scenario, subscenario):
    """Build the error raised when a scenario does not handle a subscenario."""
    return ValueError(f"Unknown subscenario {subscenario} for scenario {scenario}")


def _choice_elements_StaA(scenario, subscenario, subconfig, answer):
    """Dispatch the CHOICE scenario to the matching Prompt_Choice_StaA builder."""
    params = S.prompt_parameter
    kinds = S.subscenario_names

    if subscenario == kinds.REGULAR:
        return Prompt_Choice_StaA.regular_experiment(
            number_of_choices=subconfig[params.NUMBER],
            **answer,
        )
    if subscenario == kinds.REPEATED_INDEPENDENT:
        return Prompt_Choice_StaA.independant_previous_launch(
            number_of_choices=subconfig[params.NUMBER],
            previous_value=subconfig[params.PREVIOUS_RESULT],
            **answer,
        )
    raise _unsupported_subscenario_StaA(scenario, subscenario)


def _coins_elements_StaA(scenario, subscenario, subconfig, answer):
    """Dispatch the COINS scenario to the matching Prompt_Coins_StaA builder."""
    params = S.prompt_parameter
    kinds = S.subscenario_names

    # Pick the builder first so that an unknown subscenario is reported
    # before any sub-config key is read.
    if subscenario == kinds.REGULAR:
        builder = Prompt_Coins_StaA.regular_experiment
        needs_previous = False
    elif subscenario == kinds.REPEATED_DEPENDENT:
        builder = Prompt_Coins_StaA.dependant_previous_launch
        needs_previous = True
    elif subscenario == kinds.REPEATED_INDEPENDENT:
        builder = Prompt_Coins_StaA.independant_previous_launch
        needs_previous = True
    else:
        raise _unsupported_subscenario_StaA(scenario, subscenario)

    kwargs = {
        "number_of_coins": subconfig[params.NUMBER],
        "weight_of_focus": subconfig[params.BIAS],
        "focus": subconfig[params.FOCUS],
    }
    if needs_previous:
        # Only the "repeated" variants require a previous result in the config.
        kwargs["previous_value"] = subconfig[params.PREVIOUS_RESULT]
    return builder(**kwargs, **answer)


def _preference_elements_StaA(scenario, subscenario, subconfig, answer):
    """Dispatch the PREFERENCES scenario to the matching Prompt_Preference_StaA builder."""
    params = S.prompt_parameter
    kinds = S.subscenario_names

    options = subconfig[params.OPTIONS]
    focus = subconfig[params.FOCUS]
    bias = subconfig[params.BIAS]

    # The "other" option is the last option that differs from the focus,
    # or None when every option equals the focus (or there are no options).
    non_focus = [option for option in options if option != focus]
    other = non_focus[-1] if non_focus else None

    if subscenario == kinds.REGULAR:
        return Prompt_Preference_StaA.regular_experiment(
            biased_factor=bias,
            focus=focus,
            other=other,
            **answer,
        )
    if subscenario == kinds.REPEATED_INDEPENDENT:
        return Prompt_Preference_StaA.independant_previous_launch(
            biased_factor=bias,
            focus=focus,
            other=other,
            previous_value=subconfig[params.PREVIOUS_RESULT],
            **answer,
        )
    raise _unsupported_subscenario_StaA(scenario, subscenario)


def _dice_elements_StaA(scenario, subscenario, subconfig, answer):
    """Dispatch the DICE scenario to the matching Prompt_Dice_StaA builder."""
    params = S.prompt_parameter
    kinds = S.subscenario_names

    if subscenario == kinds.REGULAR:
        return Prompt_Dice_StaA.regular_experiment(
            number_of_dice=subconfig[params.NUMBER],
            face_per_dice=subconfig[params.FACES],
            **answer,
        )
    if subscenario == kinds.REPEATED_DEPENDENT:
        return Prompt_Dice_StaA.dependant_previous_launch(
            number_of_dice=subconfig[params.NUMBER],
            face_per_dice=subconfig[params.FACES],
            previous_value=subconfig[params.PREVIOUS_RESULT],
            **answer,
        )
    if subscenario == kinds.REPEATED_INDEPENDENT:
        return Prompt_Dice_StaA.independant_previous_launch(
            number_of_dice=subconfig[params.NUMBER],
            face_per_dice=subconfig[params.FACES],
            previous_value=subconfig[params.PREVIOUS_RESULT],
            **answer,
        )
    if subscenario == kinds.OBSERVATIONS:
        observations = subconfig[params.OBSERVATIONS]
        count = len(observations)
        # Only one or two observations have a dedicated prompt builder.
        if count == 1:
            return Prompt_Dice_StaA.one_observation(
                number_of_dice=subconfig[params.NUMBER],
                face_per_dice=subconfig[params.FACES],
                observation=observations[0],
                **answer,
            )
        if count == 2:
            return Prompt_Dice_StaA.two_observations(
                number_of_dice=subconfig[params.NUMBER],
                face_per_dice=subconfig[params.FACES],
                observation1=observations[0],
                observation2=observations[1],
                **answer,
            )
        raise ValueError(f"Invalid number of observations : {observations}")
    raise _unsupported_subscenario_StaA(scenario, subscenario)


def generate_prompt_elements_StaA(config):
    """Build the StaA prompt elements for the scenario described by ``config``.

    The scenario and subscenario names are read from ``config``; the
    scenario-specific parameters are read from ``config[scenario]``. The true
    distribution and its maximum-likelihood value are computed through
    ``UD`` and forwarded to the selected ``Prompt_*_StaA`` builder as
    ``true_distribution`` and ``question_for_value``.

    Args:
        config: Mapping holding the scenario name, the subscenario name and a
            sub-config stored under the scenario name.

    Returns:
        For CHOICE, COINS, PREFERENCES and DICE: whatever the selected
        ``Prompt_*_StaA`` builder returns (presumably the same 5-tuple shape
        as the STATS branch -- TODO confirm against the builders).
        For STATS: the tuple ``(prompt, question, choices, values, answer)``
        read directly from the sub-config.

    Raises:
        AssertionError: If ``UD.find_maximum_likelihoo`` reports that the
            distribution has no maximum likelihood.
        ValueError: If the scenario, the subscenario, or the number of dice
            observations is not supported.
        KeyError: If ``config`` or the sub-config lacks a required key.
    """
    scenario = config[S.prompt_parameter.SCENARIO]
    subscenario = config[S.prompt_parameter.SUBSCENARIO]
    subconfig = config[scenario]

    distribution = UD.generateTrueDistribution(config=config)
    has_maximum, question_for_value = UD.find_maximum_likelihoo(distribution)

    # Raised explicitly (instead of using ``assert``) so the check is not
    # stripped when Python runs with -O; the exception type is unchanged.
    if not has_maximum:
        raise AssertionError("The scenario/config does not have a maximum likelihood")

    # Keyword arguments shared by every Prompt_*_StaA builder.
    answer = {
        "question_for_value": question_for_value,
        "true_distribution": distribution,
    }

    scenarios = S.scenario_names
    if scenario == scenarios.CHOICE:
        return _choice_elements_StaA(scenario, subscenario, subconfig, answer)
    if scenario == scenarios.COINS:
        return _coins_elements_StaA(scenario, subscenario, subconfig, answer)
    if scenario == scenarios.PREFERENCES:
        return _preference_elements_StaA(scenario, subscenario, subconfig, answer)
    if scenario == scenarios.DICE:
        return _dice_elements_StaA(scenario, subscenario, subconfig, answer)
    if scenario == scenarios.STATS:
        # STATS rows already carry every prompt element; nothing is generated.
        return (
            subconfig["prompt"],
            subconfig["question"],
            subconfig["choices"],
            subconfig["values"],
            subconfig["answer"],
        )
    raise ValueError(f"Unknown scenario {scenario}")


def _unsupported_subscenario_REVB(scenario, subscenario):
    """Build the error raised when a scenario does not handle a subscenario."""
    return ValueError(f"Unknown subscenario {subscenario} for scenario {scenario}")


def _choice_prompt_REVB(scenario, subscenario, subconfig):
    """Return the CHOICE base prompt from the matching Prompt_Choice_REVB builder."""
    params = S.prompt_parameter
    kinds = S.subscenario_names

    if subscenario == kinds.REGULAR:
        return Prompt_Choice_REVB.regular_experiment(
            number_of_choices=subconfig[params.NUMBER],
        )
    if subscenario == kinds.REPEATED_INDEPENDENT:
        return Prompt_Choice_REVB.independant_previous_launch(
            number_of_choices=subconfig[params.NUMBER],
            previous_value=subconfig[params.PREVIOUS_RESULT],
        )
    raise _unsupported_subscenario_REVB(scenario, subscenario)


def _coins_prompt_REVB(scenario, subscenario, subconfig):
    """Return the COINS base prompt from the matching Prompt_Coins_REVB builder."""
    params = S.prompt_parameter
    kinds = S.subscenario_names

    # Pick the builder first so that an unknown subscenario is reported
    # before any sub-config key is read.
    if subscenario == kinds.REGULAR:
        builder = Prompt_Coins_REVB.regular_experiment
        needs_previous = False
    elif subscenario == kinds.REPEATED_DEPENDENT:
        builder = Prompt_Coins_REVB.dependant_previous_launch
        needs_previous = True
    elif subscenario == kinds.REPEATED_INDEPENDENT:
        builder = Prompt_Coins_REVB.independant_previous_launch
        needs_previous = True
    else:
        raise _unsupported_subscenario_REVB(scenario, subscenario)

    kwargs = {
        "number_of_coins": subconfig[params.NUMBER],
        "weight_of_focus": subconfig[params.BIAS],
        "focus": subconfig[params.FOCUS],
    }
    if needs_previous:
        # Only the "repeated" variants require a previous result in the config.
        kwargs["previous_value"] = subconfig[params.PREVIOUS_RESULT]
    return builder(**kwargs)


def _preference_prompt_REVB(scenario, subscenario, subconfig):
    """Return the PREFERENCES base prompt from the matching Prompt_Preference_REVB builder."""
    params = S.prompt_parameter
    kinds = S.subscenario_names

    options = subconfig[params.OPTIONS]
    focus = subconfig[params.FOCUS]
    bias = subconfig[params.BIAS]

    # The "other" option is the last option that differs from the focus,
    # or None when every option equals the focus (or there are no options).
    non_focus = [option for option in options if option != focus]
    other = non_focus[-1] if non_focus else None

    if subscenario == kinds.REGULAR:
        return Prompt_Preference_REVB.regular_experiment(
            biased_factor=bias,
            focus=focus,
            other=other,
        )
    if subscenario == kinds.REPEATED_INDEPENDENT:
        return Prompt_Preference_REVB.independant_previous_launch(
            biased_factor=bias,
            focus=focus,
            other=other,
            previous_value=subconfig[params.PREVIOUS_RESULT],
        )
    raise _unsupported_subscenario_REVB(scenario, subscenario)


def _dice_prompt_REVB(scenario, subscenario, subconfig):
    """Return the DICE base prompt from the matching Prompt_Dice_REVB builder."""
    params = S.prompt_parameter
    kinds = S.subscenario_names

    if subscenario == kinds.REGULAR:
        return Prompt_Dice_REVB.regular_experiment(
            number_of_dice=subconfig[params.NUMBER],
            face_per_dice=subconfig[params.FACES],
        )
    if subscenario == kinds.REPEATED_DEPENDENT:
        return Prompt_Dice_REVB.dependant_previous_launch(
            number_of_dice=subconfig[params.NUMBER],
            face_per_dice=subconfig[params.FACES],
            previous_value=subconfig[params.PREVIOUS_RESULT],
        )
    if subscenario == kinds.REPEATED_INDEPENDENT:
        return Prompt_Dice_REVB.independant_previous_launch(
            number_of_dice=subconfig[params.NUMBER],
            face_per_dice=subconfig[params.FACES],
            previous_value=subconfig[params.PREVIOUS_RESULT],
        )
    if subscenario == kinds.OBSERVATIONS:
        observations = subconfig[params.OBSERVATIONS]
        count = len(observations)
        # Only one or two observations have a dedicated prompt builder.
        if count == 1:
            return Prompt_Dice_REVB.one_observation(
                number_of_dice=subconfig[params.NUMBER],
                face_per_dice=subconfig[params.FACES],
                observation=observations[0],
            )
        if count == 2:
            return Prompt_Dice_REVB.two_observations(
                number_of_dice=subconfig[params.NUMBER],
                face_per_dice=subconfig[params.FACES],
                observation1=observations[0],
                observation2=observations[1],
            )
        raise ValueError(f"Invalid number of observations : {observations}")
    raise _unsupported_subscenario_REVB(scenario, subscenario)


def _stats_prompt_REVB(subconfig):
    """Render the STATS base prompt through the tokenizer's chat template.

    The user turn is the row's ``"prompt"`` and the assistant turn is the
    row's ``config["revb_ending"]``. The template is rendered as text
    (``tokenize=False``) with ``continue_final_message=True``.
    """
    # NOTE(review): the tokenizer is presumably a Hugging Face tokenizer
    # exposing ``apply_chat_template`` -- confirm against the caller.
    tokenizer = subconfig["tokenizer"]
    user_query = subconfig["prompt"]
    assistant_prompt = subconfig["config"]["revb_ending"]
    conversation = [
        {"role": "user", "content": user_query},
        {"role": "assistant", "content": assistant_prompt},
    ]
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        continue_final_message=True,
    )


def generate_prompt_elements_REVB(config):
    """Build the REVB base prompt and the list of outcomes for ``config``.

    The scenario and subscenario names are read from ``config``; the
    scenario-specific parameters are read from ``config[scenario]``. The true
    distribution is computed through ``UD.generateTrueDistribution`` and its
    keys are returned as the outcomes.

    Args:
        config: Mapping holding the scenario name, the subscenario name and a
            sub-config stored under the scenario name.

    Returns:
        A ``(base_prompt, outcomes)`` tuple, where ``base_prompt`` is the
        value produced by the selected ``Prompt_*_REVB`` builder (or by the
        tokenizer's chat template for STATS) and ``outcomes`` is the list of
        keys of the true distribution.

    Raises:
        ValueError: If the scenario, the subscenario, or the number of dice
            observations is not supported.
        KeyError: If ``config`` or the sub-config lacks a required key.
    """
    scenario = config[S.prompt_parameter.SCENARIO]
    subscenario = config[S.prompt_parameter.SUBSCENARIO]
    subconfig = config[scenario]

    distribution = UD.generateTrueDistribution(config=config)

    scenarios = S.scenario_names
    if scenario == scenarios.CHOICE:
        base_prompt = _choice_prompt_REVB(scenario, subscenario, subconfig)
    elif scenario == scenarios.COINS:
        base_prompt = _coins_prompt_REVB(scenario, subscenario, subconfig)
    elif scenario == scenarios.PREFERENCES:
        base_prompt = _preference_prompt_REVB(scenario, subscenario, subconfig)
    elif scenario == scenarios.DICE:
        base_prompt = _dice_prompt_REVB(scenario, subscenario, subconfig)
    elif scenario == scenarios.STATS:
        base_prompt = _stats_prompt_REVB(subconfig)
    else:
        raise ValueError(f"Unknown scenario {scenario}")

    outcomes = list(distribution.keys())
    return base_prompt, outcomes
