import copy
import json
from typing import List, Optional

import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

from mctextworld.models.base_model import BasePlanningModel
from mctextworld.utils import language_action_to_subgoal

# Opening instruction shared by the option-selection prompts assembled in
# PlanningModel.plan_with_reflection and PlanningModel.decomposed_plan.
prompt_decomposed_plan = """For an item name, you need to make a plan, by selecting one among provided options.
"""


# Prompt template containing {example}, {task}, {visual} and {graph}
# placeholders (presumably filled with str.format by a caller).
# NOTE(review): nothing in the visible part of this file references it --
# confirm it is used elsewhere before changing or removing it.
plan_prompt = """For a given game screen and task, you need to make a plan with the help of <visual info> and <craft graph>.
<visual info>: Consists of the following aspects: health bar, food bar, hotbar, environment. Based on the current visual information, you need to consider whether prequel steps needed to ensure that agent can complete the task.
<craft graph>: a top-down list of all the tools and materials needed to complete the task. 
I will give you an example of planning under specific visual conditions as follow:

[Example]
{example}

[Your turn]
Here is a game screen and task, you MUST output <task planning> in JSON format. Remember <task planning> MUST output in JSON format.
<image_placeholder>
<task>: {task}
<visual info>: {visual}
<craft graph>: {graph}
<task planning>
"""


# Few-shot preamble for failure analysis. PlanningModel.reflect_on_failure
# appends the concrete "[Your turn]" section to this text by concatenation,
# so the braces in the examples below are literal characters (the string is
# never passed through str.format in the visible code).
reflection_prompt = """You are a professional game analyst.
For a given <item_name> and <inventory>, you need to analyze why <plan> failed to get the item.
I will give you examples of analysis as follow.

[Example]
<item_name>: wooden_pickaxe
<inventory>: {'stick': 4, 'planks': 4, 'crafting_table': 1}
<plan>: smelt wooden_pickaxe
<failure_analysis>
{"analysis": "You failed because you cannot smelt a wooden_pickaxe. You should craft it instead."}

[Example]
<item_name>: stone_pickaxe
<inventory>: {'stick': 4, 'planks': 4, 'crafting_table': 1}
<plan>: craft stone_pickaxe
<failure_analysis>
{"analysis": "You failed because you do not have enough cobblestones."}

"""

