# import torch
# import torch.nn as nn  
  
# class Generator(nn.Module):  
#     def __init__(self, z_dim=100,im=32, channels=3, num_classes=10):  
#         super(Generator, self).__init__()  
          
#         self.label_emb = nn.Embedding(num_classes, 128)  
#         self.dim=im//8
#         self.init_block = nn.Sequential(  
#             nn.Linear(z_dim + 128, 128 * self.dim* self.dim),  
#             nn.BatchNorm1d(128 * self.dim * self.dim),  
#             nn.ReLU(True),  
#         )  
          
#         self.conv_blocks = nn.Sequential(  
#             nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),  
#             nn.BatchNorm2d(64),  
#             nn.ReLU(True),  
              
#             nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),  
#             nn.BatchNorm2d(32),  
#             nn.ReLU(True),  
              
#             nn.ConvTranspose2d(32, channels, 4, 2, 1, bias=False),  
#             nn.Tanh(),  
#         )  
  
#     def forward(self, noise, labels):  #torch.Size([32])
#         label_emb = self.label_emb(labels)  #torch.Size([32, 128])
#         gen_input = torch.cat([noise, label_emb], 1)  #torch.Size([32, 228])
          
#         x = self.init_block(gen_input)  #torch.Size([32, 100352])
#         x = x.view(-1, 128, self.dim, self.dim) # torch.Size([32, 128, 28, 28])
#         x = self.conv_blocks(x)  #torch.Size([32, 3, 224, 224])
          
#         return x
    
  
# class Discriminator(nn.Module):  
#     def __init__(self, im=32,channels=3, num_classes=10):  
#         super(Discriminator, self).__init__()  
          
#         self.features = nn.Sequential(  
#             nn.Conv2d(channels, 32, 4, 2, 1, bias=False),  
#             nn.LeakyReLU(0.2, inplace=True),  
              
#             nn.Conv2d(32, 64, 4, 2, 1, bias=False),  
#             nn.BatchNorm2d(64),  
#             nn.LeakyReLU(0.2, inplace=True),  
              
#             nn.Conv2d(64, 128, 4, 2, 1, bias=False),  
#             nn.BatchNorm2d(128),  
#             nn.LeakyReLU(0.2, inplace=True),  
#         )  
          
#         dim=im//8
#         self.validity = nn.Linear(128 * dim * dim, 1, bias=False)  
#         self.label = nn.Linear(128 * dim * dim, num_classes, bias=False)  
  
#     def forward(self, img):  #torch.Size([32, 3, 224, 224])  torch.Size([64, 1, 28, 28])
#         features = self.features(img)  #torch.Size([32, 128, 28, 28]) torch.Size([64, 128, 3, 3])
#         features = features.view(features.size(0), -1)  #torch.Size([32, 100352]) torch.Size([64, 1152])
          
#         validity = self.validity(features)  #torch.Size([32, 100352]) torch.Size([64, 1])
#         label = self.label(features)  #  torch.Size([32, 31])        torch.Size([64, 10])
          
#         return label, validity

import torch
import torch.nn as nn  
  
