import torch
import faiss

def perform_clustering(X, n_clusters=7, niter=20, verbose=False):
    """Run CPU faiss k-means on the rows of ``X`` and assign each row to a centroid.

    Args:
        X: 2-D tensor whose rows are the points to cluster. It is detached,
            moved to the CPU and converted to a contiguous NumPy array first.
        n_clusters: number of centroids passed to ``faiss.Kmeans``.
        niter: number of k-means iterations.
        verbose: forwarded to ``faiss.Kmeans``.

    Returns:
        ``(centroids, labels)``: ``kmeans.centroids`` after training, and the
        index array from ``kmeans.index.search(points, 1)``, i.e. the id of
        the single nearest centroid for every row.
    """
    _, dim = X.shape  # unpacking also rejects inputs that are not 2-D
    points = X.detach().cpu().contiguous().numpy()

    kmeans = faiss.Kmeans(dim, n_clusters, niter=niter, verbose=verbose, gpu=False)
    kmeans.train(points)

    # k=1: only the nearest centroid id is kept; distances are discarded.
    _, assignments = kmeans.index.search(points, 1)
    return kmeans.centroids, assignments
def propagate_aff_cam_with_bkg(cams, aff=None, mask=None, *, n_pow=2, n_log_iter=1, eps=1e-4):
    """Propagate per-token class scores through a (sharpened, normalized) affinity.

    The affinity is raised element-wise to ``n_pow``, divided by its sums over
    dim 1 (plus ``eps``), squared ``n_log_iter`` times (a ``2 ** n_log_iter``
    step walk), and finally applied as ``aff @ cams`` for every batch element.

    Args:
        cams: tensor of shape (b, n, c) -- scores for n tokens and c classes.
        aff: tensor of shape (b, n, n). Required despite the ``None`` default
            (kept for backward compatibility). It is detached, so no gradient
            flows through it, and it is never modified in place.
        mask: optional per-token labels of shape (n,) (shared by the whole
            batch) or (b, n). When given, affinities between tokens with
            different labels are zeroed, so scores only propagate between
            tokens that share a label.
        n_pow: exponent applied to the affinity before normalization.
        n_log_iter: how many times the normalized affinity is squared.
        eps: added to the normalizing sums to avoid division by zero.

    Returns:
        Tensor of shape (b, n, c) with the propagated scores.

    Raises:
        ValueError: if ``aff`` is None, or ``mask`` holds neither n nor b * n labels.
    """
    if aff is None:
        raise ValueError("propagate_aff_cam_with_bkg requires an affinity tensor `aff`")

    b, n, c = cams.shape
    # Detach first; every step below is out-of-place, so the caller's tensor is untouched
    # (the previous version zeroed entries of the caller's `aff` in place).
    aff = aff.detach()

    if mask is not None:
        labels = torch.as_tensor(mask, device=aff.device)
        if labels.numel() not in (n, b * n):
            raise ValueError(
                f"mask must hold one label per token ({n} or {b * n} values), "
                f"got shape {tuple(labels.shape)}"
            )
        labels = labels.reshape(-1, n)  # (1, n) shared across the batch, or (b, n)
        # Keep only pairs (p, q) whose labels match. The earlier per-label loop zeroed every
        # row whose label differed from *each* label in turn, which wiped the entire matrix
        # as soon as two distinct labels were present.
        different = labels.unsqueeze(2) != labels.unsqueeze(1)  # (1 or b, n, n)
        aff = aff.masked_fill(different, 0)

    aff = aff ** n_pow
    # NOTE(review): summing over dim=1 makes each *column* sum to ~1, while the product
    # below is `aff @ cams` (rows index output tokens), so output rows are not convex
    # combinations of the input rows. Kept as-is to preserve results -- confirm whether a
    # row normalization (dim=-1) was intended.
    aff = aff / (aff.sum(dim=1, keepdim=True) + eps)

    for _ in range(n_log_iter):
        aff = torch.matmul(aff, aff)

    # Batched (b, n, n) @ (b, n, c) -> (b, n, c); same values as a per-sample loop.
    return torch.matmul(aff, cams)