from typing import Optional

from src.retriever.base_retriever import BaseRetriever
from src.retriever.dense_retriever import DenseRetriever
from src.retriever.gt_retriever import GTRetriever
from src.retriever.sparse_retriever import SparseRetriever
from src.vector_db import vector_db_factory
from src.vector_db.embedding_fn import BaseEmbeddingFunction


# Public names exported by a star-import of this module: the factory function
# and the base retriever class.
__all__ = [
    "retriever_factory",
    "BaseRetriever",
]


def retriever_factory(
    retriever_name: str,
    memory_db_uri: str,
    embedding_fn: BaseEmbeddingFunction,
    vector_db_name: Optional[str],
    top_k: int,
) -> BaseRetriever:
    """Build the retriever selected by ``retriever_name``.

    Args:
        retriever_name: One of ``"sparse"``, ``"dense"`` or ``"gt"``.
        memory_db_uri: Passed as ``memory_db_uri`` to the sparse and gt
            retrievers, and as ``uri`` to ``vector_db_factory`` for the
            dense retriever.
        embedding_fn: Forwarded to the sparse and dense retrievers; not used
            for ``"gt"``.
        vector_db_name: Forwarded to ``vector_db_factory`` for ``"dense"``;
            not used for ``"sparse"``; must be unset (``None`` or empty) for
            ``"gt"``.
        top_k: Forwarded as ``top_k`` to the sparse and dense retrievers; not
            used for ``"gt"``.

    Returns:
        A ``SparseRetriever``, ``DenseRetriever`` or ``GTRetriever`` instance,
        depending on ``retriever_name``.

    Raises:
        ValueError: If ``retriever_name`` is not recognised, or if
            ``vector_db_name`` is set while requesting the ``"gt"`` retriever.
    """
    if retriever_name == "sparse":
        return SparseRetriever(
            memory_db_uri=memory_db_uri,
            embedding_fn=embedding_fn,
            top_k=top_k,
        )

    if retriever_name == "dense":
        vector_db = vector_db_factory(
            vector_db_name=vector_db_name,
            uri=memory_db_uri,
        )
        return DenseRetriever(
            embedding_fn=embedding_fn,
            vector_db=vector_db,
            top_k=top_k,
        )

    if retriever_name == "gt":
        # Explicit raise rather than `assert`: assertions are removed when
        # Python runs with -O, which would silently disable this check.
        if vector_db_name:
            raise ValueError(
                "vector_db_name must not be set for the 'gt' retriever, "
                f"got: {vector_db_name!r}"
            )
        return GTRetriever(
            memory_db_uri=memory_db_uri,
        )

    raise ValueError(f"Unknown retriever_name: {retriever_name}")
