# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import List
import logging
import numpy as np

from syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.common \
    import Configuration, INTERNAL_METRIC_NAME
from syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.tuning_job_state \
    import TuningJobState

# Module-level logger; `DebugLogPrinter` emits its blocks through it
logger = logging.getLogger(__name__)

# Public API of this module
__all__ = ["DebugLogPrinter"]


def _param_dict_to_str(params: dict) -> str:
    parts = []
    for name, param in params.items():
        if isinstance(param, float):
            parts.append(f"{name}: {param:.4e}")
        else:
            parts.append(f"{name}: {param}")
    return '{' + ', '.join(parts) + '}'


class DebugLogPrinter(object):
    """
    Supports a concise debug log.
    In particular, information about `get_config` is displayed in a single
    block. For that, different parts are first collected until the end of
    `get_config`.

    Lifecycle of a block:
    - `start_get_config` opens it.
    - `set_final_config`, the other `set_*` methods and `append_extra`
      record its parts. All `set_*` methods except `set_final_config`
      require a block of type 'BO'.
    - `write_block` emits it and closes it.

    """
    def __init__(self):
        self._reset()

    def _reset(self):
        # Closes the current block (if any) and drops all collected parts
        self.get_config_trial_id = None
        self.get_config_type = None
        self.block_info = dict()

    def start_get_config(self, gc_type: str, trial_id: str):
        """
        Opens a new block for a `get_config` call.

        :param gc_type: Type of the block, either 'random' or 'BO'
        :param trial_id: Trial the new config is suggested for
        """
        assert gc_type in {'random', 'BO'}
        assert trial_id is not None
        assert self.get_config_type is None, \
            "Block for get_config of type '{}' is currently open".format(
                self.get_config_type)
        self.get_config_trial_id = trial_id
        self.get_config_type = gc_type
        logger.debug(f"Starting get_config[{gc_type}] for trial_id {trial_id}")

    def set_final_config(self, config: Configuration):
        """
        Records the config returned by `get_config`, one `key: value` per
        line. `write_block` requires this part to be set.
        """
        assert self.get_config_type is not None, "No block open right now"
        entries = ['{}: {}'.format(k, v) for k, v in config.items()]
        msg = '\n'.join(entries)
        self.block_info['final_config'] = msg

    def _observed_trial_ids(self, state: TuningJobState) -> List[str]:
        """
        Labels for all evaluations in `state` which have a value for
        `INTERNAL_METRIC_NAME`.

        If that value is a dict (presumably keyed by resource level), one
        label '<trial_id>:<key>' is returned per key. Otherwise the label is
        just the trial ID.
        """
        trial_ids = []
        for ev in state.trials_evaluations:
            metric_entry = ev.metrics.get(INTERNAL_METRIC_NAME)
            if metric_entry is None:
                continue
            if isinstance(metric_entry, dict):
                # f-string (as in `_pending_trial_ids`) so that non-string
                # resource keys do not raise a TypeError
                trial_ids.extend(
                    f"{ev.trial_id}:{resource}" for resource in metric_entry)
            else:
                trial_ids.append(ev.trial_id)
        return trial_ids

    def _pending_trial_ids(self, state: TuningJobState) -> List[str]:
        """
        Labels for all pending evaluations in `state`. Each label is the
        trial ID, with ':<resource>' appended if the evaluation has a
        resource.
        """
        trial_ids = []
        for ev in state.pending_evaluations:
            if ev.resource is None:
                trial_ids.append(ev.trial_id)
            else:
                trial_ids.append(f"{ev.trial_id}:{ev.resource}")
        return trial_ids

    def set_state(self, state: TuningJobState):
        """
        Records which evaluations are labeled (have a metric value) and which
        are pending.
        """
        assert self.get_config_type == 'BO', "Need to be in 'BO' block"
        labeled_str = ', '.join(self._observed_trial_ids(state))
        msg = 'Labeled: ' + labeled_str
        if state.pending_evaluations:
            pending_str = ', '.join(self._pending_trial_ids(state))
            msg += '. Pending: ' + pending_str
        self.block_info['state'] = msg

    def set_targets(self, targets: np.ndarray):
        """Records the target values, flattened to 1D."""
        assert self.get_config_type == 'BO', "Need to be in 'BO' block"
        msg = 'Targets: ' + str(targets.reshape((-1,)))
        self.block_info['targets'] = msg

    def set_model_params(self, params: dict):
        """Records the model parameters (floats in scientific notation)."""
        assert self.get_config_type == 'BO', "Need to be in 'BO' block"
        msg = 'Model params: ' + _param_dict_to_str(params)
        self.block_info['params'] = msg

    def set_fantasies(self, fantasies: np.ndarray):
        """Records the fantasized target values as printed by `str`."""
        assert self.get_config_type == 'BO', "Need to be in 'BO' block"
        msg = 'Fantasized targets:\n' + str(fantasies)
        self.block_info['fantasies'] = msg

    def set_init_config(
            self, config: Configuration, top_scores: np.ndarray = None):
        """
        Records the config the BO search was started from.

        :param config: Starting config, one `key: value` per line in the log
        :param top_scores: Optional score values, logged flattened to 1D
        """
        assert self.get_config_type == 'BO', "Need to be in 'BO' block"
        entries = ['{}: {}'.format(k, v) for k, v in config.items()]
        msg = "Started BO from (top scorer):\n" + '\n'.join(entries)
        if top_scores is not None:
            msg += ("\nTop score values: " + str(top_scores.reshape((-1,))))
        self.block_info['start_config'] = msg

    def set_num_evaluations(self, num_evals: int):
        """Records the number of evaluations, shown in the block header."""
        assert self.get_config_type == 'BO', "Need to be in 'BO' block"
        self.block_info['num_evals'] = num_evals

    def append_extra(self, extra: str):
        """Appends free-form text (newline-separated) to the debug output."""
        if 'extra' in self.block_info:
            self.block_info['extra'] = '\n'.join(
                [self.block_info['extra'], extra])
        else:
            self.block_info['extra'] = extra

    def write_block(self):
        """
        Emits the collected block and closes it.

        One `logger.info` message contains the header, the final config and,
        for 'BO' blocks, the 'state' and 'params' parts. A second message,
        sent via `logger.debug`, contains the start config, targets,
        fantasies and extra text. The debug message is skipped if none of
        these are present. The block is closed even if emitting it raises
        (e.g. `KeyError` when `set_final_config` was not called). Otherwise
        every later `start_get_config` would fail its assertion.
        """
        assert self.get_config_type is not None, "No block open right now"
        try:
            info = self.block_info
            trial_id = self.get_config_trial_id
            if 'num_evals' in info:
                parts = ['[{}: {}] ({} evaluations)'.format(
                    trial_id, self.get_config_type, info['num_evals'])]
            else:
                parts = ['[{}: {}]'.format(trial_id, self.get_config_type)]
            parts.append(info['final_config'])
            debug_parts = []  # Parts for logger.debug
            if self.get_config_type == 'BO':
                if 'start_config' in info:
                    debug_parts.append(info['start_config'])
                # The following 3 should be present!
                for name in ('state', 'targets', 'params'):
                    v = info.get(name)
                    if v is not None:
                        if name == 'targets':
                            debug_parts.append(v)
                        else:
                            parts.append(v)
                    else:
                        logger.info(
                            "debug_log.write_block: '{}' part is missing!".format(
                                name))
                if 'fantasies' in info:
                    debug_parts.append(info['fantasies'])
            if 'extra' in info:
                debug_parts.append(info['extra'])
            logger.info('\n'.join(parts))
            # Avoid emitting an empty debug record (e.g. for 'random' blocks)
            if debug_parts:
                logger.debug('\n'.join(debug_parts))
        finally:
            self._reset()
