import logging
import os
import warnings
from collections.abc import Sequence
from typing import overload

from google import genai
from google.genai import types as genai_types
from google.genai.errors import ClientError
from pydantic import BaseModel
from tenacity import retry, retry_if_not_exception_type, stop_after_attempt, wait_random_exponential
from typing_extensions import deprecated

from agentdojo.agent_pipeline.base_pipeline_element import BasePipelineElement
from agentdojo.functions_runtime import EmptyEnv, Env, Function, FunctionCall, FunctionsRuntime
from agentdojo.types import ChatAssistantMessage, ChatMessage, text_content_block_from_string

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


@overload
def remove_additional_properties_recursively(schema_node: dict) -> dict: ...


@overload
def remove_additional_properties_recursively(schema_node: list) -> list: ...


def remove_additional_properties_recursively(schema_node: dict | list) -> dict | list:
    """
    Recursively removes the 'additionalProperties' key from all dictionary
    objects within a JSON schema structure. This modification happens in-place.

    Args:
        schema_node: A dictionary or list representing a part of a JSON schema.

    Returns:
        The modified schema_node.
    """
    if isinstance(schema_node, dict):
        # Remove 'additionalProperties' from the current dictionary if it exists
        if "additionalProperties" in schema_node:
            del schema_node["additionalProperties"]

        # Recursively process values in the dictionary
        # Iterate over a list of keys in case the dictionary is modified by recursion (though not in this specific logic for values)
        for key in list(schema_node.keys()):
            remove_additional_properties_recursively(schema_node[key])

    elif isinstance(schema_node, list):
        # Recursively process items in the list
        for item in schema_node:
            remove_additional_properties_recursively(item)

    return schema_node


def resolve_refs(cls: type[BaseModel]):
    """Return the JSON schema of `cls` with its local `$ref` pointers inlined.

    The schema is obtained from `cls.model_json_schema()`. If it has no `$defs`
    section it is returned untouched (including any `default` entries).
    Otherwise `$defs` is removed and every `$ref` reachable through
    `properties`, `items` and the `anyOf` / `allOf` / `oneOf` combinators is
    replaced by the keys of the definition it points to; `default` entries met
    on the way are dropped.

    Self-referential models are not supported: inlining a definition that
    refers back to itself does not terminate.

    Args:
        cls: The pydantic model class whose JSON schema should be flattened.

    Returns:
        The schema dictionary, without `$defs` and with references inlined.

    Raises:
        KeyError: If a `$ref` points to a name that is missing from `$defs`.
    """
    schema = cls.model_json_schema()
    if "$defs" not in schema:
        return schema

    defs = schema.pop("$defs")

    def _resolve(node):
        # Sub-schemas may be non-dict values (e.g. a boolean `items`); skip them.
        if not isinstance(node, dict):
            return
        node.pop("default", None)
        if "$ref" in node:
            ref = node.pop("$ref")
            # References look like "#/$defs/Name"; only the last segment is the key.
            node.update(defs[ref.split("/")[-1]])
        for prop in node.get("properties", {}).values():
            _resolve(prop)
        if "items" in node:
            _resolve(node["items"])
        # Optional / union fields are emitted as combinators whose members may
        # themselves be references (e.g. `anyOf: [{"$ref": ...}, {"type": "null"}]`).
        # Leaving them unresolved would dangle once `$defs` is gone.
        for keyword in ("anyOf", "allOf", "oneOf"):
            for option in node.get(keyword, ()):
                _resolve(option)

    _resolve(schema)
    return schema


def _parameters_to_google(parameters: type[BaseModel]) -> genai_types.Schema:
    """Build a Google `Schema` from a pydantic parameters model.

    The model's JSON schema first has its references inlined via
    `resolve_refs`, then all `additionalProperties` keys are stripped, and the
    result is validated into a `genai_types.Schema`.
    """
    inlined_schema = resolve_refs(parameters)
    cleaned_schema = remove_additional_properties_recursively(inlined_schema)
    return genai_types.Schema.model_validate(cleaned_schema)


def _function_to_google(f: Function) -> genai_types.FunctionDeclaration:
    """Translate a runtime `Function` into a Google `FunctionDeclaration`.

    Name and description are copied as-is; the parameters model is converted
    with `_parameters_to_google`.
    """
    google_parameters = _parameters_to_google(f.parameters)
    return genai_types.FunctionDeclaration(
        name=f.name,
        description=f.description,
        parameters=google_parameters,
    )


