import matplotlib.pyplot as plt
import keras
from keras.datasets import fashion_mnist
from keras import backend as K
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten,Activation
from keras.layers import Conv2D, MaxPooling2D
import sys
sys.path.append('../')
from aaa_issta import lib_deep
from sklearn import model_selection



def get_data(datasetname):
    '''
    Description
        Load an image dataset selected by (a substring of) its name, scale
        pixel values to [0, 1] as float32 and one-hot encode the labels.

        Matching is by substring:
          - names containing 'mnist' load Fashion-MNIST (10 classes, 28x28
            images reshaped to carry one channel axis placed according to
            K.image_data_format()),
          - otherwise names containing 'cifar' load CIFAR-100 (100 classes).

    Parameters
        datasetname: str, e.g. 'fashionmnist-learningrate' or 'cifar'

    Returns
        X_train, y_train, X_hold, y_hold
        The Keras test split is returned as the hold-out set (X_hold, y_hold).

    Raises
        ValueError: if datasetname matches no supported dataset.
    '''
    if 'mnist' in datasetname:
        num_classes = 10
        # input image dimensions
        img_rows, img_cols = 28, 28
        # the data, split between train and test sets
        (x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

        # Conv2D needs an explicit channel axis; its position follows the
        # backend's configured image data format.
        if K.image_data_format() == 'channels_first':
            sample_shape = (1, img_rows, img_cols)
        else:
            sample_shape = (img_rows, img_cols, 1)
        x_train = x_train.reshape((x_train.shape[0],) + sample_shape)
        x_test = x_test.reshape((x_test.shape[0],) + sample_shape)
    elif 'cifar' in datasetname:
        # Local import: only this branch needs the CIFAR-100 loader.
        from keras.datasets import cifar100
        num_classes = 100
        # The data, split between train and test sets:
        (x_train, y_train), (x_test, y_test) = cifar100.load_data()
    else:
        # Previously this fell through to the return and raised UnboundLocalError.
        raise ValueError('Unsupported dataset name: %r' % (datasetname,))

    # Scale 8-bit pixel intensities to [0, 1].
    X_train = x_train.astype('float32') / 255
    X_hold = x_test.astype('float32') / 255

    # convert class vectors to binary class matrices
    y_train = keras.utils.to_categorical(y_train, num_classes)
    y_hold = keras.utils.to_categorical(y_test, num_classes)

    return X_train, y_train, X_hold, y_hold



def create_model_mnist(dropoutrate, num_classes=10, img_rows=28, img_cols=28, channels=1):
    '''
    Description
        Build and compile a small CNN classifier: two Conv2D layers (32 and
        64 filters, 3x3, ReLU), 2x2 max pooling, Dropout(0.25), a 64-unit
        ReLU dense layer, Dropout(0.5) and a softmax output layer.

    Parameters
        dropoutrate: despite its name, this value is the learning rate given
            to the Adadelta optimizer. The Dropout layers use the fixed rates
            0.25 and 0.5. The name is kept so existing callers keep working.
        num_classes: number of softmax output units (default 10).
        img_rows, img_cols, channels: input image geometry (default 28x28x1,
            the Fashion-MNIST shape produced by get_data).

    Returns
        A compiled keras Sequential model with categorical cross-entropy loss
        and the 'accuracy' metric.
    '''
    # Place the channel axis according to the backend's image data format.
    if K.image_data_format() == 'channels_first':
        input_shape = (channels, img_rows, img_cols)
    else:
        input_shape = (img_rows, img_cols, channels)

    model = Sequential()
    model.add(Conv2D(32, kernel_size=(3, 3),
                     activation='relu',
                     input_shape=input_shape))
    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))
    model.add(Flatten())
    model.add(Dense(64, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(num_classes, activation='softmax'))

    # `dropoutrate` is the optimizer learning rate (see docstring).
    model.compile(loss=keras.losses.categorical_crossentropy,
                  optimizer=keras.optimizers.Adadelta(learning_rate=dropoutrate),
                  metrics=['accuracy'])
    return model



def draw_Figure5(datasetname, learningrates=None, epoch=10):
    '''
    Description
        Run the learning-rate sweep behind this figure. The data is split as
        follows:
          - get_data(datasetname) provides the training and hold-out sets,
          - 20% of the training data is split off for validation.

        For each learning rate the function:
          - calls lib_deep.get_MV_CV_and_test_deep_sametrain with
            create_model_mnist,
          - collects the 'pv', 'cv' and 'test' values it returns.

        Finally it appends the three metric lists to
        ./../results/deeplearning_<datasetname>.csv.

    Parameters
        datasetname: dataset name forwarded to get_data,
            e.g. 'fashionmnist-learningrate'
        learningrates: optional list of learning rates to sweep
            (default [0.01, 0.001, 0.0001, 0.00001])
        epoch: value forwarded to lib_deep as the epoch count (default 10)

    Returns
        None. Results are appended to the CSV file above, one line per list
        in the form  name=[v1,v2,...,]  (trailing comma kept so the format
        matches existing result files).
    '''
    if learningrates is None:
        learningrates = [0.01, 0.001, 0.0001, 0.00001]

    resultpath = './../results/deeplearning_' + str(datasetname) + '.csv'

    X_train_o, y_train_o, X_hold, y_hold = get_data(datasetname)
    # Hold out 20% of the training data as a validation set.
    X_train, X_valid, y_train, y_valid = model_selection.train_test_split(X_train_o, y_train_o, test_size=0.2)

    pvlist = []
    cvlist = []
    testlist = []

    for rate in learningrates:  # one training run per learning rate
        print('Learning rate: ' + str(rate))

        # NOTE(review): presumably lib_deep forwards `rate` to create_model_mnist,
        # which uses it as the Adadelta learning rate; confirm in lib_deep.
        dic_metric_value = lib_deep.get_MV_CV_and_test_deep_sametrain(create_model_mnist, X_train, y_train, X_valid,
                                                                      y_valid, X_hold, y_hold,
                                                                      epoch, rate, resultpath)
        pvlist.append(dic_metric_value['pv'])
        testlist.append(dic_metric_value['test'])
        cvlist.append(dic_metric_value['cv'])

    # Append (not overwrite); `with` guarantees the file is closed on error.
    with open(resultpath, 'a') as f:
        for name, values in (('pvlist', pvlist), ('cvlist', cvlist), ('testlist', testlist)):
            f.write(name + '=[' + ''.join("%s," % item for item in values) + ']\n')

    # Plotting is currently disabled; the code below is kept for reference.
    # NOTE(review): its axis label/ticks describe dropout rates, not the
    # learning rates swept above; update before re-enabling.

    # plt.figure(figsize=(4, 5))
    # modellist = droputratelist
    # ax = plt.subplot(111)
    # ax.set_ylim([-0.05, 1.05])
    # plt.figtext(0.5, 0.9, datasetname, fontsize=25, ha='center')
    # plt.tick_params(axis='x', labelsize=20)
    # plt.tick_params(axis='y', labelsize=20)
    # plt.xticks([0.0, 0.2, 0.4, 0.6, 0.8], [0.8, 0.6, 0.4, 0.2, 0.0])
    # plt.xlabel("dropout rate", fontsize=18)

    # plt.plot(modellist, pvlist, 'o-', color="r", label="PV") # draw PV
    # plt.plot(modellist, cvlist, 'o-', color="g", label='Validation accuracy') # draw CV
    # plt.plot(modellist, testlist, 'o-', color="black", label='Test accuracy') # draw test accuracy

    # plt.legend(fontsize=18)
    # plt.savefig("cifar.pdf")

    # plt.show()



if __name__ == "__main__":
    # Script entry point: run the Fashion-MNIST learning-rate sweep.
    draw_Figure5("fashionmnist-learningrate")





