import torch
import torch.nn as nn
import torch.nn.functional as F


# resampling functions
def MeanPool(input):
    """Average ``input`` over non-overlapping 2x2 windows of its last two dims.

    The last two dimensions are treated as (H, W); e.g. (N, C, H, W) becomes
    (N, C, H // 2, W // 2). When H or W is odd, the trailing row/column is
    dropped, matching the default (floor) behaviour of
    ``F.avg_pool2d(x, kernel_size=2)``.

    Args:
        input: tensor with at least two dimensions, spatial dims last.

    Returns:
        Tensor of shape (..., H // 2, W // 2).
    """
    height, width = input.shape[-2], input.shape[-1]
    # Crop to even sizes so the four strided slices below line up; without
    # this, odd sizes give slices of unequal shape and the sum fails.
    input = input[..., : height - height % 2, : width - width % 2]
    output = (
        input[..., ::2, ::2]
        + input[..., 1::2, ::2]
        + input[..., ::2, 1::2]
        + input[..., 1::2, 1::2]
    ) / 4
    return output


def Upsample(input):
    """Nearest-neighbour 2x upsampling over the last two (spatial) dims.

    Each input pixel becomes a 2x2 patch of identical values in the same
    channel, e.g. (N, C, H, W) -> (N, C, 2H, 2W).
    """
    # NOTE: do not use torch.cat((x,) * 4, dim=1) followed by
    # F.pixel_shuffle(., 2). pixel_shuffle fills output channel c's 2x2 patch
    # from input channels 4c..4c+3, so that construction interleaves
    # *different* channels instead of replicating one. Weights trained with
    # that older version will behave differently with this one.
    return input.repeat_interleave(2, dim=-2).repeat_interleave(2, dim=-1)


# Residual Blocks
class ResBlock(nn.Module):
    """Pre-activation residual block with optional 2x up/down-sampling.

    Input and output both have ``num_channel`` channels. ``resample`` selects
    the path:

    * ``"None"`` (the string): BN -> ReLU -> 3x3 conv, twice, plus an
      identity shortcut. Spatial size is unchanged.
    * ``"up"``: like ``"None"``, but the first conv runs on the output of
      ``Upsample``; the shortcut is ``Upsample`` -> 1x1 conv.
    * ``"down"``: ReLU -> 3x3 conv, twice, with no BatchNorm, followed by
      ``MeanPool``; the shortcut is 1x1 conv -> ``MeanPool``.

    ``norm1``/``norm2`` are constructed in every mode, but the ``"down"`` path
    never calls them.

    Args:
        num_channel: channel count of the input and output feature maps.
        resample: ``"None"``, ``"up"`` or ``"down"``.

    Raises:
        ValueError: if ``resample`` is not one of the accepted values.
    """

    _RESAMPLE_MODES = ("None", "up", "down")

    def __init__(self, num_channel, resample="None"):
        super(ResBlock, self).__init__()

        # Fail fast at construction rather than on the first forward pass.
        if resample not in self._RESAMPLE_MODES:
            raise ValueError(
                f"invalid resample value: {resample!r} "
                f"(expected one of {self._RESAMPLE_MODES})"
            )

        self.resample = resample
        self.conv1 = nn.Conv2d(
            num_channel, num_channel, kernel_size=3, stride=1, padding=1
        )
        self.conv2 = nn.Conv2d(
            num_channel, num_channel, kernel_size=3, stride=1, padding=1
        )
        self.norm1 = nn.BatchNorm2d(num_channel)
        self.norm2 = nn.BatchNorm2d(num_channel)

        # Resampling blocks need a learned 1x1 shortcut; "None" uses identity.
        if self.resample != "None":
            self.conv3 = nn.Conv2d(num_channel, num_channel, kernel_size=1)

    def forward(self, input):
        if self.resample == "down":
            output = self.conv1(F.relu(input))
            output = MeanPool(self.conv2(F.relu(output)))
            shortcut = MeanPool(self.conv3(input))

        elif self.resample == "up":
            output = self.conv1(Upsample(F.relu(self.norm1(input))))
            output = self.conv2(F.relu(self.norm2(output)))
            shortcut = self.conv3(Upsample(input))

        elif self.resample == "None":
            output = self.conv1(F.relu(self.norm1(input)))
            output = self.conv2(F.relu(self.norm2(output)))
            shortcut = input

        else:
            # Only reachable if self.resample was reassigned after __init__.
            raise ValueError(f"invalid resample value: {self.resample!r}")

        return output + shortcut


