import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

def draw_heatmap(data, labels, title="Heatmap", save_path=None):
    """Draw an annotated heatmap of ``data`` and optionally save it to disk.

    Args:
        data: 2-D matrix passed to ``seaborn.heatmap``; every cell is annotated
            with its value formatted to two decimal places.
        labels: Tick labels used for both the x and y axes.
        title: Figure title.
        save_path: If truthy, the figure is written to this path at 300 dpi.
    """
    plt.figure(figsize=(10, 8))
    sns.heatmap(data, xticklabels=labels, yticklabels=labels, annot=True, fmt=".2f", cmap="YlGnBu")
    plt.title(title)
    # Tilt x tick labels and right-align them to their ticks for readability.
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    # Save BEFORE show(): with a blocking GUI backend, show() returns only after
    # the window is closed and the figure is released, so a savefig() issued
    # afterwards writes an empty image.
    if save_path:
        plt.savefig(save_path, dpi=300)
    plt.show()



def make_latex_table(acc_res, acc_res_ci, model_names, caption="Accuracy Results", label="tab:acc_results",
                     escape_names=True):
    r"""Render a square model-by-model score matrix as a LaTeX ``table``.

    Cell (i, j) shows ``acc_res[i, j]`` and ``acc_res_ci[i, j]``, both scaled
    by 100, formatted as ``<mean>\% $\pm$ <ci>``. ``model_names`` labels both
    the rows and the columns. The output uses ``\toprule`` / ``\midrule`` /
    ``\bottomrule``, so the document needs the ``booktabs`` package.

    Args:
        acc_res: 2-D array-like indexed ``[row, column]`` in ``model_names``
            order; values are multiplied by 100 for display.
        acc_res_ci: 2-D array-like of the same layout holding interval
            half-widths, also multiplied by 100.
        model_names: Row/column labels, in matrix order.
        caption: Caption text; a fixed suffix meaning "(shown as percentages)"
            is appended.
        label: Key for the LaTeX ``\label``.
        escape_names: If True (default), LaTeX special characters in
            ``model_names`` are escaped (e.g. ``qwen_7B`` -> ``qwen\_7B``) so
            the table compiles. Pass False when the names are already LaTeX.

    Returns:
        The complete table source as one newline-joined string.
    """
    # Accept nested lists / other array-likes, not only ndarrays, for [i, j] indexing.
    acc_res = np.asarray(acc_res)
    acc_res_ci = np.asarray(acc_res_ci)

    # One-pass translation: each special character is replaced exactly once, so the
    # backslash mapping cannot re-escape backslashes introduced by other mappings.
    latex_escapes = str.maketrans({
        "\\": r"\textbackslash{}",
        "&": r"\&",
        "%": r"\%",
        "$": r"\$",
        "#": r"\#",
        "_": r"\_",
        "{": r"\{",
        "}": r"\}",
        "~": r"\textasciitilde{}",
        "^": r"\textasciicircum{}",
    })
    names = [str(name).translate(latex_escapes) if escape_names else name for name in model_names]
    n = len(names)

    table = [
        "\\begin{table}",
        "\\centering",
        "\\begin{tabular}{l" + "c" * n + "}",
        "\\toprule",
        # Empty top-left corner cell, then one column header per model.
        " & ".join([""] + names) + " \\\\",
        "\\midrule",
    ]
    for i, row_name in enumerate(names):
        row = [row_name]
        for j in range(n):
            mean = acc_res[i, j] * 100
            ci = acc_res_ci[i, j] * 100
            row.append(f"{mean:.2f}\\% $\\pm$ {ci:.2f}")
        table.append(" & ".join(row) + " \\\\")
    table.extend([
        "\\bottomrule",
        "\\end{tabular}",
        f"\\caption{{{caption} (百分比表示)}}",
        f"\\label{{{label}}}",
        "\\end{table}",
    ])
    return "\n".join(table)

