"""dense net in pytorch



[1] Gao Huang, Zhuang Liu, Laurens van der Maaten, Kilian Q. Weinberger.

    Densely Connected Convolutional Networks
    https://arxiv.org/abs/1608.06993v5
"""
import torch
import torch.nn as nn
from adjointNetwork import conv2dFirstLayer, conv2dAdjoint, batchNorm, linear   
import gc
#from torch.utils.checkpoint.checkpoint

#"""Bottleneck layers. Although each layer only produces k
#output feature-maps, it typically has many more inputs. It
#has been noted in [37, 11] that a 1×1 convolution can be in-
#troduced as bottleneck layer before each 3×3 convolution
#to reduce the number of input feature-maps, and thus to
#improve computational efficiency."""
class Bottleneck(nn.Module):
    """DenseNet-B bottleneck layer built from adjoint layers.

    Computes ``y = Conv3x3(ReLU(BN(Conv1x1(ReLU(BN(x))))))``, where the 1x1
    convolution produces ``4 * growth_rate`` channels and the 3x3 convolution
    produces ``growth_rate`` channels. The bottleneck is run under activation
    checkpointing. ``x`` and ``y`` are then concatenated along dim 1.

    ``forward`` returns a ``(first_half, second_half)`` pair split along dim 0
    at ``batch // 2``. For masked layers with ``index > 0``, the channels of the
    second half are reordered (see ``forward``).

    Args:
        in_channels: channels of the input ``x``.
        growth_rate: output channels of the 3x3 convolution (``k`` in the paper).
        mask_layer: passed to the adjoint layers; when False (or ``index == 0``)
            no channel reordering is done.
        compression_factor: passed to the adjoint convolutions and used to size
            the reordered channel slices.
        index: position of this layer inside its dense block.
    """

    def __init__(self, in_channels, growth_rate, mask_layer, compression_factor, index):
        super().__init__()
        #"""In  our experiments, we let each 1×1 convolution
        #produce 4k feature-maps."""
        inner_channel = 4 * growth_rate
        self.in_channels = in_channels
        self.growth_rate = growth_rate
        self.compression_factor = compression_factor
        self.index = index
        self.mask_layer = mask_layer
        #"""We find this design especially effective for DenseNet and
        #we refer to our network with such a bottleneck layer, i.e.,
        #to the BN-ReLU-Conv(1×1)-BN-ReLU-Conv(3×3) version of H ` ,
        #as DenseNet-B."""
        self.bottle_neck = nn.Sequential(
            batchNorm(in_channels,mask_layer),
            nn.ReLU(inplace=True),
            conv2dAdjoint(in_channels, inner_channel, kernel_size=1, stride=1, padding=0, bias=False,mask_layer=mask_layer,compression_factor=compression_factor),
            batchNorm(inner_channel, mask_layer),
            nn.ReLU(inplace=True),
            conv2dAdjoint(inner_channel, growth_rate, kernel_size=3, stride=1, padding=1, bias=False,mask_layer=mask_layer,compression_factor=compression_factor)
        )

    def forward(self, x):
        """Return ``(z1, z2)``: the two batch halves of ``cat([x, y], dim=1)``.

        For masked layers with ``index > 0`` the second half is laid out as
        ``[x_head, y[:, :c // cf], x_tail, y[:, c // cf:]]``. Here ``x_tail`` is
        the last ``ceil(growth_rate * index / cf)`` channels of ``x`` and ``c``
        is the channel count of ``y``.
        NOTE(review): the intent of this permutation depends on how
        adjointNetwork masks channels; confirm there.
        """
        # Recompute the bottleneck during backward instead of keeping its activations.
        y = torch.utils.checkpoint.checkpoint(self.bottle_neck, x)
        batch, c, _, _ = y.shape
        half = batch // 2

        if not self.mask_layer or self.index == 0:
            z = torch.cat([x, y], 1)
            return z[:half], z[half:]

        # Unary minus binds tighter than //, so this is (-n) // cf == -ceil(n / cf).
        # For a positive n it is always <= -1, so the slices below never degenerate
        # into the x[:, -0:] ("whole tensor") pitfall.
        x_split = -(self.growth_rate * self.index) // self.compression_factor
        y_split = c // self.compression_factor

        z1 = torch.cat([x[:half], y[:half]], 1)
        z2 = torch.cat([x[half:, :x_split],
                        y[half:, :y_split],
                        x[half:, x_split:],
                        y[half:, y_split:]], 1)
        return z1, z2

