import matplotlib.pyplot as plt
import keras
from keras.datasets import fashion_mnist,cifar10
from keras import backend as K
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten,Activation
from keras.layers import Conv2D, MaxPooling2D
import sys
sys.path.append('../')
from aaa_issta import lib_deep
from sklearn import model_selection



def get_data(datasetname):
    '''
    Description
        This function returns the train / hold-out data of an image dataset,
        selected by a substring of its name.

    Parameters
        datasetname: a name containing 'mnist' (loads Fashion-MNIST) or
                     'cifar' (loads CIFAR-10), e.g. 'cifar10-learningrate'.
                     If both substrings are present, 'cifar' takes precedence.

    Returns
        X_train,y_train,X_hold,y_hold
            X_* are the images cast to float32 and divided by 255;
            y_* are the labels converted with keras.utils.to_categorical
            using 10 classes.

    Raises
        ValueError: if datasetname contains neither 'mnist' nor 'cifar'.
    '''
    num_classes = 10

    # 'cifar' is checked first so that a name containing both substrings
    # yields CIFAR-10 without loading Fashion-MNIST only to discard it.
    if 'cifar' in datasetname:
        # The data, split between train and test sets:
        (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    elif 'mnist' in datasetname:
        # input image dimensions
        img_rows, img_cols = 28, 28
        # the data, split between train and test sets
        (x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

        # add the single channel axis where the backend expects it
        if K.image_data_format() == 'channels_first':
            sample_shape = (1, img_rows, img_cols)
        else:
            sample_shape = (img_rows, img_cols, 1)
        x_train = x_train.reshape((x_train.shape[0],) + sample_shape)
        x_test = x_test.reshape((x_test.shape[0],) + sample_shape)
    else:
        # fail early with a clear message instead of an UnboundLocalError
        raise ValueError(
            "Unknown dataset name %r: expected it to contain 'mnist' or 'cifar'"
            % (datasetname,))

    # scale pixel values by 1/255
    X_train = x_train.astype('float32')
    X_hold = x_test.astype('float32')
    X_train /= 255
    X_hold /= 255

    # convert class vectors to binary class matrices
    y_train = keras.utils.to_categorical(y_train, num_classes)
    y_hold = keras.utils.to_categorical(y_test, num_classes)

    return X_train,y_train,X_hold,y_hold



def create_model_cifar(dropoutrate,x_train):
    '''
    Description
        This function builds and compiles a small convolutional classifier:
        two (Conv-ReLU-Conv-ReLU-MaxPool-Dropout) stages followed by a
        512-unit dense layer and a 10-way softmax output.

    Parameters
        dropoutrate: despite its name, this value is passed to RMSprop as
                     the learning rate (lr); the Dropout layers use the
                     fixed rates 0.25 / 0.25 / 0.5.
        x_train: training inputs; only x_train.shape[1:] is read, as the
                 input shape of the first convolution.

    Returns
        the compiled keras Sequential model
    '''
    sample_shape = x_train.shape[1:]

    # Layers are listed in the exact order they are stacked.
    stack = [
        # first convolutional stage (32 filters)
        Conv2D(32, (3, 3), padding='same', input_shape=sample_shape),
        Activation('relu'),
        Conv2D(32, (3, 3)),
        Activation('relu'),
        MaxPooling2D(pool_size=(2, 2)),
        Dropout(0.25),
        # second convolutional stage (64 filters)
        Conv2D(64, (3, 3), padding='same'),
        Activation('relu'),
        Conv2D(64, (3, 3)),
        Activation('relu'),
        MaxPooling2D(pool_size=(2, 2)),
        Dropout(0.25),
        # classifier head
        Flatten(),
        Dense(512),
        Activation('relu'),
        Dropout(0.5),
        Dense(10),
        Activation('softmax'),
    ]

    net = Sequential()
    for layer in stack:
        net.add(layer)

    # RMSprop optimizer; the first argument of this function is its learning rate
    rmsprop = keras.optimizers.RMSprop(lr=dropoutrate, decay=1e-6)

    net.compile(optimizer=rmsprop,
                loss='categorical_crossentropy',
                metrics=['accuracy'])

    return net



def draw_Figure5(datasetname):
    '''
    Description
        This function collects the metric values for one dataset. It loads
        the data, holds out 20% of the training set for validation, calls
        lib_deep.get_MV_CV_and_test_deep_sametrain once per value of the
        swept hyper-parameter, and appends the collected 'pv', 'cv' and
        'test' values to './../results/deeplearning_<datasetname>.csv'.
        Nothing is plotted here.

    Parameters
        datasetname: a name accepted by get_data, i.e. containing 'mnist'
                     or 'cifar' (e.g. 'cifar10-learningrate'); it is also
                     used verbatim in the result file name.

    Returns
        None; the results are appended to the result file.
    '''
    resultpath = './../results/deeplearning_'+str(datasetname)+'.csv'

    X_train_o, y_train_o, X_hold, y_hold = get_data(datasetname)
    X_train, X_valid, y_train, y_valid = model_selection.train_test_split(X_train_o, y_train_o, test_size=0.2)

    pvlist = []
    cvlist = []
    testlist = []

    # NOTE(review): named "dropout rate" here, but create_model_cifar feeds
    # its first argument to RMSprop as the learning rate -- presumably these
    # are learning rates; confirm against lib_deep.
    droputratelist = [0.01,0.001,0.0001,0.00001]
    epoch = 50

    for i in droputratelist: # iterate over each swept value

        print('Dropout rate: ' + str(i))

        dic_metric_value = lib_deep.get_MV_CV_and_test_deep_sametrain(create_model_cifar, X_train, y_train, X_valid,
                                                                      y_valid, X_hold, y_hold,
                                                                      epoch, i, resultpath)
        pvlist.append(dic_metric_value['pv'])
        testlist.append(dic_metric_value['test'])
        cvlist.append(dic_metric_value['cv'])

    # Append the three lists as "name=[v1,v2,...,]" lines; the context
    # manager guarantees the file is closed even if a write fails.
    with open(resultpath, 'a') as f:
        for label, values in (('pvlist', pvlist),
                              ('cvlist', cvlist),
                              ('testlist', testlist)):
            f.write(label + '=[')
            f.write(''.join("%s," % item for item in values))
            f.write(']\n')





if __name__ == "__main__":
    # Script entry point: run the sweep on CIFAR-10 (the name contains
    # 'cifar', which is what get_data matches on).
    draw_Figure5(datasetname='cifar10-learningrate')






