import json
import os
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

from openai import OpenAI
from tqdm import tqdm


# Evaluation mode; selects which scoring branch at the bottom of the file runs:
#   'oe'      - a prediction is correct when it contains the candidate string
#   'shuffle' - a prediction is correct when the DeepSeek judge answers "Yes"
mode = 'shuffle'

# Hugging Face model identifiers; the part after the last '/' is used to build
# the result file names.
model_list = [
    'microsoft/Phi-3.5-vision-instruct',
    'openbmb/MiniCPM-V-2_6',
    'Qwen/Qwen2-VL-7B-Instruct',
    'llava-hf/llama3-llava-next-8b-hf',
    'HuggingFaceM4/idefics2-8b',
    'THUDM/glm-4v-9b',
    'llava-hf/llava-v1.6-34b-hf',
    'OpenGVLab/InternVL2-Llama3-76B'
]

# Every MMMU question type that is counted in `total_nums`.
mmmu_total = ['Accounting', 'Agriculture', 'Architecture_and_Engineering', 'Art', 'Art_Theory', 'Basic_Medical_Science', 'Biology', 'Chemistry', 'Clinical_Medicine', 'Computer_Science', 'Design', 'Diagnostics_and_Laboratory_Medicine', 'Economics', 'Electronics', 'Energy_and_Power', 'Finance', 'Geography', 'History', 'Literature', 'Manage', 'Marketing', 'Materials', 'Math', 'Mechanical_Engineering', 'Music', 'Pharmacy', 'Physics', 'Psychology', 'Public_Health', 'Sociology']
# Question types that the scoring code skips.
mmmu_except = ['Architecture_and_Engineering', 'Electronics', 'Energy_and_Power', 'Materials', 'Music', 'Mechanical_Engineering']

# Load the MMMU test split inside a context manager so the handle is closed
# right after parsing instead of being left to the garbage collector.
with open('/data/mmmu/test.json', 'r', encoding='utf-8') as _mmmu_file:
    mmmu_data = json.load(_mmmu_file)

# Question id -> question type.
mmmu_idx_mapper = {each['id']: each['type'] for each in mmmu_data}

# Number of questions per type. A type missing from `mmmu_total` raises
# KeyError here, which surfaces unexpected data early.
total_nums = {each: 0 for each in mmmu_total}
for question_type in mmmu_idx_mapper.values():
    total_nums[question_type] += 1


def call_deepseek(prompt: str) -> str:
    """Send *prompt* as a single user message to ``deepseek-chat``.

    The API key is read from the ``DEEPSEEK_API_KEY`` environment variable so
    that no secret is stored in the source file.

    Args:
        prompt: Text sent as the content of the only (user) message.

    Returns:
        The message content of the first choice in the response.

    Raises:
        RuntimeError: If ``DEEPSEEK_API_KEY`` is unset or empty.
    """
    api_key = os.environ.get("DEEPSEEK_API_KEY")
    if not api_key:
        raise RuntimeError(
            "DEEPSEEK_API_KEY is not set; export it before running the evaluation."
        )
    client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com")
    response = client.chat.completions.create(
        model="deepseek-chat",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.7,
    )
    return response.choices[0].message.content


def deepseek_judge(pred, label, *, max_attempts=3, retry_delay=1.0):
    """Ask DeepSeek whether *pred* matches *label* and return the verdict.

    The call is retried on failure; after the last failed attempt the
    exception is re-raised so the caller sees the real error instead of the
    function looping forever.

    Args:
        pred: Model response, interpolated into the judge prompt.
        label: Reference answer, interpolated into the judge prompt.
        max_attempts: Maximum number of API calls (values below 1 are
            treated as 1).
        retry_delay: Seconds to sleep between failed attempts.

    Returns:
        True when the judge's reply is "Yes" (ignoring surrounding
        whitespace, a trailing period and letter case), False otherwise.

    Raises:
        Exception: Whatever ``call_deepseek`` raised on the final attempt.
    """
    prompt = f"""You should judge whether the given response match the given answer.
Response:
 {pred}
Answer:
 {label}
Only output the result, no need to explain, result should be one word "Yes" or "No".
Result:
"""
    attempts = max(1, max_attempts)
    for attempt in range(1, attempts + 1):
        try:
            answer = call_deepseek(prompt)
            break
        except Exception as e:
            if attempt == attempts:
                raise
            print(e)
            time.sleep(retry_delay)
    # The reply may carry stray whitespace, a trailing period or different
    # casing; an empty/None reply counts as "not Yes".
    verdict = (answer or "").strip().rstrip(".").strip().lower()
    return verdict == "yes"

def ds(preds, labels):
    """Judge each (pred, label) pair concurrently and count the accepted ones.

    Pairs are formed with ``zip``, so extra items in the longer sequence are
    ignored. Up to 10 judge calls run in parallel threads, with a tqdm
    progress bar advancing as calls complete.

    Args:
        preds: Sequence of model responses.
        labels: Sequence of reference answers, aligned with *preds*.

    Returns:
        The number of pairs for which ``deepseek_judge`` returned a true value.
    """
    with ThreadPoolExecutor(max_workers=10) as executor:
        futures = [executor.submit(deepseek_judge, pred, label) for pred, label in zip(preds, labels)]
        # `total` (not `tot`) is tqdm's keyword; use the number of submitted
        # futures because zip() may have truncated the inputs.
        results = [future.result() for future in tqdm(as_completed(futures), total=len(futures))]
    return sum(1 for verdict in results if verdict)


def _load_json(path):
    """Parse the JSON file at *path*, closing the file handle afterwards."""
    with open(path, 'r', encoding='utf-8') as fh:
        return json.load(fh)


def _percentage(numerator, denominator):
    """Return ``numerator * 100 / denominator``, or NaN when the denominator is 0."""
    if denominator == 0:
        return float('nan')
    return numerator * 100 / denominator


def _count_substring_hits(records):
    """Count records whose 'pred' contains their 'candidate' ('oe' rule)."""
    return sum(1 for each in records if each['candidate'] in each['pred'])


def _count_judged_hits(records):
    """Count records accepted by the DeepSeek judge ('shuffle' rule)."""
    return ds([each['pred'] for each in records], [each['candidate'] for each in records])


# Scoring rule per evaluation mode; the mode name is also the directory name
# under /result/ that holds the corresponding result files.
_scorers = {'oe': _count_substring_hits, 'shuffle': _count_judged_hits}

if mode in _scorers:
    count_correct = _scorers[mode]
    for model in model_list:
        name = model.split('/')[-1]
        # Load both result files and drop the excluded question types.
        qb_data = [each for each in _load_json(f'/result/{mode}/backbone/q_{name}_result.json')
                   if each['type'] not in mmmu_except]
        qv_data = [each for each in _load_json(f'/result/{mode}/vllm/q_{name}_result.json')
                   if each['type'] not in mmmu_except]
        qb_correct = count_correct(qb_data)
        qc_correct = count_correct(qv_data)
        total = len(qv_data)
        # Columns: model name, correct in backbone file, correct in vllm file,
        # number of scored vllm records, vllm accuracy (%), and
        # 100 - backbone_correct / vllm_correct * 100. Ratios with a zero
        # denominator are reported as nan instead of crashing the run.
        print(name, qb_correct, qc_correct, total,
              round(_percentage(qc_correct, total), 1),
              round(100 - _percentage(qb_correct, qc_correct), 1))