from __future__ import print_function
import tensorflow as tf
import keras
import keras.backend as K

import numpy as np
from copy import deepcopy


def set_indices(y_train, n_labeled_per_class=2, n_valid_per_class=500):
    """Split sample indices into labeled, unlabeled and validation sets per class.

    For every distinct label found in ``y_train`` the indices of that class
    are shuffled; the first ``n_labeled_per_class`` become labeled samples,
    the last ``n_valid_per_class`` become validation samples and everything
    in between is left unlabeled. Each of the three resulting index arrays is
    shuffled once more before being returned.

    Randomness comes from the global NumPy generator (``np.random``), so seed
    it beforehand if a reproducible split is needed.

    Args:
        y_train: array of class labels, one per training sample.
        n_labeled_per_class: number of initially labeled samples per class.
        n_valid_per_class: number of validation samples per class.

    Returns:
        Tuple ``(idx_labeled, idx_unlabeled, idx_valid)`` of integer index
        arrays into ``y_train``. The three arrays never overlap.
    """
    labeled, unlabeled, valid = [], [], []
    # Iterate over the labels that actually occur instead of assuming they
    # are exactly 0..k-1; otherwise classes with other label values would be
    # skipped silently.
    for c in np.unique(y_train):
        idx = np.where(y_train == c)[0]
        np.random.shuffle(idx)
        # Clamp both counts so that a small class can never yield a negative
        # slice bound, which would make the validation slice wrap around and
        # overlap the labeled samples.
        n_lab = min(n_labeled_per_class, idx.size)
        n_val = min(n_valid_per_class, idx.size - n_lab)
        valid_start = idx.size - n_val
        labeled.extend(idx[:n_lab].tolist())
        unlabeled.extend(idx[n_lab:valid_start].tolist())
        valid.extend(idx[valid_start:].tolist())

    # An explicit integer dtype keeps empty results usable as index arrays
    # (an empty list would otherwise become a float array).
    idx_labeled = np.array(labeled, dtype=np.intp)
    idx_unlabeled = np.array(unlabeled, dtype=np.intp)
    idx_valid = np.array(valid, dtype=np.intp)
    np.random.shuffle(idx_labeled)
    np.random.shuffle(idx_unlabeled)
    np.random.shuffle(idx_valid)

    return idx_labeled, idx_unlabeled, idx_valid


def make_model(input_shape, num_classes=None):
    """Build and compile the convolutional classifier used in every step.

    The network is two 3x3 convolutions (32 and 64 filters, ReLU), a 2x2
    max-pool, dropout 0.25, a 128-unit ReLU dense layer, dropout 0.5 and a
    softmax output layer. All weighted layers use the ``he_normal``
    initializer. The model is compiled with the Adam optimizer, the
    ``sparse_categorical_crossentropy`` loss and the ``accuracy`` metric.

    Args:
        input_shape: shape of a single input sample, forwarded to the first
            convolution as ``input_shape``.
        num_classes: number of units in the output layer. When omitted, the
            module-level ``n_classes`` is used.

    Returns:
        The compiled ``keras.models.Sequential`` model.
    """
    if num_classes is None:
        # Fall back to the script-wide setting so existing callers that pass
        # only the input shape keep working.
        num_classes = n_classes

    layers = keras.layers
    model = keras.models.Sequential([
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=input_shape, kernel_initializer='he_normal'),
        layers.Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal'),
        layers.MaxPool2D((2, 2)),
        layers.Dropout(0.25),
        layers.Flatten(),
        layers.Dense(128, activation='relu', kernel_initializer='he_normal'),
        layers.Dropout(0.5),
        layers.Dense(num_classes, kernel_initializer='he_normal'),
        layers.Activation('softmax')
    ])
    # Use the `Adam` class rather than the lowercase `adam` alias: the alias
    # only exists in old standalone Keras releases. The learning rate is
    # passed positionally because its keyword was renamed from `lr` to
    # `learning_rate` between releases, while it has always been the first
    # argument.
    opt = keras.optimizers.Adam(0.001, beta_1=0.9, beta_2=0.999)
    model.compile(loss='sparse_categorical_crossentropy', optimizer=opt, metrics=['accuracy'])

    return model


