from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import (
    Any,
    List,
    Optional,
)

import fire
import numpy as np

from loguru import logger
from tqdm import tqdm

from src.schema import User
from src.utils.json import (
    read_json_file,
    write_json_file,
)
from src.utils.log import set_log_level
from src.vector_db import vector_db_factory
from src.vector_db.embedding_fn import (
    DENSE_EMBEDDING_FN_NAMES,
    SPARSE_EMBEDDING_FN_NAMES,
    BaseEmbeddingFunction,
    embedding_fn_factory,
)


def construct_sparse_memory_db(
    uri: str,
    users: List[User],
    embedding_fn: BaseEmbeddingFunction,
) -> None:
    """Build a file-based memory DB using a sparse embedding function.

    The ``uri`` directory is created if missing, then three artifacts are written into it:

    * ``embedding_fn_params.json`` -- parameters of the fitted embedding function,
    * ``embedding_map.npz`` -- compressed per-user embeddings, keyed by ``user_id``,
    * ``memory_map.json`` -- per-user memories dumped in pydantic JSON mode, keyed by ``user_id``.

    Args:
        uri: Output directory.
        users: Users whose memories are indexed.
        embedding_fn: Sparse embedding function. It is fitted once on the memories of
            all users before any user's documents are embedded.
    """
    out_dir = Path(uri)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Fit once on the whole corpus so every user is embedded with the same fitted parameters.
    corpus = [mem.content for u in users for mem in u.memories]
    embedding_fn.fit(corpus)
    embedding_fn.save(str(out_dir / "embedding_fn_params.json"))

    vectors_by_user = {}
    memories_by_user = {}
    for u in tqdm(users):
        key = u.user_id
        user_memories = u.memories
        vectors_by_user[key] = embedding_fn.embed_documents(
            documents=[mem.content for mem in user_memories]
        )
        memories_by_user[key] = [mem.model_dump(mode="json") for mem in user_memories]

    # NOTE(review): each user_id becomes a keyword argument here, so a user_id equal to
    # "file" would collide with savez_compressed's own `file` parameter.
    np.savez_compressed(
        file=str(out_dir / "embedding_map.npz"),
        **vectors_by_user,
    )
    write_json_file(
        file_path=str(out_dir / "memory_map.json"),
        data=memories_by_user,
    )


def construct_dense_memory_db_for_single_user(
    args: Any,
) -> None:
    """(Re)build the vector-DB collection that holds one user's memories.

    The inputs arrive packed in a single tuple so the function can be passed
    directly to ``ThreadPoolExecutor.map``.

    Args:
        args: ``(user, vector_db, embedding_fn)``. The collection is named after
            ``user.user_id``. Each memory is stored with its ``memory_id`` as the
            string id, the dense vector, and its ``content``.
    """
    user, vector_db, embedding_fn = args

    collection_name = user.user_id
    memories = user.memories

    # Best-effort drop so a rerun starts from a fresh collection. Failure is tolerated
    # (presumably because the collection does not exist yet). Only ordinary exceptions are
    # caught, so KeyboardInterrupt/SystemExit still propagate.
    try:
        vector_db.drop_collection(collection_name=collection_name)
    except Exception as e:
        logger.debug(f"Could not drop collection {collection_name} (it may not exist yet): {e!r}")

    vector_db.create_collection(
        collection_name=collection_name,
        dimension=embedding_fn.dim,
        metric_type=embedding_fn.metric_type,
        id_type="string",
    )

    # The (empty) collection is still created for users without memories.
    if not memories:
        logger.warning(f"Empty memories for user {user.user_id}; skipping memory db construction")
        return

    vectors = embedding_fn.embed_documents(documents=[memory.content for memory in memories])

    data = [
        {
            "id": memory.memory_id,
            "vector": vector,
            "content": memory.content,
        }
        for memory, vector in zip(memories, vectors)
    ]

    vector_db.insert(collection_name=collection_name, data=data)
    vector_db.unload_collection(collection_name=collection_name)


