# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
import numpy as np

from syne_tune.optimizer.schedulers.searchers.bayesopt.models.gp_model \
    import get_internal_candidate_evaluations
from syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.common import dictionarize_objective, \
    INTERNAL_METRIC_NAME
from syne_tune.optimizer.schedulers.searchers.bayesopt.utils.test_objects \
    import dimensionality_and_warping_ranges, create_tuning_job_state
from syne_tune.search_space import uniform, randint, choice, loguniform
from syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.hp_ranges_factory \
    import make_hyperparameter_ranges


def test_get_internal_candidate_evaluations():
    """Check shapes and target normalization of the internal evaluations.

    The case without any evaluations is deliberately not covered: it is
    assumed that some evaluations are always generated at the beginning of
    the BO loop.
    """
    config_space = {
        'a': randint(0, 10),
        'b': uniform(0.0, 10.0),
        'c': choice(['X', 'Y']),
    }
    hp_ranges = make_hyperparameter_ranges(config_space)
    cand_tuples = [(2, 3.3, 'X'), (1, 9.9, 'Y'), (7, 6.1, 'X')]
    metrics = list(map(dictionarize_objective, (5.3, 10.9, 13.1)))

    state = create_tuning_job_state(
        hp_ranges=hp_ranges,
        cand_tuples=cand_tuples,
        metrics=metrics,
    )
    # Trial '0' has an observation, but is additionally marked as failed.
    # The checks below still expect one row per candidate.
    state.failed_trials.append('0')

    result = get_internal_candidate_evaluations(
        state,
        INTERNAL_METRIC_NAME,
        normalize_targets=True,
        num_fantasy_samples=20,
    )

    # Both features and targets have to be two-dimensional
    feature_shape = result.features.shape
    assert len(feature_shape) == 2, "Input should be a matrix"
    target_shape = result.targets.shape
    assert len(target_shape) == 2, "Output should be a matrix"

    # One feature row per candidate, one target value per row
    assert feature_shape[0] == len(cand_tuples)
    assert target_shape[-1] == 1, \
        "Only single output value per row is suppored"

    # normalize_targets=True: targets should have zero mean and unit std
    targets_mean = np.mean(result.targets)
    assert np.abs(targets_mean) < 1e-8, \
        "Mean of the normalized outputs is not 0.0"
    targets_std = np.std(result.targets)
    assert np.abs(targets_std - 1.0) < 1e-8, \
        "Std. of the normalized outputs is not 1.0"

    # Statistics of the raw metric values (5.3, 10.9, 13.1)
    expected_mean = 9.766666666666666
    expected_std = 3.283629428273267
    np.testing.assert_almost_equal(result.mean, expected_mean)
    np.testing.assert_almost_equal(result.std, expected_std)


def test_dimensionality_and_warping_ranges():
    """Check encoded dimensionality and warping ranges of a mixed space.

    NOTE(review): the expected values are consistent with categorical
    parameters being one-hot encoded (2 + 1 + 3 + 1 + 2 = 9 dimensions),
    which would put the numerical parameters 'b' and 'd' at positions 2
    and 6 -- confirm against ``dimensionality_and_warping_ranges``.
    """
    config_space = {
        'a': choice(['X', 'Y']),
        'b': loguniform(0.1, 10.0),
        'c': choice(['a', 'b', 'c']),
        'd': uniform(0.0, 10.0),
        'e': choice(['X', 'Y']),
    }
    hp_ranges = make_hyperparameter_ranges(config_space)

    expected_dim = 9
    expected_warping_ranges = {2: (0.0, 1.0), 6: (0.0, 1.0)}

    dim, warping_ranges = dimensionality_and_warping_ranges(hp_ranges)
    assert dim == expected_dim
    assert warping_ranges == expected_warping_ranges
