import json
import os
import random
from collections import Counter

import numpy as np
import verl.utils.reward_score.math as verl_math
from math_verify.metric import math_metric
from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig

# ---------------- Configuration ----------------
# Display label -> result-file stem. Each dataset directory is expected to
# contain one "<stem>.json" file per label.
leg2 = {
    "Qwen3-1.7B": "Qwen3-1.7B",
    "DAPO": "pareto-DAPO",
    "DAPO-HAMMER": "pareto-DAPO-HAMMER",
}
# Labels to evaluate, in order (dicts preserve insertion order).
leg = list(leg2)

datasets = ["AIME24", "AIME25", "AMC", "Olympiad"]

S = 64  # resampling rounds per (dataset, label) pair
T = 10  # problems drawn per round

# Gold answers are extracted as LaTeX; predictions as a plain expression or LaTeX.
_gold_targets = (LatexExtractionConfig(),)
_pred_targets = (ExprExtractionConfig(), LatexExtractionConfig())
verify_func = math_metric(
    gold_extraction_target=_gold_targets,
    pred_extraction_target=_pred_targets,
)

def _output_is_correct(output, ground_truth):
    """Return True if the last boxed answer in *output* verifies against *ground_truth*.

    A completion with no extractable boxed answer is counted as incorrect.
    """
    try:
        boxed = verl_math.last_boxed_only_string(output)
        # remove_boxed may raise on strings it does not recognise (NOTE(review):
        # e.g. an "\fbox{...}" match), so it stays inside the guard as well.
        answer = verl_math.remove_boxed(boxed) if boxed else None
    except Exception:
        answer = None
    if answer is None:
        # No answer: score it as wrong rather than substituting a placeholder
        # such as "0", which would be marked correct whenever the gold is 0.
        return False
    try:
        score, _ = verify_func([rf"\[{ground_truth}\]"], [rf"\[{answer}\]"])
    except Exception:
        return False
    return score >= 1


def compute_pass_at_item(item, k, test_times=5):
    """Estimate pass@k for a single problem by Monte Carlo resampling.

    Each of ``test_times`` trials draws ``k`` distinct completions (without
    replacement) from ``item["output"]``. A trial counts as solved if any drawn
    completion's last boxed answer verifies against the gold answer. When ``k``
    exceeds the number of completions, every completion is used, so the
    estimate is capped at pass@n.

    Args:
        item: Record holding the gold answer under ``"gound_truth"`` (sic, the
            key spelling used by the input files) and a list of model
            completions under ``"output"``.
        k: Number of completions drawn per trial.
        test_times: Number of Monte Carlo trials.

    Returns:
        Fraction of trials with at least one correct draw, in ``[0, 1]``.
        Returns ``0.0`` when the item has no completions.
    """
    ground_truth = item["gound_truth"]
    outputs = item["output"]
    if not outputs:
        return 0.0

    # Correctness depends only on the completion, so run the (slow) verifier
    # at most once per index per call instead of once per draw.
    verdicts = {}

    def is_correct(idx):
        if idx not in verdicts:
            verdicts[idx] = _output_is_correct(outputs[idx], ground_truth)
        return verdicts[idx]

    draws = max(0, min(k, len(outputs)))
    solved = 0
    for _ in range(test_times):
        # Draw without replacement. pass@k is defined over k distinct samples;
        # drawing with replacement repeats completions and biases pass@k low.
        picked = random.sample(range(len(outputs)), draws)
        if any(is_correct(idx) for idx in picked):
            solved += 1
    return solved / test_times

