from tensorflow.keras.layers import BatchNormalization
from keras.layers.convolutional import Conv2D
from keras.layers.convolutional import AveragePooling2D
from keras.layers.convolutional import MaxPooling2D
from keras.layers.core import Activation
from keras.layers.core import Dropout
from keras.layers.core import Dense
from keras.layers import Flatten, MaxPool2D
from keras.layers import Input
from keras.models import Model
from keras.layers import concatenate

from tensorflow.keras.optimizers import SGD

import tensorflow as tf
import tensorflow_probability as tfp

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.optimizers import RMSprop
from keras.datasets import mnist, fashion_mnist, cifar10
from tensorflow import keras
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from keras.utils import np_utils

import copy
import pickle
import numpy as np
import os
import time
import random 
from pathlib import Path
from calibration import compute_calibration
#################functions ###############

def accur(y_hat, obs):
  """Return top-1 accuracy of predictions against (typically one-hot) targets.

  Args:
    y_hat: 2-D array of per-class scores/probabilities, one row per example.
    obs: 2-D array of targets with at least as many rows as ``y_hat``; only
      its first ``y_hat.shape[0]`` rows are compared.

  Returns:
    Fraction (float) of rows whose argmax class agrees.

  Raises:
    ZeroDivisionError: if ``y_hat`` has no rows.
    IndexError: if ``obs`` has fewer rows than ``y_hat``.
  """
  n = y_hat.shape[0]
  # Compute each argmax once for the whole array; the previous per-row loop
  # recomputed both argmaxes on every iteration (O(n^2)).
  predicted = np.argmax(y_hat, axis=1)
  expected = np.argmax(obs, axis=1)
  if expected.shape[0] < n:
    raise IndexError(
        "obs has %d rows but y_hat has %d" % (expected.shape[0], n))
  hits = int(np.count_nonzero(predicted == expected[:n]))
  # Plain Python division keeps the original ZeroDivisionError for n == 0.
  return hits / float(n)

def var_ratio(x):
  """Return the variation ratio ``1 - max(x)`` of a probability vector *x*."""
  peak = max(x)
  return 1 - peak
def max_ent(x):
  """Return the Shannon entropy ``-sum(p * log2(p))`` (in bits) of *x*.

  Zero entries contribute 0, the limit of ``p * log2(p)`` as p -> 0. The
  previous version evaluated ``0 * log2(0) = 0 * -inf`` and returned NaN
  whenever any probability was exactly zero.

  Summation runs over axis 0, the same axis the builtin ``sum`` iterated over
  before. A 1-D input therefore gives a scalar.
  """
  p = np.asarray(x)
  positive = p > 0
  # Replace non-positive entries by 1.0 before the log so log2 never sees 0;
  # their terms are then forced to 0 by the outer np.where.
  safe = np.where(positive, p, 1.0)
  terms = np.where(positive, p * np.log2(safe), 0.0)
  return -np.sum(terms, axis=0)

def weight_un(x):
  """Rescale array *x* so that its largest element maps to 1."""
  largest = x.max()
  return x / largest


def load_np(fil_name):
  """Load and return the NumPy data stored in the file *fil_name*.

  The file is opened in binary mode and closed again before returning.
  """
  with open(fil_name, "rb") as handle:
    loaded = np.load(handle)
  return loaded


#########################################




# ---- experiment configuration ----
batch_size = 100  # batch size for CNN fit/predict and for SPN normaliser adaptation
batch_size_spn  = 32  # batch size for fitting the SPN
nb_classes = 10  # number of classes used for one-hot encoding


nb_epoch = 35  # CNN training epochs

# input image dimensions
img_rows, img_cols = 32, 32

# Initial values; score, all_accuracy and all_accuracy_spn are reassigned later in this file.
score=0
all_accuracy = 0
all_accuracy_spn = 0 
# NOTE(review): acquisition_iterations, dropout_iterations and Queries are not
# read anywhere else in this file (apparently leftovers from an
# active-learning loop) -- confirm before relying on them.
acquisition_iterations = 9

# NOTE(review): the comment below says "large number", but the value is 1.
#use a large number of dropout iterations
dropout_iterations = 1
Queries = 1000


# NOTE(review): allocated but never written or read elsewhere in this file.
Experiments_All_Accuracy = np.zeros(shape=(acquisition_iterations+1))

(X_train_All, y_train_All), (X_test, y_test) = cifar10.load_data()

