import torch
import torch.nn as nn


# Custom weight initialization; Generator and Discriminator apply it to every submodule via self.apply().
def weights_init(m):
    """Initialise a single module's parameters in place.

    Meant to be passed to ``nn.Module.apply``, which calls it once for every
    submodule. Dispatch is on the class name:

    - names containing ``'Conv'`` (e.g. ``Conv2d``, ``ConvTranspose2d``):
      weight drawn from N(0, 0.02); any bias is left untouched;
    - names containing ``'BatchNorm'``: weight drawn from N(1, 0.02), bias
      set to zero;
    - anything else is left untouched.

    Parameters that are missing or ``None`` are skipped. A
    ``BatchNorm2d(affine=False)``, or a container whose name happens to
    contain ``'Conv'``, therefore does not raise.

    Args:
        m: the module to initialise.
    """
    classname = m.__class__.__name__
    # nn.Module.__getattr__ raises AttributeError for unknown names; getattr's
    # default turns that into None so modules without the parameter are skipped.
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        if bias is not None:
            bias.data.fill_(0)


class Generator(nn.Module):
    """Image generator built from five transposed-convolution stages.

    Stages 1-4 are ConvTranspose2d -> BatchNorm2d -> ReLU. Stage 5 is
    ConvTranspose2d -> Tanh, so every output value lies in [-1, 1]. For an
    input of shape (N, nz, 1, 1), the spatial size grows
    1 -> 2 -> 4 -> 8 -> 16 -> 32, giving an output of shape (N, nc, 32, 32).

    Args:
        nz: number of channels in the latent input.
        ngf: base width; the hidden stages use 8, 4, 2 and 1 times ``ngf``
            channels.
        nc: number of channels in the produced image.
    """

    def __init__(self, nz, ngf, nc):
        super(Generator, self).__init__()
        # One in-place ReLU instance is shared by all four hidden stages.
        self.ReLU = nn.ReLU(inplace=True)
        self.Tanh = nn.Tanh()

        w8, w4, w2, w1 = ngf * 8, ngf * 4, ngf * 2, ngf * 1

        # Stage 1: kernel 2, stride 1, padding 0 (a 1x1 input becomes 2x2).
        self.conv1 = nn.ConvTranspose2d(nz, w8, 2, 1, 0, bias=False)
        self.BatchNorm1 = nn.BatchNorm2d(w8)

        # Stages 2-4: kernel 4, stride 2, padding 1 doubles height and width.
        self.conv2 = nn.ConvTranspose2d(w8, w4, 4, 2, 1, bias=False)
        self.BatchNorm2 = nn.BatchNorm2d(w4)
        self.conv3 = nn.ConvTranspose2d(w4, w2, 4, 2, 1, bias=False)
        self.BatchNorm3 = nn.BatchNorm2d(w2)
        self.conv4 = nn.ConvTranspose2d(w2, w1, 4, 2, 1, bias=False)
        self.BatchNorm4 = nn.BatchNorm2d(w1)

        # Output stage: project to nc channels and double once more; no norm.
        self.conv5 = nn.ConvTranspose2d(w1, nc, 4, 2, 1, bias=False)

        self.apply(weights_init)

    def forward(self, input):
        hidden = input
        stages = (
            (self.conv1, self.BatchNorm1),
            (self.conv2, self.BatchNorm2),
            (self.conv3, self.BatchNorm3),
            (self.conv4, self.BatchNorm4),
        )
        for upsample, norm in stages:
            hidden = self.ReLU(norm(upsample(hidden)))
        return self.Tanh(self.conv5(hidden))


