import os
import matplotlib.pyplot as plt
import numpy as np

import matplotlib.colors as mcolors
# Color names (the keys of matplotlib's TABLEAU_COLORS dict), in dict order.
# mcolors.TABLEAU_COLORS is a dict mapping name -> color; other palettes such
# as mcolors.CSS4_COLORS can be used the same way.
plt_colors = [color_name for color_name in mcolors.TABLEAU_COLORS]

def visualize_factor(all_loss, all_loss_label=('Regression_loss',), img_path=None,
                     *, xlim=(0, 113), ylim=(0.95, 1.02), close=True):
    """Plot per-step factor curves and save them as a timestamped PDF.

    Typical use: during training, append each step's value to a list, then
    pass those lists in ``all_loss`` and their legend names in
    ``all_loss_label``.

    Args:
        all_loss: Sequence of curves; each curve is a sequence of values
            indexed by training step.
        all_loss_label: Legend label per curve. Curves and labels are paired
            with ``zip``, so only ``min(len(all_loss), len(all_loss_label))``
            curves are drawn.
        img_path: Directory for the output file. ``None`` (the default) means
            the current working directory at call time.
        xlim: x-axis limits as ``(left, right)``, or ``None`` to autoscale.
        ylim: y-axis limits as ``(bottom, top)``, or ``None`` to autoscale.
        close: Close the figure after saving so repeated calls do not keep
            figures open. Pass ``False`` to keep it (e.g. for ``plt.show()``).

    Returns:
        Path of the written PDF, named ``<time.time()>vis_factor.pdf``.
    """
    import time

    if img_path is None:
        # Resolve per call: a default of os.getcwd() in the signature is
        # evaluated only once, at import time.
        img_path = os.getcwd()

    set_fontsize = 24
    font_name = 'Times New Roman'
    font = {'family': font_name, 'size': set_fontsize}

    fig, ax = plt.subplots(figsize=(12, 8))
    # Global side effect: the default font size stays changed for any figure
    # created later in this process.
    plt.rcParams.update({"font.size": set_fontsize})

    cmap = plt.get_cmap('tab20c')
    for i, (loss, loss_label) in enumerate(zip(all_loss, all_loss_label)):
        # Color indices 1, 5, 9, ... (step 4, wrapping at cmap.N). The first
        # three match the original fixed [1, 5, 9]; further curves now get a
        # color instead of being silently dropped by zip.
        ax.plot(range(len(loss)), loss, color=cmap((1 + 4 * i) % cmap.N),
                linewidth=3.0, label=loss_label)

    ax.set_ylabel('Factor Value', fontdict=font)
    ax.set_xlabel('# Training Step', fontdict=font)
    if xlim is not None:
        ax.set_xlim(*xlim)
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.legend(loc='right', prop=font)
    ax.set_title('Variation of Momentum Fusion Factors', fontdict=font)
    ax.tick_params(axis='both', labelsize=set_fontsize)
    for tick_label in ax.get_xticklabels() + ax.get_yticklabels():
        tick_label.set_fontname(font_name)

    out_path = os.path.join(img_path, str(time.time()) + "vis_factor.pdf")
    fig.savefig(out_path)
    if close:
        plt.close(fig)
    return out_path

