#!/usr/bin/python3
"""
Knowledge retrieval through MedGemma, a variant of Gemma 3 trained for medical
text comprehension.

Author(s):
    Anonymized Authors @anonymized-authors

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import json
import torch
from transformers import pipeline
from typing import Any, Dict, Final

from .base import KnowledgeBase


class MedGemmaKnowledgeBase(KnowledgeBase):
    """
    Knowledge base that answers queries by prompting a MedGemma model.

    The model is instructed (see `system_prompt`) to reply with a JSON object
    mapping each topic to a list of facts. `retrieve` flattens that object
    into one `topic: fact fact ...` line per topic.
    """

    # Checkpoints accepted by the constructor.
    _SUPPORTED_MODEL_IDS = (
        "google/medgemma-4b-it", "google/medgemma-27b-text-it"
    )

    def __init__(
        self,
        top_k: int,
        model_id: str,
        max_new_tokens: int = 2048,
        **kwargs: Any
    ):
        """
        Args:
            top_k: forwarded to the base-class initializer; this class reads
                it back as `self.top_k` to cap the number of facts kept per
                topic and to fill in the system prompt.
            model_id: the MedGemma model to use. Must be one of
                [`google/medgemma-4b-it`, `google/medgemma-27b-text-it`].
            max_new_tokens: maximum number of new tokens to generate.
            **kwargs: accepted for signature compatibility and discarded.
        Raises:
            ValueError: if `model_id` is not a supported MedGemma model.
        """
        del kwargs
        # Raise explicitly rather than `assert`, which is removed when Python
        # runs with optimizations (-O) enabled.
        if model_id not in self._SUPPORTED_MODEL_IDS:
            raise ValueError(
                f"Unsupported model_id {model_id!r}; expected one of "
                f"{list(self._SUPPORTED_MODEL_IDS)}."
            )
        super().__init__(top_k)

        self.model_id: Final[str] = model_id
        self.max_new_tokens: Final[int] = max_new_tokens
        self.pipe = pipeline(
            "text-generation",
            model=self.model_id,
            torch_dtype=torch.bfloat16,
            device=("cuda" if torch.cuda.is_available() else "cpu")
        )

    def retrieve(self, query: str) -> str:
        """
        Retrieves the most relevant information for a given query.
        Input:
            query: a string query.
        Returns:
            A string of the retrieved information: one `topic: facts` line
            per topic when the model reply contains a JSON object, otherwise
            the brace-delimited portion of the raw reply.
        """
        # The 4B checkpoint is sent each message as a list of typed parts;
        # the 27B text checkpoint is sent plain strings.
        # NOTE(review): presumably required by the 4B chat template - confirm.
        use_typed_parts = self.model_id == "google/medgemma-4b-it"
        messages = []
        for role, text in (("system", self.system_prompt), ("user", query)):
            content: Any = text
            if use_typed_parts:
                content = [{"type": "text", "text": text}]
            messages.append({"role": role, "content": content})

        outputs = self.pipe(messages, max_new_tokens=self.max_new_tokens)
        # The last entry of the generated conversation is the model's reply.
        reply = outputs[0]["generated_text"][-1]["content"]
        return self._format_reply(reply)

    def _format_reply(self, reply: str) -> str:
        """
        Converts the raw model reply into the text returned by `retrieve`.
        Input:
            reply: the text generated by the model.
        Returns:
            One `topic: facts` line per topic if a JSON object can be decoded
            from the reply; otherwise the text between the first `{` and the
            first following `}`, wrapped in braces.
        """
        # Drop any preamble before the first opening brace (e.g. a Markdown
        # code fence). If there is no brace, the whole reply is kept.
        body = reply.split("{", 1)[-1]
        try:
            # raw_decode parses the first complete JSON value and ignores any
            # trailing text, so braces inside fact strings or nested objects
            # no longer cut the document short.
            parsed, _ = json.JSONDecoder().raw_decode("{" + body)
        except json.JSONDecodeError:
            return "{" + body.split("}", 1)[0] + "}"

        lines = []
        # The decoded text starts with `{`, so `parsed` is always a dict.
        for topic, facts in parsed.items():
            if not isinstance(facts, list):
                # A bare value instead of a list is treated as a single fact
                # (joining a string directly would space out its characters).
                facts = [facts]
            kept: Dict[str, Any] = {"facts": facts[:self.top_k]}
            # Stringify items so non-string facts cannot break the join.
            joined = " ".join(str(fact) for fact in kept["facts"])
            lines.append(f"{topic}: {joined}")
        return "\n".join(lines)

    @property
    def system_prompt(self) -> str:
        """
        Returns the system prompt to use for the MedGemma knowledge base.
        Input:
            None.
        Returns:
            The system prompt to use for the MedGemma knowledge base. It
            embeds `self.top_k` as the maximum number of facts per topic.
        """
        # NOTE(review): the template below shows unquoted keys and values,
        # which is not strict JSON; a reply that copies it verbatim cannot be
        # decoded and is returned by `retrieve` as raw text.
        return (
            "You are a helpful medical assistant that provides accurate and "
            "evidence-based information on medical topics. When given a "
            "query, you will identify each medical topic separated by "
            "semicolons and provide detailed and precise information. "
            "Follow these guidelines:\n"
            "    1. **Topic Identification:** Recognize and separate medical "
            "topics provided in the query by semicolons.\n"
            "    2. **Fact Retrieval:** For each topic, retrieve up to "
            f"{self.top_k} facts that are accurate, objective, evidence-based "
            "and up-to-date.\n"
            "    3. **Conciseness:** Offer clear and concise medical "
            "explanations for each topic.\n"
            "    4. **Formatting:** Present the information in a structured "
            "format that is easy to read.\n"
            "    5. **Clarity:** Ensure explanations are accessible to both "
            "medical professionals and patients.\n"
            "Respond in JSON format using the following template:\n"
            "    {\n"
            "        topic_1: [fact_1, fact_2, fact_3, ...],\n"
            "        topic_2: [fact_1, fact_2, fact_3, ...]\n"
            "    }"
        )

    @classmethod
    def knowledge_description(
        cls, *args: Any, **kwargs: Any
    ) -> str:
        """
        Returns a description of the knowledge.
        Input:
            Any positional or keyword arguments are accepted and ignored.
        Returns:
            The description of the knowledge base.
        """
        del args, kwargs
        return "Retrieves information about biomedical topics."

    @classmethod
    def query_description(cls) -> str:
        """
        Returns a description of the expected query type.
        Input:
            None.
        Returns:
            The description of the expected query type.
        """
        return "A string query to ask the medical expert about."


class MedGemma4BKnowledgeBase(MedGemmaKnowledgeBase):
    """MedGemma knowledge base pinned to the `google/medgemma-4b-it` model."""

    def __init__(
        self,
        top_k: int,
        max_new_tokens: int = 2048,
        **kwargs: Any
    ) -> None:
        """
        Args:
            top_k: passed through to the parent initializer.
            max_new_tokens: maximum number of new tokens to generate.
            **kwargs: any further keyword arguments, passed through to the
                parent initializer.
        """
        # Only the model id is fixed here; everything else is forwarded.
        checkpoint = "google/medgemma-4b-it"
        super().__init__(
            top_k,
            max_new_tokens=max_new_tokens,
            model_id=checkpoint,
            **kwargs
        )


class MedGemma27BKnowledgeBase(MedGemmaKnowledgeBase):
    """
    MedGemma knowledge base pinned to the `google/medgemma-27b-text-it` model.
    """

    def __init__(
        self,
        top_k: int,
        max_new_tokens: int = 2048,
        **kwargs: Any
    ) -> None:
        """
        Args:
            top_k: passed through to the parent initializer.
            max_new_tokens: maximum number of new tokens to generate.
            **kwargs: any further keyword arguments, passed through to the
                parent initializer.
        """
        # Only the model id is fixed here; everything else is forwarded.
        checkpoint = "google/medgemma-27b-text-it"
        super().__init__(
            top_k,
            max_new_tokens=max_new_tokens,
            model_id=checkpoint,
            **kwargs
        )
