import math
import argparse
import cv2
import os
from tqdm import tqdm
import decord
import numpy as np
import matplotlib
from concurrent.futures import ThreadPoolExecutor
import glob
import torch
import gc
from multiprocessing import Pool, cpu_count

from dwpose_utils.dwpose_detector import dwpose_detector_aligned

# Presence threshold shared by the hand/face drawing and face-mask helpers: a
# keypoint coordinate must be strictly greater than ``eps`` to be drawn/used.
# The hand and face drawers compare it against pixel coordinates; the face-mask
# helpers compare it against normalised coordinates. Missing keypoints are
# presumably encoded by the detector as non-positive coordinates — TODO confirm.
eps = 0.01


def alpha_blend_color(color, alpha):
    """Scale every channel of ``color`` by ``alpha`` (e.g. a keypoint confidence).

    Each scaled channel is truncated toward zero with ``int``; the result is a
    new list with the same number of channels as ``color``.
    """
    return list(map(lambda channel: int(channel * alpha), color))

def draw_bodypose_aligned(canvas, candidate, subset, score):
    """Draw the 18-keypoint body skeleton of every person in ``subset``.

    Args:
        canvas: H x W x 3 uint8 image. Limbs are filled onto it in place.
        candidate: keypoint table; columns 0 and 1 are x and y as fractions of
            the canvas width and height.
        subset: one row of 18 indices into ``candidate`` per person; -1 marks a
            missing joint.
        score: one row of 18 keypoint confidences per person (same row order
            as ``subset``). May be ``None`` or empty, in which case joints
            present in ``subset`` are treated as fully confident and missing
            ones (index -1) as absent.

    Returns:
        A new uint8 canvas: the input (with limbs) dimmed to 60% intensity,
        with the joints drawn on top as dots. Limb and joint colours are
        attenuated by confidence.
    """
    H, W, _ = canvas.shape
    candidate = np.array(candidate)
    subset = np.array(subset)
    if score is None or len(score) == 0:
        # draw_pose_aligned passes score=None when the pose has no 'score' key;
        # indexing None below would raise, so derive presence-only confidences.
        score = np.where(subset == -1, 0.0, 1.0)

    stickwidth = 4
    limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10],
               [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17],
               [1, 16], [16, 18], [3, 17], [6, 18]]
    colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0],
              [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
              [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]

    # Only the first 17 entries of limbSeq are drawn (the last two are skipped).
    # A limb is drawn only when both endpoint confidences are at least 0.3.
    for i in range(17):
        limb = np.array(limbSeq[i]) - 1  # limbSeq is 1-based
        for n in range(len(subset)):
            index = subset[n][limb]
            conf = score[n][limb]
            if conf[0] < 0.3 or conf[1] < 0.3:
                continue
            Y = candidate[index.astype(int), 0] * float(W)
            X = candidate[index.astype(int), 1] * float(H)
            mX = np.mean(X)
            mY = np.mean(Y)
            length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
            angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
            # Each limb is a filled ellipse centred on the limb midpoint.
            polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
            cv2.fillConvexPoly(canvas, polygon, alpha_blend_color(colors[i], conf[0] * conf[1]))

    canvas = (canvas * 0.6).astype(np.uint8)

    for i in range(18):
        for n in range(len(subset)):
            index = int(subset[n][i])
            if index == -1:
                continue
            x, y = candidate[index][0:2]
            conf = score[n][i]
            x = int(x * W)
            y = int(y * H)
            cv2.circle(canvas, (x, y), 4, alpha_blend_color(colors[i], conf), thickness=-1)

    return canvas

def draw_handpose_aligned(canvas, all_hand_peaks, all_hand_scores):
    """Draw 21-keypoint hand skeletons onto ``canvas`` in place and return it.

    Coordinates in ``all_hand_peaks`` are fractions of the canvas width and
    height. Each of the 20 bones gets its own hue, evenly spaced around the
    HSV wheel, multiplied by ``int(conf_a * conf_b * 255)`` of its two
    endpoints. Each keypoint is drawn as a dot whose third channel is
    ``int(conf * 255)``. Points whose pixel coordinates are not greater than
    ``eps`` are skipped.
    """
    height, width, _ = canvas.shape

    bones = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10],
             [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
    bone_total = float(len(bones))

    for hand_pts, hand_conf in zip(all_hand_peaks, all_hand_scores):
        for bone_no, (start, end) in enumerate(bones):
            sx, sy = hand_pts[start]
            ex, ey = hand_pts[end]
            sx, sy = int(sx * width), int(sy * height)
            ex, ey = int(ex * width), int(ey * height)
            strength = int(hand_conf[start] * hand_conf[end] * 255)
            if sx > eps and sy > eps and ex > eps and ey > eps:
                hue_rgb = matplotlib.colors.hsv_to_rgb([bone_no / bone_total, 1.0, 1.0])
                cv2.line(canvas, (sx, sy), (ex, ey), hue_rgb * strength, thickness=2)

        for k, (kx, ky) in enumerate(hand_pts):
            px, py = int(kx * width), int(ky * height)
            level = int(hand_conf[k] * 255)
            if px > eps and py > eps:
                cv2.circle(canvas, (px, py), 4, (0, 0, level), thickness=-1)

    return canvas

def draw_facepose_aligned(canvas, all_lmks, all_scores):
    """Draw face landmarks as grey dots on ``canvas`` in place and return it.

    Landmark coordinates are fractions of the canvas width and height. Each
    dot's grey level is ``int(score * 255)``. Landmarks whose pixel
    coordinates are not greater than ``eps`` are skipped.
    """
    height, width, _ = canvas.shape
    for face_pts, face_conf in zip(all_lmks, all_scores):
        for (fx, fy), conf in zip(face_pts, face_conf):
            px, py = int(fx * width), int(fy * height)
            grey = int(conf * 255)
            if px > eps and py > eps:
                cv2.circle(canvas, (px, py), 3, (grey,) * 3, thickness=-1)
    return canvas

def draw_pose_aligned(pose, H, W, ref_w=2160, person_idx=None):
    """
    Draw body, hands and face of a pose dict, optionally for a single person.

    Drawing happens on a canvas scaled so that its shorter side equals
    ``ref_w``; the result is then resized to (W, H). The scaling is presumably
    there so the fixed-pixel line and dot sizes look the same at any output
    resolution.

    Args:
        pose: Dict with 'bodies' ('candidate', 'subset', optional 'score'),
            'faces', 'hands' and optional 'faces_score' / 'hands_score'.
        H, W: Height and width of the output image.
        ref_w: Target length of the drawing canvas's shorter side.
        person_idx: If provided, only draw this person. Body rows, body scores,
            faces and face scores are all taken at this index. Hands are
            assumed to be stored two per person in person order (TODO confirm).

    Returns:
        H x W x 3 uint8 image with channels converted by ``COLOR_BGR2RGB``.
    """
    bodies = pose['bodies']
    candidate = bodies['candidate']
    subset = bodies['subset']
    body_score = bodies['score'] if 'score' in bodies else None
    faces = pose['faces']
    faces_score = pose['faces_score'] if 'faces_score' in pose else []
    hands = pose['hands']
    hands_score = pose['hands_score'] if 'hands_score' in pose else []

    if person_idx is not None:
        # Slice every per-person array with the same index so the confidences
        # stay paired with the selected person's keypoints. Out-of-range
        # slices are simply empty.
        keep = slice(person_idx, person_idx + 1)
        subset = subset[keep]
        if body_score is not None:
            body_score = body_score[keep]
        faces = faces[keep]
        faces_score = faces_score[keep]
        hand_keep = slice(2 * person_idx, 2 * person_idx + 2)
        hands = hands[hand_keep]
        hands_score = hands_score[hand_keep]

    sz = min(H, W)
    sr = (ref_w / sz) if sz != ref_w else 1
    canvas = np.zeros(shape=(int(H * sr), int(W * sr), 3), dtype=np.uint8)

    canvas = draw_bodypose_aligned(canvas, candidate, subset, score=body_score)
    canvas = draw_handpose_aligned(canvas, hands, hands_score)
    canvas = draw_facepose_aligned(canvas, faces, faces_score)

    canvas = cv2.resize(canvas, (W, H))
    return cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)


