import numpy as np
import matplotlib.pyplot as plt
from sklearn import metrics

# Demo of the three plt.step `where` modes on a sampled sine (does not read any GC data)
def test_meihua():
    """Demonstrate the three ``where`` modes of ``plt.step`` on a sampled sine.

    Each mode is drawn with its own vertical offset (2, 1, 0) and overlaid with
    the raw sample points as a faint grey dashed line, so the alignment of the
    steps relative to the samples is visible. Ends with ``plt.show()``.

    NOTE(review): despite the ``test_`` prefix this is a plotting demo, not a
    unit test; pytest's default discovery would still collect it.
    """
    samples = np.arange(14)
    wave = np.sin(samples / 2)

    # (vertical offset, `where` keyword or None to keep matplotlib's default, legend label)
    variants = (
        (2, None, 'pre (default)'),
        (1, 'mid', 'mid'),
        (0, 'post', 'post'),
    )
    for offset, where, legend_label in variants:
        shifted = wave + offset
        step_kwargs = {'label': legend_label}
        if where is not None:
            step_kwargs['where'] = where
        plt.step(samples, shifted, **step_kwargs)
        # Faint overlay of the underlying sample points.
        plt.plot(samples, shifted, 'o--', color='grey', alpha=0.3)

    plt.grid(axis='x', color='0.95')
    plt.legend(title='Parameter where:')
    plt.title('plt.step(where=...)')
    plt.show()


# Takes two GC matrices (GC = ground truth, GC_est = prediction) and returns the ROC AUC; when draw=True it also plots the ROC curve.
def draw_GC_ROC_curve(GC, GC_est_original, draw=True):
    """Score a predicted GC matrix against the ground truth with ROC AUC.

    Both matrices are flattened, so every entry is one sample: ``GC`` supplies
    the labels and ``GC_est_original`` supplies the scores.

    Args:
        GC: ground-truth GC matrix, used as labels for ``sklearn.metrics``
            (presumably binary 0/1 -- confirm against callers).
        GC_est_original: predicted GC scores with the same number of entries
            as ``GC``.
        draw: if True, also plot the ROC curve via ``draw_ROC_curve`` (which
            calls ``plt.show()``); otherwise only compute the score.

    Returns:
        The area under the ROC curve.
    """
    GC = GC.flatten().astype(np.float64)
    GC_est = GC_est_original.flatten().astype(np.float64)

    # Scale scores so the largest is 1, but only when the maximum is
    # positive. Dividing by a zero maximum yields NaN/inf, which sklearn
    # rejects. Dividing by a negative maximum reverses the score ranking and
    # therefore inverts the AUC. ROC AUC is invariant to positive scaling,
    # so skipping the division in those cases cannot change a valid result.
    peak = GC_est.max()
    if peak > 0:
        GC_est = GC_est / peak

    if draw:
        return draw_ROC_curve(GC, GC_est)
    return metrics.roc_auc_score(GC, GC_est)

# 画ROC图
def draw_ROC_curve(label, pre):
    """Plot the ROC curve of scores ``pre`` against ``label`` and return its AUC.

    The curve is drawn as blue star markers joined by a line, next to the red
    dashed chance diagonal. The figure is shown with ``plt.show()`` before the
    area is computed.

    Args:
        label: ground-truth labels accepted by ``sklearn.metrics.roc_curve``.
        pre: predicted scores aligned with ``label``.

    Returns:
        Area under the ROC curve, computed with ``metrics.auc`` (trapezoidal rule).
    """
    false_pos_rate, true_pos_rate, _thresholds = metrics.roc_curve(label, pre)

    plt.plot(false_pos_rate, true_pos_rate, 'b*-', label='roc')
    plt.plot([0, 1], [0, 1], 'r--', label="45°")
    plt.legend()
    plt.xlabel('FPR')
    plt.ylabel('TPR')
    plt.show()

    return metrics.auc(false_pos_rate, true_pos_rate)

