# Sentence model specification - real CUB image version
import os
import json

import numpy as np
import torch
import torch.distributions as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from torch.utils.data import DataLoader

from datasets_compatible import CUBSentences
from utils import Constants, FakeCategorical
from .vae import VAE

# Constants
maxSentLen = 32  # max length of any description for birds dataset
# Only used in this file to build `vocab_path`; presumably the minimum number
# of occurrences a word needs to enter the vocabulary -- TODO confirm against
# CUBSentences.
minOccur = 3
# Width of a token embedding: output size of Enc.embedding and input size of
# Dec.toVocabSize.
embeddingDim = 128
lenWindow = 3  # not referenced anywhere in the visible part of this file
fBase = 32  # base channel count; the conv stacks use multiples of it
# Size of the one-hot token axis: input size of Enc.embedding and output size
# of Dec.toVocabSize.
vocabSize = 1590
# Path of the JSON vocabulary file read by CUB_Sentence.load_vocab; the
# directory name encodes minOccur and maxSentLen.
vocab_path = '../data/cub/oc:{}_msl:{}/cub.vocab'.format(minOccur, maxSentLen)

# Classes
class Enc(nn.Module):
    """Generate latent parameters for sentence data.

    The sentence is embedded token-wise with a linear layer and then fed,
    as a one-channel 2-D map, through two independent convolutional
    branches: one producing the parameters of the ``w`` latent and one
    producing the parameters of the ``u`` latent.

    Args:
        latentDim_w: number of dimensions of the ``w`` latent.
        latentDim_u: number of dimensions of the ``u`` latent.
    """

    def __init__(self, latentDim_w, latentDim_u):
        super(Enc, self).__init__()
        # Linear map applied to the last axis: vocabSize -> embeddingDim.
        self.embedding = nn.Linear(vocabSize, embeddingDim)
        self.enc_w = nn.Sequential(
            # input size: 1 x 32 x 128
            nn.Conv2d(1, fBase, 4, 2, 1, bias=True),
            nn.BatchNorm2d(fBase),
            nn.ReLU(True),
            # size: (fBase) x 16 x 64
            nn.Conv2d(fBase, fBase * 2, 4, 2, 1, bias=True),
            nn.BatchNorm2d(fBase * 2),
            nn.ReLU(True),
            # size: (fBase * 2) x 8 x 32
            nn.Conv2d(fBase * 2, fBase * 4, 4, 2, 1, bias=True),
            nn.BatchNorm2d(fBase * 4),
            nn.ReLU(True)
            # size: (fBase * 4) x 4 x 16  (= fBase * 16 * 16 values per sample)
        )
        self.enc_u = nn.Sequential(
            # input size: 1 x 32 x 128
            nn.Conv2d(1, fBase, 4, 2, 1, bias=True),
            nn.BatchNorm2d(fBase),
            nn.ReLU(True),
            # size: (fBase) x 16 x 64
            nn.Conv2d(fBase, fBase * 2, 4, 2, 1, bias=True),
            nn.BatchNorm2d(fBase * 2),
            nn.ReLU(True),
            # size: (fBase * 2) x 8 x 32
            nn.Conv2d(fBase * 2, fBase * 4, 4, 2, 1, bias=True),
            nn.BatchNorm2d(fBase * 4),
            nn.ReLU(True),
            # size: (fBase * 4) x 4 x 16
            nn.Conv2d(fBase * 4, fBase * 8, (1, 4), (1, 2), (0, 1), bias=True),
            nn.BatchNorm2d(fBase * 8),
            nn.ReLU(True),
            # size: (fBase * 8) x 4 x 8
            nn.Conv2d(fBase * 8, fBase * 16, (1, 4), (1, 2), (0, 1), bias=True),
            nn.BatchNorm2d(fBase * 16),
            nn.ReLU(True),
            # size: (fBase * 16) x 4 x 4
        )
        # Heads of the w branch act on the flattened enc_w output.
        self.c1_w = nn.Linear(fBase * 16 * 16, latentDim_w)
        self.c2_w = nn.Linear(fBase * 16 * 16, latentDim_w)

        # Heads of the u branch: a 4x4 conv that collapses the 4 x 4 map to 1 x 1.
        self.c1_u = nn.Conv2d(fBase * 16, latentDim_u, 4, 1, 0, bias=True)
        self.c2_u = nn.Conv2d(fBase * 16, latentDim_u, 4, 1, 0, bias=True)

    def forward(self, x):
        """Return the location and scale of the latent posterior for ``x``.

        Args:
            x: batch of sentences whose last axis has size ``vocabSize``.
                The layer sizes above expect shape (batch, 32, vocabSize).

        Returns:
            A tuple ``(mu, scale)``; each is the concatenation along dim 1 of
            the ``w`` part followed by the ``u`` part. ``scale`` is
            ``softplus(raw) + Constants.eta``.
        """
        x_emb = self.embedding(x).unsqueeze(1)  # add the channel axis

        e_w = self.enc_w(x_emb)
        e_w = e_w.view(-1, fBase * 16 * 16)
        mu_w, lv_w = self.c1_w(e_w), self.c2_w(e_w)

        e_u = self.enc_u(x_emb)
        # Drop only the two trailing 1 x 1 spatial axes. A bare ``squeeze()``
        # would also remove the batch axis for a batch of one (or the feature
        # axis when latentDim_u == 1) and break the concatenation below.
        mu_u = self.c1_u(e_u).squeeze(-1).squeeze(-1)
        lv_u = self.c2_u(e_u).squeeze(-1).squeeze(-1)

        return torch.cat((mu_w, mu_u), dim=1), \
               torch.cat((F.softplus(lv_w) + Constants.eta,
                          F.softplus(lv_u) + Constants.eta), dim=1)


