import torch
import torch.nn as nn

class Generator(nn.Module):
    """
    DCGAN-style generator: maps a latent tensor to an image through a stack
    of transposed convolutions.

    With a latent input of shape (N, nz, 1, 1) the spatial size grows
    1 -> 4 -> 8 -> 16 -> 32 -> 64, and a final 1x1 transposed convolution
    followed by Tanh produces an (N, nc, 64, 64) output in [-1, 1].
    """
    def __init__(self, ngpu, nc, nz, ngf):
        """
        Build the layer stack.

        Args:
            ngpu (int): Number of GPUs available.
            nc (int): Number of channels in the output image.
            nz (int): Size of the latent noise vector.
            ngf (int): Base number of generator filters.
        """
        super().__init__()
        self.ngpu = ngpu

        def up_block(c_in, c_out, stride, padding):
            # One stage: 4x4 transposed conv -> batch norm -> in-place ReLU.
            return [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=stride,
                                   padding=padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]

        # Channel count entering/leaving each stage, from the latent vector
        # down to the (ngf // 2)-channel 64x64 feature map.
        widths = [nz, ngf * 8, ngf * 4, ngf * 2, ngf, ngf // 2]

        stack = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            if stage == 0:
                # Projection stage: (nz) x 1 x 1 -> (ngf*8) x 4 x 4.
                stack += up_block(c_in, c_out, stride=1, padding=0)
            else:
                # Each remaining stage doubles the spatial resolution.
                stack += up_block(c_in, c_out, stride=2, padding=1)

        # 1x1 projection to the image channels, squashed into [-1, 1].
        stack.append(
            nn.ConvTranspose2d(widths[-1], nc, kernel_size=1, stride=1,
                               padding=0, bias=False)
        )
        stack.append(nn.Tanh())

        self.main = nn.Sequential(*stack)

    def forward(self, input):
        """
        Run the generator on a batch of latent vectors.

        Args:
            input (Tensor): Input latent vector.

        Returns:
            Tensor: Generated image.
        """
        multi_gpu = input.is_cuda and self.ngpu > 1
        if not multi_gpu:
            return self.main(input)
        # Split the batch across the first `ngpu` devices.
        return nn.parallel.data_parallel(self.main, input, range(self.ngpu))

    
class Discriminator(nn.Module):
    """
    Discriminator network that classifies images as real or fake.

    Four stride-2 convolutions each halve the spatial size; a final 4x4
    stride-2 convolution without padding followed by a Sigmoid produces the
    score. For 64x64 inputs the final feature map is 1x1, so the network
    returns exactly one probability per image.
    """
    def __init__(self, ngpu, nc, ndf, noise_std=0.0):
        """
        Initialize the Discriminator.

        Args:
            ngpu (int): Number of GPUs available.
            nc (int): Number of channels in the input image.
            ndf (int): Base number of discriminator filters.
            noise_std (float): Standard deviation of the Gaussian noise added
                after every convolution and batch-norm layer during the
                forward pass. ``0.0`` (the default) disables noise.
        """
        super().__init__()
        self.ngpu = ngpu
        self.noise_std = noise_std

        # Define the discriminator layers as a sequential model
        layers = [
            # Input: (nc) x H x W, output: (ndf) x (H/2) x (W/2)
            nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            # Downsample: (ndf) x (H/2) x (W/2) -> (ndf*2) x (H/4) x (W/4)
            nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),

            # Downsample: (ndf*2) x (H/4) x (W/4) -> (ndf*4) x (H/8) x (W/8)
            nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),

            # Downsample: (ndf*4) x (H/8) x (W/8) -> (ndf*8) x (H/16) x (W/16)
            nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),

            # Final convolution: collapses a 4x4 map to 1x1 with one channel
            nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=2, padding=0, bias=False),
            nn.Sigmoid(),
        ]

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """
        Forward pass of the discriminator.

        Args:
            input (Tensor): Input image.

        Returns:
            Tensor: One-dimensional tensor of probabilities (one per input
            image when the final feature map is 1x1).
        """
        if input.is_cuda and self.ngpu > 1:
            # NOTE: intermediate noise is not injected on the multi-GPU path;
            # only the plain sequential stack is replicated across devices.
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        elif self.noise_std != 0:
            # Walk the stack manually so noise can be injected after every
            # convolution and batch-norm output.
            output = input
            for module in self.main:
                output = module(output)
                if isinstance(module, (nn.Conv2d, nn.BatchNorm2d)):
                    # Out-of-place add keeps autograd's saved tensors intact.
                    output = output + self.noise_std * torch.randn_like(output)
        else:
            # No noise requested: skip sampling entirely so the forward pass
            # neither wastes compute nor advances the global RNG.
            output = self.main(input)

        # Reshape output to a one-dimensional tensor for classification
        return output.view(-1, 1).squeeze(1)


