
"""
Pre-activation ResNet, and noisy ensembles of it, for CIFAR in PyTorch.
Reference:
[1] K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In CVPR, 2016.
[2] K. He, X. Zhang, S. Ren, and J. Sun. Identity mappings in deep residual networks. In ECCV, 2016.
"""
import torch
import torch.nn as nn
import math
import torch.nn.functional as fcnal




from torch.autograd import Variable

def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with one pixel of zero padding.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (defaults to 1).

    Returns:
        The newly constructed ``nn.Conv2d`` layer.
    """
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer


class PreActBasicBlock(nn.Module):
    """Pre-activation basic residual block with optional additive Gaussian noise.

    The main path is BN -> ReLU -> 3x3 conv -> BN -> ReLU -> 3x3 conv, and its
    result is summed with a shortcut.

    Args:
        inplanes: number of input channels.
        planes: number of output channels of both convolutions.
        stride: stride of the first convolution.
        downsample: optional module producing the shortcut; when given it is
            applied to the pre-activated input (after ``bn1`` and ``relu``).
        noise_coef: when not ``None``, ``noise_coef * N(0, 1)`` noise is added
            to the block output.
    """

    # Output-channel multiplier relative to ``planes``.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, noise_coef=None):
        super().__init__()
        self.stride = stride
        self.noise_coef = noise_coef

        # Sub-modules are registered in a fixed order: bn1, relu, conv1, bn2,
        # conv2, downsample.
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.downsample = downsample

    def forward(self, x):
        """Apply the residual block to ``x`` and return the (possibly noisy) sum."""
        # The ReLU is in-place, so ``pre`` is the activated output of bn1.
        pre = self.relu(self.bn1(x))

        # Identity shortcut by default; otherwise project the pre-activated input.
        shortcut = x if self.downsample is None else self.downsample(pre)

        out = self.conv1(pre)
        out = self.conv2(self.relu(self.bn2(out)))
        out += shortcut  # skip connection

        if self.noise_coef is None:
            return out
        # The noise has a fixed scale and is added on every forward pass; there
        # is no check of the module's training flag.
        return out + self.noise_coef * torch.randn_like(out)




class PreAct_ResNet_Cifar(nn.Module):
    """Three-stage pre-activation ResNet classifier.

    Layout: 3x3 stem convolution (3 -> 16 channels), three stages of residual
    blocks with 16, 32 and 64 base channels (the last two with stride 2), a
    final BN + ReLU, global average pooling and a linear classifier.

    Args:
        block: residual block class (or factory). It must expose an
            ``expansion`` attribute and accept
            ``(inplanes, planes, stride=..., downsample=..., noise_coef=...)``.
        layers: sequence of three ints, the number of blocks in each stage.
        num_classes: number of outputs of the final linear layer.
        noise_coef: forwarded unchanged to every block constructor.
    """

    def __init__(self, block, layers, num_classes=10, noise_coef=None):
        super().__init__()
        # Channel count fed to the next block; updated by ``_make_layer``.
        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer1 = self._make_layer(block, 16, layers[0], noise_coef=noise_coef)
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2, noise_coef=noise_coef)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2, noise_coef=noise_coef)
        self.bn = nn.BatchNorm2d(64 * block.expansion)
        self.relu = nn.ReLU(inplace=True)
        # Global average pooling. A fixed 8x8 window only fits 32x32 inputs;
        # the adaptive variant pools the same 8x8 map in that case and also
        # accepts other input resolutions.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        # Re-initialise every convolution with N(0, sqrt(2 / fan_out)) weights
        # and every batch-norm layer with unit scale and zero shift.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, 0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def _make_layer(self, block, planes, blocks, stride=1, noise_coef=None):
        """Build one stage of residual blocks as an ``nn.Sequential``.

        Args:
            block: residual block class (or factory).
            planes: base channel count of the stage.
            blocks: number of blocks in the stage; the first block is always
                created, so values below 1 still yield one block.
            stride: stride of the first block; the remaining blocks use the
                block's default stride.
            noise_coef: forwarded unchanged to every block.

        Returns:
            ``nn.Sequential`` holding the stage's blocks.
        """
        out_planes = planes * block.expansion

        # A 1x1 projection shortcut is needed whenever the first block changes
        # the spatial resolution or the channel count.
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes, kernel_size=1, stride=stride, bias=False)
            )

        layers = [block(self.inplanes, planes, stride=stride, downsample=downsample, noise_coef=noise_coef)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, noise_coef=noise_coef))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for images ``x``."""
        x = self.conv1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.bn(x)
        x = self.relu(x)
        x = self.avgpool(x)
        # Collapse (batch, channels, 1, 1) to (batch, channels) for the classifier.
        x = torch.flatten(x, 1)

        return self.fc(x)