class Discriminator(nn.Module):
    """Discriminator with a real/fake head and an auxiliary class head.

    The input is normalised by ``ModelwNorm``, then passed through five
    convolutions with LeakyReLU(0.2) activations; stages 2-4 also use
    BatchNorm. The conv stack reduces a 32x32 input to 1x1
    (32 -> 16 -> 8 -> 4 -> 2 -> 1). The resulting ``ndf``-long feature
    vector feeds two linear heads:

    - ``disc_linear`` followed by a sigmoid gives ``s``, shape (N, 1);
    - ``aux_linear`` gives ``c``, shape (N, nb_label), raw scores with no
      softmax applied.

    Inputs must therefore be 32x32. With ``ModelwNorm``'s default statistics
    they must also have 3 channels.

    Args:
        ndf: base width; the conv stages use 1, 2, 4, 8 and 1 times ``ndf``
            channels.
        nc: number of input image channels.
        nb_label: number of classes scored by the auxiliary head.
    """

    def __init__(self, ndf, nc, nb_label):
        super(Discriminator, self).__init__()
        self.normal = ModelwNorm()
        self.LeakyReLU = nn.LeakyReLU(0.2, inplace=True)
        self.conv1 = nn.Conv2d(nc, ndf * 1, 4, 2, 1, bias=False)
        self.conv2 = nn.Conv2d(ndf * 1, ndf * 2, 4, 2, 1, bias=False)
        self.BatchNorm2 = nn.BatchNorm2d(ndf * 2)
        self.conv3 = nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False)
        self.BatchNorm3 = nn.BatchNorm2d(ndf * 4)
        self.conv4 = nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False)
        self.BatchNorm4 = nn.BatchNorm2d(ndf * 8)
        # Kernel 2, stride 1, no padding: collapses the remaining 2x2 map to 1x1.
        self.conv5 = nn.Conv2d(ndf * 8, ndf * 1, 2, 1, 0, bias=False)
        self.disc_linear = nn.Linear(ndf * 1, 1)
        self.aux_linear = nn.Linear(ndf * 1, nb_label)
        # Not used in forward(); kept so the attribute stays available.
        # An explicit dim avoids the implicit-dimension deprecation warning.
        self.softmax = nn.Softmax(dim=1)
        self.sigmoid = nn.Sigmoid()
        self.ndf = ndf
        self.apply(weights_init)

    def forward(self, input):
        x = self.normal(input)
        x = self.LeakyReLU(self.conv1(x))  # no BatchNorm on the first stage
        x = self.LeakyReLU(self.BatchNorm2(self.conv2(x)))
        x = self.LeakyReLU(self.BatchNorm3(self.conv3(x)))
        x = self.LeakyReLU(self.BatchNorm4(self.conv4(x)))
        x = self.conv5(x)
        # Flatten per sample so the batch dimension is preserved. For
        # non-32x32 inputs the feature length no longer equals ndf, and the
        # linear heads raise instead of spreading spatial positions across
        # extra "batch" rows.
        x = torch.flatten(x, 1)
        c = self.aux_linear(x)  # raw class scores, no softmax applied
        s = self.sigmoid(self.disc_linear(x))
        return s, c


class ModelwNorm(nn.Module):
    """Per-channel input normalisation: returns ``(x - mean) / std``.

    The statistics are registered as non-persistent buffers. They follow the
    module through ``.to()``, ``.cuda()`` and ``.cpu()``, but they are not
    written to ``state_dict``, so existing checkpoint keys are unchanged. In
    ``forward`` they are also moved to the input's device, so the layer works
    on CPU or on any GPU.

    The buffers have shape (1, C, 1, 1) and broadcast against (N, C, H, W)
    inputs, where C == len(mean).

    Args:
        mean: per-channel means. The default values (0.507, 0.487, 0.441)
            look like CIFAR-100 RGB statistics (NOTE(review): confirm).
        std: per-channel standard deviations; defaults to
            (0.267, 0.256, 0.276).
    """

    def __init__(self, mean=(0.507, 0.487, 0.441), std=(0.267, 0.256, 0.276)):
        super(ModelwNorm, self).__init__()
        self.register_buffer(
            'mean', torch.tensor(mean).view(1, -1, 1, 1), persistent=False)
        self.register_buffer(
            'std', torch.tensor(std).view(1, -1, 1, 1), persistent=False)

    def forward(self, x):
        # Only the device is matched. The dtype is left to type promotion so
        # integer inputs still normalise in floating point. .to() is a no-op
        # when the buffers are already on x's device.
        m = self.mean.to(device=x.device)
        s = self.std.to(device=x.device)
        return (x - m) / s


if __name__ == '__main__':
    # Smoke test: generate one image from a random latent vector, show it,
    # then score it with the discriminator.
    import matplotlib.pyplot as plt
    import torchvision.transforms as transforms

    # Keep models and inputs on the same device; the normalisation layer
    # needs its statistics on the input's device.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    x = torch.randn(1, 100, 1, 1, device=device)
    g = Generator(nz=100, ngf=64, nc=3).to(device)
    d = Discriminator(ndf=64, nc=3, nb_label=100).to(device)

    with torch.no_grad():
        o = g(x)
        print(o.size())  # expected: torch.Size([1, 3, 32, 32])

        # Tanh output lies in [-1, 1]; ToPILImage expects floats in [0, 1],
        # and negative values would wrap around when cast to bytes.
        generated_image = transforms.ToPILImage()((o.squeeze(0).cpu() + 1) / 2)
        plt.imshow(generated_image)
        plt.axis('off')
        # plt.show()  # uncomment to display the sample interactively

        s, c = d(o)
        print(s.size(), c.size())