import json
import cv2
import os
import numpy as np
from pycocotools import mask as maskUtils

def draw_bbox(image, bbox, color, thickness=2):
    """Outline a box given as (x, y, width, height) on the image.

    The four box values are truncated to integers before drawing. The
    rectangle is drawn directly onto ``image``, which is also returned.
    """
    left, top, box_w, box_h = bbox
    left, top = int(left), int(top)
    right, bottom = left + int(box_w), top + int(box_h)
    cv2.rectangle(image, (left, top), (right, bottom), color, thickness)
    return image

def draw_mask(image, segmentation, color=(0, 255, 0), alpha=0.5):
    """Overlay a segmentation mask on the image.

    Handled segmentation formats:
      * a list of polygons, each a flat coordinate list [x0, y0, x1, y1, ...]
      * a one-element list wrapping an RLE dict
      * an RLE dict whose 'counts' is a list (converted before decoding)
      * an RLE dict whose 'counts' is anything else (decoded directly)

    Any other format is skipped and the input image is returned as is.
    Errors raised while building or applying the mask are printed together
    with a traceback, and the input image is returned.
    """
    try:
        wrapped_in_list = isinstance(segmentation, list)

        if wrapped_in_list and segmentation and isinstance(segmentation[0], list):
            # Polygon format: rasterise each polygon into one shared mask
            mask = np.zeros(image.shape[:2], dtype=np.uint8)
            for coords in segmentation:
                if len(coords) < 6:
                    # Fewer than 3 points (6 coordinates) cannot form a polygon
                    continue
                points = np.array(coords).reshape(-1, 2).astype(np.int32)
                cv2.fillPoly(mask, [points], 1)
        elif wrapped_in_list and len(segmentation) == 1 and isinstance(segmentation[0], dict):
            # RLE dictionary stored inside a single-element list
            mask = maskUtils.decode(segmentation[0])
        elif isinstance(segmentation, dict) and isinstance(segmentation['counts'], list):
            # RLE with list-valued counts: convert with frPyObjects, then decode
            rows, cols = segmentation["size"]
            converted = maskUtils.frPyObjects([segmentation], rows, cols)[0]
            mask = maskUtils.decode(converted)
        elif isinstance(segmentation, dict):
            # Remaining RLE dicts are passed to the decoder unchanged
            mask = maskUtils.decode(segmentation)
        else:
            # Unrecognised format: leave the image untouched
            return image

        # Paint the mask pixels onto a black canvas of the same shape as the image
        overlay = np.zeros_like(image)
        overlay[mask == 1] = color

        # Add the overlay, scaled by alpha, on top of the full-weight image
        return cv2.addWeighted(image, 1, overlay, alpha, 0)

    except Exception as e:
        print(f"Error drawing mask: {e}")
        import traceback
        traceback.print_exc()
        return image

def visualize_coco_annotations(json_path, image_root_path, output_dir, max_images=200, confidence_threshold=0.5):
    """
    Visualize COCO format annotations

    Draws the bounding box and segmentation mask of every annotation onto its
    image and writes the result to ``<output_dir>/<image_id>_bright.jpg``.

    Args:
        json_path: Path to COCO format JSON file; it must contain top-level
            'images' and 'annotations' entries
        image_root_path: Root directory that each image's 'file_name' is joined to
        output_dir: Directory to save visualization results (created if missing)
        max_images: Maximum number of images to save (default: 200). Images that
            are missing, unreadable, or fail to save do not count towards it
        confidence_threshold: Annotations that carry a 'score' below this value
            are skipped; annotations without a 'score' are always drawn
            (default: 0.5)
    """

    # Load COCO format JSON (JSON text is UTF-8; do not depend on the locale default)
    with open(json_path, 'r', encoding='utf-8') as f:
        coco_data = json.load(f)

    images = coco_data['images']
    annotations = coco_data['annotations']

    # Build mapping from image_id to annotations
    image_to_anns = {}
    for ann in annotations:
        image_to_anns.setdefault(ann['image_id'], []).append(ann)

    # Create output directory if not exists
    os.makedirs(output_dir, exist_ok=True)

    # Palette is written as RGB for readability, then flipped to the BGR
    # channel order that OpenCV drawing functions expect
    rgb_colors = [
        (255, 0, 0),      # red
        (255, 165, 0),    # orange
        (255, 255, 0),    # yellow
        (0, 255, 0),      # green
        (0, 0, 255),      # blue
        (128, 0, 128),    # purple
        (255, 192, 203),  # pink
    ]
    bgr_colors = [(b, g, r) for (r, g, b) in rgb_colors]

    # Process images until max_images of them have been saved
    processed_count = 0
    for img_info in images:
        if processed_count >= max_images:
            break

        image_id = img_info['id']
        image_path = os.path.join(image_root_path, img_info['file_name'])

        # Check if image file exists
        if not os.path.exists(image_path):
            print(f"Warning: Image {image_path} not found, skipping...")
            continue

        # cv2.imread signals an unreadable file by returning None
        image = cv2.imread(image_path)
        if image is None:
            print(f"Warning: Failed to load image {image_path}, skipping...")
            continue

        # Draw every annotation attached to this image
        for idx, det in enumerate(image_to_anns.get(image_id, [])):
            if 'score' in det and det['score'] < confidence_threshold:
                continue

            # Colour follows the annotation's position in the list, cycling
            # through the palette
            color = bgr_colors[idx % len(bgr_colors)]

            if 'bbox' in det:
                image = draw_bbox(image, det['bbox'], color)

            if 'segmentation' in det and det['segmentation']:
                image = draw_mask(image, det['segmentation'], color)

        # Save visualization result; cv2.imwrite reports failure through its
        # return value, so a failed save must not be counted as processed
        output_path = os.path.join(output_dir, f"{image_id}_bright.jpg")
        if not cv2.imwrite(output_path, image):
            print(f"Warning: Failed to save {output_path}, skipping...")
            continue

        processed_count += 1
        if processed_count % 50 == 0:
            print(f"Processed {processed_count} images...")

    print(f"Visualization completed! Processed {processed_count} images.")
    print(f"Results saved to: {output_dir}")

