# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import List, Tuple, Iterator, Optional
import logging
from dataclasses import dataclass
import numpy as np
import itertools

from syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.common import (
    Configuration,
)
from syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.hp_ranges import (
    HyperparameterRanges,
)
from syne_tune.optimizer.schedulers.searchers.bayesopt.models.model_transformer import (
    ModelStateTransformer,
)
from syne_tune.optimizer.schedulers.searchers.bayesopt.tuning_algorithms.base_classes import (
    NextCandidatesAlgorithm,
    CandidateGenerator,
    ScoringFunction,
    LocalOptimizer,
    SurrogateModel,
)
from syne_tune.optimizer.schedulers.searchers.bayesopt.tuning_algorithms.bo_algorithm_components import (
    LBFGSOptimizeAcquisition,
)
from syne_tune.optimizer.schedulers.searchers.bayesopt.tuning_algorithms.common import (
    generate_unique_candidates,
    ExclusionList,
)
from syne_tune.optimizer.schedulers.searchers.bayesopt.utils.debug_log import (
    DebugLogPrinter,
)
from syne_tune.optimizer.schedulers.searchers.bayesopt.utils.duplicate_detector import (
    DuplicateDetector,
)
from syne_tune.optimizer.schedulers.utils.simple_profiler import SimpleProfiler

logger = logging.getLogger(__name__)


