from tensorflow.keras.layers import BatchNormalization
from keras.layers.convolutional import Conv2D
from keras.layers.convolutional import AveragePooling2D
from keras.layers.convolutional import MaxPooling2D
from keras.layers.core import Activation
from keras.layers.core import Dropout
from keras.layers.core import Dense
from keras.layers import Flatten, MaxPool2D
from keras.layers import Input
from keras.models import Model
from keras.layers import concatenate
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from keras.datasets import mnist, fashion_mnist, cifar10
from tensorflow.keras.optimizers import SGD
from tensorflow import keras
import tensorflow as tf
import tensorflow_probability as tfp


from keras.models import Sequential
from keras.layers import Conv2D
from keras.layers import MaxPooling2D
from keras.layers import Dense
from keras.layers import Flatten
from tensorflow.keras.optimizers import SGD

from keras.layers import Dropout
from keras.layers import BatchNormalization

import copy
import pickle
import numpy as np
import os
import sys
import time

import random 
from keras.utils import np_utils
#####################################################


# ---- experiment configuration ----
batch_size = 120
nb_classes = 10
nb_epoch = 100

# input image dimensions
img_rows, img_cols = 28, 28

# bookkeeping for results
score = 0
all_accuracy = 0
all_accuracy_spn = 0
acquisition_iterations = 98

dropout_iterations = 1
Queries = 10

Experiments_All_Accuracy = np.zeros(shape=(acquisition_iterations + 1))

# ---- data ----
(X_train_All, y_train_All), (X_test, y_test) = mnist.load_data()

# Append a trailing single-channel axis so images are (rows, cols, 1).
X_train_All = X_train_All.reshape(-1, img_rows, img_cols, 1)
X_test = X_test.reshape(-1, img_rows, img_cols, 1)

# Shuffle the training set with a random permutation of its indices.
random_split = np.asarray(random.sample(range(X_train_All.shape[0]), X_train_All.shape[0]))
X_train_All = X_train_All[random_split]
y_train_All = y_train_All[random_split]

# Partition the shuffled data: [5000, 10000) validation, [10000, 60000)
# unlabelled pool, [0, 5000) candidate initial-training split.
# The training split is taken last because it overwrites X_train_All/y_train_All.
X_valid, y_valid = X_train_All[5000:10000], y_train_All[5000:10000]
X_Pool, y_Pool = X_train_All[10000:60000], y_train_All[10000:60000]
X_train_All, y_train_All = X_train_All[:5000], y_train_All[:5000]

# Build a class-balanced initial training set: the first `initial_per_class`
# examples (in shuffled order) of every class from the 5000-image training split.
# Indices are concatenated class by class (0, 1, ..., nb_classes-1).
initial_per_class = 2
balanced_idx = np.concatenate(
  [np.where(y_train_All == label)[0][:initial_per_class] for label in range(nb_classes)]
)
X_train = X_train_All[balanced_idx, :, :, :]
y_train = y_train_All[balanced_idx]

print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')

# Convert to float32 and rescale by 1/255 (pixels presumably uint8 in [0, 255]).
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_valid = X_valid.astype('float32')
X_Pool = X_Pool.astype('float32')
X_train /= 255
X_valid /= 255
X_Pool /= 255
X_test /= 255

# One-hot encode labels, as required by the categorical cross-entropy loss.
Y_test = np_utils.to_categorical(y_test, nb_classes)
Y_valid = np_utils.to_categorical(y_valid, nb_classes)
# NOTE(review): Y_Pool is computed once here and is not refreshed when points
# are moved out of the pool later; y_Pool is the label array kept in sync.
Y_Pool = np_utils.to_categorical(y_Pool, nb_classes)

# History of acquired indices; starts with a single placeholder 0.
x_pool_All = np.zeros(shape=(1))

Y_train = np_utils.to_categorical(y_train, nb_classes)




# ######

# model

 

def define_model(input_shape=(28, 28, 1), num_classes=10, learning_rate=0.001):
  """Build and compile the CNN classifier that is retrained every acquisition round.

  Layers: two 32-filter 3x3 convs, 2x2 max-pool, two 64-filter 3x3 convs,
  2x2 max-pool, Dropout(0.5), Flatten, a 128-unit ReLU Dense layer named
  "feature", Dropout(0.5), and a softmax output layer.

  Args:
    input_shape: shape of a single input image, (rows, cols, channels).
    num_classes: width of the softmax output layer.
    learning_rate: SGD learning rate (momentum is fixed at 0.9).

  Returns:
    A compiled Keras ``Sequential`` model using categorical cross-entropy
    (so targets must be one-hot) and tracking the 'accuracy' metric.
  """
  # Every conv layer shares the same activation, initializer and padding.
  conv_kwargs = dict(activation='relu', kernel_initializer='he_uniform', padding='same')

  model = Sequential()
  model.add(Conv2D(32, (3, 3), input_shape=input_shape, **conv_kwargs))
  model.add(Conv2D(32, (3, 3), **conv_kwargs))
  model.add(MaxPooling2D((2, 2)))

  model.add(Conv2D(64, (3, 3), **conv_kwargs))
  model.add(Conv2D(64, (3, 3), **conv_kwargs))
  model.add(MaxPooling2D((2, 2)))
  model.add(Dropout(0.5))

  model.add(Flatten())
  model.add(Dense(128, name="feature", activation='relu', kernel_initializer='he_uniform'))
  model.add(Dropout(0.5))
  model.add(Dense(num_classes, activation='softmax'))

  opt = SGD(learning_rate=learning_rate, momentum=0.9)
  model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])
  return model

