import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import sys
sys.path.append('../')
sys.path.append('../dad')

from util.flops import *
from util.memory import *
from set.set import *
from module.operations import *
from set.operations import GatedConv2d as GatedConv2d_s

class GatedNet(nn.Module):
    """Base class for networks whose convolutions carry prunable gates.

    Subclasses are expected to fill ``self.gated_layers`` with their gated
    layers and to implement the ``count_*`` hooks.  The helpers here fan a
    call out to every gated layer (gate construction, regularisers, active
    widths) or compare a pruned configuration against the reference costs
    stored in ``self.full_flops`` / ``self.full_mem``.
    """

    def __init__(self):
        super(GatedNet, self).__init__()
        # Plain list used only as an index over the gated layers; the
        # subclasses in this file register those layers as submodules
        # through their own attributes and collect them from self.modules().
        self.gated_layers = []
        self.use_gate = False
        # Placeholders; subclasses overwrite them with the unpruned costs.
        self.full_flops = 1.
        self.full_mem = 1.

    def _per_layer_kwargs(self, argdicts):
        """Normalise ``argdicts`` to something indexable by layer position.

        ``None`` means "no extra keyword arguments"; a single dict is shared
        by every gated layer; any other value is returned unchanged and must
        be indexable by layer position.
        """
        if argdicts is None:
            argdicts = {}
        if isinstance(argdicts, dict):
            return [argdicts] * len(self.gated_layers)
        return argdicts

    def build_gate(self, gate_fn, argdicts=None):
        """Enable gating and call ``build_gate`` on every gated layer.

        Args:
            gate_fn: forwarded unchanged to each layer's ``build_gate``.
            argdicts: ``None`` (default), one dict of keyword arguments for
                all layers, or a sequence with one dict per gated layer.

        Raises:
            IndexError: if a per-layer sequence is shorter than
                ``self.gated_layers``.
        """
        self.use_gate = True
        kwargs_per_layer = self._per_layer_kwargs(argdicts)
        for i, layer in enumerate(self.gated_layers):
            layer.build_gate(gate_fn, **kwargs_per_layer[i])

    def build_gate_dep(self, gate_fn, argdicts=None):
        """Enable gating and call ``build_gate_dep`` on every gated layer.

        ``argdicts`` is interpreted exactly as in :meth:`build_gate`.
        """
        self.use_gate = True
        kwargs_per_layer = self._per_layer_kwargs(argdicts)
        for i, layer in enumerate(self.gated_layers):
            layer.build_gate_dep(gate_fn, **kwargs_per_layer[i])

    def reset_dep(self):
        """Call ``reset()`` on each gated layer's ``dgate``."""
        for layer in self.gated_layers:
            layer.dgate.reset()

    def get_reg(self):
        """Return the sum of ``get_reg()`` over all gated layers (0. if none)."""
        reg = 0.
        for layer in self.gated_layers:
            reg += layer.get_reg()
        return reg

    def get_reg_dep(self):
        """Return the sum of ``get_reg_dep()`` over all gated layers (0. if none)."""
        reg = 0.
        for layer in self.gated_layers:
            reg += layer.get_reg_dep()
        return reg

    def get_pruned_size(self):
        """Return ``get_num_active()`` of every gated layer, in layer order."""
        return [layer.get_num_active() for layer in self.gated_layers]

    def get_pruned_weight(self):
        """Return ``get_weight_nonactive()`` of every gated layer, in layer order."""
        return [layer.get_weight_nonactive() for layer in self.gated_layers]

    def get_pruned_weight_sum(self):
        """Return the sum of ``get_weight_nonactive()`` over all gated layers (0 if none)."""
        total = 0
        for layer in self.gated_layers:
            total += layer.get_weight_nonactive()
        return total

    def get_weight(self):
        """Return ``get_weight()`` of every gated layer, in layer order."""
        return [layer.get_weight() for layer in self.gated_layers]

    def get_pruned_size_dep(self):
        """Return ``int(layer.dgate.num_active)`` for every gated layer."""
        return [int(layer.dgate.num_active) for layer in self.gated_layers]

    def count_flops(self, num_units):
        """Return the FLOP count for per-gated-layer widths ``num_units``."""
        raise NotImplementedError

    def count_flops_dep(self, num_units, num_units_dep):
        """Return the FLOP count including the dependent-gate widths."""
        raise NotImplementedError

    def count_memory(self, num_units):
        """Return the memory estimate for per-gated-layer widths ``num_units``."""
        raise NotImplementedError

    def count_memory_dep(self, num_units, num_units_dep):
        """Return the memory estimate including the dependent-gate widths."""
        raise NotImplementedError

    def get_speedup(self):
        """Return ``full_flops / count_flops(current pruned widths)``."""
        pruned = self.get_pruned_size()
        return float(self.full_flops) / float(self.count_flops(pruned))

    def get_speedup_dep(self):
        """Return ``full_flops / count_flops_dep(pruned, pruned_dep)``."""
        pruned = self.get_pruned_size()
        pruned_dep = self.get_pruned_size_dep()
        return float(self.full_flops) / \
                float(self.count_flops_dep(pruned, pruned_dep))

    def get_memory_saving(self):
        """Return ``count_memory(current pruned widths) / full_mem``."""
        pruned = self.get_pruned_size()
        return float(self.count_memory(pruned)) / float(self.full_mem)

    def get_memory_saving_dep(self):
        """Return ``count_memory_dep(pruned, pruned_dep) / full_mem``."""
        pruned = self.get_pruned_size()
        pruned_dep = self.get_pruned_size_dep()
        return float(self.count_memory_dep(pruned, pruned_dep)) / \
                float(self.full_mem)


"""
ResNet-18 models for local/bb.py (vs. SNIP).
"""

