import torch
from torch import nn
import numpy as np
import random
import os
import torch.nn.functional as F
from torch.distributions import Bernoulli, RelaxedBernoulli
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from sklearn.metrics import roc_curve
from scipy.spatial import distance
from sklearn.mixture import GaussianMixture
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from torchvision.utils import make_grid
from scipy.signal import savgol_filter

def tsne_plot(embedded_features, labels):
	"""Project features to 2-D with t-SNE and scatter-plot them by class.

	Fits ``TSNE(n_components=2, init='pca')`` on *embedded_features* and
	draws one scatter series per class id 0-9 into a new 16x16-inch pyplot
	figure, then adds a legend.

	Args:
		embedded_features: Array-like of feature vectors, one row per sample
			(presumably shape (n_samples, n_features) -- TODO confirm with callers).
		labels: Sequence of integer class ids, one per row of
			*embedded_features*; converted to a NumPy array for boolean masking.
	"""
	labels = np.asarray(labels)
	tsne_model = TSNE(n_components=2, init='pca')
	print('fitting')
	points_2d = tsne_model.fit_transform(embedded_features)
	print('fitted')

	# Class ids are fixed to 0-9 regardless of which ids actually occur in
	# *labels* (the original computed the unique labels and then discarded them).
	target_names = list(range(10))
	colors = ['orangered', 'lawngreen', 'deepskyblue', 'black', 'brown',
		'grey', 'orange', 'yellow', 'pink', 'cyan', 'magenta']
	# NOTE(review): legend text for classes 2 and 3 is swapped (class 2 is
	# shown as '3' and class 3 as '2'). This is preserved from the original
	# code -- confirm the remapping is intended.
	legend_overrides = {2: '3', 3: '2'}

	plt.figure(figsize=(16, 16))
	for class_id, color in zip(target_names, colors):
		mask = labels == class_id
		plt.scatter(points_2d[mask, 0], points_2d[mask, 1], c=color,
			label=legend_overrides.get(class_id, str(class_id)))

	# Each plt.legend() call replaces the previous legend on the axes, so the
	# original three successive calls only ever produced the last one; keep
	# just that call.
	plt.legend(fontsize="xx-large")
