import functools
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

import pandas as pd
from openai import OpenAI
from tqdm import tqdm


# Models whose saved predictions are evaluated. Only the part after the last
# '/' is used when building result/report file names.
model_list = [
    'microsoft/Phi-3.5-vision-instruct',
    'openbmb/MiniCPM-V-2_6',
    'Qwen/Qwen2-VL-7B-Instruct',
    'llava-hf/llama3-llava-next-8b-hf',
    'HuggingFaceM4/idefics2-8b',
    'THUDM/glm-4v-9b',
    'llava-hf/llava-v1.6-34b-hf',
    'OpenGVLab/InternVL2-Llama3-76B'
]


# Datasets evaluated by pipelines() (the 'qr*'/'qk*' modes override this with
# only viquae and infoseek).
dataset_types = [
    'viquae',
    'infoseek',
    'sqa',
    'mmmu'
]


# Every MMMU subject, in the row order of the per-subject detail CSV.
mmmu_total = ['Accounting', 'Agriculture', 'Architecture_and_Engineering', 'Art', 'Art_Theory', 'Basic_Medical_Science', 'Biology', 'Chemistry', 'Clinical_Medicine', 'Computer_Science', 'Design', 'Diagnostics_and_Laboratory_Medicine', 'Economics', 'Electronics', 'Energy_and_Power', 'Finance', 'Geography', 'History', 'Literature', 'Manage', 'Marketing', 'Materials', 'Math', 'Mechanical_Engineering', 'Music', 'Pharmacy', 'Physics', 'Psychology', 'Public_Health', 'Sociology']
# Subjects left out of the "mmmu" (sampled) accuracy column; "mmmu_full" keeps them.
mmmu_except = ['Architecture_and_Engineering', 'Electronics', 'Energy_and_Power', 'Materials', 'Music', 'Mechanical_Engineering']

# Report accumulators filled by pipelines(): one formatted percentage per model
# (mmmu_detail_result holds one list of per-subject percentages per model).
viquae_result = []
infoseek_result = []
sqa_result = []
mmmu_full_result = []
mmmu_sample_result = []
mmmu_detail_result = [[] for _ in model_list]

# Load the MMMU test split once; a `with` block guarantees the handle is closed.
with open('/data/mmmu/test.json', 'r') as _mmmu_fh:
    mmmu_data = json.load(_mmmu_fh)
# Question id -> subject name (the record's 'type' field).
mmmu_idx_mapper = {each['id']: each['type'] for each in mmmu_data}
# Question count per subject. Subjects absent from the data stay at 0, and an
# unknown subject raises KeyError rather than being silently counted.
total_nums = {each: 0 for each in mmmu_total}
for _subject in mmmu_idx_mapper.values():
    total_nums[_subject] += 1
# Denominator of the sampled accuracy: questions in the non-excluded subjects.
sample_length = sum(v for k, v in total_nums.items() if k not in mmmu_except)


@functools.lru_cache(maxsize=1)
def _deepseek_client():
    """Return a process-wide OpenAI-compatible client for the DeepSeek API.

    The API key is read from the ``DEEPSEEK_API_KEY`` environment variable
    rather than being hard-coded in source. The key previously committed here
    should be treated as leaked and rotated.

    Raises:
        RuntimeError: if ``DEEPSEEK_API_KEY`` is unset or empty. Failures are not
            cached, so setting the variable later works.
    """
    api_key = os.environ.get("DEEPSEEK_API_KEY")
    if not api_key:
        raise RuntimeError("DEEPSEEK_API_KEY environment variable is not set")
    return OpenAI(api_key=api_key, base_url="https://api.deepseek.com")


def call_deepseek(prompt: str, model: str = "deepseek-chat", temperature: float = 0.7) -> str:
    """Send *prompt* as one user message to DeepSeek chat and return the reply text.

    Args:
        prompt: Full user message content.
        model: DeepSeek chat model name (defaults to the previously hard-coded one).
        temperature: Sampling temperature (defaults to the previous 0.7).

    Returns:
        The content of the first choice. The SDK may return ``None`` here,
        so callers should tolerate that.

    Raises:
        RuntimeError: if the API key is not configured.
        Any exception raised by the OpenAI client (network, auth, rate limit).
    """
    # One shared client instead of a new client (and connection pool) per call.
    response = _deepseek_client().chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=temperature,
    )
    return response.choices[0].message.content


