import os
import csv
import json
import pickle
from tqdm import tqdm


def get_data(base_dir):
    """Load and concatenate all ``chunks/chunk_<i>.json`` files under *base_dir*.

    Chunk files are read in index order (``chunk_0.json``, ``chunk_1.json``,
    ...). Each file's top-level JSON value is appended element-wise with
    ``list.extend``, so every chunk is expected to contain a JSON array.

    Only entries named ``chunk_*.json`` are counted. Unrelated files in the
    directory (e.g. ``.DS_Store`` or editor backups) therefore no longer
    inflate the chunk count and trigger a spurious ``FileNotFoundError``.

    Args:
        base_dir: Directory that contains a ``chunks`` subdirectory.

    Returns:
        A single list with the concatenated contents of every chunk.

    Raises:
        FileNotFoundError: If ``chunks`` is missing or the chunk indices are
            not contiguous starting from 0.
    """
    chunk_dir = os.path.join(base_dir, "chunks")
    num_chunks = sum(
        1
        for name in os.listdir(chunk_dir)
        if name.startswith("chunk_") and name.endswith(".json")
    )
    all_chunk_data = []
    for chunk_idx in range(num_chunks):
        file_name = os.path.join(chunk_dir, f"chunk_{chunk_idx}.json")
        # A gap in the numbering makes this open() fail. That is deliberate:
        # silently skipping a chunk would misalign results with the test set.
        with open(file_name, "r", encoding="utf-8") as f:
            all_chunk_data.extend(json.load(f))
    return all_chunk_data


def _score_item(item, sol, types):
    """Return one 0/1 score per question type for a single test item.

    A type scores 1 only when the model's ``choice`` is exactly one of the
    letters ``A``-``D`` and equals the reference ``Answer`` for that type.
    """
    scores = []
    for t in types:
        choice = sol['solution'][t]['choice']
        # Tuple membership compares with ==, so None or non-string choices
        # simply score 0. The former ``choice not in "ABCD"`` substring test
        # raised TypeError on them and also let "", "AB", "BCD", ... through.
        if choice in ("A", "B", "C", "D"):
            scores.append(int(item['QAs'][t]['Answer'] == choice))
        else:
            scores.append(0)
    return scores


def compare_model_modality(model, modality, *, pkl_path="pkl/data_test.pkl",
                           benchmark_dir="benchmark", comparison_dir="comparison"):
    """Score one model's answers for one modality and write them to a CSV.

    Reads the pickled test items from *pkl_path* and the model's solutions
    from ``<benchmark_dir>/<model>/<modality>/chunks``. For each item it
    writes one row to ``<comparison_dir>/<model>/<modality>.csv`` with a 0/1
    score for each of Recognition, Understanding, Grounding and Reasoning.

    Items and solutions are paired positionally. If their counts differ, a
    warning is printed and only the overlapping prefix is scored.

    Args:
        model: Model directory name under *benchmark_dir*.
        modality: Modality subdirectory name; also used as the CSV file stem.
        pkl_path: Path to the pickled test set.
        benchmark_dir: Root directory of model outputs.
        comparison_dir: Root directory for the CSV results.
    """
    print(f"Model: {model}")
    print(f"Modality: {modality}")

    ### Loading pkl and solutions
    # NOTE: pickle.load can execute arbitrary code; only load a trusted file.
    with open(pkl_path, 'rb') as file:
        data_test = pickle.load(file)
    solutions = get_data(os.path.join(benchmark_dir, model, modality))

    ### Creating output directories
    comparison_model_dir = os.path.join(comparison_dir, model)
    os.makedirs(comparison_model_dir, exist_ok=True)
    output_file = os.path.join(comparison_model_dir, f"{modality}.csv")

    ### Verifying
    # zip() stops at the shorter sequence; make the truncation visible
    # instead of silently dropping rows.
    num_pairs = min(len(data_test), len(solutions))
    if len(data_test) != len(solutions):
        print(f"Warning: {len(data_test)} test items but {len(solutions)} "
              f"solutions; scoring only the first {num_pairs}.")

    types = ['Recognition', 'Understanding', 'Grounding', 'Reasoning']
    verifications = []
    pairs = enumerate(zip(data_test, solutions))
    for idx, (item, sol) in tqdm(pairs, total=num_pairs, desc="Processing diagrams"):
        verifications.append((idx, _score_item(item, sol, types)))

    ### Saving results
    with open(output_file, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(['idx'] + types)
        writer.writerows([idx] + v for idx, v in verifications)


def main():
    """Score the "empty"-modality outputs of each configured model, in order.

    To sweep every model directory under ``benchmark/`` across the
    ``real``/``synthetic``/``triple`` modalities instead, loop over
    ``os.listdir("benchmark")`` crossed with those modalities and call
    ``compare_model_modality`` for each pair.
    """
    models_to_score = (
        "Qwen2.5-VL-3B-Instruct",
        "Qwen2.5-VL-7B-Instruct",
        "Qwen2.5-VL-32B-Instruct",
        "Qwen2.5-VL-72B-Instruct",
        "llava-v1.6-vicuna-7b-hf",
        "llava-v1.6-vicuna-13b-hf",
        "llava-v1.6-34b-hf",
    )
    for model_name in models_to_score:
        compare_model_modality(model_name, "empty")


if __name__ == "__main__":
    # Run the scoring only when executed as a script, never on import.
    main()

