from concurrent.futures import ThreadPoolExecutor
from typing import (
    Any,
    List,
    Optional,
)

from torch.utils.data import DataLoader
from tqdm import tqdm

from src.schema import (
    QA,
    InferenceResult,
    PredictedAnswerSet,
)
from src.service.answer_separation_service import AnswerSeparationService
from src.service.document_service import DocumentService
from src.service.rag_service.base_rag_service import BaseRAGService
from src.utils.set import union_of_sets


class InferenceService:
    """Run RAG inference (retrieval, generation, or both) over a dataset of QAs.

    Each ``QA`` carries one or more answer sets. Every answer set is answered
    independently, using its ``user_id`` as the RAG collection name.
    """

    # Accepted values for ``infer_mode``.
    _INFER_MODES = ("both", "ret", "gen")

    def __init__(
        self,
        rag_service: BaseRAGService,
        document_service: DocumentService,
        answer_separation_service: Optional[AnswerSeparationService],
        data_loader: DataLoader,
    ) -> None:
        """
        Args:
            rag_service: Backend whose ``forward`` / ``retrieve`` / ``generate``
                methods perform the actual RAG calls.
            document_service: Resolves a QA's ``document_ids`` into documents.
            answer_separation_service: Splits a raw generated answer into answer
                elements. May be ``None``; the raw answer is then used as the
                single predicted element.
            data_loader: Source of ``QA`` items consumed by :meth:`infer`.
                NOTE(review): assumed to yield one ``QA`` per iteration (e.g.
                ``batch_size=None`` or an identity collate) -- confirm.
        """
        self.rag_service = rag_service
        self.document_service = document_service
        self.answer_separation_service = answer_separation_service
        self.data_loader = data_loader

    @classmethod
    def _check_infer_mode(cls, infer_mode: str) -> None:
        """Reject unknown ``infer_mode`` values.

        Raises:
            AssertionError: if ``infer_mode`` is not one of ``_INFER_MODES``.
                Raised explicitly rather than via ``assert`` so the check is
                not stripped under ``python -O``; the exception type is kept
                for backward compatibility.
        """
        if infer_mode not in cls._INFER_MODES:
            raise AssertionError(
                f"You can only choose from ['both', 'ret', 'gen'], but got {infer_mode}"
            )

    def _possible_answers(
        self,
        answer_sets: List[Any],
        answer_element_universe: Optional[List[str]],
    ) -> Any:
        """Return the candidate elements the answer separator may choose from.

        Uses ``answer_element_universe`` when it is truthy; otherwise the union
        of every answer set's ``elements``. Returns ``None`` when no separation
        service is configured or there are no answer sets, because the value
        would never be used in either case.
        """
        if self.answer_separation_service is None or not answer_sets:
            return None
        return answer_element_universe or union_of_sets(
            [answer_set.elements for answer_set in answer_sets]
        )

    def _separate(self, raw_answer: Any, possible_answers: Any) -> Any:
        """Split ``raw_answer`` into predicted elements.

        Without a separation service, the whole raw answer becomes the only
        element: ``{raw_answer}``.
        """
        if self.answer_separation_service is None:
            return {raw_answer}
        return self.answer_separation_service.separate(
            raw_answer=raw_answer,
            possible_answers=possible_answers,
        )

    def infer_both(
        self,
        question: str,
        documents: List[str],
        answer_sets: List[Any],
        answer_element_universe: Optional[List[str]] = None,
    ) -> List[PredictedAnswerSet]:
        """Run retrieval and generation together via ``rag_service.forward``.

        Args:
            question: Query text.
            documents: Documents resolved for the QA.
            answer_sets: Answer-set objects; ``user_id``, ``answer_set_id`` and
                ``elements`` are read from each.
            answer_element_universe: Optional closed set of valid answer
                elements, forwarded to the RAG service and used as separation
                candidates when truthy.

        Returns:
            One ``PredictedAnswerSet`` per input answer set, in input order.
        """
        # Same for every answer set, so compute it once instead of per iteration.
        possible_answers = self._possible_answers(answer_sets, answer_element_universe)

        pred_answer_sets = []
        for answer_set in answer_sets:
            user_id = answer_set.user_id
            raw_answer, ret_memory_contents, ret_memory_ids, ret_memory_scores = self.rag_service.forward(
                collection_name=user_id,
                query=question,
                documents=documents,
                answer_element_universe=answer_element_universe,
            )
            pred_answer_sets.append(
                PredictedAnswerSet(
                    answer_set_id=answer_set.answer_set_id,
                    user_id=user_id,
                    raw_answer=raw_answer,
                    elements=self._separate(raw_answer, possible_answers),
                    ref_memory_ids=ret_memory_ids,
                    ref_memory_contents=ret_memory_contents,
                    ref_memory_scores=ret_memory_scores,
                )
            )
        return pred_answer_sets

    def infer_ret(
        self,
        question: str,
        documents: List[str],
        answer_sets: List[Any],
    ) -> List[PredictedAnswerSet]:
        """Run retrieval only via ``rag_service.retrieve``.

        The returned ``PredictedAnswerSet`` objects carry only the retrieved
        memories (ids, contents, scores); no answer is generated.

        Returns:
            One ``PredictedAnswerSet`` per input answer set, in input order.
        """
        pred_answer_sets = []
        for answer_set in answer_sets:
            user_id = answer_set.user_id
            ret_memory_contents, ret_memory_ids, ret_memory_scores = self.rag_service.retrieve(
                collection_name=user_id,
                query=question,
                documents=documents,
            )
            pred_answer_sets.append(
                PredictedAnswerSet(
                    answer_set_id=answer_set.answer_set_id,
                    user_id=user_id,
                    ref_memory_ids=ret_memory_ids,
                    ref_memory_contents=ret_memory_contents,
                    ref_memory_scores=ret_memory_scores,
                )
            )
        return pred_answer_sets

    def infer_gen(
        self,
        question: str,
        documents: List[str],
        answer_sets: List[Any],
        answer_element_universe: Optional[List[str]] = None,
    ) -> List[PredictedAnswerSet]:
        """Run generation only, using each answer set's stored memories.

        Retrieval is skipped. The answer set's own ``ret_memory_ids`` and
        ``ret_memory_contents`` are fed to ``rag_service.generate`` and echoed
        back as the referenced memories.

        Returns:
            One ``PredictedAnswerSet`` per input answer set, in input order.
        """
        # Same for every answer set, so compute it once instead of per iteration.
        possible_answers = self._possible_answers(answer_sets, answer_element_universe)

        pred_answer_sets = []
        for answer_set in answer_sets:
            gt_ret_memory_ids = answer_set.ret_memory_ids
            gt_ret_memory_contents = answer_set.ret_memory_contents
            raw_answer = self.rag_service.generate(
                query=question,
                documents=documents,
                ret_memory_contents=gt_ret_memory_contents,
                answer_element_universe=answer_element_universe,
            )
            pred_answer_sets.append(
                PredictedAnswerSet(
                    answer_set_id=answer_set.answer_set_id,
                    user_id=answer_set.user_id,
                    raw_answer=raw_answer,
                    elements=self._separate(raw_answer, possible_answers),
                    ref_memory_ids=gt_ret_memory_ids,
                    ref_memory_contents=gt_ret_memory_contents,
                )
            )
        return pred_answer_sets

    def infer_single(
        self,
        qa: QA,
        infer_mode: str = "both",
    ) -> InferenceResult:
        """Run inference for one QA.

        Args:
            qa: The question, its answer sets, document ids and optional answer
                element universe.
            infer_mode: ``"both"`` (retrieve + generate), ``"ret"`` (retrieve
                only) or ``"gen"`` (generate from stored memories).

        Raises:
            AssertionError: on an unknown ``infer_mode``.
        """
        self._check_infer_mode(infer_mode)

        question = qa.question
        answer_sets = qa.answer_sets
        answer_element_universe = qa.answer_element_universe

        documents = self.document_service.get_documents(qa.document_ids)

        if infer_mode == "both":
            pred_answer_sets = self.infer_both(
                question=question,
                documents=documents,
                answer_sets=answer_sets,
                answer_element_universe=answer_element_universe,
            )
        elif infer_mode == "ret":
            pred_answer_sets = self.infer_ret(
                question=question,
                documents=documents,
                answer_sets=answer_sets,
            )
        else:  # "gen" -- the only mode left after validation
            pred_answer_sets = self.infer_gen(
                question=question,
                documents=documents,
                answer_sets=answer_sets,
                # Previously omitted, which made "gen" mode raise TypeError.
                answer_element_universe=answer_element_universe,
            )

        return InferenceResult(qa=qa, pred_answer_sets=pred_answer_sets)

    def infer(
        self,
        infer_mode: str = "both",
        num_workers: int = 8,
    ) -> List[InferenceResult]:
        """Run :meth:`infer_single` over every item of ``data_loader`` in parallel.

        Args:
            infer_mode: See :meth:`infer_single`.
            num_workers: Maximum number of worker threads.

        Returns:
            One ``InferenceResult`` per item, in ``data_loader`` order
            (``Executor.map`` preserves input order).

        Raises:
            AssertionError: on an unknown ``infer_mode``; checked before any work
                is scheduled.
        """
        self._check_infer_mode(infer_mode)

        with ThreadPoolExecutor(max_workers=num_workers) as pool:
            results = list(
                tqdm(
                    pool.map(
                        lambda qa: self.infer_single(qa, infer_mode=infer_mode),
                        self.data_loader,
                    ),
                    total=len(self.data_loader),
                )
            )

        return results
