from typing import (
    Any,
    List,
)

from nptyping import (
    Float32,
    NDArray,
)
from pymilvus.model.dense import SentenceTransformerEmbeddingFunction

from src.vector_db.embedding_fn.base_embedding_fn import BaseEmbeddingFunction


class SBERTEmbeddingFunction(BaseEmbeddingFunction):
    """Sentence-Transformers (SBERT) embedding function for the vector DB.

    Adapts pymilvus' ``SentenceTransformerEmbeddingFunction`` to the
    project's ``BaseEmbeddingFunction`` interface: document/query embedding,
    the embedding dimensionality, and the Milvus metric to search with.
    """

    def __init__(
        self,
        model_name: str = "sentence-transformers/all-mpnet-base-v2",
        device: str = "cpu",
    ) -> None:
        """Load the sentence-transformers model.

        Args:
            model_name: Model identifier passed to pymilvus'
                ``SentenceTransformerEmbeddingFunction``. Defaults to
                ``"sentence-transformers/all-mpnet-base-v2"``.
            device: Device the model is loaded on, e.g. ``"cpu"`` or
                ``"cuda"``. Defaults to ``"cpu"``.
        """
        self.embedding_fn = SentenceTransformerEmbeddingFunction(
            model_name=model_name,
            device=device,
            # ``metric_type`` below reports inner product ("IP"), which only
            # ranks like cosine similarity on unit-length vectors. pymilvus
            # documents True as the default; it is passed explicitly so the
            # metric choice cannot silently drift if that default changes.
            normalize_embeddings=True,
        )

    def embed_documents(self, documents: List[str]) -> NDArray[Any, Float32]:
        """Embed a batch of documents for storage in the vector DB.

        Args:
            documents: Texts to embed.

        Returns:
            The embeddings produced by the underlying
            ``encode_documents`` call, one per input text.
        """
        # NOTE(review): pymilvus may return a list of 1-D arrays here rather
        # than a single 2-D array; confirm the annotation matches callers.
        return self.embedding_fn.encode_documents(documents)

    def embed_queries(self, queries: List[str]) -> NDArray[Any, Float32]:
        """Embed a batch of search queries.

        Args:
            queries: Query texts to embed.

        Returns:
            The embeddings produced by the underlying ``encode_queries``
            call, one per input text.
        """
        return self.embedding_fn.encode_queries(queries)

    @property
    def dim(self) -> int:
        """Dimensionality of the produced vectors, as reported by pymilvus."""
        return self.embedding_fn.dim

    @property
    def metric_type(self) -> str:
        """Milvus metric type for similarity search over these embeddings.

        Inner product ("IP") is used; with normalized embeddings (enabled in
        ``__init__``) it orders results the same as cosine similarity.
        """
        return "IP"
