import torch
import torch.nn as nn
import torch.nn.functional as F
    
####################################
#       Omniglot Decoder          #
###################################
class OmniglotDecoder(nn.Module):
    """Decode a 128-channel feature map into a single-channel 28x28 image.

    The per-layer size notes assume an input of shape ``(B, 128, 1, 1)``.
    Every ConvTranspose2d follows the PyTorch output-size formula (dilation=1)::

        H_out = (H_in - 1) * stride - 2 * padding + (kernel_size - 1) + output_padding + 1

    which gives the spatial progression 1 -> 5 -> 5 -> 13 -> 13 -> 28 -> 28.
    The final ReLU makes every output value non-negative but does not bound it
    above.
    """

    def __init__(self):
        super(OmniglotDecoder, self).__init__()
        self.decoder = nn.Sequential(
            # 1 -> 5: (1-1)*2 - 0 + 3 + 1 + 1 = 5   (4x4 kernel, stride-2 upsample)
            nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=0, output_padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # 5 -> 5: size-preserving 3x3 refinement
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            # 5 -> 13: (5-1)*2 - 0 + 3 + 1 + 1 = 13   (4x4 kernel, stride-2 upsample)
            nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=0, output_padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            # 13 -> 13: size-preserving 3x3 refinement
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            # 13 -> 28: (13-1)*2 - 0 + 3 + 0 + 1 = 28   (4x4 kernel, stride-2 upsample)
            # output_padding must stay 0 here; output_padding=1 would yield 29.
            nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2, padding=0, output_padding=0),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            # 28 -> 28: project down to a single output channel
            nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1),
            # ReLU keeps outputs non-negative; swap in nn.Sigmoid() if outputs
            # must lie in [0, 1] instead.
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run ``x`` through the decoder stack and return the decoded image."""
        x = self.decoder(x)
        return x
    
#######################################
#         28*28 Decoder         #
#######################################
class Decoder28x28(nn.Module):
    """Upsample a 32-channel feature map to a single-channel 28x28 image.

    Three 3x3, stride-2 transposed convolutions take a 32x3x3 input through
    16x7x7 and 8x14x14 to 1x28x28.  The first two stages are followed by
    BatchNorm + ReLU; the last stage is followed by ReLU only, so the output
    is non-negative.
    """

    def __init__(self):
        super().__init__()
        # Each entry: (in_channels, out_channels, padding, output_padding).
        stages = (
            (32, 16, 0, 0),  # 3x3   -> 16x7x7
            (16, 8, 1, 1),   # 7x7   -> 8x14x14
            (8, 1, 1, 1),    # 14x14 -> 1x28x28
        )
        last = len(stages) - 1
        layers = []
        for idx, (c_in, c_out, pad, out_pad) in enumerate(stages):
            layers.append(
                nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2,
                                   padding=pad, output_padding=out_pad)
            )
            # The final stage skips BatchNorm and goes straight to the pixel activation.
            if idx != last:
                layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(inplace=True))
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        """Decode ``x`` with the transposed-convolution stack."""
        return self.decoder(x)

    
####################################
#          RN Block                #
####################################
class ResnetBlock(nn.Module):
    """Pre-activation residual block computing ``shortcut(x) + 0.1 * residual(x)``.

    The residual branch is BN -> ReLU -> 3x3 conv (fin -> fhidden) followed by
    BN -> ReLU -> 3x3 conv (fhidden -> fout); both convolutions preserve the
    spatial size.  The shortcut is the identity when ``fin == fout`` and a
    bias-free 1x1 convolution otherwise.

    Args:
        fin: number of input channels.
        fout: number of output channels.
        fhidden: channels between the two 3x3 convs; defaults to ``min(fin, fout)``.
        is_bias: whether the second 3x3 conv has a bias term.
    """

    def __init__(self, fin, fout, fhidden=None, is_bias=True):
        super().__init__()

        self.learned_shortcut = fin != fout
        self.fin, self.fout = fin, fout
        self.fhidden = min(fin, fout) if fhidden is None else fhidden
        hidden = self.fhidden

        # Submodules are created in a fixed order so seeded initialisation is reproducible.
        self.conv_0 = nn.Conv2d(fin, hidden, 3, stride=1, padding=1)
        self.conv_1 = nn.Conv2d(hidden, fout, 3, stride=1, padding=1, bias=is_bias)
        if self.learned_shortcut:
            # 1x1 projection so the skip path matches the residual's channel count.
            self.conv_s = nn.Conv2d(fin, fout, 1, stride=1, padding=0, bias=False)
        self.bn0 = nn.BatchNorm2d(fin)
        self.bn1 = nn.BatchNorm2d(hidden)

    def forward(self, x):
        """Return the skip path plus the residual branch scaled by 0.1."""
        skip = self._shortcut(x)
        residual = self.conv_0(F.relu(self.bn0(x)))
        residual = self.conv_1(F.relu(self.bn1(residual)))
        return skip + 0.1 * residual

    def _shortcut(self, x):
        """Project ``x`` with ``conv_s`` when channel counts differ, else pass it through."""
        return self.conv_s(x) if self.learned_shortcut else x
    
class Resnet_Decoder(nn.Module):
    """Residual decoder that grows a feature map by repeated 2x bilinear upsampling.

    ``nlayers = int(log2(size / s0))`` stages each apply a ResnetBlock that halves
    the channel count (every count capped at ``nf_max``) followed by a 2x bilinear
    upsample.  A final ResnetBlock, BN -> ReLU, and a size-preserving 3x3
    transposed conv produce 3 channels, followed by a ReLU.

    The input must have ``self.nf0`` channels.  Spatial size is multiplied by
    ``2 ** nlayers``, so an ``s0 x s0`` input yields a ``size x size`` output only
    when ``size / s0`` is a power of two; otherwise ``log2`` is truncated and the
    output is smaller than ``size``.

    Args:
        s0: spatial size of the input feature map the layer count is derived from.
        nf: channel count of the last stage (capped at ``nf_max``).
        nf_max: upper bound on the channel count of any stage.
        size: target output height/width.
    """

    def __init__(self, s0=2, nf=8, nf_max=256, size=32):
        super(Resnet_Decoder, self).__init__()

        self.s0 = s0
        self.nf = nf
        self.nf_max = nf_max

        # Number of 2x upsampling stages (truncated toward zero).
        nlayers = int(torch.log2(torch.tensor(size / s0).float()))
        # Channel count expected on the input feature map.
        self.nf0 = min(nf_max, nf * 2 ** nlayers)

        blocks = []
        for i in range(nlayers):
            nf0 = min(nf * 2 ** (nlayers - i), nf_max)
            nf1 = min(nf * 2 ** (nlayers - i - 1), nf_max)
            blocks += [
                ResnetBlock(nf0, nf1),
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            ]
        # The last stage emits min(nf, nf_max) channels (the same cap as every
        # stage above).  Sizing the tail with bare `nf` made the model unusable
        # whenever nf > nf_max; for nf <= nf_max this is identical to `nf`.
        nf_out = min(nf, nf_max)
        blocks += [
            ResnetBlock(nf_out, nf_out),
        ]
        self.resnet = nn.Sequential(*blocks)

        self.bn0 = nn.BatchNorm2d(nf_out)
        # kernel 3, stride 1, padding 1 keeps the spatial size unchanged.
        self.conv_img = nn.ConvTranspose2d(nf_out, 3, kernel_size=3, padding=1)

    def forward(self, z):
        """Decode feature map ``z`` (``self.nf0`` channels) into a non-negative 3-channel image."""
        out = self.resnet(z)
        out = self.conv_img(F.relu(self.bn0(out)))
        out = F.relu(out)
        return out
