#!/usr/bin/python3
"""
Catalogue Of Somatic Mutations In Cancer (COSMIC) Knowledge Base.

Author(s):
    Anonymized Authors @anonymized-authors

URL:
    https://cancer.sanger.ac.uk/cosmic

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import os
from enum import Enum
from typing import Any, Dict, Optional

import pandas as pd

from .base import KnowledgeBase


class CancerType(str, Enum):
    """
    Cancer types supported by the COSMIC knowledge base.

    Each value is both the query string accepted by
    `COSMICKnowledgeBase.retrieve` and the token embedded in the matching
    data filename (`GRCh38_<value>_COSMIC_v102_20250523.csv`).
    """

    BREAST = "BC"  # rendered as "Breast Cancer" in responses
    NSCLC = "NSCLC"  # rendered as-is; presumably non-small cell lung cancer
    COLON = "Colon"  # rendered as "Colon Cancer" in responses
    HEMEONC = "HemeOnc"  # rendered as "Hematological Malignancies"


class COSMICKnowledgeBase(KnowledgeBase):
    """
    Knowledge base over pre-exported COSMIC (v102, GRCh38) gene tables.

    Each supported `CancerType` has one CSV file under the `COSMIC/`
    directory next to this package. `retrieve` reads the first `top_k` rows
    of that file and reports each gene's mutation frequency as a sentence.
    """

    def __init__(self, top_k: int, **kwargs: Any):
        """
        Args:
            top_k: number of top genes to return.
            kwargs: accepted and ignored.
        """
        del kwargs
        super().__init__(top_k)

    @staticmethod
    def _parse_query(query: Any) -> Optional[CancerType]:
        """
        Map a query to a supported `CancerType`, or None if unsupported.

        Accepts a `CancerType` member, or a string equal to one of the enum
        values after stripping surrounding whitespace and ignoring case.
        """
        if isinstance(query, CancerType):
            return query
        if not isinstance(query, str):
            return None
        key = query.strip().casefold()
        for type_ in CancerType:
            if type_.value.casefold() == key:
                return type_
        return None

    def retrieve(self, query: str) -> str:
        """
        Retrieve the top k genes for the specified cancer type.

        Input:
            query: a cancer type code, one of the `CancerType` values
                ("BC", "NSCLC", "Colon", "HemeOnc"). Matching ignores case
                and surrounding whitespace; a `CancerType` member is also
                accepted.
        Returns:
            One line per gene, of the form
            "Gene <name> is mutated in <cancer> in <freq>% of samples.",
            where freq = 100 * "Mutated samples" / "Samples tested".
            Genes come from the first `top_k` rows of the data file in file
            order; they are not re-sorted here, so the export is presumably
            pre-sorted. Rows whose "Samples tested" is not positive are
            skipped. Returns "No information found." for an unsupported
            cancer type or when no usable rows remain.
        Raises:
            FileNotFoundError: if the data file for a supported cancer type
                is missing.
        """
        not_found = "No information found."
        cancer = self._parse_query(query)
        if cancer is None:
            return not_found

        # Types without an entry (NSCLC) are shown by their raw code.
        display_names = {
            CancerType.BREAST: "Breast Cancer",
            CancerType.COLON: "Colon Cancer",
            CancerType.HEMEONC: "Hematological Malignancies",
        }
        cancer_name = display_names.get(cancer, cancer.value)
        response_fmt = (
            "Gene {gene} is mutated in {cancer} in {freq:.1f}% of samples."
        )

        # Use `.value` explicitly: on Python 3.12+ formatting a (str, Enum)
        # member yields "CancerType.X" rather than its value.
        data_fn = os.path.join(
            os.path.dirname(os.path.dirname(__file__)),
            "COSMIC",
            f"GRCh38_{cancer.value}_COSMIC_v102_20250523.csv",
        )
        data = pd.read_csv(data_fn, nrows=self.top_k)

        lines = []
        for gene, mutated, tested in zip(
            data["Gene name"], data["Mutated samples"], data["Samples tested"]
        ):
            tested = float(tested)
            # `not tested > 0` also rejects NaN; a zero denominator would
            # otherwise raise ZeroDivisionError and NaN would print "nan%".
            if not tested > 0:
                continue
            lines.append(
                response_fmt.format(
                    gene=gene,
                    cancer=cancer_name,
                    freq=100.0 * float(mutated) / tested,
                )
            )
        return "\n".join(lines) if lines else not_found

    @classmethod
    def knowledge_description(cls, *args: Any, **kwargs: Any) -> str:
        """
        Returns a description of the knowledge.
        Input:
            None. Any positional or keyword arguments are ignored.
        Returns:
            The description of the knowledge base.
        """
        del args, kwargs
        return "Retrieves information about commonly mutated genes in cancer."

    @classmethod
    def query_description(cls) -> str:
        """
        Returns a description of the expected query type.
        Input:
            None.
        Returns:
            The description of the expected query type, including the list
            of accepted cancer type codes so callers can form valid queries.
        """
        valid = ", ".join(type_.value for type_ in CancerType)
        return (
            "The cancer type to retrieve information about. "
            f"One of: {valid}."
        )
