from config import *
import tensorflow as tf
import csv


def _read_csv_rows(path_data, parse_cell):
    """Read every non-blank CSV row of ``path_data``, converting each cell with ``parse_cell``."""
    # The context manager guarantees the handle is closed even if a cell fails to parse.
    # newline='' is the opening mode recommended by the csv module documentation.
    with open(path_data, newline='') as handle:
        # csv.reader yields [] for blank lines; keeping them would make the array ragged.
        return [[parse_cell(cell) for cell in row] for row in csv.reader(handle) if row]


def load_data(path_data):
    """Load a labelled test set from a CSV file for the experiment selected by ``ex_mode``.

    Each CSV row is ``label, feature_1, ..., feature_n``. The globals ``ex_mode``,
    ``mean`` and ``std`` are not defined in this file (presumably they come from
    ``from config import *``).

    Handling per ``ex_mode``:
        1: cells parsed with ``int``; features divided by 255, kept flat.
        2: cells parsed with ``int``; features reshaped to ``(N, 28, 28, 1)`` and divided by 255.
        3, 4: cells parsed with ``int(float(...))``; features reshaped to
            ``(N, 32, 32, 3)`` and divided by 255.
        5: cells parsed with ``float``; features used as-is.

    Args:
        path_data: Path of the CSV file to read.

    Returns:
        Tuple ``(x_test, x_test_n, y_test)`` where ``x_test`` holds the features,
        ``x_test_n`` is ``(x_test - mean) / std`` and ``y_test`` is the first column.

    Raises:
        ValueError: If ``ex_mode`` is not one of 1-5.
    """
    # ex_mode -> (cell parser, per-sample image shape or None, divide by 255?)
    if ex_mode == 1:  # MNIST FFNN
        parse_cell, sample_shape, scale_pixels = int, None, True
    elif ex_mode == 2:  # MNIST CNN
        parse_cell, sample_shape, scale_pixels = int, (28, 28, 1), True
    elif ex_mode in (3, 4):  # CIFAR10 FNN / CNN
        # These files may store integers in float notation (e.g. "12.0"),
        # which int() alone rejects.
        def parse_cell(cell):
            return int(float(cell))
        sample_shape, scale_pixels = (32, 32, 3), True
    elif ex_mode == 5:
        parse_cell, sample_shape, scale_pixels = float, None, False
    else:
        raise ValueError(
            "load_data does not support ex_mode={!r}; expected one of 1-5".format(ex_mode))

    rows = np.asarray(_read_csv_rows(path_data, parse_cell))
    print(np.shape(rows))

    y_test = rows[:, 0]
    x_test = rows[:, 1:]
    if sample_shape is not None:
        # -1 lets NumPy infer the number of samples instead of requiring exactly 100 rows.
        x_test = x_test.reshape((-1,) + sample_shape)
    if scale_pixels:
        x_test = x_test / np.float64(255)
    x_test_n = (x_test - mean) / std

    return x_test, x_test_n, y_test


def load_net(path_net):
    """Build or load the network for the experiment selected by ``ex_mode``.

    ``ex_mode`` is a module global that is not defined in this file
    (presumably it comes from ``from config import *``).

    Modes 0 and 5 load a complete saved model with
    ``tf.keras.models.load_model``. Modes 1-4 rebuild a fixed ``Sequential``
    architecture in code and then read its weights from ``path_net``:
        1: MNIST FNN, input ``(784,)``, 7 x Dense(200, relu), Dense(10).
        2: MNIST CNN, input ``(28, 28, 1)``, two padded stride-2 convolutions,
           5 x Dense(256, relu), Dense(10).
        3: CIFAR FNN, input ``(32, 32, 3)``, 7 x Dense(200, relu), Dense(10).
        4: CIFAR10 CNN, input ``(32, 32, 3)``, four padded convolutions,
           5 x Dense(256, relu), Dense(10).

    Args:
        path_net: Path of the saved model (modes 0, 5) or weight file (modes 1-4).

    Returns:
        The loaded Keras model.

    Raises:
        ValueError: If ``ex_mode`` is not one of 0-5.
    """
    if ex_mode in (0, 5):
        return tf.keras.models.load_model(path_net)
    if ex_mode not in (1, 2, 3, 4):
        # Fail here with a clear message rather than with an unbound local at the return.
        raise ValueError(
            "load_net does not support ex_mode={!r}; expected one of 0-5".format(ex_mode))

    layers = tf.keras.layers
    relu = tf.nn.relu

    # Layers are instantiated in the same order as the final stack.
    if ex_mode == 1:  # MNIST FNN
        stack = [layers.Dense(200, input_shape=(784,), activation=relu)]
        stack += [layers.Dense(200, activation=relu) for _ in range(6)]

    elif ex_mode == 2:  # MNIST CNN
        stack = [
            layers.ZeroPadding2D(padding=(1, 1), input_shape=(28, 28, 1)),
            layers.Conv2D(4, 4, strides=(2, 2), padding='valid', activation=relu),
            layers.ZeroPadding2D(padding=(1, 1)),
            layers.Conv2D(8, 4, strides=(2, 2), padding='valid', activation=relu),
            layers.Permute((3, 1, 2)),
            layers.Flatten(),
        ]
        stack += [layers.Dense(256, activation=relu) for _ in range(5)]

    elif ex_mode == 3:  # CIFAR FNN
        stack = [
            layers.InputLayer(input_shape=(32, 32, 3)),
            layers.Permute((3, 1, 2)),
            layers.Flatten(),
        ]
        stack += [layers.Dense(200, activation=relu) for _ in range(7)]

    else:  # ex_mode == 4, CIFAR10 CNN
        stack = [
            layers.ZeroPadding2D(padding=(6, 6), input_shape=(32, 32, 3)),
            layers.Conv2D(4, 13, strides=(1, 1), padding='valid', activation=relu),
            layers.ZeroPadding2D(padding=(1, 1)),
            layers.Conv2D(4, 4, strides=(2, 2), padding='valid', activation=relu),
            layers.ZeroPadding2D(padding=(1, 1)),
            layers.Conv2D(8, 3, strides=(1, 1), padding='valid', activation=relu),
            layers.ZeroPadding2D(padding=(1, 1)),
            layers.Conv2D(8, 4, strides=(2, 2), padding='valid', activation=relu),
            layers.Permute((3, 1, 2)),
            layers.Flatten(),
        ]
        stack += [layers.Dense(256, activation=relu) for _ in range(5)]

    # Every in-code architecture ends with a 10-unit Dense layer without activation.
    stack.append(layers.Dense(10))

    modelNN = tf.keras.models.Sequential(stack)
    modelNN.load_weights(path_net)
    return modelNN


