
from xnas import utils

# Module-level logger, named after this module via ``__name__``.
logger = utils.get_logger(__name__)


def load_model(*args, custom_objects=None, **kwargs):
    """Load a Keras model with the project's custom layers registered.

    Thin wrapper around ``tf.keras.models.load_model`` that always adds
    ``TransposeLayer`` to ``custom_objects`` so saved models using it can be
    deserialized.

    Args:
        *args: Positional arguments forwarded unchanged to
            ``tf.keras.models.load_model``.
        custom_objects: Optional mapping of additional custom objects. The
            caller's mapping is NOT modified; a copy is extended instead.
        **kwargs: Remaining keyword arguments forwarded unchanged to
            ``tf.keras.models.load_model``.

    Returns:
        Whatever ``tf.keras.models.load_model`` returns.
    """
    import tensorflow as tf
    from xnas.search_space_common import TransposeLayer

    # Copy before adding our entry so the caller's dict is never mutated
    # as a side effect of loading a model.
    custom_objects = dict(custom_objects or {})
    custom_objects['TransposeLayer'] = TransposeLayer
    return tf.keras.models.load_model(
        *args, custom_objects=custom_objects, **kwargs)


def keras_logger_callback(config):
    """Build a Keras callback that logs every train/test/predict hook.

    All messages go to a per-run logger obtained from
    ``utils.get_logger('ID=<id>')``, so output from different runs can be
    told apart.

    Args:
        config: Mapping-like object; only ``config.get('id', '?')`` is read,
            and it is used to name the logger.

    Returns:
        A ``tf.keras.callbacks.Callback`` instance.
    """
    import tensorflow as tf

    # Keys of the Keras ``logs`` dict that are never printed.
    skipped_keys = frozenset({'outputs'})

    class LoggingCallback(tf.keras.callbacks.Callback):
        """Callback that reports each Keras hook to a per-run logger."""

        def __init__(self):
            super().__init__()
            id_ = 'ID=' + str(config.get('id', '?'))
            self._logger = utils.get_logger(id_)
            # Placeholder shown in batch messages until on_epoch_begin runs.
            self._epoch = '<<< UNDEFINED >>>'

        def _log_logs(self, logs):
            """Log each entry of ``logs`` (minus skipped keys), right-aligned."""
            if not logs:
                return
            shown = {k: v for k, v in logs.items() if k not in skipped_keys}
            if not shown:
                return
            # Width is computed only over printed keys so skipped keys do
            # not distort the alignment.
            maxlen = max(len(k) for k in shown)
            for k, v in shown.items():
                self._logger.info(f'  {k:>{maxlen}} = {v}')

        def _report(self, message, logs):
            """Log ``message`` followed by the contents of ``logs``."""
            self._logger.info(message)
            self._log_logs(logs)

        def on_train_begin(self, logs=None):
            self._report('!!! Starting training !!!', logs)

        def on_train_end(self, logs=None):
            self._report('!!! Stopped training !!!', logs)

        def on_epoch_begin(self, epoch, logs=None):
            # Uses the per-run logger like every other hook (previously the
            # module-level logger, which dropped the run ID).
            self._report(f'Start epoch {epoch} of training.', logs)
            self._epoch = epoch

        def on_epoch_end(self, epoch, logs=None):
            self._report(f'End epoch {epoch} of training.', logs)

        def on_test_begin(self, logs=None):
            self._report('!!! Starting testing !!!', logs)

        def on_test_end(self, logs=None):
            self._report('!!! Stopped testing !!!', logs)

        def on_predict_begin(self, logs=None):
            self._report('!!! Starting predicting !!!', logs)

        def on_predict_end(self, logs=None):
            self._report('!!! Stopped prediction !!!', logs)

        def on_train_batch_begin(self, batch, logs=None):
            self._report(f'...Training: start of batch {batch} (epoch '
                         f'{self._epoch})', logs)

        def on_train_batch_end(self, batch, logs=None):
            self._report(f'...Training: end of batch {batch} (epoch '
                         f'{self._epoch})', logs)

        def on_test_batch_begin(self, batch, logs=None):
            self._report(f'...Evaluating: start of batch {batch}', logs)

        def on_test_batch_end(self, batch, logs=None):
            self._report(f'...Evaluating: end of batch {batch}', logs)

        def on_predict_batch_begin(self, batch, logs=None):
            self._report(f'...Prediction: start of batch {batch}', logs)

        def on_predict_batch_end(self, batch, logs=None):
            self._report(f'...Prediction: end of batch {batch}', logs)

    return LoggingCallback()
