from pathlib import Path
from typing import Optional

import fire

from loguru import logger

from src.dataset import dataset_factory
from src.generator import generator_factory
from src.generator.cached_generator import maybe_apply_cache
from src.retriever import retriever_factory
from src.service.answer_separation_service import AnswerSeparationService
from src.service.document_service import DocumentService
from src.service.inference_service import InferenceService
from src.service.rag_service import rag_service_factory
from src.utils.json import write_json_file
from src.utils.log import set_log_level
from src.vector_db.embedding_fn import embedding_fn_factory


def _build_generator(generator_name: str, temperature: float, cache_dir: Optional[Path]):
    """Create a generator with ``generator_factory`` and pass it through ``maybe_apply_cache``.

    Args:
        generator_name: Name forwarded to ``generator_factory``.
        temperature: Temperature forwarded to ``generator_factory``.
        cache_dir: Cache directory forwarded to ``maybe_apply_cache`` (may be ``None``).

    Returns:
        Whatever ``maybe_apply_cache`` returns for the freshly built generator.
    """
    generator = generator_factory(generator_name=generator_name, temperature=temperature)
    return maybe_apply_cache(generator, cache_dir)


def infer(
    rag_service_name: str,
    inference_results_path: str,
    dataset_name: str,
    qa_data_path: str,
    document_data_path: str,
    retriever_name: str,
    memory_db_uri: str,
    embedding_fn_name: str,
    generator_pri_name: str,
    top_k: Optional[int] = None,
    vector_db_name: Optional[str] = None,
    generator_pub_name: Optional[str] = None,
    generator_subquery_name: Optional[str] = None,
    generator_pri_temperature: float = 0.0,
    generator_pub_temperature: float = 0.0,
    generator_subquery_temperature: float = 0.0,
    generator_cache_dir: Optional[str] = None,
    batch_size: Optional[int] = None,
    shuffle: bool = False,
    infer_mode: str = "both",
    num_workers: int = 8,
    use_answer_separation: bool = True,
    default_document: str = "",
) -> None:
    """Assemble the RAG pipeline, run inference, and write the results as JSON.

    Args:
        rag_service_name: Name forwarded to ``rag_service_factory``.
        inference_results_path: File path the serialized results are written to.
        dataset_name: Name forwarded to ``dataset_factory``.
        qa_data_path: Passed to ``dataset_factory`` as ``data_path``.
        document_data_path: Passed to ``DocumentService``.
        retriever_name: Name forwarded to ``retriever_factory``.
        memory_db_uri: Forwarded to ``retriever_factory``.
        embedding_fn_name: Name forwarded to ``embedding_fn_factory``.
        generator_pri_name: Name of the ``generator_pri`` generator (always built).
        top_k: Forwarded to both ``retriever_factory`` and ``rag_service_factory``.
        vector_db_name: Forwarded to ``retriever_factory``.
        generator_pub_name: Name of the optional ``generator_pub`` generator;
            when falsy, ``generator_pub`` is ``None``.
        generator_subquery_name: Name of the optional sub-query generator; when
            falsy, ``generator_pub`` is reused, or ``generator_pri`` if there is
            no ``generator_pub``.
        generator_pri_temperature: Temperature for ``generator_pri``.
        generator_pub_temperature: Temperature for ``generator_pub``.
        generator_subquery_temperature: Temperature for the sub-query generator.
        generator_cache_dir: Optional directory handed to ``maybe_apply_cache``
            (as a ``Path``); a falsy value is passed on as ``None``.
        batch_size: Forwarded to ``dataset.get_data_loader``.
        shuffle: Forwarded to ``dataset.get_data_loader``.
        infer_mode: Forwarded to ``InferenceService.infer``.
        num_workers: Forwarded to both the data loader and ``InferenceService.infer``.
        use_answer_separation: When true, an ``AnswerSeparationService`` backed by
            a "gpt-4o" generator at temperature 0.0 is attached; otherwise
            ``None`` is passed and that generator is not built.
        default_document: Forwarded to ``rag_service_factory``.
    """
    set_log_level()

    dataset = dataset_factory(dataset_name=dataset_name, data_path=qa_data_path)
    embedding_fn = embedding_fn_factory(embedding_fn_name=embedding_fn_name)
    retriever = retriever_factory(
        retriever_name=retriever_name,
        memory_db_uri=memory_db_uri,
        embedding_fn=embedding_fn,
        vector_db_name=vector_db_name,
        top_k=top_k,
    )

    # Keep the str parameter untouched and use a separate, correctly typed local.
    cache_dir = Path(generator_cache_dir) if generator_cache_dir else None

    generator_pri = _build_generator(generator_pri_name, generator_pri_temperature, cache_dir)
    generator_pub = (
        _build_generator(generator_pub_name, generator_pub_temperature, cache_dir)
        if generator_pub_name
        else None
    )
    # Sub-query generator: use the explicitly named one when given; otherwise fall
    # back to generator_pub, and to generator_pri when generator_pub is unavailable.
    if generator_subquery_name:
        generator_subquery = _build_generator(
            generator_subquery_name, generator_subquery_temperature, cache_dir
        )
    else:
        generator_subquery = generator_pub or generator_pri

    rag_service = rag_service_factory(
        rag_service_name=rag_service_name,
        retriever=retriever,
        generator_pri=generator_pri,
        generator_pub=generator_pub,
        generator_subquery=generator_subquery,
        top_k=top_k,
        default_document=default_document,
    )

    document_service = DocumentService(document_data_path=document_data_path)

    # Only build the gpt-4o generator when answer separation actually needs it,
    # so runs with use_answer_separation=False do not construct an unused generator.
    answer_separation_service = None
    if use_answer_separation:
        generator_4o_greedy = _build_generator("gpt-4o", 0.0, cache_dir)
        answer_separation_service = AnswerSeparationService(generator=generator_4o_greedy)

    data_loader = dataset.get_data_loader(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )
    inference_service = InferenceService(
        rag_service=rag_service,
        document_service=document_service,
        answer_separation_service=answer_separation_service,
        data_loader=data_loader,
    )

    inference_results = inference_service.infer(infer_mode=infer_mode, num_workers=num_workers)
    serialized_results = [res.model_dump(mode="json") for res in inference_results]

    # Guard the sample log: indexing [0] on an empty result list would raise
    # IndexError and abort before the (empty) results file is written.
    if serialized_results:
        logger.info(f"[Inference Result Sample]\n{serialized_results[0]}")
    else:
        logger.warning("Inference produced no results; writing an empty result list.")
    write_json_file(file_path=inference_results_path, data=serialized_results)


if __name__ == "__main__":
    # Script entry point: hand `infer` to python-fire so it can be invoked
    # from the command line.
    fire.Fire(infer)