def assign_ids_by_mask(ref_pose, frame_idx, mask_folder, mask_suffix="_%04d.png"):
    """Give each detected person the id of the tracking mask covering most of their joints.

    The mask files for ``frame_idx`` are the entries of ``mask_folder`` whose
    names end with ``mask_suffix % frame_idx``. The object id is the text
    before the first underscore with ``"OBJ"`` removed (``OBJ3_0012.png`` -> 3).
    Files whose prefix is not an integer are ignored. A mask pixel is
    foreground when its grayscale value exceeds 127.

    Args:
        ref_pose: pose dict with ``bodies['subset']``, ``bodies['candidate']``
            (x/y as fractions of the frame) and ``frame_shape`` = (H, W).
        frame_idx: frame number substituted into ``mask_suffix``.
        mask_folder: directory containing the mask images.
        mask_suffix: %-format pattern for the frame part of a mask filename.

    Returns:
        One object id per row of ``subset``. The id is -1 when no mask contains
        any of that person's first 18 keypoints. Ties go to the mask whose
        filename sorts first.
    """
    subset = ref_pose["bodies"]["subset"]
    candidate = ref_pose["bodies"]["candidate"]
    H, W = ref_pose["frame_shape"]
    suffix = mask_suffix % frame_idx

    obj_mask_dict = {}
    # Sorted so the iteration order (and hence tie-breaking below) is deterministic.
    for mf in sorted(os.listdir(mask_folder)):
        if not mf.endswith(suffix):
            continue
        try:
            obj_id = int(mf.split("_")[0].replace("OBJ", ""))
        except ValueError:
            # Not an OBJ<id>_<frame> mask (a stray file sharing the suffix).
            continue
        mask_img = cv2.imread(os.path.join(mask_folder, mf), cv2.IMREAD_GRAYSCALE)
        if mask_img is None:
            continue
        obj_mask_dict[obj_id] = mask_img > 127

    pids = []
    for row in subset:
        kpts_xy = []
        for idx in row[:18]:
            if idx == -1:
                continue
            x = int(round(candidate[int(idx)][0] * W))
            y = int(round(candidate[int(idx)][1] * H))
            kpts_xy.append((x, y))

        best_obj, best_count = -1, 0
        for obj_id, mask_bool in obj_mask_dict.items():
            mask_h, mask_w = mask_bool.shape[:2]
            cur_count = sum(1 for xx, yy in kpts_xy
                            if 0 <= yy < mask_h and 0 <= xx < mask_w and mask_bool[yy, xx])
            if cur_count > best_count:
                best_obj, best_count = obj_id, cur_count
        pids.append(best_obj)

    return pids

