#!/usr/bin/python3
"""
Embedding model implementations and API.

Author(s):
    Anonymized Authors @anonymized-authors

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import hashlib
import numpy as np
import os
import torch
from enum import Enum
from llama_index.embeddings.azure_openai import (  # type: ignore
    AzureOpenAIEmbedding
)
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.base.embeddings.base import Embedding
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer
from typing import Final, Iterable, List, Tuple


# Public API of this module: the names exported by `from <module> import *`.
# Private helpers such as `_hash` are intentionally not listed.
__all__ = [
    "EmbedderOptions",
    "RandomEmbedding",
    "HuggingFaceEmbedding",
    "SentenceTransformerEmbedding",
    "get_embedder_options",
    "get_embedder"
]


def _hash(x: str) -> int:
    """
    Deterministically hashes a string to an integer.
    Input:
        x: the string to hash.
    Returns:
        The integer representation of the hash value of the input string.
    """
    return int(hashlib.sha256(x.encode("utf-8")).hexdigest(), 16)


class EmbedderOptions(str, Enum):
    """
    Identifiers of the embedders that `get_embedder` can construct.

    Each value is a "<provider>/<model>" string. Because the enum mixes in
    `str`, a member compares equal to its plain string value. The declaration
    order below is the order in which `get_embedder_options` lists them, so
    reordering members changes that function's output.
    """
    # Hash-seeded pseudo-random vectors (RandomEmbedding); no model or API.
    RANDOM = "leon/random"
    # Built via AzureOpenAIEmbedding in get_embedder; the part after "/" is
    # passed as the `model` argument.
    TEXT_EMBEDDING_ADA_002 = "openai/text-embedding-ada-002"
    TEXT_EMBEDDING_3_SMALL = "openai/text-embedding-3-small"
    TEXT_EMBEDDING_3_LARGE = "openai/text-embedding-3-large"
    # Built via SentenceTransformerEmbedding in get_embedder.
    E5_MISTRAL_7B_INSTRUCT = "intfloat/e5-mistral-7b-instruct"
    # Remaining options fall through to HuggingFaceEmbedding in get_embedder.
    MEDEMBED_LARGE = "abhinand/MedEmbed-large-v0.1"
    BIOCLINICAL_BERT = "emilyalsentzer/Bio_ClinicalBERT"


def get_embedder_options() -> List[str]:
    """
    List the identifiers of every embedder this module knows how to build.
    Input:
        None.
    Returns:
        The string value of each EmbedderOptions member, in declaration order.
    """
    return list(map(lambda member: member.value, EmbedderOptions))


class RandomEmbedding(BaseEmbedding):
    """
    Deterministic pseudo-random embedder.

    Every text is mapped to a unit-length Gaussian vector whose RNG is seeded
    from a SHA-256 hash of the text, so the same text always yields the same
    embedding. Queries and documents share the same embedding function.
    """

    def __init__(self, dimensions: int = 1024, **kwargs):
        """
        Args:
            dimensions: the number of embedding dimensions to use; must be
                greater than 1.
            kwargs: additional keyword arguments forwarded to `BaseEmbedding`.
        Raises:
            ValueError: if `dimensions` is not greater than 1.
        """
        # Explicit check instead of `assert`, which is stripped under `-O`.
        # Written as `not (... > 1)` to reject exactly what the former
        # assertion rejected.
        if not dimensions > 1:
            raise ValueError(
                f"dimensions must be greater than 1, got {dimensions!r}."
            )
        super().__init__(model_name="RandomEmbedding", **kwargs)
        self._embedding_dim: Final[int] = dimensions

    def _rand_embed(self, x: str) -> torch.Tensor:
        """
        Embed an input text.
        Input:
            x: the text to embed.
        Returns:
            A float64 tensor of shape (D,) with unit L2 norm, where D is the
            embedding dimension.
        """
        # Seed from a process-independent hash so results are reproducible.
        rng = np.random.default_rng(seed=_hash(x))
        embedding = torch.tensor(rng.normal(size=self._embedding_dim))
        return embedding / torch.norm(embedding, p=2)

    def _get_query_embedding(self, query: str) -> Embedding:
        """
        Embed the input query synchronously.
        Input:
            query: the query to embed.
        Returns:
            The embedding representation of the input query.
        """
        return self._get_text_embedding(query)

    async def _aget_query_embedding(self, query: str) -> Embedding:
        """
        Embed the input query asynchronously.
        Input:
            query: the query to embed.
        Returns:
            The embedding representation of the input query.
        """
        return self._get_query_embedding(query)

    def _get_text_embedding(self, text: str) -> Embedding:
        """
        Embed the input text synchronously.
        Input:
            text: the text to embed.
        Returns:
            The embedding representation of the input text as a list of floats.
        """
        return self._rand_embed(text).detach().tolist()


class HuggingFaceEmbedding(BaseEmbedding):
    """
    Embedder backed by a Hugging Face `AutoModel`.

    Each text is tokenized on its own (truncated to the model's maximum input
    length), run through the model without gradients, averaged over dim 1 of
    `last_hidden_state`, and L2-normalized.
    """

    def __init__(self, model_id: str, **kwargs):
        """
        Args:
            model_id: the ID of the embedding model to use from Hugging Face.
            kwargs: additional keyword arguments forwarded to `BaseEmbedding`.
        """
        super(HuggingFaceEmbedding, self).__init__(
            model_name=model_id.split("/")[-1], **kwargs
        )
        self._model_id: Final[str] = model_id
        self._tokenizer = AutoTokenizer.from_pretrained(self._model_id)
        self._model = AutoModel.from_pretrained(self._model_id)
        # Token limit used for truncation so long inputs do not overflow the
        # model's position embeddings (None -> tokenizer default).
        self._max_length = self._resolve_max_length()

    def _resolve_max_length(self):
        """
        Determine the maximum number of tokens to feed the model.
        Input:
            None.
        Returns:
            The smallest plausible limit among the tokenizer's
            `model_max_length` and the model config's
            `max_position_embeddings`, or None if neither provides one.
        """
        candidates = (
            getattr(self._tokenizer, "model_max_length", None),
            getattr(
                getattr(self._model, "config", None),
                "max_position_embeddings",
                None
            ),
        )
        # Tokenizers with no configured limit report a huge sentinel value,
        # so implausibly large numbers are ignored.
        limits = [
            n for n in candidates
            if isinstance(n, int) and 0 < n < 1_000_000
        ]
        return min(limits) if limits else None

    def parameters(self) -> Iterable[torch.Tensor]:
        """
        Returns the parameters associated with the embedding model.
        Input:
            None.
        Returns:
            An iterable over the embedding model's parameters.
        """
        return self._model.parameters()

    def named_parameters(self) -> Iterable[Tuple[str, torch.Tensor]]:
        """
        Returns the named parameters associated with the embedding model.
        Input:
            None.
        Returns:
            An iterable over the embedding model's named parameters.
        """
        return self._model.named_parameters()

    def _hf_embed(self, x: str) -> torch.Tensor:
        """
        Embed an input text.
        Input:
            x: the text to embed.
        Returns:
            A unit-L2-norm tensor of shape (D,), where D is the embedding
            dimension.
        """
        # Truncate so texts longer than the model's context do not crash.
        inputs = self._tokenizer(
            [x],
            return_tensors="pt",
            truncation=True,
            max_length=self._max_length
        )
        with torch.no_grad():
            embeddings = self._model(**inputs).last_hidden_state
        if embeddings.ndim > 2:
            # Single, unpadded sequence: a plain mean over dim 1 is exact.
            embeddings = embeddings.mean(dim=1)
        # keepdim=True normalizes each row independently via broadcasting.
        norms = torch.norm(embeddings, p=2, dim=-1, keepdim=True)
        return (embeddings / norms).squeeze(0)

    def _get_query_embedding(self, query: str) -> Embedding:
        """
        Embed the input query synchronously.
        Input:
            query: the query to embed.
        Returns:
            The embedding representation of the input query.
        """
        return self._get_text_embedding(query)

    async def _aget_query_embedding(self, query: str) -> Embedding:
        """
        Embed the input query asynchronously.
        Input:
            query: the query to embed.
        Returns:
            The embedding representation of the input query.
        """
        return self._get_query_embedding(query)

    def _get_text_embedding(self, text: str) -> Embedding:
        """
        Embed the input text synchronously.
        Input:
            text: the text to embed.
        Returns:
            The embedding representation of the input text as a list of floats.
        """
        return self._hf_embed(text).detach().tolist()


class SentenceTransformerEmbedding(BaseEmbedding):
    """
    Embedder backed by a `sentence_transformers.SentenceTransformer` model.

    Queries are encoded with a named prompt from the model's configuration;
    documents are encoded without a prompt.
    """

    def __init__(
        self,
        model_id: str,
        *,
        query_prompt_name: str = "web_search_query",
        **kwargs
    ):
        """
        Args:
            model_id: the ID of the sentence-transformers model to load.
            query_prompt_name: name of the prompt (as configured for the
                model) prepended to queries at encoding time.
            kwargs: additional keyword arguments forwarded to `BaseEmbedding`.
        """
        super(SentenceTransformerEmbedding, self).__init__(
            model_name=model_id.split("/")[-1], **kwargs
        )
        self._model_id: Final[str] = model_id
        self._query_prompt_name: Final[str] = query_prompt_name
        self._sbert = SentenceTransformer(self._model_id)

    def _get_query_embedding(self, query: str) -> Embedding:
        """
        Embed the input query synchronously.
        Input:
            query: the query to embed.
        Returns:
            The embedding representation of the input query.
        """
        embedding = self._sbert.encode(
            query, prompt_name=self._query_prompt_name
        )
        # `encode` returns a NumPy array by default, which has no `.detach()`;
        # `.tolist()` works on both NumPy arrays and torch tensors.
        return embedding.tolist()

    async def _aget_query_embedding(self, query: str) -> Embedding:
        """
        Embed the input query asynchronously.
        Input:
            query: the query to embed.
        Returns:
            The embedding representation of the input query.
        """
        return self._get_query_embedding(query)

    def _get_text_embedding(self, text: str) -> Embedding:
        """
        Embed the input text synchronously.
        Input:
            text: the text to embed.
        Returns:
            The embedding representation of the input text.
        """
        return self._sbert.encode(text).tolist()


def get_embedder(
    knowledge_embedder: str, **kwargs
) -> BaseEmbedding:
    """
    Returns a specified embedder.
    Input:
        knowledge_embedder: the string name of the embedder to return; must be
            one of `get_embedder_options()`.
        kwargs: additional keyword arguments forwarded to the selected
            embedder's constructor.
    Returns:
        The specified embedder. OpenAI options read their endpoint, API
        version, and key from the `API_ENDPOINT_EMBED`, `AZURE_API_VERSION`,
        and `API_KEY_EMBED` environment variables.
    Raises:
        NotImplementedError: if `knowledge_embedder` is not a known option.
    """
    options = get_embedder_options()
    if knowledge_embedder not in options:
        raise NotImplementedError(
            f"Unsupported embedder {knowledge_embedder!r}; "
            f"expected one of {options}."
        )

    if knowledge_embedder == EmbedderOptions.RANDOM.value:
        # Forward kwargs like every other backend (previously dropped).
        return RandomEmbedding(**kwargs)
    if knowledge_embedder in (
        EmbedderOptions.TEXT_EMBEDDING_ADA_002.value,
        EmbedderOptions.TEXT_EMBEDDING_3_SMALL.value,
        EmbedderOptions.TEXT_EMBEDDING_3_LARGE.value,
    ):
        return AzureOpenAIEmbedding(
            azure_endpoint=os.environ.get("API_ENDPOINT_EMBED"),
            api_version=os.environ.get("AZURE_API_VERSION"),
            api_key=os.environ.get("API_KEY_EMBED"),
            # Strip the "openai/" provider prefix to get the model name.
            model=knowledge_embedder.split("/", 1)[-1],
            **kwargs
        )
    if knowledge_embedder == EmbedderOptions.E5_MISTRAL_7B_INSTRUCT.value:
        return SentenceTransformerEmbedding(knowledge_embedder, **kwargs)
    # All remaining options are plain Hugging Face encoder models.
    return HuggingFaceEmbedding(knowledge_embedder, **kwargs)
