# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """Two-convolution residual block with optional batch norm and skip.

    The main path (``self.left``) is
    ``conv3x3 -> [BatchNorm] -> activation -> conv3x3 -> [BatchNorm]``.
    The shortcut path (``self.shortcut``) is an empty ``nn.Sequential``
    (identity) unless ``stride != 1`` or ``inchannel != outchannel``, in which
    case it is a strided 1x1 convolution, followed by BatchNorm when batch
    norm is enabled.

    Args:
        args: configuration object. Only ``args.activation`` is read:
            ``'tanh'``, ``'sigmoid'`` and ``'leaky_relu'`` select the matching
            non-linearity; any other value falls back to ReLU.
        if_bn: when falsy, every BatchNorm2d layer of the block is omitted.
        if_skip: when truthy, the shortcut output is added to the main path
            before the final activation.
        inchannel: number of input channels.
        outchannel: number of output channels.
        stride: stride of the first 3x3 convolution and of the 1x1 shortcut
            convolution.
    """

    def __init__(self, args, if_bn, if_skip, inchannel, outchannel, stride=1):
        super(ResidualBlock, self).__init__()
        self.if_skip = if_skip
        self.args = args
        use_bn = bool(if_bn)

        # Main path. The layer order (and therefore the Sequential indices
        # used in state_dict keys) is fixed: conv, [bn], act, conv, [bn].
        main_path = [
            nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False)
        ]
        if use_bn:
            main_path.append(nn.BatchNorm2d(outchannel))
        main_path.append(self._make_activation(self.args.activation))
        main_path.append(
            nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False)
        )
        if use_bn:
            main_path.append(nn.BatchNorm2d(outchannel))
        self.left = nn.Sequential(*main_path)

        # Shortcut path. It is built even when ``if_skip`` is falsy, so its
        # parameters are always part of the module.
        projection = []
        if stride != 1 or inchannel != outchannel:
            projection.append(
                nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False)
            )
            if use_bn:
                projection.append(nn.BatchNorm2d(outchannel))
        self.shortcut = nn.Sequential(*projection)

    @staticmethod
    def _make_activation(name):
        """Return the activation module placed between the two convolutions."""
        if name == 'tanh':
            return nn.Tanh()
        if name == 'sigmoid':
            return nn.Sigmoid()
        if name == 'leaky_relu':
            return nn.LeakyReLU(inplace=True)
        return nn.ReLU(inplace=True)

    @staticmethod
    def _apply_activation(name, tensor):
        """Apply the block's final non-linearity selected by ``name``."""
        if name == 'tanh':
            # torch.tanh / torch.sigmoid replace the deprecated
            # F.tanh / F.sigmoid aliases.
            return torch.tanh(tensor)
        if name == 'sigmoid':
            return torch.sigmoid(tensor)
        if name == 'leaky_relu':
            return F.leaky_relu(tensor)
        return F.relu(tensor)

    def forward(self, x):
        """Run the main path, optionally add the shortcut, then activate.

        ``args.activation`` is read at call time, so the final non-linearity
        follows the current value stored on ``self.args``.
        """
        out = self.left(x)
        if self.if_skip:
            out += self.shortcut(x)
        return self._apply_activation(self.args.activation, out)

