from collections.abc import Sequence
from typing import Any

import vertexai.generative_models as genai
from google.cloud.aiplatform_v1beta1 import types as aiplatform_types
from google.protobuf.struct_pb2 import Struct
from proto.marshal.collections.maps import MapComposite
from proto.marshal.collections.repeated import RepeatedComposite
from pydantic import BaseModel
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)

from agentgym.agent_pipeline.base_pipeline_element import BasePipelineElement
from agentgym.functions_engine.functions_engine import (
    Env,
    Function,
    FunctionCall,
)
from agentgym.openapi_utils import make_openapi_from_pydantic_model
from agentgym.types import (
    ChatAssistantMessage,
    ChatMessage,
    DummyEnv,
)


def _parameters_to_google(parameters: type[BaseModel]) -> dict[str, Any]:
    """Convert a pydantic parameters model into a schema dict for Google.

    The schema is produced by ``make_openapi_from_pydantic_model``. Any
    property that is wrapped in an ``allOf`` list is replaced by the first
    entry of that list, with the entry's ``title`` key removed (the local
    naming suggests this is how enum-typed fields are rendered -- TODO confirm).

    Args:
        parameters: The pydantic model class describing the function arguments.

    Returns:
        The schema dictionary with ``allOf`` wrappers flattened.
    """
    openapi_parameters = make_openapi_from_pydantic_model(parameters)  # type: ignore
    # A schema without a "properties" key simply has nothing to flatten.
    properties = openapi_parameters.get("properties", {})
    for property_name, property_value in properties.items():
        if "allOf" in property_value:
            enum = property_value["allOf"][0]
            # The entry is not guaranteed to carry a title; do not fail if absent.
            enum.pop("title", None)
            # Re-assigning an existing key is safe while iterating over items().
            properties[property_name] = enum
    return openapi_parameters


def _function_to_google(f: Function) -> genai.FunctionDeclaration:
    """Build the Google function declaration for a pipeline ``Function``.

    Args:
        f: The function whose name, description and parameters are exported.

    Returns:
        A ``genai.FunctionDeclaration`` carrying the same name and description,
        with the parameters converted by ``_parameters_to_google``.
    """
    declaration = genai.FunctionDeclaration(
        name=f.name,
        description=f.description,
        parameters=_parameters_to_google(f.parameters),
    )
    return declaration


def _merge_tool_result_messages(
    messages: list[genai.Content],
) -> list[genai.Content]:
    """Collapse consecutive function-response messages into a single message.

    Whenever two adjacent messages both contain function-response parts, the
    function-response parts of the second one are appended to the first one and
    the second message is dropped. The check is then repeated at the same
    position, so a run of any length ends up as one message.

    Args:
        messages: The conversation, in order.

    Returns:
        A new list with the merged messages. Neither the input list nor the
        ``Content`` objects in it are modified.
    """
    merged = list(messages)
    i = 0
    while i < len(merged) - 1:
        current = merged[i]
        following = merged[i + 1]
        current_has_responses = any(part.function_response for part in current.parts)
        following_responses = [part for part in following.parts if part.function_response]
        if current_has_responses and following_responses:
            # Rebuild the Content rather than extending `current.parts` in place:
            # the SDK's `parts` accessor hands back a newly built list, so an
            # in-place extend would be lost while the second message is still
            # removed (NOTE(review): confirm against the installed SDK version).
            merged[i] = genai.Content(
                # An unset role reads back as an empty value; keep it unset.
                role=current.role or None,
                parts=[*current.parts, *following_responses],
            )
            merged.pop(i + 1)
            # Do not advance: the merged message may absorb the next one too.
        else:
            i += 1
    return merged


