from torch import nn
import numpy as np
import torch
from torch.nn import functional as F
from .DataGeneratorImg import DataGeneratorImg
from .DataGeneratorText import DataGeneratorText
from .FeatureExtractorImg import FeatureExtractorImg
from .FeatureExtractorText import FeatureExtractorText
from .ResidualBlocks import ResidualBlock1dConv, ResidualBlock1dTransposeConv

def actvn(x):
    """Apply the file's shared activation: leaky ReLU with negative slope 0.2.

    Args:
        x: Input tensor.

    Returns:
        Tensor of the same shape with negative entries scaled by 0.2.
    """
    return F.leaky_relu(x, negative_slope=2e-1)


class ResnetBlock(nn.Module):
    """Pre-activation residual block with two 3x3 convolutions.

    The residual branch is ``conv_1(actvn(conv_0(actvn(x))))``; it is scaled by
    0.1 and added to the shortcut. Both convolutions use stride 1 and padding 1,
    so spatial size is preserved. When ``fin != fout`` the shortcut is a
    bias-free 1x1 convolution (``conv_s``); otherwise it is the identity.

    Args:
        fin: Number of input channels.
        fout: Number of output channels.
        fhidden: Channels between the two convolutions; defaults to ``min(fin, fout)``.
        is_bias: Whether ``conv_1`` has a bias term (``conv_0`` always has one).
    """

    def __init__(self, fin, fout, fhidden=None, is_bias=True):
        super().__init__()
        self.is_bias = is_bias
        self.learned_shortcut = fin != fout
        self.fin = fin
        self.fout = fout
        self.fhidden = min(fin, fout) if fhidden is None else fhidden

        # Creation order (conv_0, conv_1, conv_s) fixes parameter order and RNG use.
        self.conv_0 = nn.Conv2d(self.fin, self.fhidden, 3, stride=1, padding=1)
        self.conv_1 = nn.Conv2d(
            self.fhidden, self.fout, 3, stride=1, padding=1, bias=is_bias
        )
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(
                self.fin, self.fout, 1, stride=1, padding=0, bias=False
            )

    def forward(self, x):
        shortcut = self._shortcut(x)
        residual = self.conv_1(actvn(self.conv_0(actvn(x))))
        return shortcut + 0.1 * residual

    def _shortcut(self, x):
        """Return the skip-path tensor: ``conv_s(x)`` if channels change, else ``x``."""
        return self.conv_s(x) if self.learned_shortcut else x

class Flatten(torch.nn.Module):
    """Collapse every dimension after the batch dimension into one.

    Uses ``Tensor.view``, so the input must be viewable (e.g. contiguous).
    """

    def forward(self, x):
        batch_size = x.size(0)
        return x.view(batch_size, -1)


class Unflatten(torch.nn.Module):
    """Reshape ``(batch, N)`` into ``(batch, *ndims)``.

    Args:
        ndims: Target per-sample shape, e.g. ``(128, 4, 4)``; its product must
            equal the number of elements per sample.
    """

    def __init__(self, ndims):
        super().__init__()
        self.ndims = ndims

    def forward(self, x):
        target_shape = (x.size(0), *self.ndims)
        return x.view(target_shape)

# Encoders
class EncoderMNIST(nn.Module):
    """MLP encoder over flattened images with 784 elements per sample.

    Args:
        hidden_dim: Width of every hidden layer and of the output.
        num_hidden_layers: Total number of Linear+ReLU layers; values below 1
            still produce the single input layer.
    """

    def __init__(self, hidden_dim, num_hidden_layers):
        super().__init__()
        first = nn.Sequential(nn.Linear(784, hidden_dim), nn.ReLU(True))
        rest = []
        for _ in range(num_hidden_layers - 1):
            rest.append(nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU(True)))
        self.enc = nn.Sequential(first, *rest)

    def forward(self, x):
        # Flatten the last three dims (presumably (1, 28, 28) -> 784; confirm).
        flat = x.view(*x.size()[:-3], -1)
        features = self.enc(flat)
        # NOTE(review): with extra leading dims (e.g. (K, B, ...)) this merges
        # everything after dim 0 into one axis; confirm that is intended.
        return features.view(features.size(0), -1)

