import numpy as np
import tensorflow as tf
import sys, os

sys.path.append('utils/')
from load_save import save_params, load_params
from visualisation import plot_gif, plot_gif_mnist

sys.path.append('models/')
from Disc_Cont import content_disc
# from Ladder_MNIST_12 import prior, encoder, decoder, predictor
from Res_MNIST import prior, encoder, decoder, predictor
from helper_functions import *
import numpy as np
import argparse
import scipy.io
import cal_mmd
import time
from tools_funs import *
from load_data import load_dataset
from torch.utils.data import DataLoader
def sample_gaussian(mu, log_sig):
    """Return a reparameterised Gaussian sample ``mu + exp(log_sig) * eps``.

    ``eps`` is standard-normal noise created with the static shape of ``mu``
    (``mu.get_shape()``), so ``mu`` must have a fully defined static shape.
    """
    std = tf.exp(log_sig)
    eps = tf.random_normal(mu.get_shape())
    return mu + std * eps

# Module-level default kernel name (same value as the default argument of mmd_eval).
kernel_choice = 'mix_rbf'


def mmd_eval(z_tilde, z, kernel_choice='mix_rbf'):
    """Evaluate ``cal_mmd.mmd2`` between two batches of samples.

    The kernel is looked up by name on the ``cal_mmd`` module as
    ``_<kernel_choice>_kernel`` and applied to ``(z_tilde, z)``; its result is
    passed unchanged to ``cal_mmd.mmd2``.

    Args:
        z_tilde: first batch of samples (passed first to the kernel).
        z: second batch of samples.
        kernel_choice: kernel name used to build the attribute name.

    Returns:
        Whatever ``cal_mmd.mmd2`` returns for the kernel output.
    """
    kernel_fn = getattr(cal_mmd, '_{}_kernel'.format(kernel_choice))
    kernel_out = kernel_fn(z_tilde, z)
    return cal_mmd.mmd2(kernel_out)

def gan_loss(fake, real, discriminator, batch_size):
    """Build the critic loss and gradient penalty for a WGAN-GP style critic.

    Args:
        fake: batch of generated samples.
        real: batch of reference samples.
        discriminator: callable whose output has a size-1 axis 1 (it is squeezed).
        batch_size: number of samples; used for the interpolation weights and
            to flatten the gradient.

    Returns:
        (critic_loss, gradient_penalty):
            critic_loss is mean(D(fake) - D(real));
            gradient_penalty is mean((||dD/dz_hat|| - 1)^2) on random
            interpolates z_hat between real and fake.
    """
    # One uniform mixing weight per sample, broadcast across the feature axis.
    mix = tf.random_uniform((batch_size, 1), minval=0.0, maxval=1.0)
    interpolated = mix * real + (1.0 - mix) * fake

    def score(samples):
        # (B, 1) -> (B,)
        return tf.squeeze(discriminator(samples), axis=1)

    # Same evaluation order as before: fake, real, interpolates.
    score_fake = score(fake)
    score_real = score(real)
    score_interp = score(interpolated)

    critic_loss = tf.reduce_mean(score_fake - score_real, axis=0)

    (interp_grad,) = tf.gradients(score_interp, [interpolated])
    flat_grad = tf.reshape(interp_grad, (batch_size, -1))
    grad_norm = tf.norm(flat_grad, axis=1)
    gradient_penalty = tf.reduce_mean((grad_norm - 1.0) ** 2, axis=0)
    return critic_loss, gradient_penalty