# ---------------------------------2画图--------------------------------
# my_color = ["#6cc6e3", "#accf78", "#fac6a9", "#d8538a", "#fff87b", "#c57c72"]
# my_color = ["#001eff", "#ff0000", "#0cff00", "#00fff0", "#d200ff", "#fffc00"]
def _scale_by_positive_max(scores):
    """Return ``scores / scores.max()`` if the maximum is positive, else ``scores`` unchanged.

    Dividing by a zero maximum yields NaN/inf, which sklearn rejects.
    Dividing by a negative maximum reverses the score ranking, which inverts
    the ROC curve. ROC metrics are invariant to positive scaling, so skipping
    the division in those cases leaves valid results untouched.
    """
    peak = scores.max()
    return scores / peak if peak > 0 else scores


def _plot_roc_family(GC_true_list, GC_pre_list, labels, name, legend_title):
    """Overlay one ROC curve per (true, predicted) GC pair and show the figure.

    For each pair, the AUROC is printed as ``"<name>=<label> score=<auc>"``.
    A red dashed chance diagonal, a grid, and a legend are added before
    ``plt.show()``.

    Args:
        GC_true_list: ground-truth GC matrices, one per experimental setting.
        GC_pre_list: predicted GC score matrices, aligned with ``GC_true_list``.
        labels: legend label for each pair, in the same order as the lists.
        name: parameter name used in the printed score line.
        legend_title: title of the plot legend.

    Raises:
        ValueError: if there are fewer labels than (true, predicted) pairs.
    """
    pairs = list(zip(GC_true_list, GC_pre_list))
    if len(labels) < len(pairs):
        raise ValueError(
            f"got {len(pairs)} GC (true, pred) pairs but only {len(labels)} labels: {labels}"
        )

    for lbl, (GC_true, GC_pre) in zip(labels, pairs):
        y_true = GC_true.flatten().astype(np.float64)
        y_score = _scale_by_positive_max(GC_pre.flatten().astype(np.float64))
        FPR, TPR, _ = metrics.roc_curve(y_true, y_score)
        print(f"{name}={lbl} score={metrics.roc_auc_score(y_true, y_score)}")
        # For fixed colours, add color=my_color[i] using one of the palettes
        # kept as comments in the section header of this file.
        plt.plot(FPR, TPR, marker='.', label=lbl, alpha=0.8)

    # Chance-level diagonal for reference.
    plt.plot([0, 1], [0, 1], '--', color=(1, 0, 0), alpha=0.5)
    plt.legend(title=legend_title)
    plt.grid(True)
    plt.xlabel('FPR')
    plt.ylabel('TPR')
    plt.show()


def draw_var(GC_true_list, GC_pre_list, sparsity_a=2, sparsity_b=6):
    """Plot ROC curves for a sparsity sweep 0.<sparsity_a> .. 0.<sparsity_b - 1>.

    Args:
        GC_true_list: ground-truth GC matrices, one per sparsity level.
        GC_pre_list: predicted GC score matrices, aligned with ``GC_true_list``.
        sparsity_a: first sparsity digit (inclusive), e.g. 2 -> "0.2".
        sparsity_b: last sparsity digit (exclusive).
    """
    # Labels now follow sparsity_a/sparsity_b instead of a hard-coded offset of 2.
    labels = ['0.' + str(i) for i in range(sparsity_a, sparsity_b)]
    _plot_roc_family(GC_true_list, GC_pre_list, labels, 'sparsity', 'sparsity')


def draw_var_P(GC_true_list, GC_pre_list, P_a=10, P_b=30):
    """Plot ROC curves for a sweep over P = P_a, P_a + 5, ... (< P_b).

    Args:
        GC_true_list: ground-truth GC matrices, one per P value.
        GC_pre_list: predicted GC score matrices, aligned with ``GC_true_list``.
        P_a: first P value (inclusive).
        P_b: upper bound on P (exclusive); the step is 5.
    """
    labels = [str(i) for i in range(P_a, P_b, 5)]
    _plot_roc_family(GC_true_list, GC_pre_list, labels, 'P', 'P')


