# SPDX-License-Identifier: Apache-2.0
import json
import re
from copy import deepcopy
from unittest.mock import MagicMock

import pytest
from pydantic import TypeAdapter

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              ChatCompletionToolsParam)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat

def _strict_function_tool(name: str, description: str, properties: dict,
                          required: list) -> dict:
    """Return an OpenAI-style function tool definition as a plain dict.

    The parameters schema is an object with the given ``properties`` and
    ``required`` keys and ``additionalProperties`` disabled. The tool entry
    carries ``"strict": True``.
    """
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
                "additionalProperties": False,
            },
        },
        # NOTE(review): the OpenAI API documents "strict" inside "function";
        # here it sits on the tool entry itself -- confirm this is intended.
        "strict": True,
    }


# The tool set shared by the guided-JSON and streaming tests in this module.
EXAMPLE_TOOLS = [
    _strict_function_tool(
        name="get_current_weather",
        description="Get the current weather in a given location",
        properties={
            "city": {
                "type": "string",
                "description": ("The city to find the weather for, "
                                "e.g. 'San Francisco'"),
            },
        },
        required=["city"],
    ),
    _strict_function_tool(
        name="get_forecast",
        description="Get the weather forecast for a given location",
        properties={
            "city": {
                "type": "string",
                "description": ("The city to get the forecast for, "
                                "e.g. 'New York'"),
            },
            "days": {
                "type": "integer",
                "description": "Number of days to get the forecast for (1-7)",
            },
        },
        required=["city", "days"],
    ),
]


def _compile_and_check(tools: list[ChatCompletionToolsParam], sample_output,
                       should_match: bool):
    """Assert whether ``sample_output`` fits the tool-call guided-JSON schema.

    The schema comes from ``ChatCompletionRequest._get_guided_json_from_tool``
    for a request with ``tool_choice="required"``. It is compiled to a regex
    with outlines-core and fully matched against ``json.dumps(sample_output)``.

    Args:
        tools: Tool definitions to put on the mocked request.
        sample_output: JSON-serializable candidate model output.
        should_match: Whether the serialized output is expected to match.
    """
    # _get_guided_json_from_tool is invoked unbound, so a MagicMock configured
    # with tool_choice and tools stands in for a real request object.
    mock_request = MagicMock(tool_choice="required", tools=tools)
    schema = ChatCompletionRequest._get_guided_json_from_tool(mock_request)
    assert isinstance(schema, dict)

    # build_regex_from_schema is what JSONLogitsProcessor uses to create its
    # Guide, so this should mirror the constraint applied during generation.
    from outlines_core.fsm.json_schema import build_regex_from_schema
    pattern = re.compile(build_regex_from_schema(json.dumps(schema)))
    candidate = json.dumps(sample_output)

    assert (pattern.fullmatch(candidate) is not None) == should_match


# Tool-call arrays the "required" tool_choice schema is expected to accept,
# each paired with the expected match flag (always True). They cover single
# calls, repeated calls to the same tool and interleaved calls to both tools.
VALID_TOOL_OUTPUTS = [
    ([
        {"name": "get_current_weather", "parameters": {"city": "Vienna"}},
    ], True),
    ([
        {"name": "get_current_weather", "parameters": {"city": "Vienna"}},
        {"name": "get_current_weather", "parameters": {"city": "Berlin"}},
    ], True),
    ([
        {"name": "get_forecast", "parameters": {"city": "Vienna", "days": 7}},
    ], True),
    ([
        {"name": "get_forecast", "parameters": {"city": "Vienna", "days": 7}},
        {"name": "get_current_weather", "parameters": {"city": "Vienna"}},
    ], True),
    ([
        {"name": "get_forecast", "parameters": {"city": "Vienna", "days": 7}},
        {"name": "get_current_weather", "parameters": {"city": "Vienna"}},
        {"name": "get_forecast", "parameters": {"city": "Berlin", "days": 7}},
        {"name": "get_current_weather", "parameters": {"city": "Berlin"}},
    ], True),
]

# The same tool-call arrays with the expected-match flag dropped.
VALID_TOOLS = [calls for calls, _expected in VALID_TOOL_OUTPUTS]


# Outputs the "required" tool_choice schema must reject, each paired with
# the expected match flag (always False).
_INVALID_TOOL_OUTPUTS = [
    (None, False),
    ([], False),  # empty list cannot be generated
    ({}, False),  # empty object cannot be generated
    ([{}], False),  # list with empty object cannot be generated
    # function without required parameters cannot be generated
    ([{"name": "get_current_weather"}], False),
    ([{"name": "get_current_weather", "parameters": {}}], False),
    ([{"name": "get_current_weather", "parameters": None}], False),
    # tool call without the surrounding list cannot be generated
    ({"name": "get_current_weather", "parameters": {"city": "Vienna"}},
     False),
    # tool call with extra parameters cannot be generated
    ([{
        "name": "get_current_weather",
        "parameters": {"city": "Vienna", "extra": "value"},
    }], False),
    # tool call where "parameters" comes before "name" cannot be generated
    ([{"parameters": {"city": "Vienna"}, "name": "get_current_weather"}],
     False),
    # tool call without all required parameters cannot be generated
    ([{"name": "get_forecast", "parameters": {"city": "Vienna"}}], False),
    # tool call with an unknown name/parameters cannot be generated
    ([{"name": "get_weather", "parameters": {"city": "Vienna", "days": 7}}],
     False),
    # tool call list mixing a valid call with an empty object
    # cannot be generated
    ([
        {"name": "get_current_weather", "parameters": {"city": "Vienna"}},
        {},
    ], False),
]


