import matplotlib.pyplot as plt
import pickle
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import numpy as np
import os
from tqdm import tqdm

# Seed numpy's global RNG once at import time so that the per-label
# subsampling in read_data (np.random.choice) and any other consumer of
# numpy's global RNG are reproducible across runs.
np.random.seed(seed=42)

def read_data(task_name, sample_num=100, task='class', model_name='fusion'):
    """Load a label-balanced sample of embeddings plus the query embeddings of a task.

    Files are read relative to the working directory, under ``data/<task_name>/``:

    * ``samples/<fn>``  -- pickle holding a dict; only its ``'label'`` entry is used.
    * ``<model>/<fn>``  -- pickle holding the embedding of sample ``fn``; it must
      support ``.tolist()`` (presumably a numpy array or tensor -- confirm).
    * ``nlp_embed/0``   -- pickle holding a dict whose ``'question'`` entry is the
      question embedding.
    * ``label_map.pkl`` -- pickled mapping; its values (one per key) are the
      answer embeddings.

    At most ``sample_num`` samples are kept per distinct label. Larger groups are
    subsampled without replacement using numpy's global RNG, so the selection
    depends on the module-level seed.

    Args:
        task_name: sub-directory of ``data/`` holding the task's files.
        sample_num: maximum number of samples kept per distinct label.
        task: ``'class'`` keeps labels as loaded; any other value converts every
            label to a tuple so that vector-valued labels become hashable.
        model_name: one embedding directory name (str), or a pair
            ``(pre, post)`` of directory names whose embeddings are loaded
            side by side.

    Returns:
        ``(X, labels, queries)`` when ``model_name`` is a str, otherwise
        ``(X, X_1, labels, queries)``. ``queries`` is
        ``{'question': ndarray, 'answers': ndarray}``. If the per-sample embeddings
        of ``X`` have an extra axis, only the first row of each is kept and reduced
        to 384 components with PCA; otherwise ``X`` is returned as a list.

    Warning:
        Uses ``pickle.load`` on files from disk -- only run on trusted data.
    """
    pre_post = not isinstance(model_name, str)
    sample_dir = "data/{}/samples".format(task_name)

    # fetch labels; hidden files (e.g. .DS_Store) are skipped
    labels, fns = list(), list()
    for fn in tqdm(sorted(os.listdir(sample_dir))):
        if fn.startswith('.'):
            continue
        with open("{}/{}".format(sample_dir, fn), 'rb') as f:
            labels.append(pickle.load(f)['label'])
        fns.append(fn)
    labels, fns = np.array(labels), np.array(fns)
    if task != 'class':
        labels = [tuple(n) for n in labels]

    # group sample indices by label in a single pass (each list stays ascending)
    idxs_by_label = dict()
    for i, label in enumerate(labels):
        idxs_by_label.setdefault(label, []).append(i)

    # sample; labels are visited in set order, as before, so the RNG draws
    # (and therefore the selected samples) are unchanged for a given seed
    chosen_idxs = list()
    for label in set(labels):
        curr_idxs = idxs_by_label.get(label, [])
        if len(curr_idxs) > sample_num:
            curr_idxs = np.random.choice(curr_idxs, sample_num, replace=False).tolist()
        chosen_idxs += curr_idxs

    # fetch embeddings of the chosen samples, in chosen order
    labels = [labels[idx] for idx in chosen_idxs]
    X, X_1 = list(), list()
    for fn in tqdm(fns[chosen_idxs]):
        if pre_post:  # load the two models' embeddings side by side
            with open("data/{}/{}/{}".format(task_name, model_name[0], fn), 'rb') as f:
                X.append(pickle.load(f).tolist())
            with open("data/{}/{}/{}".format(task_name, model_name[1], fn), 'rb') as f:
                X_1.append(pickle.load(f).tolist())
        else:
            with open("data/{}/{}/{}".format(task_name, model_name, fn), 'rb') as f:
                X.append(pickle.load(f).tolist())

    # fetch query embeddings
    with open("data/{}/nlp_embed/0".format(task_name), 'rb') as f:
        question = np.array([pickle.load(f)['question']])
    with open("data/{}/label_map.pkl".format(task_name), "rb") as f:
        label_map = pickle.load(f)
    answers = np.array([label_map[key] for key in label_map])
    queries = {
        'question': question,
        'answers': answers
    }

    # down sample first: keep the first row of each per-sample embedding and
    # reduce it to 384 dims (presumably to match the query embedding size -- confirm)
    X_arr = np.array(X)
    if X_arr.ndim > 2:
        X = PCA(n_components=384).fit_transform(X_arr[:, 0, :])

    if not pre_post:
        return X, labels, queries

    # Slice only when there is an extra axis, so already-flat embeddings (and an
    # empty selection) do not raise IndexError.
    X_1 = np.array(X_1)
    if X_1.ndim > 2:
        X_1 = X_1[:, 0, :]
    # NOTE(review): unlike X, X_1 is not PCA-reduced, yet get_tsne concatenates
    # both; their feature sizes must already agree -- confirm for the models used.
    return X, X_1, labels, queries

