import os
from pathlib import Path

from dotenv import load_dotenv

from lita.core.action import Action
from lita.core.llm import LLM
from lita.core.memory import Memory
from lita.core.protos import Message, Role
from lita.core.tools import Tool
from lita.core.utils import PrettyPrinter

# Populate os.environ from a .env file (python-dotenv) at import time, so settings
# such as MODEL (read in AgentBase.__init__) can be supplied that way.
load_dotenv()


class AgentBase:
    """Base agent that wires together an LLM, conversation memory, and tools.

    Component breakdown follows https://lilianweng.github.io/posts/2023-06-23-agent/
    (agent key components: LLM, Memory, Tools, Planning, Action).
    """

    # Core components, populated per instance in ``__init__``. ``tools`` and
    # ``action`` stay ``None`` when the agent is built without tools.
    llm: LLM | None = None
    memory: Memory | None = None
    tools: list[Tool] | None = None
    # Planning is a tool by current design
    action: Action | None = None

    # runtime settings
    run_id: str | None = None

    def __init__(
        self,
        model: str | None = None,
        system: str | None = None,
        tools: list[Tool] | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        run_id: str | None = None,
        verbose: bool = False,
        debug: bool = False,
        **kwargs,
    ):
        """Configure the agent.

        Args:
            model: Model name; falls back to the ``MODEL`` environment variable.
            system: Optional system prompt, stored as the first memory message.
            tools: Tools the LLM may call; also used to build the ``Action`` executor.
            max_tokens: Generation limit forwarded to every ``LLM.call``.
            temperature: Sampling temperature forwarded to every ``LLM.call``.
            run_id: Run identifier; used as the memory id and as the save filename.
            verbose: Print tool calls, tool outputs and final replies.
            debug: Like ``verbose``, and also call ``printer.prompt_continue()``
                after printing each tool call.
            **kwargs: ``max_num_turns`` (default 50), ``timeout`` (default 180) and
                ``max_retries`` (default 3) are read here; all kwargs are also
                forwarded unchanged to the ``LLM`` constructor.

        Raises:
            AssertionError: if no model name is given and ``MODEL`` is unset.
        """
        if not model:
            model = os.getenv("MODEL")
        assert model, "model name must be provided or set MODEL in environment variable"
        self.model = model

        self.max_tokens = max_tokens
        self.temperature = temperature

        self.max_num_turns = kwargs.get("max_num_turns", 50)
        self.timeout = kwargs.get("timeout", 180)
        self.max_retries = kwargs.get("max_retries", 3)

        if tools:
            self.tools = tools
            self.action = Action(tools=tools)

        self.llm = LLM(self.model, **kwargs)

        self.run_id = run_id
        self.memory = Memory(id=self.run_id)
        if system:
            self.memory.add_message(role=Role.SYSTEM, content=system)
        self.verbose = verbose
        self.debug = debug
        # The printer only exists when output was requested at construction time.
        self.printer = PrettyPrinter() if verbose or debug else None

    def call_llm_with_tools(self, prompt: str) -> None:
        """Process a prompt and handle tool calls in a loop.

        Appends ``prompt`` as a user message, then repeatedly calls the LLM with
        the full memory. Every tool call in a reply is executed through
        ``self.action`` and its output is stored as a tool message. Returns when
        the LLM replies without tool calls, when the ``finish`` tool is called,
        or after ``max_num_turns`` LLM calls.

        Raises:
            ValueError: if the LLM returns no result, or requests a tool call
                while the agent has no tools configured.
        """
        self.memory.add_message(role=Role.USER, content=prompt)

        for _ in range(self.max_num_turns):
            result: Message = self.llm.call(
                self.memory.get_messages(),
                [t.schema() for t in self.tools] if self.tools else None,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                timeout=self.timeout,
                max_retries=self.max_retries,
            )
            if result is None:
                raise ValueError("LLM call failed, no result returned")

            self.memory.add_message(result)
            if not result.tool_calls:
                # No tool calls, directly returns
                if self.verbose or self.debug:
                    self.printer.print_message(result.role, result.content)
                return

            for tool_call in result.tool_calls:
                if self.verbose or self.debug:
                    self.printer.print_tool_call(tool_call)
                    if self.debug:
                        self.printer.prompt_continue()
                name = tool_call.function.name
                arguments = tool_call.function.arguments
                if self.action is None:
                    # Fail with a clear message instead of AttributeError on None.
                    raise ValueError(f"LLM requested tool '{name}' but the agent has no tools configured")
                tool_call_output = self.action.execute(name=name, arguments=arguments)
                if name == "summarize":
                    # reset memory, only keep system and user messages in the head and add the summary content
                    self.memory.reset_summarization()
                self.memory.add_message(
                    role=Role.TOOL,
                    content=tool_call_output,
                    tool_call_id=tool_call.id,
                )
                if self.verbose or self.debug:
                    self.printer.print_message("tool output", tool_call_output)
                    self.printer.print_separator()

                if name == "finish":
                    return
        # NOTE(review): exhausting max_num_turns returns silently; callers cannot
        # distinguish this from a normal finish.

    def run(self, prompt: str) -> None:
        """Run the agent on ``prompt``.

        Raises:
            ValueError: wrapping any exception raised while processing. The
                original exception is kept as ``__cause__`` so its traceback
                is not lost.
        """
        try:
            self.call_llm_with_tools(prompt)
        except Exception as e:
            raise ValueError(f"Agent Sys Error: {e}") from e

    def save(self, save_dir: str) -> str:
        """Save the memory to ``<save_dir>/<run_id>.jsonl`` and return that path.

        ``save_dir`` is created (including parents) when missing. When
        ``run_id`` is ``None`` the file is named ``None.jsonl``.
        """
        directory = Path(save_dir)
        # exist_ok avoids the race between a separate exists() check and mkdir().
        directory.mkdir(parents=True, exist_ok=True)
        save_file = directory / f"{self.run_id}.jsonl"
        self.memory.save(save_file)

        return str(save_file)

    def save_tools_schema(self, save_file: str) -> str:
        """Write the JSON schemas of the agent's tools to ``save_file``.

        Writes an empty JSON list when the agent has no tools. Returns
        ``save_file`` as a string.
        """
        import json

        # Build the schemas before opening the file so a failing schema() call
        # does not leave a truncated file behind.
        tools_schema = [t.schema() for t in self.tools or []]
        with open(save_file, "w", encoding="utf-8") as f:
            json.dump(tools_schema, f, indent=2)
        return str(save_file)
