import torch
import cProfile
import numpy as np
class Kmeans_torch:
    """k-means-style clustering on the cost ``sum_i min_j ||x_i - c_j|| ** z``.

    Two fitting strategies are provided:

    * :meth:`fit` -- optimize all ``k`` centers jointly with AdamW on the full cost.
    * :meth:`fit_2` -- alternate hard assignment of every point to its nearest
      center with a per-cluster AdamW refinement of each center, starting from
      k^z-means++ seeding (:meth:`get_initial_kzplusplus`).

    After fitting, ``centers`` holds the final centers as a NumPy ``(k, d)``
    array, ``best_cost`` is the cost of exactly those centers, and ``history``
    is the list of costs recorded during fitting.

    NOTE(review): every optimizer here is ``torch.optim.AdamW`` with its default
    weight decay, which also shrinks centers toward the origin; confirm this
    regularization is intended.
    """

    def __init__(self, k, z, n_init, lr=0.1):
        """
        Args:
            k: number of clusters.
            z: exponent applied to each point-to-center distance in the cost.
            n_init: stored only; no method in this class performs restarts.
            lr: AdamW learning rate used for every center update.
        """
        self.k = k
        self.centers = None
        self.best_cost = None
        self.history = None
        self.n_init = n_init
        self.lr = lr
        self.z = z

    def fit(self, data, max_iter=1000, initial_centers=None):
        """Fit all ``k`` centers jointly by gradient descent on the full cost.

        Args:
            data: 2-D tensor of points, one per row.
            max_iter: number of AdamW steps.
            initial_centers: optional ``(k, d)`` array-like of starting centers.
                It is copied (and cast to ``data``'s dtype), so the caller's
                array is never modified. When omitted, centers are drawn
                uniformly from ``[0, 1)``.

        Sets ``self.centers`` (NumPy ``(k, d)``), ``self.history`` (cost of the
        starting centers followed by the cost after every step, ``max_iter + 1``
        entries) and ``self.best_cost`` (cost of ``self.centers``).
        """
        if initial_centers is not None:
            # Copy so AdamW's in-place updates never write into the caller's array.
            c = torch.as_tensor(initial_centers, dtype=data.dtype).detach().clone()
        else:
            c = torch.rand(self.k, data.shape[1], dtype=data.dtype)
        c.requires_grad_(True)
        optimizer = torch.optim.AdamW([c], lr=self.lr)
        cost_history = []
        for _ in range(max_iter):
            loss = torch.sum(torch.min(torch.cdist(data, c), dim=1)[0] ** self.z)
            # `loss` is the cost of the centers *before* this step.
            cost_history.append(loss.item())
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        # Append the cost of the centers actually returned, so that
        # best_cost matches self.centers.
        with torch.no_grad():
            final_cost = torch.sum(
                torch.min(torch.cdist(data, c), dim=1)[0] ** self.z
            ).item()
        cost_history.append(final_cost)
        self.centers = c.detach().numpy()
        self.best_cost = final_cost
        self.history = cost_history

    def fit_2(self, data, max_reassigns, max_sgd, initial_centers=None):
        """Alternate nearest-center assignment with per-cluster center refinement.

        Args:
            data: 2-D tensor of points, one per row.
            max_reassigns: number of assignment/refinement rounds.
            max_sgd: AdamW steps used to refine each center per round (> 0).
            initial_centers: optional ``(k, d)`` NumPy array of starting
                centers; it is copied, never modified. Defaults to
                k^z-means++ seeding.

        Sets ``self.centers`` (NumPy ``(k, d)``), ``self.history`` (cost before
        each round plus the cost of the final centers, ``max_reassigns + 1``
        entries) and ``self.best_cost`` (cost of ``self.centers``).
        """
        if initial_centers is None:
            c = self.get_initial_kzplusplus(data)
        else:
            # Copy: the loop below writes center rows in place.
            c = np.array(initial_centers)
        history = []
        for _ in range(max_reassigns):
            min_dists, clusters = self.assign_clusters(data, c)
            history.append(torch.sum(min_dists ** self.z).item())
            for j in range(self.k):
                c[j] = self.get_cluster_center(data[clusters == j], c[j], max_sgd)
        # Cost of the centers actually returned, so that best_cost matches
        # self.centers (this also keeps max_reassigns == 0 well-defined).
        min_dists, _ = self.assign_clusters(data, c)
        history.append(torch.sum(min_dists ** self.z).item())
        self.centers = c
        self.best_cost = history[-1]
        self.history = history

    def assign_clusters(self, data, centers_in):
        """Assign every point to its nearest center.

        Args:
            data: 2-D tensor of points, one per row.
            centers_in: ``(k, d)`` NumPy array or tensor of centers; converted
                to ``data``'s dtype.

        Returns:
            ``(min_dists, clusters)``: each point's distance to its nearest
            center, and that center's index.
        """
        centers = torch.as_tensor(centers_in, dtype=data.dtype)
        min_dists, clusters = torch.min(torch.cdist(data, centers), dim=1)
        return min_dists, clusters

    def get_initial_kzplusplus(self, data):
        """Seed ``k`` centers by D^z sampling (k^z-means++).

        The first center is a uniformly random row of ``data``. Each further
        center is a row drawn with probability proportional to its current
        cost ``min_dist ** z``. Sampling uses NumPy's global RNG.

        Returns:
            ``(k, d)`` NumPy array whose rows are copies of rows of ``data``.
        """
        n = data.shape[0]
        centers = data[np.random.choice(n, 1)]
        for _ in range(self.k - 1):
            min_dists = torch.min(torch.cdist(data, centers), dim=1)[0]
            costs = (min_dists ** self.z).double()
            total = costs.sum()
            # If every point already coincides with a center, all costs are 0;
            # fall back to uniform sampling instead of dividing by zero.
            probs = (costs / total).numpy() if total > 0 else None
            # Keep `centers` a tensor: torch.cdist rejects NumPy arrays.
            centers = torch.cat(
                [centers, data[np.random.choice(n, 1, p=probs)]], dim=0
            )
        return centers.detach().numpy()

    def get_cluster_center(self, cluster, center, max_sgd):
        """Refine one center with AdamW on ``sum(||x - center|| ** z)`` over its points.

        Args:
            cluster: 2-D tensor of the points assigned to this center.
            center: length-``d`` NumPy array (or tensor) holding the current
                center; it is copied, never modified.
            max_sgd: number of AdamW steps; must be positive.

        Returns:
            ``(1, d)`` NumPy array with the refined center. An empty cluster
            has nothing to fit, so its center is returned unchanged.

        Raises:
            ValueError: if ``max_sgd`` is not positive.
        """
        if max_sgd <= 0:
            raise ValueError(f"max_sgd must be positive, got {max_sgd}")
        center_t = (
            torch.as_tensor(center, dtype=cluster.dtype).detach().reshape(1, -1).clone()
        )
        if cluster.shape[0] == 0:
            # Zero gradient here: only AdamW's weight decay would move the
            # center, pulling it toward the origin for no reason.
            return center_t.numpy()
        center_t.requires_grad_(True)
        optimizer = torch.optim.AdamW([center_t], lr=self.lr)
        for _ in range(max_sgd):
            loss = torch.sum(torch.cdist(cluster, center_t) ** self.z)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        return center_t.detach().numpy()
        