from llm import call_llm
from utils import load_json, save_json, parse_json
from prompts import IN_CONTEXT_MEMORY_UPDATE_PROMPT, get_in_context_hack_prompt
from loguru import logger
import os
import json
import shutil
from datetime import datetime
from copy import deepcopy
from backoff import on_exception, expo


class NaiveAgent:
    """Naive LLM agent that sends its whole message history on every call."""

    def __init__(self, llm_config):
        """Keep ``llm_config`` for later LLM calls and start with no history."""
        self.llm_config = llm_config
        self.reset()

    def reset(self):
        """Replace the message history with a fresh empty list."""
        self.msg_history = []

    def act(self, obs):
        """Record ``obs`` as a user turn, query the LLM, and return its reply.

        The reply is also recorded in ``msg_history`` as an assistant turn.
        """
        user_turn = {"role": "user", "content": obs}
        self.msg_history.append(user_turn)
        reply = call_llm(self.msg_history, self.llm_config)
        assistant_turn = {"role": "assistant", "content": reply}
        self.msg_history.append(assistant_turn)
        return reply

    def add_msgs(self, msgs):
        """Append already-formed message dicts to the history, in order."""
        self.msg_history += msgs

    def load_state(self, local_dir):
        """Replace the history with the contents of ``msg_history.json``."""
        history_path = os.path.join(local_dir, "msg_history.json")
        self.msg_history = load_json(history_path)

    def save_state(self, local_dir):
        """Write the history to ``msg_history.json`` in ``local_dir``.

        The directory is created when it does not exist yet.
        """
        os.makedirs(local_dir, exist_ok=True)
        history_path = os.path.join(local_dir, "msg_history.json")
        save_json(history_path, self.msg_history)

    def answer_question(self, question):
        """Ask ``question`` on top of the history without recording it.

        Returns whatever ``call_llm`` returns when invoked with
        ``return_token_usage=True``.
        """
        question_turn = {"role": "user", "content": question}
        prompt_msgs = self.msg_history + [question_turn]
        return call_llm(prompt_msgs, self.llm_config, return_token_usage=True)


