from typing import Any, List, Optional
import copy

from .utils import (
    transform_messages,
    create_inline_tool_results_content,
    requires_inline_tool_results,
    is_google_model,
)

# Configuration for content filtering
def get_max_banned_word_length(banned_keywords: list) -> int:
    """Return the length of the longest banned keyword, or 0 if there are none."""
    longest = 0
    for keyword in banned_keywords:
        longest = max(longest, len(keyword))
    return longest

# Utility functions
def censorBannedWords(text: str, banned_keywords: list) -> str:
    """Replace each whole-word, case-insensitive banned keyword in *text* with '???'.

    Args:
        text: The text to censor.
        banned_keywords: Keywords to mask. Empty strings are ignored.

    Returns:
        *text* with every banned keyword occurrence replaced by ``???``; *text*
        unchanged when there is nothing to censor.
    """
    import re

    # An empty alternative would match at every word boundary and sprinkle
    # '???' throughout the text, so drop empty keywords (and bail if none remain).
    words = [word for word in banned_keywords if word]
    if not words:
        return text

    # Longest first so a multi-word keyword ("bad word") wins over its prefix ("bad").
    words.sort(key=len, reverse=True)

    # Single backslashes inside raw strings: r"\b" is the regex word boundary.
    # (The previous r'\\b' matched a literal backslash + 'b' and never fired.)
    pattern = r"\b(" + "|".join(re.escape(word) for word in words) + r")\b"
    return re.sub(pattern, "???", text, flags=re.IGNORECASE)

def delete_refusal(model: str, message: dict):
    """Remove the ``refusal`` entry from *message* when *model* is Pixtral Large 2411.

    The message is mutated in place and nothing is returned. It only changes
    when ``"refusal" in message`` is true, so *message* must support ``in`` and
    item deletion (e.g. a dict).
    """
    if model != "mistralai/pixtral-large-2411":
        return
    if "refusal" in message:
        del message["refusal"]

def transform_tools_for_google_provider(tool_function: dict) -> dict:
    """
    Return a copy of a tool function schema adapted for the Google provider.

    Google rejects ``object`` schemas that declare no subproperties. Such schemas
    (at the top-level ``parameters`` or as a direct child of
    ``parameters["properties"]``) are rewritten as arrays of strings:
    ``type`` becomes ``"array"``, ``items`` becomes ``{"type": "string"}`` and
    an existing string ``description`` gets a "NAME=VALUE" usage hint appended.

    Args:
        tool_function: OpenAI-style function schema (``name``, ``parameters``, ...).
            It is deep-copied and never mutated.

    Returns:
        The transformed deep copy.
    """
    import copy

    array_hint = ' Provide as array of strings in "NAME=VALUE" format.'

    def _is_empty_object_schema(schema) -> bool:
        # Missing, None, or empty ``properties`` all mean "no declared
        # subproperties"; the old len(...) check crashed on None.
        return (
            isinstance(schema, dict)
            and schema.get("type") == "object"
            and not schema.get("properties")
        )

    def _convert_to_string_array(schema: dict) -> None:
        # Rewrites *schema* in place into an array-of-strings schema.
        schema["type"] = "array"
        if isinstance(schema.get("description"), str):
            schema["description"] = f'{schema["description"]}{array_hint}'
        schema["items"] = {"type": "string"}

    cloned_function = copy.deepcopy(tool_function)
    parameters = cloned_function.get("parameters")

    # The whole parameter object is free-form: convert it and stop.
    if parameters and _is_empty_object_schema(parameters):
        _convert_to_string_array(parameters)
        return cloned_function

    if not parameters or not parameters.get("properties"):
        return cloned_function

    # Only direct children are inspected; deeper nesting is left untouched.
    for prop_value in parameters["properties"].values():
        if prop_value and _is_empty_object_schema(prop_value):
            _convert_to_string_array(prop_value)
    return cloned_function

