from typing import Optional

import fire
from loguru import logger

from src.dataset import dataset_factory
from src.batch_processor import batch_processor_factory
from src.service.document_service import DocumentService
from src.service.memory_service import MemoryService
from src.service.generation_service import GenerationService
from src.schema import Prediction
from src.utils.json import (
    read_jsonl_file,
    write_jsonl_file
)
from src.utils.log import set_log_level


def gen_build_req(
    retrieval_path: str,
    request_path: str,
    dataset_name: str,
    qa_data_path: str,
    doc_data_path: str,
    mem_data_path: str,
    generator_name: str,
    generation_strategy: str,  # [naive, gen_sq, oracle_sq, gen_pg, oracle_pg]
    api_type: str,  # [s, p, b]
    generator_temperature: float = 0.0,
    semantic_memory_enabled: bool = False,
    batch_size: Optional[int] = None,
    shuffle: bool = False,
    num_workers: int = 8,
) -> None:
    """Build generation requests from retrieval output and save them.

    Reads the retrieval records, wraps each one in a ``Prediction``, lets the
    batch processor selected by ``generator_name`` / ``api_type`` turn them
    into requests, and writes those requests to ``request_path``.

    Args:
        retrieval_path: File read with ``read_jsonl_file``; every record is
            expanded into a ``Prediction``.
        request_path: Destination passed to ``write_jsonl_file``.
        dataset_name: Dataset identifier given to ``dataset_factory``.
        qa_data_path: Data path given to ``dataset_factory``.
        doc_data_path: Data path given to ``DocumentService``.
        mem_data_path: Data path given to ``MemoryService``.
        generator_name: Generator identifier given to
            ``batch_processor_factory``.
        generation_strategy: Forwarded to ``build_request``; the signature
            comment lists naive, gen_sq, oracle_sq, gen_pg and oracle_pg.
        api_type: Given to ``batch_processor_factory``; the signature comment
            lists s, p and b.
        generator_temperature: Forwarded to ``build_request``.
        semantic_memory_enabled: Forwarded to ``build_request``.
        batch_size: Forwarded to the dataset's ``get_data_loader``.
        shuffle: Forwarded to the dataset's ``get_data_loader``.
        num_workers: Forwarded to the dataset's ``get_data_loader``.
    """
    set_log_level()

    # One Prediction per record found in the retrieval file.
    preds = [Prediction(**record) for record in read_jsonl_file(retrieval_path)]

    qa_dataset = dataset_factory(dataset_name=dataset_name, data_path=qa_data_path)
    loader = qa_dataset.get_data_loader(
        batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
    )

    # Services the batch processor receives, keyed by the keyword argument
    # name that build_request expects for each of them.
    services = {
        "doc_service": DocumentService(doc_data_path=doc_data_path),
        "mem_service": MemoryService(mem_data_path=mem_data_path),
        "generation_service": GenerationService(),
    }

    processor = batch_processor_factory(
        generator_name=generator_name, api_type=api_type
    )
    built_requests = processor.build_request(
        data_loader=loader,
        predictions=preds,
        **services,
        generation_strategy=generation_strategy,
        generator_temperature=generator_temperature,
        semantic_memory_enabled=semantic_memory_enabled,
    )

    write_jsonl_file(request_path, built_requests)
    logger.success(f"Done!")


if __name__ == "__main__":
    # Run only when executed as a script: hand gen_build_req to python-fire
    # so that it is invoked from the command line.
    fire.Fire(gen_build_req)
