from inits import *
from flip_gradient import flip_gradient
import utils


class BaseModel(object):
    def __init__(self, num_domains, batch_size, num_lab_each_epoch, mu=0):
        """Store the hyper-parameters, create the row-index variables and build the graph.

        Args:
            num_domains: number of per-domain branches built by `_build_model`.
            batch_size: number of rows in each per-domain training sub-batch.
                Must be even.
            num_lab_each_epoch: number of rows, beyond the first
                `batch_size // 2`, that are treated as labelled in a
                sub-batch. Must be smaller than `batch_size // 2`.
            mu: base of the per-domain loss weight
                `mu ** (num_domains - i - 1)` used in `_build_model`.

        Raises:
            ValueError: if `batch_size` is odd or no unlabelled row would
                remain in a sub-batch.
        """
        half_batch = batch_size // 2
        # `unlab_index` holds batch_size - half_batch - num_lab rows while
        # `one_label_tensor` builds half_batch - num_lab rows; the two are
        # concatenated column-wise, so they only agree for an even batch size.
        if batch_size % 2 != 0:
            raise ValueError(
                "batch_size must be even, got {}".format(batch_size))
        # With no unlabelled rows left the unlabelled branch cannot be built.
        if num_lab_each_epoch >= half_batch:
            raise ValueError(
                "num_lab_each_epoch ({}) must be smaller than batch_size // 2 ({})".format(
                    num_lab_each_epoch, half_batch))

        self.num_domains = num_domains
        self.batch_size = batch_size
        self.dim_z = 256  # width of the latent code z
        self.dim_y = 10   # number of label classes
        self.mu = mu
        self.num_lab_each_epoch = num_lab_each_epoch

        # Rows [0, n_labelled) of a sub-batch are labelled, the rest are not.
        n_labelled = half_batch + num_lab_each_epoch
        id_index = np.arange(n_labelled)
        unlab_id_index = np.arange(n_labelled, batch_size)
        self.lab_index = tf.Variable(tf.constant(id_index), trainable=False)
        self.unlab_index = tf.Variable(tf.constant(unlab_id_index), trainable=False)

        self._build_model()

    def _build_model(self):
        """Create the placeholders, the per-domain branches, the total loss and the train op.

        For every domain index `i` a sub-batch is gathered from `self.X` /
        `self.y` (only when `self.train` is fed True; otherwise the whole
        input is used), encoded, and passed through `label_informed_DA`.
        The per-domain VAE, prediction and domain-discriminator losses are
        summed and combined with weights `mu ** (num_domains - i - 1)`.

        Sets `self.total_loss`, `self.train_op`, `self.label_acc` and
        `self.pred` in addition to the placeholders.
        """
        self.X = tf.placeholder(tf.float32, [None, 32, 32, 1])
        self.y = tf.placeholder(tf.float32, [None, self.dim_y])
        self.domain = tf.placeholder(tf.float32, [None, 2])
        self.l = tf.placeholder(tf.float32, [])  # forwarded to flip_gradient
        self.train = tf.placeholder(tf.bool, [])
        self.learning_rate = tf.placeholder(tf.float32, [])
        self.initializer = tf.contrib.layers.xavier_initializer(seed=2019)

        # Row indices are computed with integer arithmetic so they are exact
        # and never rely on truncating float ranges.
        half_batch = self.batch_size // 2
        # Rows that follow all `num_domains` leading half-batches; every
        # sub-batch shares them (presumably the target-domain rows).
        shared_rows = np.arange(half_batch * self.num_domains,
                                half_batch * (self.num_domains + 1))

        all_vae_loss, all_pred_loss, all_latent_z = [], [], []
        for i in range(self.num_domains):
            domain_rows = np.arange(half_batch * i, half_batch * (i + 1))
            id_index = np.concatenate([domain_rows, shared_rows]).astype(int)

            # Training uses the gathered sub-batch; evaluation uses the input as fed.
            all_X = lambda: self.X
            sub_X = lambda: tf.gather(self.X, id_index)
            X_source = tf.cond(self.train, sub_X, all_X)

            all_Y = lambda: self.y
            sub_Y = lambda: tf.gather(self.y, id_index)
            Y_source = tf.cond(self.train, sub_Y, all_Y)

            feature = self.encoder(X_source)
            vae_loss, pred_loss, z, label_pred = self.label_informed_DA(feature, Y_source)
            all_vae_loss.append(vae_loss)
            all_pred_loss.append(pred_loss)
            all_latent_z.append(z)

        all_d_logits = self.domain_discriminator(all_latent_z)
        all_domain_loss = [
            tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=all_d_logits[i],
                                                                   labels=self.domain))
            for i in range(self.num_domains)
        ]
        all_vae_loss = tf.stack(all_vae_loss)
        all_pred_loss = tf.stack(all_pred_loss)
        all_domain_loss = tf.stack(all_domain_loss)
        all_loss = all_vae_loss + all_pred_loss + all_domain_loss

        # Later domains get larger weights; with mu == 0 only the last domain
        # contributes because 0 ** 0 == 1 in Python.
        self.total_loss = 0.
        for i in range(self.num_domains):
            self.total_loss += (self.mu**(self.num_domains-i-1)) * all_loss[i]

        self.train_op = tf.train.MomentumOptimizer(self.learning_rate, 0.9).minimize(self.total_loss)

        # `label_pred` is the value left over from the final loop iteration,
        # i.e. the prediction of the last domain branch.
        # NOTE(review): it is compared against the full `self.y`, which only
        # lines up row-for-row when `self.train` is fed False - confirm callers
        # fetch `label_acc` in evaluation mode only.
        correct_label_pred = tf.equal(tf.argmax(self.y, 1), tf.argmax(label_pred, 1))
        self.label_acc = tf.reduce_mean(tf.cast(correct_label_pred, tf.float32))
        self.pred = label_pred

    def encoder(self, x):
        """Encode an image batch into `dim_y` un-activated output units.

        Three 5x5 convolutions (the first two followed by instance
        normalisation) feed three dense layers. All variables are created
        under the "encoder" scope with `AUTO_REUSE`.

        Args:
            x: image batch; the shape comments below assume the 32x32x1
                inputs declared by the `self.X` placeholder.

        Returns:
            Tensor with `self.dim_y` columns (no activation on the last layer).
        """
        def conv5x5(inputs, filters, stride):
            # Shared settings of every convolution in the encoder.
            return tf.layers.conv2d(inputs, filters, [5, 5], strides=stride, padding='SAME',
                                    kernel_initializer=self.initializer,
                                    activation=self.leaky_relu)

        with tf.variable_scope("encoder", reuse=tf.AUTO_REUSE):
            # 32x32x1 --> 16x16x64
            net = self.instance_normalization(conv5x5(x, 64, 2), "conv1")

            # 16x16x64 --> 8x8x64
            net = self.instance_normalization(conv5x5(net, 64, 2), "conv2")

            # 8x8x64 --> 8x8x128 (no normalisation after the last convolution)
            net = conv5x5(net, 128, 1)

            # Flatten, then two hidden dense layers and a linear output layer.
            net = tf.reshape(net, [-1, 8*8*128])
            dense_specs = ((1024, self.leaky_relu),
                           (1024, self.leaky_relu),
                           (self.dim_y, None))
            for width, act in dense_specs:
                net = tf.layers.dense(inputs=net,
                                      units=width,
                                      activation=act,
                                      kernel_initializer=self.initializer)
        return net

    def label_informed_DA(self, feat, labels):
        """Build the VAE losses, the classifier loss and the latent code for one sub-batch.

        When `self.train` is True only the rows in `self.lab_index` are
        treated as labelled; otherwise every row is. The rows in
        `self.unlab_index` are always gathered as the unlabelled set.

        Args:
            feat: encoder output for the sub-batch (rank 2; it is concatenated
                with labels along axis 1).
            labels: label rows aligned with `feat`.

        Returns:
            Tuple `(vae_loss, pred_loss, all_z, label_pred)`:
            `vae_loss` is 0.1 times the sum of the labelled and unlabelled
            VAE losses, `pred_loss` the mean softmax cross-entropy on the
            labelled rows, `all_z` a latent sample for all rows of `feat`
            (true labels for labelled rows, classifier probabilities for the
            unlabelled ones) and `label_pred` the class probabilities of the
            labelled rows.
        """
        all_features = lambda: feat
        source_features = lambda: tf.gather(feat, self.lab_index)
        classify_feats = tf.cond(self.train, source_features, all_features)

        all_labels = lambda: labels
        source_labels = lambda: tf.gather(labels, self.lab_index)
        classify_labels = tf.cond(self.train, source_labels, all_labels)

        # ---- Labelled data points ----
        y_lab_logits = self.label_discriminator(classify_feats)
        z_lab, z_lab_mu, z_lab_sigma = self._generate_zxy(classify_feats, classify_labels)
        x_recon_lab = self._generate_xzy(z_lab, classify_labels)
        y_recon_lab = self._generate_yz(z_lab)
        # The utils helpers are opaque here; presumably Gaussian log-density,
        # prior and entropy terms of the variational bound - verify in utils.
        lab_vae_loss = - tf.reduce_mean(utils.tf_normal_logpdf(classify_feats, x_recon_lab[0], x_recon_lab[1]), axis=1) \
                       - tf.reduce_mean(utils.tf_normal_logpdf(classify_labels, y_recon_lab[0], y_recon_lab[1]), axis=1)\
                       - tf.reduce_mean(utils.tf_gaussian_marg(z_lab_mu, z_lab_sigma), axis=1)\
                       + tf.reduce_mean(utils.tf_gaussian_ent(z_lab_sigma), axis=1)
        lab_vae_loss = tf.reduce_mean(lab_vae_loss)

        # ---- Unlabelled data points ----
        unlabelled_x = tf.gather(feat, self.unlab_index)
        y_ulab_logits = self.label_discriminator(unlabelled_x, reuse=True)

        # One loss column per candidate label: every unlabelled row is
        # evaluated as if it carried that label.
        Loss_ulab = []
        for label in range(self.dim_y):
            _y_ulab = self.one_label_tensor(label)
            z_ulab, z_ulab_mu, z_ulab_sigma = self._generate_zxy(unlabelled_x, _y_ulab, reuse=True)
            x_recon_ulab = self._generate_xzy(z_ulab, _y_ulab, reuse=True)
            y_recon_ulab = self._generate_yz(z_ulab, reuse=True)
            _L_ulab = - tf.reduce_mean(utils.tf_normal_logpdf(unlabelled_x, x_recon_ulab[0], x_recon_ulab[1]), axis=1, keepdims=True) \
                      - tf.reduce_mean(utils.tf_normal_logpdf(_y_ulab, y_recon_ulab[0], y_recon_ulab[1]), axis=1, keepdims=True) \
                      - tf.reduce_mean(utils.tf_gaussian_marg(z_ulab_mu, z_ulab_sigma), axis=1, keepdims=True) \
                      + tf.reduce_mean(utils.tf_gaussian_ent(z_ulab_sigma), axis=1, keepdims=True)
            Loss_ulab.append(_L_ulab)

        Loss_ulab = tf.concat(Loss_ulab, axis=1)
        y_ulab = tf.nn.softmax(y_ulab_logits)
        # log_softmax is computed from the logits directly, so a probability
        # that underflows to 0 contributes 0 instead of 0 * log(0) = NaN.
        log_y_ulab = tf.nn.log_softmax(y_ulab_logits)
        # Expected per-label loss under the classifier, minus its entropy.
        ulab_vae_loss = tf.reduce_sum(tf.multiply(y_ulab, tf.subtract(Loss_ulab, log_y_ulab)), 1)

        # The classifier probabilities double as pseudo-labels for the
        # unlabelled rows; reuse `y_ulab` rather than running the classifier
        # a second time on the same input.
        psudo_all_y = tf.concat([classify_labels, y_ulab], axis=0)
        all_z, _, _ = self._generate_zxy(feat, psudo_all_y, reuse=True)

        vae_loss = 0.1 * (lab_vae_loss + tf.reduce_mean(ulab_vae_loss))
        label_pred = tf.nn.softmax(y_lab_logits, axis=1)
        pred_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y_lab_logits, labels=classify_labels))
        return vae_loss, pred_loss, all_z, label_pred

    def _generate_zxy(self, x, y, reuse=False):
        """Sample a latent code z from linear heads applied to `[x, y]`.

        Args:
            x: rank-2 feature tensor.
            y: rank-2 label tensor with the same number of rows as `x`.
            reuse: accepted for call-site compatibility but not read; the
                scope always uses `tf.AUTO_REUSE`.

        Returns:
            Tuple `(z_sample, z_mu, z_sigma)`, each with `self.dim_z` columns.
            `z_sigma` is used as a log-variance: the sample is
            `z_mu + exp(z_sigma / 2) * noise`.
        """
        def linear_head():
            # One un-activated dense projection of the concatenated input.
            return tf.layers.dense(inputs=tf.concat([x, y], axis=1),
                                   units=self.dim_z,
                                   activation=None,
                                   kernel_initializer=self.initializer)

        with tf.variable_scope('generate_zxy', reuse=tf.AUTO_REUSE):
            z_mu = linear_head()
            z_sigma = linear_head()
            std = tf.exp(z_sigma / 2)
            noise = tf.random_normal([tf.shape(x)[0], self.dim_z], 0, 1, dtype=tf.float32)
            z_sample = z_mu + tf.multiply(std, noise)
        return z_sample, z_mu, z_sigma

    def _generate_xzy(self, z, y, reuse=False):
        """Produce the two reconstruction heads for x from `[z, y]`.

        Args:
            z: rank-2 latent tensor.
            y: rank-2 label tensor with the same number of rows as `z`.
            reuse: accepted for call-site compatibility but not read; the
                scope always uses `tf.AUTO_REUSE`.

        Returns:
            List `[x_recon_mu, x_recon_sigma]`, each with `self.dim_y` columns
            (the encoder output that is reconstructed also has `dim_y` columns).
        """
        def linear_head():
            # One un-activated dense projection of the concatenated input.
            return tf.layers.dense(inputs=tf.concat([z, y], axis=1),
                                   units=self.dim_y,
                                   activation=None,
                                   kernel_initializer=self.initializer)

        with tf.variable_scope('generate_xzy', reuse=tf.AUTO_REUSE):
            recon_mu = linear_head()
            recon_sigma = linear_head()
        return [recon_mu, recon_sigma]

    def _generate_yz(self, z, reuse=False):
        """Produce the two reconstruction heads for y from the latent code.

        Args:
            z: rank-2 latent tensor.
            reuse: accepted for call-site compatibility but not read; the
                scope always uses `tf.AUTO_REUSE`.

        Returns:
            List `[y_recon_mu, y_recon_sigma]`, each with `self.dim_y` columns.
        """
        def linear_head():
            # One un-activated dense projection of z.
            return tf.layers.dense(inputs=z,
                                   units=self.dim_y,
                                   activation=None,
                                   kernel_initializer=self.initializer)

        with tf.variable_scope('generate_yz', reuse=tf.AUTO_REUSE):
            recon_mu = linear_head()
            recon_sigma = linear_head()
        return [recon_mu, recon_sigma]

    def label_discriminator(self, feat, reuse=False):
        """Return class logits from a single linear layer on `feat`.

        Args:
            feat: rank-2 feature tensor.
            reuse: accepted for call-site compatibility but not read; the
                scope always uses `tf.AUTO_REUSE`.

        Returns:
            Un-activated logits with `self.dim_y` columns.
        """
        with tf.variable_scope("label_discriminator", reuse=tf.AUTO_REUSE):
            return tf.layers.dense(inputs=feat,
                                   units=self.dim_y,
                                   activation=None,
                                   kernel_initializer=self.initializer)

    def domain_discriminator(self, feat, reuse=False):
        """Build one two-way domain classifier per domain.

        Args:
            feat: sequence indexed by domain; `feat[i]` is the latent tensor
                fed to the i-th discriminator.
            reuse: forwarded to each "domain_discriminator_<i>" variable scope.

        Returns:
            List of `self.num_domains` logit tensors with 2 columns each.
        """
        logits_per_domain = []
        for domain_id in range(self.num_domains):
            scope_name = "domain_discriminator_{}".format(domain_id)
            with tf.variable_scope(scope_name, reuse=reuse):
                # flip_gradient is imported from elsewhere; presumably a
                # gradient-reversal op scaled by self.l - verify there.
                reversed_feat = flip_gradient(feat[domain_id], self.l)
                hidden = tf.layers.dense(inputs=reversed_feat,
                                         units=256,
                                         activation=self.leaky_relu,
                                         kernel_initializer=self.initializer)
                domain_logits = tf.layers.dense(inputs=hidden,
                                                units=2,
                                                activation=None,
                                                kernel_initializer=self.initializer)
                logits_per_domain.append(domain_logits)
        return logits_per_domain

    def instance_normalization(self, x, name):
        """Normalise `x` over axes 1 and 2, then apply a learned per-channel scale and offset.

        Args:
            x: tensor whose axes 1 and 2 are reduced and whose last axis is
                the channel axis (rank 4 in the encoder).
            name: sub-scope under "instance_norm" that holds this layer's
                `scale` and `offset` variables.

        Returns:
            Tensor of the same shape as `x`.
        """
        with tf.variable_scope("instance_norm"), tf.variable_scope(name):
            per_channel = [x.get_shape()[-1]]
            mean, var = tf.nn.moments(x, [1, 2], keep_dims=True)
            scale = tf.get_variable(
                'scale', per_channel,
                initializer=tf.truncated_normal_initializer(mean=1.0, stddev=0.02))
            offset = tf.get_variable(
                'offset', per_channel,
                initializer=tf.constant_initializer(0.0))
            # The small constant keeps the division finite when the variance is zero.
            normalized = tf.div(x - mean, tf.sqrt(var + 1e-5))
            return scale * normalized + offset

    def leaky_relu(self, x):
        """Return the element-wise maximum of `x` and `0.05 * x`."""
        negative_slope = 0.05
        leaked = negative_slope * x
        return tf.maximum(x, leaked)

    def one_label_tensor(self, label):
        """Return a float32 matrix whose every row is the one-hot vector of `label`.

        The matrix has `batch_size // 2 - num_lab_each_epoch` rows and
        `self.dim_y` columns.

        Args:
            label: integer class index placed in every row.

        Returns:
            Dense float32 tensor with 1.0 in column `label` and 0.0 elsewhere.
        """
        n_rows = self.batch_size // 2 - self.num_lab_each_epoch
        # tf.fill + tf.one_hot builds the constant matrix in two ops and, unlike
        # a SparseTensor assembled from Python lists, also accepts zero rows.
        return tf.one_hot(tf.fill([n_rows], label), depth=self.dim_y, dtype=tf.float32)
