import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm


def weights_init(m):
    """Initialize one module in place; intended for use with ``Module.apply``.

    Dispatch is by class name:

    * names containing ``Conv`` or ``Linear`` get an orthogonal weight and a
      zero bias. If the module carries a ``weight_orig`` attribute (as created
      by ``torch.nn.utils.spectral_norm``), that tensor is initialized instead
      of ``weight``.
    * names containing ``BatchNorm`` get weight ~ N(1.0, 0.02) and zero bias.

    Any weight or bias that is absent or ``None`` (e.g. ``bias=False`` layers
    or ``affine=False`` batch norm) is skipped instead of raising.

    Args:
        m: the module to initialize. Modules whose class name matches none of
            the patterns above are left untouched.
    """
    classname = m.__class__.__name__

    if 'Conv' in classname or 'Linear' in classname:
        # Prefer weight_orig when present: it is the underlying parameter of a
        # spectral-normalized layer.
        weight = getattr(m, 'weight_orig', None)
        if weight is None:
            weight = getattr(m, 'weight', None)
        if weight is not None:
            nn.init.orthogonal_(weight)

        bias = getattr(m, 'bias', None)
        if bias is not None:
            nn.init.zeros_(bias)
    elif 'BatchNorm' in classname:
        # With affine=False both attributes are None, so guard each access.
        weight = getattr(m, 'weight', None)
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)

        bias = getattr(m, 'bias', None)
        if bias is not None:
            nn.init.zeros_(bias)


class GenResBlock(nn.Module):
    """Pre-activation residual block for the generator.

    Residual path: BN -> ReLU -> (optional 2x upsample) -> 3x3 conv ->
    BN -> ReLU -> 3x3 conv.
    Shortcut path: (optional 2x upsample) -> 1x1 conv when the channel count
    changes, otherwise identity.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        upsample: when True, both paths are upsampled by a factor of 2 with
            ``F.interpolate`` (default interpolation mode).
    """

    def __init__(self, in_ch, out_ch, upsample=False):
        super().__init__()
        self.upsample = upsample
        self.c1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
        self.c2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(in_ch)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.activation = nn.ReLU()

        # A projection is only needed on the shortcut when channels differ.
        if in_ch == out_ch:
            self.c_sc = nn.Identity()
        else:
            self.c_sc = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0)

    def _maybe_upsample(self, t):
        """Return ``t`` upsampled 2x if this block upsamples, else ``t``."""
        if not self.upsample:
            return t
        return F.interpolate(t, scale_factor=2)

    def forward(self, x):
        """Return ``residual(x) + shortcut(x)``."""
        # Residual branch: upsampling happens before the first convolution.
        residual = self.activation(self.bn1(x))
        residual = self.c1(self._maybe_upsample(residual))
        residual = self.c2(self.activation(self.bn2(residual)))

        # Shortcut branch: match spatial size first, then channel count.
        shortcut = self.c_sc(self._maybe_upsample(x))

        return residual + shortcut

class DiscResBlock(nn.Module):
    """Residual block for the discriminator; every convolution is spectral-normalized.

    Residual path: ReLU -> 3x3 conv -> ReLU -> 3x3 conv -> (optional 2x
    average pool).
    Shortcut path: 1x1 conv when the channel count changes (identity
    otherwise) -> (optional 2x average pool).

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        downsample: when True, both paths are reduced with ``F.avg_pool2d(., 2)``.
    """

    def __init__(self, in_ch, out_ch, downsample=False):
        super().__init__()
        self.downsample = downsample
        self.activation = nn.ReLU()

        self.c1 = spectral_norm(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
        )
        self.c2 = spectral_norm(
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1)
        )

        # A projection is only needed on the shortcut when channels differ.
        if in_ch == out_ch:
            self.c_sc = nn.Identity()
        else:
            self.c_sc = spectral_norm(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0)
            )

    def _maybe_pool(self, t):
        """Return ``t`` average-pooled 2x if this block downsamples, else ``t``."""
        if not self.downsample:
            return t
        return F.avg_pool2d(t, 2)

    def forward(self, x):
        """Return ``residual(x) + shortcut(x)``."""
        residual = self.c1(self.activation(x))
        residual = self.c2(self.activation(residual))
        residual = self._maybe_pool(residual)

        shortcut = self._maybe_pool(self.c_sc(x))

        return residual + shortcut

