import json
import cv2
import numpy as np
import torch
import os
from detectron2.utils.visualizer import Visualizer, ColorMode
from detectron2.structures import Boxes, Instances
from tqdm import tqdm
from collections import defaultdict
from detectron2.data import MetadataCatalog
from pycocotools import mask as coco_mask

# Display colors (RGB triplets) stored as detectron2 `thing_colors` metadata
# for the "my_dataset" entry. Order: red, orange, yellow, green, blue,
# purple, pink.
MetadataCatalog.get("my_dataset").thing_colors = [
    list(rgb)
    for rgb in (
        (255, 0, 0),      # red
        (255, 165, 0),    # orange
        (255, 255, 0),    # yellow
        (0, 255, 0),      # green
        (0, 0, 255),      # blue
        (128, 0, 128),    # purple
        (255, 192, 203),  # pink
    )
]

def draw_bbox(image, bbox, color=(0, 255, 0), thickness=2):
    """Draw an axis-aligned rectangle outline on ``image`` and return it.

    Args:
        image: Image array passed straight to ``cv2.rectangle`` (drawn on in
            place); in this file it comes from ``cv2.imread``.
        bbox: Box as ``(x, y, w, h)``: top-left corner plus width and height.
        color: Outline color tuple in the image's channel order (BGR for
            ``cv2.imread`` images).
        thickness: Outline thickness in pixels.

    Returns:
        The same ``image`` object, so calls can be chained.
    """
    x, y, w, h = (float(v) for v in bbox)
    # Round each corner from its float endpoint. Truncating x and w
    # separately (int(x) + int(w)) can lose up to a pixel on the far edge.
    top_left = (round(x), round(y))
    bottom_right = (round(x + w), round(y + h))
    cv2.rectangle(image, top_left, bottom_right, color, thickness)
    return image

def _rle_to_mask(rle):
    """Decode a single COCO RLE dict into a 2-D uint8 mask.

    Accepts both compressed RLE (``counts`` is a string) and uncompressed RLE
    (``counts`` is a list of run lengths). Uncompressed RLE is first converted
    with ``coco_mask.frPyObjects`` and then decoded.
    """
    if isinstance(rle['counts'], list):
        height, width = rle["size"]
        rle = coco_mask.frPyObjects([rle], height, width)[0]
    return coco_mask.decode(rle)


def _segmentation_to_mask(segmentation, shape):
    """Rasterize a COCO ``segmentation``; return ``None`` for unsupported input.

    Supported forms:
      * a list of polygons, each a flat ``[x1, y1, x2, y2, ...]`` list,
        filled into a mask of ``shape``;
      * a one-element list wrapping an RLE dict;
      * an RLE dict (compressed or uncompressed ``counts``).
    """
    if isinstance(segmentation, dict):
        return _rle_to_mask(segmentation)
    if not isinstance(segmentation, list) or not segmentation:
        return None
    if isinstance(segmentation[0], list):
        mask = np.zeros(shape, dtype=np.uint8)
        for seg in segmentation:
            if len(seg) >= 6:  # a polygon needs at least 3 (x, y) points
                poly = np.array(seg).reshape(-1, 2).astype(np.int32)
                cv2.fillPoly(mask, [poly], 1)
        return mask
    if len(segmentation) == 1 and isinstance(segmentation[0], dict):
        # Same decoding path as a bare RLE dict, so uncompressed counts work too.
        return _rle_to_mask(segmentation[0])
    return None


def draw_mask(image, segmentation, color=(0, 255, 0), alpha=0.5):
    """Overlay a COCO segmentation on ``image`` and return the result.

    The overlay is additive: ``cv2.addWeighted(image, 1, colored_mask, alpha, 0)``
    adds ``alpha * color`` to the masked pixels, which brightens them, and
    leaves the other pixels unchanged because ``colored_mask`` is zero there.

    Args:
        image: Image array; its first two dimensions give the mask size.
        segmentation: Polygon list, ``[rle_dict]`` or RLE dict (see
            ``_segmentation_to_mask``). Unsupported forms leave the image
            untouched.
        color: Color tuple written into the masked pixels.
        alpha: Weight of the colored mask in the sum.

    Returns:
        The blended image. The input image is returned unchanged when the
        segmentation is unsupported, does not match the image size, or fails
        to decode. The error is printed; this is a best-effort helper.
    """
    try:
        mask = _segmentation_to_mask(segmentation, image.shape[:2])
        if mask is None:
            return image
        if mask.shape[:2] != image.shape[:2]:
            # An RLE whose "size" disagrees with the image cannot index it.
            print(f"Skipping mask: size {mask.shape[:2]} does not match image size {image.shape[:2]}")
            return image

        colored_mask = np.zeros_like(image)
        colored_mask[mask == 1] = color
        return cv2.addWeighted(image, 1, colored_mask, alpha, 0)

    except Exception as e:
        print(f"绘制掩码时出错: {e}")
        import traceback
        traceback.print_exc()
        return image

