import os
import json
import pandas as pd
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed

# Root directory that every data file and image path in this script is joined
# onto via os.path.join. The empty string means paths are used as given,
# i.e. relative to the current working directory.
base_path = ""

def delete_image(unused_image_chunk):
    """Delete every image in ``unused_image_chunk``, skipping failures.

    Each entry is joined onto the module-level ``base_path`` before removal.
    Deletion is best-effort: a missing file, or any other OS-level failure for
    a single path, is reported on stdout and the loop continues with the next
    path.

    Args:
        unused_image_chunk: Iterable of image paths relative to ``base_path``.

    Returns:
        Always ``True``, so the caller can confirm the worker ran to the end.
    """
    for image_path in unused_image_chunk:
        try:
            os.remove(os.path.join(base_path, image_path))
        except FileNotFoundError:
            print(f"Not existing image path {image_path}")
        except OSError as err:
            # Permission problems, directories, etc. were previously swallowed
            # by a bare `except:` and mislabelled as "not existing"; report the
            # real cause but keep going. Non-OS errors (and Ctrl-C) now propagate.
            print(f"Failed to delete image path {image_path}: {err}")

    return True

def _collect_used_paths(image_id_fields, dataset_id_to_path):
    """Return the set of image paths referenced by any query row.

    Args:
        image_id_fields: Iterable of ``'|'``-separated dataset image id strings,
            one per query row (e.g. the ``dataset_image_ids`` column).
        dataset_id_to_path: Mapping from dataset image id to image path.

    Raises:
        KeyError: If a referenced id has no entry in ``dataset_id_to_path``.
        AttributeError: If a field is not a string (e.g. NaN from an empty CSV
            cell). Failing is deliberate: it stops the script before anything
            is deleted instead of risking removal of a used image.
    """
    used_paths = set()
    for field in image_id_fields:
        used_paths.update(dataset_id_to_path[image_id] for image_id in field.split('|'))
    return used_paths


def _split_round_robin(items, num_chunks):
    """Split the sequence ``items`` into ``num_chunks`` near-equal lists.

    Item ``k`` goes to chunk ``k % num_chunks``, so chunk sizes differ by at
    most one; when ``len(items) < num_chunks`` the trailing chunks are empty.
    ``num_chunks`` must be at least 1.
    """
    return [items[offset::num_chunks] for offset in range(num_chunks)]


# Delete unused images to save disk space: only about 514k of the ~2M images are
# actually referenced by Encyclopedic-VQA queries.
# NOTE(review): a previous comment here also promised to remove redundant images
# so that only one image per text query remains. This script does NOT do that;
# every image referenced by any train/val/test query is kept. Confirm intent.
if __name__ == '__main__':

    with open(os.path.join(base_path, 'dataset_id_to_path.json'), 'r', encoding='utf-8') as f:
        dataset_id_to_path = json.load(f)
    all_data_path_set = set(dataset_id_to_path.values())

    # Load encyclopedic-vqa query files. Only the image-id column is needed; read
    # it as str so ids match the JSON mapping's keys (JSON object keys are always
    # strings) even when a file's ids happen to look numeric.
    encyclo = pd.concat(
        [
            pd.read_csv(
                os.path.join(base_path, name),
                usecols=['dataset_image_ids'],
                dtype={'dataset_image_ids': str},
            )
            for name in ('train_clean.csv', 'val_clean.csv', 'test_clean.csv')
        ],
        ignore_index=True,
    )

    # Attain the actually used image data paths.
    actual_data_path_set = _collect_used_paths(
        tqdm(encyclo['dataset_image_ids'], total=len(encyclo)),
        dataset_id_to_path,
    )

    # Collect unused image data paths.
    unused_data_path_set = all_data_path_set - actual_data_path_set
    print(f"Total Size {len(all_data_path_set)}")
    print(f"Actual size {len(actual_data_path_set)}")
    print(f"Unused size {len(unused_data_path_set)}")

    unused_data_path_list = list(unused_data_path_set)

    # Remove the files.
    # Capped at 32 workers: more caused HTTP 426 client errors from too many
    # concurrent requests (presumably against the storage backend).
    # os.cpu_count() returns None when the count cannot be determined.
    num_workers = min(os.cpu_count() or 1, 32)
    chunks = [chunk for chunk in _split_round_robin(unused_data_path_list, num_workers) if chunk]

    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(delete_image, chunk) for chunk in chunks]

        for future in tqdm(as_completed(futures), total=len(futures), desc="Deleting unused images"):
            # result() re-raises any worker exception; an explicit check replaces
            # `assert`, which is stripped when Python runs with -O.
            if not future.result():
                raise RuntimeError("delete_image did not complete")

    print("Done!")