#!/usr/bin/env python3
"""
掩码压缩工具
功能：根据压缩比率n，对每个类别只保留1/n的不全为黑的掩码，其余替换为全黑掩码
"""

import argparse
import glob
import json
import os
import shutil
import sys

import numpy as np
from PIL import Image
from tqdm import tqdm

def is_black_mask(image_path, threshold=0.0):
    """
    Decide whether a mask image counts as black.

    The image is converted to 8-bit grayscale ('L') if needed, and every pixel
    with a value above 0 is counted as non-black. The mask counts as black when
    the fraction of non-black pixels is at most ``threshold``; with the default
    0.0 this means every pixel must be 0.

    Args:
        image_path: path of the mask image.
        threshold: maximum fraction (0.0-1.0) of non-black pixels a mask may
            contain and still count as black. Default 0.0 (completely black).

    Returns:
        bool: True if the mask counts as black, False otherwise. Images that
        cannot be opened or decoded are reported on stdout and treated as
        black (True).
    """
    try:
        with Image.open(image_path) as img:
            gray = img if img.mode == 'L' else img.convert('L')
            # Materialise the pixels while the file is still open.
            pixels = np.asarray(gray)
    except Exception as e:  # best effort: any unreadable file is treated as black
        print(f"处理图片 {image_path} 时出错: {e}")
        return True

    # An empty image has no non-black pixels.
    if pixels.size == 0:
        return True

    # 'L' pixels are uint8, so "non-zero" is exactly "> 0".
    non_black_ratio = np.count_nonzero(pixels) / pixels.size
    return non_black_ratio <= threshold

def create_black_mask(width, height):
    """
    Build an all-black single-channel mask.

    Args:
        width: mask width in pixels.
        height: mask height in pixels.

    Returns:
        PIL.Image.Image: a new 8-bit grayscale ('L') image whose pixels are all 0.
    """
    size = (width, height)
    return Image.new(mode='L', size=size, color=0)

