# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
import numpy as np

from syne_tune.search_space import uniform, randint, choice
from syne_tune.optimizer.schedulers.searchers.gp_searcher_factory import \
    gp_fifo_searcher_defaults, gp_fifo_searcher_factory
from syne_tune.optimizer.schedulers.searchers.utils.default_arguments \
    import check_and_merge_defaults
from syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.kernel \
    import Matern52, ProductKernelFunction


def test_create_transfer_learning():
    config_space = {
        'task_id': choice(['0', '1', '2', '3']),
        'a': uniform(lower=0.0, upper=2.0),
        'b': randint(lower=2, upper=5),
        'c': choice(['a', 'b', 'c']),
    }
    active_config_space = {
        'a': uniform(lower=0.2, upper=0.8),
        'b': randint(lower=2, upper=4),
        'c': choice(['a', 'c']),
    }
    search_options = {
        'scheduler': 'fifo',
        'configspace': config_space,
        'transfer_learning_task_attr': 'task_id',
        'transfer_learning_active_task': '2',
        'transfer_learning_active_config_space': active_config_space,
    }
    kwargs = check_and_merge_defaults(
        search_options, *gp_fifo_searcher_defaults(),
        dict_name='search_options')
    kwargs_int = gp_fifo_searcher_factory(**kwargs)

    filter_observed_data = kwargs_int.get('filter_observed_data')
    assert filter_observed_data is not None
    config = dict(task_id='2', a=0.0, b=2, c='a')
    assert filter_observed_data(config)
    config = dict(task_id='0', a=0.0, b=2, c='a')
    assert not filter_observed_data(config)

    hp_ranges = kwargs_int['hp_ranges']
    assert hp_ranges.internal_keys == ['task_id', 'a', 'b', 'c']
    assert hp_ranges.ndarray_size() == 9
    expected_ndarray_bounds = [
        (0.0, 0.0), (0.0, 0.0), (1.0, 1.0), (0.0, 0.0),
        (0.1, 0.4),
        (0.0, 0.75),
        (0.0, 1.0), (0.0, 0.0), (0.0, 1.0)]
    ndarray_bounds = hp_ranges.get_ndarray_bounds()
    mat = np.array([[x[i] for x in ndarray_bounds] for i in range(2)])
    expected_mat = np.array([[x[i] for x in expected_ndarray_bounds]
                             for i in range(2)])
    np.testing.assert_almost_equal(mat, expected_mat)

    kernel = kwargs_int['model_factory']._gpmodel.likelihood.kernel
    assert isinstance(kernel, ProductKernelFunction)
    kernel1 = kernel.kernel1
    assert isinstance(kernel1, Matern52)
    assert kernel1.dimension == 4
    assert not kernel1.ARD
    kernel2 = kernel.kernel2
    assert isinstance(kernel2, Matern52)
    assert kernel2.dimension == 5
    assert kernel2.ARD