class Generator(nn.Module):
    """Conditional generator: maps (noise, class label) to an image.

    The class label is embedded, concatenated with the noise vector and
    projected by a fully connected layer to a
    ``(4 * base_channels, im // 8, im // 8)`` feature map.  Three stride-2
    transposed convolutions then upsample it 8x, ending in ``tanh``.

    Args:
        z_dim: Length of each noise vector passed to ``forward``.
        im: Target image side length.  The produced side length is
            ``8 * (im // 8)``, so it equals ``im`` only when ``im`` is a
            multiple of 8 (e.g. ``im=28`` yields 24x24 images).
        channels: Number of output image channels.
        num_classes: Number of distinct class labels accepted by ``forward``.
        base_channels: Width multiplier.  The label embedding and the first
            feature map have ``4 * base_channels`` channels, followed by
            ``2 * base_channels`` and ``base_channels``.  Defaults to ``im``
            (the original behaviour).  With that default the projection layer
            has ``4 * im * (im // 8) ** 2`` outputs, which grows roughly as
            ``im ** 3`` (702,464 outputs for ``im=224``).  Pass a fixed value
            such as 32 or 64 for large images.
    """

    def __init__(self, z_dim=100, im=32, channels=3, num_classes=10,
                 base_channels=None):
        super().__init__()
        if base_channels is None:
            # Backward-compatible default: channel width tied to image size.
            base_channels = im
        wide = base_channels * 4

        # Registration order (label_emb, init_block, conv_blocks) matches the
        # original so parameter ordering and seeded initialisation are stable.
        self.label_emb = nn.Embedding(num_classes, wide)
        self.im_size = im
        self.base_channels = base_channels
        # Side length of the projected feature map; 3 upsamplings give 8x.
        self.dim = im // 8
        self.init_block = nn.Sequential(
            nn.Linear(z_dim + wide, wide * self.dim * self.dim),
            nn.BatchNorm1d(wide * self.dim * self.dim),
            nn.ReLU(True),
        )

        # Each ConvTranspose2d(k=4, s=2, p=1) exactly doubles the side length.
        self.conv_blocks = nn.Sequential(
            nn.ConvTranspose2d(wide, base_channels * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(base_channels * 2),
            nn.ReLU(True),

            nn.ConvTranspose2d(base_channels * 2, base_channels, 4, 2, 1,
                               bias=False),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(True),

            nn.ConvTranspose2d(base_channels, channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, noise, labels):
        """Generate one image per (noise, label) pair.

        Args:
            noise: Float tensor of shape ``(batch, z_dim)``.
            labels: Integer class indices of shape ``(batch,)``.

        Returns:
            Tensor of shape ``(batch, channels, 8 * (im // 8), 8 * (im // 8))``
            with values in ``[-1, 1]`` (tanh output).

        Note:
            ``init_block`` uses ``BatchNorm1d``, which in training mode
            requires a batch size greater than 1.
        """
        wide = self.base_channels * 4
        label_emb = self.label_emb(labels)            # (B, wide)
        gen_input = torch.cat([noise, label_emb], 1)  # (B, z_dim + wide)
        x = self.init_block(gen_input)                # (B, wide * dim * dim)
        x = x.view(-1, wide, self.dim, self.dim)      # (B, wide, dim, dim)
        return self.conv_blocks(x)                    # (B, channels, 8*dim, 8*dim)
# Example Generator shape trace (batch=32, im=224, z_dim=100): label_emb [32, 896] -> gen_input [32, 996] -> init_block [32, 702464] -> view [32, 896, 28, 28] -> conv_blocks [32, 3, 224, 224]

class Discriminator(nn.Module):
    """Discriminator with a validity head and a class-label head.

    Three stride-2 convolutions (kernel 4, padding 1) each map a side length
    ``n`` to ``n // 2``.  The result is a
    ``(4 * base_channels, im // 8, im // 8)`` feature map, which is flattened
    and fed to two bias-free linear heads.

    Args:
        im: Input image side length.  It determines the flattened feature
            size (``im // 8`` per side), so inputs must have a side length
            with the same ``// 8`` value.
        channels: Number of input image channels.
        num_classes: Number of outputs of the label head.
        base_channels: Width of the first convolution; later layers use
            ``2 *`` and ``4 *`` this.  Defaults to ``im`` (the original
            behaviour).  Pass a fixed value for large images to keep the
            convolution widths manageable.
    """

    def __init__(self, im=32, channels=3, num_classes=10, base_channels=None):
        super().__init__()
        if base_channels is None:
            # Backward-compatible default: channel width tied to image size.
            base_channels = im
        wide = base_channels * 4

        self.features = nn.Sequential(
            # No normalisation on the first layer.
            nn.Conv2d(channels, base_channels, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(base_channels, base_channels * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(base_channels * 2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(base_channels * 2, wide, 4, 2, 1, bias=False),
            nn.BatchNorm2d(wide),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # floor(floor(floor(im/2)/2)/2) == im // 8 for any integer im.
        dim = im // 8
        self.validity = nn.Linear(wide * dim * dim, 1, bias=False)
        self.label = nn.Linear(wide * dim * dim, num_classes, bias=False)

    def forward(self, img):
        """Score a batch of images.

        Args:
            img: Tensor of shape ``(batch, channels, H, W)`` with
                ``H // 8 == W // 8 == im // 8``.

        Returns:
            ``(label, validity)``: raw linear outputs (no sigmoid/softmax
            applied) of shapes ``(batch, num_classes)`` and ``(batch, 1)``.
        """
        features = self.features(img)                        # (B, 4*base, H//8, W//8)
        features = features.view(features.size(0), -1)      # (B, 4*base*(H//8)*(W//8))

        validity = self.validity(features)                   # (B, 1)
        label = self.label(features)                         # (B, num_classes)

        return label, validity
    # Example Discriminator shape trace (batch=32, im=224, channels=3): img [32, 3, 224, 224] -> features [32, 896, 28, 28] -> flattened [32, 702464]