import json
import re

# Directory prepended to the data paths used by the __main__ block below.
# With the empty default, f"{path_prefix}/data/..." yields the absolute path
# "/data/...". Set this to the project root to read and write relative to it.
path_prefix: str = ""

# The whole response must be a <budget> element immediately followed by a
# <solution> element. Group 1 is the budget text; it may contain '<' as long as
# it does not start a (closing) budget tag. Group 2 is the solution body.
# Compiled once because budget_cnt is called once per dataset item.
_BUDGET_RESPONSE_RE = re.compile(
    r"^<budget>([^<]*(?:<(?!/?budget>)[^<]*)*)<\/budget><solution>([\s\S]*?)<\/solution>$",
    re.DOTALL,
)


def budget_cnt(solution_str):
    """Extract the declared token budget from a formatted model response.

    The response must be exactly ``<budget>N</budget><solution>...</solution>``.
    Nothing may appear before, between or after the two elements, except the
    single trailing newline that ``$`` tolerates.

    Args:
        solution_str: Raw model response text.

    Returns:
        ``(budget, lower, upper)``. ``budget`` is the integer inside
        ``<budget>``, with surrounding whitespace ignored. ``lower`` and
        ``upper`` are ``0.5 * budget`` and ``1.5 * budget``, the accepted
        response-length window. If the format does not match, or the budget
        text is not an integer, returns ``(None, None, None)``.
    """
    match = _BUDGET_RESPONSE_RE.search(solution_str)
    # Format mismatch: no usable budget.
    if match is None:
        return None, None, None

    try:
        budget = int(match.group(1).strip())
    except ValueError:
        # Budget element present, but its content is not an integer.
        return None, None, None

    return budget, 0.5 * budget, 1.5 * budget


if __name__ == "__main__":
    # Annotate each response with its declared budget and check whether
    # response_length falls within [0.5 * budget, 1.5 * budget].
    data_path = f"{path_prefix}/data/resdata/s1-train-budget_output.json"
    output_path = f"{path_prefix}/data/resdata/s1-train-budget_output_budget.json"

    # Context manager closes the input handle. Read as UTF-8 to match how
    # the output file is written below.
    with open(data_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    for item in data:
        item["budget"], budget_min, budget_max = budget_cnt(item["model_response"])
        # Compare against None explicitly. A parsed budget of 0 is falsy but is
        # still a successfully parsed budget and should be evaluated.
        if item["budget"] is not None:
            item["budget_match"] = budget_min <= item["response_length"] <= budget_max
        else:
            item["budget_match"] = None

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=4)

    print("Results saved to", output_path)

    total = len(data)
    meaningful = sum(1 for t in data if t["budget_match"] is not None)
    matched = sum(1 for t in data if t["budget_match"])
    # Report 0% instead of raising ZeroDivisionError on an empty input file.
    meaningful_rate = meaningful / total if total else 0.0
    match_rate = matched / total if total else 0.0
    print(f"Meaningful Budget rate: {meaningful_rate:.2%}")
    print(f"Budget match rate: {match_rate:.2%}")