@dataclass
class BayesianOptimizationAlgorithm(NextCandidatesAlgorithm):
    """
    Core logic of the Bayesian optimization algorithm

    :param initial_candidates_generator: generator of candidates
    :param initial_candidates_scorer: scoring function used to rank the
        initial candidates.
        Note: If a batch is selected in one go (num_requested_candidates > 1,
        greedy_batch_selection = False), this function should encourage
        diversity among its top scorers. In general, greedy batch selection
        is recommended.
    :param num_initial_candidates: how many initial candidates to generate, if
        possible
    :param local_optimizer: local optimizer which starts from score minimizer.
        If a batch is selected in one go (not greedily), then local
        optimizations are started from the top num_requested_candidates ranked
        candidates (after scoring)
    :param pending_candidate_state_transformer: Once a candidate is selected,
        it becomes pending, and the state is transformed by appending
        information. This is done by the transformer.
        This object is needed only if next_candidates goes through > 1 outer
        iterations (i.e., if greedy_batch_selection is True and
        num_requested_candidates > 1). Otherwise, None can be passed here.
        Note: Model updates (by the state transformer) for batch candidates
        beyond the first do not involve fitting hyperparameters, so they are
        usually cheap.
    :param exclusion_candidates: candidates that should not be returned,
        because they are already labeled, currently pending, or have failed.
        Note: During greedy batch selection, every selected candidate (except
        those of the last outer iteration) is added to this list.
    :param num_requested_candidates: number of candidates to return. If this
        is <= 0, no candidates are returned
    :param greedy_batch_selection: If True and num_requested_candidates > 1, we
        generate, order, and locally optimize for each single candidate to be
        selected. Otherwise (False), this is done just once, and
        num_requested_candidates are extracted in one go.
        Note: If this is True, pending_candidate_state_transformer is needed.
    :param duplicate_detector: used to make sure no candidates equal to
        already evaluated ones are returned
    :param num_initial_candidates_for_batch: This is used only if
        num_requested_candidates > 1 and greedy_batch_selection is True. In
        this case, num_initial_candidates_for_batch overrides
        num_initial_candidates when selecting all but the first candidate for
        the batch. Typically, num_initial_candidates is larger than
        num_initial_candidates_for_batch in this case, which speeds up
        selecting large batches, but still selects the first candidate very
        thoroughly
    :param profiler: If given, this is used for profiling parts in the code
    :param sample_unique_candidates: If True, initial candidates sampled at
        random are checked to be unique and disjoint from the exclusion list
        (this can be expensive). If False, initial candidates are still not in
        the exclusion list, but may contain duplicates (these are skipped
        before local optimization).
    :param debug_log: If a DebugLogPrinter is passed here, it is used to write
        log messages
    """

    initial_candidates_generator: CandidateGenerator
    initial_candidates_scorer: ScoringFunction
    num_initial_candidates: int
    local_optimizer: LocalOptimizer
    pending_candidate_state_transformer: Optional[ModelStateTransformer]
    exclusion_candidates: ExclusionList
    num_requested_candidates: int
    greedy_batch_selection: bool
    duplicate_detector: DuplicateDetector
    num_initial_candidates_for_batch: Optional[int] = None
    profiler: Optional[SimpleProfiler] = None
    sample_unique_candidates: bool = False
    debug_log: Optional[DebugLogPrinter] = None

    # Note: For greedy batch selection (num_outer_iterations > 1), the
    # underlying SurrogateModel changes with each new pending candidate. The
    # model changes are managed by pending_candidate_state_transformer. The
    # model has to be passed to both initial_candidates_scorer and
    # local_optimizer.
    def next_candidates(self) -> List[Configuration]:
        """
        Selects up to num_requested_candidates new configurations.

        Fewer configurations are returned (and a warning is logged) if no
        further candidates can be found, e.g. because all entries of a finite
        configuration space have been selected.

        :return: List of selected configurations
        """
        if self.num_requested_candidates <= 0:
            # Nothing requested: return right away, so that no candidates are
            # generated, scored or locally optimized
            return []
        if self.greedy_batch_selection:
            # Select batch greedily, one candidate at a time, updating the
            # model in between
            num_outer_iterations = self.num_requested_candidates
            num_inner_candidates = 1
        else:
            # Select batch in one go
            num_outer_iterations = 1
            num_inner_candidates = self.num_requested_candidates
        next_trial_id = None
        if num_outer_iterations > 1:
            assert (
                self.pending_candidate_state_transformer
            ), "Need pending_candidate_state_transformer for greedy batch selection"
            # For greedy batch selection, we need to assign new trial_id's to
            # configs included into the batch, in order to update the state
            # maintained in `pending_candidate_state_transformer`.
            # This is just to make batch suggestion work: neither the state
            # nor these trial_id's are used in the future.
            next_trial_id = self._first_unused_trial_id()
        candidates = []
        just_added = True
        model = None  # SurrogateModel, if num_outer_iterations > 1
        for outer_iter in range(num_outer_iterations):
            if just_added:
                if self.exclusion_candidates.config_space_exhausted():
                    self._warn_config_space_exhausted(len(candidates))
                    break
                just_added = False
            # Later candidates of a greedy batch may use a (typically
            # smaller) number of initial candidates
            if (
                self.num_initial_candidates_for_batch is not None
                and self.greedy_batch_selection
                and outer_iter > 0
            ):
                num_initial_candidates = self.num_initial_candidates_for_batch
            else:
                num_initial_candidates = self.num_initial_candidates
            inner_candidates = self._get_next_candidates(
                num_inner_candidates,
                model=model,
                num_initial_candidates=num_initial_candidates,
            )
            candidates.extend(inner_candidates)
            if outer_iter < num_outer_iterations - 1 and len(inner_candidates) > 0:
                # This is not the last outer iteration
                just_added = True
                for cand in inner_candidates:
                    self.exclusion_candidates.add(cand)
                # State transformer is used to produce new model
                # Note: We suppress fit_hyperpars for models obtained during
                # batch selection
                for candidate in inner_candidates:
                    self.pending_candidate_state_transformer.append_trial(
                        trial_id=str(next_trial_id), config=candidate
                    )
                    next_trial_id += 1
                model = self.pending_candidate_state_transformer.model(
                    skip_optimization=True
                )
            if (
                len(inner_candidates) < num_inner_candidates
                and len(candidates) < self.num_requested_candidates
            ):
                self._warn_config_space_exhausted(len(candidates))
                break

        return candidates

    def _first_unused_trial_id(self) -> int:
        """
        Returns 1 + the largest integer-valued trial_id found in the state of
        pending_candidate_state_transformer (or 1 if there is none).

        trial_id's which cannot be converted by int() are ignored. This is
        sufficient to guarantee that `str(result + i)` is not equal to any
        existing trial_id for all i >= 0, also if trial_id's are arbitrary
        strings.
        """
        max_trial_id = 0
        config_for_trial = self.pending_candidate_state_transformer.state.config_for_trial
        for trial_id in config_for_trial.keys():
            try:
                max_trial_id = max(max_trial_id, int(trial_id))
            except ValueError:
                # Non-integer trial_id: cannot clash with str(int), so ignore
                pass
        return max_trial_id + 1

    def _warn_config_space_exhausted(self, num_selected: int):
        """Logs that fewer than num_requested_candidates could be returned."""
        logger.warning(
            "All entries of finite config space (size "
            + f"{self.exclusion_candidates.configspace_size}) have been selected. Returning "
            + f"{num_selected} configs instead of {self.num_requested_candidates}"
        )

    def _get_next_candidates(
        self,
        num_candidates: int,
        model: Optional[SurrogateModel],
        num_initial_candidates: Optional[int] = None,
    ) -> List[Configuration]:
        """
        Generates initial candidates, ranks them by initial_candidates_scorer,
        locally optimizes them lazily in that order, and picks up to
        num_candidates non-duplicate results.

        :param num_candidates: Maximum number of candidates to return
        :param model: Surrogate model passed to scorer and local optimizer
            (None in the first outer iteration)
        :param num_initial_candidates: Number of initial candidates to
            generate. Defaults to self.num_initial_candidates
        :return: List of at most num_candidates configurations (possibly
            empty)
        """
        if num_initial_candidates is None:
            num_initial_candidates = self.num_initial_candidates
        # generate random candidates among which to pick the ones to be
        # locally optimized
        logger.info(
            f"BayesOpt Algorithm: Generating {num_initial_candidates} "
            "initial candidates."
        )
        if self.profiler is not None:
            self.profiler.push_prefix("nextcand")
            self.profiler.start("all")
            self.profiler.start("genrandom")
        if self.sample_unique_candidates:
            # This can be expensive, depending on what type Candidate is
            initial_candidates = generate_unique_candidates(
                self.initial_candidates_generator,
                num_initial_candidates,
                self.exclusion_candidates,
            )
        else:
            # Will not return candidates in `exclusion_candidates`, but there
            # can be duplicates
            initial_candidates = (
                self.initial_candidates_generator.generate_candidates_en_bulk(
                    num_initial_candidates, exclusion_list=self.exclusion_candidates
                )
            )
        if self.profiler is not None:
            self.profiler.stop("genrandom")
            self.profiler.start("scoring")
        logger.info("BayesOpt Algorithm: Scoring (and reordering) candidates.")
        if self.debug_log is not None:
            candidates_and_scores = _order_candidates(
                initial_candidates,
                self.initial_candidates_scorer,
                model=model,
                with_scores=True,
            )
            initial_candidates = [cand for _, cand in candidates_and_scores]
            # There may be no initial candidates at all (e.g., a finite config
            # space is exhausted); indexing [0] would then raise IndexError.
            # NOTE(review): debug_log then gets no init config for this call.
            if initial_candidates:
                top_scores = np.array([score for score, _ in candidates_and_scores[:5]])
                self.debug_log.set_init_config(initial_candidates[0], top_scores)
        else:
            initial_candidates = _order_candidates(
                initial_candidates, self.initial_candidates_scorer, model=model
            )
        if self.profiler is not None:
            self.profiler.stop("scoring")
            self.profiler.start("localsearch")
        candidates_with_optimization = _lazily_locally_optimize(
            initial_candidates,
            self.local_optimizer,
            hp_ranges=self.exclusion_candidates.hp_ranges,
            model=model,
        )
        logger.info("BayesOpt Algorithm: Selecting final set of candidates.")
        if self.debug_log is not None and isinstance(
            self.local_optimizer, LBFGSOptimizeAcquisition
        ):
            # We would like to get num_evaluations from the first run (usually
            # the only one). This requires peeking at the first entry of the
            # iterator. Using a default avoids leaking StopIteration if there
            # are no candidates at all (entries are tuples, never None)
            peek = next(candidates_with_optimization, None)
            if peek is not None:
                self.debug_log.set_num_evaluations(self.local_optimizer.num_evaluations)
                candidates_with_optimization = itertools.chain(
                    [peek], candidates_with_optimization
                )
        candidates = _pick_from_locally_optimized(
            candidates_with_optimization,
            self.exclusion_candidates,
            num_candidates,
            self.duplicate_detector,
        )
        if self.profiler is not None:
            self.profiler.stop("localsearch")
            self.profiler.stop("all")
            self.profiler.pop_prefix()  # nextcand
        return candidates


