# Convert VLMEvalKit tsv files to HF datasets and save images to disk for image tokenization
import os
import torch
import random
import numpy as np
from PIL import Image
from io import BytesIO
import datasets
import logging
from datasets import load_dataset, load_from_disk, Dataset, DatasetDict, concatenate_datasets
from tqdm.auto import tqdm
import pandas as pd
import multiprocessing
from utils import dataset_URLs, download_file, decode_base64_to_image_file
import string
import argparse


def parse_args():
    """Parse the command-line arguments and print a short hardware summary.

    The core/GPU summary is printed before the arguments are parsed, so it is
    also shown when the script is invoked with ``--help``.

    Returns:
        argparse.Namespace: with attributes ``output_dir``, ``dataset_name``,
        ``download_datasets``, ``seed``, ``process_batch_size`` and
        ``process_num_workers``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, default='YOUR_ROOT_PATH/data/MLLM/Evaluation', help='path to save the output')
    parser.add_argument('--dataset_name', type=str, default='COCO', help='dataset name')
    parser.add_argument('--download_datasets', action='store_true', help='download datasets')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--process_batch_size', type=int, default=200, help='process batch size')
    parser.add_argument('--process_num_workers', type=int, default=multiprocessing.cpu_count(), help='preprocessing num workers')
    print('Number of available cores:', multiprocessing.cpu_count())
    print('Number of available gpus:', torch.cuda.device_count())

    # Best-effort GPU diagnostics: querying device 0 raises when CUDA is not
    # usable. Catch Exception (not a bare except) so that Ctrl-C and
    # SystemExit are never swallowed here.
    try:
        print('GPU model name:', torch.cuda.get_device_name(0))
        print('GPU memory size:', torch.cuda.get_device_properties(0).total_memory / 1024 / 1024 / 1024, 'GB')
    except Exception:
        print('No GPU available.')

    return parser.parse_args()

def download_datasets(args):
    """Download the evaluation benchmarks from the Hugging Face Hub and save them to disk.

    Every benchmark is written with ``save_to_disk`` into its own
    sub-directory of ``args.output_dir``: COCO, NoCaps, Flickr30K, VQAv2_VAL,
    VQAv2_TEST, OK-VQA, VizWiz_VAL, VizWiz_TEST, TextVQA,
    MathVista_MultiChoice, MathVista_OpenEnded, Winoground, Winoground-YN,
    WHOOPS-Caption, WHOOPS-VQA, POPE, llava-bench-in-the-wild,
    MMMU_VAL_MultiChoice, MMMU_VAL_OpenEnded, MMMU_TEST_MultiChoice,
    MMMU_TEST_OpenEnded, MMLU and PIQA_VAL.

    Args:
        args: Parsed command-line namespace. Reads ``args.output_dir`` (root
            directory for the saved datasets) and ``args.process_num_workers``
            (``num_proc`` used when adding the MMMU ``category`` column).
    """
    def save_benchmark(dataset, name):
        # All benchmarks share the same layout: <output_dir>/<name>, 20GB shards.
        dataset.save_to_disk(os.path.join(args.output_dir, name), max_shard_size="20GB")

    # ---- Captioning benchmarks ----
    # we only save the used set of each dataset to save disk space
    raw_datasets = DatasetDict()
    raw_datasets['test'] = load_dataset("HuggingFaceM4/COCO", trust_remote_code=True, name="2014_captions", split='test')
    print(raw_datasets)
    save_benchmark(raw_datasets, 'COCO')
    raw_datasets = DatasetDict()
    raw_datasets["validation"] = load_dataset("HuggingFaceM4/NoCaps", trust_remote_code=True, split="validation")
    print(raw_datasets)
    save_benchmark(raw_datasets, 'NoCaps')
    raw_datasets = DatasetDict()
    raw_datasets['test'] = load_dataset("umaru97/flickr30k_train_val_test", trust_remote_code=True, split='test')
    print(raw_datasets)
    save_benchmark(raw_datasets, 'Flickr30K')

    # ---- VQA benchmarks ----
    raw_datasets_val = DatasetDict()
    raw_datasets_val['validation'] = load_dataset("HuggingFaceM4/VQAv2", trust_remote_code=True, split='validation')
    print(raw_datasets_val)
    save_benchmark(raw_datasets_val, 'VQAv2_VAL')
    raw_datasets_test = DatasetDict()
    raw_datasets_test['testdev'] = load_dataset("HuggingFaceM4/VQAv2", trust_remote_code=True, split='testdev')
    print(raw_datasets_test)
    save_benchmark(raw_datasets_test, 'VQAv2_TEST')
    raw_datasets = load_dataset("Multimodal-Fatima/OK-VQA_test", trust_remote_code=True)
    print(raw_datasets)
    save_benchmark(raw_datasets, 'OK-VQA')
    raw_datasets_val = DatasetDict()
    raw_datasets_val['validation'] = load_dataset("Multimodal-Fatima/VizWiz", split='validation')
    print(raw_datasets_val)
    save_benchmark(raw_datasets_val, 'VizWiz_VAL')
    raw_datasets_test = DatasetDict()
    raw_datasets_test['test'] = load_dataset("Multimodal-Fatima/VizWiz", split='test')
    print(raw_datasets_test)
    save_benchmark(raw_datasets_test, 'VizWiz_TEST')
    raw_datasets = load_dataset("Multimodal-Fatima/TextVQA_validation")
    print(raw_datasets)
    save_benchmark(raw_datasets, 'TextVQA')

    # ---- MathVista ----
    raw_datasets_test = DatasetDict()
    raw_datasets_test['testmini'] = load_dataset("AI4Math/MathVista", split='testmini')
    print(raw_datasets_test)

    def transform_answer_math_vista(example):
        # Multi-choice answers are stored as the choice text; convert them to
        # the letter (A, B, ...) of the matching entry in `choices`.
        if example['question_type'] == 'multi_choice':
            example['answer_transformed'] = string.ascii_uppercase[example['choices'].index(example['answer'])]
        else:
            example['answer_transformed'] = example['answer']
        # Flatten the nested metadata into top-level columns.
        metadata = example['metadata']
        example['language'] = metadata['language']
        example['source'] = metadata['source']
        example['category_2'] = metadata['category']
        example['category'] = metadata['task']
        example['context'] = metadata['context']
        example['grade'] = metadata['grade']
        example['skills'] = metadata['skills']

        return example

    raw_datasets_test = raw_datasets_test.map(transform_answer_math_vista)
    # save_benchmark(raw_datasets_test, 'MathVista_Mix')
    save_benchmark(raw_datasets_test.filter(lambda x: x['question_type'] == 'multi_choice'), 'MathVista_MultiChoice')
    save_benchmark(raw_datasets_test.filter(lambda x: x['question_type'] == 'free_form'), 'MathVista_OpenEnded')

    # ---- Winoground (raw + yes/no reformulation) ----
    raw_datasets = load_dataset("facebook/winoground")
    print(raw_datasets)
    save_benchmark(raw_datasets, 'Winoground')
    winoground_rows = []

    def transform_winoground(example):
        # https://twitter.com/ChengleiSi/status/1731047075528561119
        # Each example expands into the 4 (image, caption) combinations;
        # matching pairs are labelled 'yes', mismatched pairs 'no'.
        image_0, image_1 = example['image_0'], example['image_1']
        caption_0, caption_1 = example['caption_0'], example['caption_1']
        pairings = [
            (image_0, caption_0, 'yes'),
            (image_0, caption_1, 'no'),
            (image_1, caption_0, 'no'),
            (image_1, caption_1, 'yes'),
        ]
        for offset, (image, caption, answer) in enumerate(pairings):
            winoground_rows.append({
                "image": image,
                "caption": caption,
                "answer": answer,
                'id': example['id'],
                'question_id': example['id'] * 4 + offset,
                'tag': example['tag'],
                'secondary_tag': example['secondary_tag'],
                'num_main_preds': example['num_main_preds'],
                'collapsed_tag': example['collapsed_tag']
            })

    # map() is used only for its side effect of filling `winoground_rows`;
    # the function returns None, so the mapped dataset itself is discarded.
    raw_datasets.map(transform_winoground, desc="Transforming Winoground")
    transform_raw_dataset = DatasetDict({"test": Dataset.from_list(winoground_rows)})
    print(transform_raw_dataset)
    save_benchmark(transform_raw_dataset, 'Winoground-YN')

    # ---- WHOOPS (captioning + VQA reformulation) ----
    raw_datasets = load_dataset("nlphuji/whoops")
    print(raw_datasets)
    save_benchmark(raw_datasets, 'WHOOPS-Caption')
    # below for WHOOPS-VQA, but its VQA annotations are not very good, so we don't use it
    whoops_rows = []
    question_id = 0
    reserve_column_names = ['designer_explanation', 'selected_caption', 'crowd_captions', 'crowd_explanations', 'commonsense_category', 'image_id', 'image_designer']

    def transform_whoops_vqa(example, idx):
        # `question_id` lives in the enclosing function scope, so it must be
        # declared nonlocal (a `global` declaration would look for a
        # module-level name that does not exist and raise NameError).
        nonlocal question_id
        for qa_pair in example['question_answering_pairs']:
            new_example = {key: example[key] for key in reserve_column_names}
            new_example['image_index'] = idx
            new_example['question_id'] = question_id
            question_id += 1
            new_example['question'] = qa_pair[0]
            new_example['answer'] = qa_pair[1]
            whoops_rows.append(new_example)

    # As above, map() is only used to visit every example (with its index).
    raw_datasets.map(transform_whoops_vqa, with_indices=True, desc="Transforming WHOOPS VQA")
    transform_raw_dataset = DatasetDict({"test": Dataset.from_list(whoops_rows)})
    print(transform_raw_dataset)
    save_benchmark(transform_raw_dataset, 'WHOOPS-VQA')

    # ---- POPE ----
    # raw_datasets = load_dataset("Otter-AI/POPE")
    raw_datasets = load_dataset("lmms-lab/POPE")

    def fix_spelling_errors(example):
        # Some questions contain the typo "the imange".
        if "the imange" in example['question']:
            example['question'] = example['question'].replace("the imange", "the image")
        return example

    raw_datasets = raw_datasets.map(fix_spelling_errors)
    print(raw_datasets)
    save_benchmark(raw_datasets, 'POPE')

    # raw_datasets = load_dataset("Shengcao1006/MMHal-Bench")
    # print(raw_datasets)
    # raw_datasets.save_to_disk(os.path.join(args.output_dir, 'MMHal-Bench'), max_shard_size="20GB")

    ##### USE VLMEvalKit #####
    # raw_datasets = load_dataset("rayguan/HallusionBench")
    # print(raw_datasets)
    # raw_datasets.save_to_disk(os.path.join(args.output_dir, 'HallusionBench'), max_shard_size="20GB")

    # raw_datasets = load_dataset("Otter-AI/MME")
    # print(raw_datasets)
    # raw_datasets.save_to_disk(os.path.join(args.output_dir, 'MME'), max_shard_size="20GB")
    # raw_datasets = load_dataset("Otter-AI/MMBench")
    # print(raw_datasets)
    # raw_datasets.save_to_disk(os.path.join(args.output_dir, 'MMBench'), max_shard_size="20GB")
    # raw_datasets = load_dataset("Otter-AI/MMVet")
    # print(raw_datasets)
    # raw_datasets.save_to_disk(os.path.join(args.output_dir, 'MMVet'), max_shard_size="20GB")
    # raw_datasets = load_dataset("AILab-CVC/SEED-Bench")
    # print(raw_datasets)
    # raw_datasets.save_to_disk(os.path.join(args.output_dir, 'SEED-Bench'), max_shard_size="20GB")
    ##### END #####

    # ---- LLaVA-Bench (in the wild) ----
    raw_datasets = load_dataset("liuhaotian/llava-bench-in-the-wild")
    print(raw_datasets)
    save_benchmark(raw_datasets, 'llava-bench-in-the-wild')

    # ---- MMMU ----
    subsets_MMMU = ['Accounting', 'Agriculture', 'Architecture_and_Engineering', 'Art', 'Art_Theory', 'Basic_Medical_Science', 'Biology', 'Chemistry', 'Clinical_Medicine', 'Computer_Science', 'Design', 'Diagnostics_and_Laboratory_Medicine', 'Economics', 'Electronics', 'Energy_and_Power', 'Finance', 'Geography', 'History', 'Literature', 'Manage', 'Marketing', 'Materials', 'Math', 'Mechanical_Engineering', 'Music', 'Pharmacy', 'Physics', 'Psychology', 'Public_Health', 'Sociology']
    # https://github.com/MMMU-Benchmark/MMMU/blob/main/eval/utils/data_utils.py
    DOMAIN_CAT2SUB_CAT = {
        'Art and Design': ['Art', 'Art_Theory', 'Design', 'Music'],
        'Business': ['Accounting', 'Economics', 'Finance', 'Manage','Marketing'],
        'Science': ['Biology', 'Chemistry', 'Geography', 'Math', 'Physics',],
        'Health and Medicine': ['Basic_Medical_Science', 'Clinical_Medicine', 'Diagnostics_and_Laboratory_Medicine', 'Pharmacy', 'Public_Health'],
        'Humanities and Social Science': ['History', 'Literature', 'Sociology', 'Psychology'],
        'Tech and Engineering': ['Agriculture', 'Architecture_and_Engineering', 'Computer_Science', 'Electronics', 'Energy_and_Power', 'Materials', 'Mechanical_Engineering'],
    }
    # Invert the mapping: subset name -> domain category.
    subset_to_category = {}
    for category, sub_categories in DOMAIN_CAT2SUB_CAT.items():
        for sub_category in sub_categories:
            subset_to_category[sub_category] = category
    datasetdict_list = []
    for subset in subsets_MMMU:
        subset_dict = load_dataset("MMMU/MMMU", subset)
        # Tag every row with the subset it came from before concatenating.
        for split in subset_dict.keys():
            subset_dict[split] = subset_dict[split].add_column("subset", [subset] * len(subset_dict[split]))
        datasetdict_list.append(subset_dict)
    datasetdict_MMMU = DatasetDict()
    for split in ["dev", "validation", "test"]:
        datasetdict_MMMU[split] = concatenate_datasets([x[split] for x in datasetdict_list])
    datasetdict_MMMU = datasetdict_MMMU.map(lambda x: {'category': subset_to_category[x['subset']]}, num_proc=args.process_num_workers)
    print(datasetdict_MMMU)
    # After the two pops only the 'validation' split is left in the dict;
    # the popped 'test' split is a plain Dataset and is saved as such.
    datasetdict_MMMU_TEST = datasetdict_MMMU.pop('test')
    datasetdict_MMMU.pop('dev')
    save_benchmark(datasetdict_MMMU.filter(lambda x: x['question_type'] == 'multiple-choice'), 'MMMU_VAL_MultiChoice')
    save_benchmark(datasetdict_MMMU.filter(lambda x: x['question_type'] == 'open'), 'MMMU_VAL_OpenEnded')
    save_benchmark(datasetdict_MMMU_TEST.filter(lambda x: x['question_type'] == 'multiple-choice'), 'MMMU_TEST_MultiChoice')
    save_benchmark(datasetdict_MMMU_TEST.filter(lambda x: x['question_type'] == 'open'), 'MMMU_TEST_OpenEnded')

    # ---- MMLU ----
    # https://huggingface.co/datasets/cais/mmlu
    # https://github.com/FranxYao/chain-of-thought-hub/blob/main/MMLU/run_mmlu_open_source.py
    subsets_MMLU =  ['abstract_algebra', 'anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge', 'college_biology', 'college_chemistry', 'college_computer_science', 'college_mathematics', 'college_medicine', 'college_physics', 'computer_security', 'conceptual_physics', 'econometrics', 'electrical_engineering', 'elementary_mathematics', 'formal_logic', 'global_facts', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science', 'high_school_european_history', 'high_school_geography', 'high_school_government_and_politics', 'high_school_macroeconomics', 'high_school_mathematics', 'high_school_microeconomics', 'high_school_physics', 'high_school_psychology', 'high_school_statistics', 'high_school_us_history', 'high_school_world_history', 'human_aging', 'human_sexuality', 'international_law', 'jurisprudence', 'logical_fallacies', 'machine_learning', 'management', 'marketing', 'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios', 'nutrition', 'philosophy', 'prehistory', 'professional_accounting', 'professional_law', 'professional_medicine', 'professional_psychology', 'public_relations', 'security_studies', 'sociology', 'us_foreign_policy', 'virology', 'world_religions']
    datasetdict_list = [load_dataset("cais/mmlu", subset) for subset in subsets_MMLU]
    datasetdict_MMLU = DatasetDict()
    for split in ["dev", "test"]:
        datasetdict_MMLU[split] = concatenate_datasets([x[split] for x in datasetdict_list])

    def transform_answer(example):
        # Integer answer index -> letter (0 -> 'A', 1 -> 'B', ...).
        example['answer_transformed'] = string.ascii_uppercase[example['answer']]
        return example

    datasetdict_MMLU = datasetdict_MMLU.map(transform_answer)
    print(datasetdict_MMLU)
    save_benchmark(datasetdict_MMLU, 'MMLU')

    # ---- PIQA (validation split only) ----
    # https://huggingface.co/datasets/piqa
    raw_datasets = load_dataset("piqa")
    raw_datasets.pop('train')
    raw_datasets.pop('test')
    # Rename to the multiple-choice schema so transform_answer can be reused.
    raw_datasets = raw_datasets.rename_columns({'goal': 'question', 'sol1': 'A', 'sol2':'B', 'label': 'answer'})
    raw_datasets = raw_datasets.map(transform_answer)
    print(raw_datasets)
    save_benchmark(raw_datasets, 'PIQA_VAL')

    # # need to git clone to download the images for SEED-Bench-2
    # raw_datasets = load_dataset("AILab-CVC/SEED-Bench-2", trust_remote_code=True)
    # print(raw_datasets)
    # raw_datasets.save_to_disk(os.path.join(args.output_dir, 'SEED-Bench-2'), max_shard_size="20GB")
    
def main():
    """Script entry point: parse the CLI arguments and run the requested action.

    Creates ``args.output_dir`` if needed, then downloads all benchmarks when
    ``--download_datasets`` was passed.

    Raises:
        SystemExit: with status 1 when no output directory is configured, and
            with status 0 once the datasets have been downloaded.
    """
    args = parse_args()

    if args.output_dir is None:
        print("There is no `args.output_dir` specified! Datasets will not be saved.")
        # Signal the failure to the calling shell instead of exiting with 0.
        raise SystemExit(1)
    os.makedirs(args.output_dir, exist_ok=True)

    if args.download_datasets:
        download_datasets(args)
        # Nothing else to do once the datasets are on disk.
        raise SystemExit(0)

# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
