import bisect
import multiprocessing
import os
import subprocess
import sys

from datasets import load_dataset, load_from_disk, Dataset, DatasetDict

sys.path.append(os.path.join(os.path.dirname(__file__), '../WIT/pre_process'))
from utils import read_with_orjsonl, write_with_orjsonl, write_with_orjsonl_extend

# Root of the extracted JourneyDB data; replace the YOUR_ROOT_PATH placeholder before running.
root_path = "YOUR_ROOT_PATH/data/MLLM/IC/JourneyDB/data"
# Number of image shards per split: archives 000.tgz ... (folder_count - 1).tgz under <split>/imgs.
folder_count = 200
# Annotation JSONL file for each split, keyed by split name.
# NOTE(review): "realease" spelling is kept as-is; presumably it matches the file on disk.
annotation_file = dict(
    train=os.path.join(root_path, "train/train_anno_realease_repath.jsonl"),
    valid=os.path.join(root_path, "valid/valid_anno_repath.jsonl"),
)

def delete_zero_size_files(directory):
    """Remove every empty (0-byte) regular file directly inside ``directory``.

    Sub-directories are neither removed nor descended into. Each deleted path is
    printed as ``Deleted <path>``.

    Raises:
        FileNotFoundError / NotADirectoryError: propagated from ``os.listdir`` when
        ``directory`` does not exist or is not a directory.
    """
    # Lazy generator: each entry is checked and removed before the next is built.
    candidate_paths = (os.path.join(directory, entry) for entry in os.listdir(directory))
    for candidate in candidate_paths:
        if not os.path.isfile(candidate):
            continue
        if os.path.getsize(candidate) != 0:
            continue
        os.remove(candidate)
        print(f"Deleted {candidate}")

def _split_img_path(img_path):
    """Split an annotation ``img_path`` into ``(folder, file_name)``.

    Everything up to and including the first ``'.'`` plus the one character after
    it is discarded, and the remainder is split at its first ``'/'``. Presumably
    paths look like ``./000/<name>.jpg`` -- TODO confirm against the annotation files.

    Raises:
        IndexError / ValueError / AttributeError: when ``img_path`` does not have
        that shape (no ``'.'``, no ``'/'`` afterwards, or not a string).
    """
    relative_path = img_path.split('.', maxsplit=1)[1][1:]
    folder, file_name = relative_path.split('/', maxsplit=1)
    return folder, file_name


# Keys discarded from every annotation record before anything else happens.
_UNUSED_KEYS = ('Task3', 'prompt', 'ori_prompt')
# Errors that mark a record as malformed: missing keys, None/non-dict task entries,
# non-string styles, or an unparsable img_path.
_MALFORMED_RECORD_ERRORS = (KeyError, TypeError, ValueError, IndexError, AttributeError)

folder_list = [str(i).zfill(3) for i in range(folder_count)]

for split in ['train', 'valid']:
    imgs_dir = os.path.join(root_path, split, 'imgs')

    # Unpack each image shard, then drop the zero-byte files it leaves behind.
    # An argument list (no shell) keeps paths with spaces or shell metacharacters
    # intact; a failing archive is reported but, as before, does not abort the run.
    for folder in folder_list:
        archive = os.path.join(imgs_dir, folder + '.tgz')
        result = subprocess.run(['tar', '-xzf', archive, '-C', imgs_dir])
        if result.returncode != 0:
            print(f"tar exited with status {result.returncode} for {archive}")
        delete_zero_size_files(os.path.join(imgs_dir, folder))

    # NOTE: in an earlier run these two train images were also removed by hand:
    # rm -rf YOUR_ROOT_PATH/data/MLLM/IC/JourneyDB/data/train/imgs/000/7f0b4fae-2a33-440e-bfc0-b15d7e0d53fe.jpg
    # rm -rf YOUR_ROOT_PATH/data/MLLM/IC/JourneyDB/data/train/imgs/000/6d721d04-7953-43b8-a089-61c4467e6f05.jpg

    # Keep the loaded records under a local name. Rebinding ``annotation_file`` here
    # (and deleting it afterwards) destroyed the split -> path mapping, so the
    # 'valid' iteration failed with NameError.
    annotations = read_with_orjsonl(annotation_file[split])
    kept = []
    drop_index = []
    images_path_deleted = []
    for i, record in enumerate(annotations):
        for key in _UNUSED_KEYS:
            record.pop(key, None)  # these fields are discarded, so a missing one is harmless
        try:
            caption = record['Task2']['Caption']
            style = ", ".join(record['Task1']['Style'])
            folder, file_name = _split_img_path(record['img_path'])
        except _MALFORMED_RECORD_ERRORS as e:
            print(e)
            drop_index.append(i)
            # Also delete the image of a dropped record, but only if its path can be
            # recovered -- the parse may be exactly what failed above.
            try:
                folder, file_name = _split_img_path(record['img_path'])
            except _MALFORMED_RECORD_ERRORS:
                print(f"Cannot locate image for dropped record {i}")
            else:
                images_path_deleted.append(os.path.join(imgs_dir, folder, file_name))
            continue
        # Same key order as before (new columns appended, raw fields removed), so the
        # dataset columns -- taken from the first record -- and CSV layout are unchanged.
        record['caption'] = caption
        record['style'] = style
        record['folder'], record['file_name'] = folder, file_name
        for key in ('img_path', 'Task1', 'Task2'):
            record.pop(key)
        kept.append(record)
    del annotations

    print(drop_index)  # indices of dropped records
    print(len(drop_index))  # an earlier run noted 17
    print(images_path_deleted)
    for path in images_path_deleted:
        try:
            os.remove(path)
        except FileNotFoundError:
            # Possibly already removed (e.g. it was zero bytes) -- nothing left to do.
            print(f"Already missing, skipped: {path}")

    annotation_dataset = Dataset.from_list(kept)
    del kept
    print(annotation_dataset)

    # After sorting, every folder's rows form one contiguous slice, so locate its
    # bounds by binary search instead of re-scanning the whole dataset per folder.
    annotation_dataset = annotation_dataset.sort("folder")
    sorted_folders = list(annotation_dataset['folder'])
    for folder in folder_list:
        start = bisect.bisect_left(sorted_folders, folder)
        end = bisect.bisect_right(sorted_folders, folder)
        annotation_dataset.select(range(start, end)).to_csv(
            os.path.join(imgs_dir, folder, 'metadata.csv')
        )