# Wall-clock start time of the experiment.
start = time.time()

# Train the baseline model on the small class-balanced initial set.
tf.keras.backend.clear_session()
deterministic_model = define_model()
hist = deterministic_model.fit(
  x=X_train,
  y=Y_train,
  batch_size=batch_size,
  epochs=nb_epoch,
  verbose=1,
  validation_data=(X_valid, Y_valid),
)

print('Evaluating Test Accuracy Without Acquisition')
score, acc = deterministic_model.evaluate(x=X_test, y=Y_test, verbose=1)

# Model outputs on the test set, one row per test image.
predi = deterministic_model.predict(x=X_test, batch_size=batch_size)

def accur(y_hat, obs):
  """Return the fraction of rows where ``y_hat`` and ``obs`` share the same arg-max column.

  Args:
    y_hat: 2-D array of per-class scores (e.g. softmax outputs), one row per sample.
    obs: 2-D array of per-class targets (e.g. one-hot labels) with at least as
      many rows as ``y_hat``; only its first ``len(y_hat)`` rows are compared.

  Returns:
    A Python float in [0, 1].

  Raises:
    ZeroDivisionError: if ``y_hat`` has no rows.
    IndexError: if ``obs`` has fewer rows than ``y_hat``.
  """
  n = y_hat.shape[0]
  if n == 0:
    raise ZeroDivisionError("accur() needs at least one sample")

  # Each arg-max is computed exactly once, so the check is O(n) overall.
  predicted = np.argmax(y_hat, axis=1)
  observed = np.argmax(obs, axis=1)
  if observed.shape[0] < n:
    raise IndexError("obs has fewer rows than y_hat")

  matches = int(np.count_nonzero(predicted == observed[:n]))
  return matches / float(n)

# Seed the accuracy history with the baseline model's test accuracy as reported
# by Keras evaluate(); each acquisition round appends its own accuracy to it.
all_accuracy = acc



for i in range(acquisition_iterations):
  print('POOLING ITERATION', i)

  # Draw a random candidate subset (without replacement) of the unlabelled pool.
  pool_subset = 5000
  pool_subset_dropout = np.asarray(random.sample(range(0, X_Pool.shape[0]), pool_subset))
  X_Pool_Dropout = X_Pool[pool_subset_dropout, :, :, :]
  y_Pool_Dropout = y_Pool[pool_subset_dropout]

  # Random acquisition baseline: give every candidate an independent U(0, 1)
  # score and keep the `Queries` highest, i.e. a uniformly random selection.
  # All scores are drawn in one vectorised call (no inner loop, so the outer
  # loop variable `i` is never shadowed).
  acquisition_scores = np.random.uniform(0, 1, X_Pool_Dropout.shape[0])
  x_pool_index = acquisition_scores.argsort()[-Queries:][::-1]

  # Record the acquired indices. NOTE: they index into this round's random
  # subset (X_Pool_Dropout), not into X_Pool, so they are not stable
  # identifiers across rounds.
  x_pool_All = np.append(x_pool_All, x_pool_index)

  Pooled_X = X_Pool_Dropout[x_pool_index, :, :, :]
  Pooled_Y = y_Pool_Dropout[x_pool_index]

  # Remove the acquired points from the pool: drop the whole candidate subset
  # from X_Pool, drop the acquired points from that subset, then append the
  # remaining candidates back onto the pool.
  delete_Pool_X = np.delete(X_Pool, pool_subset_dropout, axis=0)
  delete_Pool_Y = np.delete(y_Pool, pool_subset_dropout, axis=0)

  delete_Pool_X_Dropout = np.delete(X_Pool_Dropout, x_pool_index, axis=0)
  delete_Pool_Y_Dropout = np.delete(y_Pool_Dropout, x_pool_index, axis=0)

  X_Pool = np.concatenate((delete_Pool_X, delete_Pool_X_Dropout), axis=0)
  y_Pool = np.concatenate((delete_Pool_Y, delete_Pool_Y_Dropout), axis=0)

  print('Acquised Points added to training set')

  X_train = np.concatenate((X_train, Pooled_X), axis=0)
  y_train = np.concatenate((y_train, Pooled_Y), axis=0)
  print('Train Model with pooled points')

  # convert class vectors to binary class matrices
  Y_train = np_utils.to_categorical(y_train, nb_classes)

  # Retrain from scratch on old + newly acquired data. Clearing the Keras
  # session first releases global state left by the previous round's model,
  # which otherwise keeps growing as a new model is built every iteration.
  tf.keras.backend.clear_session()
  deterministic_model = define_model()

  hist = deterministic_model.fit(X_train, Y_train, batch_size=batch_size, epochs=nb_epoch, verbose=1, validation_data=(X_valid, Y_valid))

  print('Evaluate Model Test Accuracy with pooled points')

  score, acc = deterministic_model.evaluate(X_test, Y_test, verbose=1)
  print('Test score:', score)
  print('Test accuracy:', acc)
  all_accuracy = np.append(all_accuracy, acc)


file_name = "random_mnist_1.pkl"

# Persist the per-round test accuracies. The context manager closes the file
# even if pickling raises.
with open(file_name, "wb") as open_file:
  pickle.dump(all_accuracy, open_file)

endd = time.time()
print("time runing: ", endd-start)


















