import os
import random

import matplotlib.pyplot as plt
import numpy as np
from keras import initializers
from keras.layers import Dense, Input
from keras.models import Model, load_model
from keras.optimizers import SGD
from sklearn.cluster import KMeans


class Clusterer:
    """Base class holding an image set plus helpers shared by clusterers.

    ``__init__`` stores ``images``, ``image_ids`` and ``n_clusters``.
    ``getImages`` additionally reads ``self.examples``, ``self.num_numbers``
    and ``self.num_digits_per_nr``; this class never assigns them, so they
    must be provided on the instance before ``getImages`` is called.
    """

    def __init__(self, images, image_ids, n_clusters):
        """Store the images, their ids and the requested number of clusters."""
        self.images = images
        self.image_ids = image_ids
        self.n_clusters = n_clusters

    def getImages(self, exampleIds):
        """Collect the distinct images belonging to the selected examples.

        Each example ``ex`` is read as ``ex[0][n][m]`` (the image) and
        ``ex[2][n][m]`` (that image's id) for ``n < self.num_numbers`` and
        ``m < self.num_digits_per_nr``.

        Arguments:
            exampleIds: container of positions into ``self.examples``; only
                examples whose position is ``in exampleIds`` are visited.
        return:
            (images, idxs): a numpy array of the images and the list of their
            ids, both in first-seen order, each id appearing once.
        """
        images = []
        idxs = []
        # Set mirror of ``idxs`` so the duplicate check does not rescan the
        # whole list for every image. Requires the ids to be hashable.
        seen = set()
        for i, ex in enumerate(self.examples):
            if i not in exampleIds:
                continue
            for n in range(self.num_numbers):
                for m in range(self.num_digits_per_nr):
                    image_id = ex[2][n][m]
                    if image_id in seen:
                        continue
                    seen.add(image_id)
                    images.append(ex[0][n][m])
                    idxs.append(image_id)
        return np.array(images), idxs

    def savefigs(self, images, labels, n_clusters, image_shape=(28, 28)):
        """Save one grid of randomly chosen member images per cluster.

        For every cluster ``i`` a 5x5 figure is written to
        ``./kmeans_cluster_<i>.png``. Up to 25 members are drawn at random
        without replacement; when a cluster has fewer members, all of them
        are shown and the remaining panels are left blank.

        Arguments:
            images: indexable collection of images, aligned with ``labels``.
            labels: cluster index (``0 <= label < n_clusters``) of each image.
            n_clusters: number of clusters, i.e. number of figures written.
            image_shape: 2-D shape every image is reshaped to before drawing.
        """
        nrows = 5
        ncols = 5

        # Bucket the images by their cluster label.
        members_by_cluster = [[] for _ in range(n_clusters)]
        for position, label in enumerate(labels):
            members_by_cluster[label].append(images[position])

        for i, members in enumerate(members_by_cluster):
            fig, axes = plt.subplots(nrows=nrows, ncols=ncols, sharex=True)
            try:
                fig.set_figheight(15)
                fig.set_figwidth(15)

                # random.sample raises ValueError when asked for more items
                # than the population holds, so cap the request.
                n_shown = min(len(members), nrows * ncols)
                chosen = random.sample(range(len(members)), n_shown)

                panels = axes.ravel()  # row-major: fills row by row
                for panel, member_idx in zip(panels, chosen):
                    panel.imshow(np.reshape(members[member_idx], image_shape))
                for panel in panels[n_shown:]:
                    panel.set_axis_off()

                fig.savefig('./kmeans_cluster_' + str(i) + '.png')
            finally:
                # pyplot keeps every figure alive until it is closed.
                plt.close(fig)


