from copy import deepcopy
from typing import List, Optional

from loguru import logger
from pydantic import BaseModel

from vita.agent.base import (
    LocalAgent,
    ValidAgentInputMessage,
    is_valid_agent_history_message,
)
from vita.data_model.message import (
    APICompatibleMessage,
    AssistantMessage,
    Message,
    MultiToolMessage,
    SystemMessage,
    UserMessage,
)
from vita.data_model.tasks import Task
from vita.environment.tool import Tool, as_tool
from vita.utils.llm_utils import generate
from vita.utils.utils import get_now, get_weekday
from vita.memory import MemoryManager


class LLMAgentState(BaseModel):
    """Per-conversation state threaded through the LLM agents' turns.

    Attributes:
        system_messages: System message(s) placed before the history when
            the LLM is called.
        messages: Message history so far (user, assistant and tool messages).
    """

    # System prompt message(s); built once in get_init_state.
    system_messages: List[SystemMessage]
    # Conversation history, appended to on every turn.
    messages: List[APICompatibleMessage]


class LLMAgent(LocalAgent[LLMAgentState]):
    """
    An LLM agent that can be used to solve a task.

    When ``enable_memory`` is True the agent keeps two views of the dialogue:
    the full history in ``LLMAgentState.messages`` and a sliding window
    ``window_messages``, which is what is actually sent to the LLM. Once the
    window holds ``MEMORY_WINDOW_TRIGGER`` messages, its oldest
    ``MEMORY_CHUNK_SIZE`` messages are summarised by a ``MemoryManager`` and
    trimmed from the window. The resulting summary is appended to the system
    prompt on that turn and on every later turn, so compressed context is not
    lost.
    """

    # Window length at which the oldest messages are compressed into memory.
    MEMORY_WINDOW_TRIGGER = 20
    # Number of oldest window messages summarised per compression step.
    MEMORY_CHUNK_SIZE = 10
    # Initial memory text ("no historical memory") passed to the first extraction.
    NO_MEMORY_PLACEHOLDER = '无历史记忆'

    def __init__(
        self,
        tools: List[Tool],
        domain_policy: str,
        llm: Optional[str] = None,
        llm_args: Optional[dict] = None,
        time=None,
        enable_think: bool = False,
        language: Optional[str] = None,
        enable_memory: bool = True,
        memory_llm: Optional[str] = None,
        memory_llm_args: Optional[dict] = None
    ):
        """
        Initialize the LLMAgent.

        Args:
            tools: Tools the LLM may call.
            domain_policy: System prompt template, formatted with a ``time`` field.
            llm: Model name passed to ``generate``.
            llm_args: Extra keyword arguments for ``generate`` (deep-copied).
            time: Time string shown in the system prompt. When None, the
                current wall-clock time is used each time the prompt is built.
            enable_think: Forwarded to ``generate``.
            language: Passed to ``get_weekday`` together with ``time``.
            enable_memory: Whether to compress old turns into a memory summary.
            memory_llm: Model used for memory extraction; defaults to ``llm``.
            memory_llm_args: Arguments for memory extraction; defaults to ``llm_args``.
        """
        super().__init__(tools=tools, domain_policy=domain_policy)
        self.llm = llm
        self.llm_args = deepcopy(llm_args) if llm_args is not None else {}
        # ``time`` is optional (system_prompt falls back to get_now()), so only
        # build the "time weekday" string when a time was actually supplied.
        self.time = (time + " " + get_weekday(time, language)) if time is not None else None
        self.enable_think = enable_think
        self.enable_memory = enable_memory
        if self.enable_memory:
            # Use memory_llm and memory_llm_args if provided, otherwise fall back
            # to the agent's own llm and llm_args.
            memory_llm_to_use = memory_llm if memory_llm is not None else self.llm
            memory_llm_args_to_use = memory_llm_args if memory_llm_args is not None else self.llm_args
            self.memory_manager = MemoryManager(llm=memory_llm_to_use, llm_args=memory_llm_args_to_use)
            self.previous_memory = self.NO_MEMORY_PLACEHOLDER
            self.window_messages = []  # Messages actually sent to the LLM.
            # True once at least one compression produced a summary.
            self._has_memory_summary = False

    @property
    def system_prompt(self) -> str:
        """Render ``domain_policy`` with the fixed time, or the current time if none was given."""
        if self.time is not None:
            return self.domain_policy.format(
                time=self.time
            )
        return self.domain_policy.format(
            time=get_now("%Y-%m-%d %H:%M:%S")
        )

    def get_init_state(
        self, message_history: Optional[list[Message]] = None
    ) -> LLMAgentState:
        """Get the initial state of the agent.

        Args:
            message_history: The message history of the conversation.

        Returns:
            The initial state of the agent.

        Raises:
            AssertionError: If the history contains a message type the agent
                does not accept.
        """
        if message_history is None:
            message_history = []
        assert all(is_valid_agent_history_message(m) for m in message_history), (
            "Message history must contain only AssistantMessage, UserMessage, or ToolMessage to Agent."
        )

        if self.enable_memory:
            # The window restarts from the full history, so any summary kept
            # from a previous run would duplicate or leak old context: reset it.
            self.window_messages = message_history.copy()
            self.previous_memory = self.NO_MEMORY_PLACEHOLDER
            self._has_memory_summary = False

        return LLMAgentState(
            system_messages=[SystemMessage(role="system", content=self.system_prompt)],
            messages=message_history,
        )

    def generate_next_message(
        self, message: ValidAgentInputMessage, state: LLMAgentState
    ) -> tuple[AssistantMessage, LLMAgentState]:
        """
        Respond to a user or tool message.

        Args:
            message: The incoming message; a ``MultiToolMessage`` is unpacked
                into its individual tool messages.
            state: The agent state; its ``messages`` list is appended to in place.

        Returns:
            The generated assistant message and the (same) updated state.
        """
        incoming = (
            message.tool_messages if isinstance(message, MultiToolMessage) else [message]
        )
        state.messages.extend(incoming)

        if self.enable_memory:
            self.window_messages.extend(incoming)
            if len(self.window_messages) >= self.MEMORY_WINDOW_TRIGGER:
                self._compress_window()
            conversation = self.window_messages
        else:
            conversation = state.messages

        assistant_message = generate(
            model=self.llm,
            tools=self.tools,
            messages=self._build_system_messages(state) + conversation,
            enable_think=self.enable_think,
            **self.llm_args,
        )
        state.messages.append(assistant_message)
        if self.enable_memory:
            self.window_messages.append(assistant_message)
        return assistant_message, state

    def _compress_window(self) -> None:
        """Summarise the oldest window messages into ``previous_memory`` and trim the window."""
        chunk = self.window_messages[:self.MEMORY_CHUNK_SIZE]
        dialogue_chunk = self._create_dialogue_chunk(chunk)

        memory_manager = getattr(self, "memory_manager", None)
        if memory_manager:
            # The new summary folds the chunk into the previous summary.
            self.previous_memory = memory_manager.extract_memory(
                chunk=dialogue_chunk,
                previous_memory=self.previous_memory,
            )
            self._has_memory_summary = True
            logger.debug("updated_memory: {}", self.previous_memory)

        # Drop the compressed messages, but restart the window at the last
        # user/assistant message inside the chunk so that the window begins with
        # a user or assistant message rather than a bare tool result.
        # NOTE(review): if that message is at index 0, nothing is dropped and the
        # same chunk is compressed again on the next turn.
        keep_from_index = 0
        for i in range(len(chunk) - 1, -1, -1):
            if getattr(self.window_messages[i], "role", None) in ("user", "assistant"):
                keep_from_index = i
                break
        self.window_messages = self.window_messages[keep_from_index:]
        logger.debug("window_messages长度: {}", len(self.window_messages))

    def _build_system_messages(self, state: LLMAgentState) -> list[SystemMessage]:
        """Return this turn's system messages, with the memory summary appended once one exists."""
        if not (self.enable_memory and self._has_memory_summary):
            return state.system_messages
        enhanced_system_content = state.system_messages[0].content + f"\n\n# 历史对话信息: {self.previous_memory}"
        return [SystemMessage(role="system", content=enhanced_system_content)]

    def _create_dialogue_chunk(self, messages) -> str:
        """
        Create a dialogue chunk from a list of messages for memory compression.

        Only user, assistant and tool messages with non-empty ``content`` are
        included, so assistant messages that carry only tool calls are skipped.

        Args:
            messages: List of messages to convert to dialogue format.

        Returns:
            Formatted dialogue string, one "<speaker>: <content>" line per message.
        """
        dialogue_lines = []
        for msg in messages:
            if hasattr(msg, 'role') and hasattr(msg, 'content') and msg.content:
                if msg.role == 'user':
                    dialogue_lines.append(f"用户: {msg.content}")
                elif msg.role == 'assistant':
                    dialogue_lines.append(f"助理: {msg.content}")
                elif msg.role == 'tool':
                    dialogue_lines.append(f"工具: {msg.content}")

        return "\n".join(dialogue_lines)

    def set_seed(self, seed: int):
        """Set the seed for the LLM.

        Raises:
            ValueError: If no LLM is configured.
        """
        if self.llm is None:
            raise ValueError("LLM is not set")
        cur_seed = self.llm_args.get("seed", None)
        if cur_seed is not None:
            logger.warning(f"Seed is already set to {cur_seed}, resetting it to {seed}")
        self.llm_args["seed"] = seed


