from sklearn.cluster import KMeans
import numpy as np
import cProfile

class K_z_algo:
    """(k, z)-clustering: minimise sum over points of dist(x, nearest center) ** z.

    Centers are refined by per-cluster gradient descent on the z-th power
    cost. A scikit-learn ``KMeans`` estimator (squared-distance cost) is kept
    alongside for baseline comparisons.

    Attributes:
        k: number of centers.
        z: exponent applied to each point's distance to its nearest center.
        n_init: forwarded to the baseline ``KMeans`` estimator.
        kmeans: the baseline ``KMeans`` estimator.
        clusters: index of the nearest center for each point, written by
            ``reassign_points`` (None before the first assignment).
        centers: centers produced by the last ``fit`` call (None before fit).
    """

    def __init__(self, k, z, n_init):
        self.k = k
        self.z = z
        self.n_init = n_init
        # Baseline estimator used by get_original_cost / get_subset_and_original.
        self.kmeans = KMeans(n_clusters=k, n_init=n_init)
        self.clusters = None
        self.centers = None

    def fit(self, data, init_centers=None, n_iter=1):
        """Fit centers to ``data`` and store them in ``self.centers``.

        Args:
            data: array of shape (n_points, n_features).
            init_centers: optional (k, n_features) starting centers; when
                omitted, k distinct rows of ``data`` are drawn at random.
            n_iter: number of reassign-then-optimize rounds (default 1).

        Returns:
            self, so calls can be chained.
        """
        data = np.asarray(data)
        if init_centers is None:
            centers = data[np.random.choice(data.shape[0], self.k, replace=False)]
        else:
            centers = init_centers
        centers = np.asarray(centers, dtype=float)
        print(f"initial cost: {self.calculate_cost(data, centers)}")
        for _ in range(n_iter):
            # optimize_centers reassigns points to the current centers before
            # taking gradient steps, so successive rounds alternate the two.
            centers = self.optimize_centers(data, centers)
        # Keep the stored assignment consistent with the final centers.
        self.reassign_points(data, centers)
        self.centers = centers
        return self

    def optimize_centers(self, data, centers, gradient_steps=10, lr=0.1):
        """Reassign points, then run gradient descent on each center separately.

        Each center takes ``gradient_steps`` steps of size ``lr`` on the cost of
        the points assigned to it; assignments are not updated between steps.

        NOTE(review): ``calculate_gradient`` sums over the cluster's points, so
        the effective step grows with cluster size. For z = 2 the update
        overshoots the cluster mean once a cluster has more than 1/(2*lr)
        points and diverges once it has more than 1/lr points; consider
        averaging the gradient or shrinking ``lr`` for large clusters.

        Returns:
            A new float array of updated centers; ``centers`` is not modified.
        """
        self.reassign_points(data, centers)
        # Work on a float copy: writing into the caller's array mutated it, and
        # an integer array (e.g. rows of integer data) truncated the updates.
        new_centers = np.array(centers, dtype=float)
        for i in range(len(new_centers)):
            points = data[self.clusters == i]
            center = new_centers[i]
            for _ in range(gradient_steps):
                center = center - lr * self.calculate_gradient(points, center)
            new_centers[i] = center
        return new_centers

    def reassign_points(self, data, centers):
        """Assign every point to its nearest center.

        Stores the per-point center index in ``self.clusters`` and returns it.
        """
        dist_to_centers = self.get_dists_to_centers(data, centers)
        self.clusters = np.argmin(dist_to_centers, axis=0)
        return self.clusters

    def calculate_gradient(self, cluster, center):
        """Gradient w.r.t. ``center`` of sum_x ||x - center|| ** z over ``cluster``.

        Uses d/dc ||x - c||**z = -z * ||x - c||**(z - 2) * (x - c), so each
        difference vector is weighted by dist**(z - 2). An empty ``cluster``
        yields a zero gradient.
        """
        diff = cluster - center
        dists = np.linalg.norm(diff, axis=1)
        # Substitute 1 for zero distances before the power so z < 2 does not
        # produce inf (and then 0 * inf = nan); those rows have diff == 0, so
        # they still contribute nothing.
        safe_dists = np.where(dists > 0, dists, 1.0)
        weights = safe_dists ** (self.z - 2)
        return -self.z * np.sum(diff * weights[:, np.newaxis], axis=0)

    def get_dists_to_centers(self, data, centers):
        """Euclidean distance from every center to every point.

        Returns an array of shape (len(centers), len(data)).
        """
        return np.linalg.norm(data - np.asarray(centers)[:, np.newaxis], axis=2)

    def calculate_cost(self, data, centers):
        """Sum over points of (distance to nearest center) ** z."""
        min_dists = np.min(self.get_dists_to_centers(data, centers), axis=0)
        return np.sum(min_dists ** self.z)

    def get_original_cost(self, data):
        """Fit the baseline KMeans on ``data`` and return its ``inertia_``."""
        self.kmeans.fit(data)
        return self.kmeans.inertia_

    def get_subset_and_original(self, data_full, data_subset):
        """Fit KMeans on ``data_subset`` and evaluate its centers on both sets.

        Returns:
            (subset_cost, full_cost): the KMeans ``inertia_`` on the subset and
            the squared-distance cost of the same centers on ``data_full``.
        """
        self.kmeans.fit(data_subset)
        full_cost = self.get_cost_for_centers(data_full, self.kmeans.cluster_centers_)
        return self.kmeans.inertia_, full_cost

    def get_cost_for_centers(self, data, centers):
        """Sum over points of squared distance to the nearest center (z = 2)."""
        min_dists = np.min(self.get_dists_to_centers(data, centers), axis=0)
        return np.sum(min_dists ** 2)
    
        