class InContextMemAgent:
    """Agent that keeps labelled memories and injects them into the system prompt.

    Recent turns are held verbatim in ``local_msgs``. Once that buffer reaches
    ``update_bsz + local_length`` messages, the oldest ``update_bsz`` messages
    are removed from it and handed to an LLM call whose JSON answer is merged
    into ``in_context_memory`` (a list of ``label`` / ``value`` / ``timestamp``
    dicts).
    """

    def __init__(self, config):
        """Store ``config`` (by reference, not copied) and initialise state."""
        self.config = config
        self.reset()

    def reset(self):
        """Clear memories and buffered messages and re-select the update prompt.

        The prompt is re-derived from ``config["info_types"]``, so a prompt
        previously installed through ``set_prompts`` is replaced.
        """
        self.in_context_memory = []
        self.local_msgs = []
        info_types = self.config.get("info_types", None)
        if info_types is None:
            self.memory_update_prompt = IN_CONTEXT_MEMORY_UPDATE_PROMPT
        else:
            logger.info(f"Using in-context memory update prompt with info types: {info_types}")
            self.memory_update_prompt = get_in_context_hack_prompt(info_types)
            logger.debug(f"memory_update_prompt: {self.memory_update_prompt}")

    def _build_messages(self, content):
        """Return system prompt (with memories) + buffered turns + new user turn."""
        new_msg = {"role": "user", "content": content}
        sorted_memories = sorted(self.in_context_memory, key=lambda x: x["timestamp"])
        memories_str = "\n".join(
            f"- {entry['label']}: {entry['value']}" for entry in sorted_memories
        )
        system_prompt = (
            "You are a helpful AI. Respond according to memories of the user.\n"
            "User memories ordered by time (earliest to latest):\n"
            f"{memories_str}"
        )
        return [{"role": "system", "content": system_prompt}] + self.local_msgs + [new_msg]

    def act(self, obs):
        """Answer ``obs`` and record the exchange.

        The user turn and the LLM reply are passed to ``add_msgs``, which may
        trigger a memory update. Returns the LLM reply.
        """
        messages = self._build_messages(obs)
        response = call_llm(messages, self.config["llm_config"])
        new_response = {"role": "assistant", "content": response}
        # messages[-1] is the user turn created by _build_messages.
        self.add_msgs([messages[-1], new_response])
        return response

    def load_state(self, local_dir):
        """Load buffered messages and memories from JSON files in ``local_dir``."""
        self.local_msgs = load_json(os.path.join(local_dir, "msg_history.json"))
        self.in_context_memory = load_json(os.path.join(local_dir, "in_context_memory.json"))

    def save_state(self, local_dir):
        """Write buffered messages and memories to JSON files in ``local_dir``.

        The directory is created when it does not exist yet.
        """
        os.makedirs(local_dir, exist_ok=True)
        save_json(os.path.join(local_dir, "msg_history.json"), self.local_msgs)
        save_json(os.path.join(local_dir, "in_context_memory.json"), self.in_context_memory)

    def answer_question(self, question):
        """Answer ``question`` without recording it or touching the memories.

        Returns whatever ``call_llm`` returns when invoked with
        ``return_token_usage=True``.
        """
        messages = self._build_messages(question)
        return call_llm(messages, self.config["llm_config"], return_token_usage=True)

    def add_msgs(self, messages):
        """Buffer one user/assistant exchange and flush to memory when full.

        ``messages`` must contain exactly two items (enforced by ``assert``).
        When the buffer reaches ``update_bsz + local_length`` messages, the
        oldest ``update_bsz`` are removed and sent to ``_update_memory``.
        """
        assert len(messages) == 2, "Only support two-turn interactions in one batch"
        limit = self.config["agent_config"]["update_bsz"] + \
            self.config["agent_config"]["local_length"]
        self.local_msgs += messages
        if len(self.local_msgs) >= limit:
            # Move the oldest update_bsz messages out of the buffer into memory.
            update_bsz = self.config["agent_config"]["update_bsz"]
            msgs_to_insert, self.local_msgs = self.local_msgs[:update_bsz], self.local_msgs[update_bsz:]
            logger.trace(
                f"Inserting {len(msgs_to_insert)} messages into memory.\n{[msg for msg in msgs_to_insert if msg['role'] == 'user']}")
            self._update_memory(msgs_to_insert)

    def _update_memory(self, messages):
        """Ask the LLM for memory updates based on ``messages`` and merge them.

        The LLM answer is parsed with ``json.loads`` and used as a mapping of
        label to new value. Labels that already exist are overwritten in
        place; unknown labels are appended. Every touched entry receives the
        same millisecond-resolution local timestamp.
        """
        current_memories = {entry["label"]: entry["value"] for entry in self.in_context_memory}
        current_memories_str = json.dumps(current_memories, indent=2, ensure_ascii=False)
        conversation_str = json.dumps(messages, indent=2, ensure_ascii=False)
        memory_prompt = self.memory_update_prompt.format(
            current_memories=current_memories_str, conversation=conversation_str
        )
        memory_updates = call_llm(
            [{"role": "user", "content": memory_prompt}], self.config["llm_config"], json=True)
        memory_updates = json.loads(memory_updates)
        timestamp = datetime.now().strftime(
            '%Y-%m-%dT%H:%M:%S.%f')[:-3]  # ISO 8601 format, trimmed to milliseconds
        logger.trace(f"memory update {timestamp}: {memory_updates}")

        # Update existing entries; consumed labels are removed from the
        # update mapping so only genuinely new labels remain afterwards.
        for entry in self.in_context_memory:
            if entry["label"] in memory_updates:
                entry["value"] = memory_updates[entry["label"]]
                entry["timestamp"] = timestamp
                del memory_updates[entry["label"]]

        # Append the remaining labels as new entries.
        for label, value in memory_updates.items():
            assert label not in current_memories
            self.in_context_memory.append(
                {"label": label, "value": value, "timestamp": timestamp})

        logger.trace(f"updated memory {timestamp}: {self.in_context_memory}")

    def set_prompts(self, prompts):
        """Install a custom memory update prompt.

        Only the ``"memory_update_prompt"`` key of ``prompts`` is read; when it
        is absent nothing changes. The prompt is later rendered with
        ``str.format``, so lone ``{`` / ``}`` characters are doubled here while
        the ``{current_memories}`` and ``{conversation}`` placeholders are
        kept usable.

        NOTE: runs of two or more adjacent braces are treated as already
        escaped and left untouched, so literal nested braces such as the
        ``}}`` closing a nested JSON example are collapsed by ``str.format``.
        """
        if "memory_update_prompt" in prompts:
            prompt = prompts['memory_update_prompt']

            # Local import: re is only needed on this path.
            import re

            # First, protect our placeholder patterns
            prompt = prompt.replace('{current_memories}', '___PLACEHOLDER_CURRENT___')
            prompt = prompt.replace('{conversation}', '___PLACEHOLDER_CONVERSATION___')

            # Escape single braces that aren't already escaped:
            # a { or } with no identical brace directly before or after it.
            prompt = re.sub(r'(?<!\{)\{(?!\{)', '{{', prompt)
            prompt = re.sub(r'(?<!\})\}(?!\})', '}}', prompt)

            # Restore placeholders
            prompt = prompt.replace('___PLACEHOLDER_CURRENT___', '{current_memories}')
            prompt = prompt.replace('___PLACEHOLDER_CONVERSATION___', '{conversation}')

            self.memory_update_prompt = prompt