class PlanningModel(BasePlanningModel):

    def __init__(self, model_path: str = "Qwen/Qwen2.5-VL-7B-Instruct", device_id: int = 1,
                 system_prompt: Optional[str] = None) -> None:
        """Load the Qwen2.5-VL checkpoint and its processor.

        Args:
            model_path: Checkpoint name or path handed to ``from_pretrained``
                for both the model and the processor.
            device_id: Index of the CUDA device the model is moved to.
            system_prompt: Not read by this constructor.
        """
        cuda_device = f"cuda:{device_id}"
        self.device = cuda_device

        # Weights are loaded in bfloat16, moved to the target GPU, then put in eval mode.
        vl_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_path, torch_dtype=torch.bfloat16
        )
        self.model = vl_model.to(cuda_device).eval()
        self.processor = AutoProcessor.from_pretrained(model_path)


    def reflect_on_failure(self, item_name: str, language_action_str: str, inventory_before_action: dict):
        """Ask the model why a plan failed to produce an item.

        The few-shot ``reflection_prompt`` is extended with a "[Your turn]"
        section describing the concrete failure, and the result is sent to
        ``_inference`` as a text-only request.

        Args:
            item_name: Item the plan was meant to obtain.
            language_action_str: The plan (language action) that failed.
            inventory_before_action: Inventory held before the plan was tried;
                it is rendered with ``str``.

        Returns:
            The raw text produced by ``_inference``; no parsing happens here.
        """
        your_turn_section = (
            "[Your turn]\n"
            "Here is <item_name>, <inventory> and <plan>, you MUST output <failure_analysis> concisely in JSON format.\n"
            "\n"
            f"<item_name>: {item_name}\n"
            f"<inventory>: {inventory_before_action!s}\n"
            f"<plan>: {language_action_str}\n"
            "<failure_analysis>\n"
        )
        return self._inference(reflection_prompt + your_turn_section, None)


    def plan_with_reflection(
            self,
            waypoint: str,
            images: str | List[str],
            similar_wp_sg_dict: dict | None = None,
            reflection_list: List[str] | None = None
    ):
        """Build an option-selection prompt (with failure analyses) and query the model.

        Args:
            waypoint: Name of the item to obtain.
            images: Ignored; the request is always sent without images.
            similar_wp_sg_dict: Mapping of similar item name -> subgoal string
                used as reference examples. When missing or empty, a single
                ``wooden_pickaxe`` example is shown for formatting.
            reflection_list: Analyses of earlier failed plans for this item;
                each entry is rendered with ``str``.

        Returns:
            A ``(model_output, prompt)`` tuple.
        """
        sections = [
            prompt_decomposed_plan,
            "I will give you examples of which plans are needed to achieve an item, just for reference.\n",
        ]

        # Reference examples, or one hard-coded example that only shows the format.
        if similar_wp_sg_dict is not None and len(similar_wp_sg_dict) > 0:
            example_pairs = list(similar_wp_sg_dict.items())
        else:
            formatting_example = json.dumps({"task": "craft wooden_pickaxe", "goal": ["wooden_pickaxe", 1]})
            example_pairs = [("wooden_pickaxe", formatting_example)]
        for example_item, example_plan in example_pairs:
            sections.append(f"[Example]\n<item name>\n{example_item}\n<task planning>\n{example_plan}\n\n")

        # Each reflection is expected to be a Python dictionary.
        if reflection_list is not None and len(reflection_list) > 0:
            sections.append("Here are some analyses on previous failed plans for this item.\n")
            for analysis in reflection_list:
                sections.append(f"[Analysis]\n{analysis!s}\n")

        if "log" in waypoint:
            # Logs use fixed wording that does not embed the waypoint name.
            candidate_actions = ["chop a tree", "craft logs", "smelt logs"]
        else:
            candidate_actions = [f"{verb} {waypoint}" for verb in ("dig down and mine", "craft", "smelt")]

        numbered_options = []
        for number, candidate in enumerate(candidate_actions, start=1):
            _, subgoal = language_action_to_subgoal(candidate, waypoint)
            numbered_options.append(f"{number}. {subgoal}")
        options_str = "\n".join(numbered_options)

        sections.append(
            "\n[Your turn]\n"
            "Here is <item name>, you MUST select one from below <options>, to make <task planning>.\n"
            "you MUST select one from below <options>. DO NOT MAKE A PLAN NOT IN <options>.\n"
            "\n"
            "<options>\n"
            f"{options_str}\n"
            "\n"
            "<item name>\n"
            f"{waypoint}\n"
            "<task planning>\n"
        )
        prompt = "".join(sections)

        print(f"====\n{prompt}\n====")
        return self._inference(prompt, None), prompt


    def decomposed_plan(
        self,
        waypoint: str,
        images: str | List[str],
        similar_wp_sg_dict: dict | None = None,
        failed_sg_list_for_wp: List[str] | None = None,
    ):
        """Build an option-selection prompt that omits already-failed actions and query the model.

        Args:
            waypoint: Name of the item to obtain.
            images: Ignored; the request is always sent without images.
            similar_wp_sg_dict: Mapping of similar item name -> subgoal string
                used as reference examples. When missing or empty, a single
                ``wooden_pickaxe`` example is shown for formatting.
            failed_sg_list_for_wp: Subgoal strings that already failed for this
                waypoint. ``None`` (the default) or an empty list means nothing
                has failed yet, so every action is offered.

        Returns:
            A ``(model_output, prompt)`` tuple.
        """
        # None is the declared default, so it must be iterable-safe.
        failed_subgoals = failed_sg_list_for_wp or []

        prompt = prompt_decomposed_plan
        prompt += "I will give you examples of which plans are needed to achieve an item, just for reference.\n"
        if similar_wp_sg_dict:
            example_pairs = list(similar_wp_sg_dict.items())
        else:
            # Give only one example for formatting.
            example_pairs = [
                ("wooden_pickaxe", json.dumps({"task": "craft wooden_pickaxe", "goal": ["wooden_pickaxe", 1]}))
            ]
        for example_item, example_plan in example_pairs:
            prompt += f"[Example]\n<item name>\n{example_item}\n<task planning>\n{example_plan}\n\n"

        # (keywords found in a failed subgoal) -> the action that failure rules out.
        is_log = "log" in waypoint
        if is_log:
            removal_rules = [
                (("mine ", "chop "), "chop a tree"),
                (("craft ",), "craft logs"),
                (("smelt ",), "smelt logs"),
            ]
        else:
            removal_rules = [
                (("mine ",), "dig down and mine"),
                (("craft ",), "craft"),
                (("smelt ",), "smelt"),
            ]

        remaining_actions = [action for _, action in removal_rules]
        for failed_sg_str in failed_subgoals:
            # Only the first matching rule applies to each failed subgoal.
            for keywords, action in removal_rules:
                if any(keyword in failed_sg_str for keyword in keywords):
                    if action in remaining_actions:
                        remaining_actions.remove(action)
                    break

        if is_log:
            language_action_options = remaining_actions
            if not language_action_options:
                # Everything failed already: offer the full set again.
                language_action_options = ["chop a tree", f"craft {waypoint}", f"smelt {waypoint}"]
        else:
            language_action_options = [f"{action} {waypoint}" for action in remaining_actions]
            if not language_action_options:
                # Everything failed already: offer the full set again.
                language_action_options = [
                    f"dig down and mine {waypoint}",
                    f"craft {waypoint}",
                    f"smelt {waypoint}",
                ]

        numbered_options = []
        for number, candidate in enumerate(language_action_options, start=1):
            _, subgoal = language_action_to_subgoal(candidate, waypoint)
            numbered_options.append(f"{number}. {subgoal}")
        options_str = "\n".join(numbered_options)

        prompt += (
            "\n[Your turn]\n"
            "Here is <item name>, you MUST select one from below <options>, to make <task planning>.\n"
            "you MUST select one from below <options>. DO NOT MAKE A PLAN NOT IN <options>.\n"
            "\n"
            "<options>\n"
            f"{options_str}\n"
            "\n"
            "<item name>\n"
            f"{waypoint}\n"
            "<task planning>\n"
        )

        return self._inference(prompt, None), prompt

    def generate_hypothesis(self, original_hypothesis, item_name, inventory, topK_verified_items, topK_verified_recipes, all_reflections):

        def render_json_output(llm_output):
            llm_output = llm_output.replace("<required_items>:", "<required_items>")
            sep_str = "<required_items>"

            temp = llm_output.split(sep_str)[-1].strip()
            if "```json" in temp:
                temp = temp.split("```json")[1].strip().split("```")[0].strip()

            if "{{" in temp:
                temp = temp.replace("{{", "{").replace("}}", "}")

            r = temp.rfind("}")
            temp = temp[: r + 1]

            try:
                temp = json.loads(temp)
            except Exception as e:
                return llm_output, str(e)

            return temp, None