def _message_to_google(
    message: ChatMessage,
) -> genai.Content:
    """Translate one chat message into a ``genai.Content``.

    * ``user`` messages become a ``user`` content with a single text part.
    * ``assistant`` messages become a ``model`` content holding the text (when
      the content is not ``None``) followed by one function-call part per tool
      call.
    * ``tool`` messages become a role-less content with a single
      function-response part wrapping the tool output under ``"content"``.

    Args:
        message: The chat message to convert.

    Returns:
        The equivalent ``genai.Content``.

    Raises:
        ValueError: If the message role is none of the above.
    """
    role = message["role"]

    if role == "user":
        user_part = genai.Part.from_text(message["content"])
        return genai.Content(role="user", parts=[user_part])

    if role == "assistant":
        assistant_parts = []
        text = message["content"]
        if text is not None:
            assistant_parts.append(genai.Part.from_text(text))
        # A missing/empty tool call list contributes no parts.
        for call in message["tool_calls"] or []:
            call_payload = {"function_call": {"name": call["function"], "args": call["args"]}}
            assistant_parts.append(genai.Part.from_dict(call_payload))
        return genai.Content(role="model", parts=assistant_parts)

    if role == "tool":
        originating_call = message["tool_call"]
        response_part = genai.Part.from_function_response(
            name=originating_call["function"],
            response={"content": message["content"]},
        )
        return genai.Content(parts=[response_part])

    raise ValueError(f"Invalid role for Google: {message['role']}")


@retry(
    stop=stop_after_attempt(3),
    wait=wait_random_exponential(multiplier=1, max=40),
    reraise=True,
)
def chat_completion_request(
    model: genai.GenerativeModel,
    contents: list[genai.Content],
    tools: list[genai.Tool],
    generation_config: genai.GenerationConfig,
) -> genai.GenerationResponse:
    """Send a non-streaming generation request to the model.

    The call is attempted up to three times with randomized exponential
    back-off between attempts; the last exception is re-raised if every
    attempt fails.

    Args:
        model: The generative model to query.
        contents: The conversation to send.
        tools: The tools to expose; an empty list is sent as ``None``.
        generation_config: Generation settings forwarded to the model.

    Returns:
        The response returned by ``model.generate_content``.
    """
    tools_or_none = tools or None
    result: genai.GenerationResponse = model.generate_content(  # type: ignore -- streaming=False, so no iterable
        contents,
        tools=tools_or_none,
        generation_config=generation_config,
    )
    return result


def repeated_composite_to_dict(
    repeated_composite: RepeatedComposite,
) -> list:
    """Convert a repeated protobuf container into a plain Python list.

    Nested ``RepeatedComposite`` items are converted recursively, ``Struct`` /
    ``MapComposite`` items are converted with ``struct_to_dict``, and every
    other item is kept unchanged.

    Args:
        repeated_composite: The container to convert.

    Returns:
        A list with one converted entry per item, in the original order.
    """

    def _convert(element):
        if isinstance(element, RepeatedComposite):
            return repeated_composite_to_dict(element)
        if isinstance(element, Struct | MapComposite):
            return struct_to_dict(element)
        return element

    return [_convert(element) for element in repeated_composite]


def _struct_list_item_to_python(item: Any) -> Any:
    """Convert one element of a plain list found inside a struct value."""
    if isinstance(item, Struct | MapComposite | dict):
        return struct_to_dict(item)  # type: ignore -- plain dicts expose the same items() interface
    if isinstance(item, RepeatedComposite):
        return repeated_composite_to_dict(item)
    if isinstance(item, list):
        return [_struct_list_item_to_python(nested) for nested in item]
    # Scalars (strings, numbers, booleans, None, ...) are kept as they are.
    return item


def struct_to_dict(struct: Struct | MapComposite) -> dict:
    """Recursively convert a protobuf struct/map into a plain ``dict``.

    Nested ``Struct`` / ``MapComposite`` values become dicts,
    ``RepeatedComposite`` values become lists (via
    ``repeated_composite_to_dict``) and plain lists are converted element by
    element. All other values are copied over unchanged.

    Args:
        struct: The mapping-like protobuf object to convert.

    Returns:
        A new dictionary with the same keys and converted values.
    """
    out = {}
    for key, value in struct.items():
        if isinstance(value, Struct | MapComposite):
            out[key] = struct_to_dict(value)
        elif isinstance(value, RepeatedComposite):
            out[key] = repeated_composite_to_dict(value)
        elif isinstance(value, list):
            # List elements are not necessarily mappings: only mapping-like
            # elements are expanded, scalars must pass through untouched
            # instead of failing on a missing `.items()`.
            out[key] = [_struct_list_item_to_python(item) for item in value]
        else:
            out[key] = value
    return out


# Placeholder used as the tool-call function name when the API response carries
# an empty function name (see `_google_to_tool_call`).
EMPTY_FUNCTION_NAME = "<empty-function-name>"


