#!/usr/bin/python3
"""
Adapter class to expose expected endpoints for running LLAMBO (Large Language
Models to enhance Bayesian Optimization).

Author(s):
    Anonymized Authors @anonymized-authors

Citation(s):
    [1] Liu T, Astorga N, Seedat N, van der Schaar M. Large language models
        to enhance Bayesian optimization. Proc ICLR (2024). URL:
        https://openreview.net/forum?id=OOxotBmGol

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import json
from typing import Any, Dict, Final, List, NamedTuple, Optional, Tuple

from ..data.utils import HIVDB_FEATURES
from ..envs.base import BaseTask
from ..utils import initialize_designs


class LLAMBOWrapperOptimizer:
    """
    Adapter exposing the endpoints used to run LLAMBO on an optimization task.

    The wrapper provides three things: an initial set of designs
    (`generate_initialization`), evaluation of a single candidate design
    (`evaluate_point`), and a task context dictionary (`task_context`).
    """

    def __init__(
        self,
        task: BaseTask,
        ref_individual: NamedTuple,
        seed: int,
        task_context: Optional[Dict[str, Any]] = None
    ):
        """
        Args:
            task: the optimization task.
            ref_individual: the individual to perform optimization for.
            seed: random seed.
            task_context: optional task context to override the default.
        Raises:
            NotImplementedError: if no `task_context` override is given and
                the task has no default context defined.
            KeyError: if the resolved task context has no "model" entry.
        """
        self.seed: Final[int] = seed
        self.task: Final[BaseTask] = task
        self._task_context: Final[Optional[Dict[str, Any]]] = task_context
        # Resolved through the `task_context` property, so an explicit
        # override takes precedence over the task-derived default.
        self.model: Final[str] = str(self.task_context["model"])
        self.ref_individual: Final[NamedTuple] = ref_individual

    def generate_initialization(self, n_samples: int) -> List[Dict[str, Any]]:
        """
        Initialize the set of designs for the optimizer.
        Args:
            n_samples: the number of designs to initialize for the optimizer.
        Returns:
            A list of the initial design configurations.
        """
        init_configs, _ = initialize_designs(
            self.task, n_samples, seed=self.seed
        )
        # Round-trip each design through its JSON dump so that callers
        # receive plain dictionaries rather than model objects.
        return [json.loads(x.model_dump_json()) for x in init_configs]

    def evaluate_point(
        self, candidate_config: Dict[str, Any]
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        Evaluate a single candidate design.
        Args:
            candidate_config: an input dictionary containing the point to be
                evaluated.
        Returns:
            candidate_config: the original input dictionary containing the
                evaluated point.
            fvals: a dictionary containing the evaluation results. Both the
                "score" and "generalization_score" entries hold the first
                element of the score returned by the task.
        """
        if self.task.task_name == "HIVDB-v0":
            # Coerce every value to an integer, except values whose string
            # form contains "adjuvant", which are kept as strings.
            _config = {
                key: (int(val) if "adjuvant" not in str(val) else str(val))
                for key, val in candidate_config.items()
            }
        else:
            _config = candidate_config

        # Validate the configuration against the task's design schema via a
        # JSON round-trip before pairing it with the reference individual.
        design = self.task.design_schema.model_validate_json(
            json.dumps(_config)
        )
        score = self.task(self.task.extend([design], self.ref_individual))
        out = {"score": score[0], "generalization_score": score[0]}
        return candidate_config, out

    @property
    def task_context(self) -> Dict[str, Any]:
        """
        Returns the task context for LLAMBO.

        If a task context was passed to the constructor, it is returned
        as-is. Otherwise a default context is built from the task.
        Returns:
            The task context for LLAMBO.
        Raises:
            NotImplementedError: if no override was provided and the task is
                not one of the tasks with a default context defined.
        """
        if self._task_context is not None:
            return self._task_context

        # Resolve the task-specific constraints first so that unsupported
        # tasks are rejected with a clear message before any other task
        # attribute is touched.
        task_name = self.task.task_name
        if task_name == "IWPCWarfarin-v0":
            constraints = {"warfarin_dose": ["float", "linear", [0, 45]]}
        elif task_name == "HIVDB-v0":
            constraints = {
                x: ["ordinal", None, [0, 1]]
                for x in HIVDB_FEATURES["medications"]
            }
        else:
            raise NotImplementedError(
                f"No default LLAMBO task context is defined for task "
                f"{task_name!r}. Supported tasks are 'IWPCWarfarin-v0' and "
                f"'HIVDB-v0'; otherwise pass `task_context` explicitly."
            )

        num_cat_feats = len(self.task.train[0].discrete_features().keys())
        return {
            "task": self.task,
            "model": self.task._train_model.__class__.__name__,
            "tot_feats": self.task.ndim,
            "cat_feats": num_cat_feats,
            "num_feats": self.task.ndim - num_cat_feats,
            "n_classes": 2,
            "metric": "score",
            "lower_is_better": False,
            "hyperparameter_constraints": constraints
        }
