import constants as C
from numpy.random import shuffle
import frac as F
import numpy as np
import strings as S


# Lowercase English words for the dice counts that the prompt templates can
# spell out. Counts missing from this table (5, 7, 9, 11, ...) have no entry,
# so looking them up raises KeyError.
NUM_TO_TEXT = dict(
    zip(
        (1, 2, 3, 4, 6, 8, 10, 12),
        ("one", "two", "three", "four", "six", "eight", "ten", "twelve"),
    )
)


class Prompt_Dice_REVB:
    """Builders for open-ended dice prompts.

    Each public method returns one string that describes a dice experiment and
    stops on an unfinished sentence (for example ``"The die lands on face
    number"``), leaving the final value to be appended by the caller. The
    wording switches between singular and plural on ``number_of_dice == 1``.

    ``number_of_dice`` is always looked up in the module-level ``NUM_TO_TEXT``
    table, so a count that is not a key of that table raises ``KeyError``.
    """

    # --- Sentence templates -------------------------------------------------
    _BASE_INTRO_SING = "A die has {nb_faces} faces. "
    _BASE_INTRO_PLUR = (
        # num_text is the spelled-out count taken from NUM_TO_TEXT ("two", ...).
        "There are {num_text} {nb_faces}-sided dice. "
    )

    _BASE_PROBABILITY_SING = "The die is equally likely to land on any of its faces. "
    _BASE_PROBABILITY_PLUR = "Each die is equally likely to land on any of its faces. "

    _BASE_INSTANTIATE_SING = "The die is cast. "
    _BASE_INSTANTIATE_PLUR = "The dice are cast. "

    # --- Unfinished closing sentences (no trailing value or period) ---------
    _BASE_CONCLUSION_SING = "The die lands on face number"

    _BASE_CONCLUSION_PLUR = "The sum of their faces is equal to"

    _CONCLUSION_REPEATED_DEPENDENT = "The sum of both results is equal to"

    _CONCLUSION_OBSERVATION = "Indeed, the result is equal to"

    # "Cast + completed conclusion" sentences reporting an earlier throw.
    _PREV_RESULT_PROMPT_SING = (
        _BASE_INSTANTIATE_SING + _BASE_CONCLUSION_SING + " {target}. "
    )
    _PREV_RESULT_PROMPT_PLUR = (
        _BASE_INSTANTIATE_PLUR + _BASE_CONCLUSION_PLUR + " {target}. "
    )

    # Same as the cast sentences, but ending in "... again. ".
    _INSTANTIATE_SING_AGAIN = _BASE_INSTANTIATE_SING.replace(". ", " again. ")
    _INSTANTIATE_PLUR_AGAIN = _BASE_INSTANTIATE_PLUR.replace(". ", " again. ")

    _OBSERVATION_BASE_PROMPT_SING = "We observe that the result is "
    _OBSERVATION_BASE_PROMPT_PLUR = "We observe that the sum of their faces is "
    _OBSERVATION_COMPLEMENT = " and that it is also "

    # Observation clauses keyed by observation name. Every clause ends with
    # ". " so that two_observations can strip those two characters when it
    # chains a second clause.
    _POSSIBLE_OBSERVATIONS = {
        S.observation_names.SMALLER_THAN_MIDDLE: "smaller than {mid_val}. ",
        S.observation_names.LARGER_THAN_MIDDLE: "greater than {mid_val}. ",
        S.observation_names.EVEN: "an even number. ",
        S.observation_names.ODD: "an odd number. ",
        S.observation_names.NOT_MIDDLE: "not equal to {mid_val}. ",
        S.observation_names.NOT_ONE: "not equal to {number_of_dice}. ",
    }

    # --- Private helpers ----------------------------------------------------
    @staticmethod
    def _preamble(number_of_dice) -> str:
        """Return the unformatted intro and fairness sentences."""
        tpl = Prompt_Dice_REVB
        if number_of_dice == 1:
            return tpl._BASE_INTRO_SING + tpl._BASE_PROBABILITY_SING
        return tpl._BASE_INTRO_PLUR + tpl._BASE_PROBABILITY_PLUR

    @staticmethod
    def _previous_launch(
        number_of_dice, face_per_dice, previous_value, conclusion
    ) -> str:
        """Build a 'cast, report result, cast again' prompt ending in ``conclusion``."""
        tpl = Prompt_Dice_REVB
        if number_of_dice == 1:
            replay = tpl._PREV_RESULT_PROMPT_SING + tpl._INSTANTIATE_SING_AGAIN
        else:
            replay = tpl._PREV_RESULT_PROMPT_PLUR + tpl._INSTANTIATE_PLUR_AGAIN

        template = tpl._preamble(number_of_dice) + replay + conclusion
        return template.format(
            nb_faces=face_per_dice,
            num_text=NUM_TO_TEXT[number_of_dice],
            target=previous_value,
        )

    @staticmethod
    def _observed_launch(number_of_dice, face_per_dice, observed) -> str:
        """Build a 'cast, state the observation' prompt around ``observed``.

        ``observed`` is the already-assembled observation clause; it may still
        contain the ``{mid_val}`` and ``{number_of_dice}`` placeholders.
        """
        tpl = Prompt_Dice_REVB
        if number_of_dice == 1:
            cast = tpl._BASE_INSTANTIATE_SING + tpl._OBSERVATION_BASE_PROMPT_SING
        else:
            cast = tpl._BASE_INSTANTIATE_PLUR + tpl._OBSERVATION_BASE_PROMPT_PLUR

        template = (
            tpl._preamble(number_of_dice)
            + cast
            + observed
            + tpl._CONCLUSION_OBSERVATION
        )

        # Half of number_of_dice * (face_per_dice + 1), truncated to an int.
        midvalue = int(0.5 * number_of_dice * (face_per_dice + 1))

        return template.format(
            nb_faces=face_per_dice,
            num_text=NUM_TO_TEXT[number_of_dice],
            mid_val=midvalue,
            number_of_dice=number_of_dice,
        )

    # --- Public builders ----------------------------------------------------
    @staticmethod
    def regular_experiment(number_of_dice, face_per_dice) -> str:
        """Return the prompt for a single throw with no extra information.

        Args:
            number_of_dice: How many dice are thrown; 1 selects the singular
                wording. Must be a key of ``NUM_TO_TEXT``.
            face_per_dice: Number of faces, inserted as ``{nb_faces}``.

        Returns:
            The formatted prompt, ending with the unfinished conclusion.

        Raises:
            KeyError: If ``number_of_dice`` is not a key of ``NUM_TO_TEXT``.
        """
        tpl = Prompt_Dice_REVB
        if number_of_dice == 1:
            ending = tpl._BASE_INSTANTIATE_SING + tpl._BASE_CONCLUSION_SING
        else:
            ending = tpl._BASE_INSTANTIATE_PLUR + tpl._BASE_CONCLUSION_PLUR

        template = tpl._preamble(number_of_dice) + ending
        return template.format(
            nb_faces=face_per_dice, num_text=NUM_TO_TEXT[number_of_dice]
        )

    @staticmethod
    def independant_previous_launch(
        number_of_dice,
        face_per_dice,
        previous_value,
    ) -> str:
        """Return the prompt for a second throw after a reported first throw.

        The closing sentence asks about the new throw alone (face number for
        one die, sum of faces for several).

        Args:
            number_of_dice: How many dice are thrown; must be a key of
                ``NUM_TO_TEXT``.
            face_per_dice: Number of faces, inserted as ``{nb_faces}``.
            previous_value: Result of the first throw, inserted as ``{target}``.

        Returns:
            The formatted prompt, ending with the unfinished conclusion.

        Raises:
            KeyError: If ``number_of_dice`` is not a key of ``NUM_TO_TEXT``.
        """
        tpl = Prompt_Dice_REVB
        conclusion = (
            tpl._BASE_CONCLUSION_SING
            if number_of_dice == 1
            else tpl._BASE_CONCLUSION_PLUR
        )
        return tpl._previous_launch(
            number_of_dice, face_per_dice, previous_value, conclusion
        )

    @staticmethod
    def dependant_previous_launch(number_of_dice, face_per_dice, previous_value) -> str:
        """Return the prompt for a second throw whose result is added to the first.

        Identical to ``independant_previous_launch`` except that the closing
        sentence is about the sum of both results, whatever the dice count.

        Args:
            number_of_dice: How many dice are thrown; must be a key of
                ``NUM_TO_TEXT``.
            face_per_dice: Number of faces, inserted as ``{nb_faces}``.
            previous_value: Result of the first throw, inserted as ``{target}``.

        Returns:
            The formatted prompt, ending with the unfinished conclusion.

        Raises:
            KeyError: If ``number_of_dice`` is not a key of ``NUM_TO_TEXT``.
        """
        tpl = Prompt_Dice_REVB
        return tpl._previous_launch(
            number_of_dice,
            face_per_dice,
            previous_value,
            tpl._CONCLUSION_REPEATED_DEPENDENT,
        )

    @staticmethod
    def one_observation(number_of_dice, face_per_dice, observation) -> str:
        """Return the prompt for a throw followed by one stated observation.

        Args:
            number_of_dice: How many dice are thrown; must be a key of
                ``NUM_TO_TEXT``.
            face_per_dice: Number of faces, inserted as ``{nb_faces}``.
            observation: Key of ``_POSSIBLE_OBSERVATIONS`` selecting the clause.

        Returns:
            The formatted prompt, ending with the unfinished conclusion.

        Raises:
            KeyError: If ``observation`` or ``number_of_dice`` is unknown.
        """
        tpl = Prompt_Dice_REVB
        observed = tpl._POSSIBLE_OBSERVATIONS[observation]
        return tpl._observed_launch(number_of_dice, face_per_dice, observed)

    @staticmethod
    def two_observations(
        number_of_dice, face_per_dice, observation1, observation2
    ) -> str:
        """Return the prompt for a throw followed by two chained observations.

        Args:
            number_of_dice: How many dice are thrown; must be a key of
                ``NUM_TO_TEXT``.
            face_per_dice: Number of faces, inserted as ``{nb_faces}``.
            observation1: Key of ``_POSSIBLE_OBSERVATIONS`` for the first clause.
            observation2: Key of ``_POSSIBLE_OBSERVATIONS`` for the second clause.

        Returns:
            The formatted prompt, ending with the unfinished conclusion.

        Raises:
            KeyError: If an observation or ``number_of_dice`` is unknown.
        """
        tpl = Prompt_Dice_REVB
        observed = (
            # Drop the trailing ". " so the first clause runs into the second.
            tpl._POSSIBLE_OBSERVATIONS[observation1][:-2]
            + tpl._OBSERVATION_COMPLEMENT
            + tpl._POSSIBLE_OBSERVATIONS[observation2]
        )
        return tpl._observed_launch(number_of_dice, face_per_dice, observed)


