# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from collections import Counter
from typing import Callable
from dataclasses import dataclass

from syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.tuning_job_state \
    import TuningJobState
from syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.common \
    import TrialEvaluations, PendingEvaluation, INTERNAL_METRIC_NAME
from syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.hp_ranges \
    import HyperparameterRanges


@dataclass
class MapReward:
    """
    Bundles two callables `forward` and `reverse`, each mapping a float to
    a float. An instance is callable itself: calling it applies `forward`.
    `reverse` is only stored here, it is up to the user of the object to
    call it.

    """
    forward: Callable[[float], float]
    reverse: Callable[[float], float]

    def __call__(self, x: float) -> float:
        # Calling the object is shorthand for applying `forward`
        mapped = self.forward(x)
        return mapped


def map_reward_const_minus_x(const=1.0) -> MapReward:
    """
    Factory for map_reward argument in GPMultiFidelitySearcher.

    The returned `MapReward` uses one and the same function
    `x -> const - x` for both `forward` and `reverse` (applying this map
    twice gives back `x`).

    :param const: Constant from which the argument is subtracted
    :return: `MapReward` object wrapping the map
    """
    def const_minus_x(x):
        return const - x

    reward_map = MapReward(reverse=const_minus_x, forward=const_minus_x)
    return reward_map


# Identifiers accepted for the initial scoring option
SUPPORTED_INITIAL_SCORING = {'thompson_indep', 'acq_func'}


# Default identifier (a member of `SUPPORTED_INITIAL_SCORING`)
DEFAULT_INITIAL_SCORING = 'thompson_indep'


def encode_state(state: TuningJobState) -> dict:
    """
    Encodes `state` as a dictionary. Entries of `state.trials_evaluations`
    and `state.pending_evaluations` are converted to dicts, while
    `state.config_for_trial` and `state.failed_trials` are passed through
    as they are (not copied). `state.hp_ranges` is not part of the encoding.

    :param state: State to be encoded
    :return: Dictionary with keys 'config_for_trial', 'trials_evaluations',
        'failed_trials', 'pending_evaluations'
    """
    observed = []
    for evaluation in state.trials_evaluations:
        observed.append({
            'trial_id': evaluation.trial_id,
            'metrics': evaluation.metrics})
    pending = []
    for evaluation in state.pending_evaluations:
        entry = {'trial_id': evaluation.trial_id}
        # The 'resource' key is only written if a resource is set
        if evaluation.resource is not None:
            entry['resource'] = evaluation.resource
        pending.append(entry)
    return {
        'config_for_trial': state.config_for_trial,
        'trials_evaluations': observed,
        'failed_trials': state.failed_trials,
        'pending_evaluations': pending}


def decode_state(enc_state: dict, hp_ranges: HyperparameterRanges) \
        -> TuningJobState:
    """
    Rebuilds a `TuningJobState` from a dictionary with keys
    'config_for_trial', 'trials_evaluations', 'failed_trials',
    'pending_evaluations' (the layout written by `encode_state`).

    :param enc_state: Encoded state. Each entry of 'trials_evaluations'
        ('pending_evaluations') is a dict which is passed as keyword
        arguments to `TrialEvaluations` (`PendingEvaluation`)
    :param hp_ranges: Not contained in `enc_state`, so has to be passed
    :return: Decoded state
    """
    observed = []
    for entry in enc_state['trials_evaluations']:
        observed.append(TrialEvaluations(**entry))
    pending = []
    for entry in enc_state['pending_evaluations']:
        pending.append(PendingEvaluation(**entry))
    return TuningJobState(
        hp_ranges=hp_ranges,
        config_for_trial=enc_state['config_for_trial'],
        trials_evaluations=observed,
        failed_trials=enc_state['failed_trials'],
        pending_evaluations=pending)


