from collections import defaultdict
from typing import List, Optional

from llm_mcts.data_types import Action
from llm_mcts.llm_generation_interface import GenerationRequest, GenerationResult
from llm_mcts.mcts_algo.eval_result import EvalResultWithAns
from llm_mcts.prompt_configs import PromptConfig
from llm_mcts.prompts.arc.grid_repr import list_format
from llm_mcts.prompts.base import PromptTemplate
from llm_mcts.tasks.arc.task import ARCProblem


class KouV1Prompt(PromptTemplate):
    version = "kou_v1"

    def __init__(self, prompt_config: PromptConfig, problem: ARCProblem):
        self.is_o1 = prompt_config.is_o1
        self.problem = problem

    def initial_prompt(self) -> str:
        prompt = initial_prompt(self.is_o1)
        prompt += problem_prompt(self.problem)
        return prompt

    def feedback_prompt(
        self,
        action: Action,
        eval_results: Optional[List[EvalResultWithAns]],
        generation_result: GenerationResult,
    ) -> str:
        llm_answer_python_code = generation_result.parse_python_code()
        if eval_results is None:
            return "Your previous code didn't work as expected due to error or invalid format."

        match action:
            case "transform":
                return transform_feedback_prompt(
                    problem=self.problem,
                    eval_results=eval_results,
                    pycode=llm_answer_python_code,
                )
            case "question":
                return question_feedback_prompt(eval_results=eval_results)
            case "multi_questions":
                prompt = ""

                fn_name_to_evs = defaultdict(list)
                for eval_result in eval_results:
                    assert isinstance(eval_result, dict)
                    for fn_name, is_correct in eval_result.items():
                        fn_name_to_evs[fn_name].append((is_correct, None))

                for fn_name, evs in fn_name_to_evs.items():
                    prompt += f"# Evaluation Results for function {fn_name}\n\n"
                    prompt += question_feedback_prompt(evs, fn_name)
                return prompt
            case _:
                raise NotImplementedError(
                    f"feedback_prompt not implemented for action {action}"
                )

    def add_next_action_instruction(
        self, action: Action, next_prompt: GenerationRequest
    ) -> GenerationRequest:
        last_user_msg = next_prompt.messages[-1]
        assert last_user_msg.role == "user"

        # If the prompt.messages contains assistant message, it means that the prompt is not the first turn
        # len(prompt.messages) == 1 is not restricted to the first turn when the messages contain image prompts
        is_first_turn = True
        for msg in next_prompt.messages:
            if msg.role == "assistant":
                is_first_turn = False
                break

        last_user_msg.content += next_task_prompt(action, is_first_turn=is_first_turn)
        return next_prompt


def problem_prompt(problem: ARCProblem) -> str:
    """Render the demonstration pairs and test inputs of `problem` as text.

    Every entry of `problem.demos` becomes a numbered "# Example N" section
    with its input and output grids; every entry of `problem.tests` becomes a
    numbered "# Additional Input N" section with its input grid. Numbering
    starts at 1. Grids are formatted with `list_format`.
    """
    sections = []

    for number, demo in enumerate(problem.demos, start=1):
        demo_input = list_format(demo["input"])
        demo_output = list_format(demo["output"])
        sections.append(
            f"\n# Example {number}\n\n"
            f"## Input\n{demo_input}\n\n"
            f"## Output\n{demo_output}\n\n"
        )

    for number, test in enumerate(problem.tests, start=1):
        test_input = list_format(test["input"])
        sections.append(f"\n# Additional Input {number}\n{test_input}\n\n")

    return "".join(sections)


