# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
# TODO: This code could eventually be removed. Is it still needed by the unit tests?

"""
Object definitions that are used for testing.
"""

from typing import Iterator, Tuple, Dict, List, Optional, Union
import numpy as np

from syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.common \
    import Hyperparameter, Configuration, dictionarize_objective
from syne_tune.search_space import Categorical, loguniform, randint, \
    choice, uniform
from syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.hp_ranges \
    import HyperparameterRanges
from syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.hp_ranges_factory \
    import make_hyperparameter_ranges
from syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.tuning_job_state \
    import TuningJobState
from syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.common import \
    TrialEvaluations, PendingEvaluation
from syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.constants \
    import MCMCConfig, OptimizationConfig
from syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.gp_regression \
    import GaussianProcessRegression
from syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.gpr_mcmc \
    import GPRegressionMCMC
from syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.kernel \
    import Matern52, KernelFunction
from syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.warping \
    import WarpedKernel, Warping
from syne_tune.optimizer.schedulers.searchers.bayesopt.tuning_algorithms.base_classes \
    import CandidateGenerator
from syne_tune.optimizer.schedulers.searchers.bayesopt.tuning_algorithms.common \
    import ExclusionList


def build_kernel(state: TuningJobState,
                 do_warping: bool = False) -> KernelFunction:
    """
    Creates the covariance kernel for the search space of `state`.

    The base kernel is `Matern52` with `ARD=True` over all encoded input
    dimensions, as counted by `dimensionality_and_warping_ranges`. If
    `do_warping` is True, the base kernel is wrapped in a `WarpedKernel`.
    The warping covers the dimensions returned as warping ranges.

    :param state: Tuning job state whose `hp_ranges` define the inputs
    :param do_warping: Whether to add input warping
    :return: Kernel function
    """
    num_dims, ranges = dimensionality_and_warping_ranges(state.hp_ranges)
    base_kernel = Matern52(num_dims, ARD=True)
    if not do_warping:
        return base_kernel
    return WarpedKernel(
        kernel=base_kernel, warping=Warping(num_dims, ranges))


def default_gpmodel(
        state: TuningJobState, random_seed: int,
        optimization_config: OptimizationConfig) -> GaussianProcessRegression:
    """
    Creates a `GaussianProcessRegression` model for `state`.

    The kernel comes from `build_kernel(state)`, which uses the default of no
    input warping.

    :param state: Tuning job state defining the search space
    :param random_seed: Seed passed to the model
    :param optimization_config: Configuration passed to the model
    :return: GP regression model
    """
    kernel = build_kernel(state)
    return GaussianProcessRegression(
        kernel=kernel, optimization_config=optimization_config,
        random_seed=random_seed)


def default_gpmodel_mcmc(
        state: TuningJobState, random_seed: int,
        mcmc_config: MCMCConfig) -> GPRegressionMCMC:
    """
    Creates a `GPRegressionMCMC` model for `state`.

    The model receives a zero-argument factory rather than a kernel instance.
    Each call of the factory returns a new kernel from
    `build_kernel(state)`, using the default of no input warping.

    :param state: Tuning job state defining the search space
    :param random_seed: Seed passed to the model
    :param mcmc_config: Configuration passed to the model
    :return: MCMC GP regression model
    """
    def kernel_factory():
        return build_kernel(state)

    return GPRegressionMCMC(
        build_kernel=kernel_factory, mcmc_config=mcmc_config,
        random_seed=random_seed)


def dimensionality_and_warping_ranges(hp_ranges: HyperparameterRanges) -> \
        Tuple[int, Dict[int, Tuple[float, float]]]:
    """
    Computes the number of encoded input dimensions and the warping ranges.

    Two configs are encoded with `hp_ranges.to_ndarray`:
    - one holding every non-categorical hyperparameter at its lower bound;
    - one holding every non-categorical hyperparameter at its upper bound.
    Categoricals take their first category in both configs.

    Walking over `hp_ranges.internal_keys`:
    - a non-categorical hyperparameter takes one position;
    - that position gets a warping range `(lower, upper)` if the encoded
      upper value exceeds the encoded lower value;
    - a categorical hyperparameter is counted as `len(categories)`
      positions and never gets a warping range.

    :param hp_ranges: Hyperparameter ranges of the search space
    :return: `(dims, warping_ranges)`. `dims` is the total number of
        positions. `warping_ranges` maps a position to its encoded
        `(lower, upper)` pair.
    """
    space = hp_ranges.config_space

    def _bound_config(attr: str) -> dict:
        # Categoricals have no bounds: plug in the first category so that
        # `to_ndarray` receives a complete config
        return {
            name: domain.categories[0] if isinstance(domain, Categorical)
            else getattr(domain, attr)
            for name, domain in space.items()}

    encoded_lower = hp_ranges.to_ndarray(_bound_config('lower'))
    encoded_upper = hp_ranges.to_ndarray(_bound_config('upper'))
    position = 0
    ranges = dict()
    for name in hp_ranges.internal_keys:
        domain = space[name]
        if isinstance(domain, Categorical):
            position += len(domain.categories)
            continue
        low, high = encoded_lower[position], encoded_upper[position]
        if high > low:
            ranges[position] = (low, high)
        else:
            # Degenerate range (max equal to min): no warping
            assert low == high
        position += 1
    return position, ranges


