from typing import List, Optional

from PIL import Image

from llm_mcts.data_types import Action
from llm_mcts.llm_generation_interface import GenerationRequest, GenerationResult
from llm_mcts.mcts_algo.eval_result import EvalResult, EvalResultWithScore
from llm_mcts.prompt_configs import PromptConfig
from llm_mcts.prompts.base import PromptTemplate
from llm_mcts.prompts.math_vista.math_vista_code import create_one_query, SHOT_EXAMPLES
from llm_mcts.tasks.math_vista.task import MathVistaTask


class MathVistaOfficialPrompt(PromptTemplate):
    """Prompt template following the official MathVista query format.

    The first turn sends the problem image followed by a zero-shot,
    solution-style query produced by ``create_one_query``. Later turns carry
    evaluator feedback plus a revision instruction. Only the ``"answer"``
    action is supported.
    """

    def __init__(self, prompt_config: PromptConfig, task: MathVistaTask):
        """Bind this template to ``task``.

        Raises:
            AssertionError: if ``prompt_config.with_image`` is falsy, because
                every initial prompt built here includes the problem image.
        """
        self.task = task
        assert prompt_config.with_image, "Math Vista Official Prompt requires an image input"

    def initial_prompt(self) -> List[str | Image.Image]:
        """Return ``[image, query_text]`` for the first turn.

        The image is converted to RGB and placed before the text, mirroring
        the official evaluation code's input order.
        """
        # Zero-shot, solution-style query without caption or OCR context.
        query_text = create_one_query(
            problem=self.task.problem,
            examples=SHOT_EXAMPLES,
            shot_num=0,
            shot_type="solution",
            use_caption=False,
            use_ocr=False,
        )
        # Official code inputs an image first: https://github.com/lupantech/MathVista/blob/99fa993d4e3f659f8d93b7786502e9109e94d273/models/gpt.py#L32
        rgb_image = self.task.problem.decoded_image.convert("RGB")
        return [rgb_image, query_text]

    def feedback_prompt(
        self,
        action: Action,
        eval_results: Optional[List[EvalResult]],
        generation_result: GenerationResult,
    ) -> str:
        """Render evaluator feedback for ``action`` as user-message text.

        ``generation_result`` is accepted for interface compatibility but is
        not used by this template.

        Raises:
            NotImplementedError: for any action other than ``"answer"``.
        """
        if action != "answer":
            raise NotImplementedError(
                f"feedback_prompt not implemented for action {action}"
            )
        return answer_feedback_prompt(eval_results=eval_results)

    def add_next_action_instruction(
        self, action: Action, next_prompt: GenerationRequest
    ) -> GenerationRequest:
        """Append the instruction for ``action`` to the final user message.

        ``next_prompt`` is modified in place (its last message's ``content``
        is extended) and then returned.
        """
        final_msg = next_prompt.messages[-1]
        assert final_msg.role == "user"

        # Any earlier assistant reply means this is a follow-up turn. Message
        # count alone is not a reliable first-turn signal once image prompts
        # are involved.
        first_turn = not any(
            message.role == "assistant" for message in next_prompt.messages
        )

        final_msg.content += next_task_prompt(action, is_first_turn=first_turn)
        return next_prompt


def answer_feedback_prompt(eval_results: "List[EvalResultWithScore]") -> str:
    """Format a single evaluator verdict as feedback text for the next turn.

    Args:
        eval_results: A sequence holding exactly one evaluation result. Its
            ``score`` and ``reason`` attributes are interpolated into the text.

    Returns:
        A newline-delimited block reporting the score and the reason.

    Raises:
        ValueError: if ``eval_results`` is ``None`` or does not contain exactly
            one result. This is an explicit check rather than ``assert`` so it
            still runs under ``python -O``.
    """
    # The caller may hand over ``None`` (its parameter is Optional), so guard
    # against that before calling len().
    if eval_results is None or len(eval_results) != 1:
        got = "None" if eval_results is None else len(eval_results)
        raise ValueError(
            f"answer_feedback_prompt expects exactly one eval result, got {got}"
        )

    result = eval_results[0]
    prompt = f"""
Another AI assistant carefully evaluated your solution and gave a score to your solution by an integer from 0 to 5:
Score:
{result.score}
Reason:
{result.reason}
"""
    return prompt


def next_task_prompt(kind: "Action", is_first_turn: bool) -> str:
    """Return the instruction text to append to the user message for ``kind``.

    Args:
        kind: The action being requested; only ``"answer"`` is supported.
        is_first_turn: True when this is the first turn of the conversation.

    Returns:
        ``""`` on the first turn (nothing is appended). Otherwise, a request
        to revise the solution in light of the preceding feedback.

    Raises:
        NotImplementedError: for any ``kind`` other than ``"answer"``.
    """
    if kind != "answer":
        # Name the offending action, matching the error raised by
        # MathVistaOfficialPrompt.feedback_prompt.
        raise NotImplementedError(
            f"next_task_prompt not implemented for action {kind}"
        )
    if is_first_turn:
        return ""
    return """
Given the above feedback from another AI assistant, carefully revise your solution by revisiting your reasoning and thinking to answer the given mathematical question.
"""
