from colmind.agents.agent import LLMAgent
from pathlib import Path
from javascript import require
import colmind.utils as U

import dspy
import time
import re


# Prompt fragments used as `desc` text for the output fields of the DSPy
# signatures defined below. They are runtime text sent to the LLM; edit with
# care.

# Description of the `explain` output field of the `Action` signature.
OUTPUT_EXPLAIN = """
    Explain (if applicable): Are there any steps missing in your plan? Why does the code not complete the task? What does the chat log and execution error imply?
"""

# Belief-oriented counterpart, used for `BeliefAction.explain`.
OUTPUT_EXPLAIN_BELIEF = """
    Explain (if applicable): Are there any steps missing in your plan? Why does the code not complete the task? What does the task beliefs imply? Is there any contradiction between the code and your beliefs?
"""

# Description of the `plan` output field of the `Action` signature.
OUTPUT_PLAN = """
How to complete the task step by step. You should pay attention to Inventory since it tells what you have. The task completeness check is also based on your final inventory.
"""

# Belief-oriented counterpart, used for `BeliefAction.plan`.
OUTPUT_PLAN_BELIEF = """
How to complete the task step by step. You should pay attention to your perception beliefs since it tells what you have. The task completeness check is based on your final inventory, that is summarized in the perception beliefs.
"""

# Coding rules given to the LLM as the `desc` of the `code` output field in
# both `Action` and `BeliefAction`. This is runtime prompt text.
# The `smeltItem` entry previously read `smeltItem(bot, name count)`; the
# missing comma made it the only helper advertised with a malformed call form.
# NOTE(review): the helper functions themselves are not defined in this file -
# confirm the advertised parameter lists against their implementations.
OUTPUT_CODE = """
1) Write an async function taking the bot as the only argument.
2) Reuse the above useful programs as much as possible.
    - Use `mineBlock(bot, name, count)` to collect blocks. Do not use `bot.dig` directly.
    - Use `craftItem(bot, name, count)` to craft items. Do not use `bot.craft` or `bot.recipesFor` directly.
    - Use `smeltItem(bot, name, count)` to smelt items. Do not use `bot.openFurnace` directly.
    - Use `placeItem(bot, name, position)` to place blocks. Do not use `bot.placeBlock` directly.
    - Use `killMob(bot, name, timeout)` to kill mobs. Do not use `bot.attack` directly.
3) Your function will be reused for building more complex functions. Therefore, you should make it generic and reusable. You should not make strong assumption about the inventory (as it may be changed at a later time), and therefore you should always check whether you have the required items before using them. If not, you should first collect the required items and reuse the above useful programs.
4) Functions in the "Code from the last round" section will not be saved or executed. Do not reuse functions listed there.
5) Anything defined outside a function will be ignored, define all your variables inside your functions.
6) Call `bot.chat` to show the intermediate progress.
7) Use `exploreUntil(bot, direction, maxDistance, callback)` when you cannot find something. You should frequently call this before mining blocks or killing mobs. You should select a direction at random every time instead of constantly using (1, 0, 1).
8) `maxDistance` should always be 32 for `bot.findBlocks` and `bot.findBlock`. Do not cheat.
9) Do not write infinite loops or recursive functions.
10) Do not use `bot.on` or `bot.once` to register event listeners. You definitely do not need them.
11) Name your function in a meaningful way (can infer the task from the name).
"""

# Response-format template passed as the `prefix` of the `code` output field in
# both `Action` and `BeliefAction`. The example fence is tagged `javascript`,
# one of the two fence tags (`javascript` / `js`) that
# `ActionAgent.parse_action` extracts code from. This is runtime prompt text.
CODE_PREFIX = """
You should only respond in the format as described below:
RESPONSE FORMAT:

Explain: ...
Plan:
1) ...
2) ...
3) ...
...
Code:
```javascript
// helper functions (only if needed, try to avoid them)
...
// main function after the helper functions
async function yourMainFunctionName(bot) {
// ...
}
```
"""