class ResidualBlock_gated(nn.Module):
    """ResNet basic block whose two 3x3 convolutions are ``GatedConv2d`` layers.

    Main path: gated conv (strided) -> BN -> ReLU -> gated conv -> BN.
    Skip path: identity when ``stride == 1``, otherwise a plain (ungated)
    1x1 conv + BN projection.  The sum goes through a final ReLU.

    Args:
        inp: input channel count.
        oup: output channel count.
        stride: stride of the first conv and of the projection shortcut.
    """

    def __init__(self, inp, oup, stride=1):
        super(ResidualBlock_gated, self).__init__()
        self.stride = stride

        # Submodules are created in the same order as before so that seeded
        # initialisation and parameter ordering are unchanged.
        self.conv = GatedConv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(oup)
        self.relu = nn.ReLU()

        self.conv2 = GatedConv2d(oup, oup, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(oup)

        if self.stride != 1:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inp, oup, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(oup),
            )

    def forward(self, x):
        branch = self.relu(self.bn(self.conv(x)))
        branch = self.bn2(self.conv2(branch))
        skip = x if self.stride == 1 else self.shortcut(x)
        return self.relu(branch + skip)
class Net_gated(GatedNet):
    """ResNet-18 (four stages of two ResidualBlock_gated) with gated 3x3 convs.

    After construction ``self.gated_layers`` holds every GatedConv2d found by
    ``self.modules()``, in registration order: (conv, conv2) of each block,
    stage by stage -- 16 layers for this [2, 2, 2, 2] layout.  That order is
    the indexing used by the ``num_units`` lists of the ``count_*`` methods.
    ``full_flops`` / ``full_mem`` are the costs with every gated layer at full
    width.  The cost formulas and ``forward``'s 4x4 average pool assume
    32x32 inputs.
    """
    def __init__(self, num_classes=10):
        super(Net_gated, self).__init__()
        block=ResidualBlock_gated
        num_blocks=[2,2,2,2]
        self.num_classes=num_classes
        self.inp = 64
        self.conv0 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn0 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)  # in-place ReLU: overwrites its input to save memory
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512, num_classes)

        # Collect gated convs in registration order; this order defines how
        # num_units is indexed below.  The 1x1 shortcut convs are plain
        # nn.Conv2d and are not collected.
        for layer in self.modules():
            if isinstance(layer,GatedConv2d):
                self.gated_layers.append(layer)
        # Reference costs of the unpruned network (every gated layer at full width).
        self.full_flops = self.count_flops([
            64, 64, 64, 64,
            128, 128, 128, 128,
            256, 256, 256, 256,
            512, 512, 512, 512])
        self.full_mem = self.count_memory([
             64, 64, 64, 64,
            128, 128, 128, 128,
            256, 256, 256, 256,
            512, 512, 512, 512])


    def _make_layer(self, block, oup, num_block, stride=1):
        """Build one stage of ``num_block`` blocks as an nn.Sequential.

        Only the first block uses ``stride``; the rest use stride 1.
        ``self.inp`` is advanced to ``oup`` so the next stage starts from this
        stage's width.  Note that ``num_block <= 0`` still yields one block.
        """
        layers = []
        # Only the first block of a stage is strided (stride 2 halves H and W).
        strides = [stride] + [1]*(num_block-1)
        for stride in strides:
            layers.append(block(self.inp, oup, stride))
            self.inp = oup
        return nn.Sequential(*layers)


    def count_flops(self, num_units):
            """Return the FLOP count for the given per-gated-layer widths.

            Args:
                num_units: 16 channel counts; ``num_units[i]`` is the
                    (possibly pruned) output width of gated layer ``i``.

            Counts the stem conv, the 16 gated convs, the three 1x1 projection
            shortcuts (from the stage input width to the width of the block
            output they are added to), global average pooling and the dense
            head.

            NOTE(review): the shortcut terms pass ``padding=1`` although
            ResidualBlock_gated builds its 1x1 shortcut convs without padding;
            check whether count_flops_conv's result depends on padding.
            """
            # num_units: candidate (pruned) width of each gated layer.
            # Per the original note, gate/mask-generation cost is not included.
            flops = count_flops_conv(32, 32, 3, 64, 3, padding=1) \

            flops += count_flops_conv(32, 32, 64, num_units[0], 3, padding=1) \
                    + count_flops_conv(32, 32, num_units[0], num_units[1], 3, padding=1) \
                    + count_flops_conv(32, 32, num_units[1], num_units[2], 3, padding=1) \
                    + count_flops_conv(32, 32, num_units[2], num_units[3], 3, padding=1) \

            flops += count_flops_conv(32, 32, num_units[3], num_units[4], 3, stride=2, padding=1) \
                    +count_flops_conv(32, 32, num_units[3], num_units[5], 1, stride=2, padding=1) \
                    + count_flops_conv(16, 16, num_units[4], num_units[5], 3, padding=1) \
                    + count_flops_conv(16, 16, num_units[5], num_units[6], 3, padding=1) \
                    + count_flops_conv(16, 16, num_units[6], num_units[7], 3,  padding=1) \

            flops +=  count_flops_conv(16, 16, num_units[7], num_units[8], 3, stride=2, padding=1) \
                    +count_flops_conv(16, 16,num_units[7] , num_units[9], 1,  stride=2, padding=1) \
                    +count_flops_conv(8, 8, num_units[8], num_units[9], 3, padding=1) \
                    + count_flops_conv(8, 8, num_units[9], num_units[10], 3, padding=1) \
                    + count_flops_conv(8, 8, num_units[10], num_units[11], 3, padding=1) \

            flops += count_flops_conv(8, 8, num_units[11], num_units[12], 3,  stride=2,padding=1) \
                    +count_flops_conv(8, 8,num_units[11], num_units[13], 1,  stride=2, padding=1) \
                    + count_flops_conv(4, 4, num_units[12], num_units[13], 3, padding=1) \
                    + count_flops_conv(4, 4, num_units[13], num_units[14], 3, padding=1) \
                    + count_flops_conv(4, 4, num_units[14], num_units[15], 3, padding=1) \
                    + count_flops_global_avg_pool(4,4,num_units[15])

            flops += count_flops_dense(num_units[15], self.num_classes)
            return flops

    def count_flops_dep(self, num_units, num_units_dep):
        """Return ``count_flops``-style FLOPs plus a dbb term per gated layer.

        The conv, pooling and dense terms match :meth:`count_flops`.  On top
        of those, ``count_flops_conv_dbb(h, w, num_units[i])`` is added for
        each of the 16 gated layers.

        NOTE(review): ``num_units_dep`` is accepted but never read here --
        every term uses ``num_units``; confirm this is intended.
        NOTE(review): for the strided layers 4, 8 and 12 the spatial size
        passed to count_flops_conv_dbb is the conv's input resolution (e.g.
        32x32 for layer 4, whose output is 16x16); confirm that matches what
        the helper expects.
        """
        # Only the dbb terms are added on top of count_flops; the original
        # note asks why the bb terms are not included as well.
        flops = count_flops_conv(32, 32, 3, 64, 3, padding=1) \

        flops += count_flops_conv(32, 32, 64, num_units[0], 3, padding=1) \
                + count_flops_conv(32, 32, num_units[0], num_units[1], 3, padding=1) \
                + count_flops_conv(32, 32, num_units[1], num_units[2], 3, padding=1) \
                + count_flops_conv(32, 32, num_units[2], num_units[3], 3, padding=1) \

        flops += count_flops_conv(32, 32, num_units[3], num_units[4], 3,   stride=2,padding=1) \
                +count_flops_conv(32, 32, num_units[3], num_units[5], 1, stride=2, padding=1) \
                + count_flops_conv(16, 16, num_units[4], num_units[5], 3, padding=1) \
                + count_flops_conv(16, 16, num_units[5], num_units[6], 3, padding=1) \
                + count_flops_conv(16, 16, num_units[6], num_units[7], 3,  padding=1) \

        flops +=  count_flops_conv(16, 16, num_units[7], num_units[8], 3,  stride=2,padding=1) \
                +count_flops_conv(16, 16, num_units[7], num_units[9], 1,  stride=2, padding=1) \
                +count_flops_conv(8, 8, num_units[8], num_units[9], 3, padding=1) \
                + count_flops_conv(8, 8, num_units[9], num_units[10], 3, padding=1) \
                + count_flops_conv(8, 8, num_units[10], num_units[11], 3, padding=1) \

        flops += count_flops_conv(8, 8, num_units[11], num_units[12], 3,  stride=2,padding=1) \
                +count_flops_conv(8, 8,num_units[11], num_units[13], 1,  stride=2, padding=1) \
                + count_flops_conv(4, 4, num_units[12], num_units[13], 3, padding=1) \
                + count_flops_conv(4, 4, num_units[13], num_units[14], 3, padding=1) \
                + count_flops_conv(4, 4, num_units[14], num_units[15], 3, padding=1) \
                + count_flops_global_avg_pool(4,4,num_units[15])\
                +count_flops_conv_dbb(32, 32, num_units[0])\
                +count_flops_conv_dbb(32, 32, num_units[1])\
                +count_flops_conv_dbb(32, 32, num_units[2])\
                +count_flops_conv_dbb(32, 32, num_units[3])\
                +count_flops_conv_dbb(32, 32, num_units[4])\
                +count_flops_conv_dbb(16, 16, num_units[5])\
                +count_flops_conv_dbb(16, 16, num_units[6])\
                +count_flops_conv_dbb(16, 16, num_units[7])\
                +count_flops_conv_dbb(16, 16, num_units[8])\
                +count_flops_conv_dbb(8, 8, num_units[9])\
                +count_flops_conv_dbb(8, 8, num_units[10])\
                +count_flops_conv_dbb(8, 8, num_units[11])\
                +count_flops_conv_dbb(8, 8, num_units[12])\
                +count_flops_conv_dbb(4, 4, num_units[13])\
                +count_flops_conv_dbb(4, 4, num_units[14])\
                +count_flops_conv_dbb(4, 4, num_units[15])\


        flops += count_flops_dense(num_units[15], self.num_classes)


        return flops

    def count_memory(self, num_units):
        """Return the memory estimate for the given per-gated-layer widths.

        Sums count_memory_conv over the stem conv, the 16 gated convs and the
        three 1x1 shortcuts, plus count_memory_dense for the head (there is
        no pooling term here, unlike count_flops).  ``num_units`` is indexed
        like ``self.gated_layers``.
        """

        mem = count_memory_conv(32, 32, 3, 64, 3, padding=1) \

        mem += count_memory_conv(32, 32, 64, num_units[0], 3, padding=1) \
                + count_memory_conv(32, 32, num_units[0], num_units[1], 3, padding=1) \
                + count_memory_conv(32, 32, num_units[1], num_units[2], 3, padding=1) \
                + count_memory_conv(32, 32, num_units[2], num_units[3], 3, padding=1) \

        mem += count_memory_conv(32, 32, num_units[3], num_units[4], 3, stride=2, padding=1) \
                +count_memory_conv(32, 32, num_units[3], num_units[5], 1, stride=2, padding=1) \
                + count_memory_conv(16, 16, num_units[4], num_units[5], 3, padding=1) \
                + count_memory_conv(16, 16, num_units[5], num_units[6], 3, padding=1) \
                + count_memory_conv(16, 16, num_units[6], num_units[7], 3,  padding=1) \

        mem +=  count_memory_conv(16, 16, num_units[7], num_units[8], 3, stride=2, padding=1) \
                +count_memory_conv(16, 16,num_units[7] , num_units[9], 1,  stride=2, padding=1) \
                +count_memory_conv(8, 8, num_units[8], num_units[9], 3, padding=1) \
                + count_memory_conv(8, 8, num_units[9], num_units[10], 3, padding=1) \
                + count_memory_conv(8, 8, num_units[10], num_units[11], 3, padding=1) \

        mem += count_memory_conv(8, 8, num_units[11], num_units[12], 3,  stride=2,padding=1) \
                +count_memory_conv(8, 8,num_units[11], num_units[13], 1,  stride=2, padding=1) \
                + count_memory_conv(4, 4, num_units[12], num_units[13], 3, padding=1) \
                + count_memory_conv(4, 4, num_units[13], num_units[14], 3, padding=1) \
                + count_memory_conv(4, 4, num_units[14], num_units[15], 3, padding=1) \


        mem += count_memory_dense(num_units[15], self.num_classes)

        return mem

    def count_memory_dep(self, num_units, num_units_dep):
        """Return ``count_memory``-style memory plus a dbb term per gated layer.

        NOTE(review): the ``count_flops_conv_dbb`` calls below take a single
        argument, whereas count_flops_dep calls the same helper with three
        (h, w, channels), and a FLOP helper is being added to a memory total.
        One of the two call forms is likely wrong -- confirm against
        util.flops / util.memory before relying on this value.
        NOTE(review): ``num_units_dep`` is accepted but never read here.
        """

        mem = count_memory_conv(32, 32, 3, 64, 3, padding=1) \

        mem += count_memory_conv(32, 32, 64, num_units[0], 3, padding=1) \
                + count_memory_conv(32, 32, num_units[0], num_units[1], 3, padding=1) \
                + count_memory_conv(32, 32, num_units[1], num_units[2], 3, padding=1) \
                + count_memory_conv(32, 32, num_units[2], num_units[3], 3, padding=1) \

        mem += count_memory_conv(32, 32, num_units[3], num_units[4], 3, stride=2, padding=1) \
                +count_memory_conv(32, 32, num_units[3], num_units[5], 1, stride=2, padding=1) \
                + count_memory_conv(16, 16, num_units[4], num_units[5], 3, padding=1) \
                + count_memory_conv(16, 16, num_units[5], num_units[6], 3, padding=1) \
                + count_memory_conv(16, 16, num_units[6], num_units[7], 3,  padding=1) \

        mem +=  count_memory_conv(16, 16, num_units[7], num_units[8], 3, stride=2, padding=1) \
                +count_memory_conv(16, 16,num_units[7] , num_units[9], 1,  stride=2, padding=1) \
                +count_memory_conv(8, 8, num_units[8], num_units[9], 3, padding=1) \
                + count_memory_conv(8, 8, num_units[9], num_units[10], 3, padding=1) \
                + count_memory_conv(8, 8, num_units[10], num_units[11], 3, padding=1) \

        mem += count_memory_conv(8, 8, num_units[11], num_units[12], 3,  stride=2,padding=1) \
                +count_memory_conv(8, 8,num_units[11], num_units[13], 1,  stride=2, padding=1) \
                + count_memory_conv(4, 4, num_units[12], num_units[13], 3, padding=1) \
                + count_memory_conv(4, 4, num_units[13], num_units[14], 3, padding=1) \
                + count_memory_conv(4, 4, num_units[14], num_units[15], 3, padding=1) \
                +count_flops_conv_dbb(num_units[0])\
                +count_flops_conv_dbb(num_units[1])\
                +count_flops_conv_dbb(num_units[2])\
                +count_flops_conv_dbb(num_units[3])\
                +count_flops_conv_dbb(num_units[4])\
                +count_flops_conv_dbb(num_units[5])\
                +count_flops_conv_dbb(num_units[6])\
                +count_flops_conv_dbb(num_units[7])\
                +count_flops_conv_dbb(num_units[8])\
                +count_flops_conv_dbb(num_units[9])\
                +count_flops_conv_dbb(num_units[10])\
                +count_flops_conv_dbb(num_units[11])\
                +count_flops_conv_dbb(num_units[12])\
                +count_flops_conv_dbb(num_units[13])\
                +count_flops_conv_dbb(num_units[14])\
                +count_flops_conv_dbb(num_units[15])
        mem += count_memory_dense(num_units[15], self.num_classes)
        return mem

    def forward(self, x):
        """Return class logits for ``x``.

        The fixed 4x4 average pool reduces the final feature map to 1x1 only
        when it is 4x4, i.e. for 32x32 inputs; other sizes would not match
        the 512-input linear layer.
        """

        out = self.relu(self.bn0(self.conv0(x)))  # conv layer 1 (stem)
        out = self.layer1(out)   # conv layers 2-5 (2 basic blocks)
        out = self.layer2(out)   # conv layers 6-9 (2 basic blocks)
        out = self.layer3(out)   # conv layers 10-13 (2 basic blocks)
        out = self.layer4(out)  # conv layers 14-17 (2 basic blocks)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)   # layer 18 (classifier)
        return out

