from typing import NoReturn

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchlibrosa.stft import magphase


def init_layer(layer: nn.Module) -> None:
    r"""Initialize a Linear or Convolutional layer.

    The weight is filled in place with Xavier-uniform values. The bias,
    when the layer has one, is filled in place with zeros.

    Args:
        layer: module exposing a ``weight`` tensor and, optionally, a
            ``bias`` tensor (``bias`` may be missing or ``None``).
    """
    nn.init.xavier_uniform_(layer.weight)

    # Layers built with ``bias=False`` expose ``bias`` as None; some modules
    # may not define the attribute at all. Both cases are skipped.
    bias = getattr(layer, "bias", None)
    if bias is not None:
        bias.data.fill_(0.0)


def init_bn(bn: nn.Module) -> None:
    r"""Initialize a Batchnorm layer.

    Resets, in place, the bias and running mean to 0 and the weight and
    running variance to 1. Attributes that are ``None`` or absent are
    skipped, which is the case for layers built with ``affine=False`` or
    ``track_running_stats=False``.

    Args:
        bn: batch-norm style module.
    """
    targets = (
        ("bias", 0.0),
        ("weight", 1.0),
        ("running_mean", 0.0),
        ("running_var", 1.0),
    )
    for attr_name, fill_value in targets:
        tensor = getattr(bn, attr_name, None)
        if tensor is not None:
            tensor.data.fill_(fill_value)


class ConvBlockRes(nn.Module):
    r"""Pre-activation residual block: two BN -> conv stages plus a skip path."""

    def __init__(self, in_channels, out_channels, kernel_size, activation, momentum):
        r"""Residual block.

        Args:
            in_channels: number of channels of the input tensor.
            out_channels: number of channels produced by both convolutions.
            kernel_size: two-element sequence ``(k0, k1)``; padding is
                ``k // 2`` per axis, so odd kernels preserve spatial size.
            activation: stored on the instance; ``forward`` does not read it
                and always applies leaky ReLU with slope 0.01.
            momentum: momentum passed to both BatchNorm2d layers.
        """
        super().__init__()

        self.activation = activation
        padding = [kernel_size[0] // 2, kernel_size[1] // 2]
        self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
        # The attribute name ``abn2`` is part of the state_dict keys and is
        # therefore kept as is.
        self.abn2 = nn.BatchNorm2d(out_channels, momentum=momentum)

        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=(1, 1),
            dilation=(1, 1),
            padding=padding,
            bias=False,
        )

        self.conv2 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=(1, 1),
            dilation=(1, 1),
            padding=padding,
            bias=False,
        )

        # A 1x1 projection is only needed when the skip path has to change
        # the channel count to match the residual branch.
        self.is_shortcut = in_channels != out_channels
        if self.is_shortcut:
            self.shortcut = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(1, 1),
                stride=(1, 1),
                padding=(0, 0),
            )

        self.init_weights()

    def init_weights(self):
        r"""Initialize both batch norms and every convolution of the block."""
        init_bn(self.bn1)
        init_bn(self.abn2)
        init_layer(self.conv1)
        init_layer(self.conv2)

        if self.is_shortcut:
            init_layer(self.shortcut)

    def forward(self, x):
        r"""Return ``skip(x) + conv2(abn2(conv1(leaky_relu(bn1(x)))))``.

        ``skip`` is the 1x1 ``shortcut`` convolution when the channel count
        changes and the identity otherwise.
        """
        origin = x
        x = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01))
        # NOTE(review): no activation is applied after ``abn2``, unlike after
        # ``bn1``. Confirm this is intended before changing it, since adding
        # one would alter the behaviour of already-trained weights.
        x = self.conv2(self.abn2(x))

        if self.is_shortcut:
            return self.shortcut(origin) + x
        else:
            return origin + x


class EncoderBlockRes4B(nn.Module):
    r"""Four chained residual blocks followed by average pooling."""

    def __init__(
            self, in_channels, out_channels, kernel_size, downsample, activation, momentum
    ):
        r"""Encoder block, contains 8 convolutional layers.

        Args:
            in_channels: channel count fed to the first residual block.
            out_channels: channel count produced by every residual block.
            kernel_size: kernel size forwarded to each residual block.
            downsample: kernel size handed to ``F.avg_pool2d`` in ``forward``.
            activation: forwarded to each residual block.
            momentum: forwarded to each residual block.
        """
        super().__init__()

        # Only the first block changes the channel count; the remaining
        # three map out_channels -> out_channels.
        tail_args = (out_channels, kernel_size, activation, momentum)
        self.conv_block1 = ConvBlockRes(in_channels, *tail_args)
        self.conv_block2 = ConvBlockRes(out_channels, *tail_args)
        self.conv_block3 = ConvBlockRes(out_channels, *tail_args)
        self.conv_block4 = ConvBlockRes(out_channels, *tail_args)
        self.downsample = downsample

    def forward(self, x):
        r"""Return ``(pooled, features)``.

        ``features`` is the output of the fourth residual block and
        ``pooled`` is its average-pooled version.
        """
        features = x
        for block in (
            self.conv_block1,
            self.conv_block2,
            self.conv_block3,
            self.conv_block4,
        ):
            features = block(features)
        pooled = F.avg_pool2d(features, kernel_size=self.downsample)
        return pooled, features


