import tensorflow as tf
import numpy as np
from ops import *
from ops_biggan import self_attention_2, fully_conneted_R, resblock_up_condition, conv_R, batch_norm_R, spectral_norm, resblock_down_diff
def sample_gaussian(mu, log_sig):
    """Draw a reparameterised Gaussian sample ``mu + exp(log_sig) * eps``.

    ``eps`` is standard-normal noise generated with ``tf.random_normal``.

    Args:
        mu: Tensor holding the mean.
        log_sig: Tensor holding the log standard deviation; it is
            exponentiated and multiplied element-wise with the noise.

    Returns:
        The sampled tensor.
    """
    static_shape = mu.get_shape()
    # Use the static shape whenever it is complete (identical to the previous
    # behaviour); otherwise fall back to the runtime shape so inputs with an
    # unknown dimension (e.g. a placeholder batch axis) no longer fail.
    if static_shape.is_fully_defined():
        noise_shape = static_shape
    else:
        noise_shape = tf.shape(mu)
    return mu + tf.exp(log_sig) * tf.random_normal(noise_shape)
def lrelu(x, rate=0.1):
    """Leaky ReLU built from element-wise min/max.

    For ``0 <= rate <= 1`` this yields ``x`` where ``x`` is non-negative and
    ``rate * x`` where it is negative.

    Args:
        x: Input tensor.
        rate: Slope applied to the negative part.

    Returns:
        Tensor with the activation applied.
    """
    # Negative inputs keep their scaled value; positive inputs are clamped to 0
    # here so that the final maximum picks x itself.
    negative_part = tf.minimum(x * rate, 0)
    return tf.maximum(negative_part, x)
# Short module-level aliases for the tf.contrib.layers layer constructors.
conv2d = tf.contrib.layers.convolution2d
conv2d_t = tf.contrib.layers.convolution2d_transpose
fc_layer = tf.contrib.layers.fully_connected

def conv2d_bn_lrelu(inputs, num_outputs, kernel_size, stride):
    """Apply a 2-D convolution, then batch norm, then ``lrelu``.

    The convolution uses a normal(stddev=0.02) weight initialiser, an L2
    weight regulariser (scale 2.5e-5) and no built-in activation.

    Args:
        inputs: Input tensor passed to ``tf.contrib.layers.convolution2d``.
        num_outputs: Number of output filters.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.

    Returns:
        The activated tensor.
    """
    layers = tf.contrib.layers
    pre_activation = layers.convolution2d(
        inputs, num_outputs, kernel_size, stride,
        weights_initializer=tf.random_normal_initializer(stddev=0.02),
        weights_regularizer=layers.l2_regularizer(2.5e-5),
        activation_fn=tf.identity)
    return lrelu(layers.batch_norm(pre_activation))

def fc_bn_lrelu(inputs, num_outputs):
    """Apply a fully connected layer, then batch norm, then ``lrelu``.

    The dense layer uses a normal(stddev=0.02) weight initialiser, an L2
    weight regulariser (scale 2.5e-5) and no built-in activation.

    Args:
        inputs: Input tensor passed to ``tf.contrib.layers.fully_connected``.
        num_outputs: Number of output units.

    Returns:
        The activated tensor.
    """
    contrib_layers = tf.contrib.layers
    dense_out = contrib_layers.fully_connected(
        inputs, num_outputs,
        weights_initializer=tf.random_normal_initializer(stddev=0.02),
        weights_regularizer=contrib_layers.l2_regularizer(2.5e-5),
        activation_fn=tf.identity)
    normalized = contrib_layers.batch_norm(dense_out)
    return lrelu(normalized)

def fc_bn_relu(inputs, num_outputs):
    """Apply a fully connected layer, then batch norm, then ``tf.nn.relu``.

    The dense layer uses a normal(stddev=0.02) weight initialiser, an L2
    weight regulariser (scale 2.5e-5) and no built-in activation.

    Args:
        inputs: Input tensor passed to ``tf.contrib.layers.fully_connected``.
        num_outputs: Number of output units.

    Returns:
        The activated tensor.
    """
    contrib_layers = tf.contrib.layers
    linear = contrib_layers.fully_connected(
        inputs, num_outputs,
        weights_initializer=tf.random_normal_initializer(stddev=0.02),
        weights_regularizer=contrib_layers.l2_regularizer(2.5e-5),
        activation_fn=tf.identity)
    return tf.nn.relu(contrib_layers.batch_norm(linear))

