from typing import (
    List,
    Optional,
    Tuple,
)

from dotenv import load_dotenv

from src.retriever import BaseRetriever
from src.schema import Document
from src.service.generation_stage_service.question_answering_service import QuestionAnsweringService
from src.service.rag_service.base_rag_service import BaseRAGService


# Load variables from a local .env file into the process environment
# (python-dotenv). This runs once, as an import-time side effect of this module.
# NOTE(review): nothing visible in this module reads os.environ, so this
# presumably serves the retriever / question-answering backends configured
# elsewhere — confirm before moving or removing it.
load_dotenv()


class NaiveRAGService(BaseRAGService):
    """Retrieve-then-generate RAG pipeline with a single retrieval pass.

    ``forward`` queries the injected retriever once, then hands the retrieved
    memory contents, together with caller-supplied documents, to the
    question-answering service. Nothing else (reranking, query rewriting,
    repeated retrieval) happens between the two steps.
    """

    def __init__(
        self,
        retriever: BaseRetriever,
        question_answering_service: QuestionAnsweringService,
    ) -> None:
        """Store the two collaborators the pipeline delegates to.

        Args:
            retriever: Backend queried by :meth:`retrieve`.
            question_answering_service: Backend whose ``generate_answer`` is
                called by :meth:`generate`.
        """
        # NOTE(review): BaseRAGService.__init__ is not called here — confirm the
        # base class needs no initialization of its own.
        self.question_answering_service = question_answering_service
        self.retriever = retriever

    def retrieve(
        self,
        collection_name: str,
        query: str,
    ) -> Tuple[List[str], List[str], List[float]]:
        """Look up memories relevant to ``query`` in ``collection_name``.

        Args:
            collection_name: Name of the collection to search.
            query: Text to retrieve against.

        Returns:
            Whatever ``self.retriever.retrieve`` returns, annotated as a
            ``(contents, ids, scores)`` triple of lists.
        """
        hits = self.retriever.retrieve(
            collection_name=collection_name,
            query=query,
        )
        return hits

    def generate(
        self,
        query: str,
        documents: List[Document],
        ret_memory_contents: List[str],
        answer_element_universe: Optional[List[str]] = None,
    ) -> str:
        """Produce an answer for ``query`` via the question-answering service.

        Args:
            query: The question to answer.
            documents: Caller-supplied documents, passed through unchanged.
            ret_memory_contents: Retrieved memory texts (``forward`` passes the
                contents returned by :meth:`retrieve`).
            answer_element_universe: Optional list forwarded unchanged; its
                meaning is defined by ``QuestionAnsweringService.generate_answer``.

        Returns:
            The string returned by ``generate_answer``.
        """
        qa_kwargs = dict(
            query=query,
            documents=documents,
            ret_memory_contents=ret_memory_contents,
            answer_element_universe=answer_element_universe,
        )
        answer = self.question_answering_service.generate_answer(**qa_kwargs)
        return answer

    def forward(
        self,
        collection_name: str,
        query: str,
        documents: List[Document],
        answer_element_universe: Optional[List[str]] = None,
    ) -> Tuple[str, List[str], List[str], List[float]]:
        """Run retrieval followed by generation for a single query.

        Args:
            collection_name: Collection searched by :meth:`retrieve`.
            query: Question used both for retrieval and for generation.
            documents: Extra documents passed straight to :meth:`generate`.
            answer_element_universe: Optional list forwarded to :meth:`generate`.

        Returns:
            ``(raw_answer, contents, ids, scores)``: the generated answer
            followed by the three values produced by :meth:`retrieve`.
        """
        # Step 1: fetch memories; tuple unpacking requires exactly three items.
        contents, ids, scores = self.retrieve(
            collection_name=collection_name,
            query=query,
        )

        # Step 2: answer from the caller's documents plus the retrieved texts.
        answer = self.generate(
            query=query,
            documents=documents,
            ret_memory_contents=contents,
            answer_element_universe=answer_element_universe,
        )

        return answer, contents, ids, scores
