import torch
import torch.nn as nn

class Discriminator(nn.Module):
    """MLP discriminator that maps each flattened sample to one sigmoid score.

    Architecture: Linear(input_dim, 1024) -> LeakyReLU(0.2)
    -> Linear(1024, 256) -> LeakyReLU(0.2) -> Linear(256, 1) -> Sigmoid,
    where ``input_dim = img_channels * img_size * img_size``.

    Args:
        img_channels: Number of channels per input image.
        img_size: Side length of the (square) input image.
    """

    def __init__(self, img_channels: int, img_size: int):
        super().__init__()
        # Number of features per sample once the image is flattened.
        self.input_dim = img_channels * img_size * img_size

        self.net = nn.Sequential(
            nn.Linear(self.input_dim, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            # Squash the final logit into [0, 1]; presumably paired with
            # nn.BCELoss in the training loop -- confirm against the caller.
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score a batch of samples.

        Args:
            x: Tensor whose first dimension is the batch size ``B`` and whose
               remaining dimensions hold ``input_dim`` elements per sample,
               e.g. shape (B, img_channels, img_size, img_size).

        Returns:
            Tensor of shape (B, 1) holding sigmoid outputs in [0, 1].

        Raises:
            RuntimeError: if a sample does not contain exactly ``input_dim``
                elements.
        """
        # Flatten everything except the batch dimension. ``reshape`` (unlike
        # ``view``) also accepts non-contiguous inputs such as permuted or
        # channels_last tensors, copying only when necessary. Spelling out
        # ``input_dim`` instead of -1 keeps an empty batch (B == 0) valid.
        x = x.reshape(x.size(0), self.input_dim)
        return self.net(x)

class Generator(nn.Module):
    """MLP generator that maps noise vectors to square images with values in [-1, 1].

    Two hidden blocks of Linear -> BatchNorm1d -> LeakyReLU(0.2) with widths
    256 and 1024, followed by Linear -> Tanh, whose output is reshaped to
    (B, img_channels, img_size, img_size).

    Args:
        noise_dim: Length of each input noise vector.
        img_channels: Number of channels of each generated image.
        img_size: Side length of each (square) generated image.
    """

    def __init__(self, noise_dim, img_channels, img_size):
        super().__init__()
        self.img_channels = img_channels
        self.img_size = img_size
        # Number of values produced per sample before reshaping to an image.
        self.output_dim = img_channels * img_size * img_size

        # Build the stack programmatically; the resulting Sequential keeps the
        # same layer order (and therefore the same state_dict keys) as an
        # explicit literal listing.
        layers = []
        in_features = noise_dim
        for width in (256, 1024):
            layers.extend((
                nn.Linear(in_features, width),
                nn.BatchNorm1d(width),
                nn.LeakyReLU(0.2),
            ))
            in_features = width
        layers.extend((nn.Linear(in_features, self.output_dim), nn.Tanh()))
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Generate a batch of images from noise.

        Args:
            z: Noise tensor, expected shape (B, noise_dim). In training mode
               BatchNorm1d needs more than one sample per batch.

        Returns:
            Tensor of shape (B, img_channels, img_size, img_size) with values
            in [-1, 1] (Tanh output).
        """
        batch = z.size(0)
        flat = self.net(z)
        return flat.view(batch, self.img_channels, self.img_size, self.img_size)
