# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import List, Set, Tuple
import pytest

from syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.common \
    import Configuration, dictionarize_objective
from syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.hp_ranges \
    import HyperparameterRanges
from syne_tune.optimizer.schedulers.searchers.bayesopt.tuning_algorithms.common \
    import ExclusionList, generate_unique_candidates
from syne_tune.optimizer.schedulers.searchers.bayesopt.utils.test_objects \
    import RepeatedCandidateGenerator
from syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.hp_ranges_factory \
    import make_hyperparameter_ranges
from syne_tune.search_space import randint, choice
from syne_tune.optimizer.schedulers.searchers.bayesopt.utils.test_objects \
    import create_tuning_job_state, create_exclusion_set


@pytest.fixture(scope='function')
def hp_ranges():
    """Search space used by the exclusion-list tests.

    Holds an integer hyperparameter ``hp1`` drawn from ``randint(0, 200)``
    and a categorical hyperparameter ``hp2`` with values ``'a'``, ``'b'``,
    ``'c'``. A fresh object is built for every test function.
    """
    config_space = {
        'hp1': randint(0, 200),
        'hp2': choice(['a', 'b', 'c']),
    }
    return make_hyperparameter_ranges(config_space)


@pytest.mark.parametrize('observed_data,failed_tuples,pending_tuples,expected', [
    ([], [], [], set()),
    ([((123, 'a'), 9.87)], [], [], {'hp1:123,hp2:0'}),
    ([], [(123, 'a')], [], {'hp1:123,hp2:0'}),
    ([], [], [(123, 'a')], {'hp1:123,hp2:0'}),
    ([((1, 'a'), 9.87)], [(2, 'b')], [(3, 'c')],
     {'hp1:1,hp2:0', 'hp1:2,hp2:1', 'hp1:3,hp2:2'}),
])
def test_compute_blacklisted_candidates(
        hp_ranges: HyperparameterRanges,
        observed_data: List[Tuple],
        failed_tuples: List[Tuple],
        pending_tuples: List[Tuple],
        expected: Set[str]):
    """Observed, failed and pending configs must all end up in the exclusion set.

    ``observed_data`` is a list of ``(config_tuple, metric_value)`` pairs;
    ``expected`` lists the match strings the resulting ``ExclusionList``
    should contain.
    """
    cand_tuples, metrics = [], []
    if observed_data:
        # Split the (config, value) pairs into two parallel sequences and
        # wrap each raw value as an objective dict.
        cand_tuples, raw_values = zip(*observed_data)
        metrics = [dictionarize_objective(value) for value in raw_values]

    state = create_tuning_job_state(
        hp_ranges,
        cand_tuples=cand_tuples,
        metrics=metrics,
        pending_tuples=pending_tuples,
        failed_tuples=failed_tuples)
    exclusion_list = ExclusionList(state)
    assert set(expected) == exclusion_list.excl_set


def _assert_no_duplicates(
        candidates: List[Configuration], hp_ranges: HyperparameterRanges):
    """Fail if two entries of ``candidates`` share a match string.

    Configurations are compared through ``hp_ranges.config_to_match_string``,
    so two configs count as duplicates when they map to the same string.
    """
    seen_match_strings = set()
    for config in candidates:
        seen_match_strings.add(hp_ranges.config_to_match_string(config))
    assert len(candidates) == len(seen_match_strings)


@pytest.mark.parametrize('num_unique_candidates,num_requested_candidates', [
    (5, 10),
    (15, 10)
])
def test_generate_unique_candidates(num_unique_candidates, num_requested_candidates):
    """``generate_unique_candidates`` de-duplicates and honours exclusions.

    ``RepeatedCandidateGenerator(num_unique_candidates)`` is the source of
    candidates. The test runs two phases:

    1. With an empty exclusion set, the result must contain
       ``min(num_unique, num_requested)`` distinct configs.
    2. After excluding half of those configs, the result must shrink
       accordingly, stay duplicate-free, and contain none of the excluded
       configs.
    """
    generator = RepeatedCandidateGenerator(num_unique_candidates)
    hp_ranges = generator.hp_ranges

    # Phase 1: nothing excluded.
    exclusion_candidates = create_exclusion_set([], hp_ranges)
    candidates = generate_unique_candidates(
        candidates_generator=generator,
        num_candidates=num_requested_candidates,
        exclusion_candidates=exclusion_candidates)
    assert len(candidates) == min(num_unique_candidates, num_requested_candidates)
    _assert_no_duplicates(candidates, hp_ranges)

    # Phase 2: exclude the first half of the (already unique) candidates.
    size_excluded = len(candidates) // 2
    excluded = list(candidates)[:size_excluded]
    exclusion_candidates = create_exclusion_set(
        excluded, hp_ranges, is_dict=True)
    candidates = generate_unique_candidates(
        candidates_generator=generator,
        num_candidates=num_requested_candidates,
        exclusion_candidates=exclusion_candidates)

    # The excluded configs are themselves unique (phase 1 checked this), so
    # the pool of available unique candidates shrinks by exactly len(excluded).
    assert len(candidates) == min(num_unique_candidates - len(excluded), num_requested_candidates)
    _assert_no_duplicates(candidates, hp_ranges)

    # A count check alone would not catch an implementation that returns
    # excluded configs in place of others. Check membership explicitly,
    # using the same match strings as the duplicate check.
    excluded_match_strings = {
        hp_ranges.config_to_match_string(config) for config in excluded}
    returned_excluded = [
        config for config in candidates
        if hp_ranges.config_to_match_string(config) in excluded_match_strings]
    assert not returned_excluded, returned_excluded