def _order_candidates(
    candidates: List[Configuration],
    scoring_function: ScoringFunction,
    model: Optional[SurrogateModel],
    with_scores: bool = False,
):
    if len(candidates) == 0:
        return []
    # scored in batch as this can be more efficient
    scores = scoring_function.score(candidates, model=model)
    sorted_list = sorted(zip(scores, candidates), key=lambda x: x[0])
    if with_scores:
        return sorted_list
    else:
        return [cand for score, cand in sorted_list]


def _lazily_locally_optimize(
    candidates: List[Configuration],
    local_optimizer: LocalOptimizer,
    hp_ranges: HyperparameterRanges,
    model: Optional[SurrogateModel],
) -> Iterator[Tuple[Configuration, Configuration]]:
    """
    Lazily yields (candidate, locally optimized candidate) pairs, in the order
    of `candidates`.

    Local deduplication means we do not know in advance how many candidates
    have to be locally optimized. This is why a generator is used: the local
    optimizer runs only when the consumer asks for the next pair.
    `candidates` may contain duplicates; only the first occurrence of each
    is optimized and yielded.

    :param candidates: Configurations to optimize, in order of preference
    :param local_optimizer: Provides optimize(candidate, model=model)
    :param hp_ranges: Used to build the list of already seen candidates
    :param model: Surrogate model forwarded to the local optimizer
    """
    seen = ExclusionList.empty_list(hp_ranges)
    for candidate in candidates:
        if seen.contains(candidate):
            # Duplicate of an earlier entry: skip it
            continue
        seen.add(candidate)
        optimized = local_optimizer.optimize(candidate, model=model)
        yield candidate, optimized