class Dec(nn.Module):
    """Decode a latent sample into per-token distributions over the vocabulary.

    The latent vector is split into its ``w`` and ``u`` parts; each part is
    upsampled by its own transposed-convolution branch, the two feature maps
    are concatenated channel-wise, and a shared trunk upsamples the result to
    a single-channel map whose rows are projected to ``vocabSize`` scores.
    """

    def __init__(self, latentDim_w, latentDim_u):
        super(Dec, self).__init__()

        def up_block(c_in, c_out, kernel, stride, pad):
            # One upsampling stage: transposed conv -> batch norm -> ReLU.
            return [
                nn.ConvTranspose2d(c_in, c_out, kernel, stride, pad, bias=True),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]

        # Kernel/stride/padding triple that doubles the width and keeps the height.
        widen = ((1, 4), (1, 2), (0, 1))

        # w branch: 1 x 1 -> (fBase*16) x 4 x 4 -> (fBase*8) x 4 x 8 -> (fBase*4) x 4 x 16
        self.dec_w = nn.Sequential(
            *up_block(latentDim_w, fBase * 16, 4, 1, 0),
            *up_block(fBase * 16, fBase * 8, *widen),
            *up_block(fBase * 8, fBase * 4, *widen),
        )
        # u branch: like the w branch, plus a size-preserving 3x3 stage at 4 x 8.
        self.dec_u = nn.Sequential(
            *up_block(latentDim_u, fBase * 16, 4, 1, 0),
            *up_block(fBase * 16, fBase * 8, *widen),
            *up_block(fBase * 8, fBase * 8, 3, 1, 1),
            *up_block(fBase * 8, fBase * 4, *widen),
        )
        # Shared trunk on the concatenated branches:
        # (fBase*8) x 4 x 16 -> (fBase*4) x 4 x 16 -> (fBase*2) x 8 x 32
        # -> (fBase) x 16 x 64 -> 1 x 32 x 128
        self.dec_h = nn.Sequential(
            *up_block(fBase * 8, fBase * 4, 3, 1, 1),
            *up_block(fBase * 4, fBase * 2, 4, 2, 1),
            *up_block(fBase * 2, fBase, 4, 2, 1),
            nn.ConvTranspose2d(fBase, 1, 4, 2, 1, bias=True),
            nn.ReLU(True),
        )
        # Maps an embedding-sized row back to vocabulary scores, i.e. the
        # counterpart of the encoder's embedding layer.
        self.toVocabSize = nn.Linear(embeddingDim, vocabSize)

        self.latent_dim_w = latentDim_w
        self.latent_dim_u = latentDim_u

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, z):
        """Return a one-element list holding the token probabilities for ``z``.

        ``z`` has the ``w`` and ``u`` latents concatenated on its last axis and
        may carry any number of leading axes; those axes are restored on the
        output, which ends in ``(maxSentLen, vocabSize)``.
        """
        z_w, z_u = torch.split(z, [self.latent_dim_w, self.latent_dim_u], dim=-1)

        # Append 1 x 1 spatial axes so the latents fit the deconv layers.
        z_u = z_u[..., None, None]
        lead_dims = z_u.size()[:-3]  # every axis in front of (channels, 1, 1)
        feat_u = self.dec_u(z_u.view(-1, *z_u.size()[-3:]))

        z_w = z_w[..., None, None]
        feat_w = self.dec_w(z_w.view(-1, *z_w.size()[-3:]))

        trunk_out = self.dec_h(torch.cat((feat_w, feat_u), dim=1))

        # Each row of the decoded map is one token embedding.
        token_rows = trunk_out.view(-1, embeddingDim)
        scores = self.toVocabSize(token_rows).view(*lead_dims, maxSentLen, vocabSize)

        # The softmax is key for this to work
        return [self.softmax(scores)]


