#!/usr/bin/env python3
import os
import cv2
import argparse
import numpy as np
from PIL import Image, ImageDraw
from ultralytics import YOLO


# IoU threshold read by all_pairs_low_iou: a pair of boxes whose IoU is at or
# above this value makes that check fail.
iou_t = 0.1

def parse_args():
    """Parse the script's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with ``input_path`` (str, required),
        ``output_dir`` (str, required), ``conf_threshold`` (float,
        default 0.6) and ``frame_index`` (int or None, default None).
    """
    summary = (
        '保存视频所有帧到images目录，并在前3秒内查找最佳帧（满足人数最多且所有bbox IOU均低于阈值），'
        '或直接指定处理的帧号。'
    )
    cli = argparse.ArgumentParser(description=summary)

    # One (flag, add_argument keyword options) entry per supported option.
    option_specs = (
        ('--input_path', dict(type=str, required=True,
                              help='输入mp4视频文件路径')),
        ('--output_dir', dict(type=str, required=True,
                              help='保存输出文件的文件夹')),
        ('--conf_threshold', dict(type=float, default=0.6,
                                  help='检测bbox置信度阈值')),
        ('--frame_index', dict(type=int, default=None,
                               help='指定要处理的帧号。如果指定，则直接提取该帧的人体边界框，但仍会保存所有帧。')),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    return cli.parse_args()

def compute_iou(box1, box2):
    """Return the intersection-over-union of two ``(x1, y1, x2, y2)`` boxes.

    Args:
        box1: Four values unpacked as left, top, right, bottom.
        box2: Four values unpacked as left, top, right, bottom.

    Returns:
        ``0.0`` when the boxes do not overlap with positive area (touching
        edges count as no overlap) or when the union area is zero;
        otherwise intersection area divided by union area.
    """
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2

    # Corners of the overlap rectangle (if any).
    left, top = max(ax1, bx1), max(ay1, by1)
    right, bottom = min(ax2, bx2), min(ay2, by2)

    if right <= left or bottom <= top:
        return 0.0

    overlap = (right - left) * (bottom - top)
    union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - overlap
    if not union:
        return 0.0
    return overlap / union

def all_pairs_low_iou(boxes):
    """Report whether every distinct pair of boxes has IoU below ``iou_t``.

    Args:
        boxes: Indexable sequence of boxes accepted by ``compute_iou``.

    Returns:
        ``False`` as soon as some pair reaches or exceeds the module-level
        ``iou_t`` threshold, ``True`` otherwise (including when there are
        fewer than two boxes).
    """
    count = len(boxes)
    return not any(
        compute_iou(boxes[first], boxes[second]) >= iou_t
        for first in range(count)
        for second in range(first + 1, count)
    )

def _detect_person_boxes(model, frame, conf_threshold):
    """Run the detector on one frame and return the boxes that pass the filter.

    Keeps detections whose class id is 0 (the class this script reports as a
    person) and whose confidence is at least ``conf_threshold``. Each box is
    returned as the numpy form of ``box.xyxy[0]``.
    """
    results = model(frame, verbose=False)[0]
    return [
        box.xyxy[0].cpu().numpy()
        for box in results.boxes
        if box.cls == 0 and box.conf[0] >= conf_threshold
    ]


def save_all_frames_and_find_best_frame(video_path, output_dir, model, conf_threshold, specified_frame=None):
    """Save every frame of a video as JPEG and choose one frame with detections.

    All frames are written to ``<output_dir>/images/frame_<index>.jpg``
    regardless of mode.

    * If ``specified_frame`` is given, detection runs only on that frame.
    * Otherwise detection runs on the frames of the first three seconds
      (``int(fps * 3)`` frames). The chosen frame is the earliest one with
      the most boxes among frames that have at least two boxes and pass
      ``all_pairs_low_iou``. If no frame qualifies, the first frame with any
      detection is used.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.
        output_dir: Directory under which the ``images`` folder is created.
        model: Callable detector invoked as ``model(frame, verbose=False)``.
        conf_threshold: Minimum confidence for a box to be kept.
        specified_frame: Optional frame index to process directly.

    Returns:
        ``(frame_index, boxes)`` for the chosen frame, or ``(None, None)``
        when the video cannot be opened, the specified frame was never
        reached, or nothing was detected.
    """
    images_dir = os.path.join(output_dir, "images")
    os.makedirs(images_dir, exist_ok=True)

    best_index = None
    best_boxes = None
    any_detection_frame = None
    any_detection_boxes = None
    specified_frame_boxes = None

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            # Report the real cause instead of silently reading zero frames.
            print(f"无法打开视频文件: {video_path}")
            return None, None

        fps = cap.get(cv2.CAP_PROP_FPS)
        # "not fps > 0" also catches NaN, which would otherwise crash int().
        if not fps > 0:
            fps = 30
        three_seconds_frames = int(fps * 3)

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f"开始保存所有帧，总帧数：{total_frames}")

        frame_index = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            img_path = os.path.join(images_dir, f"frame_{frame_index}.jpg")
            if not cv2.imwrite(img_path, frame, [cv2.IMWRITE_JPEG_QUALITY, 80]):
                print(f"警告：无法写入帧图像: {img_path}")

            if specified_frame is not None:
                if frame_index == specified_frame:
                    specified_frame_boxes = _detect_person_boxes(model, frame, conf_threshold)
            elif frame_index < three_seconds_frames:
                valid_boxes = _detect_person_boxes(model, frame, conf_threshold)

                if valid_boxes and any_detection_frame is None:
                    any_detection_frame = frame_index
                    any_detection_boxes = valid_boxes

                # Strict ">" keeps the earliest frame among equally good ones.
                if (
                    len(valid_boxes) >= 2
                    and (best_boxes is None or len(valid_boxes) > len(best_boxes))
                    and all_pairs_low_iou(valid_boxes)
                ):
                    best_index = frame_index
                    best_boxes = valid_boxes

            frame_index += 1
            if frame_index % 100 == 0:
                print(f"已处理 {frame_index}/{total_frames} 帧")
    finally:
        # Release the capture even if detection or writing raises.
        cap.release()

    print(f"已保存所有帧到：{images_dir}")

    if specified_frame is not None:
        if specified_frame_boxes is not None:
            print(f"已处理指定的帧 {specified_frame}，检测到 {len(specified_frame_boxes)} 个人")
            return specified_frame, specified_frame_boxes
        print(f"未能处理指定的帧 {specified_frame}，可能超出视频范围")
        return None, None

    if best_boxes is not None:
        print(f"找到最佳帧 {best_index}，包含 {len(best_boxes)} 人，且所有人的IOU都低于阈值")
        return best_index, best_boxes

    if any_detection_frame is not None:
        print(f"未找到所有bbox IOU都低的帧，使用第一个有检测结果的帧: {any_detection_frame}，包含 {len(any_detection_boxes)} 人")
        return any_detection_frame, any_detection_boxes

    return None, None


