from typing import (
    Any,
    List,
    Optional,
)
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor

import fire
import numpy as np
from tqdm import tqdm
from loguru import logger

from src.vector_db import vector_db_factory
from src.vector_db.embedding_fn import (
    DENSE_EMBEDDING_FN_NAMES,
    SPARSE_EMBEDDING_FN_NAMES,
    BaseEmbeddingFunction,
    embedding_fn_factory,
)
from src.schema import User
from src.utils.json import (
    read_json_file,
    write_json_file,
)
from src.utils.log import set_log_level


def construct_sparse_mem_db(
    uri: str,
    users: List[User],
    mem_type: str,
    embedding_fn: BaseEmbeddingFunction,
) -> None:
    """Fit a sparse embedding function on one memory type and persist the results.

    Three artifacts are written into the directory ``uri`` (created if missing):

    * ``embedding_fn_params.json`` -- the fitted function's state, as written
      by ``embedding_fn.save``.
    * ``embedding_map.npz`` -- one entry per user, keyed by ``user.user_id``,
      holding the embeddings of that user's selected memories.
    * ``mem_map.json`` -- for each ``user.user_id``, the ``model_dump(mode="json")``
      of the embedded memories, in the same order as their vectors.

    Args:
        uri: Output directory.
        users: Users whose memories are embedded.
        mem_type: Which memory field of each user to use; one of
            ``"dialogue"``, ``"observation"``, ``"summary"``,
            ``"episodic_memory"``.
        embedding_fn: Sparse embedding function. It is fitted here on the
            contents of every selected memory across all users.

    Raises:
        ValueError: If ``mem_type`` is not one of the supported values.
    """
    # Each supported mem_type is also the name of the User attribute holding
    # those memories (mirrors construct_dense_mem_db_for_single_user).
    valid_mem_types = ("dialogue", "observation", "summary", "episodic_memory")
    if mem_type not in valid_mem_types:
        # Validate before touching the filesystem so bad input leaves no
        # empty output directory behind.
        raise ValueError(f"Invalid mem_type: {mem_type}")

    uri_path = Path(uri)
    uri_path.mkdir(parents=True, exist_ok=True)

    # Select the mem_type-specific memories once, so the corpus used to fit
    # the embedding function is exactly the set of memories embedded below.
    selected = [(user.user_id, getattr(user, mem_type)) for user in users]

    # step 1. fit embedding function
    total_mem_contents = [mem.content for _, memories in selected for mem in memories]
    embedding_fn.fit(total_mem_contents)
    embedding_fn.save(str(uri_path / "embedding_fn_params.json"))

    # step 2. get embedding map and mem map
    embedding_map = {}
    mem_map = {}

    for collection_name, memories in tqdm(selected):
        mem_contents = [mem.content for mem in memories]
        embedding_map[collection_name] = embedding_fn.embed_documents(documents=mem_contents)
        mem_map[collection_name] = [mem.model_dump(mode="json") for mem in memories]

    # step 3. save embedding map and mem map
    np.savez_compressed(
        file=str(uri_path / "embedding_map.npz"),
        **embedding_map,
    )

    write_json_file(
        file_path=str(uri_path / "mem_map.json"),
        data=mem_map,
    )


def construct_dense_mem_db_for_single_user(
    args: Any,
) -> None:
    """(Re)build one user's dense-vector collection for a single memory type.

    ``args`` is a 4-tuple ``(user, mem_type, vector_db, embedding_fn)``, packed
    into one argument so the function can be handed to ``map``-style APIs.

    Steps:
      1. Drop the collection named ``user.user_id`` (best effort; failures are
         logged at debug level and ignored).
      2. Create it again with ``embedding_fn.dim`` / ``embedding_fn.metric_type``
         and string ids.
      3. If the user has no memories of ``mem_type``, log a warning and return,
         leaving the freshly created collection empty.
      4. Otherwise embed every memory's content and insert one row per memory
         with keys ``id`` (``memory_id``), ``vector`` and ``content``.

    Raises:
        ValueError: If ``mem_type`` is not one of ``"dialogue"``,
            ``"observation"``, ``"summary"``, ``"episodic_memory"``.
        RuntimeError: If creating the collection, embedding, or inserting
            fails. The original exception is chained as ``__cause__``.
    """
    user, mem_type, vector_db, embedding_fn = args

    collection_name = user.user_id

    # Each supported mem_type is also the name of the User attribute holding
    # those memories.
    valid_mem_types = ("dialogue", "observation", "summary", "episodic_memory")
    if mem_type not in valid_mem_types:
        raise ValueError(f"Invalid mem_type: {mem_type}")
    memories = getattr(user, mem_type)

    try:
        vector_db.drop_collection(collection_name=collection_name)
    except Exception as e:
        # Best effort: presumably the collection may simply not exist yet.
        # Log instead of silently swallowing so real DB errors stay visible.
        logger.debug(f"Could not drop collection {collection_name} (ignored): {e}")

    try:
        vector_db.create_collection(
            collection_name=collection_name,
            dimension=embedding_fn.dim,
            metric_type=embedding_fn.metric_type,
            id_type="string",
        )
    except Exception as e:
        raise RuntimeError(
            f"Error creating collection for user {collection_name}: {e}, "
            f"embedding_fn.metric_type: {embedding_fn.metric_type}, "
            f"embedding_fn.dim: {embedding_fn.dim}"
        ) from e

    if not memories:
        logger.warning(f"Empty memories for user {collection_name}; skipping mem db construction")
        return

    try:
        # list() materializes the vectors (rows, for a 2-D array) so the
        # count can be checked before pairing them with memories.
        vectors = list(embedding_fn.embed_documents(documents=[mem.content for mem in memories]))
        if len(vectors) != len(memories):
            # zip() below would otherwise silently drop the unmatched memories.
            raise ValueError(
                f"embedding function returned {len(vectors)} vectors for {len(memories)} memories"
            )

        data = [
            {
                "id": mem.memory_id,
                "vector": vector,
                "content": mem.content,
            }
            for mem, vector in zip(memories, vectors)
        ]

        vector_db.insert(collection_name=collection_name, data=data)
    except Exception as e:
        raise RuntimeError(f"Error inserting documents for user {user.user_id}: {e}") from e

    # vector_db.unload_collection(collection_name=collection_name)