def _get_trial_id(
        hp_ranges: HyperparameterRanges, config: dict, config_for_trial: dict,
        trial_for_config: dict) -> str:
    """
    Returns the trial ID registered for `config`, creating and registering
    a new one if there is none. Configs are identified by the string
    `hp_ranges.config_to_match_string(config, skip_last=True)`.

    A new ID is the string of the current size of `trial_for_config`. If
    one is created, both `trial_for_config` and `config_for_trial` are
    modified in place.

    :param hp_ranges: Provides the match string for `config`
    :param config: Configuration to look up
    :param config_for_trial: Maps trial ID to config
    :param trial_for_config: Maps match string to trial ID
    :return: Trial ID for `config`
    """
    key = hp_ranges.config_to_match_string(config, skip_last=True)
    known_id = trial_for_config.get(key)
    if known_id is not None:
        return known_id
    # Config not seen before: register it under the next free ID
    new_id = str(len(trial_for_config))
    trial_for_config[key] = new_id
    config_for_trial[new_id] = config
    return new_id


def decode_state_from_old_encoding(
        enc_state: dict, hp_ranges: HyperparameterRanges) -> TuningJobState:
    """
    Decodes `TuningJobState` from encoding done for the old definition of
    `TuningJobState`. Code maintained for backwards compatibility.

    Note: The old `TuningJobState` did not contain `trial_id`, so IDs are
    made up here. They are assigned while running over
    `candidate_evaluations`, then `failed_candidates`, then
    `pending_candidates`, and a config matching one seen earlier receives
    the ID assigned before.

    :param enc_state: Old encoding, with keys 'candidate_evaluations',
        'failed_candidates', 'pending_candidates'
    :param hp_ranges: Used to match configs, and to obtain the name of the
        resource attribute (`hp_ranges.name_last_pos`)
    :return: Decoded state
    """
    config_for_trial = {}
    trial_for_config = {}

    def trial_id_for(config: dict) -> str:
        # All lookups share the two dictionaries, which grow along the way
        return _get_trial_id(
            hp_ranges, config, config_for_trial, trial_for_config)

    trials_evaluations = [
        TrialEvaluations(trial_id_for(entry['candidate']), entry['metrics'])
        for entry in enc_state['candidate_evaluations']]
    failed_trials = [
        trial_id_for(config) for config in enc_state['failed_candidates']]
    resource_attr_name = hp_ranges.name_last_pos
    pending_evaluations = []
    for config in enc_state['pending_candidates']:
        resource = None
        has_resource = resource_attr_name is not None \
            and resource_attr_name in config
        if has_resource:
            # Extended config (multi-fidelity): split off the resource
            # attribute, without modifying the entry in `enc_state`
            resource = int(config[resource_attr_name])
            stripped = config.copy()
            stripped.pop(resource_attr_name)
            config = stripped
        pending_evaluations.append(
            PendingEvaluation(trial_id_for(config), resource))
    return TuningJobState(
        hp_ranges=hp_ranges,
        config_for_trial=config_for_trial,
        trials_evaluations=trials_evaluations,
        failed_trials=failed_trials,
        pending_evaluations=pending_evaluations)


class ResourceForAcquisitionMap:
    """
    Base class for rules which select the resource level `r_acq` used with
    a standard acquisition function (like expected improvement) in
    multi-fidelity HPO: the AF is evaluated w.r.t. the posterior
    distribution over `f(x, r=r_acq)`. The choice of `r_acq` can depend on
    the current state.

    Subclasses have to implement `__call__`.

    """
    def __call__(self, state: TuningJobState, **kwargs) -> int:
        raise NotImplementedError


