import json
import cv2
import numpy as np
import torch
import os
from detectron2.utils.visualizer import Visualizer
from detectron2.structures import Boxes, Instances
from tqdm import tqdm
from collections import defaultdict

def batch_visualize(json_file_path, images_path, output_dir="output", confidence_threshold=0.5, start_index=0, max_count=100):
    """
    Batch visualize COCO format detections and save one rendered JPEG per image.

    Args:
        json_file_path: Path to JSON file containing detection results
            (COCO dict with 'images'/'annotations', or a plain list of detections)
        images_path: Path to a directory containing images, or an iterable of image paths
        output_dir: Directory to save visualization results (created if missing)
        confidence_threshold: Minimum confidence score to display; detections
            without a 'score' key are treated as having score 1.0
        start_index: Index of the first entry of the detection list to use
        max_count: Maximum number of entries taken from the detection list.
            NOTE: the slice is applied to detections (annotations), not to images,
            so the last image in the slice may be drawn with only part of its detections.

    Raises:
        ValueError: If the JSON top-level value is neither a dict nor a list, or if
            images_path is a single path that is not an existing directory.
    """

    # Load detection results
    print("Loading detection results...")
    with open(json_file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Handle both COCO dict format and list format
    if isinstance(data, dict):
        detections_list = data.get('annotations', [])
        # Map image_id -> bare filename (any directory prefix in file_name is dropped)
        image_id_to_filename = {
            img_info['id']: img_info['file_name'].split('/')[-1]
            for img_info in data.get('images', [])
        }
    elif isinstance(data, list):
        detections_list = data
        image_id_to_filename = {}
    else:
        raise ValueError("JSON data must be either a dict (COCO format) or a list")

    detections_list = detections_list[start_index:start_index + max_count]

    # Group detections by image_id, dropping those below the threshold
    detections_by_image = defaultdict(list)
    for det in detections_list:
        if det.get("score", 1.0) >= confidence_threshold:
            detections_by_image[det['image_id']].append(det)

    image_files = _index_image_files(images_path, image_id_to_filename)

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Process each image
    print(f"Processing {len(detections_by_image)} images...")
    processed_count = 0

    for image_id, detections in tqdm(detections_by_image.items()):
        img_path = image_files.get(image_id)
        if img_path is None:
            print(f"Warning: Image file not found for image_id {image_id}")
            continue

        # Load image
        try:
            img = cv2.imread(img_path)
            if img is None:
                print(f"Warning: Cannot load image {img_path}")
                continue
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            continue

        height, width = img.shape[:2]
        output_path = os.path.join(output_dir, f"{image_id}_visualized.jpg")

        # Building the instances is inside the try so that one malformed
        # annotation skips this image instead of aborting the whole batch.
        try:
            instances = _build_instances(detections, height, width)
            v = Visualizer(img_rgb, font_size_scale=1.5)
            out = v.draw_instance_predictions(instances)
            saved = cv2.imwrite(output_path, cv2.cvtColor(out.get_image(), cv2.COLOR_RGB2BGR))
        except Exception as e:
            print(f"Error processing image_id {image_id}: {e}")
            continue

        # cv2.imwrite reports failure through its return value, not an exception
        if not saved:
            print(f"Warning: Failed to write {output_path}")
            continue
        processed_count += 1

    print(f"Successfully processed {processed_count} images")
    print(f"Results saved to: {output_dir}")


def _image_id_from_filename(filename):
    """Return the integer image_id encoded in a filename stem, or None if it is not numeric."""
    try:
        return int(os.path.splitext(filename)[0])
    except ValueError:
        return None


def _index_image_files(images_path, image_id_to_filename):
    """Build a mapping image_id -> image file path.

    When image_id_to_filename is non-empty it is used to associate ids with files;
    otherwise the id is parsed from the numeric filename stem (e.g. '000123.jpg' -> 123).
    """
    image_files = {}

    if isinstance(images_path, (str, os.PathLike)):
        images_dir = os.fspath(images_path)
        if not os.path.isdir(images_dir):
            # Without this check a plain string would be iterated character by character
            raise ValueError(
                f"images_path must be an existing directory or an iterable of image paths: {images_dir!r}"
            )

        if image_id_to_filename:
            for image_id, filename in image_id_to_filename.items():
                img_path = os.path.join(images_dir, filename)
                if os.path.exists(img_path):
                    image_files[image_id] = img_path
            return image_files

        for filename in os.listdir(images_dir):
            if not filename.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp')):
                continue
            image_id = _image_id_from_filename(filename)
            if image_id is None:
                print(f"Warning: Cannot derive image_id from filename {filename}")
                continue
            image_files[image_id] = os.path.join(images_dir, filename)
        return image_files

    # Iterable of image paths.
    # Reverse lookup built once; setdefault keeps the first id for a duplicated filename.
    filename_to_id = {}
    for img_id, fname in image_id_to_filename.items():
        filename_to_id.setdefault(fname, img_id)

    for img_path in images_path:
        filename = os.path.basename(img_path)
        if image_id_to_filename:
            if filename in filename_to_id:
                image_files[filename_to_id[filename]] = img_path
            continue
        image_id = _image_id_from_filename(filename)
        if image_id is None:
            print(f"Warning: Cannot derive image_id from filename {filename}")
            continue
        image_files[image_id] = img_path
    return image_files


def _segmentation_to_mask(seg, height, width):
    """Decode a COCO segmentation (RLE dict or polygon list) into a single 2-D mask.

    Returns None when seg is empty/missing or has an unrecognized format.
    """
    if not seg:
        return None

    # Imported lazily so that box-only detections do not require pycocotools
    from pycocotools import mask as maskUtils

    if isinstance(seg, dict) and 'counts' in seg:
        if isinstance(seg['counts'], list):
            # Uncompressed RLE has to be converted before it can be decoded
            seg = maskUtils.frPyObjects(seg, height, width)
        return maskUtils.decode(seg)

    if isinstance(seg, list):
        # One object may consist of several polygons: frPyObjects yields one RLE
        # per polygon, so merge them to get a single mask for the object.
        rles = maskUtils.frPyObjects(seg, height, width)
        return maskUtils.decode(maskUtils.merge(rles))

    return None


def _build_instances(detections, height, width):
    """Convert the detections of one image into a detectron2 Instances object."""
    instances = Instances((height, width))

    boxes = []
    masks = []
    for det in detections:
        # Convert bbox from [x, y, w, h] to [x1, y1, x2, y2]
        x, y, w, h = det['bbox']
        boxes.append([x, y, x + w, y + h])
        masks.append(_segmentation_to_mask(det.get('segmentation'), height, width))

    instances.pred_boxes = Boxes(torch.tensor(boxes, dtype=torch.float32))
    # Each instance gets its 1-based position within the image as its class id;
    # category_id and score are intentionally not attached.
    instances.pred_classes = torch.arange(1, len(boxes) + 1, dtype=torch.int64)

    if any(m is not None for m in masks):
        # Every field of an Instances object must have one entry per instance,
        # so detections without a mask get an empty placeholder.
        empty_mask = np.zeros((height, width), dtype=np.uint8)
        aligned = [empty_mask if m is None else m for m in masks]
        instances.pred_masks = torch.tensor(np.stack(aligned), dtype=torch.bool)

    return instances


if __name__ == "__main__":
    # Alternative annotation file that can be swapped in:
    #   "pseudo_labels/coco_val2017_cutonce_factor0.1.json"
    #   (with output_dir="visualization_COCOval2017_CutOnce")
    ann_file = "/data/xxx/segmentation/CutLER/maskcut/maskcut_annotation/coco_train_fixsize480_tau0.15_N3.json"

    run_config = {
        "json_file_path": ann_file,
        "images_path": "/data/xxx/datasets/coco/val2017",
        "output_dir": "visualization_COCOval2017_MaskCut",
        "confidence_threshold": 0.5,
        "max_count": 6000,
    }
    batch_visualize(**run_config)

    # Alternative run on ImageNet val pseudo labels:
    #   batch_visualize(
    #       json_file_path="pseudo_labels/imagenet_val_cutonce_filter0.95-new.json",
    #       images_path="/data/xxx/datasets/imagenet/val",
    #       output_dir="visualization_imagenet_val-CutOnce",
    #       confidence_threshold=0.5,
    #   )
