import torch
import torch.nn as nn
import torch.nn.functional as F


class CNNMnist(nn.Module):
    """Two-conv LeNet-style classifier for 1x28x28 images.

    ``forward`` returns per-class log-probabilities of shape ``(N, 10)``.
    The hard-coded flatten size ``320`` (20 channels * 4 * 4) only matches
    28x28 inputs: 28 -conv5-> 24 -pool-> 12 -conv5-> 8 -pool-> 4.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=(5, 5))
        self.conv2 = nn.Conv2d(10, 20, kernel_size=(5, 5))
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)
        # Channel-wise dropout applied to conv2's feature maps.
        self.drop = nn.Dropout2d()

    def forward(self, x):
        h = F.relu(F.max_pool2d(self.conv1(x), 2))
        feature_maps = self.drop(self.conv2(h))
        h = F.relu(F.max_pool2d(feature_maps, 2))
        flat = h.view(-1, 320)
        hidden = F.dropout(F.relu(self.fc1(flat)), training=self.training)
        return F.log_softmax(self.fc2(hidden), dim=1)


class CNNMnist2(nn.Module):
    """Compact strided-conv classifier for 1x28x28 images (no pooling layers).

    Two 3x3 stride-2 convolutions shrink 28x28 -> 14x14 -> 7x7, giving a
    16 * 7 * 7 = 784-dim feature vector. Returns log-probabilities ``(N, 10)``.
    """

    def __init__(self):
        super().__init__()
        conv_kwargs = dict(kernel_size=(3, 3), stride=2, padding=1)
        self.conv1 = nn.Conv2d(1, 6, **conv_kwargs)
        self.conv2 = nn.Conv2d(6, 16, **conv_kwargs)
        self.fc1 = nn.Linear(784, 32)
        self.fc2 = nn.Linear(32, 10)

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = F.relu(conv(x))
        hidden = F.relu(self.fc1(x.view(-1, 784)))
        return F.log_softmax(self.fc2(hidden), dim=1)


class CNNFashion_Mnist(nn.Module):
    """Two conv-BN-ReLU-pool stages plus a linear head for 1x28x28 images.

    Returns raw logits of shape ``(N, 10)`` (no softmax). The ``7 * 7 * 32``
    head size assumes 28x28 inputs (28 -> 14 -> 7 after the two poolings).
    """

    def __init__(self):
        super().__init__()
        self.layer1 = self._stage(1, 16)
        self.layer2 = self._stage(16, 32)
        self.fc = nn.Linear(7 * 7 * 32, 10)

    @staticmethod
    def _stage(in_channels, out_channels):
        """Build one 5x5 same-padded conv -> BN -> ReLU -> 2x2 max-pool stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=5, padding=2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

    def forward(self, x):
        features = self.layer2(self.layer1(x))
        flat = features.view(features.size(0), -1)
        return self.fc(flat)


class CNNCifar(nn.Module):
    """LeNet-5-style classifier for 3x32x32 images returning logits ``(N, 10)``.

    Spatial sizes: 32 -conv5-> 28 -pool-> 14 -conv5-> 10 -pool-> 5, hence the
    16 * 5 * 5 flatten size.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        x = x.view(-1, 16 * 5 * 5)
        for hidden_fc in (self.fc1, self.fc2):
            x = F.relu(hidden_fc(x))
        return self.fc3(x)


"""ResNet in PyTorch.
For Pre-activation ResNet, see 'preact_resnet.py'.
Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
"""


class BasicBlock(nn.Module):
    """Two-layer 3x3 residual block (ResNet-18/34 style).

    Args:
        in_planes: channels of the incoming tensor.
        planes: channels produced by both 3x3 convolutions.
        stride: stride of the first convolution and of the projection
            shortcut; 2 halves the spatial size.

    When the stride or channel count changes, the identity shortcut is
    replaced by a 1x1 conv + BN projection so the residual sum is shape-valid.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        if stride == 1 and in_planes == out_planes:
            # An empty Sequential acts as the identity.
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes, out_planes, kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        main = F.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))
        return F.relu(main + self.shortcut(x))