def load_net_2(path_net):
    """Build or load the small second network for the experiment selected by ``ex_mode``.

    ``ex_mode`` is a module global that is not defined in this file
    (presumably it comes from ``from config import *``).

    Modes 0 and 5 load a complete saved model with
    ``tf.keras.models.load_model``. Modes 1-4 rebuild a network with one
    hidden layer (Dense(20, relu) followed by Dense(10)) and read its weights
    from ``path_net``:
        1: MNIST FNN, input ``(784,)``.
        2: MNIST CNN data, input ``(28, 28, 1)`` flattened.
        3, 4: CIFAR10, input ``(32, 32, 3)`` permuted with ``(3, 1, 2)`` and flattened.

    Args:
        path_net: Path of the saved model (modes 0, 5) or weight file (modes 1-4).

    Returns:
        The loaded Keras model.

    Raises:
        ValueError: If ``ex_mode`` is not one of 0-5.
    """
    if ex_mode in (0, 5):
        return tf.keras.models.load_model(path_net)
    if ex_mode not in (1, 2, 3, 4):
        # Fail here with a clear message rather than with an unbound local at the return.
        raise ValueError(
            "load_net_2 does not support ex_mode={!r}; expected one of 0-5".format(ex_mode))

    layers = tf.keras.layers
    relu = tf.nn.relu

    # Layers are instantiated in the same order as the final stack.
    if ex_mode == 1:  # MNIST FNN
        stack = [layers.Dense(20, input_shape=(784,), activation=relu)]
    elif ex_mode == 2:  # MNIST CNN
        stack = [
            layers.Flatten(input_shape=(28, 28, 1)),
            layers.Dense(20, activation=relu),
        ]
    else:  # ex_mode 3 or 4, CIFAR10
        stack = [
            layers.Permute((3, 1, 2), input_shape=(32, 32, 3)),
            layers.Flatten(),
            layers.Dense(20, activation=relu),
        ]
    stack.append(layers.Dense(10))

    modelNN = tf.keras.models.Sequential(stack)
    modelNN.load_weights(path_net)
    return modelNN