class Action(dspy.Signature):
    """Generate actions."""
    # NOTE(review): DSPy is documented to use a signature's docstring as the
    # prompt instructions, so the docstring above is presumably runtime text -
    # confirm before rewording it.
    #
    # Signature driven by raw observations (biome, inventory, chat log, ...).
    # `ActionAgent.__call__` in this file does not use it; it predicts with
    # `BeliefAction` instead.

    # --- Inputs: reusable programs and feedback from the previous round ---
    procedural_memory: str = dspy.InputField(desc="useful programs written with Mineflayer APIs")
    code_from_last_round: str = dspy.InputField(desc="code generated for executing the previous action")
    execution_error: str = dspy.InputField(desc="execution error associated with the code generated for the previous action")
    chat_log: str = dspy.InputField()
    # --- Inputs: current observation of the world and of the bot ---
    biome: str = dspy.InputField()
    time_of_day: str = dspy.InputField()
    nearby_blocks: list = dspy.InputField()
    nearby_entities: dict = dspy.InputField()
    health: float = dspy.InputField()
    hunger: float = dspy.InputField()
    position: dict = dspy.InputField()
    equipment: str = dspy.InputField()
    inventory: list = dspy.InputField()
    # --- Inputs: the task and guidance for solving it ---
    task: str = dspy.InputField(desc="description of the task that the agent tries to solve")
    context: str = dspy.InputField(desc="information about solving the task")
    critique: str = dspy.InputField(desc="critique of the previous action")



    # --- Outputs: reasoning, step-by-step plan, and the JavaScript program ---
    explain = dspy.OutputField(desc=OUTPUT_EXPLAIN)
    plan = dspy.OutputField(desc=OUTPUT_PLAN)
    code = dspy.OutputField(desc=OUTPUT_CODE, prefix=CODE_PREFIX)

class BeliefAction(dspy.Signature):
    """Generate actions grounded in beliefs."""
    # NOTE(review): DSPy is documented to use a signature's docstring as the
    # prompt instructions, so the docstring above is presumably runtime text -
    # confirm before rewording it.
    #
    # Signature used by `ActionAgent.__call__`: the inputs are the goal,
    # memories and belief sets rather than raw observations.

    # --- Inputs: goal, memories and reusable programs ---
    desire: str = dspy.InputField(desc="what the agent wants to achieve. the goal we are trying to reach")
    episodic_memory: str = dspy.InputField(desc="summary of past experiences at solving similar tasks. can be used to avoid repeating the same mistakes.")
    programs: str = dspy.InputField(desc="useful programs or subroutines written with Mineflayer APIs that can be reused for the code")
    # --- Inputs: belief sets (declared without a type annotation) ---
    task_beliefs = dspy.InputField(desc="set of beliefs that indicate how to solve the task.")
    perception_beliefs = dspy.InputField(desc="set of beliefs that encapsulates information about the external environment, possible execution errors and informative chat logs.")
    interaction_beliefs = dspy.InputField(desc="set of beliefs gathered by communication with another agent that offers information necessary to solve the task.")
    partner_beliefs = dspy.InputField(desc="set of beliefs about the mental model of one or more agents")
    # --- Input: feedback on the previous action ---
    past_critique = dspy.InputField(desc="internal critique of the previous action. take the critique in consideration to improve.")

    # --- Outputs: reasoning, step-by-step plan, and the JavaScript program ---
    explain = dspy.OutputField(desc=OUTPUT_EXPLAIN_BELIEF)
    plan = dspy.OutputField(desc=OUTPUT_PLAN_BELIEF)
    code = dspy.OutputField(desc=OUTPUT_CODE, prefix=CODE_PREFIX)