# Make the (N, rows, cols, 3) image layout explicit.
X_train_All = X_train_All.reshape(X_train_All.shape[0],img_rows, img_cols, 3)
X_test = X_test.reshape(X_test.shape[0],img_rows, img_cols,3)

# A fixed index permutation is loaded from rand_init_0.npy next to this script
# for reproducible splits; the commented-out line shows one way to generate it.
# random_split = np.asarray(random.sample(range(0,X_train_All.shape[0]), X_train_All.shape[0]))
rand_path = Path(__file__).resolve().parent
random_split = load_np(os.path.join(rand_path ,"rand_init_0.npy"))

# Shuffle the training set with the stored permutation, then split it:
# validation = rows 45000:50000, training = rows 0:45000.
X_train_All = X_train_All[random_split, :, :, :]
y_train_All = y_train_All[random_split]
X_valid = X_train_All[45000:50000, :, :, :]
y_valid = y_train_All[45000:50000]
# NOTE(review): the 60000 bound looks carried over from MNIST. The CIFAR-10
# training set has 50000 images, so this slice silently stops at 50000 and
# the pool overlaps both the training rows (10000:45000) and all validation
# rows. X_Pool / y_Pool are not used for training in this file -- confirm intent.
X_Pool = X_train_All[10000:60000, :, :, :]
y_Pool = y_train_All[10000:60000]
X_train_All = X_train_All[0:45000, :, :, :]
y_train_All = y_train_All[0:45000]

# Per-class subsets: the first 100 training examples of each class (row
# indices taken from np.where on the label array).
# NOTE(review): these subsets are not used below -- X_train / y_train take the
# full training split; the balanced concatenation is kept only as a comment.
idx_0 = np.where(y_train_All == 0)[0][:100]
X_0, y_0 = X_train_All[idx_0], y_train_All[idx_0]

idx_1 = np.where(y_train_All == 1)[0][:100]
X_1, y_1 = X_train_All[idx_1], y_train_All[idx_1]

idx_2 = np.where(y_train_All == 2)[0][:100]
X_2, y_2 = X_train_All[idx_2], y_train_All[idx_2]

idx_3 = np.where(y_train_All == 3)[0][:100]
X_3, y_3 = X_train_All[idx_3], y_train_All[idx_3]

idx_4 = np.where(y_train_All == 4)[0][:100]
X_4, y_4 = X_train_All[idx_4], y_train_All[idx_4]

idx_5 = np.where(y_train_All == 5)[0][:100]
X_5, y_5 = X_train_All[idx_5], y_train_All[idx_5]

idx_6 = np.where(y_train_All == 6)[0][:100]
X_6, y_6 = X_train_All[idx_6], y_train_All[idx_6]

idx_7 = np.where(y_train_All == 7)[0][:100]
X_7, y_7 = X_train_All[idx_7], y_train_All[idx_7]

idx_8 = np.where(y_train_All == 8)[0][:100]
X_8, y_8 = X_train_All[idx_8], y_train_All[idx_8]

idx_9 = np.where(y_train_All == 9)[0][:100]
X_9, y_9 = X_train_All[idx_9], y_train_All[idx_9]

# Balanced alternative: np.concatenate((X_0, ..., X_9), axis=0) and likewise for y.
X_train = X_train_All
y_train = y_train_All


print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')


# Make float32 copies of the pixel arrays, then rescale them in place from [0, 255] to [0, 1].
X_train, X_test, X_valid, X_Pool = (
  images.astype('float32') for images in (X_train, X_test, X_valid, X_Pool))
for _scaled in (X_train, X_valid, X_Pool, X_test):
  _scaled /= 255

# One-hot encode integer labels (the CNN uses categorical cross-entropy).
Y_test, Y_valid, Y_Pool = (
  np_utils.to_categorical(labels, nb_classes) for labels in (y_test, y_valid, y_Pool))


# NOTE(review): x_pool_All is not read anywhere else in this file.
x_pool_All = np.zeros(shape=(1))

Y_train = np_utils.to_categorical(y_train, nb_classes)
##########model and initial trining ################


# ######models

 

