#!/usr/bin/python3
"""
Hetionet Biomedical Knowledge Graph as a Knowledge Base.

Author(s):
    Anonymized Authors @anonymized-authors

Citation(s):
    [1] Himmelstein DS, Lizee A, Hessler C, et al. Systematic integration of
        biomedical knowledge prioritizes drugs for repurposing. eLife 6:e26726.
        (2017). doi: 10.7554/eLife.26726

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
from hetnetpy.hetnet import MetaGraph, Graph, Node  # type: ignore
from hetnetpy.readwrite import (  # type: ignore
    extract_writable,
    graph_from_writable
)
from hetnetpy.pathtools import paths_between  # type: ignore
from llama_index.core import Document
from llama_index.core.embeddings import BaseEmbedding
from pathlib import Path
from typing import Any, Dict, Final, List, Union

from .base import RAGKnowledgeBase


class HetionetKGKnowledgeBase(RAGKnowledgeBase):
    """
    Hetionet v1.0 exposed as a retrieval knowledge base.

    Documents are produced one per Compound and Disease node (see
    `get_knowledge`). A query is answered by ranking Compound documents
    against the query text, resolving `target` to the best-ranked Disease
    document, and reporting Compound-binds-Gene-associates-Disease (CbGaD)
    paths between the two.
    """

    url: str = (
        "https://github.com/hetio/hetionet/raw/refs/heads/main/hetnet/json/"
        "hetionet-v1.0.json.bz2"
    )

    def __init__(
        self,
        embedder: BaseEmbedding,
        top_k: int,
        target: str,
        cache_dir: Union[Path, str] = (
            Path.home() / ".cache" / "leon" / "hetionet"
        ),
        **kwargs: Dict[str, Any]
    ):
        """
        Args:
            embedder: the embedding function to use for retrieval.
            top_k: the number of documents to retrieve per query.
            target: the target to relate retrieved drugs to; it is matched
                against the indexed Disease documents at query time.
            cache_dir: a local directory to cache the knowledge locally.
            kwargs: ignored; accepted for interface compatibility.
        """
        del kwargs
        # NOTE(review): `extract_writable` is given the remote URL directly;
        # `cache_dir` is only forwarded to the base class, so the graph is
        # presumably fetched on every instantiation — confirm.
        writable = extract_writable(self.url)
        self.graph: Final[Graph] = graph_from_writable(writable)
        self.target: Final[str] = target
        self.metagraph: Final[MetaGraph] = self.graph.metagraph
        self.nodes: Final[List[Node]] = list(self.graph.get_nodes())
        # Name -> position in `self.nodes`. Names are not unique across node
        # kinds, so a later node of another kind overwrites an earlier one.
        self.node2idx: Final[Dict[str, int]] = {
            x.name: i for i, x in enumerate(self.nodes)
        }
        # Kind (metanode identifier) -> name -> node. Used for lookups so a
        # Compound/Disease is never confused with a same-named node of
        # another kind.
        self._nodes_by_kind: Final[Dict[str, Dict[str, Node]]] = {}
        for node in self.nodes:
            self._nodes_by_kind.setdefault(
                node.metanode.identifier, {}
            )[node.name] = node
        super(HetionetKGKnowledgeBase, self).__init__(
            embedder=embedder, top_k=top_k, cache_dir=cache_dir
        )

    def _resolve_node(self, line: str) -> Union[Node, None]:
        """
        Maps one retrieved line of the form "<Kind>: <name>" (the document
        format produced by `get_knowledge`) back to its graph node.
        Input:
            line: a single retrieved document line.
        Returns:
            The matching node, or None if the line names no known node.
        """
        kind, sep, name = line.partition(": ")
        if not sep:
            return None
        return self._nodes_by_kind.get(kind, {}).get(name)

    def retrieve(self, query: str) -> str:
        """
        Retrieves the most relevant relationship(s) associated with a query.
        Input:
            query: a string query.
        Returns:
            Up to `top_k` newline-separated sentences, each describing one
            CbGaD path from a query-ranked Compound to the target Disease,
            or a fallback message if no such path is found.
        """
        no_knowledge = (
            "No relevant knowledge available from the knowledge graph."
        )
        metapath = self.metagraph.metapath_from_abbrev("CbGaD")

        k = self.top_k
        # Widen the retriever so every indexed document is ranked: many
        # top-ranked compounds may have no path to the target, so a long
        # candidate list is needed to still collect up to `k` results.
        self.top_k = len(self.nodes)
        self._retriever = self._build_retriever()
        try:
            src = super(HetionetKGKnowledgeBase, self).retrieve(query)
            tgt = super(HetionetKGKnowledgeBase, self).retrieve(self.target)
        finally:
            # Restore `top_k` BEFORE rebuilding, so the retriever left in
            # place uses the configured `top_k` even if retrieval raised.
            self.top_k = k
            self._retriever = self._build_retriever()

        # The metapath ends at a Disease, so the target must be the
        # best-ranked Disease document, not merely the first hit.
        target_node = next(
            (
                node
                for node in map(self._resolve_node, tgt.split("\n"))
                if node is not None
                and node.metanode.identifier == "Disease"
            ),
            None,
        )
        if target_node is None:
            return no_knowledge

        out: List[str] = []
        for line in src.split("\n"):
            source_node = self._resolve_node(line)
            # Only Compounds can start a CbGaD path.
            if (
                source_node is None
                or source_node.metanode.identifier != "Compound"
            ):
                continue
            try:
                paths = paths_between(
                    self.graph, source_node, target_node, metapath
                )
            except KeyError:
                # Kept as a safeguard: hetnetpy raises KeyError when a node
                # lacks edges of the requested metaedge.
                continue
            if not paths:
                continue
            compound, gene, disease = paths[0].get_nodes()
            out.append(
                f"{compound.name} targets {gene.name} associated with "
                f"{disease.name}."
            )
            if len(out) >= k:
                break
        if out:
            return "\n".join(out)
        return no_knowledge

    def get_knowledge(self) -> List[Document]:
        """
        Returns a list of the knowledge from the knowledge base.
        Input:
            None.
        Returns:
            One document per Compound or Disease node, with text formatted
            as "<Kind>: <name>".
        """
        return [
            Document(text=f"{node.metanode.identifier}: {node.name}")
            for node in self.nodes
            if node.metanode.identifier in ("Compound", "Disease")
        ]

    @classmethod
    def knowledge_description(
        cls, *args: Any, **kwargs: Dict[str, Any]
    ) -> str:
        """
        Returns a description of the knowledge.
        Input:
            target: the target to relate drugs to, given either as the
                `target` keyword argument or as the first positional
                argument (the keyword takes precedence).
        Returns:
            The description of the knowledge base.
        Raises:
            ValueError: if no target is supplied.
        """
        target = kwargs.get("target")
        if target is None:
            if not args:
                raise ValueError("A `target` must be provided.")
            target = args[0]
        desc = "Retrieves information about how drugs are related to {target}."
        return desc.format(target=target)

    @classmethod
    def query_description(cls) -> str:
        """
        Returns a description of the expected query type.
        Input:
            None.
        Returns:
            The description of the expected query type.
        """
        return "The drug to retrieve information about."
