#!/usr/bin/python3
"""
PrimeKG Knowledge Graph as a Knowledge Base.

Author(s):
    Anonymized Authors @anonymized-authors

Citation(s):
    [1] Chandak P, Huang K, Zitnik M. Building a knowledge graph to
        enable precision medicine. Sci Data 10(67). (2023). doi:
        10.1038/s41597-023-01960-3

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import os
import pandas as pd
import subprocess
from llama_index.core import Document
from llama_index.core.embeddings import BaseEmbedding
from pathlib import Path
from typing import Any, Dict, Final, List, Union

from .base import RAGKnowledgeBase


class PrimeKGKnowledgeBase(RAGKnowledgeBase):
    """
    Knowledge base over the drug -> disease edges of PrimeKG [1].

    The unique drug names and disease names are indexed as retrievable
    documents. A query is resolved to its nearest node, and the relations from
    that node to nodes retrieved for the configured target are reported as
    short natural-language sentences.
    """

    # Harvard Dataverse download URL for the PrimeKG CSV (cached as kg.csv).
    url: str = "https://dataverse.harvard.edu/api/access/datafile/6180620"

    def __init__(
        self,
        embedder: BaseEmbedding,
        top_k: int,
        target: str,
        cache_dir: Union[Path, str] = (
            Path.home() / ".cache" / "leon" / "primekg"
        ),
        **kwargs: Any
    ):
        """
        Args:
            embedder: the embedding function to use for retrieval.
            top_k: the number of documents to retrieve per query.
            target: the target to relate queries to; it is used as the
                retrieval query for candidate target nodes.
            cache_dir: a local directory to cache the knowledge locally.
            kwargs: ignored; accepted for interface compatibility.
        """
        del kwargs
        os.makedirs(str(cache_dir), exist_ok=True)
        cache_path = os.path.join(str(cache_dir), "kg.csv")
        if not os.path.exists(cache_path):
            # `wget -O` creates its output file even when the download fails
            # or is interrupted. Download to a side file and rename only on
            # success, so a truncated kg.csv is never mistaken for a valid
            # cache on later runs. A stale .part file is simply overwritten.
            partial_path = cache_path + ".part"
            subprocess.run(
                ["wget", self.url, "-O", partial_path], check=True
            )
            os.replace(partial_path, cache_path)
        graph = pd.read_csv(cache_path, dtype={"x_id": str, "y_id": str})
        # Keep only drug -> disease edges and the columns retrieve() reads.
        graph = graph.query("y_type=='disease' & x_type == 'drug'")
        self.graph = graph[["display_relation", "x_name", "y_name"]]
        # Retrievable corpus: sorted unique drug (source) names followed by
        # sorted unique disease (target-side) names.
        self.nodes = sorted(set(self.graph["x_name"]))
        self.nodes.extend(sorted(set(self.graph["y_name"])))
        self.target: Final[str] = target
        super().__init__(
            embedder=embedder, top_k=top_k, cache_dir=cache_dir
        )

    def retrieve(self, query: str) -> str:
        """
        Retrieves the most relevant relationship(s) associated with a query.
        Input:
            query: a string query.
        Returns:
            Up to `top_k` newline-separated sentences of the form
            "A {relation} of {source} is {target}.", or a fallback message
            when no matching edge exists.
        """
        # Resolve the query to its nearest node in the corpus.
        # NOTE(review): `src` is compared against single node names below, so
        # this presumably relies on the base retriever returning one name here
        # (top_k == 1); confirm how multi-result retrievals are formatted.
        src = super().retrieve(query)
        graph = self.graph[self.graph["x_name"] == src]

        # Rank candidate targets over the whole corpus rather than only
        # top_k. Enough candidates then survive the edge filter below, and
        # the output is capped at the caller's top_k. The retriever is built
        # around the temporary top_k and always restored, even on error.
        k = self.top_k
        self.top_k = max(len(self.nodes), k)
        self._retriever = self._build_retriever()
        try:
            tgt = super().retrieve(self.target)
        finally:
            self.top_k = k
            self._retriever = self._build_retriever()

        out: List[str] = []
        for _tgt in tgt.split("\n"):
            # Edges from the resolved source node to this specific candidate.
            subgraph = graph[graph["y_name"] == _tgt]
            if subgraph.empty:
                continue
            relation = subgraph.iloc[0]["display_relation"]
            out.append(f"A {relation} of {src} is {_tgt}.")
            if len(out) >= k:
                break
        if out:
            return "\n".join(out)
        return "No relevant knowledge available from the knowledge graph."

    def get_knowledge(self) -> List[Document]:
        """
        Returns a list of the knowledge from the knowledge base.
        Input:
            None.
        Returns:
            One Document per node name (drug names, then disease names).
        """
        return [Document(text=node) for node in self.nodes]

    @classmethod
    def knowledge_description(cls, *args: Any, **kwargs: Any) -> str:
        """
        Returns a description of the knowledge.
        Input:
            target: the target to relate queries to, given either as the
                `target` keyword or as the first positional argument.
        Returns:
            The description of the knowledge base.
        Raises:
            TypeError: if no target is supplied.
        """
        target = kwargs.get("target")
        if target is None:
            # Fall back to the first positional argument. An explicit check
            # is used because an assert would be stripped under `python -O`.
            if not args:
                raise TypeError(
                    "knowledge_description() requires a `target` argument."
                )
            target = args[0]
        desc = (
            "Retrieves information about how biomedical queries are related "
            "to {target}."
        )
        return desc.format(target=str(target))

    @classmethod
    def query_description(cls) -> str:
        """
        Returns a description of the expected query type.
        Input:
            None.
        Returns:
            The description of the expected query type.
        """
        return "The biomedical query to retrieve information about."