def define_model():
  """Build and compile a VGG-style CNN for 32x32x3 inputs and 10 classes.

  The network has four stages of 3x3 Conv2D layers (ReLU, He-uniform init,
  'same' padding), each conv followed by BatchNormalization. Stage widths x
  depths are 64x2, 128x2, 256x3 and 512x3. Every stage ends with 2x2 max
  pooling and Dropout (0.2, 0.3, 0.4, 0.5).

  The head is Flatten -> Dense(512, ReLU) named "feature" ->
  BatchNormalization named "norm" -> Dropout(0.5) -> Dense(10) named
  "befor" (no activation) -> Softmax. The model is compiled with
  SGD(learning_rate=0.001, momentum=0.9) and categorical cross-entropy.

  NOTE(review): a second ``define_model`` later in this file rebinds this
  name, so the training code does not call this version.
  """
  stages = (
    ((64, 64), 0.2),
    ((128, 128), 0.3),
    ((256, 256, 256), 0.4),
    ((512, 512, 512), 0.5),
  )

  net = Sequential()
  needs_input_shape = True
  for widths, drop_rate in stages:
    for width in widths:
      # Only the very first layer declares the input shape.
      shape_kwargs = {'input_shape': (32, 32, 3)} if needs_input_shape else {}
      needs_input_shape = False
      net.add(Conv2D(width, (3, 3), activation='relu', kernel_initializer='he_uniform',
                     padding='same', **shape_kwargs))
      net.add(BatchNormalization())
    net.add(MaxPooling2D((2, 2)))
    net.add(Dropout(drop_rate))

  net.add(Flatten())
  net.add(Dense(512, name="feature", activation='relu', kernel_initializer='he_uniform'))
  net.add(BatchNormalization(name="norm"))
  net.add(Dropout(0.5))
  net.add(Dense(10, name="befor"))
  net.add(tf.keras.layers.Softmax())

  net.compile(optimizer=SGD(learning_rate=0.001, momentum=0.9),
              loss='categorical_crossentropy', metrics=['accuracy'])
  return net

# #################################Second model resnet
def feature_extractor(inputs):
  """Apply an ImageNet-pretrained ResNet50 backbone (no top) to *inputs*.

  The backbone's declared ``input_shape`` is taken from ``inputs`` itself.
  It used to be hard-coded to (32, 32, 3), but ``final_model`` calls this
  function on a 7x-upsampled tensor. The declared shape then disagreed with
  the tensor actually passed in, which Keras warns about (or rejects) when a
  functional model is called on an incompatible input shape.

  Args:
    inputs: 4-D Keras tensor (batch, height, width, channels).

  Returns:
    Output tensor of the ResNet50 convolutional base.
  """
  backbone = tf.keras.applications.resnet.ResNet50(
      input_shape=tuple(inputs.shape[1:]),
      include_top=False,
      weights='imagenet')
  return backbone(inputs)


def classifier(inputs):
    """Define the final dense layers and the subsequent softmax for classification.

    The layers are GlobalAveragePooling2D -> Flatten -> Dense(1024, ReLU)
    named 'feature' -> Dense(512, ReLU) -> Dense(10) named "befor" (no
    activation) -> Softmax. Returns the softmax output tensor.
    """
    pooled = tf.keras.layers.GlobalAveragePooling2D()(inputs)
    flat = tf.keras.layers.Flatten()(pooled)
    hidden = tf.keras.layers.Dense(1024, activation="relu", name='feature')(flat)
    hidden = tf.keras.layers.Dense(512, activation="relu")(hidden)
    logits = tf.keras.layers.Dense(10, name="befor")(hidden)
    return tf.keras.layers.Softmax()(logits)

def final_model(inputs):
    """Return class probabilities for *inputs* from the ResNet50-based pipeline.

    The images are upsampled 7x in both spatial dimensions (e.g. 32x32 ->
    224x224), passed through ``feature_extractor`` and then ``classifier``.
    """
    upsampled = tf.keras.layers.UpSampling2D(size=(7,7))(inputs)
    return classifier(feature_extractor(upsampled))

def define_compile_model():
  """Define the ResNet50-based model for 32x32x3 inputs and compile it.

  The optimizer is stochastic gradient descent (the Keras 'SGD' alias). The
  loss is sparse categorical cross-entropy on softmax probabilities
  (from_logits=False), summed over the batch. The model therefore expects
  integer class labels rather than one-hot targets.

  Returns:
    The compiled ``tf.keras.Model``.
  """
  inputs = tf.keras.layers.Input(shape=(32,32,3))

  classification_output = final_model(inputs)
  model = tf.keras.Model(inputs=inputs, outputs=classification_output)
  # A freshly constructed loss object is not shared with anything, so the
  # former copy.deepcopy() around it was redundant.
  losss = keras.losses.SparseCategoricalCrossentropy(
      reduction=tf.keras.losses.Reduction.SUM, from_logits=False)

  model.compile(optimizer='SGD',
                loss=losss,
                metrics=['accuracy'])

  return model

