import os

import matplotlib.pyplot as plt
import torch
from scipy.spatial.distance import squareform, pdist
from sklearn.manifold import MDS, TSNE

def MDS_plot(X, path, red=True, s=None):
    """Embed ``X`` in 2-D with MDS, save the results, and write a scatter plot.

    Pairwise Euclidean distances between the rows of ``X`` are computed and
    handed to scikit-learn's ``MDS`` as a precomputed dissimilarity matrix.

    Files written (``<stem>`` is ``path`` without its extension):

    * ``<stem>.pt``     -- the 2-D embedding, via ``torch.save``
    * ``<stem>_raw.pt`` -- the input ``X`` unchanged, via ``torch.save``
    * ``path``          -- the scatter plot, via ``plt.savefig``

    Args:
        X: 2-D array-like of samples (one row per sample) accepted by
            ``scipy.spatial.distance.pdist``.
        path: Destination of the figure; its extension is replaced to name
            the two ``.pt`` files.
        red: If true, highlight the second-to-last row in red and the last
            row in orange (the latter only when it lies within 100 units of
            the origin). Needs at least two rows, since row ``-2`` is indexed.
        s: Marker size forwarded to ``plt.scatter`` when ``red`` is false;
            ``None`` uses matplotlib's default.

    Returns:
        None.
    """
    distances = squareform(pdist(X, 'euclidean'))
    embedding = MDS(n_components=2, dissimilarity="precomputed")
    X_transformed = embedding.fit_transform(distances)

    # Strip the real extension instead of assuming it is three characters
    # long; slicing broke for e.g. ".jpeg" or extension-less paths.
    stem = os.path.splitext(path)[0]
    torch.save(X_transformed, stem + ".pt")
    torch.save(X, stem + "_raw.pt")

    # Reuse pyplot's current figure, wiping whatever was drawn on it before.
    plt.clf()
    if red:
        # Every row except the last; the second-to-last row is drawn again
        # in red on top of its blue marker.
        plt.scatter(X_transformed[:-1, 0], X_transformed[:-1, 1],
                    label="samples")
        plt.scatter(X_transformed[-2, 0], X_transformed[-2, 1],
                    color="red", label="second-to-last sample")
        # The last row is only shown when its squared norm is below 10000
        # (presumably so a far-away point does not stretch the axes --
        # TODO confirm intent).
        if X_transformed[-1, 0] ** 2 + X_transformed[-1, 1] ** 2 < 10000:
            plt.scatter(X_transformed[-1, 0], X_transformed[-1, 1],
                        color="orange", label="last sample")
        # Labels are set above so the legend has entries; calling legend()
        # without any labelled artist only triggers a matplotlib warning.
        plt.legend()
    else:
        # scatter's own default for ``s`` is None, so it can be passed through.
        plt.scatter(X_transformed[:, 0], X_transformed[:, 1], s=s)
    plt.savefig(path)

def TSNE_plot(X, path, red=True):
    """Embed ``X`` in 2-D with t-SNE, save the embedding, and write a plot.

    Files written (``<stem>`` is ``path`` without its extension):

    * ``<stem>.pt`` -- the 2-D embedding, via ``torch.save``
    * ``path``      -- the scatter plot, via ``plt.savefig``

    Args:
        X: Samples accepted by ``sklearn.manifold.TSNE.fit_transform``
            (one row per sample).
        path: Destination of the figure; its extension is replaced by
            ``.pt`` to name the embedding file.
        red: If true, highlight the second-to-last row in red and the last
            row in orange (the latter only when it lies within 100 units of
            the origin). Needs at least two rows, since row ``-2`` is indexed.

    Returns:
        None.
    """
    embedding = TSNE(n_components=2)
    X_transformed = embedding.fit_transform(X)

    # Strip the real extension instead of assuming it is three characters
    # long; slicing broke for e.g. ".jpeg" or extension-less paths.
    stem = os.path.splitext(path)[0]
    torch.save(X_transformed, stem + ".pt")

    # Reuse pyplot's current figure, wiping whatever was drawn on it before.
    plt.clf()
    if red:
        # Every row except the last; the second-to-last row is drawn again
        # in red on top of its blue marker.
        plt.scatter(X_transformed[:-1, 0], X_transformed[:-1, 1],
                    label="samples")
        plt.scatter(X_transformed[-2, 0], X_transformed[-2, 1],
                    color="red", label="second-to-last sample")
        # The last row is only shown when its squared norm is below 10000
        # (presumably so a far-away point does not stretch the axes --
        # TODO confirm intent).
        if X_transformed[-1, 0] ** 2 + X_transformed[-1, 1] ** 2 < 10000:
            plt.scatter(X_transformed[-1, 0], X_transformed[-1, 1],
                        color="orange", label="last sample")
        # Labels are set above so the legend has entries; calling legend()
        # without any labelled artist only triggers a matplotlib warning.
        plt.legend()
    else:
        plt.scatter(X_transformed[:, 0], X_transformed[:, 1])
    plt.savefig(path)