import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class SeparableConvBlock(nn.Module):
    """Depthwise-separable convolution block.

    The forward pass applies, in order: layer norm over the whole
    ``(in_channels, 25, 5)`` feature map, a 3x3 depthwise convolution, a 1x1
    pointwise convolution and a ReLU.  Because the layer norm is built for a
    fixed ``(in_channels, 25, 5)`` shape, inputs must have exactly that
    per-sample shape.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the pointwise convolution.

    The instance also exposes three integer bookkeeping estimates,
    ``activation_num``, ``gradient_num`` and ``error_num``, computed once at
    construction time from the channel counts and the 25x5 feature-map size.
    """

    def __init__(self, in_channels, out_channels):
        super(SeparableConvBlock, self).__init__()
        self.depthwise_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, groups=in_channels)
        self.pointwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.relu = nn.ReLU()
        # Normalises over channels and both spatial axes, without learnable scale/shift.
        self.norm1 = nn.LayerNorm([in_channels, 25, 5], elementwise_affine=False)

        # Estimate the activation, gradient and error numbers.
        map_elems = 25 * 5  # elements per channel of the fixed-size feature map
        self.activation_num = in_channels * map_elems + in_channels * map_elems
        # Kept in integer arithmetic so the count stays an int: the depthwise
        # term is 3*3 weights per channel (one input channel per group).
        # NOTE(review): pointwise_conv is 1x1, so its weight count would be
        # out_channels * in_channels; the 3*3 factor in the first term is kept
        # from the original estimate -- confirm it is intended.
        self.gradient_num = out_channels * (3 * 3 * in_channels) + in_channels * (3 * 3)
        self.error_num = out_channels * map_elems + in_channels * map_elems

    def forward(self, x):
        """Return ``relu(pointwise(depthwise(norm(x))))``.

        Args:
            x: Tensor whose trailing dimensions are ``(in_channels, 25, 5)``.

        Returns:
            Tensor with ``out_channels`` channels and the same 25x5 spatial size
            (the depthwise conv uses padding 1, the pointwise conv is 1x1).
        """
        x = self.norm1(x)
        x = self.depthwise_conv(x)
        x = self.pointwise_conv(x)
        x = self.relu(x)
        return x