class BasicBlockWithDropout(BasicBlock):
    """``BasicBlock`` with dropout after each of its two ReLUs.

    Layers are inherited unchanged from ``BasicBlock``. Dropout follows the
    module's train/eval mode (``self.training``), so ``model.eval()`` disables
    it. ``F.dropout`` defaults to ``training=True`` and would otherwise keep
    dropping activations at inference time.
    """

    # Drop probability used after every activation.
    dropout_p = 0.3

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.dropout(out, p=self.dropout_p, training=self.training)
        out = self.bn2(self.conv2(out))
        out = out + self.shortcut(x)
        out = F.relu(out)
        return F.dropout(out, p=self.dropout_p, training=self.training)


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block (ResNet-50/101/152 style).

    The block reduces to ``planes`` channels, applies the (optionally strided)
    3x3 convolution, then expands to ``expansion * planes`` channels. A 1x1
    conv + BN projection shortcut is used whenever the stride or channel count
    changes; otherwise the shortcut is the identity.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        if stride == 1 and in_planes == out_planes:
            # An empty Sequential acts as the identity.
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes, out_planes, kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        main = F.relu(self.bn1(self.conv1(x)))
        main = F.relu(self.bn2(self.conv2(main)))
        main = self.bn3(self.conv3(main))
        return F.relu(main + self.shortcut(x))


class BottleneckWithDropout(Bottleneck):
    """``Bottleneck`` with dropout after each of its three ReLUs.

    Layers (conv1/bn1, conv2/bn2, conv3/bn3, shortcut) are built by
    ``Bottleneck.__init__`` with the same names and registration order, and
    ``expansion`` is inherited (4). Dropout follows the module's train/eval
    mode (``self.training``), so ``model.eval()`` disables it.
    """

    # Drop probability used after every activation.
    dropout_p = 0.3

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.dropout(out, p=self.dropout_p, training=self.training)
        out = F.relu(self.bn2(self.conv2(out)))
        out = F.dropout(out, p=self.dropout_p, training=self.training)
        out = self.bn3(self.conv3(out))
        out = out + self.shortcut(x)
        out = F.relu(out)
        return F.dropout(out, p=self.dropout_p, training=self.training)


class ResNet(nn.Module):
    """CIFAR-style ResNet (3x3 stride-1 stem, no stem max-pool) with an optional output transform.

    Args:
        block: residual block class; must expose ``expansion`` and accept
            ``(in_planes, planes, stride)``.
        num_blocks: number of blocks in each of the four stages.
        num_classes: width of the final linear layer.
        act: ``"sigmoid"`` or ``"tanh"`` to squash the linear output; any other
            value (default ``False``) leaves it unchanged.
        eps: scale applied to the activated output, and the clamp bound when
            ``clip`` is set.
        norm: with ``act="sigmoid"``, subtract 0.5 before scaling, so the
            output lies in ``(-eps, eps)`` instead of ``(0, 2 * eps)``.
        bn: apply ``BatchNorm1d(num_classes)`` to the activated output. Only
            used when ``act`` selects an activation.
        clip: clamp the final output to ``[-eps, eps]``.

    The fixed ``avg_pool2d(out, 4)`` assumes 32x32 inputs (4x4 maps after
    layer4).
    """

    def __init__(
        self,
        block,
        num_blocks,
        num_classes=10,
        act=False,
        eps=1,
        norm=False,
        bn=False,
        clip=False,
    ):
        super(ResNet, self).__init__()
        # Running channel count, advanced by _make_layer as stages are built.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

        if act == "sigmoid":
            self.activation = nn.Sigmoid()
        elif act == "tanh":
            self.activation = nn.Tanh()
        else:
            self.activation = None
        # The head output is 2-D (N, num_classes). BatchNorm2d rejects non-4-D
        # input, so a 1-D batch norm over the class dimension is required.
        if bn:
            self.bn = nn.BatchNorm1d(num_classes)
        else:
            self.bn = None
        self.act = act
        self.eps = eps
        self.norm = norm
        self.clip = clip

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``.

        Side effect: sets ``self.in_planes`` to the stage's output width.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        if self.activation is not None:
            out = self.activation(out)
            if self.act == "sigmoid":
                if self.norm:
                    out = out - 0.5
                out = out * 2
            out = out * self.eps
            if self.bn is not None:
                out = self.bn(out)
        if self.clip:
            out = out.clamp(-1 * self.eps, self.eps)
        return out


