from __future__ import print_function

import numpy as np

import math

from scipy.special import logsumexp


import torch
import torch.nn.functional as F
import torch.utils.data
import torch.nn as nn
from torch.nn import Linear
from torch.autograd import Variable

from ..utils.distributions import log_Bernoulli, log_Normal_diag, log_Normal_standard, log_Logistic_256
from ..utils.visual_evaluation import plot_histogram
from ..utils.nn import he_init, GatedDense, NonLinear

from .Model import Model
# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=

#=======================================================================================================================
class VAE(Model):
    def __init__(self, args):
        '''
        Build the encoder and decoder layers selected by ``args``.

        The layer set is chosen from ``args.dataset``:

        * ``'mnist'`` / ``'cifar'``, ``'celeba'`` and ``'oasis'`` each get a
          convolutional encoder (``self.conv``) and a transposed-convolution
          decoder (``self.deconv``);
        * any other dataset gets fully connected layers
          (``fc1``/``fc11`` and ``fc3``/``fc33``/``fc4``).

        NOTE(review): ``q_z`` and ``p_x`` dispatch on ``args.architecture``,
        not on the dataset, so ``architecture`` has to be ``'convnet'`` for the
        three convolutional datasets and ``'mlp'`` otherwise -- confirm that
        callers always pass a matching pair.

        :param args: configuration object; this method reads ``input_size``,
            ``model_name``, ``n_channels``, ``architecture``, ``dataset`` and
            ``z1_size`` from it. ``self.args.prior`` is read as well
            (``self.args`` is presumably set by ``Model.__init__``).
        '''
        super(VAE, self).__init__(args)

        self.input_size = args.input_size
        self.model_name = args.model_name
        self.n_channels = args.n_channels
        # ``is_available`` has to be *called*: the bare function object is
        # always truthy, which used to select 'cuda' on CPU-only machines too.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.architecture = args.architecture
        self.dataset = args.dataset

        assert args.architecture in ['mlp', 'convnet'], "Architecture must be in ['mlp', 'convnet']"

        if args.dataset in ('mnist', 'cifar'):
            # Same layout as "From Variational to Deterministic Autoencoders".
            self.conv = nn.Sequential(
                nn.Conv2d(self.n_channels, 128, 4, 2, padding=1),
                nn.BatchNorm2d(128),
                nn.ReLU(),
                nn.Conv2d(128, 256, 4, 2, padding=1),
                nn.BatchNorm2d(256),
                nn.ReLU(),
                nn.Conv2d(256, 512, 4, 2, padding=1),
                nn.BatchNorm2d(512),
                nn.ReLU(),
                nn.Conv2d(512, 1024, 4, 2, padding=1),
                nn.BatchNorm2d(1024),
                nn.ReLU(),
            )

            # Heads producing the posterior mean and log-variance.
            self.fc21 = nn.Linear(1024 * 2 * 2, args.z1_size)
            self.fc22 = nn.Linear(1024 * 2 * 2, args.z1_size)

            # Latent code -> 1024 x 8 x 8 feature map (reshaped in ``p_x``).
            self.fc3 = nn.Linear(args.z1_size, 1024 * 8 * 8)
            self.deconv = nn.Sequential(
                nn.ConvTranspose2d(1024, 512, 4, 2, padding=1),
                nn.BatchNorm2d(512),
                nn.ReLU(),
                nn.ConvTranspose2d(512, 256, 4, 2, padding=1, output_padding=1),
                nn.BatchNorm2d(256),
                nn.ReLU(),
                nn.ConvTranspose2d(256, self.n_channels, 4, 1, padding=2),
                nn.Sigmoid(),
            )

        elif args.dataset == 'celeba':
            self.conv = nn.Sequential(
                nn.Conv2d(self.n_channels, 128, 5, 2, padding=1),
                nn.BatchNorm2d(128),
                nn.ReLU(),
                nn.Conv2d(128, 256, 5, 2, padding=1),
                nn.BatchNorm2d(256),
                nn.ReLU(),
                nn.Conv2d(256, 512, 5, 2, padding=2),
                nn.BatchNorm2d(512),
                nn.ReLU(),
                nn.Conv2d(512, 1024, 5, 2, padding=2),
                nn.BatchNorm2d(1024),
                nn.ReLU(),
            )

            self.fc21 = nn.Linear(1024 * 4 * 4, args.z1_size)
            self.fc22 = nn.Linear(1024 * 4 * 4, args.z1_size)

            self.fc3 = nn.Linear(args.z1_size, 1024 * 8 * 8)
            self.deconv = nn.Sequential(
                nn.ConvTranspose2d(1024, 512, 5, 2, padding=2),
                nn.BatchNorm2d(512),
                nn.ReLU(),
                nn.ConvTranspose2d(512, 256, 5, 2, padding=1, output_padding=0),
                nn.BatchNorm2d(256),
                nn.ReLU(),
                nn.ConvTranspose2d(256, 128, 5, 2, padding=2, output_padding=1),
                nn.BatchNorm2d(128),
                nn.ReLU(),
                # Output channel count is hard-coded to 3 for this dataset.
                nn.ConvTranspose2d(128, 3, 5, 1, padding=1),
                nn.Sigmoid(),
            )

        elif args.dataset == 'oasis':
            # This variant uses no batch normalisation.
            self.conv = nn.Sequential(
                nn.Conv2d(self.n_channels, 64, 5, 2, padding=1),
                nn.ReLU(),
                nn.Conv2d(64, 128, 5, 2, padding=1),
                nn.ReLU(),
                nn.Conv2d(128, 256, 5, 2, padding=1),
                nn.ReLU(),
                nn.Conv2d(256, 512, 5, 2, padding=(1, 2)),
                nn.ReLU(),
                nn.Conv2d(512, 1024, 5, 2, padding=0),
                nn.ReLU(),
            )

            self.fc21 = nn.Linear(1024 * 4 * 4, args.z1_size)
            # The log-variance head is skipped when the model is named 'AE'.
            if not args.model_name == 'AE':
                self.fc22 = nn.Linear(1024 * 4 * 4, args.z1_size)

            self.fc3 = nn.Linear(args.z1_size, 1024 * 8 * 8)
            self.deconv = nn.Sequential(
                nn.ConvTranspose2d(1024, 512, 5, (3, 2), padding=(1, 0), output_padding=(0, 0)),
                nn.ReLU(),
                nn.ConvTranspose2d(512, 256, 5, 2, padding=(1, 0), output_padding=(0, 0)),
                nn.ReLU(),
                nn.ConvTranspose2d(256, 128, 5, 2, padding=0, output_padding=(0, 0)),
                nn.ReLU(),
                nn.ConvTranspose2d(128, 64, 5, 2, padding=0, output_padding=(1, 1)),
                nn.ReLU(),
                nn.ConvTranspose2d(64, self.n_channels, 5, 1, padding=1),
                nn.Sigmoid(),
            )

        else:
            # encoder network
            self.fc1 = nn.Linear(args.input_size, 1000)
            self.fc11 = nn.Linear(1000, 500)
            self.fc21 = nn.Linear(500, args.z1_size)
            self.fc22 = nn.Linear(500, args.z1_size)

            # decoder network
            self.fc3 = nn.Linear(args.z1_size, 500)
            self.fc33 = nn.Linear(500, 1000)
            self.fc4 = nn.Linear(1000, args.input_size)

        # weights initialization (only the fully connected layers)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                he_init(m)

        # add pseudo-inputs if VampPrior
        if self.args.prior == 'vampprior':
            self.add_pseudoinputs()

    # AUXILIARY METHODS
    def calculate_loss(self, x, beta=1., average=False):
        '''
        Evaluate ``-RE + beta * KL`` for a batch.

        ``RE`` is ``log_Bernoulli(x, x_mean)`` for ``'binary'`` inputs and the
        negated per-sample sum of squared errors for ``'gray'`` /
        ``'continuous'`` inputs. ``KL`` is the single-sample estimate
        ``log q(z|x) - log p(z)`` at the latent drawn by ``forward``.

        :param x: input image(s)
        :param beta: weight applied to the KL term (used for warm-up)
        :param average: when true, ``loss``, ``RE`` and ``KL`` are each reduced
            with ``torch.mean`` before being returned
        :return: tuple ``(loss, RE, KL)``
        :raises Exception: if ``self.args.input_type`` is not one of
            ``'binary'``, ``'gray'`` or ``'continuous'``
        '''
        # pass through VAE; the decoder log-variance is not used here
        x_mean, _, z_q, z_q_mean, z_q_logvar = self.forward(x)

        # reconstruction term
        input_type = self.args.input_type
        if input_type == 'binary':
            RE = log_Bernoulli(x, x_mean, dim=1)
        elif input_type in ('gray', 'continuous'):
            n_rows = x.shape[0]
            squared_error = F.mse_loss(
                x.reshape(n_rows, -1),
                x_mean.reshape(n_rows, -1),
                reduction='none',
            )
            RE = -squared_error.sum(dim=1)
        else:
            raise Exception('Wrong input type!')

        # KL term, estimated at the sampled latent
        log_p_z = self.log_p_z(z_q)
        log_q_z = log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
        KL = log_q_z - log_p_z

        loss = beta * KL - RE

        if not average:
            return loss, RE, KL

        return loss.mean(), RE.mean(), KL.mean()

    def calculate_likelihood(self, X, dir, mode='test', S=5000, MB=100):
        '''
        Estimate the negative log-likelihood of ``X`` by sampling.

        For every row of ``X`` the per-sample loss from ``calculate_loss`` is
        evaluated on ``S`` copies of that row (in chunks of at most ``MB``),
        the negated losses are combined with log-sum-exp and normalised by the
        number of samples. A histogram of the per-row values is handed to
        ``plot_histogram``.

        :param X: 2-D tensor, one data point per row
        :param dir: forwarded to ``plot_histogram``
        :param mode: forwarded to ``plot_histogram``
        :param S: number of samples per data point; when ``S > MB`` only
            ``(S // MB) * MB`` samples are drawn (any remainder is dropped)
        :param MB: maximum number of samples evaluated in one forward pass
        :return: mean negative log-likelihood estimate over the rows of ``X``
        '''
        N_test = X.size(0)

        # split the S samples into R chunks of size <= MB
        if S <= MB:
            R = 1
        else:
            R = S // MB
            S = MB

        likelihood_test = []

        # Pure evaluation: no autograd graph is needed for the sampled losses.
        with torch.no_grad():
            for j in range(N_test):
                if j % 100 == 0:
                    print('{:.2f}%'.format(j / (1. * N_test) * 100))

                # take x* and repeat it S times (expand creates a view only)
                x_single = X[j].unsqueeze(0)
                x = x_single.expand(S, x_single.size(1))

                chunks = []
                for _r in range(R):
                    a_tmp, _, _ = self.calculate_loss(x)
                    chunks.append(-a_tmp.detach().cpu().numpy())

                # log-mean-exp over all R * S samples
                a = np.asarray(chunks).reshape(-1)
                likelihood_x = logsumexp(a)
                likelihood_test.append(likelihood_x - np.log(a.size))

        likelihood_test = np.array(likelihood_test)

        plot_histogram(-likelihood_test, dir, mode)

        return -np.mean(likelihood_test)

    def calculate_lower_bound(self, X_full, MB=100):
        '''
        Average the mini-batch loss over ``X_full``.

        ``X_full`` is processed in consecutive slices of ``MB`` rows; each
        slice is flattened to ``(-1, prod(self.args.input_size))`` and passed
        to ``calculate_loss(..., average=True)``. The returned value is the
        mean of the per-batch losses, so a smaller final batch carries the
        same weight as a full one.

        :param X_full: tensor holding the whole data set; must be non-empty
        :param MB: mini-batch size
        :return: Python float, the mean of the per-batch losses
        '''
        lower_bound = 0.

        # float() keeps the ceil meaningful even under integer division
        I = int(math.ceil(X_full.size(0) / float(MB)))
        flat_size = int(np.prod(self.args.input_size))

        # Pure evaluation: gradients are not needed.
        with torch.no_grad():
            for i in range(I):
                x = X_full[i * MB: (i + 1) * MB].view(-1, flat_size)

                loss, _, _ = self.calculate_loss(x, average=True)

                # ``loss`` is a 0-dim tensor; ``.item()`` extracts the scalar
                # (indexing it with [0] raises on 0-dim tensors).
                lower_bound += loss.item()

        lower_bound /= I

        return lower_bound

    # ADDITIONAL METHODS
    def generate_x(self, N=25):
        '''
        Decode latent samples drawn from the prior.

        * ``'standard'`` prior: ``N`` latents are drawn from a standard normal
          (and moved to the GPU when ``self.args.cuda`` is set).
        * ``'vampprior'``: the first ``N`` rows of
          ``self.means(self.idle_input)`` are encoded with ``q_z`` and a latent
          is sampled from each resulting posterior.

        :param N: number of samples to generate
        :return: first output of ``p_x`` (the decoded means) for the latents
        :raises Exception: if ``self.args.prior`` is neither ``'standard'``
            nor ``'vampprior'``
        '''
        prior = self.args.prior

        if prior == 'standard':
            z_sample_rand = Variable(torch.FloatTensor(N, self.args.z1_size).normal_())
            if self.args.cuda:
                z_sample_rand = z_sample_rand.cuda()

        elif prior == 'vampprior':
            means = self.means(self.idle_input)[0:N]
            z_sample_gen_mean, z_sample_gen_logvar = self.q_z(means)
            z_sample_rand = self.reparameterize(z_sample_gen_mean, z_sample_gen_logvar)

        else:
            # Same error as ``log_p_z``; previously this fell through to an
            # UnboundLocalError on ``z_sample_rand``.
            raise Exception('Wrong name of the prior!')

        samples_rand, _ = self.p_x(z_sample_rand)
        return samples_rand

    def reconstruct_x(self, x):
        '''
        Run a full forward pass and keep only the decoded mean.

        :param x: input batch, passed unchanged to ``forward``
        :return: first element of the tuple returned by ``forward``
        '''
        reconstruction, _x_logvar, _z, _z_mean, _z_logvar = self.forward(x)
        return reconstruction

    # THE MODEL: VARIATIONAL POSTERIOR
    def q_z(self, x):
        '''
        Encode ``x`` into the parameters of q(z | x).

        For ``'convnet'`` the input is reshaped to an image batch
        (208 x 176 for ``'oasis'``, otherwise a square whose side is
        ``int(sqrt(input_size / n_channels))``), passed through ``self.conv``
        and flattened per sample. For ``'mlp'`` it is flattened to
        ``input_size`` features and passed through ``fc1`` and ``fc11``.

        :param x: input batch
        :return: tuple ``(z_q_mean, z_q_logvar)``; for the ``'oasis'`` dataset
            the log-variance is squashed with ``tanh``
        '''
        if self.architecture == 'convnet':
            n_rows = x.shape[0]
            if self.dataset == 'oasis':
                height, width = 208, 176
            else:
                side = int((self.input_size / self.n_channels) ** 0.5)
                height, width = side, side
            images = x.reshape(-1, self.n_channels, height, width)
            x = self.conv(images).reshape(n_rows, -1)

        elif self.architecture == 'mlp':
            hidden = F.relu(self.fc1(x.reshape(-1, self.input_size)))
            x = F.relu(self.fc11(hidden))

        z_q_mean = self.fc21(x)
        z_q_logvar = self.fc22(x)

        if self.dataset == 'oasis':
            z_q_logvar = torch.tanh(z_q_logvar)

        return z_q_mean, z_q_logvar

    # THE MODEL: GENERATIVE DISTRIBUTION
    def p_x(self, z):
        '''
        Decode latent codes into the mean of p(x | z).

        For ``'convnet'`` the latent is mapped by ``fc3`` to a
        1024 x 8 x 8 feature map and passed through ``self.deconv``. For
        ``'mlp'`` it goes through ``fc3``, ``fc33`` and a sigmoid-activated
        ``fc4``.

        :param z: batch of latent codes, one per row
        :return: tuple ``(x_mean, x_logvar)`` where ``x_mean`` is flattened to
            ``(z.shape[0], -1)`` and ``x_logvar`` is a constant one-element
            zero tensor on the same device as ``z``
        :raises ValueError: if ``self.architecture`` is neither ``'convnet'``
            nor ``'mlp'``
        '''
        batch_size = z.shape[0]

        if self.architecture == 'convnet':
            feature_map = self.fc3(z).reshape(-1, 1024, 8, 8)
            x_mean = self.deconv(feature_map)

        elif self.architecture == 'mlp':
            h3 = F.relu(self.fc3(z))
            h3 = F.relu(self.fc33(h3))
            x_mean = torch.sigmoid(self.fc4(h3))

        else:
            raise ValueError('Unknown architecture: {}'.format(self.architecture))

        # Placeholder log-variance. It follows the device of ``z`` rather than
        # a separately tracked device string, so it can never disagree with
        # where the model actually lives (e.g. on CPU-only hosts).
        x_logvar = torch.tensor([0], device=z.device)

        return x_mean.reshape(batch_size, -1), x_logvar

    # the prior
    def log_p_z(self, z):
        '''
        Evaluate the log-density of the prior at ``z``.

        * ``'standard'``: ``log_Normal_standard(z, dim=1)``.
        * ``'vampprior'``: an equally weighted mixture over
          ``self.args.number_components`` components, each being the encoder
          posterior ``q_z`` evaluated at a pseudo-input from
          ``self.means(self.idle_input)``.

        :param z: batch of latent codes (MB x M)
        :return: log prior density, one value per row of ``z``
        :raises Exception: if ``self.args.prior`` is neither ``'standard'``
            nor ``'vampprior'``
        '''
        if self.args.prior == 'standard':
            log_prior = log_Normal_standard(z, dim=1)

        elif self.args.prior == 'vampprior':
            # z - MB x M
            C = self.args.number_components

            # pseudo-inputs and their posterior parameters (C x M)
            X = self.means(self.idle_input)
            z_p_mean, z_p_logvar = self.q_z(X)

            # broadcast every latent against every component
            z_expand = z.unsqueeze(1)           # MB x 1 x M
            means = z_p_mean.unsqueeze(0)       # 1 x C x M
            logvars = z_p_logvar.unsqueeze(0)   # 1 x C x M

            a = log_Normal_diag(z_expand, means, logvars, dim=2) - math.log(C)  # MB x C

            # Numerically stable log-sum-exp over the components; unlike the
            # manual max-shift it does not yield NaN when a row is all -inf.
            log_prior = torch.logsumexp(a, dim=1)  # MB

        else:
            raise Exception('Wrong name of the prior!')

        return log_prior

    # THE MODEL: FORWARD PASS
    def forward(self, x):
        '''
        Encode ``x``, sample a latent and decode it.

        :param x: input batch
        :return: tuple ``(x_mean, x_logvar, z_q, z_q_mean, z_q_logvar)`` --
            the two outputs of ``p_x``, the sampled latent and the two
            outputs of ``q_z``
        '''
        # parameters of q(z | x)
        posterior_mean, posterior_logvar = self.q_z(x)

        # z ~ q(z | x); ``reparameterize`` is not defined in this class
        # (presumably inherited from ``Model``)
        latent = self.reparameterize(posterior_mean, posterior_logvar)

        # parameters of p(x | z)
        decoded_mean, decoded_logvar = self.p_x(latent)

        return decoded_mean, decoded_logvar, latent, posterior_mean, posterior_logvar