def conv2d_t_bn(inputs, num_outputs, kernel_size, stride):
    """Apply a transposed 2-D convolution followed by batch norm.

    No non-linearity is applied. The transposed convolution uses a
    normal(stddev=0.02) weight initialiser and an L2 weight regulariser
    (scale 2.5e-5).

    Args:
        inputs: Input tensor passed to
            ``tf.contrib.layers.convolution2d_transpose``.
        num_outputs: Number of output filters.
        kernel_size: Kernel size.
        stride: Stride.

    Returns:
        The batch-normalised tensor.
    """
    layers = tf.contrib.layers
    upsampled = layers.convolution2d_transpose(
        inputs, num_outputs, kernel_size, stride,
        weights_initializer=tf.random_normal_initializer(stddev=0.02),
        weights_regularizer=layers.l2_regularizer(2.5e-5),
        activation_fn=tf.identity)
    return layers.batch_norm(upsampled)

def conv2d_t_bn_relu(inputs, num_outputs, kernel_size, stride):
    """Apply a transposed 2-D convolution, then batch norm, then ``lrelu``.

    NOTE(review): despite the ``_relu`` suffix, the activation applied here is
    the leaky ``lrelu`` helper, not ``tf.nn.relu`` -- confirm this is intended.

    The transposed convolution uses a normal(stddev=0.02) weight initialiser
    and an L2 weight regulariser (scale 2.5e-5).

    Args:
        inputs: Input tensor passed to
            ``tf.contrib.layers.convolution2d_transpose``.
        num_outputs: Number of output filters.
        kernel_size: Kernel size.
        stride: Stride.

    Returns:
        The activated tensor.
    """
    layers = tf.contrib.layers
    upsampled = layers.convolution2d_transpose(
        inputs, num_outputs, kernel_size, stride,
        weights_initializer=tf.random_normal_initializer(stddev=0.02),
        weights_regularizer=layers.l2_regularizer(2.5e-5),
        activation_fn=tf.identity)
    normalized = layers.batch_norm(upsampled)
    return lrelu(normalized)