class ResNet(nn.Module):
    """Four-stage ResNet whose batch norm and skip connections can be ablated.

    The network is ``conv1 -> layer1 -> layer2 -> layer3 -> layer4 ->
    avg_pool2d(4) -> fc``. Which stages keep their skip connections and batch
    norm is selected through ``args``:

    * ``args.skip`` in ``1..4`` disables the skip connection in the first
      ``args.skip`` stages; batch norm is then enabled in every stage
      regardless of ``args.bn``.
    * any other ``args.skip`` keeps every skip connection and lets ``args.bn``
      (``'no_1'``, ``'no_2'``, ``'no_3'``, ``'no_4'``, ``'no_12'``,
      ``'no_123'``, ``'no_bn'``) choose the stages built without batch norm;
      unknown values keep batch norm everywhere.

    In both cases the first block of every stage is always built with batch
    norm and skip enabled (see ``make_layer``); the flags above only affect
    the remaining blocks of the stage.

    Args:
        args: configuration object providing ``data`` (``'MNIST'``,
            ``'cifar10'`` or ``'cifar100'``), ``bn``, ``activation``, ``skip``
            and ``num_classes``.
        ResidualBlock: callable invoked as
            ``block(args, if_bn, if_skip, inchannel, outchannel, stride)``
            that returns one residual block module.
        blocks: indexable with four entries, the requested number of blocks
            for ``layer1`` .. ``layer4``.

    Raises:
        ValueError: if ``args.data`` is not one of the supported datasets.
    """

    # Number of input image channels for each supported ``args.data`` value.
    _INPUT_CHANNELS = {'MNIST': 1, 'cifar10': 3, 'cifar100': 3}

    # (output channels, stride of the first block) for layer1 .. layer4.
    _STAGES = ((64, 1), (128, 2), (256, 2), (512, 2))

    _ALL_ON = (True, True, True, True)

    # ``args.skip`` -> per-stage ``if_skip`` flags.
    _SKIP_LAYOUTS = {
        1: (False, True, True, True),
        2: (False, False, True, True),
        3: (False, False, False, True),
        4: (False, False, False, False),
    }

    # ``args.bn`` -> per-stage ``if_bn`` flags; only consulted when
    # ``args.skip`` does not select one of the skip layouts above.
    _BN_LAYOUTS = {
        'no_1': (False, True, True, True),
        'no_2': (True, False, True, True),
        'no_3': (True, True, False, True),
        'no_4': (True, True, True, False),
        'no_12': (False, False, True, True),
        'no_123': (False, False, False, True),
        'no_bn': (False, False, False, False),
    }

    def __init__(self, args, ResidualBlock, blocks):
        super(ResNet, self).__init__()
        self.args = args

        # Fail early with a clear message instead of hitting an unbound
        # local when the dataset name is not recognised.
        if self.args.data not in self._INPUT_CHANNELS:
            raise ValueError(
                "unsupported args.data %r; expected one of %s"
                % (self.args.data, sorted(self._INPUT_CHANNELS))
            )
        init_channel = self._INPUT_CHANNELS[self.args.data]

        # Running input-channel count; make_layer updates it per block.
        self.inchannel = 64

        # Stem: conv3x3 -> [BatchNorm] -> activation.
        stem = [nn.Conv2d(init_channel, 64, kernel_size=3, stride=1, padding=1, bias=False)]
        if self.args.bn != 'no_bn':
            stem.append(nn.BatchNorm2d(64))
        stem.append(self._stem_activation(self.args.activation))
        self.conv1 = nn.Sequential(*stem)

        skip_flags = self._SKIP_LAYOUTS.get(self.args.skip)
        if skip_flags is not None:
            bn_flags = self._ALL_ON
        else:
            skip_flags = self._ALL_ON
            bn_flags = self._BN_LAYOUTS.get(self.args.bn, self._ALL_ON)

        # Registers layer1 .. layer4 in order.
        for index, (channels, stride) in enumerate(self._STAGES):
            stage = self.make_layer(
                skip_flags[index], ResidualBlock, channels, blocks[index],
                stride=stride, if_bn=bn_flags[index],
            )
            setattr(self, 'layer%d' % (index + 1), stage)

        self.fc = nn.Linear(512, self.args.num_classes)

    @staticmethod
    def _stem_activation(name):
        """Return the activation module that ends the ``conv1`` stem."""
        if name == 'tanh':
            return nn.Tanh()
        if name == 'sigmoid':
            return nn.Sigmoid()
        if name == 'leaky_relu':
            return nn.LeakyReLU()
        return nn.ReLU()

    def make_layer(self, if_skip, block, channels, num_blocks, stride, if_bn=True):
        """Build one stage as an ``nn.Sequential`` of residual blocks.

        The first block uses ``stride`` and is always created with batch norm
        and skip enabled; every following block uses stride 1 and the given
        ``if_bn`` / ``if_skip`` flags. ``self.inchannel`` is set to
        ``channels`` after each block so the next block (and the next stage)
        sees the right input width.

        Args:
            if_skip: skip flag for all blocks after the first.
            block: callable invoked as
                ``block(args, if_bn, if_skip, inchannel, outchannel, stride)``.
            channels: output channels of every block in the stage.
            num_blocks: requested number of blocks; at least one block is
                always created.
            stride: stride of the first block.
            if_bn: batch-norm flag for all blocks after the first.

        Returns:
            ``nn.Sequential`` holding the blocks in order.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for index, block_stride in enumerate(strides):
            if index == 0:
                layers.append(block(self.args, True, True, self.inchannel, channels, block_stride))
            else:
                layers.append(block(self.args, if_bn, if_skip, self.inchannel, channels, block_stride))
            self.inchannel = channels
        return nn.Sequential(*layers)

    def forward(self, x):
        """Compute class scores for a batch of images.

        After ``layer4`` the feature map is average-pooled with a 4x4 window
        and flattened per sample; the flattened size must be 512 to match
        ``self.fc`` (assumes the ``layer4`` output is 4x4 spatially, e.g. for
        32x32 inputs -- TODO confirm for other input sizes).
        """
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out


def resnet18(args, **kwargs):
    """Build the ResNet-18 configuration: two residual blocks per stage.

    ``args`` and any extra keyword arguments are passed to ``ResNet``
    unchanged.
    """
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(args, ResidualBlock, blocks_per_stage, **kwargs)

def resnet34(args, **kwargs):
    """Build the ResNet-34 configuration: 3, 4, 6 and 3 residual blocks per stage.

    ``args`` and any extra keyword arguments are passed to ``ResNet``
    unchanged.
    """
    blocks_per_stage = [3, 4, 6, 3]
    return ResNet(args, ResidualBlock, blocks_per_stage, **kwargs)
