import os
import cv2
import numpy as np
from pycocotools.coco import COCO
from pycocotools import mask as coco_mask
import argparse

def extract_image_id_from_filename(filename):
   """从文件名提取image_id"""
   # 假设文件名格式为 000000391895.jpg
   basename = os.path.splitext(filename)[0]
   # 去掉前导零，转换为整数
   return int(basename.lstrip('0')) if basename.lstrip('0') else 0

def draw_bbox(image, bbox, color=(0, 255, 0), thickness=2):
   """在图像上绘制边界框"""
   x, y, w, h = bbox
   x, y, w, h = int(x), int(y), int(w), int(h)
   cv2.rectangle(image, (x, y), (x + w, y + h), color, thickness)
   return image

def draw_mask(image, segmentation, color=(0, 255, 0), alpha=0.5):
    """在图像上绘制掩码"""
    try:
        if isinstance(segmentation, list):
            # 检查是否是polygon格式还是RLE格式
            if len(segmentation) > 0 and isinstance(segmentation[0], list):
                # polygon格式 - 每个元素是坐标列表
                mask = np.zeros(image.shape[:2], dtype=np.uint8)
                for seg in segmentation:
                    if len(seg) >= 6:  # 至少需要3个点(6个坐标)
                        poly = np.array(seg).reshape(-1, 2).astype(np.int32)
                        cv2.fillPoly(mask, [poly], 1)
            else:
                # RLE格式但是以list形式存储
                # 需要构造RLE字典
                if len(segmentation) == 1 and isinstance(segmentation[0], dict):
                    # 已经是RLE字典格式
                    mask = coco_mask.decode(segmentation[0])
                else:
                    # 跳过无法处理的格式
                    return image
        # RLE dict格式处理应该是：
        elif isinstance(segmentation, dict):
            if isinstance(segmentation['counts'], list):
                # 未压缩RLE，需要转换
                height, width = segmentation["size"]
                rle = coco_mask.frPyObjects([segmentation], height, width)[0]  # 取第一个
                mask = coco_mask.decode(rle)
            else:
                # 已压缩RLE，直接decode
                mask = coco_mask.decode(segmentation)
        else:
            # 其他格式，跳过
            return image
        
        # 创建彩色掩码
        colored_mask = np.zeros_like(image)
        colored_mask[mask == 1] = color
        
        # 与原图融合
        image = cv2.addWeighted(image, 1, colored_mask, alpha, 0)
        return image
        
    except Exception as e:
        print(f"绘制掩码时出错: {e}")
        import traceback
        traceback.print_exc()
        return image

def visualize_coco_annotations(image_dir, annotation_file, output_dir):
    """可视化COCO格式的标注"""
    
    # 创建输出目录
    os.makedirs(output_dir, exist_ok=True)
    
    # 加载COCO数据
    coco = COCO(annotation_file)
    
    # 获取所有图像文件
    image_files = [f for f in os.listdir(image_dir) 
                    if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
    
    # 预定义颜色列表
    colors = [
        (255, 0, 0),    # 红色
        (255, 165, 0),  # 橙色
        (255, 255, 0),  # 黄色
        (0, 255, 0),    # 绿色
        (0, 0, 255),    # 蓝色
        (128, 0, 128),  # 紫色
        (255, 192, 203), # 粉色
    ]
    
    print(f"开始处理 {len(image_files)} 张图像...")
    

    for i, image_file in enumerate(image_files):
        try:
            # 提取image_id
            image_id = extract_image_id_from_filename(image_file)
            
            # 检查该图像是否在COCO数据中
            img_info = coco.loadImgs(image_id)
            if not img_info:
                print(f"警告: 图像 {image_file} (ID: {image_id}) 不在标注文件中")
                continue
            
            # 加载图像
            image_path = os.path.join(image_dir, image_file)
            image = cv2.imread(image_path)
            if image is None:
                print(f"错误: 无法加载图像 {image_file}")
                continue
            
            # 获取该图像的所有标注
            ann_ids = coco.getAnnIds(imgIds=image_id)
            annotations = coco.loadAnns(ann_ids)
            
            # 在图像上绘制标注
            for j, ann in enumerate(annotations):
                color = colors[j % len(colors)]
                
                # 绘制边界框
                if 'bbox' in ann:
                    image = draw_bbox(image, ann['bbox'], color)
                
                # 绘制掩码 - 添加更多检查
                if 'segmentation' in ann and ann['segmentation']:
                    # 检查segmentation是否为空
                    seg = ann['segmentation']
                    if seg:  # 确保不是空列表
                        image = draw_mask(image, seg, color)
            # 在图像上绘制标注
            # for j, ann in enumerate(annotations):
            #     color = colors[j % len(colors)]
                
            #     # 绘制边界框
            #     if 'bbox' in ann:
            #         image = draw_bbox(image, ann['bbox'], color)
                
            #     # 绘制掩码
            #     if 'segmentation' in ann:
            #         image = draw_mask(image, ann['segmentation'], color)
            
            # 保存可视化结果
            output_filename = f"{image_id}_gt.jpg"
            output_path = os.path.join(output_dir, output_filename)
            cv2.imwrite(output_path, image)
            
            if (i + 1) % 100 == 0:
                print(f"已处理 {i + 1}/{len(image_files)} 张图像")
                
        except Exception as e:
            print(f"处理图像 {image_file} 时出错: {e}")
            continue
    
    print(f"处理完成！结果保存在 {output_dir}")

def main():
   parser = argparse.ArgumentParser(description='可视化COCO格式标注')
   parser.add_argument('--image_dir', type=str, default='/data/xxx/datasets/coco/val2017', 
                       help='图像目录路径')
   parser.add_argument('--annotation_file', type=str, default='/data/xxx/datasets/coco/annotations/coco_cls_agnostic_instances_val2017.json', 
                       help='COCO格式标注文件路径')
   parser.add_argument('--output_dir', type=str, default='visulization_COCOval2017-Ground_truth', 
                       help='输出目录路径')
   
   args = parser.parse_args()
   print(args)

   visualize_coco_annotations(args.image_dir, args.annotation_file, args.output_dir)

if __name__ == "__main__":
   main()