@on_exception(expo, Exception, max_tries=10)  # retry on failure
def insert_mem0(memory, batch, user_id, infer):
    """Add ``batch`` to ``memory`` for ``user_id`` and trace-log the result.

    ``infer`` is forwarded unchanged to ``memory.add``. Any ``Exception``
    triggers a retry through the ``on_exception`` decorator (``expo`` wait
    generator, up to 10 tries). Nothing is returned.
    """
    add_result = memory.add(batch, user_id=user_id, infer=infer)
    logger.trace(f"Memory log: {add_result}")


def format_mem0_memories(memories):
    """Render the entries of ``memories["results"]`` as a newline-joined bullet list.

    Entries are sorted ascending by the first 19 characters of their
    ``updated_at`` value or, when that is empty or ``None``, of their
    ``created_at`` value. Entries with equal keys keep their input order.
    Each entry becomes one ``- <memory>`` line.
    """
    def sort_key(item):
        # Prefer the update time; fall back to creation time when unset.
        stamp = item["updated_at"] or item["created_at"]
        return stamp[:19]

    ordered = sorted(memories["results"], key=sort_key)
    bullet_lines = (f"- {item['memory']}" for item in ordered)
    return "\n".join(bullet_lines)


class Mem0Agent:
    """Agent that retrieves memories from a mem0 ``Memory`` store.

    Recent turns are held verbatim in ``local_msgs``. Once that buffer reaches
    ``update_bsz + local_length`` messages, the oldest ``update_bsz`` messages
    are removed from it and inserted into the store under user id ``"USER"``.
    """

    # Text prepended to message content before insertion into the store.
    _ROLE_PREFIXES = {
        "user": "USER INPUT: ",
        "assistant": "ASSISTANT RESPONSE: ",
    }

    def __init__(self, config):
        """Create a fresh store under ``config["local_mem_dir"]``.

        ``config`` is deep-copied. Any existing directory at
        ``local_mem_dir`` is deleted and recreated, and the vector-store URL
        and history DB path in the copied config are pointed at ``mem.db`` and
        ``mem_hist.db`` inside it.
        """
        from mem0 import Memory
        self.config = deepcopy(config)
        mem_dir = self.config.get("local_mem_dir")
        if os.path.exists(mem_dir):
            shutil.rmtree(mem_dir)
        os.makedirs(mem_dir, exist_ok=True)
        self.config["memory_config"]["vector_store"]["config"]["url"] = os.path.join(
            mem_dir, "mem.db")
        self.config["memory_config"]["history_db_path"] = os.path.join(
            mem_dir, "mem_hist.db")
        self.memory = Memory.from_config(self.config["memory_config"])
        self.reset()

    def reset(self):
        """Call the store's ``reset()`` and empty the local message buffer."""
        self.memory.reset()
        self.local_msgs = []

    def _build_messages(self, content):
        """Return system prompt (with retrieved memories) + buffer + new user turn.

        Up to ``agent_config["top_k"]`` memories are searched for using
        ``content`` as the query.
        """
        new_msg = {"role": "user", "content": content}
        relevant_memories = self.memory.search(
            query=content,
            user_id="USER",
            limit=self.config["agent_config"]["top_k"]
        )
        memories_str = format_mem0_memories(relevant_memories)
        logger.trace(f"Retrieved memories: {memories_str}")
        system_prompt = (
            "You are a helpful AI. Respond according to retrieved memories.\n"
            "Relevant user memories ordered by time (earliest to latest):\n"
            f"{memories_str}"
        )
        return [{"role": "system", "content": system_prompt}] + self.local_msgs + [new_msg]

    def act(self, obs):
        """Answer ``obs`` and record the exchange.

        The user turn and the LLM reply are passed to ``add_msgs``, which may
        trigger an insertion into the store. Returns the LLM reply.
        """
        messages = self._build_messages(obs)
        response = call_llm(messages, self.config["llm_config"])
        new_response = {"role": "assistant", "content": response}

        # record current interaction; messages[-1] is the new user turn
        self.add_msgs(messages=[messages[-1], new_response])

        return response

    def load_state(self, local_dir):
        """Restore the store files and message buffer saved by ``save_state``.

        The current store object is dropped, ``local_mem_dir`` is wiped and
        recreated, ``mem.db`` and ``mem_hist.db`` are copied in from
        ``local_dir``, and the store is rebuilt from the config.
        """
        from mem0 import Memory
        del self.memory
        shutil.rmtree(self.config["local_mem_dir"])
        os.makedirs(self.config["local_mem_dir"])
        shutil.copy2(
            os.path.join(local_dir, "mem.db"),
            self.config["memory_config"]["vector_store"]["config"]["url"]
        )
        shutil.copy2(
            os.path.join(local_dir, "mem_hist.db"),
            self.config["memory_config"]["history_db_path"]
        )
        self.memory = Memory.from_config(self.config["memory_config"])
        # Explicit UTF-8: the file is written with ensure_ascii=False.
        with open(os.path.join(local_dir, "msg_history.json"), "r", encoding="utf-8") as f:
            self.local_msgs = json.load(f)

    def save_state(self, local_dir):
        """Copy the store files and dump the message buffer into ``local_dir``.

        The directory is created when it does not exist yet.
        """
        os.makedirs(local_dir, exist_ok=True)
        shutil.copy2(self.config["memory_config"]
                     ["vector_store"]["config"]["url"], local_dir)
        shutil.copy2(self.config["memory_config"]
                     ["history_db_path"], local_dir)
        # Explicit UTF-8 so non-ASCII content (ensure_ascii=False) does not
        # depend on the platform's default encoding.
        with open(os.path.join(local_dir, "msg_history.json"), "w", encoding="utf-8") as f:
            json.dump(self.local_msgs, f, indent=2, ensure_ascii=False)

    def answer_question(self, question):
        """Answer ``question`` without recording it or touching the store.

        Returns whatever ``call_llm`` returns when invoked with
        ``return_token_usage=True``.
        """
        messages = self._build_messages(question)
        return call_llm(messages, self.config["llm_config"], return_token_usage=True)

    def add_msgs(self, messages):
        """Buffer one user/assistant exchange and flush to the store when full.

        ``messages`` must contain exactly two items (enforced by ``assert``).
        When the buffer reaches ``update_bsz + local_length`` messages, the
        oldest ``update_bsz`` are removed, role-prefixed and inserted. The
        caller's message dicts are never modified.

        Raises:
            ValueError: if a flushed message has a role other than ``user``
                or ``assistant``.
        """
        assert len(messages) == 2, "Only support two-turn interactions in one batch"
        limit = self.config["agent_config"]["update_bsz"] + \
            self.config["agent_config"]["local_length"]
        infer = self.config["agent_config"]["enable_llm_mem_policy"]
        self.local_msgs += messages
        if len(self.local_msgs) >= limit:
            # update memory
            update_bsz = self.config["agent_config"]["update_bsz"]
            msgs_to_insert, self.local_msgs = self.local_msgs[:update_bsz], self.local_msgs[update_bsz:]
            logger.trace(
                f"Inserting {len(msgs_to_insert)} messages into memory.\n{[msg for msg in msgs_to_insert if msg['role'] == 'user']}")
            # Build prefixed copies instead of rewriting the dicts in place:
            # they may still be referenced by whoever passed them in.
            labelled_msgs = []
            for msg in msgs_to_insert:
                prefix = self._ROLE_PREFIXES.get(msg["role"])
                if prefix is None:
                    raise ValueError(f"Unknown message role: {msg['role']}")
                labelled_msgs.append({**msg, "content": prefix + msg["content"]})
            insert_mem0(self.memory, labelled_msgs, user_id="USER", infer=infer)


def create_agent(agent_config, output_dir, item=None):
    match agent_config["type"]:
        case "naive":
            return NaiveAgent(agent_config["llm_config"])
        case "in-context":
            if item is not None:
                agent_config["info_types"] = list(item["state_schema"].keys())
            return InContextMemAgent(agent_config)
        case "mem0":
            local_mem_dir = os.path.join(output_dir, "latest_memories")
            return Mem0Agent(agent_config | {"local_mem_dir": local_mem_dir})
        case _:
            raise ValueError(f"Unknown agent type: {agent_config['type']}")
