# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import Dict, List, Optional
import numpy as np
import logging

from syne_tune.optimizer.schedulers.searchers.bayesopt.models.model_transformer \
    import TransformerModelFactory
from syne_tune.optimizer.schedulers.searchers.bayesopt.models.model_base \
    import BaseSurrogateModel
from syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.config_ext \
    import ExtendedConfiguration
from syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.tuning_job_state \
    import TuningJobState
from syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.learncurve.gpiss_model \
    import GaussianProcessLearningCurveModel
from syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.learncurve.issm \
    import prepare_data, prepare_data_with_pending
from syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.learncurve.posterior_state \
    import GaussProcAdditivePosteriorState
from syne_tune.optimizer.schedulers.searchers.bayesopt.tuning_algorithms.base_classes \
    import SurrogateModel
from syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.common import \
    ConfigurationFilter, FantasizedPendingEvaluation
from syne_tune.optimizer.schedulers.searchers.bayesopt.utils.debug_log \
    import DebugLogPrinter
from syne_tune.optimizer.schedulers.utils.simple_profiler \
    import SimpleProfiler

logger = logging.getLogger(__name__)


class GaussProcAdditiveSurrogateModel(BaseSurrogateModel):
    def __init__(
            self, state: TuningJobState,
            gpmodel: GaussianProcessLearningCurveModel,
            fantasy_samples: List[FantasizedPendingEvaluation],
            active_metric: str,
            filter_observed_data: Optional[ConfigurationFilter] = None,
            normalize_mean: float = 0.0, normalize_std: float = 1.0):
        """
        Surrogate model wrapping a Gaussian process additive learning curve
        model, whose parameters are fit by marginal likelihood maximization.

        Note: `predict_mean_current_candidates` calls `predict` for all
        observed and pending extended configs. This may not be exactly
        correct, because `predict` is not meant to be used for configs
        which have observations (it IS correct at r = r_max).

        `fantasy_samples` holds the sampled (normalized) target values for
        pending configs. Only target values for `active_metric` are
        considered, and the target values of a pending config form a flat
        vector.

        :param state: TuningJobSubState
        :param gpmodel: GaussianProcessLearningCurveModel
        :param fantasy_samples: See above
        :param active_metric: See parent class
        :param filter_observed_data: See parent class
        :param normalize_mean: Mean which was used to normalize targets
        :param normalize_std: Stddev which was used to normalize targets

        """
        super().__init__(state, active_metric, filter_observed_data)
        self._gpmodel = gpmodel
        self.fantasy_samples = fantasy_samples
        # Normalization constants, needed to map predictions back to the
        # original scale of the targets
        self.mean = normalize_mean
        self.std = normalize_std

    def predict(self, inputs: np.ndarray) -> List[Dict[str, np.ndarray]]:
        """
        Computes predictive means and stddevs at `inputs`, whose rows are
        features of extended configs (x, r). The result has one entry for
        every (mean, variance) pair returned by the underlying GP model.

        :param inputs: Input features
        :return: List of dicts with entries 'mean', 'std'
        """
        results = []
        for mean_norm, var_norm in self._gpmodel.predict(inputs):
            num_rows = inputs.shape[0]
            assert mean_norm.shape[0] == num_rows, \
                (mean_norm.shape, inputs.shape)
            assert var_norm.shape == (num_rows,), \
                (var_norm.shape, inputs.shape)
            # The GP works on normalized targets: map means and stddevs
            # back to the original scale
            entry = {
                'mean': mean_norm * self.std + self.mean,
                'std': np.sqrt(var_norm) * self.std}
            results.append(entry)
        return results

    def backward_gradient(
            self, input: np.ndarray,
            head_gradients: List[Dict[str, np.ndarray]]) -> List[np.ndarray]:
        """
        Delegates the gradient computation to each posterior state, pairing
        the i-th posterior state with the i-th entry of `head_gradients`.

        :param input: Input passed on to the posterior states
        :param head_gradients: One entry per posterior state
        :return: List of results, one per posterior state
        """
        states = self.posterior_states
        assert states is not None, \
            "Cannot run backward_gradient without a posterior state"
        num_states = len(states)
        num_heads = len(head_gradients)
        assert num_states == num_heads, \
            "len(posterior_states) = {} != {} = len(head_gradients)".format(
                num_states, num_heads)
        gradients = []
        for state, head_gradient in zip(states, head_gradients):
            gradients.append(
                state.backward_gradient(
                    input, head_gradient, self.mean, self.std))
        return gradients

    def does_mcmc(self):
        """
        :return: Always False
        """
        return False

    @property
    def posterior_states(self) -> Optional[List[GaussProcAdditivePosteriorState]]:
        """
        :return: Posterior states currently held by the underlying GP model
        """
        return self._gpmodel.states


