import random

from pathlib import Path
from typing import (
    List,
    Tuple,
)

import orjson

from src.retriever.base_retriever import BaseRetriever


class GTRetriever(BaseRetriever):
    """Retriever that returns ground-truth reference memories instead of searching.

    Reference memories are read from ``qa.json`` and looked up by
    ``(user_id, question)``; every returned memory is given a score of 1.0.
    """

    def __init__(
        self,
        memory_db_uri: str,
    ) -> None:
        """Load the ground-truth references from ``qa.json``.

        Args:
            memory_db_uri: Path of the memory DB. ``qa.json`` is looked up at
                ``<parent of memory_db_uri's parent>/json_data/qa.json``.

        Raises:
            FileNotFoundError: if ``qa.json`` does not exist.
            KeyError: if an entry lacks ``question``, ``answer_sets``,
                ``user_id``, ``ref_memory_contents`` or ``ref_memory_ids``.
        """
        qa_json_path = Path(memory_db_uri).parent.parent / "json_data" / "qa.json"
        # Read raw bytes: orjson decodes UTF-8 itself, which avoids text-mode
        # decoding with the platform's locale-dependent default encoding.
        qas = orjson.loads(qa_json_path.read_bytes())

        # Both maps are keyed by (user_id, question). If the same pair occurs
        # more than once in qa.json, the last occurrence wins.
        self.gt_memory_ids_by_user_id_query = {}
        self.gt_memory_contents_by_user_id_query = {}
        for qa in qas:
            query = qa["question"]
            for answer_set in qa["answer_sets"]:
                key = (answer_set["user_id"], query)
                self.gt_memory_contents_by_user_id_query[key] = answer_set["ref_memory_contents"]
                self.gt_memory_ids_by_user_id_query[key] = answer_set["ref_memory_ids"]

    def retrieve(self, collection_name: str, query: str) -> Tuple[List[str], List[str], List[float]]:
        """Return the ground-truth memories recorded for ``(collection_name, query)``.

        Args:
            collection_name: Used as the user id the references were stored under.
            query: The original question text, matched exactly. Sub-queries
                will not be found.

        Returns:
            ``(contents, ids, scores)``: copies of the reference contents and
            ids, plus a score of 1.0 for each id.

        Raises:
            KeyError: if no ground truth exists for this user id and query.
            ValueError: if the stored contents and ids differ in length.
        """
        key = (collection_name, query)
        try:
            contents = self.gt_memory_contents_by_user_id_query[key]
            ids = self.gt_memory_ids_by_user_id_query[key]
        except KeyError:
            raise KeyError(
                f"No ground-truth memories for user_id={collection_name!r}, "
                f"query={query!r}; GTRetriever requires the original query, not a sub-query."
            ) from None
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(contents) != len(ids):
            raise ValueError(
                f"Ground truth for user_id={collection_name!r}, query={query!r} has "
                f"{len(contents)} contents but {len(ids)} ids."
            )
        # Return copies so callers mutating the result cannot corrupt the stored ground truth.
        return list(contents), list(ids), [1.0] * len(ids)
