import json
import typing as T

import tiktoken
from anthropic import AnthropicVertex, AsyncAnthropicVertex
from google.cloud.aiplatform_v1beta1.types import content
from google.oauth2 import service_account
from llama_index.core.base.llms.types import LLMMetadata
from llama_index.core.llms.function_calling import FunctionCallingLLM
from llama_index.core.llms.llm import LLM
from llama_index.llms.anthropic import Anthropic
from llama_index.llms.azure_inference import AzureAICompletionsModel
from llama_index.llms.azure_openai import AzureOpenAI
from llama_index.llms.vertex import Vertex
from mypy_extensions import DefaultNamedArg

from minimal.configuration import NON_OPENAI_CONTEXT_WINDOW_FACTOR, cfg
from minimal.logger import logger
from minimal.patches import (
    _get_all_kwargs,
)

# Monkeypatch: replace the `_get_all_kwargs` attribute on the Anthropic LLM
# class with the project's implementation imported from `minimal.patches`.
# This runs at import time, before the Anthropic instances further down in
# this module are constructed.
# NOTE(review): the motivation for the patch is not visible from this file —
# see `minimal.patches._get_all_kwargs`.
Anthropic._get_all_kwargs = _get_all_kwargs  # type: ignore


def _scale(
    context_window_length: int, factor: float = NON_OPENAI_CONTEXT_WINDOW_FACTOR
) -> int:
    """Scale a context window length by ``factor`` and truncate to an int.

    Args:
        context_window_length: The context window size to scale.
        factor: Multiplier applied to the length; defaults to
            ``NON_OPENAI_CONTEXT_WINDOW_FACTOR`` from the project configuration.

    Returns:
        ``context_window_length * factor`` converted with ``int()``.
    """
    scaled = context_window_length * factor
    return int(scaled)


def _azure_openai(*, model: str, deployment_name: str, api_version: str) -> AzureOpenAI:
    """Build an ``AzureOpenAI`` client with the settings shared by all deployments.

    Every client uses the API key and endpoint from ``cfg.azure_oai``,
    temperature 0, no client-side retries, and tags requests with the
    ``"flowgen"`` user.

    Args:
        model: Model name passed to ``AzureOpenAI``.
        deployment_name: Name of the Azure deployment to call.
        api_version: Azure OpenAI API version string.

    Returns:
        A newly constructed ``AzureOpenAI`` instance.
    """
    return AzureOpenAI(
        model=model,
        deployment_name=deployment_name,
        api_key=cfg.azure_oai.api_key.get_secret_value(),
        azure_endpoint=str(cfg.azure_oai.api_url),
        api_version=api_version,
        temperature=0,
        max_retries=0,
        # A fresh dict per client so instances never share mutable state.
        additional_kwargs={"user": "flowgen"},
    )


AZURE_GPT35_TURBO = _azure_openai(
    model="gpt-3.5-turbo", deployment_name="gpt-35", api_version="2024-06-01"
)

AZURE_GPT4O_MINI = _azure_openai(
    model="gpt-4o-mini", deployment_name="gpt-4o-mini", api_version="2024-06-01"
)

AZURE_GPT4O_STD = _azure_openai(
    model="gpt-4o", deployment_name="gpt-4o", api_version="2024-06-01"
)

# The o-series deployments are addressed with a newer (preview) API version.
AZURE_o1 = _azure_openai(
    model="o1", deployment_name="o1", api_version="2024-12-01-preview"
)

AZURE_o3_MINI = _azure_openai(
    model="o3-mini", deployment_name="o3-mini", api_version="2024-12-01-preview"
)

# Safety settings passed to every Vertex Gemini client in this module: each
# listed harm category is set to the BLOCK_ONLY_HIGH threshold.
GCP_SAFETY_SETTINGS = {
    harm_category: content.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH
    for harm_category in (
        content.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        content.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        content.HarmCategory.HARM_CATEGORY_HARASSMENT,
        content.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
    )
}

def _vertex_gemini(*, model: str, raw_context_window: int) -> Vertex:
    """Build a ``Vertex`` Gemini client with the settings shared by all variants.

    Every client uses the project id and service-account credentials from
    ``cfg.gcp_vertex``, temperature 0, ``GCP_SAFETY_SETTINGS``, an 8000-token
    output cap, and no client-side retries.

    Args:
        model: Gemini model identifier passed to ``Vertex``.
        raw_context_window: Context window size before it is passed through
            ``_scale``.

    Returns:
        A newly constructed ``Vertex`` instance.
    """
    service_account_info = json.loads(cfg.gcp_vertex.credentials.get_secret_value())
    return Vertex(
        model=model,
        project=cfg.gcp_vertex.project_id,
        credentials=service_account.Credentials.from_service_account_info(
            service_account_info
        ),
        temperature=0,
        safety_settings=GCP_SAFETY_SETTINGS,
        max_tokens=8000,
        context_window=_scale(raw_context_window),
        max_retries=0,
        additional_kwargs={},
    )