class EncoderMMNIST(nn.Module):
    """Convolutional encoder for 3x28x28 images.

    Adapted from:
    https://www.cs.toronto.edu/~lczhang/360/lec/w05/autoencoder.html

    Three stride-2 convolutions take a (3, 28, 28) input to (128, 4, 4). The
    result is flattened to 2048 features, projected to ``hidden_dim`` and passed
    through a ReLU.

    Args:
        hidden_dim: Size of the output feature vector.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        layers = [
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),    # (3, 28, 28) -> (32, 14, 14)
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),   # -> (64, 7, 7)
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # -> (128, 4, 4)
            nn.ReLU(),
            Flatten(),                                               # -> (2048,)
            nn.Linear(2048, hidden_dim),                             # -> (hidden_dim,)
            nn.ReLU(),
        ]
        # Same layer order as before, so state_dict keys ("shared_encoder.<i>.*") match.
        self.shared_encoder = nn.Sequential(*layers)

    def forward(self, x):
        return self.shared_encoder(x)

class EncoderTransMMNIST(nn.Module):
    """ResNet-style encoder for 3x28x28 images.

    ``conv_img`` lifts the input to 64 channels. A stack of ``ResnetBlock``s
    interleaved with stride-2 average pooling then halves the resolution twice
    (28 -> 14 -> 7) while doubling the channels (64 -> 128 -> 256). ``fc_mu``
    maps the flattened (256, 7, 7) map to ``hidden_dim`` features. A leaky ReLU
    (``actvn``) is applied before and after ``fc_mu``.

    Args:
        hidden_dim: Size of the output feature vector.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.s0 = s0 = 7
        self.nf = nf = 64
        self.nf_max = nf_max = 1024
        size = 28

        nlayers = int(np.log2(size / s0))
        self.nf0 = min(nf_max, nf * 2 ** nlayers)

        def width(level):
            return min(nf * 2 ** level, nf_max)

        # The residual blocks are built before conv_img; keep this order so
        # parameter initialisation consumes the RNG exactly as before.
        blocks = [ResnetBlock(nf, nf)]
        for level in range(nlayers):
            blocks.append(nn.AvgPool2d(3, stride=2, padding=1))
            blocks.append(ResnetBlock(width(level), width(level + 1)))

        self.conv_img = nn.Conv2d(3, 1 * nf, 3, padding=1)
        self.resnet = nn.Sequential(*blocks)
        self.fc_mu = nn.Linear(self.nf0 * s0 * s0, hidden_dim)

    def forward(self, x):
        batch_size = x.size(0)
        features = actvn(self.resnet(self.conv_img(x)))
        flat = features.view(batch_size, self.nf0 * self.s0 * self.s0)
        return actvn(self.fc_mu(flat))

class EncoderMNISTLabel(nn.Module):
    """Two-layer MLP encoder for 10-dimensional label vectors.

    Args:
        hidden_dim: Width of both layers and of the (ReLU-activated) output.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(10, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        return F.relu(self.fc2(hidden))

class EncoderSVHN(nn.Module):
    """Four-layer convolutional encoder for 3-channel images.

    Every convolution has kernel 4 and stride 2 and is followed by a ReLU; the
    result is flattened per sample. For a 32x32 input the spatial size goes
    32 -> 16 -> 8 -> 4 -> 1, giving 128 features.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1, dilation=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1, dilation=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1, dilation=1)
        self.conv4 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=0, dilation=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        h = x
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            h = self.relu(conv(h))
        return h.view(h.size(0), -1)

class EncoderText(nn.Module):
    """1-D convolutional encoder for sequences of per-position feature vectors.

    Expects input of shape ``(batch, length, num_features)``; the convolutions
    run over the length axis. A kernel-1 convolution is followed by two
    kernel-4 stride-2 convolutions, each with a ReLU. For a length-8 input
    (the length produced by ``DecoderText``) the sequence collapses to a single
    position, which is returned as a ``(batch, 2 * dim)`` feature vector.

    Args:
        dim: Half the number of hidden/output channels.
        num_features: Number of input features per position.

    Raises:
        ValueError: If the convolutions do not reduce the sequence to exactly
            one position. Previously such inputs were silently reshaped so that
            positions leaked into the batch axis.
    """

    def __init__(self, dim, num_features):
        super().__init__()
        self.dim = dim
        self.conv1 = nn.Conv1d(num_features, 2 * self.dim, kernel_size=1)
        self.conv2 = nn.Conv1d(2 * self.dim, 2 * self.dim, kernel_size=4, stride=2, padding=1, dilation=1)
        self.conv5 = nn.Conv1d(2 * self.dim, 2 * self.dim, kernel_size=4, stride=2, padding=0, dilation=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.transpose(-2, -1)  # -> (batch, num_features, length) for Conv1d
        out = self.relu(self.conv1(x))
        out = self.relu(self.conv2(out))
        out = self.relu(self.conv5(out))
        # Only a single remaining position can be viewed as (batch, 2*dim)
        # without interleaving positions into the batch dimension.
        if out.size(-1) != 1:
            raise ValueError(
                f"EncoderText expects the sequence to reduce to one position, "
                f"got {out.size(-1)} positions from input length {x.size(-1)}"
            )
        return out.view(-1, 2 * self.dim)

class EncoderCelebA(nn.Module):
    """Image encoder: ``FeatureExtractorImg`` then Linear(640, 128) + BatchNorm1d + ReLU.

    Args:
        a_img: Forwarded to ``FeatureExtractorImg`` as ``a``.
        b_img: Forwarded to ``FeatureExtractorImg`` as ``b``.
    """

    def __init__(self, a_img=2.0, b_img=0.3):
        super().__init__()
        self.feature_extractor = FeatureExtractorImg(a=a_img, b=b_img)
        # NOTE(review): 640 presumably equals the flattened extractor output
        # (an earlier comment read "5*DIM_text"); confirm against FeatureExtractorImg.
        self.output = nn.Sequential(nn.Linear(640, 128), nn.BatchNorm1d(128), nn.ReLU(True))

    def forward(self, x0):
        features = self.feature_extractor(x0)
        flat = features.view(features.size(0), -1)
        return self.output(flat)

class EncoderCelebAText(nn.Module):
    """Text encoder: ``FeatureExtractorText`` then Linear(640, 128) + BatchNorm1d + ReLU.

    Args:
        a_text: Forwarded to ``FeatureExtractorText`` as ``a``.
        b_text: Forwarded to ``FeatureExtractorText`` as ``b``.
    """

    def __init__(self, a_text=2.0, b_text=0.3):
        super().__init__()
        self.feature_extractor = FeatureExtractorText(a=a_text, b=b_text)
        # NOTE(review): 640 presumably is 5*DIM_text, the flattened extractor
        # output; confirm against FeatureExtractorText.
        self.output = nn.Sequential(nn.Linear(640, 128), nn.BatchNorm1d(128), nn.ReLU(True))

    def forward(self, x1):
        features = self.feature_extractor(x1)
        flat = features.view(features.size(0), -1)
        return self.output(flat)

class EncoderCelebAAttribute(nn.Module):
    """Two-layer MLP encoder for attribute vectors.

    Args:
        y_dim: Number of input attributes.
        hidden_dim: Width of both layers and of the (ReLU-activated) output.
    """

    def __init__(self, y_dim=40, hidden_dim=128):
        super().__init__()
        layers = [nn.Linear(y_dim, hidden_dim), nn.ReLU()]
        layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)

