import multiprocessing
import os
import shutil
import zipfile
from io import BytesIO

from datasets import load_dataset, load_from_disk, DatasetDict
from PIL import Image

from utils import dataset_URLs, download_file, decode_base64_to_image_file


# Remote archive containing the GQA image files (hosted on downloads.cs.stanford.edu).
image_dict = dict(
    url="https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip",
    file_name="images.zip",
)

# Test-dev balanced annotations as a single parquet file on the Hugging Face Hub;
# saved locally under a more descriptive file name.
annotation_dict = dict(
    url="https://huggingface.co/datasets/Otter-AI/GQA/resolve/main/testdev_balanced_instructions/testdev-00000-of-00001.parquet",
    file_name="gqa_testdev_balanced_annotations.parquet",
)

# Root under which <dataset_name>/origin (raw downloads) and
# <dataset_name>/images (renamed image copies) are created.
output_dir = "YOUR_ROOT_PATH/data/MLLM/Evaluation"
dataset_name = "GQA_TESTDEV_BALANCED"

origin_path = os.path.join(output_dir, dataset_name, 'origin')
new_image_path = os.path.join(output_dir, dataset_name, 'images')
# Extraction target for images.zip (created by the unzip step, not here).
origin_image_path = os.path.join(origin_path, 'images')
for directory in (origin_path, new_image_path):
    os.makedirs(directory, exist_ok=True)

# Fetch each remote file only if no local copy exists yet; an existing file is
# reused without any integrity check.
for item in (image_dict, annotation_dict):
    target = os.path.join(origin_path, item["file_name"])
    if os.path.exists(target):
        continue
    download_file(item["url"], target)

# Extract images.zip into origin_image_path, on the first run only.
# The directory's existence serves as the "already extracted" marker, so a
# failed or interrupted extraction must remove it again. Otherwise later runs
# would skip this step and continue with a missing or partial image set.
if not os.path.exists(origin_image_path):
    images_zip_path = os.path.join(origin_path, image_dict['file_name'])
    os.makedirs(origin_image_path)
    try:
        # Why zipfile instead of os.system("unzip ..."):
        # - it raises on a missing/corrupt archive (os.system's exit status was ignored);
        # - it needs no shell or external binary;
        # - it strips absolute/".." member paths from the downloaded archive.
        with zipfile.ZipFile(images_zip_path) as archive:
            archive.extractall(origin_image_path)
    except BaseException:
        shutil.rmtree(origin_image_path, ignore_errors=True)
        raise

# Load the downloaded annotation parquet with the generic `datasets` parquet
# builder. A plain data_files path is loaded as a single "train" split, which
# is the split kept here.
annotation_parquet_path = os.path.join(origin_path, annotation_dict["file_name"])
annotation_dataset = load_dataset("parquet", data_files=annotation_parquet_path)["train"]

# Map each distinct imageId to a dense integer index (0, 1, 2, ...).
# Index order is the order in which imageIds first appear in the annotations.
# dict.fromkeys deduplicates while keeping that order, so the mapping is
# identical on every run. Wrapping the column in set() first (as before) made
# the order depend on per-process string hash randomization. Because existing
# destination images are never overwritten, a re-run would then have recorded
# paths pointing at the wrong images.
image_id2image_index = {
    image_id: index
    for index, image_id in enumerate(dict.fromkeys(annotation_dataset['imageId']))
}


def update_image(example):
    """Attach a dense image index and a local image path to one annotation row.

    Intended as the callback for ``Dataset.map``. Reads the module-level
    ``image_id2image_index``, ``origin_image_path`` and ``new_image_path``.

    Args:
        example: One annotation row; must contain an ``'imageId'`` key.

    Returns:
        The same ``example``, with two fields added:
        - ``'image_index'``: the dense index of the row's image.
        - ``'local_image_path'``: ``<new_image_path>/<image_index zero-padded to 7 digits>.jpg``.

    Side effects:
        Copies ``<origin_image_path>/<imageId>.jpg`` to ``'local_image_path'``
        unless that destination file already exists.

    Raises:
        KeyError: if the row's imageId is not in ``image_id2image_index``.
        OSError: if the source image cannot be copied (e.g. it is missing).
    """
    image_id = example['imageId']
    image_index = image_id2image_index[image_id]
    source_path = os.path.join(origin_image_path, f"{image_id}.jpg")
    target_path = os.path.join(new_image_path, f"{str(image_index).zfill(7)}.jpg")

    example['image_index'] = image_index
    example['local_image_path'] = target_path

    # shutil.copy avoids building a shell command from data-derived paths.
    # Unlike os.system("cp ..."), it fails loudly when the source image is
    # missing instead of recording a path to a non-existent file.
    if not os.path.exists(target_path):
        shutil.copy(source_path, target_path)
    return example

# Add image_index / local_image_path to every row (copying images as a side effect).
annotation_dataset = annotation_dataset.map(update_image)

# Wrap the annotated rows as the single "train" split and persist them to
# <output_dir>/<dataset_name>/datasets.
hfdatasets_path = os.path.join(output_dir, dataset_name, 'datasets')
hfdatasets = DatasetDict({'train': annotation_dataset})
print(hfdatasets)
hfdatasets.save_to_disk(hfdatasets_path)