def get_tsne(X, queries=None, X_1=None, perplexity=10):
    """Project embeddings to 2-D with a single joint t-SNE fit.

    All supplied groups are stacked and embedded together so they share one map.
    Rows of the result are ordered ``[X_1, X, queries['answers'], queries['question']]``,
    with absent groups omitted. ``plot`` splits the rows back apart in that order,
    so ``X_1`` is placed first whether or not ``queries`` is given.

    Args:
        X: 2-D array-like of data embeddings (n_samples, n_features).
        queries: optional dict with ``'answers'`` and ``'question'`` 2-D arrays
            whose feature size matches ``X``.
        X_1: optional 2-D array of a second embedding set with the same feature size.
        perplexity: t-SNE perplexity (default 10). Recent scikit-learn versions
            require it to be smaller than the number of stacked rows.

    Returns:
        ndarray of shape (n_stacked_rows, 2).
    """
    tsne_obj = TSNE(n_components=2, learning_rate='auto', init='pca', perplexity=perplexity, metric='cosine')

    groups = [X]
    if queries is not None:
        groups += [queries['answers'], queries['question']]
    if X_1 is not None:
        groups.insert(0, X_1)

    stacked_X = groups[0] if len(groups) == 1 else np.concatenate(groups, axis=0)
    return tsne_obj.fit_transform(np.array(stacked_X))

def plot(X, labels, l_names=None, task_name='ecg_heart_cat', title='title', task='class', queries=None, X_1=None):
    """Scatter-plot jointly embedded 2-D points and save them to ``<task_name>_tsne.pdf``.

    ``X`` is expected to come from ``get_tsne`` called with the same ``queries`` and
    ``X_1``, so its rows are ordered ``[X_1 rows, data rows, answer rows, question row]``,
    with absent groups omitted.

    Args:
        X: (n_rows, 2) array of coordinates.
        labels: one label per data row; the same labels are reused for the X_1 rows.
            For ``task='class'`` they index ``l_names`` and the colour palette;
            otherwise each label is a sequence whose element 0 sets marker opacity.
        l_names: display name per class index, also used for the answer markers.
        task_name: prefix of the output PDF file name.
        title: figure title.
        task: ``'class'`` for categorical labels; anything else is plotted as regression.
        queries: dict with ``'answers'`` and ``'question'`` arrays, as passed to
            ``get_tsne``; only their row counts are used here.
        X_1: the second embedding set passed to ``get_tsne``; only its length is used.
    """
    l_names = [] if l_names is None else list(l_names)
    colors = ['blue', 'orange', 'red', 'lime', 'cyan', 'black', 'yellow']
    X = np.asarray(X)

    # Split the stacked rows back into groups. The query rows must be read off
    # the tail BEFORE X is truncated; otherwise they would be data points.
    query_points, n_answers = None, 0
    if queries is not None:
        n_answers = len(queries['answers'])
        n_data = len(X) - n_answers - len(queries['question'])
        query_points = X[n_data:, :]
        X = X[:n_data, :]

    X_aligned = None
    if X_1 is not None:
        n_aligned = len(X_1)
        X_aligned = X[:n_aligned, :]
        X = X[n_aligned:, :]
        colors_1 = ['green', 'red']
        l_names_1 = ["{}_align".format(ln) for ln in l_names]

    if task != 'class':
        l_idx = 0  # label component that drives marker opacity
        label_values = [n[l_idx] for n in set(labels)]
        min_, max_ = np.min(label_values), np.max(label_values)
        min_ -= 0.1  # shift so the smallest label still gets alpha > 0

    # start plot
    plt.clf()
    plt.style.use("seaborn-v0_8-dark")

    # main plot
    for label in sorted(set(labels), reverse=True):
        idxs = [i for i, lab in enumerate(labels) if lab == label]

        if task == 'class':
            plt.scatter(X[idxs, 0], X[idxs, 1], label=l_names[label], edgecolors='black', alpha=0.5, c=colors[label % len(colors)])
        else:
            plt.scatter(X[idxs, 0], X[idxs, 1], edgecolors='black', alpha=(label[l_idx]-min_)/(max_-min_), c='red', s=120)

        if X_aligned is not None:
            plt.scatter(X_aligned[idxs, 0], X_aligned[idxs, 1], label=l_names_1[label], edgecolors='black', alpha=0.5, c=colors_1[label])

    if query_points is not None and len(query_points):
        # the question is the last stacked row
        plt.scatter([query_points[-1, 0]], [query_points[-1, 1]], edgecolors='black', label="Question", s=22**2, c='purple')
        # answer rows come first, one per class; colour each like its class's points
        for i, name in enumerate(l_names[:n_answers]):
            plt.scatter([query_points[i, 0]], [query_points[i, 1]], edgecolors='black', label="State: {}".format(name), s=22**2, c=colors[i % len(colors)])

    # general setting
    plt.xlabel("Component 1")
    plt.ylabel("Component 2")
    plt.title(title)
    if task == 'class':
        plt.legend(frameon=True)
    fig_name = "{}_tsne.pdf".format(task_name)
    plt.savefig(fig_name, format="pdf", bbox_inches="tight")
    print("kubectl cp physio-model-59d4995db8-b6swn:ppg_bp/{} figures/{} -c physio-model".format(fig_name, fig_name))

