from typing import (
    List,
    Optional,
    Tuple,
)

from src.retriever import BaseRetriever
from src.schema import Document
from src.service.generation_stage_service.reasoning_graph_execution_service import ReasoningGraphExecutionService
from src.service.generation_stage_service.reasoning_graph_generation_service import ReasoningGraphGenerationService
from src.service.rag_service.base_rag_service import BaseRAGService
from src.service.retrieval_stage_service.memory_reranking_service import MemoryRerankingService


class GraphRAGService(BaseRAGService):
    """RAG pipeline driven by a generated reasoning graph.

    ``forward`` runs the whole pipeline:

    1. Build a reasoning graph for the query.
    2. Retrieve memories for the sub-queries of the graph's non-leaf nodes.
    3. Rerank the pooled results.
    4. Execute the reasoning graph against the reranked memories to get the answer.

    The stand-alone ``retrieve`` / ``generate`` hooks are not supported by this
    service, because retrieval depends on the reasoning graph built inside
    ``forward``.
    """

    def __init__(
        self,
        retriever: BaseRetriever,
        memory_reranking_service: MemoryRerankingService,
        reasoning_graph_generation_service: ReasoningGraphGenerationService,
        reasoning_graph_execution_service: ReasoningGraphExecutionService,
    ) -> None:
        """Store the collaborating services; no other work is done here.

        Args:
            retriever: Queried once per sub-query by ``_retrieve_for_sub_queries``.
            memory_reranking_service: Reranks the pooled retrieval results in ``forward``.
            reasoning_graph_generation_service: Builds the reasoning graph for a query.
            reasoning_graph_execution_service: Executes the reasoning graph to produce the answer.
        """
        self.retriever = retriever
        self.memory_reranking_service = memory_reranking_service
        self.reasoning_graph_generation_service = reasoning_graph_generation_service
        self.reasoning_graph_execution_service = reasoning_graph_execution_service

    def retrieve(
        self,
        collection_name: str,
        query: str,
        documents: List[Document],
    ) -> Tuple[List[str], List[str], List[float]]:
        """Not supported for this service; use ``forward`` instead.

        Retrieval here is driven by the sub-queries of a reasoning graph.
        That graph is only built inside ``forward``.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            f"{type(self).__name__}.retrieve is not implemented; "
            "retrieval depends on the reasoning graph, so call forward() instead."
        )

    def generate(
        self,
        query: str,
        documents: List[Document],
        ret_memory_contents: List[str],
    ) -> str:
        """Not supported for this service; use ``forward`` instead.

        Answer generation here means executing a reasoning graph.
        That graph is only built inside ``forward``.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            f"{type(self).__name__}.generate is not implemented; "
            "generation executes the reasoning graph, so call forward() instead."
        )

    def _retrieve_for_sub_queries(
        self,
        collection_name: str,
        sub_queries: List[str],
    ) -> Tuple[List[str], List[str], List[float]]:
        """Run the retriever once per sub-query and concatenate the results.

        Results are appended in sub-query order with no deduplication. If
        several sub-queries retrieve the same item, it appears more than once.

        Args:
            collection_name: Collection passed through to ``self.retriever.retrieve``.
            sub_queries: Queries to issue; an empty list yields three empty lists.

        Returns:
            ``(contents, ids, scores)`` accumulated across all sub-queries.
            These are presumably aligned index-by-index, as returned by the
            retriever; verify against ``BaseRetriever``.
        """
        ret_items: List[str] = []
        ret_item_ids: List[str] = []
        ret_item_scores: List[float] = []
        for sub_query in sub_queries:
            items, item_ids, item_scores = self.retriever.retrieve(
                collection_name=collection_name,
                query=sub_query,
            )
            ret_items.extend(items)
            ret_item_ids.extend(item_ids)
            ret_item_scores.extend(item_scores)
        return ret_items, ret_item_ids, ret_item_scores

    def forward(
        self,
        collection_name: str,
        query: str,
        documents: List[Document],
        answer_element_universe: Optional[List[str]] = None,
    ) -> Tuple[str, List[str], List[str], List[float]]:
        """Answer ``query`` by generating, then executing, a reasoning graph.

        Args:
            collection_name: Collection the retriever searches for each sub-query.
            query: The query to answer.
            documents: Passed to both reasoning-graph generation and execution.
            answer_element_universe: Optional; forwarded unchanged to both
                reasoning-graph generation and execution.

        Returns:
            ``(raw_answer, ret_memory_contents, ret_memory_ids, ret_memory_scores)``.
            The three memory lists are the output of the reranking service.
        """
        reasoning_graph = self.reasoning_graph_generation_service.generate_reasoning_graph(
            query=query,
            documents=documents,
            answer_element_universe=answer_element_universe,
        )
        # Only the sub-queries of non-leaf nodes are sent to the retriever.
        # NOTE(review): confirm that excluding leaf nodes is intended.
        sub_queries = [node.sub_query for node in reasoning_graph.nodes if not node.is_leaf]
        ret_memory_contents, ret_memory_ids, ret_memory_scores = self._retrieve_for_sub_queries(
            collection_name=collection_name,
            sub_queries=sub_queries,
        )
        ret_memory_contents, ret_memory_ids, ret_memory_scores = self.memory_reranking_service.rerank(
            ret_memory_contents=ret_memory_contents,
            ret_memory_ids=ret_memory_ids,
            ret_memory_scores=ret_memory_scores,
        )
        raw_answer = self.reasoning_graph_execution_service.execute_reasoning_graph(
            query=query,
            documents=documents,
            reasoning_graph=reasoning_graph,
            ret_memory_contents=ret_memory_contents,
            answer_element_universe=answer_element_universe,
        )
        return raw_answer, ret_memory_contents, ret_memory_ids, ret_memory_scores