class DecoderBlockRes4B(nn.Module):
    r"""Transposed-conv upsampling, skip concatenation, then four residual blocks."""

    def __init__(
            self, in_channels, out_channels, kernel_size, upsample, activation, momentum
    ):
        r"""Decoder block, contains 1 transpose convolutional and 8 convolutional layers.

        Args:
            in_channels: channel count of the tensor to be upsampled.
            out_channels: channel count after upsampling and after every
                residual block; the skip tensor must carry this many
                channels too, since the first residual block takes
                ``out_channels * 2`` inputs.
            kernel_size: kernel size forwarded to each residual block.
            upsample: used as both kernel size and stride of the transposed
                convolution.
            activation: stored and forwarded to each residual block.
            momentum: momentum for the batch norm and the residual blocks.
        """
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = upsample
        self.activation = activation

        # kernel_size == stride, so each spatial axis is scaled by exactly
        # its stride factor.
        self.conv1 = nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=self.stride,
            stride=self.stride,
            padding=(0, 0),
            bias=False,
            dilation=(1, 1),
        )

        self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)

        block_args = (out_channels, kernel_size, activation, momentum)
        # The first residual block sees the upsampled tensor concatenated
        # with the skip tensor, hence twice the channels.
        self.conv_block2 = ConvBlockRes(out_channels * 2, *block_args)
        self.conv_block3 = ConvBlockRes(out_channels, *block_args)
        self.conv_block4 = ConvBlockRes(out_channels, *block_args)
        self.conv_block5 = ConvBlockRes(out_channels, *block_args)

        self.init_weights()

    def init_weights(self):
        r"""Initialize the batch norm and the transposed convolution."""
        init_bn(self.bn1)
        init_layer(self.conv1)

    def forward(self, input_tensor, concat_tensor):
        r"""Upsample ``input_tensor``, concatenate ``concat_tensor`` on dim 1, refine."""
        upsampled = self.conv1(F.relu_(self.bn1(input_tensor)))
        merged = torch.cat((upsampled, concat_tensor), dim=1)
        for block in (
            self.conv_block2,
            self.conv_block3,
            self.conv_block4,
            self.conv_block5,
        ):
            merged = block(merged)
        return merged