class Merge(nn.Module):
    """Re-join a ``(first_half, second_half)`` pair along the batch dimension.

    Placed after every Bottleneck, which returns its output split into two
    batch halves.
    """

    def forward(self, x):
        """Return ``torch.cat([x[0], x[1]], 0)``."""
        # Release unoccupied cached CUDA memory between layers (the author's
        # memory-pressure workaround). This does nothing when CUDA is not initialised.
        torch.cuda.empty_cache()
        # Deliberately NOT checkpointed. cat saves no activations for backward, so
        # checkpointing it saves no memory. Passing the tensors inside a list also
        # hid them from the (default, reentrant) checkpoint. Its output then did
        # not require grad and the backward graph was cut here.
        return torch.cat([x[0], x[1]], 0)

#"""We refer to layers between blocks as transition
#layers, which do convolution and pooling."""
class Transition(nn.Module):
    """Transition layer placed between two dense blocks.

    Applies batch norm, then a 1x1 adjoint convolution mapping ``in_channels``
    to ``out_channels``, then a 2x2 average pool with stride 2. ``mask_layer``
    is given to the batch norm and the convolution; ``compression_factor`` is
    given to the convolution.
    """

    def __init__(self, in_channels, out_channels, mask_layer, compression_factor):
        super().__init__()
        # Paper: "The transition layers used in our experiments consist of a batch
        # normalization layer and an 1x1 convolutional layer followed by a 2x2
        # average pooling layer."
        norm = batchNorm(in_channels, mask_layer)
        project = conv2dAdjoint(in_channels, out_channels, kernel_size=1, stride=1, padding=0,
                                bias=False, mask_layer=mask_layer,
                                compression_factor=compression_factor)
        pool = nn.AvgPool2d(2, stride=2)
        self.down_sample = nn.Sequential(norm, project, pool)

    def forward(self, x):
        """Run the down-sampling path under activation checkpointing."""
        path = self.down_sample
        return torch.utils.checkpoint.checkpoint(path, x)

#DenseNet-BC
#B stands for bottleneck layer (BN-ReLU-Conv(1x1)-BN-ReLU-Conv(3x3))
#C stands for compression factor (0 < theta <= 1)
class DenseNet(nn.Module):
    """DenseNet-BC assembled from adjoint (masked / compressed) layers.

    Args:
        block: bottleneck layer class, called as
            ``block(in_channels, growth_rate, mask_layer, compression_factor, index)``.
        nblocks: number of bottleneck layers in each dense block (one entry per block).
        growth_rate: channels each bottleneck layer adds (``k`` in the paper).
        reduction: transition compression ``theta``; a transition that receives
            ``m`` channels emits ``int(reduction * m)`` channels.
        num_class: output width of the final ``linear`` layer.
        compression: ``compression_factor`` used by the adjoint layers of every
            dense block after the first.

    The first dense block and its transition are built with ``mask_layer=False``
    and ``compression_factor=1``. Later blocks use ``mask_layer=True`` and
    ``compression``. The final dense block and the closing batch norm reuse the
    settings of the block before them. For a single-block ``nblocks`` they use
    the first-block settings.
    """

    def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_class=1000, compression=4):
        super().__init__()
        self.growth_rate = growth_rate

        #"""Before entering the first dense block, a convolution
        #with 16 (or twice the growth rate for DenseNet-BC)
        #output channels is performed on the input images."""
        inner_channels = 2 * growth_rate

        # 3x3 stem convolution, zero-padded by one pixel to keep the spatial size.
        self.conv1 = conv2dFirstLayer(3, inner_channels, kernel_size=3, stride=1, padding=1, bias=False)

        self.features = nn.Sequential()

        # Bind the first-block settings up front. When nblocks has a single entry,
        # the loop below never runs, and the final block / batch norm would
        # otherwise hit an unbound name.
        mask_layer, compression_factor = False, 1
        for index in range(len(nblocks) - 1):
            if index == 0:
                mask_layer, compression_factor = False, 1
            else:
                mask_layer, compression_factor = True, compression
            self.features.add_module("dense_block_layer_{}".format(index), self._make_dense_layers(block, inner_channels, nblocks[index], mask_layer, compression_factor))
            inner_channels += growth_rate * nblocks[index]

            #"""If a dense block contains m feature-maps, we let the
            #following transition layer generate θm output feature-
            #maps, where 0 < θ ≤ 1 is referred to as the compression
            #factor."""
            out_channels = int(reduction * inner_channels)  # int() floors the value
            self.features.add_module("transition_layer_{}".format(index), Transition(inner_channels, out_channels, mask_layer, compression_factor))
            inner_channels = out_channels

        last = len(nblocks) - 1
        self.features.add_module("dense_block{}".format(last), self._make_dense_layers(block, inner_channels, nblocks[last], mask_layer, compression_factor))
        inner_channels += growth_rate * nblocks[last]
        self.features.add_module('bn', batchNorm(inner_channels, mask_layer))
        self.features.add_module('relu', nn.ReLU(inplace=True))

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.linear = linear(inner_channels, num_class)

    def forward(self, x):
        """Stem conv -> checkpointed feature stack -> global avg pool -> flatten -> linear."""
        output = self.conv1(x)
        output = torch.utils.checkpoint.checkpoint(self.features, output)
        output = self.avgpool(output)
        output = output.view(output.size()[0], -1)
        output = self.linear(output)
        return output

    def _make_dense_layers(self, block, in_channels, nblocks, mask_layer, compression_factor):
        """Build one dense block made of ``nblocks`` (bottleneck, Merge) pairs.

        Layer ``i`` receives ``in_channels + i * self.growth_rate`` channels,
        because each bottleneck appends ``self.growth_rate`` channels.
        """
        dense_block = nn.Sequential()
        for index in range(nblocks):
            dense_block.add_module('bottle_neck_layer_{}'.format(index), block(in_channels, self.growth_rate, mask_layer, compression_factor, index))
            # The bottleneck returns its output as two batch halves; Merge
            # concatenates them back along dim 0.
            dense_block.add_module('merge_layer_{}'.format(index), Merge())
            in_channels += self.growth_rate
        return dense_block