# ResNet-18 with a 3x3 stride-1 stem (no max-pool).
class Net(nn.Module):
    """ResNet-18 built from plain ``ResidualBlock`` s.

    Stem: 3x3 stride-1 conv (3 -> 64) + BN + in-place ReLU.  Then four stages
    of two blocks each (64, 128, 256, 512 channels; stages 2-4 start with a
    stride-2 block), a fixed 4x4 average pool and a linear classifier.  The
    512-input classifier therefore expects a 4x4 final map, i.e. 32x32 inputs.
    """

    def __init__(self,  num_classes=10):
        super(Net, self).__init__()

        block = ResidualBlock
        num_blocks = [2, 2, 2, 2]
        self.inp = 64
        self.conv0 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn0 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)  # in-place to save activation memory
        # The generator is fully consumed before the four attributes are
        # assigned left to right, so stages are built and registered in order.
        self.layer1, self.layer2, self.layer3, self.layer4 = (
            self._make_layer(block, width, depth, stride=first_stride)
            for width, depth, first_stride in zip((64, 128, 256, 512), num_blocks, (1, 2, 2, 2))
        )
        self.linear = nn.Linear(512, num_classes)

    def _make_layer(self, block, oup, num_block, stride=1):
        """Return an nn.Sequential stage of ``block`` instances.

        The first block gets ``stride``; the others get 1 (at least one block
        is always built).  ``self.inp`` ends up equal to ``oup``.
        """
        stride_plan = [stride] + [1] * (num_block - 1)
        stage = []
        for block_stride in stride_plan:
            stage.append(block(self.inp, oup, block_stride))
            self.inp = oup
        return nn.Sequential(*stage)

    def forward(self, x):
        out = self.relu(self.bn0(self.conv0(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        out = F.avg_pool2d(out, 4)
        return self.linear(out.view(out.size(0), -1))


class Net_I(nn.Module):
    """ResNet-18 with a 7x7 stride-2 stem, a 3x3 stride-2 max-pool and
    global (adaptive 1x1) average pooling before the classifier.

    Stages: two ``ResidualBlock`` s each at 64, 128, 256 and 512 channels;
    stages 2-4 start with a stride-2 block.
    """

    def __init__(self,  num_classes=10):
        super(Net_I, self).__init__()

        block = ResidualBlock
        num_blocks = [2, 2, 2, 2]
        self.inp = 64
        # stem
        self.conv0 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn0 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)  # in-place to save activation memory
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # residual stages, built and registered in order layer1..layer4
        self.layer1, self.layer2, self.layer3, self.layer4 = (
            self._make_layer(block, width, depth, stride=first_stride)
            for width, depth, first_stride in zip((64, 128, 256, 512), num_blocks, (1, 2, 2, 2))
        )
        # head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(512, num_classes)

    def _make_layer(self, block, oup, num_block, stride=1):
        """Return an nn.Sequential stage of ``block`` instances.

        The first block gets ``stride``; the others get 1 (at least one block
        is always built).  ``self.inp`` ends up equal to ``oup``.
        """
        stride_plan = [stride] + [1] * (num_block - 1)
        stage = []
        for block_stride in stride_plan:
            stage.append(block(self.inp, oup, block_stride))
            self.inp = oup
        return nn.Sequential(*stage)

    def forward(self, x):
        out = self.maxpool(self.relu(self.bn0(self.conv0(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        pooled = self.avgpool(out)
        return self.linear(pooled.view(pooled.size(0), -1))
class ResidualBlock(nn.Module):
    """Plain ResNet basic block: conv-BN-ReLU, conv-BN, add skip, ReLU.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the first 3x3 conv and of the projection shortcut.

    A 1x1 conv + BN projection shortcut is created whenever the input cannot
    be added to the block output unchanged: when ``stride != 1`` (spatial size
    changes) or ``inp != oup`` (channel count changes).  Otherwise the skip
    path is the identity.
    """
    def __init__(self, inp, oup, stride=1):
        super(ResidualBlock, self).__init__()
        self.stride=stride

        self.conv = nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(oup)
        self.relu = nn.ReLU()

        self.conv2 = nn.Conv2d(oup, oup, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(oup)
        # Previously only built for stride > 1, which made inp != oup with
        # stride 1 fail at the residual add with a shape mismatch.
        if stride != 1 or inp != oup:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inp, oup, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(oup))

    def forward(self, x):
        residual = x
        out = self.relu(self.bn(self.conv(x)))
        out = self.bn2(self.conv2(out))

        # Decide the skip path by whether a projection exists rather than by
        # stride alone; for every configuration the old test handled this
        # selects the same branch.
        shortcut = getattr(self, 'shortcut', None)
        if shortcut is None:
            out = out + residual
        else:
            out = out + shortcut(residual)

        out = self.relu(out)
        return out
class ResidualBlock_shortcut(nn.Module):
    """Basic block whose skip path is always a 1x1 conv + BN projection.

    Unlike ``ResidualBlock``, the projection is built unconditionally, so the
    residual is projected even when ``stride == 1`` and ``inp == oup``.

    Args:
        inp: input channel count.
        oup: output channel count.
        stride: stride of the first 3x3 conv and of the projection.
    """

    def __init__(self, inp, oup, stride=1):
        super(ResidualBlock_shortcut, self).__init__()
        self.stride = stride

        # main path: 3x3 (strided) -> BN -> ReLU -> 3x3 -> BN
        self.conv = nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(oup)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(oup, oup, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(oup)

        # skip path (created last, as before, to keep seeded init identical)
        projection = nn.Conv2d(inp, oup, kernel_size=1, stride=stride, bias=False)
        self.shortcut = nn.Sequential(projection, nn.BatchNorm2d(oup))

    def forward(self, x):
        main = self.bn2(self.conv2(self.relu(self.bn(self.conv(x)))))
        return self.relu(main + self.shortcut(x))

class ResidualBlock_set(nn.Module):
    """Basic block that carries an image batch ``x`` and a second tensor ``s``.

    Both tensors pass through the same ``GatedConv2d_s`` layers, which take
    and return an ``(x, s)`` pair.  Only ``x`` is batch-normalised; ``s`` gets
    ReLU only.  When ``stride != 1`` the skip path of ``s`` applies just the
    projection conv ``self.shortcut[0]`` (no BatchNorm).

    Args:
        inp: input channel count.
        oup: output channel count.
        stride: stride of the first conv and of the projection shortcut.
    """

    def __init__(self, inp, oup, stride=1):
        super(ResidualBlock_set, self).__init__()
        self.stride = stride

        self.conv = GatedConv2d_s(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(oup)
        self.relu = nn.ReLU()

        self.conv2 = GatedConv2d_s(oup, oup, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(oup)

        if self.stride != 1:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inp, oup, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(oup),
            )

    def forward(self, x, s):
        skip_x, skip_s = x, s

        x, s = self.conv(x, s)
        x = self.relu(self.bn(x))
        s = F.relu(s)

        x, s = self.conv2(x, s)
        x = self.bn2(x)

        if self.stride == 1:
            x = x + skip_x
            s = s + skip_s
        else:
            x = x + self.shortcut(skip_x)
            # Apply only the projection conv's parameters functionally: this
            # skips the BatchNorm and does not go through the module's
            # __call__ for the s branch.
            proj = self.shortcut[0]
            s = s + F.conv2d(input=skip_s, weight=proj.weight, bias=proj.bias,
                             stride=proj.stride, padding=proj.padding,
                             dilation=proj.dilation, groups=proj.groups)

        return self.relu(x), F.relu(s)
class Net_set(GatedNet):
    """ResNet-18 of ResidualBlock_set blocks that also propagates a tensor ``s``.

    ``self.gated_layers`` collects every GatedConv2d_s found by
    ``self.modules()`` in registration order; the ``num_units`` lists of the
    ``count_*`` methods are indexed in that order.  ``set_apply()`` must be
    called before ``forward``, which uses ``self.set_func``.  The cost
    formulas and ``forward``'s reshape / 4x4 pooling assume 32x32 inputs.
    """
    def __init__(self, num_classes=10):
        super(Net_set, self).__init__()
        block=ResidualBlock_set
        num_blocks=[2,2,2,2]
        self.num_classes=num_classes
        self.inp = 64
        self.conv0 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn0 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)  # in-place ReLU: overwrites its input to save memory
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512, num_classes)

        # Collect gated convs in registration order (defines num_units indexing).
        for layer in self.modules():
            if isinstance(layer,GatedConv2d_s):
                self.gated_layers.append(layer)
        # Reference costs of the unpruned network.
        self.full_flops = self.count_flops([
            64, 64, 64, 64,
            128, 128, 128, 128,
            256, 256, 256, 256,
            512, 512, 512, 512])
        self.full_mem = self.count_memory([
             64, 64, 64, 64,
            128, 128, 128, 128,
            256, 256, 256, 256,
            512, 512, 512, 512])


    def _make_layer(self, block, oup, num_block, stride=1):
        """Build one stage of blocks as an nn.ModuleList.

        A ModuleList (not nn.Sequential) is returned because each
        ResidualBlock_set takes and returns two tensors, so ``forward`` loops
        over the blocks itself.  Only the first block uses ``stride``;
        ``self.inp`` is advanced to ``oup``.
        """
        layers = []
        # Only the first block of a stage is strided (stride 2 halves H and W).
        strides = [stride] + [1]*(num_block-1)
        for stride in strides:
            layers.append(block(self.inp, oup, stride))
            self.inp = oup
        return nn.ModuleList(layers)


    def count_flops(self, num_units):
            """Return the FLOP count for the given per-gated-layer widths.

            ``num_units[i]`` is the (possibly pruned) output width of gated
            layer ``i``.  Counts the stem conv, the 16 gated convs, the three
            1x1 projection shortcuts, global average pooling and the dense
            head.  Work done on the ``s`` branch is not counted separately.

            NOTE(review): the shortcut terms pass ``padding=1`` although the
            1x1 shortcut convs are built without padding; check whether
            count_flops_conv's result depends on padding.
            """
            # num_units: candidate (pruned) width of each gated layer.
            # Per the original note, gate/mask-generation cost is not included.
            flops = count_flops_conv(32, 32, 3, 64, 3, padding=1) \

            flops += count_flops_conv(32, 32, 64, num_units[0], 3, padding=1) \
                    + count_flops_conv(32, 32, num_units[0], num_units[1], 3, padding=1) \
                    + count_flops_conv(32, 32, num_units[1], num_units[2], 3, padding=1) \
                    + count_flops_conv(32, 32, num_units[2], num_units[3], 3, padding=1) \

            flops += count_flops_conv(32, 32, num_units[3], num_units[4], 3, stride=2, padding=1) \
                    +count_flops_conv(32, 32, num_units[3], num_units[5], 1, stride=2, padding=1) \
                    + count_flops_conv(16, 16, num_units[4], num_units[5], 3, padding=1) \
                    + count_flops_conv(16, 16, num_units[5], num_units[6], 3, padding=1) \
                    + count_flops_conv(16, 16, num_units[6], num_units[7], 3,  padding=1) \

            flops +=  count_flops_conv(16, 16, num_units[7], num_units[8], 3, stride=2, padding=1) \
                    +count_flops_conv(16, 16,num_units[7] , num_units[9], 1,  stride=2, padding=1) \
                    +count_flops_conv(8, 8, num_units[8], num_units[9], 3, padding=1) \
                    + count_flops_conv(8, 8, num_units[9], num_units[10], 3, padding=1) \
                    + count_flops_conv(8, 8, num_units[10], num_units[11], 3, padding=1) \

            flops += count_flops_conv(8, 8, num_units[11], num_units[12], 3,  stride=2,padding=1) \
                    +count_flops_conv(8, 8,num_units[11], num_units[13], 1,  stride=2, padding=1) \
                    + count_flops_conv(4, 4, num_units[12], num_units[13], 3, padding=1) \
                    + count_flops_conv(4, 4, num_units[13], num_units[14], 3, padding=1) \
                    + count_flops_conv(4, 4, num_units[14], num_units[15], 3, padding=1) \
                    + count_flops_global_avg_pool(4,4,num_units[15])

            flops += count_flops_dense(num_units[15], self.num_classes)
            return flops

    def count_flops_dep(self, num_units, num_units_dep):
        """Return ``count_flops(num_units_dep)``; ``num_units`` is ignored."""

        return self.count_flops(num_units_dep)

    def count_memory(self, num_units):
        """Return the memory estimate for the given per-gated-layer widths.

        Sums count_memory_conv over the stem conv, the 16 gated convs and the
        three 1x1 shortcuts, plus count_memory_dense for the head.
        """

        mem = count_memory_conv(32, 32, 3, 64, 3, padding=1) \

        mem += count_memory_conv(32, 32, 64, num_units[0], 3, padding=1) \
                + count_memory_conv(32, 32, num_units[0], num_units[1], 3, padding=1) \
                + count_memory_conv(32, 32, num_units[1], num_units[2], 3, padding=1) \
                + count_memory_conv(32, 32, num_units[2], num_units[3], 3, padding=1) \

        mem += count_memory_conv(32, 32, num_units[3], num_units[4], 3, stride=2, padding=1) \
                +count_memory_conv(32, 32, num_units[3], num_units[5], 1, stride=2, padding=1) \
                + count_memory_conv(16, 16, num_units[4], num_units[5], 3, padding=1) \
                + count_memory_conv(16, 16, num_units[5], num_units[6], 3, padding=1) \
                + count_memory_conv(16, 16, num_units[6], num_units[7], 3,  padding=1) \

        mem +=  count_memory_conv(16, 16, num_units[7], num_units[8], 3, stride=2, padding=1) \
                +count_memory_conv(16, 16,num_units[7] , num_units[9], 1,  stride=2, padding=1) \
                +count_memory_conv(8, 8, num_units[8], num_units[9], 3, padding=1) \
                + count_memory_conv(8, 8, num_units[9], num_units[10], 3, padding=1) \
                + count_memory_conv(8, 8, num_units[10], num_units[11], 3, padding=1) \

        mem += count_memory_conv(8, 8, num_units[11], num_units[12], 3,  stride=2,padding=1) \
                +count_memory_conv(8, 8,num_units[11], num_units[13], 1,  stride=2, padding=1) \
                + count_memory_conv(4, 4, num_units[12], num_units[13], 3, padding=1) \
                + count_memory_conv(4, 4, num_units[13], num_units[14], 3, padding=1) \
                + count_memory_conv(4, 4, num_units[14], num_units[15], 3, padding=1) \


        mem += count_memory_dense(num_units[15], self.num_classes)

        return mem

    def count_memory_dep(self, num_units, num_units_dep):
        """Return ``count_memory(num_units_dep)``; ``num_units`` is ignored."""


        return self.count_memory(num_units_dep)
    def set_apply(self):
        """Attach ``self.set_func`` = SetTransformer(3072, 1, 3072) and return True.

        Must be called before ``forward``, which uses ``self.set_func``.
        """
        self.set_func=SetTransformer(3072,1,3072)
        return True
    def forward(self, x,s):
        """Return class logits for ``x``.

        ``s`` is first mapped by ``self.set_func`` and reshaped with
        ``.view(1, 3, 32, 32)``, which requires the set_func output to hold
        exactly 3*32*32 = 3072 elements.  It is then pushed through the stem
        conv weights (functionally, no BatchNorm) and alongside ``x`` through
        every block; the final ``s`` is not used in the returned value.
        """

        s=self.set_func(s).view(1,3,32,32)
        out = self.relu(self.bn0(self.conv0(x)))  # conv layer 1 (stem)
        # Reuse the stem conv's parameters for s without BN.
        s=F.conv2d(input=s,weight=self.conv0.weight,bias=self.conv0.bias,stride=self.conv0.stride,padding=self.conv0.padding,dilation=self.conv0.dilation,groups=self.conv0.groups)
        s=F.relu(s)

        for b in self.layer1:
            out,s = b(out,s)   # conv layers 2-5 (2 basic blocks)
        for b in self.layer2:
            out,s = b(out,s)   # conv layers 6-9
        for b in self.layer3:
            out,s = b(out,s)   # conv layers 10-13
        for b in self.layer4:
            out,s = b(out,s)   # conv layers 14-17
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)   # layer 18 (classifier)
        return out
class Net_set_I(GatedNet):
    """Gated, set-conditioned ResNet-18-style network for ImageNet-sized input.

    Stem: 7x7 stride-2 convolution, BN, ReLU and a 3x3 stride-2 max-pool.
    Body: four stages of two ``ResidualBlock_set`` blocks each, with widths
    64/128/256/512. Head: adaptive average pool and a linear classifier.

    ``forward(x, s)`` also encodes a set ``s`` with ``self.set_func`` (created
    by ``set_apply``). It views the encoding as one image and threads it
    through the shared stem and every residual block alongside ``x``.
    """

    # Side length of the square input image. It drives the FLOP / memory
    # estimates, the width of the set encoder and the reshape in ``forward``.
    input_size = 224

    def __init__(self, num_classes=10):
        super(Net_set_I, self).__init__()
        block = ResidualBlock_set
        num_blocks = [2, 2, 2, 2]
        self.num_classes = num_classes
        self.inp = 64  # running input width, advanced by _make_layer

        # ImageNet-style stem: 224 -> 112 (conv0) -> 56 (maxpool).
        self.conv0 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn0 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)  # compute on input (mem save)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.linear = nn.Linear(512, num_classes)

        for layer in self.modules():
            if isinstance(layer, GatedConv2d_s):
                self.gated_layers.append(layer)

        # Unpruned width of each gated conv, in the order count_flops expects.
        full_units = [64] * 4 + [128] * 4 + [256] * 4 + [512] * 4
        self.full_flops = self.count_flops(full_units)
        self.full_mem = self.count_memory(full_units)

    def _make_layer(self, block, oup, num_block, stride=1):
        """Build one stage of ``num_block`` blocks of output width ``oup``.

        Only the first block uses ``stride`` (spatial downsampling); the rest
        use stride 1. ``self.inp`` is advanced so the next stage starts from
        this stage's width.
        """
        layers = []
        for block_stride in [stride] + [1] * (num_block - 1):
            layers.append(block(self.inp, oup, block_stride))
            self.inp = oup
        return nn.ModuleList(layers)

    @staticmethod
    def _conv_out(size, kernel, stride, padding):
        """Output side length of a conv/pool window (dilation 1)."""
        return (size + 2 * padding - kernel) // stride + 1

    def _stage_sizes(self):
        """Return the spatial side length each residual stage runs at.

        With the default 224 input, conv0 gives 112 and the max-pool gives 56.
        Each later stage halves the size, giving [56, 28, 14, 7].
        """
        size = self._conv_out(self.input_size, 7, 2, 3)  # conv0
        size = self._conv_out(size, 3, 2, 1)             # maxpool
        sizes = [size]
        for _ in range(3):
            # First block of stages 2-4 downsamples with a 3x3 / stride 2 / pad 1 conv.
            size = self._conv_out(size, 3, 2, 1)
            sizes.append(size)
        return sizes

    def _conv_specs(self, num_units):
        """List counted convs as (h, w, in_ch, out_ch, kernel, stride, padding).

        ``h``/``w`` are the conv's *input* spatial size. ``num_units`` holds 16
        widths, and indices 4k..4k+3 belong to stage k+1. Stages 2-4 also
        include a 1x1 stride-2 shortcut from the previous stage's last width
        to the first block's output width.
        """
        s1, s2, s3, s4 = self._stage_sizes()
        specs = [(self.input_size, self.input_size, 3, 64, 7, 2, 3)]  # conv0

        # Stage 1: four 3x3 stride-1 convs, identity shortcuts.
        prev = 64
        for i in range(4):
            specs.append((s1, s1, prev, num_units[i], 3, 1, 1))
            prev = num_units[i]

        # Stages 2-4: stride-2 entry conv + 1x1 shortcut, then three 3x3 convs.
        for k, (size_in, size_out) in enumerate(((s1, s2), (s2, s3), (s3, s4)), start=1):
            base = 4 * k
            specs.append((size_in, size_in, num_units[base - 1], num_units[base], 3, 2, 1))
            # NOTE(review): padding=1 on the 1x1 shortcut is kept from the
            # previous estimate; confirm against ResidualBlock_set's downsample.
            specs.append((size_in, size_in, num_units[base - 1], num_units[base + 1], 1, 2, 1))
            for i in range(base, base + 3):
                specs.append((size_out, size_out, num_units[i], num_units[i + 1], 3, 1, 1))
        return specs

    def count_flops(self, num_units):
        """Estimate forward FLOPs for the gated-conv widths ``num_units``.

        Sums every conv from ``_conv_specs``, the global average pool over the
        last stage, and the final dense layer. Max-pool, BN and ReLU are not
        counted.
        """
        flops = 0
        for h, w, cin, cout, k, stride, padding in self._conv_specs(num_units):
            flops += count_flops_conv(h, w, cin, cout, k, stride=stride, padding=padding)
        last = self._stage_sizes()[-1]
        flops += count_flops_global_avg_pool(last, last, num_units[15])
        flops += count_flops_dense(num_units[15], self.num_classes)
        return flops

    def count_flops_dep(self, num_units, num_units_dep):
        """FLOPs for the dependency-gated widths; ``num_units`` is unused."""
        return self.count_flops(num_units_dep)

    def count_memory(self, num_units):
        """Estimate memory for the gated-conv widths ``num_units``.

        Sums every conv from ``_conv_specs`` and the final dense layer.
        """
        mem = 0
        for h, w, cin, cout, k, stride, padding in self._conv_specs(num_units):
            mem += count_memory_conv(h, w, cin, cout, k, stride=stride, padding=padding)
        mem += count_memory_dense(num_units[15], self.num_classes)
        return mem

    def count_memory_dep(self, num_units, num_units_dep):
        """Memory for the dependency-gated widths; ``num_units`` is unused."""
        return self.count_memory(num_units_dep)

    def set_apply(self):
        """Create the set encoder ``self.set_func`` and return True.

        Its width is one flattened 3 x input_size x input_size image (150528
        for the default 224), matching the reshape in ``forward``.
        """
        flat = 3 * self.input_size * self.input_size
        self.set_func = SetTransformer(flat, 1, flat)
        return True

    def forward(self, x, s):
        """Classify ``x`` while threading the encoded set ``s`` through the blocks."""
        s = self.set_func(s).view(1, 3, self.input_size, self.input_size)
        out = self.relu(self.bn0(self.conv0(x)))  # layer1
        out = self.maxpool(out)
        # The set path reuses conv0's parameters. NOTE(review): unlike `out`,
        # it is not batch-normalised; confirm that is intended.
        s = F.conv2d(
            input=s,
            weight=self.conv0.weight,
            bias=self.conv0.bias,
            stride=self.conv0.stride,
            padding=self.conv0.padding,
            dilation=self.conv0.dilation,
            groups=self.conv0.groups,
        )
        s = F.relu(s)
        s = self.maxpool(s)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            for b in stage:
                out, s = b(out, s)
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        out = self.linear(out)  # layer 18
        return out
