
import argparse
import pickle
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Global plot styling: seaborn white grid with the DejaVu Sans font.
sns.set(style="whitegrid")
plt.rcParams.update({'font.family': 'DejaVu Sans'})

# Command-line interface: a single required path to the pickled log.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--log',
    type=str,
    required=True,
    help='Path to .pkl log file',
)
args = parser.parse_args()

# Load the logged data.
# NOTE(security): pickle.load can execute arbitrary code while
# deserializing; only point --log at files from a trusted source.
with open(args.log, "rb") as log_file:
    log_data = pickle.load(log_file)

# Convert each logged series to a NumPy array. Keys are read in the order
# listed, and a missing key raises KeyError.
mod_factor, ce_loss, kd_loss, total_loss, acc = (
    np.array(log_data[key])
    for key in ("mod_factor", "ce_loss", "kd_loss", "total_loss", "acc")
)

# Visualization 1: distribution of the modulating factor
# (40-bin histogram with a KDE overlay), saved as a 300-dpi PNG.
fig, ax = plt.subplots(figsize=(6, 4))
sns.histplot(mod_factor, bins=40, kde=True, color="royalblue", ax=ax)
ax.set_title("Distribution of Modulating Factor")
ax.set_xlabel("Modulating Factor")
ax.set_ylabel("Count")
fig.tight_layout()
fig.savefig("mod_factor_distribution.png", dpi=300)
plt.close(fig)

# Visualization 2: average total loss per modulating-factor bin.
# The [min, max] range of mod_factor is split into n_bins equal-width bins
# and the mean of total_loss is computed over the samples in each bin.
n_bins = 5
bins = np.linspace(mod_factor.min(), mod_factor.max(), n_bins + 1)

# np.digitize (right=False) returns i such that bins[i-1] <= x < bins[i],
# so samples equal to the last edge (i.e. mod_factor.max()) get index
# len(bins) and would fall outside every bin. Fold them into the last bin
# so the right-most bin is closed on both sides and no sample is dropped.
bin_indices = np.minimum(np.digitize(mod_factor, bins), n_bins)

bin_means = []
for i in range(1, n_bins + 1):
    bin_total_loss = total_loss[bin_indices == i]
    # An empty bin has no defined mean: use NaN so the line shows a gap
    # instead of a misleading "average loss of 0" data point.
    bin_means.append(bin_total_loss.mean() if bin_total_loss.size else np.nan)

plt.figure(figsize=(6, 4))
# Each bin is plotted at its upper edge.
plt.plot(bins[1:], bin_means, marker='o', color="orange")
plt.title("Average Total Loss vs Modulating Factor Bin")
plt.xlabel("Modulating Factor Bin")
plt.ylabel("Average Total Loss")
plt.tight_layout()
plt.savefig("loss_vs_mod_factor.png", dpi=300)
plt.close()

# Visualization 3: hard vs. easy samples (binary split at the median).
# Samples whose modulating factor is strictly below the median are labeled
# "hard"; all others (including those equal to the median) are "easy".
# NOTE(review): whether a low modulating factor really means "hard" depends
# on how mod_factor is defined by the training code -- confirm the direction.
threshold = np.median(mod_factor)
is_hard = np.less(mod_factor, threshold)
hard_losses = total_loss[is_hard]
easy_losses = total_loss[np.logical_not(is_hard)]

fig, ax = plt.subplots(figsize=(6, 4))
for losses, label, color in (
    (hard_losses, 'Hard Samples', "red"),
    (easy_losses, 'Easy Samples', "green"),
):
    sns.kdeplot(losses, label=label, fill=True, color=color, linewidth=2, ax=ax)
ax.set_title("Loss Distribution: Hard vs Easy Samples")
ax.set_xlabel("Total Loss")
ax.legend()
fig.tight_layout()
fig.savefig("hard_vs_easy_loss.png", dpi=300)
plt.close(fig)

# Report the generated image files, then the largest modulating factor.
print("[INFO] Visualization saved:")
for saved_name in (
    "mod_factor_distribution.png",
    "loss_vs_mod_factor.png",
    "hard_vs_easy_loss.png",
):
    print(f"- {saved_name}")
print("[DEBUG] Max mod_factor:", mod_factor.max())