def construct_dense_mem_db(
    uri: str,
    users: List[User],
    mem_type: str,
    embedding_fn: BaseEmbeddingFunction,
    vector_db_name: str,
    max_workers: Optional[int] = None,
) -> None:
    """Build one dense-vector collection per user in the vector DB at ``uri``.

    The parent directory of ``uri`` is created first; presumably ``uri``
    points at a file-backed database (TODO confirm against vector_db_factory).
    Each user is then handled by ``construct_dense_mem_db_for_single_user``.

    Args:
        uri: Location of the vector DB, passed to ``vector_db_factory``.
        users: Users whose memories are indexed.
        mem_type: Memory field to index for each user.
        embedding_fn: Dense embedding function used for every user.
        vector_db_name: Backend name passed to ``vector_db_factory``.
        max_workers: Currently ignored; users are processed one at a time.
    """
    Path(uri).parent.mkdir(parents=True, exist_ok=True)

    db = vector_db_factory(vector_db_name=vector_db_name, uri=uri)

    # NOTE(review): processing is sequential; `max_workers` is kept only for
    # interface compatibility until concurrent insertion is re-enabled.
    for each_user in tqdm(users):
        job = (each_user, mem_type, db, embedding_fn)
        construct_dense_mem_db_for_single_user(job)


def construct_mem_db(
    mem_data_path: str,
    mem_db_uri: str,
    mem_type: str,
    embedding_fn_name: str,
    vector_db_name: Optional[str] = None,
    max_workers: Optional[int] = None,
) -> None:
    """Load users from a JSON file and build their memory DB.

    The embedding function named ``embedding_fn_name`` decides the backend:
    sparse names go through ``construct_sparse_mem_db``, and dense names go
    through ``construct_dense_mem_db``, which also receives ``vector_db_name``
    and ``max_workers``.

    Args:
        mem_data_path: JSON file holding a list of user dicts, each passed to
            ``User(**...)``.
        mem_db_uri: Output location for the memory DB.
        mem_type: Memory field of each user to index.
        embedding_fn_name: Name resolved via ``embedding_fn_factory``.
        vector_db_name: Vector DB backend (dense path only).
        max_workers: Forwarded to the dense path.

    Raises:
        ValueError: If ``embedding_fn_name`` is neither a sparse nor a dense
            embedding function name.
    """
    set_log_level()

    records = read_json_file(file_path=mem_data_path)
    users = [User(**record) for record in records]

    embedding_fn = embedding_fn_factory(embedding_fn_name=embedding_fn_name)

    # Arguments shared by both the sparse and dense builders.
    common = {
        "uri": mem_db_uri,
        "users": users,
        "mem_type": mem_type,
        "embedding_fn": embedding_fn,
    }

    if embedding_fn_name in SPARSE_EMBEDDING_FN_NAMES:
        construct_sparse_mem_db(**common)
    elif embedding_fn_name in DENSE_EMBEDDING_FN_NAMES:
        construct_dense_mem_db(
            **common,
            vector_db_name=vector_db_name,
            max_workers=max_workers,
        )
    else:
        raise ValueError(f"Invalid embedding_fn_name: {embedding_fn_name}")

    logger.success("Done!")


if __name__ == "__main__":
    # Command-line entry point: python-fire exposes construct_mem_db's
    # parameters as CLI arguments.
    fire.Fire(component=construct_mem_db)
