# thanks https://github.com/kjahan/k_means/blob/master/src/clustering.py
# generalized from it to handle high-dimensional features

import numpy as np
import random as rand
import multiprocessing

def plain_distance(a, b):
    """Return the 2-norm of ``a - b``.

    For 1-D feature vectors this is the ordinary Euclidean distance.
    """
    difference = a - b
    return np.linalg.norm(difference, ord=2)

class Clusterer:
    """K-means clusterer over an arbitrary, user-supplied distance function.

    Initial means are chosen farthest-first: the first seed is a random
    training point. Each following seed is the remaining point whose summed
    distance to the seeds chosen so far is largest.

    Args:
        distance: callable ``distance(a, b) -> float`` between two feature
            arrays (e.g. ``plain_distance``).
        TYPE: stored as ``self.type``; not used inside this class.
        K: number of clusters.
        TOL: iteration stops once no mean moves by more than this distance.
    """

    def __init__(self, distance, TYPE, K, TOL):
        self.distance = distance
        self.type = TYPE
        self.k = K
        self.tol = float(TOL)
        self.clusters = None  # after fit: dict {cluster index: list of member points}
        self.means = []  # cluster means; position i is the mean of cluster i

    def _nearest(self, point):
        """Return ``(index, distance)`` of the mean closest to ``point``.

        Ties resolve to the lowest index (``np.argmin`` semantics).
        """
        dist = [self.distance(point, mean) for mean in self.means]
        index = np.argmin(np.array(dist))
        return index, dist[index]

    def next_random(self, index, points, clusters):
        """Return the point of ``points`` farthest from the current seeds.

        "Farthest" is the largest sum of distances to the first point of every
        cluster in ``clusters``. ``index`` is kept for backward compatibility
        and is not used.
        """
        seeds = [cluster[0] for cluster in clusters.values()]
        dist = [sum(self.distance(point, seed) for seed in seeds) for point in points]
        return points[int(np.argmax(dist))]

    def initial_means(self, points):
        """Seed ``self.means`` with K distinct points chosen farthest-first.

        Raises:
            ValueError: if ``points`` holds fewer than K distinct points, so K
                distinct seeds cannot be chosen.
        """
        point_ = rand.choice(points)
        clusters = {0: [point_]}
        # drop every copy of the chosen point so it cannot be picked again
        points = [p for p in points if not (p == point_).all()]
        # pick the remaining K-1 seeds deterministically (farthest-first)
        for i in range(1, self.k):
            if not points:
                raise ValueError(
                    f"need at least {self.k} distinct points to seed {self.k}-Means"
                )
            point_ = self.next_random(i, points, clusters)
            clusters[i] = [point_]
            points = [p for p in points if not (p == point_).all()]
        self.means = self.compute_means(clusters)

    def compute_means(self, clusters):
        """Return the mean of every cluster in ``clusters``, ordered by cluster index."""
        # sorted keys: dict insertion order need not match cluster indices
        return [np.mean(clusters[i], axis=0) for i in sorted(clusters)]

    def _update_means(self, clusters):
        """Recompute all means from ``clusters``; an empty cluster keeps its old mean.

        The result is as long as ``self.means`` and position ``i`` always
        belongs to cluster ``i``. Without this, an empty cluster would shrink
        K and misalign the old/new comparison in ``need_stop``.
        """
        return [
            np.mean(clusters[i], axis=0) if i in clusters else old
            for i, old in enumerate(self.means)
        ]

    def assign_points(self, points):
        """Group ``points`` by the index of their nearest mean.

        Returns a dict ``{cluster index: list of points}``. Clusters that
        attract no point are absent from the dict.
        """
        clusters = dict()
        for point in points:
            index, _ = self._nearest(point)
            clusters.setdefault(index, []).append(point)
        return clusters

    def need_stop(self, means, threshold):
        """Return True when no mean moved by more than ``threshold``.

        ``means`` are the new means and ``self.means`` the old ones; they are
        compared position by position.
        """
        return not any(
            self.distance(old, new) > threshold for old, new in zip(self.means, means)
        )

    def get_centroids(self, num_samples=4):
        """Yield ``(mean, samples)`` for every non-empty fitted cluster.

        ``samples`` holds ``num_samples`` cluster members drawn with
        replacement (``random.choices``), so duplicates are possible.
        """
        for i, cluster in self.clusters.items():
            yield self.means[i], rand.choices(cluster, k=num_samples)

    def infer(self, x):
        """Replace every sample of ``x`` by its nearest mean.

        Args:
            x: iterable of samples, presumably shaped (num_sample, num_feat).
        Returns:
            ``np.vstack`` of the nearest means, one row per sample.
        """
        return np.vstack([self.means[self._nearest(sample)[0]] for sample in x])

    def fit(self, train_array, max_iter=None):
        """Run k-means on the rows of ``train_array`` and print progress.

        Args:
            train_array: numpy array whose first axis indexes training points;
                must contain more than K points.
            max_iter: optional cap on the number of mean updates. ``None``
                (default) iterates until ``need_stop`` reports convergence.
        """
        assert train_array.shape[0] > self.k
        points_ = list(train_array)  # a list of feature arrays
        self.initial_means(points_)
        print(f"Starting {self.k}-Means...")
        iterations = 0
        while True:
            # assignment step: attach each point to the closest current mean
            clusters = self.assign_points(points_)
            means = self._update_means(clusters)
            if self.need_stop(means, self.tol):
                break
            if max_iter is not None and iterations >= max_iter:
                break
            self.means = means
            iterations += 1
        print(f"{self.k}-Means completed in {iterations} iterations")
        # keys of `clusters` index self.means, which produced this assignment
        self.clusters = clusters

    def get_assignment(self, train_array, get_mean_error=False):
        """Return the nearest-mean index of every row of ``train_array``.

        Works on any data, not only training points. With
        ``get_mean_error=True`` returns ``(indices, mean_error)``, where
        ``mean_error`` is the average distance to the assigned mean (NaN for
        empty input).
        """
        res = []
        total_error = 0
        for point in train_array:
            index, dist = self._nearest(point)
            res.append(index)
            total_error += dist
        if not get_mean_error:
            return res
        return res, (total_error / len(res) if res else float("nan"))

