import json
from datasets import load_dataset

def clean_unrecognized_datas_viquae(
    *,
    dataset_path="data/viquae/multiple_choice_data.json",
    preds_path="outputs/analysis/viquae/Qwen2-VL-7B-Instruct/viquae_mc_recognize_T0.0.txt",
    output_path="data/viquae/cleaned_dataset_mc_Qwen2-VL-7B-Instruct.json",
):
    """Keep only ViQuAE samples whose recognition prediction matches the question.

    A sample is kept when its prediction (stripped, case-insensitive) is a
    non-empty substring of the sample's ``original_question``. Samples with no
    prediction are dropped. The kept samples are written to ``output_path`` as
    indented UTF-8 JSON, and their count is printed.

    Args:
        dataset_path: JSON file holding a list of samples; each sample needs
            ``"id"`` and ``"original_question"`` keys.
        preds_path: Text file where each non-blank line is a JSON object
            mapping sample id -> predicted string. All lines are merged, and a
            later line wins on a duplicate id.
        output_path: Destination for the filtered sample list.

    Returns:
        The list of kept samples (also written to ``output_path``).
    """
    with open(dataset_path, "r", encoding="utf-8") as fin:
        dataset = json.load(fin)

    preds = {}
    with open(preds_path, "r", encoding="utf-8") as fin:
        for line in fin:
            line = line.strip()
            if not line:
                continue  # tolerate blank / trailing lines instead of crashing json.loads
            preds.update(json.loads(line))

    cleaned_datas = []
    for data in dataset:
        pred = preds.get(data["id"])
        if pred is None:
            continue
        pred = pred.strip().lower()
        # "" is a substring of every string, so an empty prediction must not
        # count as a recognized entity.
        if pred and pred in data["original_question"].lower():
            cleaned_datas.append(data)
    print(len(cleaned_datas))

    # Explicit UTF-8: ensure_ascii=False emits raw non-ASCII characters, which
    # can fail under a non-UTF-8 locale default encoding.
    with open(output_path, "w", encoding="utf-8") as fout:
        json.dump(cleaned_datas, fout, indent=2, ensure_ascii=False)
    return cleaned_datas
        
def clean_unrecognized_datas_infoseek(
    *,
    dataset_path="data/infoseek/infoseek_val_mc.json",
    preds_path="outputs/analysis/infoseek/Qwen2-VL-7B-Instruct/infoseek_mc_recognize_T0.0.txt",
    output_path="data/infoseek/cleaned_dataset_mc_Qwen2-VL-7B-Instruct.json",
):
    """Keep only InfoSeek samples whose recognition prediction matches the entity.

    A sample is kept when its prediction and its ``entity`` are both non-empty
    (after stripping) and one contains the other, compared case-insensitively.
    Samples with no prediction are dropped. The kept samples are written to
    ``output_path`` as indented UTF-8 JSON, and their count is printed.

    Args:
        dataset_path: JSON file holding a list of samples; each sample needs
            ``"data_id"`` and ``"entity"`` keys.
        preds_path: Text file where each non-blank line is a JSON object
            mapping data_id -> predicted string. All lines are merged, and a
            later line wins on a duplicate id.
        output_path: Destination for the filtered sample list.

    Returns:
        The list of kept samples (also written to ``output_path``).
    """
    with open(dataset_path, "r", encoding="utf-8") as fin:
        dataset = json.load(fin)

    preds = {}
    with open(preds_path, "r", encoding="utf-8") as fin:
        for line in fin:
            line = line.strip()
            if not line:
                continue  # tolerate blank / trailing lines instead of crashing json.loads
            preds.update(json.loads(line))

    cleaned_datas = []
    for data in dataset:
        pred = preds.get(data["data_id"])
        if pred is None:
            continue
        pred = pred.strip().lower()
        entity = data["entity"].strip().lower()
        # "" is a substring of every string; without this guard an empty
        # prediction or an empty entity would match vacuously.
        if not pred or not entity:
            continue
        if pred in entity or entity in pred:
            cleaned_datas.append(data)
    print(len(cleaned_datas))

    # Explicit UTF-8: ensure_ascii=False emits raw non-ASCII characters, which
    # can fail under a non-UTF-8 locale default encoding.
    with open(output_path, "w", encoding="utf-8") as fout:
        json.dump(cleaned_datas, fout, indent=2, ensure_ascii=False)
    return cleaned_datas

if __name__ == "__main__":
    # Select the dataset from the command line instead of toggling
    # commented-out calls; the default keeps the previous behavior (InfoSeek).
    import argparse

    parser = argparse.ArgumentParser(
        description="Filter multiple-choice data to samples whose entity the model recognized."
    )
    parser.add_argument(
        "--dataset",
        choices=("viquae", "infoseek"),
        default="infoseek",
        help="which dataset to clean (default: infoseek)",
    )
    args = parser.parse_args()
    if args.dataset == "viquae":
        clean_unrecognized_datas_viquae()
    else:
        clean_unrecognized_datas_infoseek()