def visualize_loss(all_loss, all_loss_label=('Regression_loss',), img_path=None,
                   *, xlim=(0, 35), ylim=None, close=True):
    """Plot per-step loss curves and save them as a timestamped PDF.

    Typical use: during training, append each step's loss to a list, then
    pass those lists in ``all_loss`` and their legend names in
    ``all_loss_label``.

    Args:
        all_loss: Sequence of curves; each curve is a sequence of loss values
            indexed by training step.
        all_loss_label: Legend label per curve. Curves and labels are paired
            with ``zip``, so only ``min(len(all_loss), len(all_loss_label))``
            curves are drawn.
        img_path: Directory for the output file. ``None`` (the default) means
            the current working directory at call time.
        xlim: x-axis limits as ``(left, right)``, or ``None`` to autoscale.
        ylim: y-axis limits as ``(bottom, top)``, or ``None`` to autoscale.
        close: Close the figure after saving so repeated calls do not keep
            figures open. Pass ``False`` to keep it (e.g. for ``plt.show()``).

    Returns:
        Path of the written PDF, named ``<time.time()>vis_loss.pdf``.
    """
    import time

    if img_path is None:
        # Resolve per call: a default of os.getcwd() in the signature is
        # evaluated only once, at import time.
        img_path = os.getcwd()

    set_fontsize = 24
    font_name = 'Times New Roman'
    font = {'family': font_name, 'size': set_fontsize}

    fig, ax = plt.subplots(figsize=(12, 8))
    # Global side effect: the default font size stays changed for any figure
    # created later in this process.
    plt.rcParams.update({"font.size": set_fontsize})

    cmap = plt.get_cmap('Set2')
    for i, (loss, loss_label) in enumerate(zip(all_loss, all_loss_label)):
        # Color by position. all_loss.index(loss) returned the first *equal*
        # curve, so identical curves shared a color. Wrap at cmap.N so extra
        # curves cycle through the palette instead of collapsing onto one color.
        ax.plot(range(len(loss)), loss, color=cmap(i % cmap.N),
                linewidth=3.0, label=loss_label)

    ax.set_ylabel('Loss Value', fontdict=font)
    ax.set_xlabel('# Training Step', fontdict=font)
    if xlim is not None:
        ax.set_xlim(*xlim)
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.legend(loc='upper right', prop=font)
    ax.set_title('Variation of Feature Distillation Loss', fontdict=font)
    ax.tick_params(axis='both', labelsize=set_fontsize)
    for tick_label in ax.get_xticklabels() + ax.get_yticklabels():
        tick_label.set_fontname(font_name)

    out_path = os.path.join(img_path, str(time.time()) + "vis_loss.pdf")
    fig.savefig(out_path)
    if close:
        plt.close(fig)
    return out_path