def deepseek_judge(idx, pred, label, max_retries=None):
    """Ask DeepSeek whether *pred* matches *label*, retrying on API errors.

    Args:
        idx: Question id; returned unchanged so concurrent results can be
            matched back to their question.
        pred: The model's prediction text.
        label: Reference answer text, interpolated into the judging prompt.
        max_retries: How many times to retry after a failed call. ``None``
            (the default) retries forever, sleeping 5 s between attempts.

    Returns:
        ``(idx, is_correct)``. *is_correct* is True when the judge's reply is
        "Yes", ignoring case, surrounding whitespace, quotes and trailing
        '.'/'!'.

    Raises:
        The last API exception once *max_retries* retries are used up.
    """
    prompt = f"""You will get a prediction and an answer of the same question, please judge whether the prediction is correct or not.
The answer is two parts, one part is an alphabet, one part is a sentence.
If the prediction can match one part of the answer, then the prediction is correct. 
If the prediction can't match any part of the answer, then the prediction is wrong.
Prediction:
 {pred}
Answer:
 {label}
Only output the result, no need to explain, result should be one word "Yes" or "No".
Result:
"""
    attempt = 0
    while True:
        try:
            answer = call_deepseek(prompt)
            break
        except Exception as e:
            print(f"deepseek_judge({idx!r}) failed: {e}")
            if max_retries is not None and attempt >= max_retries:
                raise
            attempt += 1
            time.sleep(5)
    # The judge often answers "Yes." or "yes\n"; an exact == "Yes" test scored
    # those as wrong. The message content can also be None.
    verdict = (answer or "").strip().strip('"\'.!').strip().lower()
    return idx, verdict == "yes"


def make_json(mode, model, dataset):
    """Load the saved prediction records of *model* on *dataset* for *mode*.

    ``{m}`` below is the part of *model* after the last '/'. Files tried, by mode:

    * ``'qb'`` -> ``/result/{dataset}/backbone/q_{m}_result.json``
    * ``'o'``  -> ``/result/{dataset}/backbone/o_{m}_result.json``
    * ``'e'``  -> ``/result/{dataset}/vllm/e_{m}_result.json``
    * ``'qv'`` -> ``/result/{dataset}/vllm/q_{m}_result.json``, then
      ``/result/{dataset}/vllm/{m}_result.json``
    * modes starting with ``'qr'`` or ``'qk'`` ->
      ``/result/{dataset}/vllm/{mode}_{m}_result.json``

    If none of the files can be opened and parsed as JSON, a warning is printed
    and a fresh one-record placeholder is returned
    (``id="a"``, ``pred="b"``, ten ``"c"`` candidates), so the pipeline keeps running.

    Raises:
        ValueError: if *mode* is not one of the modes above.
    """
    model = model.split('/')[-1]
    base = f'/result/{dataset}'
    if mode == 'qb':
        paths = [f'{base}/backbone/q_{model}_result.json']
    elif mode == 'o':
        paths = [f'{base}/backbone/o_{model}_result.json']
    elif mode == 'e':
        paths = [f'{base}/vllm/e_{model}_result.json']
    elif mode == 'qv':
        paths = [f'{base}/vllm/q_{model}_result.json',
                 f'{base}/vllm/{model}_result.json']
    elif mode.startswith(('qr', 'qk')):
        paths = [f'{base}/vllm/{mode}_{model}_result.json']
    else:
        raise ValueError(f'unknown mode: {mode!r}')

    for path in paths:
        try:
            with open(path) as f:
                return json.load(f)
        except (OSError, ValueError):  # missing/unreadable file or invalid JSON
            continue
    print(f'make_json: no readable result file for mode={mode!r}, '
          f'model={model!r}, dataset={dataset!r}; using placeholder')
    return [dict(id="a", pred="b", candidate=["c"] * 10)]


def choice(ids, preds, candidates):
    """Score predictions against per-question candidate lists.

    Candidates at positions 0-5 count as a hit when the raw prediction starts
    with them (case-sensitive). Candidates from position 6 on count as a hit
    when their normalized form appears anywhere in the normalized prediction.
    Normalization lowercases, drops quote characters, removes 'the ' and maps
    '-' to a space.

    Returns:
        dict mapping each id to 1 on a hit, else 0.
    """
    def normalize(text):
        lowered = text.lower()
        unquoted = lowered.replace('"', '').replace("'", "")
        return unquoted.replace('the ', '').replace('-', ' ')

    scores = dict.fromkeys(ids, 0)
    for qid, pred, options in zip(ids, preds, candidates):
        norm_pred = normalize(pred)
        for pos, option in enumerate(options):
            if pos < 6:
                hit = pred.startswith(option)
            else:
                hit = normalize(option) in norm_pred
            if hit:
                scores[qid] = 1
                break
    return scores