''' Motion prior'''
class prior(object):
    """Prior network over the motion latent ``z``.

    Calling the object maps the previous latent and recurrent state to the
    mean / log-sigma of the next latent plus the updated recurrent state.
    All variables live in the ``'prior'`` variable scope with
    ``tf.AUTO_REUSE``, so repeated calls share weights.
    """

    def __init__(self, dimZ, dimH, noisy_scale = 0.1):
        """Store the layer sizes.

        Args:
            dimZ: Width of the latent ``z`` (also of ``mu`` / ``log_sig``).
            dimH: Width of the hidden and cell state.
            noisy_scale: Stored on the instance; not read by this class.
        """
        self.dimZ = dimZ
        self.dimH = dimH
        self.noisy_scale = noisy_scale
        self.name = 'prior'

    def __call__(self, z_prev, h_prev=None, c_prev=None):
        """Run one prior step.

        Args:
            z_prev: Previous latent, or ``None`` to use zeros (the batch size
                is then read from ``h_prev``'s static shape).
            h_prev: Previous hidden state, or ``None`` to use zeros.
            c_prev: Previous cell state, or ``None`` to use zeros.

        Returns:
            Tuple ``(mu, log_sig, h, c)``.

        Raises:
            ValueError: If both ``z_prev`` and ``h_prev`` are ``None``, since
                no batch size can be inferred.
        """
        if z_prev is None and h_prev is None:
            raise ValueError('prior needs z_prev or h_prev to infer the batch size')
        if z_prev is None:
            z_prev = tf.zeros(shape=(h_prev.get_shape().as_list()[0], self.dimZ))
        # The signature advertises None defaults for the recurrent state, so
        # honour them with zero states instead of failing inside tf.concat or
        # the cell update.
        if h_prev is None or c_prev is None:
            batch_size = z_prev.get_shape().as_list()[0]
            if h_prev is None:
                h_prev = tf.zeros(shape=(batch_size, self.dimH))
            if c_prev is None:
                c_prev = tf.zeros(shape=(batch_size, self.dimH))
        with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
            embedded = tf.layers.dense(tf.concat([z_prev, h_prev], 1), units =self.dimH, use_bias=True, name='pr_in1')
            embedded = tf.nn.leaky_relu(embedded)
            embedded = tf.layers.dense(embedded, units=self.dimH, use_bias=True, name='pr_in2')
            embedded = tf.nn.leaky_relu(embedded)

            # First recurrent step.
            out = tf.layers.dense(embedded, units=4 * self.dimH, use_bias=True, name='pr_in3')
            f, i, c, o = tf.split(out, 4, axis=1)
            c_prev1 = tf.nn.sigmoid(i) * tf.nn.tanh(c) + tf.nn.sigmoid(f) * c_prev
            # The hidden state is taken from the *updated* cell state, matching
            # the second step below (previously the raw candidate `c` was used).
            h_prev1 = tf.nn.sigmoid(o) * tf.nn.tanh(c_prev1)

            # Second recurrent step.
            # NOTE(review): this reuses the layer name 'pr_in3', so under
            # AUTO_REUSE both steps share one set of weights -- confirm intended.
            out = tf.layers.dense(h_prev1, units=4 * self.dimH, use_bias=True, name='pr_in3')
            f, i, c, o = tf.split(out, 4, axis=1)
            c = tf.nn.sigmoid(i) * tf.nn.tanh(c) + tf.nn.sigmoid(f) * c_prev1
            h = tf.nn.sigmoid(o) * tf.nn.tanh(c)

            mu = tf.layers.dense(h, units =self.dimZ, use_bias=True, name='pr_mu')
            log_sig = tf.layers.dense(h, units=self.dimZ, use_bias=True, name='pr_sig')

            return mu, log_sig, h, c

    @property
    def vars(self):
        """Global variables whose name contains ``self.name`` (substring match)."""
        return [var for var in tf.global_variables() if self.name in var.name]
'''After concatenating the motion and content codes from the encoder, an LSTM is applied before the decoder.'''
class predictor(object):
    """Single LSTM-like step applied to its input before decoding.

    Variables are created in the ``'predictor'`` variable scope with
    ``tf.AUTO_REUSE``.
    """

    def __init__(self, N, dimH):
        """Remember the batch size ``N`` and the state width ``dimH``."""
        self.dimH = dimH
        self.N = N
        self.name = 'predictor'

    def __call__(self, inputs, h_prev=None, c_prev=None):
        """Advance the recurrent state by one step.

        Args:
            inputs: Tensor concatenated with ``h_prev`` along axis 1.
            h_prev: Previous hidden state; zeros of shape ``(N, dimH)`` if None.
            c_prev: Previous cell state; zeros of shape ``(N, dimH)`` if None.

        Returns:
            Tuple ``(output, h, c)``: the tanh-activated projection of ``h``,
            the new hidden state and the new cell state.
        """
        with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
            state_shape = (self.N, self.dimH)
            h_prev = tf.zeros(shape=state_shape) if h_prev is None else h_prev
            c_prev = tf.zeros(shape=state_shape) if c_prev is None else c_prev

            joined = tf.concat([inputs, h_prev], 1)
            hidden = tf.nn.tanh(
                tf.layers.dense(joined, units=self.dimH, use_bias=True, name='pred_inp'))

            gates = tf.layers.dense(hidden, units=4 * self.dimH, use_bias=True, name='pred')
            forget_gate, input_gate, candidate, output_gate = tf.split(gates, 4, axis=1)
            cell = (tf.nn.sigmoid(input_gate) * tf.nn.tanh(candidate)
                    + tf.nn.sigmoid(forget_gate) * c_prev)
            hidden_state = tf.nn.sigmoid(output_gate) * tf.nn.tanh(cell)

            projected = tf.layers.dense(hidden_state, units=self.dimH, use_bias=True, name='pred_out')
            return tf.nn.tanh(projected), hidden_state, cell

    @property
    def vars(self):
        """Global variables whose name contains ``self.name`` (substring match)."""
        return [v for v in tf.global_variables() if self.name in v.name]