def save_image(image, path):
    """Write ``image`` to ``path`` with ``cv2.imwrite`` at JPEG quality 90.

    OpenCV picks the file format from the extension of ``path``. The quality
    flag only affects JPEG output.

    Raises:
        OSError: if ``cv2.imwrite`` reports that the file was not written
            (it signals failure by returning False rather than raising).
    """
    if not cv2.imwrite(path, image, [int(cv2.IMWRITE_JPEG_QUALITY), 90]):
        raise OSError(f"cv2.imwrite failed to write {path}")

def get_frames_from_folder(folder_path):
    """Load all ``*.png`` / ``*.jpg`` / ``*.jpeg`` images in ``folder_path`` with ``cv2.imread``.

    Frames are ordered primarily by the integer after the first underscore of
    the file name (``frame_0012.png`` -> 12). Names without such a number get
    key 0. Ties are broken by a natural (digit-aware) sort of the file name,
    so names like ``00001.png`` or ``1.png, 2.png, 10.png`` still load in
    order rather than in arbitrary directory order. Unreadable files are
    skipped.

    Raises:
        ValueError: if no matching files exist or none of them could be decoded.
    """
    import re

    image_extensions = ("*.png", "*.jpg", "*.jpeg")
    frame_paths = []
    for ext in image_extensions:
        frame_paths.extend(glob.glob(os.path.join(folder_path, ext)))
    if not frame_paths:
        raise ValueError(f"No image files found in {folder_path}")

    def extract_frame_num(path):
        parts = os.path.basename(path).split('_')
        if len(parts) < 2:
            return 0
        try:
            return int(parts[1].split('.')[0])
        except ValueError:
            return 0

    def natural_key(path):
        # re.split with a capturing group alternates text / digit runs, so odd
        # positions are always digit runs and compare numerically.
        tokens = re.split(r'(\d+)', os.path.basename(path))
        return [int(tok) if pos % 2 else tok for pos, tok in enumerate(tokens)]

    frame_paths = sorted(frame_paths, key=lambda p: (extract_frame_num(p), natural_key(p)))

    frames = []
    for path in frame_paths:
        frame = cv2.imread(path)
        if frame is not None:
            frames.append(frame)
    if not frames:
        raise ValueError("Couldn't load any frames. Please check if image files are valid.")
    return frames

