import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from AdversarialClassifier import AdversarialClassifier


class LinearNorm(torch.nn.Module):
    """Fully connected layer whose weight is Xavier-uniform initialised.

    Args:
        in_dim: size of the last dimension of the input.
        out_dim: size of the last dimension of the output.
        bias: whether the wrapped ``Linear`` layer has an additive bias.
        w_init_gain: nonlinearity name handed to
            ``torch.nn.init.calculate_gain`` to scale the initialisation.
    """

    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super().__init__()
        projection = nn.Linear(in_dim, out_dim, bias=bias)
        # Replace the default weight init with Xavier-uniform scaled by
        # the gain of the requested nonlinearity; the bias is untouched.
        init_gain = nn.init.calculate_gain(w_init_gain)
        nn.init.xavier_uniform_(projection.weight, gain=init_gain)
        self.linear_layer = projection

    def forward(self, x):
        """Return the linear projection of ``x``."""
        return self.linear_layer(x)


class ConvNorm(torch.nn.Module):
    """1-D convolution whose weight is Xavier-uniform initialised.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: width of the convolution kernel.
        stride: convolution stride.
        padding: explicit padding; when ``None`` it is derived as
            ``dilation * (kernel_size - 1) / 2``, which requires an odd
            ``kernel_size``.
        dilation: spacing between kernel taps.
        bias: whether the convolution has an additive bias.
        w_init_gain: nonlinearity name handed to
            ``torch.nn.init.calculate_gain`` to scale the initialisation.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
        super().__init__()

        if padding is None:
            # Symmetric padding is only well defined for odd kernels.
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)

        conv_kwargs = dict(kernel_size=kernel_size, stride=stride,
                           padding=padding, dilation=dilation, bias=bias)
        self.conv = nn.Conv1d(in_channels, out_channels, **conv_kwargs)

        init_gain = nn.init.calculate_gain(w_init_gain)
        nn.init.xavier_uniform_(self.conv.weight, gain=init_gain)

    def forward(self, signal):
        """Return the convolution of ``signal``."""
        return self.conv(signal)



class Encoder(nn.Module):
    """Encoder: LSTM -> three conv/batch-norm blocks -> LSTM -> projection.

    Args:
        dim_mel: feature size of the input frames.
        dim_content: number of output features reserved for content.
        dim_speaker: number of output features reserved for the speaker.
        dim_pre: hidden size of the first LSTM and channel count of the
            convolution stack.
    """

    def __init__(self, dim_mel, dim_content, dim_speaker, dim_pre):
        super().__init__()

        self.lstm1 = nn.LSTM(dim_mel, dim_pre, 1, batch_first=True)

        # Three identical conv + batch-norm blocks operating on dim_pre channels.
        self.convolutions = nn.ModuleList([
            nn.Sequential(
                ConvNorm(dim_pre, dim_pre,
                         kernel_size=5, stride=1, padding=2,
                         dilation=1, w_init_gain='relu'),
                nn.BatchNorm1d(dim_pre),
            )
            for _ in range(3)
        ])

        self.lstm2 = nn.LSTM(dim_pre, 1024, 2, batch_first=True)

        # A single projection emits content and speaker features side by side.
        self.linear_projection = LinearNorm(1024, dim_content + dim_speaker)

    def forward(self, x):
        """Encode ``x``.

        ``x`` is batch-first, (batch, length, dim_mel); the result is
        (batch, length, dim_content + dim_speaker).
        """
        hidden, _ = self.lstm1(x)

        # Conv1d expects channels before time, so swap the last two axes
        # around the convolution stack.
        hidden = hidden.transpose(1, 2)
        for block in self.convolutions:
            hidden = F.relu(block(hidden))
        hidden = hidden.transpose(1, 2)

        hidden, _ = self.lstm2(hidden)
        return self.linear_projection(hidden)


class Decoder(nn.Module):
    """Decoder: LSTM -> three conv/batch-norm blocks -> LSTM -> projection.

    Args:
        dim_neck: feature size of the content part of the input.
        dim_emb: feature size of the speaker part of the input.
        dim_pre: hidden size of the first LSTM and channel count of the
            convolution stack.
    """

    def __init__(self, dim_neck, dim_emb, dim_pre):
        super().__init__()

        # The input is the concatenation of content and speaker features.
        self.lstm1 = nn.LSTM(dim_neck + dim_emb, dim_pre, 1, batch_first=True)

        # Three identical conv + batch-norm blocks operating on dim_pre channels.
        self.convolutions = nn.ModuleList([
            nn.Sequential(
                ConvNorm(dim_pre, dim_pre,
                         kernel_size=5, stride=1, padding=2,
                         dilation=1, w_init_gain='relu'),
                nn.BatchNorm1d(dim_pre),
            )
            for _ in range(3)
        ])

        self.lstm2 = nn.LSTM(dim_pre, 1024, 2, batch_first=True)

        # Output feature size is fixed at 80.
        self.linear_projection = LinearNorm(1024, 80)

    def forward(self, x):
        """Decode ``x``.

        ``x`` is batch-first, (batch, length, dim_neck + dim_emb); the
        result is (batch, length, 80).
        """
        hidden, _ = self.lstm1(x)

        # Conv1d expects channels before time, so swap the last two axes
        # around the convolution stack.
        hidden = hidden.transpose(1, 2)
        for block in self.convolutions:
            hidden = F.relu(block(hidden))
        hidden = hidden.transpose(1, 2)

        hidden, _ = self.lstm2(hidden)
        return self.linear_projection(hidden)


class Postnet(nn.Module):
    """Postnet
        - Five 1-d convolutions (80 -> 512 -> 512 -> 512 -> 512 -> 80
          channels), kernel size 5, each followed by batch normalisation
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, init gain) for each layer, in order.
        layer_specs = (
            [(80, 512, 'tanh')]
            + [(512, 512, 'tanh')] * 3
            + [(512, 80, 'linear')]
        )

        self.convolutions = nn.ModuleList()
        for n_in, n_out, gain in layer_specs:
            block = nn.Sequential(
                ConvNorm(n_in, n_out,
                         kernel_size=5, stride=1, padding=2,
                         dilation=1, w_init_gain=gain),
                nn.BatchNorm1d(n_out),
            )
            self.convolutions.append(block)

    def forward(self, x):
        """Run ``x`` through the stack.

        ``tanh`` follows every block except the last, whose output is
        returned as is.
        """
        *hidden_blocks, output_block = self.convolutions
        for block in hidden_blocks:
            x = torch.tanh(block(x))
        return output_block(x)