def deconv_R(x, channels, kernel=4, stride=2, padding='SAME', use_bias=True, sn=False, scope='deconv_0'):
    """Transposed 2-D convolution, optionally with a spectrally normalised kernel.

    Args:
        x: Input tensor; indexed as (batch, height, width, channels).
        channels: Number of output channels.
        kernel: Square kernel size.
        stride: Stride used for both spatial dimensions.
        padding: ``'SAME'`` or any other padding string accepted by TensorFlow;
            it selects how the explicit output shape is computed when ``sn``
            is set.
        use_bias: Whether a bias is added.
        sn: If True, build the kernel by hand and wrap it in ``spectral_norm``;
            otherwise delegate to ``tf.layers.conv2d_transpose``.
        scope: Variable scope name.

    Returns:
        The transposed-convolution output tensor.

    Note:
        ``weight_init`` and ``weight_regularizer`` are not defined in this
        file; they are presumably provided by ``from ops import *``.
    """
    with tf.variable_scope(scope):
        if not sn:
            # The layers API infers the output shape itself, so no static
            # shape arithmetic is needed on this path.
            return tf.layers.conv2d_transpose(inputs=x, filters=channels,
                                              kernel_size=kernel, kernel_initializer=weight_init, kernel_regularizer=weight_regularizer,
                                              strides=stride, padding=padding, use_bias=use_bias)

        x_shape = x.get_shape().as_list()
        # tf.nn.conv2d_transpose needs a concrete output shape; fall back to
        # the runtime batch size when the static one is unknown.
        batch = x_shape[0] if x_shape[0] is not None else tf.shape(x)[0]

        if padding == 'SAME':
            output_shape = [batch, x_shape[1] * stride, x_shape[2] * stride, channels]
        else:
            extra = max(kernel - stride, 0)
            output_shape = [batch, x_shape[1] * stride + extra, x_shape[2] * stride + extra, channels]

        # Kernel layout for conv2d_transpose: [height, width, out_channels, in_channels].
        w = tf.get_variable("kernel", shape=[kernel, kernel, channels, x_shape[-1]], initializer=weight_init, regularizer=weight_regularizer)
        x = tf.nn.conv2d_transpose(x, filter=spectral_norm(w), output_shape=output_shape, strides=[1, stride, stride, 1], padding=padding)

        if use_bias:
            bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
            x = tf.nn.bias_add(x, bias)

        return x

