import multiprocessing
import os
from datasets import load_dataset, load_from_disk, DatasetDict, Dataset, concatenate_datasets
import argparse


def parse_args():
    """Read the command-line options for the caption-extraction script.

    Returns:
        An ``argparse.Namespace`` carrying ``dataset_name`` (a string,
        ``'Merged_new'`` when the flag is not given).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--dataset_name',
        type=str,
        default='Merged_new',
        help='dataset name',
    )
    return cli.parse_args()

def main():
    """Collect the rows whose captions still need generating and save them.

    Reads the ``train`` split of each image-token shard stored at
    ``<root_path>/<dataset_name>/image_token/<i>`` (16 shards for
    ``Merged_new``, 2 for any other dataset) and concatenates them. The
    result is then mapped, in batches of 1000 across all CPU cores, down to
    three columns:

    - ``input_index``: the row's position in the concatenated dataset.
    - ``caption_origin``.
    - ``caption_coco``.

    For ``Merged_new`` only rows whose ``caption_capsfusion`` is the empty
    string are kept; for other datasets every row is kept. The mapped
    dataset is saved to ``<root_path>/<dataset_name>/caption``.
    """
    args = parse_args()
    root_path = 'YOUR_ROOT_PATH/data/MLLM/IC'
    image_token_path = os.path.join(root_path, args.dataset_name, 'image_token')
    caption_path = os.path.join(root_path, args.dataset_name, 'caption')

    # Evaluate the dataset switch once. It drives the shard count, the
    # per-row filter and the columns to drop below.
    is_merged = args.dataset_name == 'Merged_new'
    folder_count = 16 if is_merged else 2

    # Load every shard first and concatenate in a single call, instead of
    # re-concatenating the growing result with each new shard.
    shards = [
        load_from_disk(os.path.join(image_token_path, str(i)))['train']
        for i in range(folder_count)
    ]
    image_token_dataset = concatenate_datasets(shards)

    print(image_token_dataset)

    def extract_caption(examples, idxes):
        """Keep only the rows of a batch whose caption still has to be generated.

        Args:
            examples: one batch, as a dict mapping column name to a list of values.
            idxes: the batch's row indices in the full dataset (``with_indices=True``).

        Returns:
            A dict of lists with ``input_index``, ``caption_origin`` and
            ``caption_coco`` for every kept row.
        """
        new_examples = {
            "input_index": [],
            "caption_origin": [],
            "caption_coco": [],
        }
        for example_idx, idx in enumerate(idxes):
            # Non-Merged datasets keep every row. Merged_new keeps only rows
            # whose capsfusion caption is the empty string. Short-circuiting
            # means 'caption_capsfusion' is only read for Merged_new.
            if not is_merged or examples['caption_capsfusion'][example_idx] == '':
                new_examples['input_index'].append(idx)
                new_examples['caption_origin'].append(examples['caption_origin'][example_idx])
                new_examples['caption_coco'].append(examples['caption_coco'][example_idx])
        return new_examples

    # Drop the listed input columns so the batched map may return fewer rows
    # than it received. NOTE(review): assumes these are all of the dataset's
    # columns — confirm against the image_token schema.
    remove_columns = ['url', 'image_tokens', 'caption_origin', 'caption_coco']
    if is_merged:
        remove_columns.append('caption_capsfusion')

    caption_to_be_generated = image_token_dataset.map(
        extract_caption,
        with_indices=True,
        batched=True,
        batch_size=1000,
        num_proc=multiprocessing.cpu_count(),
        remove_columns=remove_columns,
    )

    print(caption_to_be_generated)
    caption_to_be_generated.save_to_disk(caption_path, max_shard_size="20GB")


# Run the extraction only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()