def simple_face_detection(image):
    """Return a uint8 mask with 255 inside faces found by OpenCV's frontal-face Haar cascade.

    Returns an all-zero mask when the cascade cannot be loaded. ``image`` may
    be a 3-channel BGR image or an already single-channel grayscale image.
    """
    height, width = image.shape[:2]
    mask = np.zeros((height, width), dtype=np.uint8)

    try:
        face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
    except (AttributeError, cv2.error):
        # cv2.data ships with the opencv-python wheels and may be absent in
        # other builds; a malformed cascade file raises cv2.error.
        return mask
    if face_cascade.empty():
        # A missing cascade file does not raise. It yields an empty classifier,
        # which detectMultiScale would reject with cv2.error.
        return mask

    gray = image if image.ndim == 2 else cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    faces = face_cascade.detectMultiScale(gray, scaleFactor=1.3, minNeighbors=5, minSize=(30, 30))
    for (x, y, w, h) in faces:
        cv2.rectangle(mask, (int(x), int(y)), (int(x + w), int(y + h)), 255, thickness=cv2.FILLED)

    return mask

def generate_face_mask(pose_data, height, width):
    """Return a ``height`` x ``width`` uint8 mask with a filled box around the first face.

    Only ``faces[0]`` / ``faces_score[0]`` are used. A landmark is kept when
    both normalised coordinates exceed ``eps``, its score exceeds 0.3, and it
    lands inside the image. The box encloses the kept landmarks, is padded by
    1% of the shorter image side, and is clamped to the image. The mask is
    all zeros when nothing qualifies.
    """
    mask = np.zeros((height, width), dtype=np.uint8)
    faces = pose_data.get('faces', [])
    faces_score = pose_data.get('faces_score', [])
    if not (len(faces) and len(faces_score)):
        return mask

    landmarks = faces[0]
    if not len(landmarks):
        return mask

    xs = [p[0] for p in landmarks]
    ys = [p[1] for p in landmarks]
    pixels = []
    for fx, fy, conf in zip(xs, ys, faces_score[0]):
        if not (fx > eps and fy > eps and conf > 0.3):
            continue
        col, row = int(fx * width), int(fy * height)
        if 0 <= col < width and 0 <= row < height:
            pixels.append((col, row))
    if not pixels:
        return mask

    pts = np.array(pixels)
    pad = int(min(width, height) * 0.01)
    left = max(0, np.min(pts[:, 0]) - pad)
    top = max(0, np.min(pts[:, 1]) - pad)
    right = min(width - 1, np.max(pts[:, 0]) + pad)
    bottom = min(height - 1, np.max(pts[:, 1]) + pad)
    cv2.rectangle(mask, (left, top), (right, bottom), 255, thickness=cv2.FILLED)
    return mask

def generate_combined_face_masks(pose_data, height, width, pids):
    """Return one uint8 mask with a filled box around every tracked person's face.

    ``faces`` and ``faces_score`` are paired positionally with ``pids``. An
    entry is skipped when its pid is -1, or when its landmarks or its scores
    are all zero. For each remaining face, landmarks with both normalised
    coordinates above ``eps``, a score above 0.3, and an in-image pixel
    position are enclosed in a box. The box is padded by 1% of the shorter
    image side and clamped to the image.
    """
    combined_mask = np.zeros((height, width), dtype=np.uint8)
    pad = int(min(width, height) * 0.01)

    for landmarks, confs, pid in zip(pose_data.get('faces', []), pose_data.get('faces_score', []), pids):
        if pid == -1:
            continue
        if not landmarks.any() or not confs.any():
            continue

        xs = [p[0] for p in landmarks]
        ys = [p[1] for p in landmarks]
        pixels = []
        for fx, fy, conf in zip(xs, ys, confs):
            if not (fx > eps and fy > eps and conf > 0.3):
                continue
            col, row = int(fx * width), int(fy * height)
            if 0 <= col < width and 0 <= row < height:
                pixels.append((col, row))
        if not pixels:
            continue

        pts = np.array(pixels)
        left = max(0, np.min(pts[:, 0]) - pad)
        top = max(0, np.min(pts[:, 1]) - pad)
        right = min(width - 1, np.max(pts[:, 0]) + pad)
        bottom = min(height - 1, np.max(pts[:, 1]) + pad)
        cv2.rectangle(combined_mask, (left, top), (right, bottom), 255, thickness=cv2.FILLED)

    return combined_mask