def initial_prompt(is_o1: bool) -> str:
    """Return the fixed task instructions that open the conversation.

    The text is the task explanation, followed by the reasoning/code-format
    instruction, followed by general guidance. When `is_o1` is true a shorter
    variant is produced: the `<reasoning>` tag request is replaced by a plain
    "write your reasoning" sentence, and the paragraphs about asking questions
    and about the reasoning style are omitted.
    """
    task_explanation = """
You will be given some number of paired example inputs and outputs. The outputs were produced by applying a transformation rule to the inputs. In addition to the paired example inputs and outputs, there is also one additional input without a known output. Your task is to determine the transformation rule and implement it in code.

The inputs and outputs are each "grids". A grid is a rectangular matrix of integers between 0 and 9 (inclusive). These grids will be shown to you as grids of numbers (ASCII). Each number corresponds to a color. The correspondence is as follows: black: 0, blue: 1, red: 2, green: 3, yellow: 4, grey: 5, pink: 6, orange: 7, purple: 8, brown: 9.

The transformation only needs to be unambiguous and applicable to the example inputs and the additional input. It doesn't need to work for all possible inputs.
"""

    # Pieces shared verbatim by both variants.
    code_format = "code in triple backticks (```python and then ```). You should write a function called `transform` which takes a single argument, the input grid as `list[list[int]]`, and returns the transformed grid (also as `list[list[int]]`). You should make sure that you implement a version of the transformation which works in general (it shouldn't just work for the additional input).\n"
    no_tests = """
Don't write tests in your python code, just output the `transform` function. (It will be tested later.)
"""
    no_hardcoding = """
You are creative and accomplished at solving puzzles. When you write `transform`, do not hardcode the solution for each example. We will run your transform function on additional inputs later and check if your logic is generic in addition to check the correctness.
"""

    if is_o1:
        reasoning_lead = "\nAt the beginning, write your reasoning and details.\n\nAfter that, write "
        # The o1 variant carries no question / reasoning-style paragraphs.
        style_guidance = ""
    else:
        reasoning_lead = "\nYou'll need to carefully reason in order to determine the transformation rule. Start your response by carefully reasoning in <reasoning></reasoning> tags. Then, implement the transformation in code.\n\nAfter your reasoning write "
        style_guidance = """
You can also ask question to verify your observation on the inputs/outputs patterns in the form of python function which takes two arguments, the input and expected output grid both as `list[list[int]]` and returns the boolean flag (True or False). We will help you by running your Python function on examples and let you know whether your question is True or False.

You follow a particular reasoning style. You break down complex problems into smaller parts and reason through them step by step, arriving at sub-conclusions before stating an overall conclusion. This reduces the extent to which you need to do large leaps of reasoning.

You reason in substantial detail for as is necessary to determine the transformation rule.
"""

    return "".join(
        [
            task_explanation,
            reasoning_lead,
            code_format,
            no_tests,
            style_guidance,
            no_hardcoding,
        ]
    )


def transform_feedback_prompt(
    problem: ARCProblem, eval_results: List[EvalResultWithAns], pycode: Optional[str]
) -> str:
    """Build feedback for a `transform` attempt.

    For each evaluation result a "# Example N" section reports whether the
    result was correct (score equal to 1.0); for a wrong result the produced
    grid and the expected grid of the matching demo are shown. A summary
    follows. Finally, when `pycode` is available, it is run on the problem's
    test inputs via `problem.run_transform_on_tests` and the outputs are
    listed.

    Args:
        problem: The ARC problem; `demos[i]["output"]` supplies the expected
            grid for the i-th evaluation result.
        eval_results: Results of running the transform on the demos, in demo
            order.
        pycode: The parsed `transform` source, or `None` if it was malformed.

    Returns:
        The feedback prompt text.
    """
    parts = []
    num_correct = 0

    # Sections are numbered from 1 so that "# Example N" here refers to the
    # same item as "# Example N" in the text produced by `problem_prompt`.
    for number, eval_result in enumerate(eval_results, start=1):
        parts.append(f"# Example {number}\n\n")
        if eval_result.get_score() == 1.0:
            parts.append("Result: Correct\n\n")
            num_correct += 1
        else:
            # NOTE(review): `answer` is passed to list_format unguarded, as
            # before — confirm it cannot be None for a wrong demo result.
            produced = list_format(eval_result.answer)
            expected = list_format(problem.demos[number - 1]["output"])
            parts.append(
                "\nResult: Wrong\n\n"
                f"Your Output:\n{produced}\n"
                f"Expected Output:\n{expected}\n\n"
            )

    if num_correct == len(eval_results):
        parts.append("# Summary\n\nYour solution is correct for all the problems!\n\n")
    else:
        parts.append(
            f"# Summary\n\nYour solution is correct for {num_correct} problems among {len(eval_results)}!\n\n"
        )

    # We also show the transform function's result on the additional inputs.
    if pycode is None:
        parts.append(
            "Your `transform` function was malformed, so please fix it accordingly.\n\n"
        )
    else:
        parts.append(
            "Also, here are the outputs of your `transform` function on additional inputs. Please check if your `transform` worked on additional inputs as intended, and correct your mistake in your next turns.\n\n"
        )
        # Numbered from 1 to match "# Additional Input N" in `problem_prompt`.
        for number, test_result in enumerate(
            problem.run_transform_on_tests(pycode), start=1
        ):
            output = test_result.answer
            parts.append(f"# Transformed output on Additional Input {number}\n\n")
            if output is None:
                parts.append(
                    f"Your `transform` function is invalid for Additional Input {number}\n\n"
                )
            else:
                parts.append(f"\n{list_format(output)}\n\n")

    return "".join(parts)