def viq(ids, preds, candidates):
    """Score ViQuAE predictions by normalized substring match.

    Prediction and candidates are normalized the same way: lowercase, drop
    quote characters, remove 'the ', then map '-' to a space. A question
    scores 1 if any non-empty normalized candidate is a substring of the
    normalized prediction.

    Returns:
        dict mapping each id to 1 or 0.
    """
    def _normalize(text):
        return (text.lower().replace('"', '').replace("'", "")
                .replace('the ', '').replace('-', ' '))

    result = {idx: 0 for idx in ids}
    for idx, p, c in zip(ids, preds, candidates):
        fp = _normalize(p)
        for cc in c:
            fc = _normalize(cc)
            # An empty normalized answer is a substring of every string, so it
            # would mark any prediction correct; skip it.
            if fc and fc in fp:
                result[idx] = 1
                break
    return result


def find_all_numbers(input_string):
    """Return the unsigned integers and decimals in *input_string*, as strings.

    A fractional part is kept with its integer part ('3.5' -> '3.5'). The old
    ``\\d+`` pattern split it into '3' and '5', so decimal predictions could
    never fall inside a numeric answer range. A trailing '.' with no digits
    after it is not consumed ('5.' -> '5'). Signs and thousands separators are
    not recognized: '1990-2000' gives '1990' and '2000'; '1,000' gives '1' and '000'.
    """
    import re
    return re.findall(r'\d+(?:\.\d+)?', input_string)


def inf(ids, preds, candidates):
    """Score InfoSeek predictions.

    If a question's first candidate is a dict, it must contain a 'range'
    sequence. The question scores 1 when any number extracted from the
    prediction (via find_all_numbers) lies within ``[range[0], range[1]]``
    inclusive.

    Otherwise candidates are strings. The question scores 1 when any non-empty
    normalized candidate is a substring of the normalized prediction.
    Normalization lowercases, drops quote characters, removes 'the ' and maps
    '-' to a space.

    Questions with no candidates score 0.

    Returns:
        dict mapping each id to 1 or 0.
    """
    def _normalize(text):
        return (text.lower().replace('"', '').replace("'", "")
                .replace('the ', '').replace('-', ' '))

    result = {idx: 0 for idx in ids}
    for idx, p, c in zip(ids, preds, candidates):
        if not c:  # no reference answers: previously an IndexError on c[0]
            continue
        if isinstance(c[0], dict):
            ranges = c[0]["range"]
            if any(ranges[0] <= float(num) <= ranges[1] for num in find_all_numbers(p)):
                result[idx] = 1
        else:
            fp = _normalize(p)
            # Skip candidates that normalize to "": they would match every prediction.
            if any(fc and fc in fp for fc in map(_normalize, c)):
                result[idx] = 1
    return result


def ds(ids, preds, labels):
    """Judge every (id, prediction, label) triple with deepseek_judge concurrently.

    Requests run on a 50-worker thread pool, and a tqdm bar tracks completed
    judgements.

    Returns:
        dict mapping each id to the judge's boolean verdict. Keys are inserted
        in completion order.
    """
    with ThreadPoolExecutor(max_workers=50) as pool:
        pending = [
            pool.submit(deepseek_judge, qid, pred, label)
            for qid, pred, label in zip(ids, preds, labels)
        ]
        judged = []
        for fut in tqdm(as_completed(pending), total=len(pending)):
            judged.append(fut.result())
    return dict(judged)


def return_result(mode, model, dataset):
    """Load one model's saved predictions for *dataset* and score them.

    Scorer used per dataset:

    * 'viquae': viq (substring match)
    * 'infoseek': inf (substring or numeric-range match)
    * 'mmmu' / 'sqa': ds (DeepSeek judge). The judged label is each record's
      ``candidate[-5]``.

    Returns:
        dict mapping question id to score: 0/1 ints for viquae/infoseek,
        booleans for mmmu/sqa.

    Raises:
        ValueError: for an unsupported *dataset* (previously an UnboundLocalError).
    """
    if dataset not in ('mmmu', 'viquae', 'infoseek', 'sqa'):
        raise ValueError(f'unknown dataset: {dataset!r}')
    json_data = make_json(mode, model, dataset)
    ids = [each['id'] for each in json_data]
    preds = [each['pred'] for each in json_data]
    if dataset in ('mmmu', 'sqa'):
        labels = [each['candidate'][-5] for each in json_data]
        return ds(ids, preds, labels)
    candidates = [each['candidate'] for each in json_data]
    if dataset == 'viquae':
        return viq(ids, preds, candidates)
    return inf(ids, preds, candidates)


def _format_pct(numerator, denominator):
    """Format ``numerator / denominator`` as a percentage string rounded to 2 places."""
    return f"{round(numerator * 100 / denominator, 2)}%"


