import json
import random
from collections import defaultdict
from pycocotools.coco import COCO

def create_fewshot_coco_dataset(ann_file, output_ann_file, fraction_per_class=0.2, seed=None):
    """
    Creates a few-shot COCO dataset with a specified fraction of images per class.

    For every category, a random subset of the images containing that category
    is drawn. The resulting dataset is the union of these per-category subsets
    together with *all* annotations of the selected images (not only the
    annotations of the category an image was drawn for), so the effective share
    of a class in the output can be larger than ``fraction_per_class``.

    Parameters:
    - ann_file (str): Path to the original COCO annotations file (e.g., instances_train2017.json).
    - output_ann_file (str): Path to save the new few-shot annotations file.
    - fraction_per_class (float): Fraction of images per class to include in the dataset (e.g., 0.2 for 20%).
      At least one image is kept for every category that has images, and never
      more than the category actually contains.
    - seed (int, optional): Seed for a private random generator, making the selection
      reproducible. When None (default), the global ``random`` module state is used.

    Returns:
    - None. The few-shot dataset is written to ``output_ann_file`` as JSON.
    """

    # Load the COCO dataset
    print("Loading COCO annotations...")
    coco = COCO(ann_file)

    # Get all category IDs
    cat_ids = coco.getCatIds()
    total_cats = len(cat_ids)
    print(f"Total categories: {total_cats}")

    # A dedicated generator keeps a seeded run from disturbing (or depending on)
    # the global random state; without a seed we fall back to the module itself.
    rng = random if seed is None else random.Random(seed)

    # Union of the image IDs picked for every category
    selected_img_ids = set()

    # For each category, randomly select a fraction of images
    print("Selecting images per category...")
    for idx, cat_id in enumerate(cat_ids, start=1):
        # Sort so the sampled population has a stable order regardless of how
        # the IDs were stored internally.
        img_ids = sorted(coco.getImgIds(catIds=[cat_id]))
        num_images = len(img_ids)

        if num_images == 0:
            # Sampling from an empty population would raise ValueError.
            print(f"Category {idx}/{total_cats}: No images found, skipping.")
            continue

        # Keep at least one image, but never ask for more than are available
        # (random.sample raises ValueError when the sample exceeds the population).
        num_select = min(num_images, max(1, int(num_images * fraction_per_class)))
        selected_img_ids.update(rng.sample(img_ids, num_select))
        print(f"Category {idx}/{total_cats}: Selected {num_select} out of {num_images} images.")

    # Load image information for the selected images
    print("Loading selected images...")
    ordered_img_ids = sorted(selected_img_ids)
    imgs = coco.loadImgs(ordered_img_ids)

    # Get all annotation IDs for the selected images
    print("Collecting annotations for selected images...")
    if ordered_img_ids:
        anns = coco.loadAnns(coco.getAnnIds(imgIds=ordered_img_ids))
    else:
        # getAnnIds treats an empty imgIds list as "no filter" and would return
        # every annotation in the dataset, so handle the empty selection here.
        anns = []

    # Create the new dataset dictionary
    fewshot_dataset = {
        'info': coco.dataset.get('info', {}),
        'licenses': coco.dataset.get('licenses', []),
        'images': imgs,
        'annotations': anns,
        'categories': coco.dataset.get('categories', [])
    }

    # Save the new annotations to a JSON file
    print(f"Saving few-shot dataset to {output_ann_file}...")
    with open(output_ann_file, 'w') as f:
        json.dump(fewshot_dataset, f)

    print(f"Few-shot COCO dataset created with {len(imgs)} images and {len(anns)} annotations.")
    print(f"Dataset saved to {output_ann_file}")

# Example usage
if __name__ == "__main__":
    # Replace both paths below with the locations on your machine.
    create_fewshot_coco_dataset(
        ann_file='/root/annotations/instances_train2017.json',
        output_ann_file='/root/autodl-tmp/new/instances_train2017_fewshot.json',
        fraction_per_class=0.3,  # keep 30% of the images of each class
    )
