# Convert VLMEvalKit TSV files to Hugging Face datasets and save their images to disk for later image tokenization
import os
import torch
import random
import numpy as np
from PIL import Image
from io import BytesIO
import datasets
import logging
from datasets import load_dataset, load_from_disk
from tqdm.auto import tqdm
import pandas as pd
import multiprocessing
from utils import dataset_URLs, download_file, decode_base64_to_image_file
import argparse

def tsv_to_jsonl(tsv_path, jsonl_path, dataset_name):
    """Convert a VLMEvalKit TSV file into a JSON-lines file sorted by ``index``.

    Args:
        tsv_path: Path of the tab-separated input file.
        jsonl_path: Destination path for the JSON-lines output (one record per row).
        dataset_name: Benchmark name; only ``'MME'`` produces a return value.

    Returns:
        For ``'MME'``: a dict mapping each distinct ``image_path`` value to a
        consecutive integer id, assigned in order of first appearance after
        sorting by ``index``. For every other dataset: ``None``.
    """
    table = pd.read_csv(tsv_path, sep='\t')
    print(table.columns)
    ordered = table.sort_values(by=['index'])
    ordered.to_json(jsonl_path, orient='records', lines=True)
    if dataset_name != 'MME':
        return None
    # dict.fromkeys de-duplicates while preserving first-occurrence order, so
    # rows that share an image_path are given the same id.
    unique_paths = dict.fromkeys(ordered['image_path'].tolist())
    return {path: position for position, path in enumerate(unique_paths)}

def parse_args(argv=None):
    """Parse command-line options and print a short summary of available compute.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (handy for programmatic calls and tests).

    Returns:
        argparse.Namespace with ``output_dir``, ``dataset_name``,
        ``dataset_name_list``, ``seed``, ``process_batch_size`` and
        ``process_num_workers``.
    """
    cpu_count = multiprocessing.cpu_count()

    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, default='YOUR_ROOT_PATH/data/MLLM/Evaluation', help='path to save the output')
    # Only consulted when --dataset_name_list is explicitly passed as an empty
    # string, because --dataset_name_list has a non-empty default.
    parser.add_argument('--dataset_name', type=str, default='HallusionBench', help='dataset name')
    parser.add_argument('--dataset_name_list', type=str, default='MME,SEEDBench_IMG,MMVet,ScienceQA_VAL,ScienceQA_TEST,HallusionBench', help='dataset name list')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--process_batch_size', type=int, default=200, help='process batch size')
    parser.add_argument('--process_num_workers', type=int, default=cpu_count, help='preprocessing num workers')

    print('Number of available cores:', cpu_count)
    print('Number of available gpus:', torch.cuda.device_count())

    # Query device 0 only when CUDA is usable. Failures here are diagnostics
    # only, so catch Exception (not a bare except, which would also swallow
    # KeyboardInterrupt/SystemExit).
    if torch.cuda.is_available():
        try:
            print('GPU model name:', torch.cuda.get_device_name(0))
            print('GPU memory size:', torch.cuda.get_device_properties(0).total_memory / 1024 / 1024 / 1024, 'GB')
        except Exception:
            print('No GPU available.')
    else:
        print('No GPU available.')

    return parser.parse_args(argv)

