import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn import Parameter
import numpy as np
import torch.nn.init as nn_init


class Expression(nn.Module):
    """Adapter that exposes an arbitrary callable as an ``nn.Module``.

    Args:
        func: Callable invoked with the forward input; whatever it returns is
            returned by this module.
    """

    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, input):
        """Call the wrapped callable on ``input`` and hand back its result."""
        wrapped = self.func
        return wrapped(input)

class BasicBlock(nn.Module):
    """Pre-activation residual block: BN-ReLU-conv3x3, BN-ReLU-(dropout)-conv3x3.

    When ``in_planes == out_planes`` the raw input is added back as the
    shortcut. Otherwise the shortcut is a 1x1 convolution (with the block's
    stride) applied to the *activated* input ``relu(bn1(x))``.

    Note: the identity shortcut is chosen purely from the channel counts, so
    with ``in_planes == out_planes`` a stride other than 1 makes the two
    branches differ in spatial size.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Stride of the first 3x3 convolution (and of the 1x1 shortcut).
        dropRate: Dropout probability applied before the second convolution;
            dropout is skipped entirely when it is not greater than 0.
    """

    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        # Explicit branch instead of the ``cond and X or None`` idiom, which
        # depends on the truthiness of the constructed module.
        if self.equalInOut:
            self.convShortcut = None
        else:
            self.convShortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1,
                                          stride=stride, padding=0, bias=False)

    def forward(self, x):
        """Return ``shortcut(x) + residual(x)`` for a 4-D feature map ``x``."""
        # bn1 yields a fresh tensor, so the in-place ReLU never modifies ``x``.
        activated = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(activated)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        if self.equalInOut:
            return torch.add(x, out)
        # Channel-changing blocks project the activated input, not the raw one.
        return torch.add(self.convShortcut(activated), out)