GCP_GEMINI_PRO = _vertex_gemini(
    model="gemini-1.5-pro-002", raw_context_window=2000000
)

GCP_GEMINI_FLASH = _vertex_gemini(
    model="gemini-1.5-flash-002", raw_context_window=1048000
)

GCP_GEMINI_FLASH2 = _vertex_gemini(
    model="gemini-2.0-flash-001", raw_context_window=1048000
)


def add_scoped_credentials_anthropic(anthropic_llm: Anthropic) -> Anthropic:
    """Attach scoped Google service-account credentials to an Anthropic LLM.

    The credentials are built from the service-account JSON stored in
    ``cfg.gcp_vertex.credentials`` and restricted to the ``cloud-platform``
    scope. They are set on both the synchronous and the asynchronous client
    held by ``anthropic_llm``.

    Args:
        anthropic_llm: LLM whose ``_client`` must be an ``AnthropicVertex`` and
            whose ``_aclient`` must be an ``AsyncAnthropicVertex``.

    Returns:
        The same ``anthropic_llm`` object, mutated in place.

    Raises:
        AssertionError: If either client is not of the expected Vertex type.
    """
    service_account_info = json.loads(cfg.gcp_vertex.credentials.get_secret_value())
    scoped_credentials = service_account.Credentials.from_service_account_info(
        service_account_info
    ).with_scopes(["https://www.googleapis.com/auth/cloud-platform"])

    # Sync client first, then async — each is checked, updated and stored back.
    for attribute, expected_type in (
        ("_client", AnthropicVertex),
        ("_aclient", AsyncAnthropicVertex),
    ):
        vertex_client = getattr(anthropic_llm, attribute)
        assert isinstance(vertex_client, expected_type)
        vertex_client.credentials = scoped_credentials
        setattr(anthropic_llm, attribute, vertex_client)

    return anthropic_llm


def _vertex_anthropic(model: str) -> Anthropic:
    """Build an ``Anthropic`` LLM for ``model`` and attach scoped GCP credentials.

    The LLM is created for the configured GCP project in region ``us-east5``
    with temperature 0, then passed through
    ``add_scoped_credentials_anthropic``.

    Args:
        model: Model identifier passed to ``Anthropic``.

    Returns:
        The credentialed ``Anthropic`` instance.
    """
    llm = Anthropic(
        model=model,
        project_id=str(cfg.gcp_vertex.project_id),
        region="us-east5",
        temperature=0,
    )
    return add_scoped_credentials_anthropic(llm)


ANTHROPIC_CLAUDE_SONNET_35 = _vertex_anthropic("claude-3-5-sonnet-v2@20241022")


ANTHROPIC_CLAUDE_HAIKU_35 = _vertex_anthropic("claude-3-5-haiku@20241022")


# Azure AI completions client whose `metadata` property is overridden with
# fixed values describing Llama-3.3-70B-Instruct.
class AzureAICompletionsModelLlama(AzureAICompletionsModel):
    def __init__(self, credential, model_name, endpoint, temperature=0):
        """Forward the connection settings to ``AzureAICompletionsModel``.

        Args:
            credential: Credential passed through to the base class.
            model_name: Model name passed through to the base class.
            endpoint: Endpoint URL passed through to the base class.
            temperature: Sampling temperature; defaults to 0.
        """
        base_settings = {
            "credential": credential,
            "model_name": model_name,
            "endpoint": endpoint,
            "temperature": temperature,
        }
        super().__init__(**base_settings)

    @property
    def metadata(self):
        """Return the hard-coded ``LLMMetadata`` for this model."""
        llama_metadata = LLMMetadata(
            model_name="Llama-3.3-70B-Instruct",
            context_window=120000,
            num_output=1000,
            is_chat_model=True,
            is_function_calling_model=True,
        )
        return llama_metadata


# Azure AI completions client whose `metadata` property is overridden with
# fixed values describing Phi-4.
class AzureAICompletionsModelPhi4(AzureAICompletionsModel):
    def __init__(self, credential, model_name, endpoint, temperature=0):
        """Forward the connection settings to ``AzureAICompletionsModel``.

        Args:
            credential: Credential passed through to the base class.
            model_name: Model name passed through to the base class.
            endpoint: Endpoint URL passed through to the base class.
            temperature: Sampling temperature; defaults to 0.
        """
        base_settings = {
            "credential": credential,
            "model_name": model_name,
            "endpoint": endpoint,
            "temperature": temperature,
        }
        super().__init__(**base_settings)

    @property
    def metadata(self):
        """Return the hard-coded ``LLMMetadata`` for this model."""
        phi4_metadata = LLMMetadata(
            model_name="Phi-4",
            context_window=14000,
            num_output=1000,
            is_chat_model=True,
            is_function_calling_model=True,
        )
        return phi4_metadata