def recon(x, gen, enc, predictor, dimH):
    """Reconstruct a batch of sequences from the posterior means.

    The static code is the mean returned by ``enc.qfx`` and the per-step code
    is the mean returned by ``enc.qzx`` (no sampling).  At every step the two
    are concatenated, passed through the recurrent ``predictor`` and decoded
    with ``gen``.

    Args:
        x: rank-5 tensor with static shape; the first two dims are read as
            (batch, time).
        gen: decoder callable applied to the predictor output.
        enc: encoder exposing ``qfx(x, T)`` and ``qzx(fea_seq, T)``.
        predictor: callable ``(inputs, h, c) -> (out, h, c)``.
        dimH: width of the predictor's recurrent state.

    Returns:
        Decoded frames concatenated along axis 1.
    """
    n_seq, n_steps, _, _, _ = x.get_shape().as_list()

    # Static code: posterior mean only.
    static_code, _, features = enc.qfx(x, n_steps)
    # Per-step codes: posterior means only.
    step_means, _, _ = enc.qzx(features, n_steps)

    h_state = tf.zeros(shape=(n_seq, dimH))
    c_state = tf.zeros(shape=(n_seq, dimH))

    frames = []
    for step in range(n_steps):
        joint_code = tf.concat([static_code, step_means[step]], axis=1)
        pred, h_state, c_state = predictor(joint_code, h_state, c_state)
        frames.append(tf.expand_dims(gen(pred), axis=1))

    return tf.concat(frames, axis=1)

def gen_sam(prior, pxz, predictor, N, T, dimF, dimZ, P_dimH, dimH):
    """Generate ``N`` sequences of ``T`` steps with random ``f`` and random ``z``.

    ``f`` is drawn once per sequence from a standard normal; ``z`` is sampled
    step by step from the recurrent ``prior``, starting from an all-zero
    previous ``z`` and all-zero prior/predictor states.

    Returns:
        Decoded frames concatenated along axis 1.
    """
    frames = []
    static_code = tf.random.normal(shape=(N, dimF))
    prev_code = tf.zeros(shape=(N, dimZ))
    prior_h = tf.zeros(shape=(N, P_dimH))
    prior_c = tf.zeros(shape=(N, P_dimH))
    pred_h = tf.zeros(shape=(N, dimH))
    pred_c = tf.zeros(shape=(N, dimH))

    for _step in range(T):
        mu, log_sig, prior_h, prior_c = prior(prev_code, prior_h, prior_c)
        code = sample_gaussian(mu, log_sig)
        joint_code = tf.concat([static_code, code], axis=1)
        pred, pred_h, pred_c = predictor(joint_code, pred_h, pred_c)
        frame = pxz(pred)
        # The freshly sampled code conditions the prior at the next step.
        prev_code = code
        frames.append(tf.expand_dims(frame, axis=1))

    return tf.concat(frames, axis=1)

def gen_fix_f(prior, pxz, predictor, N, T, dimF, dimZ, P_dimH, dimH):
    """Generate ``N`` sequences that share one ``f`` but have independent ``z``.

    A single standard-normal ``f`` of shape (1, dimF) is tiled to all ``N``
    sequences; ``z`` is sampled per sequence and per step from the recurrent
    ``prior``.

    Returns:
        Decoded frames concatenated along axis 1.
    """
    shared_code = tf.random_normal(shape=(1, dimF))
    static_code = tf.tile(shared_code, [N, 1])
    code = tf.zeros(shape=(N, dimZ))
    prior_h = tf.zeros(shape=(N, P_dimH))
    prior_c = tf.zeros(shape=(N, P_dimH))
    pred_h = tf.zeros(shape=(N, dimH))
    pred_c = tf.zeros(shape=(N, dimH))

    frames = []
    for _step in range(T):
        mu, log_sig, prior_h, prior_c = prior(code, prior_h, prior_c)
        code = sample_gaussian(mu, log_sig)
        joint_code = tf.concat([static_code, code], axis=1)
        pred, pred_h, pred_c = predictor(joint_code, pred_h, pred_c)
        frame = pxz(pred)
        frames.append(tf.expand_dims(frame, axis=1))

    return tf.concat(frames, axis=1)