#############################################Lee net
def define_model():
  """Build and compile the small LeNet-style CNN trained by this script.

  Architecture:
    * two Conv2D(32) layers, then 2x2 max pooling;
    * two Conv2D(64) layers, then 2x2 max pooling and Dropout(0.5);
    * Flatten -> Dense(512, ReLU) named "feature" -> Dropout(0.5) ->
      Dense(10) named "befor" (no activation) -> Softmax.

  All convolutions are 3x3 with ReLU, He-uniform init and 'same' padding.
  The model is compiled with SGD(learning_rate=0.001, momentum=0.9) and
  categorical cross-entropy, so it expects one-hot targets.

  This definition rebinds the name ``define_model``, so it is the builder
  used by ``deterministic_model = define_model()``.
  """
  model = Sequential()
  # Stage 1: 32-filter convolutions on the 32x32x3 input.
  model.add(Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same', input_shape=(32, 32, 3)))
  model.add(Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'))
  model.add(MaxPooling2D((2, 2)))

  # Stage 2: 64-filter convolutions.
  model.add(Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'))
  model.add(Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'))
  model.add(MaxPooling2D((2, 2)))
  model.add(Dropout(0.5))

  # Head. "feature" and "befor" are looked up by name later in this file.
  model.add(Flatten())
  model.add(Dense(512, name="feature", activation='relu', kernel_initializer='he_uniform'))
  model.add(Dropout(0.5))
  model.add(Dense(10, name="befor"))
  model.add(tf.keras.layers.Softmax())

  opt = SGD(learning_rate=0.001, momentum=0.9)
  model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])
  return model
#############################################

# Build the CNN with the last `define_model` defined above (the LeNet-style one).
deterministic_model = define_model()

# NOTE(review): `start` is not read anywhere afterwards in this file.
start = time.time()

# Train on the training split while reporting metrics on the validation split.
hist = deterministic_model.fit(X_train, Y_train, batch_size=batch_size, epochs=nb_epoch, verbose=1,validation_data=(X_valid, Y_valid))

 
# evaluate() yields [loss, accuracy] since the model was compiled with metrics=['accuracy'].
score, acc = deterministic_model.evaluate(X_test,Y_test ,  verbose=1)
print("The accuracy for first acquisition:", acc)


# Recompute test accuracy from the predicted probabilities as a cross-check.
predi = deterministic_model.predict(X_test, batch_size=batch_size)
all_accuracy = accur(predi, Y_test)

print("The accuracy for first acquisition:", all_accuracy)
################################################################# getting the logit output
#*****extracting representations************
# Sub-model sharing the trained CNN's weights that stops at the "befor"
# Dense(10) layer, i.e. the inputs to the final Softmax (logits).
layer_name = "befor"

logit_layer = keras.Model(inputs=deterministic_model.input,
                                       outputs=deterministic_model.get_layer(layer_name).output)


# NOTE(review): only predict() is called on logit_layer, so this compile step
# appears unnecessary.
opt = SGD(learning_rate=0.001, momentum=0.9)

logit_layer.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])
##########################################################################
########## Expected calibration error ("CEC" in variable names) before calibration:
predi_logit = logit_layer.predict(X_test,batch_size=batch_size)
logits_cnn_before = tf.convert_to_tensor(predi_logit, dtype=tf.float32, name='logits')
# Integer class ids recovered from the one-hot test targets; reused for the SPN below.
labels_true = tf.convert_to_tensor(np.argmax(Y_test, axis=1), dtype=tf.int32, name='labels_true')
calib_cnn_before = tfp.stats.expected_calibration_error(num_bins=15, 
                                     logits=logits_cnn_before, 
                                     labels_true= labels_true)


####################################Temperature Scaling:
# Fit a single scalar temperature T on the *validation* logits by minimising
# the mean softmax cross-entropy of logits / T.
temp = tf.Variable(initial_value=1.0, trainable=True, dtype=tf.float32) 
pred_cal = logit_layer.predict(X_valid,batch_size=batch_size)
def compute_loss():
    # Zero-argument loss callable for optimizer.minimize; reads the globals
    # pred_cal, temp and Y_valid.
    y_pred_model_w_temp = tf.math.divide(pred_cal, temp)
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(\
                                tf.convert_to_tensor(Y_valid), y_pred_model_w_temp))
    return loss

