import numpy as np
from keras.layers import Conv2D
from keras.layers import Dense
from keras.layers import Flatten
from keras.layers import MaxPooling2D
from keras.models import Sequential
from keras.optimizers import SGD


class MNIST_CNN:
    """Convolutional classifier for 28x28 single-channel images.

    The compiled Keras model is exposed as ``self.model``. It ends in a
    10-way softmax and is trained with sparse categorical cross-entropy,
    so targets are integer class ids rather than one-hot vectors.
    """

    def __init__(self):
        """Build and compile the network."""
        # Convolutional feature extractor: conv -> pool -> conv -> conv -> pool.
        conv_stack = [
            Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_uniform', input_shape=(28, 28, 1)),
            MaxPooling2D((2, 2)),
            Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_uniform'),
            Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_uniform'),
            MaxPooling2D((2, 2)),
        ]
        # Dense classification head on top of the flattened feature maps.
        dense_head = [
            Flatten(),
            Dense(100, activation='relu', kernel_initializer='he_uniform'),
            Dense(10, activation='softmax'),
        ]
        net = Sequential(conv_stack + dense_head)

        # SGD with momentum; accuracy is tracked alongside the loss.
        net.compile(
            optimizer=SGD(learning_rate=0.01, momentum=0.9),
            loss='SparseCategoricalCrossentropy',
            metrics=['accuracy'],
        )
        self.model = net


class Trainer:
    """Train and evaluate an ``MNIST_CNN`` on a selected subset of images.

    Labels are handled as integer class ids (not one-hot vectors), which is
    what the sparse categorical cross-entropy loss of the network expects.
    """

    def __init__(self, num_numbers, n_clusters, examples, image_ids, train_X):
        """Store the training configuration and build a fresh network.

        Args:
            num_numbers: Stored on the instance; not read by this class.
            n_clusters: Stored on the instance; not read by this class.
            examples: Stored on the instance; not read by this class.
            image_ids: Iterable of indices selecting which entries of
                ``train_X`` (and of the labels passed to ``train``) are used.
            train_X: Indexable collection of images; every image must hold
                28 * 28 values.
        """
        self.num_numbers = num_numbers
        self.n_clusters = n_clusters
        self.examples = examples
        self.image_ids = image_ids
        self.train_X = train_X
        self.network = MNIST_CNN()

    def train(self, cluster_labels, epochs):
        """Fit the network on the selected images.

        The trained model is kept in memory only; nothing is written to disk.

        Args:
            cluster_labels: Indexable collection of integer labels, indexed
                by the same ids as ``train_X``.
            epochs: Number of training epochs.
        """
        train_x, train_y = self.process_examples(cluster_labels)
        self.network.model.fit(train_x, train_y, epochs=epochs, batch_size=32, verbose=0)

    def process_examples(self, cluster_labels):
        """Gather the selected images and labels as arrays ready for training.

        Args:
            cluster_labels: Indexable collection of labels, indexed by the
                ids in ``self.image_ids``.

        Returns:
            A ``(train_x, train_y)`` tuple: ``train_x`` is a float32 array of
            shape ``(n, 28, 28, 1)`` scaled by 1/255, ``train_y`` is an array
            holding the ``n`` matching labels.
        """
        # Materialise the ids once so a one-shot iterable is still usable
        # for both the image and the label lookup.
        ids = list(self.image_ids)
        train_x = np.array([self.train_X[idx] for idx in ids])
        train_y = np.array([cluster_labels[idx] for idx in ids])
        return self.prep_images(train_x), train_y

    def prep_images(self, images):
        """Reshape images to ``(n, 28, 28, 1)`` float32 and divide by 255.

        Args:
            images: Array-like of ``n`` images with 28 * 28 values each
                (nested lists are accepted as well as arrays). The scaling
                assumes 8-bit pixel values (0-255) -- confirm for other data.

        Returns:
            A float32 array of shape ``(n, 28, 28, 1)``.
        """
        # Accept plain sequences too; ndarray input is passed through as-is.
        images = np.asarray(images)
        # Add the trailing channel axis expected by the first Conv2D layer.
        images = images.reshape((images.shape[0], 28, 28, 1))
        return images.astype('float32') / 255.0

    def test(self, test_X, test_Y):
        """Evaluate the network and print and return its accuracy.

        Args:
            test_X: Array-like of test images with 28 * 28 values each.
            test_Y: Array-like of integer labels, one per test image.

        Returns:
            Accuracy as a percentage rounded to two decimals.
        """
        test_x = self.prep_images(test_X)
        # evaluate() returns [loss, accuracy] because the model is compiled
        # with the single 'accuracy' metric.
        _, acc = self.network.model.evaluate(test_x, np.asarray(test_Y), verbose=0)
        acc = round(acc * 100.0, 2)
        # Report the real number of evaluated images instead of a fixed count.
        print('Accuracy of the network on the {} test images: {} '.format(len(test_x), acc))
        return acc