def get_perturbed_output(model, X_labeled, y_labeled, X_pool, N, rho_star, sigma, beta):
    """Collect pool predictions from N random perturbations of the last weighted layer.

    The last layer of ``model`` that owns weights is perturbed ``N`` times
    with zero-mean Gaussian noise of standard deviation ``sigma``. For each
    perturbation the function records the predicted class of every pool
    sample and a weight ``gamma = exp(-max(0, err_n - err_hat))``, where
    ``err_hat`` / ``err_n`` are the error rates of the unperturbed /
    perturbed model on the labeled data. After each perturbation ``sigma``
    is rescaled by ``exp(-beta * (rho_n - rho_star))``, with ``rho_n`` the
    fraction of pool samples whose predicted class changed.

    The original weights are always written back to the model, including
    when evaluation or prediction raises part-way through.

    Args:
        model: model exposing ``layers``, ``evaluate`` and ``predict``.
            ``evaluate(...)[1]`` is read as accuracy, which assumes accuracy
            is the first compiled metric.
        X_labeled: inputs of the labeled set.
        y_labeled: labels of the labeled set.
        X_pool: inputs of the candidate pool.
        N: number of perturbations to draw.
        rho_star: target disagreement rate used to adapt ``sigma``.
        sigma: initial noise standard deviation.
        beta: step size of the multiplicative ``sigma`` update.

    Returns:
        Tuple ``(predictions, gammas, sigma)``: an array with one row of
        predicted pool classes per perturbation, the matching array of
        weights, and the adapted ``sigma``.
    """
    weighted_layers = [i for i, layer in enumerate(model.layers) if len(layer.get_weights()) > 0]
    target = model.layers[weighted_layers[-1]]
    # Private copy: serves both as the base for every perturbation and as
    # the snapshot restored at the end.
    base_weights = deepcopy(target.get_weights())

    err_hat = 1 - model.evaluate(X_labeled, y_labeled, verbose=0)[1]
    y_hat = np.argmax(model.predict(X_pool), axis=1)

    gammas = []
    predictions = []
    try:
        for _ in range(N):
            perturbed = [w + np.random.normal(0, sigma, w.shape) for w in base_weights]
            target.set_weights(perturbed)
            err_n = 1 - model.evaluate(X_labeled, y_labeled, verbose=0)[1]
            y_n = np.argmax(model.predict(X_pool), axis=1)
            # Only a perturbation that is worse than the unperturbed model
            # is down-weighted; a better one keeps weight 1.
            gammas.append(np.exp(-np.maximum(0, err_n - err_hat)))
            predictions.append(y_n)
            rho_n = np.mean(y_hat != y_n)
            # Shrink the noise when too many predictions flipped, grow it
            # when too few did.
            sigma *= np.exp(-beta * (rho_n - rho_star))
    finally:
        # Never leave the caller's model in a perturbed state.
        target.set_weights(base_weights)

    predictions = np.array(predictions)
    gammas = np.array(gammas)

    return predictions, gammas, sigma


def measure_uncertainty(predictions, gammas, n_classes=10):
    """Score each pool sample by how much the weighted predictions disagree.

    For every pool sample the weights ``gammas`` of the perturbations that
    predicted a given class are summed per class. The uncertainty is one
    minus the largest of these sums divided by the total weight, so it is 0
    when all perturbations agree and grows as the votes spread out.

    Args:
        predictions: 2-D array of predicted class ids with one row per
            perturbation and one column per pool sample.
        gammas: one weight per perturbation (per row of ``predictions``).
        n_classes: number of class ids to tally; ids outside
            ``0..n_classes-1`` receive no votes.

    Returns:
        1-D array with one uncertainty value per pool sample.
    """
    predictions = np.asarray(predictions)
    gammas = np.asarray(gammas)

    # votes[n, i, c] is True when perturbation n predicted class c for pool
    # sample i.
    votes = predictions[:, :, np.newaxis] == np.arange(n_classes)
    # Sum the perturbation weights per (sample, class) pair.
    weighted_class = np.sum(gammas.reshape(-1, 1, 1) * votes, axis=0)
    f_w = np.max(weighted_class, axis=1)
    uncertainty = 1.0 - f_w / np.sum(gammas)

    return uncertainty



