from typing import List, Optional
import json
import os
import copy

import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

from ours.util import language_action_to_subgoal

# Base instruction shared by the task-planning prompts; PlanningModel's
# plan_with_reflection() and decomposed_plan() append examples and options to it.
prompt_decomposed_plan = """For an item name, you need to make a plan using examples.
"""

#################### Our context-aware reasoning prompt ####################
# Sent together with the current game frame to obtain a short textual
# description of the scene (used by PlanningModel.context_aware_reasoning).
description_prompt = """Given a Minecraft game image, describe nearby Minecraft objects, like tree, grass, cobblestone, etc.

[Example]
"There is a large tree with dark green leaves surrounding the area."
"The image shows a dark, cave-like environment in Minecraft. The player is digging downwards. There are no visible trees or grass in this particular view."
"The image shows a dark, narrow tunnel made of stone blocks. The player is digging downwards."

[Your turn]
Describe the given image, simply and clearly like the examples."""

# Template filled via str.format(task=..., visual_description=..., goal=...).
# Literal braces are doubled ({{ / }}) so they survive .format(). The example
# <reasoning> objects are valid JSON (no trailing commas) so the model imitates
# parseable output, as the prompt itself demands.
context_aware_reasoning_prompt = """
Given <task> and <visual_description>, determine if the player needs intervention to achieve the goal. If intervention is needed, suggest a task that the player should perform.
I will give you examples.

[Example]
<task>: chop tree
<visual_description>: There is a large tree with dark green leaves surrounding the area.
<goal>: logs
<reasoning>:
{{
    "need_intervention": false,
    "thoughts": "The player can see a tree and can chop it down to get logs.",
    "task": ""
}}

[Example]
<task>: chop tree
<visual_description>: The image shows a dirt block in Minecraft. There is a tree in the image, but it is too far from here.
<goal>: logs
<reasoning>:
{{
    "need_intervention": true,
    "thoughts": "The player is far from trees. The player needs to move to the trees.",
    "task": "explore to find trees"
}}

[Example]
<task>: dig down to mine iron_ore
<visual_description>: The image shows a dark, narrow tunnel made of stone blocks. The player is digging downwards.
<goal>: iron_ore
<reasoning>:
{{
    "need_intervention": false,
    "thoughts": "The player is already digging down and is likely to find iron ore.",
    "task": ""
}}

[Your turn]
Here is the <task>, <visual_description>, and <goal>.
You MUST output the <reasoning> in JSON format.
<task>: {task}
<visual_description>: {visual_description}
<goal>: {goal}
<reasoning>:
"""


def is_path(path):
    """Return True when *path* has exactly two elements, otherwise False."""
    return len(path) == 2




# Few-shot prompt for PlanningModel.reflect_on_failure(); the caller appends a
# "[Your turn]" section by string concatenation. It is never passed through
# str.format(), so the single braces in the examples are literal text.
reflection_prompt = """You are a professional game analyst.
For a given <item_name> and <inventory>, you need to analyze why <plan> failed to get the item.
I will give you examples of analysis as follow.

[Example]
<item_name>: wooden_pickaxe
<inventory>: {'stick': 4, 'planks': 4, 'crafting_table': 1}
<plan>: smelt wooden_pickaxe
<failure_analysis>
{"analysis": "You failed because you cannot smelt a wooden_pickaxe. You should craft it instead."}

[Example]
<item_name>: stone_pickaxe
<inventory>: {'stick': 4, 'planks': 4, 'crafting_table': 1}
<plan>: craft stone_pickaxe
<failure_analysis>
{"analysis": "You failed because you do not have enough cobblestones."}

"""