class RepeatedCandidateGenerator(CandidateGenerator):
    """
    Generates candidates from a fixed set, cycling through it repeatedly.
    Used to test the deduplication logic.

    :param n_unique_candidates: Number of distinct candidates. Candidate `j`
        is `{'a': 1.0 * j, 'b': j, 'c': f"value_{j}"}`.
    """
    def __init__(self, n_unique_candidates: int):
        self.config_space = {
            'a': uniform(0, n_unique_candidates),
            'b': randint(0, n_unique_candidates),
            'c': choice([f"value_{i}" for i in range(n_unique_candidates)])}
        self.hp_ranges = make_hyperparameter_ranges(self.config_space)
        self.all_unique_candidates = [
            {'a': 1.0 * j, 'b': j, 'c': f"value_{j}"}
            for j in range(n_unique_candidates)]

    def generate_candidates(self) -> Iterator[Configuration]:
        """
        Yields the unique candidates in order, starting at index 0, and
        repeats the whole sequence indefinitely.

        If there are no candidates, the generator terminates immediately
        instead of failing with a division by zero.
        """
        # The previous version incremented its counter before indexing,
        # so it started at index 1 and skipped candidate 0 on the first pass
        while self.all_unique_candidates:
            yield from self.all_unique_candidates


class Quadratic3d:
    """
    Example black-box function with an adjustable location of its global
    minimum.

    A simple quadratic function is used. It could catch issues with an
    optimizer, e.g. if the optimizer somehow ignores candidates on the edge
    of the search space.

    The first coordinate is compared on log10 scale, the remaining ones on
    their original scale.
    """
    def __init__(self, local_minima, active_metric, metric_names):
        # Location of the minimum. Entry 0 is given on its original scale and
        # stored on log10 scale, which is how `__call__` compares it.
        # `astype` always copies, so the caller's sequence is not modified.
        minimum = np.asarray(local_minima).astype(float)
        minimum[0] = np.log10(minimum[0])
        self.local_minima = minimum
        # NOTE(review): these two are stored but not used by any method here
        self.active_metric = active_metric
        self.metric_names = metric_names

    @property
    def search_space(self):
        """
        Hyperparameter ranges built from:
        - `x`: loguniform(1.0, 100.0)
        - `y`: randint(0, 2)
        - `z`: choice over the strings '0.0', '1.0', '2.0'

        A new object is created on every access.
        """
        return make_hyperparameter_ranges({
            'x': loguniform(1.0, 100.0),
            'y': randint(0, 2),
            'z': choice(['0.0', '1.0', '2.0'])})

    @property
    def f_min(self):
        """Function value at the minimum, where the squared distance is 0."""
        return 0.0

    def __call__(self, candidate):
        """
        Evaluates the squared Euclidean distance to the minimum.

        The first coordinate of `candidate` is mapped to log10 scale before
        the distance is taken.

        :param candidate: Sequence of values accepted by `float`. Presumably
            ordered as (x, y, z); confirm against callers.
        :return: `dictionarize_objective` applied to the distance
        """
        point = np.array(list(map(float, candidate)))
        point[0] = np.log10(point[0])
        residual = self.local_minima - point
        return dictionarize_objective(np.sum(residual ** 2))


def tuples_to_configs(config_tpls: List[Tuple[Hyperparameter, ...]],
                      hp_ranges: HyperparameterRanges) -> List[Configuration]:
    """
    Converts configs written as tuples, as many unit tests do, into
    configuration dicts.

    :param config_tpls: Configs as tuples of hyperparameter values
    :param hp_ranges: Ranges whose `tuple_to_config` performs the conversion
    :return: New list of configs, in the same order as `config_tpls`
    """
    to_config = hp_ranges.tuple_to_config
    return list(map(to_config, config_tpls))