def process_frame(args):
    """Detect poses in one frame and attach tracking-mask person ids.

    Args:
        args: tuple ``(i, frame, frame_height, frame_width, args_obj, scale_factor)``,
            packed into one value so the function works with ``Pool.imap``.
            ``frame`` is either a BGR ``np.ndarray`` (as loaded by
            ``cv2.imread``) or a decord frame exposing ``asnumpy()`` (decord
            decodes to RGB). ``args_obj`` is the parsed CLI namespace; only
            ``mask_folder`` is read. ``scale_factor`` resizes the frame before
            detection.

    Returns:
        ``(i, ref_pose, pids)``. ``ref_pose`` is the detector output with
        ``frame_shape`` and ``bodies['pids']`` added. ``pids`` holds one
        tracking id (or -1) per detected body.
    """
    i, frame, frame_height, frame_width, args_obj, scale_factor = args

    if isinstance(frame, np.ndarray):
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    else:
        # decord already yields RGB. Converting again would hand the detector
        # BGR and make video input inconsistent with folder input.
        frame_rgb = frame.asnumpy()

    if scale_factor != 1:
        frame_rgb = cv2.resize(frame_rgb, (0, 0), fx=scale_factor, fy=scale_factor)

    ref_pose = dwpose_detector_aligned(frame_rgb)
    # Keypoint coordinates are consumed downstream as fractions of the frame
    # size: drawing and mask lookup multiply them by the full-resolution
    # width/height. Fractions are unchanged by resizing, so they must NOT be
    # divided by scale_factor; doing so pushes keypoints off-frame for any
    # scale_factor != 1.

    ref_pose["frame_shape"] = (frame_height, frame_width)

    if args_obj.mask_folder and os.path.isdir(args_obj.mask_folder):
        pids = assign_ids_by_mask(ref_pose, i, args_obj.mask_folder)
    else:
        pids = [-1] * len(ref_pose["bodies"]["subset"])
    ref_pose["bodies"]["pids"] = pids

    return i, ref_pose, pids


def save_allperson_outputs(i, ref_pose, pids, frame_height, frame_width, base_path):
    """Render and save the all-person skeleton image and combined face mask for frame ``i``.

    Writes ``{base_path}/allperson/frame_{i}.jpg`` and
    ``{base_path}/../faces/allperson/frame_{i}.jpg``, creating the folders if
    needed. Both writes run on a small thread pool and are awaited before
    returning, so any write error propagates to the caller.

    Returns:
        ``ref_pose``, unchanged.
    """
    tracked_folder = f"{base_path}/allperson"
    all_face_mask_folder = f"{base_path}/../faces/allperson"
    os.makedirs(tracked_folder, exist_ok=True)
    os.makedirs(all_face_mask_folder, exist_ok=True)

    tracked_filename = os.path.join(tracked_folder, f"frame_{i:d}.jpg")
    all_face_mask_filename = os.path.join(all_face_mask_folder, f"frame_{i:d}.jpg")

    # The context manager guarantees shutdown (and waits for pending writes)
    # even if a save raises.
    with ThreadPoolExecutor(max_workers=2) as save_executor:
        pose_vis = draw_pose_aligned(ref_pose, frame_height, frame_width)
        future_saves = [save_executor.submit(save_image, pose_vis, tracked_filename)]

        # Built while the pose image is being written.
        combined_face_mask = generate_combined_face_masks(ref_pose, frame_height, frame_width, pids)
        future_saves.append(save_executor.submit(save_image, combined_face_mask, all_face_mask_filename))

        for future in future_saves:
            future.result()

    return ref_pose


