import torch
import torch.nn as nn
from typing import List
from hyper_params import Z_DIM, DEVICE, Z_Dim, CHANNEL, CLASSES

# Random prototype vector with Z_Dim elements, sampled once at import time on
# DEVICE (this consumes the global torch RNG as a module-import side effect).
# NOTE(review): Z_Dim is distinct from Z_DIM; presumably it is the flattened
# latent size that TC_net.forward reshapes — confirm against hyper_params.
proto = torch.randn((Z_Dim,), device=DEVICE)

class TC_net(nn.Module):
    """Transposed-convolution decoder mapping a flat prototype tensor to images.

    ``forward`` reshapes its input to ``(batch, in_channels, base_size, base_size)``
    and upsamples it through four ``ConvTranspose2d`` layers. ``tc2`` and ``tc4``
    use ``kernel_size=4, stride=2, padding=1`` and each double the spatial size,
    while ``tc1``/``tc3`` (``kernel_size=3, stride=1, padding=1``) preserve it.
    The output is therefore ``(batch, outputs, 4 * base_size, 4 * base_size)``
    (28x28 for the default ``base_size=7``), passed through a final sigmoid.

    Args:
        in_channels: Channel count of the reshaped input (input channels of ``tc1``).
        outputs: Channel count of the generated image.
        classes: Stored as ``self.num_classes``; not used inside this class.
        base_size: Height/width of the reshaped input feature map.
    """

    def __init__(self, in_channels=Z_DIM, outputs=CHANNEL, classes=CLASSES, base_size=7):
        super().__init__()
        # Keep the input geometry so forward() reshapes to exactly what tc1
        # expects, instead of a hard-coded channel count that ignores in_channels.
        self.in_channels = in_channels
        self.base_size = base_size

        self.tc1 = nn.ConvTranspose2d(in_channels=in_channels, out_channels=32, kernel_size=3, stride=1, padding=1, bias=False)
        self.tc2 = nn.ConvTranspose2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False)
        self.tc3 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
        self.tc4 = nn.ConvTranspose2d(in_channels=64, out_channels=outputs, kernel_size=4, stride=2, padding=1, bias=False)
        # One stateless activation module is shared by all hidden layers.
        self.act = nn.LeakyReLU(0.01)
        self.sigmoid = nn.Sigmoid()
        self.num_classes = classes
        # Convolutions have bias=False because each is followed by BatchNorm,
        # whose affine shift makes a conv bias redundant.
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(64)

    def forward(self, proto):
        """Decode ``proto`` into a batch of images.

        Args:
            proto: Tensor whose element count is a multiple of
                ``in_channels * base_size ** 2``; it is reshaped to
                ``(-1, in_channels, base_size, base_size)``.

        Returns:
            Tensor of shape ``(batch, outputs, 4 * base_size, 4 * base_size)``
            with sigmoid-squashed values in ``[0, 1]``.
        """
        x = proto.reshape(-1, self.in_channels, self.base_size, self.base_size)

        # base_size -> base_size (k=3, s=1, p=1 preserves spatial size)
        x = self.tc1(x)
        x = self.bn1(x)
        x = self.act(x)

        # base_size -> 2 * base_size (k=4, s=2, p=1 doubles spatial size)
        x = self.tc2(x)
        x = self.bn2(x)
        x = self.act(x)

        # 2 * base_size -> 2 * base_size
        x = self.tc3(x)
        x = self.bn3(x)
        x = self.act(x)

        # 2 * base_size -> 4 * base_size, then squash to [0, 1]
        x = self.tc4(x)
        x = self.sigmoid(x)

        return x
