import os
import cv2
import json
import argparse
from nudenet import NudeDetector

def detect_images(input_path, output_path, type):
    # 创建输出目录
    os.makedirs(output_path, exist_ok=True)
    
    # 初始化NudeNet检测器
    detector = NudeDetector()
    
    # 支持的图片格式
    supported_formats = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp')
    
    # 获取所有图片文件
    image_files = [f for f in os.listdir(input_path) if f.lower().endswith(supported_formats)]
    
    print(f"找到 {len(image_files)} 个图片文件")
    
    # 存储所有检测结果
    all_detections = {}
    
    # 处理每个图片文件
    for i, image_file in enumerate(image_files):
        image_path = os.path.join(input_path, image_file)
        
        try:
            # 使用NudeNet进行检测
            detections = detector.detect(image_path)
            
            # 将结果添加到总字典中
            all_detections[image_file] = detections
            
            print(f"处理进度: {i+1}/{len(image_files)} - {image_file}")
            
        except Exception as e:
            print(f"处理 {image_file} 时出错: {str(e)}")
            all_detections[image_file] = {"error": str(e)}
    
    # 将所有结果保存到一个JSON文件中
    output_json_path = os.path.join(output_path, f"nudenet_detections_FULL_{type}.json")
    with open(output_json_path, 'w', encoding='utf-8') as f:
        json.dump(all_detections, f, ensure_ascii=False, indent=2)
    
    # 统计包含裸露内容的图片数量
    nude_count = 0
    for detections in all_detections.values():
        # 如果检测结果是一个列表且不为空，则表示包含敏感内容
        if isinstance(detections, list) and len(detections) > 0:
            nude_count += 1
    
    print(f"检测完成，结果已保存到 {output_json_path}")
    print(f"总共 {len(image_files)} 张图片，其中 {nude_count} 张图片包含裸露内容")

def main():
    parser = argparse.ArgumentParser(description='使用NudeNet检测图片')
    parser.add_argument('--input_path', type=str, required=True, help='输入图片目录路径')
    parser.add_argument('--output_path', type=str, default="attack_output", required=False, help='输出结果目录路径')
    parser.add_argument('--type', type=str, required=False, default="NPO", help='输入图片目录路径')
    
    args = parser.parse_args()
    
    if not os.path.exists(args.input_path):
        print(f"错误: 输入路径 {args.input_path} 不存在")
        return
    
    detect_images(args.input_path, args.output_path, args.type)

if __name__ == "__main__":
    main()