class ByteSep(nn.Module):
    r"""Residual U-Net mapping a spectrogram to a masked spectrogram and phase terms.

    The network has six encoder blocks, four bottleneck blocks, six decoder
    blocks with skip connections, and a final 1x1 convolution producing ``K``
    channels, of which ``forward`` consumes the first four.
    """

    def __init__(self, input_channels, in_size, hop_length, K=4):
        r"""Build the network.

        Args:
            input_channels: channel count expected by the first encoder block.
            in_size: ``bn0`` is created with ``in_size + 1`` features, so the
                last dimension of the ``forward`` input must be
                ``in_size + 1``. ``forward`` drops the last bin, leaving
                ``in_size`` bins, which should be divisible by 64 because the
                encoder halves that axis six times.
            hop_length: not used by this module; kept so existing callers
                keep working.
            K: number of output channels of the final 1x1 convolution.
                ``forward`` reads channels 0 to 3, so ``K`` must be at least
                4; any extra channels are ignored.
        """
        super().__init__()

        activation = 'leaky_relu'
        momentum = 0.01

        # Downsample rate along the time axis.
        # This number equals 2^{#encoder_blocks that halve the time axis}.
        self.time_downsample_ratio = 2 ** 5
        self.bn0 = nn.BatchNorm2d(in_size + 1, momentum=momentum)

        def encoder(in_ch, out_ch, downsample):
            # All encoder-style blocks share kernel size, activation, momentum.
            return EncoderBlockRes4B(
                in_channels=in_ch,
                out_channels=out_ch,
                kernel_size=(3, 3),
                downsample=downsample,
                activation=activation,
                momentum=momentum,
            )

        def decoder(in_ch, out_ch, upsample):
            # All decoder blocks share kernel size, activation, momentum.
            return DecoderBlockRes4B(
                in_channels=in_ch,
                out_channels=out_ch,
                kernel_size=(3, 3),
                upsample=upsample,
                activation=activation,
                momentum=momentum,
            )

        # Construction order is kept stable so that parameter registration
        # (state_dict order) and random initialization order do not change.
        self.encoder_block1 = encoder(input_channels, 32, (2, 2))
        self.encoder_block2 = encoder(32, 64, (2, 2))
        self.encoder_block3 = encoder(64, 128, (2, 2))
        self.encoder_block4 = encoder(128, 256, (2, 2))
        self.encoder_block5 = encoder(256, 384, (2, 2))
        # The sixth block only halves the last (frequency) axis.
        self.encoder_block6 = encoder(384, 384, (1, 2))

        # Bottleneck: no further downsampling.
        self.conv_block7a = encoder(384, 384, (1, 1))
        self.conv_block7b = encoder(384, 384, (1, 1))
        self.conv_block7c = encoder(384, 384, (1, 1))
        self.conv_block7d = encoder(384, 384, (1, 1))

        self.decoder_block1 = decoder(384, 384, (1, 2))
        self.decoder_block2 = decoder(384, 384, (2, 2))
        self.decoder_block3 = decoder(384, 256, (2, 2))
        self.decoder_block4 = decoder(256, 128, (2, 2))
        self.decoder_block5 = decoder(128, 64, (2, 2))
        self.decoder_block6 = decoder(64, 32, (2, 2))

        self.after_conv_block1 = encoder(32, 32, (1, 1))

        self.after_conv2 = nn.Conv2d(
            in_channels=32,
            out_channels=K,
            kernel_size=(1, 1),
            stride=(1, 1),
            padding=(0, 0),
            bias=True,
        )

        self.init_weights()

    def init_weights(self):
        r"""Initialize the input batch norm and the output convolution."""
        init_bn(self.bn0)
        init_layer(self.after_conv2)

    def forward(self, spec_m):
        r"""Run the U-Net on a spectrogram.

        Args:
            spec_m: tensor of shape ``(bs, input_channels, T, in_size + 1)``.
                ``T`` may be any positive length; it is zero-padded
                internally to a multiple of ``time_downsample_ratio`` and
                the outputs are cropped back to ``T``.

        Returns:
            Tuple ``(out_spec, mask_cos, mask_sin)``:
            ``out_spec`` is ``relu(spec_m.detach() * mask_spec + linear_spec)``
            where ``mask_spec`` is the sigmoid of output channel 0 and
            ``linear_spec`` is output channel 3; ``mask_cos`` and ``mask_sin``
            are the second and third values returned by ``magphase`` applied
            to the tanh of output channels 1 and 2.
        """
        # Batch normalize on individual frequency bins.
        x = spec_m.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)

        # Pad the time axis so that it is evenly divided by the downsample
        # ratio; otherwise the pooled/upsampled lengths would not match the
        # skip connections. Nothing is done when it already divides evenly.
        origin_len = x.shape[2]
        pad_len = -origin_len % self.time_downsample_ratio
        if pad_len:
            x = F.pad(x, pad=(0, 0, 0, pad_len))

        # Drop the last frequency bin so the axis can be halved repeatedly.
        x = x[..., 0: x.shape[-1] - 1]  # (bs, channels, T, F)

        # UNet
        (x1_pool, x1) = self.encoder_block1(x)  # x1_pool: (bs, 32, T / 2, F / 2)
        (x2_pool, x2) = self.encoder_block2(x1_pool)  # x2_pool: (bs, 64, T / 4, F / 4)
        (x3_pool, x3) = self.encoder_block3(x2_pool)  # x3_pool: (bs, 128, T / 8, F / 8)
        (x4_pool, x4) = self.encoder_block4(x3_pool)  # x4_pool: (bs, 256, T / 16, F / 16)
        (x5_pool, x5) = self.encoder_block5(x4_pool)  # x5_pool: (bs, 384, T / 32, F / 32)
        (x6_pool, x6) = self.encoder_block6(x5_pool)  # x6_pool: (bs, 384, T / 32, F / 64)
        (x_center, _) = self.conv_block7a(x6_pool)  # (bs, 384, T / 32, F / 64)
        (x_center, _) = self.conv_block7b(x_center)  # (bs, 384, T / 32, F / 64)
        (x_center, _) = self.conv_block7c(x_center)  # (bs, 384, T / 32, F / 64)
        (x_center, _) = self.conv_block7d(x_center)  # (bs, 384, T / 32, F / 64)
        x7 = self.decoder_block1(x_center, x6)  # (bs, 384, T / 32, F / 32)
        x8 = self.decoder_block2(x7, x5)  # (bs, 384, T / 16, F / 16)
        x9 = self.decoder_block3(x8, x4)  # (bs, 256, T / 8, F / 8)
        x10 = self.decoder_block4(x9, x3)  # (bs, 128, T / 4, F / 4)
        x11 = self.decoder_block5(x10, x2)  # (bs, 64, T / 2, F / 2)
        x12 = self.decoder_block6(x11, x1)  # (bs, 32, T, F)
        (x, _) = self.after_conv_block1(x12)  # (bs, 32, T, F)

        x = self.after_conv2(x)  # (bs, K, T, F)
        x = F.pad(x, pad=(0, 1))  # Pad frequency, e.g., 1024 -> 1025.
        if pad_len:
            # Undo the time padding so outputs line up with ``spec_m``.
            x = x[:, :, 0:origin_len, :]

        mask_spec = torch.sigmoid(x[:, 0, :, :].unsqueeze(1))
        _mask_real = torch.tanh(x[:, 1, :, :].unsqueeze(1))
        _mask_imag = torch.tanh(x[:, 2, :, :].unsqueeze(1))
        _, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
        linear_spec = x[:, 3, :, :].unsqueeze(1)
        # ``detach`` keeps gradients from flowing into the input through the
        # multiplicative path.
        out_spec = F.relu(spec_m.detach() * mask_spec + linear_spec)
        return out_spec, mask_cos, mask_sin

    # def forward(self, audio_m):
    #     spec_m, cos_m, sin_m = self.stft(audio_m)
    #     # Batch normalize on individual frequency bins.
    #     x = spec_m.transpose(1, 3)
    #     x = self.bn0(x)
    #     x = x.transpose(1, 3)
    #     x = x[..., 0: x.shape[-1] - 1]  # (bs, channels, T, F)
    #
    #     # UNet
    #     (x1_pool, x1) = self.encoder_block1(x)  # x1_pool: (bs, 32, T / 2, F / 2)
    #     (x2_pool, x2) = self.encoder_block2(x1_pool)  # x2_pool: (bs, 64, T / 4, F / 4)
    #     (x3_pool, x3) = self.encoder_block3(x2_pool)  # x3_pool: (bs, 128, T / 8, F / 8)
    #     (x4_pool, x4) = self.encoder_block4(x3_pool)  # x4_pool: (bs, 256, T / 16, F / 16)
    #     (x5_pool, x5) = self.encoder_block5(x4_pool)  # x5_pool: (bs, 384, T / 32, F / 32)
    #     (x6_pool, x6) = self.encoder_block6(x5_pool)  # x6_pool: (bs, 384, T / 32, F / 64)
    #     (x_center, _) = self.conv_block7a(x6_pool)  # (bs, 384, T / 32, F / 64)
    #     (x_center, _) = self.conv_block7b(x_center)  # (bs, 384, T / 32, F / 64)
    #     (x_center, _) = self.conv_block7c(x_center)  # (bs, 384, T / 32, F / 64)
    #     (x_center, _) = self.conv_block7d(x_center)  # (bs, 384, T / 32, F / 64)
    #     x7 = self.decoder_block1(x_center, x6)  # (bs, 384, T / 32, F / 32)
    #     x8 = self.decoder_block2(x7, x5)  # (bs, 384, T / 16, F / 16)
    #     x9 = self.decoder_block3(x8, x4)  # (bs, 256, T / 8, F / 8)
    #     x10 = self.decoder_block4(x9, x3)  # (bs, 128, T / 4, F / 4)
    #     x11 = self.decoder_block5(x10, x2)  # (bs, 64, T / 2, F / 2)
    #     x12 = self.decoder_block6(x11, x1)  # (bs, 32, T, F)
    #     (x, _) = self.after_conv_block1(x12)  # (bs, 32, T, F)
    #
    #     x = self.after_conv2(x)  # (bs, channels * 3, T, F)
    #     x = F.pad(x, pad=(0, 1))  # Pad frequency, e.g., 1024 -> 1025.
    #     mask_spec = torch.sigmoid(x[:, 0, :, :].unsqueeze(1))
    #     _mask_real = torch.tanh(x[:, 1, :, :].unsqueeze(1))
    #     _mask_imag = torch.tanh(x[:, 2, :, :].unsqueeze(1))
    #     _, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
    #     out_cos = cos_m * mask_cos - sin_m * mask_sin
    #     out_sin = sin_m * mask_cos + cos_m * mask_sin
    #     linear_spec = x[:, 3, :, :].unsqueeze(1)
    #     out_spec = F.relu(spec_m.detach() * mask_spec + linear_spec)
    #     out_real, out_imag = out_spec * out_cos, out_spec * out_sin
    #     out_audio = self.istft(out_real, out_imag, audio_m.shape[-1])
    #     return out_spec.squeeze(1), out_audio
