from inits import *
from read_svhn import read_svhn
from read_mnist import read_mnist
from batch_generator import batch_generator, batch_generator_source, batch_generator_target
from model import BaseModel
import pickle
import time
import os
# Select the GPU before any TensorFlow session is created: enumerate devices
# by PCI bus ID and expose only device "1" to this process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1",
})
# Only surface TensorFlow log messages at ERROR level or above.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# Domain identifiers; ``tgt_domain`` also prefixes the saved prediction files.
src_domain = 'svhn'
tgt_domain = 'mnist'
# Number of training iterations run by ``train`` (one mini-batch per iteration).
epochs = 10000
batch_size = 64
# One value per target-domain time step, in processing order. Each value is
# passed to ``read_mnist(..., eps=...)`` and embedded in cache/prediction file names.
eps = [
    0.0, 0.05, 0.1, 0.15, 0.2, 0.25,
    0.3, 0.35, 0.4, 0.45, 0.5,
]


def train(X_source, Y_source, X_target, Y_target, X_target_test, Y_target_test,
          training_mode='domain_adaptation', target_index=0, mu=0., verbose=False,
          num_lab=100, num_lab_each_epoch=1):
    """Train a fresh ``BaseModel`` for one target step, evaluate it, and save its target predictions.

    The first ``num_lab`` target training samples are used with their labels;
    the remaining target samples are fed as unlabelled data with all-zero
    placeholder labels.

    Args:
        X_source, Y_source: source inputs/labels, passed together to
            ``batch_generator_source``; ``len(X_source)`` is also passed to
            ``BaseModel``. NOTE(review): the ``__main__`` driver passes lists of
            arrays (one entry per source domain) - confirm that is the intended shape.
        X_target, Y_target: target-domain training inputs/labels.
        X_target_test, Y_target_test: target-domain test data used for the accuracy.
        training_mode: only ``'domain_adaptation'`` is implemented.
        target_index: index into ``eps``; selects the prediction file name.
        mu: forwarded to ``BaseModel``.
        verbose: print the batch loss every 100 iterations.
        num_lab: number of leading target samples treated as labelled.
        num_lab_each_epoch: forwarded to ``BaseModel`` and ``batch_generator_target``.

    Returns:
        The value of ``model.label_acc`` evaluated on the target test set.

    Raises:
        ValueError: if ``training_mode`` is not supported.
    """
    if training_mode != 'domain_adaptation':
        # Fail loudly instead of skipping the training loop and then evaluating
        # (and saving predictions of) an untrained model.
        raise ValueError("Unsupported training_mode: {!r}".format(training_mode))

    lab_target = [X_target[:num_lab], Y_target[:num_lab]]
    unlab_target = [X_target[num_lab:], np.zeros_like(Y_target[num_lab:])]

    model = BaseModel(len(X_source), batch_size, num_lab_each_epoch=num_lab_each_epoch, mu=mu)
    half_batch = batch_size // 2
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Every training batch is half source samples followed by half target samples.
        gen_source_batch = batch_generator_source([X_source, Y_source], half_batch)
        gen_target_batch = batch_generator_target(lab_target, half_batch, unlab_data=unlab_target,
                                                  num_lab_each_epoch=num_lab_each_epoch)
        # Domain labels matching that layout: [1, 0] for source rows, [0, 1] for target rows.
        domain_labels = np.vstack([np.tile([1., 0.], [half_batch, 1]),
                                   np.tile([0., 1.], [half_batch, 1])])

        for i in range(epochs):
            p = float(i) / epochs  # training progress in [0, 1)
            # Weight fed to ``model.l``: ramps smoothly from 0 towards 1 as p grows.
            lam = 2. / (1. + np.exp(-10. * p)) - 1
            # Learning rate annealed down from 0.01 as training progresses.
            lr = 0.01 / (1. + 10 * p) ** 0.75

            sample_source, sample_source_label = next(gen_source_batch)
            sample_target, sample_target_label = next(gen_target_batch)
            X = np.vstack([sample_source, sample_target])
            y = np.vstack([sample_source_label, sample_target_label])

            _, batch_loss = sess.run(
                [model.train_op, model.total_loss],
                feed_dict={model.X: X,
                           model.y: y,
                           model.domain: domain_labels,
                           model.train: True,
                           model.l: lam,
                           model.learning_rate: lr})

            if verbose and (i + 1) % 100 == 0:
                print('\t Epoch: {}, loss: {}'.format(i + 1, batch_loss))

        target_acc = sess.run(model.label_acc, feed_dict={model.X: X_target_test,
                                                          model.y: Y_target_test,
                                                          model.train: False})

        print('Test results:')
        print('Target ({}) accuracy: {:.4f}'.format(tgt_domain.upper(), target_acc))

        # Save the predicted labels on the target *training* data; the driver
        # reloads them as labels for this step's inputs at later steps.
        target_train_preds = sess.run(model.pred, feed_dict={model.X: X_target,
                                                             model.y: Y_target,
                                                             model.train: False})
        # Create the output directory so a long training run is not lost to a
        # FileNotFoundError at the very end.
        os.makedirs('pred_labels', exist_ok=True)
        pred_path = os.path.join('pred_labels',
                                 '{}Target_{}.pkl'.format(tgt_domain, eps[target_index]))
        with open(pred_path, 'wb') as pkl_file:
            pickle.dump(target_train_preds, pkl_file)

    return target_acc


if __name__ == '__main__':
    data_root = 'data/'

    X_source, Y_source, _, _ = read_svhn(data_root)
    for mu in [0.0]:
        results = []
        for i in range(len(eps)):
            print("Running on {}-th time stamp of target domain with mu = {}".format(i, mu))
            tf.reset_default_graph()
            all_X_source, all_Y_source = [X_source], [Y_source]
            for j in range(i):
                all_mnist = pickle.load(open('data/MNIST_target_{}.pkl'.format(eps[j]), "rb"))
                Y_hist_target = pickle.load(open("pred_labels/" + tgt_domain + "Target_" + str(eps[j]) + ".pkl", "rb"))
                all_X_source.append(all_mnist['X_train'])
                all_Y_source.append(Y_hist_target)

            if os.path.isfile('data/MNIST_target_{}.pkl'.format(eps[i])):
                all_mnist = pickle.load(open('data/MNIST_target_{}.pkl'.format(eps[i]), "rb"))
                X_target, Y_target, X_target_test, Y_target_test = all_mnist['X_train'], all_mnist['Y_train'], all_mnist['X_test'], all_mnist['Y_test']
            else:
                X_target, Y_target, X_target_test, Y_target_test = read_mnist(data_root, eps=eps[i])
            target_acc = train(all_X_source, all_Y_source, X_target, Y_target, X_target_test, Y_target_test, target_index=i, mu=mu)
            results.append(target_acc)
        print(results)