def _load_detections(json_file_path, start_index, max_count):
    """Read a detections JSON and return ``(detections, image_id_to_filename)``.

    Accepts a COCO-style dict (``annotations`` + ``images``) or a bare list of
    detections. ``detections`` is sliced to ``[start_index, start_index +
    max_count)``; ``max_count=None`` means "through the end".
    """
    with open(json_file_path, 'r') as f:
        data = json.load(f)

    if isinstance(data, dict):
        detections = data.get('annotations', [])
        # Keep only the base name: file_name may carry a sub-directory prefix.
        image_id_to_filename = {
            img_info['id']: img_info['file_name'].split('/')[-1]
            for img_info in data.get('images', [])
        }
    elif isinstance(data, list):
        detections = data
        image_id_to_filename = {}
    else:
        raise ValueError("JSON data must be either a dict (COCO format) or a list")

    end_index = None if max_count is None else start_index + max_count
    return detections[start_index:end_index], image_id_to_filename


def _register_image_file(image_files, filename, img_path):
    """Add ``img_path`` to ``image_files`` under every id its file name may denote.

    Keys registered:
      * the full stem;
      * the last ``_``-separated token of the stem (e.g. ``"000000000009"``
        for ``COCO_train2014_000000000009.jpg``);
      * that token as an ``int`` when it is all decimal digits.

    Detection files may carry ``str`` or ``int`` image ids, so registering
    both lets either kind match. Later files overwrite earlier ones on
    collisions.
    """
    stem = os.path.splitext(filename)[0]
    token = stem.split('_')[-1]
    image_files[stem] = img_path
    image_files[token] = img_path
    if token.isdecimal():
        image_files[int(token)] = img_path


def _index_image_files(images_path, image_id_to_filename):
    """Return ``{image_id: image_path}`` for the images available.

    ``images_path`` is either a directory (str or path-like) or an iterable of
    image paths. With a COCO ``image_id_to_filename`` table, files are matched
    by base name. Otherwise ids are derived from file names via
    ``_register_image_file``.
    """
    image_files = {}

    if isinstance(images_path, (str, os.PathLike)):
        images_dir = os.fspath(images_path)
        # Without this check a mistyped directory string would be iterated
        # character by character as if it were a list of paths.
        if not os.path.isdir(images_dir):
            raise NotADirectoryError(f"Image directory not found: {images_dir}")
        if image_id_to_filename:
            for image_id, filename in image_id_to_filename.items():
                img_path = os.path.join(images_dir, filename)
                if os.path.exists(img_path):
                    image_files[image_id] = img_path
        else:
            for filename in sorted(os.listdir(images_dir)):
                if filename.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp')):
                    _register_image_file(image_files, filename, os.path.join(images_dir, filename))
        return image_files

    if image_id_to_filename:
        # Build the reverse lookup once; the first id wins for duplicate names.
        filename_to_id = {}
        for image_id, filename in image_id_to_filename.items():
            filename_to_id.setdefault(filename, image_id)
        for img_path in images_path:
            filename = os.path.basename(img_path)
            if filename in filename_to_id:
                image_files[filename_to_id[filename]] = img_path
    else:
        for img_path in images_path:
            _register_image_file(image_files, os.path.basename(img_path), img_path)
    return image_files


