import matplotlib.pyplot as plt
import numpy as np

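# Sample input (appears to be a model-to-model similarity matrix whose rows/columns
# follow the order of `labels` below). Uncomment both to try draw_heatmap on it.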
# data = np.array([
#     [0.84962479, 0.69892694, 0.36980047, 0.51097859, 0.6023821,  0.57792226, 0.53662385, 0.4488197 ],
#     [0.68976542, 0.85254047, 0.3470345 , 0.51277074, 0.61641939, 0.58222999, 0.51480569, 0.43467951],
#     [0.3706071 , 0.35460747, 0.85302518, 0.35262502, 0.37406094, 0.37580577, 0.38141447, 0.43364828],
#     [0.53616728, 0.5281919 , 0.34727114, 0.81206162, 0.45054573, 0.55953455, 0.53116676, 0.46675413],
#     [0.63910892, 0.66077352, 0.36774834, 0.47352364, 0.76716287, 0.53582451, 0.50440671, 0.43327781],
#     [0.59895074, 0.59848395, 0.37698442, 0.5545673 , 0.50363412, 0.82994971, 0.55355006, 0.48514353],
#     [0.55120735, 0.53280687, 0.38804603, 0.53124421, 0.48428589, 0.56037386, 0.82903606, 0.54636051],
#     [0.46298632, 0.45111079, 0.44384484, 0.46833252, 0.42548162, 0.4907777 , 0.54435527, 0.83468489]
# ])

# labels = ['qwen_32B', 'deepseekv3', 'qwen_0.5B', 'llama_8B', 
#           'gpt-4o', 'qwen_7B', 'qwen_3B', 'qwen_1.5B']

def draw_heatmap(data, labels, title='Model Similarity Heatmap', save_path=None):
    """
    Draw an annotated heatmap of a 2-D matrix, optionally save it, then show it.

    :param data: 2-D array-like of values to plot (converted with ``np.asarray``,
        so nested lists work too). Cell ``[i, j]`` is drawn at x=j, y=i.
    :param labels: tick labels used for BOTH axes (columns on x, rows on y);
        presumably one label per row/column of ``data`` — TODO confirm for
        non-square input.
    :param title: title shown above the plot.
    :param save_path: file path to save the figure to; ``None`` (the default)
        means the figure is not saved.
    """
    # Accept nested lists as well as ndarrays; `data[i, j]` indexing below
    # requires an ndarray.
    data = np.asarray(data)

    plt.figure(figsize=(10, 8))
    heatmap = plt.imshow(data, cmap='viridis')

    # Tick labels: same label list on both axes; x labels rotated to avoid overlap.
    ticks = np.arange(len(labels))
    plt.xticks(ticks=ticks, labels=labels, rotation=45, ha="right")
    plt.yticks(ticks=ticks, labels=labels)

    plt.colorbar(heatmap)

    # Annotate every cell that is actually drawn (iterating over data itself
    # avoids IndexError when labels and data sizes disagree). White text for
    # low values, black for high ones: low values map to the dark end of
    # viridis. NOTE: the 0.5 cut-off is in data units, while imshow rescales
    # colors to the data's min/max, so it is only a heuristic for values
    # roughly in [0, 1].
    for (row, col), value in np.ndenumerate(data):
        plt.text(col, row, f"{value:.2f}", ha='center', va='center',
                 color='w' if value < 0.5 else 'black')

    plt.title(title)

    # Apply the layout BEFORE saving; otherwise the saved file misses it and
    # the rotated x tick labels can be clipped.
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path)

    plt.show()
