import cProfile

import numpy as np
from numba import jit
from sklearnex import patch_sklearn

# Per the Intel Extension for Scikit-learn docs, patch_sklearn() must run
# before scikit-learn estimators are imported; importing KMeans by name first
# would bind the unpatched class in this module.
patch_sklearn()

from sklearn.cluster import KMeans  # noqa: E402  (must follow patch_sklearn)

import util
class Kmeans_l2:
    """Euclidean (L2) k-means wrapper for comparing full-data and subset costs.

    A single scikit-learn ``KMeans`` estimator is fitted either on a full data
    set (``get_original_cost``) or on a subset of it
    (``get_subset_and_original``). The centers from each kind of fit are
    stored separately, so the subset solution can be evaluated on the full
    data and both sets of centers can be retrieved afterwards.

    Args:
        k: Number of clusters, passed as ``KMeans(n_clusters=k)``.
        n_init: Number of k-means restarts, passed as ``KMeans(n_init=...)``.
    """

    def __init__(self, k: int, n_init):
        self.k = k
        self.kmeans = KMeans(n_clusters=k, n_init=n_init)
        # Centers of the most recent subset / full-data fit. They are kept
        # separately because both fits reuse (and overwrite) ``self.kmeans``.
        self.centers_subset = None
        self.centers_full = None

    def get_original_cost(self, data: np.ndarray) -> float:
        """Fit k-means on ``data`` and return the estimator's ``inertia_``.

        The fitted centers are saved in ``self.centers_full``.
        """
        self.kmeans.fit(data)
        # Copy so a later refit of the shared estimator cannot alias them.
        self.centers_full = self.kmeans.cluster_centers_.copy()
        return self.kmeans.inertia_

    def get_subset_and_original(self, data_full: np.ndarray, data_subset: np.ndarray):
        """Fit k-means on ``data_subset`` and evaluate that solution on both sets.

        The fitted centers are saved in ``self.centers_subset``.

        Args:
            data_full: The complete data set the subset was drawn from.
            data_subset: The rows k-means is actually fitted on.

        Returns:
            ``(subset_cost, full_cost)``: ``subset_cost`` is the estimator's
            ``inertia_`` after fitting on ``data_subset``; ``full_cost`` is the
            sum of squared Euclidean distances from every row of ``data_full``
            to its nearest subset center.
        """
        self.kmeans.fit(data_subset)
        self.centers_subset = self.kmeans.cluster_centers_.copy()
        dists = self.calculate_min_distances(data_full, self.centers_subset)
        full_cost = np.sum(dists ** 2)
        return self.kmeans.inertia_, full_cost

    def calculate_min_distances(self, data: np.ndarray, centers: np.ndarray) -> np.ndarray:
        """Return the Euclidean distance from each row of ``data`` to its nearest center.

        Args:
            data: 2-D array, one sample per row.
            centers: 2-D array, one center per row, same number of columns.

        Returns:
            1-D float64 array with one entry per row of ``data``.

        Raises:
            ValueError: If ``centers`` contains no rows.
        """
        data = np.asarray(data)
        centers = np.asarray(centers)
        if len(centers) == 0:
            raise ValueError("centers must contain at least one center")
        # Loop over the (few) centers instead of the (many) rows: each step is
        # a vectorised pass over all rows, and memory stays O(n_samples).
        min_distances = np.full(data.shape[0], np.inf)
        for center in centers:
            np.minimum(
                min_distances,
                np.linalg.norm(data - center, axis=1),
                out=min_distances,
            )
        return min_distances

    def get_centers_subset(self) -> np.ndarray:
        """Return the centers from the last ``get_subset_and_original`` call."""
        assert self.centers_subset is not None, "call get_subset_and_original() first"
        return self.centers_subset

    def get_centers_full(self) -> np.ndarray:
        """Return the centers from the last ``get_original_cost`` call."""
        assert self.centers_full is not None, "call get_original_cost() first"
        return self.centers_full