# Root directory under which every dataset's images and annotations live.
DATASET_ROOT = "/data/xxx/datasets"
# Common prefix of the per-dataset evaluation output directories.
EVAL_ROOT = "coler_official_eval/cls_agnostic"

# Paths per dataset name. Each value is a three-item list:
#   [0] image root directory
#   [1] ground-truth annotation JSON
#   [2] inference results JSON below EVAL_ROOT
result_eval_dict = {
    "coco20k": [
        f"{DATASET_ROOT}/coco/train2014",
        f"{DATASET_ROOT}/coco/annotations/coco20k_trainval_gt.json",
        f"{EVAL_ROOT}_coco20k/inference/coco_instances_results_dict.json",
    ],
    "voc": [
        f"{DATASET_ROOT}/voc",
        f"{DATASET_ROOT}/voc/annotations/trainvaltest_2007_cls_agnostic.json",
        f"{EVAL_ROOT}_voc/inference/coco_instances_results_dict.json",
    ],
    "coco": [
        f"{DATASET_ROOT}/coco/val2017",
        f"{DATASET_ROOT}/coco/annotations/coco_cls_agnostic_instances_val2017.json",
        f"{EVAL_ROOT}_coco/inference/coco_instances_results_dict.json",
    ],
    "imagenet": [
        f"{DATASET_ROOT}/imagenet/val",
        f"{DATASET_ROOT}/imagenet/annotations/imagenet_val_cls_agnostic_gt.json",
        f"{EVAL_ROOT}_imagenet/inference/coco_instances_results_dict.json",
    ],
    "lvis": [
        f"{DATASET_ROOT}/coco",
        f"{DATASET_ROOT}/coco/annotations/lvis1.0_cocofied_val_cls_agnostic.json",
        f"{EVAL_ROOT}_lvis/inference/coco_instances_results_dict.json",
    ],
    "clipart": [
        f"{DATASET_ROOT}/clipart",
        f"{DATASET_ROOT}/clipart/annotations/traintest_clipart_cls_agnostic.json",
        f"{EVAL_ROOT}_clipart/inference/coco_instances_results_dict.json",
    ],
    "watercolor": [
        f"{DATASET_ROOT}/watercolor",
        f"{DATASET_ROOT}/watercolor/annotations/traintest_watercolor_cls_agnostic.json",
        f"{EVAL_ROOT}_watercolor/inference/coco_instances_results_dict.json",
    ],
    "comic": [
        f"{DATASET_ROOT}/comic",
        f"{DATASET_ROOT}/comic/annotations/traintest_comic_cls_agnostic.json",
        f"{EVAL_ROOT}_comic/inference/coco_instances_results_dict.json",
    ],
    "kitti": [
        f"{DATASET_ROOT}/kitti",
        f"{DATASET_ROOT}/kitti/annotations/trainval_cls_agnostic.json",
        f"{EVAL_ROOT}_kitti/inference/coco_instances_results_dict.json",
    ],
    "openimages": [
        f"{DATASET_ROOT}/openimages-v7/validation",
        f"{DATASET_ROOT}/openimages-v7/annotations/openimages_val_cls_agnostic.json",
        f"{EVAL_ROOT}_openimages/inference/coco_instances_results_dict.json",
    ],
    "objects365": [
        f"{DATASET_ROOT}/objects365/val",
        f"{DATASET_ROOT}/objects365/annotations/zhiyuan_objv2_val_cls_agnostic.json",
        f"{EVAL_ROOT}_objects365/inference/coco_instances_results_dict.json",
    ],
}

def visualize_all(dataset_names=None, use_predictions=False, max_images=200):
    """
    Run visualize_coco_annotations for several datasets from result_eval_dict.

    Called without arguments it visualizes the ground-truth annotations of
    "kitti", "openimages" and "objects365".

    Args:
        dataset_names: Keys of result_eval_dict to visualize, as an iterable
            or a single name. Defaults to ("kitti", "openimages", "objects365")
        use_predictions: If False (default), draw the ground-truth annotation
            file and write to 'visualization_<name>_gt'. If True, draw the
            inference results file and write to 'visualization_<name>_COLER'
            (presumably that file also has 'images' and 'annotations'
            entries; confirm before relying on it)
        max_images: Maximum number of images to save per dataset (default: 200)

    Raises:
        KeyError: If a requested name is not a key of result_eval_dict
    """
    if dataset_names is None:
        dataset_names = ("kitti", "openimages", "objects365")
    elif isinstance(dataset_names, str):
        # A bare string would otherwise be iterated character by character
        dataset_names = (dataset_names,)

    for dataset_name in dataset_names:
        dataset_path, gt, ann_file = result_eval_dict[dataset_name]
        if use_predictions:
            json_path, suffix = ann_file, "COLER"
        else:
            json_path, suffix = gt, "gt"
        visualize_coco_annotations(
            json_path=json_path,
            image_root_path=dataset_path,
            output_dir=f'visualization_{dataset_name}_{suffix}',
            max_images=max_images,
        )

# Script entry point
if __name__ == "__main__":
    # Batch-visualize the default dataset selection. To render one specific
    # annotation or results file instead, call
    # visualize_coco_annotations(json_path, image_root_path, output_dir, max_images)
    # directly. The process exits with status 0 once this call returns, so no
    # explicit exit is needed.
    visualize_all()