from torch import nn
from torch.nn import functional as F

class ResidualBlock(nn.Module):
    '''
    Basic residual block: two 3x3 convolutions plus a skip connection.

    ``left`` is conv3x3(stride) -> BN -> LeakyReLU -> conv3x3 -> BN. Its output
    is added to the skip path (``x`` itself, or ``shortcut(x)`` when a shortcut
    module is given) and the sum goes through a final ReLU.

    Args:
        inchannel: number of input channels.
        outchannel: number of output channels.
        stride: stride of the first 3x3 convolution.
        shortcut: optional module applied to ``x`` on the skip path. It must
            produce a tensor with the same shape as ``left(x)``. When ``None``,
            the identity is used, which in general requires
            ``inchannel == outchannel`` and ``stride == 1``.
    '''
    def __init__(self, inchannel, outchannel, stride=1, shortcut=None):
        super().__init__()
        # The module order inside ``left`` fixes the state_dict keys
        # (left.0, left.1, left.3, left.4), so keep it stable for checkpoints.
        self.left = nn.Sequential(
                nn.Conv2d(inchannel, outchannel, 3, stride, 1, bias=False),
                nn.BatchNorm2d(outchannel),
                # LeakyReLU (default slope) is used inside the block in place
                # of the plain ReLU that was used here previously.
                nn.LeakyReLU(inplace=False),
                nn.Conv2d(outchannel, outchannel, 3, 1, 1, bias=False),
                nn.BatchNorm2d(outchannel))
        self.right = shortcut
        self.out_relu = nn.ReLU(inplace=False)

    def forward(self, x):
        out = self.left(x)
        residual = x if self.right is None else self.right(x)
        # ``out + residual`` is out-of-place and already yields a new tensor,
        # so no defensive clone of ``out`` is needed for autograd.
        return self.out_relu(out + residual)

class BottleNeck(nn.Module):
    """Bottleneck residual block used by the ResNet50 variants in this file.

    The residual branch is 1x1 conv -> BN -> ReLU -> 3x3 conv (``stride``)
    -> BN -> ReLU -> 1x1 conv -> BN, which expands the width to
    ``out_channels * BottleNeck.expansion``. The shortcut is the identity when
    the shapes already match; otherwise it is a strided 1x1 conv + BN
    projection. The block output is ReLU(residual + shortcut).

    Args:
        in_channels: number of input channels.
        out_channels: bottleneck width. The block outputs
            ``out_channels * BottleNeck.expansion`` channels.
        stride: stride of the 3x3 convolution and of the projection shortcut.
    """
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        expanded = out_channels * BottleNeck.expansion
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=False),
            nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=False),
            nn.Conv2d(out_channels, expanded, kernel_size=1, bias=False),
            nn.BatchNorm2d(expanded),
        )

        # An empty Sequential acts as the identity. A projection is needed
        # whenever the stride or the channel count would change the shape.
        self.shortcut = nn.Sequential()

        if stride != 1 or in_channels != expanded:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, expanded, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(expanded)
            )

    def forward(self, x):
        # Functional ReLU instead of instantiating a new nn.ReLU module on
        # every forward call.
        return F.relu(self.residual_function(x) + self.shortcut(x))

