import re
import time
from colmind.agents.llm import customLLM, get_llm

import colmind.utils as U
from javascript import require

from langchain.prompts import SystemMessagePromptTemplate
from langchain.schema import AIMessage, HumanMessage, SystemMessage

from colmind.prompts import load_prompt
from colmind.control_primitives_context import load_control_primitives_context
from termcolor import colored

class BeliefActionAgent:
    def __init__(
        self,
        config,
        username,
        logger,
        resume=False,
        chat_log=True,
        execution_error=True,
    ):

        self.config = config
        self.chat_log = chat_log
        self.execution_error = execution_error

        U.f_mkdir(f"{config['ckpt_dir']}/action")
        if resume:
            self.logger.info(f"Loading Action Agent from {config['ckpt_dir']}/action")
            self.chest_memory = U.load_json(f"{config['ckpt_dir']}/action/chest_memory.json")
        else:
            self.chest_memory = {}

        self.llm = get_llm(config["parameters"]["llm"], config["parameters"]["temperature"])
        self.username = username
        self.logger = logger

    def update_chest_memory(self, chests):
        for position, chest in chests.items():
            if position in self.chest_memory:
                if isinstance(chest, dict):
                    self.chest_memory[position] = chest
                if chest == "Invalid":
                    self.logger.info(
                        f"Action Agent removing chest {position}: {chest}"
                    )
                    self.chest_memory.pop(position)
            else:
                if chest != "Invalid":
                    self.logger.info(f"Action Agent saving chest {position}: {chest}")
                    self.chest_memory[position] = chest
        U.dump_json(self.chest_memory, f"{self.config['ckpt_dir']}/action/chest_memory.json")

    def render_chest_observation(self):
        chests = []
        for chest_position, chest in self.chest_memory.items():
            if isinstance(chest, dict) and len(chest) > 0:
                chests.append(f"{chest_position}: {chest}")
        for chest_position, chest in self.chest_memory.items():
            if isinstance(chest, dict) and len(chest) == 0:
                chests.append(f"{chest_position}: Empty")
        for chest_position, chest in self.chest_memory.items():
            if isinstance(chest, str):
                assert chest == "Unknown"
                chests.append(f"{chest_position}: Unknown items inside")
        assert len(chests) == len(self.chest_memory)
        if chests:
            chests = "\n".join(chests)
            return f"Chests:\n{chests}\n\n"
        else:
            return f"Chests: None\n\n"

    def render_system_message(self, skills=[]):
        system_template = load_prompt("action_template")
        # FIXME: Hardcoded control_primitives
        base_skills = [
            "exploreUntil",
            "mineBlock",
            "craftItem",
            "placeItem",
            "smeltItem",
            "killMob",
            "useChest",
            "mineflayer",
        ]
        programs = "\n\n".join(load_control_primitives_context(base_skills) + skills)
        response_format = load_prompt("action_response_format")
        system_message_prompt = SystemMessagePromptTemplate.from_template(
            system_template
        )
        system_message = system_message_prompt.format(
            programs=programs, response_format=response_format
        )
        assert isinstance(system_message, SystemMessage)
        return system_message

    def render_human_message(
        self, *, events, code="", task="", context="", critique="", interaction_beliefs="", memory_beliefs=""
    ):
        chat_messages = []
        error_messages = []
        damage_messages = []
        assert events[-1][0] == "observe", "Last event must be observe"
        for i, (event_type, event) in enumerate(events):
            if event_type == "onChat":
                chat_messages.append(event["onChat"])
            elif event_type == "onError":
                error_messages.append(event["onError"])
            elif event_type == "onDamage":
                damage_messages.append(event["onDamage"])
            elif event_type == "observe":
                biome = event["status"]["biome"]
                time_of_day = event["status"]["timeOfDay"]
                voxels = event["voxels"]
                entities = event["status"]["entities"]
                health = event["status"]["health"]
                hunger = event["status"]["food"]
                position = event["status"]["position"]
                equipment = event["status"]["equipment"]
                inventory_used = event["status"]["inventoryUsed"]
                inventory = event["inventory"]
                agentsChat = {}
                playerChat = {}
                if "chatExtract" in event:
                    for chat in event["chatExtract"]:
                        if chat["username"] == "SpecialBottle": # the minecraft username of the player
                            if chat["username"] not in playerChat:
                                playerChat[chat["username"]] = [chat["message"]]
                            else:
                                playerChat[chat["username"]].append(chat["message"])


                        # if chat["username"] != self.bot_username:
                        if chat["username"] not in agentsChat:
                            agentsChat[chat["username"]] = [chat["message"]]
                        else:
                            agentsChat[chat["username"]].append(chat["message"])

                assert i == len(events) - 1, "observe must be the last event"


        observation = ""

        # ----------------  Percept ----------------

        observation += "Percept:\n"



        if code:
            observation += f"Code from the last round:\n{code}\n\n"
        else:
            observation += f"Code from the last round: No code in the first round\n\n"

        if self.execution_error:
            if error_messages:
                error = "\n".join(error_messages)
                observation += f"Execution error:\n{error}\n\n"
            else:
                observation += f"Execution error: No error\n\n"

        if self.chat_log:
            if chat_messages:
                chat_log = "\n".join(chat_messages)
                observation += f"Chat log: {chat_log}\n\n"
            else:
                observation += f"Chat log: None\n\n"

        observation += f"Biome: {biome}\n\n"

        observation += f"Time: {time_of_day}\n\n"

        if voxels:
            observation += f"Nearby blocks: {', '.join(voxels)}\n\n"
        else:
            observation += f"Nearby blocks: None\n\n"

        if entities:
            nearby_entities = [
                k for k, v in sorted(entities.items(), key=lambda x: x[1])
            ]
            observation += f"Nearby entities (nearest to farthest): {', '.join(nearby_entities)}\n\n"
        else:
            observation += f"Nearby entities (nearest to farthest): None\n\n"

        observation += f"Health: {health:.1f}/20\n\n"

        observation += f"Hunger: {hunger:.1f}/20\n\n"

        observation += f"Position: x={position['x']:.1f}, y={position['y']:.1f}, z={position['z']:.1f}\n\n"

        observation += f"Equipment: {equipment}\n\n"

        if inventory:
            observation += f"Inventory ({inventory_used}/36): {inventory}\n\n"
        else:
            observation += f"Inventory ({inventory_used}/36): Empty\n\n"

        if not (
            task == "Place and deposit useless items into a chest"
            or task.startswith("Deposit useless items into the chest at")
        ):
            observation += self.render_chest_observation()

        if agentsChat:
            observation += f"Chat Extract: {agentsChat}\n\n"


        if critique:
            observation += f"Critique: {critique}\n\n"
        else:
            observation += f"Critique: None\n\n"



        perception_beliefs = observation

        observation += "$$$\n\n"
        # ----------------  Desires ----------------
        observation += f"Desire:\n\n"
        observation += f"-Task: {task}\n\n"

        # ----------------  Beliefs ----------------

        observation += f"Beliefs:\n\n"

        observation += f"-Task-related beliefs\n"

        if context:
            observation += f"--Personal:\n{context}\n\n"
        else:
            observation += "--Personal: None\n\n"

        interaction_beliefs = self.interaction_agent.get_partner_beliefs()
        observation += f"Interaction beliefs:\n"
        observation += f"{interaction_beliefs}\n\n"

        if memory_beliefs:
            observation += f"Memory:\n {memory_beliefs}\n"


        assert perception_beliefs != observation, "Perception beliefs must be different from action message"
        time.sleep(5)

        observation += f"Perception beliefs:\n{self.interaction_agent.get_beliefs_from_perception(perception=perception_beliefs)}\n"


        observation = colored(observation, self.interaction_agent.color)
        return HumanMessage(content=observation)

    def process_ai_message(self, message):
        assert isinstance(message, AIMessage)

        retry = 3
        error = None
        while retry > 0:
            try:
                babel = require("@babel/core")
                babel_generator = require("@babel/generator").default



                code_pattern = re.compile(r"(?:```(?:javascript|js)\s*(.*?)```)|(?:Code:\s*(async function.*?)(?=\n\n|$))", re.DOTALL)

                content = message.content
                # replace ... with empty string
                content = content.replace("...", "")

                code_matches = code_pattern.findall(content)

                # Process the matches to extract non-empty strings
                code = "\n".join(filter(None, sum(code_matches, ())))

                parsed = babel.parse(code)

                functions = []
                assert len(list(parsed.program.body)) > 0, "No functions found"
                for i, node in enumerate(parsed.program.body):
                    if node.type != "FunctionDeclaration":
                        continue
                    node_type = (
                        "AsyncFunctionDeclaration"
                        if node["async"]
                        else "FunctionDeclaration"
                    )
                    functions.append(
                        {
                            "name": node.id.name,
                            "type": node_type,
                            "body": babel_generator(node).code,
                            "params": list(node["params"]),
                        }
                    )
                # find the last async function
                main_function = None
                for function in reversed(functions):
                    if function["type"] == "AsyncFunctionDeclaration":
                        main_function = function
                        break
                assert (
                    main_function is not None
                ), "No async function found. Your main function must be async."
                assert (
                    len(main_function["params"]) == 1
                    and main_function["params"][0].name == "bot"
                ), f"Main function {main_function['name']} must take a single argument named 'bot'"
                program_code = "\n\n".join(function["body"] for function in functions)
                exec_code = f"await {main_function['name']}(bot);"
                return {
                    "program_code": program_code,
                    "program_name": main_function["name"],
                    "exec_code": exec_code,
                }
            except Exception as e:
                retry -= 1
                error = e
                time.sleep(1)



        return f"Error parsing action response (before program execution): {error}"

    def summarize_chatlog(self, events):
        def filter_item(message: str):
            craft_pattern = r"I cannot make \w+ because I need: (.*)"
            craft_pattern2 = (
                r"I cannot make \w+ because there is no crafting table nearby"
            )
            mine_pattern = r"I need at least a (.*) to mine \w+!"
            if re.match(craft_pattern, message):
                return re.match(craft_pattern, message).groups()[0]
            elif re.match(craft_pattern2, message):
                return "a nearby crafting table"
            elif re.match(mine_pattern, message):
                return re.match(mine_pattern, message).groups()[0]
            else:
                return ""

        chatlog = set()
        for event_type, event in events:
            if event_type == "onChat":
                item = filter_item(event["onChat"])
                if item:
                    chatlog.add(item)
        return "I also need " + ", ".join(chatlog) + "." if chatlog else ""