def map_hands_to_person(ref_pose, person_idx, frame_width, frame_height, max_dist_ratio=0.1):
    """Select the detected hands that belong to one person by wrist proximity.

    For the person's right wrist (body keypoint 4) and left wrist (keypoint 7),
    the detected hand whose first keypoint (hand root) is closest in pixels is
    chosen, provided the distance is below ``frame_width * max_dist_ratio``.
    A single detection is never returned for both wrists. If it is the nearest
    hand to both, it goes to the closer wrist (the right one on a tie), and the
    other wrist falls back to its next-nearest candidate, if any.

    Args:
        ref_pose: pose dict with ``bodies['candidate']``, ``bodies['subset']``,
            ``hands`` and ``hands_score``; coordinates are fractions of the
            frame size.
        person_idx: row of ``bodies['subset']`` to match hands for.
        frame_width, frame_height: frame size in pixels, used to convert
            coordinates.
        max_dist_ratio: maximum wrist-to-hand distance as a fraction of
            ``frame_width``.

    Returns:
        ``(hands, hand_scores)``: the right hand first (if matched), then the
        left. A matched hand with no corresponding ``hands_score`` entry gets
        ``[]`` as its scores. Both lists are empty when ``person_idx`` is out
        of range.
    """
    candidate = ref_pose['bodies']['candidate']
    subset = ref_pose['bodies']['subset']
    hands = ref_pose['hands']
    hands_score = ref_pose['hands_score']

    if person_idx >= len(subset):
        return [], []

    max_dist = frame_width * max_dist_ratio

    def wrist_position(keypoint):
        """Pixel (x, y) of one of this person's body keypoints, or None if missing."""
        cand_idx = int(subset[person_idx][keypoint])
        if cand_idx == -1:
            return None
        return (candidate[cand_idx][0] * frame_width, candidate[cand_idx][1] * frame_height)

    def nearest_hand(wrist, exclude=-1):
        """(index, distance) of the closest hand root within max_dist, else (-1, inf)."""
        best_idx, best_dist = -1, float('inf')
        if wrist is None:
            return best_idx, best_dist
        for hand_idx, hand in enumerate(hands):
            if hand_idx == exclude or len(hand) == 0:
                continue
            dist = math.hypot(hand[0][0] * frame_width - wrist[0],
                              hand[0][1] * frame_height - wrist[1])
            if dist < best_dist and dist < max_dist:
                best_idx, best_dist = hand_idx, dist
        return best_idx, best_dist

    right_wrist = wrist_position(4)
    left_wrist = wrist_position(7)
    right_idx, right_dist = nearest_hand(right_wrist)
    left_idx, left_dist = nearest_hand(left_wrist)

    if right_idx != -1 and right_idx == left_idx:
        # One detection cannot be both hands: keep it for the closer wrist.
        if right_dist <= left_dist:
            left_idx, _ = nearest_hand(left_wrist, exclude=right_idx)
        else:
            right_idx, _ = nearest_hand(right_wrist, exclude=left_idx)

    person_hand_data = []
    person_hand_score_data = []
    for hand_idx in (right_idx, left_idx):
        if hand_idx == -1:
            continue
        person_hand_data.append(hands[hand_idx])
        person_hand_score_data.append(hands_score[hand_idx] if hand_idx < len(hands_score) else [])

    return person_hand_data, person_hand_score_data


def extract_person_pose_data(ref_pose, person_idx, frame_width, frame_height):
    """Build a single-person pose dict from a multi-person detector result.

    The full ``candidate`` table is kept because ``subset`` stores indices into
    it. Only row ``person_idx`` of ``subset``, ``score``, ``faces`` and
    ``faces_score`` is kept. Hands are chosen by ``map_hands_to_person``. When
    ``person_idx`` is past the last body, an empty dict with
    ``'visible': False`` is returned.
    """
    bodies = ref_pose['bodies']
    if person_idx >= len(bodies['subset']):
        return {
            'visible': False,
            'bodies': {'candidate': [], 'subset': [], 'score': []},
            'faces': [],
            'faces_score': [],
            'hands': [],
            'hands_score': [],
        }

    hand_data, hand_score_data = map_hands_to_person(ref_pose, person_idx, frame_width, frame_height)

    row = slice(person_idx, person_idx + 1)
    faces = ref_pose['faces']
    faces_score = ref_pose['faces_score']

    body_part = {
        'candidate': bodies['candidate'],
        'subset': bodies['subset'][row],
        'score': bodies['score'][row] if 'score' in bodies else [],
    }
    return {
        'visible': True,
        'bodies': body_part,
        'faces': faces[row] if person_idx < len(faces) else [],
        'faces_score': faces_score[row] if person_idx < len(faces_score) else [],
        'hands': hand_data,
        'hands_score': hand_score_data,
    }