class ActionAgent(LLMAgent):
    """Agent that asks the LLM for the next action and parses its answer.

    Calling the agent runs the ``BeliefAction`` signature on a context dict.
    ``parse_action`` turns the generated answer into an executable JavaScript
    program. The agent also keeps a "chest memory" (position -> contents)
    that is written to ``chest_memory.json`` in the agent's checkpoint
    directory.
    """

    def __init__(
        self,
        name: str,
        llm: str,
        temperature: float,
        request_timeout: int,
        resume: bool,
        ckpt_dir: str,
        chat_log: bool,
        execution_error: bool,
        logger
    ):
        """Set up the agent and its checkpoint directory.

        Args:
            name: Agent name; forwarded to ``LLMAgent`` and used as the
                sub-directory of ``ckpt_dir`` holding this agent's files.
            llm: Forwarded to ``LLMAgent``.
            temperature: Forwarded to ``LLMAgent``.
            request_timeout: Stored on the instance; not read by any method
                of this class.
            resume: When true, the chest memory is loaded from
                ``chest_memory.json``; otherwise it starts empty.
            ckpt_dir: Root checkpoint directory.
            chat_log: Stored on the instance; not read by any method of this
                class.
            execution_error: Stored on the instance; not read by any method
                of this class.
            logger: Forwarded to ``LLMAgent``. The methods here log through
                ``self.logger`` (presumably set by the base class - verify).
        """
        super().__init__(name, llm, temperature, logger)
        self.request_timeout = request_timeout
        self.path = Path(f"{ckpt_dir}/{name}")
        self.chat_log = chat_log
        self.execution_error = execution_error

        if not self.path.exists():
            self.path.mkdir(parents=True, exist_ok=True)
        if resume:
            self.chest_memory = U.load_json(self.path / "chest_memory.json")
        else:
            self.chest_memory = {}

    def __call__(self, context: dict):
        """Generate an action for the given context.

        Args:
            context: Must provide the keys ``task``, ``episodic_memory``,
                ``procedural_memory``, ``task_beliefs``,
                ``perception_beliefs``, ``interaction_beliefs``,
                ``partner_beliefs`` and ``critique``.

        Returns:
            Whatever ``dspy.Predict(BeliefAction)`` returns for these inputs.

        Raises:
            KeyError: If one of the required keys is missing from ``context``.
        """
        generate_action = dspy.Predict(BeliefAction)
        return generate_action(
            desire=context["task"],
            episodic_memory=context["episodic_memory"],
            programs=context["procedural_memory"],
            task_beliefs=context["task_beliefs"],
            perception_beliefs=context["perception_beliefs"],
            interaction_beliefs=context["interaction_beliefs"],
            partner_beliefs=context["partner_beliefs"],
            past_critique=context["critique"],
        )

    def parse_action(self, action: str):
        """Extract and validate the JavaScript program in an LLM answer.

        All fenced blocks tagged ``javascript`` or ``js`` are concatenated and
        parsed with Babel. Every top-level function declaration is kept; the
        last async one is treated as the main function and must take exactly
        one parameter named ``bot``.

        Args:
            action: Raw text containing the fenced code.

        Returns:
            On success, a dict with ``program_code`` (all top-level functions
            joined by blank lines), ``program_name`` (the main function's
            name) and ``exec_code`` (a statement awaiting the main function).
            If all three attempts fail, a string describing the last error.
        """
        attempts = 3
        # The pattern is constant, so build it once instead of once per attempt.
        code_pattern = re.compile(r"```(?:javascript|js)(.*?)```", re.DOTALL)
        error = None
        for attempt in range(1, attempts + 1):
            try:
                babel = require("@babel/core")
                babel_generator = require("@babel/generator").default

                code = "\n".join(code_pattern.findall(action))
                parsed = babel.parse(code)
                # Validation uses explicit raises rather than `assert`, which
                # is removed when Python runs with -O. The messages (and hence
                # the returned error string) are unchanged.
                if len(list(parsed.program.body)) == 0:
                    raise ValueError("No functions found")
                functions = []
                for node in parsed.program.body:
                    if node.type != "FunctionDeclaration":
                        continue
                    node_type = (
                        "AsyncFunctionDeclaration"
                        if node["async"]
                        else "FunctionDeclaration"
                    )
                    functions.append(
                        {
                            "name": node.id.name,
                            "type": node_type,
                            "body": babel_generator(node).code,
                            "params": list(node["params"]),
                        }
                    )
                # The main function is the last async function in the answer.
                main_function = next(
                    (
                        function
                        for function in reversed(functions)
                        if function["type"] == "AsyncFunctionDeclaration"
                    ),
                    None,
                )
                if main_function is None:
                    raise ValueError(
                        "No async function found. Your main function must be async."
                    )
                params = main_function["params"]
                if len(params) != 1 or params[0].name != "bot":
                    raise ValueError(
                        f"Main function {main_function['name']} must take a single argument named 'bot'"
                    )
                program_code = "\n\n".join(function["body"] for function in functions)
                exec_code = f"await {main_function['name']}(bot);"
                return {
                    "program_code": program_code,
                    "program_name": main_function["name"],
                    "exec_code": exec_code,
                }
            except Exception as e:
                # Broad on purpose: any failure is turned into the returned
                # error string after the attempts are used up.
                error = e
                self.logger.warning(e)
                # Pause only when another attempt follows.
                if attempt < attempts:
                    time.sleep(1)
        return f"Error parsing action response (before program execution): {error}"

    def update_chest_memory(self, chests):
        """Merge observed chests into the chest memory and persist it.

        Args:
            chests: Mapping of chest position to its observed state. A dict
                value replaces a known chest's contents; the string
                ``"Invalid"`` removes a known chest and is never stored for an
                unknown one; any other value is stored only for positions not
                yet known.
        """
        for position, chest in chests.items():
            if position in self.chest_memory:
                if isinstance(chest, dict):
                    self.chest_memory[position] = chest
                elif chest == "Invalid":
                    self.logger.info(f"Action Agent removing chest {position}: {chest}")
                    self.chest_memory.pop(position)
            elif chest != "Invalid":
                self.logger.info(f"Action Agent saving chest {position}: {chest}")
                self.chest_memory[position] = chest
        U.dump_json(self.chest_memory, self.path / "chest_memory.json")

    def render_chest_observation(self):
        """Render the chest memory as text.

        Chests with known contents are listed first, then empty chests, then
        chests recorded as ``"Unknown"``.

        Returns:
            ``"Chests:\\n<one line per chest>\\n\\n"``, or ``"Chests: None\\n\\n"``
            when no chest is remembered.

        Raises:
            AssertionError: If an entry is neither a dict nor the string
                ``"Unknown"``.
        """
        filled, empty, unknown = [], [], []
        # One pass with three buckets keeps the original grouping order.
        for chest_position, chest in self.chest_memory.items():
            if isinstance(chest, dict):
                if len(chest) > 0:
                    filled.append(f"{chest_position}: {chest}")
                else:
                    empty.append(f"{chest_position}: Empty")
            elif isinstance(chest, str):
                assert chest == "Unknown"
                unknown.append(f"{chest_position}: Unknown items inside")
        chests = filled + empty + unknown
        # Every remembered chest must have been rendered by one of the branches.
        assert len(chests) == len(self.chest_memory)
        if chests:
            chests = "\n".join(chests)
            return f"Chests:\n{chests}\n\n"
        return "Chests: None\n\n"

    def summarize_chatlog(self, events):
        """Summarise the missing requirements reported in chat messages.

        Args:
            events: Iterable of ``(event_type, event)`` pairs. Only pairs whose
                type is ``"onChat"`` are inspected; the message is read from
                ``event["onChat"]``.

        Returns:
            ``"I also need <item>, <item>."`` listing each distinct requirement
            once, in the order it was first reported, or ``""`` when no
            message reports one.
        """
        craft_missing = re.compile(r"I cannot make \w+ because I need: (.*)")
        craft_no_table = re.compile(
            r"I cannot make \w+ because there is no crafting table nearby"
        )
        mine_tool = re.compile(r"I need at least a (.*) to mine \w+!")

        def filter_item(message: str):
            """Return the requirement named by ``message``, or ``""``."""
            found = craft_missing.match(message)
            if found:
                return found.group(1)
            if craft_no_table.match(message):
                return "a nearby crafting table"
            found = mine_tool.match(message)
            if found:
                return found.group(1)
            return ""

        # A dict de-duplicates while keeping first-seen order. Joining a set
        # would make the wording depend on string hash randomisation.
        needed = {}
        for event_type, event in events:
            if event_type != "onChat":
                continue
            item = filter_item(event["onChat"])
            if item:
                needed[item] = None
        if not needed:
            return ""
        return "I also need " + ", ".join(needed) + "."

    def restore(self):
        """Placeholder; restoring agent state is not implemented yet."""
        # TODO: implement this
        pass



