#!/usr/bin/python3
"""
LEON API that follows the `scipy.optimize.minimize()` API convention.

Author(s):
    Anonymized Authors @anonymized-authors

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import numpy as np
import torch
from importlib import import_module
from scipy.optimize import OptimizeResult
from tqdm import tqdm
from typing import Any, Dict, Optional

from .config import parse_options
from ..embedding import get_embedder
from ..envs.base import BaseTask
from ..knowledge import make_knowledge
from ..model import LipschitzMLP
from ..optim import BaseOptimizer
from ..optim.state import OptimizerState
from ..utils import initialize_designs


def maximize(
    task: BaseTask,
    x0: Optional[np.ndarray],
    args: Any,
    options: Optional[Dict[str, Any]] = None,
    **kwargs: Dict[str, Any]
) -> OptimizeResult:
    """
    Maximization of a scalar function of one or more variables.
    Input:
        task: the personalized medicine task to be maximized.
        x0: an optional initial guess or batch of initial guesses of shape
            (D,) or (N, D), where D is the number of independent variables and
            N is the number of initial guesses. NOTE: currently ignored; the
            first batch of designs always comes from `initialize_designs()`.
        args: a length-1 sequence holding an integer patient index. The index
            selects a patient from a permutation of `task.test` that is
            seeded by `options["seed"]` (after `parse_options()`).
        options: a dictionary of solver options, parsed by `parse_options()`.
        kwargs: accepted for `scipy.optimize` API parity and ignored.
    Returns:
        The optimization result with the following attributes:
            x: the logged design with the highest score in `state.yq`, or
                `None` if no design was ever proposed.
            success: `True` if at least one design was evaluated, else `False`.
            fun: the value of the hidden oracle objective function at `x`
                (NaN if no design was proposed). The same value is also
                stored under the legacy key `func`.
            nfev: the number of evaluations of the oracle objective function.
            nit: the number of surrogate function evaluations actually
                performed (the number of designs scored by the surrogate).
    """
    del kwargs  # Unused; only accepted for API parity.
    categorical_config = task.train[0].discrete_features()
    config: Dict[str, Any] = parse_options(options)
    # Modules are resolved dynamically so that the optimizer class and the
    # equivalence relation can be selected by name from `config`.
    optim_module = import_module("...optim", package=__name__)
    core_module = import_module("...core", package=__name__)

    # Exactly one integer patient index is expected in `args`.
    assert len(args) == 1 and isinstance(args[0], int)
    # Seeded permutation of the test set: for a fixed seed, a given index
    # always maps to the same patient.
    test_idxs = np.random.default_rng(config["seed"]).choice(
        len(task.test), size=len(task.test), replace=False
    )
    ref = task.test[test_idxs[args[0]]]

    optim: BaseOptimizer = getattr(optim_module, config["llm"])(
        task=task,
        ref=ref,
        batch_size=config["batch_size"],
        temperature=config["temperature"],
        user_prompt_version=config["user_prompt_version"],
        reflection=config["do_reflection"],
        categorical_bounds=[
            list(range(len(vals))) for vals in categorical_config.values()
        ]
    )

    state = OptimizerState(
        task=task,
        optimizer_name=config["llm"],
        individual=ref,
        seed=config["seed"]
    )

    embedder = get_embedder(config["embedder_name"])

    knowledge = {"task": task.task_description()}
    prior_knowledge, _knowledge_metadata = make_knowledge(
        task=task,
        optimizer=optim,
        z=state.individual,
        knowledge_sources=config["knowledge_source"],
        embedder=embedder,
        top_k=config["knowledge_top_k"],
        user_prompt_version=config["user_prompt_version"],
    )
    knowledge["prior_knowledge"] = prior_knowledge

    # Materialize the training designs once; they feed both the surrogate's
    # reference tensor and the reduced reference designs below.
    train_designs = [task.train[i] for i in range(len(task.train))]

    transform_cls = getattr(core_module, "EntropyPenalizedTransform")
    sim = getattr(core_module, config["equivalence_relation"])(
        task=task,
        embedder=embedder,
        num_equivalence_classes=config["num_equivalence_classes"]
    )
    forward_model = transform_cls(
        task=task,
        critic=LipschitzMLP(in_dim=task.ndim),
        Xp=torch.vstack([design.as_tensor() for design in train_designs]),
        equiv_relation=sim,
        lambda_=config["lambda_"],
        mu_=config["mu_"],
        W0=config["w0"]
    )

    reference_designs = task.reduce(train_designs)

    # Number of designs actually scored by the surrogate; reported as `nit`.
    num_surrogate_evals = 0
    for _ in tqdm(
        range(state.max_samples // config["batch_size"]),
        desc=f"{task.task_name}: {optim.optimizer_name} | {args[0]}"
    ):
        if state.has_converged:
            break

        if not state.designs:
            # First round: seed the history with initial designs.
            new_xq, metadata = initialize_designs(
                task, config["batch_size"], seed=config["seed"]
            )
        else:
            new_xq, metadata = optim(state=state, knowledge=knowledge)

        if not new_xq:
            # Nothing proposed this round; spend the iteration and retry.
            continue
        full_xq = task.extend(new_xq, state.individual)
        new_yq, *_ = forward_model(full_xq)
        num_surrogate_evals += len(new_xq)
        state.log(new_xq, new_yq, **metadata)

        optim.fit(state.designs, state.predictions)
        forward_model.fit(reference_designs, state.designs)

    if not state.designs:
        # No design was ever evaluated (e.g. `max_samples < batch_size`, or
        # every proposed batch was empty). Report failure instead of
        # indexing into an empty history.
        return OptimizeResult(
            x=None,
            success=False,
            fun=float("nan"),
            func=float("nan"),
            nfev=0,
            nit=num_surrogate_evals,
            message="No designs were proposed by the optimizer."
        )

    designs = task.extend(state.xq, ref)
    # Index (as a length-1 array) of the highest-scoring logged design.
    idxs = np.argsort(state.yq.reshape(-1))[-1:]
    # A single call to the oracle on the single best design.
    scores = np.array(state.task.predict([designs[i] for i in idxs]))
    best_value = float(scores.item())
    return OptimizeResult(
        x=state.xq[idxs[0]],
        success=True,
        fun=best_value,
        func=best_value,  # Legacy alias kept for backward compatibility.
        nfev=1,
        nit=num_surrogate_evals
    )
