import numpy as np
from sklearn.mixture import GaussianMixture


def get_t(data, k=2, eps=1e-3, random_state=None):
    """Fit a 1-D Gaussian mixture to ``data`` and return an upper-tail threshold.

    The values in ``data`` are flattened and modelled with a ``k``-component
    ``GaussianMixture``. The component with the smallest mean is selected, and
    the threshold is the point above its mean where that component's normal
    density drops to ``eps``:

        t = mu + sigma * sqrt(-2 * ln(eps * sigma * sqrt(2 * pi)))

    Parameters
    ----------
    data : array_like
        Scores to model; any shape, flattened before fitting.
    k : int
        Number of mixture components.
    eps : float
        Density level at which the selected component's upper tail is cut.
    random_state : int, numpy.random.RandomState or None
        Forwarded to ``GaussianMixture``; pass an int for reproducible fits.
        ``None`` keeps the previous (unseeded) behaviour.

    Returns
    -------
    numpy.ndarray
        Array of shape (1, 1) holding the threshold (callers index ``[0]``).
    """
    # np.float was removed in NumPy 1.24; float64 is what it aliased.
    values = np.ravel(data).astype(np.float64).reshape(-1, 1)
    gmm = GaussianMixture(n_components=k, covariance_type='full',
                          random_state=random_state)
    gmm.fit(values)

    # GaussianMixture does not order its components, so "index 0" is whatever
    # the initialisation produced. Select the lowest-mean component explicitly
    # so the result is well-defined.
    # NOTE(review): lowest-mean is assumed to be the intended (benign-side)
    # component because callers keep samples *below* the threshold - confirm.
    idx = int(np.argmin(gmm.means_[:, 0]))
    mu = gmm.means_[idx]                    # shape (1,)
    sigma = np.sqrt(gmm.covariances_[idx])  # shape (1, 1) for 1-D 'full' covariance

    # Solve N(x; mu, sigma) == eps for x > mu. When eps exceeds the
    # component's peak density 1/(sigma*sqrt(2*pi)) there is no crossing and
    # the log term would make the sqrt NaN; clip so the threshold falls back
    # to the component mean instead.
    log_term = -2.0 * np.log(eps * sigma * np.sqrt(2.0 * np.pi))
    return sigma * np.sqrt(np.clip(log_term, 0.0, None)) + mu


# Per-sample similarity scores after and before augmentation. Poisoned samples
# are concatenated first, so indices [0, num_poisoned) are poisoned and every
# index >= num_poisoned is benign.
new_similarity_poisoned = np.load('npy/blended/3000/auged_poisoned_caption_3000.npy')
original_similarity_poisoned = np.load('npy/blended/3000/poisoned_caption.npy')
new_similarity_benign = np.load('npy/clean/clipcap_clean.npy')
original_similarity_benign = np.load('npy/clean/ori_clean.npy')

# Boundary between poisoned and benign entries in the concatenated arrays.
# Derived from the data rather than hard-coded (assumes 1-D score arrays).
num_poisoned = len(new_similarity_poisoned)

new_similarity_all = np.concatenate([new_similarity_poisoned, new_similarity_benign])
original_similarity_all = np.concatenate([original_similarity_poisoned, original_similarity_benign])

consistency = new_similarity_all - original_similarity_all

# Take the q samples with the largest consistency as the "pure poison" subset.
q = 50
poisoned_subset_indices = np.argpartition(consistency, -q)[-q:]
# Any index at or past the poisoned block is a benign sample (the original
# `> 3000` check missed index 3000 itself).
num_of_benign_in_poisoned_subset = np.count_nonzero(poisoned_subset_indices >= num_poisoned)
print(f"Number of benign pairs mistakenly grouped in to poisoned subset: {num_of_benign_in_poisoned_subset}")
np.save('npy/blended/3000/pure_poison.npy', poisoned_subset_indices)

# Samples below the GMM-derived threshold form the "pure benign" subset.
threshold = get_t(consistency, k=5, eps=1e-5)[0]
benign_subset_indices = np.where(consistency < threshold)[0]

num_of_poison_in_benign_subset = np.count_nonzero(benign_subset_indices < num_poisoned)
print(f"Number of poison pairs mistakenly grouped in to benign subset: {num_of_poison_in_benign_subset}")
# Save the benign subset here (previously the poisoned indices were written by mistake).
np.save('npy/blended/3000/pure_benign.npy', benign_subset_indices)
print('./')