def save_visualization_and_boxes(frame_index, boxes, output_dir):
    """Draw the boxes on a previously saved frame and dump their coordinates.

    Reads ``<output_dir>/images/frame_<frame_index>.jpg``; if it cannot be
    read, prints a message and returns without writing anything. Otherwise
    writes ``frame_<frame_index>_detected.jpg`` (boxes outlined in red) and
    ``frame_<frame_index>_boxes.txt`` (one ``x1,y1,x2,y2`` line per box,
    ordered by horizontal center from left to right) into ``output_dir``.

    Args:
        frame_index: Index used to build the input and output file names.
        boxes: Iterable of boxes exposing ``tolist()`` that yields four values.
        output_dir: Directory holding ``images/`` and receiving the outputs.
    """
    source_path = os.path.join(output_dir, "images", f"frame_{frame_index}.jpg")
    bgr_frame = cv2.imread(source_path)
    if bgr_frame is None:
        print(f"无法读取已保存的图像: {source_path}")
        return

    canvas = Image.fromarray(cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB))
    pen = ImageDraw.Draw(canvas)

    def horizontal_center(box):
        left, _top, right, _bottom = box.tolist()
        return (left + right) / 2

    # Stable sort, so boxes with equal centers keep their input order.
    left_to_right = sorted(boxes, key=horizontal_center)

    for box in boxes:
        pen.rectangle(box.tolist(), outline="red", width=3)

    vis_path = os.path.join(output_dir, f"frame_{frame_index}_detected.jpg")
    canvas.save(vis_path, quality=80)
    print(f"已保存检测可视化结果到：{vis_path}")

    txt_path = os.path.join(output_dir, f"frame_{frame_index}_boxes.txt")
    with open(txt_path, 'w') as out:
        out.writelines(
            ",".join(str(value) for value in box.tolist()) + "\n"
            for box in left_to_right
        )
    print(f"已保存bbox坐标到：{txt_path}（已按照中心点从左到右排序）")


def main():
    """Entry point: save all frames, pick a frame, and write its outputs.

    Parses the CLI options, creates the output directory, loads the
    ``yolov8x.pt`` model and runs the frame search. When a frame is chosen
    its visualization and box file are written. When none is chosen, a
    single placeholder box is written for frame 0 instead.
    """
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    detector = YOLO('yolov8x.pt')

    chosen_index, chosen_boxes = save_all_frames_and_find_best_frame(
        video_path=args.input_path,
        output_dir=args.output_dir,
        model=detector,
        conf_threshold=args.conf_threshold,
        specified_frame=args.frame_index,
    )

    if chosen_index is None:
        print("未在视频中检测到任何人。")
        # Placeholder result for frame 0, written in place of a real detection.
        placeholder_boxes = np.array([[100, 100, 300, 400]])
        save_visualization_and_boxes(0, placeholder_boxes, args.output_dir)
        print("为第 0 帧创建了虚拟检测结果，以防下游处理出错。")
        return

    save_visualization_and_boxes(chosen_index, chosen_boxes, args.output_dir)
    print(f"成功为第 {chosen_index} 帧保存检测结果")

# Run the command-line workflow only when this file is executed as a script.
if __name__ == '__main__':
    main()