import json
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import gaussian_kde
from matplotlib.patches import Patch

# Datasets shown as figure columns, with their display titles.
datasets = ["AIME24", "AIME25", "AMC"]
dataset_name = {
    "AIME24": "AIME 2024",
    "AIME25": "AIME 2025",
    "AMC": "AMC 2023",
}

# Models to compare; each model is drawn in the color at the same index.
leg = ["Qwen3-1.7B", "DAPO", "DAPO-HAMMER"]
colors = ["#FF8C00", "#00BFFF", "#FF6347"]
fontsize = 16

# Metrics indexed as results[dataset][model][metric], where metric is one of
# "pass1", "pass10", "pass100", "cons100" (as read by the plotting loop); each
# metric is a sequence of numbers, one per scatter point.
# JSON is UTF-8 by specification, so decode explicitly instead of relying on
# the platform's locale-dependent default encoding.
with open("parsed_metrics.json", encoding="utf-8") as f:
    results = json.load(f)

# ---------------- Plotting ----------------
# Rows are the metrics below, each plotted against pass@1; columns are datasets.
row_metrics = ["pass10", "pass100", "cons100"]


def _draw_density(ax, xs, ys, color):
    """Shade the 2-D Gaussian KDE of the points (xs, ys) on ``ax``.

    The density is drawn as nested translucent fills. Each fill covers the
    region whose density exceeds a fraction of the peak, so overlapping fills
    make high-density cores darker than the fringe. A single ``contourf`` with
    one color and one alpha cannot show this: its auto-chosen lowest level lies
    below the minimum density, so it tints the whole grid rectangle uniformly.

    Non-finite points are ignored. If no KDE can be fitted (fewer than two
    points, or constant or perfectly collinear values giving a singular
    covariance), the shading is skipped and only the scatter remains.
    """
    finite = np.isfinite(xs) & np.isfinite(ys)
    xs, ys = xs[finite], ys[finite]
    try:
        kde = gaussian_kde(np.vstack([xs, ys]))
    except (np.linalg.LinAlgError, ValueError):
        return
    X, Y = np.meshgrid(
        np.linspace(xs.min(), xs.max(), 100),
        np.linspace(ys.min(), ys.max(), 100),
    )
    Z = kde(np.vstack([X.ravel(), Y.ravel()])).reshape(X.shape)
    z_max = Z.max()
    # Four nested fills at alpha 0.07 stack to ~0.25 at the core, in line with
    # the 0.25 alpha of the legend proxy patches.
    for frac in (0.2, 0.4, 0.6, 0.8):
        ax.contourf(X, Y, Z, levels=[frac * z_max, z_max],
                    colors=[color], alpha=0.07)


fig, axes = plt.subplots(
    nrows=len(row_metrics),
    ncols=len(datasets),
    figsize=(4 * len(datasets), 6.5),
    squeeze=False,  # keep `axes` 2-D even when there is a single dataset
)

for i, lab in enumerate(leg):
    for col, dataset in enumerate(datasets):
        model_metrics = results[dataset][lab]
        xs = np.asarray(model_metrics["pass1"], dtype=float)
        for r, metric in enumerate(row_metrics):
            ys_data = np.asarray(model_metrics[metric], dtype=float)
            ax = axes[r, col]
            _draw_density(ax, xs, ys_data, colors[i])
            # zorder=3 lifts every model's points above all density fills.
            ax.scatter(xs, ys_data, color=colors[i], s=20, alpha=0.5,
                       edgecolor='k', linewidth=0.2, zorder=3)

# Panel styling does not depend on the model, so apply it once per axes.
for ax in axes.flat:
    ax.grid(True, linestyle='--', alpha=0.5)
    ax.set_axisbelow(True)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

# Column titles across the top row.
for col, dataset in enumerate(datasets):
    axes[0, col].set_title(dataset_name[dataset], fontsize=fontsize + 2, pad=12)

# Row labels down the first column, top to bottom.
for row, label in enumerate(["pass@10", "pass@100", "cons@100"]):
    axes[row, 0].set_ylabel(label, fontsize=fontsize)

# x-axis label along the bottom row only.
for ax in axes[-1]:
    ax.set_xlabel("pass@1", fontsize=fontsize)

# Proxy patches stand in for the density shading in the legend.
proxy_handles = [
    Patch(facecolor=color, alpha=0.25, edgecolor='k') for color in colors[:len(leg)]
]

# The legend hangs above the subplot grid. tight_layout does not account for
# figure legends, so its layout rect must stop where the legend begins;
# otherwise the top-row titles can run into the legend.
legend_bottom = 0.92
fig.legend(
    proxy_handles, leg,
    loc="lower center",
    ncol=len(leg),
    fontsize=fontsize,
    frameon=False,
    bbox_to_anchor=(0.5, legend_bottom),
)

fig.tight_layout(rect=[0, 0, 1, legend_bottom])
fig.savefig("pareto_curves_density_beauty.png", dpi=300, bbox_inches='tight')
plt.show()