if __name__ == "__main__":
    # Manual smoke test: ask the configured LLM for an action and print the
    # parsed program. Requires network access and a valid OpenAI API key.
    import logging
    import os

    # Do not clobber a real key that is already exported; the placeholder only
    # fills the gap.
    os.environ.setdefault("OPENAI_API_KEY", "ADD_KEY_HERE")
    action_agent = ActionAgent(
        name="action_agent",
        llm="openai/gpt-4o-mini",
        temperature=0.7,
        request_timeout=120,
        resume=False,
        ckpt_dir="./junk",
        chat_log=False,
        execution_error=False,
        # parse_action logs through self.logger when parsing fails, so give it
        # a real logger (presumably stored by LLMAgent - verify).
        logger=logging.getLogger("action_agent"),
    )
    context = dict(
        procedural_memory="",
        code_from_last_round="print('hello world')",
        execution_error=None,
        chat_log=[
            {"role": "user", "content": "What should I do?"},
            {"role": "assistant", "content": "Gather wood and build a shelter."},
        ],
        biome="forest",
        time_of_day="day",
        nearby_blocks=[
            {"name": "oak_log", "position": [10, 64, 12]},
            {"name": "dirt", "position": [11, 63, 12]},
            {"name": "grass_block", "position": [10, 63, 13]},
            {"name": "leaves", "position": [10, 67, 12]},
        ],
        nearby_entities=[
            {"name": "sheep", "position": [15, 64, 18]},
            {"name": "cow", "position": [12, 65, 20]},
        ],
        health=20,
        hunger=18,
        position=[10, 64, 10],
        equipment={"mainhand": "wooden_axe", "offhand": None, "head": None, "chest": None, "legs": None, "feet": None},
        inventory=[
            {"name": "wooden_axe", "quantity": 1},
            {"name": "oak_log", "quantity": 16},
            {"name": "dirt", "quantity": 5},
        ],
        task="Build a shelter",
        context="I have gathered some wood and need to find a suitable location for my shelter.",
        critique="Consider building near a water source for farming.",
        # Keys read by ActionAgent.__call__ that the demo previously lacked;
        # without them the call fails with KeyError.
        episodic_memory="No past experience with this task.",
        task_beliefs="A shelter can be built from the wood I have already gathered.",
        perception_beliefs=(
            "I am in a forest during the day at position [10, 64, 10]. "
            "My inventory holds 1 wooden_axe, 16 oak_log and 5 dirt. "
            "A sheep and a cow are nearby."
        ),
        interaction_beliefs="No information has been exchanged with other agents yet.",
        partner_beliefs="No other agents are known.",
    )
    action = action_agent(context)
    result = action_agent.parse_action(action["code"])
    # parse_action returns an error string instead of a dict when parsing fails.
    if isinstance(result, dict):
        print(result["program_code"])
    else:
        print(result)
