import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import random

def read_values(prefix, num_models, hessian_things_dir, max_samples=30):
    """Load per-checkpoint Hessian statistics saved as ``.pth`` files.

    For each ``i`` in ``1..num_models`` this looks for
    ``<hessian_things_dir>/<prefix>_<i>_hessian_data.pth``. Files that do not
    exist are skipped entirely (nothing is appended for them), so the returned
    arrays can be shorter than ``num_models``.

    Args:
        prefix: File-name prefix of the run, e.g. ``"model_clip"``.
        num_models: Highest checkpoint index to look for (indices start at 1).
        hessian_things_dir: Directory holding the files; ``""`` means the
            current working directory.
        max_samples: Unused; kept only so existing callers keep working.

    Returns:
        ``(traces, top1_eigenvalues)``: two equal-length float NumPy arrays
        with one entry per file found. ``traces[k]`` is the mean of the stored
        ``'trace'`` values. ``top1_eigenvalues[k]`` is the first entry of the
        stored ``'top_eigenvalues'``. A value is 0 when its key is absent or
        its sequence is empty, and an eigenvalue is also 0 when it is stored as
        an ``int``.
    """
    traces = []
    top1_eigenvalues = []
    for i in range(1, num_models + 1):
        # os.path.join (instead of f"{dir}/{name}") keeps an empty directory
        # relative to the CWD rather than turning it into "/<name>".
        file_path = os.path.join(hessian_things_dir, f"{prefix}_{i}_hessian_data.pth")
        if not os.path.exists(file_path):
            continue
        # Map onto the CPU so files written from GPU tensors also load on
        # machines without CUDA.
        loaded_data = torch.load(file_path, map_location="cpu")

        trace_value = 0.0
        if 'trace' in loaded_data:
            trace_samples = np.asarray(loaded_data['trace'], dtype=float)
            # np.mean of an empty array is NaN (with a warning); treat as missing.
            if trace_samples.size:
                trace_value = float(trace_samples.mean())

        top1_eigenvalue = 0.0
        if 'top_eigenvalues' in loaded_data:
            eigenvalues = loaded_data['top_eigenvalues']
            # An int value is treated as "no eigenvalues" (presumably a
            # placeholder written when none were computed); an empty sequence
            # likewise has no top-1 entry.
            if not isinstance(eigenvalues, int) and len(eigenvalues) > 0:
                top1_eigenvalue = float(eigenvalues[0])

        traces.append(trace_value)
        top1_eigenvalues.append(top1_eigenvalue)

    return np.array(traces), np.array(top1_eigenvalues)

# Global plot styling shared by every figure below.
sns.set(style="whitegrid")
sns.set_context("talk")
fontsize = 25
# Directory holding the "<prefix>_<i>_hessian_data.pth" files. Use "." (not "")
# for the current directory so the file path can never degenerate into an
# absolute "/<file>" at the filesystem root.
hessian_things_dir = "."

# Highest checkpoint index looked up for every run.
num_checkpoints = 50

model_clip_traces, model_clip_top1 = read_values("model_clip", num_checkpoints, hessian_things_dir)
model_distilledclip_traces, model_distilledclip_top1 = read_values("model_distilledclip", num_checkpoints, hessian_things_dir)
model_ssl_traces, model_ssl_top1 = read_values("model_ssl", num_checkpoints, hessian_things_dir)
model_distilledssl_traces, model_distilledssl_top1 = read_values("model_distilledssl", num_checkpoints, hessian_things_dir)


# Compute normalized x-coordinates: the k-th loaded checkpoint of a run (k = 1..n)
# is placed at k/n, so runs with different numbers of checkpoints share one
# (0, 1] axis. Note: missing files are skipped when loading, so a gap in the
# checkpoint sequence shifts all later points to the left.
def _normalized_positions(count):
    """Return ``[1/count, 2/count, ..., 1]``, or an empty array when ``count`` is 0."""
    if count == 0:
        # 1/count would raise ZeroDivisionError for a run with no files.
        return np.empty(0)
    return np.linspace(1 / count, 1, count)


len_clip = len(model_clip_traces)
len_clip_md = len(model_distilledclip_traces)
len_ssl = len(model_ssl_traces)
len_ssl_md = len(model_distilledssl_traces)

