import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable
from torch import autograd



class Critic(nn.Module):
    """Fully convolutional critic that assigns a score to each image in a batch.

    Layout (every kernel is 4x4):
        input          conv, stride 2, pad 1 -> LeakyReLU
        hidden_conv_1  conv, stride 2, pad 1 -> LeakyReLU
        hidden_conv_2  conv, stride 2, pad 1 -> LeakyReLU
        output         conv, stride 1, pad 0 -> 1 channel, no activation

    For a 32x32 input the spatial size goes 32 -> 16 -> 8 -> 4 -> 1, so the
    result has shape (N, 1, 1, 1). No normalization layers are used.

    Args:
        filter_sizes: Channel widths of the three strided convolutions; only
            ``filter_sizes[0]``, ``[1]`` and ``[2]`` are read.
        leaky_relu_alpha: Negative slope shared by every LeakyReLU.
        img_channels: Number of channels of the incoming images.
        bias: Whether each convolution learns an additive bias.
    """

    def __init__(self, filter_sizes, leaky_relu_alpha, img_channels, bias=True):
        super().__init__()

        def strided_stage(in_ch, out_ch):
            # 4x4 / stride 2 / pad 1 halves even spatial sizes; the LeakyReLU
            # overwrites the conv output in place.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=bias),
                nn.LeakyReLU(leaky_relu_alpha, inplace=True),
            )

        # Attribute names and Sequential indices determine the state_dict keys
        # (e.g. "input.0.weight"), so they must stay exactly as they are.
        self.input = strided_stage(img_channels, filter_sizes[0])
        self.hidden_conv_1 = strided_stage(filter_sizes[0], filter_sizes[1])
        self.hidden_conv_2 = strided_stage(filter_sizes[1], filter_sizes[2])

        # Unpadded 4x4 projection to a single score channel; no activation follows.
        self.output = nn.Sequential(
            nn.Conv2d(filter_sizes[2], 1, kernel_size=4, stride=1, padding=0, bias=bias)
        )

    def forward(self, x):
        """Run ``x`` through the four stages in order and return the raw scores."""
        for stage in (self.input, self.hidden_conv_1, self.hidden_conv_2, self.output):
            x = stage(x)
        return x

class Generator(nn.Module):
    """DCGAN-style generator: latent codes in, images with values in [-1, 1] out.

    Layout (every kernel is 4x4):
        input           tconv, stride 1, pad 0 -> BatchNorm -> ReLU
        hidden_tconv_1  tconv, stride 2, pad 1 -> BatchNorm -> ReLU
        hidden_tconv_2  tconv, stride 2, pad 1 -> BatchNorm -> ReLU
        output          tconv, stride 2, pad 1 -> Tanh

    From a 1x1 latent the spatial size grows 1 -> 4 -> 8 -> 16 -> 32.

    Args:
        filter_sizes: Channel widths of the first three transposed convolutions;
            only ``filter_sizes[0]``, ``[1]`` and ``[2]`` are read.
        leaky_relu_alpha: Unused by this class (the hidden activations are
            plain ReLU).
        latent_dim: Number of channels of the latent input.
        bias: Whether each transposed convolution learns a bias. NOTE: a bias
            directly followed by BatchNorm is cancelled by the normalization's
            mean subtraction, so it adds parameters without effect there.
        bnorm_affine: Whether the BatchNorm layers have a learnable scale/shift.
        img_channels: Number of channels of the generated image (default 3).
    """

    def __init__(self, filter_sizes, leaky_relu_alpha, latent_dim, bias=True, bnorm_affine=True,
                 img_channels=3):
        super(Generator, self).__init__()

        # Input Tconv | (N, latent_dim, 1, 1) -> (N, filter_sizes[0], 4, 4)
        self.input = nn.Sequential(
            nn.ConvTranspose2d(in_channels=latent_dim, out_channels=filter_sizes[0], kernel_size=4, stride=1, padding=0, bias=bias),
            nn.BatchNorm2d(num_features=filter_sizes[0], affine=bnorm_affine),
            nn.ReLU(inplace=True))

        # Hidden Tconv 1 | -> (N, filter_sizes[1], 8, 8)
        self.hidden_tconv_1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=filter_sizes[0], out_channels=filter_sizes[1], kernel_size=4, stride=2, padding=1, bias=bias),
            nn.BatchNorm2d(num_features=filter_sizes[1], affine=bnorm_affine),
            nn.ReLU(inplace=True))

        # Hidden Tconv 2 | -> (N, filter_sizes[2], 16, 16)
        self.hidden_tconv_2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=filter_sizes[1], out_channels=filter_sizes[2], kernel_size=4, stride=2, padding=1, bias=bias),
            nn.BatchNorm2d(num_features=filter_sizes[2], affine=bnorm_affine),
            nn.ReLU(inplace=True))

        # Output Tconv | -> (N, img_channels, 32, 32); Tanh bounds values to [-1, 1]
        self.output = nn.Sequential(
            nn.ConvTranspose2d(in_channels=filter_sizes[2], out_channels=img_channels, kernel_size=4, stride=2, padding=1, bias=bias),
            nn.Tanh())

        self.latent_dim = latent_dim
        self.img_channels = img_channels

    def forward(self, x):
        """Generate images from latent codes.

        Args:
            x: Latent batch of shape (N, latent_dim, 1, 1). A flat
                (N, latent_dim) batch is also accepted and given the two
                trailing 1x1 spatial dimensions first.

        Returns:
            Tensor of shape (N, img_channels, 32, 32) for 1x1 latents, with
            values in [-1, 1].
        """
        if x.dim() == 2:
            # The first transposed conv needs 4-D input; treat each latent vector as 1x1.
            x = x.unsqueeze(-1).unsqueeze(-1)

        features = self.input(x)

        features = self.hidden_tconv_1(features)
        features = self.hidden_tconv_2(features)

        output = self.output(features)

        return output
    

# Custom weight initialization
def weights_init(m):
    """Re-initialize one module in place with DCGAN-style values.

    The function inspects a single module, so apply it to a whole network
    with ``model.apply(weights_init)``.

    * ``nn.Conv2d`` and ``nn.ConvTranspose2d``: weight ~ N(0, 0.02); the bias
      (if any) is left as is.
    * ``nn.BatchNorm2d`` with affine parameters: weight ~ N(1, 0.02), bias = 0.
    * Every other module is left unchanged.

    Args:
        m: The module to initialize.
    """
    # ConvTranspose2d is listed explicitly: it is not a Conv2d subclass, and its
    # class name does not contain "Conv2d", so both kinds must be named.
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
        # weight is None when affine=False; then there is nothing to initialize.
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.constant_(m.bias, 0)