def calc_confident_interval(accs, confidence=0.95, ddof=1):
    """Return a normal-approximation confidence interval for the mean of ``accs``.

    The half-width is ``z * std(accs, ddof=ddof) / sqrt(n)``, where ``z`` is the
    two-sided standard-normal quantile for ``confidence`` (~1.96 at 0.95).

    Args:
        accs: Sequence of numeric samples (flattened if multi-dimensional).
        confidence: Two-sided confidence level, strictly between 0 and 1.
        ddof: Delta degrees of freedom for the standard deviation. The default
            of 1 gives the sample std, the usual choice for a standard error;
            pass 0 to use the population std.

    Returns:
        ``(lower, upper, half_width)``. If there are too few samples to
        estimate spread (``n <= ddof``), a zero-width interval around the mean
        is returned.

    Raises:
        ValueError: If ``accs`` is empty or ``confidence`` is not in (0, 1).
    """
    from statistics import NormalDist

    samples = np.asarray(accs, dtype=float).ravel()
    n = samples.size
    if n == 0:
        raise ValueError("accs must contain at least one value")
    if not 0.0 < confidence < 1.0:
        raise ValueError(f"confidence must be in (0, 1), got {confidence!r}")

    mean = np.mean(samples)
    if n <= ddof:
        # Spread is undefined with n - ddof <= 0; avoid NaN and return zero width.
        h = 0.0
    else:
        stderr = np.std(samples, ddof=ddof) / np.sqrt(n)
        # Two-sided quantile: e.g. confidence=0.95 -> inv_cdf(0.975) ~= 1.96.
        z = NormalDist().inv_cdf(0.5 + confidence / 2.0)
        h = stderr * z
    return mean - h, mean + h, h


if __name__ == "__main__":
    # Data.
    # Display names of the 11 models. The order presumably matches the
    # rows/columns of `data` below (11 names for an 11x11 matrix) — confirm
    # against whatever produced these numbers.
    labels = [
        'qwen_0.5B', 'llama_8B', 'qwen_1.5B', 'deepseekv3', 'llama_70B',
        'qwen_14B', 'qwen_3B', 'gpt-4o', 'qwen_32B', 'qwen_72B', 'qwen_7B'
    ]
    # Hard-coded 11x11 matrix of scores in [0, 1], presumably copied from an
    # earlier experiment run.
    # NOTE(review): what a row vs. a column represents (e.g. source model vs.
    # evaluated model) is not established here — confirm before labeling axes.
    data = np.array([
        [0.98812787, 0.32895019, 0.41346255, 0.2444609 , 0.42270556, 0.18237309, 0.1925959 , 0.26379972, 0.26351951, 0.28121759, 0.2457969 ],
        [0.31547718, 0.88820581, 0.80212754, 0.85392874, 0.83021501, 0.8281339 , 0.82595449, 0.86534616, 0.84386385, 0.85940979, 0.85924786],
        [0.41204849, 0.81178488, 0.86223077, 0.79186215, 0.83232643, 0.75408093, 0.75848839, 0.80891965, 0.79735216, 0.81094885, 0.79477159],
        [0.22652756, 0.83494078, 0.76952216, 0.91088217, 0.77807259, 0.88423031, 0.87303841, 0.89258553, 0.87684951, 0.8817167 , 0.88284568],
        [0.4287918 , 0.82338439, 0.81560368, 0.79320678, 0.88307976, 0.74327558, 0.74447951, 0.8110396 , 0.79494462, 0.81643991, 0.78974279],
        [0.16321228, 0.79947335, 0.72557651, 0.87085266, 0.72133269, 0.92219337, 0.88369211, 0.86254907, 0.86111409, 0.85118045, 0.86791718],
        [0.16724049, 0.80384763, 0.73128299, 0.86902675, 0.72513508, 0.89338594, 0.91149749, 0.86048552, 0.85233115, 0.84738748, 0.8678642 ],
        [0.25519555, 0.84301723, 0.78533161, 0.88865635, 0.79825809, 0.87474637, 0.8605649 , 0.91665096, 0.87632964, 0.88455789, 0.88792575],
        [0.24571859, 0.8309412 , 0.77772008, 0.88597312, 0.78473774, 0.88335707, 0.86311767, 0.88829094, 0.9014819 , 0.88529117, 0.87716191],
        [0.26619183, 0.84292001, 0.78968843, 0.88758909, 0.80521201, 0.87074132, 0.85608249, 0.89305932, 0.88170222, 0.90531711, 0.87887587],
        [0.23297797, 0.84147819, 0.77500274, 0.88351944, 0.78152947, 0.88206322, 0.87172213, 0.89320357, 0.87090372, 0.87624242, 0.91273237]
    ])