def gen_fix_z(prior, pxz, predictor, N, T, dimF, dimZ, P_dimH, dimH):
    """Generate ``N`` sequences that share one ``z`` trajectory but differ in ``f``.

    ``f`` is drawn independently for each of the ``N`` sequences.  The prior is
    run with a batch of one, and the sampled ``z`` is tiled to ``N`` rows
    before it is combined with ``f``.

    Returns:
        Decoded frames concatenated along axis 1.
    """
    frames = []
    static_code = tf.random_normal(shape=(N, dimF))
    # Prior state is kept at batch size 1 so every sequence sees the same z.
    code = tf.zeros(shape=(1, dimZ))
    prior_h = tf.zeros(shape=(1, P_dimH))
    prior_c = tf.zeros(shape=(1, P_dimH))
    pred_h = tf.zeros(shape=(N, dimH))
    pred_c = tf.zeros(shape=(N, dimH))

    for _step in range(T):
        mu, log_sig, prior_h, prior_c = prior(code, prior_h, prior_c)
        code = sample_gaussian(mu, log_sig)
        tiled_code = tf.tile(code, [N, 1])
        joint_code = tf.concat([static_code, tiled_code], axis=1)
        pred, pred_h, pred_c = predictor(joint_code, pred_h, pred_c)
        frame = pxz(pred)
        frames.append(tf.expand_dims(frame, axis=1))

    return tf.concat(frames, axis=1)


