from typing import List, Optional

from llm_mcts.data_types import Action
from llm_mcts.llm_generation_interface import GenerationRequest, GenerationResult
from llm_mcts.mcts_algo.eval_result import EvalResult, EvalResultWithAns
from llm_mcts.prompt_configs import PromptConfig
from llm_mcts.prompts.base import PromptTemplate
from llm_mcts.tasks.code_contest.task import CodeContestTask
from llm_mcts.tasks.code_contest.problem import CodeContestProblem


class CodeContestBaselinePrompt(PromptTemplate):
    """Baseline prompt template for code-contest tasks.

    Produces the opening problem prompt, converts evaluation results into
    feedback text, and appends the follow-up instruction to the latest user
    message.
    """

    def __init__(self, prompt_config: PromptConfig, task: CodeContestTask):
        # prompt_config is accepted but not stored; only the task is kept.
        self.task = task

    def initial_prompt(self) -> str:
        """Return the fixed instructions followed by the problem description."""
        # based on the large language monkey's prompt
        # https://github.com/ScalingIntelligence/large_language_monkeys/blob/main/llmonk/generate/code_contests.py#L53
        instructions = """
Write python code to solve the following coding problem that obeys the constraints and passes the example test cases.

The output code needs to read from and write to standard IO. Please wrap your code answer using ```python and ```.

"""
        return instructions + self.task.problem.description

    def feedback_prompt(
        self,
        action: Action,
        eval_results: Optional[List[EvalResult]],
        generation_result: GenerationResult,
    ) -> str:
        """Return feedback text for the previous generation.

        Returns a fixed notice when ``eval_results`` is ``None``; otherwise the
        per-test report from ``answer_feedback_prompt`` for the ``"answer"``
        action. ``generation_result`` is not used here.

        Raises:
            NotImplementedError: For any action other than ``"answer"``.
        """
        if eval_results is None:
            return "Your answer doesn't include any code."

        if action == "answer":
            return answer_feedback_prompt(
                problem=self.task.problem,
                eval_results=eval_results,
            )

        raise NotImplementedError(
            f"feedback_prompt not implemented for action {action}"
        )

    def add_next_action_instruction(
        self, action: Action, next_prompt: GenerationRequest
    ) -> GenerationRequest:
        """Append the next-action instruction to the final user message.

        ``next_prompt`` is modified in place and also returned. The final
        message must have role ``"user"``.
        """
        final_msg = next_prompt.messages[-1]
        assert final_msg.role == "user"

        # Any assistant message means this is not the first turn. Counting
        # messages is unreliable: with image prompts the first turn may hold
        # more than one message.
        has_assistant_reply = any(
            message.role == "assistant" for message in next_prompt.messages
        )

        final_msg.content += next_task_prompt(
            action, is_first_turn=not has_assistant_reply
        )
        return next_prompt


def answer_feedback_prompt(
    problem: CodeContestProblem, eval_results: List[EvalResultWithAns]
) -> str:
    """Render per-test feedback and a summary for a submitted solution.

    Args:
        problem: Problem whose ``public_tests["input"]`` and
            ``public_tests["output"]`` are indexed by result position to show
            a failing test's input and expected output.
        eval_results: Evaluation results, paired with the public tests by
            position. Each must provide ``get_score()`` (a score equal to 1
            counts as passed) and, for failed tests, ``answer`` (shown as the
            produced output).

    Returns:
        Text with one ``## Public test i`` section per result followed by a
        ``# Summary`` section. When every result passed (including the case of
        an empty ``eval_results``) the summary reports all tests correct;
        otherwise it reports the pass count out of the total.
    """
    sections = []
    num_correct = 0
    for i, eval_result in enumerate(eval_results):
        sections.append(f"## Public test {i}\n\n")
        # Use truthiness rather than ``is True``: comparing a non-builtin
        # score (e.g. a NumPy scalar) with ``==`` may yield a truthy object
        # that is not the ``True`` singleton, which would otherwise be
        # reported as a failure.
        if eval_result.get_score() == 1:
            sections.append("Result: Correct\n\n")
            num_correct += 1
        else:
            sections.append(
                "\nResult: Wrong\n\n"
                f"Input:\n{problem.public_tests['input'][i]}\n"
                f"Expected Output:\n{problem.public_tests['output'][i]}\n"
                f"Your Output:\n{eval_result.answer}\n\n"
            )

    total = len(eval_results)
    if num_correct == total:
        sections.append(
            "# Summary\n\nYour solution is correct for all the public test cases!\n\n"
        )
    else:
        sections.append(
            f"# Summary\n\nYour solution is correct for {num_correct} public test cases among {total}!\n\n"
        )

    return "".join(sections)


def next_task_prompt(kind: Action, is_first_turn: bool) -> str:
    """Return the instruction text appended to the next user message.

    Args:
        kind: Action to be taken next; only ``"answer"`` is supported.
        is_first_turn: Whether no feedback has been given yet.

    Returns:
        An empty string on the first turn; otherwise a request to revise the
        solution based on the feedback and to make it more efficient.

    Raises:
        NotImplementedError: If ``kind`` is not ``"answer"``.
    """
    if kind != "answer":
        # Name the offending action so the failure is diagnosable.
        raise NotImplementedError(
            f"next_task_prompt not implemented for action {kind}"
        )
    if is_first_turn:
        return ""
    return """
Given the above feedback, carefully revise your solution by revisiting your reasoning and thinking to answer the given coding problem.

Even if your code produces the correct answer, strive to optimize it further by considering ways to make it faster and more efficient in execution.
"""
