from tensorflow.keras.layers import BatchNormalization
from keras.layers.convolutional import Conv2D
from keras.layers.convolutional import AveragePooling2D
from keras.layers.convolutional import MaxPooling2D
from keras.layers.core import Activation
from keras.layers.core import Dropout
from keras.layers.core import Dense
from keras.layers import Flatten, MaxPool2D
from keras.layers import Input
from keras.models import Model
from keras.layers import concatenate

from tensorflow.keras.optimizers import SGD

import tensorflow as tf
import tensorflow_probability as tfp

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.optimizers import RMSprop
from keras.datasets import mnist, fashion_mnist, cifar10
from tensorflow import keras
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from keras.utils import np_utils
from sklearn.preprocessing import LabelBinarizer

import copy
import pickle
import numpy as np
import os
import time
import random 
from pathlib import Path
import scipy.io
#################functions ###############
def accur(y_hat, obs):
  """Return the fraction of rows whose arg-max class matches.

  Args:
    y_hat: 2-D array of per-class scores, one row per sample.
    obs: 2-D array of the same shape holding the reference labels
      (e.g. one-hot rows).

  Returns:
    A Python float in [0, 1]: matching rows divided by the number of rows
    in ``y_hat``.

  Raises:
    ZeroDivisionError: if ``y_hat`` has no rows.
  """
  n_samples = y_hat.shape[0]
  # Compute each arg-max once for the whole array; recomputing it per row
  # makes the comparison quadratic in the number of samples.
  predicted = np.argmax(y_hat, axis=1)
  expected = np.argmax(obs, axis=1)
  return np.count_nonzero(predicted == expected) / float(n_samples)

def var_ratio(x):
  """Return one minus the largest entry of ``x``."""
  largest = max(x)
  return 1 - largest
def max_ent(x):
  """Return the base-2 Shannon entropy of the probability vector ``x``.

  Entries that are exactly zero contribute 0 to the sum (the usual
  ``0 * log 0 = 0`` convention). Evaluating ``0 * log2(0)`` directly gives
  ``nan``, which would make every score derived from the entropy ``nan``.

  Args:
    x: array-like of probabilities. For a 1-D input the result is a scalar;
      for higher-rank input the sum runs over the first axis.

  Returns:
    ``-sum(x * log2(x))`` with zero entries skipped.
  """
  probs = np.asarray(x)
  # Swap zeros for 1 inside the logarithm only: log2(1) == 0, so those terms
  # vanish. Negative entries are left alone and still surface as nan.
  log_arg = np.where(probs == 0, 1, probs)
  return -np.sum(probs * np.log2(log_arg), axis=0)

def weight_un(x):
  """Scale ``x`` by its own maximum so the largest entry becomes 1."""
  peak = x.max()
  return x / peak


def load_np(fil_name):
  """Open ``fil_name`` in binary mode and return what ``np.load`` reads from it."""
  handle = open(fil_name, "rb")
  try:
    return np.load(handle)
  finally:
    # Always release the file handle, whether or not loading succeeded.
    handle.close()


#########################################




# ---- Experiment settings (first copy; repeated in the "Setting" section below) ----
# batch_size = 100
# Mini-batch size passed to the long SPN fit() calls.
batch_size_spn  = 32
# NOTE(review): nb_classes is only referenced from commented-out code in this
# file; the class count 10 is hard-coded where it is actually needed.
nb_classes = 10


# Epoch count for every CNN fit; SPN refits in the acquisition loop use nb_epoch+300.
nb_epoch = 250

# input image dimensions
img_rows, img_cols = 32, 32

# Placeholders; each is overwritten once the first models have been evaluated.
score=0
all_accuracy = 0
all_accuracy_spn = 0 
# Number of acquisition rounds run by the active-learning loop.
acquisition_iterations = 9

# NOTE(review): dropout_iterations is not read by any code visible in this file.
dropout_iterations = 1
# Number of pool samples moved into the training set per acquisition round.
Queries = 1000


# Directory two levels above this script; the .mat files are read from here.
datapath =Path(__file__).resolve().parent.parent



# Load the raw train/test MATLAB files from datapath.

train_raw = scipy.io.loadmat(os.path.join(datapath, "train_32x32.mat"))
test_raw = scipy.io.loadmat(os.path.join(datapath,"test_32x32.mat"))


