#!/usr/bin/python3
"""
Knowledge base for retrieving relevant prior knowledge for a query.

Author(s):
    Anonymized Authors @anonymized-authors

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
from __future__ import annotations
import abc
import os
from llama_index.core import (
    Document,
    Settings,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage
)
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.embeddings import BaseEmbedding
from pathlib import Path
from typing import Any, Dict, Final, List, Optional, Set, Union


class BioEntity(str):
    """
    A string whose value is an entity's official name and which also
    compares equal to any of the entity's synonyms.

    NOTE: equality is synonym-overlap based and ignores the entity type,
    whereas the hash covers the entity type and the full synonym set. Two
    entities (or an entity and a plain string) can therefore compare equal
    while hashing differently, so set/dict lookups are not guaranteed to
    find "equal" keys.
    """

    _entity_type: str
    _synonyms: Set[str]

    def __new__(
        cls, entity_type: str, name: str, synonyms: Optional[List[str]] = None
    ):
        """
        Args:
            entity_type: the type of the entity.
            name: the official name of the entity.
            synonyms: optional synonyms of the entity.
        """
        obj = super().__new__(cls, name)
        obj._entity_type = entity_type
        # The official name always counts as one of the synonyms.
        obj._synonyms = {name, *(synonyms or [])}
        return obj

    def __getnewargs__(self):
        """
        Supplies the arguments used to rebuild the entity when it is
        copied or pickled.
        Input:
            None.
        Returns:
            The (entity_type, name, synonyms) tuple expected by __new__.
        """
        # str.__getnewargs__ only provides the string value, which does not
        # match the two required arguments of __new__ and makes
        # copy/deepcopy/pickle fail with a TypeError.
        return (self._entity_type, str(self), sorted(self._synonyms))

    def __eq__(self, other: object) -> bool:
        """
        Defines equality of entities.
        Input:
            other: another entity (or plain string) to compare against.
        Returns:
            Whether the two entities share at least one synonym, or whether
            the plain string is one of this entity's synonyms. For any other
            type, NotImplemented is returned so that Python falls back to
            its default comparison (i.e., the objects are unequal).
        """
        synonyms: Set[str] = getattr(self, "_synonyms", set())
        if isinstance(other, BioEntity):
            return not synonyms.isdisjoint(getattr(other, "_synonyms", set()))
        if isinstance(other, str):
            return other in synonyms
        return NotImplemented

    def __ne__(self, other: object) -> bool:
        """
        Defines inequality of entities as the negation of equality.
        Input:
            other: another entity (or plain string) to compare against.
        Returns:
            Whether the two entities are not equal.
        """
        # Without this override `!=` resolves to str.__ne__, which compares
        # the raw names and can contradict the synonym-aware __eq__.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        """
        Defines the hash function for entities.
        Input:
            None.
        Returns:
            The hash value of the entity, computed from the entity type
            followed by the sorted synonyms.
        """
        return hash((self._entity_type, *sorted(self._synonyms)))


class KnowledgeBase(abc.ABC):
    """
    Abstract interface for a queryable source of prior knowledge.

    Concrete subclasses implement `retrieve` together with the two
    class-level description hooks.
    """

    def __init__(self, top_k: int, **kwargs: Dict[str, Any]):
        """
        Args:
            top_k: the number of documents to retrieve per query.
            kwargs: extra keyword arguments; accepted and discarded here.
        """
        self.top_k: int = top_k
        # Extra keyword arguments are not used by the base class.
        del kwargs

    @abc.abstractmethod
    def retrieve(self, query: str) -> str:
        """
        Retrieves the most relevant document(s) for a given query.
        Input:
            query: a string query.
        Returns:
            A string containing the retrieved document(s).
        """
        raise NotImplementedError()

    @classmethod
    @abc.abstractmethod
    def knowledge_description(
        cls, *args: Any, **kwargs: Dict[str, Any]
    ) -> str:
        """
        Describes the knowledge held by this knowledge base.
        Input:
            Implementation-defined positional and keyword arguments.
        Returns:
            The description of the knowledge base.
        """
        raise NotImplementedError()

    @classmethod
    @abc.abstractmethod
    def query_description(cls) -> str:
        """
        Describes the kind of query this knowledge base expects.
        Input:
            None.
        Returns:
            The description of the expected query type.
        """
        raise NotImplementedError()


class RAGKnowledgeBase(KnowledgeBase):
    """
    Knowledge base backed by a vector index over the documents returned by
    `get_knowledge`, with the index cached on local disk.
    """

    def __init__(
        self,
        embedder: BaseEmbedding,
        top_k: int,
        cache_dir: Union[Path, str],
        **kwargs: Dict[str, Any]
    ):
        """
        Args:
            embedder: the embedding function to use for retrieval.
            top_k: the number of documents to retrieve per query.
            cache_dir: a local directory to cache the knowledge locally. The
                index lives under `<cache_dir>/<embedder.model_name>/index`.
        """
        # The embedder is installed on the global llama_index Settings
        # before the index is built or loaded below.
        Settings.embed_model = embedder
        super().__init__(top_k, **kwargs)

        model_name = embedder.model_name
        self.embedder: Final[str] = model_name
        # One cache sub-directory per embedding model name.
        self.cache_dir: Final[str] = os.path.join(str(cache_dir), model_name)
        os.makedirs(self.cache_dir, exist_ok=True)

        self._retriever = self._build_retriever()

    def retrieve(self, query: str) -> str:
        """
        Retrieves the most relevant document(s) for a given query.
        Input:
            query: a string query.
        Returns:
            The text of each retrieved node, joined by newlines. Nodes
            without a `text` attribute contribute an empty string.
        """
        hits = self._retriever.retrieve(query)
        return "\n".join(getattr(hit.node, "text", "") for hit in hits)

    def _build_retriever(self) -> VectorIndexRetriever:
        """
        Builds a retriever from the specified knowledge source.
        Input:
            None.
        Returns:
            A retriever over the index, configured to return `top_k`
            results. The index is loaded from `<cache_dir>/index` when that
            directory exists; otherwise it is built from `get_knowledge()`
            and persisted to that directory.
        """
        index_dir = os.path.join(self.cache_dir, "index")
        if not os.path.isdir(index_dir):
            # No cached index yet: build it and persist it for later runs.
            index = VectorStoreIndex.from_documents(self.get_knowledge())
            index.storage_context.persist(persist_dir=index_dir)
        else:
            storage = StorageContext.from_defaults(persist_dir=index_dir)
            index = load_index_from_storage(storage)  # type: ignore

        return VectorIndexRetriever(index=index, similarity_top_k=self.top_k)

    @abc.abstractmethod
    def get_knowledge(self) -> List[Document]:
        """
        Returns the documents that make up the knowledge base.
        Input:
            None.
        Returns:
            A list of the manuscript texts from the knowledge base.
        """
        raise NotImplementedError()
