# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Ref: https://python.langchain.com/docs/how_to/custom_chat_model/
"""

import asyncio
import json
import logging
import os
import uuid
from typing import Any, Optional

from langchain_core.language_models import BaseChatModel
from langchain_core.language_models.base import LanguageModelInput
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    convert_to_openai_messages,
)
from langchain_core.messages.tool import InvalidToolCall, ToolCall
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.tools import StructuredTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import Field

from verl.experimental.agent_loop.agent_loop import AgentLoopOutput, AsyncLLMServerManager
from verl.experimental.agent_loop.tool_parser import ToolParser

# Module logger, named after this file's path; its level is taken from the
# VERL_LOGGING_LEVEL environment variable and falls back to "WARN" when unset.
logger = logging.getLogger(__file__)
logger.setLevel(os.environ.get("VERL_LOGGING_LEVEL", "WARN"))


class MaxTokenExceededError(Exception):
    """Raised when the chat history plus the new tool/human message no longer fits in ``max_tokens``."""


class ChatModel(BaseChatModel):
    """LangChain chat model that generates through an ``AsyncLLMServerManager``.

    Generation is token-in/token-out: ``_preprocess`` turns the message list into
    token ids, ``client.generate`` produces response token ids, and ``_postprocess``
    decodes them into an ``AIMessage``. The running trajectory (``request_id``,
    ``prompt_ids``, ``response_mask``) is carried from turn to turn inside
    ``AIMessage.response_metadata``.
    """

    model_name: str = Field(alias="model")
    """The name of the model"""

    client: AsyncLLMServerManager
    """AsyncLLM server manager"""

    tokenizer: Any
    """Tokenizer for the model"""

    max_tokens: int
    """Max tokens to generate"""

    tool_parser: str = "hermes"
    """Tool parser for the model"""

    max_parallel_calls: int = 1
    """Max parallel tool calls"""

    temperature: float = 1.0
    """Temperature for sampling"""

    top_p: float = 1.0
    """Top p for sampling"""

    repetition_penalty: float = 1.0
    """Repetition penalty for sampling"""

    def bind_tools(self, tools, **kwargs) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tools to the model.

        Besides the OpenAI-formatted tools, the token ids of the chat template's
        prefix for an empty message are bound as ``system_prompt``; ``_preprocess``
        uses their length to strip that prefix from encoded tool/human responses.

        Args:
            tools: Sequence of tools to bind to the model.
            **kwargs: Extra keyword arguments forwarded to ``bind``.

        Returns:
            A Runnable that returns a message.
        """
        formatted_tools: list = [convert_to_openai_tool(tool) for tool in tools]

        # used to remove system prompt prefix when encoding tool response
        system_prompt = self.tokenizer.apply_chat_template([{}], add_generation_prompt=False, tokenize=True)
        kwargs["system_prompt"] = system_prompt

        return self.bind(tools=formatted_tools, **kwargs)

    def with_structured_output(
        self,
        schema: dict | type,
        *,
        include_raw: bool = False,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, dict | BaseChatModel]:
        """Structured output is not supported by this model.

        Ref: https://langchain-ai.github.io/langgraph/how-tos/react-agent-structured-output/

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Synchronous generation is not supported; use ``_agenerate``.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    async def _agenerate(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronously generate chat completion message.

        Args:
            messages (list[BaseMessage]): List of messages.
            stop (Optional[list[str]], optional): Stop words. Accepted for interface compatibility; this
                implementation does not forward it to the server. Defaults to None.
            **kwargs: Bound arguments such as ``tools`` and ``system_prompt``. An optional ``sampling_params``
                dict overrides the model's default sampling parameters for this call.

        Returns:
            ChatResult: Chat result holding a single generation.

        Raises:
            MaxTokenExceededError: Propagated from ``_preprocess`` when the trajectory exceeds ``max_tokens``.
        """
        request_id, prompt_ids, response_mask = await self._preprocess(messages, **kwargs)

        sampling_params = {
            "temperature": self.temperature,
            "top_p": self.top_p,
            "repetition_penalty": self.repetition_penalty,
        }
        if "sampling_params" in kwargs:
            sampling_params.update(kwargs["sampling_params"])

        response_ids = await self.client.generate(
            request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params
        )

        message = await self._postprocess(request_id, prompt_ids, response_mask, response_ids, **kwargs)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this chat model."""
        return self.model_name

    async def _preprocess(self, messages: list[BaseMessage], **kwargs: Any) -> tuple[str, list[int], list[int]]:
        """Preprocess messages for chat completion.

        To ensure strong consistency with policy model, AsyncLLM server generate response with token in token out
        instead of messages list.

        But all agent frameworks use messages list to represent chat history. To mitigate the gap, we store trajectory
        (prompt_ids, response_mask) in latest AIMessage.response_metadata.

        1. Encode ToolMessage to token ids.
        2. Retrieve trajectory (prompt_ids, response_mask) from latest AIMessage.response_metadata.
        3. Append ToolMessage token ids to prompt_ids, and append 0 to response_mask.

        Ref: https://python.langchain.com/docs/concepts/chat_history/

        Args:
            messages (list[BaseMessage]): List of messages.
            **kwargs: Bound arguments; ``tools`` is read for the initial prompt and ``system_prompt``
                (set by ``bind_tools``) is required for follow-up turns.

        Returns:
            tuple[str, list[int], list[int]]: Request id, prompt ids, response mask.

        Raises:
            MaxTokenExceededError: If the response mask plus the newly encoded messages reaches ``max_tokens``.
        """
        # messages: [system], human, ai, human|tool, ai, human|tool, ...
        assert messages[-1].type in ["human", "tool"], (
            f"Last message must be human or tool, but got {messages[-1].type}"
        )
        loop = asyncio.get_running_loop()

        # Case 1: initial chat completion: [system], human
        if messages[-1].type == "human" and (len(messages) == 1 or messages[-2].type != "ai"):
            # tokenization runs in the default executor so it does not block the event loop
            prompt_ids = await loop.run_in_executor(
                None,
                lambda: self.tokenizer.apply_chat_template(
                    convert_to_openai_messages(messages),
                    tools=kwargs.get("tools"),
                    add_generation_prompt=True,
                    tokenize=True,
                ),
            )
            # a fresh request id starts a new trajectory with an empty response mask
            return str(uuid.uuid4()), prompt_ids, []

        # Case 2: follow up chat completion with tool/human response: [system], human, ai, human|tool, ...
        # walk backwards to the latest AIMessage, which carries the trajectory so far
        for i in range(len(messages) - 1, -1, -1):
            if messages[i].type == "ai":
                break
        assert "prompt_ids" in messages[i].response_metadata, "Last message must have prompt_ids in response_metadata"
        assert "response_mask" in messages[i].response_metadata, (
            "Last message must have response_mask in response_metadata"
        )

        # encode tool response
        tool_responses = convert_to_openai_messages(messages[i + 1 :])
        tool_response_ids = await loop.run_in_executor(
            None,
            lambda messages=tool_responses: self.tokenizer.apply_chat_template(
                messages, add_generation_prompt=True, tokenize=True
            ),
        )
        # drop the template prefix whose length was measured in bind_tools
        tool_response_ids = tool_response_ids[len(kwargs["system_prompt"]) :]

        # stop generation if response length exceeds max response length
        if len(messages[i].response_metadata["response_mask"]) + len(tool_response_ids) >= self.max_tokens:
            raise MaxTokenExceededError(f"Max response length {self.max_tokens} exceeded")

        # append tool response to prompt
        # the trajectory is popped (moved) out of the old AIMessage; _postprocess stores it on the new one
        request_id = messages[i].response_metadata.pop("request_id")
        prompt_ids = messages[i].response_metadata.pop("prompt_ids")
        response_mask = messages[i].response_metadata.pop("response_mask")
        prompt_ids += tool_response_ids
        response_mask += [0] * len(tool_response_ids)

        return request_id, prompt_ids, response_mask

    async def _postprocess(
        self, request_id: str, prompt_ids: list[int], response_mask: list[int], response_ids: list[int], **kwargs: Any
    ) -> AIMessage:
        """Postprocess response_ids when chat completion is done.

        1. Decode response_ids, parse tool calls to AIMessage.
        2. Append response_ids to prompt_ids, and append 1 to response_mask.
        3. Store trajectory (prompt_ids, response_mask) in AIMessage.response_metadata.

        Tool calls whose arguments are not valid JSON, or are valid JSON but not a JSON object, are reported
        as ``invalid_tool_calls`` instead of raising.

        Args:
            request_id (str): Unique request id.
            prompt_ids (list[int]): Input prompt token ids in this chat completion. Extended in place.
            response_mask (list[int]): Response mask before this chat completion. Extended in place.
            response_ids (list[int]): LLM generated token ids in this chat completion.

        Returns:
            AIMessage: Postprocessed message.
        """
        prompt_ids += response_ids
        response_mask += [1] * len(response_ids)

        tool_parser = ToolParser.get_tool_parser(self.tool_parser, self.tokenizer)
        content, function_calls = await tool_parser.extract_tool_calls(response_ids)

        tool_calls, invalid_tool_calls = [], []
        for function_call in function_calls:
            error = None
            try:
                args = json.loads(function_call.arguments)
            except json.JSONDecodeError as e:
                error = f"Invalid json tool arguments: {e}"
            else:
                # Valid JSON that is not an object (list, number, string, ...) cannot be used as tool
                # kwargs. json.JSONDecodeError needs (msg, doc, pos), so it is not raised by hand here.
                if not isinstance(args, dict):
                    error = f"Invalid json tool arguments: expected a JSON object, got {args!r}"

            if error is None:
                tool_calls.append(
                    ToolCall(
                        args=args,
                        name=function_call.name,
                        id=str(uuid.uuid4()),
                    )
                )
                continue

            logger.warning(error)
            invalid_tool_calls.append(
                InvalidToolCall(
                    args=function_call.arguments,
                    name=function_call.name,
                    # InvalidToolCall declares an `id` key; supply one like for valid calls
                    id=str(uuid.uuid4()),
                    error=error,
                )
            )

        message = AIMessage(
            content=content,
            tool_calls=tool_calls[: self.max_parallel_calls],
            invalid_tool_calls=invalid_tool_calls[: self.max_parallel_calls],
            response_metadata={
                "request_id": request_id,
                "prompt_ids": prompt_ids,
                "response_mask": response_mask,
            },
        )
        return message


class TruncateStructuredTool(StructuredTool):
    """Structured tool with response truncation.

    The tool result is converted to ``str`` and, when longer than
    ``max_tool_response_length``, cut down and tagged with a ``(truncated)`` marker.
    """

    tool_response_truncate_side: str
    """truncate side of tool response: left, middle, right"""

    max_tool_response_length: int
    """max length of tool response"""

    async def _arun(
        self,
        *args: Any,
        config: RunnableConfig,
        **kwargs: Any,
    ) -> Any:
        """Run the wrapped tool and truncate its stringified response.

        Truncation modes, selected by ``tool_response_truncate_side``:
          - ``"left"``: keep the first ``max_tool_response_length`` characters.
          - ``"right"``: keep the last ``max_tool_response_length`` characters.
          - anything else: keep ``max_tool_response_length // 2`` characters from each end.

        Args:
            *args: Positional arguments forwarded to the parent ``_arun``.
            config (RunnableConfig): Runnable config forwarded to the parent ``_arun``.
            **kwargs: Keyword arguments forwarded to the parent ``_arun``.

        Returns:
            The tool response as a string, truncated if it exceeded the limit.
        """
        tool_response = str(await super()._arun(*args, config=config, **kwargs))

        limit = self.max_tool_response_length
        total = len(tool_response)
        if total <= limit:
            return tool_response

        side = self.tool_response_truncate_side
        if side == "left":
            return tool_response[:limit] + "...(truncated)"

        # Tails are sliced with an explicit start index: `s[-n:]` returns the whole
        # string when n == 0, which would defeat truncation for tiny limits.
        if side == "right":
            return "(truncated)..." + tool_response[total - limit :]

        half = limit // 2
        return tool_response[:half] + "...(truncated)..." + tool_response[total - half :]


def convert_to_agent_output(messages: list[BaseMessage], response_length: int) -> AgentLoopOutput:
    """Convert messages to AgentLoopOutput.

    The trajectory is read from the last non-tool message, which must be an assistant
    message. Its ``prompt_ids`` hold prompt and response tokens concatenated; the last
    ``len(response_mask)`` tokens are the response part.

    Args:
        messages (List[BaseMessage]): List of messages, last message (ignoring trailing tool
            messages) must be assistant with response_metadata containing `prompt_ids`
            and `response_mask`.
        response_length (int): Max length of response; response ids and mask are cut to it.

    Returns:
        AgentLoopOutput: agent loop output trajectory used for training.

    Raises:
        AssertionError: If ``messages`` is empty, the last non-tool message is not an
            assistant message, or the trajectory metadata is missing.
    """
    assert messages, "messages must not be empty"

    # skip last tool calls; fall back to the first message when every message is a tool message
    last_message = next((message for message in reversed(messages) if message.type != "tool"), messages[0])
    assert last_message.type == "ai", f"Last message must be assistant, but got {last_message.type}"
    assert "prompt_ids" in last_message.response_metadata, "Last message must have prompt_ids in response_metadata"
    assert "response_mask" in last_message.response_metadata, (
        "Last message must have response_mask in response_metadata"
    )

    num_turns = 0
    for idx, message in enumerate(messages):
        if message.type == "system":
            continue
        # parallel tool calls are in single turn
        if idx == 0 or message.type != messages[idx - 1].type:
            num_turns += 1

    prompt_ids = last_message.response_metadata["prompt_ids"]
    response_mask = last_message.response_metadata["response_mask"]

    # Split by an explicit index: `prompt_ids[-0:]` would return the whole sequence
    # as the response when the response mask is empty.
    split = len(prompt_ids) - len(response_mask)
    response_ids = prompt_ids[split:]
    prompt_ids = prompt_ids[:split]

    output = AgentLoopOutput(
        prompt_ids=prompt_ids,
        response_ids=response_ids[:response_length],
        response_mask=response_mask[:response_length],
        num_turns=num_turns,
        metrics={},
    )
    return output
