from typing import List, Optional

from llm_mcts.data_types import Action
from llm_mcts.llm_generation_interface import GenerationRequest, GenerationResult
from llm_mcts.mcts_algo.eval_result import EvalResult, EvalResultWithAns
from llm_mcts.prompt_configs import PromptConfig
from llm_mcts.prompts.base import PromptTemplate
from llm_mcts.tasks.code_contest.task import CodeContestTask
from llm_mcts.tasks.code_contest.problem import CodeContestProblem


class CodeContestCOTV5SingleTurnPrompt(PromptTemplate):
    def __init__(self, prompt_config: PromptConfig, task: CodeContestTask):
        self.task = task

    def initial_prompt(self) -> str:
        # based on the large language monkey's prompt
        # https://github.com/ScalingIntelligence/large_language_monkeys/blob/main/llmonk/generate/code_contests.py#L53
        prompt = """
Write python code to solve the following coding problem that obeys the constraints and passes the example test cases.

The output code needs to read from and write to standard IO. Please wrap your code answer using ```python and ```.

"""
        prompt += self.task.problem.description
        return prompt

    def feedback_prompt(
        self,
        action: Action,
        eval_results: Optional[List[EvalResult]],
        generation_result: GenerationResult,
    ) -> str:
        try:
            code = generation_result.parse_python_code()
        except:
            code = ""
        match action:
            case "answer":
                return answer_feedback_prompt(
                    problem=self.task.problem,
                    eval_results=eval_results,
                    code=code,
                )
            case _:
                raise NotImplementedError(
                    f"feedback_prompt not implemented for action {action}"
                )

    def add_next_action_instruction(
        self, action: Action, next_prompt: GenerationRequest
    ) -> GenerationRequest:
        last_user_msg = next_prompt.messages[-1]
        assert last_user_msg.role == "user"

        # If the prompt.messages contains assistant message, it means that the prompt is not the first turn
        # len(prompt.messages) == 1 is not restricted to the first turn when the messages contain image prompts
        is_first_turn = True
        for msg in next_prompt.messages:
            if msg.role == "assistant":
                is_first_turn = False
                break

        next_prompt.messages = next_prompt.messages[-1:]

        last_user_msg.content += next_task_prompt(action, is_first_turn=is_first_turn)
        return next_prompt


def _format_public_test(
    problem: CodeContestProblem, index: int, output: object, is_correct: bool
) -> str:
    """Render one public test case: header, I/O comparison, and reflection hints.

    Args:
        problem: Supplies ``public_tests["input"][index]`` and
            ``public_tests["output"][index]``.
        index: Position of the test case in ``problem.public_tests``.
        output: The program's actual output, interpolated as-is.
        is_correct: Whether the evaluation scored this test as passed.

    Returns:
        The markdown section for this test case.
    """
    comparison = (
        f"## Public test {index}\n\n"
        "\n"
        f"Input:\n{problem.public_tests['input'][index]}\n"
        f"Expected Output:\n{problem.public_tests['output'][index]}\n"
        f"Your Output:\n{output}\n"
        "\n"
    )
    if is_correct:
        return comparison + (
            "Result: **Correct**\n"
            "Great job on this test case! Let's reflect briefly:\n"
            "- Could there be any special edge cases not covered here?\n"
            "- Is the time or memory complexity optimal for large inputs?\n"
        )
    return comparison + (
        "Result: **Wrong**\n"
        "Reflect step-by-step (Chain of Thought):\n"
        "- Compare the actual output with the expected output.\n"
        "- Identify the logical or implementation point that could cause the discrepancy.\n"
        "- Propose a fix or adjustment to resolve this issue.\n"
        "- Consider how the same fix might affect other test cases or edge scenarios.\n"
        "\n"
    )


def answer_feedback_prompt(
    problem: CodeContestProblem,
    eval_results: Optional[List[EvalResultWithAns]],
    code: str,
) -> str:
    """Build a "fix this code" prompt from public-test evaluation results.

    Args:
        problem: The problem. Its ``description`` is restated and its
            ``public_tests`` are shown alongside each result.
        eval_results: One result per public test, in the same order as
            ``problem.public_tests``. A result counts as passed when its
            ``get_score()`` equals 1. ``None`` means no evaluation is
            available.
        code: The previously generated program. An empty value (or ``None``)
            means no code could be extracted.

    Returns:
        The prompt text. If there is no code or ``eval_results`` is ``None``,
        the prompt ends with a note that the answer contained no code.
        Otherwise it contains the code, a section per public test, and a
        summary. An empty ``eval_results`` list gets the all-correct summary
        (0 of 0 tests failed).
    """
    # Prompt wording is kept verbatim from the existing prompt; any change,
    # typos included, changes what the model sees.
    prompt = """
Write python code to solve the following coding problem that obeys the constraints and passes the example test cases.

There is already some code, but it has some errors and might not be passing the tests.

You will generate a fixed version of the program. You must put the entired fixed program within code delimiters only for once.

The output code needs to read from and write to standard IO. Please wrap your code answer using ```python and ```.

"""
    prompt += problem.description + "\n\n"

    if not code or eval_results is None:
        return prompt + "### Answer: Your answer doesn't include any code."

    parts = [prompt, f"\n## Code to fix\n```python\n{code}\n```\n\n"]

    num_correct = 0
    for index, eval_result in enumerate(eval_results):
        is_correct = eval_result.get_score() == 1
        if is_correct:
            num_correct += 1
        parts.append(
            _format_public_test(problem, index, eval_result.answer, is_correct)
        )

    total = len(eval_results)
    if num_correct == total:
        parts.append(
            "# Summary\n\n"
            "Your solution is correct for all the public test cases!\n\n"
            "Even so, think carefully:\n"
            "- Are there any hidden or extreme edge cases not tested publicly?\n"
            "- Could you optimize time or space complexity further?\n"
            "- Is there any refactoring that could simplify or clarify your solution?\n\n"
        )
    else:
        parts.append(
            "# Summary\n\n"
            f"Your solution is correct for {num_correct} public test case(s) "
            f"among {total}.\n\n"
            "Next steps for self-correction:\n"
            "- Revisit each failed test case with your chain-of-thought reasoning.\n"
            "- Identify the source of the error or inefficiency.\n"
            "- Update and refine your solution, then retest.\n\n"
        )

    return "".join(parts)


def next_task_prompt(kind: Action, is_first_turn: bool) -> str:
    """Return the instruction text appended to the next user turn.

    Only the ``"answer"`` action is supported. The first turn gets no extra
    instruction. Later turns get a step-by-step revision checklist that ends
    with a trailing newline.

    Raises:
        NotImplementedError: If ``kind`` is anything other than ``"answer"``.
    """
    if kind == "answer":
        if is_first_turn:
            return ""
        checklist_lines = [
            "Given the above feedback, carefully revise your solution step by step.",
            "1. Reflect on any failed test case: hypothesize the cause.",
            "2. Diagnose the relevant parts of your code or logic.",
            "3. Fix and improve your implementation.",
            "4. Validate again with the given tests (and consider additional tests if possible).",
            "",
            "Even if all tests currently pass, consider:",
            "- Further optimization (time/space complexity, readability, maintainability).",
            "- Potential corner cases that might appear for very large or special inputs.",
        ]
        return "\n".join(checklist_lines) + "\n"
    raise NotImplementedError()
