import asyncio
import json
import logging
import os
import uuid
from typing import Any, Optional
from langchain_core.language_models import BaseChatModel
from langchain_core.language_models.base import LanguageModelInput
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    convert_to_openai_messages,
)
from langchain_core.messages.tool import InvalidToolCall, ToolCall
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.tools import StructuredTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import Field
from verl.experimental.agent_loop.agent_loop import AgentLoopOutput, AsyncLLMServerManager
from verl.experimental.agent_loop.tool_parser import ToolParser
# Module logger, named after this file's path; its level comes from VERL_LOGGING_LEVEL ("WARN" if unset).
logger = logging.getLogger(__file__)
logger.setLevel(os.environ.get("VERL_LOGGING_LEVEL", "WARN"))

# Raw value of the TOOL_FREE environment variable, or False when it is unset.
# NOTE(review): when the variable is set this is a *string*, so values such as "0" or "false" are still truthy.
tool_free = os.environ.get("TOOL_FREE", False)
class MaxTokenExceededError(Exception):
    """Signal that a conversation's accumulated response tokens reached the configured maximum."""
class ChatModel(BaseChatModel):
    """LangChain chat model that generates token ids through an ``AsyncLLMServerManager``.

    Every ``AIMessage`` produced by this model stores ``request_id``, ``prompt_ids`` and
    ``response_mask`` in its ``response_metadata``. On the next call those values are taken
    back out of the most recent AI message and extended with the newly tokenized messages,
    so a conversation keeps growing one token sequence instead of being re-tokenized.
    """

    # Model identifier; populated through the ``model`` alias and reported by ``_llm_type``.
    model_name: str = Field(alias="model")
    # Server manager whose ``generate`` coroutine produces the response token ids.
    client: AsyncLLMServerManager
    # Tokenizer; must provide ``apply_chat_template`` and is also handed to the tool parser.
    tokenizer: Any
    # Budget compared against (response mask length + new tool-response tokens) in ``_preprocess``.
    max_tokens: int
    # Name of the tool parser looked up via ``ToolParser.get_tool_parser``.
    tool_parser: str = "hermes"
    # Upper bound on the valid and on the invalid tool calls kept per generated message.
    max_parallel_calls: int = 1
    # Default sampling parameters; overridable per call through ``kwargs["sampling_params"]``.
    temperature: float = 1.0
    top_p: float = 1.0
    repetition_penalty: float = 1.0

    def bind_tools(self, tools, **kwargs) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tools (converted to the OpenAI tool format) to this model.

        Also tokenizes the chat template for a single empty message and binds the result as
        ``system_prompt``; ``_preprocess`` strips that many leading tokens from tokenized
        follow-up messages. Presumably this is the prefix the template always emits
        (e.g. a default system prompt) -- confirm against the tokenizer's template.

        Args:
            tools: Tool definitions accepted by ``convert_to_openai_tool``.
            **kwargs: Extra keyword arguments forwarded to ``bind``.

        Returns:
            A runnable with ``tools`` and ``system_prompt`` bound.
        """
        formatted_tools: list = [convert_to_openai_tool(tool) for tool in tools]
        system_prompt = self.tokenizer.apply_chat_template([{}], add_generation_prompt=False, tokenize=True)
        kwargs["system_prompt"] = system_prompt
        return self.bind(tools=formatted_tools, **kwargs)

    def with_structured_output(
        self,
        schema: dict | type,
        *,
        include_raw: bool = False,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, dict | BaseChatModel]:
        """Structured output is not supported by this model.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Synchronous generation is not supported; use the async path.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    async def _agenerate(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate one assistant message for ``messages``.

        Args:
            messages: Conversation so far; the last message must be a human or tool message.
            stop: Accepted for interface compatibility; not forwarded to the server.
            **kwargs: Bound arguments. ``tools`` and ``system_prompt`` are read by
                ``_preprocess``; ``sampling_params`` (if present) overrides the defaults.

        Returns:
            A ``ChatResult`` holding a single generation.

        Raises:
            MaxTokenExceededError: Propagated from ``_preprocess``.
        """
        request_id, prompt_ids, response_mask = await self._preprocess(messages, **kwargs)

        sampling_params = {
            "temperature": self.temperature,
            "top_p": self.top_p,
            "repetition_penalty": self.repetition_penalty,
        }
        if "sampling_params" in kwargs:
            sampling_params.update(kwargs["sampling_params"])

        response_ids = await self.client.generate(
            request_id=request_id,
            prompt_ids=prompt_ids,
            sampling_params=sampling_params,
            run_baseline=False,
            baseline_name="google/gemini-3-pro-preview",
        )

        message = await self._postprocess(request_id, prompt_ids, response_mask, response_ids, **kwargs)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    @property
    def _llm_type(self) -> str:
        """Return the model name as the LLM type identifier."""
        return self.model_name

    async def _preprocess(self, messages: list[BaseMessage], **kwargs: Any) -> tuple[str, list[int], list[int]]:
        """Turn the conversation into ``(request_id, prompt_ids, response_mask)``.

        For a fresh prompt (last message is human and not preceded by an AI message) the whole
        conversation is tokenized, a new request id is created and the mask starts empty.
        Otherwise the state stored on the most recent AI message is continued: the messages
        after it are tokenized, the ``system_prompt`` prefix is cut off, and the remaining
        tokens are appended to ``prompt_ids`` with mask value 0.

        Args:
            messages: Conversation so far; must be non-empty.
            **kwargs: Bound arguments; ``tools`` is optional, ``system_prompt`` is required on
                the continuation path (it is bound by ``bind_tools``).

        Returns:
            The request id, the prompt token ids, and the response mask.

        Raises:
            AssertionError: If the last message is neither human nor tool, or if the
                continuation path finds no AI message carrying the required metadata.
            MaxTokenExceededError: If the mask length plus the new tokens reaches ``max_tokens``.
        """
        assert messages[-1].type in ["human", "tool"], (
            f"Last message must be human or tool, but got {messages[-1].type}"
        )
        loop = asyncio.get_running_loop()

        # Fresh prompt: nothing to continue from, so tokenize everything (off the event loop).
        if messages[-1].type == "human" and (len(messages) == 1 or messages[-2].type != "ai"):
            prompt_ids = await loop.run_in_executor(
                None,
                lambda: self.tokenizer.apply_chat_template(
                    convert_to_openai_messages(messages),
                    tools=kwargs.get("tools"),
                    add_generation_prompt=True,
                    tokenize=True,
                ),
            )
            return str(uuid.uuid4()), prompt_ids, []

        # Continuation: walk backwards to the most recent AI message.
        for i in range(len(messages) - 1, -1, -1):
            if messages[i].type == "ai":
                break
        # Without this check a conversation lacking any AI message would fail below with a
        # misleading "must have prompt_ids" message.
        assert messages[i].type == "ai", "Messages must contain an AI message to continue from"
        assert "prompt_ids" in messages[i].response_metadata, "Last message must have prompt_ids in response_metadata"
        assert "response_mask" in messages[i].response_metadata, (
            "Last message must have response_mask in response_metadata"
        )

        # Tokenize only the messages that follow the AI message.
        tool_responses = convert_to_openai_messages(messages[i + 1 :])
        tool_response_ids = await loop.run_in_executor(
            None,
            lambda messages=tool_responses: self.tokenizer.apply_chat_template(
                messages, add_generation_prompt=True, tokenize=True
            ),
        )
        # Drop the prefix whose length was measured in ``bind_tools``.
        tool_response_ids = tool_response_ids[len(kwargs["system_prompt"]) :]

        # Checked before the metadata is popped, so the AI message is left intact on failure.
        if len(messages[i].response_metadata["response_mask"]) + len(tool_response_ids) >= self.max_tokens:
            raise MaxTokenExceededError(f"Max response length {self.max_tokens} exceeded")

        # Take ownership of the stored state; the lists are extended in place below.
        request_id = messages[i].response_metadata.pop("request_id")
        prompt_ids = messages[i].response_metadata.pop("prompt_ids")
        response_mask = messages[i].response_metadata.pop("response_mask")

        prompt_ids += tool_response_ids
        response_mask += [0] * len(tool_response_ids)  # not generated by the model

        return request_id, prompt_ids, response_mask

    async def _postprocess(
        self, request_id: str, prompt_ids: list[int], response_mask: list[int], response_ids: list[int], **kwargs: Any
    ) -> AIMessage:
        """Build the ``AIMessage`` for freshly generated ``response_ids``.

        ``prompt_ids`` and ``response_mask`` are extended in place with the response tokens
        (mask value 1) and stored, together with ``request_id``, in the message's
        ``response_metadata``. Tool calls are extracted with the configured tool parser; calls
        whose arguments are not a JSON object are reported as invalid tool calls.

        Args:
            request_id: Identifier of the generation request.
            prompt_ids: Token ids sent to the server; mutated in place.
            response_mask: Mask accumulated so far; mutated in place.
            response_ids: Token ids returned by the server.
            **kwargs: Unused.

        Returns:
            The assistant message, with at most ``max_parallel_calls`` valid and at most
            ``max_parallel_calls`` invalid tool calls.
        """
        prompt_ids += response_ids
        response_mask += [1] * len(response_ids)

        tool_parser = ToolParser.get_tool_parser(self.tool_parser, self.tokenizer)
        content, function_calls = await tool_parser.extract_tool_calls(response_ids)

        tool_calls, invalid_tool_calls = [], []
        for function_call in function_calls:
            try:
                args = json.loads(function_call.arguments)
                if not isinstance(args, dict):
                    raise ValueError(f"Invalid json tool arguments: {args}")
            except ValueError as e:
                # json.JSONDecodeError subclasses ValueError, so this single handler covers both
                # malformed JSON and well-formed JSON that is not an object.
                logger.warning(
                    "Invalid json tool arguments: %s (raw arguments: %r)", e, function_call.arguments
                )
                invalid_tool_calls.append(
                    InvalidToolCall(
                        args=function_call.arguments,
                        name=function_call.name,
                        error=f"Invalid json tool arguments: {e}",
                    )
                )
            else:
                tool_calls.append(
                    ToolCall(
                        args=args,
                        name=function_call.name,
                        id=str(uuid.uuid4()),
                    )
                )

        message = AIMessage(
            content=content,
            tool_calls=tool_calls[: self.max_parallel_calls],
            invalid_tool_calls=invalid_tool_calls[: self.max_parallel_calls],
            response_metadata={
                "request_id": request_id,
                "prompt_ids": prompt_ids,
                "response_mask": response_mask,
            },
        )
        return message
class TruncateStructuredTool(StructuredTool):
    """``StructuredTool`` whose async result is converted to ``str`` and capped in length."""

    # Which part of an over-long response is kept: "left" keeps the head, "right" keeps the
    # tail, and any other value keeps both ends.
    tool_response_truncate_side: str
    # Maximum number of characters kept from the tool response (truncation markers excluded).
    max_tool_response_length: int

    async def _arun(
        self,
        *args: Any,
        config: RunnableConfig,
        **kwargs: Any,
    ) -> Any:
        """Run the wrapped tool and truncate its stringified response if it is too long.

        Args:
            *args: Positional arguments forwarded to the parent ``_arun``.
            config: Runnable configuration forwarded to the parent ``_arun``.
            **kwargs: Keyword arguments forwarded to the parent ``_arun``.

        Returns:
            The response as a string, unchanged when it has at most
            ``max_tool_response_length`` characters, otherwise truncated with a marker.
        """
        tool_response = str(await super()._arun(*args, config=config, **kwargs))

        limit = self.max_tool_response_length
        total = len(tool_response)
        if total <= limit:
            return tool_response

        # Tails are sliced from an explicit start index: ``text[-n:]`` with ``n == 0`` would
        # return the *whole* string (because -0 == 0) instead of an empty tail.
        if self.tool_response_truncate_side == "left":
            return tool_response[:limit] + "...(truncated)"
        if self.tool_response_truncate_side == "right":
            return "(truncated)..." + tool_response[total - limit :]

        half = limit // 2
        return tool_response[:half] + "...(truncated)..." + tool_response[total - half :]
def convert_to_agent_output(messages: list[BaseMessage], response_length: int) -> AgentLoopOutput:
    """Convert a finished conversation into an ``AgentLoopOutput``.

    The token state is read from the ``response_metadata`` of the last non-tool message,
    which must be an AI message. Its ``prompt_ids`` hold the full token sequence; the trailing
    ``len(response_mask)`` tokens are the response and everything before them is the prompt.

    Args:
        messages: Conversation messages; must be non-empty.
        response_length: Maximum number of response tokens (and mask entries) to keep.

    Returns:
        The prompt ids, the truncated response ids and mask, the number of turns, and an
        empty metrics dict.

    Raises:
        AssertionError: If ``messages`` is empty, if the last non-tool message is not an AI
            message, or if it lacks ``prompt_ids`` / ``response_mask`` metadata.
    """
    assert messages, "messages must not be empty"

    # Skip trailing tool messages to reach the message that carries the token state.
    for i in range(len(messages) - 1, -1, -1):
        if messages[i].type != "tool":
            break
    last_message = messages[i]
    assert last_message.type == "ai", f"Last message must be assistant, but got {last_message.type}"
    assert "prompt_ids" in last_message.response_metadata, "Last message must have prompt_ids in response_metadata"
    assert "response_mask" in last_message.response_metadata, (
        "Last message must have response_mask in response_metadata"
    )

    # A turn starts whenever the message type changes; system messages are not counted.
    num_turns = 0
    for index, message in enumerate(messages):
        if message.type == "system":
            continue
        if index == 0 or message.type != messages[index - 1].type:
            num_turns += 1

    all_ids = last_message.response_metadata["prompt_ids"]
    response_mask = last_message.response_metadata["response_mask"]

    # Split at an explicit index: ``all_ids[-len(mask):]`` would return the whole sequence
    # when the mask is empty (because -0 == 0), turning the entire prompt into "response".
    split = len(all_ids) - len(response_mask)
    prompt_ids = all_ids[:split]
    response_ids = all_ids[split:]

    output = AgentLoopOutput(
        prompt_ids=prompt_ids,
        response_ids=response_ids[:response_length],
        response_mask=response_mask[:response_length],
        num_turns=num_turns,
        metrics={},
    )
    return output