from __future__ import division
from __future__ import print_function
from __future__ import absolute_import

import tensorflow as tf
from tensorflow.contrib.framework import arg_scope
from tensorflow.contrib.framework import add_arg_scope

from tensorbayes.layers import dense, conv2d, batch_norm, instance_norm
from tensorbayes.tfutils import softmax_cross_entropy_with_two_logits as softmax_x_entropy_two
from keras import backend as K
from layers import leaky_relu
import os
from generic_utils import model_dir
import numpy as np
import tensorbayes as tb
from layers import batch_ema_acc
from collections import OrderedDict


def build_block(input_layer, layout, info=1):
    """Chain the layers described by ``layout`` on top of ``input_layer``.

    Every entry of ``layout`` is a ``(fn, args, kwargs)`` triple. Entry ``i``
    is applied as ``fn(x, *args, **kwargs)`` inside the variable scope
    ``'l<i>'``, where ``x`` is the output of the previous entry.

    Args:
        input_layer: value fed to the first layer.
        layout: sequence of ``(fn, args, kwargs)`` triples.
        info: when greater than 1, each intermediate output is printed.

    Returns:
        The output of the last layer, or ``input_layer`` for an empty layout.
    """
    current = input_layer
    for position, (layer_fn, layer_args, layer_kwargs) in enumerate(layout):
        with tf.variable_scope('l{:d}'.format(position)):
            current = layer_fn(current, *layer_args, **layer_kwargs)
            if info > 1:
                print(current)
    return current


@add_arg_scope
def normalize_perturbation(d, scope=None):
    """L2-normalize ``d`` jointly over every axis except the first one.

    The op is created under the name scope ``scope`` (default name
    ``'norm_pert'``).
    """
    # All axes but axis 0, derived from the static rank of ``d``.
    non_leading_axes = np.arange(1, len(d.shape))
    with tf.name_scope(scope, 'norm_pert'):
        return tf.nn.l2_normalize(d, axis=non_leading_axes)


def build_encode_template(
        input_layer, training_phase, scope, encode_layout,
        reuse=None, internal_update=False, getter=None, inorm=True, cnn_size='large'):
    """Build the encoder described by ``encode_layout`` on ``input_layer``.

    The layout is constructed and applied inside variable scope ``scope``
    (honouring ``reuse`` and the custom ``getter``) with these arg-scope
    defaults: ``leaky_relu`` uses ``a=0.1``; ``conv2d`` and ``dense`` use
    ``leaky_relu`` activation, ``bn=True`` and ``phase=training_phase``;
    ``batch_norm`` receives ``internal_update``.

    ``encode_layout`` is called with ``preprocess`` (``instance_norm`` when
    ``inorm`` is true, otherwise ``tf.identity``), ``training_phase`` and
    ``cnn_size``; its result is handed to ``build_block``.

    Returns:
        The output of the last layer of the layout.
    """
    # Choosing the preprocessing callable creates no graph ops, so it can be
    # done before any scope is entered.
    preprocess = instance_norm if inorm else tf.identity

    with tf.variable_scope(scope, reuse=reuse, custom_getter=getter):
        with arg_scope([leaky_relu], a=0.1):
            with arg_scope([conv2d, dense], activation=leaky_relu, bn=True, phase=training_phase):
                with arg_scope([batch_norm], internal_update=internal_update):
                    layers_spec = encode_layout(
                        preprocess=preprocess, training_phase=training_phase, cnn_size=cnn_size)
                    return build_block(input_layer, layers_spec)


def build_class_discriminator_template(
        input_layer, training_phase, scope, num_classes, class_discriminator_layout,
        reuse=None, internal_update=False, getter=None, cnn_size='large'):
    """Build the classifier head described by ``class_discriminator_layout``.

    The layout is constructed and applied inside variable scope ``scope``
    (honouring ``reuse`` and the custom ``getter``) with the same arg-scope
    defaults as the encoder: ``leaky_relu`` with ``a=0.1``; ``conv2d`` and
    ``dense`` with ``leaky_relu`` activation, ``bn=True`` and
    ``phase=training_phase``; ``batch_norm`` with ``internal_update``.

    ``class_discriminator_layout`` is called with ``num_classes``,
    ``global_pool=True``, ``activation=None`` and ``cnn_size``.

    Returns:
        The output of the last layer of the layout.
    """
    with tf.variable_scope(scope, reuse=reuse, custom_getter=getter):
        with arg_scope([leaky_relu], a=0.1):
            with arg_scope([conv2d, dense], activation=leaky_relu, bn=True, phase=training_phase):
                with arg_scope([batch_norm], internal_update=internal_update):
                    layers_spec = class_discriminator_layout(
                        num_classes=num_classes, global_pool=True, activation=None, cnn_size=cnn_size)
                    return build_block(input_layer, layers_spec)


def build_domain_discriminator_template(x, domain_layout, c=1, reuse=None):
    """Apply the layout returned by ``domain_layout(c=c)`` to ``x``.

    Layers are created in variable scope ``'domain_disc'`` (honouring
    ``reuse``), with ``dense`` defaulting to ``tf.nn.relu`` activation.

    Returns:
        The output of the last layer of the layout.
    """
    with tf.variable_scope('domain_disc', reuse=reuse), \
            arg_scope([dense], activation=tf.nn.relu):
        return build_block(x, domain_layout(c=c))


def get_default_config():
    """Return the ``tf.ConfigProto`` used for the model's session.

    GPU memory growth and soft device placement are enabled; device
    placement logging is disabled.
    """
    config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    return config