class ResNet34(nn.Module):
    '''
    Main module: an adapted ResNet34.

    A stem (``pre``) is followed by four stages (``layer1`` .. ``layer4``).
    Each stage is a stack of ``ResidualBlock`` modules built by
    ``_make_layer``. The stages are followed by global average pooling and a
    linear classifier ``fc``.

    Stage widths are 128/256/512/512 and first-block strides are 1/1/2/2.

    Args:
        n_outputs: number of logits produced by ``fc``.
    '''
    def __init__(self, n_outputs=10):
        super().__init__()
        self.model_name = 'resnet34'

        # Stem: 3x3 conv (stride 1) + BN + ReLU, then a stride-2 max-pool.
        self.pre = nn.Sequential(
                nn.Conv2d(3, 64, 3, 1, 1, bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=False),
                nn.MaxPool2d(3, 2, 1))

        # Repeated stages with 3, 4, 6 and 3 residual blocks respectively.
        self.layer1 = self._make_layer(64, 128, 3)
        self.layer2 = self._make_layer(128, 256, 4, stride=1)
        self.layer3 = self._make_layer(256, 512, 6, stride=2)
        self.layer4 = self._make_layer(512, 512, 3, stride=2)

        # Fully connected classification head.
        self.fc = nn.Linear(512, n_outputs)

    def _make_layer(self, inchannel, outchannel, block_num, stride=1):
        '''
        Build one stage made of ``block_num`` residual blocks.

        The first block uses ``stride`` and a 1x1 conv + BN projection shortcut
        (``inchannel`` -> ``outchannel``). The remaining blocks use stride 1 and
        identity shortcuts. At least one block is always built.
        '''
        shortcut = nn.Sequential(
                nn.Conv2d(inchannel, outchannel, 1, stride, bias=False),
                nn.BatchNorm2d(outchannel))

        layers = [ResidualBlock(inchannel, outchannel, stride, shortcut)]
        for _ in range(1, block_num):
            layers.append(ResidualBlock(outchannel, outchannel))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.pre(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Global average pooling to 1x1. The network downsamples by 8
        # (max-pool, layer3, layer4), so a 32x32 input ends at 4x4. That makes
        # this identical to the former fixed ``F.avg_pool2d(x, 4)``, but it
        # also pools the whole map for other input resolutions.
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

class ResNet18(nn.Module):
    '''
    Adapted ResNet18.

    A stem (``pre``) is followed by four stages (``layer1`` .. ``layer4``) of
    ``ResidualBlock`` modules with 2, 3, 5 and 2 blocks respectively. The
    stages are followed by global average pooling and a linear classifier
    ``fc``.

    NOTE(review): despite the name, these stage depths (12 blocks in total)
    differ from the standard ResNet18 layout of 2, 2, 2, 2 blocks.

    Args:
        n_outputs: number of logits produced by ``fc``.
    '''
    def __init__(self, n_outputs=10):
        super().__init__()
        self.model_name = 'resnet18'

        # Stem: 3x3 conv (stride 1) + BN + ReLU, then a stride-2 max-pool.
        self.pre = nn.Sequential(
                nn.Conv2d(3, 64, 3, 1, 1, bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=False),
                nn.MaxPool2d(3, 2, 1))

        # Repeated stages with 2, 3, 5 and 2 residual blocks respectively.
        self.layer1 = self._make_layer(64, 128, 2)
        self.layer2 = self._make_layer(128, 256, 3, stride=1)
        self.layer3 = self._make_layer(256, 512, 5, stride=2)
        self.layer4 = self._make_layer(512, 512, 2, stride=2)

        # Fully connected classification head.
        self.fc = nn.Linear(512, n_outputs)

    def _make_layer(self, inchannel, outchannel, block_num, stride=1):
        '''
        Build one stage made of ``block_num`` residual blocks.

        The first block uses ``stride`` and a 1x1 conv + BN projection shortcut
        (``inchannel`` -> ``outchannel``). The remaining blocks use stride 1 and
        identity shortcuts. At least one block is always built.
        '''
        shortcut = nn.Sequential(
                nn.Conv2d(inchannel, outchannel, 1, stride, bias=False),
                nn.BatchNorm2d(outchannel))

        layers = [ResidualBlock(inchannel, outchannel, stride, shortcut)]
        for _ in range(1, block_num):
            layers.append(ResidualBlock(outchannel, outchannel))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.pre(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Global average pooling to 1x1. The network downsamples by 8
        # (max-pool, layer3, layer4), so a 32x32 input ends at 4x4. That makes
        # this identical to the former fixed ``F.avg_pool2d(x, 4)``, but it
        # also pools the whole map for other input resolutions.
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

class NewResNet34(nn.Module):
    '''
    Adapted ResNet34 whose classifier is named ``new_fc``.

    A stem (``pre``) is followed by four stages (``layer1`` .. ``layer4``).
    Each stage is a stack of ``ResidualBlock`` modules built by
    ``_make_layer``. The stages are followed by global average pooling and a
    linear classifier ``new_fc``.

    NOTE(review): the classifier is named ``new_fc`` rather than ``fc``,
    presumably so that backbone weights from a model with an ``fc`` head can
    be loaded without the old classifier. Confirm against the loading code.

    Args:
        n_outputs: number of logits produced by ``new_fc``.
    '''
    def __init__(self, n_outputs=10):
        super().__init__()
        self.model_name = 'new_resnet34'

        # Stem: 3x3 conv (stride 1) + BN + ReLU, then a stride-2 max-pool.
        self.pre = nn.Sequential(
                nn.Conv2d(3, 64, 3, 1, 1, bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=False),
                nn.MaxPool2d(3, 2, 1))

        # Repeated stages with 3, 4, 6 and 3 residual blocks respectively.
        self.layer1 = self._make_layer(64, 128, 3)
        self.layer2 = self._make_layer(128, 256, 4, stride=1)
        self.layer3 = self._make_layer(256, 512, 6, stride=2)
        self.layer4 = self._make_layer(512, 512, 3, stride=2)

        # Fully connected classification head.
        self.new_fc = nn.Linear(512, n_outputs)

    def _make_layer(self, inchannel, outchannel, block_num, stride=1):
        '''
        Build one stage made of ``block_num`` residual blocks.

        The first block uses ``stride`` and a 1x1 conv + BN projection shortcut
        (``inchannel`` -> ``outchannel``). The remaining blocks use stride 1 and
        identity shortcuts. At least one block is always built.
        '''
        shortcut = nn.Sequential(
                nn.Conv2d(inchannel, outchannel, 1, stride, bias=False),
                nn.BatchNorm2d(outchannel))

        layers = [ResidualBlock(inchannel, outchannel, stride, shortcut)]
        for _ in range(1, block_num):
            layers.append(ResidualBlock(outchannel, outchannel))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.pre(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Global average pooling to 1x1. The network downsamples by 8
        # (max-pool, layer3, layer4), so a 32x32 input ends at 4x4. That makes
        # this identical to the former fixed ``F.avg_pool2d(x, 4)``, but it
        # also pools the whole map for other input resolutions.
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(x.size(0), -1)
        x = self.new_fc(x)
        return x


class ResNet50(nn.Module):
    """ResNet50-style network built from ``BottleNeck`` blocks.

    The stem ``conv1`` is a stride-1 3x3 conv + BN + ReLU with no max-pool.
    Four stages (``conv2_x`` .. ``conv5_x``) of 3, 4, 6 and 3 bottleneck
    blocks follow. The first stage keeps the spatial size, and each later
    stage starts with a stride-2 block. Adaptive average pooling and a linear
    head ``fc`` produce the logits.

    In eval mode (``self.training`` is False), ``forward`` also stores in
    ``self.features`` a list of four detached (input, output) pairs, one per
    stage. In training mode ``self.features`` is left untouched.

    Args:
        n_outputs: number of logits produced by ``fc``.
    """

    def __init__(self, n_outputs=100):
        super().__init__()

        # Running channel count, advanced by ``_make_layer``.
        self.in_channels = 64
        bottleneck = BottleNeck
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=False),
        )
        # This model uses a different input size than the original paper,
        # so the first stage (conv2_x) keeps stride 1.
        self.conv2_x = self._make_layer(bottleneck, out_channels=64, num_blocks=3, stride=1)
        self.conv3_x = self._make_layer(bottleneck, out_channels=128, num_blocks=4, stride=2)
        self.conv4_x = self._make_layer(bottleneck, out_channels=256, num_blocks=6, stride=2)
        self.conv5_x = self._make_layer(bottleneck, out_channels=512, num_blocks=3, stride=2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * bottleneck.expansion, n_outputs)
        self.features = None

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Build one stage of residual blocks.

        Only the first block uses ``stride``; every later block uses stride 1.
        After each block, ``self.in_channels`` is set to
        ``out_channels * block.expansion``, so consecutive calls chain the
        stage widths.

        Args:
            block: block class to instantiate (e.g. ``BottleNeck``).
            out_channels: base channel width of the blocks in this stage.
            num_blocks: number of blocks in the stage (at least one block is
                always built).
            stride: stride of the first block.

        Returns:
            An ``nn.Sequential`` holding the blocks.
        """
        block_strides = [stride] + [1] * (num_blocks - 1)
        blocks = []
        for block_stride in block_strides:
            blocks.append(block(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*blocks)

    def forward(self, x):
        stages = (self.conv2_x, self.conv3_x, self.conv4_x, self.conv5_x)
        activations = [self.conv1(x)]
        for stage in stages:
            activations.append(stage(activations[-1]))

        if not self.training:
            # One detached (input, output) pair per stage.
            self.features = [
                (stage_in.detach(), stage_out.detach())
                for stage_in, stage_out in zip(activations, activations[1:])
            ]

        pooled = self.avg_pool(activations[-1])
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)


class NewResNet50(nn.Module):
    """ResNet50-style network built from ``BottleNeck`` blocks, with a
    classifier named ``new_fc``.

    The stem ``conv1`` is a stride-1 3x3 conv + BN + ReLU with no max-pool.
    Four stages (``conv2_x`` .. ``conv5_x``) of 3, 4, 6 and 3 bottleneck
    blocks follow. The first stage keeps the spatial size, and each later
    stage starts with a stride-2 block. Adaptive average pooling and the
    linear head ``new_fc`` produce the logits.

    In eval mode (``self.training`` is False), ``forward`` also stores in
    ``self.features`` a list of four detached (input, output) pairs, one per
    stage. In training mode ``self.features`` is left untouched.

    Args:
        n_outputs: number of logits produced by ``new_fc``.
    """

    def __init__(self, n_outputs=10):
        super().__init__()

        # Running channel count, advanced by ``_make_layer``.
        self.in_channels = 64
        bottleneck = BottleNeck
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=False),
        )
        # This model uses a different input size than the original paper,
        # so the first stage (conv2_x) keeps stride 1.
        self.conv2_x = self._make_layer(bottleneck, out_channels=64, num_blocks=3, stride=1)
        self.conv3_x = self._make_layer(bottleneck, out_channels=128, num_blocks=4, stride=2)
        self.conv4_x = self._make_layer(bottleneck, out_channels=256, num_blocks=6, stride=2)
        self.conv5_x = self._make_layer(bottleneck, out_channels=512, num_blocks=3, stride=2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.new_fc = nn.Linear(512 * bottleneck.expansion, n_outputs)
        self.features = None

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Build one stage of residual blocks.

        Only the first block uses ``stride``; every later block uses stride 1.
        After each block, ``self.in_channels`` is set to
        ``out_channels * block.expansion``, so consecutive calls chain the
        stage widths.

        Args:
            block: block class to instantiate (e.g. ``BottleNeck``).
            out_channels: base channel width of the blocks in this stage.
            num_blocks: number of blocks in the stage (at least one block is
                always built).
            stride: stride of the first block.

        Returns:
            An ``nn.Sequential`` holding the blocks.
        """
        block_strides = [stride] + [1] * (num_blocks - 1)
        blocks = []
        for block_stride in block_strides:
            blocks.append(block(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*blocks)

    def forward(self, x):
        stages = (self.conv2_x, self.conv3_x, self.conv4_x, self.conv5_x)
        activations = [self.conv1(x)]
        for stage in stages:
            activations.append(stage(activations[-1]))

        if not self.training:
            # One detached (input, output) pair per stage.
            self.features = [
                (stage_in.detach(), stage_out.detach())
                for stage_in, stage_out in zip(activations, activations[1:])
            ]

        pooled = self.avg_pool(activations[-1])
        flat = pooled.view(pooled.size(0), -1)
        return self.new_fc(flat)


# ResNet50 variant with an extra fc head (dis_fc) that discriminates real from fake attacks
class DisResNet50(nn.Module):
    """ResNet50-style network with two linear heads on shared features.

    The stem ``conv1`` is a stride-1 3x3 conv + BN + ReLU with no max-pool.
    Four stages (``conv2_x`` .. ``conv5_x``) of 3, 4, 6 and 3 ``BottleNeck``
    blocks follow. The pooled features feed two heads:

    * ``new_fc``: ``n_outputs`` label logits;
    * ``dis_fc``: 2 logits for the real-vs-fake discrimination task.

    In eval mode (``self.training`` is False), ``forward`` also stores in
    ``self.features`` a list of four detached (input, output) pairs, one per
    stage. In training mode ``self.features`` is left untouched.

    Args:
        n_outputs: number of logits produced by ``new_fc``.
    """

    def __init__(self, n_outputs=10):
        super().__init__()

        # Running channel count, advanced by ``_make_layer``.
        self.in_channels = 64
        bottleneck = BottleNeck
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=False),
        )
        # This model uses a different input size than the original paper,
        # so the first stage (conv2_x) keeps stride 1.
        self.conv2_x = self._make_layer(bottleneck, out_channels=64, num_blocks=3, stride=1)
        self.conv3_x = self._make_layer(bottleneck, out_channels=128, num_blocks=4, stride=2)
        self.conv4_x = self._make_layer(bottleneck, out_channels=256, num_blocks=6, stride=2)
        self.conv5_x = self._make_layer(bottleneck, out_channels=512, num_blocks=3, stride=2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        pooled_width = 512 * bottleneck.expansion
        self.new_fc = nn.Linear(pooled_width, n_outputs)
        self.dis_fc = nn.Linear(pooled_width, 2)
        self.features = None

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Build one stage of residual blocks.

        Only the first block uses ``stride``; every later block uses stride 1.
        After each block, ``self.in_channels`` is set to
        ``out_channels * block.expansion``, so consecutive calls chain the
        stage widths.

        Args:
            block: block class to instantiate (e.g. ``BottleNeck``).
            out_channels: base channel width of the blocks in this stage.
            num_blocks: number of blocks in the stage (at least one block is
                always built).
            stride: stride of the first block.

        Returns:
            An ``nn.Sequential`` holding the blocks.
        """
        block_strides = [stride] + [1] * (num_blocks - 1)
        blocks = []
        for block_stride in block_strides:
            blocks.append(block(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return ``(label_logits, dis_logits)`` computed from shared features."""
        stages = (self.conv2_x, self.conv3_x, self.conv4_x, self.conv5_x)
        activations = [self.conv1(x)]
        for stage in stages:
            activations.append(stage(activations[-1]))

        if not self.training:
            # One detached (input, output) pair per stage.
            self.features = [
                (stage_in.detach(), stage_out.detach())
                for stage_in, stage_out in zip(activations, activations[1:])
            ]

        pooled = self.avg_pool(activations[-1])
        flat = pooled.view(pooled.size(0), -1)
        return self.new_fc(flat), self.dis_fc(flat)