import json
import os
from typing import List, Optional, Tuple

from PIL import Image

from llm_mcts.data_types import Action
from llm_mcts.llm_generation_interface import GenerationRequest, GenerationResult
from llm_mcts.mcts_algo.eval_result import EvalResultWithAns
from llm_mcts.prompt_configs import PromptConfig
from llm_mcts.prompts.arc.grid_repr import list_format
from llm_mcts.prompts.arc.special_prompt import IMAGE_TOKEN
from llm_mcts.prompts.base import PromptTemplate
from llm_mcts.tasks.arc.task import ARCProblem
from llm_mcts.tasks.arc.visualize_arc import plot_abstract_reasoning


class YuichiPrompt(PromptTemplate):
    """ARC prompt template built around the "yuichi" prompt family.

    Produces the opening prompt (system preamble + demonstrations + hint),
    feedback text after a ``transform`` attempt has been evaluated, and the
    per-turn task instructions appended to the latest user message.
    """

    def __init__(self, prompt_config: PromptConfig, problem: ARCProblem):
        self.prompt_config = prompt_config
        self.problem = problem
        # Mirrors prompt_config.with_image; decides whether demo images are
        # interleaved into the initial prompt.
        self.with_image = prompt_config.with_image

    def initial_prompt(self) -> str | List[str | Image.Image]:
        """Return the first-turn prompt.

        Without images this is a single string. With images it is a list that
        alternates text and images: the text is split at every IMAGE_TOKEN and
        the i-th image is placed where the i-th token was, so the list starts
        and ends with a text chunk.
        """
        body, images = initial_prompt_yuichi(self.prompt_config, self.problem)
        full_text = system_prompt_v1 + "\n\n" + body
        if not self.with_image:
            return full_text

        # Every placeholder must have exactly one image to fill it.
        assert full_text.count(IMAGE_TOKEN) == len(images)
        head, *tails = full_text.split(IMAGE_TOKEN)
        interleaved = [head]
        for image, text_after in zip(images, tails):
            interleaved += [image, text_after]
        return interleaved

    def feedback_prompt(
        self,
        action: Action,
        eval_results: Optional[List[EvalResultWithAns]],
        generation_result: GenerationResult,
    ) -> str:
        """Return feedback text on the model's previous ``transform`` attempt.

        A generic failure message is returned when ``eval_results`` is None;
        otherwise the "transform" action is delegated to
        ``transform_feedback_prompt_yuichi`` with the parsed Python code.

        Raises:
            NotImplementedError: for any action other than "transform" (only
                reached when ``eval_results`` is not None).
        """
        parsed_code = generation_result.parse_python_code()
        if eval_results is None:
            return "Your previous code didn't work as expected due to error or invalid format."
        if action != "transform":
            raise NotImplementedError(
                f"feedback_prompt not implemented for action {action}"
            )
        return transform_feedback_prompt_yuichi(
            problem=self.problem,
            eval_results=eval_results,
            pycode=parsed_code,
        )

    def add_next_action_instruction(
        self, action: Action, next_prompt: GenerationRequest
    ) -> GenerationRequest:
        """Append the task instructions for ``action`` to the last user message.

        The first-turn instructions are used while the conversation contains
        no assistant message; afterwards the feedback-turn instructions are
        used. ``next_prompt`` is modified in place and also returned.

        Raises:
            NotImplementedError: if ``action`` is not "transform".
            ValueError: if the last message content is neither str nor list.
        """
        if action != "transform":
            raise NotImplementedError(
                f"next_action instruction not implemented for {action}"
            )

        last_user_msg = next_prompt.messages[-1]
        assert last_user_msg.role == "user"

        # len(messages) == 1 is not a reliable first-turn test when image
        # prompts are involved, so look for any assistant reply instead.
        has_replied = any(msg.role == "assistant" for msg in next_prompt.messages)
        task_prompt = task_prompt_feedback if has_replied else task_prompt_initial

        content = last_user_msg.content
        if isinstance(content, str):
            last_user_msg.content += task_prompt
        elif isinstance(content, list):
            # Prompts containing images carry a list of parts; add a text part.
            content.append(task_prompt)
        else:
            raise ValueError(f"Unknown message type: {type(content)}")
        return next_prompt


