import json
import os

import numpy as np
from tqdm import tqdm

from inference_rlhf.code.helpers.utils import rget_json_files_from_dir

# Directory whose JSON files are scored by this script.
DATA_DIR = "anonymous/anonymous/inference-rlhf/data/math/qwen-25-7b"

# JSON file paths discovered for DATA_DIR by the project helper.
# NOTE(review): the "r" prefix suggests the helper searches recursively and
# returns string paths (they are used with str.endswith and open below) —
# confirm against inference_rlhf.code.helpers.utils.
json_files = rget_json_files_from_dir(DATA_DIR)

# Per-file accuracy: for every matching file, the mean of the "strict_correct"
# field over all records in that file. One entry per scored file.
p_at_1s = []
for json_file in tqdm(json_files, desc="Processing json files"):
    # Only files with this exact suffix are scored; everything else is ignored.
    if not json_file.endswith("generations-123.json"):
        continue

    # JSON is UTF-8 by specification; be explicit so decoding does not depend
    # on the platform's locale default.
    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    if not data:
        # A file with no records has no defined accuracy. Skip it rather than
        # aborting the whole run with ZeroDivisionError.
        print(f"Skipping {json_file}: no records")
        continue

    # Each record is expected to carry a numeric/boolean "strict_correct"
    # value; the per-file score is the fraction of records where it is truthy.
    p_at_1 = sum(d["strict_correct"] for d in data) / len(data)
    p_at_1s.append(p_at_1)

# Unweighted average across files (each file counts equally regardless of its
# record count). If no file matched, np.mean yields nan with a RuntimeWarning.
print(np.mean(p_at_1s))