class OptimizedResBlock(nn.Module):
    """Input block for the discriminator: projects raw channels and halves H, W.

    Main path: 3x3 conv -> ReLU -> 3x3 conv -> ``MeanPool``. Shortcut:
    ``MeanPool`` -> 1x1 conv. Unlike ``ResBlock``, no activation is applied to
    the raw input before the first conv.

    Args:
        num_channel: number of output channels.
        in_channels: number of input channels (default 3, the original
            hard-coded value, so existing callers are unaffected).
    """

    def __init__(self, num_channel, in_channels=3):
        super(OptimizedResBlock, self).__init__()

        self.conv1 = nn.Conv2d(
            in_channels, num_channel, kernel_size=3, stride=1, padding=1
        )
        self.conv2 = nn.Conv2d(
            num_channel, num_channel, kernel_size=3, stride=1, padding=1
        )
        self.conv3 = nn.Conv2d(in_channels, num_channel, kernel_size=1)

    def forward(self, input):
        output = F.relu(self.conv1(input))
        output = MeanPool(self.conv2(output))
        # Pool before the 1x1 conv on the shortcut (cheaper; same output shape).
        shortcut = self.conv3(MeanPool(input))
        output = output + shortcut
        return output


# Generator and Discriminator
class Generator(nn.Module):
    """ResNet generator mapping a latent vector to a 3-channel image in [-1, 1].

    A linear layer projects the latent code to a ``num_channel`` x 4 x 4 map.
    Three ``"up"`` residual blocks upsample it, and a BN -> ReLU -> 3x3 conv ->
    tanh head produces the 3-channel output.

    Args:
        num_channel: channel width of the residual trunk.
        latent_dim: size of the input latent vector.
    """

    def __init__(self, num_channel=128, latent_dim=128):
        super(Generator, self).__init__()

        self.num_channel = num_channel
        # Registration order is significant: it fixes the RNG draw order at
        # construction and the traversal order of self.modules() below.
        for block_name in ("block1", "block2", "block3"):
            self.add_module(block_name, ResBlock(num_channel=num_channel, resample="up"))
        self.linear = nn.Linear(latent_dim, 4 * 4 * num_channel)
        self.conv = nn.Conv2d(num_channel, 3, kernel_size=3, stride=1, padding=1)
        self.norm = nn.BatchNorm2d(num_channel)

        # He (fan_in, ReLU gain) init for every conv, including those inside
        # the residual blocks; conv biases start at zero.
        conv_layers = (m for m in self.modules() if isinstance(m, nn.Conv2d))
        for layer in conv_layers:
            nn.init.kaiming_normal_(layer.weight, mode="fan_in", nonlinearity="relu")
            nn.init.zeros_(layer.bias)

    def forward(self, input):
        features = self.linear(input).view([-1, self.num_channel, 4, 4])
        for block in (self.block1, self.block2, self.block3):
            features = block(features)
        features = F.relu(self.norm(features))
        return torch.tanh(self.conv(features))


class Discriminator(nn.Module):
    """ResNet critic mapping an image batch to one scalar score per sample.

    Pipeline: ``OptimizedResBlock`` (halves H, W) -> ``"down"`` ``ResBlock``
    -> two ``"None"`` ``ResBlock``s -> ReLU -> mean over H, W -> linear layer
    with a single output. No sigmoid is applied, so the scores are unbounded.

    Args:
        act, eps, norm, clip: stored as attributes of the same names; nothing
            in this class reads them. NOTE(review): presumably consumed by
            external training code -- confirm before removing.
        bn: accepted for interface compatibility; neither stored nor used.
        num_channel: channel width of every residual block.
        nc: accepted but unused here; ``block1`` is built without it.

    Returns (from ``forward``):
        Tensor of shape (N, 1).
    """

    def __init__(
        self, act=False, eps=1, norm=False, bn=False, clip=False, num_channel=128, nc=3
    ):
        super(Discriminator, self).__init__()

        self.num_channel = num_channel
        self.block1 = OptimizedResBlock(num_channel=self.num_channel)
        self.block2 = ResBlock(num_channel=self.num_channel, resample="down")
        self.block3 = ResBlock(num_channel=self.num_channel, resample="None")
        self.block4 = ResBlock(num_channel=self.num_channel, resample="None")
        self.linear = nn.Linear(self.num_channel, 1)

        # Options that forward() does not read; kept as public attributes.
        self.act = act
        self.eps = eps
        self.norm = norm
        self.clip = clip

        # He (fan_in, ReLU gain) init for every conv, including those inside
        # the residual blocks; conv biases start at zero.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(
                    module.weight, mode="fan_in", nonlinearity="relu"
                )
                nn.init.zeros_(module.bias)

    def forward(self, input):
        output = self.block1(input)
        output = self.block2(output)
        output = self.block3(output)
        output = self.block4(output)
        output = F.relu(output)
        # Global average pool over the spatial dims -> (N, num_channel).
        output = torch.mean(output, dim=(2, 3))
        output = self.linear(output)
        return output
