import argparse
import pathlib
import json
import os

def load_aokvqa(aokvqa_dir, split, version='v1p0'):
    """Load one A-OKVQA annotation split from disk.

    Reads ``<aokvqa_dir>/aokvqa_<version>_<split>.json`` and returns the
    decoded JSON content.

    Args:
        aokvqa_dir: Directory containing the A-OKVQA JSON files (str or
            path-like).
        split: One of ``'train'``, ``'val'``, ``'test'``, ``'test_w_ans'``.
        version: Version tag embedded in the file name (default ``'v1p0'``).

    Returns:
        The parsed JSON object. The script below iterates it as a sequence
        of dicts with a ``'question_id'`` key — presumably a list of
        question records; confirm against the dataset files.

    Raises:
        ValueError: If ``split`` is not a known split name. (An explicit
            raise is used instead of ``assert`` so the check survives
            ``python -O``.)
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    valid_splits = ('train', 'val', 'test', 'test_w_ans')
    if split not in valid_splits:
        raise ValueError(
            f"split must be one of {valid_splits}, got {split!r}"
        )
    path = os.path.join(aokvqa_dir, f"aokvqa_{version}_{split}.json")
    # Context manager guarantees the handle is closed; explicit UTF-8 keeps
    # decoding independent of the platform's default locale encoding.
    with open(path, 'r', encoding='utf-8') as f:
        dataset = json.load(f)
    return dataset


def _build_predictions(dataset, mc_preds):
    """Build ``{question_id: {'multiple_choice': prediction}}`` for every question.

    Every ``question_id`` in ``dataset`` gets an entry. Questions with no
    prediction in ``mc_preds`` (or when ``mc_preds`` is empty) map to an
    empty dict.

    Args:
        dataset: Iterable of dicts, each with a ``'question_id'`` key.
        mc_preds: Mapping from question_id to the predicted choice, or an
            already-converted mapping produced by this script.

    Returns:
        A new dict keyed by question_id.
    """
    predictions = {}
    for record in dataset:
        qid = record['question_id']
        entry = {}
        if mc_preds and qid in mc_preds:
            pred = mc_preds[qid]
            # Input may already be this script's output (it overwrites its
            # input by default); unwrap instead of nesting a second time.
            if isinstance(pred, dict) and 'multiple_choice' in pred:
                pred = pred['multiple_choice']
            entry['multiple_choice'] = pred
        predictions[qid] = entry
    return predictions


def main(argv=None):
    """Convert raw multiple-choice predictions into per-question entries.

    Loads the A-OKVQA split, reads the prediction JSON, and writes
    ``{question_id: {'multiple_choice': prediction}}`` (empty dict for
    questions without a prediction). With no arguments, it uses the original
    hard-coded paths and overwrites the predictions file in place.

    Args:
        argv: Optional argument list for argparse (defaults to ``sys.argv``).
    """
    default_preds = '/home/test/yxl/MCoT/aokvqa/results/qwen-test/CoT_val.json'
    parser = argparse.ArgumentParser(
        description='Wrap multiple-choice predictions per A-OKVQA question.'
    )
    parser.add_argument('--aokvqa-dir', default='/home/test/yxl/MCoT/data/aokvqa')
    parser.add_argument('--split', default='val',
                        choices=['train', 'val', 'test', 'test_w_ans'])
    parser.add_argument('--preds', default=default_preds,
                        help='JSON file mapping question_id -> predicted choice')
    parser.add_argument('--output', default=None,
                        help='output path (defaults to --preds, i.e. overwrite in place)')
    args = parser.parse_args(argv)
    out_path = args.output or args.preds

    dataset = load_aokvqa(args.aokvqa_dir, args.split)
    with open(args.preds, 'r', encoding='utf-8') as f:
        mc_preds = json.load(f)

    predictions = _build_predictions(dataset, mc_preds)

    # Write to a sibling temp file, then atomically swap it in, so a crash
    # mid-dump cannot truncate the (possibly same) input predictions file.
    tmp_path = out_path + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(predictions, f)
    os.replace(tmp_path, out_path)


if __name__ == '__main__':
    main()