from typing import (
    List,
    Optional,
    Tuple,
)

from src.retriever import BaseRetriever
from src.schema import Document
from src.service.generation_stage_service.cot_execution_service import COTExecutionService
from src.service.generation_stage_service.cot_generation_service import COTGenerationService
from src.service.rag_service.base_rag_service import BaseRAGService
from src.service.retrieval_stage_service.memory_reranking_service import MemoryRerankingService
from src.service.retrieval_stage_service.sub_query_generation_service import SubQueryGenerationService


class COTRAGService(BaseRAGService):
    """RAG pipeline that retrieves memories, reranks them, and answers via a chain of thought.

    Retrieval either sends the query straight to the retriever or, when a
    ``SubQueryGenerationService`` is configured, fans out over the sub-queries
    it produces. Every retrieval result is passed through the memory reranking
    service. Generation first builds a chain of thought (CoT) from the query and
    documents, then executes that CoT against the retrieved memory contents.
    """

    def __init__(
        self,
        retriever: BaseRetriever,
        sub_query_generation_service: Optional[SubQueryGenerationService],
        memory_reranking_service: MemoryRerankingService,
        cot_generation_service: COTGenerationService,
        cot_execution_service: COTExecutionService,
    ) -> None:
        """Store the collaborating services.

        Args:
            retriever: Queried with a collection name and a query string.
            sub_query_generation_service: Optional query decomposer. When
                ``None``, the original query is sent to the retriever as-is.
            memory_reranking_service: Applied to every retrieval result
                before ``retrieve`` returns it.
            cot_generation_service: Produces the chain of thought used by
                ``generate``.
            cot_execution_service: Executes the chain of thought against the
                retrieved memory contents to produce the raw answer.
        """
        self.retriever = retriever
        self.sub_query_generation_service = sub_query_generation_service
        self.memory_reranking_service = memory_reranking_service
        self.cot_generation_service = cot_generation_service
        self.cot_execution_service = cot_execution_service

    def _retrieve_for_sub_queries(
        self,
        collection_name: str,
        sub_queries: List[str],
    ) -> Tuple[List[str], List[str], List[float]]:
        """Run the retriever once per sub-query and concatenate the results.

        Results are appended in sub-query order. Items returned by more than
        one sub-query are NOT de-duplicated here.

        Args:
            collection_name: Collection to search.
            sub_queries: Queries to issue, one retriever call each.

        Returns:
            ``(contents, ids, scores)`` as three parallel lists.
        """
        ret_items: List[str] = []
        ret_item_ids: List[str] = []
        ret_item_scores: List[float] = []
        for sub_query in sub_queries:
            items, item_ids, item_scores = self.retriever.retrieve(
                collection_name=collection_name,
                query=sub_query,
            )
            ret_items.extend(items)
            ret_item_ids.extend(item_ids)
            ret_item_scores.extend(item_scores)
        return ret_items, ret_item_ids, ret_item_scores

    def retrieve(
        self,
        collection_name: str,
        query: str,
        documents: List[Document],
        answer_element_universe: Optional[List[str]] = None,
    ) -> Tuple[List[str], List[str], List[float]]:
        """Retrieve memories for ``query`` and rerank them.

        If a sub-query generation service is configured, it is given the
        query, documents and ``answer_element_universe``, and the retriever is
        called once per returned sub-query. If no service is configured, or it
        returns no sub-queries, the original query is used instead.

        Args:
            collection_name: Collection to search.
            query: The user query.
            documents: Forwarded to sub-query generation.
            answer_element_universe: Optional list forwarded to sub-query
                generation (presumably the candidate answer set; confirm).

        Returns:
            ``(contents, ids, scores)`` as produced by the reranking service.
        """
        sub_queries: Optional[List[str]] = None
        if self.sub_query_generation_service is not None:
            sub_queries = self.sub_query_generation_service.generate_sub_queries(
                query=query,
                documents=documents,
                answer_element_universe=answer_element_universe,
            )

        if sub_queries:
            ret_memory_contents, ret_memory_ids, ret_memory_scores = self._retrieve_for_sub_queries(
                collection_name=collection_name,
                sub_queries=sub_queries,
            )
        else:
            # Either no decomposer is configured, or it produced nothing. Fall
            # back to the original query: an empty sub-query list would
            # otherwise retrieve zero memories and starve generation of context.
            ret_memory_contents, ret_memory_ids, ret_memory_scores = self.retriever.retrieve(
                collection_name=collection_name,
                query=query,
            )

        # Unpack explicitly so the return value is always a 3-tuple.
        ret_memory_contents, ret_memory_ids, ret_memory_scores = self.memory_reranking_service.rerank(
            ret_memory_contents=ret_memory_contents,
            ret_memory_ids=ret_memory_ids,
            ret_memory_scores=ret_memory_scores,
        )
        return ret_memory_contents, ret_memory_ids, ret_memory_scores

    def generate(
        self,
        query: str,
        documents: List[Document],
        ret_memory_contents: List[str],
        answer_element_universe: Optional[List[str]] = None,
    ) -> str:
        """Build a chain of thought and execute it to produce the raw answer.

        The CoT is generated from ``query``, ``documents`` and
        ``answer_element_universe``. It is then executed with ``query``,
        ``ret_memory_contents`` and ``answer_element_universe``. ``documents``
        is not passed to the execution step.

        Args:
            query: The user query.
            documents: Input to CoT generation.
            ret_memory_contents: Retrieved memory texts used during CoT
                execution.
            answer_element_universe: Optional list forwarded to both services.

        Returns:
            The raw answer string returned by the CoT execution service.
        """
        cot = self.cot_generation_service.generate_cot(
            query=query,
            documents=documents,
            answer_element_universe=answer_element_universe,
        )
        raw_answer = self.cot_execution_service.execute_cot(
            query=query,
            ret_memory_contents=ret_memory_contents,
            answer_element_universe=answer_element_universe,
            cot=cot,
        )
        return raw_answer

    def forward(
        self,
        collection_name: str,
        query: str,
        documents: List[Document],
        answer_element_universe: Optional[List[str]] = None,
    ) -> Tuple[str, List[str], List[str], List[float]]:
        """Run retrieval followed by generation.

        Args:
            collection_name: Collection to search.
            query: The user query.
            documents: Passed to both ``retrieve`` and ``generate``.
            answer_element_universe: Optional list passed to both stages.

        Returns:
            ``(raw_answer, memory_contents, memory_ids, memory_scores)``, where
            the memory lists are the reranked retrieval results.
        """
        ret_memory_contents, ret_memory_ids, ret_memory_scores = self.retrieve(
            collection_name=collection_name,
            query=query,
            documents=documents,
            answer_element_universe=answer_element_universe,
        )
        raw_answer = self.generate(
            query=query,
            documents=documents,
            ret_memory_contents=ret_memory_contents,
            answer_element_universe=answer_element_universe,
        )
        return raw_answer, ret_memory_contents, ret_memory_ids, ret_memory_scores
