from typing import (
    Any,
    List,
    Optional,
)

from nptyping import (
    Float32,
    NDArray,
)
from pymilvus.model.sparse import BM25EmbeddingFunction as PyMilvusBM25EmbeddingFunction
from pymilvus.model.sparse.bm25.tokenizers import build_default_analyzer

from src.utils.matrix import csr_matrix_to_numpy
from src.vector_db.embedding_fn import BaseEmbeddingFunction


class BM25EmbeddingFunction(BaseEmbeddingFunction):
    """BM25 embedding function backed by pymilvus' ``BM25EmbeddingFunction``.

    Thin adapter exposing the pymilvus BM25 model through the project's
    ``BaseEmbeddingFunction`` interface. Encoded vectors returned by pymilvus
    are passed through ``csr_matrix_to_numpy`` before being handed back.
    Corpus statistics are produced by ``fit`` or restored by ``load``;
    NOTE(review): embedding before either is presumably invalid — confirm
    against pymilvus behavior.
    """

    def __init__(self, language: str = "en") -> None:
        """Build the text analyzer and the wrapped pymilvus BM25 model.

        Args:
            language: Language passed to pymilvus' ``build_default_analyzer``.
                Defaults to ``"en"``, which keeps the previous behavior.
        """
        self.analyzer = build_default_analyzer(language=language)
        # The same analyzer instance is handed to the pymilvus model, so
        # ``get_analyzer`` returns exactly the analyzer used for encoding.
        self.embedding_fn = PyMilvusBM25EmbeddingFunction(self.analyzer)

    def embed_documents(self, documents: List[str]) -> NDArray[Any, Float32]:
        """Encode documents with the BM25 document encoder.

        Args:
            documents: Raw document texts.

        Returns:
            The vectors from ``encode_documents`` after conversion by
            ``csr_matrix_to_numpy`` (presumably one row per document —
            confirm against that helper).
        """
        vectors = self.embedding_fn.encode_documents(documents)
        # pymilvus returns a sparse matrix; downstream code expects NumPy.
        vectors_npy = csr_matrix_to_numpy(matrix=vectors)
        return vectors_npy

    def embed_queries(self, queries: List[str]) -> NDArray[Any, Float32]:
        """Encode queries with the BM25 query encoder.

        Args:
            queries: Raw query texts.

        Returns:
            The vectors from ``encode_queries`` after conversion by
            ``csr_matrix_to_numpy`` (presumably one row per query —
            confirm against that helper).
        """
        vectors = self.embedding_fn.encode_queries(queries)
        # pymilvus returns a sparse matrix; downstream code expects NumPy.
        vectors_npy = csr_matrix_to_numpy(matrix=vectors)
        return vectors_npy

    def fit(self, documents: List[str]) -> None:
        """Fit the wrapped BM25 model's corpus statistics on ``documents``."""
        self.embedding_fn.fit(documents)

    def save(self, path: str) -> None:
        """Persist the wrapped BM25 model's state to ``path``."""
        self.embedding_fn.save(path)

    def load(self, path: str) -> None:
        """Restore the wrapped BM25 model's state from ``path``."""
        self.embedding_fn.load(path)

    def get_analyzer(self) -> Any:
        """Return the analyzer built in ``__init__`` and shared with pymilvus."""
        return self.analyzer

    @property
    def dim(self) -> int:
        """Dimension reported by the wrapped pymilvus model.

        NOTE(review): for BM25 this is presumably the fitted vocabulary size,
        so it may change after ``fit``/``load`` — confirm in pymilvus.
        """
        return self.embedding_fn.dim

    @property
    def metric_type(self) -> Optional[str]:
        """Metric type advertised to the vector store; always ``None`` here.

        NOTE(review): confirm how callers interpret ``None`` (e.g. whether
        they fall back to a default metric for these vectors).
        """
        return None