# Images are stored under key 'X' and labels under key 'y'.

train_images = np.array(train_raw['X'])
test_images = np.array(test_raw['X'])

train_labels = train_raw['y']
test_labels = test_raw['y']
# Move the last axis to the front.
# NOTE(review): presumably the stored layout is (rows, cols, channels, sample),
# making the result (sample, rows, cols, channels) -- confirm against the files.
train_images = np.moveaxis(train_images, -1, 0)
test_images = np.moveaxis(test_images, -1, 0)

# Floating-point copies so the in-place division by 255 later on is valid.
train_images = train_images.astype('float64')
test_images = test_images.astype('float64')


# Convert train and test labels into 'int64' type

train_labels = train_labels.astype('int64')
test_labels = test_labels.astype('int64')




# NOTE(review): this binarizer is not used before `lb` is rebound further down.
lb = LabelBinarizer()
# train_labels = lb.fit_transform(train_labels)
# test_labels = lb.fit_transform(test_labels)








########################## Setting##############################




# Mini-batch size for CNN fit/predict calls and for the tf.data pipelines.
batch_size = 100

# NOTE(review): everything from here to Queries repeats, with identical
# values, the assignments made in the first settings section of this file.
nb_classes = 10


# Epoch count for every CNN fit.
nb_epoch = 250

# input image dimensions
img_rows, img_cols = 32, 32


score=0
all_accuracy = 0 # placeholder; later holds the CNN test accuracy of every acquisition round
# Number of acquisition rounds run by the active-learning loop.
acquisition_iterations = 9


# NOTE(review): dropout_iterations is not read by any code visible in this file.
dropout_iterations = 1
# Number of pool samples moved into the training set per acquisition round.
Queries = 1000
###########################
def load_np(fil_name):
  """Read a NumPy file from ``fil_name`` and return the loaded object."""
  stream = open(fil_name, "rb")
  try:
    loaded = np.load(stream)
  finally:
    # Close the handle on both the success and the error path.
    stream.close()
  return loaded

###########################

# Buffer with one slot for the initial model plus one per acquisition round.
Experiments_All_Accuracy = np.zeros(shape=(acquisition_iterations+1))
# (X_train_All, y_train_All), (X_test, y_test) = cifar10.load_data()

(X_train_All, y_train_All), (X_test, y_test) = (train_images, train_labels), (test_images, test_labels)

# In-place scaling to [0, 1]. X_train_All and X_test are the very same array
# objects as train_images and test_images, so they see the scaled values too.
train_images /= 255.0
test_images /= 255.0

# X_train_All = X_train_All.reshape(X_train_All.shape[0],img_rows, img_cols, 3)
# X_test = X_test.reshape(X_test.shape[0],img_rows, img_cols,3)

# The shuffling permutation is read from disk (three directories above this
# script) rather than sampled, so the split is the same whenever the same
# file is used.
rand_path = Path(__file__).resolve().parent.parent.parent
random_split = load_np(os.path.join(rand_path ,"rand_init_4.npy"))
X_train_All = X_train_All[random_split, :, :, :]
y_train_All = y_train_All[random_split]

# Shuffled layout: [0, 5000) candidates for the initial training set,
# [5000, 10000) validation, [10000, end) unlabelled pool.
X_valid = X_train_All[5000:10000, :, :, :]
y_valid = y_train_All[5000:10000]
X_Pool = X_train_All[10000:, :, :, :]
y_Pool = y_train_All[10000:]
X_train_All = X_train_All[0:5000, :, :, :]
y_train_All = y_train_All[0:5000]

# Balanced initial training set: up to 100 samples of each label, taken in
# the order 10, 1, 2, ..., 9.
# NOTE(review): label 10 presumably encodes the digit 0 -- confirm against
# the dataset documentation.
per_class_X = []
per_class_y = []
for class_label in (10, 1, 2, 3, 4, 5, 6, 7, 8, 9):
  # Row indices of the first 100 samples carrying this label.
  class_rows = np.where(y_train_All == class_label)[0][0:100]
  X_class = X_train_All[class_rows, :, :, :]
  y_class = y_train_All[class_rows]
  print(X_class.shape)
  per_class_X.append(X_class)
  per_class_y.append(y_class)

X_train = np.concatenate(per_class_X, axis=0 )
y_train = np.concatenate(per_class_y, axis=0 )


