from pathlib import Path
from typing import (
    List,
    Tuple,
)

from src.retriever.base_retriever import BaseRetriever
from src.utils.io import load_qa_dataset


class GTRetriever(BaseRetriever):
    """Retriever that returns the ground-truth reference memories of a query.

    Nothing is searched: every answer in the QA dataset carries the memories
    it references, and those are indexed once at construction time by
    ``(user_id, question)`` and handed back verbatim by :meth:`retrieve`.
    """

    # NOTE(review): BaseRetriever.__init__ is not invoked here; confirm that
    # the base class needs no initialisation.
    def __init__(
        self,
        db_uri: str,
    ) -> None:
        """Load the QA dataset and index its reference memories.

        Args:
            db_uri: Path of the database. It is only used to locate the
                dataset, which is read from ``json_data/qa.json`` under the
                directory two levels above ``db_uri``.
        """
        qa_json_path = Path(db_uri).parent.parent / "json_data" / "qa.json"
        qa_dataset = load_qa_dataset(qa_json_path)

        # Both mappings share the key (user_id, question). If the dataset
        # contains the same pair more than once, the last occurrence wins.
        self.gt_memory_ids_by_user_id_query = {}
        self.gt_memory_contents_by_user_id_query = {}
        for qa_sample in qa_dataset:
            query = qa_sample.question
            for answer in qa_sample.answers:
                key = (answer.user_id, query)
                self.gt_memory_contents_by_user_id_query[key] = answer.ref_memory_contents
                self.gt_memory_ids_by_user_id_query[key] = answer.ref_memory_ids

    def retrieve(self, collection_name: str, query: str) -> Tuple[List[str], List[str], List[float]]:
        """Return the ground-truth memories recorded for a query.

        Args:
            collection_name: Looked up as the ``user_id`` of the dataset
                answers.
            query: The original question exactly as it appears in the QA
                dataset. Sub-queries derived from it will not match.

        Returns:
            A tuple ``(contents, ids, scores)`` of equally long lists. Every
            score is ``1.0``.

        Raises:
            KeyError: If the dataset has no entry for
                ``(collection_name, query)``.
            ValueError: If the stored contents and ids differ in length.
        """
        key = (collection_name, query)
        try:
            contents = self.gt_memory_contents_by_user_id_query[key]
            ids = self.gt_memory_ids_by_user_id_query[key]
        except KeyError:
            raise KeyError(
                f"No ground-truth memories for collection {collection_name!r} "
                f"and query {query!r}; the query must be the original "
                "question, not a sub-query."
            ) from None

        # Raised explicitly rather than asserted so the check survives
        # `python -O`, which strips assert statements.
        if len(contents) != len(ids):
            raise ValueError(
                f"Ground-truth data for collection {collection_name!r} and "
                f"query {query!r} is inconsistent: {len(contents)} contents "
                f"vs {len(ids)} ids."
            )

        # Return copies so callers that modify the result in place cannot
        # corrupt the cached ground truth used by later calls.
        return list(contents), list(ids), [1.0] * len(ids)
