"""Fit a CONDOR ordinal model (condor_tensorflow) to toy 2-D angular-sector data."""
import os
import random

import condor_tensorflow as condor
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split

# Global seed for every RNG used below (other candidate seeds: 321, 123).
seed = 42

# NOTE: PYTHONHASHSEED is only read when an interpreter starts, so setting it
# here affects subprocesses launched later, not this process's str/bytes
# hashing. For a fully reproducible run, export it before starting Python.
os.environ['PYTHONHASHSEED'] = str(seed)
# Seed Python's own RNG as well as NumPy's and TensorFlow's.
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)


def angle_to_class(points, n_classes):
    """Map each 2-D point to the index of its angular sector.

    The angle is ``np.arctan2(points[:, 0], points[:, 1])``, and [-pi, pi] is
    split into ``n_classes`` equal-width bins.

    Args:
        points: array of shape (n, 2); column 0 is the first argument of
            ``np.arctan2`` and column 1 the second.
        n_classes: number of equal-width sectors spanning [-pi, pi].

    Returns:
        int64 array of shape (n,) with values in ``[0, n_classes - 1]``.
    """
    points = np.asarray(points, dtype=float)
    angles = np.arctan2(points[:, 0], points[:, 1])
    edges = np.linspace(-np.pi, np.pi, num=n_classes + 1)
    labels = np.digitize(angles, edges) - 1
    # arctan2 can return exactly +pi (e.g. x0 == +0.0 with x1 < 0). That value
    # equals the last edge, and np.digitize maps it one past the final bin.
    # Fold it into the last class so every label is a valid class index.
    return np.clip(labels, 0, n_classes - 1).astype(np.int64)


# Create some toy data: 2-D standard-normal points labelled by angular sector.
num_classes = 4
samps = 1000
X = np.random.normal(size=(samps, 2))
y = angle_to_class(X, num_classes)

# Hold out a fixed number of samples (samps // 10, i.e. 10%) as the test set.
# The fixed random_state keeps the split reproducible.
test_count = samps // 10
X_train, X_test, labels_train, labels_test = train_test_split(
    X, y, test_size=test_count, random_state=1
)

# Create the neural network: two small ReLU hidden layers feeding
# num_classes - 1 linear outputs (CONDOR uses one output per ordinal threshold).
model = tf.keras.models.Sequential()
# Give the input shape as a tuple derived from the data; newer Keras (3.x)
# rejects a bare integer shape such as Input(2).
model.add(tf.keras.layers.Input(shape=(X_train.shape[1],)))
model.add(tf.keras.layers.Dense(10, activation="relu"))
model.add(tf.keras.layers.Dense(10, activation="relu"))
model.add(tf.keras.layers.Dense(num_classes - 1))

model.summary()

# Compile with the CONDOR ordinal cross-entropy and track ordinal metrics:
# mean absolute error, earth mover's distance, accuracy, and accuracy with
# tolerance=1.
model.compile(
    loss=condor.SparseCondorOrdinalCrossEntropy(),
    metrics=[
        condor.SparseOrdinalMeanAbsoluteError(),
        condor.SparseOrdinalEarthMoversDistance(),
        condor.SparseOrdinalAccuracy(),
        condor.SparseOrdinalAccuracy(tolerance=1),
    ],
    optimizer=tf.keras.optimizers.Adam(),
)

# Train for up to 100 epochs, using 20% of the training split for validation.
# Early stopping waits 10 epochs for an improvement of at least 0.001 and then
# restores the best weights it saw.
early_stop = tf.keras.callbacks.EarlyStopping(
    patience=10,
    min_delta=0.001,
    restore_best_weights=True,
)
history = model.fit(
    x=X_train,
    y=labels_train,
    epochs=100,
    batch_size=32,
    validation_split=0.2,
    callbacks=[early_stop],
)
# Score the trained model on the held-out test split.
model.evaluate(x=X_test, y=labels_test)

#ordinal_logits = model.predict(X_test)
#tensor_probs = condor.ordinal_softmax(ordinal_logits)
#probs_df = pd.DataFrame(tensor_probs.numpy())
#
## Compute the probability mass each test sample assigns to incorrect labels
#wrong = np.zeros(len(labels_test))
#for index, row in probs_df.iterrows():
#    yi = labels_test[index]
#    wrong[index] = 1. - row[np.int_(yi)]
#
#print("Average percentage incorrect class: ",wrong.mean())
##Average percentage incorrect class:  0.045961754322052004
##Average earth movers distance: 0.0537
#
## Find probability distribution over classes
#print(probs_df.mean())