def main(
    task_name='ecg_heart_cat',
    l_names=('Normal', 'Abnormal'),
    sample_num=100,
    title='Heart Beat Abnormal Detection',
    task='class',
    model_name='fusion'
):
    """Single-model pipeline: load a balanced sample, run a joint t-SNE with the
    question/answer embeddings, and save ``<task_name>_tsne.pdf``.

    Args:
        task_name: sub-directory of ``data/`` holding the task's files.
        l_names: display name per class index (an immutable default avoids the
            shared-mutable-default pitfall; any sequence works).
        sample_num: maximum number of samples kept per distinct label.
        title: figure title.
        task: ``'class'`` for categorical labels, anything else for regression.
        model_name: name of the embedding directory to load.
    """
    X, labels, queries = read_data(task_name, sample_num=sample_num, task=task, model_name=model_name)
    # set queries = None here to embed and plot the data points alone
    X = get_tsne(X, queries=queries)
    plot(X, labels, l_names=l_names, task_name=task_name, title=title, task=task, queries=queries)

def main_pre_post(
    task_name='ecg_heart_cat',
    l_names=('Normal', 'Abnormal'),
    sample_num=100,
    title='Heart Beat Abnormal Detection',
    task='class',
    model_name=('mae', 'mae_msitf')
):
    """Two-model pipeline: load the same balanced sample from two embedding
    directories, embed both (plus the query embeddings) in one t-SNE map, and
    save ``<task_name>_tsne.pdf``.

    Args:
        task_name: sub-directory of ``data/`` holding the task's files.
        l_names: display name per class index (immutable default; any sequence works).
        sample_num: maximum number of samples kept per distinct label.
        title: figure title.
        task: ``'class'`` for categorical labels, anything else for regression.
        model_name: pair of embedding directory names; the second set is plotted
            with the ``_align``-suffixed legend entries.
    """
    X, X_1, labels, queries = read_data(task_name, sample_num=sample_num, task=task, model_name=model_name)
    # set queries = None here to embed and plot the data points alone
    X = get_tsne(X, queries=queries, X_1=X_1)
    plot(X, labels, l_names=l_names, task_name=task_name, title=title, task=task, queries=queries, X_1=X_1)

if __name__ == '__main__':
    # Other experiment presets -- swap one of these in for the call at the bottom.
    #
    # main(
    #     task_name='non_invasive_bp',
    #     sample_num=500,
    #     title='Cuff-less and Non-invasive Blood Pressure Estimation',
    #     task='reg'
    # )
    #
    # main(
    #     task_name='ppg_hgb',
    #     sample_num=500,
    #     title='Non-invasive Hemoglobin Estimation',
    #     task='reg'
    # )
    #
    # main(
    #     task_name='uci_har',
    #     l_names=['Walking', 'Walking Upstairs', 'Walking Downstairs', "Sitting", "Standing", "Laying"],
    #     sample_num=500,
    #     title='Human Activity Recognition'
    # )
    #
    # NOTE(review): the CVD title below looks copied from another task -- confirm.
    # main(
    #     task_name='CVD',
    #     l_names=['Normal', 'Cerebrovascular Disease', 'Insufficiency of cerebral blood supply'],
    #     sample_num=5000,
    #     title='Fatigue Detection'
    # )

    # Default run: compare the 'mae' and 'mae_msitf' embeddings on the ECG task.
    pre_post_kwargs = dict(
        task_name='ecg_heart_cat',
        l_names=['Normal', 'Abnormal'],
        sample_num=4000,  # 4000, 5000
        title='Heart Beat Abnormal Detection',
        model_name=['mae', 'mae_msitf'],
    )
    main_pre_post(**pre_post_kwargs)