import os
import sys
import torch.nn.functional as F
import torch.nn as nn
import torch
import torch.nn.init as init
import numpy as np

class ResidualBlock(nn.Module):
    """Residual block built from LayerNorm and bias-free convolutions.

    Data flow: (optional stem: LayerNorm -> conv 3->in_channels) ->
    LayerNorm -> conv3x3 -> ReLU -> LayerNorm -> conv3x3 -> add skip -> ReLU.
    The skip branch carries the *normalised* input and is passed through a
    1x1 convolution whenever the stride or the channel count changes.

    The constructor also fills three size estimates, ``activation_num``,
    ``gradient_num`` and ``error_num``. ``gradient_num`` equals the number of
    convolution weights held by the block.

    Args:
        input_sz: Side length of the square input feature map; it fixes the
            normalized shape of the LayerNorms applied before ``conv1``.
        output_sz: Side length of the feature map produced by ``conv1``.
        in_channels: Channel count entering ``conv1``.
        out_channels: Channel count produced by the block.
        initial_bl: If true, prepend a stem that maps a 3-channel input to
            ``in_channels`` channels.
        stride: Stride of ``conv1`` and of the skip projection.
    """

    def __init__(self, input_sz, output_sz, in_channels, out_channels, initial_bl=False, stride=1):
        super(ResidualBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.initial_bl = initial_bl
        needs_projection = stride != 1 or in_channels != out_channels

        if initial_bl:
            self.norm0 = nn.LayerNorm([3, input_sz, input_sz], elementwise_affine=False)
            self.conv0 = nn.Conv2d(3, self.in_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.norm1 = nn.LayerNorm([self.in_channels, input_sz, input_sz], elementwise_affine=False)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.norm2 = nn.LayerNorm([self.out_channels, output_sz, output_sz], elementwise_affine=False)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        if needs_projection:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
            )
        else:
            self.downsample = None
        self.relu = nn.ReLU(inplace=True)

        # Estimate the activation, gradient and error counts.
        in_act = in_channels * input_sz * input_sz
        out_act = out_channels * output_sz * output_sz
        # To compute the gradients, each conv needs its input activation and
        # the error at its output.
        self.activation_num = in_act + out_act
        self.gradient_num = 9 * in_channels * out_channels + 9 * out_channels * out_channels
        self.error_num = out_act + out_act
        if initial_bl:
            self.activation_num += 3 * input_sz * input_sz
            # The stem conv maps 3 -> in_channels, so it holds 9*3*in_channels
            # weights (not 9*in_channels*in_channels).
            self.gradient_num += 9 * 3 * in_channels
            self.error_num += in_act
        if needs_projection:
            # The projection shares the input activation with conv1.
            self.error_num += out_act
            self.gradient_num += in_channels * out_channels

    def forward(self, x):
        """Apply the block to ``x`` and return the activated sum of both branches."""
        if self.initial_bl:
            x = self.norm0(x)
            x = self.conv0(x)
        x = self.norm1(x)
        # The skip branch starts from the normalised tensor, not the raw input.
        identity = x
        out = self.conv1(x)
        out = self.relu(out)
        out = self.norm2(out)
        out = self.conv2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out

class FF_sequential(nn.Sequential):
    """Sequential container for stages that take ``(x, y, opts)``.

    Every stage must return a pair; the second element is fed to the next
    stage as its ``x`` and the first element is discarded here.
    """

    def forward(self, x, y, opts):
        """Thread ``x`` through all stages and return the last stage's second output."""
        carried = x
        for stage in self:
            _discarded, carried = stage(carried, y, opts)
        return carried
    
class FF_customblock(nn.Sequential):
    """Feature stages followed by a label-driven multiplicative gate.

    The positional modules are run in order on ``x``. Their output is
    optionally adaptive-max-pooled to ``pool_en x pool_en`` and multiplied
    elementwise by ``lc(y)``, a linear map of the 10-wide input ``y``
    reshaped to the pooled feature shape.

    Args:
        *args: Feature stages; each must expose ``activation_num``,
            ``gradient_num`` and ``error_num``.
        mask: Channel count of the last feature stage's output.
        pool_en: Pooled side length; a falsy value skips pooling.
            NOTE(review): with ``pool_en=0`` the gate ``lc`` has zero
            outputs, so the reshape in ``forward`` presumably cannot
            succeed on non-empty features - confirm before using the default.
        downsample_en: Accepted but not used by this class.
    """

    # mask is its last channel number
    def __init__(self, *args, mask, pool_en=0, downsample_en=[2,2,0,8]):
        super(FF_customblock, self).__init__(*args)
        self.pool_en = pool_en
        gate_width = mask * pool_en * pool_en
        self.lc = nn.Linear(10, gate_width)

        # Estimate the activation, gradient and error counts.
        # ``lc`` is registered last, so every module before it is a feature stage.
        feature_stages = list(self._modules.values())[:-1]
        self.activation_num = sum(stage.activation_num for stage in feature_stages) + 10
        self.gradient_num = sum(stage.gradient_num for stage in feature_stages) + gate_width * 10
        self.error_num = sum(stage.error_num for stage in feature_stages) + gate_width

    def forward(self, x, y, opts):
        """Return ``(gated, features)`` and store ``gated`` on ``self.h``.

        ``opts`` is accepted for interface compatibility and is not read.
        """
        *feature_stages, _gate_layer = self._modules.values()
        for stage in feature_stages:
            x = stage(x)

        if self.pool_en:
            pooled = F.adaptive_max_pool2d(x, (self.pool_en, self.pool_en))
        else:
            pooled = x

        gate = self.lc(y)
        gate_shape = (gate.shape[0], pooled.shape[1], pooled.shape[2], pooled.shape[3])
        gated = pooled * gate.view(gate_shape)
        self.h = gated
        return gated, x
    
class FF_customblock_v2(nn.Sequential):
    """Feature stages with a learned downsampling path and a label-driven gate.

    Unlike ``FF_customblock``, the features are reduced by a bias-free
    convolution (``mask`` -> ``pool_en`` channels) followed by an adaptive
    max pool before being multiplied by ``lc(y)``.

    Args:
        *args: Feature stages; each must expose ``activation_num``,
            ``gradient_num`` and ``error_num``.
        mask: Channel count of the last feature stage's output.
        pool_en: Output channel count of the downsampling convolution; a
            falsy value makes ``forward`` bypass the downsampling path.
        downsample_en: ``[kernel_size, stride, padding, pooled_side]`` for
            the downsampling path.
    """

    # mask is its last channel number
    def __init__(self, *args, mask, pool_en=0, downsample_en=[2,2,0,8]):
        super(FF_customblock_v2, self).__init__(*args)
        self.pool_en = pool_en

        kernel = downsample_en[0]
        pooled_side = downsample_en[3]
        self.downsample = nn.Sequential(
            nn.Conv2d(
                in_channels=mask,
                out_channels=pool_en,
                kernel_size=kernel,
                stride=downsample_en[1],
                padding=downsample_en[2],
                bias=False,
                groups=1,
            ),
            nn.AdaptiveMaxPool2d((pooled_side, pooled_side)),
        )
        in_channels = mask
        gate_width = pool_en * pooled_side * pooled_side
        self.lc = nn.Linear(10, gate_width)

        # Estimate the activation, gradient and error counts.
        # Registration order is: feature stages..., downsample, lc.
        feature_stages = list(self._modules.values())[:-2]
        downsample_act = in_channels * pooled_side * pooled_side
        downsample_grad = kernel * kernel * pool_en * in_channels
        self.activation_num = (
            sum(stage.activation_num for stage in feature_stages) + downsample_act + 10
        )
        self.gradient_num = (
            sum(stage.gradient_num for stage in feature_stages) + downsample_grad + gate_width * 10
        )
        self.error_num = (
            sum(stage.error_num for stage in feature_stages) + gate_width + gate_width
        )

    def forward(self, x, y, opts):
        """Return ``(gated, features)`` and store ``gated`` on ``self.h``.

        ``opts`` is accepted for interface compatibility and is not read.
        """
        # Run only the feature stages; the trailing ``downsample`` and ``lc``
        # modules are applied explicitly below.
        for stage in list(self._modules.values())[:-2]:
            x = stage(x)

        reduced = self.downsample(x) if self.pool_en else x

        gate = self.lc(y)
        gate_shape = (gate.shape[0], reduced.shape[1], reduced.shape[2], reduced.shape[3])
        gated = reduced * gate.view(gate_shape)
        self.h = gated
        return gated, x

class Resnet_ff_new(nn.Module):
    """Residual network split into label-gated blocks, chained by ``FF_sequential``.

    ``combo`` selects how the residual stages are grouped into gated blocks
    (the letters in the per-branch comments name the stages of one group).

    Args:
        combo: Layout selector, an integer from 0 to 10.
        downsample_en: Four integers consumed by combos 0-4 as pooling sizes
            or downsampling settings; their meaning depends on the combo
            (see the branches). Combos 5-10 use fixed pooling sizes.

    Raises:
        ValueError: If ``combo`` is not one of the supported layouts.
    """

    def __init__(self, combo, downsample_en=[6,4,3,8]):
        super(Resnet_ff_new, self).__init__()

        def stem(in_channels):
            # First residual stage on 32x32 maps, including the 3-channel stem conv.
            return ResidualBlock(input_sz=32, output_sz=32, in_channels=in_channels,
                                 out_channels=16, initial_bl=True)

        def stage(input_sz, output_sz, in_channels, out_channels, stride=1):
            return ResidualBlock(input_sz=input_sz, output_sz=output_sz, in_channels=in_channels,
                                 out_channels=out_channels, stride=stride)

        # Tuple elements are built left to right, so layers are created in the
        # same order as the data flows through them.
        if combo == 0:  # (a+b)+(c)
            heads = (
                FF_customblock(stem(16), stage(32, 16, 16, 32, stride=2),
                               mask=32, pool_en=downsample_en[0]),
                FF_customblock(stage(16, 8, 32, 64, stride=2),
                               mask=64, pool_en=downsample_en[1]),
            )
        elif combo == 1:  # (a+b)+c
            heads = (
                # downsample_en passed on as: kernel, stride, padding, output size
                FF_customblock_v2(stem(16), stage(32, 16, 16, 32, stride=2),
                                  mask=32, pool_en=16,
                                  downsample_en=[3, downsample_en[0], downsample_en[1], downsample_en[2]]),
                FF_customblock(stage(16, 8, 32, 64, stride=2),
                               mask=64, pool_en=downsample_en[3]),
            )
        elif combo == 2:  # (a+b)+c
            heads = (
                FF_customblock_v2(stem(16), stage(32, 16, 16, 32, stride=2),
                                  mask=32, pool_en=downsample_en[0],
                                  downsample_en=[3, 1, 1, downsample_en[2]]),
                FF_customblock_v2(stage(16, 8, 32, 64, stride=2),
                                  mask=64, pool_en=downsample_en[1],
                                  downsample_en=[3, 1, 1, downsample_en[3]]),
            )
        elif combo == 3:  # (a) + (b) + (c)
            heads = (
                FF_customblock(stem(16), mask=16, pool_en=downsample_en[0]),
                FF_customblock(stage(32, 16, 16, 32, stride=2),
                               mask=32, pool_en=downsample_en[1]),
                FF_customblock(stage(16, 8, 32, 64, stride=2),
                               mask=64, pool_en=downsample_en[2]),
            )
        elif combo == 4:  # a + (b+c)
            heads = (
                FF_customblock(stem(16), mask=16, pool_en=downsample_en[0]),
                FF_customblock(stage(32, 16, 16, 32, stride=2),
                               stage(16, 8, 32, 64, stride=2),
                               mask=64, pool_en=downsample_en[1]),
            )
        elif combo == 5:  # (a+b+c)+(d)
            heads = (
                FF_customblock(stem(16),
                               stage(32, 16, 16, 32, stride=2),
                               stage(16, 8, 32, 64, stride=2),
                               mask=64, pool_en=4),
                FF_customblock(stage(8, 4, 64, 128, stride=2), mask=128, pool_en=2),
            )
        elif combo == 6:  # (a)+(b+c)
            heads = (
                FF_customblock(stem(16), mask=16, pool_en=6),
                FF_customblock(stage(32, 16, 16, 32, stride=2),
                               stage(16, 8, 32, 64, stride=2),
                               mask=64, pool_en=3),
            )
        elif combo == 7:  # (a)+(2b+c)
            heads = (
                FF_customblock(stem(3), mask=16, pool_en=6),
                FF_customblock(stage(32, 16, 16, 32, stride=2),
                               stage(16, 16, 32, 32, stride=1),
                               stage(16, 8, 32, 64, stride=2),
                               mask=64, pool_en=3),
            )
        elif combo == 8:  # (a)+(2b+c)+(c)
            heads = (
                FF_customblock(stem(3), mask=16, pool_en=4),
                FF_customblock(stage(32, 16, 16, 32, stride=2),
                               stage(16, 16, 32, 32, stride=1),
                               stage(16, 8, 32, 64, stride=2),
                               mask=64, pool_en=2),
                FF_customblock(stage(8, 8, 64, 64, stride=1), mask=64, pool_en=2),
            )
        elif combo == 9:  # (a+b)+(b+2c)
            heads = (
                FF_customblock(stem(3), stage(32, 16, 16, 32, stride=2),
                               mask=32, pool_en=4),
                FF_customblock(stage(16, 16, 32, 32),
                               stage(16, 8, 32, 64, stride=2),
                               stage(8, 8, 64, 64),
                               mask=64, pool_en=3),
            )
        elif combo == 10:  # (a+b)+(b+2c+d)
            heads = (
                FF_customblock(stem(3), stage(32, 16, 16, 32, stride=2),
                               mask=32, pool_en=4),
                FF_customblock(stage(16, 16, 32, 32, stride=1),
                               stage(16, 8, 32, 64, stride=2),
                               stage(8, 8, 64, 64),
                               stage(8, 4, 64, 128, stride=2),
                               mask=128, pool_en=2),
            )
        else:
            # Previously an unknown combo produced a model without ``blocks``
            # that only failed later, inside ``forward``.
            raise ValueError(f"unsupported combo {combo!r}; expected an integer from 0 to 10")

        self.blocks = FF_sequential(*heads)

    def forward(self, x, y, opts=None, diff_res=np.array([1,1])):
        """Run all blocks and return the sum of squared gated activations per block.

        Args:
            x: Input passed to the first block.
            y: Label tensor given to every block's gate.
            opts: Forwarded unchanged to every block.
            diff_res: Only its length is used: one column is returned for
                each of the first ``len(diff_res)`` blocks. With the default
                length of 2, three-block layouts (combos 3 and 8) report only
                their first two blocks unless a longer array is passed.

        Returns:
            Tensor of shape ``(batch, len(diff_res))``.
        """
        # The call populates each block's ``h``; the chained output itself is not needed.
        self.blocks(x, y, opts)
        per_block = []
        for index in range(len(diff_res)):
            gated = self.blocks[index].h
            per_block.append(gated.view(gated.shape[0], -1).pow(2).sum(1))
        return torch.stack(per_block, dim=1)

class ResidualBlock_bp(nn.Module):
    """Residual block with BatchNorm: conv3x3 -> BN -> ReLU -> conv3x3 -> BN, plus skip.

    With ``initial_bl`` a stem (conv 3->in_channels, BN, ReLU) runs first.
    The skip branch is the block input, passed through a 1x1 convolution and
    BatchNorm whenever the stride or the channel count changes.

    The constructor also fills the size estimates ``activation_num``,
    ``gradient_num`` and ``error_num``.

    Args:
        input_sz: Side length of the square input feature map (used only for
            the size estimates).
        output_sz: Side length of the output feature map (used only for the
            size estimates).
        in_channels: Channel count entering ``conv1``.
        out_channels: Channel count produced by the block.
        initial_bl: If true, prepend the 3-channel stem.
        stride: Stride of ``conv1`` and of the skip projection.
    """

    def __init__(self, input_sz, output_sz, in_channels, out_channels, initial_bl=False ,stride=1):
        super(ResidualBlock_bp, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.initial_bl = initial_bl
        has_projection = stride != 1 or in_channels != out_channels

        if initial_bl:
            self.conv0 = nn.Conv2d(3, self.in_channels, kernel_size=3, stride=1, padding=1, bias=False)
            self.norm0 = nn.BatchNorm2d(self.in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.norm1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.norm2 = nn.BatchNorm2d(out_channels)
        if has_projection:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.downsample = None
        self.relu = nn.ReLU(inplace=True)

        # Estimate the activation, gradient and error counts.
        input_volume = in_channels * input_sz * input_sz
        output_volume = out_channels * output_sz * output_sz
        activations = input_volume + output_volume
        gradients = 9 * in_channels * out_channels + 9 * out_channels * out_channels
        errors = output_volume + output_volume
        if initial_bl:
            activations += 3 * input_sz * input_sz
            gradients += 9 * 3 * in_channels
            errors += input_volume
        if has_projection:
            # The projection shares its input activation with conv1, so only
            # its error and weights are added.
            errors += output_volume
            gradients += 1 * in_channels * out_channels
        self.activation_num = activations
        self.gradient_num = gradients
        self.error_num = errors

    def forward(self, x):
        """Apply the block to ``x`` and return the activated sum of both branches."""
        if self.initial_bl:
            x = self.relu(self.norm0(self.conv0(x)))

        main = self.relu(self.norm1(self.conv1(x)))
        main = self.norm2(self.conv2(main))

        shortcut = x if self.downsample is None else self.downsample(x)

        main += shortcut
        return self.relu(main)

class Resnet_bp(nn.Module):
    """Stack of ``ResidualBlock_bp`` stages with a pooled 10-way linear classifier.

    Args:
        combo: Layout selector; one of 4, 5, 6 or 7 (see ``_LAYOUTS``).

    Raises:
        ValueError: If ``combo`` is not a supported layout. Previously this
            surfaced as an ``AttributeError`` about a missing ``blocks``.
    """

    # combo -> ordered stage specs:
    # (input_sz, output_sz, in_channels, out_channels, stride).
    # The first stage of every layout also carries the 3-channel stem.
    _LAYOUTS = {
        4: ((32, 32, 16, 16, 1),
            (32, 16, 16, 32, 2),
            (16, 8, 32, 64, 2)),
        5: ((32, 32, 16, 16, 1),
            (32, 16, 16, 32, 2),
            (16, 16, 32, 32, 1),
            (16, 8, 32, 64, 2)),
        6: ((32, 32, 16, 16, 1),
            (32, 16, 16, 32, 2),
            (16, 16, 32, 32, 1),
            (16, 8, 32, 64, 2),
            (8, 8, 64, 64, 1)),
        7: ((32, 32, 16, 16, 1),
            (32, 16, 16, 32, 2),
            (16, 8, 32, 64, 2),
            (8, 4, 64, 128, 2)),
    }

    def __init__(self, combo):
        super(Resnet_bp, self).__init__()
        try:
            layout = self._LAYOUTS[combo]
        except (KeyError, TypeError):
            supported = sorted(self._LAYOUTS)
            raise ValueError(f"unsupported combo {combo!r}; expected one of {supported}") from None

        # Stages are created in data-flow order.
        stages = [
            ResidualBlock_bp(input_sz=input_sz, output_sz=output_sz,
                             in_channels=in_channels, out_channels=out_channels,
                             initial_bl=(index == 0), stride=stride)
            for index, (input_sz, output_sz, in_channels, out_channels, stride) in enumerate(layout)
        ]
        self.blocks = nn.Sequential(*stages)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        # The classifier width is the channel count of the last stage.
        num_out = layout[-1][3]
        self.lc = nn.Linear(num_out, 10)

        # Estimate the activation, gradient and error counts: the sum over
        # the stages plus the classifier's contribution.
        self.activation_num = sum(block.activation_num for block in self.blocks) + num_out
        self.gradient_num = sum(block.gradient_num for block in self.blocks) + 10 * num_out
        self.error_num = sum(block.error_num for block in self.blocks) + 10

    def forward(self, x):
        """Return the 10 classifier outputs for the batch ``x``."""
        x = self.blocks(x)
        x = self.avgpool(x)
        x = self.flatten(x)
        x = self.lc(x)
        return x

# depth-wise + point-wise init in_channels = 8
class depthwise_block(nn.Module):
    """LayerNorm -> 3x3 depthwise conv -> 1x1 pointwise conv -> ReLU.

    With ``initial_bl`` a stem (LayerNorm, stride-2 conv 3->8, ReLU) runs
    first. The stem always emits 8 channels, so ``in_channels`` must be 8 in
    that case.

    The constructor also fills the size estimates ``activation_num``,
    ``gradient_num`` and ``error_num``.

    Args:
        input_sz: Side length of the square input. With ``initial_bl`` this
            is the image size before the stride-2 stem.
        output_sz: Side length of the block's output feature map.
        in_channels: Channel count entering the depthwise convolution.
        output_channels: Channel count produced by the pointwise convolution.
        initial_bl: If true, prepend the 3-channel stem.
        stride: Stride of the depthwise convolution.
    """

    def __init__(self, input_sz, output_sz, in_channels, output_channels, initial_bl=False, stride=1):
        super(depthwise_block, self).__init__()
        # Size of the raw input, kept for the stem's activation estimate.
        raw_input_sz = input_sz
        if initial_bl:
            self.norm0 = nn.LayerNorm([3, input_sz, input_sz], elementwise_affine=False)
            self.conv0 = nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1, bias=False)
            # The stride-2 stem halves the spatial size seen by the rest of the block.
            input_sz = input_sz // 2
        self.norm1 = nn.LayerNorm([in_channels, input_sz, input_sz], elementwise_affine=False)
        self.dwconv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=stride, padding=1,
                                 groups=in_channels, bias=False)
        # NOTE(review): norm2 is constructed but never applied in forward();
        # confirm whether it was meant to sit between the two convolutions.
        self.norm2 = nn.LayerNorm([in_channels, output_sz, output_sz], elementwise_affine=False)
        self.pwconv1 = nn.Conv2d(in_channels, output_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.initial_bl = initial_bl

        # Estimate the activation, gradient and error counts.
        self.activation_num = 0
        self.gradient_num = 0
        self.error_num = 0
        if initial_bl:
            # Use the real input size instead of a hard-coded 96x96 image.
            self.activation_num += 3 * raw_input_sz * raw_input_sz
            self.gradient_num += 9 * 3 * 8
            self.error_num += 8 * input_sz * input_sz

        self.activation_num += in_channels * input_sz * input_sz + output_channels * output_sz * output_sz
        self.gradient_num += 9 * in_channels + 1 * in_channels * output_channels
        self.error_num += output_channels * output_sz * output_sz + in_channels * output_sz * output_sz

    def forward(self, x):
        """Apply the (optional) stem and the depthwise-separable convolution to ``x``."""
        if self.initial_bl:
            x = self.norm0(x)
            x = self.conv0(x)
            x = self.relu(x)
        x = self.norm1(x)
        x = self.dwconv1(x)
        x = self.pwconv1(x)
        x = self.relu(x)
        return x

class InvertedResidual(nn.Module):
    """Inverted-residual block: 1x1 expand -> 3x3 depthwise -> 1x1 project.

    Every convolution is bias-free and preceded by a LayerNorm without
    learnable parameters. The input is added to the output only when
    ``stride == 1`` and ``in_channels == output_channels``.

    With ``initial_bl`` a stem (LayerNorm, stride-2 conv 3->32) is prepended.
    The stem always emits 32 channels, so ``in_channels`` must be 32 then.

    The constructor fills ``activation_num``, ``gradient_num`` and
    ``initial_act_num``. Unlike the other blocks in this file it does not
    define ``error_num``.

    Args:
        input_sz: Side length of the square input. With ``initial_bl`` this
            is the image size before the stride-2 stem.
        output_sz: Side length of the block's output feature map.
        in_channels: Channel count entering the expand (or depthwise) conv.
        output_channels: Channel count produced by the projection.
        stride: Stride of the depthwise convolution.
        expand_ratio: Hidden width is ``round(in_channels * expand_ratio)``;
            a ratio of 1 omits the expand convolution.
        initial_bl: If true, prepend the 3-channel stem.
    """

    def __init__(self, input_sz, output_sz, in_channels, output_channels, stride, expand_ratio, initial_bl=False):
        super(InvertedResidual, self).__init__()

        hidden_dim = round(in_channels * expand_ratio)
        self.use_res_connect = stride == 1 and in_channels == output_channels

        # Sizes needed by the stem estimates, taken before input_sz is reduced.
        raw_input_sz = input_sz
        stem_output_sz = input_sz // 2

        layers = []
        if initial_bl:
            layers.extend([
                nn.LayerNorm([3, input_sz, input_sz], elementwise_affine=False),
                nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)
            ])
            input_sz = stem_output_sz
        # 1x1 pointwise convolution to expand
        if expand_ratio != 1:
            layers.extend([
                nn.LayerNorm([in_channels, input_sz, input_sz], elementwise_affine=False),
                nn.Conv2d(in_channels, hidden_dim, kernel_size=1, stride=1, padding=0, bias=False),
                nn.ReLU6(inplace=True)
            ])

        # 3x3 depthwise convolution
        layers.extend([
            nn.LayerNorm([hidden_dim, input_sz, input_sz], elementwise_affine=False),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=stride, padding=1, groups=hidden_dim, bias=False),
            nn.ReLU6(inplace=True),
        ])

        if stride == 2:
            input_sz = input_sz // 2

        # 1x1 pointwise convolution to project (and squeeze)
        layers.extend([
            nn.LayerNorm([hidden_dim, input_sz, input_sz], elementwise_affine=False),
            nn.Conv2d(hidden_dim, output_channels, kernel_size=1, stride=1, padding=0, bias=False),
            # Note: No ReLU here!
        ])

        self.block = nn.Sequential(*layers)

        # Estimate the activation and gradient counts.
        self.activation_num = 0
        self.gradient_num = 0
        if initial_bl:
            # Derived from the real input size instead of a hard-coded
            # 96x96 image (48x48 after the stem).
            self.activation_num += 32 * stem_output_sz * stem_output_sz
            self.gradient_num += 9 * 3 * 32 + 32
            self.initial_act_num = 3 * raw_input_sz * raw_input_sz
        else:
            # NOTE(review): input_sz has already been halved here when
            # stride == 2, so this is the post-stride size - confirm intent.
            self.initial_act_num = in_channels * input_sz * input_sz

        self.activation_num += hidden_dim * output_sz * output_sz + hidden_dim * output_sz * output_sz
        # NOTE(review): the trailing "+ hidden_dim + output_channels" terms do
        # not correspond to any parameter here (convs are bias-free, norms
        # have no affine weights) - confirm intent.
        self.gradient_num += 9 * hidden_dim + 1 * hidden_dim * output_channels + hidden_dim + output_channels
        if expand_ratio != 1:
            self.activation_num += hidden_dim * input_sz * input_sz
            self.gradient_num += 1 * in_channels * hidden_dim + hidden_dim

    def forward(self, x):
        """Apply the block, adding the input back when the residual connection is enabled."""
        if self.use_res_connect:
            return x + self.block(x)
        else:
            return self.block(x)
