import keras
from keras.datasets import mnist, fashion_mnist
from keras.layers import Dense, Conv2D, Flatten, Dropout
from keras.layers.pooling import MaxPooling2D, AveragePooling2D
from keras.models import Sequential
from random import randint
import numpy as np
import keras.backend as K
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
import scipy.stats

# Load and preprocess dataset
num_classes = 10
# Set threshold to 0.02 for mnist or 0.05 for fashion_mnist
# (it is subtracted from a training accuracy to form the cutoff passed to find_sig).
threshold = 0.05
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

# One-hot encode the integer class labels.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

# Convert pixels to float32, divide by 255, and append a single channel axis.
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
x_train = x_train.reshape(x_train.shape + (1,))
x_test = x_test.reshape(x_test.shape + (1,))

# n: number of training examples (the supersample); m: half of n, the size of
# every training subset drawn below.
n = int(x_train.shape[0])
m = n // 2
input_shape = x_train.shape[1:]

# The modified LeNet architecture from (Zhou+ 19) and (Dziugaite+ 20)
def create_CNN():
    """Build, compile and return the LeNet-style CNN used in every experiment.

    Layers: Conv2D(20, 5x5, linear) -> MaxPool(2x2) -> Conv2D(50, 5x5, linear)
    -> MaxPool(2x2) -> Flatten -> Dense(500, relu) -> Dense(num_classes, softmax).
    The model is compiled with categorical cross-entropy, ``SGD()`` with its
    default arguments and the "acc" metric, and its summary is printed.

    Reads the module-level ``input_shape`` and ``num_classes``.

    Returns:
        The compiled Keras ``Sequential`` model.
    """
    model = Sequential()
    # Take the input shape from the preprocessed data instead of hard-coding
    # (28, 28, 1), so the network follows whichever dataset is loaded.
    model.add(Conv2D(20, kernel_size=(5, 5), strides=(1, 1), activation='linear',
                     input_shape=input_shape, padding="valid"))
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='valid'))
    model.add(Conv2D(50, kernel_size=(5, 5), strides=(1, 1), activation='linear', padding='valid'))
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='valid'))
    model.add(Flatten())
    model.add(Dense(500, activation='relu'))
    model.add(Dense(num_classes, activation='softmax'))
    model.compile(loss=keras.losses.categorical_crossentropy,
                  optimizer=keras.optimizers.SGD(), metrics=["acc"])
    model.summary()
    return model

# Define learning rate schedule
def decayed_learning_rate(epoch):
    """Return the learning rate for ``epoch`` under a stepwise inverse-time decay.

    The rate is 0.01 / (1 + 2 * floor(epoch / 20)): epochs 0-19 use 0.01,
    epochs 20-39 use 0.01 / 3, epochs 40-59 use 0.01 / 5, and so on.
    """
    base_lr, step_size, rate = 0.01, 20, 2
    completed_steps = np.floor(epoch / step_size)
    denominator = 1 + rate * completed_steps
    return base_lr / denominator
# Learning-rate callback passed to every model.fit call made by evaluate().
lrate = LearningRateScheduler(schedule=decayed_learning_rate)

def evaluate(model, x_train, y_train, x_test, y_test, filename="weights_pw.{epoch:02d}.hdf5",
             batch_size=128, epochs=5):
    """Train ``model`` on (x_train, y_train), saving its weights after every epoch.

    Despite the name, this function trains the model. It uses (x_test, y_test)
    as Keras validation data, the module-level ``lrate`` learning-rate
    callback, and a ``ModelCheckpoint`` that writes weights only, once per
    epoch, to the path pattern ``filename`` (formatted with ``epoch``).

    Args:
        model: compiled Keras model to train in place.
        x_train, y_train: training inputs and one-hot targets.
        x_test, y_test: validation inputs and one-hot targets.
        filename: checkpoint path pattern containing an ``{epoch}`` field.
        batch_size: mini-batch size for ``model.fit``.
        epochs: number of training epochs.

    Returns:
        The Keras ``History`` object from ``model.fit``; callers read
        ``history.history['acc']``.
    """
    # The previous version also ran model.evaluate on the test set and read
    # history['acc'] into a local, but discarded both results; that extra
    # full pass over the test set has been dropped.
    checkpoint = ModelCheckpoint(filename, monitor='loss', verbose=1,
                                 save_best_only=False, mode='auto', period=1,
                                 save_weights_only=True)
    return model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, verbose=True,
                     validation_data=(x_test, y_test),
                     callbacks=[checkpoint, lrate])
    
