from pathlib import Path
from typing import (
    List,
    Tuple,
)

import numpy as np

from src.retriever.base_retriever import BaseRetriever
from src.utils.json import read_json_file
from src.vector_db.embedding_fn import BaseEmbeddingFunction


class SparseRetriever(BaseRetriever):
    """Retrieve stored memories ranked by dot-product score against a query.

    On construction, three files are read from the ``memory_db_uri`` directory:

    * ``embedding_fn_params.json`` -- passed to ``embedding_fn.load``.
    * ``embedding_map.npz`` -- loaded into a dict mapping each archive key
      (used here as a collection name) to its array of key embeddings.
    * ``memory_map.json`` -- read via ``read_json_file``; indexed by collection
      name and then by position, each entry providing ``"content"`` and
      ``"memory_id"``.
    """

    def __init__(self, memory_db_uri: str, embedding_fn: BaseEmbeddingFunction, top_k: int) -> None:
        """Load the embedding function parameters and the stored maps.

        Args:
            memory_db_uri: Directory containing the three files described in
                the class docstring.
            embedding_fn: Embedding function; its ``load`` method is called with
                the path of ``embedding_fn_params.json``.
            top_k: Maximum number of items returned by :meth:`retrieve`.
        """
        self.memory_db_uri = memory_db_uri
        self.embedding_fn = embedding_fn
        self.top_k = top_k

        db_dir = Path(memory_db_uri)
        embedding_fn_params_path = db_dir / "embedding_fn_params.json"
        embedding_map_path = db_dir / "embedding_map.npz"
        memory_map_path = db_dir / "memory_map.json"

        self.embedding_fn.load(str(embedding_fn_params_path))
        # np.load on an .npz returns a lazily-read archive that holds an open
        # file handle; copy every array into a plain dict inside the context
        # manager so the handle is released once loading is done.
        with np.load(str(embedding_map_path)) as archive:
            self.embedding_map = dict(archive)
        self.memory_map = read_json_file(file_path=str(memory_map_path))

    def retrieve(self, collection_name: str, query: str) -> Tuple[List[str], List[str], List[float]]:
        """Return the highest-scoring memories of a collection for ``query``.

        Args:
            collection_name: Key that must be present in both ``embedding_map``
                and ``memory_map``.
            query: Query text, embedded with ``embedding_fn.embed_queries``.

        Returns:
            A tuple ``(contents, memory_ids, scores)`` ordered from highest to
            lowest score and holding at most ``top_k`` items. When ``top_k`` is
            0 all three lists are empty.

        Raises:
            AssertionError: If ``collection_name`` is missing from either map.
        """
        # Raised explicitly (rather than via `assert`) so the checks still run
        # under `python -O`; the exception type and messages are unchanged.
        if collection_name not in self.embedding_map:
            raise AssertionError(f"Collection '{collection_name}' not found in embedding_map.")
        if collection_name not in self.memory_map:
            raise AssertionError(f"Collection '{collection_name}' not found in memory_map.")

        # NOTE(review): the shape returned by embed_queries must be compatible
        # with np.dot(key_embeddings, ...) -- confirm against the embedding fn.
        query_embedding = self.embedding_fn.embed_queries([query])
        key_embeddings = self.embedding_map[collection_name]
        scores = np.dot(key_embeddings, query_embedding)

        # Reverse the ascending argsort first and then take the leading top_k
        # entries. This matches `[-top_k:][::-1]` for every non-zero top_k but,
        # unlike that form, yields nothing for top_k == 0 (`[-0:]` is the whole
        # array).
        top_k_indices = np.argsort(scores)[::-1][: self.top_k]

        memories = self.memory_map[collection_name]
        ret_items = [memories[i]["content"] for i in top_k_indices]
        ret_item_ids = [memories[i]["memory_id"] for i in top_k_indices]
        ret_item_scores = [scores[i] for i in top_k_indices]

        return ret_items, ret_item_ids, ret_item_scores
