
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import math


class SincLowpassPool(nn.Module):
    """Depthwise pooling layer whose kernels are symmetric sinc filters.

    Every channel is convolved with its own 1 x ``kernel_length`` kernel
    ``sin(2*pi*sigma*n) / (2*pi*sigma*n)`` (with the centre tap fixed to 1),
    where ``sigma`` is either one value per channel or a single value shared
    by all channels. ``sigma`` is clamped to ``[0.001, 0.5]`` each time the
    kernels are built.

    :param channel_num: number of input channels; also the number of groups
        of the depthwise convolution.
    :param kernel_length: number of taps of each kernel; must be a positive
        odd integer so the kernel has a centre tap.
    :param per_channel_pool: if True, ``sigma`` holds ``channel_num`` values;
        otherwise it holds one value that is shared by every channel.
    :param init_freq_val: initial value of every entry of ``sigma``.
    :param stride: stride passed to ``F.conv2d``.
    :param padding: ``padding`` argument passed to ``F.conv2d``. Note that
        ``forward`` already zero-pads the last axis by ``kernel_length // 2``
        on both sides before the convolution.
    :param sample_rate: stored on the instance; not used by this class.
    :param trainable: whether ``sigma`` requires gradients.
    :param use_bias: stored on the instance; not used by this class (the
        convolution is always run with ``bias=None``).
    :raises ValueError: if ``kernel_length`` is not a positive odd integer.
    """

    def __init__(self,
                 channel_num=256,
                 kernel_length=257,
                 per_channel_pool=True,
                 init_freq_val=0.1,
                 stride=600,
                 padding='valid',
                 sample_rate=24000,
                 trainable=True,
                 use_bias=True):
        super(SincLowpassPool, self).__init__()
        # Explicit check instead of ``assert`` so it survives ``python -O``;
        # ``< 1`` is needed because a negative odd value also has ``% 2 == 1``.
        if kernel_length < 1 or kernel_length % 2 != 1:
            raise ValueError(
                f"kernel_length must be a positive odd integer, got {kernel_length}")

        self.channel_num = channel_num
        self.per_channel_pool = per_channel_pool
        self.trainable = trainable
        self.kernel_length = kernel_length
        self.stride = stride
        self.padding = padding
        self.init_freq_val = init_freq_val
        self.use_bias = use_bias
        self.sample_rate = sample_rate

        self.init_trainable_sigma()
        self.prepare_filterbank()

    def init_trainable_sigma(self):
        """Create ``self.sigma``, filled with ``init_freq_val``.

        The parameter has ``channel_num`` entries when ``per_channel_pool`` is
        True and a single entry otherwise.
        """
        sigma_count = self.channel_num if self.per_channel_pool else 1
        init_freq = torch.full((sigma_count,),
                               float(self.init_freq_val),
                               dtype=torch.float32)
        self.sigma = nn.Parameter(data=init_freq,
                                  requires_grad=self.trainable)

    def prepare_filterbank(self):
        """Create the constant tensors used to build the kernels.

        ``self.n_`` holds ``2*pi*n`` for the taps left of the centre
        (``n = -(kernel_length // 2), ..., -1``) with shape
        ``[1, kernel_length // 2]``; ``self.bandpass_center`` holds the centre
        tap (all ones) with shape ``[channel_num, 1]``.
        """
        half_length = self.kernel_length // 2
        n_ = 2 * math.pi * np.arange(-half_length, 0, dtype=np.float64)
        n_ = n_.reshape(1, -1)

        # Both constants are kept as frozen parameters (not buffers) so that
        # ``parameters()`` and ``state_dict()`` keep their existing layout.
        self.n_ = nn.Parameter(data=torch.from_numpy(n_).to(dtype=torch.float32),
                               requires_grad=False)

        self.bandpass_center = nn.Parameter(
            data=torch.ones(self.channel_num, 1, dtype=torch.float32),
            requires_grad=False)

    def construct_filterbank(self):
        """Build the sinc kernels from the current value of ``self.sigma``.

        The bank is rebuilt on every call because it is computed from the
        live ``self.sigma`` parameter with torch operations, which keeps the
        result connected to autograd.

        :return: tensor of shape ``[channel_num, 1, 1, kernel_length]``
        """
        if self.per_channel_pool:
            sigma = self.sigma
        else:
            # Share the single cutoff across all channels (view, no copy).
            sigma = self.sigma.expand(self.channel_num)

        # Lower bound keeps ``f_times_t`` away from zero (no 0/0 below).
        sigma = torch.clamp(sigma, min=0.001, max=0.5)

        sigma = torch.unsqueeze(sigma, dim=-1)  # [channel_num, 1]
        f_times_t = torch.matmul(sigma, self.n_)  # [channel_num, kernel_length//2]

        band_pass_left = torch.sin(f_times_t) / f_times_t

        # The kernel is symmetric: mirror the left half to get the right half.
        band_pass_right = torch.flip(band_pass_left, dims=[1])

        band_pass = torch.cat([band_pass_left,
                               self.bandpass_center,
                               band_pass_right],
                              dim=1)

        # reshape the filter bank to fit 2D conv operation
        sinc_filters = band_pass.view(self.channel_num,
                                      1,
                                      1,
                                      self.kernel_length)

        return sinc_filters

    def get_sinc_filterbank(self):
        """Return the kernels produced by :meth:`construct_filterbank`."""
        return self.construct_filterbank()

    def forward(self, input_feat):
        """Pool ``input_feat`` with the sinc kernels via depthwise convolution.

        The last axis is zero-padded by ``kernel_length // 2`` on both sides
        and then convolved, one kernel per channel
        (``groups=channel_num``), using the configured stride and padding.

        :param input_feat: [B, n_channels, 1, timelen]
        :return: Torch Tensor; with ``padding='valid'`` its last axis has
            length ``(timelen - 1) // stride + 1``.
        """
        sinc_filterbank = self.construct_filterbank()
        half_length = self.kernel_length // 2
        input_feat = F.pad(input_feat,
                           pad=(half_length, half_length),
                           mode='constant',
                           value=0.)
        pooled_feat = F.conv2d(input=input_feat,
                               weight=sinc_filterbank,
                               stride=self.stride,
                               padding=self.padding,
                               bias=None,
                               groups=self.channel_num)

        return pooled_feat
