#!/usr/bin/python3
"""
Knowledge retrieval functions for retrieving relevant prior knowledge.

Author(s):
    Anonymized Authors @anonymized-authors

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import gc
import json
import logging
import torch
import warnings
from gymnasium.envs.registration import registry
from llama_index.core.embeddings import BaseEmbedding
from textwrap import dedent
from typing import Any, Dict, List, NamedTuple, Tuple, Union

from . import source
from .source import (
    __all__ as knowledge_sources, KnowledgeBase, DEFAULT_KNOWLEDGE_SOURCES
)
from .source.cosmic import CancerType
from ..envs.base import BaseTask
from ..optim.base import BaseOptimizer
from ..optim.llm import BaseLLMOptimizer


# System prompt used by make_knowledge when at least one knowledge-source
# tool is available: the LLM may call tools sequentially, then answer.
# make_knowledge collapses its lines into one space-separated string before
# sending it. The leading double underscore only marks these constants as
# module-private; no name mangling applies because they are read from
# module-level functions, not from a class body.
__PRIOR_KNOWLEDGE_SYSTEM_PROMPT: str = dedent("""
    You are a helpful biomedical knowledge assistant whose job is to retrieve
    relevant knowledge to help a domain expert solve a specific task in
    biology and medicine. Given a description of a patient or cell-line and a
    task, you will decide which knowledge source functions to retrieve from
    (and with what arguments). You may call multiple functions sequentially;
    when you have enough information, return a final answer of relevant prior
    knowledge without any further function calls. Be specific, concise, and
    comprehensive in your response.
""")


# System prompt used by make_knowledge when no tool specifications were
# produced (e.g., none of the requested sources match a knowledge base), so
# the LLM must answer directly from its own knowledge.
__PRIOR_KNOWLEDGE_SYSTEM_PROMPT_NO_TOOL_CALL: str = dedent("""
    You are a helpful biomedical knowledge assistant whose job is to provide
    relevant knowledge to help a domain expert solve a specific task in
    biology and medicine. Given a description of a patient or cell-line and a
    task, you will return a final answer of relevant prior knowledge. Be
    specific, concise, and comprehensive in your response.