class EncoderCubImage(nn.Module):
    """ResNet-style image encoder that flattens an ``nf0 x s0 x s0`` feature map.

    ``conv_img`` lifts the 3-channel input to ``nfilter`` channels. A first
    ``ResnetBlock`` runs at full resolution, followed by ``int(log2(img_size / s0))``
    rounds of stride-2 average pooling plus a ``ResnetBlock`` that doubles the
    channel count (capped at ``nfilter_max``). The result is flattened to
    ``nf0 * s0 * s0`` features with no output activation; this assumes
    ``img_size / s0`` is a power of two.

    Args:
        img_size: Side length of the square input image.
        s0: Side length of the final feature map.
        nfilter: Channel count after ``conv_img``.
        nfilter_max: Upper bound on the channel count.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, img_size=64, s0=1, nfilter=64, nfilter_max=1024, **kwargs):
        super().__init__()
        self.s0 = s0
        self.nf = nf = nfilter
        self.nf_max = nf_max = nfilter_max

        nlayers = int(np.log2(img_size / s0))
        self.nf0 = min(nf_max, nf * 2 ** nlayers)

        # Residual blocks are instantiated before conv_img; keep this order so
        # parameter initialisation consumes the RNG exactly as before.
        blocks = [ResnetBlock(nf, nf)]
        for level in range(nlayers):
            in_ch = min(nf * 2 ** level, nf_max)
            out_ch = min(nf * 2 ** (level + 1), nf_max)
            blocks.extend([nn.AvgPool2d(3, stride=2, padding=1), ResnetBlock(in_ch, out_ch)])

        self.conv_img = nn.Conv2d(3, 1 * nf, 3, padding=1)
        self.resnet = nn.Sequential(*blocks)

    def forward(self, x):
        batch_size = x.size(0)
        features = self.resnet(self.conv_img(x))
        return features.view(batch_size, self.nf0 * self.s0 * self.s0)

class EncoderCubText(nn.Module):
    """Residual 1-D convolutional text encoder.

    ``forward`` transposes its input from ``(batch, length, vocab_size)`` to
    ``(batch, vocab_size, length)``. It then applies ``conv1`` (vocab_size ->
    dim_text channels, kernel 4, stride 2) and ``resblock_1``..``resblock_4``
    (channels dim_text -> 2*dim_text -> 3*dim_text -> 4*dim_text -> 5*dim_text).
    Finally it flattens each sample and applies an in-place ReLU.

    NOTE(review): ``resblock_5`` and ``resblock_6`` are constructed, so they own
    parameters and appear in ``state_dict``, but ``forward`` never calls them.
    Deleting them would change the ``state_dict`` keys and break strict loading
    of existing checkpoints. Confirm whether they are still needed.

    Args:
        dim_text: Base channel width.
        vocab_size: Number of input features per sequence position.
    """
    def __init__(self, dim_text=32, vocab_size=1590):
        super().__init__()
        # No-op self-assignment.
        dim_text = dim_text
        num_features = vocab_size

        # Initial stride-2 convolution: vocab_size -> dim_text channels.
        self.conv1 = nn.Conv1d(num_features, dim_text,
                               kernel_size=4, stride=2, padding=1, dilation=1)
        self.resblock_1 = self.make_res_block_encoder_feature_extractor(dim_text,
                                                                                 2 * dim_text,
                                                                                 kernelsize=4, stride=2, padding=1,
                                                                                 dilation=1)
        self.resblock_2 = self.make_res_block_encoder_feature_extractor(2 * dim_text,
                                                                                 3 * dim_text,
                                                                                 kernelsize=4, stride=2, padding=1,
                                                                                 dilation=1)
        self.resblock_3 = self.make_res_block_encoder_feature_extractor(3 * dim_text,
                                                                                 4 * dim_text,
                                                                                 kernelsize=4, stride=2, padding=1,
                                                                                 dilation=1)
        self.resblock_4 = self.make_res_block_encoder_feature_extractor(4 * dim_text,
                                                                                 5 * dim_text,
                                                                                 kernelsize=4, stride=2, padding=1,
                                                                                 dilation=1)
        # NOTE(review): resblock_5 and resblock_6 are registered but unused by forward.
        self.resblock_5 = self.make_res_block_encoder_feature_extractor(5 * dim_text,
                                                                                 5 * dim_text,
                                                                                 kernelsize=4, stride=2, padding=1,
                                                                                 dilation=1)
        self.resblock_6 = self.make_res_block_encoder_feature_extractor(5 * dim_text,
                                                                                 5 * dim_text,
                                                                                 kernelsize=4, stride=2, padding=0,
                                                                                 dilation=1)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        batch_size = x.size()[0]
        # Conv1d wants (batch, channels, length); move the feature axis to dim -2.
        x = x.transpose(-2, -1)
        out = self.conv1(x)
        out = self.resblock_1(out)
        out = self.resblock_2(out)
        out = self.resblock_3(out)
        out = self.resblock_4(out)
        # Flatten channels x remaining positions into one vector per sample.
        out = out.view(batch_size, -1)
        out = self.relu(out)
        return out

    @staticmethod
    def make_res_block_encoder_feature_extractor(in_channels, out_channels, kernelsize, stride, padding, dilation,
                                                 a_val=2.0, b_val=0.3):
        """Build one ``ResidualBlock1dConv`` wrapped in an ``nn.Sequential``.

        When the stride, channel count, or dilation changes, a ``Conv1d`` (same
        kernel, stride, padding, dilation) followed by ``BatchNorm1d`` is built
        and passed as ``downsample``. Otherwise ``downsample`` is ``None``.
        ``a_val`` / ``b_val`` are forwarded as ``a`` / ``b``.

        NOTE(review): how ResidualBlock1dConv uses ``downsample``, ``a`` and
        ``b`` is defined in ResidualBlocks and not visible here.
        """
        downsample = None
        if (stride != 1) or (in_channels != out_channels) or dilation != 1:
            downsample = nn.Sequential(nn.Conv1d(in_channels, out_channels,
                                                 kernel_size=kernelsize,
                                                 stride=stride,
                                                 padding=padding,
                                                 dilation=dilation),
                                       nn.BatchNorm1d(out_channels))
        layers = []
        layers.append(
            ResidualBlock1dConv(in_channels, out_channels, kernelsize, stride, padding, dilation, downsample, a=a_val,
                                b=b_val))
        return nn.Sequential(*layers)

class EncoderCubText2(nn.Module):
    """Shallower residual 1-D convolutional text encoder.

    ``forward`` transposes its input from ``(batch, length, vocab_size)`` to
    ``(batch, vocab_size, length)``. It then applies ``conv1`` (vocab_size ->
    4*dim_text channels, kernel 4, stride 2), ``resblock_1`` (4*dim_text ->
    5*dim_text, stride 1) and ``resblock_2`` / ``resblock_3`` (5*dim_text,
    stride 2). Finally it flattens each sample and applies an in-place ReLU.

    Args:
        dim_text: Base channel width.
        vocab_size: Number of input features per sequence position.
    """
    def __init__(self, dim_text=32, vocab_size=1590):
        super().__init__()
        # No-op self-assignment.
        dim_text = dim_text
        num_features = vocab_size

        # Initial stride-2 convolution: vocab_size -> 4*dim_text channels.
        self.conv1 = nn.Conv1d(num_features, 4*dim_text,
                               kernel_size=4, stride=2, padding=1, dilation=1)
        # NOTE(review): kernel 4 / stride 1 / padding 1 means the Conv1d built for
        # `downsample` here shortens the length by one; confirm this is intended.
        self.resblock_1 = self.make_res_block_encoder_feature_extractor(4 * dim_text,
                                                                                 5 * dim_text,
                                                                                 kernelsize=4, stride=1, padding=1,
                                                                                 dilation=1)
        self.resblock_2 = self.make_res_block_encoder_feature_extractor(5 * dim_text,
                                                                                 5 * dim_text,
                                                                                 kernelsize=4, stride=2, padding=1,
                                                                                 dilation=1)
        self.resblock_3 = self.make_res_block_encoder_feature_extractor(5 * dim_text,
                                                                                 5 * dim_text,
                                                                                 kernelsize=4, stride=2, padding=1,
                                                                                 dilation=1)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        batch_size = x.size()[0]
        # Conv1d wants (batch, channels, length); move the feature axis to dim -2.
        x = x.transpose(-2, -1)
        out = self.conv1(x)
        out = self.resblock_1(out)
        out = self.resblock_2(out)
        out = self.resblock_3(out)
        # Flatten channels x remaining positions into one vector per sample.
        out = out.view(batch_size, -1)
        out = self.relu(out)
        return out

    @staticmethod
    def make_res_block_encoder_feature_extractor(in_channels, out_channels, kernelsize, stride, padding, dilation,
                                                 a_val=2.0, b_val=0.3):
        """Build one ``ResidualBlock1dConv`` wrapped in an ``nn.Sequential``.

        Same code as ``EncoderCubText.make_res_block_encoder_feature_extractor``.
        When the stride, channel count, or dilation changes, a ``Conv1d`` +
        ``BatchNorm1d`` is built and passed as ``downsample``; otherwise
        ``downsample`` is ``None``. ``a_val`` / ``b_val`` are forwarded as
        ``a`` / ``b``.

        NOTE(review): the semantics of ``downsample``, ``a`` and ``b`` live in
        ResidualBlock1dConv, which is not visible here.
        """
        downsample = None
        if (stride != 1) or (in_channels != out_channels) or dilation != 1:
            downsample = nn.Sequential(nn.Conv1d(in_channels, out_channels,
                                                 kernel_size=kernelsize,
                                                 stride=stride,
                                                 padding=padding,
                                                 dilation=dilation),
                                       nn.BatchNorm1d(out_channels))
        layers = []
        layers.append(
            ResidualBlock1dConv(in_channels, out_channels, kernelsize, stride, padding, dilation, downsample, a=a_val,
                                b=b_val))
        return nn.Sequential(*layers)

# Decoders
class DecoderMNIST(nn.Module):
    """MLP decoder producing single-channel square images with values in [0, 1].

    Any leading dimensions of ``z`` are preserved: ``(..., dim)`` maps to
    ``(..., 1, img_size_mnist, img_size_mnist)``.

    Args:
        dim: Size of the latent input.
        hidden_dim: Width of every hidden layer.
        num_hidden_layers: Total number of Linear+ReLU hidden layers; values
            below 1 still produce the single input layer.
        img_size_mnist: Side length of the output image. The output layer has
            ``img_size_mnist ** 2`` units (784 for the usual 28).
    """

    def __init__(self, dim, hidden_dim, num_hidden_layers, img_size_mnist):
        super().__init__()
        modules = [nn.Sequential(nn.Linear(dim, hidden_dim), nn.ReLU(True))]
        modules.extend(
            nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU(True))
            for _ in range(num_hidden_layers - 1)
        )
        self.dec = nn.Sequential(*modules)
        # Sized from img_size_mnist so the final view always matches; it was
        # hard-coded to 784, which made every size other than 28 crash.
        self.fc3 = nn.Linear(hidden_dim, img_size_mnist * img_size_mnist)
        self.relu = nn.ReLU()  # not used by forward
        self.sigmoid = nn.Sigmoid()
        self.img_size_mnist = img_size_mnist

    def forward(self, z):
        x_hat = self.sigmoid(self.fc3(self.dec(z)))
        return x_hat.view(*z.size()[:-1], 1, self.img_size_mnist, self.img_size_mnist)

class DecoderMMNIST(nn.Module):
    """Transposed-convolution decoder producing 3x28x28 images.

    Adapted from:
    https://www.cs.toronto.edu/~lczhang/360/lec/w05/autoencoder.html

    The latent vector is projected to 2048 features and reshaped to
    (128, 4, 4). Three stride-2 transposed convolutions then bring it to
    (3, 28, 28). There is no output activation.

    Args:
        dim: Size of the latent input.
    """

    def __init__(self, dim):
        super().__init__()
        layers = [
            nn.Linear(dim, 2048),                                                              # -> (2048,)
            nn.ReLU(),
            Unflatten((128, 4, 4)),                                                            # -> (128, 4, 4)
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1),                   # -> (64, 7, 7)
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),  # -> (32, 14, 14)
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=3, stride=2, padding=1, output_padding=1),   # -> (3, 28, 28)
        ]
        # Same layer order as before, so state_dict keys ("decoder.<i>.*") match.
        self.decoder = nn.Sequential(*layers)

    def forward(self, z):
        return self.decoder(z)

class DecoderTransMMNIST(nn.Module):
    """ResNet-style decoder from a latent vector to a 3x28x28 image.

    ``fc`` maps ``z`` to a (256, 7, 7) map. Two rounds of ``ResnetBlock`` plus
    2x upsampling take it to (128, 14, 14) and then (64, 28, 28). A final
    ``ResnetBlock`` and a leaky ReLU precede ``conv_img``, which outputs 3
    channels with no output activation.

    NOTE(review): an earlier docstring credited
    https://www.cs.toronto.edu/~lczhang/360/lec/w05/autoencoder.html, but this
    architecture is ResNet-based; confirm the attribution.

    Args:
        dim: Size of the latent input.
    """

    def __init__(self, dim):
        super().__init__()
        self.s0 = s0 = 7
        self.nf = nf = 64
        self.nf_max = nf_max = 512
        size = 28

        nlayers = int(np.log2(size / s0))
        self.nf0 = min(nf_max, nf * 2 ** nlayers)

        # Parameters are created in the order fc -> residual blocks -> conv_img.
        self.fc = nn.Linear(dim, self.nf0 * s0 * s0)

        blocks = []
        for step in range(nlayers):
            in_ch = min(nf * 2 ** (nlayers - step), nf_max)
            out_ch = min(nf * 2 ** (nlayers - step - 1), nf_max)
            blocks.extend([ResnetBlock(in_ch, out_ch), nn.Upsample(scale_factor=2)])
        blocks.append(ResnetBlock(nf, nf))

        self.resnet = nn.Sequential(*blocks)
        self.conv_img = nn.Conv2d(nf, 3, 3, padding=1)

    def forward(self, z):
        batch_size = z.size(0)
        grid = self.fc(z).view(batch_size, self.nf0, self.s0, self.s0)
        return self.conv_img(actvn(self.resnet(grid)))

class DecoderMNISTLabel(nn.Module):
    """Three-layer MLP decoder producing a distribution over 10 labels.

    Args:
        dim: Size of the latent input.
        hidden_dim: Width of the two hidden layers.

    ``forward`` returns probabilities normalized over the last (label)
    dimension, so inputs with leading sample dims such as ``(K, B, dim)`` are
    handled correctly.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 10)

    def forward(self, z):
        h = F.relu(self.fc1(z))
        h = F.relu(self.fc2(h))
        # Normalize over the label axis; dim=1 was wrong for inputs with more than 2 dims.
        return F.softmax(self.fc3(h), dim=-1)