def draw_loz(GC_true_list, GC_pre_list, F_a=10, F_b=50):
    """Plot ROC curves for a sweep over F = F_a, F_a + 10, ... (< F_b).

    Args:
        GC_true_list: ground-truth GC matrices, one per F value.
        GC_pre_list: predicted GC score matrices, aligned with ``GC_true_list``.
        F_a: first F value (inclusive).
        F_b: upper bound on F (exclusive); the step is 10.
    """
    labels = [str(i) for i in range(F_a, F_b, 10)]
    # Printed as "F=..." to match the legend title and the F=<i>_*.txt file names.
    _plot_roc_family(GC_true_list, GC_pre_list, labels, 'F', 'F')

# ---------------------------1读取文件--------------------------------
def draw_var_AUROC():
    """Load the sparsity-sweep GC results from ./save and plot their ROC curves.

    Reads ``sparsity=0.<i>_true.txt`` and ``sparsity=0.<i>_pre.txt`` for
    i in [sparsity_a, sparsity_b) from ``./save/<folder_name>/``, loading all
    ground-truth files first and then all predictions. Both lists are passed
    to ``draw_var``.
    """
    # Result folder to evaluate; alternatives: 'NGC_mlp_adam_var', 'NGC_lstm_adam_var'.
    folder_name = 'mixer_lam=0.00002_final'
    # Sparsity sweep 0.2 .. 0.5 (the upper digit is exclusive).
    sparsity_a, sparsity_b = 2, 6
    levels = range(sparsity_a, sparsity_b)

    prefix = './save/' + folder_name + '/sparsity=0.'
    GC_true_list = [np.loadtxt(f'{prefix}{i}_true.txt') for i in levels]
    GC_pre_list = [np.loadtxt(f'{prefix}{i}_pre.txt') for i in levels]
    draw_var(GC_true_list, GC_pre_list, sparsity_a, sparsity_b)

    # Alternative experiment: sweep over P = 10, 15, 20, 25 instead.
    # P_a, P_b = 10, 30
    # folder_name = 'mixer_p=N'  # or 'NGC_lstm_adam_var_P=N' / 'NGC_mlp_adam_var_P=N'
    # GC_true_list = [np.loadtxt(f'./save/{folder_name}/p={i}_true.txt') for i in range(P_a, P_b, 5)]
    # GC_pre_list = [np.loadtxt(f'./save/{folder_name}/p={i}_pre.txt') for i in range(P_a, P_b, 5)]
    # draw_var_P(GC_true_list, GC_pre_list, P_a, P_b)

def draw_loz_AUROC():
    """Load the F-sweep GC results from ./save and plot their ROC curves.

    Reads ``F=<i>_true.txt`` and ``F=<i>_pre.txt`` for i = F_a, F_a + 10, ...
    (< F_b) from ``./save/<folder_name>/``. Both lists are passed to
    ``draw_loz`` together with the same bounds.
    """
    # Result folder to evaluate; alternatives: 'NGC_mlp_ista_loz', 'NGC_lstm_ista_loz'.
    folder_name = 'mixer_lam=0.025_loz'
    F_a = 10  # first F value (inclusive)
    F_b = 50  # upper bound on F (exclusive); step is 10
    F_values = range(F_a, F_b, 10)
    GC_true_list = [np.loadtxt('./save/'+folder_name+'/F='+str(i)+'_true.txt') for i in F_values]
    GC_pre_list = [np.loadtxt('./save/'+folder_name+'/F='+str(i)+'_pre.txt') for i in F_values]
    # Pass the bounds through so draw_loz's legend labels always match the
    # files actually loaded. Previously it silently relied on its own
    # defaults being equal to F_a / F_b here.
    draw_loz(GC_true_list, GC_pre_list, F_a, F_b)


if __name__ == '__main__':
    # Run the sparsity-sweep evaluation; swap in draw_loz_AUROC() for the F sweep.
    draw_var_AUROC()
    # draw_loz_AUROC()
