# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from  https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
"""Utilities for client classes"""

import logging
import warnings
from typing import Any, Optional, Protocol, runtime_checkable


@runtime_checkable
class FormatterProtocol(Protocol):
    """Structural type for structured-output objects that can render themselves as text.

    Any object exposing a ``format()`` method that returns a ``str`` satisfies this
    protocol. Because the protocol is ``runtime_checkable``, ``isinstance(obj,
    FormatterProtocol)`` only verifies that a ``format`` member exists; it does not
    check the signature or the return type.
    """

    def format(self) -> str:
        """Return the string rendering of this object."""
        ...


def validate_parameter(
    params: dict[str, Any],
    param_name: str,
    allowed_types: tuple[Any, ...],
    allow_None: bool,  # noqa: N803
    default_value: Any,
    numerical_bound: Optional[tuple[Optional[float], Optional[float]]],
    allowed_values: Optional[list[Any]],
) -> Any:
    """Validates a given config parameter, checking its type, values, and setting defaults.

    The value is read from ``params`` (falling back to ``default_value`` when the key is
    missing) and checked, in order, for: ``None``-ness, type, numerical bounds, and allowed
    values. If any check fails, a ``UserWarning`` is emitted and ``default_value`` is
    returned instead. ``default_value`` itself is never validated.

    Parameters:
        params (Dict[str, Any]): Dictionary containing parameters to validate.
        param_name (str): The name of the parameter to validate.
        allowed_types (Tuple): Tuple of acceptable types for the parameter. A single type
            (e.g. ``str``) is also accepted, as with ``isinstance``. Note that ``bool`` is a
            subclass of ``int``, so ``True``/``False`` pass an ``int`` type check.
        allow_None (bool): Whether the parameter can be `None`.
        default_value (Any): The default value to use if the parameter is invalid or missing.
        numerical_bound (Optional[Tuple[Optional[float], Optional[float]]]):
            A tuple specifying the inclusive lower and upper bounds for numerical parameters.
            Each bound can be `None` if not applicable.
        allowed_values (Optional[List[Any]]): A list of acceptable values for the parameter.
            Can be `None` if no specific values are required. It is enforced in addition to
            ``numerical_bound`` when both are given.

    Returns:
        Any: The validated parameter value or the default value if validation fails.

    Raises:
        TypeError: If `allowed_values` is provided but is not a list.

    Example Usage:
    ```python
        # Validating a numerical parameter within specific bounds
        params = {"temperature": 0.5, "safety_model": "Meta-Llama/Llama-Guard-7b"}
        temperature = validate_parameter(params, "temperature", (int, float), True, 0.7, (0, 1), None)
        # Result: 0.5

        # Validating a parameter that can be one of a list of allowed values
        model = validate_parameter(
            params, "safety_model", str, True, None, None, ["Meta-Llama/Llama-Guard-7b", "Meta-Llama/Llama-Guard-13b"]
        )
        # Result: "Meta-Llama/Llama-Guard-7b". If "safety_model" were missing it would default to None;
        # if it held a value outside the list, a warning is emitted and None is returned.
    ```
    """
    if allowed_values is not None and not isinstance(allowed_values, list):
        raise TypeError(f"allowed_values should be a list or None, got {type(allowed_values).__name__}")

    param_value = params.get(param_name, default_value)
    warning = ""

    if param_value is None:
        # None short-circuits all other checks: it is either explicitly allowed or an error
        if not allow_None:
            warning = "cannot be None"
    elif not isinstance(param_value, allowed_types):
        # List the acceptable types in the warning; allowed_types may be a tuple or a single type
        if isinstance(allowed_types, tuple):
            formatted_types = "(" + ", ".join(t.__name__ for t in allowed_types) + ")"
        else:
            formatted_types = allowed_types.__name__
        warning = f"must be of type {formatted_types}{' or None' if allow_None else ''}"
    else:
        if numerical_bound:
            # Bounds are inclusive; either side may be None to leave it open
            lower_bound, upper_bound = numerical_bound
            if (lower_bound is not None and param_value < lower_bound) or (
                upper_bound is not None and param_value > upper_bound
            ):
                warning = "has numerical bounds"
                if lower_bound is not None:
                    warning += f", >= {lower_bound!s}"
                if upper_bound is not None:
                    if lower_bound is not None:
                        warning += " and"
                    warning += f" <= {upper_bound!s}"
                if allow_None:
                    warning += ", or can be None"

        # Checked independently of numerical_bound so both constraints apply when both are given
        if not warning and allowed_values and param_value not in allowed_values:
            warning = f"must be one of these values [{allowed_values}]{', or can be None' if allow_None else ''}"

    # If we failed any checks, warn and set to default value
    if warning:
        warnings.warn(
            f"Config error - {param_name} {warning}, defaulting to {default_value}.",
            UserWarning,
        )
        param_value = default_value

    return param_value