class NetworkBlock(nn.Module):
    """A sequential stage made of ``nb_layers`` instances of ``block``.

    The first instance maps ``in_planes -> out_planes`` with the given
    ``stride``; every following instance maps ``out_planes -> out_planes``
    with stride 1.

    Args:
        nb_layers: Number of block instances in the stage.
        in_planes: Channels entering the stage.
        out_planes: Channels produced by every block of the stage.
        block: Callable invoked as ``block(in_planes, out_planes, stride, dropRate)``
            that returns an ``nn.Module``.
        stride: Stride handed to the first block only.
        dropRate: Dropout rate forwarded unchanged to every block.
    """

    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
        super().__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
        """Build the ``nn.Sequential`` holding the stage's blocks."""
        # Conditional expressions rather than ``cond and a or b``: the latter
        # silently falls through to ``b`` whenever ``a`` is falsy (e.g. 0).
        layers = [
            block(
                in_planes if i == 0 else out_planes,
                out_planes,
                stride if i == 0 else 1,
                dropRate,
            )
            for i in range(nb_layers)
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through all blocks in order."""
        return self.layer(x)


class WideResNet(nn.Module):
    """Wide residual network with three stages of ``BasicBlock``s.

    Stage widths are ``16 * widen_factor``, ``32 * widen_factor`` and
    ``64 * widen_factor``; stages two and three use stride 2. ``forward``
    pools with a fixed 8x8 window, so a 32x32 input ends up as a 1x1 map.

    Args:
        depth: Total depth; ``(depth - 4)`` must be divisible by 6 and
            ``(depth - 4) // 6`` blocks are placed in each stage.
        num_classes: Size of the classifier output.
        widen_factor: Multiplier applied to the base stage widths.
        dropRate: Dropout rate forwarded to every block.
    """

    def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
        super().__init__()
        widths = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        assert ((depth - 4) % 6 == 0)
        blocks_per_stage = (depth - 4) // 6
        block_cls = BasicBlock

        # Stem convolution ahead of the residual stages.
        self.conv1 = nn.Conv2d(3, widths[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # Three residual stages; only the last two down-sample.
        self.block1 = NetworkBlock(blocks_per_stage, widths[0], widths[1], block_cls, 1, dropRate)
        self.block2 = NetworkBlock(blocks_per_stage, widths[1], widths[2], block_cls, 2, dropRate)
        self.block3 = NetworkBlock(blocks_per_stage, widths[2], widths[3], block_cls, 2, dropRate)
        # Final normalisation, activation and linear classifier.
        self.bn1 = nn.BatchNorm2d(widths[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(widths[3], num_classes)
        self.nChannels = widths[3]

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                # Normal init with variance 2 / (kernel_h * kernel_w * out_channels).
                kernel_h, kernel_w = module.kernel_size
                fan = kernel_h * kernel_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Linear):
                # Only the bias is reset; the weight keeps its default init.
                module.bias.data.zero_()

    def forward(self, x):
        """Return ``(pooled_features, logits)``.

        ``pooled_features`` is the 4-D output of the 8x8 average pooling and
        ``logits`` is the classifier applied to it reshaped to
        ``(-1, nChannels)``.
        """
        feats = self.conv1(x)
        for stage in (self.block1, self.block2, self.block3):
            feats = stage(feats)
        feats = self.relu(self.bn1(feats))
        out_avg = F.avg_pool2d(feats, 8)
        flat = out_avg.view(-1, self.nChannels)
        return out_avg, self.fc(flat)


class ResNet32x32(nn.Module):
    """Three-stage ResNet trunk with two parallel linear heads.

    ``block`` must expose ``out_channels(planes, groups)`` and be constructible
    as ``block(inplanes, planes, groups)`` and
    ``block(inplanes, planes, groups, stride, downsample)``.

    Args:
        block: Residual block class fulfilling the contract above.
        layers: Sequence of exactly three block counts, one per stage.
        channels: ``planes`` of the first stage; doubled for each later stage.
        groups: Forwarded to ``block`` and ``block.out_channels``.
        num_classes: Output size of both heads.
        downsample: ``'basic'`` (1x1 conv + BN) or ``'shift_conv'``; selects
            how strided shortcuts are built.
    """

    def __init__(self, block, layers, channels, groups=1, num_classes=1000, downsample='basic'):
        super().__init__()
        assert len(layers) == 3
        self.downsample_mode = downsample
        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.layer1 = self._make_layer(block, channels, groups, layers[0])
        self.layer2 = self._make_layer(block, channels * 2, groups, layers[1], stride=2)
        self.layer3 = self._make_layer(block, channels * 4, groups, layers[2], stride=2)
        self.avgpool = nn.AvgPool2d(8)
        # Both heads read the same flattened feature vector.
        self.fc1 = nn.Linear(block.out_channels(channels * 4, groups), num_classes)
        self.fc2 = nn.Linear(block.out_channels(channels * 4, groups), num_classes)

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                kernel_h, kernel_w = module.kernel_size
                fan = kernel_h * kernel_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_layer(self, block, planes, groups, blocks, stride=1):
        """Build one stage of ``blocks`` blocks and advance ``self.inplanes``.

        Only the first block receives ``stride`` and the optional shortcut
        projection; the remaining ones are built with the block's defaults.
        """
        width = block.out_channels(planes, groups)

        shortcut = None
        if stride != 1 or self.inplanes != width:
            if self.downsample_mode == 'basic' or stride == 1:
                shortcut = nn.Sequential(
                    nn.Conv2d(self.inplanes, width,
                              kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(width),
                )
            elif self.downsample_mode == 'shift_conv':
                # NOTE(review): ShiftConvDownsample is not defined in the
                # visible part of this file - confirm it is provided elsewhere.
                shortcut = ShiftConvDownsample(in_channels=self.inplanes,
                                               out_channels=width)
            else:
                assert False

        stage = [block(self.inplanes, planes, groups, stride, shortcut)]
        self.inplanes = width
        stage.extend(block(self.inplanes, planes, groups) for _ in range(1, blocks))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Return ``(fc1(features), fc2(features))`` for an image batch ``x``."""
        x = self.conv1(x)
        for stage in (self.layer1, self.layer2, self.layer3):
            x = stage(x)
        pooled = self.avgpool(x)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc1(flat), self.fc2(flat)



class GaussianNoise(nn.Module):
    """Add zero-mean Gaussian noise to the input while in training mode.

    Args:
        sigma: Standard deviation of the noise.
    """

    def __init__(self, sigma):
        super().__init__()
        self.sigma = sigma

    def forward(self, input):
        """Return ``input + noise`` when training, otherwise ``input`` itself."""
        if not self.training:
            return input
        # new_empty keeps the input's dtype and device and is detached from
        # autograd; it replaces the deprecated Variable(input.data.new(...)).
        noise = input.new_empty(input.size()).normal_(mean=0.0, std=self.sigma)
        return input + noise

class LLPNet(nn.Module):
    """Nine-layer convolutional classifier with noisy, dropout-regularised input.

    The trunk is three stacks of bias-free convolutions with LeakyReLU(0.2).
    The first two stacks end in a stride-2 convolution and Dropout(0.5); the
    third starts with an unpadded 3x3 convolution. A 6x6 average pool feeds a
    128-input linear layer, so the pooled map must be 1x1 (true for 32x32
    inputs: 32 -> 16 -> 8 -> 6 -> 1).

    Args:
        num_classes: Number of outputs of the linear head.
    """

    def __init__(self, num_classes):
        super().__init__()

        def conv_lrelu(c_in, c_out, kernel, stride, pad):
            # One bias-free convolution followed by its LeakyReLU(0.2).
            return (
                nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=stride, padding=pad, bias=False),
                nn.LeakyReLU(0.2),
            )

        self.num_classes = num_classes

        # Input perturbation: additive Gaussian noise, then channel dropout.
        self.gaosi = nn.Sequential(GaussianNoise(0.05), nn.Dropout2d(0.15))

        self.core_net1 = nn.Sequential(
            *conv_lrelu(3, 128, 3, 1, 1),
            *conv_lrelu(128, 128, 3, 1, 1),
            *conv_lrelu(128, 128, 3, 2, 1),
            nn.Dropout(0.5),
        )

        self.core_net2 = nn.Sequential(
            *conv_lrelu(128, 256, 3, 1, 1),
            *conv_lrelu(256, 256, 3, 1, 1),
            *conv_lrelu(256, 256, 3, 2, 1),
            nn.Dropout(0.5),
        )

        self.core_net3 = nn.Sequential(
            *conv_lrelu(256, 512, 3, 1, 0),
            *conv_lrelu(512, 256, 1, 1, 0),
            *conv_lrelu(256, 128, 1, 1, 0),
        )

        self.avgpool = nn.AvgPool2d(6)
        self.out_net1 = nn.Linear(128, self.num_classes)

    def forward(self, x):
        """Return the class logits for an image batch ``x``."""
        for stage in (self.gaosi, self.core_net1, self.core_net2, self.core_net3):
            x = stage(x)
        pooled = self.avgpool(x)
        flat = pooled.view(pooled.size(0), -1)
        return self.out_net1(flat)


