import json


# Candidate model identifiers, written as 'owner/name'.
_candidate_models = (
    'microsoft/Phi-3.5-vision-instruct',
    'openbmb/MiniCPM-V-2_6',
    'Qwen/Qwen2-VL-7B-Instruct',
    'llava-hf/llama3-llava-next-8b-hf',
    'HuggingFaceM4/idefics2-8b',
    'THUDM/glm-4v-9b',
    'llava-hf/llava-v1.6-34b-hf',
    'OpenGVLab/InternVL2-Llama3-76B',
)

# Only the candidate at index 3 is selected for this run; adjust the slice
# bounds to process a different subset.
model_list = list(_candidate_models[3:4])

# Dataset keys listed for reference; nothing in the visible code iterates
# over this list.
dataset_types = ['viquae', 'infoseek', 'sqa', 'mmmu']


# Subject names (presumably the complete MMMU subject list -- confirm
# against the dataset).
mmmu_total = [
    'Accounting',
    'Agriculture',
    'Architecture_and_Engineering',
    'Art',
    'Art_Theory',
    'Basic_Medical_Science',
    'Biology',
    'Chemistry',
    'Clinical_Medicine',
    'Computer_Science',
    'Design',
    'Diagnostics_and_Laboratory_Medicine',
    'Economics',
    'Electronics',
    'Energy_and_Power',
    'Finance',
    'Geography',
    'History',
    'Literature',
    'Manage',
    'Marketing',
    'Materials',
    'Math',
    'Mechanical_Engineering',
    'Music',
    'Pharmacy',
    'Physics',
    'Psychology',
    'Public_Health',
    'Sociology',
]

# Subset of mmmu_total intended to be excluded from an analysis.
mmmu_except = [
    'Architecture_and_Engineering',
    'Electronics',
    'Energy_and_Power',
    'Materials',
    'Music',
    'Mechanical_Engineering',
]

# Read the records through a context manager so the handle is closed
# deterministically, and decode as UTF-8 regardless of the locale.
with open('/data/mmmu/test.json', 'r', encoding='utf-8') as mmmu_file:
    mmmu_data = json.load(mmmu_file)

# Lookup table: record 'id' -> record 'type'.
mmmu_idx_mapper = {each['id']: each['type'] for each in mmmu_data}

def _load_json(path):
    """Return the parsed content of the UTF-8 JSON file at *path*.

    The file is opened in a context manager so the handle is always closed,
    even when parsing fails.
    """
    with open(path, 'r', encoding='utf-8') as handle:
        return json.load(handle)


# Dataset whose result files are processed by this run.
dataset_type = 'viquae'

# Reference id list taken from one fixed result file. It is only consumed by
# the disabled "both_wrong" analysis further down.
full_data = _load_json(f'/result/{dataset_type}/vllm/glm-4v-9b_result.json')
idxs = [each['id'] for each in full_data]

# Records collected across every model in model_list.
result = []
for model in model_list:
    # Result files are keyed by the part of the identifier after the last '/'.
    name = model.split('/')[-1]

    # Per-model result records, indexed by their 'id' field.
    qv_result = _load_json(f'/result/{dataset_type}/vllm/{name}_result.json')
    qv_result = {each['id']: each for each in qv_result}
    kn_result = _load_json(f'/result/{dataset_type}/backbone/o_{name}_result.json')
    kn_result = {each['id']: each for each in kn_result}

    # Id collections for the two settings (presumably the ids each setting
    # answered correctly -- TODO confirm against the producer of /format_idx).
    qv_data = _load_json(f'/format_idx/qv_{name}_{dataset_type}.json')
    kn_data = _load_json(f'/format_idx/o_{name}_{dataset_type}.json')

    # Build the membership set once so each lookup below is O(1) instead of a
    # linear scan over qv_data.
    qv_ids = set(qv_data)

    # Keep the ids present in kn_data but absent from qv_data: take the
    # backbone record and attach the vllm prediction next to it.
    for each in kn_data:
        if each not in qv_ids:
            temp = kn_result[each]
            temp['qv_pred'] = qv_result[each]['pred']
            result.append(temp)

    # --- Disabled alternative analyses, kept for reference ------------------
    # Mirror case: ids present in qv_data but absent from kn_data.
    # for each in qv_data:
    #     if each not in kn_data:
    #         temp = qv_result[each]
    #         temp['kn_pred'] = kn_result[each]['pred']
    #         result.append(temp)

    # Overlap statistics. NOTE: the lists `a` and `b` are not defined in the
    # visible code and must be created before the loop if this is re-enabled.
    # qv_data = [each for each in qv_data if mmmu_idx_mapper[each] not in mmmu_except]
    # kn_data = [each for each in kn_data if mmmu_idx_mapper[each] not in mmmu_except]
    # both_correct, both_wrong = 0, 0
    # for each in qv_data:
    #     if each in kn_data:
    #         both_correct += 1

    # for each in idxs:
    #     if each not in kn_data and each not in qv_data:
    #         both_wrong += 1

    # a.append(round(both_correct / len(qv_data) * 100, 1))
    # b.append(round(both_wrong / (len(idxs) - len(kn_data)) * 100, 1))
    # print(name, both_correct, len(qv_data), both_wrong, len(idxs) - len(kn_data))

# print(a)
# print(b)

# ensure_ascii=False emits non-ASCII characters verbatim, so the file is
# opened as UTF-8 explicitly rather than relying on the locale encoding.
with open('/2.json', 'w', encoding='utf-8') as f:
    json.dump(result, f, ensure_ascii=False, indent=4)