import os
import cv2
import matplotlib.pyplot as plt
from pycocotools.coco import COCO

def batch_visualize_coco(ann_file, img_dir, output_dir, img_range=None):
    """
    Batch generate visualization files for a COCO dataset.

    Each selected image is drawn with its annotations overlaid via
    ``COCO.showAnns`` and saved as ``<output_dir>/<file_name stem>_gt.jpg``.
    Images that cannot be processed are reported and skipped; one bad image
    does not abort the batch.

    Args:
        ann_file: path to COCO annotation file
        img_dir: directory containing images
        output_dir: directory to save visualization results (created if missing)
        img_range: tuple (start, end) for an inclusive image ID range,
            None for all images in the annotation file
    """
    # Load COCO annotation file
    coco = COCO(ann_file)

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Get image IDs to process
    if img_range is None:
        img_ids = coco.getImgIds()
        print(f"Processing all {len(img_ids)} images")
    else:
        start_id, end_id = img_range
        # IDs in the range need not exist in the annotation file; missing
        # ones are reported and skipped inside the loop.
        img_ids = list(range(start_id, end_id + 1))
        print(f"Processing images from ID {start_id} to {end_id}")

    processed_count = 0
    skipped_count = 0

    for img_id in img_ids:
        try:
            # coco.loadImgs() indexes coco.imgs directly and raises KeyError
            # for unknown IDs, so test membership explicitly to get the
            # intended warning instead of a generic error.
            if img_id not in coco.imgs:
                print(f"Warning: Image ID {img_id} not found in annotations")
                skipped_count += 1
                continue

            img_info = coco.loadImgs([img_id])[0]
            img_path = os.path.join(img_dir, img_info['file_name'])

            # Check if image file exists
            if not os.path.exists(img_path):
                print(f"Warning: Image file {img_path} not found")
                skipped_count += 1
                continue

            # cv2.imread returns None (it does not raise) when the file is
            # unreadable or not a decodable image.
            img = cv2.imread(img_path)
            if img is None:
                print(f"Warning: Image file {img_path} could not be read")
                skipped_count += 1
                continue
            # OpenCV loads BGR; matplotlib expects RGB.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            # Get annotations for this image
            ann_ids = coco.getAnnIds(imgIds=img_id)
            anns = coco.loadAnns(ann_ids)

            # Generate output filename; file_name may contain sub-directories,
            # which are mirrored under output_dir.
            base_name = os.path.splitext(img_info['file_name'])[0]
            output_path = os.path.join(output_dir, f"{base_name}_gt.jpg")
            os.makedirs(os.path.dirname(output_path), exist_ok=True)

            fig = plt.figure(figsize=(12, 8))
            try:
                plt.axis('off')
                plt.tight_layout(pad=0)
                plt.imshow(img)

                # showAnns draws onto the current axes of this figure
                if anns:
                    coco.showAnns(anns)

                fig.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=150)
            finally:
                # Always release the figure, even if drawing or saving fails;
                # otherwise every failed image leaks an open figure.
                plt.close(fig)

            processed_count += 1
            if processed_count % 100 == 0:
                print(f"Processed {processed_count} images...")

        except Exception as e:
            # Best-effort batch: report the failure and move on to the next image.
            print(f"Error processing image ID {img_id}: {str(e)}")
            skipped_count += 1

    print("Batch visualization completed!")
    print(f"Successfully processed: {processed_count} images")
    print(f"Skipped: {skipped_count} images")

# Usage example
if __name__ == "__main__":
    import argparse

    # Defaults reproduce the previously hard-coded setup, so running the
    # script without arguments behaves exactly as before.
    parser = argparse.ArgumentParser(
        description="Render COCO ground-truth annotations onto images."
    )
    parser.add_argument(
        '--ann-file',
        default='/data/xxx/datasets/coco/annotations/coco_cls_agnostic_instances_val2017.json',
        help='path to COCO annotation file',
    )
    parser.add_argument(
        '--img-dir',
        default='/data/xxx/datasets/coco/val2017',
        help='directory containing images',
    )
    parser.add_argument(
        '--output-dir',
        default='/data/xxx/datasets/coco/visualizations',
        help='directory to save visualization results',
    )
    parser.add_argument(
        '--range',
        dest='img_range',
        nargs=2,
        type=int,
        metavar=('START', 'END'),
        default=None,
        help='inclusive image ID range, e.g. --range 2 102 (default: all images)',
    )
    args = parser.parse_args()

    img_range = tuple(args.img_range) if args.img_range is not None else None
    batch_visualize_coco(args.ann_file, args.img_dir, args.output_dir, img_range=img_range)