class RWVAE(object):
    """Trainer for a sequence auto-encoder with a static code ``f`` and per-frame codes ``z``.

    The constructor stores the sub-networks and hyper-parameters, opens a
    TensorFlow session, creates the input placeholders and builds all losses
    and optimisers (``construct_opt``).  ``train`` loads the data, optionally
    restores a checkpoint and renders sample GIFs, then runs the alternating
    discriminator/encoder and decoder/prior updates.

    Sub-networks (project types; contracts as used in this class):
        cont_disc: critic applied to ``f`` codes; exposes ``.vars``.
        enc: exposes ``qfx(x, T)``, ``qzx(fea_seq, T)`` and ``.vars``.
        dec: callable mapping predictor outputs to frames; exposes ``.vars``.
        prior: callable ``(z_prev, h, c) -> (mu, log_sig, h, c)``; exposes ``.vars``.
        predictor: callable ``(concat([f, z]), h, c) -> (out, h, c)``; exposes ``.vars``.

    Hyper-parameters:
        dataset: dataset name; only ``'mmnist'`` is implemented in ``train``.
        dimH / prior_dimH: recurrent state widths of predictor / prior.
        dimF / dimZ: sizes of the static and per-frame codes.
        d_steps: discriminator+encoder updates per decoder update (must be >= 1).
        n_iter: number of epochs.
        N_train: nominal training-set size; sets iterations per epoch.
        ws_beta / kl_beta / gan_beta: weights of the MMD, KL and critic terms.
        grad_pty: weight of the gradient penalty in the discriminator loss.
        lr_dec / lr_enc / lr_disc: learning rates.
        likelihood: one of ``'l2'``, ``'l1'``, ``'bernoulli'``.
        checkpoint: negative to train from scratch, otherwise the checkpoint
            passed to ``load_params``.
        N_test, N_gen, N_pred: stored but not used by this class.
    """

    # Reconstruction terms implemented in construct_opt.
    _LIKELIHOODS = ('l2', 'l1', 'bernoulli')

    def __init__(self, cont_disc, enc, dec, prior, predictor, dataset, dimH, prior_dimH, dimF, dimZ, d_steps, n_iter, n_digits, N_train, N_test, ws_beta, kl_beta, gan_beta, grad_pty, lr_dec, lr_enc, lr_disc, batch_size, N_gen, N_frames, N_pred, likelihood, checkpoint):
        """Store the configuration, open a session and build the training graph."""
        self.cont_disc = cont_disc
        self.enc = enc
        self.dec = dec
        self.prior = prior
        self.predictor = predictor
        self.dataset = dataset
        self.dimH = dimH
        self.prior_dimH = prior_dimH
        self.dimF = dimF
        self.dimZ = dimZ
        self.d_steps = d_steps
        self.n_iter = n_iter
        self.N_train = N_train
        self.n_digits = n_digits
        self.N_test = N_test
        self.ws_beta = ws_beta
        self.kl_beta = kl_beta
        self.gan_beta = gan_beta
        self.grad_pty = grad_pty
        # Python-float learning rates; they are fed into the placeholders below.
        self.lr_enc = lr_enc
        self.lr_dec = lr_dec
        self.lr_d = lr_disc
        self.batch_size = batch_size
        self.N_gen = N_gen
        self.N_frames = N_frames
        self.N_pred = N_pred
        self.likelihood = likelihood
        self.checkpoint = checkpoint
        gpu_options = tf.GPUOptions(allow_growth=True)
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        # Input frame size: 64*64*1.  The batch dimension is fixed, so every
        # fed batch must contain exactly ``batch_size`` sequences.
        self.x = tf.placeholder(tf.float32, shape=(self.batch_size, self.N_frames, 64, 64, 1))
        self.x_test = tf.placeholder(tf.float32, shape=(self.batch_size, self.N_frames, 64, 64, 1))
        # Scalar learning-rate placeholders (note: ``self.lr_disc`` is the
        # placeholder, ``self.lr_d`` the float value).
        self.lr_venc = tf.placeholder(tf.float32, shape=())
        self.lr_vdec = tf.placeholder(tf.float32, shape=())
        self.lr_disc = tf.placeholder(tf.float32, shape=())
        self.construct_opt()

    def construct_opt(self):
        """Build the losses, monitoring tensors and the three Adam optimisers.

        Sets ``loss_disc``, ``model_loss``, ``mmd_v``, ``mse``, ``kl_v``,
        ``gg_loss`` and the train ops ``opt_dis``, ``opt_enc``, ``opt_dec``.

        Raises:
            ValueError: if ``self.likelihood`` is not a supported name.
        """
        # Fail early with a clear message instead of a NameError on ``ll``
        # halfway through graph construction.
        if self.likelihood not in self._LIKELIHOODS:
            raise ValueError("unknown likelihood %r; expected one of %s"
                             % (self.likelihood, ', '.join(self._LIKELIHOODS)))

        mu_qf, log_sig_qf, fea_seq = self.enc.qfx(self.x, self.N_frames)

        # KL term of the static code against (mu=0, log_sig=0).  It is only
        # accumulated into ``kl_v`` for monitoring; it is not part of ``bound``.
        kl_tmp = KL(mu_qf, log_sig_qf, 0.0, 0.0)
        f_fake = sample_gaussian(mu_qf, log_sig_qf)
        f_real = tf.random.normal(shape=(self.batch_size, self.dimF))
        critic_loss, grad_penalty = gan_loss(f_fake, f_real, self.cont_disc, self.batch_size)
        self.loss_disc = critic_loss + self.grad_pty * grad_penalty
        # The model is trained against the critic: it minimises -critic_loss.
        bound = - self.gan_beta * critic_loss
        mmd_tmp, mse_tmp = 0., 0.

        mu_qz, log_sig_qz, z_seq = self.enc.qzx(fea_seq, self.N_frames)
        # Initial states of the prior and predictor recurrences.
        h_prev_p = tf.zeros(shape=(self.batch_size, self.prior_dimH))
        c_prev_p = tf.zeros(shape=(self.batch_size, self.prior_dimH))
        z_prev = tf.zeros(shape=(self.batch_size, self.dimZ))
        h_pred_prev = tf.zeros(shape=(self.batch_size, self.dimH))
        c_pred_prev = tf.zeros(shape=(self.batch_size, self.dimH))
        for t in range(self.N_frames):
            mu_pz, log_sig_pz, h_prev_p, c_prev_p = self.prior(z_prev, h_prev_p, c_prev_p)
            kl_z = KL(mu_qz[t], log_sig_qz[t], mu_pz, log_sig_pz)
            z_fake = z_seq[t]
            z_real = sample_gaussian(mu_pz, log_sig_pz)
            mmd_z = mmd_eval(z_fake, z_real)
            output, h_pred_prev, c_pred_prev = self.predictor(tf.concat([f_fake, z_fake], axis=1), h_pred_prev, c_pred_prev)
            # Reconstructed image frame.
            x_fake_t = self.dec(output)
            x_true_t = self.x[:, t]
            if self.likelihood == 'l2':
                ll = tf.reduce_mean((x_true_t - x_fake_t) ** 2, [1, 2, 3])
            elif self.likelihood == 'l1':
                ll = tf.reduce_mean(tf.abs(x_true_t - x_fake_t), [1, 2, 3])
            else:  # 'bernoulli' (validated above)
                ll = -log_bernoulli_prob(x_true_t, x_fake_t)
            # Per-frame loss; z is the motion variable.
            bound += (ll + self.ws_beta * mmd_z + self.kl_beta * kl_z)
            mmd_tmp += mmd_z
            mse_tmp += ll
            kl_tmp += kl_z
            # The posterior sample conditions the prior at the next step.
            z_prev = z_fake

        self.model_loss = tf.reduce_mean(bound)
        self.mmd_v = tf.reduce_mean(mmd_tmp)
        # Summed per-frame reconstruction term (whatever ``likelihood`` selects).
        self.mse = tf.reduce_mean(mse_tmp)
        self.kl_v = tf.reduce_mean(kl_tmp)
        self.gg_loss = critic_loss

        # Update content discriminator.  Use the instance's critic rather than
        # a module-level global so the class works when imported.
        self.opt_dis = tf.train.AdamOptimizer(learning_rate=self.lr_disc, beta1=.5, beta2=.999
                                              ).minimize(self.loss_disc, var_list=self.cont_disc.vars)
        # Update encoder and predictor.  The learning rate comes from the fed
        # placeholder, consistent with the discriminator optimiser.
        enc_vars = [self.enc.vars, self.predictor.vars]
        self.opt_enc = tf.train.AdamOptimizer(learning_rate=self.lr_venc, beta1=.5, beta2=.999
                                              ).minimize(self.model_loss, var_list=enc_vars)
        # Update decoder and motion prior.
        dec_vars = [self.dec.vars, self.prior.vars]
        self.opt_dec = tf.train.AdamOptimizer(learning_rate=self.lr_vdec, beta1=.5, beta2=.999,
                                              ).minimize(self.model_loss, var_list=dec_vars)

    def _param_tag(self):
        """Return the name under which parameters are saved and loaded."""
        return 'comb' + str(self.dimF) + str(self.lr_dec)

    def _load_data(self):
        """Return ``(training_batch_generator, validation_array)`` for the dataset.

        Raises:
            ValueError: if ``self.dataset`` is not ``'mmnist'``.
        """
        if self.dataset != 'mmnist':
            raise ValueError("unsupported dataset %r: only 'mmnist' is implemented"
                             % (self.dataset,))

        # NOTE(review): '~' is passed through as-is; confirm load_dataset expands it.
        data_root = '~/Dropbox'
        train_data, test_data = load_dataset('smmnist', data_root, self.N_frames, 64, num_digs=self.n_digits)
        train_loader = DataLoader(train_data,
                                  num_workers=3,
                                  batch_size=self.batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  pin_memory=True)
        # The placeholders have a fixed batch dimension, so at least
        # ``batch_size`` validation sequences are needed (20 was too few for
        # larger batches and made the feed fail on shape).
        n_valid = max(20, self.batch_size)
        X_valid = np.empty((n_valid, self.N_frames, 64, 64, 1))
        for i in range(n_valid):
            X_valid[i] = test_data[i]

        def get_training_batch():
            # Cycle over the loader forever; epochs are counted by the caller.
            while True:
                for sequence in train_loader:
                    yield sequence

        return get_training_batch(), X_valid

    def _plot_checkpoint_samples(self, test_data, shape_high):
        """Render reconstruction and generation GIFs into ``figs``."""
        path = 'figs'
        num_plot = self.batch_size
        suffix = self.dataset + str(self.dimF) + str(self.lr_dec)
        test_batch = test_data[:self.batch_size]

        def render(tensor, prefix, feed_dict=None):
            frames = self.sess.run(tensor, feed_dict=feed_dict)
            plot_gif_mnist(frames[:num_plot], shape_high, path, prefix + suffix)

        # Reconstruction of the test batch.
        render(recon(self.x, self.dec, self.enc, self.predictor, self.dimH),
               'recon_', {self.x: test_batch})
        # Unconditional samples.
        render(gen_sam(self.prior, self.dec, self.predictor, self.batch_size, self.N_frames,
                       self.dimF, self.dimZ, self.prior_dimH, self.dimH), 'gen_')
        # Samples with a shared f.
        render(gen_fix_f(self.prior, self.dec, self.predictor, self.batch_size, self.N_frames,
                         self.dimF, self.dimZ, self.prior_dimH, self.dimH), 'genfixF_')
        # Samples with a shared z trajectory.
        render(gen_fix_z(self.prior, self.dec, self.predictor, self.batch_size, self.N_frames,
                         self.dimF, self.dimZ, self.prior_dimH, self.dimH), 'genfixZ_')

        test_feed = {self.x_test: test_batch}
        # NOTE(review): 'orig_' is a reconstruction through x_test, not the raw
        # input frames -- confirm this is what the plot is meant to show.
        render(recon(self.x_test, self.dec, self.enc, self.predictor, self.dimH),
               'orig_', test_feed)
        render(recon_random_z(self.x_test, self.dec, self.enc, self.prior, self.predictor,
                              self.prior_dimH, self.dimH, self.dimZ),
               'recon_random_z', test_feed)
        render(recon_shuffle_z(self.x_test, self.dec, self.enc, self.predictor, self.dimH,
                               self.batch_size),
               'recon_shuffle_z', test_feed)
        render(recon_random_f(self.x_test, self.dec, self.enc, self.predictor, self.dimH),
               'recon_random_f', test_feed)
        render(recon_shuffle_f(self.x_test, self.dec, self.enc, self.predictor, self.dimH),
               'recon_shuffle_f', test_feed)

    def train(self):
        """Load data, initialise or restore parameters, and run the training loop.

        With ``checkpoint < 0`` all variables are freshly initialised; otherwise
        parameters are restored with ``load_params`` and sample GIFs are
        written before training continues.  Parameters are saved every 20
        epochs and once more at the end.

        Raises:
            ValueError: if the dataset is unsupported, or if ``d_steps < 1``
                while there are training iterations to run.
        """
        training_batch_generator, test_data = self._load_data()
        width, height, channel = 64, 64, 1
        shape_high = (width, height, channel)
        if self.checkpoint < 0:
            print('training from scratch')
            self.sess.run(tf.global_variables_initializer())
        else:
            # NOTE(review): no variable initialiser runs on this path; confirm
            # load_params also restores the optimiser slot variables.
            load_params(self.sess, self._param_tag(), self.checkpoint)
            self._plot_checkpoint_samples(test_data, shape_high)

        n_iter_vae = self.N_train // self.batch_size
        # The iteration counter only advances inside the d_steps loop, so a
        # non-positive d_steps would spin forever.
        if self.n_iter > 0 and n_iter_vae > 0 and self.d_steps < 1:
            raise ValueError("d_steps must be >= 1, got %r" % (self.d_steps,))

        for epoch in range(self.n_iter):
            if (epoch+1) % 20 == 0:
                save_params(self.sess, self._param_tag(), self.checkpoint)
            bound_total, disc_total, mmd_total, kl_total, gan_total = 0.0, .0, .0, 0., 0.
            begin = time.time()
            j = 0
            kk = 0
            while j < n_iter_vae:
                batch = next(training_batch_generator)
                # Update discriminator and encoder, each step on a new batch.
                for steps in range(self.d_steps):
                    self.sess.run([self.opt_dis, self.opt_enc], feed_dict={self.x: batch, self.lr_disc: self.lr_d, self.lr_venc: self.lr_enc})
                    batch = next(training_batch_generator)
                    j = j + 1
                # Update decoder/prior params and fetch the monitoring values.
                bound, dis_bd, mmd_bd, kl_ls, g_ls, _ = self.sess.run([self.mse, self.loss_disc, self.mmd_v, self.kl_v, self.gg_loss, self.opt_dec], feed_dict={self.x: batch, self.lr_vdec: self.lr_dec})
                bound_total += bound
                disc_total += dis_bd
                mmd_total += mmd_bd
                kl_total += kl_ls
                gan_total += g_ls
                kk = kk + 1.
            # Guard against an epoch with zero iterations (N_train < batch_size).
            denom = kk if kk > 0 else 1.
            bound_total = bound_total / denom
            disc_total = disc_total / denom
            mmd_total = mmd_total / denom
            gan_total = gan_total / denom
            kl_total = kl_total / denom
            end = time.time()
            print("Epoch %d, mse=%.5f, gan_loss=%.5f, mmd=%.5f,  kl_loss=%.5f,  disc_loss=%.5f, time=%.2f" \
                  % (epoch, bound_total, gan_total, mmd_total, kl_total, disc_total, end - begin))
            # Step-size decay (disabled):
            # if epoch % 50 == 0:
            #     self.lr_d = self.lr_d * .9
            #     self.lr_enc = self.lr_enc * .9
            #     self.lr_dec = self.lr_dec * .9
        save_params(self.sess, self._param_tag(), self.checkpoint)