# System prompt template for LLMSoloAgent (the text is in Chinese).
# It gives the model the current time and explains how to use tools. It also
# requires the model to:
#   - finish the whole task in one pass, in the order the task describes;
#   - never interact with the user;
#   - treat any step that needs user confirmation as already confirmed;
#   - emit '###STOP###' once everything is done.
# LLMSoloAgent.system_prompt renders this with str.format(time=...), so any
# literal braces added to the text must be doubled ("{{" / "}}").
SYSTEM_PROMPT_SOLO: str = """
# 环境
- 当前时间：{time}

# 工具使用规范
- 根据任务需求和提供的信息，确定需要调用的工具及参数
- 按照逻辑顺序执行必要的工具调用来完成任务
- 参考Precondition和Postcondition确保任务正确完成

# 任务要求
- 你需要根据提供的完整任务描述和用户信息，按照顺序一次性完成用户的需求
- 任务中涉及到的先下单后取消订单、先下单后修改订单等操作，请严格按照任务描述中的要求顺序执行
- 所有必要的信息都已在任务描述中提供，包括用户偏好、约束条件等
- 执行过程中不能与用户进行交互
- 默认对于需要用户确认的逻辑，都认为用户已经确认
- 在完成用户所有的需求以后，生成 '###STOP###' 标记来结束对话
""".strip()


