import tensorflow as tf
import numpy as np
def sample_gaussian(mu, log_sig):
    """Draw a reparameterised sample from a diagonal Gaussian.

    Computes ``mu + exp(log_sig) * eps`` with ``eps ~ N(0, 1)`` element-wise,
    so gradients can flow through ``mu`` and ``log_sig``.

    Args:
        mu: mean tensor.
        log_sig: log standard deviation; must broadcast against ``mu``.

    Returns:
        A tensor of ``mu``'s dtype whose shape is ``mu``'s shape broadcast
        with ``log_sig``'s.
    """
    sig = tf.exp(log_sig)
    # Use the dynamic shape: ``mu.get_shape()`` cannot be turned into a shape
    # tensor when some static dimension is unknown (e.g. a ``None`` batch).
    # Match mu's dtype (base_dtype strips TF1 ref types) so non-float32 means
    # do not fail at the multiply below.
    eps = tf.random_normal(tf.shape(mu), dtype=mu.dtype.base_dtype)
    return mu + sig * eps

def visualize(x, enc):
    """Return the posterior means of the dynamic latents z for a batch.

    Despite the name, nothing is decoded: the result is built from the
    per-step means returned by ``enc.qzx``, not from reconstructed frames.

    Args:
        x: rank-5 input tensor whose second static dimension is the sequence
            length T (presumably (batch, T, H, W, C) video — confirm with
            callers).
        enc: encoder exposing ``qfx(x, T) -> (mu, log_sig, features)`` and
            ``qzx(features, T) -> (mu, log_sig, samples)``.

    Returns:
        ``mu_q[t]`` for t in 0..T-1, each expanded at axis 1 and concatenated
        along axis 1 (i.e. (batch, T, dimZ) if each ``mu_q[t]`` is
        (batch, dimZ) — assumed, not checked here).
    """
    # Unpacking five values also enforces that x has rank 5.
    _, T, _, _, _ = x.get_shape().as_list()
    # qfx must still run: it produces the feature sequence that qzx consumes.
    _, _, fea_seq = enc.qfx(x, T)
    mu_q, _, _ = enc.qzx(fea_seq, T)
    z_means = [tf.expand_dims(mu_q[t], axis=1) for t in range(T)]
    return tf.concat(z_means, axis=1)

def recon_random_z(x, gen, enc, pz_param, predictor, P_dimH, dimH, dimZ):
    """Generate sequences that keep the inferred f but draw z from the prior.

    The static code ``f`` is sampled from the encoder's posterior for ``x``.
    The dynamic codes are NOT inferred from ``x``: they are rolled out step by
    step from the recurrent prior ``pz_param``, starting from z = 0.

    Args:
        x: rank-5 input tensor; its static batch size and length T (dims 0
            and 1) size the rollout.
        gen: called on each predictor output; its results are the returned
            frames (presumably the decoder).
        enc: encoder exposing ``qfx(x, T) -> (mu, log_sig, features)``.
        pz_param: recurrent prior, ``(z, h, c) -> (mu, log_sig, h, c)``.
        predictor: recurrent module, ``(inputs, h, c) -> (pred, h, c)``; its
            inputs are ``concat([f, z], axis=1)``.
        P_dimH: state size of the prior's recurrence.
        dimH: state size of the predictor's recurrence.
        dimZ: dimensionality of z.

    Returns:
        ``gen`` outputs for each of the T steps, each expanded at axis 1 and
        concatenated along axis 1.
    """
    batch_size, T, _, _, _ = x.get_shape().as_list()

    # Sample the static code f from its approximate posterior q(f | x).
    mu_qf, log_sig_qf, _ = enc.qfx(x, T)
    f = sample_gaussian(mu_qf, log_sig_qf)

    # Zero initial z and zero recurrent states for the prior and the predictor.
    z = tf.zeros(shape=(batch_size, dimZ))
    h_prior = tf.zeros(shape=(batch_size, P_dimH))
    c_prior = tf.zeros(shape=(batch_size, P_dimH))
    h_pred = tf.zeros(shape=(batch_size, dimH))
    c_pred = tf.zeros(shape=(batch_size, dimH))

    frames = []
    for _ in range(T):
        # Advance the prior one step and sample the next z from it.
        mu, log_sig, h_prior, c_prior = pz_param(z, h_prior, c_prior)
        z = sample_gaussian(mu, log_sig)
        pred, h_pred, c_pred = predictor(tf.concat([f, z], axis=1), h_pred, c_pred)
        # Named `frame` (not `x`) so the input parameter is not shadowed.
        frame = gen(pred)
        frames.append(tf.expand_dims(frame, axis=1))

    return tf.concat(frames, axis=1)

