import torch
import torch.nn as nn
import SincLowpassPool

import DyadicSinc_FilterBank

class PDNNet(nn.Module):
    """Waveform-to-framewise-score network.

    The model chains three parts:

    1. A front end of eight cascaded ``DyadicSincFilterBank`` stages
       (2, 4, 8, ..., 256 filters).  Each stage is optionally followed by a
       ReLU, and the first five stages are followed by a stride-2
       ``SincLowpassPool``.
    2. A backbone that alternates ``SincLowpassPool`` down-sampling with
       ``Conv2d -> BatchNorm2d -> ReLU`` blocks
       (channels 256 -> 512 -> 1024 -> 512 -> 256).
    3. A per-frame ``Linear(256, 1) -> Sigmoid`` head.

    Sub-module attribute names (``decomposer_degreeN``, ``downsample_degreeN``,
    ``downsample_layerN``, ``conv_blockN``, ``fc_linear_score``) and their
    registration order are part of the checkpoint format and must not change.
    """

    # Number of cascaded filter-bank stages; stage ``d`` uses ``2 ** d`` filters.
    _NUM_DEGREES = 8
    # Only the first stages are followed by a stride-2 low-pass pool; the
    # remaining stages (6-8) deliberately keep their temporal resolution.
    _NUM_POOLED_DEGREES = 5

    def __init__(self,
                 decomp_degree = 8,
                 sample_freq = 24000,
                 kernel_length = 1025,
                 stride = 1,
                 padding = 'same',
                 freq_overlap = 20,
                 divide_in_melscale=False,
                 use_relu_act = True,
                 device = None ):
        """Build the front end, the shared activation and the backbone.

        :param decomp_degree: stored on the instance only.
            NOTE(review): the architecture always builds 8 stages regardless
            of this value -- confirm whether it was meant to be configurable.
        :param sample_freq: forwarded to every filter bank as
            ``high_freq2model``.
        :param kernel_length: kernel length of every front-end filter bank.
        :param stride: stride forwarded to every front-end filter bank.
        :param padding: padding mode forwarded to every front-end filter bank.
        :param freq_overlap: forwarded to every front-end filter bank.
        :param divide_in_melscale: forwarded to every front-end filter bank.
        :param use_relu_act: apply ReLU after each filter-bank stage when True.
        :param device: forwarded to every front-end filter bank.
        """
        super().__init__()
        self.decomp_degree = decomp_degree
        self.sample_freq = sample_freq
        self.stride = stride
        self.padding = padding
        self.kernel_length = kernel_length
        self.divide_in_melscale = divide_in_melscale
        self.freq_overlap = freq_overlap
        self.use_relu_act = use_relu_act
        self.device = device
        # The call order fixes the sub-module registration order (and thus the
        # parameter order seen by optimizers); keep it stable.
        self.prepare_dyadic_filters()
        self.init_activation_layer()
        self.init_backbone_layers()

    def init_activation_layer(self):
        """Create the in-place ReLU shared by all front-end stages."""
        self.relu = nn.ReLU(inplace=True)

    @staticmethod
    def _make_lowpass_pool(channel_num, stride):
        """Return a trainable per-channel ``SincLowpassPool`` with the model's fixed settings."""
        return SincLowpassPool.SincLowpassPool(channel_num=channel_num,
                                               kernel_length=129,
                                               per_channel_pool=True,
                                               init_freq_val=0.4,
                                               stride=stride,
                                               padding='valid',
                                               trainable=True,
                                               use_bias=True)

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels):
        """Return a ``Conv2d(1x3, same) -> BatchNorm2d -> ReLU`` block."""
        return nn.Sequential(nn.Conv2d(in_channels=in_channels,
                                       out_channels=out_channels,
                                       kernel_size=(1, 3),
                                       stride=(1, 1),
                                       padding='same'),
                             nn.BatchNorm2d(num_features=out_channels),
                             nn.ReLU(inplace=True))

    def _make_filter_bank(self, filter_num):
        """Return a ``DyadicSincFilterBank`` with ``filter_num`` filters, configured from ``self``."""
        # NOTE(review): ``sample_freq`` is hard-coded to 24000 here while
        # ``high_freq2model`` follows ``self.sample_freq``.  Both agree for the
        # default configuration; confirm the intent before changing either.
        return DyadicSinc_FilterBank.DyadicSincFilterBank(
            filter_num=filter_num,
            kernel_length=self.kernel_length,
            low_freq2model=0,
            high_freq2model=self.sample_freq,
            sample_freq=24000,
            min_low_hz=20,
            min_band_hz=50,
            freq_overlap=self.freq_overlap,
            output_channel=1,
            stride=self.stride,
            padding=self.padding,
            divide_in_melscale=self.divide_in_melscale,
            device=self.device)

    def init_backbone_layers(self):
        '''
        Build the quasi encoder-decoder backbone applied to the front-end
        feature [B, 256, 1, timelen].

        Layer sequence and channel counts:
            downsample_layer1 (256 ch, stride 5) -> conv_block1 (256 -> 512)
            downsample_layer2 (512 ch, stride 5) -> conv_block2 (512 -> 1024)
            downsample_layer3 (1024 ch, stride 3) -> conv_block3 (1024 -> 512)
            conv_block4 (512 -> 256)
            fc_linear_score: Linear(256, 1) + Sigmoid, applied per frame.

        The exact temporal length after each pool depends on
        ``SincLowpassPool`` (kernel 129, 'valid' padding).
        '''
        # (pool channels, pool stride, conv output channels) per encoder stage.
        encoder_plan = ((256, 5, 512),
                        (512, 5, 1024),
                        (1024, 3, 512))
        for index, (channels, pool_stride, conv_out) in enumerate(encoder_plan,
                                                                  start=1):
            setattr(self, f'downsample_layer{index}',
                    self._make_lowpass_pool(channel_num=channels,
                                            stride=pool_stride))
            setattr(self, f'conv_block{index}',
                    self._conv_bn_relu(channels, conv_out))

        self.conv_block4 = self._conv_bn_relu(512, 256)

        self.fc_linear_score = nn.Sequential(nn.Linear(in_features=256,
                                                       out_features=1,
                                                       bias=True),
                                             nn.Sigmoid())

    def prepare_dyadic_filters(self):
        """Create the eight front-end filter banks and the five stride-2 pools.

        Stage ``d`` (1-based) registers ``decomposer_degree{d}`` with
        ``2 ** d`` filters.  Stages 1-5 additionally register
        ``downsample_degree{d}`` with the same channel count.
        """
        for degree in range(1, self._NUM_DEGREES + 1):
            filter_num = 2 ** degree
            setattr(self, f'decomposer_degree{degree}',
                    self._make_filter_bank(filter_num))
            if degree <= self._NUM_POOLED_DEGREES:
                setattr(self, f'downsample_degree{degree}',
                        self._make_lowpass_pool(channel_num=filter_num,
                                                stride=2))

    def timefreq_rep_frontend(self, input_waveform):
        """Run the cascaded filter-bank front end.

        Each stage applies its filter bank, then ReLU when
        ``self.use_relu_act`` is set, then (stages 1-5 only) its low-pass pool.

        :param input_waveform: tensor fed to the first filter bank;
            ``forward`` passes [B, 1, 1, wavelen].
        :return: output of the last (256-filter) stage.
        """
        feat = input_waveform
        for degree in range(1, self._NUM_DEGREES + 1):
            feat = getattr(self, f'decomposer_degree{degree}')(feat)
            if self.use_relu_act:
                feat = self.relu(feat)
            if degree <= self._NUM_POOLED_DEGREES:
                feat = getattr(self, f'downsample_degree{degree}')(feat)
        return feat

    def backbone_encoding(self, frontend_encode_feat ):
        '''
        Given the front-end encoded feature, further encode the front-end feature
        to framewise representation
        :param frontend_encode_feat: [B, input_channel_num, 1, T]
        :return: [B, output_channel_num, 1, T1], where T1 << T
        '''
        interm_feat = self.downsample_layer1(frontend_encode_feat)
        interm_feat = self.conv_block1(interm_feat)
        interm_feat = self.downsample_layer2(interm_feat)
        interm_feat = self.conv_block2(interm_feat)
        interm_feat = self.downsample_layer3(interm_feat)
        interm_feat = self.conv_block3(interm_feat)
        interm_feat = self.conv_block4(interm_feat)

        return interm_feat

    def forward(self, input_waveform):
        '''
        forward pass of one batch of input waveform
        :param input_waveform: [B, wavelen]
        :return: [B, T1] framewise scores in (0, 1), one per backbone frame
        '''
        # [B, wavelen] -> [B, 1, 1, wavelen]
        waveform = torch.unsqueeze(input_waveform, dim=1)
        waveform = torch.unsqueeze(waveform, dim=1)
        frontend_feat_rep = self.timefreq_rep_frontend(waveform)

        # e.g. [10, 256, 1, 50]
        framewise_feat_encode = self.backbone_encoding(frontend_feat_rep)

        # [B, C, 1, T1] -> [B, T1, 1, C] so the linear head acts on channels.
        framewise_feat_encode = torch.permute(framewise_feat_encode,
                                              dims=[0, 3, 2, 1])

        # Squeeze only the singleton height axis.  A dimension-less squeeze
        # would also drop the batch axis when B == 1 (or the frame axis when
        # T1 == 1) and silently change the output rank.
        framewise_feat_encode = torch.squeeze(framewise_feat_encode, dim=2)
        densitymap_pred = self.fc_linear_score(framewise_feat_encode)
        # Drop only the trailing score axis: [B, T1, 1] -> [B, T1].
        densitymap_pred = torch.squeeze(densitymap_pred, dim=-1)

        return densitymap_pred