class ClustererKMeansWithAutoEncoder(Clusterer):
    """Cluster images with k-means on features from a dense auto-encoder."""

    def autoencoder(self, dims, act='relu', init='glorot_uniform'):
        """
        Fully connected auto-encoder model, symmetric.
        Arguments:
            dims: list of number of units in each layer of encoder. dims[0] is input dim, dims[-1] is units in hidden layer.
                The decoder mirrors the encoder, giving 2*(len(dims)-1) Dense layers in total.
            act: activation of the internal layers; the hidden (bottleneck) layer and the output layer
                are created without an activation argument.
            init: kernel initializer passed to every Dense layer.
        return:
            (ae_model, encoder_model), Model of autoencoder and model of encoder. Both share the same
            input and encoder layers.
        """
        n_stacks = len(dims) - 1
        input_img = Input(shape=(dims[0],), name='input')

        # Encoder: every layer but the last uses ``act``.
        x = input_img
        for i in range(n_stacks - 1):
            x = Dense(dims[i + 1], activation=act, kernel_initializer=init, name='encoder_%d' % i)(x)

        # Hidden layer: features are extracted from here.
        encoded = Dense(dims[-1], kernel_initializer=init, name='encoder_%d' % (n_stacks - 1))(x)

        # Decoder: walk the encoder widths backwards down to dims[1].
        x = encoded
        for i in range(n_stacks - 1, 0, -1):
            x = Dense(dims[i], activation=act, kernel_initializer=init, name='decoder_%d' % i)(x)

        # Output layer restores the input dimensionality.
        decoded = Dense(dims[0], kernel_initializer=init, name='decoder_0')(x)

        ae_model = Model(inputs=input_img, outputs=decoded, name='AE')
        encoder_model = Model(inputs=input_img, outputs=encoded, name='encoder')
        return ae_model, encoder_model

    def _load_or_train_encoder(self, X, results_dir):
        """Return the cached encoder from ``results_dir``, or train and cache a new one on ``X``."""
        model_path = os.path.join(results_dir, 'encoder_model.h5')
        weights_path = os.path.join(results_dir, 'encoder_weights.h5')

        # Reuse a previous run only when both the model and its weights exist.
        if os.path.exists(model_path) and os.path.exists(weights_path):
            encoder = load_model(model_path)
            encoder.load_weights(weights_path)
            return encoder

        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(results_dir, exist_ok=True)

        dims = [X.shape[-1], 500, 500, 2000, 10]
        init = initializers.VarianceScaling(scale=1. / 3., mode='fan_in',
                                            distribution='uniform')
        pretrain_optimizer = SGD(learning_rate=1, momentum=0.9)
        pretrain_epochs = 300
        batch_size = 256

        autoencoder, encoder = self.autoencoder(dims, init=init)
        autoencoder.compile(optimizer=pretrain_optimizer, loss='mse')
        # The auto-encoder is trained to reconstruct its own input.
        autoencoder.fit(X, X, batch_size=batch_size, epochs=pretrain_epochs)

        autoencoder.save_weights(os.path.join(results_dir, 'ae_weights.h5'))
        encoder.save(model_path)
        encoder.save_weights(weights_path)
        return encoder

    def cluster(self, savefigs=False):
        """Encode ``self.images`` and cluster the encodings with k-means.

        The images are flattened to one row each and divided by 255
        (presumably 8-bit pixel intensities -- confirm against the caller).
        An encoder cached under ``./results`` is reused when present;
        otherwise an auto-encoder is trained and its encoder cached there.

        Side effects: sets ``self.encoder`` and ``self.cluster_centers``,
        may write model files under ``./results`` and, when ``savefigs`` is
        truthy, writes one sample figure per cluster via ``self.savefigs``.

        Arguments:
            savefigs: when truthy, also save per-cluster sample figures.
        return:
            dict mapping every id in ``self.image_ids`` to the k-means label
            of the image at the same position.
        """
        X = self.images.reshape(len(self.images), -1)
        X = X.astype(float) / 255.

        encoder = self._load_or_train_encoder(X, './results')
        self.encoder = encoder

        # KMeans clustering on the encoder's output features.
        kmeans = KMeans(n_clusters=self.n_clusters, n_init=20)
        kmeans.fit(encoder.predict(X))
        labels = kmeans.labels_

        self.cluster_centers = kmeans.cluster_centers_

        cluster_ids = {idx: labels[i] for i, idx in enumerate(self.image_ids)}

        if savefigs:
            self.savefigs(self.images, labels, self.n_clusters)

        return cluster_ids