# Azure AI completions client whose `metadata` property is overridden with
# fixed values describing Deepseek-R1. Unlike the sibling classes in this
# module, it reports `is_function_calling_model=False`.
class AzureAICompletionsModelR1(AzureAICompletionsModel):
    def __init__(self, credential, model_name, endpoint, temperature=0):
        """Forward the connection settings to ``AzureAICompletionsModel``.

        Args:
            credential: Credential passed through to the base class.
            model_name: Model name passed through to the base class.
            endpoint: Endpoint URL passed through to the base class.
            temperature: Sampling temperature; defaults to 0.
        """
        base_settings = {
            "credential": credential,
            "model_name": model_name,
            "endpoint": endpoint,
            "temperature": temperature,
        }
        super().__init__(**base_settings)

    @property
    def metadata(self):
        """Return the hard-coded ``LLMMetadata`` for this model."""
        r1_metadata = LLMMetadata(
            model_name="Deepseek-R1",
            context_window=120000,
            num_output=8000,
            is_chat_model=True,
            is_function_calling_model=False,
        )
        return r1_metadata


# Llama 3.3 70B client. The endpoint host is assembled from the fixed
# deployment name and the region configured under `cfg.azure_inference_llama33`.
AZURE_LLAMA33_70B = AzureAICompletionsModelLlama(
    credential=cfg.azure_inference_llama33.api_key.get_secret_value(),  # type: ignore
    model_name="AzureLlama3370B",  # type: ignore
    endpoint=(  # type: ignore
        "https://Llama-3-3-70B-Instruct-drocto"
        f".{cfg.azure_inference_llama33.region_name!s}"
        ".models.ai.azure.com"
    ),
    temperature=0,  # type: ignore
)

# Phi-4 client. The endpoint host is assembled from the fixed deployment name
# and a configured region.
# NOTE(review): the API key comes from `cfg.azure_inference_phi4`, but the
# region is read from `cfg.azure_inference_llama33`. This looks like a
# copy-paste leftover; confirm whether `cfg.azure_inference_phi4.region_name`
# should be used instead (left unchanged here because the config schema is not
# visible from this file).
AZURE_PHI4 = AzureAICompletionsModelPhi4(
    credential=cfg.azure_inference_phi4.api_key.get_secret_value(),  # type: ignore
    model_name="AzurePhi-4",  # type: ignore
    endpoint=(  # type: ignore
        "https://Phi-4-drocto"
        f".{cfg.azure_inference_llama33.region_name!s}"
        ".models.ai.azure.com"
    ),
    temperature=0,  # type: ignore
)


# Azure AI completions client whose `metadata` property is overridden with
# fixed values describing mistral-large-2411.
class AzureAICompletionsModelMistral(AzureAICompletionsModel):
    def __init__(self, credential, model_name, endpoint, temperature=0):
        """Forward the connection settings to ``AzureAICompletionsModel``.

        Args:
            credential: Credential passed through to the base class.
            model_name: Model name passed through to the base class.
            endpoint: Endpoint URL passed through to the base class.
            temperature: Sampling temperature; defaults to 0.
        """
        base_settings = {
            "credential": credential,
            "model_name": model_name,
            "endpoint": endpoint,
            "temperature": temperature,
        }
        super().__init__(**base_settings)

    @property
    def metadata(self):
        """Return the hard-coded ``LLMMetadata`` for this model."""
        mistral_metadata = LLMMetadata(
            model_name="mistral-large-2411",
            context_window=120000,
            num_output=2056,
            is_chat_model=True,
            is_function_calling_model=True,
        )
        return mistral_metadata


# Mistral Large client. The endpoint host is assembled from the fixed
# deployment name and the region configured under `cfg.azure_inference_mistral`.
MISTRAL_LARGE = AzureAICompletionsModelMistral(
    credential=cfg.azure_inference_mistral.api_key.get_secret_value(),  # type: ignore
    model_name="MistralLarge2411",  # type: ignore
    endpoint=(  # type: ignore
        "https://Mistral-Large-2411-drocto"
        f".{cfg.azure_inference_mistral.region_name!s}"
        ".models.ai.azure.com"
    ),
    temperature=0,  # type: ignore
)

