import matplotlib.pyplot as plt
import torch
import numpy as np
from sklearn.manifold import TSNE
import torch.nn.functional as F
import os

def plot_tsne(classifier, encoder, loader, args):
    """Plot a 2-D t-SNE embedding of encoder features and classifier prototypes.

    Every batch from ``loader`` is passed through ``encoder``. The resulting
    feature vectors are L2-normalised along ``dim=1`` and collected together
    with their labels. The rows of ``classifier.fc.weight`` (one per class)
    are L2-normalised the same way and appended as class "prototypes". All
    points are embedded with t-SNE (``random_state=33``) and plotted:
    sample features as circles coloured by label, prototypes as black
    diamonds. The figure is saved to ``./tsne/<s1-s2-...>_<t>.png`` (the
    directory is created if missing), then shown and closed.

    Args:
        classifier: Module with a linear layer ``fc``. Its weight rows are
            used as prototypes, and its device decides where inputs are sent.
        encoder: Module mapping an input batch to a feature tensor. Assumed
            to be (batch, dim), since it is normalised along ``dim=1``
            (TODO confirm for all encoders).
        loader: Iterable of batches where ``data[0]`` is the model input and
            ``data[1]`` holds the per-sample class labels.
        args: Object with ``s`` (iterable of source identifiers) and ``t``
            (target identifier), used only to build the output filename.

    Both modules are put in eval mode for the forward passes. Afterwards
    they are restored to whatever training mode they had before the call,
    even if an error is raised.
    """
    classifier_was_training = classifier.training
    encoder_was_training = encoder.training
    classifier.eval()
    encoder.eval()
    print("-------Plotting TSNE---------")
    features, labels = [], []

    num_classes = classifier.fc.weight.size(0)
    # Follow the model's device rather than hard-coding CUDA.
    device = classifier.fc.weight.device
    try:
        with torch.no_grad():
            for data in loader:
                inputs = data[0].to(device)
                target = data[1]
                f = encoder(inputs)
                features.extend(F.normalize(f, dim=1, p=2).cpu().numpy().tolist())
                # Store plain Python ints (not 0-d tensors) so np.array(labels)
                # becomes a clean integer array usable as scatter colours.
                labels.extend(torch.as_tensor(target).cpu().numpy().tolist())

            prototypes = F.normalize(classifier.fc.weight, dim=1, p=2).cpu().numpy().tolist()
    finally:
        # Restore the caller's modes instead of forcing train(). The modules
        # are not used again below, so this can happen right away.
        classifier.train(classifier_was_training)
        encoder.train(encoder_was_training)

    # Prototypes go last; `end` marks where the sample features stop.
    features.extend(prototypes)
    labels.extend(range(num_classes))

    # Rebuild the jet colormap as a custom LinearSegmentedColormap.
    cmap = plt.cm.jet
    cmaplist = [cmap(i) for i in range(cmap.N)]
    cmap = cmap.from_list('Custom cmap', cmaplist, cmap.N)

    features, labels = np.array(features), np.array(labels)
    X_tsne = TSNE(n_components=2, random_state=33).fit_transform(features)

    fig = plt.figure(figsize=(10, 10))
    end = len(features) - num_classes
    plt.scatter(X_tsne[:end, 0], X_tsne[:end, 1], c=labels[:end], cmap=cmap, marker="o", alpha=0.5)
    # A fixed colour is used for prototypes, so no cmap is passed here.
    plt.scatter(X_tsne[end:, 0], X_tsne[end:, 1], c="black", marker="D", alpha=1)

    out_dir = './tsne'
    os.makedirs(out_dir, exist_ok=True)
    filename = '{}_{}.png'.format("-".join(str(s) for s in args.s), args.t)
    # Save BEFORE show(). Interactive backends tear the figure down when the
    # window closes, which would otherwise produce a blank image.
    plt.savefig(os.path.join(out_dir, filename))
    plt.show()
    # Free the figure; otherwise repeated calls (e.g. once per epoch) leak memory.
    plt.close(fig)
 