print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')


# Fit the binarizer once, on the balanced training labels, and reuse that
# mapping for every split. Refitting per split makes the column layout depend
# on which labels happen to occur in that split.
lb = LabelBinarizer()
lb.fit(y_train)
Y_test = lb.transform(y_test)
Y_valid = lb.transform(y_valid)
Y_Pool = lb.transform(y_Pool)
Y_train = lb.transform(y_train)

# Running record of the subset-relative indices acquired in each round
# (starts with a single 0 placeholder entry).
x_pool_All = np.zeros(shape=(1))








########## model and initial training ################


# ######models

 

def define_model():
  """Build and compile the convolutional classifier for 32x32x3 inputs.

  The network has four stages. Each stage is a run of 3x3 'same'-padded ReLU
  convolutions, each followed by batch normalization, then 2x2 max pooling
  and dropout. The head is a 512-unit dense layer named "feature", a batch
  normalization layer named "norm", dropout, and a 10-way softmax.

  Returns:
    The Sequential model, compiled with SGD (learning rate 0.001,
    momentum 0.9), categorical cross-entropy and an accuracy metric.
  """
  conv_args = dict(activation='relu', kernel_initializer='he_uniform', padding='same')
  # One entry per stage: (filters, number of conv layers, dropout after pooling).
  stage_specs = ((64, 2, 0.2), (128, 2, 0.3), (256, 3, 0.4), (512, 3, 0.5))

  net = Sequential()
  is_input_layer = True
  for filters, depth, drop_rate in stage_specs:
    for _ in range(depth):
      if is_input_layer:
        # Only the very first layer declares the input shape.
        net.add(Conv2D(filters, (3, 3), input_shape=(32, 32, 3), **conv_args))
        is_input_layer = False
      else:
        net.add(Conv2D(filters, (3, 3), **conv_args))
      net.add(BatchNormalization())
    net.add(MaxPooling2D((2, 2)))
    net.add(Dropout(drop_rate))

  net.add(Flatten())
  # The "feature" layer is looked up by name elsewhere to extract representations.
  net.add(Dense(512, name="feature", activation='relu', kernel_initializer='he_uniform'))
  net.add(BatchNormalization(name="norm"))
  net.add(Dropout(0.5))
  net.add(Dense(10, activation='softmax'))

  # compile model
  sgd = SGD(learning_rate=0.001, momentum=0.9)
  net.compile(optimizer=sgd, loss='categorical_crossentropy', metrics=['accuracy'])
  return net

# Build the CNN and train it on the initial labelled set.
deterministic_model = define_model()

# Wall-clock start; the elapsed time is printed at the very end of the script.
start = time.time()

hist = deterministic_model.fit(X_train, Y_train, batch_size=batch_size, epochs=nb_epoch, verbose=1,validation_data=(X_valid, Y_valid))

 
# Test loss and accuracy as reported by Keras.
score, acc = deterministic_model.evaluate(X_test,Y_test ,  verbose=1)
print("The accuracy for first acquisition:", acc)
# Accuracy recomputed from the raw predictions with the local accur() helper;
# it is printed under the same label as the Keras figure above.
predi = deterministic_model.predict(X_test, batch_size=batch_size)
all_accuracy = accur(predi, Y_test)
print("The accuracy for first acquisition:", all_accuracy)

############## SPN part for probability calculation ###################
#*****extracting representations************
# Name of the 512-unit dense layer whose activations are fed to the SPN.
layer_name = "feature"

# Sub-model sharing the CNN's layers that stops at the "feature" layer.
intermediate_layer_model = keras.Model(inputs=deterministic_model.input,
                                       outputs=deterministic_model.get_layer(layer_name).output)


# This optimizer instance is reused when the sub-model is rebuilt inside the
# acquisition loop.
opt = SGD(learning_rate=0.001, momentum=0.9)

# NOTE(review): the sub-model is only ever used through predict() in this
# file, so this compile step looks unnecessary -- confirm before removing.
intermediate_layer_model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])
# intermediate_layer_model.summary()
intermediate_output = intermediate_layer_model.predict(X_train,batch_size=batch_size)


# View each 512-long feature vector as a 32x16 single-channel map, the input
# shape the SPN layers below are declared with.
output_rep_tr = intermediate_output.reshape(X_train.shape[0],32, 16, 1)
#*************spn training ****************
import libspn_keras as spnk