class ResourceForAcquisitionBOHB(ResourceForAcquisitionMap):
    """
    Implements a heuristic proposed in the BOHB paper: `r_acq` is the
    largest `r` such that we have at least `threshold` observations at
    `r`. If no level has at least `threshold` observations, the smallest
    level with any observation is returned.

    """
    def __init__(
            self, threshold: int, active_metric: str = INTERNAL_METRIC_NAME):
        self.threshold = threshold
        self.active_metric = active_metric

    def __call__(self, state: TuningJobState, **kwargs) -> int:
        assert state.num_observed_cases(self.active_metric) > 0, \
            f"state must have data for metric {self.active_metric}"
        # Count, for each resource level, the number of observations of
        # the active metric across all trials. The metric entry of a trial
        # is accessed via `keys()`, its keys are converted to int
        histogram = Counter()
        for trial_eval in state.trials_evaluations:
            histogram.update(
                int(r) for r in trial_eval.metrics[self.active_metric].keys())
        return self._max_at_least_threshold(histogram)

    def _max_at_least_threshold(self, counter: Counter) -> int:
        """
        Get largest key of `counter` whose value is at least `threshold`.
        If there is no such key, the smallest key of `counter` is returned.

        :param counter: dict with keys that support comparison operators.
            Must not be empty
        :return: largest key of `counter` with value at least `threshold`,
            or smallest key if there is none
        """
        smallest = min(counter.keys())
        frequent = [
            r for r, count in counter.items() if count >= self.threshold]
        return max(frequent) if frequent else smallest


class ResourceForAcquisitionFirstMilestone(ResourceForAcquisitionMap):
    """
    Here, `r_acq` is the smallest rung level to be attained by a config
    started from scratch. This value is not computed here, but has to be
    passed by the caller as `kwargs['milestone']`.

    """
    def __call__(self, state: TuningJobState, **kwargs) -> int:
        assert 'milestone' in kwargs, (
            "Need the first milestone to be attained by the new config "
            "passed as kwargs['milestone']. Use a scheduler which does "
            "that (e.g., HyperbandScheduler)")
        # `state` is not used: the caller supplies the answer
        milestone = kwargs['milestone']
        return milestone


class ResourceForAcquisitionFinal(ResourceForAcquisitionMap):
    """
    Here, `r_acq = r_max` is the largest resource level. The value passed
    at construction is returned for every call, independent of the state.

    """
    def __init__(self, r_max: int):
        self._r_max = r_max

    def __call__(self, state: TuningJobState, **kwargs) -> int:
        # Neither `state` nor `kwargs` are used
        r_acq = self._r_max
        return r_acq


# Values accepted for `kwargs['resource_acq']` in
# `resource_for_acquisition_factory`
SUPPORTED_RESOURCE_FOR_ACQUISITION = {
    'bohb',
    'first',
    'final'}


def resource_for_acquisition_factory(
        kwargs: dict,
        hp_ranges: HyperparameterRanges) -> ResourceForAcquisitionMap:
    """
    Creates the `ResourceForAcquisitionMap` object selected by
    `kwargs['resource_acq']` (default is 'bohb'):

    - 'bohb': `ResourceForAcquisitionBOHB`. The threshold is
      `kwargs['resource_acq_bohb_threshold']`, with default
      `hp_ranges.ndarray_size()`
    - 'first': `ResourceForAcquisitionFirstMilestone`
    - 'final': `ResourceForAcquisitionFinal`, with
      `r_max = kwargs['max_epochs']`

    :param kwargs: Arguments, see above for the entries which are read
    :param hp_ranges: Only used to determine the default threshold for
        'bohb'
    :return: Object mapping the state to `r_acq`
    :raises AssertionError: If `kwargs['resource_acq']` is not in
        `SUPPORTED_RESOURCE_FOR_ACQUISITION`
    :raises KeyError: If 'final' is selected, but `kwargs` does not
        contain 'max_epochs'
    """
    resource_acq = kwargs.get('resource_acq', 'bohb')
    assert resource_acq in SUPPORTED_RESOURCE_FOR_ACQUISITION, \
        f"resource_acq = {resource_acq} not supported, must be in " +\
        str(SUPPORTED_RESOURCE_FOR_ACQUISITION)
    if resource_acq == 'bohb':
        threshold = kwargs.get(
            'resource_acq_bohb_threshold', hp_ranges.ndarray_size())
        return ResourceForAcquisitionBOHB(threshold=threshold)
    if resource_acq == 'first':
        # No further check needed here: the branch condition already
        # guarantees the value
        return ResourceForAcquisitionFirstMilestone()
    # Given the membership check above, only 'final' remains
    return ResourceForAcquisitionFinal(r_max=kwargs['max_epochs'])