class LDROT():
    def __init__(self,
                 model_name="LDROT",
                 learning_rate=1.0,
                 batch_size=128,
                 num_iters=10000,
                 summary_freq=100,
                 src_class_trade_off=1.0,
                 src_vat_trade_off=1.0,
                 trg_vat_troff=1.0,
                 trg_ent_troff=1.0,
                 domain_trade_off=1.0,
                 adapt_domain_trade_off=False,
                 encode_layout=None,
                 classify_layout=None,
                 domain_layout=None,
                 current_time='',
                 inorm=True,
                 theta=0.1,
                 g_network_trade_off=1.0,
                 ldrot_model_id='',
                 save_grads=False,
                 only_save_final_model=True,
                 cnn_size='large',
                 update_target_loss=True,
                 data_augmentation=False,
                 sample_size=50,
                 data_shift_troff=10.0,
                 tau=0.1,
                 pseudo_lbl_at=4000,
                 muy=0.9,
                 use_lbl_smoothing=True,
                 random_seed=6789,
                 lbl_shift_troff=1.0,
                 trg_trade_off=1.0,
                 **kwargs):
        """Store the hyper-parameters; no TensorFlow objects are created here.

        Args:
            model_name: name used in the log and checkpoint directories.
            learning_rate: learning rate given to every Adam optimizer.
            batch_size: mini-batch size assumed by the graph.
            num_iters: number of training iterations.
            summary_freq: interval (in iterations) between progress reports.
            src_class_trade_off: weight of the source classification loss.
            src_vat_trade_off: weight of the source VAT loss.
            trg_vat_troff: weight of the target VAT loss.
            trg_ent_troff: weight of the target conditional-entropy loss.
            domain_trade_off: weight of the OT loss in the primary objective.
            adapt_domain_trade_off: when true, the graph also creates a
                placeholder for the domain trade-off.
            encode_layout: callable returning the encoder layout.
            classify_layout: callable returning the classifier layout.
            domain_layout: callable returning the domain-network layout.
            current_time: run identifier used for the log directory.
            inorm: apply instance normalisation as encoder preprocessing.
            theta: divisor of the exponent term and scale of the OT loss.
            g_network_trade_off: weight of the mean domain-network output.
            ldrot_model_id: sub-directory name of the checkpoint to restore.
            save_grads: also build gradient ops and gradient histograms.
            only_save_final_model: stored as-is; not read in the code
                visible here.
            cnn_size: forwarded to the layout callables.
            update_target_loss: build and run the extra target-only update.
            data_augmentation: stored as-is; not read in the code visible
                here.
            sample_size: samples drawn per class for balanced mini-batches.
            data_shift_troff: initial value; the graph builder replaces the
                attribute with a tensor.
            tau: temperature dividing the similarity term.
            pseudo_lbl_at: iteration from which pseudo-labels are used.
            muy: similarity threshold; the graph builder replaces the
                attribute with a constant tensor.
            use_lbl_smoothing: smooth the source labels with factor 0.1.
            random_seed: seed for NumPy and TensorFlow.
            lbl_shift_troff: weight of the label-shift loss.
            trg_trade_off: weight of the target-only loss that is built when
                ``update_target_loss`` is true. Previously this attribute
                was never assigned although the graph builder reads it.
                NOTE(review): the default of 1.0 mirrors the other
                trade-offs -- confirm the intended value.
            **kwargs: accepted for call-site compatibility and ignored.
        """
        self.model_name = model_name
        self.random_seed = random_seed
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.num_iters = num_iters
        self.summary_freq = summary_freq

        # Loss weights.
        self.src_class_trade_off = src_class_trade_off
        self.src_vat_trade_off = src_vat_trade_off
        self.trg_vat_troff = trg_vat_troff
        self.trg_ent_troff = trg_ent_troff
        self.trg_trade_off = trg_trade_off
        self.domain_trade_off = domain_trade_off
        self.adapt_domain_trade_off = adapt_domain_trade_off
        self.g_network_trade_off = g_network_trade_off
        self.data_shift_troff = data_shift_troff
        self.lbl_shift_troff = lbl_shift_troff

        # Network layout factories.
        self.encode_layout = encode_layout
        self.classify_layout = classify_layout
        self.domain_layout = domain_layout
        self.inorm = inorm
        self.cnn_size = cnn_size

        # Bookkeeping for logs and checkpoints.
        self.current_time = current_time
        self.ldrot_model_id = ldrot_model_id
        self.save_grads = save_grads
        self.only_save_final_model = only_save_final_model

        # Optimal-transport and sampling settings.
        self.theta = theta
        self.tau = tau
        self.muy = muy
        self.sample_size = sample_size
        self.pseudo_lbl_at = pseudo_lbl_at
        self.update_target_loss = update_target_loss
        self.data_augmentation = data_augmentation
        self.use_lbl_smoothing = use_lbl_smoothing

        # Fixed settings that are not exposed as constructor arguments.
        self.dirtt = 0
        self.beta_kl_trade_off = 0.1

    def _init(self, y_src_train, y_trg_train):
        """Seed the RNGs, start a fresh graph and session, and count classes.

        Args:
            y_src_train: source labels; the number of distinct values becomes
                ``self.src_num_classes`` and ``self.num_classes``.
            y_trg_train: target labels; the number of distinct values becomes
                ``self.trg_num_classes``.
        """
        np.random.seed(self.random_seed)

        # The graph-level seed is stored on the current default graph, so the
        # graph has to be reset *before* seeding; seeding first would attach
        # the seed to the graph that is about to be discarded.
        tf.reset_default_graph()
        tf.set_random_seed(self.random_seed)

        self.tf_graph = tf.get_default_graph()
        self.tf_config = get_default_config()
        self.tf_session = tf.Session(config=self.tf_config, graph=self.tf_graph)

        self.src_num_classes = len(np.unique(y_src_train))
        self.trg_num_classes = len(np.unique(y_trg_train))
        self.num_classes = self.src_num_classes

    def _get_variables(self, list_scopes):
        """Collect the trainable variables of each scope in ``list_scopes``.

        Returns:
            A list with one entry per scope, each entry being the list that
            ``tf.get_collection('trainable_variables', scope)`` returns.
        """
        return [tf.get_collection('trainable_variables', scope_name)
                for scope_name in list_scopes]

    def convert_one_hot(self, y):
        y_idx = y.reshape(-1).astype(int) if y is not None else None
        y = np.eye(self.src_num_classes)[y_idx] if y is not None else None
        return y

    def _get_scope(self, part_name, side_name, same_network=True):
        """Return the variable-scope name of a network part.

        When ``same_network`` is true the scope is just ``part_name``;
        otherwise ``'/' + side_name`` is appended to it.
        """
        side_suffix = '' if same_network else '/' + side_name
        return part_name + side_suffix

    def _get_primary_scopes(self):
        """Return the scope names whose variables the primary loss updates."""
        primary_scopes = ['generator', 'classifier']
        return primary_scopes

    def _get_secondary_scopes(self):
        """Return the scope names whose variables the secondary loss updates."""
        secondary_scopes = ['domain_disc']
        return secondary_scopes

    def _build_source_middle(self, x_src):
        """Encode ``x_src`` with the generator and return the features.

        The encoder is built in the scope given by
        ``_get_scope('generator', 'src')`` without requesting variable reuse.
        """
        encoder_scope = self._get_scope('generator', 'src')
        return build_encode_template(
            x_src,
            encode_layout=self.encode_layout,
            scope=encoder_scope,
            training_phase=self.is_training,
            inorm=self.inorm,
            cnn_size=self.cnn_size,
        )

    def _build_target_middle(self, x_trg):
        """Encode ``x_trg`` with the generator and return the features.

        The encoder is built in the scope given by
        ``_get_scope('generator', 'trg')`` with ``reuse=True`` and
        ``internal_update=True``.
        """
        encoder_scope = self._get_scope('generator', 'trg')
        encoder_kwargs = dict(
            encode_layout=self.encode_layout,
            scope=encoder_scope,
            training_phase=self.is_training,
            inorm=self.inorm,
            reuse=True,
            internal_update=True,
            cnn_size=self.cnn_size,
        )
        return build_encode_template(x_trg, **encoder_kwargs)

    def _build_classifier(self, x, num_classes, ema=None, is_teacher=False):
        """Build an inference-mode encoder + classifier on ``x`` (used for DIRT-T).

        With ``is_teacher`` false, the existing ``'generator'`` and
        ``'classifier'`` scopes are reused and variables are read through
        ``tb.tfutils.get_getter(ema)``. With ``is_teacher`` true, new
        networks are created in ``'generator/teacher'`` and
        ``'classifier/teacher'`` with no custom getter.

        Both networks are built with ``training_phase=False``. The encoder
        output is stored in ``self.g_x`` and the classifier output in
        ``self.h_g_x``.

        Returns:
            The classifier output (``self.h_g_x``).
        """
        teacher_encoder_scope = self._get_scope('generator', 'teacher', same_network=False)
        teacher_head_scope = self._get_scope('classifier', 'teacher', same_network=False)

        if is_teacher:
            encoder_scope, head_scope, share_variables = teacher_encoder_scope, teacher_head_scope, False
        else:
            encoder_scope, head_scope, share_variables = 'generator', 'classifier', True

        self.g_x = build_encode_template(
            x,
            encode_layout=self.encode_layout,
            scope=encoder_scope,
            training_phase=False,
            inorm=self.inorm,
            reuse=share_variables,
            getter=None if is_teacher else tb.tfutils.get_getter(ema),
            cnn_size=self.cnn_size,
        )

        self.h_g_x = build_class_discriminator_template(
            self.g_x,
            training_phase=False,
            scope=head_scope,
            num_classes=num_classes,
            reuse=share_variables,
            class_discriminator_layout=self.classify_layout,
            getter=None if is_teacher else tb.tfutils.get_getter(ema),
            cnn_size=self.cnn_size,
        )
        return self.h_g_x

    def _build_domain_discriminator(self, x_mid, reuse=False):
        """Apply the domain network (layout built with ``c=1``) to ``x_mid``."""
        return build_domain_discriminator_template(
            x_mid,
            domain_layout=self.domain_layout,
            c=1,
            reuse=reuse,
        )

    def _build_class_src_discriminator(self, x_src, num_src_classes):
        """Build the classifier head on the source features.

        Note that the ``x_src`` argument is not read: the head is always
        applied to ``self.x_src_mid``, in scope ``'classifier'``.
        """
        head_kwargs = dict(
            training_phase=self.is_training,
            scope='classifier',
            num_classes=num_src_classes,
            class_discriminator_layout=self.classify_layout,
            cnn_size=self.cnn_size,
        )
        return build_class_discriminator_template(self.x_src_mid, **head_kwargs)

    def _build_class_trg_discriminator(self, x_trg, num_trg_classes):
        """Build the classifier head on the target features.

        Note that the ``x_trg`` argument is not read: the head is always
        applied to ``self.x_trg_mid``, reusing the ``'classifier'`` scope
        with ``internal_update=True``.
        """
        head_kwargs = dict(
            training_phase=self.is_training,
            scope='classifier',
            num_classes=num_trg_classes,
            reuse=True,
            internal_update=True,
            class_discriminator_layout=self.classify_layout,
            cnn_size=self.cnn_size,
        )
        return build_class_discriminator_template(self.x_trg_mid, **head_kwargs)

    def perturb_image(self, x, p, num_classes, class_discriminator_layout, encode_layout,
                      pert='vat', scope=None, radius=3.5, scope_classify=None, scope_encode=None, training_phase=None):
        """Return an adversarially perturbed copy of ``x``.

        A small random probe (a normalized Gaussian sample scaled by 1e-6) is
        added to ``x`` and passed through the reused encoder and classifier.
        The gradient of ``softmax_x_entropy_two(labels=p, logits=...)`` with
        respect to the probe is normalized, scaled by ``radius`` and added to
        ``x``. ``pert`` is accepted but not read.

        Returns:
            ``x`` plus the perturbation, wrapped in ``tf.stop_gradient``.
        """
        with tf.name_scope(scope, 'perturb_image'):
            gaussian_noise = tf.random_normal(shape=tf.shape(x))
            probe = 1e-6 * normalize_perturbation(gaussian_noise)

            # Prediction for the randomly nudged input (variables are reused).
            probe_features = build_encode_template(
                x + probe,
                encode_layout=encode_layout,
                scope=scope_encode,
                training_phase=training_phase,
                reuse=True,
                inorm=self.inorm,
                cnn_size=self.cnn_size,
            )
            probe_logits = build_class_discriminator_template(
                probe_features,
                class_discriminator_layout=class_discriminator_layout,
                training_phase=training_phase,
                scope=scope_classify,
                reuse=True,
                num_classes=num_classes,
                cnn_size=self.cnn_size,
            )
            divergence = softmax_x_entropy_two(labels=p, logits=probe_logits)

            # The gradient w.r.t. the probe gives the direction in which the
            # divergence grows fastest.
            steepest = tf.gradients(divergence, [probe], aggregation_method=2)[0]
            unit_direction = normalize_perturbation(steepest)

            # The adversarial input is treated as a constant downstream.
            return tf.stop_gradient(x + radius * unit_direction)

    def vat_loss(self, x, p, num_classes, class_discriminator_layout, encode_layout,
                 scope=None, scope_classify=None, scope_encode=None, training_phase=None):
        """Virtual adversarial training loss for inputs ``x`` with logits ``p``.

        ``x`` is perturbed with ``perturb_image``, pushed through the reused
        encoder and classifier, and compared against ``p`` (gradient-stopped)
        via ``softmax_x_entropy_two``.

        Returns:
            The mean of that loss over the batch.
        """
        with tf.name_scope(scope, 'smoothing_loss'):
            adversarial_x = self.perturb_image(
                x, p, num_classes,
                class_discriminator_layout=class_discriminator_layout,
                encode_layout=encode_layout,
                scope_classify=scope_classify,
                scope_encode=scope_encode,
                training_phase=training_phase,
            )

            adversarial_features = build_encode_template(
                adversarial_x,
                encode_layout=encode_layout,
                scope=scope_encode,
                training_phase=training_phase,
                inorm=self.inorm,
                reuse=True,
                cnn_size=self.cnn_size,
            )
            adversarial_logits = build_class_discriminator_template(
                adversarial_features,
                training_phase=training_phase,
                scope=scope_classify,
                reuse=True,
                num_classes=num_classes,
                class_discriminator_layout=class_discriminator_layout,
                cnn_size=self.cnn_size,
            )

            # The clean prediction acts as a fixed target.
            fixed_target = tf.stop_gradient(p)
            per_sample = softmax_x_entropy_two(labels=fixed_target, logits=adversarial_logits)
            return tf.reduce_mean(per_sample)

    def _build_vat_loss(self, x, p, num_classes, scope=None, scope_classify=None, scope_encode=None):
        """Build ``vat_loss`` using this model's layouts and training flag.

        The result measures the divergence between the prediction ``p`` and
        the prediction for an adversarially perturbed ``x``.
        """
        model_settings = dict(
            class_discriminator_layout=self.classify_layout,
            encode_layout=self.encode_layout,
            training_phase=self.is_training,
        )
        return self.vat_loss(
            x, p, num_classes,
            scope=scope, scope_classify=scope_classify, scope_encode=scope_encode,
            **model_settings
        )

    def _compute_cosine_similarity(self, x_src_mid, x_trg_mid, use_distance=True):
        """Pairwise cosine similarity between two batches of features.

        Both inputs are flattened per sample. Entry ``[i, j]`` of the result
        relates row ``i`` of ``x_src_mid`` to row ``j`` of ``x_trg_mid``.

        Returns:
            ``1 - cosine`` when ``use_distance`` is true, otherwise
            ``(1 + cosine) / 2``, which lies in [0, 1].
        """
        first = tf.layers.Flatten()(x_src_mid)
        second = tf.layers.Flatten()(x_trg_mid)

        # Broadcasting (n, 1, d) against (m, d) yields all pairwise products.
        dot_products = tf.reduce_sum(first[:, tf.newaxis] * second, axis=-1)
        # The inputs are not assumed to be unit length, so divide by the norms.
        norm_products = tf.norm(first[:, tf.newaxis], axis=-1) * tf.norm(second, axis=-1)
        cosine = dot_products / norm_products

        if not use_distance:
            return (1.0 + cosine)/2
        return 1.0 - cosine

    def _compute_label_shift_loss(self, y_src_logit, y_trg_logit):
        """Pairwise prediction loss between two batches of logits.

        ``y_src_logit`` is stacked ``batch_size`` times along a new leading
        axis and flattened back to two dimensions; every row of
        ``y_trg_logit`` is repeated ``batch_size`` times. The two expanded
        tensors are passed positionally to ``softmax_x_entropy_two`` (second
        argument first) and the result is reshaped to
        ``(batch_size, batch_size)``.

        The intermediates are kept on ``self`` as ``y_src_logit_rep``,
        ``y_trg_logit_rep`` and ``label_shift_loss1``.
        NOTE(review): the final reshape presumably requires both batches to
        hold exactly ``batch_size`` rows -- confirm against the callers.
        """
        n = self.batch_size

        stacked = K.repeat_elements(tf.expand_dims(y_src_logit, axis=0), rep=n, axis=0)
        num_outputs = stacked.get_shape()[-1]
        self.y_src_logit_rep = tf.reshape(stacked, [-1, num_outputs])

        self.y_trg_logit_rep = K.repeat_elements(y_trg_logit, rep=n, axis=0)

        self.label_shift_loss1 = softmax_x_entropy_two(self.y_trg_logit_rep, self.y_src_logit_rep)
        return tf.reshape(self.label_shift_loss1, [n, n])

    def _build_model(self):
        """Build the full training graph and store its tensors/ops on ``self``.

        In order, this creates: the input placeholders, the shared encoder and
        classifier for source and target, the classification losses and
        accuracies, the optimal-transport loss (data-shift and label-shift
        costs combined with the domain network), the VAT and conditional
        entropy losses, the weighted primary loss with its optimizer and EMA
        update, the secondary (domain network) optimizer, the optional
        target-only optimizer, and the TensorBoard summaries.

        Side effects worth knowing:
          * ``self.muy`` and ``self.data_shift_troff`` are replaced by tensors.
          * ``self.trg_trade_off`` is read when ``update_target_loss`` is true
            and must already be set on the instance; it is not assigned here.
        """
        # Inputs are fixed to 2048 features. Each input is divided by
        # tf.norm(...) called without an axis, i.e. one norm for the whole fed
        # tensor. After the division the attributes refer to the normalized
        # tensors, which are also what gets fed through feed_dict.
        self.x_src = tf.placeholder(dtype=tf.float32, shape=(None, 2048)) # name='x_src_input'
        self.x_src /= tf.norm(self.x_src)
        self.x_trg = tf.placeholder(dtype=tf.float32, shape=(None, 2048)) # name='x_trg_input'
        self.x_trg /= tf.norm(self.x_trg)
        self.y_src = tf.placeholder(dtype=tf.float32, shape=(None, self.src_num_classes))
        self.y_trg = tf.placeholder(dtype=tf.float32, shape=(None, self.trg_num_classes))

        # Separate inputs used only for the EMA (evaluation) classifier below.
        T = tb.utils.TensorDict(dict(
            x_tmp=tf.placeholder(dtype=tf.float32, shape=(None, 2048)),
            y_tmp=tf.placeholder(dtype=tf.float32, shape=(None, self.trg_num_classes))
        ))

        T.x_tmp /= tf.norm(T.x_tmp)

        self.is_training = tf.placeholder(tf.bool, shape=(), name='is_training')

        # Encode source samples into the joint feature space (creates the
        # generator variables).
        self.x_src_mid = self._build_source_middle(self.x_src)

        # Encode target samples with the same generator (variables reused).
        self.x_trg_mid = self._build_target_middle(self.x_trg)

        # Classifier heads. Both helpers ignore their first argument and read
        # self.x_src_mid / self.x_trg_mid directly.
        self.y_src_logit = self._build_class_src_discriminator(self.x_src_mid, self.src_num_classes)  # (batch_size, src_num_classes)
        self.y_trg_logit = self._build_class_trg_discriminator(self.x_trg_mid, self.trg_num_classes)  # (batch_size, trg_num_classes)

        # Predicted and ground-truth class ids (labels are fed one-hot).
        self.y_src_pred = tf.argmax(self.y_src_logit, 1, output_type=tf.int32)
        self.y_trg_pred = tf.argmax(self.y_trg_logit, 1, output_type=tf.int32)
        self.y_src_sparse = tf.argmax(self.y_src, 1, output_type=tf.int32)
        self.y_trg_sparse = tf.argmax(self.y_trg, 1, output_type=tf.int32)

        ###############################
        # classification loss
        if self.use_lbl_smoothing:
            # Label smoothing with a fixed factor of 0.1.
            self.y_src_smooth = (1 - 0.1) * self.y_src + 0.1 / self.num_classes
            self.src_loss_class_detail = tf.nn.softmax_cross_entropy_with_logits_v2(
                logits=self.y_src_logit, labels=self.y_src_smooth)
        else:
            self.src_loss_class_detail = tf.nn.softmax_cross_entropy_with_logits_v2(
                logits=self.y_src_logit, labels=self.y_src)
        self.src_loss_class = tf.reduce_mean(self.src_loss_class_detail)

        # Target classification loss: it is not part of lst_primary_losses
        # below; it is exported as a summary.
        self.trg_loss_class_detail = tf.nn.softmax_cross_entropy_with_logits_v2(
            logits=self.y_trg_logit, labels=self.y_trg)
        self.trg_loss_class = tf.reduce_mean(self.trg_loss_class_detail)

        self.src_accuracy = tf.reduce_mean(tf.cast(tf.equal(self.y_src_sparse, self.y_src_pred), 'float32'))
        self.trg_accuracy = tf.reduce_mean(tf.cast(tf.equal(self.y_trg_sparse, self.y_trg_pred), 'float32'))

        #############################
        # optimal transport
        # Arguments are deliberately passed target-first, so rows of the cost
        # matrix index target samples and columns index source samples.
        self.data_shift_loss = self._compute_cosine_similarity(self.x_trg_mid, self.x_src_mid)  # SWAP
        self.compute_muy_use_pseudolbl = tf.placeholder(tf.bool, shape=(), name='compute_muy_use_pseudolbl')
        self.trg_pseudo_label = tf.placeholder(dtype=tf.int32, shape=(None,))
        # NOTE: compute_muy and compute_muy_default are only referenced by the
        # commented-out tf.cond below, so they are currently never called.
        def compute_muy():
            sim_mask = tf.cast(tf.reshape(tf.equal(self.trg_pseudo_label[:, tf.newaxis], self.y_src_sparse), [-1]),
                               'float32')  # 1: same classes, 0: different classes

            x_trg_mid_flatten = tf.layers.Flatten()(self.x_trg_mid)
            x_src_mid_flatten = tf.layers.Flatten()(self.x_src_mid)
            # Cosine similarity
            similarity = tf.reduce_sum(x_trg_mid_flatten[:, tf.newaxis] * x_src_mid_flatten, axis=-1)
            similarity /= tf.norm(x_trg_mid_flatten[:, tf.newaxis], axis=-1) * tf.norm(x_src_mid_flatten, axis=-1)
            # # If you prefer the distance measure
            # distance = (1.0 - similarity2)/2
            similarity = tf.reshape(similarity, [-1])

            # top_k on the negated values followed by a second negation sorts
            # the similarities in ascending order.
            sorted_sim, sim_indices = tf.nn.top_k(tf.negative(similarity), k=self.batch_size * self.batch_size)
            sorted_sim = tf.negative(sorted_sim)
            sorted_mask = tf.gather(sim_mask, sim_indices)
            # First same-class position and last different-class position in
            # the sorted order.
            minId_same_cls = tf.where(tf.equal(sorted_mask, 1.0))[0][0]
            maxId_diff_cls = tf.where(tf.equal(sorted_mask, 0.0))[-1][0]

            def _compute_muy_mix_region():
                # Median similarity of the region between the two positions.
                sim_mix_region = tf.gather(sorted_sim, tf.range(minId_same_cls, maxId_diff_cls))
                muy = tf.contrib.distributions.percentile(sim_mix_region, 50.0)
                return muy

            def _compute_muy_default():
                muy = (self.num_classes - 1.0) / self.num_classes
                return muy

            muy = tf.cond(minId_same_cls < maxId_diff_cls, _compute_muy_mix_region, _compute_muy_default)
            return muy

        def compute_muy_default():
            muy = (self.num_classes - 1.0) / self.num_classes
            return muy

        # self.muy = tf.cond(self.compute_muy_use_pseudolbl,
        #                    compute_muy, compute_muy_default)
        # From here on self.muy is a constant tensor holding the constructor value.
        self.muy = tf.constant(self.muy)
        # The constructor's data_shift_troff is replaced by a per-pair weight
        # exp((similarity - muy) / tau), where the similarity is computed on the
        # normalized inputs (not on the encoded features) and rescaled to [0, 1].
        self.data_shift_troff = tf.exp(
            (self._compute_cosine_similarity(self.x_trg, self.x_src, use_distance=False) - self.muy)
            / self.tau)

        # Target logits are passed as the first argument here.
        self.label_shift_loss = self._compute_label_shift_loss(self.y_trg_logit, self.y_src_logit)
        self.data_label_shift_loss = self.data_shift_troff*self.data_shift_loss + self.lbl_shift_troff * self.label_shift_loss

        # Domain network output on the target features, flattened to 1-D.
        self.g_network = tf.reshape(self._build_domain_discriminator(self.x_trg_mid), [-1])
        self.exp_term = (- self.data_label_shift_loss + self.g_network) / self.theta
        self.g_network_loss = tf.reduce_mean(self.g_network)

        # mean(-theta * (log(1 / batch_size) + logsumexp(exp_term, axis=1)))
        # plus the weighted mean output of the domain network.
        self.OT_loss = tf.reduce_mean(
            - self.theta * \
            (
                    tf.log(1.0 / self.batch_size) +
                    tf.reduce_logsumexp(self.exp_term, axis=1)
            )
        ) + self.g_network_trade_off * self.g_network_loss

        #############################
        # vat loss
        self.src_loss_vat = self._build_vat_loss(
            self.x_src, self.y_src_logit, self.num_classes,
            scope_encode=self._get_scope('generator', 'src'), scope_classify='classifier'
        )
        self.trg_loss_vat = self._build_vat_loss(
            self.x_trg, self.y_trg_logit, self.num_classes,
            scope_encode=self._get_scope('generator', 'trg'), scope_classify='classifier'
        )

        #############################
        # conditional entropy loss w.r.t target distribution
        # (the target logits are passed as both labels and logits)
        self.trg_loss_cond_entropy = tf.reduce_mean(softmax_x_entropy_two(labels=self.y_trg_logit,
                                                                   logits=self.y_trg_logit))

        #############################
        # construct primary loss
        if self.adapt_domain_trade_off:
            # This placeholder is only consumed by the summary further down;
            # the loss below uses the Python-side self.domain_trade_off.
            self.domain_trade_off_ph = tf.placeholder(dtype=tf.float32)
        # (weight, loss) pairs; when self.dirtt != 0 the source/OT terms are
        # re-weighted or switched off.
        lst_primary_losses = [
            (self.src_class_trade_off if self.dirtt == 0 else self.beta_kl_trade_off, self.src_loss_class),
            (self.domain_trade_off if self.dirtt == 0 else 0.0, self.OT_loss),
            (self.src_vat_trade_off if self.dirtt == 0 else 0.0, self.src_loss_vat),
            (self.trg_vat_troff, self.trg_loss_vat),
            (self.trg_ent_troff, self.trg_loss_cond_entropy)
        ]
        self.primary_loss = tf.constant(0.0)
        for trade_off, loss in lst_primary_losses:
            # Terms with a zero weight are left out of the graph entirely.
            if trade_off != 0:
                self.primary_loss += trade_off * loss
        # One list of variables per primary scope (generator, classifier).
        primary_variables = self._get_variables(self._get_primary_scopes())

        # Evaluation (EMA)
        ema = tf.train.ExponentialMovingAverage(decay=0.998)
        var_list_for_ema = primary_variables[0] + primary_variables[1]
        ema_op = ema.apply(var_list=var_list_for_ema)
        self.ema_p = self._build_classifier(T.x_tmp, self.trg_num_classes, ema)

        # Accuracies
        self.batch_ema_acc = batch_ema_acc(T.y_tmp, self.ema_p)
        self.fn_batch_ema_acc = tb.function(self.tf_session, [T.x_tmp, T.y_tmp], self.batch_ema_acc)

        # Student accuracy
        self.batch_student_acc = batch_ema_acc(self.y_trg, self.y_trg_logit)
        self.fn_batch_student_acc = tb.function(self.tf_session, [self.x_trg, self.y_trg, self.is_training],
                                                self.batch_student_acc)

        # Primary step: minimize the primary loss over generator + classifier
        # variables, grouped with the EMA update.
        self.train_main = \
            tf.train.AdamOptimizer(self.learning_rate, 0.5).minimize(self.primary_loss, var_list=primary_variables)
        self.primary_train_op = tf.group(self.train_main, ema_op)

        if self.save_grads:
            self.grads_wrt_primary_loss = tf.train.AdamOptimizer(self.learning_rate, 0.5).compute_gradients(
                self.primary_loss, var_list=primary_variables)

        # construct secondary loss
        # The domain network is trained to maximize the OT loss (it minimizes
        # the negated loss).
        secondary_variables = self._get_variables(self._get_secondary_scopes())
        self.secondary_train_op = \
            tf.train.AdamOptimizer(self.learning_rate, 0.5).minimize(-self.OT_loss,
                                                                   var_list=secondary_variables)

        if self.update_target_loss:
            # Extra target-only update. self.trg_trade_off must already be set
            # on the instance before this point.
            self.target_loss = self.trg_trade_off * (self.trg_loss_vat + self.trg_loss_cond_entropy)

            self.target_train_op = \
                tf.train.AdamOptimizer(self.learning_rate, 0.5).minimize(self.target_loss,
                                                                         var_list=primary_variables)
        if self.save_grads:
            self.grads_wrt_secondary_loss = tf.train.AdamOptimizer(self.learning_rate, 0.5).compute_gradients(
                -self.OT_loss, var_list=secondary_variables)
        ############################
        # summaries
        tf.summary.scalar('domain/primary_loss', self.primary_loss)
        tf.summary.scalar('primary_loss/src_loss_class', self.src_loss_class)
        tf.summary.scalar('primary_loss/data_shift_loss', tf.reduce_mean(self.data_shift_loss))
        tf.summary.scalar('primary_loss/label_shift_loss', tf.reduce_mean(self.label_shift_loss))
        tf.summary.scalar('primary_loss/data_label_shift_loss', tf.reduce_mean(self.data_label_shift_loss))
        tf.summary.scalar('primary_loss/exp_term', tf.reduce_mean(self.exp_term))
        tf.summary.histogram('primary_loss/g_batch', self.g_network)
        tf.summary.scalar('primary_loss/g_network_loss', self.g_network_loss)
        tf.summary.scalar('primary_loss/W_distance', self.OT_loss)

        tf.summary.scalar('acc/src_acc', self.src_accuracy)
        tf.summary.scalar('acc/trg_acc', self.trg_accuracy)

        tf.summary.scalar('trg_loss_class', self.trg_loss_class)
        tf.summary.scalar('hyperparameters/learning_rate', self.learning_rate)
        tf.summary.scalar('hyperparameters/src_class_trade_off', self.src_class_trade_off)
        tf.summary.scalar('hyperparameters/g_network_trade_off', self.g_network_trade_off)
        tf.summary.scalar('hyperparameters/domain_trade_off',
                          self.domain_trade_off_ph if self.adapt_domain_trade_off
                          else self.domain_trade_off)
        tf.summary.scalar('hyperparameters/src_vat_trade_off', self.src_vat_trade_off)
        tf.summary.scalar('hyperparameters/trg_vat_off', self.trg_vat_troff)
        tf.summary.scalar('hyperparameters/trg_ent_off', self.trg_ent_troff)

        # Filled only when save_grads is on: variable name -> gradient shape,
        # plus the (gradient, variable) pairs of the primary loss.
        self.list_str_variables = OrderedDict()
        self.gradient_and_value = []
        if self.save_grads:
            with tf.name_scope("visualize"):
                for var in tf.trainable_variables():
                    tf.summary.histogram(var.op.name + '/values', var)
                for grad, var in self.grads_wrt_primary_loss:
                    if grad is not None:
                        tf.summary.histogram(var.op.name + '/grads_wrt_primary_loss', grad)
                        self.gradient_and_value += [(grad, var)]
                        self.list_str_variables[var.op.name] = grad.get_shape().as_list()
                for grad, var in self.grads_wrt_secondary_loss:
                    if grad is not None:
                        tf.summary.histogram(var.op.name + '/grads_wrt_secondary_loss', grad)
        self.tf_merged_summaries = tf.summary.merge_all()

    def mini_batch_class_balanced(self, label, sample_size=20, shuffle=False):
        if shuffle:
            rindex = np.random.permutation(len(label))
            label = label[rindex]

        n_class = len(np.unique(label))
        index = []
        for i in range(n_class):
            s_index = np.nonzero(label == i)

            if len(s_index[0]) < sample_size:
                s_ind = np.random.choice(s_index[0], sample_size)
            else:
                s_ind = np.random.permutation(s_index[0])

            index = np.append(index, s_ind[0:sample_size])
        index = np.array(index, dtype=int)
        return index

    def create_id_for_minibatch_balanced(self, idx_trg_samples, balanced_class=False):
        y_trg_logit = self.tf_session.run(self.y_trg_logit, feed_dict={
            self.x_trg: self.x_trg_[idx_trg_samples, :],
            self.is_training: False
        })
        y_trg_logit = np.argmax(y_trg_logit, 1)
        new_idx_trg_samples = self.mini_batch_class_balanced(y_trg_logit,
                            sample_size=self.sample_size) if balanced_class else idx_trg_samples
        return new_idx_trg_samples[:self.batch_size], y_trg_logit[:self.batch_size]

    def _fit_loop(self, x_src=None, y_src=None, x_trg=None, y_trg=None):

        self.x_src_ = x_src
        self.y_src_ = y_src
        self.x_trg_ = x_trg
        self.y_trg_ = y_trg

        print('Start training LDROT model at', os.path.abspath(__file__))
        print('============ LOG-ID: %s ============' % self.current_time)

        # num_src_samples = x_src.shape[0]
        num_trg_samples = x_trg.shape[0]

        self.tf_session.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=101)
        self.log_path = os.path.join(model_dir(), self.model_name, "logs",
                                     "{}".format(self.current_time))
        self.tf_summary_writer = tf.summary.FileWriter(self.log_path, self.tf_session.graph)

        self.checkpoint_path = os.path.join(model_dir(), self.model_name, "saved-model", "{}".format(self.ldrot_model_id))
        check_point = tf.train.get_checkpoint_state(self.checkpoint_path)

        if check_point and tf.train.checkpoint_exists(check_point.model_checkpoint_path):
            print("Load model parameters from %s\n" % check_point.model_checkpoint_path)
            saver.restore(self.tf_session, check_point.model_checkpoint_path)

        for it in range(self.num_iters):
            idx_src_samples = self.mini_batch_class_balanced(self.y_src_, sample_size=self.sample_size)

            feed_data = dict()
            feed_data[self.x_src] = self.x_src_[idx_src_samples, :]
            feed_data[self.y_src] = self.y_src_[idx_src_samples]
            feed_data[self.y_src] = self.convert_one_hot(feed_data[self.y_src])

            if it >= self.pseudo_lbl_at:
                idx_trg_samples = np.random.permutation(num_trg_samples)[:self.batch_size * 10]
                idx_trg_samples, trg_pseudo_label = self.create_id_for_minibatch_balanced(idx_trg_samples)
                feed_data[self.x_trg] = self.x_trg_[idx_trg_samples]
                feed_data[self.y_trg] = self.y_trg_[idx_trg_samples]
                feed_data[self.y_trg] = self.convert_one_hot(feed_data[self.y_trg])
                feed_data[self.compute_muy_use_pseudolbl] = True
                feed_data[self.trg_pseudo_label] = trg_pseudo_label
            else:
                idx_trg_samples = np.random.permutation(num_trg_samples)[:self.batch_size]
                feed_data[self.x_trg] = self.x_trg_[idx_trg_samples, :]
                feed_data[self.y_trg] = self.y_trg_[idx_trg_samples]
                feed_data[self.y_trg] = self.convert_one_hot(feed_data[self.y_trg])
                feed_data[self.compute_muy_use_pseudolbl] = False
                feed_data[self.trg_pseudo_label] = self.y_trg_[idx_trg_samples]

            feed_data[self.is_training] = True

            # compute accuracy of source classifier on target data
            trg_acc_on_src_classifier = self.tf_session.run(self.src_accuracy, feed_dict={
                self.x_src: self.x_trg_[idx_trg_samples, :],
                self.y_src: self.convert_one_hot(self.y_trg_[idx_trg_samples]),
                self.is_training: False
            })

            for i in range(0, 5):
                g_idx_src_samples = self.mini_batch_class_balanced(self.y_src_, sample_size=self.sample_size)
                g_feed_data = dict()
                g_feed_data[self.x_src] = self.x_src_[g_idx_src_samples, :]
                g_feed_data[self.y_src] = self.y_src_[g_idx_src_samples]
                g_feed_data[self.y_src] = self.convert_one_hot(g_feed_data[self.y_src])

                if it >= self.pseudo_lbl_at:
                    g_idx_trg_samples = np.random.permutation(num_trg_samples)[:self.batch_size * 10]
                    g_idx_trg_samples, trg_pseudo_label = self.create_id_for_minibatch_balanced(g_idx_trg_samples)
                    g_feed_data[self.x_trg] = self.x_trg_[g_idx_trg_samples]
                    g_feed_data[self.y_trg] = self.y_trg_[g_idx_trg_samples]
                    g_feed_data[self.y_trg] = self.convert_one_hot(g_feed_data[self.y_trg])
                    g_feed_data[self.compute_muy_use_pseudolbl] = True
                    g_feed_data[self.trg_pseudo_label] = trg_pseudo_label
                else:
                    g_idx_trg_samples = np.random.permutation(num_trg_samples)[:self.batch_size]
                    g_feed_data[self.x_trg] = self.x_trg_[g_idx_trg_samples, :]
                    g_feed_data[self.y_trg] = self.y_trg_[g_idx_trg_samples]
                    g_feed_data[self.y_trg] = self.convert_one_hot(g_feed_data[self.y_trg])
                    g_feed_data[self.compute_muy_use_pseudolbl] = False
                    g_feed_data[self.trg_pseudo_label] = self.y_trg_[g_idx_trg_samples]

                g_feed_data[self.is_training] = True

                _, W_dist = \
                    self.tf_session.run(
                        [self.secondary_train_op, self.OT_loss],
                        feed_dict=g_feed_data
                    )

            _, max_term, src_loss_class, trg_loss_class, src_acc, trg_acc, data_shift_loss, label_shift_loss, data_label_shift_loss, g_network_loss, muy = \
                self.tf_session.run(
                    [self.primary_train_op, self.exp_term, self.src_loss_class,
                     self.trg_loss_class, self.src_accuracy, self.trg_accuracy, self.data_shift_loss,
                     self.label_shift_loss, self.data_label_shift_loss, self.g_network_loss, self.muy],
                    feed_dict=feed_data
                )

            if self.update_target_loss:
                _, target_loss = \
                    self.tf_session.run(
                        [self.target_train_op, self.target_loss],
                        feed_dict=feed_data
                    )

            if it == 0 or (it + 1) % self.summary_freq == 0:
                print("iter %d/%d W_dist %.3f; muy %.3f; exp_term %.3f; src_loss_class %.5f; trg_loss_class %.3f" % (
                    it + 1, self.num_iters, W_dist, muy, np.mean(max_term), src_loss_class, trg_loss_class))
                print("src_acc %.2f; trg_acc: %.2f; trg_acc_on_src_classifier %.2f" % (src_acc * 100, trg_acc * 100, trg_acc_on_src_classifier * 100))
                print(
                    "data_shift_loss %.3f; label_shift_loss %.3f; data_label_shift_loss %.3f; g_network_loss %.3f" % (
                    np.mean(data_shift_loss), np.mean(label_shift_loss), np.mean(data_label_shift_loss),
                    g_network_loss))

                summary = self.tf_session.run(self.tf_merged_summaries, feed_dict=feed_data)
                self.tf_summary_writer.add_summary(summary, it + 1)
                self.tf_summary_writer.flush()

            if it == 0 or (it + 1) % self.summary_freq == 0:
                if not self.only_save_final_model:
                    self.save_trained_model(saver, it + 1)
                elif it + 1 == self.num_iters:
                    self.save_trained_model(saver, it + 1)
                # Save acc values
                self.save_value(step=it + 1)

    def save_trained_model(self, saver, step):
        """Write a checkpoint of the current session for the given step.

        The checkpoint prefix is
        ``<model_dir()>/<model_name>/saved-model/<current_time>/ldrot_<current_time>.ckpt``;
        its parent directory is created when it does not exist yet.

        Args:
            saver: object whose ``save(session, path, global_step=...)`` method
                performs the actual write.
            step: value forwarded to ``saver.save`` as ``global_step``.
        """
        run_folder = os.path.join(
            model_dir(), self.model_name, "saved-model", "{}".format(self.current_time))
        ckpt_prefix = os.path.join(run_folder, "ldrot_" + self.current_time + ".ckpt")

        # Make sure the folder holding the checkpoint files is there before saving.
        parent = os.path.dirname(ckpt_prefix)
        if not os.path.exists(parent):
            os.makedirs(parent)

        saver.save(self.tf_session, ckpt_prefix, global_step=step)

    def save_value(self, step):
        """Evaluate on the full target set, log the summary and print accuracies.

        Calls ``self.compute_value`` on ``self.x_trg_full`` / ``self.y_trg_full``,
        writes the returned summary to ``self.tf_summary_writer`` at ``step``,
        flushes the writer, and prints both accuracies as percentages rounded
        to two decimals.

        Args:
            step: step index attached to the written summary.
        """
        results = self.compute_value(xt_full=self.x_trg_full, yt=self.y_trg_full, labeler=None)
        ema_accuracy, student_accuracy, eval_summary = results

        writer = self.tf_summary_writer
        writer.add_summary(eval_summary, step)
        writer.flush()

        # Report accuracies as percentages with two decimals.
        report = [
            'trg_test_ema', round(ema_accuracy * 100, 2),
            'trg_test_acc', round(student_accuracy * 100, 2),
        ]
        print(report)

    def compute_value(self, xt_full, yt, labeler, full=False, shuffle=True):
        """Compute EMA and student accuracy on target data and build a summary.

        The data is optionally shuffled with a fixed seed, optionally truncated
        to the first 1000 samples, and evaluated in batches of 200 through
        ``self.fn_batch_ema_acc`` and ``self.fn_batch_student_acc``. The
        per-sample results are averaged.

        Args:
            xt_full: target inputs; indexed by a permutation array and sliced,
                so presumably a NumPy array -- TODO confirm against callers.
            yt: target labels, or None. When given they are converted with
                ``self.convert_one_hot``; when None, ``labeler`` supplies the
                labels for each batch.
            labeler: callable mapping a batch of inputs to labels; only used
                when ``yt`` is None.
            full: if False, only the first 1000 (possibly shuffled) samples
                are evaluated.
            shuffle: if True, samples are permuted inside
                ``tb.nputils.FixedSeed(0)`` before evaluation.

        Returns:
            Tuple ``(ema_acc, test_acc, summary)`` where ``summary`` is a
            ``tf.Summary`` carrying both means under the tags
            ``trg_test/ema_acc`` and ``trg_test/trg_acc``.
        """
        # Convert y to one-hot encoder. Skip this when no labels were supplied
        # so the ``labeler`` fallback below is actually reachable; previously
        # None was handed to convert_one_hot before any of the None checks.
        if yt is not None:
            yt = self.convert_one_hot(yt)

        if shuffle:
            # Fixed seed: presumably makes the evaluated subset identical
            # across calls -- verify against tb.nputils.FixedSeed.
            with tb.nputils.FixedSeed(0):
                shuffle1 = np.random.permutation(len(xt_full))

            xt_shuff = xt_full[shuffle1]
            yt_shuff = yt[shuffle1] if yt is not None else None
        else:
            xt_shuff = xt_full
            yt_shuff = yt

        if not full:
            # Quick evaluation on a 1000-sample subset.
            xt_shuff = xt_shuff[:1000]
            yt_shuff = yt_shuff[:1000] if yt_shuff is not None else None

        nt = len(xt_shuff)
        bs = 200

        # Per-sample result buffers, filled batch by batch below.
        ema_acc_full = np.ones(nt, dtype=float)
        test_acc_full = np.ones(nt, dtype=float)

        for i in range(0, nt, bs):
            xt_batch = xt_shuff[i:i + bs]
            yt_batch = yt_shuff[i:i + bs] if yt_shuff is not None else labeler(xt_batch)
            ema_acc_batch = self.fn_batch_ema_acc(xt_batch, yt_batch)  # (batch_size,)
            # NOTE(review): the trailing False is presumably the training-phase
            # flag -- confirm against the definition of fn_batch_student_acc.
            test_acc_batch = self.fn_batch_student_acc(xt_batch, yt_batch, False)
            ema_acc_full[i:i + bs] = ema_acc_batch
            test_acc_full[i:i + bs] = test_acc_batch

        ema_acc = np.mean(ema_acc_full)
        test_acc = np.mean(test_acc_full)

        summary1 = tf.Summary.Value(tag='trg_test/ema_acc', simple_value=ema_acc)
        summary2 = tf.Summary.Value(tag='trg_test/trg_acc', simple_value=test_acc)

        summary = tf.Summary(value=[summary1, summary2])
        return ema_acc, test_acc, summary