class LLPNet_new(nn.Module):
    """Same layer layout as ``LLPNet`` but ``forward`` also returns the pooled features.

    The trunk is three stacks of bias-free convolutions with LeakyReLU(0.2),
    preceded by Gaussian noise and Dropout2d on the input. A 6x6 average pool
    feeds a 128-input linear layer, so the pooled map must be 1x1 (true for
    32x32 inputs).

    Args:
        num_classes: Number of outputs of the linear head.
    """

    def __init__(self, num_classes):
        super().__init__()

        def build_stack(specs, drop_p=None):
            # specs: (in_ch, out_ch, kernel, stride, padding) per convolution.
            modules = []
            for c_in, c_out, kernel, stride, pad in specs:
                modules.append(nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=stride,
                                         padding=pad, bias=False))
                modules.append(nn.LeakyReLU(0.2))
            if drop_p is not None:
                modules.append(nn.Dropout(drop_p))
            return nn.Sequential(*modules)

        self.num_classes = num_classes

        # Input perturbation: additive Gaussian noise, then channel dropout.
        self.gaosi = nn.Sequential(GaussianNoise(0.05), nn.Dropout2d(0.15))

        self.core_net1 = build_stack(
            [(3, 128, 3, 1, 1), (128, 128, 3, 1, 1), (128, 128, 3, 2, 1)], drop_p=0.5)

        self.core_net2 = build_stack(
            [(128, 256, 3, 1, 1), (256, 256, 3, 1, 1), (256, 256, 3, 2, 1)], drop_p=0.5)

        self.core_net3 = build_stack(
            [(256, 512, 3, 1, 0), (512, 256, 1, 1, 0), (256, 128, 1, 1, 0)])

        self.avgpool = nn.AvgPool2d(6)
        self.out_net1 = nn.Linear(128, self.num_classes)

    def forward(self, x):
        """Return ``(pooled_features, logits)``.

        ``pooled_features`` is the 4-D average-pool output; ``logits`` is the
        linear head applied to it flattened per sample.
        """
        hidden = self.core_net1(self.gaosi(x))
        hidden = self.core_net3(self.core_net2(hidden))
        x_avg = self.avgpool(hidden)
        flat = x_avg.view(x_avg.size(0), -1)
        return x_avg, self.out_net1(flat)