def _pick_evenly_spaced(items, count):
    """
    Pick up to ``count`` items from ``items`` using a fixed stride.

    The stride is ``len(items) // count`` (at least 1) and selection starts at
    index 0, so the picks are spread across the list instead of clustering at
    the front.

    Args:
        items: list to pick from (order is preserved).
        count: maximum number of items to return; ``<= 0`` returns ``[]``.

    Returns:
        list: the selected items, in their original order.
    """
    if not items or count <= 0:
        return []
    step = max(1, len(items) // count)
    return list(items[::step][:count])


def compress_masks(input_folder, compression_ratio, output_folder=None, threshold=0.0):
    """
    Keep roughly 1/n of the non-black masks of every category and black out the rest.

    Reads ``info.json`` from ``input_folder`` (keys ``total_categories`` and
    ``image_resolution.width`` / ``image_resolution.height``) and writes it to
    the output folder. Then, for each category id in ``range(total_categories)``,
    it processes the files ``<id>_*.png`` in sorted order:

    * each file is classified with ``is_black_mask`` (unreadable images count
      as black, so they are replaced as well);
    * ``max(1, non_black // compression_ratio)`` non-black masks are kept,
      chosen with a fixed stride over the sorted non-black list;
    * every other file is written as an all-black ``width x height`` mask.

    Args:
        input_folder: folder containing ``info.json`` and the mask PNGs.
        compression_ratio: n, keep about 1/n of the non-black masks; must be >= 1.
        output_folder: destination folder; ``None`` overwrites files in place.
        threshold: passed through to ``is_black_mask`` as its ``threshold``
            argument (default 0.0).

    Returns:
        dict: run statistics (totals plus per-category ``category_details``).

    Raises:
        ValueError: if ``compression_ratio`` is smaller than 1.
        FileNotFoundError: if ``info.json`` is missing.
    """
    # 0 would otherwise surface as a ZeroDivisionError mid-run, and negative
    # values would silently keep one mask per category.
    if compression_ratio < 1:
        raise ValueError(f"压缩比率必须为正整数: {compression_ratio}")

    info_path = os.path.join(input_folder, 'info.json')
    if not os.path.exists(info_path):
        raise FileNotFoundError(f"找不到info.json文件: {info_path}")

    with open(info_path, 'r', encoding='utf-8') as f:
        info = json.load(f)

    total_categories = info['total_categories']
    width = info['image_resolution']['width']
    height = info['image_resolution']['height']

    print(f"总类别数: {total_categories}")
    print(f"图片分辨率: {width}x{height}")
    print(f"压缩比率: 1/{compression_ratio}")

    if output_folder is None:
        output_folder = input_folder
        print("将覆盖原文件")
    else:
        os.makedirs(output_folder, exist_ok=True)
        print(f"输出到文件夹: {output_folder}")

    # Compare resolved paths so that "-o <input>/" or a relative spelling of the
    # input folder is still recognised as an in-place run.
    in_place = os.path.realpath(output_folder) == os.path.realpath(input_folder)

    output_info_path = os.path.join(output_folder, 'info.json')
    with open(output_info_path, 'w', encoding='utf-8') as f:
        json.dump(info, f, indent=2, ensure_ascii=False)

    stats = {
        'total_categories': total_categories,
        'compression_ratio': compression_ratio,
        'categories_processed': 0,
        'total_images_processed': 0,
        'total_images_kept': 0,
        'total_images_replaced': 0,
        'category_details': {}
    }

    # Escape the folder so glob metacharacters in its name ("[", "*", "?") are
    # matched literally; only the file-name part of the pattern is a wildcard.
    glob_root = glob.escape(input_folder)
    # Every replacement is the same image, so build it once.
    black_mask = create_black_mask(width, height)

    for category_id in tqdm(range(total_categories), desc="处理类别"):
        pattern = os.path.join(glob_root, f"{category_id}_*.png")
        category_images = sorted(glob.glob(pattern))

        if not category_images:
            print(f"警告：类别 {category_id} 没有找到图片")
            continue

        non_black_masks = [
            path for path in category_images
            if not is_black_mask(path, threshold=threshold)
        ]

        total_masks = len(category_images)
        non_black_count = len(non_black_masks)

        print(f"\n类别 {category_id}:")
        print(f"  总图片数: {total_masks}")
        print(f"  不全为黑的掩码数: {non_black_count}")

        # Keep at least one mask when the category has any non-black mask;
        # a category without non-black masks keeps nothing.
        if non_black_count:
            masks_to_keep = max(1, non_black_count // compression_ratio)
        else:
            masks_to_keep = 0
        print(f"  保留掩码数: {masks_to_keep}")

        # A set makes the per-file membership test O(1).
        keep_paths = set(_pick_evenly_spaced(non_black_masks, masks_to_keep))

        images_kept = 0
        images_replaced = 0
        for img_path in category_images:
            output_path = os.path.join(output_folder, os.path.basename(img_path))
            if img_path in keep_paths:
                # Copy the original bytes verbatim; in place there is nothing to do.
                if not in_place:
                    shutil.copyfile(img_path, output_path)
                images_kept += 1
            else:
                black_mask.save(output_path)
                images_replaced += 1

        stats['category_details'][category_id] = {
            'total_images': total_masks,
            'non_black_masks': non_black_count,
            'masks_kept': images_kept,
            'masks_replaced': images_replaced
        }

        stats['total_images_processed'] += total_masks
        stats['total_images_kept'] += images_kept
        stats['total_images_replaced'] += images_replaced
        stats['categories_processed'] += 1

        print(f"  保留: {images_kept}, 替换: {images_replaced}")

    return stats


def main():
    """
    Command-line entry point: parse arguments, run ``compress_masks``, print a summary.

    Exits with status 2 on invalid arguments (via argparse). Exits with status 1
    when the input folder or ``info.json`` is missing, or when processing fails,
    so that calling scripts can detect the failure.
    """
    parser = argparse.ArgumentParser(description="掩码压缩工具")
    parser.add_argument("input_folder", help="输入文件夹路径")
    parser.add_argument("compression_ratio", type=int, help="压缩比率n（保留1/n的掩码）")
    parser.add_argument("-o", "--output", help="输出文件夹路径（可选，默认覆盖原文件）")
    parser.add_argument("--threshold", type=float, default=0.0,
                        help="全黑掩码判断阈值（默认0.0，即完全全黑）")

    args = parser.parse_args()

    # "Keep 1/n" only makes sense for n >= 1; 0 would divide by zero later.
    if args.compression_ratio < 1:
        parser.error(f"压缩比率必须为正整数: {args.compression_ratio}")
    # threshold is a fraction of pixels, so only [0, 1] is meaningful.
    if not 0.0 <= args.threshold <= 1.0:
        parser.error(f"阈值必须在0到1之间: {args.threshold}")

    if not os.path.isdir(args.input_folder):
        print(f"错误：输入文件夹不存在: {args.input_folder}", file=sys.stderr)
        sys.exit(1)

    info_path = os.path.join(args.input_folder, 'info.json')
    if not os.path.exists(info_path):
        print(f"错误：找不到info.json文件: {info_path}", file=sys.stderr)
        sys.exit(1)

    try:
        stats = compress_masks(args.input_folder, args.compression_ratio,
                               args.output, threshold=args.threshold)
    except Exception as e:  # top-level boundary: report and fail with status 1
        print(f"处理过程中出错: {e}", file=sys.stderr)
        sys.exit(1)

    print("\n" + "="*50)
    print("处理完成！统计结果：")
    print(f"总类别数: {stats['total_categories']}")
    print(f"已处理类别数: {stats['categories_processed']}")
    print(f"总图片数: {stats['total_images_processed']}")
    print(f"保留图片数: {stats['total_images_kept']}")
    print(f"替换图片数: {stats['total_images_replaced']}")
    print(f"压缩比率: 1/{stats['compression_ratio']}")

    if args.output:
        print(f"结果已保存到: {args.output}")
    else:
        print("原文件已被覆盖")

if __name__ == "__main__":
    # main() returns None on success, so the process exits with status 0 then.
    raise SystemExit(main())