class GaussProcAdditiveModelFactory(TransformerModelFactory):
    def __init__(
            self, gpmodel: GaussianProcessLearningCurveModel,
            num_fantasy_samples: int,
            active_metric: str,
            configspace_ext: ExtendedConfiguration,
            normalize_targets: bool = False,
            profiler: Optional[SimpleProfiler] = None,
            debug_log: Optional[DebugLogPrinter] = None,
            filter_observed_data: Optional[ConfigurationFilter] = None):
        """
        If `num_fantasy_samples > 0`, we draw this many fantasy targets
        independently, while each sample is dependent over all pending
        evaluations. If `num_fantasy_samples == 0`, pending evaluations
        in `state` are ignored.

        :param gpmodel: GaussianProcessLearningCurveModel
        :param num_fantasy_samples: See above
        :param active_metric: Name of the metric to optimize.
        :param configspace_ext: ExtendedConfiguration
        :param normalize_targets: Normalize observed target values?
        :param profiler: SimpleProfiler (optional), passed to `gpmodel.fit`
        :param debug_log: DebugLogPrinter (optional)
        :param filter_observed_data: Filter for observed data before
            computing incumbent

        """
        self._gpmodel = gpmodel
        self.active_metric = active_metric
        r_min, r_max = configspace_ext.resource_attr_range
        assert 0 < r_min < r_max, \
            f"r_min = {r_min}, r_max = {r_max}: Need 0 < r_min < r_max"
        assert num_fantasy_samples >= 0, \
            f"num_fantasy_samples = {num_fantasy_samples}, must be non-negative int"
        self.num_fantasy_samples = num_fantasy_samples
        self._configspace_ext = configspace_ext
        self._debug_log = debug_log
        self._profiler = profiler
        self._filter_observed_data = filter_observed_data
        self.normalize_targets = normalize_targets

    @property
    def debug_log(self) -> Optional[DebugLogPrinter]:
        return self._debug_log

    @property
    def profiler(self) -> Optional[SimpleProfiler]:
        return self._profiler

    def get_params(self):
        """
        :return: Current parameters of the underlying GP model
        """
        return self._gpmodel.get_params()

    def set_params(self, param_dict):
        """
        :param param_dict: Parameters to be assigned to the underlying GP
            model
        """
        self._gpmodel.set_params(param_dict)

    def _normalization_kwargs(self, data: Dict) -> Dict:
        """
        Extra constructor arguments for `GaussProcAdditiveSurrogateModel`,
        carrying the normalization constants stored in `data`. Empty if
        targets are not normalized.

        :param data: Result of `prepare_data`
        :return: Dict with 'normalize_mean', 'normalize_std', or empty dict
        """
        if not self.normalize_targets:
            return dict()
        return {
            'normalize_mean': data['mean_targets'],
            'normalize_std': data['std_targets']}

    def model(self, state: TuningJobState, fit_params: bool) -> SurrogateModel:
        """
        Creates the surrogate model for `state`. If `fit_params` is True,
        the parameters of the GP model are fit to the observed data first.
        If `state` has pending evaluations and `num_fantasy_samples > 0`,
        fantasy targets are sampled for them, and the posterior state is
        recomputed with these fantasies included.

        :param state: Current tuning job state; must contain at least one
            observation for `active_metric`
        :param fit_params: Fit parameters of the GP model?
        :return: Surrogate model
        """
        assert state.num_observed_cases(self.active_metric) > 0, \
            "Cannot compute posterior: state has no labeled datapoints " +\
            f"for metric {self.active_metric}"
        if self._debug_log is not None:
            self._debug_log.set_state(state)
        do_fantasizing = \
            bool(state.pending_evaluations) and self.num_fantasy_samples > 0

        # [1] Fit model and compute posterior state, ignoring pending evals
        data = prepare_data(
            state, self._configspace_ext, self.active_metric,
            normalize_targets=self.normalize_targets,
            do_fantasizing=False)
        if fit_params:
            logger.info(f"Fitting surrogate model for {self.active_metric}")
            self._gpmodel.fit(data, profiler=self._profiler)
        elif not do_fantasizing:
            # Only if part below is skipped (otherwise, the posterior state
            # is recomputed there anyway)
            logger.info("Recomputing posterior state")
            self._gpmodel.recompute_states(data)
        if self._debug_log is not None:
            self._debug_log.set_model_params(self.get_params())
        # Normalization constants are taken from the data without pending
        # evaluations, so this must happen before `data` is replaced in [2]
        extra_kwargs = self._normalization_kwargs(data)

        # [2] Fantasizing for pending evaluations (optional)
        if do_fantasizing:
            # Sample fantasy values for pending evaluations
            logger.info("Sampling fantasy target values for pending evaluations")
            state_with_fantasies = self._draw_fantasy_values(state)
            fantasy_samples = state_with_fantasies.pending_evaluations
            # Recompute posterior state with fantasy samples
            logger.info("Recomputing posterior state with fantasy targets")
            data = prepare_data(
                state=state_with_fantasies,
                configspace_ext=self._configspace_ext,
                active_metric=self.active_metric,
                normalize_targets=self.normalize_targets,
                do_fantasizing=True)
            self._gpmodel.recompute_states(data)
        else:
            fantasy_samples = []

        return GaussProcAdditiveSurrogateModel(
            state=state,
            gpmodel=self._gpmodel,
            fantasy_samples=fantasy_samples,
            active_metric=self.active_metric,
            filter_observed_data=self._filter_observed_data, **extra_kwargs)

    def model_for_fantasy_samples(
            self, state: TuningJobState,
            fantasy_samples: List[FantasizedPendingEvaluation]) -> SurrogateModel:
        """
        Same as `model` with `fit_params=False`, but `fantasy_samples` are
        passed in, rather than sampled here.

        :param state: See `model`
        :param fantasy_samples: See above
        :return: See `model`

        """
        assert state.num_observed_cases(self.active_metric) > 0, \
            "Cannot compute posterior: state has no labeled datapoints " +\
            f"for metric {self.active_metric}"
        assert state.pending_evaluations and self.num_fantasy_samples > 0

        # Copy of `state`, with pending evaluations replaced by the
        # fantasy samples passed in
        state_with_fantasies = TuningJobState(
            hp_ranges=state.hp_ranges,
            config_for_trial=state.config_for_trial,
            trials_evaluations=state.trials_evaluations,
            failed_trials=state.failed_trials,
            pending_evaluations=fantasy_samples)
        # Recompute posterior state with fantasy samples
        data = prepare_data(
            state=state_with_fantasies,
            configspace_ext=self._configspace_ext,
            active_metric=self.active_metric,
            normalize_targets=self.normalize_targets,
            do_fantasizing=True)
        self._gpmodel.recompute_states(data)
        extra_kwargs = self._normalization_kwargs(data)

        return GaussProcAdditiveSurrogateModel(
            state=state,
            gpmodel=self._gpmodel,
            fantasy_samples=fantasy_samples,
            active_metric=self.active_metric,
            filter_observed_data=self._filter_observed_data, **extra_kwargs)

    def _draw_fantasy_values(
            self, state: TuningJobState) -> TuningJobState:
        """
        Note: Fantasized target values are not de-normalized, because they
        are used internally only (see `prepare_data` with
        `do_fantasizing=True`).

        Side effect: The posterior states of the GP model are recomputed,
        based on data from trials without pending evaluations only.

        :param state: State with pending evaluations without fantasies
        :return: Copy of `state`, where `pending_evaluations` contains
            fantasized target values

        """
        assert self.num_fantasy_samples > 0
        # Fantasies are drawn in sequential chunks, one trial with pending
        # evaluations at a time.
        data_nopending, data_pending = prepare_data_with_pending(
            state=state,
            configspace_ext=self._configspace_ext,
            active_metric=self.active_metric,
            normalize_targets=self.normalize_targets)
        if not data_nopending['configs']:
            # It can happen that all trials with observed data also have
            # pending evaluations. This is possible only at the very start,
            # as long as no trial has been stopped or paused.
            # In this case, we take the first trial in `data_pending` and
            # remove its pending evaluations, so `data_nopending` gets one
            # entry. It is not possible to compute a posterior state without
            # any data, so handling this case correctly would be very
            # tedious.
            # NOTE(review): The first entry is presumably the trial with the
            # largest number of observed targets, which depends on the
            # ordering used by `prepare_data_with_pending` -- TODO confirm
            assert data_pending['configs'], \
                "State is empty, cannot do posterior inference:\n" +\
                str(state)
            names = ('configs', 'targets', 'trial_ids')
            elem = {k: data_pending[k].pop(0) for k in names}
            for k, v in elem.items():
                data_nopending[k] = [v]
            # Features are stored as matrix: the first row moves to
            # `data_nopending`, the remaining rows stay in `data_pending`
            k = 'features'
            all_features = data_pending[k]
            data_nopending[k] = all_features[0].reshape((1, -1))
            data_pending[k] = all_features[1:, :]
            logger.info(
                "All trials currently have pending evaluations. In order to "
                "sample fantasy targets, I'll remove pending evaluations "
                f"from trial_id {elem['trial_ids']} (which has "
                f"{elem['targets'].size} observations)")
        # Start with posterior state, conditioned on data from trials without
        # pending evaluations
        self._gpmodel.recompute_states(data_nopending)
        poster_state_nopending = self._gpmodel.states[0]
        # Loop over trials with pending evaluations: For each trial, we sample
        # fantasy targets given observed ones, then update `poster_state` by
        # conditioning on both. This ensures we obtain a joint sample (the
        # ordering of trials does not matter). For the application here, we
        # do not need the final `poster_state`.
        # `all_fantasy_targets[pos]` collects the samples for the trial at
        # position `pos` in `data_pending`, one entry per fantasy sample
        all_fantasy_targets = []
        for sample_it in range(self.num_fantasy_samples):
            fantasy_targets, _ = poster_state_nopending.sample_and_update_for_pending(
                data_pending, sample_all_nonobserved=False,
                random_state=self._gpmodel.random_state)
            for pos, fantasies in enumerate(fantasy_targets):
                if sample_it == 0:
                    all_fantasy_targets.append([fantasies])
                else:
                    all_fantasy_targets[pos].append(fantasies)
        # Convert into `FantasizedPendingEvaluation`
        r_min = self._configspace_ext.resource_attr_range[0]
        pending_evaluations_with_fantasies = []
        for trial_id, targets, fantasies in zip(
                data_pending['trial_ids'], data_pending['targets'],
                all_fantasy_targets):
            n_observed = targets.size
            n_pending = fantasies[0].size
            # Pending resource levels follow directly after the observed ones
            start = r_min + n_observed
            resources = list(range(start, start + n_pending))
            # One column per fantasy sample, one row per pending resource
            # level. `np.hstack` requires a sequence: passing a generator is
            # deprecated and rejected by recent NumPy versions
            fantasy_matrix = np.hstack(
                [v.reshape((-1, 1)) for v in fantasies])
            assert fantasy_matrix.shape == (n_pending, self.num_fantasy_samples)
            for resource, fantasy in zip(resources, fantasy_matrix):
                pending_evaluations_with_fantasies.append(
                    FantasizedPendingEvaluation(
                        trial_id=trial_id,
                        fantasies={self.active_metric: fantasy},
                        resource=resource))
        # Return new state, with `pending_evaluations` replaced
        return TuningJobState(
            hp_ranges=state.hp_ranges,
            config_for_trial=state.config_for_trial,
            trials_evaluations=state.trials_evaluations,
            failed_trials=state.failed_trials,
            pending_evaluations=pending_evaluations_with_fantasies)