def initialize_weights(net):
    """Re-initialise the conv, transposed-conv and linear layers of ``net`` in place.

    Weights are drawn from a normal distribution with mean 0 and standard
    deviation 0.02; biases are set to zero where the layer has one. Layers
    created with ``bias=False`` are handled (their absent bias is skipped).

    Args:
        net: Module whose sub-modules, as yielded by ``net.modules()``, are
            visited.
    """
    target_types = (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)
    for m in net.modules():
        if not isinstance(m, target_types):
            continue
        nn.init.normal_(m.weight, 0.0, 0.02)
        # ``bias`` is None for layers built with bias=False, so guard before
        # zeroing to avoid an AttributeError.
        if m.bias is not None:
            nn.init.zeros_(m.bias)


class generator(nn.Module):
    """Fully-connected + transposed-convolution image generator.

    A latent vector is mapped by two Linear/BatchNorm1d/ReLU layers to a
    ``128 x (input_size // 4) x (input_size // 4)`` tensor, which two
    stride-2 transposed convolutions (kernel 4, padding 1) each double in
    size before a final Tanh.

    Args:
        input_dim: Size of the latent vector.
        output_dim: Number of channels of the generated image.
        input_size: Target image side; the feature map starts at
            ``input_size // 4``.
    """

    def __init__(self, input_dim=100, output_dim=1, input_size=32):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.input_size = input_size

        base_side = self.input_size // 4
        flat_features = 128 * base_side * base_side

        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, flat_features),
            nn.BatchNorm1d(flat_features),
            nn.ReLU(),
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
            nn.Tanh(),
        )
        initialize_weights(self)

    def forward(self, input):
        """Map a batch of latent vectors to images with values in [-1, 1]."""
        base_side = self.input_size // 4
        hidden = self.fc(input)
        hidden = hidden.view(-1, 128, base_side, base_side)
        return self.deconv(hidden)


class generator_new(nn.Module):
    """Generator with one linear layer and three kernel-5 transposed convolutions.

    The latent vector is projected to a
    ``512 x (input_size // 8) x (input_size // 8)`` tensor and passed through
    three stride-2 transposed convolutions followed by Tanh.

    NOTE(review): with kernel 5, stride 2, padding 1 and no output padding,
    each transposed convolution maps side ``n`` to ``2 * n + 1``, so the output
    side is ``8 * (input_size // 8) + 7`` (39 for ``input_size=32``) rather
    than ``input_size``. If an exact ``input_size`` output was intended,
    ``padding=2, output_padding=1`` would give it - confirm with the callers
    before changing.

    Args:
        input_dim: Size of the latent vector.
        output_dim: Number of channels of the generated image.
        input_size: Nominal image side; the feature map starts at
            ``input_size // 8``.
    """

    def __init__(self, input_dim=100, output_dim=1, input_size=32):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.input_size = input_size

        base_side = self.input_size // 8
        flat_features = 512 * base_side * base_side

        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, flat_features),
            nn.BatchNorm1d(flat_features),
            nn.ReLU(),
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 5, 2, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 5, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, self.output_dim, 5, 2, 1),
            nn.Tanh(),
        )
        initialize_weights(self)

    def forward(self, input):
        """Map a batch of latent vectors to images with values in [-1, 1]."""
        base_side = self.input_size // 8
        hidden = self.fc(input)
        hidden = hidden.view(-1, 512, base_side, base_side)
        return self.deconv(hidden)