class Prompt_Dice_StaA:
    """Builders for dice prompts paired with an explicit probability question.

    Every public method returns the 5-tuple
    ``(context, question, choices, values, correct_answer)``:

    * ``context`` is the formatted description of the experiment;
    * ``question`` asks for the probability of ``question_for_value``;
    * ``choices``, ``values`` and ``correct_answer`` are the three items
      returned by ``F.populate_answers_fractions`` for
      ``true_distribution[question_for_value]`` (their exact shape is defined
      in the ``frac`` module, not here).

    The wording switches between singular and plural on
    ``number_of_dice == 1``. ``number_of_dice`` is always looked up in
    ``C.NUM_TO_TEXT``.
    """

    # --- Sentence templates -------------------------------------------------
    _BASE_INTRO_SING = "A die has {nb_faces} faces. "
    _BASE_INTRO_PLUR = (
        # num_text is filled with C.NUM_TO_TEXT[number_of_dice].
        "There are {num_text} {nb_faces}-sided dice. "
    )

    _BASE_PROBABILITY_SING = "The die is equally likely to land on any of its faces. "
    _BASE_PROBABILITY_PLUR = "Each die is equally likely to land on any of its faces. "

    _BASE_INSTANTIATE_SING = "The die is cast. "
    _BASE_INSTANTIATE_PLUR = "The dice are cast. "

    # --- Conclusions and the matching explicit questions --------------------
    _BASE_CONCLUSION_SING = "The die lands on face number"
    _QUESTION_CONCLUSION_SING_INSTRUCT = (
        "What is the probability that the die lands on face {question_value}?"
    )

    _BASE_CONCLUSION_PLUR = "The sum of their faces is equal to"
    _QUESTION_CONCLUSION_PLUR_INSTRUCT = (
        "What is the probability that the sum of their faces is equal to"
        " {question_value}?"
    )

    _CONCLUSION_REPEATED_DEPENDENT = "The sum of both results is equal to"
    _QUESTION_REPEATED_DEPENDENT_INSTRUCT = (
        "What is the probability that the sum of both results is equal to"
        " {question_value}?"
    )

    _CONCLUSION_OBSERVATION = "Indeed, the result is equal to"
    _QUESTION_OBSERVATION_INSTRUCT = (
        "What is the probability that the result is equal to {question_value}?"
    )

    # "Cast + completed conclusion" sentences reporting an earlier throw.
    _PREV_RESULT_PROMPT_SING = (
        _BASE_INSTANTIATE_SING + _BASE_CONCLUSION_SING + " {target}. "
    )
    _PREV_RESULT_PROMPT_PLUR = (
        _BASE_INSTANTIATE_PLUR + _BASE_CONCLUSION_PLUR + " {target}. "
    )

    # Same as the cast sentences, but ending in "... again. ".
    _INSTANTIATE_SING_AGAIN = _BASE_INSTANTIATE_SING.replace(". ", " again. ")
    _INSTANTIATE_PLUR_AGAIN = _BASE_INSTANTIATE_PLUR.replace(". ", " again. ")

    _OBSERVATION_BASE_PROMPT_SING = "We observe that the result is "
    _OBSERVATION_BASE_PROMPT_PLUR = "We observe that the sum of their faces is "
    _OBSERVATION_COMPLEMENT = " and that it is also "

    # Observation clauses keyed by observation name. Every clause ends with
    # ". " so that two_observations can strip those two characters when it
    # chains a second clause.
    _POSSIBLE_OBSERVATIONS = {
        S.observation_names.SMALLER_THAN_MIDDLE: "smaller than {mid_val}. ",
        S.observation_names.LARGER_THAN_MIDDLE: "greater than {mid_val}. ",
        S.observation_names.EVEN: "an even number. ",
        S.observation_names.ODD: "an odd number. ",
        S.observation_names.NOT_MIDDLE: "not equal to {mid_val}. ",
        S.observation_names.NOT_ONE: "not equal to {number_of_dice}. ",
    }

    # --- Private helpers ----------------------------------------------------
    @staticmethod
    def _preamble(number_of_dice):
        """Return the unformatted intro and fairness sentences."""
        tpl = Prompt_Dice_StaA
        if number_of_dice == 1:
            return tpl._BASE_INTRO_SING + tpl._BASE_PROBABILITY_SING
        return tpl._BASE_INTRO_PLUR + tpl._BASE_PROBABILITY_PLUR

    @staticmethod
    def _package(
        template, question, true_answer, number_of_dice, face_per_dice, **fields
    ):
        """Generate the answer options, fill ``template`` and bundle the 5-tuple.

        ``fields`` carries the extra placeholders of ``template`` (``target``,
        ``mid_val``, ``number_of_dice``). No context template contains
        ``{question_value}``, so that value is not passed here.
        """
        # The answer options are produced before the context is formatted, so a
        # failing C.NUM_TO_TEXT lookup still happens after this call.
        choices, values, correct_answer = F.populate_answers_fractions(true_answer)

        context = template.format(
            nb_faces=face_per_dice,
            num_text=C.NUM_TO_TEXT[number_of_dice],
            **fields,
        )
        return context, question, choices, values, correct_answer

    @staticmethod
    def _previous_launch(
        number_of_dice, face_per_dice, previous_value, question, true_answer
    ):
        """Build the 'cast, report result, cast again' context for ``question``."""
        tpl = Prompt_Dice_StaA
        if number_of_dice == 1:
            replay = tpl._PREV_RESULT_PROMPT_SING + tpl._INSTANTIATE_SING_AGAIN
        else:
            replay = tpl._PREV_RESULT_PROMPT_PLUR + tpl._INSTANTIATE_PLUR_AGAIN

        return tpl._package(
            tpl._preamble(number_of_dice) + replay,
            question,
            true_answer,
            number_of_dice,
            face_per_dice,
            target=previous_value,
        )

    @staticmethod
    def _observed_launch(
        number_of_dice, face_per_dice, observed, question, true_answer
    ):
        """Build the 'cast, state the observation' context around ``observed``.

        ``observed`` is the already-assembled observation clause; it may still
        contain the ``{mid_val}`` and ``{number_of_dice}`` placeholders.
        """
        tpl = Prompt_Dice_StaA
        if number_of_dice == 1:
            cast = tpl._BASE_INSTANTIATE_SING + tpl._OBSERVATION_BASE_PROMPT_SING
        else:
            cast = tpl._BASE_INSTANTIATE_PLUR + tpl._OBSERVATION_BASE_PROMPT_PLUR

        # Floor of the midpoint between the smallest sum (all ones) and the
        # largest sum (all dice showing their top face).
        lowest_sum = number_of_dice
        highest_sum = number_of_dice * face_per_dice
        midvalue = (lowest_sum + highest_sum) // 2

        return tpl._package(
            tpl._preamble(number_of_dice) + cast + observed,
            question,
            true_answer,
            number_of_dice,
            face_per_dice,
            mid_val=midvalue,
            number_of_dice=number_of_dice,
        )

    # --- Public builders ----------------------------------------------------
    @staticmethod
    def regular_experiment(
        number_of_dice, face_per_dice, question_for_value, true_distribution
    ):
        """Return context, question and answer options for a single throw.

        Args:
            number_of_dice: How many dice are thrown; 1 selects the singular
                wording. Looked up in ``C.NUM_TO_TEXT``.
            face_per_dice: Number of faces, inserted as ``{nb_faces}``.
            question_for_value: Outcome the question asks about; also the index
                into ``true_distribution``.
            true_distribution: Container indexed by outcome; presumably holds
                the true probability of each outcome (confirm with caller).

        Returns:
            ``(context, question, choices, values, correct_answer)``.
        """
        tpl = Prompt_Dice_StaA
        true_answer = true_distribution[question_for_value]

        if number_of_dice == 1:
            question_template = tpl._QUESTION_CONCLUSION_SING_INSTRUCT
            cast = tpl._BASE_INSTANTIATE_SING
        else:
            question_template = tpl._QUESTION_CONCLUSION_PLUR_INSTRUCT
            cast = tpl._BASE_INSTANTIATE_PLUR
        question = question_template.format(question_value=question_for_value)

        return tpl._package(
            tpl._preamble(number_of_dice) + cast,
            question,
            true_answer,
            number_of_dice,
            face_per_dice,
        )

    @staticmethod
    def independant_previous_launch(
        number_of_dice,
        face_per_dice,
        previous_value,
        question_for_value,
        true_distribution,
    ):
        """Return context, question and options for a throw after a reported one.

        The question is about the new throw alone (face for one die, sum of
        faces for several).

        Args:
            number_of_dice: How many dice are thrown; looked up in
                ``C.NUM_TO_TEXT``.
            face_per_dice: Number of faces, inserted as ``{nb_faces}``.
            previous_value: Result of the first throw, inserted as ``{target}``.
            question_for_value: Outcome the question asks about; also the index
                into ``true_distribution``.
            true_distribution: Container indexed by outcome; presumably holds
                the true probability of each outcome (confirm with caller).

        Returns:
            ``(context, question, choices, values, correct_answer)``.
        """
        tpl = Prompt_Dice_StaA
        true_answer = true_distribution[question_for_value]

        question_template = (
            tpl._QUESTION_CONCLUSION_SING_INSTRUCT
            if number_of_dice == 1
            else tpl._QUESTION_CONCLUSION_PLUR_INSTRUCT
        )
        question = question_template.format(question_value=question_for_value)

        return tpl._previous_launch(
            number_of_dice, face_per_dice, previous_value, question, true_answer
        )

    @staticmethod
    def dependant_previous_launch(
        number_of_dice,
        face_per_dice,
        previous_value,
        question_for_value,
        true_distribution,
    ):
        """Return context, question and options when both throws are summed.

        Same context as ``independant_previous_launch``; the question is about
        the sum of both results, whatever the dice count.

        Args:
            number_of_dice: How many dice are thrown; looked up in
                ``C.NUM_TO_TEXT``.
            face_per_dice: Number of faces, inserted as ``{nb_faces}``.
            previous_value: Result of the first throw, inserted as ``{target}``.
            question_for_value: Outcome the question asks about; also the index
                into ``true_distribution``.
            true_distribution: Container indexed by outcome; presumably holds
                the true probability of each outcome (confirm with caller).

        Returns:
            ``(context, question, choices, values, correct_answer)``.
        """
        tpl = Prompt_Dice_StaA
        true_answer = true_distribution[question_for_value]

        question = tpl._QUESTION_REPEATED_DEPENDENT_INSTRUCT.format(
            question_value=question_for_value
        )

        return tpl._previous_launch(
            number_of_dice, face_per_dice, previous_value, question, true_answer
        )

    # NOTE(review): the original code carried the remark "only argmax for this
    # method" here; its meaning is not evident from this file -- confirm.
    @staticmethod
    def one_observation(
        number_of_dice,
        face_per_dice,
        observation,
        question_for_value,
        true_distribution,
    ):
        """Return context, question and options for a throw with one observation.

        Args:
            number_of_dice: How many dice are thrown; looked up in
                ``C.NUM_TO_TEXT``.
            face_per_dice: Number of faces, inserted as ``{nb_faces}``.
            observation: Key of ``_POSSIBLE_OBSERVATIONS`` selecting the clause.
            question_for_value: Outcome the question asks about; also the index
                into ``true_distribution``.
            true_distribution: Container indexed by outcome; presumably holds
                the true probability of each outcome (confirm with caller).

        Returns:
            ``(context, question, choices, values, correct_answer)``.

        Raises:
            KeyError: If ``observation`` is not in ``_POSSIBLE_OBSERVATIONS``.
        """
        tpl = Prompt_Dice_StaA
        true_answer = true_distribution[question_for_value]

        question = tpl._QUESTION_OBSERVATION_INSTRUCT.format(
            question_value=question_for_value
        )

        observed = tpl._POSSIBLE_OBSERVATIONS[observation]

        return tpl._observed_launch(
            number_of_dice, face_per_dice, observed, question, true_answer
        )

    @staticmethod
    def two_observations(
        number_of_dice,
        face_per_dice,
        observation1,
        observation2,
        question_for_value,
        true_distribution,
    ):
        """Return context, question and options for a throw with two observations.

        Args:
            number_of_dice: How many dice are thrown; looked up in
                ``C.NUM_TO_TEXT``.
            face_per_dice: Number of faces, inserted as ``{nb_faces}``.
            observation1: Key of ``_POSSIBLE_OBSERVATIONS`` for the first clause.
            observation2: Key of ``_POSSIBLE_OBSERVATIONS`` for the second clause.
            question_for_value: Outcome the question asks about; also the index
                into ``true_distribution``.
            true_distribution: Container indexed by outcome; presumably holds
                the true probability of each outcome (confirm with caller).

        Returns:
            ``(context, question, choices, values, correct_answer)``.

        Raises:
            KeyError: If an observation is not in ``_POSSIBLE_OBSERVATIONS``.
        """
        tpl = Prompt_Dice_StaA
        true_answer = true_distribution[question_for_value]

        question = tpl._QUESTION_OBSERVATION_INSTRUCT.format(
            question_value=question_for_value
        )

        observed = (
            # Drop the trailing ". " so the first clause runs into the second.
            tpl._POSSIBLE_OBSERVATIONS[observation1][:-2]
            + tpl._OBSERVATION_COMPLEMENT
            + tpl._POSSIBLE_OBSERVATIONS[observation2]
        )

        return tpl._observed_launch(
            number_of_dice, face_per_dice, observed, question, true_answer
        )