normalized_x_clip = _normalized_positions(len_clip)
normalized_x_clip_md = _normalized_positions(len_clip_md)
normalized_x_ssl = _normalized_positions(len_ssl)
normalized_x_ssl_md = _normalized_positions(len_ssl_md)


# Plot the absolute value of the top-1 Hessian eigenvalue per checkpoint on a
# log y-axis. Color encodes the variant (red: base, blue: "_md"); the marker
# encodes the family (circle: clip, square: ssl).
plt.figure()
first_row_labels = ['clip', 'clip_md']
second_row_labels = ['ssl', 'ssl_md']


def _abs_for_log_axis(values):
    """Return ``|values|`` as floats with non-positive entries replaced by NaN.

    read_values() reports a missing eigenvalue as 0. A log axis cannot show 0
    (matplotlib clips non-positive values by default, so the line dives to the
    bottom of the plot); NaN hides those points instead.
    """
    magnitudes = np.abs(np.asarray(values, dtype=float))
    return np.where(magnitudes > 0, magnitudes, np.nan)


line1, = plt.plot(normalized_x_clip, _abs_for_log_axis(model_clip_top1), marker='o', color='red')
line2, = plt.plot(normalized_x_clip_md, _abs_for_log_axis(model_distilledclip_top1), marker='o', color='blue')
line3, = plt.plot(normalized_x_ssl, _abs_for_log_axis(model_ssl_top1), marker='s', color='red')
line4, = plt.plot(normalized_x_ssl_md, _abs_for_log_axis(model_distilledssl_top1), marker='s', color='blue')

# Two-row legend. matplotlib fills legend columns top-to-bottom, so with ncol=2
# this handle order yields the rows "clip | clip_md" and "ssl | ssl_md".
# (A None handle is not a spacer: matplotlib warns and drops it.)
lines = [line1, line3, line2, line4]
labels = [first_row_labels[0], second_row_labels[0], first_row_labels[1], second_row_labels[1]]
plt.legend(lines, labels, loc='upper center', bbox_to_anchor=(0.5, -0.2), ncol=2, fontsize=18)

plt.yscale("log")
plt.xlabel("Normalized Epoch Number", fontsize=fontsize)
plt.ylabel("Abs Sharpness", fontsize=fontsize)
plt.xlim(left=0.0, right=1.0)

plt.savefig("top1_abs_eigenvalue_log.pdf", dpi=100, bbox_inches='tight')
# Release the figure; the next plot starts its own.
plt.close()



# Create a new figure for the Hessian trace per checkpoint. Color encodes the
# variant (red: base, blue: "_md"); the marker encodes the family
# (circle: clip, square: ssl).
plt.figure()
first_row_labels = ['clip', 'clip_md']
second_row_labels = ['ssl', 'ssl_md']

# Use scatter (no connecting lines) for the trace values.
scatter1 = plt.scatter(normalized_x_clip, model_clip_traces, marker='o', color='red')
scatter2 = plt.scatter(normalized_x_clip_md, model_distilledclip_traces, marker='o', color='blue')
scatter3 = plt.scatter(normalized_x_ssl, model_ssl_traces, marker='s', color='red')
scatter4 = plt.scatter(normalized_x_ssl_md, model_distilledssl_traces, marker='s', color='blue')

# Two-row legend. matplotlib fills legend columns top-to-bottom, so with ncol=2
# this handle order yields the rows "clip | clip_md" and "ssl | ssl_md".
# (A None handle is not a spacer: matplotlib warns and drops it.)
scatters = [scatter1, scatter3, scatter2, scatter4]
labels_trace = [first_row_labels[0], second_row_labels[0], first_row_labels[1], second_row_labels[1]]
plt.legend(scatters, labels_trace, loc='upper center', bbox_to_anchor=(0.5, -0.2), ncol=2, fontsize=18)

# symlog is used rather than log, so zero (missing) and negative traces stay visible.
plt.yscale("symlog")
plt.xlabel("Normalized Epoch Number", fontsize=fontsize)
plt.ylabel("Trace", fontsize=fontsize)
plt.xlim(left=0.0, right=1.0)

plt.savefig("trace.pdf", dpi=100, bbox_inches='tight')
# Release the figure once it has been written out.
plt.close()






