import torch
import torchvision.transforms as transforms
from sklearn.cluster import KMeans
import numpy as np
# Default device for feature extraction: the first CUDA device when available.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# Inverse of Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)):
# Normalize computes (x - m') / s', and with m' = -m/s, s' = 1/s that equals x*s + m.
# NOTE(review): these look like the usual CIFAR-10 statistics -- confirm they match
# the forward normalization used by the datasets fed to this module.
inv_normalize = transforms.Normalize(
    mean=[-0.4914/0.2023, -0.4822/0.1994, -0.4465/0.2010],
    std=[1/0.2023, 1/0.1994, 1/0.2010],
)
# Converts a tensor image back to a PIL image.
to_pil = transforms.ToPILImage()


def _MadryDefense(net, poisoned_trainset, num_poisons_expected=1000, n_clusters=None):
    """Filter a poisoned training set with the spectral-signature defense.

    Per-sample features are the output of every child module of ``net`` except
    the last one, flattened. For each class, the features are centred on the
    class mean and projected onto their top right singular vector. The samples
    with the largest squared projections are flagged as suspicious.

    Args:
        net: model whose ``children()`` (minus the last) are applied in order to
            produce features. NOTE(review): this assumes ``net.forward`` is
            equivalent to running its children sequentially -- confirm for
            models with skip connections or functional ops inside ``forward``.
            The model is run in eval mode; every module's training flag is
            restored before returning.
        poisoned_trainset: indexable dataset with a ``classes`` attribute whose
            items are ``(img, target, _)`` triples, ``target`` indexing
            ``classes``.
        num_poisons_expected: expected poison count. Up to
            ``int(1.5 * num_poisons_expected)`` samples are removed from *each*
            class, always leaving at least one sample per non-empty class.
        n_clusters: unused; accepted so every defense from ``get_defense``
            shares one call signature.

    Returns:
        ``(clean_indices, bad)``: dataset indices kept and indices removed.
    """
    class_indices = [[] for _ in range(len(poisoned_trainset.classes))]
    feature_extractor = torch.nn.Sequential(*list(net.children())[:-1], torch.nn.Flatten())
    # Send inputs to the device the model actually lives on; the module-level
    # `device` guess is only a fallback for a parameter-less model.
    first_param = next(net.parameters(), None)
    run_device = first_param.device if first_param is not None else device

    # Extract features in eval mode. In train mode, BatchNorm layers would use
    # per-batch (here single-sample) statistics and update their running stats,
    # and Dropout would randomise the features. Each module's own training flag
    # is restored afterwards, even if extraction fails.
    saved_modes = [(module, module.training) for module in net.modules()]
    net.eval()
    feats = []
    try:
        with torch.no_grad():
            for i in range(len(poisoned_trainset)):
                img, target, _ = poisoned_trainset[i]
                feats.append(feature_extractor(img.unsqueeze(0).to(run_device)))
                class_indices[target].append(i)
    finally:
        for module, mode in saved_modes:
            module.training = mode

    clean_indices = []
    bad = []
    max_removed = int(1.5 * num_poisons_expected)
    for members in class_indices:
        if not members:
            continue  # empty class: nothing to score (torch.cat would fail)
        class_feats = torch.cat([feats[idx] for idx in members])
        class_feats = class_feats - class_feats.mean(dim=0)
        # Only the top right singular vector is needed, so use the reduced SVD
        # rather than materialising a full (n x n) U.
        _, _, V = torch.svd(class_feats, some=True)
        scores = (class_feats @ V[:, 0]).pow(2)
        # Never remove every sample of a class.
        k = min(max_removed, len(members) - 1)
        _, top = torch.topk(scores, k)
        flagged = {members[j] for j in top.tolist()}
        bad.extend(sorted(flagged))
        clean_indices.extend(idx for idx in members if idx not in flagged)
    return clean_indices, bad

    
def _ActivationClustering(net, poisoned_trainset, n_clusters=2, num_poisons_expected=None):
    """Filter a poisoned training set with activation clustering.

    Per-sample features are the output of every child module of ``net`` except
    the last one, flattened. For each class, the features are split into
    ``n_clusters`` groups with k-means. Samples in the most populous cluster are
    kept and all others are flagged.

    Args:
        net: model whose ``children()`` (minus the last) are applied in order to
            produce features. NOTE(review): this assumes ``net.forward`` is
            equivalent to running its children sequentially -- confirm.
            The model is run in eval mode; every module's training flag is
            restored before returning.
        poisoned_trainset: indexable dataset with a ``classes`` attribute whose
            items are ``(img, target, _)`` triples, ``target`` indexing
            ``classes``.
        n_clusters: number of k-means clusters per class. A class with fewer
            samples than this cannot be clustered and is kept entirely.
        num_poisons_expected: unused; accepted so every defense from
            ``get_defense`` shares one call signature.

    Returns:
        ``(clean_indices, bad)``: dataset indices kept and indices removed.

    KMeans is not seeded here, so results may differ between runs.
    """
    class_indices = [[] for _ in range(len(poisoned_trainset.classes))]
    feature_extractor = torch.nn.Sequential(*list(net.children())[:-1], torch.nn.Flatten())
    # Send inputs to the device the model actually lives on; the module-level
    # `device` guess is only a fallback for a parameter-less model.
    first_param = next(net.parameters(), None)
    run_device = first_param.device if first_param is not None else device

    # Extract features in eval mode (BatchNorm/Dropout would otherwise use
    # per-sample statistics or randomness and mutate running stats). Each
    # module's own training flag is restored afterwards, even on error.
    saved_modes = [(module, module.training) for module in net.modules()]
    net.eval()
    feats = []
    try:
        with torch.no_grad():
            for i in range(len(poisoned_trainset)):
                img, target, _ = poisoned_trainset[i]
                out = feature_extractor(img.unsqueeze(0).to(run_device))
                # Drop only the batch axis so 1-dimensional features stay 1-D.
                feats.append(out.squeeze(0).cpu().numpy())
                class_indices[target].append(i)
    finally:
        for module, mode in saved_modes:
            module.training = mode

    clean_indices = []
    bad = []
    for members in class_indices:
        if len(members) < n_clusters:
            # KMeans needs at least n_clusters samples; keep this class as-is.
            clean_indices.extend(members)
            continue
        class_feats = np.stack([feats[idx] for idx in members])
        labels = KMeans(n_clusters=n_clusters).fit(class_feats).labels_
        counts = np.bincount(labels, minlength=n_clusters)
        # The largest cluster is treated as clean. Ties go to the higher label,
        # which reproduces the two-cluster rule "label 1 wins on equal sizes".
        clean_label = max(range(n_clusters), key=lambda c: (counts[c], c))
        for idx, label in zip(members, labels.tolist()):
            if label == clean_label:
                clean_indices.append(idx)
            else:
                bad.append(idx)
    return clean_indices, bad



def get_defense(defense):
    """Look up a poison-filtering defense by name (case-insensitive).

    Returns ``_MadryDefense`` for ``'madry'``, ``_ActivationClustering`` for
    ``'activation_clustering'``, and ``None`` for any other name.
    """
    name = defense.lower()
    if name == 'madry':
        return _MadryDefense
    if name == 'activation_clustering':
        return _ActivationClustering
    return None
