import math

import einops
import librosa
import numpy as np
import torch
import torch.nn as nn
import torchaudio.functional
from PIL.Image import Image
from matplotlib import pyplot as plt
from scipy import signal

from CrossModal.R2Plus1D import R2Plus1DClassifier

class AudioNet(nn.Module):
    """Audio model that runs log-STFT spectrograms through an R(2+1)D backbone.

    Each time step of the input is turned into a log-magnitude spectrogram by
    the module-level ``stft`` helper. The spectrograms are then rearranged into
    a single-channel 5-D tensor ``(B, 1, self.time, H, W)`` for
    ``R2Plus1DClassifier``.

    Args:
        n_fft: forwarded to ``stft`` (which uses an FFT size of ``2 * n_fft``).
        frame: stored as ``self.frame``; not read anywhere in this class.
        auto_mode: accepted for interface compatibility; currently ignored.
        num_classes: stored as ``self.num_classes``.
            NOTE(review): the backbone is constructed with a literal ``1`` as
            its first argument rather than ``num_classes`` — confirm intent.
        audio_only: when False, ``forward`` returns the backbone's two outputs.
        direct_classify: with ``audio_only`` True, ``forward`` returns only the
            backbone's first output.
        time: size of the T axis the spectrogram width is split into before
            the backbone.
    """

    def __init__(self, n_fft, frame, auto_mode=False, num_classes=1, audio_only=False, direct_classify=False, time=10):
        super().__init__()
        # Plain configuration attributes (read by ``forward`` where relevant).
        self.n_fft = n_fft
        self.frame = frame
        self.time = time
        self.num_classes = num_classes
        self.audio_only = audio_only
        self.classify = direct_classify
        self.core_net = R2Plus1DClassifier(1, (2, 2, 2, 2), channel=1)

    def forward(self, x):
        """Run the spectrogram front-end and the backbone.

        Args:
            x: 3-D tensor unpacked as ``(B, T, D)``; presumably ``D`` is raw
                waveform samples per step (TODO confirm against callers).

        Returns:
            * ``audio_only`` False: the backbone's ``(output, fea)`` pair.
            * ``audio_only`` and ``direct_classify``: the backbone's first output.
            * ``audio_only`` only: the last element of the backbone's output,
              averaged over dims 2 and 3.
        """
        batch, steps, _ = x.shape
        # One spectrogram per (batch, step) row.
        spec = stft(einops.rearrange(x, 'B T D -> (B T) D'), self.n_fft)
        # Concatenate the per-step spectrograms along the width axis ...
        spec = einops.rearrange(spec, '(B T) H (W C)-> B H (W T) C', B=batch, T=steps, C=1)
        # ... then re-split that width into ``self.time`` slices for the 3-D backbone.
        clip = einops.rearrange(spec, 'B H (W T) C -> B C T H W ', T=self.time)

        if self.audio_only and not self.classify:
            backbone_out = self.core_net(clip)
            return backbone_out[-1].mean(dim=[2, 3])

        logits, fea = self.core_net(clip)
        if self.audio_only:
            return logits
        return logits, fea


def generate_filter(start, end, n_fft, frame):
    """Build an ``n_fft`` x ``frame`` band mask as nested Python lists.

    Row ``i`` is filled with ``1.`` when ``start <= i <= end`` (both bounds
    inclusive) and with ``0.`` otherwise; each row has ``frame`` entries.
    """
    mask = []
    for row in range(n_fft):
        outside_band = row > end or row < start
        mask.append([0. if outside_band else 1.] * frame)
    return mask


def norm_sigma(x):
    """Squash tensor ``x`` element-wise into (-1, 1) as ``2 * sigmoid(x) - 1``.

    In floating point the endpoints can be reached when the sigmoid saturates.
    """
    squashed = torch.sigmoid(x)
    return squashed * 2. - 1.


def extract_frequency_indices(low_freq, high_freq, sample_rate, n_fft):
    """Map a frequency band in Hz onto FFT bin indices.

    Bins are treated as ``sample_rate / n_fft`` Hz wide. The lower edge is
    truncated toward zero with ``int`` and the upper edge is rounded up with
    ``math.ceil``, so for non-negative frequencies the bin range never narrows
    the requested band.

    NOTE(review): ``stft`` in this module uses an FFT size of ``2 * n_fft``;
    check which size callers pass here.

    Returns:
        ``(low_bin, high_bin)`` as ints.
    """
    bin_width = sample_rate / n_fft
    return int(low_freq / bin_width), math.ceil(high_freq / bin_width)


def stft(input, n_fft):
    """Return the log-magnitude short-time Fourier transform of ``input``.

    The FFT and window size is ``2 * n_fft`` (not ``n_fft``). The hop is
    ``n_fft // 2``. Frames use a Blackman window and are taken without
    centering/padding (``center=False``).

    Args:
        input: waveform tensor as accepted by ``torch.stft`` (``(L,)`` or
            ``(batch, L)``); ``L`` must be at least ``2 * n_fft``.
        n_fft: half of the FFT size; must be >= 2 so the hop is non-zero.

    Returns:
        ``log(|STFT| + 1e-8)``. For real input this has ``n_fft + 1`` frequency
        bins (one-sided) and ``1 + (L - 2 * n_fft) // (n_fft // 2)`` frames.
    """
    win_length = n_fft * 2
    # Create the window on the input's own device (and, for floating input, in
    # its dtype). A global "cuda if available" choice breaks CPU tensors on GPU
    # machines with a device-mismatch error.
    window_dtype = input.dtype if input.is_floating_point() else None
    window = torch.blackman_window(win_length, device=input.device, dtype=window_dtype)
    fft_res = torch.stft(input, win_length, hop_length=n_fft // 2, win_length=win_length,
                         window=window, center=False, pad_mode='reflect',
                         normalized=False, onesided=None, return_complex=True)
    # The small epsilon keeps log() finite for zero-magnitude bins.
    return torch.log(torch.abs(fft_res) + 1e-8)


if __name__ == "__main__":
    # Quick FLOPs / parameter count for an audio-only, direct-classify AudioNet.
    from thop import profile

    n_fft, frame, time = 513, 80, 1
    net = AudioNet(n_fft=n_fft, frame=frame, time=time, auto_mode=True, audio_only=True, direct_classify=True)

    # AudioNet.forward unpacks a 3-D (batch, time, samples) tensor, so a 4-D
    # input cannot be used here. Pick a waveform length that yields `frame`
    # STFT frames, given stft()'s window of 2 * n_fft samples, its hop of
    # n_fft // 2 and center=False.
    # NOTE(review): confirm the resulting spectrogram size is what
    # R2Plus1DClassifier expects.
    num_samples = 2 * n_fft + (frame - 1) * (n_fft // 2)
    inputs = torch.rand(1, time, num_samples)

    flops, params = profile(net, inputs=(inputs,))
    print('FLOPs = ' + str(flops / 1000 ** 3) + 'G')
    print('Params = ' + str(params / 1000 ** 2) + 'M')