class LLMSoloAgent(LocalAgent[LLMAgentState]):
    """
    An LLM agent that can be used to solve a task without any interaction with the customer.
    The task needs to specify a ticket format.

    Every generated message must either contain tool calls or the
    ``STOP_TOKEN`` marker; anything else is rejected with a ``ValueError``.
    """

    # End-of-task marker. It must match the marker SYSTEM_PROMPT_SOLO asks the
    # model to emit. Defined here so that is_stop() does not depend on a base
    # class attribute.
    STOP_TOKEN = "###STOP###"

    def __init__(
        self,
        tools: List[Tool],
        domain_policy: str,
        llm: Optional[str] = None,
        llm_args: Optional[dict] = None,
        time=None,
        enable_think: bool = False,
        language: Optional[str] = None
    ):
        """
        Initialize the LLMSoloAgent.

        Args:
            tools: Tools the LLM may call.
            domain_policy: Stored by the base class (not used in the solo prompt).
            llm: Model name passed to ``generate``.
            llm_args: Extra keyword arguments for ``generate`` (deep-copied).
            time: Time string shown in the system prompt. When None, the
                current wall-clock time is used each time the prompt is built.
            enable_think: Forwarded to ``generate``.
            language: Passed to ``get_weekday`` together with ``time``.
        """
        super().__init__(tools=tools, domain_policy=domain_policy)
        self.llm = llm
        self.llm_args = deepcopy(llm_args) if llm_args is not None else {}
        # ``time`` is optional (system_prompt falls back to get_now()), so only
        # build the "time weekday" string when a time was actually supplied.
        self.time = (time + " " + get_weekday(time, language)) if time is not None else None
        self.enable_think = enable_think

    @property
    def system_prompt(self) -> str:
        """Render SYSTEM_PROMPT_SOLO with the fixed time, or the current time if none was given."""
        # NOTE(review): domain_policy is not part of this prompt; confirm that is intended.
        if self.time is not None:
            return SYSTEM_PROMPT_SOLO.format(
                time=self.time
            )
        return SYSTEM_PROMPT_SOLO.format(
            time=get_now("%Y-%m-%d %H:%M:%S")
        )

    @classmethod
    def is_stop(cls, message: AssistantMessage) -> bool:
        """Check whether the message contains the stop marker."""
        if message.content is None:
            return False
        return cls.STOP_TOKEN in message.content

    def get_init_state(
        self, message_history: Optional[list[Message]] = None
    ) -> LLMAgentState:
        """Get the initial state of the agent.

        Args:
            message_history: The message history of the conversation.

        Returns:
            The initial state of the agent.

        Raises:
            AssertionError: If the history contains a message type the agent
                does not accept.
        """
        if message_history is None:
            message_history = []
        assert all(is_valid_agent_history_message(m) for m in message_history), (
            "Message history must contain only AssistantMessage, UserMessage, or ToolMessage to Agent."
        )
        return LLMAgentState(
            system_messages=[SystemMessage(role="system", content=self.system_prompt)],
            messages=message_history,
        )

    def generate_next_message(
        self, message: Optional[ValidAgentInputMessage], state: LLMAgentState
    ) -> tuple[AssistantMessage, LLMAgentState]:
        """
        Respond to a user or tool message.

        Args:
            message: The incoming message. It may be None only for the very
                first turn, when the history is still empty.
            state: The agent state; its ``messages`` list is appended to in place.

        Returns:
            The generated assistant message and the (same) updated state.

        Raises:
            ValueError: If the model replies with neither tool calls nor the stop marker.
        """
        if isinstance(message, MultiToolMessage):
            state.messages.extend(message.tool_messages)
        elif message is None:
            assert len(state.messages) == 0, "Message history should be empty"
        else:
            state.messages.append(message)
        messages = state.system_messages + state.messages
        assistant_message = generate(
            model=self.llm,
            tools=self.tools,
            messages=messages,
            tool_choice="auto",
            enable_think=self.enable_think,
            **self.llm_args,
        )
        if not assistant_message.is_tool_call() and not self.is_stop(assistant_message):
            raise ValueError("LLMSoloAgent only supports tool calls before ###STOP###.")
        state.messages.append(assistant_message)
        return assistant_message, state

    def set_seed(self, seed: int):
        """Set the seed for the LLM.

        Raises:
            ValueError: If no LLM is configured.
        """
        if self.llm is None:
            raise ValueError("LLM is not set")
        cur_seed = self.llm_args.get("seed", None)
        if cur_seed is not None:
            logger.warning(f"Seed is already set to {cur_seed}, resetting it to {seed}")
        self.llm_args["seed"] = seed