def _google_to_tool_call(function_call: aiplatform_types.FunctionCall) -> FunctionCall:
    """Convert a Google function call into the pipeline's ``FunctionCall``.

    Args:
        function_call: The function call returned by the API.

    Returns:
        A ``FunctionCall`` with the arguments converted to a plain dict and an
        empty ``id``. An empty function name is replaced by
        ``EMPTY_FUNCTION_NAME``.
    """
    # sometimes the API returns an empty string
    name = EMPTY_FUNCTION_NAME if function_call.name == "" else function_call.name
    return FunctionCall(
        function=name,
        args=struct_to_dict(function_call.args),
        id="",
    )


def _google_to_assistant_message(
    message: genai.GenerationResponse,
) -> ChatAssistantMessage:
    """Build an assistant chat message from the first response candidate.

    Every part of the first candidate is inspected: function calls are
    collected as tool calls and non-empty texts are collected and joined with a
    single space.

    Args:
        message: The generation response returned by the model.

    Returns:
        An assistant message whose ``content`` is the joined text (``None`` when
        no text was found) and whose ``tool_calls`` lists the converted calls.
    """
    calls = []
    text_chunks = []
    for part in message.candidates[0].content.parts:
        function_call = part.function_call
        if function_call:
            calls.append(_google_to_tool_call(function_call))
        # Accessing `text` may raise AttributeError; treat that as "no text".
        chunk = getattr(part, "text", None)
        if chunk:
            text_chunks.append(chunk)
    content = " ".join(text_chunks) if text_chunks else None
    return ChatAssistantMessage(
        role="assistant",
        content=content,
        tool_calls=calls,
    )


class GoogleLLM(BasePipelineElement):
    """Pipeline element that queries a Vertex AI generative model."""

    def __init__(
        self,
        model: str,
        temperature: float | None = 0.0,
    ):
        """Store the model name and build the generation configuration.

        Args:
            model: Name of the model, passed as ``model_name`` on every query.
            temperature: Sampling temperature placed in the generation config.
        """
        self.model = model
        self.generation_config = genai.GenerationConfig(temperature=temperature)

    def query(
        self,
        query: str,
        tools: Sequence[Function],
        env: Env = DummyEnv(),
        messages: Sequence[ChatMessage] = [],
        extra_args: dict = {},
    ) -> tuple[str, Sequence[Function], Env, Sequence[ChatMessage], dict]:
        """Send the conversation to the model and append its reply.

        A leading ``system`` message, if present, is passed as the model's
        system instruction instead of being sent as part of the conversation.
        ``query``, ``tools``, ``env`` and ``extra_args`` are returned unchanged;
        the default ``messages`` / ``extra_args`` objects are never mutated
        here.

        Args:
            query: The user query; passed through untouched.
            tools: The functions exposed to the model.
            env: The environment; passed through untouched.
            messages: The conversation so far. Must not be empty.
            extra_args: Extra pipeline arguments; passed through untouched.

        Returns:
            The tuple ``(query, tools, env, messages, extra_args)`` where
            ``messages`` is a new list ending with the assistant reply.

        Raises:
            ValueError: If ``messages`` is empty.
        """
        if not messages:
            # Unpacking below would fail with a cryptic "not enough values to
            # unpack" ValueError; report the actual problem instead.
            raise ValueError("GoogleLLM.query requires at least one message in `messages`.")
        first_message, *other_messages = messages
        if first_message["role"] == "system":
            system_instruction = first_message["content"] or ""
        else:
            system_instruction = None
            other_messages = messages
        google_messages = [_message_to_google(message) for message in other_messages]
        google_messages = _merge_tool_result_messages(google_messages)
        google_functions = [_function_to_google(tool) for tool in tools]
        # No tool wrapper is built when there are no functions to declare.
        google_tools = [genai.Tool(function_declarations=google_functions)] if google_functions else []
        model = genai.GenerativeModel(
            model_name=self.model,
            # An empty system prompt is treated the same as no system prompt.
            system_instruction=genai.Part.from_text(text=system_instruction) if system_instruction else None,
        )
        completion = chat_completion_request(
            model,
            google_messages,
            tools=google_tools,
            generation_config=self.generation_config,
        )
        output = _google_to_assistant_message(completion)
        messages = [*messages, output]
        return query, tools, env, messages, extra_args