# System preamble for the ARC task. YuichiPrompt.initial_prompt prepends it,
# followed by a blank line ("\n\n"), to the prompt built by
# initial_prompt_yuichi. The text is sent to the model verbatim, including the
# trailing space after "problem." on the first line.
system_prompt_v1 = """You are tasked with solving an ARC (Abstraction and Reasoning Corpus) Challenge problem. 

You will be given some number of paired example inputs and outputs. The outputs were produced by applying a transformation rule to the inputs. Your task is to determine the transformation rule and implement it in code.

The inputs and outputs are each "grids" represented as 2D lists of integers in Python. A grid is a rectangular matrix of integers between 0 and 9 (inclusive). Each number in the grid corresponds to a specific color, as follows: black: 0, blue: 1, red: 2, green: 3, yellow: 4, grey: 5, pink: 6, orange: 7, purple: 8, brown: 9.

The transformation only needs to be unambiguous and applicable to the example inputs and the additional input. It doesn't need to work for all possible inputs.
"""


# Directories searched, highest priority first, for per-task hint files named
# "<task label>.json" (see initial_prompt_yuichi). Public-eval hints come
# before training hints; within each split, experiment exp025 precedes exp020,
# and models keep the order listed below. The training split lists only the
# first two models.
_HINT_ROOT = "./data/transformation_rule"
_HINT_EXPERIMENTS = ("exp025", "exp020")
_HINT_MODELS_PUBLIC_EVAL = (
    "anthropic.claude-3-5-sonnet-20240620-v1:0",
    "gpt-4o-2024-08-06",
    "gemini-1.5-pro-002",
)
_HINT_MODELS_TRAINING = _HINT_MODELS_PUBLIC_EVAL[:2]

HINT_PATH_LIST = [
    # Public eval
    *(
        f"{_HINT_ROOT}/rule_001_public_eval/project007_{exp}_public_eval/{model}/"
        for exp in _HINT_EXPERIMENTS
        for model in _HINT_MODELS_PUBLIC_EVAL
    ),
    # Training
    *(
        f"{_HINT_ROOT}/rule_001/project007_{exp}/{model}/"
        for exp in _HINT_EXPERIMENTS
        for model in _HINT_MODELS_TRAINING
    ),
]

# Body of the "yuichi_pj004_exp205" initial prompt. initial_prompt_yuichi fills
# it with str.format using:
#   demos - the demonstration pairs rendered by make_demos
#   hint  - the first hint found under HINT_PATH_LIST, or "No hint is provided."
# Because of str.format, any literal braces added to this text must be doubled.
initial_prompt_pj004_exp205 = """I will provide you with several training demonstrations. Each demonstration consists of an Input and an Output pair. The Output is generated from the Input based on a common transformation rule across all the pairs.

Identify and describe the common transformation rule that applies to all the provided Input-Output pairs.
Once the rule is identified, implement a Python `transform` function that applies the rule. This function should take the Input as its argument and return the corresponding Output.

## Demonstrations:

{demos}

## Hints of the transformation rule:

{hint}
"""

# Instructions shared verbatim by both task prompts below: how to describe the
# transformation rule, plus the `transform` skeleton the model must fill in.
# A single copy keeps the first-turn and feedback-turn prompts in sync.
# NOTE: several lines end in trailing whitespace that is part of the prompt
# text (after "prohibited." and "expected output", and the 4-space blank lines
# inside the code skeleton); editors that strip trailing whitespace change it.
_TASK_PROMPT_RULE_AND_CODE = """### Transformation Rule:

When describing the transformation rule, consider the following:
- Describe the input pattern and the output pattern.
- What are the specific changes or operations that are applied to the Input to generate the Output?
- Step by step breakdown of the transformation process, all of which are easy to implement in Python code.

[Please describe the function here in as much detail as possible.]

### Python function which can transform the input to the expected output:

When implementing the transform function:

- Break out additional helper functions as needed to handle specific tasks.
- Ensure that the function works without errors for all provided inputs and can generalize to similar cases.
- Step-by-step breakdown of the transformation process based on the description provided.
- Ensure your code is efficient, readable, and well-commented.
- The use of `while` loops is prohibited. 

```python
import numpy as np

# main transform function which can transform the input to the expected output 
def transform(grid_list: list[list[int]]) -> list[list[int]]:
    grid = np.array(grid_list)
    
    # Write your function here
    # Do not use while loops
    
    return grid.tolist()
```

Don't write anything after the python code.
"""

