from typing import Optional

from src.generator import BaseGenerator
from src.retriever import BaseRetriever
from src.service.generation_stage_service.cot_execution_service import COTExecutionService
from src.service.generation_stage_service.cot_generation_service import COTGenerationService
from src.service.generation_stage_service.question_answering_service import QuestionAnsweringService
from src.service.generation_stage_service.reasoning_graph_execution_service import ReasoningGraphExecutionService
from src.service.generation_stage_service.reasoning_graph_generation_service import ReasoningGraphGenerationService
from src.service.rag_service.advanced_rag_service import AdvancedRAGService
from src.service.rag_service.base_rag_service import BaseRAGService
from src.service.rag_service.cot_rag_service import COTRAGService
from src.service.rag_service.graph_rag_service import GraphRAGService
from src.service.rag_service.naive_rag_service import NaiveRAGService
from src.service.retrieval_stage_service.memory_reranking_service import MemoryRerankingService
from src.service.retrieval_stage_service.sub_query_generation_service import SubQueryGenerationService


# Public API of this module: the factory function and the base service type it returns.
__all__ = [
    "rag_service_factory",
    "BaseRAGService",
]


def rag_service_factory(
    rag_service_name: str,
    retriever: BaseRetriever,
    generator_pri: BaseGenerator,
    generator_pub: Optional[BaseGenerator],
    generator_subquery: BaseGenerator,
    top_k: int,
    default_document: str,
) -> BaseRAGService:
    """Build the RAG service selected by ``rag_service_name``.

    Args:
        rag_service_name: One of ``"naive"``, ``"advanced"``, ``"graph"``, ``"cot"``
            or ``"cot_without_subquery"``.
        retriever: Retriever handed to the constructed service.
        generator_pri: Private generator. Used for question answering ("naive",
            "advanced"), reasoning-graph execution ("graph") and CoT execution
            ("cot", "cot_without_subquery").
        generator_pub: Public generator. Required by the dual-system services
            ("graph", "cot", "cot_without_subquery"), which use it for
            reasoning-graph / CoT generation. Services that don't use the dual
            system (e.g. "naive", "advanced") may pass ``None``.
        generator_subquery: Generator for sub-query generation; only used by
            "advanced" and "cot".
        top_k: Forwarded to ``MemoryRerankingService`` (every service except "naive").
        default_document: Forwarded to the generation-stage services that accept it.

    Returns:
        The constructed RAG service.

    Raises:
        ValueError: If ``rag_service_name`` is not recognized, or if a dual-system
            service is requested with ``generator_pub=None``.
    """
    dual_system_services = ("graph", "cot", "cot_without_subquery")
    if rag_service_name in dual_system_services and generator_pub is None:
        # Fail fast rather than passing None into the graph/CoT generation services.
        raise ValueError(
            f"rag_service_name={rag_service_name!r} uses the private/public dual system "
            "and requires a non-None generator_pub"
        )

    if rag_service_name == "naive":
        question_answering_service = QuestionAnsweringService(
            generator=generator_pri, default_document=default_document
        )
        return NaiveRAGService(retriever=retriever, question_answering_service=question_answering_service)

    if rag_service_name == "advanced":
        sub_query_generation_service = SubQueryGenerationService(
            generator=generator_subquery, default_document=default_document
        )
        memory_reranking_service = MemoryRerankingService(top_k=top_k)
        question_answering_service = QuestionAnsweringService(
            generator=generator_pri, default_document=default_document
        )
        return AdvancedRAGService(
            retriever=retriever,
            sub_query_generation_service=sub_query_generation_service,
            memory_reranking_service=memory_reranking_service,
            question_answering_service=question_answering_service,
        )

    if rag_service_name == "graph":
        memory_reranking_service = MemoryRerankingService(top_k=top_k)
        # Public generator builds the reasoning graph; private generator executes it.
        reasoning_graph_generation_service = ReasoningGraphGenerationService(
            generator=generator_pub,
            default_document=default_document,
        )
        reasoning_graph_execution_service = ReasoningGraphExecutionService(
            generator=generator_pri,
            default_document=default_document,
        )
        return GraphRAGService(
            retriever=retriever,
            memory_reranking_service=memory_reranking_service,
            reasoning_graph_generation_service=reasoning_graph_generation_service,
            reasoning_graph_execution_service=reasoning_graph_execution_service,
        )

    if rag_service_name in ("cot", "cot_without_subquery"):
        # "cot_without_subquery" is the same pipeline with the sub-query stage disabled.
        sub_query_generation_service = (
            SubQueryGenerationService(generator=generator_subquery, default_document=default_document)
            if rag_service_name == "cot"
            else None
        )
        memory_reranking_service = MemoryRerankingService(top_k=top_k)
        # Public generator writes the chain of thought; private generator executes it.
        cot_generation_service = COTGenerationService(generator=generator_pub, default_document=default_document)
        cot_execution_service = COTExecutionService(generator=generator_pri)
        return COTRAGService(
            retriever=retriever,
            sub_query_generation_service=sub_query_generation_service,
            memory_reranking_service=memory_reranking_service,
            cot_generation_service=cot_generation_service,
            cot_execution_service=cot_execution_service,
        )

    raise ValueError(f"Invalid rag_service_name: {rag_service_name}")
