import numpy as np
from tqdm import tqdm
import sklearn.cluster as clust


class SpamRm:
    """Score workers by how "spammy" their per-worker matrices are, and pick spammers.

    Each worker ``idx`` is scored from the 2-D array ``pi[idx]``: the score is
    the normalised sum of squared differences between every pair of its rows.
    A worker whose rows are all identical scores 0 and is treated as a spammer.

    Args:
        pi: Indexable by worker id; ``pi[idx]`` must be a 2-D array. Presumably
            the worker's ``n_classes x n_classes`` confusion matrix with one row
            per true class -- TODO confirm against the caller.
        n_workers: Number of workers; ``pi[0]`` .. ``pi[n_workers - 1]`` are scored.
        n_classes: Number of classes ``K`` used in the ``1 / (K (K - 1))``
            normalisation of the score.
    """

    def __init__(self, pi, n_workers, n_classes):
        self.pi = pi
        self.n_workers = n_workers
        self.n_classes = n_classes

    @staticmethod
    def _worker_spam_score(A, n_classes):
        """Return the spammer score of one worker matrix ``A``.

        score = 1 / (K (K - 1)) * sum_{c < c'} sum_k (A[c, k] - A[c', k]) ** 2
        with ``K = n_classes``.
        """
        # Broadcasting forms every *ordered* row pair (c, c'); each unordered
        # pair is therefore counted twice, hence the final division by 2.
        pairwise_sq = (A[np.newaxis, :, :] - A[:, np.newaxis, :]) ** 2
        return 1 / (n_classes * (n_classes - 1)) * np.sum(pairwise_sq) / 2

    @staticmethod
    def _lowest_cluster_members(labels, scores):
        """Return the indices sharing a cluster label with the lowest score.

        Spammers have scores near 0 (the same convention as the ``thresh``
        rule in ``get_spammers``), so the cluster holding the minimum score is
        the spammer cluster. Taking the label from an actual sample guarantees
        it is present in ``labels`` even if some cluster ids are unused.
        """
        low_label = labels[np.argmin(scores)]
        return np.where(labels == low_label)[0]

    def spam_score(self):
        """Compute every worker's spammer score and store it in ``self.spam``.

        Sets ``self.spam`` to a 1-D float array of length ``n_workers``. A score
        of 0 means all rows of the worker's matrix are identical; larger values
        mean more row-to-row variation.
        """
        self.spam = np.array(
            [
                self._worker_spam_score(self.pi[idx], self.n_classes)
                for idx in tqdm(range(self.n_workers))
            ]
        )

    def get_spammers(self, thresh=None, k=2, random_state=None):
        """Return the indices of workers flagged as spammers.

        Scores are computed via ``spam_score`` first if not done yet.

        Args:
            thresh: If given, workers with ``spam < thresh`` are spammers.
            k: Number of KMeans clusters used when ``thresh`` is None; the
                cluster containing the lowest score is returned.
            random_state: Forwarded to ``KMeans`` for reproducible clustering
                (``None`` keeps the previous, unseeded behaviour).

        Returns:
            1-D integer array of worker indices.
        """
        if not hasattr(self, "spam"):
            self.spam_score()
        if thresh is not None:
            return np.where(self.spam < thresh)[0]
        km = clust.KMeans(n_clusters=k, random_state=random_state)
        labels = km.fit_predict(self.spam.reshape(-1, 1))
        return self._lowest_cluster_members(labels, self.spam)
