from src.vector_db.embedding_fn.base_embedding_fn import BaseEmbeddingFunction
from src.vector_db.embedding_fn.bm25_embedding_fn import BM25EmbeddingFunction
from src.vector_db.embedding_fn.contriever_embedding_fn import ContrieverEmbeddingFunction
from src.vector_db.embedding_fn.dpr_embedding_fn import DPREmbeddingFunction
from src.vector_db.embedding_fn.dragon_embedding_fn import DRAGONEmbeddingFunction
from src.vector_db.embedding_fn.instructor_embedding_fn import InstructorEmbeddingFunction
from src.vector_db.embedding_fn.sbert_embedding_fn import SBERTEmbeddingFunction


# Public surface of this module: the two name registries, the factory that
# turns a name into an embedding function, and the base embedding-function
# type named in the factory's return annotation.
__all__ = [
    "SPARSE_EMBEDDING_FN_NAMES", "DENSE_EMBEDDING_FN_NAMES",
    "embedding_fn_factory",
    "BaseEmbeddingFunction",
]


# Names understood by `embedding_fn_factory` that build a sparse embedding function.
SPARSE_EMBEDDING_FN_NAMES = ["bm25"]

# Names understood by `embedding_fn_factory` that build a dense embedding function.
DENSE_EMBEDDING_FN_NAMES = ["sbert", "dpr", "contriever", "instructor", "dragon"]


def embedding_fn_factory(
    embedding_fn_name: str,
) -> "BaseEmbeddingFunction | None":
    """Build the embedding function registered under ``embedding_fn_name``.

    Args:
        embedding_fn_name: ``"bm25"`` (sparse), one of ``"sbert"``, ``"dpr"``,
            ``"contriever"``, ``"instructor"``, ``"dragon"`` (dense), or one
            of the sentinel names ``"gt"`` / ``"no"``.

    Returns:
        A newly constructed embedding-function instance for the sparse/dense
        names (a new object on every call), or ``None`` for ``"gt"`` and
        ``"no"``, which select no embedding function.

    Raises:
        ValueError: If ``embedding_fn_name`` is not one of the names above.
    """
    # The return annotation is a string so the ``X | None`` form is valid on
    # Python versions before 3.10 and is never evaluated at definition time.

    # Sentinel names select no embedding function at all. They are checked
    # first so no model class is touched for them.
    # NOTE(review): "gt" presumably means ground-truth retrieval — confirm with callers.
    if embedding_fn_name in ("gt", "no"):
        return None

    # Each branch constructs its class lazily, so only the requested model is built.
    if embedding_fn_name == "bm25":
        return BM25EmbeddingFunction()
    if embedding_fn_name == "sbert":
        return SBERTEmbeddingFunction()
    if embedding_fn_name == "dpr":
        return DPREmbeddingFunction()
    if embedding_fn_name == "contriever":
        return ContrieverEmbeddingFunction()
    if embedding_fn_name == "instructor":
        return InstructorEmbeddingFunction()
    if embedding_fn_name == "dragon":
        return DRAGONEmbeddingFunction()

    raise ValueError(f"Unknown embedding_fn_name: {embedding_fn_name}")
