import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as func
from torch.utils.data import DataLoader, SubsetRandomSampler
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision
from torchsummary import summary

from train import train
from noises import my_noise



class Autoencoder(nn.Module):
    """Convolutional autoencoder that injects noise before every layer block.

    Three strided convolution blocks encode the input and three transposed
    convolution blocks decode it. The last block ends in a sigmoid, so outputs
    lie in (0, 1). With the layer sizes below, a ``[batch, 3, 32, 32]`` input is
    encoded to ``[batch, 100, 4, 4]`` and decoded back to ``[batch, 3, 32, 32]``.

    Args:
        train_noise: Noise scale applied while the module is in training mode.
        test_noise: Noise scale applied while the module is in eval mode.
        distribution: Forwarded unchanged to ``my_noise``; its meaning is
            defined there.
    """

    def __init__(self, train_noise, test_noise, distribution):
        super().__init__()
        self.tr_noise = train_noise       # noise scale used when self.training is True
        self.te_noise = test_noise        # noise scale used in eval mode
        self.distribution = distribution  # passed through to my_noise

        # Encoder: kernel 4 / stride 2 / padding 1 halves (even) spatial dims.
        self.encoder1 = self._block_encoder(3, 30, 4, 2, 1)
        self.encoder2 = self._block_encoder(30, 73, 4, 2, 1)
        self.encoder3 = self._block_encoder(73, 100, 4, 2, 1)

        # Decoder mirrors the encoder: each block doubles the spatial dims.
        self.decoder1 = self._block_decoder(100, 73, 4, 2, 1)
        self.decoder2 = self._block_decoder(73, 30, 4, 2, 1)
        # Final block swaps ReLU for a sigmoid so the output lies in (0, 1).
        self.decoder3 = nn.Sequential(
            nn.ConvTranspose2d(30, 3, 4, 2, 1, bias=False),  # [batch, 3, 32, 32] for a 32x32 input
            nn.BatchNorm2d(3),
            nn.Sigmoid(),
        )

    def _block_encoder(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return a Conv2d -> BatchNorm2d -> ReLU block."""
        return nn.Sequential(
            # Bias is redundant because BatchNorm re-centres the activations.
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                      stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(num_features=out_channels),
            nn.ReLU(),
        )

    def _block_decoder(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return a ConvTranspose2d -> BatchNorm2d -> ReLU block."""
        return nn.Sequential(
            # Bias is redundant because BatchNorm re-centres the activations.
            nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                               stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(num_features=out_channels),
            nn.ReLU(),
        )

    def batchnoise(self, x):
        """Return ``x`` plus additive noise scaled by the mean of ``x``.

        Computes ``x + scale * mean(x) * my_noise(x, distribution)``, where
        ``scale`` is ``train_noise`` in training mode and ``test_noise`` in eval
        mode. ``mean(x)`` is taken over the entire tensor, including every sample
        in the batch, so the noise magnitude is shared across the batch.
        """
        scale = self.tr_noise if self.training else self.te_noise
        # Place the noise on x's own device. A global "cuda if available" check
        # breaks CPU-resident models on GPU machines and models on non-default GPUs.
        noise = my_noise(x, self.distribution).to(x.device)
        return x + scale * torch.mean(x) * noise

    def forward(self, x):
        """Run x through every encoder and decoder block, adding noise before each."""
        stages = (
            self.encoder1, self.encoder2, self.encoder3,
            self.decoder1, self.decoder2, self.decoder3,
        )
        for stage in stages:
            x = stage(self.batchnoise(x))
        return x
    

def define_the_model(args):
    """Build one ``Autoencoder`` per class and move each to ``args.device``.

    Args:
        args: Namespace-like object providing ``num_classes``, ``train_noise``,
            ``test_noise``, ``distribution``, ``device`` and ``mprint``.

    Returns:
        A list of ``args.num_classes`` separately constructed autoencoders.
        When ``args.mprint`` is set, a summary of the first one is printed.
    """
    models_gen = [
        Autoencoder(args.train_noise, args.test_noise, args.distribution).to(args.device)
        for _ in range(args.num_classes)
    ]

    # Guard against num_classes == 0, which would otherwise raise IndexError.
    if args.mprint and models_gen:
        model = models_gen[0]
        # torchsummary defaults to device="cuda" and only accepts "cuda"/"cpu".
        # Pass the device the model actually lives on.
        summary_device = 'cuda' if torch.device(args.device).type == 'cuda' else 'cpu'
        # torchsummary runs a forward pass on random input. In training mode that
        # would update model 0's BatchNorm running statistics, so run it in eval
        # mode and restore the previous mode afterwards.
        was_training = model.training
        model.eval()
        try:
            summary(model, input_size=(3, 32, 32), device=summary_device)
        finally:
            model.train(was_training)

    return models_gen


     