def should_hide_tools(messages: list[dict[str, Any]], tools: list[dict[str, Any]], hide_tools_param: str) -> bool:
    """Determines if tools should be hidden. This function is used to hide tools when they have been run, minimising the chance of the LLM choosing them when they shouldn't.
    Parameters:
        messages (List[Dict[str, Any]]): List of messages. Messages with a "tool_calls" list register
            calls (each entry read for "id" and ["function"]["name"]); messages with a "tool_call_id"
            are treated as results of an executed call.
        tools (List[Dict[str, Any]]): List of tools, each read for ["function"]["name"]. `None` or an
            empty list means there is nothing to hide.
        hide_tools_param (str): "hide_tools" parameter value. Can be "if_all_run" (hide tools if all tools have been run), "if_any_run" (hide tools if any of the tools have been run), "never" (never hide tools). Default is "never".

    Returns:
        bool: Indicates whether the tools should be excluded from the response create request

    Raises:
        TypeError: If `hide_tools_param` is not a recognised value. This is only checked when
            `hide_tools_param` is not "never" and `tools` is non-empty.

    Example Usage:
    ```python
        # Deciding whether to send the tool definitions with the next request
        messages = params.get("messages", [])
        tools = params.get("tools", None)
        hide_tools = should_hide_tools(messages, tools, params["hide_tools"])
    ```
    """
    if hide_tools_param == "never" or not tools:
        return False

    if hide_tools_param == "if_any_run":
        # Any tool-result message means at least one tool call has been executed
        return any("tool_call_id" in message for message in messages)

    if hide_tools_param == "if_all_run":
        # Names of tools not yet seen executed; entries are removed as results are found
        remaining_tool_names = [tool["function"]["name"] for tool in tools]

        # Maps each registered tool call id to the name of the function it invoked
        tool_call_names: dict[str, str] = {}

        for message in messages:
            if "tool_calls" in message:
                # `tool_calls` may be present but None on plain assistant messages
                for tool_call in message["tool_calls"] or []:
                    tool_call_names[tool_call["id"]] = tool_call["function"]["name"]
            elif "tool_call_id" in message:
                # A result whose originating call is not in `messages` (e.g. trimmed history)
                # cannot be attributed to a tool; skip it instead of raising KeyError
                called_name = tool_call_names.get(message["tool_call_id"])
                if called_name in remaining_tool_names:
                    remaining_tool_names.remove(called_name)

        # True once every tool has been called at least once
        return not remaining_tool_names

    raise TypeError(
        f"hide_tools_param is not a valid value ['if_all_run','if_any_run','never'], got '{hide_tools_param}'"
    )


# Shared log-record layout (originally from FLAML): "[logger-name: timestamp] {line-number} LEVEL - message",
# with timestamps rendered as month-day hour:minute:second.
logging_formatter = logging.Formatter(
    fmt="[%(name)s: %(asctime)s] {%(lineno)d} %(levelname)s - %(message)s",
    datefmt="%m-%d %H:%M:%S",
)