optimizer = tf.optimizers.Adam(learning_rate=0.01)

print('Temperature Initial value: {}'.format(temp.numpy()))

# 500 Adam steps, updating only the temperature variable.
for i in range(500):
    opts = optimizer.minimize(compute_loss, var_list=[temp])


print('Temperature Final value: {}'.format(temp.numpy()))
######################### Expected calibration error after calibration:
# Apply the fitted temperature to the *test* logits and recompute the ECE.
y_pred_model_w_temp = tf.math.divide(predi_logit, temp)
num_bins = 15
# labels_true = tf.convert_to_tensor(np.argmax(Y_test, axis=1), dtype=tf.int32, name='labels_true')
logits_after = tf.convert_to_tensor(y_pred_model_w_temp, dtype=tf.float32, name='logits')

CEC_cnn_AFTER = tfp.stats.expected_calibration_error(num_bins=num_bins, 
                                     logits=logits_after, 
                                     labels_true=labels_true)


###########################################SPN part 
############## SPN part for probability calculation ###################
#*****extracting representations************
# Sub-model ending at the CNN's Dense(512) layer named "feature"; its
# activations (not raw pixels) are the SPN's inputs.
layer_name = "feature"

intermediate_layer_model = keras.Model(inputs=deterministic_model.input,
                                       outputs=deterministic_model.get_layer(layer_name).output)





# NOTE(review): only predict() is called on this sub-model, so compiling it
# appears unnecessary.
opt = SGD(learning_rate=0.001, momentum=0.9)

intermediate_layer_model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])
# intermediate_layer_model.summary()
intermediate_output = intermediate_layer_model.predict(X_train,batch_size=batch_size)




# View each 512-dim feature vector as a 32 x 16 single-channel "image" so the
# SPN's 2-D product/sum layers can operate on it.
output_rep_tr = intermediate_output.reshape(X_train.shape[0],32, 16, 1)
#*************spn training ****************
import libspn_keras as spnk

print(spnk.get_default_sum_op())

# Switch libspn_keras' default sum op to SumOpGradBackprop.
spnk.set_default_sum_op(spnk.SumOpGradBackprop())
from tensorflow import keras

# Default accumulator initializer for libspn_keras layers: truncated normal,
# mean 1.0, stddev 0.5.
spnk.set_default_accumulator_initializer(
    keras.initializers.TruncatedNormal(stddev=0.5, mean=1.0)

)

# NOTE(review): tensorflow_datasets is imported but not used in this file.
import tensorflow_datasets as tfds




# NOTE(review): this module-level normaliser (adapted below) is never passed
# to the trained model -- spn_training() builds and adapts its own.
normalize = spnk.layers.NormalizeStandardScore(
    axes=spnk.layers.NormalizeAxes.GLOBAL, input_shape=(32, 16, 1),name="input1") 

train_x_ds = tf.data.Dataset.from_tensor_slices((output_rep_tr,)).batch(batch_size)

# NOTE(review): built from raw test images (not CNN features) and never used.
test_x_ds = tf.data.Dataset.from_tensor_slices((X_test,)).batch(batch_size)
# NOTE(review): helper not called anywhere in this file.
def take_first(a, b):
  return a

normalize.adapt(train_x_ds)

