import json
import matplotlib.pyplot as plt
import numpy as np

def draw_layer_scores(attention_layer_scores, save_name="test_attention_heatmap.png"):
    """Draw a heatmap of a 2-D score matrix, save it to disk, and show it.

    Rows are drawn along the y axis (labelled "Layer Index") and columns along
    the x axis (labelled "Context Tokens"). The colour scale runs from 0 to the
    matrix maximum; values below 0 are drawn with the lowest colour.

    Args:
        attention_layer_scores: 2-D array-like of scores (nested lists are
            accepted and converted with ``np.asarray``).
        save_name: Output image path passed to ``savefig`` (300 dpi).

    Raises:
        ValueError: If the input is not two-dimensional.
    """
    scores = np.asarray(attention_layer_scores)
    if scores.ndim != 2:
        raise ValueError(f"expected a 2-D score matrix, got shape {scores.shape}")
    num_layers, num_tokens = scores.shape

    fig, ax = plt.subplots(figsize=(12, 8))
    im = ax.imshow(scores,
                   cmap='viridis',  # colour map
                   aspect='auto',  # stretch cells to fill the axes
                   interpolation='nearest',  # one flat colour per cell, no smoothing
                   vmin=0,  # lower end of the colour range
                   vmax=np.max(scores))  # upper end of the colour range
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label('Attention Score', rotation=270, labelpad=15)

    # Axis labels and title.
    ax.set_xlabel('Context Tokens', fontsize=12)
    ax.set_ylabel('Layer Index', fontsize=12)
    ax.set_title('Attention Scores Across Layers', fontsize=14, pad=20)

    # One major tick per row / column.
    ax.set_yticks(np.arange(num_layers))
    ax.set_xticks(np.arange(num_tokens))

    # Optional cell-border grid. A minor grid is drawn only at minor tick
    # positions, and none exist by default, so place them on the cell edges
    # and hide the tick marks themselves.
    ax.set_xticks(np.arange(-0.5, num_tokens, 1), minor=True)
    ax.set_yticks(np.arange(-0.5, num_layers, 1), minor=True)
    ax.tick_params(which='minor', length=0)
    ax.grid(which='minor', color='gray', linestyle=':', linewidth=0.5)

    # Tighten the layout, then save and display this specific figure.
    fig.tight_layout()
    fig.savefig(save_name, dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)



def draw_u_curve(attention_layer_scores, save_name="test_u_curve.png", draw_all=False):
    """Plot column-wise curves of a row-normalised score matrix.

    Each row is first shifted so that its own minimum becomes 0. The input
    array is NOT modified; the shift is applied to a copy.

    If ``draw_all`` is true, every row is drawn as a dashed line. Otherwise two
    solid curves are drawn: the mean of the first half of the rows and the mean
    of the second half. The mean over all rows is always drawn as a thick
    line. The figure is saved (unless ``save_name`` is None), shown, and then
    closed.

    Args:
        attention_layer_scores: 2-D array-like of scores.
        save_name: Output image path, or ``None`` to skip saving.
        draw_all: Draw every row instead of the two half-averages.

    Raises:
        ValueError: If the input is not two-dimensional.
    """
    scores = np.asarray(attention_layer_scores)
    if scores.ndim != 2:
        raise ValueError(f"expected a 2-D score matrix, got shape {scores.shape}")
    # Out-of-place subtraction: an in-place `-=` would change the caller's array.
    scores = scores - scores.min(axis=1, keepdims=True)
    num_rows = len(scores)

    plt.figure(figsize=(10, 6))
    if draw_all:
        # One dashed line per row.
        for i, row in enumerate(scores):
            plt.plot(row,
                     linestyle='--',
                     linewidth=1.5,
                     alpha=0.7,
                     label=f'Row {i + 1}'
                     )
    else:
        # Integer split point shared by the slices and the legend labels, so
        # each label names exactly the rows that were averaged.
        mid = num_rows // 2
        for start, stop in ((0, mid), (mid, num_rows)):
            if stop <= start:
                # Empty half (e.g. a single row): its mean would be all-NaN.
                continue
            plt.plot(scores[start:stop].mean(axis=0), linestyle='-', alpha=0.9, linewidth=2,
                     label=f"average_{start}to{stop}")
    plt.plot(scores.mean(axis=0), linestyle='-', alpha=0.9, linewidth=5, label="average")

    # Title, axis labels, light grid, and a legend placed right of the axes.
    plt.title("Row-wise Line Plots", fontsize=14)
    plt.xlabel("Column Index", fontsize=12)
    plt.ylabel("Value", fontsize=12)
    plt.grid(alpha=0.2)
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.tight_layout()
    if save_name is not None:
        plt.savefig(save_name, dpi=300, bbox_inches='tight')
    plt.show()
    # Close instead of terminating the interpreter, so callers can keep going.
    plt.close()


def u_curve(model_type="vicuna_7b", data_name="HotpotQA", setting_type="concat"):
    """Load saved layer scores for one configuration and draw their U-curves.

    Reads ``interpre/layer_scores/{model_type}_{data_name}_{setting_type}_scores.json``.
    The JSON must contain the keys ``all_scores`` and ``scores``.

    Two plots are drawn with ``draw_u_curve``:

    * ``all_scores``, saved next to the JSON as ``..._allscores_u_3.png``.
    * ``scores`` averaged over axis 1, saved under ``draw_u_curve``'s default
      file name.

    Args:
        model_type: Model name used in the file-name prefix.
        data_name: Dataset name used in the file-name prefix.
        setting_type: Setting name used in the file-name prefix.
    """
    prefix = f"interpre/layer_scores/{model_type}_{data_name}_{setting_type}"
    # JSON is UTF-8; don't depend on the platform's default locale encoding.
    with open(f"{prefix}_scores.json", 'r', encoding='utf-8') as f:
        results = json.load(f)
    all_scores = np.array(results['all_scores'])
    scores = np.array(results['scores'])
    print(scores.shape)
    draw_u_curve(all_scores, f"{prefix}_allscores_u_3.png")
    # NOTE(review): axis 1 is presumably the layer axis (scores looks like
    # samples x layers x tokens); confirm against the code that writes the JSON.
    all_layer_average_scores = scores.mean(axis=1)
    print(all_layer_average_scores.shape)
    draw_u_curve(all_layer_average_scores)


if __name__ == "__main__":
    # Script entry point: load the default configuration and draw its plots.
    u_curve()