import torch
import numpy as np 
import matplotlib.pyplot as plt
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
from copy import deepcopy

# Demo input: 100 points in 2-D drawn i.i.d. from a standard normal.
X = torch.randn((100, 2))

def kmeans_pp(X, K, max_iters=100):
    """Cluster the rows of ``X`` into ``K`` groups using k-means++ seeding
    followed by Lloyd's (assign / update) iterations.

    Args:
        X: 2-D data of shape ``(n_samples, n_features)``, either a
            ``torch.Tensor`` or a ``numpy.ndarray``. NumPy input is converted
            to a float32 tensor; non-floating-point tensors are cast to float32.
        K: Number of clusters; must be >= 1. ``K`` may exceed the number of
            distinct points, in which case some returned centers coincide.
        max_iters: Maximum number of Lloyd iterations. ``0`` returns the
            k-means++ seeds together with their nearest-center assignment.

    Returns:
        ``(centers, cluster_assignment)``. ``centers`` has shape
        ``(K, n_features)``. ``cluster_assignment`` is an int64 tensor of
        shape ``(n_samples,)`` holding, for each row of ``X``, the index of its
        nearest center among the *returned* ``centers``.

    Raises:
        ValueError: if ``X`` is not 2-D, has no rows, contains NaN/inf, or
            ``K < 1``.
    """
    if isinstance(X, np.ndarray):
        # torch.from_numpy rejects negative strides (e.g. reversed views), so
        # make the array contiguous first (a no-op when it already is).
        X = torch.from_numpy(np.ascontiguousarray(X)).float()
    elif not X.is_floating_point():
        # cdist and mean require floating-point inputs.
        X = X.float()

    if X.dim() != 2:
        raise ValueError(
            f"X must be 2-D (n_samples, n_features), got shape {tuple(X.shape)}"
        )
    n = X.shape[0]
    if n == 0:
        raise ValueError("X must contain at least one sample")
    if K < 1:
        raise ValueError(f"K must be >= 1, got {K}")
    if not torch.isfinite(X).all():
        raise ValueError("X must contain only finite values")

    # ---- k-means++ seeding ------------------------------------------------
    centers = X[torch.randint(n, (1,), device=X.device)]
    # Squared distance from every point to its nearest chosen center. Updated
    # against only the newest center each round instead of all centers.
    closest_d2 = ((X - centers[0]) ** 2).sum(dim=1)
    for _ in range(K - 1):
        total = closest_d2.sum()
        # Sample proportionally to D^2; already-chosen points have weight
        # exactly 0. If every point coincides with a chosen center the total
        # is 0, so fall back to uniform sampling.
        weights = closest_d2 if total > 0 else torch.ones_like(closest_d2)
        idx = torch.multinomial(weights, num_samples=1)
        new_center = X[idx]  # shape (1, n_features)
        centers = torch.cat([centers, new_center], dim=0)
        closest_d2 = torch.minimum(closest_d2, ((X - new_center) ** 2).sum(dim=1))

    # ---- Lloyd's iterations -------------------------------------------------
    # Assign before the loop so the result is defined even when max_iters == 0,
    # and re-assign after every center update so the returned assignment
    # always matches the returned centers.
    cluster_assignment = torch.cdist(X, centers).argmin(dim=1)
    for _ in range(max_iters):
        new_centers = centers.clone()
        for k in range(K):
            members = X[cluster_assignment == k]
            # An empty cluster keeps its previous center (its mean would be NaN).
            if members.shape[0] > 0:
                new_centers[k] = members.mean(dim=0)
        if torch.equal(new_centers, centers):
            break
        centers = new_centers
        cluster_assignment = torch.cdist(X, centers).argmin(dim=1)
    return centers, cluster_assignment