class Toy_Generator(nn.Module):
    """
    Toy Generator network for the GAN.

    This network takes a latent noise vector and transforms it into an output
    of the specified dimensionality using three Linear + LeakyReLU blocks
    (widths 128, 256, 512) followed by a final Linear layer. Batch
    normalization is supported by the internal block helper but is disabled
    for every block in the current configuration.
    """
    def __init__(self, noise_channels, num_dimensions):
        """
        Initialize the Generator.

        Args:
            noise_channels (int): Size of the input noise vector.
            num_dimensions (int): Dimensionality of the generator output.
        """
        # Zero-argument super(): the previous call named the unrelated
        # `Generator` class, which raises TypeError on construction.
        super().__init__()

        def block(in_features, out_features, normalize=True):
            """
            Create a generator block consisting of a linear layer, optional batch normalization,
            and a LeakyReLU activation.

            Args:
                in_features (int): Number of input features.
                out_features (int): Number of output features.
                normalize (bool): Whether to apply batch normalization.

            Returns:
                list: A list of layers for the block.
            """
            layers = [nn.Linear(in_features, out_features)]
            if normalize:
                # NOTE(review): in PyTorch, momentum is the weight given to
                # the *new* batch statistic, so 0.8 makes the running stats
                # track recent batches closely — confirm this is intended.
                layers.append(nn.BatchNorm1d(out_features, momentum=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # Define the generator model using a sequence of blocks.
        self.model = nn.Sequential(
            *block(noise_channels, 128, normalize=False),  # First block: linear + activation (no normalization)
            *block(128, 256, normalize=False),             # Second block
            *block(256, 512, normalize=False),             # Third block
            nn.Linear(512, num_dimensions)                 # Final linear layer to produce output
        )

    def forward(self, z):
        """
        Forward pass of the generator.

        Args:
            z (Tensor): Input noise vector.

        Returns:
            Tensor: Generated output.
        """
        return self.model(z)


class Toy_GeneratorDiscriminator(nn.Module):
    """
    Toy Discriminator network for the GAN.

    This network takes an input (either real or generated) and outputs a probability value indicating whether the input is real (closer to 1) or fake (closer to 0).

    NOTE(review): the class name looks like a rename artifact (it is a
    discriminator only); it is kept unchanged so existing callers still work.
    """
    def __init__(self, num_dimensions):
        """
        Initialize the Discriminator.

        Args:
            num_dimensions (int): Dimensionality of the input data.
        """
        # Zero-argument super(): the previous call named the unrelated
        # `Discriminator` class, which raises TypeError on construction.
        super().__init__()

        # Define the discriminator model as a sequence of layers.
        self.model = nn.Sequential(
            nn.Linear(num_dimensions, 512),       # First linear layer
            nn.LeakyReLU(0.2, inplace=True),        # Activation function
            nn.Linear(512, 512),                   # Second linear layer
            nn.LeakyReLU(0.2, inplace=True),        # Activation function
            nn.Linear(512, 1),                     # Final linear layer for binary output
            nn.Sigmoid()                           # Sigmoid to convert output to probability
        )

    def forward(self, img):
        """
        Forward pass of the discriminator.

        Args:
            img (Tensor): Input data (either real or generated).

        Returns:
            Tensor: Probability indicating the input's validity, with a
            trailing dimension of size 1.
        """
        return self.model(img)
