import hashlib
import os

from pathlib import Path
from typing import (
    List,
    Tuple,
)

import numpy as np

from dotenv import load_dotenv
from loguru import logger

from src.retriever.base_retriever import BaseRetriever
from src.vector_db import BaseVectorDB
from src.vector_db.embedding_fn import BaseEmbeddingFunction


# Load variables from a .env file into os.environ at import time (python-dotenv),
# so that settings such as ENABLE_QUERY_EMB_CACHE (read in DenseRetriever.__init__
# via os.getenv) can be supplied through a .env file.
load_dotenv()


class DenseRetriever(BaseRetriever):
    """Retrieve documents by embedding a query and searching a vector database.

    When the environment variable ``ENABLE_QUERY_EMB_CACHE`` equals ``"true"``
    (case-insensitive), query embeddings are memoized in memory and persisted as
    ``<sha256(query)>.npy`` files in a ``query_embedding_cache`` directory located
    next to ``vector_db.uri``. Persisted entries are loaded when the retriever is
    constructed.
    """

    def __init__(
        self,
        embedding_fn: BaseEmbeddingFunction,
        vector_db: BaseVectorDB,
        top_k: int,
    ) -> None:
        """Create the retriever.

        Args:
            embedding_fn: Object whose ``embed_queries`` turns a list of query
                strings into embeddings.
            vector_db: Vector store searched by ``retrieve``; its ``uri`` also
                determines where the on-disk embedding cache lives.
            top_k: Value passed as ``limit`` to ``vector_db.search``.
        """
        self.embedding_fn = embedding_fn
        self.vector_db = vector_db
        self.top_k = top_k

        self.enable_query_cache = os.getenv("ENABLE_QUERY_EMB_CACHE", "false").lower() == "true"
        self.query_embedding_cache_dir = Path(self.vector_db.uri).parent / "query_embedding_cache"
        self.query_embedding_cache_files: List[Path] = []
        # Maps sha256(query) -> embedding batch (see _embed_query).
        self.query_embedding_cache = {}
        # Only touch the filesystem when the cache is actually enabled: a disabled
        # cache should neither create directories (which can fail, e.g. on a
        # read-only location) nor read every cached file at startup.
        if self.enable_query_cache:
            self._load_query_embedding_cache()

    def _load_query_embedding_cache(self) -> None:
        """Create the cache directory and load every ``*.npy`` entry into memory."""
        self.query_embedding_cache_dir.mkdir(parents=True, exist_ok=True)
        self.query_embedding_cache_files = list(self.query_embedding_cache_dir.glob("*.npy"))
        for cache_file in self.query_embedding_cache_files:
            try:
                self.query_embedding_cache[cache_file.stem] = np.load(cache_file)
            except (OSError, ValueError, EOFError) as err:
                # A truncated/corrupt entry should cost one re-embedding, not make
                # the retriever impossible to construct. (np.load raises EOFError
                # for an empty file and ValueError for malformed contents.)
                logger.warning("Skipping unreadable query embedding cache file {}: {}", cache_file, err)

    def _save_query_embedding(self, query_hash: str, query_embedding) -> None:
        """Persist one embedding as ``<query_hash>.npy``; failures only log a warning."""
        final_path = self.query_embedding_cache_dir / f"{query_hash}.npy"
        # Write to a sibling temp file (its name does not match the "*.npy" glob)
        # and atomically rename it into place, so an interrupted write can never
        # leave a half-written entry that a later load would trip over.
        tmp_path = final_path.with_name(f"{final_path.name}.{os.getpid()}.tmp")
        try:
            # Passing a file object keeps np.save from appending another ".npy".
            with open(tmp_path, "wb") as f:
                np.save(f, query_embedding)
            os.replace(tmp_path, final_path)
        except OSError as err:
            # The embedding is already computed and kept in memory; a failed
            # cache write must not abort retrieval.
            logger.warning("Could not persist query embedding cache file {}: {}", final_path, err)
            try:
                tmp_path.unlink()
            except OSError:
                pass  # temp file may never have been created

    def _hash_query(self, query: str) -> str:
        """Return the hex SHA-256 digest of ``query`` (UTF-8), used as the cache key."""
        return hashlib.sha256(query.encode("utf-8")).hexdigest()

    def _embed_query(self, query: str):
        """Embed ``query`` as a one-element batch, using the cache when enabled.

        NOTE(review): entries loaded from disk are ``np.ndarray`` while fresh
        entries keep whatever ``embed_queries`` returns; confirm
        ``vector_db.search`` accepts both. Also, the cache key covers only the
        query text (not the embedding model), so reusing a cache directory with
        a different ``embedding_fn`` would serve stale vectors.
        """
        if not self.enable_query_cache:
            return self.embedding_fn.embed_queries([query])

        query_hash = self._hash_query(query)
        if query_hash in self.query_embedding_cache:
            return self.query_embedding_cache[query_hash]

        logger.info("cache miss")
        query_embedding = self.embedding_fn.embed_queries([query])
        self.query_embedding_cache[query_hash] = query_embedding
        self._save_query_embedding(query_hash, query_embedding)
        return query_embedding

    def retrieve(self, collection_name: str, query: str) -> Tuple[List[str], List[str], List[float]]:
        """Search ``collection_name`` for the hits closest to ``query``.

        Args:
            collection_name: Collection passed to ``vector_db.search``.
            query: Text to embed and search for.

        Returns:
            Three parallel lists, one element per hit in the order returned by
            ``vector_db.search``: each hit's ``entity["content"]``,
            ``entity["id"]`` and ``"distance"``.
        """
        query_embedding = self._embed_query(query)

        search_results = self.vector_db.search(
            collection_name=collection_name,
            data=query_embedding,
            limit=self.top_k,
            output_fields=["id", "content"],
        )
        # Exactly one query was sent, so only the first result list is relevant.
        search_result = search_results[0]

        ret_items, ret_item_ids, ret_item_scores = [], [], []
        for item in search_result:
            ret_items.append(item["entity"]["content"])
            ret_item_ids.append(item["entity"]["id"])
            ret_item_scores.append(item["distance"])

        return ret_items, ret_item_ids, ret_item_scores