class initial_block(nn.Module):
    """Stem of the network: strided convolution, batch norm and ReLU.

    The convolution uses a (10, 4) kernel, stride (2, 2) and padding (5, 1) and
    always produces 64 channels; with a 49x10 input this yields a 25x5 map,
    which is the size the bookkeeping estimates below are written for.

    Args:
        input_channel: Number of channels of the input tensor.
        num_classes: Accepted for interface compatibility; not used by this
            block.

    The instance also exposes the bookkeeping estimates ``activation_num``,
    ``gradient_num`` and ``error_num``, computed once at construction time.
    """

    def __init__(self, input_channel = 1, num_classes=12):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channel, 64, kernel_size=(10, 4), stride=(2, 2), padding=(5, 1))
        self.batch_norm1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()

        # Estimate the activation, gradient and error numbers.
        # Input activations: assumes a 49x10 input map per channel.
        self.activation_num = input_channel * 49 * 10
        # One 10x4 kernel per (input channel, output channel) pair, so the
        # count must scale with input_channel rather than assume a single one.
        self.gradient_num = (10 * 4 * input_channel) * 64
        # Output of the stem: 64 channels of 25x5.
        self.error_num = 64 * 25 * 5

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU and return the result.

        Args:
            x: Tensor with ``input_channel`` channels, shaped (N, C, H, W).

        Returns:
            Tensor with 64 channels; spatial size follows from the strided
            convolution (25x5 for a 49x10 input).
        """
        x = self.conv1(x)
        x = self.batch_norm1(x)
        x = self.relu(x)
        return x
    
class FF_sequential(nn.Sequential):
    """Sequential container for blocks called as ``block(x, y)``.

    Every contained module must return a 2-tuple; the second element is
    passed on as the next block's ``x`` while ``y`` is handed unchanged to
    every block.
    """

    def forward(self, x, y):
        """Run all blocks in registration order and return the final ``x``.

        Args:
            x: Input to the first block.
            y: Second argument forwarded as-is to every block.

        Returns:
            The second element of the last block's output, or ``x`` itself
            when the container is empty.
        """
        features = x
        for block in self:
            # The first element of each block's output is not needed here.
            _unused, features = block(features, y)
        return features
    
class FF_customblock(nn.Sequential):
    """A stack of feature layers whose pooled output is gated by a label embedding.

    The positional modules are applied in order to ``x``.  The result is
    adaptively average-pooled to ``pooling`` and multiplied element-wise by a
    linear embedding of ``y`` reshaped to the pooled feature map.

    Args:
        *args: Feature layers, as accepted by ``nn.Sequential``.  Each must
            expose ``activation_num``, ``gradient_num`` and ``error_num``.
        channels: Channel count of the last feature layer's output.
        pooling: ``(height, width)`` target of the adaptive average pooling.
            Required in practice: it determines the embedding width
            ``channels * height * width``.
        num_classes: Size of the last dimension of ``y`` (input width of the
            label embedding).

    Raises:
        TypeError: If ``pooling`` is left as ``None``.

    The instance also exposes the bookkeeping estimates ``activation_num``,
    ``gradient_num`` and ``error_num``: the sums over the feature layers plus
    the contribution of the label embedding.
    """

    def __init__(self, *args, channels, pooling=None, num_classes=12):
        super(FF_customblock, self).__init__(*args)
        if pooling is None:
            # Indexing None would raise the same exception type with an opaque
            # message; fail with an explanation instead.
            raise TypeError(
                "FF_customblock requires pooling=(height, width): the label "
                "embedding width is channels * height * width"
            )
        mask = channels * pooling[0] * pooling[1]
        self.label_embedding = nn.Linear(num_classes, mask)
        self.pool_ = pooling

        # Estimate the activation, gradient and error numbers.
        self.activation_num = 0
        self.gradient_num = 0
        self.error_num = 0
        for layer in self._feature_layers():
            self.activation_num += layer.activation_num
            self.gradient_num += layer.gradient_num
            self.error_num += layer.error_num
        # Contribution of the label embedding: its input width, its weight
        # count and its output width.
        self.activation_num += num_classes
        self.gradient_num += mask * num_classes
        self.error_num += mask

    def _feature_layers(self):
        """Return every registered child except the label embedding."""
        # label_embedding is registered as a child module too, so it has to be
        # excluded explicitly; comparing by identity does not depend on the
        # order in which children were registered.
        return [m for m in self._modules.values() if m is not self.label_embedding]

    def forward(self, x, y):
        """Compute the label-gated pooled features and the raw features.

        Args:
            x: Input to the first feature layer.
            y: Tensor whose last dimension is ``num_classes`` (presumably a
                label encoding -- confirm against the caller).

        Returns:
            A pair ``(f, x)`` where ``x`` is the output of the feature layers
            and ``f`` is the pooled ``x`` multiplied by the reshaped label
            embedding.  ``f`` is also stored on the instance as ``self.h``.
        """
        for layer in self._feature_layers():
            x = layer(x)
        if self.pool_:
            x_temp = F.adaptive_avg_pool2d(x, self.pool_)
        else:
            x_temp = x
        c = self.label_embedding(y)
        # Reshape the flat embedding to match the pooled feature map.
        c_reshaped = c.view(c.shape[0], x_temp.shape[1], x_temp.shape[2], x_temp.shape[3])
        f = x_temp * c_reshaped
        self.h = f
        return f, x


class DS_CNN_FF(nn.Module):
    """DS-CNN split into label-conditioned ``FF_customblock`` stages.

    The network is always one ``initial_block`` followed by four
    ``SeparableConvBlock`` layers; ``config['FF_block_nums']`` only selects
    how those four layers are grouped into blocks.

    Args:
        config: Mapping with the keys
            ``'ds_channels'``: indexable of four ``(in, out)`` channel pairs,
            one per separable conv;
            ``'FF_block_nums'``: one of 2, 2.5, 2.9, 3, 3.5 or 4;
            ``'ds_pooling_2'`` / ``'ds_pooling_3'`` / ``'ds_pooling_4'``: the
            per-block pooling sizes for the chosen layout (only the one
            matching the layout is read).

    Raises:
        ValueError: If ``config['FF_block_nums']`` is not a supported layout.
    """

    # (FF_block_nums value, pooling config key, separable convs per block).
    # The first block of every layout additionally starts with initial_block.
    _LAYOUTS = (
        (2, 'ds_pooling_2', (1, 3)),
        (2.5, 'ds_pooling_2', (2, 2)),
        (2.9, 'ds_pooling_2', (3, 1)),
        (3, 'ds_pooling_3', (1, 2, 1)),
        (3.5, 'ds_pooling_3', (2, 1, 1)),
        (4, 'ds_pooling_4', (1, 1, 1, 1)),
    )

    def __init__(self, config):
        super(DS_CNN_FF, self).__init__()
        channels = config['ds_channels']
        block_nums = config['FF_block_nums']

        for layout_id, pooling_key, group_sizes in self._LAYOUTS:
            if block_nums == layout_id:
                break
        else:
            # Without this check self.blocks would never be set and the
            # failure would only surface as an AttributeError in forward().
            supported = [layout[0] for layout in self._LAYOUTS]
            raise ValueError(
                f"Unsupported FF_block_nums {block_nums!r}; expected one of {supported}"
            )

        pooling = config[pooling_key]
        blocks = []
        conv_idx = 0  # index of the next separable conv in `channels`
        for block_idx, group_size in enumerate(group_sizes):
            # Layers are created in network order so parameter initialisation
            # consumes the RNG in the same sequence as the network is laid out.
            layers = [initial_block()] if block_idx == 0 else []
            for _ in range(group_size):
                layers.append(SeparableConvBlock(channels[conv_idx][0], channels[conv_idx][1]))
                conv_idx += 1
            blocks.append(
                FF_customblock(
                    *layers,
                    # The embedding matches the output channels of the block's last conv.
                    channels=channels[conv_idx - 1][1],
                    pooling=pooling[block_idx],
                )
            )
        self.blocks = FF_sequential(*blocks)

    def forward(self, x, y):
        """Run all blocks and return their concatenated gated features.

        Args:
            x: Input passed to the first block.
            y: Second argument forwarded to every block.

        Returns:
            The ``h`` tensor stored by each block during this pass, flattened
            to ``(batch, -1)`` and concatenated along dimension 1.
        """
        # The return value is not needed: each block stores its gated output
        # on itself as `h` while it runs.
        self.blocks(x, y)
        hs = [b.h.view(b.h.shape[0], -1) for b in self.blocks.children()]
        return torch.cat(hs, dim=1)

