import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.datasets import mnist
import random
import numpy as np
import pandas as pd
from seaborn import lineplot,scatterplot
import matplotlib.pyplot as plt
def RR(x, epsi=1, sed=1, num_classes=10):
    """Apply K-ary randomized response to a single class label.

    The true label is reported with probability
    ``e^epsi / (e^epsi + K - 1)``. Each of the other ``K - 1`` labels is
    reported with probability ``1 / (e^epsi + K - 1)``, where
    ``K = num_classes``.

    Args:
        x: True label. It is converted with ``int()`` and must lie in
            ``[0, num_classes)``.
        epsi: Privacy parameter epsilon. A larger value keeps the true label
            more often.
        sed: Seed for a private ``random.Random`` instance. The draw is
            reproducible for a given seed, and the module-level ``random``
            state is not touched.
        num_classes: Number of classes K. Defaults to 10.

    Returns:
        The (possibly perturbed) label as an ``int``.

    Raises:
        ValueError: If ``int(x)`` is outside ``[0, num_classes)``.
    """
    label = int(x)
    if not 0 <= label < num_classes:
        # A negative label would otherwise silently index from the end.
        raise ValueError(f"label {label} outside [0, {num_classes})")

    # Probability mass of every "wrong" label; the true label gets e^epsi times more.
    base = 1 / (np.exp(epsi) + (num_classes - 1))
    weights = [base] * num_classes
    weights[label] = base * np.exp(epsi)

    # A dedicated generator yields the same draw as random.seed(sed) followed
    # by random.choices(...), without clobbering the global RNG state.
    rng = random.Random(sed)
    return rng.choices(range(num_classes), weights, k=1)[0]

def NN(h=2, epsilon=1, epoch=10, ES=False, SEED=1):
    """Train a one-hidden-layer MLP on MNIST using randomized-response labels.

    Training labels are perturbed with ``RR``; test labels are left clean,
    so the returned accuracy is measured against the true labels.

    Args:
        h: Number of ReLU units in the hidden layer.
        epsilon: Privacy parameter passed to ``RR``.
        epoch: Number of training epochs when ``ES`` is false. It is ignored
            when ``ES`` is true; early-stopping runs use at most 50 epochs.
        ES: Bool. If true, hold out 33% of the (noisy) training set for
            validation and stop on ``val_loss`` (min_delta 0.001,
            patience 2), restoring the best weights.
        SEED: Seed for the label noise only. Sample ``i`` uses seed
            ``SEED * len(y_train) + i``, so different SEED values use
            disjoint seed ranges. Weight initialisation and batch
            shuffling are not seeded here.

    Returns:
        Test-set accuracy of the trained model.
    """
    # Load the MNIST dataset.
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    n_train = len(y_train)

    # Perturb the training labels with randomized response. Each sample gets
    # its own seed, so the noise is reproducible per (SEED, index).
    y_train_noise = np.array(
        [RR(y_train[i], epsilon, SEED * n_train + i) for i in range(n_train)],
        dtype=int,
    )

    # Normalize the input images from [0, 255] to [0, 1].
    x_train = x_train / 255.0
    x_test = x_test / 255.0

    # One-hot encode. num_classes is explicit so the width is always 10,
    # even if randomized response never happens to emit the highest label.
    y_train_noise = tf.keras.utils.to_categorical(y_train_noise, num_classes=10)
    y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)

    # Single hidden layer of width h.
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(h, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax'),
    ])
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

    if ES:
        early_stop = tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            min_delta=0.001,
            patience=2,
            verbose=0,
            mode='auto',
            restore_best_weights=True,
        )
        # Validation uses the last 33% of the *noisy* training labels, so the
        # stopping decision never looks at the clean test set.
        model.fit(x_train, y_train_noise, epochs=50, verbose=0,
                  callbacks=[early_stop], validation_split=0.33)
    else:
        model.fit(x_train, y_train_noise, epochs=epoch, verbose=0)

    # Evaluate silently, consistent with the silent fit calls above.
    _, test_acc = model.evaluate(x_test, y_test, verbose=0)
    return test_acc

# Sweep 1: fixed 10-epoch training over epsilon x hidden width,
# with 10 label-noise seeds per setting.
# Row layout: [rep, epsilon, hidden units, test accuracy].
Error_10 = []
for ep in (1, 2, 4):
    for hh in (10, 20, 30, 40):
        for rep in range(10):
            acc = NN(hh, ep, epoch=10, SEED=rep)
            Error_10.append([rep, ep, hh, acc])
            print(ep, hh, rep)
