from __future__ import print_function
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import matplotlib.pyplot as plt

# Turn on cuDNN's benchmark mode (its auto-tuner for convolution algorithms).
cudnn.benchmark = True

# Seed Python's and PyTorch's random number generators. The seed is drawn at
# random on every run and printed so the run can be reproduced later; replace
# the randint call with a constant to get consistent output across runs.
manualSeed = random.randint(1, 10000)
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

class Discriminator(nn.Module):
    """Convolutional discriminator that returns raw (un-squashed) scores.

    The network is a stack of four bias-free convolutions:

    * three 4x4 / stride-2 / padding-1 convolutions, each halving the spatial
      size (rounding down) while widening the channels ``nc -> ndf ->
      2*ndf -> 4*ndf``; the second and third are followed by batch norm, and
      all three by ``LeakyReLU(0.2)``;
    * a final 1x1 / stride-2 convolution down to a single channel, which
      halves the spatial size once more (rounding up).

    No sigmoid is applied, so the outputs are unbounded. Every spatial
    position left after the last convolution contributes one score: an
    8x8 input yields one score per sample, a 64x64 input yields 16.

    Args:
        ngpu: Number of GPUs to spread the forward pass over when the input
            lives on a CUDA device.
        nc: Number of input channels.
        ndf: Channel width of the first convolution.
    """

    def __init__(self, ngpu, nc=1, ndf=64):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu

        slope = 0.2
        # Stem: no normalisation on the first convolution.
        layers = [
            nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(slope, inplace=False),
        ]
        # Two down-sampling stages, each doubling the channel count.
        for in_ch, out_ch in ((ndf, ndf * 2), (ndf * 2, ndf * 4)):
            layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1,
                          bias=False),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(slope, inplace=False),
            ]
        # Scoring head: single output channel, deliberately without a sigmoid.
        layers.append(
            nn.Conv2d(ndf * 4, 1, kernel_size=1, stride=2, padding=0,
                      bias=False)
        )
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Score ``input`` and return the scores as a flat 1-D tensor.

        Args:
            input: Image batch of shape ``(batch, nc, H, W)``.

        Returns:
            1-D tensor holding every score produced by the network, i.e.
            ``batch * H_out * W_out`` values.
        """
        use_parallel = input.is_cuda and self.ngpu > 1
        if not use_parallel:
            scores = self.main(input)
        else:
            # Replicate the stack over the first ``ngpu`` devices.
            scores = nn.parallel.data_parallel(
                self.main, input, range(self.ngpu)
            )
        flat = scores.view(-1, 1)
        return flat.squeeze(1)
    
class TransformerDisciminator(nn.Module):
    """Transformer-based discriminator returning one raw score per sample.

    Pipeline: a 1x1 convolution lifts the ``nc`` input channels to 64, a
    strided 4x4 convolution maps them to ``ndf`` channels, every spatial
    position of the resulting feature map becomes a token, a 4-layer
    Transformer encoder lets the tokens of each sample attend to one another,
    and a linear head scores each token. The token scores are averaged into a
    single value per sample. No sigmoid is applied.

    Args:
        ngpu: Accepted so the constructor matches ``Discriminator``; it is
            not used by this class.
        nc: Number of input channels.
        ndf: Token embedding width (the encoder's ``d_model``). The encoder
            uses 4 attention heads, so ``ndf`` must be divisible by 4.
    """

    def __init__(self, ngpu, nc=14, ndf=256):
        super(TransformerDisciminator, self).__init__()
        # NOTE(review): padding=1 on a bias-free 1x1 convolution surrounds the
        # feature map with a one-pixel border of zeros (H and W grow by 2).
        # Kept as-is to preserve existing behaviour -- confirm it is intended.
        self.mlp = nn.Conv2d(nc, 64, 1, 1, 1, bias=False)
        self.cnn = nn.Sequential(nn.Conv2d(64, ndf, 4, 2, 1, bias=False),
                                nn.LeakyReLU(0.2, inplace=True))
        # Built with the default batch_first=False, i.e. the encoder expects
        # inputs laid out as (tokens, batch, ndf).
        encoder_layer = nn.TransformerEncoderLayer(d_model=ndf, nhead=4)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
        self.out = nn.Linear(ndf, 1)

    def forward(self, x):
        """Score a batch of inputs.

        Args:
            x: Tensor of shape ``(batch, nc, H, W)``.

        Returns:
            Tensor of shape ``(batch,)`` with one score per sample. Each
            score depends only on its own sample (when dropout is inactive).
        """
        x = self.mlp(x)
        x = self.cnn(x)
        # Flatten the spatial grid into a token axis: (batch, ndf, tokens).
        x = x.reshape([x.size(0), x.size(1), -1])
        # Put tokens on the sequence axis the encoder attends over:
        # (tokens, batch, ndf). Feeding (batch, tokens, ndf) instead would
        # make attention mix different samples of the batch.
        x = x.permute(2, 0, 1)
        x = self.transformer_encoder(x)
        x = self.out(x)          # (tokens, batch, 1)
        x = torch.mean(x, 0)     # average over tokens -> (batch, 1)
        return x.squeeze(1)
        

# Testing
if __name__ == "__main__":
    # Smoke test: run one random batch through the transformer discriminator
    # and print the module summary followed by the output shape.
    # (Discriminator(1, nc=14, ndf=64) can be swapped in to try the
    # convolutional variant instead.)
    batch, channels, height, width = 4, 14, 8, 8
    model = TransformerDisciminator(1, channels, 64)
    print(model)
    flat_sample = torch.rand(batch * channels * height * width)
    sample = flat_sample.reshape([batch, channels, height, width])
    scores = model(sample)
    print(scores.shape)