"""
 Copyright (c) 2022 Intel Corporation
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
      http://www.apache.org/licenses/LICENSE-2.0
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

import functools
import tensorflow as tf


class NormActivation(tf.keras.layers.Layer):
    """Batch normalization layer with an optional activation applied after it."""

    def __init__(self,
                 momentum=0.997,
                 epsilon=1e-4,
                 trainable=True,
                 init_zero=False,
                 use_activation=True,
                 activation='relu',
                 fused=True,
                 name=None):
        """Create the batch normalization op and select the activation function.

        Args:
          momentum: momentum for the moving average.
          epsilon: small float added to variance to avoid dividing by zero.
          trainable: `bool`, forwarded both to this layer and to the wrapped
            batch normalization layer. If False, freeze batch normalization
            layer.
          init_zero: `bool` if True, initializes scale parameter of batch
            normalization with 0. If False, initialize it with 1.
          use_activation: `bool`, whether to add the optional activation layer
            after the batch normalization layer.
          activation: `str`, the type of the activation layer. Currently support
            `relu` and `swish`.
          fused: `bool` fused option in batch normalization.
          name: `str` name passed to the wrapped batch normalization layer.

        Raises:
          ValueError: if `activation` is neither `relu` nor `swish`.
        """
        super().__init__(trainable=trainable)

        initializers = tf.keras.initializers
        gamma_initializer = initializers.Zeros() if init_zero else initializers.Ones()

        self._normalization_op = tf.keras.layers.BatchNormalization(
            name=name,
            momentum=momentum,
            epsilon=epsilon,
            center=True,
            scale=True,
            trainable=trainable,
            fused=fused,
            gamma_initializer=gamma_initializer)

        self._use_activation = use_activation

        # The activation name is validated even when `use_activation` is False.
        if activation == 'relu':
            activation_fn = tf.nn.relu
        elif activation == 'swish':
            activation_fn = tf.nn.swish
        else:
            raise ValueError('Unsupported activation `{}`.'.format(activation))
        self._activation_op = activation_fn

    def __call__(self, inputs, is_training=None):
        """Apply batch normalization, then the activation if it is enabled.

        Args:
            inputs: `Tensor` to normalize.
            is_training: `boolean` or None. Forwarded as the `training`
                argument of the batch normalization layer.

        Returns:
            The normalized (and, if enabled, activated) `Tensor`.
        """
        # `None` must be passed through untouched so that the normalization
        # layer can inherit the mode from keras.Model.training.
        training = is_training
        if training and self.trainable:
            training = True

        normalized = self._normalization_op(inputs, training=training)
        if not self._use_activation:
            return normalized
        return self._activation_op(normalized)


def norm_activation_builder(momentum=0.997,
                            epsilon=1e-4,
                            trainable=True,
                            activation='relu',
                            **kwargs):
    """Return a `NormActivation` constructor with the given options pre-bound.

    Args:
      momentum: bound as the `momentum` argument of `NormActivation`.
      epsilon: bound as the `epsilon` argument of `NormActivation`.
      trainable: bound as the `trainable` argument of `NormActivation`.
      activation: bound as the `activation` argument of `NormActivation`.
      **kwargs: any further keyword arguments to bind.

    Returns:
      A `functools.partial` object; calling it constructs a `NormActivation`
      with the bound options (call-time keywords may still override them).
    """
    bound_options = dict(momentum=momentum,
                         epsilon=epsilon,
                         trainable=trainable,
                         activation=activation)
    bound_options.update(kwargs)
    return functools.partial(NormActivation, **bound_options)


def compose(*funcs):
    """Compose arbitrarily many functions, evaluated left to right.

    `compose(f, g, h)(*args, **kwargs)` is `h(g(f(*args, **kwargs)))`: the
    first function receives the call arguments unchanged and every later
    function receives the previous return value as its only argument.

    Reference: https://mathieularose.com/function-composition-in-python/

    Args:
        *funcs: one or more callables, in application order.

    Returns:
        The callable itself when exactly one is given, otherwise a new
        callable that runs the whole chain.

    Raises:
        TypeError: if no functions are given.
    """
    if not funcs:
        # functools.reduce would raise an opaque "reduce() of empty ..."
        # TypeError here; keep the exception type but name the real problem.
        raise TypeError('compose() requires at least one function.')

    head, *tail = funcs
    if not tail:
        # A single function is returned as-is rather than wrapped.
        return head

    def composed(*args, **kwargs):
        # A flat loop instead of nested lambdas keeps the call stack shallow
        # regardless of how many functions are chained.
        result = head(*args, **kwargs)
        for func in tail:
            result = func(result)
        return result

    return composed


@functools.wraps(tf.keras.layers.Conv2D)
def YoloConv2D(*args, **kwargs):
    """Wrapper to set Yolo parameters for Conv2D.

    Adds L2 regularization (factor 1e-5) on both kernel and bias unless the
    caller supplies its own `kernel_regularizer` / `bias_regularizer`; all
    other arguments are forwarded to `tf.keras.layers.Conv2D` unchanged.
    """
    weight_decay = 1e-5
    yolo_defaults = dict(
        kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
        bias_regularizer=tf.keras.regularizers.l2(weight_decay),
    )
    # Caller-supplied keywords come last so they override the defaults.
    conv_kwargs = {**yolo_defaults, **kwargs}
    return tf.keras.layers.Conv2D(*args, **conv_kwargs)


@functools.wraps(YoloConv2D)
def DarknetConv2D(*args, **kwargs):
    """Wrapper to set Darknet parameters for YoloConv2D.

    Defaults `padding` to 'valid' when the `strides` keyword requests a
    stride of 2 in both dimensions and to 'same' otherwise. An explicit
    `padding` keyword from the caller always wins. All arguments are then
    forwarded to `YoloConv2D`.

    Note: only the `strides` keyword is inspected; strides passed
    positionally do not influence the default padding.
    """
    strides = kwargs.get('strides')
    # Keras accepts strides as a single int or as a tuple/list of two ints,
    # so `2`, `(2, 2)` and `[2, 2]` must all be recognised as a stride of 2.
    if isinstance(strides, (tuple, list)):
        is_stride_two = tuple(strides) == (2, 2)
    else:
        is_stride_two = strides == 2

    darknet_conv_kwargs = {'padding': 'valid' if is_stride_two else 'same'}
    darknet_conv_kwargs.update(kwargs)
    return YoloConv2D(*args, **darknet_conv_kwargs)
