from sklearn.manifold import TSNE
import numpy as np
import pandas as pd
import torch
from torchvision import datasets, transforms 
import  matplotlib.pyplot as plt
import seaborn as sn

# Number of samples to embed. t-SNE gets slow as the point count grows; the
# MNIST test split holds 10,000 images, so this takes all of it.
N_SAMPLES = 10000

# train=False selects the MNIST *test* split. download=True fetches the files
# into ./data on first run instead of raising "Dataset not found".
MNIST_data = datasets.MNIST('./data', train=False, download=True,
                            transform=transforms.Compose([transforms.ToTensor()]))
# NOTE: the transform is only applied when indexing MNIST_data[i]; the raw
# `.data` / `.targets` tensors read below are not transformed.
# `.data` / `.targets` replace the deprecated `train_data` / `train_labels`
# aliases (which, despite their names, held the test split here).
# Each (H, W) image is flattened to one row and uint8 pixels are scaled to [0, 1].
# The `_1000` suffix is historical; the names are kept for the code below.
data_1000 = (torch.flatten(MNIST_data.data[:N_SAMPLES], start_dim=1)
             .to(torch.float32)
             .div(255.0)
             .numpy())
labels_1000 = MNIST_data.targets[:N_SAMPLES].numpy()

# attack_dataset = torch.load("./data/attacked_mnist.pt")
# data_1000 = torch.flatten(attack_dataset["data"][0:10000,:,:,:], start_dim = 1).cpu()
# labels_1000 = attack_dataset["labels"][0:10000].cpu()

# Only the first 10,000 MNIST samples are embedded, because t-SNE becomes slow as the number of points grows.

# Embed the flattened images into 2 dimensions with t-SNE. random_state fixes
# the random seed so reruns produce the same picture. Perplexity, learning
# rate and iteration count are left at sklearn's defaults. Those defaults
# differ between sklearn versions (e.g. learning_rate became "auto" in 1.2),
# so consult the installed version's docs rather than hard-coding them here.
model = TSNE(n_components=2, random_state=0)
tsne_data = model.fit_transform(np.asarray(data_1000))

# Build the frame column-by-column so "label" keeps its integer dtype;
# stacking it with the float embedding would render digit 7 as "7.0".
tsne_df = pd.DataFrame({
    "Dim_1": tsne_data[:, 0],
    "Dim_2": tsne_data[:, 1],
    "label": np.asarray(labels_1000),
})

# Scatter the embedding with one colour per label. `height` replaces the
# `size` keyword, which seaborn deprecated (0.9) and later removed.
sn.FacetGrid(tsne_df, hue="label", height=6).map(plt.scatter, "Dim_1", "Dim_2").add_legend()
plt.show()