def visualize_loss_double_y(all_loss, all_loss_label=('Regression_loss',), img_path=None,
                            *, close=True):
    """Plot four loss curves on two y-axes and save them as a timestamped PDF.

    Curves 0 and 3 are drawn dashed against the left axis (labelled
    L_gt / L_fusion, range 0-14). Curves 1 and 2 are drawn solid against the
    right axis (labelled L_reg / L_cls, range 0-2.5). The x range is 0-113.

    Args:
        all_loss: Sequence of at least four curves; each curve is a sequence
            of values indexed by training step.
        all_loss_label: Legend labels aligned with ``all_loss`` (at least four
            entries; the one-element default is not sufficient here).
        img_path: Directory for the output file. ``None`` (the default) means
            the current working directory at call time.
        close: Close the figure after saving so repeated calls do not keep
            figures open. Pass ``False`` to keep it (e.g. for ``plt.show()``).

    Returns:
        Path of the written PDF, named ``<time.time()>vis_loss.pdf``.

    Raises:
        IndexError: If ``all_loss`` or ``all_loss_label`` has fewer than four
            entries (checked before any figure is created).
    """
    import time

    left_indices = (0, 3)
    right_indices = (1, 2)
    needed = max(left_indices + right_indices) + 1
    if len(all_loss) < needed or len(all_loss_label) < needed:
        raise IndexError(
            "visualize_loss_double_y needs at least {} curves and labels, "
            "got {} and {}".format(needed, len(all_loss), len(all_loss_label)))

    if img_path is None:
        # Resolve per call: a default of os.getcwd() in the signature is
        # evaluated only once, at import time.
        img_path = os.getcwd()

    set_fontsize = 24
    font_name = 'Times New Roman'
    font = {'family': font_name, 'size': set_fontsize}

    def tableau(slot):
        # Tableau color at position `slot` of the module-level plt_colors list.
        return mcolors.TABLEAU_COLORS[plt_colors[slot]]

    fig = plt.figure(figsize=(12, 8))
    # Global side effect: the default font size stays changed for any figure
    # created later in this process.
    plt.rcParams.update({"font.size": set_fontsize})

    # Left axis: dashed curves. Index positions are used directly; the old
    # all_loss.index(loss) lookup mis-colored curves that compared equal.
    ax_left = fig.add_subplot()
    for i in left_indices:
        loss = all_loss[i]
        ax_left.plot(range(len(loss)), loss, color=tableau(i),
                     linewidth=3.0, linestyle='--', label=all_loss_label[i])
    ax_left.set_ylabel('$L_{gt}$ / $L_{fusion}$ Value', fontdict=font)
    ax_left.set_xlim(0, 113)
    ax_left.set_ylim(0, 14)
    ax_left.legend(bbox_to_anchor=(0.8, 1), prop=font)

    # Right axis (shares x): solid curves. Same color slots as before:
    # curve 1 -> slot 6, curve 2 -> slot 2.
    ax_right = ax_left.twinx()
    right_color_slot = {1: 6, 2: 2}
    for i in right_indices:
        loss = all_loss[i]
        ax_right.plot(range(len(loss)), loss, color=tableau(right_color_slot[i]),
                      linewidth=3.0, label=all_loss_label[i])
    ax_right.set_ylabel('$L_{reg}$ / $L_{cls}$ Value', fontdict=font)
    ax_right.set_xlim(0, 113)
    ax_right.set_ylim(0, 2.5)
    ax_right.legend(loc='upper right', prop=font)

    ax_right.set_title('Variation of Training Losses across Steps', fontdict=font)
    ax_left.set_xlabel('# Training Step', fontdict=font)
    # Size tick labels on both axes explicitly rather than via plt.*ticks,
    # which only affects whichever axes happens to be current.
    ax_left.tick_params(axis='both', labelsize=set_fontsize)
    ax_right.tick_params(axis='both', labelsize=set_fontsize)
    for tick_label in (ax_left.get_xticklabels() + ax_left.get_yticklabels()
                       + ax_right.get_yticklabels()):
        tick_label.set_fontname(font_name)

    out_path = os.path.join(img_path, str(time.time()) + "vis_loss.pdf")
    fig.savefig(out_path)
    if close:
        plt.close(fig)
    return out_path

def visualize_embedding(X):  # X: the representations to visualize
    """Project ``X`` to 2-D with t-SNE and show each sample as its index.

    Args:
        X: Array-like, presumably of shape (n_samples, n_features); it is
            converted with ``np.asarray``. Torch tensors should first be moved
            to CPU and detached (e.g. ``X.cpu().detach().numpy()``).

    Returns:
        The min-max normalized 2-D embedding, one row per sample.
    """
    from sklearn import manifold

    X = np.asarray(X)
    tsne = manifold.TSNE(n_components=2, init='pca', random_state=666)
    X_tsne = tsne.fit_transform(X)
    print("Org data dimension is {}. "
          "Embedded data dimension is {}".format(X.shape[-1], X_tsne.shape[-1]))

    # Min-max normalize each embedding axis (column) to [0, 1].
    x_min, x_max = X_tsne.min(0), X_tsne.max(0)
    span = x_max - x_min
    # Avoid 0/0 when every point shares the same coordinate on an axis.
    span = np.where(span == 0, 1, span)
    X_norm = (X_tsne - x_min) / span

    cmap = plt.cm.Set1
    plt.figure(figsize=(12, 12))
    for i, (px, py) in enumerate(X_norm):
        # Label each point with its sample index. Cycle the colormap: integer
        # inputs >= cmap.N would otherwise all map to the same color.
        plt.text(px, py, str(i), color=cmap(i % cmap.N),
                 fontdict={'weight': 'bold', 'size': 9})
    plt.xticks([X_norm[:, 0].min(), X_norm[:, 0].max()])
    plt.yticks([X_norm[:, 1].min(), X_norm[:, 1].max()])
    plt.show()
    return X_norm