print(spnk.get_default_sum_op())

# Global libspn-keras defaults; set before any SPN layer is constructed.
spnk.set_default_sum_op(spnk.SumOpGradBackprop())
from tensorflow import keras

spnk.set_default_accumulator_initializer(
    keras.initializers.TruncatedNormal(stddev=0.5, mean=1.0)

)

# NOTE(review): tfds is not referenced anywhere in the visible code.
import tensorflow_datasets as tfds




# NOTE(review): spn_training() builds and adapts its own normalization layer,
# and nothing visible in this file reads `normalize`, `train_x_ds`,
# `test_x_ds` or `take_first` afterwards -- they appear to be leftovers.
normalize = spnk.layers.NormalizeStandardScore(
    axes=spnk.layers.NormalizeAxes.GLOBAL, input_shape=(32, 16, 1)) 

train_x_ds = tf.data.Dataset.from_tensor_slices((output_rep_tr,)).batch(batch_size)

test_x_ds = tf.data.Dataset.from_tensor_slices((X_test,)).batch(batch_size)
# Returns its first argument and ignores the second.
def take_first(a, b):
  return a

normalize.adapt(train_x_ds)

def biuld_model(normalize, dr):
  """Assemble the SPN classifier as a Keras Sequential model.

  Layer order: the supplied normalization layer, a 16-component NormalLeaf,
  dropout, five product/sum stages, a final product layer with 'final'
  padding, SpatialToRegions, a 10-way DenseSum, a RootSum that returns
  weighted child logits, and a Softmax.

  Args:
    normalize: normalization layer placed first; the caller adapts it.
    dr: dropout rate applied right after the leaf layer.

  Returns:
    The uncompiled ``keras.Sequential`` model.
  """
  def _product(strides, dilations, padding):
    # Every product layer is depthwise with a 2x2 kernel; only the strides,
    # dilations and padding mode differ between stages.
    return spnk.layers.Conv2DProduct(
        depthwise=True,
        strides=strides,
        dilations=dilations,
        kernel_size=[2, 2],
        padding=padding
    )

  leaf = spnk.layers.NormalLeaf(
      num_components=16,
      location_trainable=True,
      location_initializer=keras.initializers.TruncatedNormal(
          stddev=1.0, mean=0.0)
  )
  stack = [normalize, leaf, tf.keras.layers.Dropout(rate=dr)]

  # (strides, dilations, padding, num_sums) for each product/sum stage:
  # two stride-2 'valid' stages, then 'full'-padded stride-1 stages with
  # dilations 1, 2 and 4.
  stage_specs = (
      ([2, 2], [1, 1], 'valid', 16),
      ([2, 2], [1, 1], 'valid', 32),
      ([1, 1], [1, 1], 'full', 32),
      ([1, 1], [2, 2], 'full', 64),
      ([1, 1], [4, 4], 'full', 64),
  )
  for strides, dilations, padding, num_sums in stage_specs:
    stack.append(_product(strides, dilations, padding))
    stack.append(spnk.layers.Local2DSum(num_sums=num_sums))

  # Dilation-8 products with 'final' padding, then the class roots.
  stack.append(_product([1, 1], [8, 8], 'final'))
  stack.append(spnk.layers.SpatialToRegions())
  stack.append(spnk.layers.DenseSum(num_sums=10))
  stack.append(spnk.layers.RootSum(return_weighted_child_logits=True))
  stack.append(tf.keras.layers.Softmax())

  spn = keras.Sequential(stack)
  return spn

def spn_training(X_train, dr):
  """Build and compile a fresh SPN classifier for the given feature maps.

  Despite the name, no fitting happens here: a normalization layer is adapted
  to ``X_train``, the network is assembled with ``biuld_model`` and compiled.
  The caller runs ``fit``.

  Args:
    X_train: feature maps used to adapt the normalization layer; the layer is
      declared with input shape (32, 16, 1).
    dr: dropout rate forwarded to ``biuld_model``.

  Returns:
    The compiled SPN model (Adam with learning rate 0.007, summed sparse
    categorical cross-entropy, sparse categorical accuracy).
  """
  normalize = spnk.layers.NormalizeStandardScore(
    axes=spnk.layers.NormalizeAxes.GLOBAL, input_shape=(32, 16, 1))
  train_x_ds = tf.data.Dataset.from_tensor_slices((X_train,)).batch(batch_size)
  normalize.adapt(train_x_ds)

  spn_net = biuld_model(normalize, dr)
  # biuld_model ends in a Softmax layer, so the network emits probabilities.
  # The loss must therefore use from_logits=False; from_logits=True would
  # apply a second softmax to values already in [0, 1].
  spn_net.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(
      reduction=tf.keras.losses.Reduction.SUM, from_logits=False),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    optimizer=keras.optimizers.Adam(learning_rate=0.007),
  )
  return spn_net