class Generator(nn.Module):
    """
    Generator network.

    The encoder output is split along its last dimension into a content
    part (first ``dim_neck`` features) and a speaker part (remaining
    ``dim_emb`` features). The speaker part is collapsed over dimension 1
    with an L2 norm, broadcast back to the source length, concatenated with
    the content part and decoded; the postnet output is added to the decoder
    output as a residual.

    dim_mel: dimension of mel
    dim_neck:dimension of content embedding
    dim_emb: dimension of speaker embedding
    dim_pre: dimension of hidden code.
    """

    def __init__(self, dim_mel, dim_neck, dim_emb, dim_pre):
        super(Generator, self).__init__()

        self.encoder = Encoder(dim_mel, dim_neck, dim_emb, dim_pre)
        self.decoder = Decoder(dim_neck, dim_emb, dim_pre)
        self.postnet = Postnet()
        self.dim_emb = dim_emb
        self.dim_neck = dim_neck
        # adversarial1 is applied to the content embedding and adversarial2
        # to the speaker embedding (see forward).
        # NOTE(review): the meaning of the second positional argument and of
        # num_classes=99 is defined by AdversarialClassifier - confirm there.
        self.adversarial1 = AdversarialClassifier(dim_neck, dim_neck, num_classes=99)
        self.adversarial2 = AdversarialClassifier(dim_emb, dim_neck, num_classes=99)

    def _split(self, encoded):
        """Split an encoder output into (content, speaker) along the last dim."""
        return torch.split(encoded, (self.dim_neck, self.dim_emb), dim=-1)

    def _decode(self, content, spk_embedding, length):
        """Decode content plus a length-1 speaker embedding broadcast to ``length``.

        Returns ``(mel_outputs, mel_outputs_postnet)``.
        """
        decoder_input = torch.cat(
            (content, spk_embedding.expand(-1, length, -1)), dim=-1)
        mel_outputs = self.decoder(decoder_input)
        # Postnet works channels-first; its output is a residual correction.
        residual = self.postnet(mel_outputs.transpose(2, 1))
        mel_outputs_postnet = mel_outputs + residual.transpose(2, 1)
        return mel_outputs, mel_outputs_postnet

    def forward(self, x1, x2=None, dim_neck=32):
        """Reconstruct ``x1``, or re-synthesise it with the speaker of ``x2``.

        Args:
            x1: source input, batch-first (batch, length, dim_mel).
            x2: optional target-speaker input with the same batch size as
                ``x1``; its length may differ. When given, the speaker
                embedding is taken from ``x2`` instead of ``x1``.
            dim_neck: unused; kept for backward compatibility. The split
                sizes come from the value passed to the constructor.

        Returns:
            If ``x2`` is None: the tuple ``(mel_outputs, mel_outputs_postnet,
            z, spk_dim_predict, content_dim_predict)`` where ``z`` is the raw
            encoder output of ``x1``.
            Otherwise: ``mel_outputs_postnet`` only.
        """
        # Fix: the source must be encoded from the ``x1`` argument. The
        # previous code referenced an undefined name ``x`` and only ran when
        # a module-level global of that name happened to exist.
        z = self.encoder(x1)
        content_embedding_source, spk_embedding_source = self._split(z)

        if x2 is None:
            # Collapse dimension 1 to a single vector per utterance.
            spk_embedding = torch.norm(spk_embedding_source, dim=1, keepdim=True)
            mel_outputs, mel_outputs_postnet = self._decode(
                content_embedding_source, spk_embedding, x1.size(1))
            spk_dim_predict = self.adversarial2(spk_embedding)
            content_dim_predict = self.adversarial1(content_embedding_source)
            return mel_outputs, mel_outputs_postnet, z, spk_dim_predict, content_dim_predict

        # NOTE(review): the debug prints below are retained so stdout is
        # unchanged; consider switching them to logging.
        print("inference phase")
        encoder_output_target = self.encoder(x2)
        spk_embedding_target = self._split(encoder_output_target)[1]
        spk_embedding_target = torch.norm(spk_embedding_target, dim=1, keepdim=True)
        print("test shape:content shape:{}, spk shape: {}".format(content_embedding_source.shape, spk_embedding_target.shape))
        _, mel_outputs_postnet = self._decode(
            content_embedding_source, spk_embedding_target, x1.size(1))
        return mel_outputs_postnet

if __name__ == "__main__":
    # Smoke test: one reconstruction pass and two conversion passes whose
    # target inputs are longer and shorter than the source.
    x = torch.rand(4, 128, 80)
    x2 = torch.rand(4, 193, 80)
    x3 = torch.rand(4, 67, 80)

    dim_mel, dim_neck, dim_emb, dim_pre = 80, 32, 256, 512

    generator = Generator(dim_mel, dim_neck, dim_emb, dim_pre)

    # Reconstruction path (no target input).
    mel, mel_postnet, encoded, spk_pred, content_pred = generator(
        x, dim_neck=dim_neck)

    # Conversion path with two different target lengths.
    test1 = generator(x, x2, dim_neck)
    test2 = generator(x, x3, dim_neck)

    print(mel.shape)
    print("test shape:", test1.shape, test2.shape)