class ImagenetResNet(nn.Module):
    """ResNet with a global adaptive average pool before the classifier.

    Uses a 3x3 stride-1 stem conv (no max-pool) followed by four stages of
    64/128/256/512 (x ``block.expansion``) channels; stages 2-4 downsample by
    2. ``AdaptiveAvgPool2d((1, 1))`` makes the head independent of the input's
    spatial size. Returns logits ``(N, num_classes)``.
    """

    def __init__(self, block, num_blocks, num_classes=100):
        super().__init__()
        # Running channel count, advanced by _make_layer as stages are built.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for index, (width, first_stride) in enumerate(stage_specs):
            stage = self._make_layer(block, width, num_blocks[index], stride=first_stride)
            setattr(self, f"layer{index + 1}", stage)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``.

        Side effect: sets ``self.in_planes`` to the stage's output width.
        """
        blocks = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            blocks.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*blocks)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        pooled = self.avgpool(out)
        return self.linear(pooled.view(pooled.size(0), -1))


# Convolutional (non-residual) generator and critic used for WGAN-GP training.


class WGANGEN(nn.Module):
    """DCGAN-style generator: latent vector -> image with values in (0, 1).

    Args:
        z_size: latent vector length.
        image_size: output height/width. The linear layer seeds an
            ``image_size // 8`` square map that three stride-2 transposed
            convolutions upsample 8x.
        image_channel_size: channels of the generated image.
        channel_size: base width; the seed map has ``8 * channel_size``
            channels, halved after each transposed convolution.
    """

    def __init__(self, z_size, image_size, image_channel_size, channel_size):
        super().__init__()
        self.z_size = z_size
        self.image_size = image_size
        self.image_channel_size = image_channel_size
        self.channel_size = channel_size

        seed_side = image_size // 8
        up_kwargs = dict(kernel_size=4, stride=2, padding=1)
        self.fc = nn.Linear(z_size, seed_side ** 2 * channel_size * 8)
        self.bn0 = nn.BatchNorm2d(channel_size * 8)
        self.bn1 = nn.BatchNorm2d(channel_size * 4)
        self.deconv1 = nn.ConvTranspose2d(channel_size * 8, channel_size * 4, **up_kwargs)
        self.bn2 = nn.BatchNorm2d(channel_size * 2)
        self.deconv2 = nn.ConvTranspose2d(channel_size * 4, channel_size * 2, **up_kwargs)
        self.bn3 = nn.BatchNorm2d(channel_size)
        self.deconv3 = nn.ConvTranspose2d(channel_size * 2, channel_size, **up_kwargs)
        # Final 3x3 stride-1 layer maps to image channels without resizing.
        self.deconv4 = nn.ConvTranspose2d(
            channel_size, image_channel_size, kernel_size=3, stride=1, padding=1
        )

    def forward(self, z):
        side = self.image_size // 8
        seed = self.fc(z).view(z.size(0), self.channel_size * 8, side, side)
        g = F.relu(self.bn0(seed))
        for deconv, bn in (
            (self.deconv1, self.bn1),
            (self.deconv2, self.bn2),
            (self.deconv3, self.bn3),
        ):
            g = F.relu(bn(deconv(g)))
        return torch.sigmoid(self.deconv4(g))


class WGANCritic(nn.Module):
    """DCGAN-style convolutional critic for WGAN training.

    Args:
        image_size: input height/width. The linear layer is sized for a final
            ``(image_size // 8) x (image_size // 8)`` feature map.
        image_channel_size: channels of the input image.
        channel_size: width of the first conv; doubled by each of the next
            three convs (up to ``8 * channel_size``).

    ``forward`` returns one unbounded score per sample with shape ``(N,)``,
    including when ``N == 1``.
    """

    def __init__(self, image_size, image_channel_size, channel_size):
        super().__init__()
        self.image_size = image_size
        self.image_channel_size = image_channel_size
        self.channel_size = channel_size

        self.conv1 = nn.Conv2d(
            image_channel_size, channel_size, kernel_size=4, stride=2, padding=1
        )
        self.conv2 = nn.Conv2d(
            channel_size, channel_size * 2, kernel_size=4, stride=2, padding=1
        )
        self.conv3 = nn.Conv2d(
            channel_size * 2, channel_size * 4, kernel_size=4, stride=2, padding=1
        )
        self.conv4 = nn.Conv2d(
            channel_size * 4,
            channel_size * 8,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.fc = nn.Linear((image_size // 8) ** 2 * channel_size * 8, 1)

    def forward(self, x):
        # Shapes with c = channel_size and S = image_size (S a multiple of 8).
        x = F.leaky_relu(self.conv1(x))  # (N, C_img, S, S) -> (N, c, S/2, S/2)
        x = F.leaky_relu(self.conv2(x))  # -> (N, 2c, S/4, S/4)
        x = F.leaky_relu(self.conv3(x))  # -> (N, 4c, S/8, S/8)
        x = F.leaky_relu(self.conv4(x))  # -> (N, 8c, S/8, S/8), size preserved
        # Flatten per sample. A size mismatch then fails in fc instead of
        # silently changing the batch dimension as view(-1, K) could.
        x = x.flatten(1)
        # squeeze(1), not squeeze(): keep the batch axis when N == 1.
        return self.fc(x).squeeze(1)


def WGANGenerator():
    """Build the default WGANGEN: 100-d latent -> 3x32x32 image, base width 64."""
    return WGANGEN(z_size=100, image_size=32, image_channel_size=3, channel_size=64)


def WGANDiscriminator():
    """Build the default WGANCritic for 3x32x32 images, base width 64."""
    return WGANCritic(image_size=32, image_channel_size=3, channel_size=64)


def DiscResNet18():
    """Build a CIFAR-style ResNet-18 with a single output unit (discriminator head)."""
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=1)


# DiscServerResNet8 / GenServerResNet8 are defined once, after the
# WGANResDiscriminator / WGANResGenerator classes they construct.


def ResNet18():
    """CIFAR-style ResNet-18 with the default 10-way head."""
    return ResNet(BasicBlock, num_blocks=[2, 2, 2, 2])


def ResNet18Rot():
    """CIFAR-style ResNet-18 with a 4-way head."""
    return ResNet(BasicBlock, num_blocks=[2, 2, 2, 2], num_classes=4)


def ImagenetResNet18():
    """ResNet-18 with an adaptive average pool and a 100-way head."""
    return ImagenetResNet(BasicBlock, num_blocks=[2, 2, 2, 2], num_classes=100)


def ImagenetResNet18Rot():
    """ResNet-18 with an adaptive average pool and a 4-way head."""
    return ImagenetResNet(BasicBlock, num_blocks=[2, 2, 2, 2], num_classes=4)


def cifar100ResNet18():
    """CIFAR-style ResNet-18 with a 100-way head."""
    return ResNet(BasicBlock, num_blocks=[2, 2, 2, 2], num_classes=100)


def tinyimagenetResNet18():
    """CIFAR-style ResNet-18 with a 200-way head."""
    return ResNet(BasicBlock, num_blocks=[2, 2, 2, 2], num_classes=200)


def ResNet18WithDropout():
    """CIFAR-style ResNet-18 built from dropout-augmented basic blocks."""
    return ResNet(BasicBlockWithDropout, num_blocks=[2, 2, 2, 2])


def ResNet34():
    """CIFAR-style ResNet-34 (basic blocks, 3/4/6/3 per stage)."""
    return ResNet(BasicBlock, num_blocks=[3, 4, 6, 3])


def ResNet50():
    """CIFAR-style ResNet-50 (bottleneck blocks, 3/4/6/3 per stage)."""
    return ResNet(Bottleneck, num_blocks=[3, 4, 6, 3])


def ResNet50Rot():
    """CIFAR-style ResNet-50 with a 4-way head.

    Matches ``ResNet18Rot`` / ``ImagenetResNet18Rot``, which also use
    ``num_classes=4``. Without it the head would silently be the 10-way default.
    """
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=4)


def ResNet101():
    """CIFAR-style ResNet-101 (bottleneck blocks, 3/4/23/3 per stage)."""
    return ResNet(Bottleneck, num_blocks=[3, 4, 23, 3])


def ResNet152():
    """CIFAR-style ResNet-152 (bottleneck blocks, 3/8/36/3 per stage)."""
    return ResNet(Bottleneck, num_blocks=[3, 8, 36, 3])


def test():
    """Smoke test: run one random 3x32x32 image through ResNet18 and print the output size."""
    model = ResNet18()
    sample = torch.randn(1, 3, 32, 32)
    print(model(sample).size())


class Iris(nn.Module):
    """Single linear layer mapping 4 input features to 3 output scores."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 3)

    def forward(self, x):
        return self.fc(x)


def MeanPool(input):
    """Average each non-overlapping 2x2 spatial window of an (N, C, H, W) tensor.

    Halves H and W. The four strided slices must have matching shapes, which
    holds for even H and W.
    """
    top_left = input[:, :, ::2, ::2]
    bottom_left = input[:, :, 1::2, ::2]
    top_right = input[:, :, ::2, 1::2]
    bottom_right = input[:, :, 1::2, 1::2]
    window_sum = top_left + bottom_left + top_right + bottom_right
    return window_sum / 4


def Upsample(input):
    """Double H and W by tiling the input four times along channels and pixel-shuffling.

    For a single channel this is a 2x nearest-neighbour upsample.
    NOTE(review): ``pixel_shuffle`` reads input channel ``c * 4 + i * 2 + j``
    for output channel ``c``, so with C > 1 each output channel interleaves
    pixels from different input channels rather than replicating its own.
    """
    tiled = torch.cat([input] * 4, dim=1)
    return F.pixel_shuffle(tiled, upscale_factor=2)


class ResBlock(nn.Module):
    """Pre-activation residual block with optional 2x up/down-sampling.

    Args:
        num_channel: channels in and out; every conv is 3x3, stride 1, padding 1.
        resample: ``"down"`` (MeanPool after conv2, pooled conv3 shortcut),
            ``"up"`` (Upsample before conv1, conv3 on the upsampled shortcut),
            or the string ``"None"`` (identity shortcut). Other values are
            rejected in ``forward``.
        isnorm: apply BatchNorm before each ReLU when True.
    """

    def __init__(self, num_channel, resample="None", isnorm=True):
        super().__init__()
        self.resample = resample
        self.isnorm = isnorm
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(num_channel, num_channel, **conv_kwargs)
        self.conv2 = nn.Conv2d(num_channel, num_channel, **conv_kwargs)
        self.norm1 = nn.BatchNorm2d(num_channel)
        self.norm2 = nn.BatchNorm2d(num_channel)
        # Only resampling blocks need a learned shortcut conv.
        if resample != "None":
            self.conv3 = nn.Conv2d(num_channel, num_channel, **conv_kwargs)

    def normalize(self, input, layer_num):
        """Apply norm1 (layer_num == 1) or norm2 (otherwise); identity if ``isnorm`` is False."""
        if not self.isnorm:
            return input
        norm = self.norm1 if layer_num == 1 else self.norm2
        return norm(input)

    def forward(self, input):
        mode = self.resample
        if mode == "down":
            h = self.conv1(F.relu(self.normalize(input, layer_num=1)))
            h = MeanPool(self.conv2(F.relu(self.normalize(h, layer_num=2))))
            return h + MeanPool(self.conv3(input))
        if mode == "up":
            h = self.conv1(Upsample(F.relu(self.normalize(input, layer_num=1))))
            h = self.conv2(F.relu(self.normalize(h, layer_num=2)))
            return h + self.conv3(Upsample(input))
        if mode == "None":
            h = self.conv1(F.relu(self.normalize(input, layer_num=1)))
            h = self.conv2(F.relu(self.normalize(h, layer_num=2)))
            return h + input
        raise Exception("invalid resample value")


class OptimizedResBlock(nn.Module):
    """First critic block: 3-channel image -> ``num_channel`` maps at half resolution.

    The main path is conv -> ReLU -> conv -> MeanPool. The shortcut pools the
    raw image first and then applies a 3x3 conv. There is no normalization.
    """

    def __init__(self, num_channel):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(3, num_channel, **conv_kwargs)
        self.conv2 = nn.Conv2d(num_channel, num_channel, **conv_kwargs)
        self.conv3 = nn.Conv2d(3, num_channel, **conv_kwargs)

    def forward(self, input):
        main = MeanPool(self.conv2(F.relu(self.conv1(input))))
        skip = self.conv3(MeanPool(input))
        return main + skip


class WGANResGenerator(nn.Module):
    """Residual WGAN-GP generator: latent ``(N, latent_dim)`` -> 3-channel image in [-1, 1].

    The pipeline is:
    - a linear layer seeds a ``num_channel`` x 4 x 4 map;
    - three upsampling ResBlocks double the spatial size each time (4 -> 32);
    - BN + ReLU;
    - a 3x3 conv to 3 channels, followed by tanh.
    """

    def __init__(self, num_channel=128, latent_dim=128):
        super().__init__()
        self.num_channel = num_channel
        self.block1 = ResBlock(num_channel=num_channel, resample="up")
        self.block2 = ResBlock(num_channel=num_channel, resample="up")
        self.block3 = ResBlock(num_channel=num_channel, resample="up")
        self.linear = nn.Linear(latent_dim, 4 * 4 * num_channel)
        self.convlayer = nn.Conv2d(num_channel, 3, kernel_size=3, stride=1, padding=1)
        self.normalize = nn.BatchNorm2d(num_channel)

    def forward(self, input):
        h = self.linear(input).view([-1, self.num_channel, 4, 4])
        h = self.block2(self.block1(h))
        h = F.relu(self.normalize(self.block3(h)))
        return torch.tanh(self.convlayer(h))


class WGANResDiscriminator(nn.Module):
    """Residual WGAN-GP critic: 3-channel image -> one unbounded score per sample, shape ``(N,)``.

    For a 32x32 input the spatial size goes 32 -> 16 (OptimizedResBlock) ->
    8 (downsampling block), and stays 8 through two plain blocks. The
    ReLU'd features are then globally averaged and passed to a linear layer.
    """

    def __init__(self, num_channel=128):
        super().__init__()
        self.block1 = OptimizedResBlock(num_channel=num_channel)
        self.block2 = ResBlock(num_channel=num_channel, resample="down", isnorm=False)
        self.block3 = ResBlock(num_channel=num_channel, resample="None", isnorm=False)
        self.block4 = ResBlock(num_channel=num_channel, resample="None", isnorm=False)
        self.linear = nn.Linear(num_channel, 1)

    def forward(self, input):
        h = input
        for block in (self.block1, self.block2, self.block3):
            h = block(h)
        h = F.relu(self.block4(h))
        pooled = h.mean(dim=(2, 3))  # global average over H and W
        return self.linear(pooled).view(-1)


def DiscServerResNet8():
    """Build the residual WGAN-GP critic at its default width (128 channels)."""
    return WGANResDiscriminator(num_channel=128)


def GenServerResNet8():
    """Build the residual WGAN-GP generator at its default width and latent size (128/128)."""
    return WGANResGenerator(num_channel=128, latent_dim=128)


class ResNet_disc(nn.Module):
    """Three-stage ResNet discriminator with an optional output transform.

    The network has a 3x3 stride-1 stem and three stages of 64/128/256
    (x ``block.expansion``) channels, each downsampling by 2. It ends with
    ``avg_pool2d(out, 4)`` and a linear head. That pooling assumes 32x32
    inputs (4x4 maps after layer3).

    Args:
        block: residual block class; must expose ``expansion`` and accept
            ``(in_planes, planes, stride)``.
        num_blocks: number of blocks in each of the three stages.
        act: ``"sigmoid"`` or ``"tanh"`` to squash the linear output; any other
            value (default ``False``) leaves it unchanged.
        eps: scale applied to the activated output, and the clamp bound when
            ``clip`` is set.
        norm: with ``act="sigmoid"``, subtract 0.5 before scaling, so the
            output lies in ``(-eps, eps)`` instead of ``(0, 2 * eps)``.
        bn: apply ``BatchNorm1d(num_classes)`` to the activated output. Only
            used when ``act`` selects an activation.
        clip: clamp the final output to ``[-eps, eps]``.
        num_classes: width of the final linear layer.
    """

    def __init__(
        self,
        block,
        num_blocks,
        act=False,
        eps=1,
        norm=False,
        bn=False,
        clip=False,
        num_classes=10,
    ):
        super(ResNet_disc, self).__init__()
        # Running channel count, advanced by _make_layer as stages are built.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=2)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.linear = nn.Linear(256 * block.expansion, num_classes)

        if act == "sigmoid":
            self.activation = nn.Sigmoid()
        elif act == "tanh":
            self.activation = nn.Tanh()
        else:
            self.activation = None
        # The head output is 2-D (N, num_classes). BatchNorm2d rejects non-4-D
        # input, so a 1-D batch norm over the class dimension is required.
        if bn:
            self.bn = nn.BatchNorm1d(num_classes)
        else:
            self.bn = None
        self.act = act
        self.eps = eps
        self.norm = norm
        self.clip = clip

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``.

        Side effect: sets ``self.in_planes`` to the stage's output width.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        if self.activation is not None:
            out = self.activation(out)
            if self.act == "sigmoid":
                if self.norm:
                    out = out - 0.5
                out = out * 2
            out = out * self.eps
            if self.bn is not None:
                out = self.bn(out)
        if self.clip:
            out = out.clamp(-1 * self.eps, self.eps)
        return out


def ResNet8_disc(*args, **kwargs):
    """Build a three-stage, one-block-per-stage ResNet_disc with BasicBlocks.

    Positional and keyword arguments are forwarded to ``ResNet_disc`` after
    ``block`` and ``num_blocks``. ``num_classes`` defaults to 1 (a single
    score) but may now be overridden by keyword without a "multiple values"
    TypeError.
    """
    kwargs.setdefault("num_classes", 1)
    return ResNet_disc(BasicBlock, [1, 1, 1], *args, **kwargs)


def ResNet18_disc(*args, **kwargs):
    """Build a CIFAR-style ResNet-18 discriminator (single output score by default).

    Positional arguments are interpreted, in order, as ``act, eps, norm, bn,
    clip``, which is the order ``ResNet_disc`` takes them after ``num_blocks``.
    Forwarding them positionally to ``ResNet`` would instead land the first
    one on ``num_classes`` and collide with ``num_classes=1``. Keyword
    arguments are passed to ``ResNet``. ``num_classes`` defaults to 1 but may
    be overridden.

    Raises:
        TypeError: more than five positional arguments, or an argument given
            both positionally and by keyword.
    """
    head_params = ("act", "eps", "norm", "bn", "clip")
    if len(args) > len(head_params):
        raise TypeError(
            f"ResNet18_disc() takes at most {len(head_params)} positional "
            f"arguments ({len(args)} given)"
        )
    for name, value in zip(head_params, args):
        if name in kwargs:
            raise TypeError(f"ResNet18_disc() got multiple values for argument '{name}'")
        kwargs[name] = value
    kwargs.setdefault("num_classes", 1)
    return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
