import torch
import torch.nn as nn
import torch.nn.utils.spectral_norm as spectral_norm

class Generator(nn.Module):
    """
    Transposed-convolution generator mapping a latent vector to an image.

    Pipeline (NCHW layout):
    - Linear: z_dim -> 4*4*128, reshaped to 128 x 4 x 4
    - Deconv1: 128 -> 1024 channels (kernel 4, stride 1, padding 1); the
      5x5 result is cropped to 4x4
    - Deconv2..4: 1024 -> 512 -> 256 -> 128 channels, each doubling the
      spatial size (4 -> 8 -> 16 -> 32)
    - img_size == 64:  Deconv5 maps 128 -> nc channels at 64x64
    - img_size == 128: Deconv5 maps 128 -> 64 channels at 64x64, followed by
      BatchNorm + ReLU and Deconv6 mapping 64 -> nc channels at 128x128
    - BatchNorm + ReLU follow every deconv except the last; Tanh is applied
      to the final output.

    Args:
        z_dim: Size of the latent vector fed to the linear layer.
        nc: Number of channels of the generated image.
        img_size: Output resolution; only 64 and 128 are supported.
        ngf: Accepted for interface compatibility; not used by this class.

    Raises:
        ValueError: If ``img_size`` is neither 64 nor 128.
    """

    def __init__(self, z_dim=128, nc=1, img_size=64, ngf=64):
        super().__init__()

        self.z_dim = z_dim
        self.nc = nc
        self.img_size = img_size

        # Projects the latent vector onto a 128-channel 4x4 feature map.
        self.linear = nn.Linear(z_dim, 4 * 4 * 128)

        # Stage 1 keeps the 4x4 resolution (forward() crops the 5x5 result).
        self.deconv1 = nn.ConvTranspose2d(
            128, 1024, kernel_size=4, stride=1, padding=1
        )
        self.bn1 = nn.BatchNorm2d(1024, eps=1e-5)

        # Stage 2: 4x4 -> 8x8
        self.deconv2 = nn.ConvTranspose2d(
            1024, 512, kernel_size=4, stride=2, padding=1
        )
        self.bn2 = nn.BatchNorm2d(512, eps=1e-5)

        # Stage 3: 8x8 -> 16x16
        self.deconv3 = nn.ConvTranspose2d(
            512, 256, kernel_size=4, stride=2, padding=1
        )
        self.bn3 = nn.BatchNorm2d(256, eps=1e-5)

        # Stage 4: 16x16 -> 32x32
        self.deconv4 = nn.ConvTranspose2d(
            256, 128, kernel_size=4, stride=2, padding=1
        )
        self.bn4 = nn.BatchNorm2d(128, eps=1e-5)

        if self.img_size == 128:
            # Two more upsampling stages: 32x32 -> 64x64 -> 128x128
            self.deconv5 = nn.ConvTranspose2d(
                128, 64, kernel_size=4, stride=2, padding=1
            )
            self.bn5 = nn.BatchNorm2d(64, eps=1e-5)
            self.deconv6 = nn.ConvTranspose2d(
                64, nc, kernel_size=4, stride=2, padding=1
            )
        elif self.img_size == 64:
            # Single output stage: 32x32 -> 64x64
            self.deconv5 = nn.ConvTranspose2d(
                128, nc, kernel_size=4, stride=2, padding=1
            )
        else:
            raise ValueError(f"Unsupported img_size: {img_size}")

        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

        self._init_weights()

    def _init_weights(self):
        """Re-initialise weights: N(0, 0.02) for conv/linear, (1, 0) for BN."""
        weighted_layers = (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                nn.init.ones_(layer.weight)
                nn.init.zeros_(layer.bias)
                continue
            if not isinstance(layer, weighted_layers):
                continue
            nn.init.normal_(layer.weight, mean=0.0, std=0.02)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, z):
        """
        Generate images from latent vectors.

        Args:
            z: Latent input passed to the linear layer.

        Returns:
            Tensor of Tanh-activated images produced by the final deconv.
        """
        # Project and reshape to a 128-channel 4x4 map (NCHW).
        feat = self.linear(z).view(-1, 128, 4, 4)

        # Stage 1 yields 5x5; keep the top-left 4x4 window.
        feat = self.deconv1(feat)[:, :, :4, :4]
        feat = self.relu(self.bn1(feat))

        # Stages 2-4 share the same deconv -> BN -> ReLU pattern.
        upsampling_stages = (
            (self.deconv2, self.bn2),  # -> 8x8
            (self.deconv3, self.bn3),  # -> 16x16
            (self.deconv4, self.bn4),  # -> 32x32
        )
        for deconv, bn in upsampling_stages:
            feat = self.relu(bn(deconv(feat)))

        feat = self.deconv5(feat)  # -> 64x64

        if self.img_size == 128:
            # Extra normalised stage before the 128x128 output layer.
            feat = self.deconv6(self.relu(self.bn5(feat)))

        return self.tanh(feat)

