import torch
import torch.nn as nn
import torch.nn.functional as F


class Enc_Caption(nn.Module):
    """Encode a caption into the parameters of two latents, ``z`` and ``w``.

    The input's last dimension (size ``vocabSize``) is projected by a linear
    "embedding" to ``embeddingDim``. The result is treated as a one-channel
    ``maxSentLen x embeddingDim`` image and fed to two independent
    convolutional branches:

    * ``enc_w`` -> flattened -> linear heads ``c1_w`` / ``c2_w`` (``ndim_w`` each)
    * ``enc_z`` -> 4x4 convolutional heads ``c1_z`` / ``c2_z`` (``ndim_z`` each)

    Args:
        ndim_z: Size of the z latent.
        ndim_w: Size of the w latent.
        vocab_size: Size of the input's last dimension. Defaults to 1590
            (the previously hard-coded value). Only the embedding layer
            depends on it; the convolutional architecture does not.
    """

    def __init__(self, ndim_z, ndim_w, vocab_size=1590):
        super().__init__()
        # The conv stacks are shaped for exactly 32 x 128 "images": the
        # z-branch must end on a 4x4 map for the 4x4 heads below. Do not
        # change maxSentLen / embeddingDim without reworking the layers.
        self.maxSentLen = 32
        self.minOccur = 3  # not used inside this class
        self.embeddingDim = 128
        self.lenWindow = 3  # not used inside this class
        self.fBase = 32
        self.vocabSize = vocab_size

        self.embedding = nn.Linear(self.vocabSize, self.embeddingDim)
        self.enc_w = nn.Sequential(
            # input size: 1 x 32 x 128
            nn.Conv2d(1, self.fBase, 4, 2, 1, bias=True),
            nn.BatchNorm2d(self.fBase),
            nn.ReLU(True),
            # size: (fBase) x 16 x 64
            nn.Conv2d(self.fBase, self.fBase * 2, 4, 2, 1, bias=True),
            nn.BatchNorm2d(self.fBase * 2),
            nn.ReLU(True),
            # size: (fBase * 2) x 8 x 32
            nn.Conv2d(self.fBase * 2, self.fBase * 4, 4, 2, 1, bias=True),
            nn.BatchNorm2d(self.fBase * 4),
            nn.ReLU(True)
            # size: (fBase * 4) x 4 x 16
        )
        self.enc_z = nn.Sequential(
            # input size: 1 x 32 x 128
            nn.Conv2d(1, self.fBase, 4, 2, 1, bias=True),
            nn.BatchNorm2d(self.fBase),
            nn.ReLU(True),
            # size: (fBase) x 16 x 64
            nn.Conv2d(self.fBase, self.fBase * 2, 4, 2, 1, bias=True),
            nn.BatchNorm2d(self.fBase * 2),
            nn.ReLU(True),
            # size: (fBase * 2) x 8 x 32
            nn.Conv2d(self.fBase * 2, self.fBase * 4, 4, 2, 1, bias=True),
            nn.BatchNorm2d(self.fBase * 4),
            nn.ReLU(True),
            # size: (fBase * 4) x 4 x 16
            # The (1, 4) kernels with (1, 2) stride halve only the width.
            nn.Conv2d(self.fBase * 4, self.fBase * 8, (1, 4), (1, 2), (0, 1), bias=True),
            nn.BatchNorm2d(self.fBase * 8),
            nn.ReLU(True),
            # size: (fBase * 8) x 4 x 8
            nn.Conv2d(self.fBase * 8, self.fBase * 16, (1, 4), (1, 2), (0, 1), bias=True),
            nn.BatchNorm2d(self.fBase * 16),
            nn.ReLU(True),
            # size: (fBase * 16) x 4 x 4
        )
        # enc_w applies three (k=4, s=2, p=1) convs, each halving both spatial
        # dims, so its output is (fBase * 4) x (maxSentLen / 8) x (embeddingDim / 8)
        # = 128 x 4 x 16 = 8192 features per sample. This is the same value
        # the original wrote as fBase * 16 * 16, so parameter shapes are unchanged.
        flat_w = self.fBase * 4 * (self.maxSentLen // 8) * (self.embeddingDim // 8)
        self.c1_w = nn.Linear(flat_w, ndim_w)
        self.c2_w = nn.Linear(flat_w, ndim_w)

        # 4x4 kernels with no padding collapse the 4x4 map to 1x1.
        self.c1_z = nn.Conv2d(self.fBase * 16, ndim_z, 4, 1, 0, bias=True)
        self.c2_z = nn.Conv2d(self.fBase * 16, ndim_z, 4, 1, 0, bias=True)

    def forward(self, x):
        """Return ``(mu, scale)`` for the concatenated latent ``[z, w]``.

        Args:
            x: Tensor of shape ``(batch, maxSentLen, vocabSize)``. Presumably
                these are one-hot or soft token distributions; TODO confirm
                with the caller.

        Returns:
            A pair of tensors, each of shape ``(batch, ndim_z + ndim_w)``,
            ordered ``[z, w]`` along the last dim. The first holds the means.
            The second holds strictly positive softplus outputs. They are
            named ``lv_*`` below, but they are not log-variances;
            NOTE(review): presumably used as a std/scale, so confirm against
            the consumer.
        """
        # (batch, L, V) -> (batch, L, E) -> (batch, 1, L, E): one-channel image.
        x_emb = self.embedding(x).unsqueeze(1)

        # w-branch. flatten(1) keeps the batch dim intact, so a shape mismatch
        # fails loudly in the linear heads instead of silently re-chunking
        # the batch the way view(-1, D) can.
        e_w = self.enc_w(x_emb).flatten(start_dim=1)
        mu_w, lv_w = self.c1_w(e_w), self.c2_w(e_w)

        # z-branch: (batch, ndim_z, 1, 1) -> (batch, ndim_z).
        e_z = self.enc_z(x_emb)
        mu_z = self.c1_z(e_z).squeeze(-1).squeeze(-1)
        lv_z = self.c2_z(e_z).squeeze(-1).squeeze(-1)

        # softplus makes the values positive. The tiny epsilon keeps them off
        # exactly zero when softplus underflows for very negative inputs.
        lv_z = F.softplus(lv_z) + 1e-20
        lv_w = F.softplus(lv_w) + 1e-20

        return torch.cat((mu_z, mu_w), dim=-1), torch.cat((lv_z, lv_w), dim=-1)


class Dec_Caption(nn.Module):
    """Generate a sentence given a sample from the latent space.

    The latent ``u = [z, w]`` is split. Each part is treated as a ``C x 1 x 1``
    image and upsampled by its own transposed-conv branch (``dec_z`` /
    ``dec_w``) to ``(fBase * 4) x 4 x 16``. The two branches are concatenated
    along channels and decoded by ``dec_h`` to a one-channel
    ``maxSentLen x embeddingDim`` map. Finally ``toVocabSize`` and a softmax
    turn it into per-position vocabulary probabilities.

    Args:
        ndim_z: Size of the z part of the latent.
        ndim_w: Size of the w part of the latent.
        vocab_size: Number of output classes per position. Defaults to 1590
            (the previously hard-coded value). Only ``toVocabSize`` depends
            on it; the convolutional architecture does not.
    """

    def __init__(self, ndim_z, ndim_w, vocab_size=1590):
        super().__init__()
        # The transposed-conv stacks are shaped to produce exactly 32 x 128;
        # do not change maxSentLen / embeddingDim without reworking them.
        self.maxSentLen = 32
        self.minOccur = 3  # not used inside this class
        self.embeddingDim = 128
        self.lenWindow = 3  # not used inside this class
        self.fBase = 32
        self.vocabSize = vocab_size

        self.dec_w = nn.Sequential(
            # input size: ndim_w x 1 x 1
            nn.ConvTranspose2d(ndim_w, self.fBase * 16, 4, 1, 0, bias=True),
            nn.BatchNorm2d(self.fBase * 16),
            nn.ReLU(True),
            # size: (fBase * 16) x 4 x 4
            # (1, 4) kernels with (1, 2) stride double only the width.
            nn.ConvTranspose2d(self.fBase * 16, self.fBase * 8, (1, 4), (1, 2), (0, 1), bias=True),
            nn.BatchNorm2d(self.fBase * 8),
            nn.ReLU(True),
            # size: (fBase * 8) x 4 x 8
            nn.ConvTranspose2d(self.fBase * 8, self.fBase * 4, (1, 4), (1, 2), (0, 1), bias=True),
            nn.BatchNorm2d(self.fBase * 4),
            nn.ReLU(True),
            # size: (fBase * 4) x 4 x 16
        )
        self.dec_z = nn.Sequential(
            # input size: ndim_z x 1 x 1
            nn.ConvTranspose2d(ndim_z, self.fBase * 16, 4, 1, 0, bias=True),
            nn.BatchNorm2d(self.fBase * 16),
            nn.ReLU(True),
            # size: (fBase * 16) x 4 x 4
            nn.ConvTranspose2d(self.fBase * 16, self.fBase * 8, (1, 4), (1, 2), (0, 1), bias=True),
            nn.BatchNorm2d(self.fBase * 8),
            nn.ReLU(True),
            # size: (fBase * 8) x 4 x 8
            nn.ConvTranspose2d(self.fBase * 8, self.fBase * 8, 3, 1, 1, bias=True),
            nn.BatchNorm2d(self.fBase * 8),
            nn.ReLU(True),
            # size: (fBase * 8) x 4 x 8
            nn.ConvTranspose2d(self.fBase * 8, self.fBase * 4, (1, 4), (1, 2), (0, 1), bias=True),
            nn.BatchNorm2d(self.fBase * 4),
            nn.ReLU(True),
            # size: (fBase * 4) x 4 x 16
        )
        self.dec_h = nn.Sequential(
            # input size: (fBase * 8) x 4 x 16 (dec_z and dec_w outputs concatenated)
            nn.ConvTranspose2d(self.fBase * 8, self.fBase * 4, 3, 1, 1, bias=True),
            nn.BatchNorm2d(self.fBase * 4),
            nn.ReLU(True),
            # size: (fBase * 4) x 4 x 16
            nn.ConvTranspose2d(self.fBase * 4, self.fBase * 2, 4, 2, 1, bias=True),
            nn.BatchNorm2d(self.fBase * 2),
            nn.ReLU(True),
            # size: (fBase * 2) x 8 x 32
            nn.ConvTranspose2d(self.fBase * 2, self.fBase, 4, 2, 1, bias=True),
            nn.BatchNorm2d(self.fBase),
            nn.ReLU(True),
            # size: (fBase) x 16 x 64
            nn.ConvTranspose2d(self.fBase, 1, 4, 2, 1, bias=True),
            nn.ReLU(True)
            # Output size: 1 x 32 x 128 (non-negative because of the final ReLU)
        )
        # Maps embedding-space features to vocabulary logits. It is intended to
        # invert the encoder's 'embedding' module up to one-hotness.
        self.toVocabSize = nn.Linear(self.embeddingDim, self.vocabSize)

        self.latent_dim_w = ndim_w
        self.latent_dim_z = ndim_z

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, u):
        """Decode latent samples into per-position vocabulary probabilities.

        Args:
            u: Tensor of shape ``(*lead, ndim_z + ndim_w)``, ordered ``[z, w]``
                along the last dim. Any number of leading dims (including
                none) is accepted; they are flattened for the convolutions
                and restored on output.

        Returns:
            A one-element list holding a tensor of shape
            ``(*lead, maxSentLen, vocabSize)`` whose last dim sums to 1.
        """
        lead_shape = u.shape[:-1]
        z, w = torch.split(u, [self.latent_dim_z, self.latent_dim_w], dim=-1)
        # Each latent vector becomes a C x 1 x 1 image. reshape (unlike view)
        # also copes with split views whose strides cannot be viewed.
        hz = self.dec_z(z.reshape(-1, self.latent_dim_z, 1, 1))
        hw = self.dec_w(w.reshape(-1, self.latent_dim_w, 1, 1))
        h = torch.cat((hz, hw), dim=1)  # (N, fBase * 8, 4, 16)
        out = self.dec_h(h)  # (N, 1, maxSentLen, embeddingDim)
        logits = self.toVocabSize(out).view(*lead_shape, self.maxSentLen, self.vocabSize)
        return [self.softmax(logits)]