def execute_tool_call(
    tool_call,
    tools: List[Any],
    message_context: dict,
    toolState: Optional[Any] = None,
) -> Optional[dict]:
    """
    Execute one model-requested tool call against the registered tools.

    Args:
        tool_call: Either an object exposing ``.id`` and ``.function.name`` /
            ``.function.arguments`` (e.g. ChatCompletionMessageToolCall) or an
            equivalent dict with ``"id"`` and ``"function"`` keys.
        tools: Tool dicts, each with a ``"function"`` schema (matched on its
            ``"name"``) and an ``"execute"`` callable.
        message_context: Passed as ``context=`` when the executor declares it.
        toolState: Passed as ``state=`` when the executor declares it.

    Returns:
        None when the call is malformed or names an unknown tool. Otherwise a
        dict with ``toolResponse`` (the executor's return value), ``toolCallObj``
        (flat id/function/arguments) and ``toolCallObj2`` (OpenAI tool-call shape).
    """
    import inspect
    import json

    print("TOOL CALL", tool_call)

    # Support both SDK objects and plain dicts.
    if hasattr(tool_call, "function"):
        function_name = getattr(tool_call.function, "name", None)
        arguments_str = getattr(tool_call.function, "arguments", None)
    elif isinstance(tool_call, dict) and "function" in tool_call:
        function_name = tool_call["function"].get("name")
        arguments_str = tool_call["function"].get("arguments")
    else:
        print("Invalid tool_call format")
        return None

    if not arguments_str:
        arguments_str = "{}"

    find_function = next(
        (tool for tool in tools if tool.get("function", {}).get("name") == function_name),
        None,
    )
    if not find_function:
        print(f'Function {function_name} not found')
        return None

    # Arguments are model-produced text: fall back to no arguments instead of aborting.
    try:
        function_arguments = json.loads(arguments_str)
    except (ValueError, TypeError, RecursionError) as e:
        print(f"Error parsing arguments: {e}")
        function_arguments = {}
    # Valid JSON that is not an object ("[]", "null", "3") cannot be splatted as kwargs.
    if not isinstance(function_arguments, dict):
        print(f"Error parsing arguments: expected a JSON object, got {type(function_arguments).__name__}")
        function_arguments = {}

    execute_fn = find_function["execute"]
    sig = inspect.signature(execute_fn)

    # Inject context/state only for executors that declare those parameters;
    # they override any same-named keys the model supplied.
    if "context" in sig.parameters:
        function_arguments["context"] = message_context
    if "state" in sig.parameters:
        function_arguments["state"] = toolState
    tool_response = execute_fn(**function_arguments)

    tool_call_id = getattr(tool_call, "id", None)
    if tool_call_id is None and isinstance(tool_call, dict):
        tool_call_id = tool_call.get("id")

    return {
        "toolResponse": tool_response,
        "toolCallObj": {
            "id": tool_call_id,
            "function": function_name,
            "arguments": arguments_str,
        },
        "toolCallObj2": {
            "id": tool_call_id,
            "function": {
                "name": function_name,
                "arguments": arguments_str,
            },
            "type": "function",
        }
    }

def truncate_messages(messages: List[Any], indirect=False) -> List[Any]:
    """
    Cut a message list at a role-based anchor (this is not token-based truncation).

    Args:
        messages: Message dicts, each read via ``.get("role")``.
        indirect: When true, keep everything up to and including the FIRST
            ``"tool"`` message; otherwise keep everything up to and including
            the LAST ``"user"`` message.

    Returns:
        A new, shortened list when the anchor is found; otherwise the very same
        ``messages`` object, unchanged.
    """
    if indirect:
        anchor = next(
            (pos for pos, entry in enumerate(messages) if entry.get("role") == "tool"),
            None,
        )
    else:
        anchor = next(
            (pos for pos in reversed(range(len(messages))) if messages[pos].get("role") == "user"),
            None,
        )
    if anchor is None:
        return messages
    return messages[:anchor + 1]

from dataclasses import dataclass
from typing import Any, List, Optional, Dict

@dataclass
class OpenAICompatibleArgs:
    """Inputs bundled for ``openai_compatible_completion``.

    ``toolArgs`` is a tool-configuration object, not a dict: the completion
    driver reads it through attribute access (``.tools``, ``.has_state``,
    ``.init_state()``, ``.context``, ``.system_prompt``,
    ``.is_indirect_prompt_injection``), so it is annotated ``Any``.
    """

    # Model id; a ":high" suffix is stripped by the driver and turned into
    # reasoning_effort="high".
    model: Optional[str] = None
    # Conversation to send; truncated and transformed before the request.
    messages: Optional[List[Any]] = None
    # Explicit system prompt; takes precedence over toolArgs.system_prompt.
    system_prompt: Optional[str] = None
    # Passed through to transform_messages; its meaning is not visible here.
    bucket: Optional[Any] = None
    # Tool configuration object (see class docstring); may be None.
    toolArgs: Optional[Any] = None

def _serialize_tool_calls(tool_calls) -> List[Dict[str, Any]]:
    """Convert SDK tool-call objects into plain OpenAI-style ``tool_calls`` dicts."""
    serialized = []
    for tc in tool_calls or []:
        fn = getattr(tc, "function", None)
        serialized.append({
            "id": getattr(tc, "id", None),
            "type": getattr(tc, "type", None),
            "function": (
                {
                    "name": getattr(fn, "name", None),
                    "arguments": getattr(fn, "arguments", None),
                }
                if fn is not None
                else None
            ),
        })
    return serialized