class Generator(nn.Module):
    """ResNet-style generator that grows a 4x4 feature map into an image.

    A linear layer maps the latent vector to a ``(ngf * 16, 4, 4)`` tensor,
    which then passes through a chain of upsampling ``GenResBlock`` stages
    (each doubles the spatial size), followed by BN -> ReLU -> 3x3 conv -> tanh.

    The number of stages is selected from ``img_size``:

    * ``img_size < 128``: 4 stages (4 -> 64), final width ``ngf * 2``
    * ``128 <= img_size < 256``: 5 stages (4 -> 128), final width ``ngf``
    * ``img_size >= 256``: 6 stages (4 -> 256), final width ``ngf // 2``

    Args:
        z_dim: size of the latent vector.
        nc: number of output image channels.
        img_size: target resolution; only used to pick the stage count above.
        ngf: base channel width.
    """

    def __init__(self, z_dim=128, nc=1, img_size=128, ngf=64):
        super().__init__()
        self.z_dim = z_dim
        self.img_size = img_size

        # Channel width at every resolution, starting with the 4x4 seed map.
        widths = [ngf * 16, ngf * 16, ngf * 8, ngf * 4, ngf * 2]
        if img_size >= 128:
            widths.append(ngf)  # extra stage: 64 -> 128
        if img_size >= 256:
            widths.append(ngf // 2)  # extra stage: 128 -> 256

        self.linear = nn.Linear(z_dim, widths[0] * 4 * 4)

        # One upsampling block per consecutive pair of widths.
        self.blocks = nn.ModuleList([
            GenResBlock(src, dst, upsample=True)
            for src, dst in zip(widths[:-1], widths[1:])
        ])

        final_ch = widths[-1]
        self.bn_end = nn.BatchNorm2d(final_ch)
        self.activation = nn.ReLU()
        self.conv_end = nn.Conv2d(final_ch, nc, kernel_size=3, stride=1, padding=1)
        self.tanh = nn.Tanh()

        self.apply(weights_init)

    def forward(self, z):
        """Map latent vectors ``z`` to images squashed by tanh."""
        # Project, then fold the flat features into a 4x4 map.
        feat = self.linear(z)
        feat = feat.view(-1, feat.shape[1] // 16, 4, 4)

        for stage in self.blocks:
            feat = stage(feat)

        # Output head.
        feat = self.conv_end(self.activation(self.bn_end(feat)))
        return self.tanh(feat)

class Discriminator(nn.Module):
    """ResNet-style discriminator with spectral normalization on every layer.

    Layout (selected by ``img_size``):

    * head: a spectral-normalized 3x3 conv. For ``img_size`` 128 or 256 it is
      followed by ReLU and a 2x average pool; for any other value the conv is
      used alone with no downsampling. The head width is ``ndf // 2`` for 256
      and ``ndf`` otherwise.
    * for ``img_size == 256`` only: one extra downsampling ``DiscResBlock``
      bringing the width to ``ndf``.
    * four downsampling ``DiscResBlock`` stages widening to
      ``ndf * 2, ndf * 4, ndf * 8, ndf * 16``.
    * ReLU -> sum over the spatial dimensions -> spectral-normalized linear
      layer producing one logit per sample.

    Args:
        nc: number of input image channels.
        img_size: input resolution; only used to pick the layout above.
        ndf: base channel width.
    """

    def __init__(self, nc=1, img_size=128, ndf=64):
        super().__init__()
        self.img_size = img_size

        # Head layer.
        if img_size in (128, 256):
            head_ch = ndf // 2 if img_size == 256 else ndf
            head = nn.Sequential(
                spectral_norm(nn.Conv2d(nc, head_ch, kernel_size=3, stride=1, padding=1)),
                nn.ReLU(),
                nn.AvgPool2d(2),
            )
        else:
            # No downsampling at the head for the smallest configuration.
            head_ch = ndf
            head = spectral_norm(nn.Conv2d(nc, head_ch, kernel_size=3, stride=1, padding=1))

        # Channel width after the head and after each residual stage.
        widths = [head_ch]
        if img_size == 256:
            widths.append(ndf)  # extra stage: 128 -> 64
        widths += [ndf * 2, ndf * 4, ndf * 8, ndf * 16]

        stages = [
            DiscResBlock(src, dst, downsample=True)
            for src, dst in zip(widths[:-1], widths[1:])
        ]
        self.blocks = nn.ModuleList([head, *stages])

        self.activation = nn.ReLU()
        self.linear = spectral_norm(nn.Linear(ndf * 16, 1))

        self.apply(weights_init)

    def forward(self, x):
        """Score a batch of images.

        Returns:
            A tuple ``(probabilities, logits)``, both reshaped to
            ``(-1, 1, 1, 1)``; ``probabilities`` is ``sigmoid(logits)``.
        """
        feat = x
        for stage in self.blocks:
            feat = stage(feat)

        # Global sum pooling over height and width.
        pooled = torch.sum(self.activation(feat), dim=[2, 3])
        logits = self.linear(pooled).view(-1, 1, 1, 1)

        return torch.sigmoid(logits), logits

if __name__ == "__main__":
    # Smoke test: build both networks at two resolutions and print the output
    # shapes. Gradients are not needed for a shape check.
    device = "cpu"
    z = torch.randn(10, 128).to(device)

    # Test 128x128, RGB
    G = Generator(z_dim=128, nc=3, img_size=128).to(device)
    D = Discriminator(nc=3, img_size=128).to(device)

    with torch.no_grad():
        fake_img = G(z)
        print(f"G output: {fake_img.shape}") # Should be [10, 3, 128, 128]

        prob, logits = D(fake_img)
        print(f"D output: {prob.shape}, {logits.shape}") # Should be [10, 1, 1, 1]

    # Verify compatibility with 64x64
    G_64 = Generator(z_dim=128, nc=3, img_size=64).to(device)
    D_64 = Discriminator(nc=3, img_size=64).to(device)

    with torch.no_grad():
        fake_img_64 = G_64(z)
        print(f"G_64 output: {fake_img_64.shape}") # Should be [10, 3, 64, 64]

        # Also run the 64x64 discriminator so the pair is checked end to end.
        prob_64, logits_64 = D_64(fake_img_64)
        print(f"D_64 output: {prob_64.shape}, {logits_64.shape}") # Should be [10, 1, 1, 1]