import torch
from torchvision import transforms as T
from torchaudio import transforms as aT
from data.transforms import audio as aT2
from pytorchvideo import transforms as vT

import numpy as np
import copy

class AudioMAE_transform:
    """
    Audio preprocessing pipeline used for AudioMAE.

    The input is passed, in order, through ``ToTensor``, ``ToMono``,
    ``Wav2fbank``, ``Permute(0, 2, 1)`` and ``Normalize`` from
    ``data.transforms.audio``, chained with ``torchvision`` ``Compose``.

    Args:
        audio_rate: forwarded to ``Wav2fbank`` as ``sampling_rate`` and kept
            on the instance as ``self.audio_rate``.
        num_mels: forwarded to ``Wav2fbank`` as ``num_mel_bins``.
        audio_size: forwarded to ``Wav2fbank`` as ``target_length``.
        augment: accepted for interface compatibility; it is not read
            anywhere in this class.
    """

    def __init__(self,
                 audio_rate=16000,
                 num_mels=128,
                 audio_size=1024,
                 augment=False):
        self.audio_rate = audio_rate

        # We follow mean and std of AudioSet dataset spectrogram
        spec_mean, spec_std = -4.2677393, 4.5689974

        # Stages are instantiated in the same order they are applied.
        steps = [aT2.ToTensor(), aT2.ToMono()]
        steps.append(
            aT2.Wav2fbank(
                sampling_rate=audio_rate,
                num_mel_bins=num_mels,
                target_length=audio_size,
            )
        )
        steps.append(aT2.Permute(0, 2, 1))  # (1, F, T) -> (1, T, F)
        steps.append(aT2.Normalize(spec_mean, spec_std))

        self.t = T.Compose(steps)

    def __call__(self, x):
        """Run ``x`` through the composed pipeline and return the result."""
        return self.t(x)


class CAV_Audio_transform:
    """
    Audio transform in CAV.

    Chains ``ToTensor``, ``ToMono``, ``CAV_Wav2fbank`` and ``CAV_Normalize``
    from ``data.transforms.audio``. An extra ``Noise`` stage is applied
    afterwards only when both ``noise`` and ``training`` are truthy.

    Args:
        audio_rate: forwarded to ``CAV_Wav2fbank`` as ``sampling_rate``.
        num_mels: forwarded to ``CAV_Wav2fbank`` as ``num_mel_bins``.
        audio_size: forwarded to ``CAV_Wav2fbank`` as ``target_length`` and
            to ``Noise``.
        freqm: forwarded to ``CAV_Wav2fbank`` when ``training`` is truthy;
            otherwise ``0`` is passed instead.
        timem: forwarded to ``CAV_Wav2fbank`` when ``training`` is truthy;
            otherwise ``0`` is passed instead.
        mean: first argument of ``CAV_Normalize``.
        std: second argument of ``CAV_Normalize``.
        noise: request the ``Noise`` stage (honoured only with ``training``).
        training: selects the training configuration described above.
    """
    def __init__(self,
                 audio_rate=16000,
                 num_mels=128,
                 audio_size=1024,
                 freqm=0,
                 timem=0,
                 mean=-5.081,
                 std=4.4849,
                 noise=False,
                 training=False,
                 ):
        self.audio_rate = audio_rate
        self.freqm = freqm
        self.timem = timem
        self.noise = noise and training

        # The training and evaluation pipelines differ only in the mask
        # values handed to CAV_Wav2fbank: evaluation always passes zeros.
        freq_mask, time_mask = (freqm, timem) if training else (0, 0)

        self.t = T.Compose([
            aT2.ToTensor(),
            aT2.ToMono(),
            aT2.CAV_Wav2fbank(
                sampling_rate=audio_rate,
                num_mel_bins=num_mels,
                target_length=audio_size,
                freqm=freq_mask,
                timem=time_mask,
            ),
            aT2.CAV_Normalize(mean, std),
        ])
        # Built unconditionally; only invoked when self.noise is truthy.
        self.noise_t = T.Compose([aT2.Noise(audio_size)])

    def __call__(self, x):
        """Apply the base pipeline to ``x``, then the noise stage if enabled."""
        out = self.t(x)
        return self.noise_t(out) if self.noise else out


class CAV_Audio_transform_buffer:
    """
    Apply the ``Noise`` transform to every item of an iterable and stack
    the results.

    Args:
        audio_size: forwarded to ``data.transforms.audio.Noise``.
        **kwargs: accepted and ignored.
    """

    def __init__(self, audio_size, **kwargs):
        noise_stage = aT2.Noise(audio_size)
        self.t = T.Compose([noise_stage])

    def __call__(self, x):
        """Transform each element of ``x`` in order and ``torch.stack`` them."""
        return torch.stack([self.t(item) for item in x])



class CAV_Audio_transform_return_shift:
    """
    Audio transform in CAV, using ``Noise_return_shift`` as the noise stage.

    Chains ``ToTensor``, ``ToMono``, ``CAV_Wav2fbank`` and ``CAV_Normalize``
    from ``data.transforms.audio``. When both ``noise`` and ``training`` are
    truthy, the result is additionally passed through
    ``Noise_return_shift`` and whatever that stage returns is returned
    as-is (presumably including the applied shift; confirm against
    ``data.transforms.audio``).

    Args:
        audio_rate: forwarded to ``CAV_Wav2fbank`` as ``sampling_rate``.
        num_mels: forwarded to ``CAV_Wav2fbank`` as ``num_mel_bins``.
        audio_size: forwarded to ``CAV_Wav2fbank`` as ``target_length`` and
            to ``Noise_return_shift``.
        freqm: forwarded to ``CAV_Wav2fbank`` when ``training`` is truthy;
            otherwise ``0`` is passed instead.
        timem: forwarded to ``CAV_Wav2fbank`` when ``training`` is truthy;
            otherwise ``0`` is passed instead.
        mean: first argument of ``CAV_Normalize``.
        std: second argument of ``CAV_Normalize``.
        noise: request the noise stage (honoured only with ``training``).
        training: selects the training configuration described above.
    """
    def __init__(self,
                 audio_rate=16000,
                 num_mels=128,
                 audio_size=1024,
                 freqm=0,
                 timem=0,
                 mean=-5.081,
                 std=4.4849,
                 noise=False,
                 training=False,
                 ):
        self.audio_rate = audio_rate
        self.freqm = freqm
        self.timem = timem
        self.noise = noise and training

        # The training and evaluation pipelines differ only in the mask
        # values handed to CAV_Wav2fbank: evaluation always passes zeros.
        freq_mask, time_mask = (freqm, timem) if training else (0, 0)

        self.t = T.Compose([
            aT2.ToTensor(),
            aT2.ToMono(),
            aT2.CAV_Wav2fbank(
                sampling_rate=audio_rate,
                num_mel_bins=num_mels,
                target_length=audio_size,
                freqm=freq_mask,
                timem=time_mask,
            ),
            aT2.CAV_Normalize(mean, std),
        ])
        # Built unconditionally; only invoked when self.noise is truthy.
        self.noise_t = T.Compose([aT2.Noise_return_shift(audio_size)])

    def __call__(self, x):
        """Apply the base pipeline to ``x``, then the noise stage if enabled."""
        out = self.t(x)
        return self.noise_t(out) if self.noise else out