import torch.nn as nn
from torchlibrosa.stft import STFT, ISTFT, magphase


class Wav2Spec(nn.Module):
    """Turn waveforms into an STFT magnitude plus phase cosine/sine.

    Args:
        hop_length: Hop size handed to the STFT; also the number of trailing
            samples dropped from every waveform before the transform.
        window_size: Used both as the FFT size and as the window length.
    """

    def __init__(self, hop_length, window_size):
        super().__init__()
        self.hop_length = hop_length
        self.stft = STFT(window_size, hop_length, window_size)

    def forward(self, audio):
        """Compute the spectrogram of ``audio``.

        Args:
            audio: Tensor whose last axis holds the samples. All leading axes
                are merged into a single batch axis before the transform.

        Returns:
            The ``(spec, cos, sin)`` triple produced by ``magphase`` from the
            real and imaginary STFT outputs.
        """
        num_samples = audio.shape[-1]
        waveforms = audio.reshape(-1, num_samples)
        # Drop the final hop so the STFT sees num_samples - hop_length samples.
        # NOTE(review): presumably done to get a specific frame count - confirm.
        trimmed = waveforms[:, :-self.hop_length]
        real, imag = self.stft(trimmed)
        magnitude, cos_phase, sin_phase = magphase(real, imag)
        return magnitude, cos_phase, sin_phase


class Spec2Wav(nn.Module):
    """Rebuild a waveform from the real and imaginary parts of an STFT.

    Args:
        hop_length: Hop size handed to the inverse STFT.
        window_size: Used both as the FFT size and as the window length.
    """

    def __init__(self, hop_length, window_size):
        super().__init__()
        self.istft = ISTFT(window_size, hop_length, window_size)

    def forward(self, real, imag, audio_len):
        """Run the inverse STFT.

        Args:
            real: Real part of the spectrogram.
            imag: Imaginary part of the spectrogram.
            audio_len: Passed through as the third ISTFT argument
                (presumably the target number of output samples).

        Returns:
            Whatever the wrapped ISTFT module returns for these inputs.
        """
        return self.istft(real, imag, audio_len)


# class DW(nn.Module):
#     def __init__(self, in_channel, hidden_channel, hidden_size, pe_size, svs_size):
#         super(DW, self).__init__()
#         self.model = nn.Sequential(
#             nn.Conv2d(in_channel, hidden_channel, (3, 3), (1, 1), (1, 1)),
#             nn.ReLU(),
#             nn.Conv2d(hidden_channel, hidden_channel, (3, 3), (1, 1), (1, 1)),
#             nn.ReLU(),
#             nn.Conv2d(hidden_channel, 1, (3, 3), (1, 1), (1, 1)),
#             nn.ReLU()
#         )
#         self.dw_svs = nn.Sequential(
#             nn.Linear(hidden_size, svs_size),
#             nn.Sigmoid()
#             # nn.ReLU(),
#         )
#         self.dw_pe = nn.Sequential(
#             nn.Linear(hidden_size, pe_size),
#             nn.Sigmoid()
#             # nn.ReLU(),
#         )
#
#     def forward(self, in_feature, scale=1):
#         out = self.model(in_feature)
#         dw_pe = self.dw_pe(out.squeeze(1)) * scale
#         dw_svs = self.dw_svs(out.squeeze(1)) * scale
#         return dw_pe, dw_svs


class DW(nn.Module):
    """Two-layer conv trunk feeding two sigmoid-activated MLP heads.

    The trunk keeps the spatial size (3x3 kernels, stride 1, padding 1). Its
    output is rearranged so that channels and the last axis are merged into
    one feature vector per position along axis 2 of the input, and each head
    maps that vector to its own output size.

    Args:
        in_channel: Number of input channels of the first convolution.
        hidden_channel: Number of channels produced by both convolutions.
        hidden_size: Size of the last input axis; the heads expect
            ``hidden_size * hidden_channel`` features.
        pe_size: Output width of the ``dw_pe`` head.
        svs_size: Output width of the ``dw_svs`` head.
    """

    def __init__(self, in_channel, hidden_channel, hidden_size, pe_size, svs_size):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.model = nn.Sequential(
            nn.Conv2d(in_channel, hidden_channel, **conv_kwargs),
            nn.ReLU(),
            nn.Conv2d(hidden_channel, hidden_channel, **conv_kwargs),
            nn.ReLU(),
        )
        flat_features = hidden_size * hidden_channel
        # The svs head is built before the pe head so that parameter
        # initialisation consumes the RNG in the same order as before.
        self.dw_svs = self._build_head(flat_features, svs_size)
        self.dw_pe = self._build_head(flat_features, pe_size)

    @staticmethod
    def _build_head(in_features, out_features):
        """Return a Linear -> Dropout -> ReLU -> Linear -> Sigmoid stack."""
        return nn.Sequential(
            nn.Linear(in_features, 2048),
            nn.Dropout(0.1),
            nn.ReLU(),
            nn.Linear(2048, out_features),
            nn.Sigmoid(),
        )

    def forward(self, in_feature, scale=2):
        """Compute both head outputs, each multiplied by ``scale``.

        Args:
            in_feature: Input to the conv trunk; assumed to be a 4-D
                ``(batch, in_channel, steps, hidden_size)`` tensor.
            scale: Multiplier applied to the sigmoid outputs, so values lie
                between 0 and ``scale``.

        Returns:
            Tuple ``(dw_pe, dw_svs)``.
        """
        feature_map = self.model(in_feature)
        # Swap axes 1 and 2, then merge everything from axis 2 onward.
        per_step = feature_map.transpose(1, 2).flatten(start_dim=2)
        dw_pe = self.dw_pe(per_step) * scale
        dw_svs = self.dw_svs(per_step) * scale
        return dw_pe, dw_svs


class DW_FC(nn.Module):
    """Fully connected trunk feeding two sigmoid-activated linear heads.

    Args:
        in_size: Number of input features expected on the last axis.
        pe_size: Output width of the ``dw_pe`` head.
        svs_size: Output width of the ``dw_svs`` head.
    """

    def __init__(self, in_size, pe_size, svs_size):
        super().__init__()
        trunk_width = 2048
        self.model = nn.Sequential(
            nn.Linear(in_size, 1024),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(1024, trunk_width),
            nn.ReLU(),
        )
        # The svs head is created before the pe head so that parameter
        # initialisation consumes the RNG in the same order as before.
        self.dw_svs = nn.Sequential(nn.Linear(trunk_width, svs_size), nn.Sigmoid())
        self.dw_pe = nn.Sequential(nn.Linear(trunk_width, pe_size), nn.Sigmoid())

    def forward(self, in_feature, scale=1):
        """Compute both head outputs, each multiplied by ``scale``.

        Args:
            in_feature: Tensor whose last axis has ``in_size`` features.
            scale: Multiplier applied to the sigmoid outputs, so values lie
                between 0 and ``scale``.

        Returns:
            Tuple ``(dw_pe, dw_svs)``.
        """
        shared = self.model(in_feature)
        dw_pe = self.dw_pe(shared) * scale
        dw_svs = self.dw_svs(shared) * scale
        return dw_pe, dw_svs