def _assistant_message_to_dict(message) -> Dict[str, Any]:
    """Convert an SDK assistant message into a history dict (role, content, tool/function calls)."""
    as_dict = {
        "role": getattr(message, "role", None),
        "content": getattr(message, "content", None),
    }
    if getattr(message, "tool_calls", None) is not None:
        as_dict["tool_calls"] = _serialize_tool_calls(message.tool_calls)
    function_call = getattr(message, "function_call", None)
    if function_call is not None:
        as_dict["function_call"] = {
            "name": getattr(function_call, "name", None),
            "arguments": getattr(function_call, "arguments", None),
        }
    return as_dict


def _checked_first_message(response, error_log_label: str, error_prefix: str):
    """Raise on a provider-reported error or empty choices; else return choices[0].message (or None)."""
    # OpenRouter can report failures as an ``error`` dict on an otherwise normal response.
    error = getattr(response, "error", None)
    if isinstance(error, dict) and error:
        print(error_log_label, error)
        raise Exception(error_prefix + str(error.get("message") or "Unknown error"))
    choices = getattr(response, "choices", None)
    if not choices:
        raise Exception("No response from OpenAI/OpenRouter")
    return getattr(choices[0], "message", None)


def _inline_tool_results(messages: List[Dict[str, Any]], tool_responses: List[Dict[str, Any]]) -> None:
    """Fold tool results into the latest assistant tool-call message and drop those tool messages, in place."""
    # Search from the end: the assistant message that issued these calls is the
    # most recent one; earlier conversation turns may also contain tool calls.
    assistant_index = next(
        (
            i
            for i in range(len(messages) - 1, -1, -1)
            if messages[i].get("role") == "assistant"
            and isinstance(messages[i].get("tool_calls"), list)
            and len(messages[i]["tool_calls"]) > 0
        ),
        None,
    )
    if assistant_index is None:
        return

    assistant_message = messages[assistant_index]
    results_text = create_inline_tool_results_content(tool_responses)
    content = assistant_message.get("content")
    if content is None:
        # Tool-call turns often have no text; without this the results were lost.
        assistant_message["content"] = results_text
    elif isinstance(content, str):
        assistant_message["content"] = f'{content} {results_text}'
    elif isinstance(content, list):
        content.append({"type": "text", "text": results_text})

    answered_ids = {tr["tool_call_id"] for tr in tool_responses}
    messages[:] = [
        msg
        for msg in messages
        if msg.get("role") != "tool" or msg.get("tool_call_id") not in answered_ids
    ]