# Appended to the last user message while the conversation has no assistant
# reply yet (see YuichiPrompt.add_next_action_instruction).
task_prompt_initial = (
    """
## Your task

Write a Python function that transforms the input to the output.

First, describe the transformation rule in precise detail, clearly and concisely, so that it can be directly translated into Python code.

Then, implement the transformation rule in Python code which works for all the examples.

"""
    + _TASK_PROMPT_RULE_AND_CODE
)

# Appended to the last user message once the conversation contains an
# assistant reply, i.e. after feedback on a previous attempt.
task_prompt_feedback = (
    """
## Your task

Write a Python function that transforms the input to the output.

First, analyze the results of the previous examples and correct your mistake.

Second, describe the transformation rule again in precise detail, clearly and concisely, so that it can be directly translated into Python code.

Finally, implement the transformation rule in Python code which works for all the examples.

### Analysis of the previous examples:

[Please describe the analysis of the previous examples here in as much detail as possible.]

"""
    + _TASK_PROMPT_RULE_AND_CODE
)


def make_demos(problem: ARCProblem, with_image: bool) -> Tuple[str, List[Image.Image]]:
    """Render ``problem.demos`` as numbered Input/Output prompt sections.

    Each grid is shown inside a ```python fence using its ``str()`` repr with
    a newline inserted after every "]," (so a list-of-lists grid shows one
    row per line).

    When ``with_image`` is true, each pair's text is followed by an
    "Input i Image" and an "Output i Image" section, each holding one
    IMAGE_TOKEN placeholder. The matching images (input, then output) are
    rendered with ``plot_abstract_reasoning`` and returned in placeholder
    order, so the number of tokens equals ``len(images)``.

    Returns:
        ``(text, images)``; ``images`` is empty when ``with_image`` is false.
    """
    sections = []
    images = []
    for idx, pair in enumerate(problem.demos):
        grid_in = pair["input"]
        grid_out = pair["output"]
        shown_in = str(grid_in).replace("],", "],\n")
        shown_out = str(grid_out).replace("],", "],\n")
        sections.append(f"Input {idx}:\n```python\n{shown_in}\n```\n\n")
        sections.append(f"Output {idx}:\n```python\n{shown_out}\n```\n\n")
        if not with_image:
            continue
        sections.append(f"Input {idx} Image:\n\n{IMAGE_TOKEN}\n\n")
        sections.append(f"Output {idx} Image:\n\n{IMAGE_TOKEN}\n\n")
        images.extend(
            (plot_abstract_reasoning(grid_in), plot_abstract_reasoning(grid_out))
        )
    return "".join(sections), images


def _load_hint(label: str) -> Optional[str]:
    """Return the first non-null transformation-rule hint recorded for ``label``.

    Reads ``<root>/<label>.json`` for each root in ``HINT_PATH_LIST`` (in
    priority order) and returns its ``"transformation_rule"`` value. A missing
    file, a null value, or an absent key all fall through to the next root.
    Returns ``None`` when no root yields a hint.
    """
    for hint_root in HINT_PATH_LIST:
        hint_path = os.path.join(hint_root, f"{label}.json")
        try:
            # JSON text is UTF-8 by specification; don't depend on the
            # platform's locale-default encoding.
            with open(hint_path, "r", encoding="utf-8") as f:
                record = json.load(f)
        except FileNotFoundError:
            print(f"# Hint file not found: {hint_path}")
            continue
        hint = record.get("transformation_rule")
        if hint is not None:
            return hint
    return None


