
import datasets
import json
from tqdm import tqdm

from inference_rlhf.code.helpers.utils import set_seeds, rget_json_files_from_dir

# Directory that main() scans (through rget_json_files_from_dir) for the JSON
# result files used to compute GPT-4o pass@1.
# NOTE(review): the path looks anonymized and is relative — confirm it resolves
# from the directory the script is launched in.
GPT_4O_PATH = "anonymous/anonymous/inference-rlhf/data/math/gpt-4o-mini"

def _find_overlap_indices(questions, subset_questions):
    """Return the positions in ``questions`` whose text is in ``subset_questions``.

    Args:
        questions: Ordered problem statements; the returned values are
            positions in this sequence.
        subset_questions: Problem statements to match against, compared by
            exact string equality.

    Returns:
        Ascending list of indices ``i`` such that ``questions[i]`` is one of
        ``subset_questions``.
    """
    subset = set(subset_questions)
    return [i for i, question in enumerate(questions) if question in subset]


def _load_pass_at_1(json_files):
    """Compute the mean correctness stored in each result file.

    Each file is expected to contain a JSON list of dicts. The prompt index is
    read from the first entry only, so a file is presumably all samples for a
    single prompt (TODO confirm against the code that writes these files).
    Every entry contributes its ``strict_correct`` value when present and its
    ``correct`` value otherwise.

    Args:
        json_files: Iterable of paths to JSON result files.

    Returns:
        Dict mapping ``prompt_idx`` to the fraction of correct entries in the
        file carrying that index. If several files share a ``prompt_idx`` the
        last one read wins.
    """
    prompt_idx_to_pass_at_1 = {}
    for file in tqdm(json_files, desc="Reading gpt-4o-pass@1 ..."):
        # JSON is UTF-8 by specification; do not depend on the locale codec.
        with open(file, "r", encoding="utf-8") as f:
            data = json.load(f)
        if not data:
            # An empty list has no prompt_idx and would divide by zero below.
            print(f"Skipping empty results file: {file}")
            continue
        prompt_idx = data[0]["prompt_idx"]
        results = [
            d["strict_correct"] if "strict_correct" in d else d["correct"]
            for d in data
        ]
        prompt_idx_to_pass_at_1[prompt_idx] = sum(results) / len(results)
    return prompt_idx_to_pass_at_1


def main():
    """Print GPT-4o pass@1 for the MATH test questions that are also in MATH-500.

    Loads the MATH-lighteval test split and the MATH-500 test split, finds the
    indices of MATH-lighteval problems whose text appears in MATH-500, reads
    the per-prompt results under ``GPT_4O_PATH``, and prints the 100 overlap
    questions with the lowest pass@1 in ascending order.

    Overlap questions without a result file are reported and left out of the
    ranking rather than aborting the run with a ``KeyError``.
    """
    # Load math data
    math_data = datasets.load_dataset(
        "DigitalLearningGmbH/MATH-lighteval", "default", split="test", trust_remote_code=True
    )
    math_questions = [data["problem"] for data in math_data]

    # Load math 500 data
    math_500_data = datasets.load_dataset(
        "HuggingFaceH4/MATH-500", "default", split="test", trust_remote_code=True
    )
    math_500_questions = {data["problem"] for data in math_500_data}

    overlap_indices = _find_overlap_indices(math_questions, math_500_questions)

    print(f"Found {len(overlap_indices)} overlap questions")
    print(f"Overlap indices: {overlap_indices}")

    # Find gpt-4o-pass@1 for all the overlap indices
    json_files = rget_json_files_from_dir(GPT_4O_PATH)
    prompt_idx_to_pass_at_1 = _load_pass_at_1(json_files)

    # Only indices that actually have results can be ranked.
    missing_indices = [idx for idx in overlap_indices if idx not in prompt_idx_to_pass_at_1]
    if missing_indices:
        print(f"No GPT-4o results for {len(missing_indices)} overlap questions: {missing_indices}")

    # Print gpt-4o-pass@1 for the overlap indices in ascending order. sorted()
    # is stable, so ties stay in ascending index order.
    scored_indices = sorted(
        (idx for idx in overlap_indices if idx in prompt_idx_to_pass_at_1),
        key=prompt_idx_to_pass_at_1.__getitem__,
    )
    for idx in scored_indices[:100]:
        print(f"GPT-4o-pass@1 for question {idx}: {prompt_idx_to_pass_at_1[idx]}")


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()