if __name__ == '__main__':
    # Command-line entry point: parse the hyper-parameters, build the
    # sub-networks and run training.
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default="bair", type=str)
    parser.add_argument("--likelihood", default="l2", type=str)
    parser.add_argument("--file_name", default="Kal", type=str)
    parser.add_argument("--dimZ", default=32, type=int)
    parser.add_argument("--dimF", default=256, type=int)
    parser.add_argument("--d_steps", default=2, type=int)
    parser.add_argument("--disc_H", default=256, type=int)
    parser.add_argument("--disc_lastH", default=128, type=int)
    parser.add_argument("--dimH", default=512, type=int)
    parser.add_argument("--prior_dimH", default=512, type=int)
    parser.add_argument("--n_iter", default=5, type=int)
    parser.add_argument("--N_gen", default=100, type=int)
    parser.add_argument("--filter_width", default=5, type=int)  ### filter_width = 3, height, width = 4, 4
    parser.add_argument("--height", default=9, type=int)
    parser.add_argument("--N_test", default=100, type=int)
    parser.add_argument("--N_train", default=8000, type=int)
    parser.add_argument("--N_frames", default=8, type=int)
    parser.add_argument("--N_pred", default=8, type=int)
    parser.add_argument("--width", default=5, type=int)
    parser.add_argument("--n_digits", default=1, type=int)
    parser.add_argument("--checkpoint", default=-1, type=int)
    parser.add_argument("--n_channel", default=256, type=int)
    parser.add_argument("--batch_size", default=50, type=int)
    parser.add_argument("--ws_beta", default=.0001, type=float)
    parser.add_argument("--kl_beta", default=.0001, type=float)
    parser.add_argument("--gan_beta", default=1.e-4, type=float)
    parser.add_argument("--grad_pty", default=5., type=float)
    parser.add_argument("--noise_scale", default=0.1, type=float)
    parser.add_argument("--lr_dec", default=1.e-4, type=float)
    parser.add_argument("--lr_enc", default=5.e-5, type=float)
    parser.add_argument("--lr_disc", default=1.e-5, type=float)
    args = parser.parse_args()

    # Options are read explicitly from ``args`` instead of being injected into
    # the namespace with ``locals().update(...)``, so every name used below is
    # visible to readers and linters.  The network objects stay at module
    # scope so that any code reading them as globals keeps working.
    cont_disc = content_disc(args.disc_H, args.disc_lastH)
    enc = encoder(args.batch_size, args.dimZ, args.dimF, args.dimH)
    dec = decoder(args.batch_size, args.dimZ, args.dimF, args.dimH)
    prior_m = prior(args.dimZ, args.prior_dimH)
    # Separate name so the imported ``predictor`` class is not rebound.
    predictor_m = predictor(args.batch_size, args.dimH)
    VAE = RWVAE(
        cont_disc=cont_disc,
        enc=enc,
        dec=dec,
        prior=prior_m,
        predictor=predictor_m,
        dataset=args.dataset,
        dimH=args.dimH,
        prior_dimH=args.prior_dimH,
        dimF=args.dimF,
        dimZ=args.dimZ,
        d_steps=args.d_steps,
        n_iter=args.n_iter,
        n_digits=args.n_digits,
        N_train=args.N_train,
        N_test=args.N_test,
        ws_beta=args.ws_beta,
        kl_beta=args.kl_beta,
        gan_beta=args.gan_beta,
        grad_pty=args.grad_pty,
        lr_dec=args.lr_dec,
        lr_enc=args.lr_enc,
        lr_disc=args.lr_disc,
        batch_size=args.batch_size,
        N_gen=args.N_gen,
        N_frames=args.N_frames,
        N_pred=args.N_pred,
        likelihood=args.likelihood,
        checkpoint=args.checkpoint,
    )
    VAE.train()
