from typing import (
    List,
    Optional,
)
from concurrent.futures import ThreadPoolExecutor

from tqdm import tqdm
from torch.utils.data import DataLoader

from src.retriever import BaseRetriever
from src.service.document_service import DocumentService
from src.service.memory_service import MemoryService
from src.service.sub_query_service import SubQueryService
from src.service.pers_graph_service import PersGraphService
from src.schema import (
    QA,
    PredictedAnswer,
    Prediction,
)


class RetrievalService:
    """Retrieve supporting documents and per-user memories for QA items.

    For every QA item yielded by ``data_loader``, documents are retrieved once
    from the shared ``"all"`` collection using the question. Then, for each
    answer of the item, a memory context is built for that answer's user. How
    the memory context is built depends on ``method_type``:

    * ``"prompt"``: the user's full pre-built prompt for ``mem_type`` is read
      from ``mem_service``; no memory ids or scores are produced.
    * ``"rag"``: memories are retrieved from the collection named by the
      answer's ``user_id``, according to ``retrieval_strategy``
      (see ``_retrieve_rag``).
    """

    # mem_type -> name of the MemoryService attribute holding per-user prompts.
    _PROMPT_MAP_ATTRS = {
        "dialogue": "dialogue_prompt_map",
        "observation": "observation_prompt_map",
        "summary": "summary_prompt_map",
        "episodic_memory": "episodic_memory_prompt_map",
    }
    _METHOD_TYPES = ("rag", "prompt")
    _RETRIEVAL_STRATEGIES = ("naive", "gen_sq", "oracle_sq", "gen_pg", "oracle_pg")

    def __init__(
        self,
        data_loader: DataLoader,
        doc_service: DocumentService,
        mem_service: MemoryService,
        doc_retriever: BaseRetriever,
        mem_retriever: BaseRetriever,
        mem_type: str,
        method_type: str,
        retrieval_strategy: str,
        sub_query_service: SubQueryService,
        pers_graph_service: PersGraphService,
        mem_top_k: int,
    ) -> None:
        """Store collaborators and validate the configuration strings.

        Args:
            data_loader: source of QA items iterated by ``retrieve`` (assumed
                to yield ``QA`` objects directly, e.g. via a custom
                collate_fn; TODO confirm).
            doc_service: stored on the instance; not used by this class.
            mem_service: provides the ``*_prompt_map`` dicts for "prompt" mode.
            doc_retriever: queried with ``collection_name="all"``.
            mem_retriever: queried with ``collection_name=<answer.user_id>``.
            mem_type: one of "dialogue", "observation", "summary",
                "episodic_memory".
            method_type: "rag" or "prompt".
            retrieval_strategy: one of "naive", "gen_sq", "oracle_sq",
                "gen_pg", "oracle_pg".
            sub_query_service: generates sub-queries for "gen_sq".
            pers_graph_service: generates a graph of sub-queries for "gen_pg".
            mem_top_k: number of memories kept after merging sub-query results.

        Raises:
            ValueError: if mem_type, method_type or retrieval_strategy is not
                a supported value.
        """
        # Explicit checks rather than ``assert`` so validation survives ``python -O``.
        if mem_type not in self._PROMPT_MAP_ATTRS:
            raise ValueError(f"Invalid mem_type: {mem_type}")
        if method_type not in self._METHOD_TYPES:
            raise ValueError(f"Invalid method_type: {method_type}")
        if retrieval_strategy not in self._RETRIEVAL_STRATEGIES:
            raise ValueError(f"Invalid retrieval_strategy: {retrieval_strategy}")

        self.data_loader = data_loader
        self.doc_service = doc_service
        self.mem_service = mem_service
        self.doc_retriever = doc_retriever
        self.mem_retriever = mem_retriever
        self.mem_type = mem_type
        self.method_type = method_type
        self.retrieval_strategy = retrieval_strategy
        self.sub_query_service = sub_query_service
        self.pers_graph_service = pers_graph_service
        self.top_k = mem_top_k

    # ------------------------------------------------------------------ helpers

    def _retrieve_docs(self, qa: QA):
        """Retrieve documents for ``qa.question`` from the shared "all" collection.

        Returns the retriever's ``(contents, ids, scores)`` triple unchanged.
        """
        return self.doc_retriever.retrieve(collection_name="all", query=qa.question)

    def _retrieve_mem_by_sub_queries(self, user_id, sub_queries):
        """Retrieve memories for each sub-query, merge them and keep the top-k.

        Each memory id appears at most once in the merged result, keeping its
        highest score across sub-queries. Without this, a memory hit by several
        sub-queries would occupy several of the ``self.top_k`` slots. Higher
        scores are treated as more relevant (the ranking sorts descending).

        Returns:
            ``(contents, ids, scores)`` lists of at most ``self.top_k`` items.
        """
        # memory id -> (content, best score); dict keeps first-seen order, so
        # ties are broken by the order in which memories were first retrieved.
        best = {}
        for sub_query in sub_queries:
            contents, ids, scores = self.mem_retriever.retrieve(
                collection_name=user_id, query=sub_query
            )
            for content, mem_id, score in zip(contents, ids, scores):
                if mem_id not in best or score > best[mem_id][1]:
                    best[mem_id] = (content, score)

        ranked = sorted(best.items(), key=lambda item: item[1][1], reverse=True)[: self.top_k]
        ret_mem_ids = [mem_id for mem_id, _ in ranked]
        ret_mem_contents = [content for _, (content, _) in ranked]
        ret_mem_scores = [score for _, (_, score) in ranked]
        return ret_mem_contents, ret_mem_ids, ret_mem_scores

    @staticmethod
    def _flatten_pers_graph(pers_graph):
        """Concatenate the sub-query collections stored as the values of ``pers_graph``."""
        return [sq for sub_queries in pers_graph.values() for sq in sub_queries]

    @staticmethod
    def _make_pred_answer(answer, docs, mems, **extra) -> PredictedAnswer:
        """Build a ``PredictedAnswer`` from document and memory retrieval triples.

        Args:
            answer: one element of ``qa.answers``; its ``answer_id`` and
                ``user_id`` are copied over.
            docs: ``(contents, ids, scores)`` for the retrieved documents.
            mems: ``(contents, ids, scores)`` for the retrieved memories.
            **extra: additional ``PredictedAnswer`` fields (e.g.
                ``gen_sub_queries``, ``gen_pers_graph``).
        """
        doc_contents, doc_ids, doc_scores = docs
        mem_contents, mem_ids, mem_scores = mems
        return PredictedAnswer(
            answer_id=answer.answer_id,
            user_id=answer.user_id,
            ret_document_ids=doc_ids,
            ret_document_contents=doc_contents,
            ret_document_scores=doc_scores,
            ret_memory_ids=mem_ids,
            ret_memory_contents=mem_contents,
            ret_memory_scores=mem_scores,
            **extra,
        )

    # --------------------------------------------------------------- strategies

    def _retrieve_prompt(self, qa: QA) -> Prediction:
        """Pair the question's documents with each user's full memory prompt.

        The memory context for an answer is a one-element list holding the
        user's prompt from ``mem_service.<mem_type>_prompt_map``. No memory
        ids or scores are produced.

        Raises:
            ValueError: if ``self.mem_type`` is not a supported value.
        """
        docs = self._retrieve_docs(qa)
        attr = self._PROMPT_MAP_ATTRS.get(self.mem_type)
        if attr is None:
            raise ValueError(f"Invalid mem_type: {self.mem_type}")
        prompt_map = getattr(self.mem_service, attr)
        pred_answers = [
            self._make_pred_answer(answer, docs, ([prompt_map[answer.user_id]], None, None))
            for answer in qa.answers
        ]
        return Prediction(qa=qa, pred_answers=pred_answers)

    def _retrieve_rag_naive(self, qa: QA) -> Prediction:
        """Retrieve each user's memories with the original question as the query."""
        docs = self._retrieve_docs(qa)
        pred_answers = [
            self._make_pred_answer(
                answer,
                docs,
                self.mem_retriever.retrieve(collection_name=answer.user_id, query=qa.question),
            )
            for answer in qa.answers
        ]
        return Prediction(qa=qa, pred_answers=pred_answers)

    def _retrieve_rag_gen_sq(self, qa: QA) -> Prediction:
        """Retrieve memories with generated sub-queries shared by all answers."""
        docs = self._retrieve_docs(qa)
        # Sub-queries depend only on the QA item, so generate them once, not per answer.
        sub_queries = self.sub_query_service.generate(context=qa.context, question=qa.question)
        pred_answers = [
            self._make_pred_answer(
                answer,
                docs,
                self._retrieve_mem_by_sub_queries(answer.user_id, sub_queries),
                gen_sub_queries=sub_queries,
            )
            for answer in qa.answers
        ]
        return Prediction(qa=qa, pred_answers=pred_answers)

    def _retrieve_rag_oracle_sq(self, qa: QA) -> Prediction:
        """Retrieve memories with each answer's own ``oracle_sub_queries``."""
        docs = self._retrieve_docs(qa)
        pred_answers = [
            self._make_pred_answer(
                answer,
                docs,
                self._retrieve_mem_by_sub_queries(answer.user_id, answer.oracle_sub_queries),
            )
            for answer in qa.answers
        ]
        return Prediction(qa=qa, pred_answers=pred_answers)

    def _retrieve_rag_gen_pg(self, qa: QA) -> Prediction:
        """Retrieve memories with sub-queries from a generated graph shared by all answers."""
        docs = self._retrieve_docs(qa)
        # The graph depends only on the QA item, so generate it once, not per answer.
        pers_graph = self.pers_graph_service.generate(context=qa.context, question=qa.question)
        sub_queries = self._flatten_pers_graph(pers_graph)
        pred_answers = [
            self._make_pred_answer(
                answer,
                docs,
                self._retrieve_mem_by_sub_queries(answer.user_id, sub_queries),
                gen_pers_graph=pers_graph,
            )
            for answer in qa.answers
        ]
        return Prediction(qa=qa, pred_answers=pred_answers)

    def _retrieve_rag_oracle_pg(self, qa: QA) -> Prediction:
        """Retrieve memories with sub-queries from each answer's ``oracle_pers_graph``."""
        docs = self._retrieve_docs(qa)
        pred_answers = [
            self._make_pred_answer(
                answer,
                docs,
                self._retrieve_mem_by_sub_queries(
                    answer.user_id, self._flatten_pers_graph(answer.oracle_pers_graph)
                ),
            )
            for answer in qa.answers
        ]
        return Prediction(qa=qa, pred_answers=pred_answers)

    # ----------------------------------------------------------------- dispatch

    def _retrieve_rag(self, qa: QA) -> Prediction:
        """Dispatch to the RAG strategy selected by ``self.retrieval_strategy``.

        Raises:
            ValueError: if the strategy is not supported.
        """
        handlers = {
            "naive": self._retrieve_rag_naive,
            "gen_sq": self._retrieve_rag_gen_sq,
            "oracle_sq": self._retrieve_rag_oracle_sq,
            "gen_pg": self._retrieve_rag_gen_pg,
            "oracle_pg": self._retrieve_rag_oracle_pg,
        }
        handler = handlers.get(self.retrieval_strategy)
        if handler is None:
            raise ValueError(f"Invalid retrieval_strategy: {self.retrieval_strategy}")
        return handler(qa)

    def _retrieve_single(self, qa: QA) -> Prediction:
        """Produce the ``Prediction`` for one QA item according to ``method_type``.

        Raises:
            ValueError: if ``self.method_type`` is not supported.
        """
        if self.method_type == "prompt":
            return self._retrieve_prompt(qa)
        if self.method_type == "rag":
            return self._retrieve_rag(qa)
        raise ValueError(f"Invalid method_type: {self.method_type}")

    def retrieve(self, num_workers: int = 8) -> List[Prediction]:
        """Run retrieval for every item of ``self.data_loader`` sequentially.

        Args:
            num_workers: currently unused. It is kept so existing callers that
                pass it keep working. Items are processed one at a time.

        Returns:
            One ``Prediction`` per item, in data-loader order.
        """
        return [self._retrieve_single(qa) for qa in tqdm(self.data_loader)]
