import copy
from keras.datasets import mnist

from cluster_label_predictor import *
from image_label_predictor import *
from data_generation import *
from trainer_with_predicted_labels import Trainer as twpl
from trainer_cnn_predicted_labels import Trainer as tcpl
import time
import numpy as np

class NNumbersAddition:
    """Experiment driver built on MNIST.

    One call to :meth:`run` performs these steps:

    * generate examples with ``generateDataset``;
    * cluster the selected training images;
    * predict cluster labels and then per-image labels;
    * train a classifier on the predicted image labels and test it on the MNIST
      test split;
    * append a CSV summary line to ``./results/output.txt``.
    """

    def __init__(self, num_examples, n_clusters):
        """Load MNIST and store the experiment configuration.

        Args:
            num_examples: number of examples requested from ``generateDataset``
                for both the training and the test split.
            n_clusters: number of clusters passed to the clusterer, the cluster
                label predictor and the trainer.
        """
        (self.train_X, self.train_Y), (self.test_X, self.test_Y) = mnist.load_data()

        self.num_examples = num_examples
        self.n_clusters = n_clusters
        # NOTE(review): not read by run(), which takes its own num_epochs argument.
        self.epochs = 1

    def get_unique_images(self, image_ids):
        """Return the training images at ``image_ids`` stacked into a NumPy array.

        Images are gathered in the iteration order of ``image_ids``. No
        de-duplication happens here; presumably ``generateDataset`` already returns
        unique ids (TODO confirm).
        """
        return np.array([self.train_X[image_id] for image_id in image_ids])

    def run(self, num_numbers, num_digits_per_nr, num_epochs=10):
        """Run one full experiment and append a CSV summary line to ./results/output.txt.

        Args:
            num_numbers: forwarded to ``generateDataset`` and used to build the trainer.
            num_digits_per_nr: forwarded to ``generateDataset``.
            num_epochs: number of epochs passed to ``trainer.train``
                (defaults to the previously hard-coded 10).

        The CSV columns are: num_numbers, num_digits_per_nr, num_examples,
        initial/fixed/final image-label accuracy, classifier accuracy, clustering
        time, cluster-label time, image-label time, classifier training time, and
        total training time.
        """
        print('num_numbers = {}, num_digits_per_nr = {}, num_examples = {}. '.format(
            num_numbers, num_digits_per_nr, self.num_examples))

        # --- dataset generation ---------------------------------------------------
        print("generating dataset")
        # perf_counter is monotonic, so intervals are immune to wall-clock changes.
        start_time = time.perf_counter()
        self.examples, self.image_ids = generateDataset(
            self.train_X, self.train_Y, self.num_examples, num_numbers, num_digits_per_nr)
        unique_images = self.get_unique_images(self.image_ids)
        print('dataset generation time: {}'.format(round(time.perf_counter() - start_time, 2)))
        # NOTE(review): test_examples/test_image_ids are not used within run();
        # the trainer is evaluated on test_X/test_Y directly.
        self.test_examples, self.test_image_ids = generateDataset(
            self.test_X, self.test_Y, self.num_examples, num_numbers, num_digits_per_nr)

        # --- clustering -----------------------------------------------------------
        print('clustering')
        start_time = time.perf_counter()
        clusterer = ClustererKMeansWithAutoEncoder(unique_images, self.image_ids, self.n_clusters)
        cluster_ids = clusterer.cluster(savefigs=False)
        clustering_time = round(time.perf_counter() - start_time, 2)
        print('clustering time: {}'.format(clustering_time))

        image_label_predictor = ImageLabelPredictor(
            self.examples, self.train_X, clusterer.cluster_centers, clusterer.encoder)

        # --- cluster label prediction: initial guess, then fix-up pass --------------
        print('predicting cluster labels')
        start_time = time.perf_counter()
        cluster_label_predictor = ClusterLabelPredictorCVXPYMajority(self.n_clusters, self.examples)
        cluster_labels, _loss, _solver_time = cluster_label_predictor.predict_cluster_labels(cluster_ids)

        print(cluster_labels)
        image_label_predictor.init_image_labels(cluster_ids, cluster_labels)
        initial_label_accuracy = image_label_predictor.get_image_label_accuracy(self.train_Y)
        print('image label accuracy: {}%'.format(initial_label_accuracy))

        cluster_labels, _loss = cluster_label_predictor.fix_cluster_labels(cluster_ids, cluster_labels)
        print(cluster_labels)
        image_label_predictor.init_image_labels(cluster_ids, cluster_labels)
        fixed_label_accuracy = image_label_predictor.get_image_label_accuracy(self.train_Y)
        print('image label accuracy: {}%'.format(fixed_label_accuracy))

        cluster_label_prediction_time = round(time.perf_counter() - start_time, 2)
        print('cluster label prediction time: {}'.format(cluster_label_prediction_time))

        # --- per-image label refinement ---------------------------------------------
        start_time = time.perf_counter()
        image_label_predictor.improve_image_labels(cluster_ids)
        final_label_accuracy = image_label_predictor.get_image_label_accuracy(self.train_Y)
        print('image label accuracy: {}%'.format(final_label_accuracy))
        image_label_prediction_time = round(time.perf_counter() - start_time, 2)
        print('image label prediction time: {}'.format(image_label_prediction_time))

        # --- classifier training and evaluation -------------------------------------
        print('training classifier')
        start_time = time.perf_counter()
        trainer = tcpl(num_numbers, self.n_clusters, self.examples, self.image_ids, self.train_X)
        trainer.train(image_label_predictor.image_labels, num_epochs)
        # Rounded to whole seconds, unlike the 2-decimal timings above.
        classifier_training_time = round(time.perf_counter() - start_time, 0)
        print('testing')
        classifier_accuracy = trainer.test(self.test_X, self.test_Y)
        # Excludes dataset generation time and test time.
        total_training_time = (clustering_time + cluster_label_prediction_time
                               + image_label_prediction_time + classifier_training_time)

        print(' accuracy = {}%, training time = {}s.'.format(classifier_accuracy, classifier_training_time))
        output = '{},{},{},{},{},{},{},{},{},{},{},{}\n'.format(num_numbers, num_digits_per_nr, self.num_examples,
                                                                initial_label_accuracy, fixed_label_accuracy,
                                                                final_label_accuracy, classifier_accuracy,
                                                                clustering_time, cluster_label_prediction_time,
                                                                image_label_prediction_time, classifier_training_time,
                                                                total_training_time)
        print(output)
        self._append_result(output)

    @staticmethod
    def _append_result(line, path="./results/output.txt"):
        """Append ``line`` to ``path``, creating the parent directory if it is missing.

        Using ``with`` guarantees the file is closed even if the write fails.
        """
        import os

        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(path, "a", encoding="utf-8") as f:
            f.write(line)

    def get_image_label_accuracy(self, image_labels, correctLabels):
        """Return the percentage (rounded to 2 decimals) of predicted labels that are correct.

        ``image_labels`` is iterated for ids and then indexed with those ids, so it
        is expected to be a mapping of image id -> predicted label. ``correctLabels``
        must be indexable by the same ids.

        Returns 0.0 for an empty ``image_labels`` instead of raising ZeroDivisionError.
        """
        total = len(image_labels)
        if total == 0:
            return 0.0
        correct = sum(1 for image_id in image_labels
                      if image_labels[image_id] == correctLabels[image_id])
        return round(correct / total * 100, 2)

if __name__ == "__main__":
    # Sweep: 2..5 numbers per example x 1..9 digits per number, 3 repeats each.
    # The per-run example count shrinks as num_numbers * digits grows.
    TOTAL_IMAGES = 60000
    REPEATS = 3
    for num_numbers in range(2, 6):
        for digits in range(1, 10):
            examples_per_run = TOTAL_IMAGES // (digits * num_numbers)
            for _repeat in range(REPEATS):
                # A fresh instance per repeat (reloads MNIST), as in the original loop.
                experiment = NNumbersAddition(num_examples=examples_per_run, n_clusters=10)
                experiment.run(num_numbers, digits)