def next_task_prompt(kind: Action, is_first_turn: bool) -> str:
    first_line = (
        "Given the above result, reflect what was correct and/or wrong with your understanding and correct it accordingly inside <reflection></reflection> block, and w"
        if not is_first_turn
        else "W"
    )

    if kind == "question":
        return (
            f"{first_line}"
            + "rite a new question Python function which takes two arguments, input grid and expected output grid of the given problem, inside code block surrounded by ```python and ```.\n"
            "We will run your Python function and tell you its results on examples, so please use this opportunity to deepen or verify your understanding and insights.\n"
            "Your final objective is to eventually write the correct `transform` function and a question is just an aid for that, so keep that in mind when you write the questions.\n"
            "Your python code block should only include your question function, and not the input and output examples or print statements.\n"
        )
    elif kind == "transform":
        return (
            f"{first_line}"
            + "rite your reasoning and details inside <reasoning></reasoning> block, and then write a new transform Python function which takes input grid as an argument inside code block surrounded by ```python and ```.\n"
            "Also, be careful to find pattern from example input and output and try to generalize it to additional inputs. "
            "DO NOT hardcode output into your `transform` function and return it for each example. Please remember that your task is to identify general transform pattern from examples.\n"
        )
    elif kind == "multi_questions":
        return (
            f"{first_line}"
            + "rite mutliple new question Python functions which takes two arguments, input grid and expected output grid of the given problem, inside code blocks surrounded by ```python and ```.\n"
            "We will run each of your Python functions and tell you its results on examples, so please use this opportunity to deepen or verify your understanding and insights.\n"
            "You can write multiple question functions, but put them in a separate python code block. For each code block, you can include only a single Python function.\n"
            "Your final objective is to eventually write the correct `transform` function and questions are just aid for that, so keep that in mind when you write the questions.\n"
            "Your python code block should only include your question function, and not the input and output examples or print statements.\n"
        )
    else:
        raise NotImplementedError()


def question_feedback_prompt(
    eval_results: List[EvalResultWithAns], fn_name: Optional[str] = None
) -> str:
    """Report what a question function returned on each problem.

    One line is emitted per evaluation result (True when its score equals
    1.0, False otherwise), followed by a summary line.

    Args:
        eval_results: Results of running the question function, one per
            problem, in order.
        fn_name: Optional name of the question function, included in the text
            when given.

    Returns:
        The feedback prompt text.
    """
    func_name = "function" if fn_name is None else f"function {fn_name}"

    outcomes = [eval_result.get_score() == 1.0 for eval_result in eval_results]

    # Problems are numbered from 1 so the numbers line up with the
    # "# Example N" headings that `problem_prompt` shows to the model.
    lines = [
        f"Your question {func_name} returned {'True' if is_true else 'False'} for problem {number}.\n\n"
        for number, is_true in enumerate(outcomes, start=1)
    ]

    num_true = sum(outcomes)
    num_false = len(outcomes) - num_true
    if num_false == 0:
        lines.append(f"Your question {func_name} returned True for all the problems.\n\n")
    elif num_true == 0:
        lines.append(f"Your question {func_name} returned False for all the problems.\n\n")
    else:
        lines.append(
            f"Your question {func_name} returned True for {num_true} problems and returned False for {num_false}.\n\n"
        )

    return "".join(lines)
