# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
import time
import xgboost
import logging
import numpy as np

from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.calibration import CalibratedClassifierCV

from typing import Dict

from syne_tune.optimizer.schedulers.searchers.searcher import \
    SearcherWithRandomSeed
from syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.hp_ranges_factory \
    import make_hyperparameter_ranges
from syne_tune.optimizer.schedulers.searchers.bore.de import DifferentialevolutionOptimizer
from syne_tune.optimizer.schedulers.searchers.bore.classififer import GPModel, MLP

# Module-level logger, named after this module (``__name__``).
logger = logging.getLogger(__name__)


class Bore(SearcherWithRandomSeed):
    """Searcher implementing BORE (Bayesian optimization by density-ratio estimation).

    A binary classifier is trained to separate the best ``gamma`` fraction of
    the observed configurations from the rest. Its predicted probability for
    the "good" class serves as acquisition function, which is maximized (by
    random search or differential evolution) to propose new configurations.
    """

    def __init__(
            self, config_space: dict, metric: str,
            points_to_evaluate=None, random_seed_generator=None,
            random_seed=None, mode: str = 'max', gamma: float = 0.25,
            calibrate: bool = False, classifier: str = 'xgboost',
            acq_optimizer: str = 'rs', feval_acq: int = 500,
            random_prob: float = 0.0, init_random: int = 6,
            classifier_kwargs: dict = None, calibration: str = 'sigmoid'):

        """
        Implements "Bayesian optimization by Density Ratio Estimation" as described in the following paper:

        BORE: Bayesian Optimization by Density-Ratio Estimation,
        Tiao, Louis C and Klein, Aaron and Seeger, Matthias W and Bonilla, Edwin V. and Archambeau, Cedric and Ramos, Fabio
        Proceedings of the 38th International Conference on Machine Learning


        Note: Bore only works in the non-parallel, non-multi-fidelity setting. Make sure that you use it with the
        FIFO scheduler and set num_workers to 1 in the backend.

        :param config_space: Configuration space. Constant parameters are filtered out
        :param metric: Name of metric reported by evaluation function.
        :param points_to_evaluate: Initial configurations, forwarded to the base
            searcher and suggested by `get_config` before any random or
            model-based configuration.
        :param random_seed_generator: Forwarded to the base searcher.
        :param random_seed: seed for the random number generator
        :param mode: 'min' if the metric is to be minimized; any other value
            means the metric is maximized.
        :param gamma: Defines the percentile, i.e. what fraction of the configurations is used to model l(x).
        :param calibrate: If set to true, we calibrate the predictions of the classifier via CV
        :param classifier: The binary classifier to model the acquisition function.
            Choices: {'mlp', 'gp', 'xgboost', 'rf', 'logreg'}
        :param acq_optimizer: The optimization method to maximize the acquisition function. Choices: {'de', 'rs'}
        :param feval_acq: Maximum allowed function evaluations of the acquisition function.
        :param random_prob: probability for returning a random configuration (epsilon greedy)
        :param init_random: Number of initial random configurations before we start with the optimization.
        :param classifier_kwargs: Dict with keyword arguments passed to the constructor of the classifier
        :param calibration: Calibration method of CalibratedClassifierCV, used only
            if `calibrate` is true. Choices: {'sigmoid', 'isotonic'}
        :raises AssertionError: if `acq_optimizer` is not one of the choices.
        :raises ValueError: if `classifier` or (with `calibrate`) `calibration`
            is not one of the choices.
        """

        super().__init__(
            config_space, metric=metric, points_to_evaluate=points_to_evaluate,
            random_seed_generator=random_seed_generator,
            random_seed=random_seed)

        assert acq_optimizer in {'rs', 'de'}
        if calibrate and calibration not in {'sigmoid', 'isotonic'}:
            raise ValueError(
                f"calibration must be 'sigmoid' or 'isotonic', got {calibration!r}")

        self.calibrate = calibrate
        self.calibration = calibration
        self.gamma = gamma
        self.classifier = classifier
        self.acq_optimizer = acq_optimizer
        self.feval_acq = feval_acq
        self.init_random = init_random
        self.random_prob = random_prob
        self.mode = mode

        self._hp_ranges = make_hyperparameter_ranges(config_space)

        # Unfitted estimator. `_update` either fits it directly or wraps it in a
        # *fresh* CalibratedClassifierCV, so calibration wrappers never nest.
        self._base_model = self._make_classifier(classifier_kwargs)
        # Model queried by `loss`; replaced by the fitted (possibly calibrated) one.
        self.model = self._base_model
        # True once a model was fitted on data containing both classes.
        self._model_fitted = False

        self.inputs = []
        self.targets = []

    def _make_classifier(self, classifier_kwargs):
        """Build the unfitted classifier selected by `self.classifier`.

        :param classifier_kwargs: Constructor keyword arguments (or None).
        :raises ValueError: if `self.classifier` is not a supported choice.
        """
        kwargs = dict(classifier_kwargs or {})
        if self.classifier == 'xgboost':
            # User-supplied kwargs may override `use_label_encoder`.
            return xgboost.XGBClassifier(**{'use_label_encoder': False, **kwargs})
        if self.classifier == 'logreg':
            return LogisticRegression(**kwargs)
        if self.classifier == 'rf':
            return RandomForestClassifier(**kwargs)
        if self.classifier == 'gp':
            return GPModel(**kwargs)
        if self.classifier == 'mlp':
            return MLP(n_inputs=self._hp_ranges.ndarray_size(), **kwargs)
        raise ValueError(
            f"Unknown classifier {self.classifier!r}; choose one of "
            "'xgboost', 'logreg', 'rf', 'gp', 'mlp'")

    def configure_scheduler(self, scheduler):
        """Check that the searcher is driven by a FIFOScheduler, then delegate to the base class."""
        from syne_tune.optimizer.schedulers.fifo import FIFOScheduler

        assert isinstance(scheduler, FIFOScheduler), \
            "This searcher requires FIFOScheduler scheduler"

        super().configure_scheduler(scheduler)

    def loss(self, x):
        """Return the negative "good class" probability predicted for `x`.

        :param x: A single encoded configuration (1-d array) or a batch of them
            (2-d array, one row per configuration).
        :return: 1-d array with one loss value per configuration; lower is better.
        """
        batch = x[None, :] if x.ndim < 2 else x
        y = - self.model.predict_proba(batch)
        if self.classifier in ['gp', 'mlp']:
            # The custom GP/MLP classifiers are read from column 0 (presumably a
            # single probability column -- verify against classififer.py).
            return y[:, 0]
        return y[:, 1]  # return probability of class 1

    def get_config(self, **kwargs):
        """Function to sample a new configuration

        This function is called inside TaskScheduler to query a new
        configuration. Initial configurations from `points_to_evaluate`
        (via `_next_initial_config`) are returned first.

        Args:
        kwargs:
            Extra information may be passed from scheduler to searcher
        returns: config
            must return a valid configuration
        """

        start_time = time.time()

        config = self._next_initial_config()
        if config is None:
            if self._should_sample_randomly():
                config = self._hp_ranges.random_config(self.random_state)
            elif self.acq_optimizer == 'de':
                config = self._optimize_acquisition_de()
            else:
                config = self._optimize_acquisition_rs()

        opt_time = time.time() - start_time
        logger.debug("[Select new candidate: config=%s] optimization time : %s",
                     config, opt_time)

        return config

    def _should_sample_randomly(self) -> bool:
        """Return True if the next configuration should be drawn uniformly at random."""
        # Without a fitted model (too few observations, or only one class seen
        # so far) the acquisition function cannot be evaluated.
        if not self._model_fitted or len(self.inputs) < self.init_random:
            return True
        # Epsilon-greedy exploration drawn from the seeded searcher RNG; skip the
        # draw entirely when exploration is disabled so the RNG stream is untouched.
        return self.random_prob > 0 and self.random_state.uniform() < self.random_prob

    def _optimize_acquisition_de(self):
        """Minimize `loss` over the encoded search space with differential evolution."""

        def wrapper(x):
            # losses as a column vector, shape (n, 1)
            return self.loss(x)[:, None]

        bounds = np.array(self._hp_ranges.get_ndarray_bounds())
        de = DifferentialevolutionOptimizer(
            wrapper, bounds[:, 0], bounds[:, 1], self.feval_acq)
        best, _ = de.run()
        return self._hp_ranges.from_ndarray(best)

    def _optimize_acquisition_rs(self):
        """Score up to `feval_acq` distinct random configurations and return the best."""
        candidates = []
        retries = 0
        # Sample without replacement; stop after 10 consecutive duplicate draws
        # (the space is then probably close to exhausted).
        while len(candidates) < self.feval_acq and retries < 10:
            candidate = self._hp_ranges.random_config(self.random_state)
            if candidate in candidates:
                logger.warning("Re-sampled the same configuration. Retry...")
                retries += 1
            else:
                candidates.append(candidate)
                retries = 0

        if not candidates:  # e.g. feval_acq <= 0
            return self._hp_ranges.random_config(self.random_state)

        # Score all candidates with a single batched model call.
        encoded = np.array([self._hp_ranges.to_ndarray(c) for c in candidates])
        values = self.loss(encoded)
        return candidates[int(np.argmin(values))]

    def _update(self, trial_id: str, config: Dict, result: Dict):
        """Update surrogate model with result

        Records the observation and, once at least `init_random` observations
        exist, refits the classifier on labels "metric better than the
        `gamma`-quantile". Fitting is skipped while only one class is present.

        :param trial_id: ID of the trial the result belongs to (unused)
        :param config: new configuration
        :param result: observed results from the train function
        """

        start_time = time.time()

        self.inputs.append(self._hp_ranges.to_ndarray(config))
        self.targets.append(result[self._metric])

        if len(self.inputs) < self.init_random:
            return

        X = np.array(self.inputs)
        y = np.array(self.targets)
        if self.mode != 'min':
            y = -y  # flip sign so that smaller is always better

        tau = np.quantile(y, q=self.gamma)
        z = np.less(y, tau)

        if np.unique(z).size < 2:
            # All targets fall on one side of tau (e.g. ties); most classifiers
            # cannot be fitted on a single class, so keep the previous model.
            logger.debug("Skipping model fit: only one class among %d observations",
                         X.shape[0])
            return

        labels = z.astype(int)
        if self.calibrate:
            model = CalibratedClassifierCV(
                self._base_model, cv=2, method=self.calibration)
        else:
            model = self._base_model
        model.fit(X, labels)
        self.model = model
        self._model_fitted = True

        accuracy = np.mean(model.predict(X) == labels)

        train_time = time.time() - start_time
        logger.debug("[Model fit: accuracy=%.3f] dataset size: %d, train time : %s",
                     accuracy, X.shape[0], train_time)

    def clone_from_state(self, state):
        """Not implemented: ignores `state` and returns None."""
        pass