class LLPNet_rotation(nn.Module):
    """LLPNet-style trunk with two linear heads on the pooled features.

    The trunk is three stacks of bias-free convolutions with LeakyReLU(0.2),
    preceded by Gaussian noise and Dropout2d on the input. A 6x6 average pool
    feeds two 128-input linear layers, so the pooled map must be 1x1 (true for
    32x32 inputs).

    Args:
        num_classes: Output size of ``out_net1``.
        num_classes_rotation: Output size of ``out_net2``.
    """

    def __init__(self, num_classes, num_classes_rotation):
        super().__init__()

        def conv_lrelu(c_in, c_out, kernel, stride, pad):
            # One bias-free convolution followed by its LeakyReLU(0.2).
            return (
                nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=stride, padding=pad, bias=False),
                nn.LeakyReLU(0.2),
            )

        self.num_classes = num_classes
        self.num_classes_rotation = num_classes_rotation

        # Input perturbation: additive Gaussian noise, then channel dropout.
        self.gaosi = nn.Sequential(GaussianNoise(0.05), nn.Dropout2d(0.15))

        self.core_net1 = nn.Sequential(
            *conv_lrelu(3, 128, 3, 1, 1),
            *conv_lrelu(128, 128, 3, 1, 1),
            *conv_lrelu(128, 128, 3, 2, 1),
            nn.Dropout(0.5),
        )

        self.core_net2 = nn.Sequential(
            *conv_lrelu(128, 256, 3, 1, 1),
            *conv_lrelu(256, 256, 3, 1, 1),
            *conv_lrelu(256, 256, 3, 2, 1),
            nn.Dropout(0.5),
        )

        self.core_net3 = nn.Sequential(
            *conv_lrelu(256, 512, 3, 1, 0),
            *conv_lrelu(512, 256, 1, 1, 0),
            *conv_lrelu(256, 128, 1, 1, 0),
        )

        self.avgpool = nn.AvgPool2d(6)
        self.out_net1 = nn.Linear(128, self.num_classes)
        self.out_net2 = nn.Linear(128, self.num_classes_rotation)

    def forward(self, x):
        """Return ``(out_net2(features), out_net1(features))``.

        The ``num_classes_rotation``-sized output comes first, the
        ``num_classes``-sized output second.
        """
        for stage in (self.gaosi, self.core_net1, self.core_net2, self.core_net3):
            x = stage(x)
        pooled = self.avgpool(x)
        flat = pooled.view(pooled.size(0), -1)
        rotation_out = self.out_net2(flat)
        class_out = self.out_net1(flat)
        return rotation_out, class_out


class LLPNet_rotation_test(nn.Module):
    """LLPNet-style trunk with three linear heads on the pooled features.

    The trunk is three stacks of bias-free convolutions with LeakyReLU(0.2),
    preceded by Gaussian noise and Dropout2d on the input. A 6x6 average pool
    feeds three 128-input linear layers, so the pooled map must be 1x1 (true
    for 32x32 inputs).

    Args:
        num_classes: Output size of ``out_net1``.
        num_classes_rotation: Output size of both ``out_net2`` and ``out_net3``.
    """

    def __init__(self, num_classes, num_classes_rotation):
        super().__init__()

        def build_stack(specs, drop_p=None):
            # specs: (in_ch, out_ch, kernel, stride, padding) per convolution.
            modules = []
            for c_in, c_out, kernel, stride, pad in specs:
                modules.append(nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=stride,
                                         padding=pad, bias=False))
                modules.append(nn.LeakyReLU(0.2))
            if drop_p is not None:
                modules.append(nn.Dropout(drop_p))
            return nn.Sequential(*modules)

        self.num_classes = num_classes
        self.num_classes_rotation = num_classes_rotation

        # Input perturbation: additive Gaussian noise, then channel dropout.
        self.gaosi = nn.Sequential(GaussianNoise(0.05), nn.Dropout2d(0.15))

        self.core_net1 = build_stack(
            [(3, 128, 3, 1, 1), (128, 128, 3, 1, 1), (128, 128, 3, 2, 1)], drop_p=0.5)

        self.core_net2 = build_stack(
            [(128, 256, 3, 1, 1), (256, 256, 3, 1, 1), (256, 256, 3, 2, 1)], drop_p=0.5)

        self.core_net3 = build_stack(
            [(256, 512, 3, 1, 0), (512, 256, 1, 1, 0), (256, 128, 1, 1, 0)])

        self.avgpool = nn.AvgPool2d(6)
        self.out_net1 = nn.Linear(128, self.num_classes)
        self.out_net2 = nn.Linear(128, self.num_classes_rotation)
        self.out_net3 = nn.Linear(128, self.num_classes_rotation)

    def forward(self, x):
        """Return ``(out_net3(features), out_net2(features), out_net1(features))``."""
        hidden = self.core_net1(self.gaosi(x))
        hidden = self.core_net3(self.core_net2(hidden))
        pooled = self.avgpool(hidden)
        flat = pooled.view(pooled.size(0), -1)
        third = self.out_net3(flat)
        second = self.out_net2(flat)
        first = self.out_net1(flat)
        return third, second, first