import os
import subprocess
import sys
from multiprocessing import Pool

from tqdm import tqdm

# Configuration. ROOT_FROM and ROOT_TO default to "" and must be filled in
# before running the script.
ROOT_FROM: str = ""  # the path of laion-ocr-zip: directory holding the <idx>.zip archives
ROOT_TO: str = ""  # the path for saving dataset: archives are extracted under this directory
MULTIPROCESSING_NUM: int = 64  # number of worker processes used for parallel unzipping
# NOTE(review): not referenced anywhere in this file's visible code — confirm
# whether another module reads it, or whether it is dead configuration.
DOWNLOAD_IMAGES: bool = False  # whether to download images from urls


def remove_files_except_caption(directory):
    """Recursively delete every file named ``charseg.npy`` under ``directory``.

    Despite the function name, only ``charseg.npy`` files are removed; all
    other files (captions included) are left untouched. A ``directory`` that
    does not exist is a no-op, since ``os.walk`` then yields nothing.
    """
    target = "charseg.npy"
    for dirpath, _subdirs, filenames in os.walk(directory):
        # A file name appears at most once per directory listing, so a single
        # membership test is equivalent to scanning the whole list for it.
        if target in filenames:
            os.remove(os.path.join(dirpath, target))


def unzip_file(idx):
    """Extract ``ROOT_FROM/<idx>.zip`` into ``ROOT_TO``, then prune ``charseg.npy`` files.

    Nothing happens when the archive does not exist, or when ``ROOT_TO/<idx>``
    already exists (that directory is treated as "already extracted"). After
    extraction, ``remove_files_except_caption`` is run on ``ROOT_TO/<idx>``.

    NOTE(review): assumes each archive unpacks into a top-level ``<idx>/``
    directory — both the skip check and the cleanup rely on it; confirm
    against the dataset layout.

    Args:
        idx: Archive stem, i.e. the archive file name without ``.zip``.
    """
    zip_path = os.path.join(ROOT_FROM, f"{idx}.zip")
    subdirectory = os.path.join(ROOT_TO, idx)
    if not os.path.exists(zip_path) or os.path.exists(subdirectory):
        return

    # Pass an argument list (no shell) so paths containing spaces or shell
    # metacharacters reach unzip verbatim instead of being split/interpreted.
    result = subprocess.run(["unzip", "-q", zip_path, "-d", ROOT_TO])
    if result.returncode != 0:
        # Don't ignore the status: a partially extracted directory would make
        # the existence check above skip this archive on every later run.
        print(
            f"unzip exited with status {result.returncode} for {zip_path}",
            file=sys.stderr,
        )

    remove_files_except_caption(subdirectory)


def multiprocess_unzip_file(idxs, processes=None):
    """Run ``unzip_file`` on every entry of ``idxs`` in a process pool.

    Creates ``ROOT_TO`` first if needed and shows a tqdm progress bar. Tasks
    are consumed with ``imap_unordered``, so the bar advances in completion
    order rather than input order.

    Args:
        idxs: Archive stems passed one at a time to ``unzip_file``. ``len()``
            is taken for the progress-bar total, so a sized sequence is needed.
        processes: Number of worker processes; defaults to ``MULTIPROCESSING_NUM``.
    """
    if processes is None:
        processes = MULTIPROCESSING_NUM
    os.makedirs(ROOT_TO, exist_ok=True)

    with Pool(processes=processes) as pool:
        with tqdm(total=len(idxs), desc="total") as pbar:
            # Per-task results are unused; iterate only to tick the bar as
            # each archive finishes.
            for _ in pool.imap_unordered(unzip_file, idxs):
                pbar.update()
    print("multiprocess_unzip_file done!")


def collect_zip_indices(filenames):
    """Return the sorted stems of the ``.zip`` entries in ``filenames``.

    Entries without a ``.zip`` extension are dropped. Stems are returned
    unchanged (no zero-padding): ``unzip_file`` rebuilds the archive path as
    ``<stem>.zip``, so padding e.g. ``"1"`` to ``"00001"`` would point at a
    non-existent file and that archive would be silently skipped.
    """
    return sorted(
        stem for stem, ext in map(os.path.splitext, filenames) if ext == ".zip"
    )


if __name__ == "__main__":
    if not ROOT_FROM or not ROOT_TO:
        raise SystemExit(
            "Set ROOT_FROM and ROOT_TO at the top of this script before running."
        )
    idxs = collect_zip_indices(os.listdir(ROOT_FROM))
    multiprocess_unzip_file(idxs)
    print("Finished!")
