import os
from typing import (
    Any,
    List,
)

from nptyping import (
    Float32,
    NDArray,
)
from pymilvus.model.dense import InstructorEmbeddingFunction as PyMilvusInstructorEmbeddingFunction

from src.vector_db.embedding_fn.base_embedding_fn import BaseEmbeddingFunction


# Set process-wide at import time: turns off parallelism in the Hugging Face
# `tokenizers` library (presumably to avoid its fork-related warning when the
# embedding model is used from worker processes — confirm).
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})


class InstructorEmbeddingFunction(BaseEmbeddingFunction):
    """Embedding function that delegates to pymilvus' Instructor dense model.

    Wraps ``pymilvus.model.dense.InstructorEmbeddingFunction`` and exposes it
    through the project's ``BaseEmbeddingFunction`` interface. Documents and
    queries are encoded through separate calls because the wrapped model
    receives a different instruction prompt for each.
    """

    def __init__(
        self,
        *,
        model_name: str = "hkunlp/instructor-xl",
        query_instruction: str = "Represent the question for retrieval:",
        doc_instruction: str = "Represent the document for retrieval:",
    ) -> None:
        """Create the wrapped pymilvus embedding function.

        All parameters are keyword-only and default to the values that were
        previously hard-coded, so ``InstructorEmbeddingFunction()`` behaves
        exactly as before.

        Args:
            model_name: Identifier of the Instructor model to use.
            query_instruction: Instruction prompt passed to the wrapped
                function for encoding queries.
            doc_instruction: Instruction prompt passed to the wrapped
                function for encoding documents.
        """
        self.embedding_fn = PyMilvusInstructorEmbeddingFunction(
            model_name=model_name,
            query_instruction=query_instruction,
            doc_instruction=doc_instruction,
        )

    def embed_documents(self, documents: List[str]) -> NDArray[Any, Float32]:
        """Encode ``documents`` with the wrapped function's document encoder.

        Args:
            documents: Texts to embed.

        Returns:
            Whatever ``encode_documents`` returns for the given texts
            (annotated as a float32 array; presumably one embedding per
            input text — confirm against the pymilvus version in use).
        """
        return self.embedding_fn.encode_documents(documents)

    def embed_queries(self, queries: List[str]) -> NDArray[Any, Float32]:
        """Encode ``queries`` with the wrapped function's query encoder.

        Args:
            queries: Query texts to embed.

        Returns:
            Whatever ``encode_queries`` returns for the given texts
            (annotated as a float32 array; presumably one embedding per
            input query — confirm against the pymilvus version in use).
        """
        return self.embedding_fn.encode_queries(queries)

    @property
    def dim(self) -> int:
        """Embedding dimensionality as reported by the wrapped function."""
        return self.embedding_fn.dim

    @property
    def metric_type(self) -> str:
        """Name of the similarity metric associated with these embeddings.

        Always ``"COSINE"``.
        """
        return "COSINE"
