from torch import nn


# Define the generator model
class Generator(nn.Module):
    """Turn a noise tensor into an image using stacked transposed convolutions.

    Args:
        noise_channels: Number of channels in the noise input.
        image_channels: Number of channels in the generated image.
        features: Base width. The hidden layers use 16x, 8x, 4x and 2x this
            many channels, narrowing toward the output.

    Spatial sizes: the entry layer (kernel 4, stride 1, padding 0) adds 3 to
    each side, and every later layer (kernel 4, stride 2, padding 1) doubles
    it. The input is presumably (N, noise_channels, 1, 1). That would produce
    (N, image_channels, 64, 64), but this should be confirmed with callers.
    """

    def __init__(self, noise_channels, image_channels, features):
        super().__init__()
        # Hidden channel widths, widest first: 16f -> 8f -> 4f -> 2f.
        hidden = [features * scale for scale in (16, 8, 4, 2)]

        # Entry layer: stride 1 with no padding, so H and W each grow by 3.
        layers = [
            nn.ConvTranspose2d(noise_channels, hidden[0],
                               kernel_size=4, stride=1, padding=0),
            nn.ReLU(),
        ]

        # One upsampling layer per pair of adjacent hidden widths.
        # Each one doubles H and W.
        for in_ch, out_ch in zip(hidden, hidden[1:]):
            layers.append(nn.ConvTranspose2d(in_ch, out_ch,
                                             kernel_size=4, stride=2,
                                             padding=1))
            layers.append(nn.ReLU())

        # Output layer: the same doubling geometry, but it emits image
        # channels and uses Tanh instead of ReLU, so values land in [-1, 1].
        layers.append(nn.ConvTranspose2d(hidden[-1], image_channels,
                                         kernel_size=4, stride=2, padding=1))
        layers.append(nn.Tanh())

        # NOTE(review): unlike Discriminator, this stack has no BatchNorm
        # between layers (the DCGAN paper uses it in the generator). Confirm
        # this is intentional. Adding it would change the state_dict keys.
        # The layers are registered in a fixed order, so the parameter names
        # (model.0, model.2, ..., model.8) stay stable.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Pass the noise tensor ``x`` through the layer stack and return the images."""
        images = self.model(x)
        return images


# Define the discriminator model
class Discriminator(nn.Module):
    """Score images with a stack of strided convolutions ending in a sigmoid.

    Args:
        image_channels: Number of channels in the input image.
        features: Base width. The convolution stages use 1x, 2x, 4x and 8x
            this many channels, growing with depth.

    Spatial sizes: each downsampling conv (kernel 4, stride 2, padding 1)
    halves an even H and W. The final conv (kernel 4, stride 2, padding 0)
    turns a 4x4 map into 1x1. The input is presumably (N, image_channels,
    64, 64), which would give an (N, 1, 1, 1) score, but this should be
    confirmed with callers.
    """

    def __init__(self, image_channels, features):
        super().__init__()
        # (in_channels, out_channels, apply_batchnorm) for each downsampling
        # stage. The first stage sees raw images and is not normalized.
        stages = [
            (image_channels, features, False),
            (features, features * 2, True),
            (features * 2, features * 4, True),
            (features * 4, features * 8, True),
        ]

        layers = []
        for in_ch, out_ch, normalize in stages:
            layers.append(nn.Conv2d(in_ch, out_ch,
                                    kernel_size=4, stride=2, padding=1))
            if normalize:
                layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.LeakyReLU(0.2))

        # Head: collapse to a single channel with no padding, then squash it
        # to (0, 1) with Sigmoid.
        layers.append(nn.Conv2d(features * 8, 1,
                                kernel_size=4, stride=2, padding=0))
        layers.append(nn.Sigmoid())

        # The registration order is fixed, so the parameter and buffer names
        # (model.0 ... model.11) stay stable for saved checkpoints.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Pass the image batch ``x`` through the layer stack and return the scores."""
        scores = self.model(x)
        return scores