def _flatten_weights(weights):
    """Concatenate a list of weight arrays into one 1-D array, in list order."""
    return np.concatenate([w.flatten() for w in weights])


def _unflatten_weights(flat, weights):
    """Split the 1-D array ``flat`` into arrays shaped like the entries of ``weights``.

    Inverse of ``_flatten_weights`` for a list with the same layer shapes.
    """
    split_points = np.cumsum([w.size for w in weights])[:-1]
    return [np.reshape(chunk, w.shape)
            for chunk, w in zip(np.split(flat, split_points), weights)]


def _mean_noisy_accuracy(model, weights_flat, weights, sigma, x, y, num_samples):
    """Return the mean accuracy of ``model`` over ``num_samples`` noisy weight draws.

    Each draw adds independent N(0, sigma**2) noise to every entry of
    ``weights_flat``. Accuracy is the second value returned by
    ``model.evaluate(x, y)``. The last draw is left loaded into ``model``.
    """
    total = 0.0
    for _ in range(num_samples):
        noisy_flat = np.random.normal(weights_flat, sigma*np.ones_like(weights_flat))
        model.set_weights(_unflatten_weights(noisy_flat, weights))
        _, acc = model.evaluate(x, y, verbose=0)
        total += acc
    return total / num_samples


# Find sigma with 1 significant digit keeping the noisy train accuracy >= cutoff
def find_sig(model, cutoff, x_train_1, y_train_1, sig_base):
    """Find the largest one-significant-digit noise level whose accuracy meets ``cutoff``.

    The search runs in two stages, and accuracy is always averaged over 5
    Gaussian weight perturbations:

    1. Starting at ``sig_base``, divide by 10 until the mean accuracy is at
       least ``cutoff``. If ``sig_base`` drops below 1e-4, print a message
       and stop searching.
    2. Try ``k * sig_base`` for k = 2, ..., 9, and keep the last multiple
       that still meets ``cutoff``.

    Args:
        model: object with ``get_weights``/``set_weights``/``evaluate`` (a
            compiled Keras model in this script). Its original weights are
            restored before returning.
        cutoff: minimum acceptable mean accuracy (an absolute value; callers
            pass a training accuracy minus ``threshold``).
        x_train_1, y_train_1: data the accuracy is measured on.
        sig_base: starting noise standard deviation (callers here pass 0.1).

    Returns:
        The selected noise standard deviation, ``k * sig_base``.
    """
    num_samples = 5
    weights = model.get_weights()
    weights_flat = _flatten_weights(weights)

    def mean_acc(sigma):
        return _mean_noisy_accuracy(model, weights_flat, weights, sigma,
                                    x_train_1, y_train_1, num_samples)

    # Coarse search over powers of ten.
    while mean_acc(sig_base) < cutoff:
        sig_base = sig_base/10
        if sig_base < 0.0001:
            print("Failed to find good sigma. Change cutoff")
            break

    # Fine search over the leading digit. The range includes 9 so that every
    # one-significant-digit value between sig_base and 10*sig_base is tried.
    sig = sig_base
    for i in range(2, 10):
        if mean_acc(sig_base*i) < cutoff:
            break
        sig = sig_base*i

    # Do not leave the model holding a random perturbation of its weights.
    model.set_weights(weights)
    return sig
   
B = 512  # mini-batch size for every training run
E = 51   # number of epochs for every training run

# Set initial weights in a reference model; every network below starts from them.
model_ref = create_CNN()
model_1 = create_CNN()
model_1.set_weights(model_ref.get_weights())

# Train Q_W networks on subsets of the supersample: run k trains on the m
# consecutive examples starting at index int(k * qw_stride).
num_qw_iterations = 10
qw_stride = np.floor((n - m - 1) / (num_qw_iterations + 1))
for qw_iteration in range(num_qw_iterations):
    new_start = int(qw_iteration * qw_stride)
    qw_indices = range(new_start, new_start + m)
    x_train_2 = x_train[qw_indices]
    y_train_2 = y_train[qw_indices]
    model_1.set_weights(model_ref.get_weights())
    checkpoint_pattern = "weights_qw_%02d.{epoch:02d}.hdf5" % qw_iteration
    history_2 = evaluate(model_1, x_train_2, y_train_2, x_test, y_test,
                         batch_size=B, epochs=E, filename=checkpoint_pattern)
	
