import os
from glob import glob

import torch
from tqdm import tqdm

from utils.config import TASK_INFO, MODEL_LIST
from dataset import build_dataset

# Expected number of stored ``*.pth`` files per dataset for the API-backed
# models ("Gemini", "GPT4"). Datasets not listed here are checked against the
# per-task count taken from TASK_INFO instead.
gpt_samples = {
    "SEED_2":     2606,
    "MME":        1000,
    "MMBench_CN": 1994,
    "MMBench_EN": 1994,
    "MMMU":        900,
    "CMMMU":       573,
    "ScienceQA":  1467,
    "CVBench":     400,
}


# Directory name under ./BPMF_LVLM of the model whose stored samples act as the
# reference: every model's "question"/"label" fields are compared against it.
_REFERENCE_MODEL_DIR = "BLIP2-opt-2.7B"


def _expected_count(model, task):
    """Return how many ``*.pth`` files ``model`` should have stored for ``task``.

    Gemini/GPT4 use the per-dataset count from ``gpt_samples`` when the dataset
    is listed there. Every other case uses the task's ``sub_sampling`` count,
    or ``num_samples`` when ``sub_sampling`` is None.
    """
    if model["model_name"] in ("Gemini", "GPT4") and task["dataset"] in gpt_samples:
        return gpt_samples[task["dataset"]]
    # Same fallback for API and non-API models: comparing a file count against
    # None could never succeed.
    if task["sub_sampling"] is not None:
        return task["sub_sampling"]
    return task["num_samples"]


# For every (model, task) pair: check that the number of stored files matches
# the expected count, then check that each file's "question" and "label"
# fields match the reference model's file of the same name.
bar = tqdm(MODEL_LIST)
for model in bar:
    bar.set_description(model["model_path"])
    for task in TASK_INFO:
        file_list = glob(f"./BPMF_LVLM/{model['store_model_path']}/{task['dataset']}_HS/*.pth")

        # Explicit raise rather than `assert`, which `python -O` removes.
        expected = _expected_count(model, task)
        if len(file_list) != expected:
            raise ValueError(
                f"Dataset: {task['dataset']}, Model: {model['model_name']} "
                f"Expected: {expected} Got: {len(file_list)}"
            )

        # Build the reference path from its parts; a str.replace over the
        # whole path would also rewrite any other occurrence of
        # store_model_path (e.g. inside the dataset or file name).
        reference_dir = f"./BPMF_LVLM/{_REFERENCE_MODEL_DIR}/{task['dataset']}_HS"
        for file in file_list:
            reference_file = os.path.join(reference_dir, os.path.basename(file))
            ins = torch.load(file)
            ins2 = torch.load(reference_file)
            for key in ["question", "label"]:
                ours, ref = ins[key], ins2[key]
                if isinstance(ours, list):
                    # List-valued fields are compared order-insensitively.
                    ours, ref = sorted(ours), sorted(ref)

                if ours != ref:
                    print(f"Key: {key} not equal for {file}")
                    print(ins)
                    print(ins2)
                    raise ValueError(
                        f"Key not equal: {key!r} differs between {file} and {reference_file}"
                    )