def biuld_model(normalize, dr):
  """Assemble (without compiling) the SPN classifier applied to CNN features.

  Layer stack of the returned ``keras.Sequential``:
    * ``normalize`` (the caller-supplied input normalisation layer);
    * ``NormalLeaf`` with 16 components and trainable locations;
    * ``Dropout(dr)`` on the leaf outputs;
    * six depthwise 2x2 ``Conv2DProduct`` layers. The first two use stride 2
      ('valid', non-overlapping). The rest use stride 1 with dilations 1, 2,
      4 ('full' padding) and 8 ('final' padding, combining all scopes). Each
      of the first five is followed by a ``Local2DSum``
      (16, 32, 32, 64, 64 sums);
    * ``SpatialToRegions``, ``NormalizeStandardScore``, a 10-way ``DenseSum``
      named "befo11resoft" and a ``RootSum`` returning weighted child logits;
    * a Keras head: BatchNormalization, Dropout(0.2), Dense(10) named
      "densebe" and Softmax named "last".

  Args:
    normalize: layer placed first in the stack.
    dr: dropout rate applied right after the leaf layer.

  Returns:
    The uncompiled ``keras.Sequential`` model.
  """
  # (strides, dilations, padding, num_sums of the following Local2DSum or None)
  product_plan = (
      ([2, 2], [1, 1], 'valid', 16),
      ([2, 2], [1, 1], 'valid', 32),
      ([1, 1], [1, 1], 'full', 32),
      ([1, 1], [2, 2], 'full', 64),
      ([1, 1], [4, 4], 'full', 64),
      ([1, 1], [8, 8], 'final', None),
  )

  stack = [
      normalize,
      spnk.layers.NormalLeaf(
          num_components=16,
          location_trainable=True,
          location_initializer=keras.initializers.TruncatedNormal(
              stddev=1.0, mean=0.0)
      ),
      tf.keras.layers.Dropout(rate=dr),
  ]

  for strides, dilations, padding, num_sums in product_plan:
    stack.append(spnk.layers.Conv2DProduct(
        depthwise=True,
        strides=strides,
        dilations=dilations,
        kernel_size=[2, 2],
        padding=padding
    ))
    if num_sums is not None:
      stack.append(spnk.layers.Local2DSum(num_sums=num_sums))

  # Class roots followed by a small dense/softmax head.
  stack.extend([
      spnk.layers.SpatialToRegions(),
      spnk.layers.NormalizeStandardScore(),
      spnk.layers.DenseSum(num_sums=10, name="befo11resoft"),
      spnk.layers.RootSum(return_weighted_child_logits=True, name="beforesoft"),
      tf.keras.layers.BatchNormalization(name="norm-s2"),
      tf.keras.layers.Dropout(rate=0.2),
      tf.keras.layers.Dense(10, name="densebe"),
      tf.keras.layers.Softmax(name="last"),
  ])
  return keras.Sequential(stack)

def spn_training(X_train, dr):
  """Build and compile an SPN classifier whose input normaliser is fitted on *X_train*.

  Args:
    X_train: features used to adapt the normalisation statistics. The
      normaliser declares input_shape (32, 16, 1), so presumably the array is
      (N, 32, 16, 1) -- confirm at call sites.
    dr: dropout rate forwarded to ``biuld_model``.

  Returns:
    The compiled (untrained) Keras model.
  """
  normalize = spnk.layers.NormalizeStandardScore(
    axes=spnk.layers.NormalizeAxes.GLOBAL, input_shape=(32, 16, 1))
  # Adapt the normaliser batch-wise, using the module-level batch_size.
  train_x_ds = tf.data.Dataset.from_tensor_slices((X_train,)).batch(batch_size)
  normalize.adapt(train_x_ds)

  spn_net = biuld_model(normalize, dr)
  # These objects are freshly constructed and owned by this model only, so
  # the former copy.deepcopy() wrappers around them were redundant.
  optimizer1 = keras.optimizers.Adam(learning_rate=0.007)
  metrics1 = [keras.metrics.SparseCategoricalAccuracy()]
  # The model ends in a Softmax, so the loss consumes probabilities
  # (from_logits=False). SUM reduction adds per-example losses instead of
  # averaging them.
  loss1 = keras.losses.SparseCategoricalCrossentropy(
      reduction=tf.keras.losses.Reduction.SUM, from_logits=False)

  spn_net.compile(loss=loss1, metrics=metrics1, optimizer=optimizer1)
  return spn_net

def cat_to_val(z):
  """Convert a 2-D one-hot (or score) array to integer class ids via a row-wise argmax."""
  return np.asarray(z).argmax(axis=1)

# SPN with dropout 0.3 after the leaf layer; its normaliser is adapted on the training features.
spn_network = spn_training(output_rep_tr, .3)

# spn_network.summary()

# Abort training as soon as the loss becomes NaN.
callbacks_sp = [tf.keras.callbacks.TerminateOnNaN()]
# Integer labels, because spn_training compiles with SparseCategoricalCrossentropy.
hist = spn_network.fit(output_rep_tr, cat_to_val(Y_train),batch_size = batch_size_spn, epochs=10,callbacks=callbacks_sp )# , validation_data=(X_valid, cat_to_val(Y_valid))
# Test-set CNN features, reshaped exactly like the training features.
intermediate_output = intermediate_layer_model.predict(X_test,batch_size=batch_size) #get_rep(deterministic_model,"feature", X_test ) #
output_rep = intermediate_output.reshape(X_test.shape[0],32, 16, 1)
score, _ = spn_network.evaluate(output_rep,cat_to_val(Y_test) ,  verbose=1)
pred = spn_network.predict(output_rep ,  verbose=1)
acc1 = accur(pred, Y_test)
# NOTE(review): overwrites the accur()-based CNN accuracy with evaluate()'s
# `acc`; neither all_accuracy nor all_accuracy_spn is read later in this file.
all_accuracy = acc
all_accuracy_spn = acc1

