# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import List, Tuple
import numpy as np
import logging

logger = logging.getLogger(__name__)


# Rung systems for all brackets: one list per bracket, each holding
# `(rung_size, resource_level)` tuples.
RungSystemsPerBracket = List[List[Tuple[int, int]]]


class SynchronousHyperbandRungSystem(object):
    """
    Collects factory methods for `RungSystemsPerBracket` rung systems to be
    used in :class:`SynchronousHyperbandBracketManager`.

    """
    @staticmethod
    def geometric(
            min_resource: int, max_resource: int,
            reduction_factor: float, num_brackets: int) -> RungSystemsPerBracket:
        """
        This is the geometric progression setup from the original papers on
        successive halving and Hyperband.

        If `s_max = ceil(log(max_resource / min_resource) /
        log(reduction_factor))`, there can be at most `s_max + 1` brackets.
        Here, bracket s has `r_num = s_max - s + 1` rungs. For
        `r < r_num - 1`, rung r in bracket s is at resource level
            `round(min_resource * power(reduction_factor, s + r))`,
        while the final rung of every bracket is at `max_resource`. The size
        of rung r in bracket s is
            `n(r,s) = ceil(((s_max + 1) / r_num) *
            power(reduction_factor, r_num - r - 1))`

        `s_max` is lowered as long as the highest intermediate resource level
        `round(min_resource * power(reduction_factor, s_max - 1))` is not
        strictly smaller than `max_resource`. This can happen if rounding
        lands on `max_resource`, or if `max_resource / min_resource` is an
        exact power of `reduction_factor`. In the latter case, floating point
        error in the log ratio can push it just above an integer. Without
        this correction, brackets would contain duplicate resource levels.

        :param min_resource: Smallest resource level (positive int)
        :param max_resource: Largest resource level (positive int)
        :param reduction_factor: Approximate ratio between successive rung levels
        :param num_brackets: Number of brackets. If this exceeds the number
            supported (`s_max + 1`), it is reduced to that (with a warning)
        :return: Rung system, one list of `(rung_size, resource_level)`
            tuples per bracket, with resource levels strictly increasing
        :raises AssertionError: If the arguments are invalid
        """
        SynchronousHyperbandRungSystem._assert_positive_int(
            min_resource, 'min_resource')
        SynchronousHyperbandRungSystem._assert_positive_int(
            max_resource, 'max_resource')
        assert min_resource < max_resource
        assert reduction_factor >= 2, \
            f"reduction_factor = {reduction_factor} must be >= 2"
        SynchronousHyperbandRungSystem._assert_positive_int(
            num_brackets, 'num_brackets')
        level = SynchronousHyperbandRungSystem._resource_level
        s_max = int(np.ceil(
            (np.log(max_resource) - np.log(min_resource)) /
            np.log(reduction_factor)))
        # Level `s_max - 1` is the highest intermediate rung level in every
        # bracket (exponent `rung + bracket` peaks there). It must stay
        # strictly below `max_resource`, otherwise the last two rungs would
        # share the same level.
        while s_max > 0 and level(
                min_resource, reduction_factor, s_max - 1) >= max_resource:
            s_max -= 1
        msg_prefix = (f"min_resource = {min_resource}, max_resource = "
                      f"{max_resource}, reduction_factor = {reduction_factor}")
        if s_max <= 0:
            logger.warning(
                msg_prefix +
                ": supports only one bracket with a single rung level of "
                "size 1. Is that really what you want?")
            return [[(1, max_resource)]]
        if num_brackets > s_max + 1:
            logger.warning(
                msg_prefix +
                f": does not support num_brackets = {num_brackets}, but at "
                f"most {s_max + 1}. I am switching to the latter one.")
            num_brackets = s_max + 1
        rung_systems = []
        for bracket in range(num_brackets):
            rungs = []
            # Bracket `bracket` has `r_num_m1 + 1` rungs in total
            r_num_m1 = s_max - bracket
            pre_fact = (s_max + 1) / (r_num_m1 + 1)
            for rung in range(r_num_m1):
                resource = level(min_resource, reduction_factor, rung + bracket)
                rsize = int(np.ceil(
                    pre_fact * np.power(reduction_factor, r_num_m1 - rung)))
                rungs.append((rsize, resource))
            # The final rung of every bracket is at `max_resource`
            rungs.append((int(np.ceil(pre_fact)), max_resource))
            rung_systems.append(rungs)
        parts = [f"Bracket {i}: rungs = {rungs}"
                 for i, rungs in enumerate(rung_systems)]
        logger.info('\n'.join(parts))

        return rung_systems

    @staticmethod
    def _resource_level(
            min_resource: int, reduction_factor: float, exponent: int) -> int:
        """
        Returns `round(min_resource * power(reduction_factor, exponent))` as
        int, the resource level of an intermediate rung.

        :param min_resource: Smallest resource level
        :param reduction_factor: Ratio between successive rung levels
        :param exponent: Non-negative exponent (rung index plus bracket index)
        :return: Rounded resource level
        """
        return int(round(
            min_resource * np.power(reduction_factor, exponent)))

    @staticmethod
    def _assert_positive_int(x: int, name: str):
        """
        Asserts that `x` is integral-valued (e.g. `3` or `3.0`) and `>= 1`.

        :param x: Value to check
        :param name: Name of the value, used in the error message
        :raises AssertionError: If the check fails
        """
        assert round(x) == x and x >= 1, \
            f"{name} = {x} must be a positive integer"
