#!/usr/bin/python3
"""
Implements the standard quasi-Expected Improvement (qEI) generative policy.

Author(s):
    Anonymized Authors @anonymized-authors

Citation(s):
    [1] Jones DR, Schonlau M, Welch WJ. Efficient global optimization of
        expensive black-box functions. J Glob Opt 13:455-92. (1998). doi:
        10.1023/A:1008306431147

Portions of this code are adapted from the DynAMO repository by @michael-s-yao
at https://github.com/michael-s-yao/DynAMO/blob/main/src/dynamo/optim/qei.py

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import numpy as np
import torch
from botorch import fit_gpytorch_mll  # type: ignore
from botorch.acquisition import qLogExpectedImprovement  # type: ignore
from botorch.models import SingleTaskGP  # type: ignore
from botorch.models.transforms.outcome import Standardize  # type: ignore
from botorch.optim import optimize_acqf  # type: ignore
from botorch.optim.optimize import (   # type: ignore
    optimize_acqf_mixed, optimize_acqf_discrete_local_search
)
from botorch.utils.transforms import normalize, unnormalize  # type: ignore
from gpytorch.mlls import ExactMarginalLogLikelihood  # type: ignore
from itertools import product
from pydantic import BaseModel
from typing import Any, Dict, Final, List, Optional, Tuple

from .base import BaseOptimizer
from .state import OptimizerState
from ..envs.base import BaseTask
from ..utils import Pydantic2Rd, Rd2Pydantic


class BOqEIOptimizer(BaseOptimizer):
    """
    Bayesian optimization policy that fits a `SingleTaskGP` surrogate and
    proposes batches of designs by optimizing `qLogExpectedImprovement`.

    Usage: call `fit()` with the designs evaluated so far, then `forward()`
    to obtain the next batch. `forward()` reads `self.acqf`, which only
    exists after `fit()` has been called.

    Design-space convention used throughout this class: the GP sees the
    continuous dimensions normalized to [0, 1] using `task.sampling_bounds`,
    while categorical dimensions (always the leading columns of a design
    vector) are kept at their raw values.
    """
    optimizer_name: str = "BO-qEI"

    def __init__(
        self,
        task: BaseTask,
        batch_size: int,
        seed: int = 2025,
        num_restarts: int = 10,
        raw_samples: int = 512,
        categorical_bounds: Optional[List[List[int]]] = None,
        **kwargs: Any
    ):
        """
        Args:
            task: the optimization task.
            batch_size: batch size to use for Bayesian sampling per iteration.
            seed: random seed. Default 2025.
            num_restarts: the number of starting points for multistart
                acquisition function optimization. Default 10.
            raw_samples: the number of samples for initialization. Default 512.
            categorical_bounds: the bounds for each categorical dimension.
                Defaults to no categorical dimensions. The rows are stacked
                into a single tensor, so they must all have equal length.
            **kwargs: accepted for interface compatibility and ignored.
        """
        del kwargs
        # `None` stands in for "no categorical dimensions" so that the
        # signature does not carry a mutable default argument.
        if categorical_bounds is None:
            categorical_bounds = []
        super().__init__(task, batch_size, seed=seed)
        self.sampling_bounds: torch.Tensor = task.sampling_bounds
        # Unit-cube bounds with the same shape as the sampling bounds:
        # row 0 holds the lower bounds (0), the remaining row(s) hold 1.
        self.normalized_bounds: torch.Tensor = torch.ones_like(
            self.sampling_bounds
        )
        self.normalized_bounds[0, :] = 0
        self.num_restarts: Final[int] = num_restarts
        self.raw_samples: Final[int] = raw_samples
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.continuous: Final[bool] = bool(
            self.task.task_name in ["IWPCWarfarin-v0"]
        )

        # The categorical dimensions should always be at the beginning of
        # the design vectors.
        self.categorical_dims: Final[torch.Tensor] = torch.tensor(
            list(range(len(categorical_bounds))), device=self.device
        )
        self.categorical_bounds: Final[torch.Tensor] = torch.tensor(
            categorical_bounds, device=self.device
        )
        # True when every design dimension is categorical.
        self.all_discrete: Final[bool] = (
            self.categorical_dims.numel() >= self.sampling_bounds.size(dim=-1)
        )
        if self.all_discrete:
            # Iterating a tensor yields its rows as tensors, so each row is
            # used directly as the choice set of one dimension.
            self.discrete_choices: List[torch.Tensor] = [
                dim_choices.detach() for dim_choices in self.categorical_bounds
            ]
        else:
            self.fixed_features_list: Final[List[Dict[int, Any]]] = (
                self.generate_feature_dicts()
            )
        torch.manual_seed(self.seed)

    def forward(
        self,
        state: OptimizerState,
        knowledge: Dict[str, str],
        options: Optional[Dict[str, Any]] = {"batch_limit": 5, "maxiter": 200},
        **kwargs: Any
    ) -> Tuple[List[BaseModel], Dict[str, int]]:
        """
        Optimizes the quasi-Expected Improvement (qEI) acquisition function and
        returns a new batch of candidates to evaluate. `fit()` must have been
        called beforehand so that `self.acqf` exists.
        Input:
            state: the current optimizer state (unused).
            knowledge: the prior knowledge to use for the optimization
                (unused).
            options: optional keyword arguments for the acquisition function
                optimization call. Not used in the all-discrete case.
        Returns:
            A batch of candidates to evaluate of shape B, where B is the
            sampling batch size, and a dictionary of metadata (always empty).
        """
        del state, knowledge, kwargs
        # Hand BoTorch a private copy so that the shared default dictionary
        # (or the caller's dictionary) can never be modified by the call.
        options = dict(options) if options is not None else None

        if self.categorical_dims.numel() == 0:
            # Purely continuous search over the unit cube.
            candidates, _ = optimize_acqf(
                acq_function=self.acqf,
                bounds=self.normalized_bounds,
                q=self.batch_size,
                num_restarts=self.num_restarts,
                raw_samples=self.raw_samples,
                options=options
            )
            candidates = unnormalize(candidates, self.sampling_bounds)
            np_candidates = candidates.detach().cpu().numpy()
            # NOTE(review): this branch hard-codes continuous=True whereas
            # fit() encodes designs with continuous=self.continuous; confirm
            # this is intended for tasks without categorical dimensions.
            return Rd2Pydantic(self.task, np_candidates, continuous=True), {}

        if self.all_discrete:
            # Purely categorical search; the candidates are raw choice values
            # and therefore are not unnormalized.
            candidates, _ = optimize_acqf_discrete_local_search(
                acq_function=self.acqf,
                discrete_choices=self.discrete_choices,
                q=self.batch_size,
                num_restarts=self.num_restarts,
                raw_samples=self.raw_samples,
                unique=True
            )
            np_candidates = candidates.detach().cpu().numpy()
            return Rd2Pydantic(self.task, np_candidates, continuous=False), {}

        # Mixed search: one continuous optimization per combination of
        # categorical values in `self.fixed_features_list`.
        candidates, _ = optimize_acqf_mixed(
            acq_function=self.acqf,
            bounds=self.normalized_bounds,
            q=self.batch_size,
            num_restarts=self.num_restarts,
            raw_samples=self.raw_samples,
            fixed_features_list=self.fixed_features_list,
            options=options
        )

        cat_dims = self.categorical_dims
        # Plain Python set for membership tests instead of comparing each
        # index against the (possibly on-GPU) tensor.
        cat_index_set = set(cat_dims.tolist())
        cont_dims = [
            d
            for d in range(self.sampling_bounds.size(dim=-1))
            if d not in cat_index_set
        ]

        # Only the continuous columns live in the unit cube; map them back to
        # the task's sampling bounds.
        cont_bounds = self.sampling_bounds[:, cont_dims]
        cont_unnorm = unnormalize(candidates[..., cont_dims], cont_bounds)

        # Categorical columns already hold raw values; round them.
        disc = candidates[..., cat_dims].round()

        candidates = torch.zeros_like(candidates)
        candidates[:, cont_dims] = cont_unnorm
        candidates[:, cat_dims] = disc.float()
        x = Rd2Pydantic(
            self.task,
            candidates.detach().cpu().numpy(),
            continuous=self.continuous
        )
        return x, {}

    def fit(self, X: List[BaseModel], y: np.ndarray) -> None:
        """
        Fits a GP surrogate model to all evaluated designs and rebuilds the
        acquisition function from the updated posterior. The results are
        stored on `self.model` and `self.acqf`.
        Input:
            X: a list of all N prior evaluated designs. They are converted to
                a numeric array with `Pydantic2Rd`.
            y: an array of shape N of all objective evaluations, where N is the
                number of designs.
        Returns:
            None.
        """
        Xval = torch.from_numpy(
            Pydantic2Rd(self.task, X, continuous=self.continuous)
        )
        # Cast to double *before* aligning the bounds with Xval, so that the
        # bounds, the training data, and hence the GP all share one dtype.
        Xval = Xval.to(device=self.device, dtype=torch.double)
        if Xval.ndim == 1:
            # Treat a flat vector as N designs with a single dimension.
            Xval = Xval.unsqueeze(dim=-1)
        self.normalized_bounds = self.normalized_bounds.to(Xval)
        self.sampling_bounds = self.sampling_bounds.to(Xval)
        yval = torch.from_numpy(y).to(Xval).unsqueeze(dim=-1)
        self.device = Xval.device

        # Normalize exactly once, starting from the raw designs. Normalizing
        # an already normalized tensor would disagree with forward(), which
        # unnormalizes the continuous columns only once.
        z = normalize(Xval, self.sampling_bounds).detach()
        if self.categorical_dims.numel() > 0:
            # Categorical columns keep their raw values, matching the raw
            # values forward() searches over for those columns.
            z[:, self.categorical_dims] = Xval[:, self.categorical_dims]
        self.model = SingleTaskGP(z, yval, outcome_transform=Standardize(m=1))
        self.model = self.model.to(self.device)
        mll = ExactMarginalLogLikelihood(self.model.likelihood, self.model)
        fit_gpytorch_mll(mll)

        # The incumbent is the best objective value observed so far.
        self.acqf = qLogExpectedImprovement(
            model=self.model, best_f=yval.max()
        )

    def generate_feature_dicts(self) -> List[Dict[int, Any]]:
        """
        Generate all possible feature dictionaries given a list of possible
        discrete dimension values.
        Input:
            None.
        Returns:
            A list of dictionaries mapping feature index to chosen value for
            every possible combination (the Cartesian product of the rows of
            `self.categorical_bounds`). With no categorical dimensions the
            result is a single empty dictionary.
        """
        return [
            dict(enumerate(combo))
            for combo in product(*self.categorical_bounds)
        ]