class DecoderSVHN(nn.Module):
    """Transposed-convolution decoder producing 3-channel 32x32 images.

    The latent is projected to 128 features, reshaped to (128, 1, 1) and passed
    through a ReLU. Four transposed convolutions follow (ReLU after the first
    three, none after the last), growing the spatial size 1 -> 4 -> 8 -> 16 -> 32.

    Args:
        dim: Size of the latent input.
    """

    def __init__(self, dim):
        super().__init__()
        self.linear_factorized = nn.Linear(dim, 128)
        self.conv1 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=1, padding=0, dilation=1)
        self.conv2 = nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1, dilation=1)
        self.conv3 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1, dilation=1)
        self.conv4 = nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1, dilation=1)
        self.relu = nn.ReLU()

    def forward(self, z):
        h = self.linear_factorized(z)
        h = self.relu(h.view(h.size(0), h.size(1), 1, 1))
        for deconv in (self.conv1, self.conv2, self.conv3):
            h = self.relu(deconv(h))
        return self.conv4(h)

class DecoderText(nn.Module):
    """1-D transposed-convolution decoder producing per-position feature distributions.

    The latent is projected to ``2 * dim`` channels at a single position. Two
    transposed convolutions (each followed by a ReLU) grow the length 1 -> 4 -> 8,
    and a kernel-1 convolution maps to ``num_features`` channels. A softmax
    over the feature axis follows. Output shape is ``(batch, 8, num_features)``.

    Args:
        dim: Size of the latent input; hidden width is ``2 * dim``.
        num_features: Number of output features per position.
    """

    def __init__(self, dim, num_features):
        super().__init__()
        self.linear_factorized = nn.Linear(dim, 2 * dim)
        self.conv1 = nn.ConvTranspose1d(2 * dim, 2 * dim, kernel_size=4, stride=1, padding=0, dilation=1)
        self.conv2 = nn.ConvTranspose1d(2 * dim, 2 * dim, kernel_size=4, stride=2, padding=1, dilation=1)
        self.conv_last = nn.Conv1d(2 * dim, num_features, kernel_size=1)
        self.relu = nn.ReLU()
        self.out_act = nn.Softmax(dim=-2)

    def forward(self, z):
        h = self.linear_factorized(z)
        h = h.view(h.size(0), h.size(1), 1)  # (batch, channels, length=1)
        h = self.relu(self.conv1(h))
        h = self.relu(self.conv2(h))
        logits = self.conv_last(h)
        # Softmax over features (dim -2), then move features last.
        return self.out_act(logits).transpose(-2, -1)

