# -*- coding: utf-8 -*-

# Copyright 2020 MINH ANH (@dathudeptrai)
#  MIT License (https://opensource.org/licenses/MIT)

"""TensorFlow MelGAN modules compatible with the PyTorch implementation."""

import tensorflow as tf

import numpy as np

from parallel_wavegan.layers.tf_layers import TFConvTranspose1d
from parallel_wavegan.layers.tf_layers import TFReflectionPad1d
from parallel_wavegan.layers.tf_layers import TFResidualStack


class TFMelGANGenerator(tf.keras.layers.Layer):
    """Tensorflow MelGAN generator module.

    Keras counterpart of the PyTorch MelGAN generator. Some constructor
    arguments are accepted only so the signature matches the PyTorch version
    and have no effect here (see the ``__init__`` docstring).
    """

    def __init__(self,
                 in_channels=80,
                 out_channels=1,
                 kernel_size=7,
                 channels=512,
                 bias=True,
                 upsample_scales=None,
                 stack_kernel_size=3,
                 stacks=3,
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params=None,
                 pad="ReflectionPad1d",
                 pad_params=None,
                 use_final_nonlinear_activation=True,
                 use_weight_norm=True,
                 use_causal_conv=False,
                 ):
        """Initialize TFMelGANGenerator module.

        Args:
            in_channels (int): Number of input channels. NOTE: not used to build
                any layer; the input signature of ``call`` is hard-coded to 80.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of initial and final conv layer.
            channels (int): Initial number of channels for conv layer. Halved at
                every upsampling stage.
            bias (bool): Whether to add bias parameter in the (non-transposed)
                convolution layers and residual stacks. Not passed to the
                transposed convolution layers.
            upsample_scales (list): List of upsampling scales.
                Defaults to ``[8, 8, 2, 2]``.
            stack_kernel_size (int): Kernel size of dilated conv layers in residual stack.
            stacks (int): Number of stacks in a single residual stack.
            nonlinear_activation (str): Name of a ``tf.keras.layers`` activation class.
            nonlinear_activation_params (dict): Keyword arguments for the activation
                layer. Defaults to ``{"alpha": 0.2}``.
            pad (str): Padding function module name. Only "ReflectionPad1d" is supported.
            pad_params (dict): Hyperparameters for padding function. Accepted for
                compatibility with the PyTorch version but currently ignored.
            use_final_nonlinear_activation (bool): Whether to apply a tanh
                activation to the output.
            use_weight_norm (bool): No effect but keep it as is to be the same as pytorch version.
            use_causal_conv (bool): Whether to use causal convolution. Must be
                False (not supported yet).

        Raises:
            AssertionError: If a hyperparameter is invalid or unsupported.

        """
        super(TFMelGANGenerator, self).__init__()

        # Resolve defaults here rather than in the signature so that no mutable
        # list/dict object is shared between instances (the activation params
        # dict is handed to every activation and residual-stack layer below).
        if upsample_scales is None:
            upsample_scales = [8, 8, 2, 2]
        if nonlinear_activation_params is None:
            nonlinear_activation_params = {"alpha": 0.2}

        # check hyper parameters are valid
        assert not use_causal_conv, "Not supported yet."
        assert channels >= np.prod(upsample_scales), \
            f"channels ({channels}) must be >= prod(upsample_scales)."
        assert channels % (2 ** len(upsample_scales)) == 0, \
            f"channels ({channels}) must be divisible by 2 ** len(upsample_scales)."
        assert pad == "ReflectionPad1d", f"Not supported (pad={pad})."

        # NOTE: layers are created in exactly this order so that Keras'
        # auto-generated layer names (and hence weight names) stay stable.

        # Initial layer. Inputs reach the Sequential as (B, T, 1, C) (see
        # ``call``), so a Conv2D with a (kernel_size, 1) kernel acts as a 1D
        # convolution over time. The reflection pad is presumably there so the
        # "valid" conv does not shrink the time axis (assumes TFReflectionPad1d
        # pads both sides by the given amount; confirm).
        layers = [
            TFReflectionPad1d((kernel_size - 1) // 2),
            tf.keras.layers.Conv2D(filters=channels,
                                   kernel_size=(kernel_size, 1),
                                   padding="valid",
                                   use_bias=bias),
        ]

        for i, upsample_scale in enumerate(upsample_scales):
            # Channel count is halved at every upsampling stage.
            stage_channels = channels // (2 ** (i + 1))

            # add upsampling layer
            layers.append(
                getattr(tf.keras.layers, nonlinear_activation)(**nonlinear_activation_params))
            layers.append(
                TFConvTranspose1d(
                    channels=stage_channels,
                    kernel_size=upsample_scale * 2,
                    stride=upsample_scale,
                    padding="same",
                ))

            # add residual stack; dilation grows as stack_kernel_size ** j
            for j in range(stacks):
                layers.append(
                    TFResidualStack(
                        kernel_size=stack_kernel_size,
                        channels=stage_channels,
                        dilation=stack_kernel_size ** j,
                        bias=bias,
                        nonlinear_activation=nonlinear_activation,
                        nonlinear_activation_params=nonlinear_activation_params,
                        padding="same",
                    ))

        # add final layer (Conv2D padding defaults to "valid"; stated explicitly)
        layers += [
            getattr(tf.keras.layers, nonlinear_activation)(**nonlinear_activation_params),
            TFReflectionPad1d((kernel_size - 1) // 2),
            tf.keras.layers.Conv2D(filters=out_channels,
                                   kernel_size=(kernel_size, 1),
                                   padding="valid",
                                   use_bias=bias),
        ]
        if use_final_nonlinear_activation:
            layers.append(tf.keras.layers.Activation("tanh"))

        self.melgan = tf.keras.models.Sequential(layers)

    # TODO(kan-bayashi): Fix hard coded dimension
    @tf.function(input_signature=[tf.TensorSpec(shape=[None, None, 80], dtype=tf.float32)])
    def call(self, c):
        """Calculate forward propagation.

        Args:
            c (Tensor): Input tensor (B, T, in_channels). The traced input
                signature currently fixes in_channels to 80 (float32).

        Returns:
            Tensor: Output tensor (B, T * prod(upsample_scales), out_channels).

        """
        # Insert a dummy axis -> (B, T, 1, in_channels) so the (k, 1) Conv2D
        # layers behave as 1D convolutions over time.
        c = tf.expand_dims(c, 2)
        c = self.melgan(c)
        # Drop the dummy axis again.
        return c[:, :, 0, :]