# Average the Q_W runs to form the prior mean (flattened weight vector).
# NOTE(review): checkpoint number E-1 of each Q_W run is loaded here, while the
# per-epoch loop further down loads checkpoint number t+1 for 0-based t. If the
# final epoch's weights were intended, this may need to be E -- confirm against
# how Keras numbers {epoch} in checkpoint file names.
weights = model_1.get_weights()
num_layers = len(weights)
qw_flat_runs = []
for qw_iteration in range(num_qw_iterations):
    model_1.load_weights('weights_qw_%02d.%02d.hdf5' % (qw_iteration, (E-1)))
    weights_prior = model_1.get_weights()
    qw_flat_runs.append(np.concatenate([layer.flatten() for layer in weights_prior[:num_layers]]))
weights_prior_flat = np.zeros_like(qw_flat_runs[0])
for run_flat in qw_flat_runs:
    weights_prior_flat = weights_prior_flat + run_flat
weights_prior_flat = weights_prior_flat/num_qw_iterations

# Find \tilde sigma_2, the candidate for sigma_2: load the averaged prior mean
# back into model_1, layer by layer.
layer_shapes = [layer.shape for layer in weights]
split_points = np.cumsum([layer.size for layer in weights])[:-1]
reconstr = [np.reshape(chunk, shape)
            for chunk, shape in zip(np.split(weights_prior_flat, split_points), layer_shapes)]
model_1.set_weights(reconstr)

# The prior-mean network's accuracy on the whole training set sets the cutoff.
unused, trainacc = model_1.evaluate(x_train, y_train, verbose=0)
cutoff = trainacc - threshold
# NOTE(review): the noisy accuracies inside find_sig are measured on
# x_train_2/y_train_2, i.e. only the window used by the last Q_W run, while the
# cutoff comes from all of x_train -- confirm this is intended.
sig_prior = find_sig(model_1, cutoff, x_train_2, y_train_2, 0.1)


# Run posterior experiments over a number of instances of Z(S)
nr_S_samples = 10
# Choose delta such that the final results hold with confidence .95
delta = 0.05/54

# The candidate sigma_2 (sig_prior, found above) stays fixed for every S sample.
# A sample may replace it with its own sigma_1 only if sigma_1 lies in
# [10**(b-1), 9*10**(b+1)], where b is the decade of the candidate. The
# candidate is never overwritten, so later samples are not range-checked
# against (or made to fall back to) an earlier sample's sigma_1.
sig_prior_b = np.floor(np.log10(sig_prior))
sig_prior_min = 10**(sig_prior_b - 1)
sig_prior_max = 9 * 10**(sig_prior_b + 1)