class DecoderCelebA(nn.Module):
    """Image decoder: a linear feature layer followed by ``DataGeneratorImg``.

    Args:
        z_dim: Size of the latent input.
        a_img: Forwarded to ``DataGeneratorImg`` as ``a``.
        b_img: Forwarded to ``DataGeneratorImg`` as ``b``.
        num_layers_img: Together with ``DIM_img``, sets the feature width
            ``num_layers_img * DIM_img``.
        DIM_img: See ``num_layers_img``.
    """

    def __init__(self, z_dim=32, a_img=2.0, b_img=0.3, num_layers_img=5, DIM_img=128):
        super().__init__()
        self.feature_generator = nn.Linear(z_dim, num_layers_img * DIM_img, bias=True)
        self.img_generator = DataGeneratorImg(a=a_img, b=b_img)

    def forward(self, c):
        features = self.feature_generator(c)
        # Present the features as a (batch, channels, 1, 1) map to the generator.
        grid = features.view(features.size(0), features.size(1), 1, 1)
        return self.img_generator(grid)

class DecoderCelebAText(nn.Module):
    """Text decoder: a linear feature layer followed by ``DataGeneratorText``.

    Args:
        z_dim: Size of the latent input.
        a_text: Forwarded to ``DataGeneratorText`` as ``a``.
        b_text: Forwarded to ``DataGeneratorText`` as ``b``.
        DIM_text: The feature width is ``5 * DIM_text``.

    The generator output's last two dimensions are swapped before returning.
    """

    def __init__(self, z_dim=32, a_text=2.0, b_text=0.3, DIM_text=128):
        super().__init__()
        self.feature_generator = nn.Linear(z_dim, 5 * DIM_text, bias=True)
        self.text_generator = DataGeneratorText(a=a_text, b=b_text)

    def forward(self, c):
        features = self.feature_generator(c).unsqueeze(-1)  # append a length-1 axis
        generated = self.text_generator(features)
        return generated.transpose(-2, -1)