def batch_visualize(json_file_path, images_path, output_dir="output", confidence_threshold=0.5, start_index=0, max_count=None):
    """Draw detections (masks + boxes) onto their images and save them as JPEGs.

    Args:
        json_file_path: JSON file with detections. Either a COCO-style dict
            (``annotations``, plus ``images`` for the id -> file name mapping)
            or a bare list of detection dicts.
        images_path: Directory containing the images, or an iterable of image
            paths.
        output_dir: Directory for the ``<image_id>_bright.jpg`` outputs; it is
            created if missing.
        confidence_threshold: Detections whose ``score`` is below this are
            dropped. Entries without a ``score`` are always kept.
        start_index: Index of the first detection to consider. Slicing applies
            to the detection list, not to images.
        max_count: Maximum number of detections to consider from
            ``start_index``. ``None`` means through the end of the list.

    Colors cycle per detection within an image, not per category.

    Raises:
        NotADirectoryError: ``images_path`` is a path that is not a directory.
        ValueError: the JSON top level is neither a dict nor a list.
    """
    print("Loading detection results...")
    detections_list, image_id_to_filename = _load_detections(json_file_path, start_index, max_count)

    # Group the detections that pass the threshold by image.
    detections_by_image = defaultdict(list)
    for det in detections_list:
        if det.get("score", 1.0) >= confidence_threshold:
            detections_by_image[det['image_id']].append(det)

    image_files = _index_image_files(images_path, image_id_to_filename)
    os.makedirs(output_dir, exist_ok=True)

    # Palette in RGB, converted once to the BGR order OpenCV images use.
    colors = [
        (255, 0, 0),      # red
        (255, 165, 0),    # orange
        (255, 255, 0),    # yellow
        (0, 255, 0),      # green
        (0, 0, 255),      # blue
        (128, 0, 128),    # purple
        (255, 192, 203),  # pink
    ]
    bgr_colors = [(b, g, r) for (r, g, b) in colors]

    print(f"Processing {len(detections_by_image)} images...")
    processed_count = 0

    for image_id, detections in tqdm(detections_by_image.items()):
        img_path = image_files.get(image_id)
        if img_path is None:
            print(f"Warning: Image file not found for image_id {image_id}")
            continue

        try:
            image = cv2.imread(img_path)
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            continue
        if image is None:
            print(f"Warning: Cannot load image {img_path}")
            continue

        for idx, det in enumerate(detections):
            color = bgr_colors[idx % len(bgr_colors)]
            # Box-only results have no "segmentation"; draw_mask skips None.
            image = draw_mask(image, det.get('segmentation'), color)
            bbox = det.get('bbox')
            if bbox is not None:
                image = draw_bbox(image, bbox, color)

        output_path = os.path.join(output_dir, f"{image_id}_bright.jpg")
        if not cv2.imwrite(output_path, image):
            print(f"Warning: Failed to write {output_path}")
            continue
        processed_count += 1

    print(f"Successfully processed {processed_count} images")
    print(f"Results saved to: {output_dir}")

# Root directory holding the raw image datasets. Defaults to the original
# location; set the DATASET_ROOT environment variable to override it.
DATASET_ROOT = os.environ.get("DATASET_ROOT") or "/data/xxx/datasets"
# Prefix shared by the per-dataset evaluation output directories; each
# dataset's results live under f"{EVAL_ROOT}_{name}/inference/".
EVAL_ROOT = "coler_official_eval/cls_agnostic"

# Image directory of each dataset, relative to DATASET_ROOT.
# (lvis reuses the COCO train2014 images.)
_DATASET_IMAGE_DIRS = {
    "coco20k": "coco/train2014",
    "voc": "voc/VOC2007/JPEGImages",
    "coco": "coco/val2017",
    "imagenet": "imagenet/val",
    "lvis": "coco/train2014",
    "clipart": "clipart/JPEGImages",
    "watercolor": "watercolor/JPEGImages",
    "comic": "comic/JPEGImages",
    "kitti": "kitti/JPEGImages",
    "openimages": "openimages-v7/validation",
    "objects365": "objects365/val/images/v2/patch16",
}

# dataset name -> [image directory, detection results JSON path]
result_eval_dict = {
    name: [
        f"{DATASET_ROOT}/{image_dir}",
        f"{EVAL_ROOT}_{name}/inference/coco_instances_results.json",
    ]
    for name, image_dir in _DATASET_IMAGE_DIRS.items()
}

def visualize_all(dataset_names=None, confidence_threshold=0.5, max_count=None):
    """Run ``batch_visualize`` for several datasets listed in ``result_eval_dict``.

    Args:
        dataset_names: A key, or iterable of keys, of ``result_eval_dict``.
            Defaults to ``["lvis", "voc", "kitti", "openimages"]``.
        confidence_threshold: Forwarded to ``batch_visualize``.
        max_count: Forwarded to ``batch_visualize`` (``None`` = all detections).

    Each dataset's images are written to ``visualization_<name>_COLER``.

    Raises:
        KeyError: some name is not a key of ``result_eval_dict``. This is
            checked before any dataset is processed.
    """
    if dataset_names is None:
        dataset_names = ["lvis", "voc", "kitti", "openimages"]
    elif isinstance(dataset_names, str):
        dataset_names = [dataset_names]
    else:
        dataset_names = list(dataset_names)

    # Validate everything up front so a typo does not abort a long run halfway.
    unknown = [name for name in dataset_names if name not in result_eval_dict]
    if unknown:
        raise KeyError(f"Unknown dataset name(s): {unknown}; available: {sorted(result_eval_dict)}")

    for dataset_name in dataset_names:
        dataset_path = result_eval_dict[dataset_name][0]
        ann_file = result_eval_dict[dataset_name][1]
        batch_visualize(
            json_file_path=ann_file,
            images_path=dataset_path,
            output_dir=f"visualization_{dataset_name}_COLER",
            confidence_threshold=confidence_threshold,
            max_count=max_count,
        )

if __name__ == "__main__":
    # Visualize the default datasets from result_eval_dict. To render a single
    # results file instead, call batch_visualize(json_file_path=...,
    # images_path=..., output_dir=...) directly.
    visualize_all()
