from typing import List, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from ..util.prompt import language_action_to_subgoal

from .base_model import BasePlanningModel

# Instruction header for the text-only decomposed-plan prompt.
# PlanningModel.decomposed_plan starts from this string and appends optional
# "[Example]" sections and a final "[Your turn]" section at call time.
prompt_decomposed_plan = """For an item name, you need to make a plan using examples.
You can "dig down and mine" / "craft" / "smelt" item.
"""


# Prompt template for planning from a game screen plus a task description.
# It contains the named placeholders {example}, {task}, {visual} and {graph},
# and a literal "<image_placeholder>" marker.
# NOTE(review): nothing in the code visible in this module references this
# template; presumably it is filled in with str.format by another caller —
# confirm before changing or removing it.
plan_prompt = """For a given game screen and task, you need to make a plan with the help of <visual info> and <craft graph>.
<visual info>: Consists of the following aspects: health bar, food bar, hotbar, environment. Based on the current visual information, you need to consider whether prequel steps needed to ensure that agent can complete the task.
<craft graph>: a top-down list of all the tools and materials needed to complete the task. 
I will give you an example of planning under specific visual conditions as follow:

[Example]
{example}

[Your turn]
Here is a game screen and task, you MUST output <task planning> in JSON format. Remember <task planning> MUST output in JSON format.
<image_placeholder>
<task>: {task}
<visual info>: {visual}
<craft graph>: {graph}
<task planning>
"""


class PlanningModel(BasePlanningModel):
    """Text-only planner backed by a Hugging Face causal language model.

    For a given waypoint (item name) it builds a prompt listing candidate
    subgoals and asks the model to pick one. Images are accepted for
    interface compatibility but are never passed to the model.
    """

    def __init__(self, model_path: str = "Qwen/Qwen2.5-7B-Instruct", device_id: int = 1, system_prompt: Optional[str] = None) -> None:
        """Load the model and tokenizer and move the model to ``cuda:{device_id}``.

        Args:
            model_path: Identifier or path handed to ``from_pretrained`` for
                both the model and the tokenizer.
            device_id: Index of the CUDA device the model is moved to.
            system_prompt: Accepted but not used by this class.
        """
        # NOTE(review): the base-class initializer is not invoked here; confirm
        # that BasePlanningModel does not require it.
        self.device = f"cuda:{device_id}"

        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype="auto",
        )
        self.model = self.model.to(self.device).eval()
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

    @staticmethod
    def _language_action_options(waypoint: str, failed_subgoals: List[str]) -> List[str]:
        """Return the candidate language actions for ``waypoint``.

        Every failed subgoal rules out at most one action: the first one whose
        keyword occurs in the failed subgoal string. If every action has been
        ruled out, the full fallback list is returned so the model always has
        something to choose from.
        """
        # Each rule is (keywords that identify a failed attempt, action text).
        # Rule order matters: it reproduces the mine -> craft -> smelt priority.
        # NOTE(review): keywords are matched against the whole failed-subgoal
        # string, so an item name that itself contains a keyword may match an
        # earlier rule than intended — verify against the subgoal string format.
        if "log" in waypoint:
            rules = [
                (("mine", "chop"), "chop a tree"),
                (("craft",), "craft logs"),
                (("smelt",), "smelt logs"),
            ]
            fallback = ["chop a tree", f"craft {waypoint}", f"smelt {waypoint}"]
        else:
            rules = [
                (("mine",), f"dig down and mine {waypoint}"),
                (("craft",), f"craft {waypoint}"),
                (("smelt",), f"smelt {waypoint}"),
            ]
            fallback = [action for _, action in rules]

        # A set makes repeated failures of the same action harmless, where
        # list.remove would raise ValueError on the second occurrence.
        excluded = set()
        for failed in failed_subgoals:
            for keywords, action in rules:
                if any(keyword in failed for keyword in keywords):
                    excluded.add(action)
                    break  # only the first matching rule applies

        remaining = [action for _, action in rules if action not in excluded]
        return remaining or fallback

    def decomposed_plan(
        self,
        waypoint: str,
        images: str | List[str],
        similar_wp_sg_dict: dict | None = None,
        failed_sg_list_for_wp: List[str] | None = None,
    ) -> tuple[str, str]:
        """Ask the model to choose a subgoal for obtaining ``waypoint``.

        Args:
            waypoint: Name of the item to obtain.
            images: Ignored; this planner is text-only.
            similar_wp_sg_dict: Optional mapping from a similar waypoint name
                to its subgoal string; each entry is added to the prompt as
                an example. ``None`` or an empty mapping adds no examples.
            failed_sg_list_for_wp: Subgoal strings that already failed for
                this waypoint; the matching actions are left out of the
                offered options. ``None`` is treated as an empty list.

        Returns:
            A ``(response, prompt)`` tuple: the decoded model output and the
            full prompt that produced it.
        """
        failed_subgoals = failed_sg_list_for_wp or []
        prompt = prompt_decomposed_plan

        if similar_wp_sg_dict:
            prompt += "I will give you examples of which plans are needed to achieve an item.\n"
            for similar_wp, sg_str in similar_wp_sg_dict.items():
                prompt += f"""[Example]
<item name>
{similar_wp}
<task planning>
{sg_str}

"""
        # Otherwise no similar waypoints are available, i.e. memory is not used.

        language_action_options = self._language_action_options(waypoint, failed_subgoals)

        # Turn each language action into its subgoal and number the options from 1.
        language_subgoal_options = []
        for index, action in enumerate(language_action_options, start=1):
            _, subgoal = language_action_to_subgoal(action, waypoint)
            language_subgoal_options.append(f"{index}. {subgoal}")
        options_str = "\n".join(language_subgoal_options)

        prompt += f"""
[Your turn]
Here is <item name>, you MUST output <task planning> in JSON format.
You can make <task planning> by selecting an option from below:
{options_str}

<item name>
{waypoint}
<task planning>
"""

        print(f"====\n{prompt}\n====")
        return self._inference(prompt, None), prompt

    def _inference(self, instruction: str, images: str | List[str] | None = None) -> str:
        """Run one single-turn chat completion and return the decoded reply.

        Args:
            instruction: Text sent as the only user message.
            images: Ignored; kept so the signature matches callers that pass it.

        Returns:
            The generated text with the prompt tokens and special tokens removed.
        """
        messages = [
            {"role": "user", "content": instruction},
        ]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.device)

        # NOTE(review): temperature presumably only matters when sampling is
        # enabled by the model's generation config — confirm for the model in use.
        generated_ids = self.model.generate(
            **model_inputs,
            max_new_tokens=512,
            temperature=0.4
        )
        # Drop the leading len(input_ids) tokens of each output sequence so
        # that only the part after the prompt is decoded.
        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

        return response
