from src.vector_db.base_vector_db import BaseVectorDB
from src.vector_db.milvus_vector_db import MilvusVectorDB
from src.vector_db.naive_vector_db import NaiveVectorDB


# Public API of this module: the backend factory and the common base type.
__all__ = ["vector_db_factory", "BaseVectorDB"]


def vector_db_factory(vector_db_name: str, uri: str) -> BaseVectorDB:
    """Build the vector database backend selected by ``vector_db_name``.

    Args:
        vector_db_name: Backend identifier, matched exactly (case-sensitive)
            against ``"milvus"`` and ``"naive"``.
        uri: Passed unchanged to the selected backend's constructor as the
            ``uri`` keyword argument.

    Returns:
        A newly constructed ``MilvusVectorDB`` or ``NaiveVectorDB``.

    Raises:
        ValueError: If ``vector_db_name`` matches no known backend.
    """
    # Ordered (name, class) pairs. These are checked with ``==`` in the same
    # order as before, so any input compares exactly as it used to.
    registry = (
        ("milvus", MilvusVectorDB),
        ("naive", NaiveVectorDB),
    )
    for backend_name, backend_cls in registry:
        if vector_db_name == backend_name:
            return backend_cls(uri=uri)
    raise ValueError(f"Unknown vector_db_name: {vector_db_name}")