def save_individual_person_outputs(i, ref_pose, pids, frame_height, frame_width, base_path,
                                  person_folders, person_face_mask_folders, global_person_ids):
    """Save one skeleton image and one face mask per tracked person for a single frame.

    Every id in ``global_person_ids`` gets a file for this frame. An id absent
    from ``pids`` gets an all-black image and mask, so each person's sequence
    stays aligned with the frame sequence. Files are named by a per-person
    counter (``person_folders[pid]["frame_count"]``), not by ``i``.

    Args:
        i: frame index (currently unused; kept for interface compatibility).
        ref_pose: multi-person pose dict for this frame.
        pids: tracking id per detected body (-1 for untracked).
        frame_height, frame_width: output image size.
        base_path: root output folder. Images go to ``person_{pid}``, masks to
            ``../faces/person_{pid}``.
        person_folders: mutated in place; pid -> {"folder", "frame_count"}.
        person_face_mask_folders: mutated in place; pid -> face-mask folder.
        global_person_ids: all ids to emit outputs for.

    Returns:
        dict mapping each id in ``global_person_ids`` to its single-person
        pose dict for this frame.
    """
    frame_person_data = {}

    # The context manager guarantees the pool is shut down (after pending
    # writes finish) even if a save raises.
    with ThreadPoolExecutor(max_workers=4) as save_executor:
        future_saves = []

        for person_id in global_person_ids:
            if person_id in pids:
                # If several bodies share an id, the first one is used.
                idx = pids.index(person_id)
                person_pose_data = extract_person_pose_data(ref_pose, idx, frame_width, frame_height)
                person_pose = draw_pose_aligned(person_pose_data, frame_height, frame_width)
                face_mask = generate_face_mask(person_pose_data, frame_height, frame_width)
            else:
                person_pose_data = {
                    'visible': False,
                    'bodies': {
                        'candidate': [],
                        'subset': [],
                        'score': []
                    },
                    'faces': [],
                    'faces_score': [],
                    'hands': [],
                    'hands_score': []
                }
                person_pose = np.zeros((frame_height, frame_width, 3), dtype=np.uint8)
                face_mask = np.zeros((frame_height, frame_width), dtype=np.uint8)

            frame_person_data[person_id] = person_pose_data

            if person_id not in person_folders:
                person_folder = f"{base_path}/person_{person_id}"
                face_mask_folder = f"{base_path}/../faces/person_{person_id}"
                os.makedirs(person_folder, exist_ok=True)
                os.makedirs(face_mask_folder, exist_ok=True)
                person_folders[person_id] = {"folder": person_folder, "frame_count": 0}
                person_face_mask_folders[person_id] = face_mask_folder

            frame_index = person_folders[person_id]["frame_count"]
            person_filename = os.path.join(person_folders[person_id]["folder"], f"frame_{frame_index:d}.jpg")
            future_saves.append(save_executor.submit(save_image, person_pose, person_filename))

            face_mask_filename = os.path.join(person_face_mask_folders[person_id], f"frame_{frame_index:d}.jpg")
            future_saves.append(save_executor.submit(save_image, face_mask, face_mask_filename))

            person_folders[person_id]["frame_count"] += 1

        for future in future_saves:
            future.result()

    return frame_person_data