def _merge_tool_result_messages(messages: list[genai_types.Content]) -> list[genai_types.Content]:
    """Collapse runs of consecutive tool-result messages into single messages.

    A message counts as a tool result when at least one of its parts carries a
    `function_response`. For every uninterrupted run of such messages, one
    `Content` is emitted that holds only the function-response parts of the
    run (in order) and takes its role from the last message of the run. All
    other messages are passed through unchanged, in their original position.
    """
    merged: list[genai_types.Content] = []
    # Function-response parts collected for the run currently being merged;
    # `None` means that no run is open.
    pending_parts: list[genai_types.Part] | None = None
    pending_role = None

    for message in messages:
        responses = [part for part in (message.parts or []) if part.function_response is not None]
        if responses:
            pending_parts = responses if pending_parts is None else pending_parts + responses
            pending_role = message.role
            continue

        # A non-tool message closes any open run before being forwarded itself.
        if pending_parts is not None:
            merged.append(genai_types.Content(parts=pending_parts, role=pending_role))
            pending_parts = None
        merged.append(message)

    # Flush a run that reaches the end of the conversation.
    if pending_parts is not None:
        merged.append(genai_types.Content(parts=pending_parts, role=pending_role))

    return merged


def _parts_from_assistant_message(assistant_message: ChatAssistantMessage) -> list[genai_types.Part]:
    """Convert an assistant message into Google parts.

    Text parts (one per content block) come first, followed by one
    function-call part per tool call. Missing content (`None`) and missing or
    empty tool calls simply contribute no parts.
    """
    content_blocks = assistant_message["content"]
    if content_blocks is None:
        text_parts = []
    else:
        text_parts = [genai_types.Part.from_text(text=block["content"]) for block in content_blocks]

    call_parts = [
        genai_types.Part.from_function_call(name=call.function, args=dict(call.args))
        for call in assistant_message["tool_calls"] or []
    ]
    return text_parts + call_parts


def _message_to_google(message: ChatMessage) -> genai_types.Content:
    """Convert a single chat message into a Google `Content`.

    * `user` messages become a `user` content with one text part per block.
    * `assistant` messages become a `model` content built by
      `_parts_from_assistant_message`.
    * `tool` messages become a `user` content holding one function-response
      part; the response payload is `{"error": ...}` when the message carries
      an error and `{"result": ...}` otherwise.

    Raises:
        ValueError: For any other role.
    """
    if message["role"] == "user":
        user_parts = [genai_types.Part.from_text(text=block["content"]) for block in message["content"]]
        return genai_types.Content(role="user", parts=user_parts)

    if message["role"] == "assistant":
        return genai_types.Content(role="model", parts=_parts_from_assistant_message(message))

    if message["role"] == "tool":
        # Tool messages reference a single `tool_call`, unlike assistant
        # messages, which carry a list of `tool_calls`. This is intentional.
        originating_call = message["tool_call"]
        payload = {"error": message["error"]} if message["error"] is not None else {"result": message["content"]}
        response_part = genai_types.Part.from_function_response(
            name=originating_call.function,
            response=payload,
        )
        return genai_types.Content(role="user", parts=[response_part])

    raise ValueError(f"Invalid role for Google: {message['role']}")


@retry(
    wait=wait_random_exponential(multiplier=1, max=40),
    stop=stop_after_attempt(3),
    reraise=True,
    retry=retry_if_not_exception_type(ClientError),
)
def chat_completion_request(
    model: str,
    client: genai.Client,
    contents: genai_types.ContentListUnion,
    generation_config: genai_types.GenerateContentConfig,
) -> genai_types.GenerateContentResponse:
    """Send one `generate_content` request to the given model.

    The call is retried (up to 3 attempts in total, with randomized
    exponential back-off capped at 40) on any exception except `ClientError`,
    which is raised immediately. Once the attempts are exhausted, the last
    exception is re-raised as-is.

    Args:
        model: Name of the model to query.
        client: The `google-genai` client used to perform the request.
        contents: The conversation to send.
        generation_config: Generation settings, forwarded as `config`.

    Returns:
        The response returned by the client.
    """
    return client.models.generate_content(model=model, contents=contents, config=generation_config)


# Placeholder function name used when the API returns a call without a name.
EMPTY_FUNCTION_NAME = "<empty-function-name>"


def _google_to_tool_call(function_call: genai_types.FunctionCall) -> FunctionCall:
    """Convert a Google function call into a runtime `FunctionCall`.

    A missing or empty name is replaced by `EMPTY_FUNCTION_NAME`, and missing
    arguments by an empty dictionary. The call id is forwarded unchanged.
    """
    # sometimes the API returns an empty string (or nothing) as the name
    function_name = function_call.name or EMPTY_FUNCTION_NAME
    arguments = function_call.args or {}
    return FunctionCall(function=function_name, args=arguments, id=function_call.id)