def construct_dense_memory_db(
    uri: str,
    users: List[User],
    embedding_fn: BaseEmbeddingFunction,
    vector_db_name: str,
    max_workers: Optional[int],
) -> None:
    """Populate a vector database with one collection per user, using a thread pool.

    Args:
        uri: Location handed to ``vector_db_factory``. Only its parent directory
            is created here.
        users: Users whose memories are indexed; each gets its own collection.
        embedding_fn: Dense embedding function used for every user.
        vector_db_name: Backend name resolved by ``vector_db_factory``.
        max_workers: Thread-pool size. ``None`` (or ``0``) means one thread per
            user, and at least one thread.

    Raises:
        Any exception raised while building a user's collection is re-raised
        here when the results are drained.
    """
    uri_path = Path(uri)
    uri_path.parent.mkdir(parents=True, exist_ok=True)

    vector_db = vector_db_factory(vector_db_name=vector_db_name, uri=uri)

    # ThreadPoolExecutor rejects max_workers=0, so an empty user list must not size the pool.
    num_workers = max_workers or max(len(users), 1)

    # NOTE(review): the same vector_db and embedding_fn instances are used concurrently by
    # all worker threads; this assumes both are thread-safe -- confirm.
    tasks = [(user, vector_db, embedding_fn) for user in users]
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        # Draining the map iterator is what surfaces worker exceptions in this thread.
        _ = list(
            tqdm(
                pool.map(construct_dense_memory_db_for_single_user, tasks),
                total=len(tasks),
            )
        )


def construct_memory_db(
    memory_data_path: str,
    memory_db_uri: str,
    embedding_fn_name: str,
    vector_db_name: Optional[str] = None,  # Required for dense embedding functions
    max_workers: Optional[int] = None,
) -> None:
    """Build a memory DB from a JSON file of users and their memories.

    Args:
        memory_data_path: JSON file holding a list of user objects, each accepted
            by ``User(**user_dict)``.
        memory_db_uri: Output location. For sparse embedding functions it is a
            directory; for dense ones it is passed to the vector DB backend.
        embedding_fn_name: Must be one of ``SPARSE_EMBEDDING_FN_NAMES`` or
            ``DENSE_EMBEDDING_FN_NAMES``.
        vector_db_name: Vector DB backend; required when ``embedding_fn_name`` is dense.
        max_workers: Thread-pool size for dense construction (``None`` means one
            thread per user).

    Raises:
        ValueError: If ``embedding_fn_name`` is unknown, or if it is dense and
            ``vector_db_name`` is not given.
    """
    set_log_level()

    is_sparse = embedding_fn_name in SPARSE_EMBEDDING_FN_NAMES
    is_dense = embedding_fn_name in DENSE_EMBEDDING_FN_NAMES

    # Validate arguments before loading data or constructing the embedding function,
    # so a misconfiguration fails immediately.
    if not (is_sparse or is_dense):
        raise ValueError(f"Invalid embedding_fn_name: {embedding_fn_name}")
    if not is_sparse and vector_db_name is None:
        raise ValueError(
            f"vector_db_name is required for dense embedding function {embedding_fn_name!r}"
        )

    memory_data = read_json_file(file_path=memory_data_path)
    users = [User(**user_dict) for user_dict in memory_data]

    embedding_fn = embedding_fn_factory(embedding_fn_name=embedding_fn_name)

    # Sparse takes precedence if a name were listed in both sets (matches the original order).
    if is_sparse:
        construct_sparse_memory_db(
            uri=memory_db_uri,
            users=users,
            embedding_fn=embedding_fn,
        )
    else:
        construct_dense_memory_db(
            uri=memory_db_uri,
            users=users,
            embedding_fn=embedding_fn,
            vector_db_name=vector_db_name,
            max_workers=max_workers,
        )

    logger.info("Done!")


if __name__ == "__main__":
    # Expose construct_memory_db as a command-line interface; its parameters can be
    # supplied as positional arguments or --flags.
    fire.Fire(component=construct_memory_db)
