import argparse
import io
from io import BytesIO
import requests
import os
import subprocess
import json
import multiprocessing
from multiprocessing.pool import ThreadPool, Pool
from PIL import Image, ImageFile
from tqdm import tqdm
from tqdm.contrib import tzip
import numpy as np
import pandas as pd
import time
import sys
from img2dataset import download
import shutil
import math
import pyarrow as pa
import pyarrow.parquet as pq
from datasets import load_from_disk, Dataset, DatasetDict, load_dataset

def parse_args():
    """Parse the command-line options of the download script.

    Every option has a long and a short spelling and a default, so the
    script can be started without any arguments.

    Returns:
        argparse.Namespace with the attributes ``data_dir``, ``save_dir``,
        ``num_shards``, ``number_sample_per_shard``, ``download_chunkindex``,
        ``test_chunkindex`` and ``test_row_index``.
    """
    # (long flag, short flag, type, default, help) -- registered in this order,
    # which is also the order shown by ``--help``.
    option_specs = (
        ("--data_dir", '-d', str, "YOUR_ROOT_PATH/data/MLLM/IC/Merged_new/urls", "data dir path"),
        ("--save_dir", '-id', str, "YOUR_ROOT_PATH/data/MLLM/IC/Merged_new/images", "image save dir"),
        ("--num_shards", '-ns', int, 10, "total shards due to the memory limit"),
        ("--number_sample_per_shard", '-nsp', int, 10000, "number of samples per shard"),
        ("--download_chunkindex", '-di', int, -1, "download chunkindex"),
        ("--test_chunkindex", '-ti', int, 0, "test chunkindex"),
        ("--test_row_index", '-tr', int, 100, "test row index"),
    )

    cli = argparse.ArgumentParser(description="BLIP and CapsFusion100M datasets download")
    for long_flag, short_flag, value_type, fallback, blurb in option_specs:
        cli.add_argument(long_flag, short_flag, type=value_type, default=fallback, help=blurb)

    return cli.parse_args()

def img2dataset(url_dataset, args):
    """Run img2dataset's ``download`` on a parquet URL list with fixed settings.

    Args:
        url_dataset: Value forwarded as ``url_list``; the input is declared
            as parquet and URLs are read from its ``url`` column.
        args: Parsed options. Only ``save_dir`` (output folder) and
            ``number_sample_per_shard`` are read here.
    """
    destination = args.save_dir
    shard_size = args.number_sample_per_shard

    # No caption_col is passed; the three caption fields are carried along
    # as additional columns instead.
    extra_columns = ["caption_origin", "caption_coco", "caption_capsfusion"]

    settings = {
        # Input
        "url_list": url_dataset,
        "input_format": "parquet",
        "url_col": "url",
        "save_additional_columns": extra_columns,
        # Image processing and filtering
        "image_size": 224,
        "resize_mode": "keep_ratio",
        "encode_quality": 100,
        "encode_format": "jpg",
        "min_image_size": 224,
        "max_aspect_ratio": 3,
        "extract_exif": False,
        # Output
        "output_folder": destination,
        "output_format": "webdataset",
        "number_sample_per_shard": shard_size,
        "incremental_mode": "incremental",
        # Execution
        "processes_count": 1,
        "thread_count": 64,
        "distributor": "multiprocessing",
        "timeout": 10,
        "retries": 0,
        "max_shard_retry": 1,
        # Logging
        "enable_wandb": True,
        "wandb_project": "img2dataset",
    }
    download(**settings)

def main(args):
    """Prepare the URL parquet for one chunk and download its images.

    With ``args.download_chunkindex == -1`` the script runs in test mode: the
    first rows (at most 10000) of ``filter_20_50.parquet`` are written to
    ``urls_test.parquet`` on every run and downloaded into ``<save_dir>/test``.
    Otherwise chunk ``i`` uses ``urls_<i>.parquet``, which is created from a
    contiguous shard of ``filter_20_50.parquet`` only when it does not exist
    yet, and images go to ``<save_dir>/<i>``.

    Args:
        args: Parsed options; reads ``data_dir``, ``save_dir``, ``num_shards``
            and ``download_chunkindex``. ``save_dir`` is rewritten in place to
            the per-chunk folder, and ``download_chunkindex`` becomes the
            string ``"test"`` in test mode.
    """
    os.makedirs(args.save_dir, exist_ok=True)

    # Resolve the chunk tag first so the URL file and the image folder are
    # derived from it exactly once. Previously test mode appended the tag to
    # an already-suffixed save_dir, producing "<save_dir>/-1/test".
    test_mode = args.download_chunkindex == -1
    if test_mode:
        args.download_chunkindex = "test"

    cur_url_path = os.path.join(args.data_dir, f'urls_{args.download_chunkindex}.parquet')
    args.save_dir = os.path.join(args.save_dir, str(args.download_chunkindex))

    # Test mode always regenerates its subset; regular chunks are cached.
    if test_mode or not os.path.exists(cur_url_path):
        urls_dataset = load_dataset("parquet", data_files=os.path.join(args.data_dir, 'filter_20_50.parquet'))['train']
        if test_mode:
            # Cap at the dataset size so a small source file does not raise.
            cur_urls_dataset = urls_dataset.select(range(min(10000, len(urls_dataset))))
        else:
            cur_urls_dataset = urls_dataset.shard(num_shards=args.num_shards, index=args.download_chunkindex, contiguous=True)
        cur_urls_dataset.to_parquet(cur_url_path)
        # Drop the dataset references before the download starts.
        del urls_dataset, cur_urls_dataset

    img2dataset(cur_url_path, args)

# Script entry point: parse the command line and run the download.
if __name__ == "__main__":
    main(parse_args())