def main():
    """Command-line entry point.

    Extracts DWPose keypoints from a video file or a folder of frames. Saves
    all-person and per-person skeleton images and face masks under
    ``--output_video_path`` and its sibling ``../faces`` folder. Writes the
    pose data to ``--output_pose_path`` with ``np.savez_compressed``. The
    saved arrays hold Python dicts, so reading them back requires
    ``np.load(..., allow_pickle=True)``.
    """
    parser = argparse.ArgumentParser("DWPose Skeleton Poses Extraction with tracking mask")
    parser.add_argument("--video_path", type=str, required=True,
                        help="Path to the input video file or folder containing frames")
    parser.add_argument("--output_video_path", type=str, required=True,
                        help="Path to save the output visualization images (base name)")
    parser.add_argument("--output_pose_path", type=str, required=True,
                        help="Path to save the extracted pose data as .npz")
    parser.add_argument("--mask_folder", type=str, default=None,
                        help="Folder for tracking masks (OBJxx_####.png)")
    parser.add_argument("--scale_factor", type=float, default=1,
                        help="Scale factor for input images to speed up processing (default: 1)")
    parser.add_argument("--process_every_n", type=int, default=1,
                        help="Process every Nth frame (default: 1, process all frames)")
    parser.add_argument("--num_workers", type=int, default=1,
                        help="Number of parallel workers for processing (default: 1)")
    parser.add_argument("--use_jpg", action="store_true",
                        help="Currently ignored: output images are always written as JPEG")

    args = parser.parse_args()
    if args.process_every_n < 1:
        parser.error("--process_every_n must be >= 1")
    if args.scale_factor <= 0:
        parser.error("--scale_factor must be > 0")

    step = args.process_every_n
    if os.path.isdir(args.video_path):
        frames = get_frames_from_folder(args.video_path)[::step]
        frame_height, frame_width = frames[0].shape[:2]
        print(f"Loaded {len(frames)} frames from folder: {args.video_path}")
    else:
        video_reader = decord.VideoReader(args.video_path)
        indices = list(range(0, len(video_reader), step))
        if not indices:
            parser.error(f"No frames could be read from video: {args.video_path}")
        frames = video_reader.get_batch(indices)
        frame_height, frame_width = video_reader[0].shape[:2]
        print(f"Processing {len(indices)} frames from video: {args.video_path}")

    base_path = args.output_video_path
    tracked_folder = f"{base_path}/allperson"
    all_face_mask_folder = f"{base_path}/../faces/allperson"
    os.makedirs(tracked_folder, exist_ok=True)
    os.makedirs(all_face_mask_folder, exist_ok=True)

    print("Extracting poses...")
    num_workers = min(args.num_workers, cpu_count(), 8)
    # The first element is the original (pre-subsampling) frame index, which
    # is what the mask filenames are numbered by.
    process_args = [
        (i * step, frame, frame_height, frame_width, args, args.scale_factor)
        for i, frame in enumerate(frames)
    ]

    if num_workers > 1:
        print(f"Using {num_workers} processes for parallel processing...")
        with Pool(processes=num_workers) as pool:
            # imap yields results in submission order, so no re-sorting is needed.
            frame_results = list(tqdm(
                pool.imap(process_frame, process_args),
                total=len(process_args),
                desc="Processing frames",
            ))
    else:
        print("Using single process mode...")
        frame_results = [process_frame(a) for a in tqdm(process_args, desc="Extracting poses")]

    print("Saving allperson outputs...")
    for i, ref_pose, pids in tqdm(frame_results, desc="Saving allperson outputs"):
        save_allperson_outputs(i, ref_pose, pids, frame_height, frame_width, base_path)

    global_person_ids = sorted({pid for _, _, pids in frame_results for pid in pids if pid != -1})
    print(f"Found {len(global_person_ids)} unique persons in the sequence: {global_person_ids}")

    person_folders = {}
    person_face_mask_folders = {}
    pose_by_person = {pid: [] for pid in global_person_ids}
    frame_metadata = []

    print("Saving individual person outputs...")
    for i, ref_pose, pids in tqdm(frame_results, desc="Saving individual person outputs"):
        frame_metadata.append({
            'frame_idx': i,
            'frame_shape': ref_pose["frame_shape"],
            'detected_pids': pids,
        })

        frame_person_data = save_individual_person_outputs(
            i, ref_pose, pids, frame_height, frame_width,
            base_path, person_folders, person_face_mask_folders, global_person_ids,
        )

        for pid in global_person_ids:
            pose_by_person[pid].append(frame_person_data[pid])

    print("Saving person-organized pose data...")
    np.savez_compressed(
        args.output_pose_path,
        pose_by_person=pose_by_person,
        global_person_ids=global_person_ids,
        frame_metadata=frame_metadata,
        frames=[ref_pose for _, ref_pose, _ in frame_results],
    )

    print("Saved outputs:")
    print(f"- All person pose images in folder: {tracked_folder}")
    print(f"- Individual person images in folders: {base_path}/person_{{pid}}")
    print(f"- Face masks in folders: {base_path}/../faces/person_{{pid}}")
    print(f"- Combined face masks in folder: {all_face_mask_folder}")
    print(f"- Person-organized pose data: {args.output_pose_path}")


if __name__ == '__main__':
    main()


