import os
import re
import json

# --- Run selection -------------------------------------------------------
# datasettype and modaltype choose the sub-directory under /result/ (input)
# and /idx/ (output); datasettype also selects the scoring rules below.
datasettype = "viquae"
modeltype = "OpenGVLab/InternVL2-Llama3-76B"  # only the part after the last '/' is used
modaltype = "vllm"

# Optional filename prefix for this run. The value 'e' additionally switches
# scoring to entity matching. None means the bare model name is used.
mode = "q"
if mode is None:
    name = modeltype.rsplit('/', 1)[-1]
else:
    name = mode + '_' + modeltype.rsplit('/', 1)[-1]
# Load the saved prediction records for this run; every record is scored below.
# The encoding is explicit so non-ASCII predictions decode the same on every platform.
with open(f'/result/{datasettype}/{modaltype}/{name}_result.json', 'r', encoding='utf-8') as f:
    data = json.load(f)
    
# Every MMMU subject name. NOTE(review): not referenced by the scoring code
# visible in this script; kept for reference.
mmmu_total = [
    'Accounting',
    'Agriculture',
    'Architecture_and_Engineering',
    'Art',
    'Art_Theory',
    'Basic_Medical_Science',
    'Biology',
    'Chemistry',
    'Clinical_Medicine',
    'Computer_Science',
    'Design',
    'Diagnostics_and_Laboratory_Medicine',
    'Economics',
    'Electronics',
    'Energy_and_Power',
    'Finance',
    'Geography',
    'History',
    'Literature',
    'Manage',
    'Marketing',
    'Materials',
    'Math',
    'Mechanical_Engineering',
    'Music',
    'Pharmacy',
    'Physics',
    'Psychology',
    'Public_Health',
    'Sociology',
]

# Subjects skipped entirely when scoring MMMU (not counted in any total).
mmmu_except = [
    'Architecture_and_Engineering',
    'Electronics',
    'Energy_and_Power',
    'Materials',
    'Music',
    'Mechanical_Engineering',
]

# Per-subject tally filled during MMMU scoring: subject -> [correct, seen].
mmmu_dict = {}


def find_all_numbers(input_string):
    """Return every unsigned number in *input_string*, in order, as strings.

    Integers and decimals with a fractional part ("3.5") are returned as a
    single token. The result can then be passed to ``float()`` without a decimal
    being split into two unrelated integers ("3", "5").

    A trailing dot ("1990.") is not part of the number.

    A leading ``-`` is deliberately not treated as a minus sign: the infoseek
    scorer in this script passes text whose spaces were replaced with ``-``.
    """
    return re.findall(r'\d+(?:\.\d+)?', input_string)

# Number of leading candidates that sqa/mmmu compare as raw prefixes of the
# prediction; later candidates use normalized substring matching.
# NOTE(review): presumably these are the option labels — confirm against the data format.
_PREFIX_CANDIDATES = 6


def _normalize(text):
    """Canonicalize a prediction or gold string for substring matching.

    The steps, in order:
    1. Drop double and single quotes.
    2. Lower-case.
    3. Remove every occurrence of the substring "the ".
    4. Replace spaces with "-".

    NOTE(review): step 3 also fires inside words (e.g. "bathe x" -> "bax").
    It is kept unchanged so scores stay comparable with earlier runs.
    """
    return text.replace('"', '').replace("'", "").lower().replace('the ', '').replace(' ', '-')


def _contains(label, pred):
    """Return True if the non-empty *label* occurs in *pred*.

    An empty label is a substring of every string, so it never counts as a
    match; otherwise a blank gold answer would mark any prediction correct.
    """
    return bool(label) and label in pred


def _match_entity(each):
    """Entity mode: the normalized gold entity occurs in the normalized prediction."""
    return _contains(_normalize(each['entity']), _normalize(each['pred']))


def _match_viquae(each):
    """ViQuAE: any normalized candidate answer occurs in the normalized prediction."""
    pred = _normalize(each['pred'])
    return any(_contains(_normalize(cand), pred) for cand in each['candidate'])


def _match_choice(each):
    """Multiple-choice scoring shared by sqa and mmmu.

    The first ``_PREFIX_CANDIDATES`` candidates match when the raw prediction
    starts with them. Candidates after that match when, normalized, they occur
    in the normalized prediction.
    """
    raw_pred = each['pred']
    pred = _normalize(raw_pred)  # hoisted: identical for every candidate
    for i, cand in enumerate(each['candidate']):
        if i < _PREFIX_CANDIDATES:
            # Guard: str.startswith('') is always True.
            if cand and raw_pred.startswith(cand):
                return True
        elif _contains(_normalize(cand), pred):
            return True
    return False


def _match_infoseek(each):
    """InfoSeek: numeric-range or string answers.

    If the first candidate is a dict, its ``"range"`` holds [low, high]. The
    item is correct when any number in the prediction falls inside that range
    (inclusive). Otherwise candidates are strings, matched by normalized
    substring.
    """
    answer = each['candidate']
    if not answer:  # no gold answer: nothing can match (previously IndexError)
        return False
    pred = _normalize(each['pred'])
    if isinstance(answer[0], dict):
        ranges = answer[0]["range"]
        low, high = ranges[0], ranges[1]
        return any(low <= float(num) <= high for num in find_all_numbers(pred))
    return any(_contains(_normalize(ans), pred) for ans in answer)


# count: number of correctly answered items; idx: their 'id' values, in data order.
count = 0
idx = []
for each in data:
    if mode == 'e':
        # Entity mode overrides the per-dataset rules (and skips the mmmu tally).
        correct = _match_entity(each)
    elif datasettype == 'viquae':
        correct = _match_viquae(each)
    elif datasettype == 'sqa':
        correct = _match_choice(each)
    elif datasettype == 'infoseek':
        correct = _match_infoseek(each)
    elif datasettype == 'mmmu':
        kind = each['type']
        if kind in mmmu_except:
            continue  # excluded subjects are not scored or counted at all
        correct = _match_choice(each)
        tally = mmmu_dict.setdefault(kind, [0, 0])  # [correct, seen]
        if correct:
            tally[0] += 1
        tally[1] += 1
    else:
        # Fail loudly instead of silently reporting 0% for a mistyped dataset name.
        raise ValueError(f"unsupported datasettype: {datasettype!r}")

    if correct:
        count += 1
        idx.append(each['id'])

def _percent(part, whole):
    """Return part/whole as a percentage rounded to 2 decimals, or 0.0 if whole is 0."""
    return round(part * 100 / whole, 2) if whole else 0.0


# Overall accuracy over every loaded record.
print(len(data))
print(f"{count} / {_percent(count, len(data))}%")
if datasettype == 'mmmu':
    # MMMU accuracy over non-excluded subjects only, then a per-subject breakdown.
    # The tally can be empty (e.g. entity mode never fills it), hence the guard.
    scored = sum(total for _, total in mmmu_dict.values())
    print(scored)
    print(f"{count} / {_percent(count, scored)}%")
    for k, (correct, total) in mmmu_dict.items():
        print(f"{k}: {_percent(correct, total)}%")

# Save the ids of correctly answered items under /idx/<dataset>/<backend>/.
folder = f'/idx/{datasettype}/{modaltype}/'
# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(folder, exist_ok=True)
# ensure_ascii=False may emit non-ASCII characters, so write UTF-8 explicitly
# rather than relying on the platform default encoding.
with open(f'{folder}{name}_result.json', 'w', encoding='utf-8') as f:
    json.dump(idx, f, ensure_ascii=False, indent=4)