# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
import logging
import os
from abc import ABC, abstractmethod

import regex
from pydantic import BaseModel

from verl.utils.rollout_trace import rollout_trace_op

# Module-level logger, named after this file's path. Its level is read from the
# VERL_LOGGING_LEVEL environment variable and falls back to "WARN" when unset.
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


class FunctionCall(BaseModel):
    """A single function (tool) call: the function name plus its arguments as a string."""

    arguments: str
    """
    The arguments to call the function with, as generated by the model in JSON
    format. Note that the model does not always generate valid JSON, and may
    hallucinate parameters not defined by your function schema. Validate the
    arguments in your code before calling your function.
    """

    name: str
    """The name of the function to call."""


class ToolParser(ABC):
    """Abstract base class for tool-call parsers, with a name-keyed class registry.

    Concrete parsers are added to the registry with ``@ToolParser.register(name)``
    and instantiated by name through ``ToolParser.get_tool_parser(name, tokenizer)``.
    """

    # Registered parser name -> parser class. Defined once on the base class, so
    # every subclass reads and writes the same mapping.
    _registry: dict[str, type["ToolParser"]] = {}

    def __init__(self, tokenizer) -> None:
        """Keep a reference to the tokenizer on the instance."""
        self.tokenizer = tokenizer

    @abstractmethod
    async def extract_tool_calls(self, responses_ids: list[int]) -> tuple[str, list[FunctionCall]]:
        """Parse the tool calls contained in a response.

        Args:
            responses_ids (List[int]): Token ids of the response.

        Returns:
            Tuple[str, List[FunctionCall]]: The content and the extracted tool calls.
        """
        raise NotImplementedError

    @classmethod
    def get_tool_parser(cls, name: str, tokenizer):
        """Instantiate the parser registered under ``name``.

        Args:
            name: Key the parser class was registered with.
            tokenizer: Passed as the only argument to the parser's constructor.

        Returns:
            A new instance of the registered parser class.

        Raises:
            ValueError: If no parser is registered under ``name``.
        """
        known_parsers = cls._registry
        if name in known_parsers:
            parser_cls = known_parsers[name]
            return parser_cls(tokenizer)
        raise ValueError(f"Unknown tool parser: {name}")

    @classmethod
    def register(cls, name: str):
        """Return a class decorator that stores the decorated class under ``name``.

        The decorated class is returned unchanged; an existing entry with the same
        name is replaced.
        """

        def _add_to_registry(parser_cls: type[ToolParser]) -> type[ToolParser]:
            cls._registry[name] = parser_cls
            return parser_cls

        return _add_to_registry


@ToolParser.register("hermes")
class HermesToolParser(ToolParser):
    """Parser for tool calls wrapped in ``<tool_call>...</tool_call>`` tags.

    Adapted from https://github.com/vllm-project/vllm/blob/v0.9.1/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
    """

    def __init__(self, tokenizer) -> None:
        super().__init__(tokenizer)

        # Non-greedy body with DOTALL, so one call may span several lines and
        # consecutive calls are matched separately.
        self.tool_call_regex = regex.compile(r"<tool_call>(.*?)</tool_call>", regex.DOTALL)
        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"

    @rollout_trace_op
    async def extract_tool_calls(self, responses_ids: list[int]) -> tuple[str, list[FunctionCall]]:
        """Decode the response and pull out every well-formed tool call.

        Args:
            responses_ids (List[int]): Token ids of the response.

        Returns:
            Tuple[str, List[FunctionCall]]: The decoded text with all
            ``<tool_call>...</tool_call>`` spans removed, and the calls whose
            payload could be parsed. When either tag is absent, the decoded text
            is returned untouched with an empty list.
        """
        # Decoding runs in the loop's default executor rather than on the event loop thread.
        decoded = await asyncio.get_running_loop().run_in_executor(None, self.tokenizer.decode, responses_ids)

        if not (self.tool_call_start_token in decoded and self.tool_call_end_token in decoded):
            return decoded, []

        parsed_calls = []
        for payload in self.tool_call_regex.findall(decoded):
            try:
                call_spec = json.loads(payload)
                parsed_calls.append(
                    FunctionCall(
                        name=call_spec["name"],
                        # Re-serialise so `arguments` is always a JSON string.
                        arguments=json.dumps(call_spec["arguments"], ensure_ascii=False),
                    )
                )
            except Exception as e:
                # A malformed call is logged and skipped; the others are still returned.
                logger.error(f"Failed to decode tool call: {e}")

        # Whatever is left once the tool-call spans are stripped is the content.
        remaining = self.tool_call_regex.sub("", decoded)
        return remaining, parsed_calls


@ToolParser.register("gpt-oss")
class GptOssToolParser(ToolParser):
    """
    Tool parser for gpt-oss model.
    Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/function_call/gpt_oss_detector.py

    Args:
        tokenizer: The tokenizer to use.
    """

    def __init__(self, tokenizer) -> None:
        super().__init__(tokenizer)
        # check https://cookbook.openai.com/articles/openai-harmony for more details.
        # A complete chain-of-thought (analysis channel) message, including the assistant start marker.
        self.cot_pattern = regex.compile(
            r"<\|start\|>assistant<\|channel\|>analysis<\|message\|>.*?<\|end\|>", regex.DOTALL
        )
        # <|start|>assistant may be pre-appended in prompts, so we need to remove it.
        self.partial_cot_pattern = regex.compile(r"<\|channel\|>analysis<\|message\|>(.*?)<\|end\|>", regex.DOTALL)
        # Group 1 is the function name, group 2 is the raw argument payload.
        self.tool_call_pattern = regex.compile(
            r"<\|start\|>assistant<\|channel\|>[^<]* to=functions\.([^<]+) "
            r"<\|constrain\|>json<\|message\|>(.*?)<\|call\|>",
            regex.DOTALL,
        )

    @rollout_trace_op
    async def extract_tool_calls(self, responses_ids: list[int]) -> tuple[str, list[FunctionCall]]:
        """Extract tool calls from a gpt-oss response.

        Args:
            responses_ids (List[int]): The ids of the response.

        Returns:
            Tuple[str, List[FunctionCall]]: The decoded text with padding tokens,
            chain-of-thought segments and tool-call segments removed, and the
            extracted tool calls. The arguments of each call are returned as the
            raw matched string and are not validated as JSON.
        """
        loop = asyncio.get_running_loop()
        # We need to keep special tokens for gpt-oss model for better tool call extraction.
        text = await loop.run_in_executor(None, lambda: self.tokenizer.decode(responses_ids, skip_special_tokens=False))

        # Remove padding tokens for better tool call extraction. A tokenizer may not
        # define a pad token (None); str.replace would raise TypeError in that case,
        # so only strip when a non-empty pad token exists.
        pad_token = getattr(self.tokenizer, "pad_token", None)
        if pad_token:
            text = text.replace(pad_token, "")

        # Remove the chain of thought first: it may contain tool call tokens,
        # but those are not valid tool calls.
        text = self.cot_pattern.sub("", text)
        text = self.partial_cot_pattern.sub("", text)

        # Each match is a (name, arguments) tuple, one per tool call in the text.
        matches = self.tool_call_pattern.findall(text)
        if not matches:
            return text, []

        function_calls = []
        for name, arguments in matches:
            try:
                # don't check if arguments is valid JSON and leave it to client
                function_calls.append(FunctionCall(name=name, arguments=arguments))
            except Exception as e:
                logger.error("Failed to decode tool call: %s", e)

        # The remaining text excludes the tool call tokens.
        content = self.tool_call_pattern.sub("", text)

        return content, function_calls