def recon_shuffle_z(x, pxz, enc, predictor, dimH, N=None, seed=0):
    """Reconstruct with each sequence's f paired with another sequence's z.

    ``f`` is sampled from q(f | x). The per-step z samples from ``enc.qzx``
    are permuted across the batch with a seeded random permutation, so row i
    is decoded from its own ``f`` and the z sequence of row ``ind[i]``.

    Args:
        x: rank-5 input tensor; dims 0 and 1 give the static batch size and
            sequence length T.
        pxz: called on each predictor output; its results are returned.
        enc: encoder exposing ``qfx(x, T)`` and ``qzx(features, T)``, each
            returning ``(mu, log_sig, samples_or_features)``.
        predictor: recurrent module, ``(inputs, h, c) -> (pred, h, c)``.
        dimH: state size of the predictor's recurrence.
        N: batch dimension of the predictor's zero initial state. Defaults
            to the static batch size of ``x``, which is what the sibling
            ``recon_*`` functions use; passing it explicitly still works.
        seed: seed for the batch permutation. A private ``RandomState`` is
            used, so NumPy's global RNG is not reseeded; the permutation is
            the same one ``np.random.seed(seed)`` + ``np.random.permutation``
            would produce.

    Returns:
        ``pxz`` outputs for each of the T steps, each expanded at axis 1 and
        concatenated along axis 1.
    """
    batch_size, T, _, _, _ = x.get_shape().as_list()
    if N is None:
        N = batch_size

    # Sample the static code f from q(f | x).
    mu_qf, log_sig_qf, fea_seq = enc.qfx(x, T)
    f = sample_gaussian(mu_qf, log_sig_qf)
    # Per-step z samples from the posterior q(z | x).
    _, _, z_seq = enc.qzx(fea_seq, T)

    # Fixed, reproducible batch permutation used to reassign z sequences.
    ind = np.random.RandomState(seed).permutation(batch_size).astype(np.int32)

    h_pred_prev = tf.zeros(shape=(N, dimH))
    c_pred_prev = tf.zeros(shape=(N, dimH))
    x_recon = []
    for t in range(T):
        z = tf.gather(z_seq[t], ind)
        pred, h_pred_prev, c_pred_prev = predictor(tf.concat([f, z], axis=1), h_pred_prev, c_pred_prev)
        mu_x = pxz(pred)
        x_recon.append(tf.expand_dims(mu_x, axis=1))

    return tf.concat(x_recon, axis=1)

def recon_random_f(x, gen, enc, predictor, dimH):
    """Reconstruct with the inferred z sequence but a random static code f.

    The per-step z samples come from the encoder's posterior for ``x``. The
    static code is replaced by standard-normal noise with the same width as
    the inferred ``f``.

    Args:
        x: rank-5 input tensor; dims 0 and 1 give the static batch size and
            sequence length T.
        gen: called on each predictor output; its results are returned.
        enc: encoder exposing ``qfx(x, T)`` and ``qzx(features, T)``, each
            returning ``(mu, log_sig, samples_or_features)``.
        predictor: recurrent module, ``(inputs, h, c) -> (pred, h, c)``.
        dimH: state size of the predictor's recurrence.

    Returns:
        ``gen`` outputs for each of the T steps, each expanded at axis 1 and
        concatenated along axis 1.
    """
    batch_size, T, _, _, _ = x.get_shape().as_list()

    # The posterior sample of f is only used for its static width below; the
    # sampling op is kept so the random ops created in the graph stay the same.
    mu_qf, log_sig_qf, fea_seq = enc.qfx(x, T)
    f = sample_gaussian(mu_qf, log_sig_qf)

    # Per-step z samples from the posterior q(z | x).
    _, _, z_seq = enc.qzx(fea_seq, T)

    # Replace f with N(0, 1) noise of the same width.
    dimF = f.get_shape().as_list()[1]
    f = tf.random_normal(shape=(batch_size, dimF))

    x_recon = []
    h_pred_prev = tf.zeros(shape=(batch_size, dimH))
    c_pred_prev = tf.zeros(shape=(batch_size, dimH))
    for t in range(T):
        pred, h_pred_prev, c_pred_prev = predictor(tf.concat([f, z_seq[t]], axis=1), h_pred_prev, c_pred_prev)
        mu_x = gen(pred)
        x_recon.append(tf.expand_dims(mu_x, axis=1))

    return tf.concat(x_recon, axis=1)

def recon_shuffle_f(x, gen, enc, predictor, dimH, seed=0):
    """Reconstruct with each sequence's z paired with another sequence's f.

    ``f`` is sampled from q(f | x) and then permuted across the batch with a
    seeded random permutation. Row i is decoded from the ``f`` of row
    ``ind[i]`` and its own per-step z samples from ``enc.qzx``.

    Args:
        x: rank-5 input tensor; dims 0 and 1 give the static batch size and
            sequence length T.
        gen: called on each predictor output; its results are returned.
        enc: encoder exposing ``qfx(x, T)`` and ``qzx(features, T)``, each
            returning ``(mu, log_sig, samples_or_features)``.
        predictor: recurrent module, ``(inputs, h, c) -> (pred, h, c)``.
        dimH: state size of the predictor's recurrence.
        seed: seed for the batch permutation. A private ``RandomState`` is
            used, so NumPy's global RNG is not reseeded; the permutation is
            the same one ``np.random.seed(seed)`` + ``np.random.permutation``
            would produce.

    Returns:
        ``gen`` outputs for each of the T steps, each expanded at axis 1 and
        concatenated along axis 1.
    """
    batch_size, T, _, _, _ = x.get_shape().as_list()

    # Sample the static code f from q(f | x).
    mu_qf, log_sig_qf, fea_seq = enc.qfx(x, T)
    f = sample_gaussian(mu_qf, log_sig_qf)

    # Per-step z samples from the posterior q(z | x).
    _, _, z_seq = enc.qzx(fea_seq, T)

    # Fixed, reproducible batch permutation used to reassign f.
    ind = np.random.RandomState(seed).permutation(batch_size).astype(np.int32)
    f = tf.gather(f, ind)

    x_recon = []
    h_pred_prev = tf.zeros(shape=(batch_size, dimH))
    c_pred_prev = tf.zeros(shape=(batch_size, dimH))
    for t in range(T):
        pred, h_pred_prev, c_pred_prev = predictor(tf.concat([f, z_seq[t]], axis=1), h_pred_prev, c_pred_prev)
        mu_x = gen(pred)
        x_recon.append(tf.expand_dims(mu_x, axis=1))

    return tf.concat(x_recon, axis=1)