def cat_to_val(z):
  """Return, for each row of ``z``, the column index of its largest entry."""
  class_ids = np.argmax(z, axis=1)
  return class_ids

# Build and compile the SPN on the training-set features (dropout rate 0.3).
spn_network = spn_training(output_rep_tr, .3)

# spn_network.summary()

# Abort training as soon as the loss becomes NaN.
callbacks_sp = [tf.keras.callbacks.TerminateOnNaN()]
# The SPN is trained on integer class indices derived from the one-hot labels.
hist = spn_network.fit(output_rep_tr, cat_to_val(Y_train),batch_size = batch_size_spn, epochs=600,callbacks=callbacks_sp )# , validation_data=(X_valid, cat_to_val(Y_valid))
# Test-set features from the CNN, reshaped to the SPN's 32x16x1 input.
intermediate_output = intermediate_layer_model.predict(X_test,batch_size=batch_size) #get_rep(deterministic_model,"feature", X_test ) #
output_rep = intermediate_output.reshape(X_test.shape[0],32, 16, 1)
score, _ = spn_network.evaluate(output_rep,cat_to_val(Y_test) ,  verbose=1)
pred = spn_network.predict(output_rep ,  verbose=1)
acc1 = accur(pred, Y_test)
# First entries of the two accuracy histories: CNN (Keras-reported figure,
# replacing the accur() value stored earlier) and SPN.
all_accuracy = acc
all_accuracy_spn = acc1

print("the accuracy for spn", acc1)