class Discriminator(nn.Module):
    """
    Convolutional discriminator with spectral normalisation on every layer.

    Architecture (migrated from gan_model.py). Every convolution uses
    kernel 4, stride 2, padding 1 and therefore halves the spatial size:
    - img_size == 128 only: Conv0 nc -> 64 channels (128x128 -> 64x64),
      then Conv1 64 -> 64 channels (64x64 -> 32x32)
    - img_size == 64: Conv1 nc -> 64 channels (64x64 -> 32x32)
    - Conv2: 64 -> 128 channels (32x32 -> 16x16)
    - Conv3: 128 -> 256 channels (16x16 -> 8x8)
    - Conv4: 256 -> 512 channels (8x8 -> 4x4)
    - Conv5: 512 -> 1 channel (4x4 -> 2x2), no activation
    - LeakyReLU(0.2) follows every convolution except Conv5.

    Args:
        nc: Number of channels of the input image.
        img_size: Input resolution; only 64 and 128 are supported.
        ndf: Accepted for interface compatibility; not used by this class.

    Raises:
        ValueError: If ``img_size`` is neither 64 nor 128.
    """

    def __init__(self, nc=1, img_size=64, ndf=64):
        super().__init__()

        self.nc = nc
        self.img_size = img_size

        if self.img_size == 64:
            # Entry layer: 64x64 -> 32x32
            self.conv1 = spectral_norm(
                nn.Conv2d(nc, 64, kernel_size=4, stride=2, padding=1)
            )
        elif self.img_size == 128:
            # Extra entry layer: 128x128 -> 64x64, then 64x64 -> 32x32
            self.conv0 = spectral_norm(
                nn.Conv2d(nc, 64, kernel_size=4, stride=2, padding=1)
            )
            self.conv1 = spectral_norm(
                nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1)
            )
        else:
            raise ValueError(f"Unsupported img_size: {img_size}")

        # Shared trunk: 32x32 -> 16x16 -> 8x8 -> 4x4 -> 2x2
        self.conv2 = spectral_norm(
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        )
        self.conv3 = spectral_norm(
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
        )
        self.conv4 = spectral_norm(
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1)
        )
        self.conv5 = spectral_norm(
            nn.Conv2d(512, 1, kernel_size=4, stride=2, padding=1)
        )

        self.lrelu = nn.LeakyReLU(0.2)
        # No sigmoid module: forward() returns both probabilities and logits.

        self._init_weights()

    def _init_weights(self):
        """Draw conv weights from N(0, 0.02) and zero the conv biases."""
        # NOTE(review): the convs are wrapped by spectral_norm; confirm that
        # writing to `.weight` here reaches the underlying trained parameter.
        for layer in self.modules():
            if not isinstance(layer, nn.Conv2d):
                continue
            nn.init.normal_(layer.weight, mean=0.0, std=0.02)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        """
        Score a batch of images.

        Args:
            x: Input image batch fed to the first convolution.

        Returns:
            Tuple ``(probabilities, logits)`` where ``logits`` is the raw
            Conv5 output and ``probabilities`` is its sigmoid.
        """
        # The 128x128 variant runs one additional entry convolution.
        if self.img_size == 128:
            entry_layers = (self.conv0, self.conv1)
        else:
            entry_layers = (self.conv1,)

        feat = x
        for conv in (*entry_layers, self.conv2, self.conv3, self.conv4):
            feat = self.lrelu(conv(feat))

        logits = self.conv5(feat)  # 2x2 map
        return torch.sigmoid(logits), logits

if __name__ == "__main__":
    # Smoke test: push a batch of random latents through both networks
    # (default 64x64, single-channel configuration) and print the shapes.
    latents = torch.randn(10, 128)

    generator = Generator()
    samples = generator(latents)
    # Expected: [10, 1, 64, 64]
    print(f"Generator output shape: {samples.shape}")

    discriminator = Discriminator()
    probabilities, _raw_logits = discriminator(samples)
    # Expected: [10, 1, 2, 2]
    print(f"Discriminator output shape: {probabilities.shape}")
