
from typing import List, Optional

class MultipleChoiceQuestion:
    """
    The class of multiple choice questions.

    A question holds its stem, the option labels (``option_ids``), the option
    texts (``options``) and, optionally, one correctness flag per option. It can
    render itself as an LLM prompt (``get_prompt``) and as the expected answer
    string (``get_formatted_answer``), and round-trip through a plain dict
    (``to_dict`` / ``load_dict``).
    """

    def __init__(
        self,
        question: str = "",
        option_ids: Optional[List[str]] = None,
        options: Optional[List[str]] = None,
        correct: Optional[List[bool]] = None,
        question_first: bool = True,
        text_type: str = "choice",
        trigger_statement: Optional[str] = None,
    ):
        """
        Args:
            question: the question text
            option_ids: the list of option indices (with formats), e.g., ['(A)', '(B)', '(C)', '(D)'].
                Defaults to a new empty list.
            options: the option contents, each of which appears after the corresponding option index.
                Defaults to a new empty list.
            correct: the list indicating the correctness of each option. E.g., [True, False, False, True]
                indicates that only the first and the last options are correct answers.
            question_first: the switch controlling whether the question appears before the options.
            text_type: 'choice' or 'judgement', indicating whether the question text appears as a multiple
                choice question or a multiple judgement question
            trigger_statement: optional name of a reasoning strategy. When set and ``text_type`` is
                'choice', the prompt asks the model to follow that strategy and uses a
                per-option-reasoning output format.

        Raises:
            AssertionError: if ``option_ids`` and ``options`` (and ``correct``, when given) differ in
                length, or ``text_type`` is neither 'choice' nor 'judgement'.
        """
        # Build fresh lists per instance: a shared mutable default (``[]``) would
        # make every default-constructed question alias the same list objects.
        if option_ids is None:
            option_ids = []
        if options is None:
            options = []
        self.question = question
        self.option_ids = option_ids
        self.options = options
        self.question_first = question_first
        self.correct = correct
        self.text_type = text_type
        self.trigger_statement = trigger_statement
        assert len(option_ids) == len(options)
        if self.correct is not None:
            assert len(option_ids) == len(self.correct)
        assert text_type == "choice" or text_type == "judgement"

    def add_trigger_statement(self, trigger_statement: Optional[str]) -> None:
        """Set (or replace) the reasoning-strategy name used by ``get_prompt``."""
        self.trigger_statement = trigger_statement

    def to_dict(self) -> dict:
        """Return the dictionary form of the current question.

        The lists are returned by reference, not copied.
        """
        result = {
            "question": self.question,
            "option_ids": self.option_ids,
            "options": self.options,
            "question_first": self.question_first,
            "correct": self.correct,
            "text_type": self.text_type,
            "trigger_statement": self.trigger_statement,
        }
        return result

    def load_dict(self, data: dict) -> None:
        """Load information of the question from a dictionary.

        ``data`` must contain the keys produced by ``to_dict``; ``trigger_statement``
        is optional (missing means ``None``). Validation is the same as ``__init__``.
        """
        trigger_statement = data.get("trigger_statement", None)
        self.__init__(
            question=data["question"],
            option_ids=data["option_ids"],
            options=data["options"],
            question_first=data["question_first"],
            correct=data["correct"],
            text_type=data["text_type"],
            trigger_statement=trigger_statement,
        )

    def _options_block(self) -> str:
        """Render the 'Options:' header followed by one '<id> <text>' line per option."""
        lines = "".join(f"{key} {value}\n" for key, value in zip(self.option_ids, self.options))
        return "Options:\n" + lines

    def __str__(self) -> str:
        """Human-readable dump of all fields (question is always shown first)."""
        assert len(self.option_ids) == len(self.options)
        prompt = f"Question: {self.question}\n" + self._options_block()
        prompt += f"Answer:{self.correct}\n"
        prompt += f"question_first:{self.question_first}\n"
        prompt += f"text_type:{self.text_type}\n"
        prompt += f"trigger_statement:{self.trigger_statement}"
        return prompt

    def get_prompt(self) -> str:
        """Get the prompt of the current question for LLMs.

        The prompt is: an instruction prefix, the question and the options (in the
        order given by ``question_first``), an output-format specification, and a
        final "Your output in a single line:" cue.

        Raises:
            AssertionError: if the option lists differ in length or ``text_type`` is invalid.
            ValueError: if ``text_type`` is invalid and asserts are disabled (``python -O``).
        """
        assert len(self.option_ids) == len(self.options)
        assert self.text_type == "choice" or self.text_type == "judgement"
        if self.text_type == "choice" and self.trigger_statement:
            prefix = f"Please select the correct option(s) from the following options given the question. To solve the problem, follow the {self.trigger_statement} reasoning strategy.\n"
        elif self.text_type == "choice":
            prefix = "Please select the correct option(s) from the following options given the question:\n"
        elif self.text_type == "judgement":
            prefix = "Please judge whether each of the options is correct or incorrect given the question:\n"
        else:
            # Only reachable when asserts are stripped; fail clearly instead of
            # raising UnboundLocalError on ``prefix`` further down.
            raise ValueError(f"Unsupported text_type: {self.text_type!r}")

        prompt = self._options_block()
        if self.question_first:
            prompt = f"Question: {self.question}\n" + prompt
        else:
            prompt = prompt + f"Question: {self.question}\n"

        if self.text_type == "choice":
            option_or = ", ".join([f'"{option}"' for option in self.option_ids])
            if self.trigger_statement:
                # With a reasoning strategy, the model reasons per option before
                # giving the list of selected options.
                prompt += '{"option_a": "<Reasoning for option a>", "option_b": "<Reasoning for option b>", ..., "answer": <the list of selected option, e.g., ["A", "B", "C", "D", "E"]>}'
            else:
                prompt += (
                    'Your output must strictly follow this format:\n{"answer": <the list of selected options, e.g., [%s]>}\n'
                    % option_or
                )
        else:  # "judgement"
            output_fmt = ", ".join(
                [f'"{option}": <"True" or "False">' for option in self.option_ids]
            )
            output_fmt = "{" + output_fmt + "}"
            prompt += f"Your output must strictly follow this format:\n{output_fmt}\n"
        prompt = prefix + prompt
        prompt += "\nYour output in a single line:"
        return prompt

    def get_formatted_answer(self) -> Optional[str]:
        """Return the reference answer in the same shape the prompt asks for.

        For 'choice': ``{"answer": ["<id>", ...]}`` listing the ids whose flag is truthy.
        For 'judgement': ``{"<id>":"<flag>", ...}`` with every option's flag.
        Returns ``None`` for any other ``text_type``. Requires ``self.correct`` to be set.
        """
        result = None
        if self.text_type == "choice":
            answers = [
                option for option, correct in zip(self.option_ids, self.correct) if correct
            ]
            content = ", ".join(['"' + option + '"' for option in answers])
            result = f'"answer": [{content}]'
            result = "{" + result + "}"
        elif self.text_type == "judgement":
            result = ", ".join(
                [
                    f'"{option}":"{correct}"'
                    for option, correct in zip(self.option_ids, self.correct)
                ]
            )
            result = "{" + result + "}"
        return result