#!/usr/bin/python3
"""
DepMap cell perturbation knowledge base. DepMap quantifies the effect of
CRISPR gene knockouts on cancer model cell lines, integrated using Chronos.
Values are copy number corrected, scaled, and screen quality corrected.

Author(s):
    Anonymized Authors @anonymized-authors

Citation(s):
    [1] Tsherniak A, Vazquez F, Montgomery PG, et al. Defining a cancer
        dependency map. Cell 170(3): 564-76. (2017). doi:
        10.1016/j.cell.2017.06.010
    [2] Dempster JM, Boyle I, Vazquez F, et al. Chronos: A cell population
        dynamics model of CRISPR experiments that improves inferences of
        gene fitness effects. Genome Biol 22: 343. (2021). doi:
        10.1186/s13059-021-02540-7

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import numpy as np
import pandas as pd
import pystow
from depmap_downloader.api import get_downloads_table
from llama_index.core import Document
from llama_index.core.embeddings import BaseEmbedding
from math import isnan
from pathlib import Path
from typing import Any, Dict, Final, List, NamedTuple, Optional, Union, cast

from .base import RAGKnowledgeBase


class DepMapModel(NamedTuple):
    """
    Metadata record for a single DepMap model cell that renders itself as a
    one-sentence natural language description.
    Fields:
        model_id: the identifier of the model; an empty string marks a
            missing model and renders as an empty description.
        model_type: the kind of model; rendered in lower case.
        age: the age in years of the donor; omitted from the description
            when it is missing, NaN, or not numeric.
        race: the race of the donor; rendered in title case.
        sex: the sex of the donor; rendered in lower case.
        primary_or_metastasis: whether the sample is primary or metastatic;
            rendered in lower case.
        cancer_type: the cancer the model was derived from; rendered as is.
        cosmic_id: an optional numeric identifier that is stored but not
            included in the description.
    """
    model_id: str
    model_type: str
    age: Optional[int]
    race: str
    sex: str
    primary_or_metastasis: str
    cancer_type: str
    cosmic_id: Optional[int]

    def __str__(self) -> str:
        """
        Returns a string representation of the DepMap model cell.
        Input:
            None.
        Returns:
            A one-sentence description of the DepMap model cell, or an empty
            string if the model identifier is empty.
        """
        if not self.model_id:
            return ""
        # The type must be checked before calling isnan(), which raises a
        # TypeError on non-numeric values; an unusable age is simply left out
        # of the sentence.
        has_age = isinstance(self.age, (int, float)) and not isnan(self.age)
        age = f"{self.age} year-old " if has_age else ""
        return (
            f"{self.model_id} is a {self.model_type.lower()} derived from a "
            f"{age}{self.race.title()} {self.sex.lower()} with "
            f"{self.primary_or_metastasis.lower()} {self.cancer_type}"
        )

    def __repr__(self) -> str:
        """
        Returns a string representation of the DepMap model cell.
        Input:
            None.
        Returns:
            The same description that is returned by `str()`.
        """
        return str(self)


def _help_download(name: str, version: Optional[str] = None) -> str:
    """
    DepMap data download helper function. Hidden function taken directly from
    the `depmap_downloader` API source code at https://github.com/cthoyt/
    depmap-downloader/blob/main/src/depmap_downloader/api.py.
    Input:
        name: the name of the file to download.
        version: the version of the dataset to download. If None, the
            release flagged as the latest in the downloads table is used.
    Returns:
        The URL to download the requested file and version from.
    Raises:
        ValueError: if no release is flagged as the latest, or if the
            requested file is not listed for the requested version.
    """
    # Fetch the downloads table once and reuse it for both lookups.
    downloads = get_downloads_table()
    if version is None:
        latest = next(
            (
                release
                for release in downloads["releaseData"]
                if release["isLatest"]
            ),
            None
        )
        if latest is None:
            raise ValueError(
                "No DepMap release is marked as the latest release."
            )
        version = cast(str, latest["releaseName"])
    for download in downloads["table"]:
        if download["fileName"] == name and download["releaseName"] == version:
            return cast(str, download["downloadUrl"])
    raise ValueError(
        f"No DepMap download named {name!r} found for release {version!r}."
    )


def _ensure_dependency(
    name: str, version: Optional[str] = None
) -> Union[Path, str]:
    """
    Makes sure that a DepMap file dependency is available in the local
    `pystow` storage module for DepMap.
    Input:
        name: the name of the file to download.
        version: the version of the dataset to download.
    Returns:
        The local path to the downloaded file dependency.
    """
    storage = pystow.module("bio", "depmap")
    source_url = _help_download(name, version=version)
    return storage.ensure(version=version, url=source_url, name=name)


class DepMapKnowledgeBase(RAGKnowledgeBase):
    version: str = "DepMap Public 24Q4"

    def __init__(
        self,
        embedder: BaseEmbedding,
        top_k: int,
        cache_dir: Union[Path, str] = (
            Path.home() / ".cache" / "leon" / "depmap"
        ),
        **kwargs: Dict[str, Any]
    ):
        """
        Fetches the gene effect, gene, and model tables for the class-level
        DepMap release, loads them into memory, and then initializes the
        parent knowledge base.
        Args:
            embedder: the embedding function to use for retrieval.
            top_k: the number of documents to retrieve per query.
            cache_dir: a local directory to cache the knowledge locally.
            kwargs: any additional keyword arguments are discarded.
        """
        del kwargs
        release = self.version
        self.ko_path = _ensure_dependency("CRISPRGeneEffect.csv", release)
        self.gene_metadata_path = _ensure_dependency("Gene.csv", release)
        self.model_path = _ensure_dependency("Model.csv", release)

        self.ko_data: Final[pd.DataFrame] = pd.read_csv(self.ko_path)
        self.gene_metadata: Final[pd.DataFrame] = pd.read_csv(
            self.gene_metadata_path
        )

        # Models missing any of these fields cannot be described in natural
        # language, so their rows are removed up front.
        required_fields = [
            "ModelType", "PatientRace", "Sex", "PrimaryOrMetastasis"
        ]
        self.model_metadata: pd.DataFrame = pd.read_csv(self.model_path)
        self.model_metadata.dropna(subset=required_fields, inplace=True)

        # The tables are loaded before the parent initializer runs.
        super().__init__(embedder=embedder, top_k=top_k, cache_dir=cache_dir)

    def retrieve(self, query: str) -> str:
        """
        Looks up the model cells that are most relevant to a query.
        Input:
            query: either the exact identifier of a known model cell or a
                string query of a patient description.
        Returns:
            The natural language description of the model cell if the query
            exactly matches a known model cell. Otherwise, a newline-separated
            string holding, for each retrieved entry, the text that precedes
            the first " is a " (presumably the model cell name).
        """
        if query in self.model_cells:
            return str(self.model_cell_metadata(query))
        # Only the first element of the parent's result is used here.
        hits = super().retrieve(query)[0]
        names = [hit.partition(" is a ")[0] for hit in hits]
        return "\n".join(names)

    def get_sensitive_ko_genes(self, model_name: str, kp: int) -> List[str]:
        """
        Returns the top-k' genes that decrease the viability of the cell line
        the most when knocked out, i.e., the genes with the lowest gene
        effect values for the model.
        Input:
            model_name: the cell line or organoid to retrieve the genes for.
            kp: the number of genes to return.
        Returns:
            A list of at most kp gene names ordered from the lowest to the
            highest gene effect. Each name is the column label up to the
            first " (". Genes without a numeric gene effect are skipped.
        """
        assert model_name in self.model_cells, (
            f"Unknown model cell: {model_name!r}"
        )
        id_column = "Unnamed: 0"
        row = self.ko_data[self.ko_data[id_column] == model_name].iloc[0]
        # Remove the model identifier before ranking so that every remaining
        # entry is a gene and results are looked up by label. Ranking the full
        # row and indexing the columns at `i + 1` returned the neighbor of
        # each intended gene.
        effects = pd.to_numeric(
            row.drop(labels=id_column), errors="coerce"
        ).dropna()
        most_sensitive = effects.nsmallest(max(kp, 0))
        return [
            str(column).split(" (", 1)[0] for column in most_sensitive.index
        ]

    @property
    def ko_genes(self) -> List[str]:
        """
        Lists every gene symbol found in the gene metadata table.
        Input:
            None.
        Returns:
            The values of the `symbol` column of the gene metadata as a list.
        """
        symbols = self.gene_metadata["symbol"]
        return symbols.tolist()

    def ko_gene_metadata(self, gene: str) -> str:
        """
        Describes a knockout gene in natural language.
        Input:
            gene: the symbol of the knockout gene to describe.
        Returns:
            A description of the form `SYMBOL (full name)`, followed by the
            gene group when the group is available as a string. An empty
            string is returned if the gene symbol is not in the gene metadata.
        """
        matches = self.gene_metadata[self.gene_metadata["symbol"] == gene]
        if matches.empty:
            return ""
        record = matches.iloc[0]
        description = f"{record['symbol']} ({record['name']})"
        group = record["gene_group"]
        # A missing group is not a string, so the clause is only appended when
        # a string is present.
        if isinstance(group, str):
            description += f", part of the {group} gene group"
        return description

    @property
    def model_cells(self) -> List[str]:
        """
        Lists the model cells that appear in both the model metadata and the
        gene effect data.
        Input:
            None.
        Returns:
            A sorted list of the identifiers shared by the `ModelID` column of
            the model metadata and the `Unnamed: 0` column of the gene effect
            data.
        """
        described = set(self.model_metadata["ModelID"].tolist())
        screened = set(self.ko_data["Unnamed: 0"].tolist())
        return sorted(described & screened)

    def model_cell_metadata(self, model: str) -> DepMapModel:
        """
        Builds the metadata record of a model cell.
        Input:
            model: the identifier of the model cell to look up.
        Returns:
            A `DepMapModel` filled in from the first metadata row matching
            the identifier. If there is no such row, a record with empty
            strings and no age or COSMIC identifier is returned instead.
        """
        matches = self.model_metadata[self.model_metadata["ModelID"] == model]
        if matches.empty:
            return DepMapModel(
                model_id="",
                model_type="",
                age=None,
                race="",
                sex="",
                primary_or_metastasis="",
                cancer_type="",
                cosmic_id=None
            )

        record = matches.iloc[0]
        model_type = str(record["ModelType"])

        # Missing numeric values show up as NaN and are mapped to None.
        age_value = record["Age"].item()
        age = None if isnan(age_value) else int(age_value)

        race = str(record["PatientRace"])
        sex = str(record["Sex"])
        stage = str(record["PrimaryOrMetastasis"])
        cancer = str(record["OncotreeSubtype"])

        cosmic_value = record["COSMICID"].item()
        cosmic_id = None if isnan(cosmic_value) else int(cosmic_value)

        return DepMapModel(
            model_id=model,
            model_type=model_type,
            age=age,
            race=race,
            sex=sex,
            primary_or_metastasis=stage,
            cancer_type=cancer,
            cosmic_id=cosmic_id
        )

    def get_knowledge(self) -> List[Document]:
        """
        Collects the natural language descriptions of all of the model cells
        in the knowledge base.
        Input:
            None.
        Returns:
            A list with one document per model cell, where the text of each
            document is the description of that model cell.
        """
        descriptions = (
            str(self.model_cell_metadata(name)) for name in self.model_cells
        )
        return [Document(text=description) for description in descriptions]

    @classmethod
    def knowledge_description(
        cls, *args: Any, **kwargs: Dict[str, Any]
    ) -> str:
        """
        Describes what this knowledge base can be used for.
        Input:
            Any positional or keyword arguments are accepted and discarded.
        Returns:
            The description of the knowledge base.
        """
        del args, kwargs
        description = (
            "Retrieves information about what cell lines are "
            "most likely to be good in-vitro models for a patient."
        )
        return description

    @classmethod
    def query_description(cls) -> str:
        """
        Describes what a query to this knowledge base should contain.
        Input:
            None.
        Returns:
            The description of the expected query type.
        """
        description = "The patient to retrieve information about."
        return description