for S_sample in range(nr_S_samples):
    model_1 = create_CNN()
    model_1.set_weights(model_ref.get_weights())
    # Make the half-split for actual training set: m distinct indices out of n.
    s_n = np.random.choice(n, m, replace=False)
    x_train_1 = x_train[s_n, :]
    y_train_1 = y_train[s_n]
    history_1 = evaluate(model_1, x_train_1, y_train_1, x_test, y_test, batch_size=B, epochs=E,
                         filename="weights_pw.{epoch:02d}.hdf5")

    # Find sigma_1 and select this sample's sigma_2 from the candidates.
    cutoff = history_1.history['acc'][E-1] - threshold
    sig = find_sig(model_1, cutoff, x_train_1, y_train_1, 0.1)
    if sig_prior_min <= sig <= sig_prior_max:
        sig_prior_S = sig
    else:
        sig_prior_S = sig_prior

    training_list = []
    test_list = []
    pacb_bound_list = []
    sd_bound_list = []
    guarantee_list = []
    fastguarantee_list = []

    # Compute empirical test/train losses and bounds over the epochs
    EList = range(0, E, 3)
    for t in EList:
        # Posterior mean: this sample's weights from checkpoint number t+1.
        model_1.load_weights('weights_pw.%02d.hdf5' % (t+1))
        weights = model_1.get_weights()
        layer_shapes = [w.shape for w in weights]
        split_points = np.cumsum([w.size for w in weights])[:-1]
        weights_flat = np.concatenate([w.flatten() for w in weights])

        # Prior mean: average of the Q_W networks at the same checkpoint number.
        weights_prior_flat = np.zeros_like(weights_flat)
        for qw_iteration in range(num_qw_iterations):
            model_1.load_weights('weights_qw_%02d.%02d.hdf5' % (qw_iteration, (t+1)))
            weights_prior = model_1.get_weights()
            weights_prior_flat = weights_prior_flat + np.concatenate(
                [w.flatten() for w in weights_prior[:len(weights)]])
        weights_prior_flat = weights_prior_flat/num_qw_iterations

        # Compute PAC-Bayesian slow-rate bound. kldiv is the closed-form KL
        # between N(weights_flat, sig^2 I) and N(weights_prior_flat, sig_prior_S^2 I).
        diffnorm = np.linalg.norm(weights_flat - weights_prior_flat)
        dim = weights_flat.shape[0]
        kldiv = 0.5*((diffnorm**2)/(sig_prior_S**2)
                     + dim*((sig/sig_prior_S)**2 - 1 + 2*np.log(sig_prior_S/sig)))
        gen_bound = np.sqrt((kldiv + np.log(1/delta))*2/m)
        pacb_bound_list.append(gen_bound)

        # Compute single-draw slow-rate bound, with log(dP/dQ) evaluated at the
        # posterior mean. logpdf is used instead of log(pdf): pdf underflows to
        # 0 (log -> -inf) when weights lie many sig_prior_S away from the prior mean.
        log_dP = scipy.stats.norm.logpdf(weights_flat, loc=weights_flat, scale=sig)
        log_dQ = scipy.stats.norm.logpdf(weights_flat, loc=weights_prior_flat, scale=sig_prior_S)
        dPdQ = np.sum(log_dP - log_dQ)
        gen_sd = np.sqrt(2/m * (dPdQ + np.log(1/delta)))
        sd_bound_list.append(gen_sd)

        # Estimate the test and train accuracies of the randomized nets
        avg_test = 0.0
        avg_train = 0.0
        avg_gen = 0.0  # mean per-draw |test - train| gap (not written to the output file)
        num_samples = 5
        for _ in range(num_samples):
            # Draw random weights from the posterior
            weights_flat_new = np.random.normal(weights_flat, sig*np.ones_like(weights_flat))
            reconstr = [np.reshape(chunk, shape)
                        for chunk, shape in zip(np.split(weights_flat_new, split_points), layer_shapes)]

            # Evaluate the training and test accuracies for these weights
            model_1.set_weights(reconstr)
            trainloss, trainacc = model_1.evaluate(x_train_1, y_train_1, verbose=0)
            testloss, testacc = model_1.evaluate(x_test, y_test, verbose=0)
            avg_train = avg_train + trainacc
            avg_test = avg_test + testacc
            avg_gen = avg_gen + abs(testacc - trainacc)

        avg_train = avg_train/num_samples
        avg_test = avg_test/num_samples
        avg_gen = avg_gen/num_samples
        guarantee = avg_train - gen_bound

        # Convert the slow-rate bound to a fast-rate test loss bound
        fastguarantee = 1.795*(1-avg_train) + (2.98/2) * gen_bound**2

        training_list.append(1-avg_train)
        test_list.append(1-avg_test)
        guarantee_list.append(1-guarantee)
        fastguarantee_list.append(fastguarantee)

    # Save all of the data.
    # NOTE(review): the file name says "mnist" although fashion_mnist is loaded
    # at the top of the script -- confirm the intended output name.
    with open('data_lenet_mnist_{0}.txt'.format(S_sample), 'w') as f:
        f.write("Epoch Train Test SlowGuarantee SD PACB FastGuarantee\n")
        for row in zip(EList, training_list, test_list, guarantee_list,
                       sd_bound_list, pacb_bound_list, fastguarantee_list):
            f.write("{0} {1} {2} {3} {4} {5} {6}\n".format(*row))