def main():
    """Download each requested VLMEvalKit benchmark, convert it to an HF dataset, and dump its images.

    For every dataset name (``--dataset_name_list`` when non-empty, otherwise
    ``--dataset_name``) this:

    1. downloads the file listed in ``dataset_URLs`` into
       ``<output_dir>/<name>/origin`` unless it is already present;
    2. converts a ``.tsv`` source to ``.jsonl`` (sorted by ``index``);
    3. adds ``image_index`` / ``local_image_path`` columns and writes each
       image to ``<output_dir>/<name>/images/<7-digit id>.jpg`` via
       ``decode_base64_to_image_file``;
    4. saves the result, without the raw ``image`` column, to
       ``<output_dir>/<name>/datasets`` with ``save_to_disk``.
    """
    import sys

    args = parse_args()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Also reject an empty string: os.makedirs('') would fail with a confusing error.
    if not args.output_dir:
        print("There is no `args.output_dir` specified! Processed datasets will not be saved.")
        sys.exit(1)
    os.makedirs(args.output_dir, exist_ok=True)

    if args.dataset_name_list:
        # Tolerate spaces around commas, e.g. "MME, MMVet".
        dataset_name_list = [name.strip() for name in args.dataset_name_list.split(',') if name.strip()]
    else:
        dataset_name_list = [args.dataset_name]

    for dataset_name in dataset_name_list:
        try:
            dataset_url = dataset_URLs[dataset_name]
        except KeyError as err:
            raise KeyError(f"Unknown dataset name {dataset_name!r}: not found in dataset_URLs") from err

        origin_path = os.path.join(args.output_dir, dataset_name, 'origin')
        os.makedirs(origin_path, exist_ok=True)
        dataset_file_name = dataset_url.split('/')[-1]
        dataset_file_path = os.path.join(origin_path, dataset_file_name)
        if not os.path.exists(dataset_file_path):
            download_file(dataset_url, dataset_file_path)

        # tsv_to_jsonl returns the image_path -> id mapping for MME and None otherwise.
        image_path_2_idx = None
        stem, ext = os.path.splitext(dataset_file_path)
        if ext == '.tsv':
            # splitext only touches the final extension, unlike str.replace,
            # which would also rewrite '.tsv' appearing elsewhere in the path.
            jsonl_path = stem + '.jsonl'
            image_path_2_idx = tsv_to_jsonl(dataset_file_path, jsonl_path, dataset_name)
            dataset_file_path = jsonl_path

        if dataset_name == 'MME' and image_path_2_idx is None:
            raise ValueError("MME requires a .tsv source to build the image_path -> image_index mapping")

        raw_datasets = load_dataset(
            "json",
            data_files=dataset_file_path,
        )

        hfdatasets_path = os.path.join(args.output_dir, dataset_name, 'datasets')
        images_path = os.path.join(args.output_dir, dataset_name, 'images')
        os.makedirs(hfdatasets_path, exist_ok=True)
        os.makedirs(images_path, exist_ok=True)

        def process_images(examples, idxes):
            """Assign image ids/paths to one batch and write its images to ``images_path``."""
            if dataset_name == 'MME':
                # Rows sharing an image_path share one id (and therefore one file).
                examples['image_index'] = [image_path_2_idx[image_path] for image_path in examples['image_path']]
            else:
                # Row position in the (index-sorted) dataset doubles as the image id.
                examples['image_index'] = list(idxes)
            examples['local_image_path'] = [
                os.path.join(images_path, f'{str(image_index).zfill(7)}.jpg')
                for image_index in examples['image_index']
            ]
            for row, image in enumerate(examples['image']):
                if image is None:
                    # e.g. HallusionBench rows that carry no image.
                    examples['image_index'][row] = None
                    examples['local_image_path'][row] = None
                elif len(image) > 5:
                    # Short values (seen in MME) presumably reference another row
                    # that holds the actual image data instead of carrying it
                    # themselves; such rows reuse that row's file through the
                    # shared image_index. TODO confirm against the MME TSV.
                    decode_base64_to_image_file(image, examples['local_image_path'][row])
            return examples

        convert_datasets = raw_datasets.map(
            process_images,
            batched=True,
            with_indices=True,
            batch_size=args.process_batch_size,
            num_proc=args.process_num_workers,  # may need to comment if stuck
            remove_columns=['image'],
            desc="Save image bytes to files",
        )

        print(convert_datasets)
        print(convert_datasets['train'].column_names)

        # The cache of tensors is quite large, e.g. 3M images => 2TB.
        # If your images are stored as bytes, you may have to convert them to tensors first.
        # If your images are stored as files, it is better to open them directly before tokenization.
        convert_datasets.save_to_disk(
            hfdatasets_path, max_shard_size="20GB"
        )

# Run the conversion only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