def _google_to_assistant_message(message: genai_types.GenerateContentResponse) -> ChatAssistantMessage:
    """Convert a Google response into an assistant chat message.

    Only the first candidate is inspected. Its function-call parts become tool
    calls and its non-empty text parts become text content blocks. When the
    response has no candidates, or the first candidate has no content or no
    parts, a warning is logged and an assistant message with empty content and
    no tool calls is returned.
    """

    def _empty_message() -> ChatAssistantMessage:
        return ChatAssistantMessage(role="assistant", content=[], tool_calls=[])

    if not message.candidates:
        logger.warning("Google completion returned no candidates.")
        return _empty_message()

    candidate = message.candidates[0]
    if candidate.content is None:
        logger.warning(f"Google completion was empty, with finish reason: {candidate.finish_reason}")
        return _empty_message()

    parts = candidate.content.parts
    if parts is None:
        logger.warning("Google completion returned a candidate with no content parts.")
        return _empty_message()

    tool_calls = []
    text_blocks = []
    for part in parts:
        call = part.function_call
        if call:
            tool_calls.append(_google_to_tool_call(call))
        # Best effort: a part whose text cannot be read is skipped silently.
        try:
            if part.text:
                text_blocks.append(text_content_block_from_string(part.text))
        except AttributeError:
            pass

    return ChatAssistantMessage(role="assistant", content=text_blocks, tool_calls=tool_calls)


class GoogleLLM(BasePipelineElement):
    """LLM pipeline element that uses Google Vertex AI (i.e. Gemini models).

    !!! warning
        In order to use Google LLMs with `VertexAI` and not the AI Studio API (by specifying an API key in the client),
        you need to run `vertexai.init()` before instantiating this class. This is done automatically if you create a
        pipeline with [`AgentPipeline.from_config`][agentdojo.agent_pipeline.AgentPipeline.from_config]. However,
        you need to have Google Cloud CLI installed on your machine and be authenticated.
        After being authenticated, you need to run `gcloud auth application-default login` in the terminal
        to let Python access your Google Cloud account.

    Args:
        model: The model name.
        client: The `google-genai` client to use. Not passing this argument (or passing `None`) is deprecated.
        temperature: The temperature to use for generation.
        max_tokens: The maximum number of output tokens to generate per request.
    """

    @overload
    @deprecated("Not specifying a client for `GoogleLLM` is deprecated and will be removed in a future release.")
    def __init__(
        self, model: str, client: None = None, temperature: float | None = 0.0, max_tokens: int = 65535
    ) -> None: ...

    @overload
    def __init__(
        self, model: str, client: genai.Client, temperature: float | None = 0.0, max_tokens: int = 65535
    ) -> None: ...

    def __init__(
        self, model: str, client: genai.Client | None = None, temperature: float | None = 0.0, max_tokens: int = 65535
    ):
        self.model = model
        if client is not None:
            self.client = client
        else:
            # `stacklevel=2` attributes the warning to the code instantiating
            # `GoogleLLM` instead of to this constructor.
            warnings.warn(
                "Not specifying a client for `GoogleLLM` is deprecated and will be removed in a future release.",
                stacklevel=2,
            )
            # Fallback: a Vertex AI client configured from the environment.
            self.client = genai.Client(
                vertexai=True, project=os.getenv("GCP_PROJECT"), location=os.getenv("GCP_LOCATION")
            )
        self.temperature = temperature
        self.max_tokens = max_tokens

    def query(
        self,
        query: str,
        runtime: FunctionsRuntime,
        env: Env = EmptyEnv(),
        messages: Sequence[ChatMessage] = [],
        extra_args: dict = {},
    ) -> tuple[str, FunctionsRuntime, Env, Sequence[ChatMessage], dict]:
        """Send the conversation to the model and append its reply to `messages`.

        A leading `system` message is not sent as part of the conversation;
        the text of its first content block is passed as the system
        instruction instead. The functions registered in `runtime` are
        offered to the model as tools.

        Args:
            query: The user query; returned unchanged.
            runtime: The functions runtime whose functions are exposed as tools.
            env: The environment; returned unchanged.
            messages: The conversation so far. Must contain at least one message.
            extra_args: Extra pipeline arguments; returned unchanged.

        Returns:
            The tuple `(query, runtime, env, messages, extra_args)`, where
            `messages` is a new list made of the input messages followed by
            the assistant's reply.

        Raises:
            ValueError: If `messages` is empty.
        """
        if not messages:
            # Fail with an explicit message rather than with the opaque
            # "not enough values to unpack" raised by the unpacking below.
            raise ValueError("`GoogleLLM.query` requires at least one message, but `messages` is empty.")

        first_message, *other_messages = messages
        if first_message["role"] == "system":
            system_instruction = first_message["content"][0]["content"]
        else:
            system_instruction = None
            other_messages = messages
        google_messages = [_message_to_google(message) for message in other_messages]
        google_messages = _merge_tool_result_messages(google_messages)
        google_functions = [_function_to_google(tool) for tool in runtime.functions.values()]
        # Only attach a tool list when there is at least one function to declare.
        google_tools: genai_types.ToolListUnion | None = (
            [genai_types.Tool(function_declarations=google_functions)] if google_functions else None
        )

        generation_config = genai_types.GenerateContentConfig(
            temperature=self.temperature,
            max_output_tokens=self.max_tokens,
            tools=google_tools,
            system_instruction=system_instruction,
        )
        completion = chat_completion_request(
            self.model,
            self.client,
            google_messages,  # type: ignore
            generation_config=generation_config,
        )
        output = _google_to_assistant_message(completion)
        messages = [*messages, output]
        return query, runtime, env, messages, extra_args