""")


# Module-level logger; make_knowledge logs its prompts, tool calls, tool
# outputs and the final learned knowledge through it.
logger = logging.getLogger(__name__)


def _make_func_specs(
    task: BaseTask, sources: Tuple[str, ...]
) -> List[Dict[str, Any]]:
    """
    Make the function tool specifications for all implemented knowledge bases.
    Input:
        task: the task to make a knowledge source for. Its `disease_name` is
            passed as the `target` of each knowledge base's description.
        sources: the allowed knowledge sources, given as knowledge base class
            names. If it contains "None", no tools are returned. The exact
            tuple ("default",) is expanded to DEFAULT_KNOWLEDGE_SOURCES; a
            "default" entry mixed with other sources is ignored.
    Returns:
        A list of function tool specifications to be used for function
        calling, one per allowed knowledge base, in the order the knowledge
        bases appear in the `source` package's `__all__`.
    """
    if "None" in sources:
        return []
    if sources == ("default",):
        sources = DEFAULT_KNOWLEDGE_SOURCES
    # Set for O(1) membership tests; "default" is a keyword, not a class name.
    allowed = {src for src in sources if src != "default"}
    # Abstract base classes exported by `source` that must never be tools.
    abstract_bases = {"KnowledgeBase", "RAGKnowledgeBase"}

    func_specs: List[Dict[str, Any]] = []
    for attr in knowledge_sources:
        if (
            not attr.endswith("KnowledgeBase")
            or attr in abstract_bases
            or attr not in allowed
        ):
            continue
        knowledge_base = getattr(source, attr)

        description = knowledge_base.knowledge_description(
            target=task.disease_name
        )
        query_spec: Dict[str, Any] = {
            "type": "string",
            "description": knowledge_base.query_description()
        }
        if attr == "COSMICKnowledgeBase":
            # Constrain COSMIC queries to the known cancer types.
            # NOTE(review): these are enum members, not `.value`s; this
            # presumably serializes correctly because CancerType is a
            # str-based Enum -- confirm.
            query_spec["enum"] = list(CancerType.__members__.values())

        func_specs.append({
            "type": "function",
            "function": {
                "name": attr,
                "description": description,
                "parameters": {
                    "type": "object",
                    "properties": {"query": query_spec},
                    "required": ["query"],
                    "additionalProperties": False
                },
                "strict": True
            }
        })

    return func_specs


def _empty_knowledge_metadata() -> Dict[str, Any]:
    """Return the metadata reported when no LLM queries were made."""
    return {
        "num_input_tokens": 0,
        "num_output_tokens": 0,
        "tool_calls": [],
        "messages": []
    }


def _parse_tool_calls(
    tool_calls: List[Dict[str, Any]]
) -> Union[List[Tuple[Dict[str, Any], str, Any]], None]:
    """
    Decode the JSON arguments of every tool call in an LLM response.
    Input:
        tool_calls: the `tool_calls` list of an LLM response message.
    Returns:
        A list of (raw tool call, function name, decoded arguments) tuples, or
        None (after emitting a warning) if any call's arguments are not valid
        JSON, in which case the whole response should be skipped.
    """
    parsed: List[Tuple[Dict[str, Any], str, Any]] = []
    for call in tool_calls:
        func_name = call["function"]["name"]
        try:
            args = json.loads(call["function"]["arguments"])
        except json.JSONDecodeError:
            warnings.warn(
                f"Failed to parse arguments for function {func_name}: "
                f"{call['function']['arguments']}."
                " Skipping this tool call and continuing..."
            )
            return None
        parsed.append((call, func_name, args))
    return parsed


def make_knowledge(
    task: BaseTask,
    optimizer: BaseOptimizer,
    z: NamedTuple,
    knowledge_sources: Tuple[str, ...],
    embedder: BaseEmbedding,
    top_k: int,
    user_prompt_version: str,
    ablate_distribution_shift: bool = False,
    max_queries: int = 8,
) -> Tuple[
    str,
    Dict[
        str,
        Union[int, List[Tuple[int, str, Dict[str, Any]]], List[Dict[str, str]]]
    ]
]:
    """
    Make the prior knowledge for a particular task and set of covariates.
    Input:
        task: the task to make a knowledge source for.
        optimizer: the optimizer to use for the knowledge source.
        z: the set of covariates (i.e., patient or cell-line information) to
            make a knowledge source for.
        knowledge_sources: the knowledge source(s) to use.
        embedder: the knowledge embedder to use for retrieval-based knowledge.
        top_k: the number of top-k results to retrieve.
        user_prompt_version: the version of the user prompt to use.
        ablate_distribution_shift: whether to ablate reasoning on the
            distribution shift.
        max_queries: the maximum number of LLM queries to make. Tools are
            withheld on the last query to force a final answer.
    Returns:
        A string of the prior knowledge and a dictionary of relevant metadata
        with keys "num_input_tokens", "num_output_tokens", "tool_calls" (a
        list of (query index, function name, arguments) tuples), and
        "messages" (the chat transcript sent to the optimizer).
    Raises:
        AssertionError: if the task name is not a registered gymnasium env.
        RuntimeError: if no final answer is produced within max_queries.
    """
    if not knowledge_sources or "None" in knowledge_sources:
        return "", _empty_knowledge_metadata()

    logger.info("### Status: Constructing Prior Knowledge ###")
    if not isinstance(optimizer, BaseLLMOptimizer) or (
        user_prompt_version.lower() != "base"
    ):
        logger.info("  Method requires no prior knowledge, skipping...")
        return "", _empty_knowledge_metadata()

    adversarial_knowledge_base = "AdversarialKnowledgeBase"
    if adversarial_knowledge_base in knowledge_sources:
        knowledge = getattr(source, adversarial_knowledge_base)().retrieve(
            task.task_name
        )
        return knowledge, _empty_knowledge_metadata()

    assert task.task_name in registry.keys()
    tools = _make_func_specs(task, knowledge_sources)
    # Only functions advertised to the LLM may be dispatched; the function
    # name comes from model output and must not reach `getattr` unchecked.
    allowed_funcs = {spec["function"]["name"] for spec in tools}

    prompt_template = (
        __PRIOR_KNOWLEDGE_SYSTEM_PROMPT if tools
        else __PRIOR_KNOWLEDGE_SYSTEM_PROMPT_NO_TOOL_CALL
    )
    # Collapse the multi-line prompt into a single space-separated line.
    system_prompt = " ".join(
        line.strip() for line in prompt_template.split("\n")
    )
    user_prompt = (
        f"{str(z)}\n\nProblem Description: "
        f"{task.task_description(ablate_distribution_shift)}\n\n"
        "Provide relevant factual information to help the expert solve "
        "the problem."
    )
    messages: List[Dict[str, str]] = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt}
    ]
    logger.info(
        f"  System Prompt:\n{system_prompt}\n\n  User Prompt:\n{user_prompt}"
    )

    # Knowledge bases are instantiated lazily, once per function name.
    dispatch_cache: Dict[str, KnowledgeBase] = {}
    called_tools: List[Tuple[int, str, Dict[str, Any]]] = []
    num_input_tokens, num_output_tokens = 0, 0
    try:
        for i in range(max_queries):
            response, metadata = optimizer.chat(
                messages=messages,
                tools=(tools if i < max_queries - 1 else [])
            )
            num_input_tokens += max(metadata.get("num_input_tokens", 0), 0)
            num_output_tokens += max(metadata.get("num_output_tokens", 0), 0)

            tool_calls = response.get("tool_calls")
            if not tool_calls:
                # `content` may be present but None; treat that as empty.
                output_knowledge = (response.get("content") or "").strip()
                logger.info(f"  Learned Knowledge: {output_knowledge}")
                return output_knowledge, {
                    "num_input_tokens": num_input_tokens,
                    "num_output_tokens": num_output_tokens,
                    "tool_calls": called_tools,
                    "messages": messages
                }

            parsed_calls = _parse_tool_calls(tool_calls)
            if parsed_calls is None:
                # Unparseable arguments: retry without recording the response.
                continue

            messages.append(response)
            try:
                response_repr = json.dumps(response, indent=2)
            except TypeError:
                response_repr = str(response)
            logger.info(f"  Function Call {i + 1}:\n{response_repr}")

            # Every tool call in the assistant message must be answered by a
            # tool message carrying its `tool_call_id`.
            for call, func_name, args in parsed_calls:
                called_tools.append((i, func_name, args))
                if func_name not in allowed_funcs:
                    warnings.warn(
                        f"LLM requested unknown function {func_name}; "
                        "returning an error message to the model."
                    )
                    knowledge = (
                        f"Error: unknown function {func_name!r}. Available "
                        f"functions: {sorted(allowed_funcs)}."
                    )
                elif not isinstance(args, dict) or "query" not in args:
                    warnings.warn(
                        f"Missing 'query' argument for function {func_name}: "
                        f"{args}. Returning an error message to the model."
                    )
                    knowledge = 'Error: missing required argument "query".'
                else:
                    if func_name not in dispatch_cache:
                        dispatch_cache[func_name] = getattr(source, func_name)(
                            top_k=top_k,
                            embedder=embedder,
                            task_name=task.task_name,
                            target=task.disease_name
                        )
                    knowledge = dispatch_cache[func_name].retrieve(
                        args["query"]
                    )

                result = {
                    "role": "tool",
                    "tool_call_id": call["id"],
                    "content": knowledge
                }
                messages.append(result)
                logger.info(
                    f"  Function Output {i + 1}:\n"
                    f"{json.dumps(result, indent=2)}"
                )
    finally:
        # Drop the knowledge bases and free cached CUDA memory whether we
        # return an answer or raise.
        dispatch_cache.clear()
        gc.collect()
        torch.cuda.empty_cache()

    raise RuntimeError("Max number of queries reached.")