##########################################################################################################Calculate SPN values before softmax
# Sub-model ending at the SPN's Dense(10) layer named "densebe", i.e. the
# logits fed to the final Softmax.
layer_name = "densebe"

logit_layer_spn = keras.Model(inputs=spn_network.input,
                                       outputs=spn_network.get_layer(layer_name).output)





# NOTE(review): only predict() is called on this sub-model, so compiling it
# appears unnecessary.
opt = SGD(learning_rate=0.001, momentum=0.9)

logit_layer_spn.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])
# intermediate_layer_model.summary()
# `output_rep` currently holds the test-set features.
logit_test = logit_layer_spn.predict(output_rep,batch_size=batch_size)
# Flatten to (num_examples, -1) for the ECE computation.
logit_test = np.reshape(logit_test, (logit_test.shape[0], -1))






logits_spn = tf.convert_to_tensor(logit_test, dtype=tf.float32, name='logits')
print(logits_spn.shape)

# Reuses `labels_true` (test-set class ids) computed in the CNN section.
# labels_true = tf.convert_to_tensor(np.argmax(Y_test, axis=1), dtype=tf.int32, name='labels_true')
calib_spn_before = tfp.stats.expected_calibration_error(num_bins=15, 
                                     logits=logits_spn, 
                                     labels_true= labels_true)







################################################################# SPN temperature scaling
# Fit a scalar temperature temp1 on the SPN's validation-set logits.
temp1 = tf.Variable(initial_value=1.0, trainable=True, dtype=tf.float32) 
# Validation images -> CNN features -> SPN logits (output_rep / pred_cal are reassigned here).
intermediate_output = intermediate_layer_model.predict(X_valid,batch_size=batch_size) #get_rep(deterministic_model,"feature", X_test ) #
output_rep = intermediate_output.reshape(X_valid.shape[0],32, 16, 1)
pred_cal = logit_layer_spn.predict(output_rep ,  verbose=1)
pred_cal = np.reshape(pred_cal, (pred_cal.shape[0], -1))
def compute_loss():
    # Rebinds the CNN version of compute_loss: mean softmax cross-entropy of
    # the SPN validation logits divided by temp1 (reads globals pred_cal, temp1, Y_valid).

    y_pred_model_w_temp = tf.math.divide(pred_cal, temp1)
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(\
                                tf.convert_to_tensor(Y_valid), y_pred_model_w_temp))
    return loss

# New Adam instance, separate from the one used for the CNN temperature.
optimizer = tf.optimizers.Adam(learning_rate=0.01)

print('Temperature Initial value: {}'.format(temp1.numpy()))

# 500 Adam steps, updating only temp1.
for i in range(500):
    opts = optimizer.minimize(compute_loss, var_list=[temp1])


print('Temperature Final value: {}'.format(temp1.numpy()))


# Apply the fitted temperature to the SPN *test* logits and recompute the ECE.
y_pred_model_w_temp = tf.math.divide(logit_test, temp1)
num_bins = 15
# labels_true = tf.convert_to_tensor(np.argmax(Y_test, axis=1), dtype=tf.int32, name='labels_true')
logits = tf.convert_to_tensor(y_pred_model_w_temp, dtype=tf.float32, name='logits')

CEC_spn_AFTER = tfp.stats.expected_calibration_error(num_bins=num_bins, 
                                     logits=logits, 
                                     labels_true=labels_true)




# Final report: SPN accuracy and expected calibration errors before/after
# temperature scaling for both models.
_report = (
  ("the accuracy for spn", acc1),
  ("The CEC after calibration for CNN", CEC_cnn_AFTER),
  ("The CEC before calibration for CNN", calib_cnn_before),
  ("The CEC after calibration for SPN", CEC_spn_AFTER),
  ("The CEC before calibration for SPN", calib_spn_before),
)
for _label, _value in _report:
  print(_label, _value)

