import json
import random

# Fixed seed so the shuffle, and therefore the split, is reproducible.
SEED = 42

# Fraction of the unique image IDs assigned to the training split;
# the remainder becomes the validation split.
TRAIN_FRACTION = 0.98

ANNOTATIONS_PATH = 'annotations/captions_train2017.json'
TRAIN_IDS_PATH = 'train_id_coco.txt'
VAL_IDS_PATH = 'validation_id_coco.txt'


def collect_image_ids(annotations):
    """Return the unique ``image_id`` values of *annotations*, sorted.

    Args:
        annotations: Iterable of mappings, each with an ``'image_id'`` key.

    Returns:
        A sorted list of the distinct image IDs.

    Raises:
        KeyError: If an annotation has no ``'image_id'`` key.
    """
    # Sorting gives the later shuffle a canonical starting order. Set
    # iteration order is an implementation detail, so shuffling the raw set
    # would not be reproducible from the seed alone.
    return sorted({item['image_id'] for item in annotations})


def split_ids(image_ids, train_fraction=TRAIN_FRACTION, seed=SEED):
    """Shuffle *image_ids* deterministically and split them in two.

    Args:
        image_ids: Iterable of IDs; it is copied, not modified.
        train_fraction: Share of the IDs that goes to the first (train) list.
        seed: Seed for the private random generator used for shuffling.

    Returns:
        A ``(train_ids, val_ids)`` tuple of lists that together contain every
        input ID exactly once.
    """
    shuffled = list(image_ids)
    # A private generator keeps the result independent of (and invisible to)
    # the global ``random`` state.
    random.Random(seed).shuffle(shuffled)
    split_index = int(train_fraction * len(shuffled))
    return shuffled[:split_index], shuffled[split_index:]


def write_ids(path, image_ids):
    """Write *image_ids* to *path*, one ID per line."""
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(f"{image_id}\n" for image_id in image_ids)


def main():
    """Read the caption annotations and write the train/validation ID files."""
    # Load the captions_train2017.json file
    with open(ANNOTATIONS_PATH, 'r', encoding='utf-8') as f:
        data = json.load(f)

    image_ids = collect_image_ids(data['annotations'])
    print(f"Total unique image IDs: {len(image_ids)}")

    train_ids, val_ids = split_ids(image_ids)
    print(f"Number of training IDs: {len(train_ids)}")
    print(f"Number of validation IDs: {len(val_ids)}")

    write_ids(TRAIN_IDS_PATH, train_ids)
    write_ids(VAL_IDS_PATH, val_ids)

    print("Split completed. Train and validation IDs saved to 'train_id_coco.txt' and 'validation_id_coco.txt'.")


if __name__ == "__main__":
    main()
