import pickle
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE

# Fix NumPy's global RNG at import time. TSNE below is constructed without an
# explicit random_state, so it presumably draws from this global generator;
# seeding here keeps the embedding reproducible across runs.
np.random.seed(seed=42)

def tsne(all_outputs, *, output_path="figures/cls_tsne_mae.pdf", perplexity=10, show=True):
    """Embed per-signal vectors in 2-D with t-SNE and save a colour-coded scatter plot.

    Args:
        all_outputs: One entry per signal type, in the order of ``label_set``
            below (entry ``i`` belongs to ``label_set[i]``). Only item ``[1]``
            of each entry is read, and it is iterated as a collection of
            vectors (presumably the [CLS]-token embeddings -- TODO confirm
            against the code that produced the pickle). Vectors whose sum is
            NaN (i.e. that contain a NaN) are skipped.
        output_path: Destination of the figure. Missing parent directories are
            created, and the file format is inferred from the extension.
        perplexity: Passed unchanged to ``sklearn.manifold.TSNE``.
        show: If true, call ``plt.show()`` after saving.

    Raises:
        ValueError: If ``all_outputs`` has more entries than there are labels.
    """
    # Entry i of ``all_outputs`` holds the vectors for label_set[i].
    label_set = ["PPG", "ECG", "GSR", "EEG_F", "ACC_X", "ACC_Y", "ACC_Z", "EEG_O", "EEG_L", "EEG_R"]
    # Order in which signals are listed in the legend (grouped by modality).
    legend_order = ["PPG", "ECG", "GSR", "ACC_X", "ACC_Y", "ACC_Z", "EEG_F", "EEG_O", "EEG_R", "EEG_L"]
    signal_to_color = {
        "PPG": 'red',
        "ECG": "orange",
        "ACC_X": "green",
        "ACC_Y": "limegreen",
        "ACC_Z": "lime",
        "EEG_F": "blue",
        "EEG_O": "royalblue",
        "EEG_L": "cyan",
        "EEG_R": "deepskyblue",
        "GSR": 'black',
    }

    # Validate before doing any expensive work; zip() below would otherwise
    # silently drop the extra entries.
    if len(all_outputs) > len(label_set):
        raise ValueError(
            f"got {len(all_outputs)} outputs but only {len(label_set)} signal labels are defined"
        )

    # Collect vectors and their signal label, dropping vectors containing NaN.
    samples, labels = [], []
    for label, output in zip(label_set, all_outputs):
        kept = [c for c in output[1] if not np.isnan(np.sum(c))]
        samples += kept
        labels += [label] * len(kept)
    label_arr = np.array(labels)

    X_embedded = TSNE(
        n_components=2, learning_rate='auto', init='pca', perplexity=perplexity
    ).fit_transform(np.array(samples))

    # 'seaborn-darkgrid' was renamed 'seaborn-v0_8-darkgrid' in Matplotlib 3.6
    # and the old name was removed in 3.8; use whichever this install provides.
    style = "seaborn-v0_8-darkgrid"
    if style not in plt.style.available:
        style = "seaborn-darkgrid"
    plt.style.use(style)

    # Draw in alphabetical order (as before) so overlapping points layer the same way.
    handles = {}
    for label in sorted(label_set):
        mask = label_arr == label
        if not mask.any():
            continue  # no samples for this signal: keep it out of the legend
        handles[label] = plt.scatter(
            X_embedded[mask, 0], X_embedded[mask, 1],
            label=label, edgecolors='black', alpha=0.5, c=signal_to_color[label],
        )

    plt.xlabel("Component 1", fontsize=16)
    plt.ylabel("Component 2", fontsize=16)

    # Order legend entries by name instead of by position in the handle list.
    ordered = [label for label in legend_order if label in handles]
    plt.legend(
        [handles[label] for label in ordered],
        ordered,
        bbox_to_anchor=(1.0, 1.02),  # (right-left, top-down)
        fontsize=16,
        markerscale=1.2,
        frameon=True,
    )

    # Fade the points only after the legend exists. Legend markers copy the
    # artist's properties when they are created, so (as in the original
    # script) the legend keeps alpha=0.5 while the data points drop to 0.2.
    for handle in handles.values():
        handle.set_alpha(0.2)

    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_path, bbox_inches="tight")
    if show:
        plt.show()

if __name__ == '__main__':
    # Load the pre-computed per-signal outputs and render the t-SNE figure.
    # NOTE: pickle can execute arbitrary code on load -- only use trusted files here.
    pickle_path = "data/processed_files/processed_samples_for_nld_mae.pkl"
    with open(pickle_path, 'rb') as fh:
        loaded_outputs = pickle.load(fh)
    tsne(loaded_outputs)