# reflection = {
#     "item_name": item_name,
#     "inventory": inventory_before_action,
#     "plan": action_str,
#     "failure_analysis": reflection
# }
        hypothesis_prompt = f"""
You are a professional game analyst. For a given <item_name>, you need to make <required_items> to get the item.
If you make <required_items> well, I will give you 1 $.

I will give you recent transitions.
"""
        for reflection in all_reflections:
            hypothesis_prompt += f"""[Failed example]
<item_name>: {reflection['item_name']}
<hypothesized_required_items>: {str(original_hypothesis)}
<inventory>: {str(reflection['inventory'])}
<action>: {reflection['plan']}
<success>: false
"""

        hypothesis_prompt += "I will give you learned items similar to <item_name>, and their validated required items, just for reference.\n"
        for verified_name, recipe in zip(topK_verified_items, topK_verified_recipes):
            recipe_dict = {'recipe': recipe}
            hypothesis_prompt += f"""[Success Example]
<item_name>:
{verified_name}
<required_items>:
{str(recipe_dict)}
"""

        hypothesis_prompt += f"""
[Your turn]
Here is <item_name>, you MUST output <required_items> to achieve {item_name} in JSON format. Remember <required_items> MUST be in JSON format.

<item_name>:
{item_name}
<required_items>:
"""

        print(f"====\n{hypothesis_prompt}\n====")

        llm_output = self._inference(hypothesis_prompt, None)
        hypothesized_recipe, error_msg = render_json_output(llm_output)
        print(f"Hypothesized recipe for {item_name}")
        print(f"====\n{hypothesized_recipe}\n====")
        return hypothesized_recipe, error_msg


    def _inference(self, instruction: str, images: str | List[str] | None = None) -> str:
        """Run one single-turn generation and return the decoded reply.

        Args:
            instruction: Text of the user message.
            images: ``None`` for a text-only request, a single image reference,
                or a list/tuple of them. Each image becomes its own image entry
                in the message (placed after the text), so a text-only request
                carries no image placeholder in the rendered chat template.

        Returns:
            The generated text for the single request, with the prompt tokens
            and special tokens removed.
        """
        if images is None:
            image_list = []
        elif isinstance(images, (list, tuple)):
            image_list = list(images)
        else:
            image_list = [images]

        content = [{"type": "text", "text": instruction}]
        content.extend({"type": "image", "image": image} for image in image_list)
        messages = [{"role": "user", "content": content}]

        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        if image_list:
            image_inputs, video_inputs = process_vision_info(messages)
            inputs = self.processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
        else:
            inputs = self.processor(
                text=[text],
                padding=True,
                return_tensors="pt",
            )
        inputs = inputs.to(self.model.device)

        generated_ids = self.model.generate(**inputs, max_new_tokens=1024)
        # generate() returns prompt + completion; drop the prompt tokens.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        response = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        return response[0]
   