from typing import (
    List,
    Optional,
)
from pathlib import Path

import fire
import numpy as np
from loguru import logger

from src.vector_db import vector_db_factory
from src.vector_db.embedding_fn import (
    DENSE_EMBEDDING_FN_NAMES,
    SPARSE_EMBEDDING_FN_NAMES,
    BaseEmbeddingFunction,
    embedding_fn_factory,
)
from src.schema import Document
from src.utils.json import (
    read_json_file,
    write_json_file,
)
from src.utils.log import set_log_level


def construct_sparse_doc_db(
    uri: str,
    docs: List[Document],
    embedding_fn: BaseEmbeddingFunction,
) -> None:
    """Build a file-based document store for a sparse embedding function.

    The directory ``uri`` is created if needed (including parents) and
    receives three files:

    * ``embedding_fn_params.json`` -- written by ``embedding_fn.save`` after
      the function has been fitted on all document contents.
    * ``embedding_map.npz`` -- compressed NumPy archive holding the document
      vectors under the single key ``"all"``.
    * ``doc_map.json`` -- JSON-mode dumps of the documents, also keyed by
      ``"all"``.

    Args:
        uri: Output directory path.
        docs: Documents to index; their ``content`` is what gets embedded.
        embedding_fn: Embedding function exposing ``fit``, ``save`` and
            ``embed_documents``.
    """
    out_dir = Path(uri)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Fit on the whole corpus first and persist the fitted parameters.
    texts = [doc.content for doc in docs]
    embedding_fn.fit(texts)
    embedding_fn.save(str(out_dir / "embedding_fn_params.json"))

    # Everything lives in one collection, so both maps have a single key.
    collection_name = "all"
    embedding_map = {
        collection_name: embedding_fn.embed_documents(documents=texts),
    }
    doc_map = {
        collection_name: [doc.model_dump(mode="json") for doc in docs],
    }

    # Persist vectors and document payloads side by side.
    np.savez_compressed(
        file=str(out_dir / "embedding_map.npz"),
        **embedding_map,
    )
    write_json_file(
        file_path=str(out_dir / "doc_map.json"),
        data=doc_map,
    )


def construct_dense_doc_db(
    uri: str,
    docs: List[Document],
    embedding_fn: BaseEmbeddingFunction,
    vector_db_name: str,
) -> None:
    """Build (or rebuild) the ``"all"`` collection of a dense vector database.

    The parent directory of ``uri`` is created if needed, a vector database
    handle is obtained from ``vector_db_factory``, any existing ``"all"``
    collection is dropped on a best-effort basis, and a fresh collection is
    created using the embedding function's ``dim`` and ``metric_type``.

    When ``docs`` is empty a warning is logged and the function returns,
    leaving the newly created collection empty. Otherwise every document's
    content is embedded and inserted as a record with the keys ``"id"``
    (``doc.document_id``), ``"vector"`` and ``"content"``.

    Args:
        uri: Location of the vector database, passed through to the factory.
        docs: Documents to index.
        embedding_fn: Embedding function exposing ``dim``, ``metric_type``
            and ``embed_documents``.
        vector_db_name: Name of the vector database backend to instantiate.
    """
    uri_path = Path(uri)
    uri_path.parent.mkdir(parents=True, exist_ok=True)

    vector_db = vector_db_factory(vector_db_name=vector_db_name, uri=uri)

    collection_name = "all"

    # Dropping is best-effort (the collection may simply not exist yet), but
    # only ordinary errors are ignored: a bare ``except`` would also swallow
    # KeyboardInterrupt / SystemExit. The failure is logged so that genuine
    # backend problems remain visible.
    try:
        vector_db.drop_collection(collection_name=collection_name)
    except Exception as exc:
        logger.debug(
            f"Could not drop collection '{collection_name}' (ignored): {exc!r}"
        )

    vector_db.create_collection(
        collection_name=collection_name,
        dimension=embedding_fn.dim,
        metric_type=embedding_fn.metric_type,
        id_type="string",
    )

    if not docs:
        logger.warning("Empty documents, nothing to do.")
        return

    vectors = embedding_fn.embed_documents(documents=[doc.content for doc in docs])

    # One record per document; vectors are paired with docs positionally.
    data = [
        {
            "id": doc.document_id,
            "vector": vector,
            "content": doc.content,
        }
        for doc, vector in zip(docs, vectors)
    ]

    vector_db.insert(collection_name=collection_name, data=data)
    # NOTE(review): the original left an `unload_collection` call disabled
    # here; the collection is intentionally not unloaded after insertion.


def construct_doc_db(
    doc_data_path: str,
    doc_db_uri: str,
    embedding_fn_name: str,
    vector_db_name: Optional[str] = None,
) -> None:
    """Load documents from JSON and build the matching document database.

    The embedding function is created from ``embedding_fn_name``; the name
    then selects the sparse or the dense construction path.

    Args:
        doc_data_path: Path to a JSON file containing a sequence of dicts,
            each used as keyword arguments for ``Document``.
        doc_db_uri: Output location handed to the sparse/dense builder.
        embedding_fn_name: Name of the embedding function to use.
        vector_db_name: Vector database backend name; only forwarded on the
            dense path.

    Raises:
        ValueError: If ``embedding_fn_name`` is in neither the sparse nor the
            dense name collection.
    """
    set_log_level()

    raw_docs = read_json_file(file_path=doc_data_path)
    docs = [Document(**raw_doc) for raw_doc in raw_docs]

    embedding_fn = embedding_fn_factory(embedding_fn_name=embedding_fn_name)

    # Arguments shared by both builders.
    common_kwargs = {
        "uri": doc_db_uri,
        "docs": docs,
        "embedding_fn": embedding_fn,
    }

    if embedding_fn_name in SPARSE_EMBEDDING_FN_NAMES:
        construct_sparse_doc_db(**common_kwargs)
    elif embedding_fn_name in DENSE_EMBEDDING_FN_NAMES:
        construct_dense_doc_db(**common_kwargs, vector_db_name=vector_db_name)
    else:
        raise ValueError(f"Invalid embedding_fn_name: {embedding_fn_name}")

    logger.success("Done!")


# Script entry point: expose construct_doc_db as a command-line interface via
# python-fire, so its parameters are supplied as CLI arguments.
if __name__ == "__main__":
    fire.Fire(construct_doc_db)
