import json
import random
import re
import subprocess
from collections import defaultdict
from contextlib import ExitStack

from tqdm import tqdm

def extract_last_num(text: str) -> float:
    r"""Return the final answer number found in *text*.

    Commas sitting between two digits are removed first, so "123,456" is
    read as 123456. If the text contains at least one ``\boxed{<number>``,
    the number of the last such occurrence is returned; otherwise the last
    number appearing anywhere in the text is returned.

    A leading minus sign is honored only directly inside ``\boxed{``, where
    it is unambiguous. In free text a hyphen may be a subtraction or a dash,
    so the fallback keeps matching unsigned numbers only.

    Args:
        text: Model output to scan.

    Returns:
        The extracted number as a float, or 0.0 when *text* contains no
        number at all.
    """
    # Drop thousands separators such as the comma in "123,456".
    text = re.sub(r"(\d),(\d)", r"\g<1>\g<2>", text)

    # Preferred source: the number that opens a \boxed{...} answer,
    # e.g. "\boxed{-123456.789".
    boxed = re.findall(r"\\boxed\{(-?\d+(?:\.\d+)?)", text)
    if boxed:
        return float(boxed[-1])

    # Fallback: any unsigned integer or decimal, e.g. "123456.789".
    plain = re.findall(r"\d+(?:\.\d+)?", text)
    if plain:
        return float(plain[-1])
    return 0.0

# Script body: read generated solutions from a JSONL file, keep for each
# record the first generation whose extracted number matches the reference
# answer, and write the kept samples to one output file per "max_fluct" value.

# Run configuration; both values are interpolated into the paths below.
model_name = ""
dataset = ""
input_path = f"../results/{model_name}-{dataset}.jsonl"
test_path = f"../../math_training/data/{model_name}.json"

# "max_fluct" values that each get their own output file.
MAX_FLUCT_VALUES = [0.5, 1, 5, 10]
SYSTEM_PROMPT = "Please reason step by step, and put your final answer within \\boxed{}."
# A generation counts as correct when its extracted number is within this
# absolute distance of the reference answer.
ANSWER_TOLERANCE = 1e-3

# Number of input lines, used only as the progress-bar total.
with open(input_path, encoding="utf-8") as f:
    line_count = sum(1 for _ in f)

# Not referenced by the code below; kept so that any later code relying on
# these names still finds them.
data = defaultdict(list)
result = defaultdict(list)
average_cot_success_number = 0

# Instructions listed in test_path; input records with one of these
# instructions are skipped.
test_set = set()
if test_path:
    with open(test_path, encoding="utf-8") as f:
        test_set = {entry["instruction"] for entry in json.load(f)}

# ExitStack closes every output file even if processing fails midway.
with ExitStack() as stack:
    # UTF-8 is set explicitly because records are dumped with
    # ensure_ascii=False and may contain non-ASCII text.
    files = {
        k: stack.enter_context(
            open(
                f"../../math_training/data/{model_name}-{dataset}-aug_{k}_32.jsonl",
                "w",
                encoding="utf-8",
            )
        )
        for k in MAX_FLUCT_VALUES
    }

    with open(input_path, "r", encoding="utf-8") as f:
        for line in tqdm(f, total=line_count):
            item = json.loads(line)
            if item["instruction"] in test_set:
                continue
            item_id = item["id"]
            max_fluct = item["max_fluct"]
            for generated_text in item["generated_texts"]:
                if abs(extract_last_num(generated_text) - item["answer"]) < ANSWER_TOLERANCE:
                    temp_item = {
                        "id": item_id,
                        "max_fluct": max_fluct,
                        "instruction": item["instruction"],
                        "code": item["code"],
                        "output": generated_text,
                        "system": SYSTEM_PROMPT,
                        "input": "",
                        "history": [],
                        "answer": item["answer"],
                    }
                    # Raises KeyError if max_fluct is not in MAX_FLUCT_VALUES.
                    files[max_fluct].write(json.dumps(temp_item, ensure_ascii=False) + "\n")
                    # Keep at most one correct generation per record.
                    break