@pytest.mark.parametrize("sample_output, should_match",
                         VALID_TOOL_OUTPUTS + _INVALID_TOOL_OUTPUTS)
def test_guided_json(sample_output, should_match):
    """Check the required-tool schema built from EXAMPLE_TOOLS.

    Every entry of VALID_TOOL_OUTPUTS must fully match the schema's regex,
    and every entry of _INVALID_TOOL_OUTPUTS must not.
    """
    adapter = TypeAdapter(list[ChatCompletionToolsParam])
    tools = adapter.validate_python(EXAMPLE_TOOLS)
    _compile_and_check(tools=tools,
                       sample_output=sample_output,
                       should_match=should_match)


def update_parameters_none(
        tool: ChatCompletionToolsParam) -> ChatCompletionToolsParam:
    """Set ``tool.function.parameters`` to ``None`` in place.

    Returns the same, now mutated, ``tool`` object so this can be used as a
    ``map``-style transform.
    """
    function = tool.function
    function.parameters = None
    return tool


def update_parameters_empty_dict(
        tool: ChatCompletionToolsParam) -> ChatCompletionToolsParam:
    """Set ``tool.function.parameters`` to a fresh empty dict in place.

    Returns the same, now mutated, ``tool`` object so this can be used as a
    ``map``-style transform.
    """
    function = tool.function
    function.parameters = {}
    return tool


@pytest.mark.parametrize(
    "sample_output, should_match",
    [
        (None, False),
        ([], False),  # empty list cannot be generated
        ({}, False),  # empty object cannot be generated
        ([{}], False),  # list with empty object cannot be generated
        # function without required parameters cannot be generated
        ([{"name": "get_current_weather"}], False),
        ([{"name": "get_current_weather", "parameters": None}], False),
        # function with extra parameters cannot be generated
        ([{"name": "get_current_weather", "parameters": {"extra": "value"}}],
         False),
        # only a function with an empty parameters object is valid
        ([{"name": "get_current_weather", "parameters": {}}], True),
    ])
@pytest.mark.parametrize(
    "update_parameters",
    [update_parameters_none, update_parameters_empty_dict])
def test_guided_json_without_parameters(sample_output, should_match,
                                        update_parameters):
    """Check the schema for a tool whose parameters are ``None`` or ``{}``.

    Only the first example tool is used, after ``update_parameters`` has
    cleared its parameters; the schema must then accept nothing but a call
    with an empty ``parameters`` object.
    """
    # Defensive copy so the module-level EXAMPLE_TOOLS cannot be affected.
    source_tools = [deepcopy(EXAMPLE_TOOLS[0])]
    adapter = TypeAdapter(list[ChatCompletionToolsParam])
    tools = [
        update_parameters(tool)
        for tool in adapter.validate_python(source_tools)
    ]
    assert all(tool.function.parameters is None
               or tool.function.parameters == {} for tool in tools)
    _compile_and_check(tools=tools,
                       sample_output=sample_output,
                       should_match=should_match)


@pytest.mark.parametrize("output", VALID_TOOLS)
@pytest.mark.parametrize("empty_params", [False, True])
@pytest.mark.parametrize("delta_len", list(range(1, 11)))
def test_streaming_output_valid(output, empty_params, delta_len):
    """Stream a tool-call array in chunks and rebuild it from the deltas.

    The JSON text of ``output`` is fed to
    ``OpenAIServingChat.extract_tool_call_required_streaming`` in slices of
    ``delta_len`` characters. When ``empty_params`` is set, every
    ``parameters`` object is emptied first. The emitted delta messages are
    stitched back into JSON, which must equal the original both as data and
    as serialized text.
    """
    # Defensive copy: parametrized values are module-level objects.
    expected = deepcopy(output)
    if empty_params:
        expected = [{
            "name": call["name"],
            "parameters": {}
        } for call in expected]
    expected_json = json.dumps(expected)

    serving_chat = MagicMock()
    previous_text = ""
    function_name_returned = False
    messages = []
    for start in range(0, len(expected_json), delta_len):
        delta_text = expected_json[start:start + delta_len]
        current_text = previous_text + delta_text

        delta_message, function_name_returned = (
            OpenAIServingChat.extract_tool_call_required_streaming(
                serving_chat,
                previous_text=previous_text,
                current_text=current_text,
                delta_text=delta_text,
                function_name_returned=function_name_returned))
        if delta_message:
            messages.append(delta_message)

        previous_text = current_text

    assert messages

    # Rebuild the JSON array text from the streamed tool-call deltas.
    combined_messages = "["
    for message in messages:
        function = message.tool_calls[0].function
        if not function.name:
            # A nameless delta only extends the arguments of the open call.
            combined_messages += function.arguments
            continue
        # A named delta starts a new call; close the previous one, if any.
        if len(combined_messages) > 1:
            combined_messages += "},"
        combined_messages += ('{"name": "' + function.name +
                              '", "parameters": ' + function.arguments)
    combined_messages += "}]"

    assert json.loads(combined_messages) == expected
    assert json.dumps(json.loads(combined_messages)) == expected_json