# Note: If duplicate_detector is at least DuplicateDetectorIdentical, it will
# filter out candidates in exclusion_candidates here. Such can in principle
# arise if sample_unique_candidates == False.
# This does not work if duplicate_detector is DuplicateDetectorNoDetection.
def _pick_from_locally_optimized(
    candidates_with_optimization: Iterator[Tuple[Configuration, Configuration]],
    exclusion_candidates: ExclusionList,
    num_candidates: int,
    duplicate_detector: DuplicateDetector,
) -> List[Configuration]:
    updated_excludelist = exclusion_candidates.copy()
    result = []
    for original_candidate, optimized_candidate in candidates_with_optimization:
        insert_candidate = None
        optimized_is_duplicate = duplicate_detector.contains(
            updated_excludelist, optimized_candidate
        )
        if optimized_is_duplicate:
            # in the unlikely case that the optimized candidate ended at a
            # place that caused a duplicate we try to return the original instead
            original_also_duplicate = duplicate_detector.contains(
                updated_excludelist, original_candidate
            )
            if not original_also_duplicate:
                insert_candidate = original_candidate
        else:
            insert_candidate = optimized_candidate
        if insert_candidate is not None:
            result.append(insert_candidate)
            updated_excludelist.add(insert_candidate)
        if len(result) == num_candidates:
            break

    return result