def worker(nodes, argument_dict):
    """Construct a Clusterer from ``argument_dict``, fit it on ``nodes`` and return it.

    Module-level so that ``multiprocessing.Pool`` can pickle it by reference.
    """
    clusterer = Clusterer(**argument_dict)
    clusterer.fit(nodes)
    return clusterer

def get_best_clusterer(nodes, times, argument_dict):
    """Fit ``times`` independent Clusterers in parallel and return the best one.

    Each run starts from its own random seed point. "Best" means the lowest
    mean distance from every node to its assigned mean.

    Args:
        nodes: training array passed to ``Clusterer.fit`` in each worker.
        times: number of independent runs; ``multiprocessing.Pool`` rejects 0.
        argument_dict: keyword arguments for ``Clusterer``. They are sent to
            worker processes, so they (including the distance callable) must
            be picklable.
    Returns:
        The fitted ``Clusterer`` with the smallest mean error. Returns None
        only if no run yields a comparable (non-NaN) error.
    """
    processes = min(times, multiprocessing.cpu_count())
    inputs = [(nodes, argument_dict)] * times
    with multiprocessing.Pool(processes) as p:
        fitted = p.starmap(worker, inputs)
        p.close()
        p.join()
    # start from +inf: the previous 10e5 sentinel returned None whenever
    # every run's error was >= 1e6
    best_err = float("inf")
    best = None
    for clusterer in fitted:
        err = clusterer.get_assignment(nodes, get_mean_error=True)[1]
        print(f"Yielding mean error of {err}")
        if err < best_err:
            best_err = err
            best = clusterer
    return best