class BP_depthwise_block(nn.Module):
    """3x3 depthwise conv -> 1x1 pointwise conv -> BatchNorm -> ReLU.

    With ``initial_bl`` a stem (stride-2 conv 3->8, BatchNorm, ReLU) runs
    first. The stem always emits 8 channels, so ``in_channels`` must be 8 in
    that case.

    The constructor also fills the size estimates ``activation_num``,
    ``gradient_num`` and ``error_num``.

    Args:
        input_sz: Side length of the square input. With ``initial_bl`` this
            is the image size before the stride-2 stem.
        output_sz: Side length of the block's output feature map.
        in_channels: Channel count entering the depthwise convolution.
        output_channels: Channel count produced by the pointwise convolution.
        initial_bl: If true, prepend the 3-channel stem.
        stride: Stride of the depthwise convolution.
    """

    def __init__(self, input_sz, output_sz, in_channels, output_channels, initial_bl=False, stride=1):
        super(BP_depthwise_block, self).__init__()
        self.initial_bl = initial_bl
        # Size of the raw input, kept for the stem's activation estimate.
        raw_input_sz = input_sz
        if initial_bl:
            self.conv0 = nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1, bias=False)
            self.norm0 = nn.BatchNorm2d(8)
            # The stride-2 stem halves the spatial size seen by the rest of the block.
            input_sz = input_sz // 2
        self.dwconv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=stride, padding=1,
                                 groups=in_channels, bias=False)
        self.pwconv1 = nn.Conv2d(in_channels, output_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.norm1 = nn.BatchNorm2d(output_channels)
        self.relu = nn.ReLU(inplace=True)

        # Estimate the activation, gradient and error counts.
        self.activation_num = 0
        self.gradient_num = 0
        self.error_num = 0
        if initial_bl:
            # Derived from the real input size instead of a hard-coded
            # 96x96 image (48x48 after the stem).
            self.activation_num += 3 * raw_input_sz * raw_input_sz
            self.error_num += 8 * input_sz * input_sz
            self.gradient_num += 9 * 3 * 8
        self.activation_num += in_channels * input_sz * input_sz + output_channels * output_sz * output_sz
        self.gradient_num += 9 * in_channels + 1 * in_channels * output_channels
        self.error_num += in_channels * output_sz * output_sz + output_channels * output_sz * output_sz

    def forward(self, x):
        """Apply the (optional) stem and the depthwise-separable convolution to ``x``."""
        if self.initial_bl:
            x = self.conv0(x)
            x = self.norm0(x)
            x = self.relu(x)
        x = self.dwconv1(x)
        x = self.pwconv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        return x
    
class FF_mobileblock(nn.Sequential):
    """Feature stages followed by a label-driven multiplicative gate.

    The positional modules are run in order on ``x``. Their output is
    optionally adaptive-max-pooled to ``pool_en x pool_en`` and multiplied
    elementwise by ``lc(y)``, a linear map of the ``label_len``-wide input
    ``y`` reshaped to the pooled feature shape.

    Args:
        *args: Feature stages; each must expose ``activation_num``,
            ``gradient_num`` and ``error_num``.
        mask: Channel count of the last feature stage's output.
        pool_en: Pooled side length; a falsy value skips pooling.
            NOTE(review): with ``pool_en=0`` the gate ``lc`` has zero
            outputs, so the reshape in ``forward`` presumably cannot
            succeed on non-empty features - confirm before using the default.
        label_len: Width of the label input ``y``.
    """

    def __init__(self, *args, mask, pool_en=0, label_len = 1):
        super(FF_mobileblock, self).__init__(*args)
        mask = mask * pool_en * pool_en
        self.lc = nn.Linear(label_len, mask)
        self.pool_en = pool_en

        # Estimate the activation, gradient and error counts.
        # ``lc`` is registered last, so every module before it is a feature stage.
        feature_stages = list(self._modules.values())[:-1]
        # The gate's input activation is the label itself, so it has
        # ``label_len`` entries (previously hard-coded as 2).
        self.activation_num = sum(stage.activation_num for stage in feature_stages) + label_len
        self.gradient_num = sum(stage.gradient_num for stage in feature_stages) + mask * label_len
        self.error_num = sum(stage.error_num for stage in feature_stages) + mask

    def forward(self, x, y, opts):
        """Return ``(gated, features)`` and store ``gated`` on ``self.h``.

        ``opts`` is accepted for interface compatibility and is not read.
        """
        # Run every module except the trailing gate layer ``lc``.
        for stage in list(self._modules.values())[:-1]:
            x = stage(x)
        if self.pool_en:
            x_temp = F.adaptive_max_pool2d(x, (self.pool_en, self.pool_en))
        else:
            x_temp = x
        c = self.lc(y)
        c_reshaped = c.view(c.shape[0], x_temp.shape[1], x_temp.shape[2], x_temp.shape[3])
        f = x_temp * c_reshaped
        self.h = f
        return f, x
    
class FF_mobilenet_v1(nn.Module):
    """MobileNet-v1 style layer stack partitioned into FF_mobileblock groups.

    Every supported ``combo`` uses layers from the same small catalogue
    (``_LAYERS``) and differs only in how the layers are grouped into
    blocks and which pooling size each block receives.

    Args:
        pool_list: indexable of pooling sizes; each block reads the entry
            named in its combo specification.  Ignored for combo 11, which
            uses the fixed sizes in ``_FIXED_POOLS``.  ``None`` (default)
            behaves like an empty list.
        combo: key into ``_COMBOS`` selecting the grouping.
        label_len: forwarded to every ``FF_mobileblock``.

    Raises:
        ValueError: if ``combo`` is not a key of ``_COMBOS``.
        IndexError: if ``pool_list`` is shorter than the combo requires.
    """

    # Keyword names of the values stored per layer code in ``_LAYERS``.
    _LAYER_FIELDS = ("input_sz", "output_sz", "in_channels", "output_channels", "stride")

    # code -> (input_sz, output_sz, in_channels, output_channels, stride)
    _LAYERS = {
        "A": (96, 48, 8, 16, 1),     # stem, additionally built with initial_bl=True
        "B": (48, 24, 16, 32, 2),
        "C": (24, 24, 32, 32, 1),
        "D": (24, 12, 32, 64, 2),
        "E": (12, 12, 64, 64, 1),
        "F": (12, 6, 64, 128, 2),
        "G": (6, 6, 128, 128, 1),
        "H": (6, 3, 128, 256, 2),
        "I": (3, 3, 256, 256, 1),
    }
    _STEM = "A"

    # combo -> one (layer codes, pool index) pair per block, in order.
    # The pool index selects an entry of ``pool_list`` (or of the fixed
    # tuple in ``_FIXED_POOLS`` when the combo has one).
    _COMBOS = {
        0: (("ABC", 0), ("DEFG", 1), ("GGGGHI", 2)),
        1: (("ABC", 0), ("DEFGG", 1), ("GGGHI", 2)),
        2: (("ABCD", 0), ("EFGGG", 1), ("GGHI", 2)),
        3: (("ABC", 0), ("DEFG", 1), ("GGGG", 2), ("HI", 3)),
        4: (("AB", 0), ("CDEF", 1), ("GGGG", 2), ("GHI", 3)),
        5: (("ABC", 0), ("DEFG", 1), ("GGG", 2)),
        6: (("AB", 0), ("CDEFG", 1), ("GGG", 2)),
        7: (("AB", 0), ("CDE", 1), ("FGG", 2), ("GG", 3)),
        8: (("AB", 0), ("DFH", 1)),
        9: (("A", 0), ("BD", 1), ("FH", 2)),
        # one layer per block (13 blocks)
        10: (
            ("A", 0), ("B", 0), ("C", 0),
            ("D", 1), ("E", 1),
            ("F", 2), ("G", 2), ("G", 2), ("G", 2), ("G", 2), ("G", 2),
            ("H", 3), ("I", 3),
        ),
        11: (("ABC", 0), ("DEFG", 1), ("H", 2)),
    }

    # Combos whose pooling sizes are hard-wired instead of read from pool_list.
    _FIXED_POOLS = {11: (8, 4, 2)}

    def __init__(self, pool_list=None, combo=0, label_len=2):
        super(FF_mobilenet_v1, self).__init__()
        spec = self._COMBOS.get(combo)
        if spec is None:
            # Previously an unknown combo left ``self.blocks`` undefined and
            # only failed later, inside ``forward``.
            raise ValueError(
                f"unsupported combo {combo!r}; expected one of {sorted(self._COMBOS)}"
            )

        pools = self._FIXED_POOLS.get(combo)
        if pools is None:
            pools = pool_list if pool_list is not None else ()

        groups = []
        for codes, pool_idx in spec:
            layers = [self._build_layer(code) for code in codes]
            groups.append(
                FF_mobileblock(
                    *layers,
                    # the gate width follows the channels of the block's last layer
                    mask=self._LAYERS[codes[-1]][3],
                    pool_en=pools[pool_idx],
                    label_len=label_len,
                )
            )
        self.blocks = FF_sequential(*groups)

    @classmethod
    def _build_layer(cls, code):
        """Instantiate the depthwise block registered under ``code``."""
        kwargs = dict(zip(cls._LAYER_FIELDS, cls._LAYERS[code]))
        if code == cls._STEM:
            kwargs["initial_bl"] = True
        return depthwise_block(**kwargs)

    def forward(self, x, y, opts=None, diff_res=np.array([1,1])):
        """Run all blocks and return one goodness column per ``diff_res`` entry.

        Column ``i`` is the per-sample sum of squared values of block ``i``'s
        stored ``h``, scaled by ``diff_res[i] / min(diff_res)``.  Only the
        first ``len(diff_res)`` blocks contribute.  ``diff_res`` acts as
        per-block weights at inference and can be left at its default when
        only training.
        """
        # Called for its side effect: every block stores its gated output in ``.h``.
        self.blocks(x, y, opts)
        base = np.min(diff_res)
        goodness = []
        for i, b in enumerate(diff_res):
            h = self.blocks[i].h
            goodness.append(h.view(h.shape[0], -1).pow(2).sum(1)*b/base)
        return torch.cat([g.unsqueeze(1) for g in goodness], dim=1)
    
class BP_mobilenet_v1(nn.Module):
    """Backprop MobileNet-v1 style classifier with a two-way linear head.

    ``combo`` selects one of several layer sequences built from the shared
    catalogue ``_LAYERS``; the average-pool kernel and the classifier input
    width are chosen to match the last layer of the selected sequence.

    Args:
        combo: key into ``_COMBOS`` (0, 1, 2, 3 or 5).

    Raises:
        ValueError: if ``combo`` is not a key of ``_COMBOS``.
    """

    # Keyword names of the values stored per layer code in ``_LAYERS``.
    _LAYER_FIELDS = ("input_sz", "output_sz", "in_channels", "output_channels", "stride")

    # code -> (input_sz, output_sz, in_channels, output_channels, stride)
    _LAYERS = {
        "A": (96, 48, 8, 16, 1),     # stem, additionally built with initial_bl=True
        "B": (48, 24, 16, 32, 2),
        "C": (24, 24, 32, 32, 1),
        "D": (24, 12, 32, 64, 2),
        "E": (12, 12, 64, 64, 1),
        "F": (12, 6, 64, 128, 2),
        "G": (6, 6, 128, 128, 1),
        "H": (6, 3, 128, 256, 2),
        "I": (3, 3, 256, 256, 1),
        "J": (3, 2, 256, 512, 2),
        "K": (2, 2, 512, 512, 1),
    }
    _STEM = "A"

    # combo -> (layer codes, classifier input width, average-pool kernel)
    _COMBOS = {
        0: ("ABCDEFGGGGGHI", 256, 3),
        1: ("ABCDEFGGGG", 128, 6),
        2: ("ABDFH", 256, 3),
        3: ("ABCDEFGGGGGHIJK", 512, 2),
        5: ("ABCDEFGH", 256, 3),
    }

    def __init__(self, combo = 0):
        super(BP_mobilenet_v1, self).__init__()
        spec = self._COMBOS.get(combo)
        if spec is None:
            # Previously this surfaced as an AttributeError on ``self.blocks``
            # further down in the constructor.
            raise ValueError(
                f"unsupported combo {combo!r}; expected one of {sorted(self._COMBOS)}"
            )
        codes, out_num, pool_size = spec

        self.blocks = nn.Sequential(*(self._build_layer(code) for code in codes))
        self.flatten = nn.Flatten()
        self.avgpool = nn.AvgPool2d(pool_size, stride=1)
        #self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(out_num, 2)

        # Totals of the per-block estimates plus the classifier head.
        self.activation_num = sum(block.activation_num for block in self.blocks) + 2
        self.gradient_num = sum(block.gradient_num for block in self.blocks) + 2*out_num
        self.error_num = sum(block.error_num for block in self.blocks) + out_num

    @classmethod
    def _build_layer(cls, code):
        """Instantiate the BP depthwise block registered under ``code``."""
        kwargs = dict(zip(cls._LAYER_FIELDS, cls._LAYERS[code]))
        if code == cls._STEM:
            kwargs["initial_bl"] = True
        return BP_depthwise_block(**kwargs)

    def forward(self, x):
        """Return the head output: blocks -> average pool -> flatten -> linear."""
        features = self.blocks(x)
        pooled = self.avgpool(features)
        return self.fc(self.flatten(pooled))
    




## The classes below are only used for the motivational example on MNIST
"""
Start here
"""
class MLP_layer(nn.Module):
    """Fully connected layer with ReLU and optional input layer-norm.

    Args:
        in_dim_feat: number of input features.
        n_neurons: number of output features.
        normalize_input: when true, a non-affine ``LayerNorm`` over the
            input features is applied before the linear layer.

    Besides the sub-modules the instance carries size estimates:
    ``activation_num`` (outputs), ``gradient_num`` (weights plus biases)
    and ``initial_act_num`` (inputs).
    """

    def __init__(self, in_dim_feat, n_neurons, normalize_input=False):
        super().__init__()
        self.normalize_input = normalize_input
        if normalize_input:
            self.norm = nn.LayerNorm(in_dim_feat, elementwise_affine=False)
        self.fc = nn.Linear(in_dim_feat, n_neurons)
        self.relu = nn.ReLU(inplace=True)

        # estimate the activation and gradient number
        self.activation_num = n_neurons
        self.gradient_num = in_dim_feat*n_neurons + n_neurons  # weights + bias
        self.initial_act_num = in_dim_feat

    def forward(self, h):
        """Return ``relu(fc(h))`` (input normalised first if enabled).

        The result is also kept on ``self.h``.
        """
        layer_in = self.norm(h) if self.normalize_input else h
        out = self.relu(self.fc(layer_in))
        self.h = out
        return out

class FF_MLP_Block(nn.Sequential):
    """Sequential MLP layers whose output is gated by a label projection.

    The positional modules are registered first and the label projection
    ``fc`` last, so the final entry of ``self._modules`` is always ``fc``;
    ``__init__`` and ``forward`` both use that to skip it.

    Args:
        *args: layers exposing ``activation_num`` and ``gradient_num``; the
            first one must also expose ``initial_act_num``.
        mask: number of outputs of the label projection (also stored as
            ``goodness_num``).
        inlen: number of input features of the label projection.
    """

    def __init__(self, *args, mask, inlen = 10):
        super().__init__(*args)
        self.fc = nn.Linear(inlen, mask)

        # estimate the activation and gradient number; the label projection
        # (last registered module) is deliberately left out of both sums.
        registered = list(self._modules.values())
        layers = registered[:-1]
        self.activation_num = sum(m.activation_num for m in layers)
        self.gradient_num = sum(m.gradient_num for m in layers)
        self.goodness_num = mask

        # compensate the input activation number
        self.activation_num += registered[0].initial_act_num

    def forward(self, x, y, opts):
        """Return ``(gated, features)``.

        ``features`` is ``x`` after every layer; ``gated`` is
        ``features * fc(y)`` and is also stored on ``self.h``.  ``opts`` is
        accepted but not used here.
        """
        feats = x
        for layer in list(self._modules.values())[:-1]:
            feats = layer(feats)
        # * or + ?
        gated = feats*self.fc(y)
        self.h = gated
        return gated, feats

class FF_MLP_Net(nn.Module):
    """Three-block forward-forward MLP built from ``FF_MLP_Block`` stages.

    Args:
        nn_strct: Sequence of at least four layer widths; block ``i`` maps
            ``nn_strct[i]`` features to ``nn_strct[i + 1]`` features.
        label_len: Number of features of the label input ``y`` that every
            block's gate projects from.
    """

    def __init__(self, nn_strct, label_len=10):
        super(FF_MLP_Net, self).__init__()
        # Each block's gate must be as wide as that block's own layer output
        # (``nn_strct[i + 1]``) and must read ``label_len`` label features;
        # otherwise the element-wise product in the block cannot be formed.
        # Only the first block skips input normalisation.
        stages = [
            FF_MLP_Block(
                MLP_layer(nn_strct[i], nn_strct[i + 1], normalize_input=(i > 0)),
                inlen=label_len,
                mask=nn_strct[i + 1],
            )
            for i in range(3)
        ]
        self.blocks = FF_sequential(*stages)

    def forward(self, x, y, opts=None, diff_res=np.array([1,1])):
        """Run the blocks and return per-block goodness values.

        Args:
            x: Input passed to the block container.
            y: Label input passed to the block container.
            opts: Forwarded unchanged to the block container.
            diff_res: One weight per block to report.  Its length decides how
                many blocks contribute a column, so the default reports the
                first two blocks only.

        Returns:
            Tensor of shape ``(batch, len(diff_res))``; column ``i`` is the
            mean squared value of block ``i``'s cached ``h``, scaled by
            ``diff_res[i] / min(diff_res)``.
        """
        # The return value is not needed; the call is what refreshes each
        # block's cached ``h`` (presumably, via FF_sequential — confirm).
        self.blocks(x, y, opts)
        #can neglect this diff_res if it is only training. Just work as inference's weights on layers' goodness
        base = np.min(diff_res)
        hs = []
        for i, b in enumerate(diff_res):
            h = self.blocks[i].h
            hs.append(h.view(h.shape[0], -1).pow(2).mean(1)*b/base)
        return torch.cat([t.unsqueeze(1) for t in hs], dim=1)
    
#simple conv2d block for ff
class ff_conv2d(nn.Module):
    """(Optional LayerNorm) -> bias-free Conv2d -> ReLU layer with size estimates.

    Args:
        input_sz: Height/width of the (square) input feature map.
        output_sz: Height/width of the output feature map, as supplied by the
            caller; it is used for bookkeeping only and is not checked against
            the convolution's actual output size.
        kernel_size: Kernel size passed to ``nn.Conv2d``.
        in_channels: Number of input channels.
        output_channels: Number of output channels.
        stride: Convolution stride.
        padding: Convolution padding.
        norm: When true, a non-affine ``LayerNorm`` over
            ``[in_channels, input_sz, input_sz]`` is applied before the conv.
    """

    def __init__(self, input_sz, output_sz, kernel_size, in_channels, output_channels, stride=1, padding=1, norm=True):
        super(ff_conv2d, self).__init__()
        self.norm = norm
        if norm:
            self.norm0 = nn.LayerNorm([in_channels,input_sz,input_sz], elementwise_affine=False)
        self.conv0 = nn.Conv2d(in_channels, output_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        self.relu = nn.ReLU(inplace=True)

        # estimate the activation and gradient number
        self.activation_num = in_channels*(input_sz)*(input_sz)
        # One gradient per conv weight.  Taken from the weight tensor so the
        # count follows ``kernel_size`` (it equals 9*in*out for a 3x3 kernel).
        self.gradient_num = self.conv0.weight.numel()
        self.error_num = output_channels*(output_sz)*(output_sz)

    def forward(self, x):
        """Return ``relu(conv0(norm0(x)))``, skipping ``norm0`` when disabled."""
        if self.norm:
            x = self.norm0(x)
        x = self.conv0(x)
        x = self.relu(x)
        return x

class ff_conv2d_block(nn.Sequential):
    """Conv stack whose (optionally pooled) output is gated by a label projection.

    The positional modules run in order on ``x``.  ``lc`` maps ``y``
    (``label_len`` features) to ``mask * pool_en * pool_en`` features, which
    are reshaped to the pooled feature map and multiplied with it.  ``lc`` is
    registered after the positional modules, so it is the last entry of
    ``_modules`` and is excluded from the sequential pass.

    NOTE(review): with the default ``pool_en=0`` the gate has zero output
    features, so the reshape in ``forward`` cannot match a non-empty feature
    map — callers in this file always pass a non-zero ``pool_en``.
    """

    def __init__(self, *args, mask, pool_en=0, label_len = 10):
        super(ff_conv2d_block, self).__init__(*args)
        gate_width = mask*pool_en*pool_en
        self.lc = nn.Linear(label_len, gate_width)
        self.pool_en = pool_en

        # estimate the activation and gradient number
        conv_layers = list(self._modules.values())[:-1]  # all but ``lc``
        self.activation_num = sum(layer.activation_num for layer in conv_layers)
        self.gradient_num = sum(layer.gradient_num for layer in conv_layers)
        self.error_num = sum(layer.error_num for layer in conv_layers)

        # Contribution of the label gate ``lc``; only its weights are counted
        # in the gradient estimate.
        self.activation_num += label_len
        self.gradient_num += gate_width*label_len
        self.error_num += gate_width

    def forward(self, x, y, opts):
        """Return ``(f, x)``: the gated pooled map and the unpooled conv output.

        ``opts`` is accepted but not used here.  ``f`` is also cached on
        ``self.h``.
        """
        *conv_layers, _ = self._modules.values()  # drop the last module (``lc``)
        for layer in conv_layers:
            x = layer(x)

        if self.pool_en:
            x_temp = nn.AdaptiveMaxPool2d((self.pool_en, self.pool_en))(x)
        else:
            x_temp = x

        gate = self.lc(y)
        dims = [x_temp.shape[axis] for axis in range(4)]
        gated = x_temp * gate.view(*dims)
        self.h = gated
        return gated, x

class ff_conv2d_oldblock(nn.Sequential):
    """Plain conv stack with optional adaptive max pooling and no label gate.

    Args:
        *args: Modules applied in order; each must expose ``activation_num``
            and ``gradient_num``.
        mask: Channel count of the block output; ``goodness_num`` is
            ``mask * pool_en * pool_en``.
        pool_en: Output size of the adaptive max pool; ``0`` disables pooling.
        label_len: Accepted for signature parity with ``ff_conv2d_block``;
            not used.
    """

    def __init__(self, *args, mask, pool_en=0, label_len = 10):
        super(ff_conv2d_oldblock, self).__init__(*args)
        mask = mask*pool_en*pool_en
        self.pool_en = pool_en

        # estimate the activation and gradient number
        modules_list = list(self._modules.values())
        self.activation_num = sum(module.activation_num for module in modules_list)
        self.gradient_num = sum(module.gradient_num for module in modules_list)
        # Add the first layer's extra input-activation count when it declares
        # one.  Layers without ``init_act`` (such as ``ff_conv2d``, whose
        # ``activation_num`` already counts its input) contribute nothing
        # here, instead of failing with AttributeError at construction.
        if modules_list:
            self.activation_num += getattr(modules_list[0], 'init_act', 0)
        self.goodness_num = mask

    def forward(self, x, y, opts):
        """Return ``(h, x)``: the (optionally pooled) output and the raw output.

        ``y`` and ``opts`` are accepted but not used.  ``h`` is also cached on
        ``self.h``.
        """
        # original forward
        for module in self._modules.values():
            x = module(x)
        if self.pool_en:
            pool = nn.AdaptiveMaxPool2d((self.pool_en, self.pool_en))
            x_temp = pool(x)
        else:
            x_temp = x
        self.h = x_temp
        return x_temp, x

class mnist_conv2d(nn.Module):
    """Forward-forward conv network for single-channel images.

    Args:
        combo: Selects the architecture.
            ``0`` — two label-gated ``ff_conv2d_block`` stages on a 28x28 input.
            ``1`` — two ``ff_conv2d_oldblock`` stages on a 30x30 input.
            ``2`` — as ``1`` plus a third ``ff_conv2d_oldblock`` stage.

    Raises:
        ValueError: If ``combo`` is not one of the supported values.
    """

    def __init__(self, combo=0):
        super(mnist_conv2d, self).__init__()
        if combo == 0: # (a+b)+(c)
            CB1 = ff_conv2d_block(ff_conv2d(input_sz=28, output_sz=14, kernel_size=3, in_channels=1, output_channels=8, stride=2, padding=1, norm=True),
                                  ff_conv2d(input_sz=14, output_sz=7, kernel_size=3, in_channels=8, output_channels=16, stride=2, padding=1, norm=True),
                                  mask = 16, pool_en=4)
            CB2 = ff_conv2d_block(ff_conv2d(input_sz=7, output_sz=4, kernel_size=3, in_channels=16, output_channels=32, stride=2, padding=1, norm=True),
                                  mask = 32, pool_en=2)
            stages = [CB1, CB2]
        elif combo in (1, 2): # (a+b)+(c) original FF, with an extra stage for combo 2
            CB1 = ff_conv2d_oldblock(ff_conv2d(input_sz=30, output_sz=15, kernel_size=3, in_channels=1, output_channels=8, stride=2, padding=1, norm=True),
                                     ff_conv2d(input_sz=15, output_sz=8, kernel_size=3, in_channels=8, output_channels=16, stride=2, padding=1, norm=True),
                                     mask = 16, pool_en=8)
            CB2 = ff_conv2d_oldblock(ff_conv2d(input_sz=8, output_sz=4, kernel_size=3, in_channels=16, output_channels=32, stride=2, padding=1, norm=True),
                                     mask = 32, pool_en=4)
            stages = [CB1, CB2]
            if combo == 2:
                CB3 = ff_conv2d_oldblock(ff_conv2d(input_sz=4, output_sz=2, kernel_size=3, in_channels=32, output_channels=64, stride=2, padding=1, norm=True),
                                         mask = 64, pool_en=2)
                stages.append(CB3)
        else:
            # Without this, the model would be created with no ``blocks`` and
            # only fail later, inside ``forward``, with an unrelated error.
            raise ValueError(f"unsupported combo {combo!r}; expected 0, 1 or 2")
        self.blocks = FF_sequential(*stages)

    def forward(self, x, y, opts=None, diff_res=np.array([1,1])):
        """Run the blocks and return per-block goodness values.

        Args:
            x: Input passed to the block container.
            y: Label input passed to the block container.
            opts: Forwarded unchanged to the block container.
            diff_res: One weight per block to report.  Its length decides how
                many blocks contribute a column, so the default reports the
                first two blocks only (pass three weights for ``combo=2`` to
                include the third).

        Returns:
            Tensor of shape ``(batch, len(diff_res))``; column ``i`` is the
            sum of squares of block ``i``'s cached ``h``, scaled by
            ``diff_res[i] / min(diff_res)``.
        """
        # The return value is not needed; the call is what refreshes each
        # block's cached ``h`` (presumably, via FF_sequential — confirm).
        self.blocks(x, y, opts)
        #can neglect this diff_res if it is only training. Just work as inference's weights on layers' goodness
        base = np.min(diff_res)
        hs = []
        for i, b in enumerate(diff_res):
            h = self.blocks[i].h
            hs.append(h.view(h.shape[0], -1).pow(2).sum(1)*b/base)
        return torch.cat([t.unsqueeze(1) for t in hs], dim=1)
    
class mnist_conv2d_bp(nn.Module):
    """Backprop baseline: three ``ff_conv2d`` layers, flatten, linear head.

    Args:
        combo: Architecture selector; only ``0`` is defined.

    Raises:
        ValueError: If ``combo`` is not ``0``.
    """

    def __init__(self, combo=0):
        super(mnist_conv2d_bp, self).__init__()
        if combo != 0:
            # No other architecture is defined; fail with a clear message
            # rather than on the missing ``blocks`` attribute further down.
            raise ValueError(f"unsupported combo {combo!r}; expected 0")
        self.blocks = nn.Sequential(
            ff_conv2d(input_sz=28, output_sz=14, kernel_size=3, in_channels=1, output_channels=8, stride=2, padding=1, norm=True),
            ff_conv2d(input_sz=14, output_sz=7, kernel_size=3, in_channels=8, output_channels=16, stride=2, padding=1, norm=True),
            ff_conv2d(input_sz=7, output_sz=4, kernel_size=3, in_channels=16, output_channels=32, stride=2, padding=1, norm=True),
        ) # only difference in ff and bp layer is its norm
        self.flatten = nn.Flatten()
        #self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(32*4*4, 10)

        # estimate the activation and gradient number
        self.activation_num = 0
        self.gradient_num = 0
        self.error_num = 0
        # Every conv layer contributes.  ``fc`` is not part of ``blocks``, so
        # no entry of ``blocks`` may be skipped on its behalf.
        for module in self.blocks:
            self.activation_num += module.activation_num
            self.gradient_num += module.gradient_num
            self.error_num += module.error_num
        # Linear head: its input activations and its weights.
        self.activation_num += 32*4*4
        # NOTE(review): the error estimate for the head equals its weight
        # count rather than its 10 outputs — kept as-is, confirm intent.
        self.error_num += 10*32*4*4
        self.gradient_num += 10*32*4*4

    def forward(self, x):
        """Return the head's output for ``x`` after the conv stack and flatten."""
        x = self.blocks(x)
        #x = self.pool(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x
    
    
"""
END here
"""