# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import Optional

from syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.hp_ranges_factory \
    import make_hyperparameter_ranges
from syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.hp_ranges \
    import HyperparameterRanges
from syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.common \
    import Configuration, ConfigurationFilter
from syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.kernel \
    import Matern52, ProductKernelFunction, KernelFunction, RangeKernelFunction


def create_hp_ranges_for_warmstarting(**kwargs) -> HyperparameterRanges:
    """
    Creates the `hp_ranges` object for `kwargs['configspace']`, taking
    optional transfer learning (warmstarting) arguments into account.

    The optional fields 'transfer_learning_task_attr',
    'transfer_learning_active_task' and
    'transfer_learning_active_config_space' of `kwargs` are described in
    :class:`GPFIFOSearcher`. If 'transfer_learning_task_attr' is given:

    - 'transfer_learning_active_task' must be given as well
    - `configspace[task_attr]` must be a categorical parameter which has
      the active task among its categories
    - `hp_ranges` is created with `prefix_keys = [task_attr]` and with an
      `active_config_space` in which `task_attr` is restricted to the
      active task as single value

    Otherwise, `hp_ranges` is created without `active_config_space` and
    `prefix_keys`.

    This function is not only called in `gp_searcher_factory` to create
    `hp_ranges` for a new :class:`GPFIFOSearcher` object. It is also needed to
    create the `TuningJobState` containing the data to be used in warmstarting.

    """
    task_attr = kwargs.get('transfer_learning_task_attr')
    config_space = kwargs['configspace']
    if task_attr is None:
        # No transfer learning: plain ranges over the full config space
        return make_hyperparameter_ranges(
            config_space, active_config_space=None, prefix_keys=None)

    from syne_tune.search_space import Categorical

    active_task = kwargs.get('transfer_learning_active_task')
    assert active_task is not None, \
        "transfer_learning_active_task is needed if transfer_learning_task_attr is given"
    task_domain = config_space.get(task_attr)
    assert isinstance(task_domain, Categorical), \
        f"config_space[{task_attr}] must be a categorical parameter"
    assert active_task in task_domain.categories, (
        f"'{active_task}' must be value in config_space[{task_attr}] "
        f"(values: {task_domain.categories})")

    # Start from the user-supplied active config space if there is one,
    # otherwise from the full config space
    base_space = kwargs.get('transfer_learning_active_config_space')
    base_space = config_space if base_space is None else base_space
    # In the active config space, `task_attr` may only take the value
    # `active_task`
    single_task_override = {task_attr: Categorical(categories=[active_task])}
    restricted_space = dict(base_space, **single_task_override)
    return make_hyperparameter_ranges(
        config_space, active_config_space=restricted_space,
        prefix_keys=[task_attr])


def create_filter_observed_data_for_warmstarting(
        **kwargs) -> Optional[ConfigurationFilter]:
    """
    Creates the filter selecting observed data which belongs to the active
    task.

    The optional fields 'transfer_learning_task_attr' and
    'transfer_learning_active_task' of `kwargs` are described in
    :class:`GPFIFOSearcher`. If 'transfer_learning_task_attr' is not given,
    None is returned. Otherwise, 'transfer_learning_active_task' must be
    given as well, and the filter returned maps a configuration to True
    iff its value for the task attribute equals the active task.

    """
    task_key = kwargs.get('transfer_learning_task_attr')
    if task_key is None:
        # No transfer learning, so nothing needs to be filtered
        return None

    target_task = kwargs.get('transfer_learning_active_task')
    assert target_task is not None, \
        "transfer_learning_active_task is needed if transfer_learning_task_attr is given"

    def belongs_to_active_task(config: Configuration) -> bool:
        return config[task_key] == target_task

    return belongs_to_active_task


def create_base_gp_kernel_for_warmstarting(
        hp_ranges: HyperparameterRanges, **kwargs) -> KernelFunction:
    """
    Creates the base kernel of the GP model in the transfer learning case.
    `kwargs['transfer_learning_task_attr']` must be given, and the task
    attribute must be the first entry of `hp_ranges.internal_keys`.

    The kernel depends on `kwargs['transfer_learning_model']`:

    - 'matern52_product' (default): Product of two `Matern52` kernels, the
      first non-ARD over the categorical parameter determining the task,
      the second ARD over the remaining parameters
    - 'matern52_same': `Matern52` ARD kernel over the remaining parameters
      only, the task parameter is ignored

    Any other value is rejected by an assertion.

    """
    task_attr = kwargs.get('transfer_learning_task_attr')
    assert task_attr is not None
    task_domain = hp_ranges.config_space.get(task_attr)
    # Note: `create_hp_ranges_for_warmstarting` passes the task attribute
    # as `prefix_keys`, so it is expected to be the first in `hp_ranges`
    assert hp_ranges.internal_keys[0] == task_attr  # Sanity check
    # Leading `task_dim` dimensions are for the task parameter, the
    # remaining ones for all other parameters
    task_dim = len(task_domain)
    total_dim = hp_ranges.ndarray_size()
    model_name = kwargs.get('transfer_learning_model', 'matern52_product')
    rest_kernel = Matern52(total_dim - task_dim, ARD=True)
    if model_name != 'matern52_product':
        assert model_name == 'matern52_same', \
            f"transfer_learning_model = {model_name} not supported"
        # Matern ARD kernel on the rest, ignoring the task_id attribute
        return RangeKernelFunction(
            dimension=total_dim, kernel=rest_kernel, start=task_dim)
    # Product of Matern kernel with single length scale on the task_id
    # attribute, and Matern ARD kernel on the rest
    task_kernel = Matern52(task_dim, ARD=False)
    return ProductKernelFunction(task_kernel, rest_kernel)