######################Active learning part ####################
# One pass per acquisition round: score a random pool subset with the SPN,
# move the top-scoring samples into the training set, retrain both models
# from scratch and record their test accuracies.
for i in range(acquisition_iterations):
  print('POOLING ITERATION', i)
  # Candidate subset scored this round; capped at the pool size because
  # random.sample raises when asked for more items than the pool holds.
  pool_subset = min(25000, X_Pool.shape[0])
  pool_subset_dropout = np.asarray(random.sample(range(0,X_Pool.shape[0]), pool_subset))

  X_Pool_Dropout = X_Pool[pool_subset_dropout, :, :, :]
  y_Pool_Dropout = y_Pool[pool_subset_dropout]

  # CNN features of the candidates, reshaped to the SPN's 32x16x1 input.
  intermediate_output = intermediate_layer_model.predict(X_Pool_Dropout, batch_size= batch_size) #get_rep(deterministic_model,"feature", X_Pool_Dropout ) #
  output_rep = intermediate_output.reshape(X_Pool_Dropout.shape[0],32, 16, 1)
  # Drop this reference; the feature extractor is rebuilt after retraining.
  del intermediate_layer_model

  # Train the SPN for 30 more single epochs and record the predicted class
  # of every candidate after each one; column k holds the votes of epoch k.
  vote_columns = []
  for _ in range(30):
    spn_network.fit(output_rep_tr, cat_to_val(Y_train), epochs=1,callbacks=callbacks_sp )
    pred = spn_network.predict(output_rep ,  verbose=1)
    vote_columns.append(np.argmax(pred, axis=1))
  y_pred = np.stack(vote_columns, axis=1)

  # Disagreement weight: number of distinct classes a candidate was assigned
  # across the 30 snapshots, scaled so the largest count is 1.
  setlist = [len(set(votes)) for votes in y_pred]
  weight_ar  = weight_un(np.array(setlist)).flatten()

  dropout_classes = spn_network.predict(output_rep,batch_size=batch_size, verbose=1)

  All_Dropout_Classes = dropout_classes
  # Entropy of the final predictive distribution of each candidate.
  lst = [max_ent(class_probs) for class_probs in All_Dropout_Classes]

  # Acquisition score: scaled entropy times the disagreement weight.
  a_1d = weight_un(np.array(lst)).flatten()
  a_1d = np.multiply(a_1d,weight_ar)

  # Indices (relative to the candidate subset) of the Queries highest
  # scores, best first.
  x_pool_index = a_1d.argsort()[-Queries:][::-1]

  #store all the pooled images indexes
  x_pool_All = np.append(x_pool_All, x_pool_index)

  Pooled_X = X_Pool_Dropout[x_pool_index, :, :, :]
  Pooled_Y = y_Pool_Dropout[x_pool_index]

  # Rebuild the pool without the acquired samples: take the candidate subset
  # out of the pool, remove the acquired rows from that subset, then put the
  # remaining candidates back.
  delete_Pool_X = np.delete(X_Pool, (pool_subset_dropout), axis=0)
  delete_Pool_Y = np.delete(y_Pool, (pool_subset_dropout), axis=0)

  delete_Pool_X_Dropout = np.delete(X_Pool_Dropout, (x_pool_index), axis=0)
  delete_Pool_Y_Dropout = np.delete(y_Pool_Dropout, (x_pool_index), axis=0)

  X_Pool = np.concatenate((delete_Pool_X, delete_Pool_X_Dropout), axis=0)
  y_Pool = np.concatenate((delete_Pool_Y, delete_Pool_Y_Dropout), axis=0)


  print('Acquised Points added to training set')

  X_train = np.concatenate((X_train, Pooled_X), axis=0)
  y_train = np.concatenate((y_train, Pooled_Y), axis=0)
  print('Train Model with pooled points')

  # convert class vectors to binary class matrices
  Y_train = lb.fit_transform(y_train)

  # Retrain the CNN from scratch on the enlarged training set.
  deterministic_model = define_model()

  hist = deterministic_model.fit(X_train, Y_train, batch_size=batch_size, epochs=nb_epoch, verbose=1, validation_data=(X_valid, Y_valid))

  # Rebuild the feature extractor on the new CNN and re-encode the training set.
  layer_name = "feature"
  intermediate_layer_model = keras.Model(inputs=deterministic_model.input,
                                        outputs=deterministic_model.get_layer(layer_name).output)
  intermediate_layer_model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])
  intermediate_output = intermediate_layer_model.predict(X_train, batch_size=batch_size)
  output_rep_tr = intermediate_output.reshape(X_train.shape[0],32, 16, 1)

  # Fresh SPN on the new features.
  spn_network = spn_training(output_rep_tr, .3)

  hist1 = spn_network.fit(output_rep_tr, cat_to_val(Y_train), batch_size=batch_size_spn, epochs=nb_epoch+300, verbose=1,callbacks=callbacks_sp) #, validation_data=(X_valid, cat_to_val(Y_valid))

  print('Evaluate Model Test Accuracy with pooled points')
  intermediate_output = intermediate_layer_model.predict(X_test, batch_size=batch_size)
  output_rep = intermediate_output.reshape(X_test.shape[0],32, 16, 1)

  score, _ = spn_network.evaluate(output_rep, cat_to_val(Y_test), verbose=0)
  pred = spn_network.predict(output_rep ,  verbose=1)
  # proper = np.argmax(pred, axis = 1)

  acc1 = accur(pred, Y_test)

  print('Test score:', score)
  print('Test accuracy_spn:', acc1)

  all_accuracy_spn = np.append(all_accuracy_spn, acc1)

  print('Use this trained model with pooled points for Dropout again')
  score, acc = deterministic_model.evaluate(X_test,Y_test ,  verbose=1)
  print('Test score:', score)
  print('Test accuracy:', acc)
  all_accuracy = np.append(all_accuracy, acc)


# Persist the SPN accuracy history. The context manager closes the file even
# if pickling raises.
file_name = "spn_outputmax_4.pkl"
with open(file_name, "wb") as open_file:
  pickle.dump(all_accuracy_spn, open_file)


# Persist the CNN accuracy history.
file_name = "cnn-outputmax_4.pkl"
with open(file_name, "wb") as open_file:
  pickle.dump(all_accuracy, open_file)

# Elapsed wall-clock time since `start` was taken before the first CNN fit.
end = time.time()

print(f"Runtime of the program is {end - start}")



