def openai_compatible_completion(
    client,
    args: OpenAICompatibleArgs,
    provider: str,
) -> Any:
    """
    Run a chat completion on an OpenAI-compatible client, executing tool calls in a bounded loop.

    Args:
        client: Object exposing ``client.chat.completions.create(...)``.
        args: Model, input messages, optional system prompt, bucket and tool
            configuration. ``args.toolArgs`` may be None; when present it is read
            through attributes (``tools``, ``has_state``, ``init_state()``,
            ``context``, ``system_prompt``, ``is_indirect_prompt_injection``).
        provider: Provider id; ``"google"`` (or a Google model name) routes tool
            schemas through ``transform_tools_for_google_provider``.

    Returns:
        The message history as a list of dicts: the (possibly system-prefixed)
        transformed inputs followed by assistant turns and tool results.

    Raises:
        Exception: When the provider reports an error or returns no choices.
            Any error from the client or a tool executor is printed and re-raised.
    """
    model = args.model
    messages = args.messages
    bucket = args.bucket
    toolArgs = args.toolArgs
    system_prompt = args.system_prompt

    # toolArgs may be None; guard like every other toolArgs access below.
    messages = truncate_messages(
        messages,
        indirect=getattr(toolArgs, "is_indirect_prompt_injection", False),
    )

    toolState = None
    if getattr(toolArgs, "has_state", False):
        toolState = toolArgs.init_state()

    # Deep-copied so tools cannot mutate the caller's context object.
    messageContext = {
        "context": copy.deepcopy(toolArgs.context) if toolArgs and getattr(toolArgs, "context", None) else None
    }

    hasTools = bool(toolArgs and getattr(toolArgs, "tools", None) and len(toolArgs.tools) > 0)
    openAICompatibleMessages = transform_messages(
        messages,
        bucket,
        "OpenAI",
        toolArgs.tools if hasTools else None,
        messageContext,
        model,
    )

    # Prepend a system prompt unless one is already first (also covers an empty list).
    if not openAICompatibleMessages or openAICompatibleMessages[0].get("role") != "system":
        if system_prompt:
            openAICompatibleMessages.insert(0, {
                "role": "system",
                "content": system_prompt,
            })
        elif toolArgs and getattr(toolArgs, "system_prompt", None):
            system_prompt_val = toolArgs.system_prompt() if callable(toolArgs.system_prompt) else toolArgs.system_prompt
            openAICompatibleMessages.insert(0, {
                "role": "system",
                "content": system_prompt_val,
            })

    # A ":high" suffix is not part of the provider's model id; it requests high reasoning effort.
    modelName = model.split(":high")[0] if model else model

    tools = None
    if hasTools:
        use_google_schema = provider == "google" or is_google_model(modelName)
        tools = [
            {
                "type": "function",
                "function": (
                    transform_tools_for_google_provider(tool["function"])
                    if use_google_schema
                    else tool["function"]
                ),
            }
            for tool in toolArgs.tools
        ]

    print("MODEL NAME", modelName)

    # Reasoning models (o1/o3/o4 families) reject a temperature parameter.
    moreArgs = (
        {}
        if any(x in modelName for x in ["o1", "o3", "o4"])
        else {"temperature": 0}
    )
    if model and ":high" in model:
        moreArgs["reasoning_effort"] = "high"

    try:
        res = client.chat.completions.create(
            model=modelName,
            tools=tools,
            messages=openAICompatibleMessages,
            stream=False,
            **moreArgs
        )
        resMessage = _checked_first_message(res, "OPENAI/OPENROUTER ERROR", "OpenAI/OpenRouter failed: ")
        resContent = getattr(resMessage, "content", "") if resMessage else ""

        # Plain answer: record it and return.
        if not (resMessage and getattr(resMessage, "tool_calls", None)):
            openAICompatibleMessages.append(
                {
                    "role": "assistant",
                    "content": resContent,
                    "tool_calls": [],
                }
            )
            return openAICompatibleMessages

        delete_refusal(model, resMessage)

        if is_google_model(modelName) and len(resMessage.tool_calls) > 1:
            print(f"Limiting Google model {modelName} to process only the first tool call")
            resMessage.tool_calls = [resMessage.tool_calls[0]]

        openAICompatibleMessages.append(_assistant_message_to_dict(resMessage))
        latestMessage = resMessage

        MAX_RECURSION_DEPTH = 5
        MAX_TOOL_CALLS_PER_ITERATION = 10
        available_tools = toolArgs.tools if toolArgs and getattr(toolArgs, "tools", None) else []
        inline_tool_result_required = requires_inline_tool_results(modelName)
        currentDepth = 0

        while getattr(latestMessage, "tool_calls", None) and currentDepth < MAX_RECURSION_DEPTH:
            toolCallsToProcess = (
                [latestMessage.tool_calls[0]]
                if is_google_model(modelName) and len(latestMessage.tool_calls) > 1
                else latestMessage.tool_calls[:MAX_TOOL_CALLS_PER_ITERATION]
            )

            toolResponses = []
            for toolCall in toolCallsToProcess:
                result = execute_tool_call(toolCall, available_tools, messageContext, toolState)
                if result:
                    toolMessage = {
                        "role": "tool",
                        "tool_call_id": toolCall.id,
                        "content": result["toolResponse"],
                    }
                    toolResponses.append(toolMessage)
                    openAICompatibleMessages.append(dict(toolMessage))

            if inline_tool_result_required and toolResponses:
                _inline_tool_results(openAICompatibleMessages, toolResponses)

            toolCallResponse = client.chat.completions.create(
                model=modelName,
                tools=tools,
                messages=openAICompatibleMessages,
                stream=False,
                **moreArgs
            )
            latestMessage = _checked_first_message(
                toolCallResponse,
                "OPENAI/OPENROUTER TOOL CALL ERROR",
                "OpenAI/OpenRouter Tool call failed: ",
            )
            if not latestMessage:
                break

            delete_refusal(model, latestMessage)

            if (
                is_google_model(modelName)
                and getattr(latestMessage, "tool_calls", None)
                and len(latestMessage.tool_calls) > 1
            ):
                print(f"Limiting Google model {modelName} to process only the first tool call in subsequent response")
                latestMessage.tool_calls = [latestMessage.tool_calls[0]]

            # Record the NEW response with its OWN tool calls, so the tool
            # messages of the next round answer the ids declared right before
            # them (the old code re-attached the previous round's calls here).
            openAICompatibleMessages.append({
                "role": "assistant",
                "content": getattr(latestMessage, "content", "*Empty response*"),
                "tool_calls": _serialize_tool_calls(getattr(latestMessage, "tool_calls", None)),
            })
            currentDepth += 1

        return openAICompatibleMessages
    except Exception as error:
        print("ERROR", error)
        raise