def compute_cons_at_item(item, k):
    """Score majority-vote accuracy (cons@k) for a single problem.

    Takes the first ``k`` completions in ``item["output"]`` and extracts each
    one's last boxed answer (as the raw boxed string). It then verifies the
    most frequent one against the gold answer. Ties go to the answer that
    appears first, so the result is reproducible across runs.

    Args:
        item: Record holding the gold answer under ``"gound_truth"`` (sic, the
            key spelling used by the input files) and a list of model
            completions under ``"output"``.
        k: Number of leading completions that take part in the vote.

    Returns:
        The verifier score for the majority answer. Returns ``0`` when there
        are no completions, when the most common outcome is "no boxed
        answer", or when the majority answer cannot be parsed or verified.
    """
    ground_truth = item["gound_truth"]
    votes = Counter()
    for output in item["output"][:max(k, 0)]:
        try:
            boxed = verl_math.last_boxed_only_string(output)
        except Exception:
            boxed = None
        votes[boxed] += 1
    if not votes:
        return 0

    # Counter.most_common orders equal counts by first occurrence. The old
    # max(set(...)) tie-break depended on per-process string hash randomisation.
    majority = votes.most_common(1)[0][0]
    if majority is None:
        # Most completions had no boxed answer. Count that as wrong instead of
        # substituting "0", which would match a gold answer of 0.
        return 0
    try:
        answer = verl_math.remove_boxed(majority)
        score, _ = verify_func([rf"\[{ground_truth}\]"], [rf"\[{answer}\]"])
    except Exception:
        score = 0
    return score

def compute_avg_metrics(data, dataset, T):
    """Average per-problem metrics over a random subset of problems.

    Draws ``min(T, len(data))`` problems without replacement. Returns the mean
    pass@1, pass@10, pass@K and cons@K over them, where K is 32 for the
    ``"Olympiad"`` dataset and 100 otherwise.

    Args:
        data: List of problem records accepted by ``compute_pass_at_item`` /
            ``compute_cons_at_item``.
        dataset: Dataset name; selects K as described above.
        T: Maximum number of problems to sample.

    Returns:
        Tuple ``(p1, p10, p100, c100)`` of means. ``p100`` and ``c100`` use
        K as described above.

    Raises:
        ValueError: If no problems are sampled (``data`` is empty or
            ``T <= 0``).
    """
    sampled = random.sample(data, min(T, len(data)))
    if not sampled:
        # Fail with a clear message rather than a bare ZeroDivisionError below.
        raise ValueError(
            f"no problems sampled for dataset {dataset!r} "
            f"(len(data)={len(data)}, T={T})"
        )

    # Largest k for this dataset. Olympiad uses 32, presumably because its
    # files hold fewer completions per problem (TODO confirm). The "100" in the
    # return names is kept for the existing result keys.
    large_k = 32 if dataset == "Olympiad" else 100

    def mean_over_sampled(metric, k):
        return sum(metric(item, k) for item in sampled) / len(sampled)

    p1 = mean_over_sampled(compute_pass_at_item, 1)
    p10 = mean_over_sampled(compute_pass_at_item, 10)
    p100 = mean_over_sampled(compute_pass_at_item, large_k)
    c100 = mean_over_sampled(compute_cons_at_item, large_k)
    return p1, p10, p100, c100

# results[dataset][label] -> {"pass1", "pass10", "pass100", "cons100"}, each a
# list of S per-round averages from compute_avg_metrics.
results = {}

for dataset in datasets:
    results[dataset] = {}
    for lab in leg:
        file_path = os.path.join(dataset, f"{leg2[lab]}.json")
        # Completions may contain non-ASCII text, so don't depend on the
        # platform's default locale encoding.
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        xs_1, ys_10, ys_100, ys_cons = [], [], [], []
        for _ in range(S):
            p1, p10, p100, c100 = compute_avg_metrics(data, dataset, T)
            # Add tiny jitter to pass@1 so that identical values from different
            # rounds are not exact duplicates (the original intent was to
            # "remove redundant points").
            xs_1.append(p1 + np.random.uniform(-1e-5, 1e-5))
            ys_10.append(p10)
            ys_100.append(p100)
            ys_cons.append(c100)

        results[dataset][lab] = {
            "pass1": xs_1,
            "pass10": ys_10,
            "pass100": ys_100,
            "cons100": ys_cons,
        }

with open("parsed_metrics.json", "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2)
