# Import required modules etc
import numpy as np
import sklearn
from sklearn import model_selection
from sklearn.model_selection import train_test_split
import pandas as pd
from scipy import special
import os
import json
import gzip
from urllib.request import urlopen
import random
from sklearn.metrics import confusion_matrix
from datetime import datetime
import tensorflow as tf
import condor_tensorflow as condor


from tensorflow.keras.wrappers.scikit_learn import KerasClassifier
from sklearn.model_selection import GridSearchCV

##########################
### SETTINGS
##########################

# Seed every random number generator used by the script so runs are repeatable.
# Other seeds tried previously: 321, 231.
seed_value = 123

# NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here
# only influences child processes, not the hashing of this running process.
os.environ['PYTHONHASHSEED'] = str(seed_value)
random.seed(seed_value)
np.random.seed(seed_value)
tf.random.set_seed(seed_value)

# Number of ordinal classes; the network emits num_classes - 1 outputs.
num_classes = 10

# Method to load MNIST from local file
def load_data(path):
    """Read the train/test arrays from an ``.npz`` archive.

    Args:
        path: Anything ``np.load`` accepts that points at an archive holding
            the keys ``x_train``, ``y_train``, ``x_test`` and ``y_test``.

    Returns:
        ``((x_train, y_train), (x_test, y_test))``.

    Raises:
        KeyError: If one of the four expected keys is missing.
    """
    with np.load(path) as archive:
        # Arrays are materialised on key access, so read them all while the
        # archive is still open.
        train_split = (archive['x_train'], archive['y_train'])
        test_split = (archive['x_test'], archive['y_test'])
    return train_split, test_split

# Fetch and format the mnist data
(mnist_images, mnist_labels), (mnist_images_test, mnist_labels_test) = load_data('./mnist.npz')

# Split off a validation dataset (5000 samples) for early stopping
mnist_images, mnist_images_val, mnist_labels, mnist_labels_val = \
    model_selection.train_test_split(mnist_images, mnist_labels,
                                     test_size=5000, random_state=seed_value)

# Rescale the images by 1/255 (presumably raw 8-bit pixel values -> [0, 1];
# verify against the contents of mnist.npz).
mnist_images = mnist_images / 255.0
mnist_images_val = mnist_images_val / 255.0
mnist_images_test = mnist_images_test / 255.0

# Add the trailing channel axis expected by the Conv2D input layer.
# -1 lets numpy infer the number of samples, so the reshape keeps working if
# the validation split size or the dataset size changes.
mnist_images = mnist_images.reshape(-1, 28, 28, 1)
mnist_images_val = mnist_images_val.reshape(-1, 28, 28, 1)
mnist_images_test = mnist_images_test.reshape(-1, 28, 28, 1)

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Flatten


def create_model():
    """Build and compile the convolutional CONDOR ordinal-regression network.

    The network is two 3x3 ReLU convolutions (64 then 32 filters) on
    28x28x1 inputs, flattened into a dense layer with ``num_classes - 1``
    outputs. It is compiled with the Adam optimizer, the sparse CONDOR
    ordinal cross-entropy loss, and ordinal metrics.

    Returns:
        The compiled ``tf.keras.Sequential`` model.
    """
    network = tf.keras.Sequential([
        Conv2D(64, kernel_size=3, activation='relu', input_shape=(28, 28, 1)),
        Conv2D(32, kernel_size=3, activation="relu"),
        Flatten(),
        # One output fewer than the number of classes.
        tf.keras.layers.Dense(num_classes - 1),
    ])

    ordinal_metrics = [
        condor.SparseOrdinalEarthMoversDistance(),
        condor.SparseOrdinalMeanAbsoluteError(),
        condor.SparseOrdinalAccuracy(name='accuracy'),
        condor.SparseOrdinalAccuracy(name='AccTol1', tolerance=1),
    ]
    network.compile(
        optimizer='adam',
        loss=condor.SparseCondorOrdinalCrossEntropy(),
        metrics=ordinal_metrics,
    )
    return network

#model=KerasClassifier(build_fn=create_model)

model = create_model()

# Sanity-check the array shapes before training.
print(mnist_images.shape)
print(mnist_labels.shape)

print(mnist_images_val.shape)
print(mnist_labels_val.shape)
#quit()

# Make sure the output directory exists *before* the long training run, so a
# missing or unwritable directory fails immediately rather than after training
# has finished and the model is about to be saved.
model_dir = os.path.join('./', 'model')
os.makedirs(model_dir, exist_ok=True)

# Train for up to 100 epochs; stop once the monitored validation quantity has
# not improved for 10 epochs and roll back to the best weights seen.
early_stopping = tf.keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)
model.fit(mnist_images, mnist_labels,
          validation_data=(mnist_images_val, mnist_labels_val),
          epochs=100,
          callbacks=[early_stopping])

# --- Disabled hyper-parameter grid search, kept for reference ---
#params={'batch_size':[64,128,256],
#        'epochs':[10,30,60,90],
#        'num_classes':[10],
#        'dropout':[0.1,0.15,0.2,0.25],
#        'learning_rate':[1e-1,1e-2,1e-3],
#        'unit':[128,256,512,768,1024]
#        }
# Note that the model generates 1 fewer outputs than the number of classes.
#model.summary()

# This takes about 5 minutes on CPU, 2.5 minutes on GPU.
#history = model.fit(dataset, epochs = num_epochs, validation_data = val_dataset,
#                    callbacks = [tf.keras.callbacks.EarlyStopping(patience = patient, restore_best_weights = True)])

#gs=GridSearchCV(estimator=model, param_grid=params, cv=5)
#gs=gs.fit(mnist_images, mnist_labels)

#best_params=gs.best_params_
#accuracy=gs.best_score_
#print(best_params)
#print(accuracy)

# Save the trained model under a timestamped file name so repeated runs do
# not overwrite each other.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')

model.save(os.path.join(model_dir, timestamp + '_' + 'model.h5'))

# Evaluate on test dataset.
model.evaluate(mnist_images_test, mnist_labels_test)


print("Predict on test dataset")

# Note that these are ordinal (cumulative) logits, not probabilities or regular logits.
ordinal_logits = model.predict(mnist_images_test)

# Convert from logits to label probabilities. The result is a tensorflow
# tensor, so turn it into a numpy array before wrapping it in a DataFrame.
pred_probs = pd.DataFrame(condor.ordinal_softmax(ordinal_logits).numpy())

# Version 1: label = column (class index) with the highest probability.
labels = pred_probs.idxmax(axis=1)

# Version 2: compare to logit-based cumulative probs. Apply the sigmoid to each
# logit, take the running product across the outputs of a sample, and count
# how many of those products exceed 0.5.
cum_probs = pd.DataFrame(ordinal_logits).apply(special.expit).cumprod(axis=1)
labels2 = (cum_probs > 0.5).sum(axis=1)

# Absolute label errors, computed once and reused by the metrics below.
abs_error = np.abs(labels - mnist_labels_test)
abs_error2 = np.abs(labels2 - mnist_labels_test)

# Output metrics
print("Mean absolute label error version 1:", np.mean(abs_error))
print("Mean absolute label error version 2:", np.mean(abs_error2))

print("Accuracy of label version 1:", np.mean(labels == mnist_labels_test))
print("Accuracy of label version 2:", np.mean(labels2 == mnist_labels_test))

print("Accuracy tolerance 1 of label version 1:", np.mean(abs_error <= 1))

print("Accuracy tolerance 1 of label version 2:", np.mean(abs_error2 <= 1))