class decoder(object):
    """BigGAN-style conditional generator that maps a code vector to an image.

    The input vector is split into six chunks: the first seeds a 2x2 feature
    map through a dense layer and the remaining five condition successive
    ``resblock_up_condition`` blocks. Variables live in the ``'dec'`` scope
    with ``tf.AUTO_REUSE``.
    """

    def __init__(self, batch_size, dimZ, dimF, dimH):
        """Store sizes and fixed architecture settings.

        Args:
            batch_size: Stored on the instance.
            dimZ: Stored on the instance.
            dimF: Stored on the instance.
            dimH: Width of the vector passed to ``__call__``; used to compute
                the split sizes.
        """
        self.batch_size = batch_size
        self.dimZ = dimZ
        self.dimH = dimH
        self.dimF = dimF
        # NOTE(review): the encoder uses [64, 64, 1]; 46 looks like a typo for
        # 64 but only data_dims[0] and data_dims[-1] are read below -- confirm.
        self.data_dims = [64, 46, 1]
        self.ch = 64
        self.sn = True
        self.ladder_dims = ladder_dims
        self.cs = [self.data_dims[-1], 64, 128, 256, 512, 1024]
        self.fs = [self.data_dims[0], self.data_dims[0] // 2, self.data_dims[0] // 4, self.data_dims[0] // 8,
                   self.data_dims[0] // 16]
        self.name = 'dec'

    @staticmethod
    def _latent_split_sizes(total, parts=6):
        """Return ``parts`` chunk sizes summing to ``total``; the last takes the remainder."""
        base = total // parts
        return [base] * (parts - 1) + [total - base * (parts - 1)]

    def __call__(self, inputs, is_training=True):
        """Decode ``inputs`` into an image tensor with values in (0, 1).

        Args:
            inputs: Code tensor split along its last axis into six chunks whose
                sizes are derived from ``self.dimH`` (so the last axis is
                assumed to have ``self.dimH`` entries). Per the original note,
                it combines motion (z) and content (f) information.
            is_training: Forwarded to the residual blocks and batch norm.

        Returns:
            Single-channel tensor after a sigmoid.
        """
        with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
            # The previous remainder branch produced sizes that summed to
            # dimH - dimH // 6, which tf.split rejects; the last chunk now
            # absorbs the remainder so the sizes always sum to dimH.
            z_split = tf.split(inputs, num_or_size_splits=self._latent_split_sizes(self.dimH), axis=-1)
            ch = 32 * self.ch  # 2048 with self.ch == 64

            x = fully_conneted_R(z_split[0], units=2 * 2 * ch, sn=self.sn, scope='dense')
            x = tf.reshape(x, shape=[-1, 2, 2, ch])  # 2x2 spatial seed
            ch = ch // 2  # 1024

            # Each resblock_up_condition presumably doubles the spatial size
            # (2 -> 4 -> 8 -> 16 -> 32 -> 64) -- TODO confirm in ops_biggan.
            x = resblock_up_condition(x, z_split[1], channels=ch, use_bias=False, is_training=is_training, sn=self.sn,
                                      scope='resblock_up_16')
            ch = ch // 2  # 512

            x = resblock_up_condition(x, z_split[2], channels=ch, use_bias=False, is_training=is_training, sn=self.sn,
                                      scope='resblock_up_8')
            ch = ch // 2  # 256

            x = resblock_up_condition(x, z_split[3], channels=ch, use_bias=False, is_training=is_training, sn=self.sn,
                                      scope='resblock_up_4')
            ch = ch // 2  # 128

            x = resblock_up_condition(x, z_split[4], channels=ch, use_bias=False, is_training=is_training, sn=self.sn,
                                      scope='resblock_up_2')

            # Non-local (self-attention) block at the 128-channel stage.
            x = self_attention_2(x, channels=ch, sn=self.sn, scope='self_attention')
            ch = ch // 2  # 64

            x = resblock_up_condition(x, z_split[5], channels=ch, use_bias=False, is_training=is_training, sn=self.sn,
                                      scope='resblock_up_1')

            x = batch_norm_R(x, is_training)
            x = tf.nn.relu(x)
            x = conv_R(x, channels=1, kernel=3, stride=1, pad=1, use_bias=False, sn=self.sn, scope='G_logit')

            x = tf.sigmoid(x)
            return x

    @property
    def vars(self):
        """Global variables whose name contains ``self.name`` (substring match)."""
        return [var for var in tf.global_variables() if self.name in var.name]

# Output width of each of the five ladder-level projections in the encoder.
ladder_dims = [6, 6, 6, 6, 6]
# Equals sum(ladder_dims); not referenced elsewhere in the visible part of this file.
latent_dims = 30
# Not referenced elsewhere in the visible part of this file.
ker_size = 3

class encoder(object):
    """Inference network: per-frame ladder features, q(f|x) and q(z|x).

    All variables are created under the ``'enc'`` variable scope with
    ``tf.AUTO_REUSE``.
    """

    def __init__(self, batch_size, dimZ, dimF, dimH):
        """Store sizes and fixed architecture settings.

        Args:
            batch_size: Stored on the instance.
            dimZ: Width of the motion latent ``z``.
            dimF: Width of the content latent ``f``.
            dimH: Width of the recurrent hidden/cell state in ``qfx``.
        """
        self.batch_size = batch_size
        self.dimZ = dimZ
        self.dimH = dimH
        self.dimF = dimF
        self.data_dims=[64, 64, 1]
        self.sn = True
        self.ch = 64
        self.ladder_dims = ladder_dims
        self.name = 'enc'

    def feature(self, inputs, is_training=True):
        """Extract the concatenated ladder features of one frame.

        Five ``resblock_down_diff`` stages are applied in sequence; after each
        one the activation is flattened and projected to
        ``self.ladder_dims[k]`` units by a linear fully connected layer.

        NOTE(review): the fully_connected calls pass no explicit scope, so
        their variable-scope names are auto-generated and depend on call
        order. Keep the statement order, and confirm whether weights are
        actually shared across repeated calls to this method.

        Args:
            inputs: Image tensor for a single frame.
            is_training: Forwarded to the residual blocks.

        Returns:
            Tensor of shape ``(batch, sum(self.ladder_dims))``.
        """
        with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
            ch =2*self.ch  # 128 with self.ch == 64
            # Each resblock_down_diff presumably halves the spatial size -- TODO confirm.
            iconv2 = resblock_down_diff(inputs, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_1')
            iconv2 = self_attention_2(iconv2, channels=ch, sn=self.sn, scope='self_attention')
            print('y0 shape', iconv2.get_shape())
            ladder0 = tf.reshape(iconv2, [-1, np.prod(iconv2.get_shape().as_list()[1:])])
            print('l0 shape', ladder0.get_shape())
            ladder0_mean = tf.contrib.layers.fully_connected(ladder0, self.ladder_dims[0], activation_fn=tf.identity)
            ch = ch * 2  # 256
            iconv4 = resblock_down_diff(iconv2, channels=ch, use_bias=False, is_training=is_training, sn=self.sn,
                              scope='resblock_down_2')
            print('y1 shape', iconv4.get_shape())
            ch = ch * 2  # 512
            ladder1 = tf.reshape(iconv4, [-1, np.prod(iconv4.get_shape().as_list()[1:])])
            print('l1 shape', ladder1.get_shape())
            ladder1_mean = tf.contrib.layers.fully_connected(ladder1, self.ladder_dims[1], activation_fn=tf.identity)

            iconv6 = resblock_down_diff(iconv4, channels=ch, use_bias=False, is_training=is_training, sn=self.sn,
                              scope='resblock_down_4')
            print('y2 shape', iconv6.get_shape())
            ch = ch * 2  # 1024
            ladder2 = tf.reshape(iconv6, [-1, np.prod(iconv6.get_shape().as_list()[1:])])
            print('l2 shape', ladder2.get_shape())
            ladder2_mean = tf.contrib.layers.fully_connected(ladder2, self.ladder_dims[2], activation_fn=tf.identity)
            iconv7 = resblock_down_diff(iconv6, channels=ch, use_bias=False, is_training=is_training, sn=self.sn,
                              scope='resblock_down_8')
            print('y3 shape', iconv7.get_shape())
            ladder3 = tf.reshape(iconv7, [-1, np.prod(iconv7.get_shape().as_list()[1:])])
            print('l3 shape', ladder3.get_shape())
            ladder3_mean = tf.contrib.layers.fully_connected(ladder3, self.ladder_dims[3], activation_fn=tf.identity)
            ch = ch * 2  # 2048

            iconv8 = resblock_down_diff(iconv7, channels=ch, use_bias=False, is_training=is_training, sn=self.sn,
                                   scope='resblock_down_16')
            print('y4 shape', iconv8.get_shape())

            ladder4 = tf.reshape(iconv8, [-1, np.prod(iconv8.get_shape().as_list()[1:])])
            print('l4 shape', ladder4.get_shape())
            ladder4_mean = tf.contrib.layers.fully_connected(ladder4, self.ladder_dims[4], activation_fn=tf.identity)

            latent_mean = tf.concat([ladder0_mean, ladder1_mean, ladder2_mean, ladder3_mean, ladder4_mean], 1)
            return latent_mean

    def qfx(self, x_seq, N_frames, h_prev = None, c_prev = None):
        """Define q(f|x), where f is the content variable.

        Runs an LSTM-like recurrence over the features of the first
        ``N_frames`` frames and projects the final hidden state to the
        parameters of ``f``.

        Args:
            x_seq: Frame sequence indexed as ``x_seq[:, t]``; the static size
                of axis 0 is used as the batch size for the default states.
            N_frames: Number of leading frames to process; must be >= 1.
            h_prev: Initial hidden state; zeros of shape ``(N, dimH)`` if None.
            c_prev: Initial cell state; zeros of shape ``(N, dimH)`` if None.

        Returns:
            Tuple ``(mu, log_sig, fea_list)`` where ``fea_list`` holds the
            per-frame feature tensors returned by ``feature``.

        Raises:
            ValueError: If ``N_frames`` is smaller than 1 (there would be no
                hidden state to read the posterior from).
        """
        if N_frames < 1:
            raise ValueError('qfx needs N_frames >= 1, got %r' % (N_frames,))
        with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
            N = x_seq.get_shape().as_list()[0]
            fea_list = []
            if h_prev is None:
                h_prev = tf.zeros(shape=(N, self.dimH))
            if c_prev is None:
                c_prev = tf.zeros(shape=(N, self.dimH))
            for t in range(N_frames):
                # feature() opens the self.name scope again inside this one, so
                # its variables are presumably created under a nested scope;
                # confirm before renaming anything (checkpoint compatibility).
                fea = self.feature(x_seq[:, t])
                fea_list.append(fea)
                # Recurrent cell over the frame features.
                embedded = tf.layers.dense(tf.concat([fea, h_prev], 1), units=self.dimH, use_bias=True,
                                           name='en_fea_f1')
                embedded = tf.nn.tanh(embedded)
                embedded = tf.layers.dense(embedded, units=self.dimH, use_bias=True, name='en_fea_f2')
                embedded = tf.nn.tanh(embedded)
                embedded = tf.layers.dense(embedded, units=self.dimH, use_bias=True, name='en_fea_f3')
                embedded = tf.nn.tanh(embedded)

                out = tf.layers.dense(embedded, units=4 * self.dimH, use_bias=True, name='en_lstm_f1')
                f, i, c, o = tf.split(out, 4, axis=1)
                c = tf.nn.sigmoid(i) * tf.nn.tanh(c) + tf.nn.sigmoid(f) * c_prev
                h = tf.nn.sigmoid(o) * tf.nn.tanh(c)
                h_prev = h
                c_prev = c
            # Posterior parameters of f are read from the last hidden state.
            mu = tf.layers.dense(h, units=self.dimF, use_bias=True, name='en_f_mu')
            log_sig = tf.layers.dense(h, units=self.dimF, use_bias=True, name='en_f_sig')

            return mu, log_sig, fea_list

    def qzx(self, fea_seq, N_frames):
        """Define q(z|x), where z is the motion variable.

        Each of the first ``N_frames`` feature tensors is passed through the
        same three-layer tanh MLP and projected to ``mu`` / ``log_sig``; a
        sample is drawn with ``sample_gaussian``.

        Args:
            fea_seq: Sequence of per-frame feature tensors, indexed by ``t``.
            N_frames: Number of leading frames to process.

        Returns:
            Tuple ``(mu_list, log_sig_list, z_list)`` of per-frame lists.
        """
        mu_list = []
        log_sig_list = []
        z_list = []
        with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
            for t in range(N_frames):
                f_op1 = tf.layers.dense(fea_seq[t], units=2 * self.dimZ, use_bias=True, name='en_z1')
                f_op1 = tf.nn.tanh(f_op1)
                f_op1 = tf.layers.dense(f_op1, units=2*self.dimZ, use_bias=True, name='en_z2')
                f_op1 = tf.nn.tanh(f_op1)
                f_op1 = tf.layers.dense(f_op1, units=self.dimZ, use_bias=True, name='en_z3')
                f_op1 = tf.nn.tanh(f_op1)
                mu = tf.layers.dense(f_op1, units=self.dimZ, use_bias=True, name='en_z_mu')
                log_sig = tf.layers.dense(f_op1, units=self.dimZ, use_bias=True, name='en_z_sig')
                mu_list.append(mu)
                log_sig_list.append(log_sig)
                z_list.append(sample_gaussian(mu, log_sig))
            return mu_list, log_sig_list, z_list

    @property
    def vars(self):
        """Global variables whose name contains ``self.name`` (substring match)."""
        return [var for var in tf.global_variables() if self.name in var.name]