import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


class Critic(nn.Module):
    """DCGAN-style critic that maps an image batch to one unbounded score per image.

    For a 64x64 input, four stride-2 convolution blocks shrink the spatial size
    64 -> 32 -> 16 -> 8 -> 4. A final 4x4 convolution with no padding then
    collapses the 4x4 map to 1x1, so the output has shape (N, 1, 1, 1).
    No activation is applied to the score.

    Args:
        channels_img: number of channels in the input image.
        features_c: channel width of the first block; each later block doubles it.
    """

    def __init__(self, channels_img, features_c = 64):
        super().__init__()

        def down(in_ch, out_ch, normalize):
            # Conv(k=4, s=2, p=1) halves H and W, then LeakyReLU(0.2).
            block = [nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)]
            if normalize:
                block.append(nn.BatchNorm2d(out_ch))
            block.append(nn.LeakyReLU(0.2))
            return block

        widths = [features_c, features_c * 2, features_c * 4, features_c * 8]

        # First block (64 -> 32) is deliberately left unnormalized.
        layers = down(channels_img, widths[0], normalize=False)
        # Remaining blocks: 32 -> 16 -> 8 -> 4, each with BatchNorm.
        # NOTE(review): if this critic is trained with a WGAN-GP gradient penalty,
        # BatchNorm makes each score depend on the rest of the batch and is usually
        # avoided there; replacing it would change the state_dict layout, so it is kept.
        for in_ch, out_ch in zip(widths, widths[1:]):
            layers += down(in_ch, out_ch, normalize=True)
        # 4x4 -> 1x1 score.
        layers.append(nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=0))

        # Layer order (and therefore state_dict keys "disc.<i>.*") is fixed.
        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the critic's raw score tensor for the image batch ``x``."""
        return self.disc(x)
    

class Generator(nn.Module):
    """DCGAN-style generator that upsamples a latent map into an image.

    For a 1x1 latent map, the first transposed convolution produces a 4x4 map.
    Four more stride-2 transposed convolutions then double the size each time:
    4 -> 8 -> 16 -> 32 -> 64. The final Tanh bounds every output pixel to [-1, 1].

    Args:
        z_dim: number of channels in the latent input.
        channels_img: number of channels in the generated image.
        features_g: base width; the first block has ``features_g * 16`` channels
            and each later block halves that.
    """

    def __init__(self, z_dim, channels_img, features_g = 64):
        super().__init__()

        def up(in_ch, out_ch, stride, padding):
            # Transposed conv (k=4) -> BatchNorm -> ReLU.
            return [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=stride, padding=padding),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ]

        multipliers = (16, 8, 4, 2)

        # 1x1 -> 4x4 (stride 1, no padding).
        layers = up(z_dim, features_g * multipliers[0], stride=1, padding=0)
        # 4 -> 8 -> 16 -> 32 (stride 2, padding 1 doubles H and W).
        for wide, narrow in zip(multipliers, multipliers[1:]):
            layers += up(features_g * wide, features_g * narrow, stride=2, padding=1)
        # 32 -> 64, projected to image channels; no BatchNorm before the output.
        layers.append(
            nn.ConvTranspose2d(features_g * multipliers[-1], channels_img,
                               kernel_size=4, stride=2, padding=1)
        )
        layers.append(nn.Tanh())  # [-1, 1]

        # Layer order (and therefore state_dict keys "net.<i>.*") is fixed.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the generated image batch for the latent batch ``x``."""
        return self.net(x)
    
def initialize_weights(model):
    """Re-initialize conv and batch-norm parameters of ``model`` in place, DCGAN style.

    - Conv2d / ConvTranspose2d weights are drawn from N(0, 0.02), as in the
      DCGAN paper.
    - Affine BatchNorm2d scales (gamma) are drawn from N(1, 0.02) and shifts
      (beta) are set to 0, matching the PyTorch DCGAN reference. Drawing gamma
      from N(0, 0.02) instead would multiply every normalized activation by a
      tiny, randomly-signed factor at initialization, shrinking the signal
      after each BatchNorm to roughly 2% of unit scale.
    - Non-affine BatchNorm2d layers have no weight/bias and are skipped.

    All other parameters (e.g. conv biases) keep PyTorch's default initialization.

    Args:
        model: any ``nn.Module``. Every submodule reached through
            ``model.modules()`` is visited.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            # nn.init already runs under no_grad, so no `.data` access is needed.
            nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d) and m.affine:
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.zeros_(m.bias)