def create_exclusion_set(
        candidates_tpl, hp_ranges: HyperparameterRanges,
        is_dict: bool = False) -> ExclusionList:
    """
    Creates an `ExclusionList` containing the given candidates.

    The candidates are registered as failed trials of an otherwise empty
    `TuningJobState`, with trial ids "0", "1", .... The exclusion list is
    then built from that state.

    :param candidates_tpl: Candidates, as tuples (default) or config dicts
    :param hp_ranges: Hyperparameter ranges, used to convert tuples
    :param is_dict: If True, `candidates_tpl` already holds config dicts
    :return: Exclusion list built from the state
    """
    if is_dict:
        configs = candidates_tpl
    else:
        configs = tuples_to_configs(candidates_tpl, hp_ranges)
    trial_ids = [str(idx) for idx in range(len(configs))]
    state = TuningJobState(
        hp_ranges=hp_ranges,
        config_for_trial=dict(zip(trial_ids, configs)),
        trials_evaluations=[],
        failed_trials=trial_ids)
    return ExclusionList(state)


# A config as accepted by `create_tuning_job_state`: either a tuple of
# hyperparameter values (converted via `tuples_to_configs`) or a config dict.
TupleOrDict = Union[tuple, dict]


def _configs_from_tuples_or_dicts(
        configs_or_tuples: List[TupleOrDict],
        hp_ranges: HyperparameterRanges) -> List[Configuration]:
    """
    Returns a new list of configs from entries given as tuples or as dicts.

    Whether the entries are tuples is decided by the first entry. Tuples are
    converted via `tuples_to_configs`. Otherwise the entries are taken as
    config dicts and copied into a fresh list, so the caller's list is never
    aliased.
    """
    if configs_or_tuples and isinstance(configs_or_tuples[0], tuple):
        return tuples_to_configs(configs_or_tuples, hp_ranges)
    return list(configs_or_tuples)


def create_tuning_job_state(
        hp_ranges: HyperparameterRanges, cand_tuples: List[TupleOrDict],
        metrics: List[Dict],
        pending_tuples: Optional[List[TupleOrDict]] = None,
        failed_tuples: Optional[List[TupleOrDict]] = None) -> TuningJobState:
    """
    Builds `TuningJobState` from basics, where configs are given as tuples or
    as dicts.

    Trial ids are assigned consecutively as strings in this order:
    1. the entries of `cand_tuples`;
    2. then those of `pending_tuples`;
    3. then those of `failed_tuples`.
    `metrics[i]` becomes the evaluation of trial `str(i)`, so `metrics` is
    expected to line up with `cand_tuples`.

    NOTE: We assume that all configs in the different lists are different!

    :param hp_ranges: Hyperparameter ranges, used to convert tuples
    :param cand_tuples: Configs of evaluated trials
    :param metrics: Metrics dicts, one per entry of `cand_tuples`
    :param pending_tuples: Configs of pending trials (optional)
    :param failed_tuples: Configs of failed trials (optional)
    :return: Resulting state; the input lists are not modified
    """
    # Always a fresh list: extending it with pending/failed configs below
    # must not mutate the caller's `cand_tuples`.
    configs = _configs_from_tuples_or_dicts(cand_tuples, hp_ranges)
    trials_evaluations = [TrialEvaluations(trial_id=str(trial_id), metrics=y)
                          for trial_id, y in enumerate(metrics)]
    pending_evaluations = None
    if pending_tuples is not None:
        first_id = len(configs)
        configs.extend(
            _configs_from_tuples_or_dicts(pending_tuples, hp_ranges))
        pending_evaluations = [PendingEvaluation(trial_id=str(trial_id))
                               for trial_id in range(first_id, len(configs))]
    failed_trials = None
    if failed_tuples is not None:
        first_id = len(configs)
        configs.extend(
            _configs_from_tuples_or_dicts(failed_tuples, hp_ranges))
        failed_trials = [str(trial_id)
                         for trial_id in range(first_id, len(configs))]

    config_for_trial = {
        str(trial_id): config for trial_id, config in enumerate(configs)}
    return TuningJobState(
        hp_ranges=hp_ranges,
        config_for_trial=config_for_trial,
        trials_evaluations=trials_evaluations,
        failed_trials=failed_trials,
        pending_evaluations=pending_evaluations)
