#!/usr/bin/python3
"""
Configuration parsing for the LEON API.

Author(s):
    Anonymized Authors @anonymized-authors

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
from importlib import import_module
from typing import Any, Dict, Final, Optional, Sequence, Tuple

from ..core import get_equivalence_relation_options
from ..embedding import get_embedder_options
from ..knowledge import get_knowledge_source_options
from ..optim import get_optimizers, get_user_prompt_versions, BaseLLMOptimizer


def _is_llm_optimizer(name: str) -> bool:
    """
    Checks whether a named optimizer from the `optim` package is LLM-based.
    Input:
        name: the attribute name of an optimizer in the `optim` package.
    Returns:
        Whether the named attribute is a subclass of `BaseLLMOptimizer`.
    """
    # The module is resolved relative to this module's own name
    # (`package=__name__`), so the three leading dots climb to the package
    # that contains `optim`.
    optim_pkg = import_module("...optim", package=__name__)
    return issubclass(getattr(optim_pkg, name), BaseLLMOptimizer)


# Names of the registered optimizers that are LLM-based; these are the
# allowed choices for the `llm` option.
__LLM_OPTIMIZERS__: Final[Sequence[str]] = list(
    filter(_is_llm_optimizer, get_optimizers())
)


def _option(
    default: Any,
    type_: Any,
    help_: str,
    options: Optional[Sequence[Any]] = None,
) -> Dict[str, Any]:
    """
    Builds the specification record for a single LEON option.
    Input:
        default: the value used when the option is not specified.
        type_: the type a user-specified value is checked against.
        help_: a one-line human-readable description of the option.
        options: the allowed values of the option, or None if unrestricted.
    Returns:
        A dict with the keys `default`, `type`, `options`, and `help`.
    """
    return {
        "default": default,
        "type": type_,
        "options": options,
        "help": help_,
    }


# Specification of every user-configurable LEON option, keyed by option name.
__LEON_OPTIONS__: Final[Dict[str, Dict[str, Any]]] = {
    "llm": _option(
        "GPT4oMiniOptimizer",
        str,
        "LLM optimizer to use with LEON.",
        options=__LLM_OPTIMIZERS__,
    ),
    "knowledge_source": _option(
        ("default",),
        Optional[list],
        "The source(s) of prior knowledge to make available.",
        options=get_knowledge_source_options(),
    ),
    "knowledge_top_k": _option(
        8,
        int,
        "Number of documents to retrieve per query per knowledge source.",
    ),
    "embedder_name": _option(
        "openai/text-embedding-3-small",
        str,
        "Embedding model to use for embedding tasks.",
        options=get_embedder_options(),
    ),
    "equivalence_relation": _option(
        "KMeansEquivalenceRelation",
        str,
        "Equivalence relation to use in LEON.",
        options=get_equivalence_relation_options(),
    ),
    "num_equivalence_classes": _option(
        None,
        Optional[int],
        "An optional fixed number of equivalence classes to use.",
    ),
    "lambda_": _option(
        None,
        Optional[float],
        "Optional fixed source critic weighting hyperparameter.",
    ),
    "mu_": _option(
        None,
        Optional[float],
        "Optional fixed LLM certainty hyperparameter.",
    ),
    "batch_size": _option(32, int, "Sampling batch size."),
    "temperature": _option(1.0, float, "LLM temperature hyperparameter."),
    "w0": _option(
        1.0,
        float,
        "1-Wasserstein distance constraint bound hyperparameter.",
    ),
    "user_prompt_version": _option(
        "base",
        str,
        "Version of the LLM user prompt to use.",
        options=get_user_prompt_versions(),
    ),
    "do_reflection": _option(
        True,
        bool,
        "Whether to perform reflection as in Ma YJ et al. ICLR (2024).",
    ),
    "seed": _option(None, Optional[int], "Optional random seed."),
}


def show_options(
    *args: Any, disp: bool = True, **kwargs: Any
) -> Optional[str]:
    """
    Show documentation for additional options of the LEON optimization solver.
    Input:
        args: ignored; accepted only for call-signature compatibility.
        disp: whether to print the result rather than returning it.
        kwargs: ignored; accepted only for call-signature compatibility.
    Returns:
        None if `disp` is True (the text is printed instead); otherwise the
        text string documenting the additional options.
    """
    def _type_name(tp: Any) -> str:
        # Plain classes render as `int` rather than `<class 'int'>`; typing
        # constructs such as `Optional[int]` drop the `typing.` prefix.
        if isinstance(tp, type):
            return tp.__name__
        return str(tp).replace("typing.", "")

    sections = []
    for opt, spec in __LEON_OPTIONS__.items():
        lines = [
            f"Parameter: {opt}",
            f"    Description: {spec['help']}",
            f"    Type: {_type_name(spec['type'])}",
            f"    Default Value: {spec['default']}",
        ]
        if spec["options"]:
            # `map(str, ...)` keeps this robust to non-string choices.
            choices = " | ".join(map(str, spec["options"]))
            lines.append(f"    Possible Values: {choices}")
        sections.append("\n".join(lines))
    text = (
        "### LEON Configuration Options\n\n" + "\n\n".join(sections)
    ).strip()
    if disp:
        print(text)
        return None
    return text


def parse_options(options: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Parses the options for the LEON API.
    Input:
        options: an optional dict of option specifications for the LEON API.
            Keys must be names declared in `__LEON_OPTIONS__`; any option
            that is omitted takes its declared default value.
    Returns:
        The parsed options for the LEON API: every declared option mapped to
        either its user-specified value or its default.
    Raises:
        ValueError: if an option name is not declared, a value is not an
            instance of the option's declared type, or a value (or any
            element of a list value) is not among the option's allowed
            values.
    """
    parsed = {key: spec["default"] for key, spec in __LEON_OPTIONS__.items()}
    if not options:
        return parsed
    # Validate names against the declared option table so that typos are
    # reported as a ValueError rather than a KeyError further down.
    unexpected = [k for k in options if k not in __LEON_OPTIONS__]
    if unexpected:
        raise ValueError(
            f"Unexpected options: {', '.join(map(str, unexpected))}"
        )
    for key, val in options.items():
        spec = __LEON_OPTIONS__[key]
        _type = spec["type"]
        if not isinstance(val, _type):
            raise ValueError(
                f"Value {val} is not of type {_type} for key {key}"
            )
        allowed = spec["options"]
        # None passes the type check only for Optional[...] types, where it
        # means "not set"; it is not one of the enumerated choices, so the
        # membership check is skipped for it.
        if allowed and val is not None:
            if isinstance(val, list):
                bad = [v for v in val if v not in allowed]
                if bad:
                    raise ValueError(
                        f"Value(s) {', '.join(map(str, bad))} for key "
                        f"{key} are not in {allowed}"
                    )
            elif val not in allowed:
                raise ValueError(
                    f"Value {val} for key {key} is not in {allowed}"
                )
        parsed[key] = val
    return parsed
