#! /usr/bin/env python3

# Similar to: https://github.com/pytorch/botorch/blob/main/botorch/optim/initializers.py

from __future__ import annotations

import torch
from torch import Tensor
from botorch.utils.sampling import boltzmann_sample
from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction
from botorch.optim.initializers import gen_batch_initial_conditions

from rescue.acquisition.causal_knowledge_gradient import (
    qCausalHypervolumeKnowledgeGradient,
    qMultiFidelityCausalHypervolumeKnowledgeGradient,
    causal_hv_value_function
)

def gen_one_shot_hvkg_initial_conditions(
    acq_function: qCausalHypervolumeKnowledgeGradient,
    bounds: Tensor,
    q: int,
    num_restarts: int,
    raw_samples: int,
    fixed_features: dict[int, float] | None = None,
    options: dict[str, bool | float | int] | None = None,
    inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
    equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
) -> Tensor:
    r"""Generate a batch of smart initializations for qCausalHypervolumeKnowledgeGradient.

    This function generates initial conditions for optimizing one-shot causal HVKG
    using hypervolume maximizing sets (of fixed size ``acq_function.num_pareto``)
    under the posterior mean. Intuitively, the hypervolume maximizing set of the
    fantasized posterior mean will often be close to a hypervolume maximizing set
    under the current posterior mean, so such sets make good starting points for
    the fantasy points.

    Specifically, the value function returned by ``causal_hv_value_function``
    (evaluated at the target fidelities for the multi-fidelity variant) is
    maximized with ``num_inner_restarts`` restarts. The resulting sets are then
    Boltzmann-sampled (with replacement, temperature ``eta``) to provide the
    fantasy points of roughly a ``1 - frac_random`` fraction of the restarts. For
    those restarts, the ``q`` candidate points are chosen by the standard
    heuristic in ``gen_batch_initial_conditions`` with the fantasy points held
    fixed. The remaining restarts (fantasy points and candidates alike) come
    entirely from the standard heuristic.

    Args:
        acq_function: The qCausalHypervolumeKnowledgeGradient (or
            qMultiFidelityCausalHypervolumeKnowledgeGradient) instance to be
            optimized.
        bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`.
        q: The number of candidates to consider.
        num_restarts: The number of starting points for multistart acquisition
            function optimization. Must be at least 1.
        raw_samples: The number of raw samples to consider in the initialization
            heuristic.
        fixed_features: A map `{feature_index: value}` for features that
            should be fixed to a particular value during generation.
        options: Options for initial condition generation. These contain all
            settings for the standard heuristic initialization from
            `gen_batch_initial_conditions` and are also forwarded to the inner
            `optimize_acqf` call. In addition, they may contain
            `frac_random` (the fraction of fully random restarts, default 0.1,
            must lie in (0, 1)), `num_inner_restarts` and `raw_inner_samples`
            (the number of restarts and raw samples for maximizing the
            hypervolume value function, defaults 20 and 1024) and `eta`
            (temperature parameter for Boltzmann-sampling the maximizers,
            default 2.0).
        inequality_constraints: Optionally, a list of tuples
            (indices, coefficients, rhs), with each tuple encoding an inequality
            constraint of the form
            `\sum_i (X[indices[i]] * coefficients[i]) >= rhs`. Each
            tensor of indices must be one-dimensional, since inter-point
            constraints are not supported here.
        equality_constraints: Optionally, a list of tuples
            (indices, coefficients, rhs), with each tuple encoding an equality
            constraint of the form
            `\sum_i (X[indices[i]] * coefficients[i]) = rhs`.

    Returns:
        A `num_restarts x q' x d` tensor that can be used as initial conditions
        for `optimize_acqf()`. Here
        `q' = acq_function.get_augmented_q_batch_size(q)`, i.e. the `q`
        candidate points followed by `acq_function.num_pseudo_points` fantasy
        points.

    Raises:
        ValueError: If `frac_random` is not in (0, 1) or `num_restarts < 1`.
    """
    options = options or {}
    frac_random: float = options.get("frac_random", 0.1)
    if not 0 < frac_random < 1:
        raise ValueError(
            f"frac_random must take on values in (0,1). Value: {frac_random}"
        )
    if num_restarts < 1:
        # Without at least one restart there is nothing to return; fail before
        # running any (expensive) optimization.
        raise ValueError(
            f"num_restarts must be a positive integer. Value: {num_restarts}"
        )

    # Function-scope import, as in the upstream BoTorch implementation.
    from botorch.optim.optimize import optimize_acqf

    d = bounds.shape[-1]
    num_optim_restarts = int(round(num_restarts * (1 - frac_random)))
    num_random_restarts = num_restarts - num_optim_restarts

    # Fantasy points taken from sampled hypervolume maximizing sets under the
    # current posterior. Only computed when at least one restart will use them,
    # since the inner optimization is expensive.
    optim_ics: Tensor | None = None
    if num_optim_restarts > 0:
        value_function = causal_hv_value_function(
            model=acq_function.model,
            ref_point=acq_function.ref_point,
            objective=acq_function.objective,
            sampler=acq_function.inner_sampler,
            use_posterior_mean=acq_function.use_posterior_mean,
        )
        is_mf_hvkg = isinstance(
            acq_function, qMultiFidelityCausalHypervolumeKnowledgeGradient
        )
        if is_mf_hvkg:
            # Pin the fidelity columns to their targets and optimize only over
            # the remaining columns.
            fidelity_dims, fidelity_targets = zip(
                *acq_function.target_fidelities.items()
            )
            value_function = FixedFeatureAcquisitionFunction(
                acq_function=value_function,
                d=d,
                columns=fidelity_dims,
                values=fidelity_targets,
            )
            # Sorted explicitly: set iteration order is not guaranteed, and the
            # bounds of the free columns must be in ascending column order to
            # line up with the inputs of the fixed-feature wrapper.
            non_fidelity_dims = sorted(set(range(d)) - set(fidelity_dims))
            inner_bounds = bounds[:, non_fidelity_dims]
        else:
            inner_bounds = bounds

        fantasy_cands, fantasy_vals = optimize_acqf(
            acq_function=value_function,
            bounds=inner_bounds,
            q=acq_function.num_pareto,
            num_restarts=options.get("num_inner_restarts", 20),
            raw_samples=options.get("raw_inner_samples", 1024),
            return_best_only=False,
            options=options,
            # NOTE(review): in the multi-fidelity case these constraints are
            # passed unchanged to an optimization over the reduced
            # (non-fidelity) columns; confirm their indices match that space.
            inequality_constraints=inequality_constraints,
            equality_constraints=equality_constraints,
            sequential=False,
        )
        # One maximizing set per fantasy per optimized restart, drawn with
        # replacement using Boltzmann weights of the inner objective values.
        idx = boltzmann_sample(
            function_values=fantasy_vals,
            num_samples=num_optim_restarts * acq_function.num_fantasies,
            eta=options.get("eta", 2.0),
            replacement=True,
        )
        optim_ics = fantasy_cands[idx]
        if is_mf_hvkg:
            # Re-insert the fixed fidelity columns.
            optim_ics = value_function._construct_X_full(optim_ics)
        optim_ics = optim_ics.reshape(
            num_optim_restarts, acq_function.num_pseudo_points, d
        )

    if num_random_restarts > 0:
        # Standard heuristic for every restart (candidates + fantasy points); a
        # random subset then gets its fantasy points overwritten below.
        base_ics = gen_batch_initial_conditions(
            acq_function=acq_function,
            bounds=bounds,
            q=acq_function.get_augmented_q_batch_size(q=q),
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            fixed_features=fixed_features,
            options=options,
            inequality_constraints=inequality_constraints,
            equality_constraints=equality_constraints,
        )
        if optim_ics is None:
            return base_ics

        probs = torch.full(
            (num_restarts,),
            1.0 / num_restarts,
            dtype=optim_ics.dtype,
            device=optim_ics.device,
        )
        optim_idxr = probs.multinomial(
            num_samples=num_optim_restarts, replacement=False
        )
        base_ics[optim_idxr, q:] = optim_ics
        optim_idcs = set(optim_idxr.view(-1).tolist())
    else:
        # Every restart uses optimized fantasy points (here num_optim_restarts
        # == num_restarts >= 1, so optim_ics is set). optim_ics is
        # num_restarts x num_pseudo_points x d; pad with zeros so base_ics is
        # num_restarts x (q + num_pseudo_points) x d. The padded candidate
        # slots are regenerated below.
        q_padding = torch.zeros(
            optim_ics.shape[0],
            q,
            optim_ics.shape[-1],
            dtype=optim_ics.dtype,
            device=optim_ics.device,
        )
        base_ics = torch.cat([q_padding, optim_ics], dim=-2)
        optim_idcs = set(range(num_restarts))

    all_ics = []
    for i in range(num_restarts):
        if i in optim_idcs:
            # Choose the q candidate points given the fixed, optimized
            # fantasy designs.
            ics = gen_batch_initial_conditions(
                acq_function=acq_function,
                bounds=bounds,
                q=q,
                num_restarts=1,
                raw_samples=raw_samples,
                fixed_features=fixed_features,
                options=options,
                inequality_constraints=inequality_constraints,
                equality_constraints=equality_constraints,
                fixed_X_fantasies=base_ics[i, q:],
            )
        else:
            # Candidates and fantasy points all come from the standard heuristic.
            ics = base_ics[i : i + 1]
        all_ics.append(ics)
    return torch.cat(all_ics, dim=0)