import math
import numpy as np

class ImageLabelPredictor:
    """Assign cluster-derived labels to images and refine them with example targets.

    Images are referred to by ids that index ``train_x``. Only two fields of
    each element ``ex`` of ``examples`` are read here:

    * ``ex[2]`` -- a sequence of rows, each a sequence of image ids in which
      ``None`` marks an empty slot;
    * ``ex[1]`` -- a target value. The labels in a row are read as the base-10
      digits of a number (rightmost slot = units digit; empty slots still
      occupy a position), and the sum of those numbers over all rows is
      compared with ``ex[1]``.
    """

    def __init__(self, examples, train_x, cluster_centers, encoder):
        """Store the data the predictor works on.

        Args:
            examples: sequence of examples; see the class docstring for the
                fields that are read.
            train_x: indexable collection of images, indexed by image id.
            cluster_centers: indexable collection of centroids in the
                encoder's output space, indexed by cluster id.
            encoder: object with a ``predict`` method mapping a 2-D float
                array (one flattened image per row) to one point per row.
        """
        self.examples = examples
        self.train_x = train_x
        self.cluster_centers = cluster_centers
        self.encoder = encoder
        # image id -> current label. Filled by init_image_labels; starting it
        # empty keeps the other methods from raising AttributeError.
        self.image_labels = {}

    def init_image_labels(self, cluster_ids, cluster_labels):
        """Label every image referenced by ``examples`` with its cluster's label.

        Replaces any previously stored labels.

        Args:
            cluster_ids: image id -> cluster id lookup.
            cluster_labels: cluster id -> label lookup.
        """
        self.image_labels = {}
        for ex in self.examples:
            for row in ex[2]:
                for image_id in row:
                    if image_id is not None:
                        self.image_labels[image_id] = cluster_labels[cluster_ids[image_id]]

    def get_image_label_accuracy(self, correctLabels):
        """Return the percentage of labelled images whose label is correct.

        Args:
            correctLabels: image id -> ground-truth label lookup.

        Returns:
            Accuracy in percent, rounded to 2 decimals. 0.0 when no image is
            labelled (the ratio would otherwise be a division by zero).
        """
        if not self.image_labels:
            return 0.0
        correct = sum(
            1
            for image_id, label in self.image_labels.items()
            if label == correctLabels[image_id]
        )
        return round(correct / len(self.image_labels) * 100, 2)

    def get_image_as_point(self, imageIds):
        """Encode the given images into the encoder's output space.

        Each image is flattened and divided by 255 (presumably 8-bit pixel
        data -- confirm) before being passed to ``encoder.predict``.

        Args:
            imageIds: non-empty sequence of image ids (``reshape(0, -1)`` is
                rejected by NumPy, so an empty sequence raises ValueError).

        Returns:
            The result of ``encoder.predict``: one point per id, in order.
        """
        images = np.array([self.train_x[image_id] for image_id in imageIds])
        flat = images.reshape(len(images), -1).astype(float) / 255.
        return self.encoder.predict(flat)

    def euclidean_distance(self, point1, point2, dims=None):
        """Return the Euclidean distance between two points.

        Args:
            point1, point2: indexable sequences of numbers.
            dims: number of leading coordinates to compare. ``None`` (the
                default) uses every coordinate of ``point1``; pass ``dims=10``
                to reproduce the former behaviour of always using exactly the
                first 10 coordinates.
        """
        if dims is None:
            dims = len(point1)
        sum_squares = 0
        for i in range(dims):
            diff = point2[i] - point1[i]
            sum_squares += diff * diff
        return math.sqrt(sum_squares)

    def get_image_label_distance_from_centroid(self, cluster_ids, threshold, band=1):
        """Select labelled images lying in a distance band around their centroid.

        An image is selected when the distance between its encoded point and
        ``cluster_centers[cluster_ids[image_id]]`` lies in
        ``(threshold - band, threshold]``. Distances are never negative, so a
        band whose lower edge is <= 0 is closed at 0. Otherwise an image
        encoded exactly on its centroid could never be selected.

        Args:
            cluster_ids: image id -> cluster id lookup.
            threshold: upper (inclusive) edge of the band.
            band: width of the band (default 1).

        Returns:
            Set of selected image ids (empty when no image is labelled).
        """
        image_ids = list(self.image_labels)
        if not image_ids:
            return set()
        lower = threshold - band
        selected = set()
        points = self.get_image_as_point(image_ids)
        for image_id, point in zip(image_ids, points):
            centre = self.cluster_centers[cluster_ids[image_id]]
            distance = self.euclidean_distance(point, centre)
            if (lower <= 0 or distance > lower) and distance <= threshold:
                selected.add(image_id)
        return selected

    def _split_slots(self, rows, correct):
        """Return ``(partial_sum, unresolved)`` for one example's rows.

        ``partial_sum`` is the contribution of images already in ``correct``.
        ``unresolved`` lists ``(image_id, coeff)`` for the other images, where
        ``coeff`` is 10 raised to the slot's distance from the row's right end.
        """
        partial_sum = 0
        unresolved = []
        for row in rows:
            width = len(row)
            for m, image_id in enumerate(row):
                if image_id is None:
                    continue
                coeff = pow(10, width - 1 - m)
                if image_id in correct:
                    partial_sum += coeff * self.image_labels[image_id]
                else:
                    unresolved.append((image_id, coeff))
        return partial_sum, unresolved

    def fix_image_labels_new_new(self, correct):
        """Run one pass that uses each example's target to confirm or repair labels.

        For each example:

        * Exactly one untrusted slot: solve ``ex[1] = partial + label * coeff``.
          Only an exact integer solution in 0..9 is accepted. The image is
          relabelled and added to ``correct``. (A truncated, non-exact
          quotient would produce a label that does not satisfy the sum.)
        * Several untrusted slots: if their current labels already make the
          total equal ``ex[1]``, all of them are added to ``correct``.

        Args:
            correct: set of image ids whose labels are trusted; updated in
                place.

        Returns:
            True if any image id was added to ``correct`` during this pass.
        """
        changed = False
        for ex in self.examples:
            target = ex[1]
            partial_sum, unresolved = self._split_slots(ex[2], correct)
            if len(unresolved) == 1:
                image_id, coeff = unresolved[0]
                label, remainder = divmod(target - partial_sum, coeff)
                if remainder == 0 and 0 <= label <= 9:
                    self.image_labels[image_id] = int(label)
                    # image_id was untrusted when this example was scanned.
                    correct.add(image_id)
                    changed = True
            elif unresolved:
                total = sum(
                    (self.image_labels[image_id] * coeff for image_id, coeff in unresolved),
                    partial_sum,
                )
                if total == target:
                    for image_id, _ in unresolved:
                        # The same image may occupy several slots.
                        if image_id not in correct:
                            correct.add(image_id)
                            changed = True
        return changed

    def improve_image_labels(self, cluster_ids, max_threshold=5):
        """Grow the set of trusted labels, starting with images nearest their centroid.

        For each integer threshold ``t`` in ``1..max_threshold``, images whose
        encoded distance to their centroid lies in the band ending at ``t``
        become trusted. ``fix_image_labels_new_new`` is then repeated until it
        adds nothing. Images farther than ``max_threshold`` are never trusted
        by distance alone.

        Args:
            cluster_ids: image id -> cluster id lookup.
            max_threshold: largest distance threshold to use (default 5).

        Returns:
            The set of image ids whose labels are considered correct.
        """
        correct = set()
        for t in range(1, max_threshold + 1):
            correct.update(self.get_image_label_distance_from_centroid(cluster_ids, t))
            # Terminates: every True return added a new id to `correct`.
            while self.fix_image_labels_new_new(correct):
                pass
        return correct