# gpu setting
# Create a TF1-style session whose GPU memory fraction for this process is
# set to 0.3 and register it with the tf.keras backend.
# NOTE(review): the models below are built through the standalone `keras`
# import while the session is registered via `tf.compat.v1.keras` - confirm
# that both share the same backend session in the installed versions.
config = tf.compat.v1.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.3
tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config=config))

# hyperparameters
n_classes = 10  # size of the output layer; read as a global by make_model
query_size = 20  # pool samples moved into the labeled set per step
steps = 50  # number of acquisition steps
pool_size = 2000  # unlabeled samples drawn as candidate pool per step
rho_star = query_size / pool_size  # target disagreement rate for the sigma update
beta = 1  # step size of the sigma update in get_perturbed_output
N = 10  # weight perturbations drawn per acquisition step
sigma = 0.01  # initial noise std; overwritten with the adapted value each step

# load dataset
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
# Add a trailing channel axis of size 1 and divide by 255.0 (presumably to
# map 8-bit pixel values into [0, 1]).
X_train = X_train.reshape(X_train.shape[0], 28, 28, 1) / 255.0
X_test = X_test.reshape(X_test.shape[0], 28, 28, 1) / 255.0

# set indices
# idx_valid is returned here but not used by the visible part of the script.
idx_labeled, idx_unlabeled, idx_valid = set_indices(y_train)

def _fit_fresh_model(features, labels):
    """Train a newly built model on the given data and score it on the test set.

    Returns the trained model together with its test accuracy.
    """
    net = make_model(X_train.shape[1:])
    net.fit(features, labels, batch_size=32, epochs=50)
    accuracy = net.evaluate(X_test, y_test, verbose=0)[1]
    return net, accuracy


# Acquisition loop: retrain from scratch on the current labeled set, then
# move the most uncertain pool samples from the unlabeled to the labeled set.
for step in range(steps):
    # training part
    X_labeled, y_labeled = X_train[idx_labeled], y_train[idx_labeled]
    model, test_acc = _fit_fresh_model(X_labeled, y_labeled)
    print(f'Step {step:03d} - test acc: {test_acc:.5f}\n')

    # acquisition part: score a random candidate pool of unlabeled samples
    np.random.shuffle(idx_unlabeled)
    idx_pool = idx_unlabeled[:pool_size]
    X_pool = X_train[idx_pool]
    predictions, gammas, sigma = get_perturbed_output(
        model, X_labeled, y_labeled, X_pool, N, rho_star, sigma, beta
    )
    uncertainty = measure_uncertainty(predictions, gammas)

    # Highest uncertainty first; the top `query_size` candidates get labeled.
    idx_sorted = np.argsort(uncertainty)[::-1]
    idx_query = idx_pool[idx_sorted[:query_size]]
    idx_labeled = np.concatenate((idx_labeled, idx_query))
    idx_unlabeled = np.setdiff1d(idx_unlabeled, idx_labeled)

    # Drop the model and reset the backend before the next step builds a new one.
    del model
    tf.compat.v1.keras.backend.clear_session()

# final training on the fully grown labeled set
X_labeled, y_labeled = X_train[idx_labeled], y_train[idx_labeled]
model, test_acc = _fit_fresh_model(X_labeled, y_labeled)
print(f'Final Step - test acc: {test_acc:.5f}\n')

del model
tf.compat.v1.keras.backend.clear_session()