def initial_prompt_yuichi(
    prompt_config: PromptConfig, problem: ARCProblem
) -> Tuple[str, List[Image.Image]]:
    """Build the task-specific part of the first-turn prompt.

    Only the "yuichi_pj004_exp205" prompt type is supported. It renders the
    demonstrations with ``make_demos`` and fills ``initial_prompt_pj004_exp205``
    with them and a transformation-rule hint loaded from ``HINT_PATH_LIST``. The
    literal "No hint is provided." is used when no hint is found.

    Returns:
        ``(prompt_text, images)``; ``images`` are the demo images matching the
        IMAGE_TOKEN placeholders in ``prompt_text`` (empty without images).

    Raises:
        ValueError: if ``prompt_config.initial_prompt_type`` is unknown.
    """
    if prompt_config.initial_prompt_type != "yuichi_pj004_exp205":
        raise ValueError(
            f"Unknown initial prompt type: {prompt_config.initial_prompt_type}"
        )

    demos, images = make_demos(problem, prompt_config.with_image)
    hint = _load_hint(problem.label)
    if hint is None:
        print(f"# No hint is provided for {problem.label}")
        hint = "No hint is provided."
    return initial_prompt_pj004_exp205.format(demos=demos, hint=hint), images


def transform_feedback_prompt_yuichi(
    problem: ARCProblem, eval_results: List[EvalResultWithAns], pycode: Optional[str]
) -> str:
    """Build feedback text on a ``transform`` attempt.

    Args:
        problem: The ARC problem; ``problem.demos[i]`` is the demonstration
            that ``eval_results[i]`` was evaluated on.
        eval_results: One result per demonstration, in ``problem.demos`` order.
            An example counts as correct when ``get_score() == 1``.
        pycode: Parsed ``transform`` source, or ``None`` if it was malformed.

    Returns:
        The prompt text, built in this order:
        - per-example Correct/Wrong sections; wrong ones show the input, the
          expected output and the produced output;
        - a summary;
        - either a "malformed" notice (``pycode`` is None) or the outputs of
          ``pycode`` on ``problem.tests``.
    """
    parts: List[str] = []
    num_correct = 0
    for i, eval_result in enumerate(eval_results):
        parts.append(f"# Example {i}\n\n")
        # Test truthiness rather than `is True`: if get_score() returns a numpy
        # scalar, `== 1` yields numpy.bool_, which is never the `True` singleton.
        if eval_result.get_score() == 1:
            parts.append("Result: Correct\n\n")
            num_correct += 1
            continue

        demo = problem.demos[i]
        shown_input = list_format(demo["input"])
        shown_expected = list_format(demo["output"])
        output = eval_result.answer
        # Mirror the additional-input section below: a missing answer is
        # reported as such instead of being passed to list_format.
        if output is None:
            shown_output = f"Your `transform` function is invalid for Example {i}"
        else:
            shown_output = list_format(output)
        parts.append(
            f"""
Result: Wrong

Input:
{shown_input}
Expected Output:
{shown_expected}
Your Output:
{shown_output}

"""
        )

    total = len(eval_results)
    if num_correct == total:
        parts.append("# Summary\n\nYour solution is correct for all the problems!\n\n")
    else:
        parts.append(
            f"# Summary\n\nYour solution is correct for {num_correct} problems among {total}!\n\n"
        )

    # We also show transform function's result on additional inputs
    if pycode is None:
        parts.append(
            "Your `transform` function was malformed, so please fix it accordingly.\n\n"
        )
        return "".join(parts)

    parts.append(
        "Also, here are the outputs of your `transform` function on additional inputs. Please check if your `transform` worked on additional inputs as intended, and correct your mistake in your next turns.\n\n"
    )
    for i, test_result in enumerate(problem.run_transform_on_tests(pycode)):
        output = test_result.answer
        parts.append(f"# Transformed output on Additional Input {i}\n\n")
        if output is None:
            parts.append(
                f"Your `transform` function is invalid for Additional Input {i}\n\n"
            )
            continue
        shown_input = list_format(problem.tests[i]["input"])
        parts.append(
            f"""Input:
{shown_input}

Your Output:
{list_format(output)}

"""
        )

    return "".join(parts)
