#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
from typing import Union, Tuple

import tensorflow as tf

from rl_coach.architectures.tensorflow_components.layers import BatchnormActivationDropout, convert_layer, Dense
from rl_coach.base_parameters import MiddlewareScheme, NetworkComponentParameters
from rl_coach.core_types import MiddlewareEmbedding


class Middleware(object):
    """
    Base class for the middle section of a network. A middleware receives the embeddings produced by the
    input embedders, after they were aggregated (for example, by concatenation), and runs them through a
    customizable stack of layers whose result is shared between the heads of the network.

    Inheriting classes are expected to override the ``schemes`` property and ``_build_module``.
    """
    def __init__(self, activation_function=tf.nn.relu,
                 scheme: MiddlewareScheme = MiddlewareScheme.Medium,
                 batchnorm: bool = False, dropout_rate: float = 0.0, name="middleware_embedder", dense_layer=Dense,
                 is_training=False):
        """
        :param activation_function: activation handed to the BatchnormActivationDropout entry that follows
                                    each layer
        :param scheme: either a MiddlewareScheme member, used as a key into ``self.schemes``, or a collection
                       of layers given directly
        :param batchnorm: batchnorm flag handed to the BatchnormActivationDropout entries
        :param dropout_rate: dropout rate handed to the BatchnormActivationDropout entries
        :param name: the module name, returned by get_name and used as the variable scope in __call__
        :param dense_layer: the dense layer type to keep on the instance. None falls back to Dense.
                            It is only stored here, this base class does not use it
        :param is_training: stored on the instance as is, this base class does not use it
        """
        # graph end points - the input is assigned in __call__, the output is left for _build_module to fill
        self.input = None
        self.output = None
        self.return_type = MiddlewareEmbedding

        # plain configuration. everything is stored before the scheme is resolved, so that an overriding
        # schemes property can already read these attributes
        self.name = name
        self.scheme = scheme
        self.activation_function = activation_function
        self.batchnorm = batchnorm
        self.dropout_rate = dropout_rate
        self.is_training = is_training
        self.dense_layer = Dense if dense_layer is None else dense_layer

        # the scheme is either a key into the schemes table of the inheriting class or an explicit collection
        # of layers. in both cases we work on a shallow copy and pass every entry through convert_layer.
        # NOTE: if a directly specified layer object is callable, it must return a TF tensor when invoked
        if isinstance(self.scheme, MiddlewareScheme):
            raw_layers = copy.copy(self.schemes[self.scheme])
        else:
            raw_layers = copy.copy(self.scheme)
        self.layers_params = [convert_layer(layer) for layer in raw_layers]

        # batchnorm, activation and dropout can be added after each layer through a single flag each, which
        # simplifies switching between a network with them and a network without them.
        # if none of the three was requested, the layers are left untouched
        if not (self.batchnorm or self.activation_function or self.dropout_rate > 0):
            return

        # resulting order per layer is layer -> batchnorm -> activation -> dropout.
        # walking the insert positions from the end keeps the positions that are still pending valid
        for insert_at in range(len(self.layers_params), 0, -1):
            post_layer = BatchnormActivationDropout(batchnorm=self.batchnorm,
                                                    activation_function=self.activation_function,
                                                    dropout_rate=self.dropout_rate)
            self.layers_params.insert(insert_at, post_layer)

    def __call__(self, input_layer: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Wrapper for building the module graph inside a variable scope named after the module
        :param input_layer: the input to the graph. It is stored in self.input before the module is built
        :return: a tuple of self.input and self.output, as they are after _build_module returns
        """
        scope_name = self.get_name()
        with tf.variable_scope(scope_name):
            self.input = input_layer
            self._build_module()

        return (self.input, self.output)

    def _build_module(self) -> None:
        """
        Builds the graph of the module
        This method is called from __call__, inside the module's variable scope and after self.input was
        assigned. It is expected to store the graph in self.output.
        The implementation in this base class does nothing.
        :return: None
        """

    def get_name(self) -> str:
        """
        Get a formatted name for the module
        :return: the name which was given to the module on construction
        """
        return self.name

    @property
    def schemes(self):
        """
        The default configurations of the middleware. __init__ indexes it with a MiddlewareScheme member and
        iterates over the entry it gets back, so inheriting classes must override this property accordingly.
        :return: never returns in this base class
        """
        message = ("Inheriting middleware must define schemes matching its allowed default "
                   "configurations.")
        raise NotImplementedError(message)

    def __str__(self):
        """
        Describe the middleware by its layers
        :return: the string of each entry in layers_params, one per line, or 'No layers' if there are none
        """
        if not self.layers_params:
            return 'No layers'
        return '\n'.join(str(layer) for layer in self.layers_params)