def densenet121(compression, num_class):
    """DenseNet-121: growth rate 32, dense blocks of 6/12/24/16 bottleneck layers."""
    layers_per_block = [6, 12, 24, 16]
    return DenseNet(
        Bottleneck,
        layers_per_block,
        growth_rate=32,
        compression=compression,
        num_class=num_class,
    )

def densenet169(compression=4, num_class=1000):
    """DenseNet-169: growth rate 32, dense blocks of 6/12/32/32 bottleneck layers.

    ``compression`` and ``num_class`` are optional, for consistency with
    ``densenet121``. Their defaults equal DenseNet's own defaults, so
    ``densenet169()`` builds the same model as before.
    """
    return DenseNet(Bottleneck, [6,12,32,32], growth_rate=32, compression=compression, num_class=num_class)

def densenet201(compression=4, num_class=1000):
    """DenseNet-201: growth rate 32, dense blocks of 6/12/48/32 bottleneck layers.

    ``compression`` and ``num_class`` are optional, for consistency with
    ``densenet121``. Their defaults equal DenseNet's own defaults, so
    ``densenet201()`` builds the same model as before.
    """
    return DenseNet(Bottleneck, [6,12,48,32], growth_rate=32, compression=compression, num_class=num_class)

def densenet161(compression=4, num_class=1000):
    """DenseNet-161: growth rate 48, dense blocks of 6/12/36/24 bottleneck layers.

    ``compression`` and ``num_class`` are optional, for consistency with
    ``densenet121``. Their defaults equal DenseNet's own defaults, so
    ``densenet161()`` builds the same model as before.
    """
    return DenseNet(Bottleneck, [6,12,36,24], growth_rate=48, compression=compression, num_class=num_class)

def densenet12(compression, num_class):
    """DenseNet with growth rate 12 and three dense blocks of 16 bottleneck layers.

    ``compression`` and ``num_class`` are forwarded to ``DenseNet``. They were
    previously ignored, which always produced a 1000-way classifier with
    compression 4.
    """
    return DenseNet(Bottleneck, [16, 16, 16], growth_rate=12, compression=compression, num_class=num_class)

def densenet40(compression, num_class):
    """DenseNet with growth rate 40 and three dense blocks of 31 bottleneck layers.

    ``compression`` and ``num_class`` are forwarded to ``DenseNet``. They were
    previously ignored, which always produced a 1000-way classifier with
    compression 4.
    """
    return DenseNet(Bottleneck, [31, 31, 31], growth_rate=40, compression=compression, num_class=num_class)

def densenet24(compression, num_class):
    """DenseNet with growth rate 24 and three dense blocks of 41 bottleneck layers.

    ``compression`` and ``num_class`` are forwarded to ``DenseNet``. They were
    previously ignored, which always produced a 1000-way classifier with
    compression 4.
    """
    return DenseNet(Bottleneck, [41, 41, 41], growth_rate=24, compression=compression, num_class=num_class)