DF10 = pd.DataFrame(Error_10)
DF10.to_csv('DF10')

# Sweep 2: fixed 10-epoch training over a wider range of hidden widths,
# with 30 repetitions per setting.
# Row layout: [rep, epsilon, hidden units, test accuracy].
# SEED=rep gives each repetition its own label noise, as in the other sweeps.
# Without it, every repetition reused SEED=1 and so saw identical noisy labels.
Error_20 = []
for ep in [1, 2, 4]:
    for hh in [10, 20, 40, 80, 160]:
        for rep in range(30):
            Error_20.append([rep, ep, hh, NN(hh, ep, epoch=10, SEED=rep)])
            print(ep, hh, rep)
DF20 = pd.DataFrame(Error_20)
DF20.to_csv('DF20')

# Sweep 3: early stopping on validation loss over epsilon x hidden width,
# with 10 label-noise seeds per setting.
# Row layout: [rep, epsilon, hidden units, test accuracy].
Error_ES = []
for ep in (1, 2, 4):
    for hh in (10, 20, 30, 40, 50):
        for rep in range(10):
            acc = NN(hh, ep, ES=True, SEED=rep)
            Error_ES.append([rep, ep, hh, acc])
            print(ep, hh, rep)
        # Rewrite the CSV with every row collected so far after each
        # (epsilon, width) setting. Presumably this keeps partial results
        # if the long run is interrupted.
        DFES = pd.DataFrame(Error_ES)
        DFES.to_csv('DFES_loss')

# Rebuild one DataFrame per sweep from the in-memory result lists.
DF10, DF20, DF_ES = (pd.DataFrame(rows) for rows in (Error_10, Error_20, Error_ES))

# Figure: test accuracy vs. hidden width for sweep 1, one line per epsilon.
# DF10 columns: 0 = rep, 1 = epsilon, 2 = hidden units, 3 = test accuracy.
# Start a fresh figure so the later plots do not draw onto these axes.
plt.figure()
for eps_value, colour in ((1, 'blue'), (2, 'red'), (4, 'black')):
    # Raw f-string: '\e' is not a valid escape in a normal string literal.
    lineplot(data=DF10[DF10[1] == eps_value], x=2, y=3, color=colour,
             label=rf'$\epsilon={eps_value}$', marker='o', markersize=5)
plt.xlabel('Hidden Units')
plt.ylabel('Testing Accuracy')
# Explicit True: a bare plt.grid() toggles visibility instead of enabling it.
plt.grid(True)

# Reload sweep 2 from disk and plot test accuracy vs. hidden width.
# Raw string: the path contains '\D' and '\L', which are invalid escape
# sequences in a normal literal. The path value itself is unchanged.
# iloc drops the index column written by DataFrame.to_csv.
DF20 = pd.read_csv(r'C:\D_Disk\Label_DP_Final\DF20').iloc[:, 1:]
DF20.columns = [0, 1, 2, 3]  # rep, epsilon, hidden units, test accuracy

# Start a fresh figure so this plot does not overlay the previous one.
plt.figure()
for eps_value, colour in ((1, 'blue'), (2, 'red'), (4, 'black')):
    lineplot(data=DF20[DF20[1] == eps_value], x=2, y=3, color=colour,
             label=rf'$\epsilon={eps_value}$', marker='o', markersize=5)
# Explicit True: a bare plt.grid() toggles visibility instead of enabling it.
plt.grid(True)
plt.xlabel('Hidden Units')
plt.ylabel('Testing Accuracy')


# Reload the early-stopping sweep from disk and plot test accuracy vs.
# hidden width.
# Raw string: the path contains '\D' and '\L', which are invalid escape
# sequences in a normal literal. The path value itself is unchanged.
# iloc drops the index column written by DataFrame.to_csv.
DFES = pd.read_csv(r'C:\D_Disk\Label_DP_Final\DFES_loss').iloc[:, 1:]
DFES.columns = [0, 1, 2, 3]  # rep, epsilon, hidden units, test accuracy

# Start a fresh figure so this plot does not overlay the previous ones.
plt.figure()
for eps_value, colour in ((1, 'blue'), (2, 'red'), (4, 'black')):
    lineplot(data=DFES[DFES[1] == eps_value], x=2, y=3, color=colour,
             label=rf'$\epsilon={eps_value}$', marker='o', markersize=5)
# Single explicit call. The original called a bare plt.grid() twice, and
# since each bare call toggles visibility, the grid ended up hidden.
plt.grid(True)
plt.xlabel('Hidden Units')
plt.ylabel('Testing Accuracy')