import torch.nn as nn
import torchvision
from nnAudio.Spectrogram import STFT
import torch.nn.functional as F

class AudioModel(nn.Module):
    """ResNet-18 encoder that embeds audio via an STFT magnitude spectrogram.

    The input is converted to a magnitude spectrogram by nnAudio's ``STFT``.
    It is then treated as a single-channel image and passed through the
    ResNet-18 trunk (stem, layer1..layer4, average pool). The ResNet
    classification head (``resnet.fc``) is never attached to this module.
    ``forward`` returns the flattened pooled features.

    Args:
        n_fft: FFT size of the STFT.
        win_length: Analysis window length in samples; must not exceed
            ``n_fft``.
        hop_length: Stride between successive STFT frames, in samples.
        sr: Sample rate passed to ``STFT``. With the defaults (400 / 160
            samples at 16000 Hz) the window spans 25 ms and the hop 10 ms,
            assuming the input is really sampled at ``sr``.

    Raises:
        ValueError: If ``win_length`` exceeds ``n_fft`` or ``hop_length`` is
            not positive.
    """

    def __init__(self, *, n_fft=512, win_length=400, hop_length=160, sr=16000):
        super().__init__()
        if win_length > n_fft:
            raise ValueError(
                f"win_length ({win_length}) must not exceed n_fft ({n_fft})"
            )
        if hop_length < 1:
            raise ValueError(f"hop_length must be positive, got {hop_length}")

        # num_classes only sizes resnet.fc, which is never attached to this
        # module. It is left at 256 so that seeded weight initialisation stays
        # identical to earlier versions of this class.
        resnet = torchvision.models.resnet18(num_classes=256)

        # Stem: a freshly initialised 1-input-channel 7x7 stride-2 conv stands
        # in for resnet.conv1, so the trunk accepts a single-channel
        # spectrogram. bias=False because BatchNorm (resnet.bn1) follows
        # immediately. resnet.bn1/relu/maxpool are reused unchanged.
        # Attribute names below are kept as-is: they form the state_dict keys.
        self.AudioInconv = nn.Sequential(
            nn.Conv2d(1, 64, stride=2, kernel_size=7, padding=3, bias=False),
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
        )
        self.AudioLayer1 = resnet.layer1
        self.AudioLayer2 = resnet.layer2
        self.AudioLayer3 = resnet.layer3
        self.AudioLayer4 = resnet.layer4
        self.AudioAvgpool = resnet.avgpool

        self.stft = STFT(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            sr=sr,
            output_format='Magnitude',
        )

    def forward(self, audio):
        """Embed a batch of waveforms.

        Args:
            audio: Waveform tensor accepted by ``self.stft``. Presumably
                ``(batch, samples)``; TODO confirm against callers.

        Returns:
            Tensor of shape ``(batch, C)``, where ``C`` is layer4's channel
            count (expected 512 for torchvision's resnet18; confirm).
        """
        # Magnitude spectrogram. unsqueeze(1) inserts the channel axis that the
        # 1-channel stem expects, presumably giving (batch, 1, freq, frames).
        x = self.stft(audio).unsqueeze(1)
        x = self.AudioInconv(x)
        x = self.AudioLayer1(x)
        x = self.AudioLayer2(x)
        x = self.AudioLayer3(x)
        x = self.AudioLayer4(x)
        x = self.AudioAvgpool(x)
        # Collapse every dimension after the batch axis into one feature vector.
        return x.flatten(1)
