from sklearn.manifold import TSNE
import numpy as np

import matplotlib as mpl
from matplotlib import pyplot as plt

# mpl.use('Qt5Agg')  # interactive mode works with this, pick one
# mpl.use('TkAgg')  # interactive mode works with this, pick one

# mpl.use('Agg')


import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn


def feat_vis(features, labels):
    """Embed ``features`` in 2-D with t-SNE and show a scatter plot coloured by label.

    Args:
        features: array-like accepted by ``np.asarray`` (list, ndarray, CPU
            tensor); t-SNE expects it to be 2-D, one row per sample.
        labels: array-like with one label per row of ``features``, or ``None``
            to plot all points in a single colour without a legend.
    """
    tsne = TSNE(n_components=2, random_state=0)
    cluster = np.asarray(tsne.fit_transform(np.asarray(features)))

    plt.figure(figsize=(10, 10))

    if labels is None:
        # Nothing to colour by; a legend would have no labelled artists.
        plt.scatter(cluster[:, 0], cluster[:, 1], marker='.')
    else:
        actual = np.asarray(labels)
        # One scatter call per distinct label gives each class its own colour
        # and a legend entry. Previously labels were ignored, so plt.legend()
        # found no labelled artists and only emitted a warning.
        for cls in np.unique(actual):
            mask = actual == cls
            plt.scatter(cluster[mask, 0], cluster[mask, 1], marker='.', label=str(cls))
        plt.legend()

    plt.show()

def feature_analysis(feats_clean, feats_adv, labels):
    """Jointly embed clean and adversarial features with t-SNE and open a figure.

    Args:
        feats_clean: tensor of features for clean inputs; dim 0 indexes samples.
        feats_adv: tensor of features for adversarial inputs; must match
            ``feats_clean`` in every dim except dim 0 (required by ``torch.cat``).
        labels: tensor of class labels for the clean samples. It is reused for
            the adversarial half, so presumably ``feats_adv`` is row-aligned with
            ``feats_clean`` -- TODO confirm with callers.
    """
    # concat clean & adv along the sample dimension
    feats_all = torch.cat((feats_clean, feats_adv), dim=0).detach().cpu()
    labels_all = torch.cat((labels, labels), dim=0).detach().cpu()  # class labels
    labels_all2 = torch.cat((torch.zeros_like(labels), torch.ones_like(labels)), dim=0).detach().cpu()  # attack label: 0 clean, 1 adv
    # 1 for the first 64 samples, 0 for the rest (same values the original
    # list comprehension `1 if i // 64 == 0 else 0` produced). The original fed
    # that list of Python ints to torch.cat, which requires tensors and raised
    # TypeError, so the mask is built directly as an int64 tensor instead.
    # NOTE(review): 64 looks like a hard-coded batch size -- confirm intent.
    labels_all3 = (torch.arange(feats_all.size(0)) < 64).long()

    # tsne
    tsne_seed = 0
    tsne = TSNE(n_components=2, random_state=tsne_seed)
    # TSNE expects a 2-D (n_samples, n_features) array; flatten any trailing
    # dims (e.g. feature maps) and hand it a plain numpy array.
    feats_np = feats_all.reshape(feats_all.size(0), -1).numpy()
    cluster = tsne.fit_transform(feats_np)

    # plot
    plt.figure(figsize=(10, 10))