# Registry mapping public model names to the LLM instances built above.
# `get_llm` resolves names through this dict and `get_llm_name` performs the
# reverse lookup by iterating it in insertion order.
# When you add a model, make sure all tests pass successfully.
# NOTE: AZURE_GPT35_TURBO and AZURE_o1 are instantiated in this module but are
# not registered here.
LLMs = {
    "o3-mini": AZURE_o3_MINI,
    "gpt-4o-mini": AZURE_GPT4O_MINI,
    "gpt-4o-std": AZURE_GPT4O_STD,
    "anthropic-sonnet-35": ANTHROPIC_CLAUDE_SONNET_35,
    "anthropic-haiku-35": ANTHROPIC_CLAUDE_HAIKU_35,
    "gemini-pro": GCP_GEMINI_PRO,
    "gemini-flash": GCP_GEMINI_FLASH,
    "gemini-flash2": GCP_GEMINI_FLASH2,
    "llama-33-70B": AZURE_LLAMA33_70B,
    "mistral-large": MISTRAL_LARGE,
    "phi-4": AZURE_PHI4,
}


def get_llm(name: str | None = None):
    """Look up a registered LLM instance by its public name.

    Args:
        name: Key into the ``LLMs`` registry. ``None`` or an empty string means
            "no model requested".

    Returns:
        The registered LLM instance, or ``None`` (after logging a warning) when
        ``name`` is falsy.

    Raises:
        AssertionError: If ``name`` is non-empty but not a key of ``LLMs``.
    """
    if name:
        assert name in LLMs, (
            f"Invalid LLM name specified: {name}. Valid options are: {list(LLMs.keys())}"
        )
        return LLMs[name]

    logger.warning("No LLM name specified.")
    return None


def get_llm_name(llm: FunctionCallingLLM | None = None):
    """Return the registry name under which ``llm`` is stored in ``LLMs``.

    The registry is scanned in insertion order and the first entry that
    compares equal (``==``) to ``llm`` wins.

    Args:
        llm: The LLM instance to look up.

    Returns:
        The matching key of ``LLMs``.

    Raises:
        ValueError: If no registered instance compares equal to ``llm``.
    """
    matching_names = (
        registered_name
        for registered_name, registered_llm in LLMs.items()
        if llm == registered_llm
    )
    first_match = next(matching_names, None)
    if first_match is None:
        raise ValueError("Invalid LLM specified")
    return first_match


def is_function_calling(llm: LLM) -> bool:
    """Report whether ``llm`` should be treated as supporting function calling.

    A model qualifies when its metadata sets ``is_function_calling_model`` to a
    truthy value and its ``model_name`` does not contain ``"flash"``.

    Args:
        llm: The LLM whose ``metadata`` is inspected.

    Returns:
        ``True`` if the model qualifies, ``False`` otherwise. ``False`` is also
        returned when reading the metadata raises ``ValueError``.
    """
    try:
        metadata = llm.metadata
        if not getattr(metadata, "is_function_calling_model", False):
            # Previously this path fell off the end of the function and
            # returned None; always return a real bool instead.
            return False
        # Models with "flash" in their name are excluded even though their
        # metadata advertises function calling.
        return "flash" not in metadata.model_name
    except ValueError:
        return False


def get_tokenizer(
    name: str,
) -> T.Callable[
    [
        str,
        DefaultNamedArg(T.Literal["all"] | T.AbstractSet[str], "allowed_special"),
        DefaultNamedArg(T.Literal["all"] | T.Collection[str], "disallowed_special"),
    ],
    list[int],
]:
    """Return the token-encoding function to use for the model called ``name``.

    Every supported name currently maps to the same encoder: the ``encode``
    method of the tiktoken encoding for ``"gpt-4o-mini"``.

    Args:
        name: Public model name.

    Returns:
        A callable that encodes a string into a list of token ids.

    Raises:
        ValueError: If ``name`` is not one of the supported model names.
    """
    supported_names = (
        "o1",
        "o3-mini",
        "gpt-4o-mini",
        "gpt-4o-std",
        "anthropic-sonnet-35",
        "anthropic-haiku-35",
        "llama-33-70B",
        "mistral-large",
        "gemini-pro",
        "gemini-flash",
        "gemini-flash2",
        "phi-4",
    )
    if name not in supported_names:
        raise ValueError("Invalid tokenizer specified")

    encoding = tiktoken.encoding_for_model("gpt-4o-mini")
    return encoding.encode


if __name__ == "__main__":
    # Manual smoke check: print the registry name of the Mistral Large client.
    registered_name = get_llm_name(MISTRAL_LARGE)
    print(registered_name)