class PlanningModel:

    def __init__(self, model_path: str = "Qwen/Qwen2.5-VL-7B-Instruct", device_id: int = 0,
                 system_prompt: Optional[str] = None) -> None:
        self.device = f"cuda:{device_id}"

        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
        )
        self.model = self.model.to(self.device).eval()
        self.processor = AutoProcessor.from_pretrained(model_path)

    def generate_hypothesis(self, original_hypothesis, item_name, inventory, topK_verified_items, topK_verified_recipes, all_reflections):

        def render_json_output(llm_output):
            llm_output = llm_output.replace("<required_items>:", "<required_items>")
            sep_str = "<required_items>"

            temp = llm_output.split(sep_str)[-1].strip()
            if "```json" in temp:
                temp = temp.split("```json")[1].strip().split("```")[0].strip()

            if "{{" in temp:
                temp = temp.replace("{{", "{").replace("}}", "}")

            r = temp.rfind("}")
            temp = temp[: r + 1]

            try:
                temp = json.loads(temp)
            except Exception as e:
                return llm_output, str(e)

            return temp, None

# reflection = {
#     "item_name": item_name,
#     "inventory": inventory_before_action,
#     "plan": action_str,
#     "failure_analysis": reflection
# }

        hypothesis_prompt = f"""
You are a professional game analyst. For a given <item_name>, you need to make <required_items> to get the item.
If you make <required_items> well, I will give you 1 $.

I will give you recent transitions.
"""

        for reflection in all_reflections:
            hypothesis_prompt += f"""[Failed example]
<item_name>: {reflection['item_name']}
<hypothesized_required_items>: {str(original_hypothesis)}
<inventory>: {str(reflection['inventory'])}
<action>: {reflection['plan']}
<success>: false
"""

        hypothesis_prompt += "I will give you learned items similar to <item_name>, and their validated required items, just for reference.\n"
        for verified_name, recipe in zip(topK_verified_items, topK_verified_recipes):
            recipe_dict = {'recipe': recipe}
            hypothesis_prompt += f"""[Success Example]
<item_name>:
{verified_name}
<required_items>:
{str(recipe_dict)}
"""

        hypothesis_prompt += f"""
[Your turn]
Here is <item_name>, you MUST output <required_items> to achieve {item_name} in JSON format. Remember <required_items> MUST be in JSON format.

<item_name>:
{item_name}
<required_items>:
"""

        print(f"====\n{hypothesis_prompt}\n====")

        llm_output = self._inference(hypothesis_prompt, None)
        hypothesized_recipe, error_msg = render_json_output(llm_output)
        print(f"Hypothesized recipe for {item_name}")
        print(f"====\n{hypothesized_recipe}\n====")
        return hypothesized_recipe, error_msg


    def reflect_on_failure(self, item_name: str, language_action_str: str, inventory_before_action: dict):
        prompt = copy.deepcopy(reflection_prompt)
        prompt += f"""[Your turn]
Here is <item_name>, <inventory> and <plan>, you MUST output <failure_analysis> concisely in JSON format.

<item_name>: {item_name}
<inventory>: {str(inventory_before_action)}
<plan>: {language_action_str}
<failure_analysis>
"""

        # print(f"====\n{prompt}\n====")

        llm_output = self._inference(prompt, None)
        return llm_output


    def plan_with_reflection(
            self,
            waypoint: str,
            image_path: str,
            similar_wp_sg_dict: dict | None = None,
            reflection_list: List[str] | None = None
    ):
        image_path = None
        prompt = copy.deepcopy(prompt_decomposed_plan)
        prompt += "I will give you examples of which plans are needed to achieve an item, just for reference.\n"
        if similar_wp_sg_dict is not None and len(similar_wp_sg_dict) > 0:
            for similar_wp, sg_str in similar_wp_sg_dict.items():
                prompt += f"""[Example]
<item name>
{similar_wp}
<task planning>
{sg_str}

"""
        else:
            # Give only one example for formatting
            tmp_sg_str = json.dumps({"task": "craft wooden_pickaxe", "goal": ["wooden_pickaxe", 1]})
            prompt += f"""[Example]
<item name>
wooden_pickaxe
<task planning>
{tmp_sg_str}

"""

        # each reflection in the reflection_list is a python dictionary
        if reflection_list is not None and len(reflection_list) > 0:
            prompt += "Here are some analyses on previous failed plans for this item.\n"
            for reflection in reflection_list:
                prompt += f"""[Analysis]
{str(reflection)}
"""

        if "log" not in waypoint:
            # waypoint is not logs
            language_actions = ["dig down and mine", "craft", "smelt"]
            language_action_options = [f"{action} {waypoint}" for action in language_actions]
        else:
            # waypoint is logs
            language_actions = ["chop a tree", "craft logs", "smelt logs"]
            language_action_options = language_actions


        language_subgoal_options = []
        i = 1
        for action in language_action_options:
            _, subgoal = language_action_to_subgoal(action, waypoint)
            language_subgoal_options.append(f"{i}. {subgoal}")
            i += 1
        options_str = "\n".join(language_subgoal_options)

        prompt += f"""
[Your turn]
Here is <item name>, you MUST select one from below <options>, to make <task planning>.
you MUST select one from below <options>. DO NOT MAKE A PLAN NOT IN <options>.

<options>
{options_str}

<item name>
{waypoint}
<task planning>
"""

        print(f"====\n{prompt}\n====")
        return self._inference(prompt, None), prompt


    def decomposed_plan(
        self,
        waypoint: str,
        image_path: str,
        similar_wp_sg_dict: dict | None = None,
        failed_sg_list_for_wp: List[str] | None = None,
    ):
        image_path = None
        prompt = prompt_decomposed_plan

        if similar_wp_sg_dict is not None and len(similar_wp_sg_dict) > 0:
            prompt += "I will give you examples of which plans are needed to achieve an item.\n"
            for similar_wp, sg_str in similar_wp_sg_dict.items():
                prompt += f"""[Example]
<item name>
{similar_wp}
<task planning>
{sg_str}

"""
        else:
            # similar waypoints are not available
            # That is, it does not use memory
            pass

        language_actions = ["dig down and mine", "craft", "smelt"]
        for failed_sg_str in failed_sg_list_for_wp:
            if "mine" in failed_sg_str:
                language_actions.remove("dig down and mine")
            elif "craft" in failed_sg_str:
                language_actions.remove("craft")
            elif "smelt" in failed_sg_str:
                language_actions.remove("smelt")
        language_action_options = [f"{action} {waypoint}" for action in language_actions]
        if len(language_action_options) == 0:
            language_action_options = [f"dig down and mine {waypoint}", f"craft {waypoint}", f"smelt {waypoint}"]

        language_subgoal_options = []
        i = 1
        for action in language_action_options:
            _, subgoal = language_action_to_subgoal(action, waypoint)
            language_subgoal_options.append(f"{i}. {subgoal}")
            i += 1
        options_str = "\n".join(language_subgoal_options)

        prompt += f"""
[Your turn]
Here is <item name>, you MUST output <task planning> in JSON format.
You can make <task planning> by selecting an option from below:
{options_str}

<item name>
{waypoint}
<task planning>
"""

        # print(f"====\n{prompt}\n====")
        return self._inference(prompt, None), prompt


    def context_aware_reasoning(
        self,
        task: str,
        goal: str,
        image_path: str,
    ):
        visual_description = self._inference(description_prompt, image_path)

        new_reasoning_prompt = context_aware_reasoning_prompt.format(
            task=task,
            visual_description=visual_description,
            goal=goal,
        )
        reasoning = self._inference(new_reasoning_prompt, None)
        return reasoning, visual_description


    def _inference(self, instruction: str, images: str | List[str] = None) -> str:
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": instruction},
                    {"type": "image", "image": images},
                ]
            },
        ]

        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        if images is None:
            inputs = self.processor(
                text=[text],
                padding=True,
                return_tensors="pt",
            )
        else:
            image_inputs, video_inputs = process_vision_info(messages)
            inputs = self.processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
        inputs = inputs.to(self.model.device)

        generated_ids = self.model.generate(**inputs, max_new_tokens=512)
        generated_ids_trimmed = [
            out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        response = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        return response[0]