class CUB_Sentence(VAE):
    """Derive a specific sub-class of a VAE for a sentence model.

    Uses Normal prior/posterior distributions, a OneHotCategorical likelihood
    and the convolutional ``Enc``/``Dec`` pair defined in this module.
    """

    def __init__(self, params):
        """Build the model.

        Args:
            params: configuration object; this constructor reads
                ``latent_dim_w``, ``latent_dim_u``, ``learn_prior`` and
                ``learn_prior_w_sent`` from it and forwards it to ``VAE``.
        """
        super(CUB_Sentence, self).__init__(
            prior_dist=dist.Normal,
            likelihood_dist=dist.OneHotCategorical,
            post_dist=dist.Normal,
            enc=Enc(params.latent_dim_w, params.latent_dim_u),
            dec=Dec(params.latent_dim_w, params.latent_dim_u),
            params=params)
        grad = {'requires_grad': params.learn_prior}
        self._pz_params = nn.ParameterList([
            nn.Parameter(torch.zeros(1, params.latent_dim_w + params.latent_dim_u), requires_grad=False),  # mu
            # unconstrained scale; `pz_params` passes it through softplus
            nn.Parameter(torch.zeros(1, params.latent_dim_w + params.latent_dim_u), **grad)
        ])
        grad_w = {'requires_grad': params.learn_prior_w_sent}
        self._pw_params = nn.ParameterList([
            nn.Parameter(torch.zeros(1, params.latent_dim_w), requires_grad=False),  # mu
            # unconstrained scale; `pw_params` passes it through softplus
            nn.Parameter(torch.zeros(1, params.latent_dim_w), **grad_w)
        ])
        self.modelName = 'cubS_conv'
        self.llik_scaling = 1.

        # Tensor of token ids -> integer numpy array on the CPU.
        self.fn_2i = lambda t: t.cpu().numpy().astype(int)
        # Cut a sentence right after the first token with id 2 (presumably the
        # end-of-sentence marker -- TODO confirm against the vocab file).
        self.fn_trun = lambda s: s[:np.where(s == 2)[0][0] + 1] if 2 in s else s
        self.vocab_file = vocab_path

        self.maxSentLen = maxSentLen
        self.vocabSize = vocabSize

        # Mapping from token id (as a string key) to word.
        self.i2w = self.load_vocab()

    @property
    def pz_params(self):
        """Prior parameters over the full latent: (mu, softplus(raw) + eta)."""
        return self._pz_params[0], \
            F.softplus(self._pz_params[1]) + Constants.eta

    @property
    def pw_params(self):
        """Prior parameters over the ``w`` latent: (mu, softplus(raw) + eta)."""
        return self._pw_params[0], \
               F.softplus(self._pw_params[1]) + Constants.eta

    @staticmethod
    def getDataLoaders(batch_size, shuffle=True, device="cuda"):
        """Return ``(train_loader, test_loader)`` over one-hot CUB sentences.

        Args:
            batch_size: batch size used by both loaders.
            shuffle: forwarded to both loaders.
            device: when equal to ``"cuda"`` the loaders use one worker and
                pinned memory; otherwise the DataLoader defaults are used.
        """
        kwargs = {'num_workers': 1, 'pin_memory': True} if device == "cuda" else {}
        tx = lambda data: torch.Tensor(data)
        t_data = CUBSentences('../data', split='train', transform=tx, one_hot=True, max_sequence_length=maxSentLen)
        s_data = CUBSentences('../data', split='test', transform=tx, one_hot=True, max_sequence_length=maxSentLen)

        train_loader = DataLoader(t_data, batch_size=batch_size, shuffle=shuffle, **kwargs)
        test_loader = DataLoader(s_data, batch_size=batch_size, shuffle=shuffle, **kwargs)

        return train_loader, test_loader

    def _sent_to_text(self, sent):
        """Turn a 1-D tensor of token ids into a space-separated sentence."""
        token_ids = self.fn_trun(self.fn_2i(sent))
        return ' '.join(self.i2w[str(i)] for i in token_ids)

    def reconstruct(self, data, runPath, epoch):
        """Reconstruct up to 8 sentences, print three and save all to a file.

        Writes ``{runPath}/recon_{epoch:03d}.txt`` with one ``[DATA]`` line and
        one ``[RECON]`` line per sentence.
        """
        recon = super(CUB_Sentence, self).reconstruct(data[:8]).argmax(-1).squeeze()
        data = data.argmax(-1)
        # zip() stops at the shorter input, so only as many sentences as were
        # reconstructed are decoded.
        pairs = [(self._sent_to_text(d_sent), self._sent_to_text(r_sent))
                 for r_sent, d_sent in zip(recon, data)]

        print("\n Reconstruction examples (excluding <PAD>):")
        for d_text, r_text in pairs[:3]:
            print('[DATA] ==> {}'.format(d_text))
            print('[RECON] ==> {}\n'.format(r_text))

        with open('{}/recon_{:03d}.txt'.format(runPath, epoch), "w+") as txt_file:
            for d_text, r_text in pairs:
                # Terminate the [DATA] line; without the newline the
                # reconstruction was appended to the same line.
                txt_file.write('[DATA]  ==> {}\n'.format(d_text))
                txt_file.write('[RECON] ==> {}\n'.format(r_text))

    def generate(self, runPath, epoch):
        """Sample sentences from the model, print three and save all to a file.

        Writes ``{runPath}/gen_samples_{epoch:03d}.txt`` with N groups of K
        sentences, groups separated by a blank line.
        """
        N, K = 5, 4
        samples = super(CUB_Sentence, self).generate(N, K).squeeze()
        samples = samples.view(K, N, maxSentLen, samples.size(-1)).transpose(0, 1)  # N x K x maxSentLen x vocab
        samples = samples.argmax(-1)  # N x K x maxSentLen token ids

        print("\n Generated examples (excluding <PAD>):")
        for s_sent in samples[0][:3]:
            print('[GEN]   ==> {}'.format(self._sent_to_text(s_sent)))

        with open('{}/gen_samples_{:03d}.txt'.format(runPath, epoch), "w+") as txt_file:
            for s_sents in samples:
                for s_sent in s_sents:
                    txt_file.write('{}\n'.format(self._sent_to_text(s_sent)))
                txt_file.write('\n')

    def analyse(self, data, runPath, epoch):
        """No analysis is implemented for this model."""
        pass

    def load_vocab(self):
        """Load and return the ``i2w`` table from the JSON vocabulary file."""
        # call dataloader function to create vocab file
        if not os.path.exists(self.vocab_file):
            _, _ = self.getDataLoaders(256)
        with open(self.vocab_file, 'r') as vocab_file:
            vocab = json.load(vocab_file)
        return vocab['i2w']
