import torch.nn as nn
import torch
from hyper_params import Z_Dim, CLASSES, DEVICE

class Generator(nn.Module):
    """Map integer class labels to latent vectors of size ``z``.

    Each label is embedded into a ``z``-dim vector, concatenated with
    ``z``-dim uniform noise, passed through ``hidden`` independent blocks of
    ``Linear(2z, 2z) -> BatchNorm1d(2z) -> ReLU``, and finally projected back
    to ``z`` dimensions by ``representation_layer``.

    Args:
        z: Size of the noise vector, of the label embedding, and of the output.
        classes: Number of distinct label values the embedding accepts.
        hidden: Number of Linear/BatchNorm/ReLU blocks.
    """

    def __init__(self, z=Z_Dim, classes=CLASSES, hidden=2):
        super(Generator, self).__init__()
        self.zdim = z
        self.embedding = nn.Embedding(classes, z)
        self.fc_layers = nn.ModuleList()
        # Create fresh modules on every iteration. Re-appending one shared
        # Linear/BatchNorm instance would tie the weights of all hidden blocks
        # and update the BatchNorm running statistics several times per pass.
        for _ in range(hidden):
            self.fc_layers += [
                nn.Linear(z * 2, z * 2),
                nn.BatchNorm1d(z * 2),
                nn.ReLU(),
            ]
        self.representation_layer = nn.Linear(z * 2, z)

    def forward(self, labels):
        """Generate one latent vector per label.

        Args:
            labels: 1-D LongTensor of class indices, shape ``(batch,)``. It
                must be 1-D so that its embedding is 2-D and can be
                concatenated with the noise along ``dim=1``.

        Returns:
            Tensor of shape ``(batch, z)``.

        Note:
            In training mode, BatchNorm1d raises an error for a batch of size 1.
        """
        y_input = self.embedding(labels)
        # Draw uniform [0, 1) noise on the embedding's device and dtype. The
        # concatenation then works wherever the module lives, independent of
        # the global DEVICE setting.
        eps = torch.rand(
            (labels.shape[0], self.zdim),
            device=y_input.device,
            dtype=y_input.dtype,
        )
        z = torch.cat((eps, y_input), dim=1)
        for layer in self.fc_layers:
            z = layer(z)
        return self.representation_layer(z)

class Predictor(nn.Module):
    """Single linear head mapping a ``z``-dim input to ``classes`` scores.

    Args:
        z: Size of the last dimension of the input.
        classes: Size of the last dimension of the output.
    """

    def __init__(self, z=Z_Dim, classes=CLASSES):
        super().__init__()
        self.fc = nn.Linear(in_features=z, out_features=classes)

    def forward(self, z):
        """Apply the linear layer; the last dimension goes from ``z`` to ``classes``."""
        scores = self.fc(z)
        return scores

if __name__ == "__main__":
    # Smoke test: push one batch of identical labels per class through an
    # untrained generator and report the shape of the latent output.
    batch_size = 16
    generator = Generator().to(DEVICE)

    # For each class, a 1-D LongTensor holding `batch_size` copies of its index.
    batched_labels = [
        torch.full((batch_size,), c, dtype=torch.long, device=DEVICE)
        for c in range(CLASSES)
    ]

    # Gradients are not needed for a shape check.
    with torch.no_grad():
        for b in batched_labels:
            print(f"batched labels = {b}, shape = {b.shape}")
            z = generator(b)
            print(f"shape of z = {z.shape}.")

    num_params = sum(p.numel() for p in generator.parameters())
    print(f"total generator parameters = {num_params}")