import torch
from torch import nn
import torch.nn.functional as F


# Define the generator model
class Generator(nn.Module):
    """Transposed-convolution generator that upsamples noise into an image.

    Five ``ConvTranspose2d`` layers are stacked. The first four are each
    followed by ``ReLU``; the last is followed by ``Tanh``, so every output
    value lies in [-1, 1]. For example, an input of shape
    ``(N, noise_channels, 1, 1)`` produces ``(N, image_channels, 64, 64)``.

    Args:
        noise_channels: Channel count of the input noise tensor.
        image_channels: Channel count of the generated image.
        features: Base width; the hidden layers use 16x, 8x, 4x and 2x
            this value.
    """

    def __init__(self, noise_channels, image_channels, features):
        super().__init__()

        # Channel counts at the boundaries of the four hidden blocks.
        widths = (
            noise_channels,
            features * 16,
            features * 8,
            features * 4,
            features * 2,
        )

        layers = []
        for position, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            # Only the first block uses stride 1 / padding 0; every later
            # block uses stride 2 / padding 1.
            is_first = position == 0
            layers.append(
                nn.ConvTranspose2d(c_in,
                                   c_out,
                                   kernel_size=4,
                                   stride=1 if is_first else 2,
                                   padding=0 if is_first else 1))
            layers.append(nn.ReLU())

        # Output block: map to the image channels and squash with Tanh.
        layers.append(
            nn.ConvTranspose2d(widths[-1],
                               image_channels,
                               kernel_size=4,
                               stride=2,
                               padding=1))
        layers.append(nn.Tanh())

        # Positional Sequential keeps the parameter names model.0, model.2, ...
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the generated tensor."""
        generated = self.model(x)
        return generated


# Define the discriminator model
class Discriminator(nn.Module):
    """Convolutional discriminator that scores an image with a sigmoid output.

    Five ``Conv2d`` layers with 4x4 kernels and stride 2 are stacked. The
    first is followed by ``LeakyReLU(0.2)`` only, the next three by
    ``BatchNorm2d`` and ``LeakyReLU(0.2)``, and the last (padding 0, a single
    output channel) by ``Sigmoid``. For example, an input of shape
    ``(N, image_channels, 64, 64)`` produces ``(N, 1, 1, 1)``.

    Args:
        image_channels: Channel count of the input image.
        features: Base width; the hidden layers use 1x, 2x, 4x and 8x this
            value.
    """

    def __init__(self, image_channels, features):
        super().__init__()

        # Input block: no batch normalisation on the first convolution.
        layers = [
            nn.Conv2d(image_channels,
                      features,
                      kernel_size=4,
                      stride=2,
                      padding=1),
            nn.LeakyReLU(0.2),
        ]

        # Three blocks that each double the channel count:
        # features -> 2x -> 4x -> 8x.
        for multiplier in (1, 2, 4):
            c_in = features * multiplier
            c_out = c_in * 2
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ]

        # Output block: unpadded convolution to one channel, then Sigmoid.
        layers += [
            nn.Conv2d(features * 8, 1, kernel_size=4, stride=2, padding=0),
            nn.Sigmoid(),
        ]

        # Positional Sequential keeps the parameter names model.0, model.2, ...
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the sigmoid scores."""
        scores = self.model(x)
        return scores

# Define the classifier model
class Classifier(nn.Module):
    """Fully connected classifier mapping 784 input features to 10 classes.

    Three hidden ``Linear`` layers (256, 128 and 64 units) are each followed
    by ReLU and dropout; the final ``Linear`` layer feeds ``log_softmax``
    over dimension 1.
    """

    def __init__(self):
        super().__init__()

        # Layer widths from input to output; consecutive pairs define
        # fc1..fc4, registered in that order.
        layer_sizes = (784, 256, 128, 64, 10)
        pairs = zip(layer_sizes, layer_sizes[1:])
        for index, (n_in, n_out) in enumerate(pairs, start=1):
            setattr(self, f"fc{index}", nn.Linear(n_in, n_out))

        # One dropout module (drop probability 0.2) shared by all hidden layers.
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x):
        """Return log-probabilities over the 10 classes for ``x``."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = self.dropout(F.relu(layer(hidden)))

        # The output layer has no dropout; normalise over the class dimension.
        logits = self.fc4(hidden)
        return F.log_softmax(logits, dim=1)

# Define the noised discriminator model
class Noised_Discriminator(nn.Module):
    """Discriminator that adds Gaussian noise to intermediate activations.

    The layer stack is five ``Conv2d`` layers (4x4 kernels, stride 2) with
    ``BatchNorm2d`` after the second to fourth convolutions,
    ``LeakyReLU(0.2)`` activations and a final ``Sigmoid``. During the
    forward pass, zero-mean Gaussian noise scaled by ``noise_std`` is added
    to the output of every ``Conv2d`` and ``BatchNorm2d`` layer.

    Args:
        image_channels: Channel count of the input image.
        features: Base width; the hidden layers use 1x, 2x, 4x and 8x this
            value.
        noise_std: Scale factor applied to the standard-normal noise.
    """

    def __init__(self, image_channels, features, noise_std=0.1):
        super().__init__()
        self.noise_std = noise_std

        def conv(c_in, c_out, padding=1):
            # Every convolution in the stack shares kernel size 4 and stride 2.
            return nn.Conv2d(c_in,
                             c_out,
                             kernel_size=4,
                             stride=2,
                             padding=padding)

        def lrelu():
            return nn.LeakyReLU(0.2)

        double, quad, octo = features * 2, features * 4, features * 8

        self.model = nn.Sequential(
            # Block 1: convolution without batch normalisation
            conv(image_channels, features),
            lrelu(),
            # Block 2
            conv(features, double),
            nn.BatchNorm2d(double),
            lrelu(),
            # Block 3
            conv(double, quad),
            nn.BatchNorm2d(quad),
            lrelu(),
            # Block 4
            conv(quad, octo),
            nn.BatchNorm2d(octo),
            lrelu(),
            # Block 5: unpadded convolution to one channel, then Sigmoid
            conv(octo, 1, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Apply each layer in turn, perturbing Conv2d/BatchNorm2d outputs."""
        noisy_layer_types = (nn.Conv2d, nn.BatchNorm2d)
        for layer in self.model:
            x = layer(x)
            if not isinstance(layer, noisy_layer_types):
                continue
            # In-place add on the freshly produced activation; the caller's
            # input tensor is never modified because x was rebound above.
            x += self.noise_std * torch.randn_like(x)
        return x
