from typing import Optional

import fire
from loguru import logger

from src.dataset import dataset_factory
from src.retriever import retriever_factory
from src.generator import generator_factory
from src.vector_db.embedding_fn import embedding_fn_factory
from src.service.document_service import DocumentService
from src.service.memory_service import MemoryService
from src.service.retrieval_service import RetrievalService
from src.service.sub_query_service import SubQueryService
from src.service.pers_graph_service import PersGraphService
from src.utils.json import write_jsonl_file
from src.utils.log import set_log_level


def retrieve(
    retrieval_path: str,
    dataset_name: str,
    qa_data_path: str,
    doc_data_path: str,
    mem_data_path: str,
    mem_type: str,
    method_type: str,
    retrieval_strategy: str,
    retriever_name: str,
    doc_db_uri: str,
    doc_embedding_fn_name: str,
    doc_top_k: int,
    mem_db_uri: str,
    mem_embedding_fn_name: str,
    mem_top_k: int,
    vector_db_name: Optional[str] = None,
    generator_name: Optional[str] = None,
    generator_temperature: Optional[float] = None,
    batch_size: Optional[int] = None,
    shuffle: bool = False,
    num_workers: int = 8,
):
    """Wire up the retrieval pipeline, run it, and dump the results as JSONL.

    The function builds a data loader over the QA dataset, one retriever for
    documents and one for memories, a generator shared by the sub-query and
    personal-graph services, and finally a ``RetrievalService`` that ties
    them together. The retrieval results are serialized with
    ``model_dump(mode="json")`` and written to ``retrieval_path``.

    Args:
        retrieval_path: Output path handed to ``write_jsonl_file``.
        dataset_name: Name forwarded to ``dataset_factory``.
        qa_data_path: Data path forwarded to ``dataset_factory``.
        doc_data_path: Path given to ``DocumentService``.
        mem_data_path: Path given to ``MemoryService``.
        mem_type: One of ``dialogue``, ``observation``, ``summary``,
            ``episodic_memory``.
        method_type: One of ``rag``, ``prompt``.
        retrieval_strategy: One of ``naive``, ``gen_sq``, ``oracle_sq``,
            ``gen_pg``, ``oracle_pg``.
        retriever_name: Retriever to build, unless the matching embedding
            function name is ``"gt"`` or ``"no"`` (see below).
        doc_db_uri: DB URI for the document retriever.
        doc_embedding_fn_name: Embedding function name for documents. When it
            is ``"gt"`` or ``"no"``, that same name is used as the document
            retriever name instead of ``retriever_name``.
        doc_top_k: ``top_k`` for the document retriever.
        mem_db_uri: DB URI for the memory retriever.
        mem_embedding_fn_name: Embedding function name for memories; handled
            like ``doc_embedding_fn_name`` for the memory retriever.
        mem_top_k: ``top_k`` for the memory retriever; also forwarded to
            ``RetrievalService``.
        vector_db_name: Forwarded to both retriever factories.
        generator_name: Generator to build. The literal string ``"none"`` is
            replaced with ``"gpt-4o"``; any other value (including ``None``)
            is passed to ``generator_factory`` unchanged.
        generator_temperature: Temperature forwarded to ``generator_factory``.
        batch_size: Forwarded to ``dataset.get_data_loader``.
        shuffle: Forwarded to ``dataset.get_data_loader``.
        num_workers: Forwarded to ``dataset.get_data_loader`` and to
            ``RetrievalService.retrieve``.
    """
    set_log_level()

    # --- Data ---------------------------------------------------------------
    qa_dataset = dataset_factory(dataset_name=dataset_name, data_path=qa_data_path)
    loader = qa_dataset.get_data_loader(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )

    # Embedding-function names that double as the retriever name.
    special_names = ("gt", "no")

    # --- Document retriever -------------------------------------------------
    doc_embedder = embedding_fn_factory(embedding_fn_name=doc_embedding_fn_name)
    if doc_embedding_fn_name in special_names:
        doc_retriever_name = doc_embedding_fn_name
    else:
        doc_retriever_name = retriever_name
    doc_retriever = retriever_factory(
        retriever_name=doc_retriever_name,
        db_uri=doc_db_uri,
        embedding_fn=doc_embedder,
        vector_db_name=vector_db_name,
        top_k=doc_top_k,
    )

    # --- Memory retriever ---------------------------------------------------
    mem_embedder = embedding_fn_factory(embedding_fn_name=mem_embedding_fn_name)
    if mem_embedding_fn_name in special_names:
        mem_retriever_name = mem_embedding_fn_name
    else:
        mem_retriever_name = retriever_name
    mem_retriever = retriever_factory(
        retriever_name=mem_retriever_name,
        db_uri=mem_db_uri,
        embedding_fn=mem_embedder,
        vector_db_name=vector_db_name,
        top_k=mem_top_k,
    )

    # --- Generator ----------------------------------------------------------
    # Only the string "none" triggers the fallback; None is passed through.
    if generator_name == "none":
        effective_generator_name = "gpt-4o"
    else:
        effective_generator_name = generator_name
    llm = generator_factory(
        generator_name=effective_generator_name,
        temperature=generator_temperature,
    )

    # --- Services -----------------------------------------------------------
    documents = DocumentService(doc_data_path=doc_data_path)
    memories = MemoryService(mem_data_path=mem_data_path)
    sub_queries = SubQueryService(generator=llm)
    pers_graphs = PersGraphService(generator=llm)
    service = RetrievalService(
        data_loader=loader,
        doc_service=documents,
        mem_service=memories,
        doc_retriever=doc_retriever,
        mem_retriever=mem_retriever,
        mem_type=mem_type,
        method_type=method_type,
        retrieval_strategy=retrieval_strategy,
        sub_query_service=sub_queries,
        pers_graph_service=pers_graphs,
        mem_top_k=mem_top_k,
    )

    # --- Run and persist ----------------------------------------------------
    results = service.retrieve(num_workers=num_workers)
    serialized = [result.model_dump(mode="json") for result in results]
    write_jsonl_file(file_path=retrieval_path, data=serialized)
    logger.success("Done!")


if __name__ == "__main__":
    # Expose `retrieve` as a command-line interface; Fire maps CLI flags
    # onto the function's parameters.
    fire.Fire(component=retrieve)