def model_properties(model):
    """Compile ``model`` and describe its layers as position-indexed tables.

    The model is compiled (sgd, SparseCategoricalCrossentropy, accuracy) and its
    summary is printed as a side effect. Layers are then visited in order; every
    layer except Dropout takes the next 1-based slot in the returned dicts.

    Args:
        model: Keras model exposing ``layers``, ``get_weights()``, ``compile()``
            and ``summary()``.

    Returns:
        Tuple ``(W, layer_type, layer_activation, n_neu, n_neu_cum)`` of dicts:
            W: per slot, the layer parameters. Dense: ``[kernel, bias]``;
                Conv1D/Conv2D: ``[kernel, bias, [strides], [padding]]``;
                MaxPooling: ``[[pool], [], [strides], [padding]]``;
                ZeroPadding2D: ``[[padding[0]], [padding[1]]]``;
                Flatten/Permute: ``[[], []]``.
            layer_type: per slot, the layer class name.
            layer_activation: per slot, the activation function name, or 'none'.
            n_neu: per slot, a list of flattened output sizes; slot 0 is the
                flattened input size. Layers with an activation list the size
                twice (presumably pre- and post-activation; confirm with callers).
            n_neu_cum: running totals matching ``n_neu``.

    Raises:
        Exception: For an unsupported layer class, or a MaxPooling layer whose
            pool size differs from its strides or whose padding is not 'valid'.
    """
    model.compile(optimizer='sgd', loss='SparseCategoricalCrossentropy', metrics=['accuracy'])
    weights = model.get_weights()
    model.summary()

    layers = model.layers
    input_size = np.prod(list(layers[0].input_shape)[1:])
    n_neu = {0: [input_size]}
    n_neu_cum = {0: [input_size]}
    W = {}
    layer_type = {}
    layer_activation = {}

    pooling_error = (
        "Sorry, this framework only supports neural networks that has same size for pooling and striding and 'valid' padding")

    def flat_size(layer):
        # Number of output values per sample (batch axis dropped).
        return np.prod(list(layer.output_shape)[1:])

    def add_counts(position, size, paired):
        # Append this layer's sizes after the previous slot's running total.
        offset = n_neu_cum[position - 1][-1]
        if paired:
            n_neu[position] = [size, size]
            n_neu_cum[position] = [offset + size, offset + np.sum(n_neu[position])]
        else:
            n_neu[position] = [size]
            n_neu_cum[position] = [offset + size]

    slot = 1       # next free key in the output dicts
    cursor = 0     # index of the next unread array in `weights`
    final_index = len(layers) - 1

    for index, layer in enumerate(layers):
        kind = layer.__class__.__name__

        if kind == 'Dropout':
            # Dropout takes no slot and consumes no weights.
            continue

        if kind == 'Dense':
            size = flat_size(layer)
            W[slot] = [weights[cursor], weights[cursor + 1]]
            layer_type[slot] = 'Dense'
            if index == final_index:
                # The closing Dense layer is always recorded without an activation.
                layer_activation[slot] = 'none'
                add_counts(slot, size, paired=False)
            else:
                layer_activation[slot] = layer.activation.__name__
                add_counts(slot, size, paired=layer_activation[slot] != 'none')
                cursor += 2
                slot += 1

        elif kind in ('Conv1D', 'Conv2D'):
            size = flat_size(layer)
            W[slot] = [weights[cursor], weights[cursor + 1], [layer.strides], [layer.padding]]
            layer_type[slot] = kind
            layer_activation[slot] = layer.activation.__name__
            # NOTE(review): convolutions compare against 'None' while Dense compares
            # against 'none'; kept as-is, confirm which spelling is intended.
            add_counts(slot, size, paired=layer_activation[slot] != 'None')
            cursor += 2
            slot += 1

        elif kind == 'MaxPooling1D':
            if layer.pool_size[0] != layer.strides[0] or layer.padding != 'valid':
                raise Exception(pooling_error)
            size = flat_size(layer)
            W[slot] = [[layer.pool_size[0]], [], [layer.strides[0]], [layer.padding]]
            layer_type[slot] = 'MaxPooling1D'
            layer_activation[slot] = 'none'
            add_counts(slot, size, paired=False)
            slot += 1

        elif kind == 'MaxPooling2D':
            if layer.pool_size != layer.strides or layer.padding != 'valid':
                raise Exception(pooling_error)
            size = flat_size(layer)
            W[slot] = [[layer.pool_size], [], [layer.strides], [layer.padding]]
            layer_type[slot] = 'MaxPooling2D'
            layer_activation[slot] = 'none'
            add_counts(slot, size, paired=False)
            slot += 1

        elif kind in ('Flatten', 'Permute'):
            size = flat_size(layer)
            W[slot] = [[], []]
            layer_type[slot] = kind
            layer_activation[slot] = 'none'
            add_counts(slot, size, paired=False)
            slot += 1

        elif kind == 'ZeroPadding2D':
            size = flat_size(layer)
            # Keras documents `padding` as ((top, bottom), (left, right)).
            W[slot] = [[layer.padding[0]], [layer.padding[1]]]
            layer_type[slot] = 'ZeroPadding2D'
            layer_activation[slot] = 'none'
            add_counts(slot, size, paired=False)
            slot += 1

        else:
            raise Exception(
                "Sorry, this framework only supports Dense, Conv1D, MaxPooling1D, Conv2D, MaxPooling2D, "
                "ZeroPadding2D, Flatten, Permute, and Dropout layers.")

    return W, layer_type, layer_activation, n_neu, n_neu_cum


def preprocess(layer_type, x):
    """Append a trailing axis of size 1 to ``x`` when the first layer is a convolution that needs it.

    Args:
        layer_type: Mapping whose entry ``1`` names the first layer's class.
        x: Input batch; axis 0 is skipped when counting dimensions.

    Returns:
        ``x`` with a new last axis if the first layer is 'Conv1D' and each
        sample has 1 dimension, or 'Conv2D' and each sample has 2 dimensions;
        otherwise ``x`` itself, unchanged.
    """
    first_layer = layer_type[1]
    if first_layer == 'Conv1D':
        bare_sample_dims = 1
    elif first_layer == 'Conv2D':
        bare_sample_dims = 2
    else:
        return x

    sample_dims = len(np.shape(x)[1:])
    if sample_dims == bare_sample_dims:
        return np.expand_dims(x, axis=-1)
    return x