def _dump_correct_ids(mode, model, dataset, result):
    """Write the ids scored 1/True in *result* to ``/format_idx/{mode}_{model}_{dataset}.json``."""
    idxs = [k for k, v in result.items() if v == 1]
    with open(f'/format_idx/{mode}_{model.split("/")[-1]}_{dataset}.json', 'w') as f:
        json.dump(idxs, f, indent=4)


def _overall_accuracy(mode, model, dataset, result):
    """Dump the correct ids for one (dataset, model) result and return its accuracy string."""
    _dump_correct_ids(mode, model, dataset, result)
    return _format_pct(sum(result.values()), len(result))


def pipelines(mode, model_list, dataset_types):
    """Score every model on every dataset for *mode* and write the reports.

    For each (dataset, model) pair this calls return_result, writes the ids
    scored correct to ``/format_idx/{mode}_{model}_{dataset}.json``, and stores
    a formatted accuracy in the module-level lists (viquae_result,
    infoseek_result, sqa_result, mmmu_full_result, mmmu_sample_result,
    mmmu_detail_result). Each list is reset before it is filled, so repeated
    calls do not mix in rows from an earlier run.

    Modes starting with 'qr'/'qk' evaluate only viquae and infoseek and write
    ``/format_mtc/{mode}.csv``. Other modes need results for viquae, infoseek,
    mmmu and sqa. They also write ``/format_mtc/{mode}_mmmu_detail.csv`` with
    per-subject MMMU accuracy. The "mmmu" column counts only subjects outside
    mmmu_except; "mmmu_full" counts every judged record.
    """
    retrieval_only = mode.startswith(('qr', 'qk'))
    if retrieval_only:
        dataset_types = ['viquae', 'infoseek']
    pipeline_result = {}
    for dataset in dataset_types:
        for model in model_list:
            pipeline_result[(dataset, model)] = return_result(mode, model, dataset)

    infoseek_result.clear()
    for model in model_list:
        infoseek_result.append(
            _overall_accuracy(mode, model, 'infoseek', pipeline_result[('infoseek', model)]))

    viquae_result.clear()
    for model in model_list:
        viquae_result.append(
            _overall_accuracy(mode, model, 'viquae', pipeline_result[('viquae', model)]))

    if retrieval_only:
        df = pd.DataFrame({
            'model': model_list,
            'viquae': viquae_result,
            'infoseek': infoseek_result,
        }).set_index('model')
        df.to_csv(f'/format_mtc/{mode}.csv')
        return

    # Size the detail table by the models actually evaluated, not the global list.
    mmmu_detail_result[:] = [[] for _ in model_list]
    mmmu_full_result.clear()
    mmmu_sample_result.clear()
    for i, model in enumerate(model_list):
        result = pipeline_result[('mmmu', model)]
        per_subject = {kind: 0 for kind in mmmu_total}
        fulls = sampled = 0
        for k, v in result.items():
            if v != 1:
                continue
            subject = mmmu_idx_mapper[k]
            per_subject[subject] += 1
            fulls += 1
            # Compare the question's SUBJECT with the excluded subjects. The old
            # code tested the id `k`, which never matches, so excluded subjects
            # inflated the sampled accuracy.
            if subject not in mmmu_except:
                sampled += 1
        _dump_correct_ids(mode, model, 'mmmu', result)
        mmmu_detail_result[i] = [_format_pct(n, total_nums[s]) for s, n in per_subject.items()]
        mmmu_full_result.append(_format_pct(fulls, len(result)))
        mmmu_sample_result.append(_format_pct(sampled, sample_length))

    sqa_result.clear()
    for model in model_list:
        sqa_result.append(
            _overall_accuracy(mode, model, 'sqa', pipeline_result[('sqa', model)]))

    summary = pd.DataFrame({
        'model': model_list,
        'viquae': viquae_result,
        'infoseek': infoseek_result,
        'sqa': sqa_result,
        'mmmu': mmmu_sample_result,
        'mmmu_full': mmmu_full_result,
    }).set_index('model')
    summary.to_csv(f'/format_mtc/{mode}.csv')

    detail = pd.DataFrame(mmmu_detail_result).T
    detail.columns = model_list
    detail.index = mmmu_total
    detail.to_csv(f'/format_mtc/{mode}_mmmu_detail.csv')
    

# Evaluation mode for this run. 'qv' reads /result/{dataset}/vllm/q_{model}_result.json
# (see make_json). Reports go to /format_idx and /format_mtc.
mode = "qv"
pipelines(mode, model_list=model_list, dataset_types=dataset_types)