class Ensemble_PreAct_ResNet_Cifar(nn.Module):
    """Average the outputs of several separately constructed ``PreAct_ResNet_Cifar`` nets.

    Args:
        block: residual block class handed to every member network.
        layers: per-stage block counts handed to every member network.
        num_classes: number of outputs of each member network.
        num_ensembles: number of member networks, also used as the divisor
            when averaging.
        noise_coef: forwarded to every member network (defaults to 0.0).
    """

    def __init__(self, block, layers, num_classes=10, num_ensembles=3, noise_coef=0.0):
        super().__init__()
        self.num_ensembles = num_ensembles
        # For the ensemble ResNet the members should use noisy blocks, which is
        # why ``noise_coef`` is passed down to each of them.
        members = []
        for _ in range(num_ensembles):
            members.append(
                PreAct_ResNet_Cifar(block, layers, num_classes=num_classes, noise_coef=noise_coef)
            )
        self.ensemble = nn.ModuleList(members)

    def forward(self, x):
        """Feed ``x`` to every member in order and return the mean of their outputs."""
        total = 0.0
        for member in self.ensemble:
            total = total + member(x)
        return total / self.num_ensembles



def en_preactresnet8_cifar(**kwargs):
    """Build the ensemble with one ``PreActBasicBlock`` per stage.

    Keyword arguments are passed unchanged to ``Ensemble_PreAct_ResNet_Cifar``.
    """
    blocks_per_stage = [1, 1, 1]
    return Ensemble_PreAct_ResNet_Cifar(PreActBasicBlock, blocks_per_stage, **kwargs)

def en_preactresnet20_cifar(**kwargs):
    """Build the ensemble with three ``PreActBasicBlock`` per stage.

    Keyword arguments are passed unchanged to ``Ensemble_PreAct_ResNet_Cifar``.
    """
    blocks_per_stage = [3, 3, 3]
    return Ensemble_PreAct_ResNet_Cifar(PreActBasicBlock, blocks_per_stage, **kwargs)

def en_preactresnet44_cifar(**kwargs):
    """Build the ensemble with seven ``PreActBasicBlock`` per stage.

    Keyword arguments are passed unchanged to ``Ensemble_PreAct_ResNet_Cifar``.
    """
    blocks_per_stage = [7, 7, 7]
    return Ensemble_PreAct_ResNet_Cifar(PreActBasicBlock, blocks_per_stage, **kwargs)

def en_preactresnet32_cifar(**kwargs):
    """Build the ensemble with five ``PreActBasicBlock`` per stage.

    Keyword arguments are passed unchanged to ``Ensemble_PreAct_ResNet_Cifar``.
    """
    blocks_per_stage = [5, 5, 5]
    return Ensemble_PreAct_ResNet_Cifar(PreActBasicBlock, blocks_per_stage, **kwargs)
def en_preactresnet56_cifar(**kwargs):
    """Build the ensemble with nine ``PreActBasicBlock`` per stage.

    Keyword arguments are passed unchanged to ``Ensemble_PreAct_ResNet_Cifar``.
    """
    blocks_per_stage = [9, 9, 9]
    return Ensemble_PreAct_ResNet_Cifar(PreActBasicBlock, blocks_per_stage, **kwargs)

def en_preactresnet110_cifar(**kwargs):
    """Build the ensemble with eighteen ``PreActBasicBlock`` per stage.

    Keyword arguments are passed unchanged to ``Ensemble_PreAct_ResNet_Cifar``.
    """
    blocks_per_stage = [18, 18, 18]
    return Ensemble_PreAct_ResNet_Cifar(PreActBasicBlock, blocks_per_stage, **kwargs)


if __name__ == '__main__':
    # Smoke test: build the default 20-layer ensemble and push one random
    # 32x32 RGB image through it.
    net = en_preactresnet20_cifar()
    # A plain tensor is passed directly; the deprecated
    # torch.autograd.Variable wrapper just returns the tensor it is given.
    y = net(torch.randn(1, 3, 32, 32))
    print(net)
    print(y.size())