class DecoderCelebAAttribute(nn.Module):
    """Three-layer MLP decoder producing per-attribute probabilities via a sigmoid.

    Args:
        z_dim: Size of the latent input.
        hidden_dim: Width of the two hidden layers.
        y_dim: Number of output attributes.
    """

    def __init__(self, z_dim, hidden_dim=128, y_dim=40):
        super().__init__()
        layers = [nn.Linear(z_dim, hidden_dim), nn.ReLU()]
        layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]
        layers.append(nn.Linear(hidden_dim, y_dim))
        self.network = nn.Sequential(*layers)

    def forward(self, y):
        # NOTE: the argument is the latent code despite its name `y`.
        logits = self.network(y)
        return torch.sigmoid(logits)

class DecoderCubImage(nn.Module):
    """ResNet-style decoder from a latent vector to a 3-channel image.

    ``fc`` maps ``z`` to an ``(nf0, s0, s0)`` map. Each of the
    ``int(log2(img_size / s0))`` rounds applies a ``ResnetBlock`` that steps the
    channel count down toward ``nfilter`` (capped at ``nfilter_max``), followed
    by 2x upsampling. A final ``ResnetBlock`` and a leaky ReLU precede
    ``conv_img``, which outputs 3 channels with no output activation.

    Args:
        z_dim: Size of the latent input.
        s0: Side length of the initial feature map.
        img_size: Target side length (assumes ``img_size / s0`` is a power of two).
        nfilter: Channel count entering ``conv_img``.
        nfilter_max: Upper bound on the channel count.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, z_dim, s0=1, img_size=64, nfilter=64, nfilter_max=512, **kwargs):
        super().__init__()
        self.z_dim = z_dim
        self.s0 = s0
        self.nf = nf = nfilter
        self.nf_max = nf_max = nfilter_max

        nlayers = int(np.log2(img_size / s0))
        self.nf0 = min(nf_max, nf * 2 ** nlayers)

        # Parameters are created in the order fc -> residual blocks -> conv_img.
        self.fc = nn.Linear(self.z_dim, self.nf0 * s0 * s0)

        blocks = []
        for step in range(nlayers):
            in_ch = min(nf * 2 ** (nlayers - step), nf_max)
            out_ch = min(nf * 2 ** (nlayers - step - 1), nf_max)
            blocks.extend([ResnetBlock(in_ch, out_ch), nn.Upsample(scale_factor=2)])
        blocks.append(ResnetBlock(nf, nf))

        self.resnet = nn.Sequential(*blocks)
        self.conv_img = nn.Conv2d(nf, 3, 3, padding=1)

    def forward(self, z):
        batch_size = z.size(0)
        grid = self.fc(z).view(batch_size, self.nf0, self.s0, self.s0)
        return self.conv_img(actvn(self.resnet(grid)))

class DecoderCubText(nn.Module):
    """Residual 1-D transposed-convolution text decoder.

    ``fc`` maps ``z`` to ``5 * dim_text`` channels, and ``unsqueeze(-1)`` adds a
    length axis (a single position for a 2-D ``z``). Residual blocks
    ``resblock_3``..``resblock_6`` then reduce the channels
    5*dim_text -> 4*dim_text -> 3*dim_text -> 2*dim_text -> dim_text.
    ``conv2`` (ConvTranspose1d, kernel 4, stride 2, padding 1) maps to
    ``vocab_size`` channels. A softmax over dim -2 (the vocabulary axis in
    ``(batch, vocab_size, length)`` layout) normalizes each position, and the
    result is transposed so the vocabulary axis comes last.

    NOTE(review): ``resblock_1`` and ``resblock_2`` are constructed (owning
    parameters and ``state_dict`` entries), but their calls in ``forward`` are
    commented out. Deleting them would change the ``state_dict`` keys.

    Args:
        z_dim: Size of the latent input.
        dim_text: Base channel width.
        vocab_size: Number of output features per sequence position.
    """
    def __init__(self, z_dim, dim_text=32, vocab_size=1590):
        super().__init__()
        num_features = vocab_size
        self.fc = nn.Linear(z_dim, 5 * dim_text)
        # NOTE(review): resblock_1 and resblock_2 are registered but not used by forward.
        self.resblock_1 = self.res_block_decoder(5 * dim_text, 5 * dim_text,
                                                      kernelsize=4, stride=1, padding=0, dilation=1, o_padding=0)
        self.resblock_2 = self.res_block_decoder(5 * dim_text, 5 * dim_text,
                                                      kernelsize=4, stride=2, padding=1, dilation=1, o_padding=0)
        self.resblock_3 = self.res_block_decoder(5 * dim_text, 4 * dim_text,
                                                      kernelsize=4, stride=2, padding=1, dilation=1, o_padding=0)
        self.resblock_4 = self.res_block_decoder(4 * dim_text, 3 * dim_text,
                                                      kernelsize=4, stride=2, padding=1, dilation=1, o_padding=0)
        self.resblock_5 = self.res_block_decoder(3 * dim_text, 2 * dim_text,
                                                      kernelsize=4, stride=2, padding=1, dilation=1, o_padding=0)
        self.resblock_6 = self.res_block_decoder(2 * dim_text, dim_text,
                                                      kernelsize=4, stride=2, padding=1, dilation=1, o_padding=0)
        # Final projection to the vocabulary; kernel 4 / stride 2 / padding 1 doubles the length.
        self.conv2 = nn.ConvTranspose1d(dim_text, num_features,
                                        kernel_size=4,
                                        stride=2,
                                        padding=1,
                                        dilation=1,
                                        output_padding=0)
        self.softmax = nn.Softmax(dim=-2)

    def forward(self, z):
        d = self.fc(z)
        # Add a length axis: (batch, 5*dim_text) -> (batch, 5*dim_text, 1) for a 2-D z.
        d = d.unsqueeze(-1)
        # Disabled in the original code; resblock_1 / resblock_2 still own parameters.
#        d = self.resblock_1(d)
#        d = self.resblock_2(d)
        d = self.resblock_3(d)
        d = self.resblock_4(d)
        d = self.resblock_5(d)
        d = self.resblock_6(d)
        d = self.conv2(d)
        d = self.softmax(d)
        d = d.transpose(-2, -1)
        return d

    @staticmethod
    def res_block_decoder(in_channels, out_channels, kernelsize, stride, padding, o_padding, dilation, a_val=2.0,
                          b_val=0.3):
        """Build one ``ResidualBlock1dTransposeConv`` wrapped in an ``nn.Sequential``.

        A ``ConvTranspose1d`` (same kernel, stride, padding, dilation,
        output_padding) followed by ``BatchNorm1d`` is built and passed as
        ``upsample`` unless the kernel and stride are both 1 and the channels
        and dilation are unchanged. ``a_val`` / ``b_val`` are forwarded as
        ``a`` / ``b``.

        NOTE(review): this signature orders ``o_padding`` before ``dilation``,
        but they are passed to ResidualBlock1dTransposeConv positionally as
        ``dilation, o_padding``. Callers in this file use keywords, so this is
        safe here; confirm against ResidualBlocks before calling positionally.
        """
        upsample = None

        if (kernelsize != 1 or stride != 1) or (in_channels != out_channels) or dilation != 1:
            upsample = nn.Sequential(nn.ConvTranspose1d(in_channels, out_channels,
                                                        kernel_size=kernelsize,
                                                        stride=stride,
                                                        padding=padding,
                                                        dilation=dilation,
                                                        output_padding=o_padding),
                                     nn.BatchNorm1d(out_channels))
        layers = []
        layers.append(
            ResidualBlock1dTransposeConv(in_channels, out_channels, kernelsize, stride, padding, dilation, o_padding,
                                         upsample=upsample, a=a_val, b=b_val))
        return nn.Sequential(*layers)

class DecoderCubText2(nn.Module):
    """Shallower residual 1-D transposed-convolution text decoder.

    ``fc`` maps ``z`` to ``5 * dim_text`` channels, and ``unsqueeze(-1)`` adds a
    length axis (a single position for a 2-D ``z``). The input then passes
    through ``resblock_1`` (5*dim_text, stride 1, padding 0), ``resblock_2``
    (5*dim_text, stride 2) and ``resblock_3`` (5*dim_text -> 4*dim_text,
    stride 2). ``conv2`` (ConvTranspose1d, kernel 4, stride 2, padding 1) maps
    to ``vocab_size`` channels. A softmax over dim -2 (the vocabulary axis)
    follows, and the result is transposed so the vocabulary axis comes last.

    Args:
        z_dim: Size of the latent input.
        dim_text: Base channel width.
        vocab_size: Number of output features per sequence position.
    """
    def __init__(self, z_dim, dim_text=32, vocab_size=1590):
        super().__init__()
        num_features = vocab_size
        self.fc = nn.Linear(z_dim, 5 * dim_text)
        self.resblock_1 = self.res_block_decoder(5 * dim_text, 5 * dim_text,
                                                      kernelsize=4, stride=1, padding=0, dilation=1, o_padding=0)
        self.resblock_2 = self.res_block_decoder(5 * dim_text, 5 * dim_text,
                                                      kernelsize=4, stride=2, padding=1, dilation=1, o_padding=0)
        self.resblock_3 = self.res_block_decoder(5 * dim_text, 4 * dim_text,
                                                      kernelsize=4, stride=2, padding=1, dilation=1, o_padding=0)
        # Final projection to the vocabulary; kernel 4 / stride 2 / padding 1 doubles the length.
        self.conv2 = nn.ConvTranspose1d(4*dim_text, num_features,
                                        kernel_size=4,
                                        stride=2,
                                        padding=1,
                                        dilation=1,
                                        output_padding=0)
        self.softmax = nn.Softmax(dim=-2)

    def forward(self, z):
        d = self.fc(z)
        # Add a length axis: (batch, 5*dim_text) -> (batch, 5*dim_text, 1) for a 2-D z.
        d = d.unsqueeze(-1)
        d = self.resblock_1(d)
        d = self.resblock_2(d)
        d = self.resblock_3(d)
        d = self.conv2(d)
        d = self.softmax(d)
        d = d.transpose(-2, -1)
        return d

    @staticmethod
    def res_block_decoder(in_channels, out_channels, kernelsize, stride, padding, o_padding, dilation, a_val=2.0,
                          b_val=0.3):
        """Build one ``ResidualBlock1dTransposeConv`` wrapped in an ``nn.Sequential``.

        Same code as ``DecoderCubText.res_block_decoder``. A ``ConvTranspose1d``
        + ``BatchNorm1d`` is built and passed as ``upsample`` unless the kernel
        and stride are both 1 and the channels and dilation are unchanged.
        ``a_val`` / ``b_val`` are forwarded as ``a`` / ``b``.

        NOTE(review): ``o_padding`` precedes ``dilation`` in this signature, but
        they are forwarded positionally as ``dilation, o_padding``. Callers here
        use keywords; confirm against ResidualBlocks before calling positionally.
        """
        upsample = None

        if (kernelsize != 1 or stride != 1) or (in_channels != out_channels) or dilation != 1:
            upsample = nn.Sequential(nn.ConvTranspose1d(in_channels, out_channels,
                                                        kernel_size=kernelsize,
                                                        stride=stride,
                                                        padding=padding,
                                                        dilation=dilation,
                                                        output_padding=o_padding),
                                     nn.BatchNorm1d(out_channels))
        layers = []
        layers.append(
            ResidualBlock1dTransposeConv(in_channels, out_channels, kernelsize, stride, padding, dilation, o_padding,
                                         upsample=upsample, a=a_val, b=b_val))
        return nn.Sequential(*layers)