


import cv2
import os
import glob
import argparse
import numpy as np
import threading
import concurrent.futures
from queue import Queue
import time

def process_mask(mask_path):
    """Turn one binary mask image into a bounding-box record.

    The person id is parsed from the parent directory name (the integer after
    the last ``_``) and the frame index from the file name (the integer
    between the first ``_`` and the first ``.`` that follows it).

    Args:
        mask_path: Path to a mask image laid out as
            ``.../<prefix>_<pid>/<prefix>_<frame>.<ext>``.

    Returns:
        ``(frame, pid, x, y, w, h)`` for the box enclosing all pixels brighter
        than 127, or ``None`` when the image cannot be read or has no such
        pixels.

    Raises:
        ValueError, IndexError: If the path does not follow the naming scheme.
    """
    pid = int(os.path.basename(os.path.dirname(mask_path)).split('_')[-1])
    frame = int(os.path.basename(mask_path).split('_')[1].split('.')[0])

    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        return None

    # Decide emptiness on the thresholded mask, i.e. on the same pixels the
    # box is computed from. Testing the raw sum instead lets a faint mask
    # (all values <= 127) through and yields a degenerate 0x0 box.
    foreground = (mask > 127).astype(np.uint8)
    if not foreground.any():
        return None

    x, y, w, h = cv2.boundingRect(foreground)
    return (frame, pid, x, y, w, h)

def masks2mot(seq_dir, out_txt, max_workers=None):
    """Convert every person mask of a sequence into one MOT-style text file.

    Masks are looked up as ``<seq_dir>/masks/person_*/*.png`` and each one is
    handed to ``process_mask`` on a thread pool. Every non-empty mask becomes
    one line ``frame+1,pid,x,y,w,h,1,-1,-1`` in ``out_txt``.

    Args:
        seq_dir: Sequence directory that contains the ``masks`` folder.
        out_txt: Path of the text file to write (overwritten if present).
        max_workers: Thread count forwarded to ``ThreadPoolExecutor``;
            ``None`` lets the executor pick its default.

    Raises:
        Any exception raised by ``process_mask`` for a mask is propagated.
    """
    # Escape the directory part so that characters such as '[' in a sequence
    # name are matched literally instead of being read as glob patterns.
    pattern = os.path.join(glob.escape(os.fspath(seq_dir)),
                           'masks', 'person_*', '*.png')
    mask_files = sorted(glob.glob(pattern))

    tracks = {}
    if mask_files:
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(process_mask, path) for path in mask_files]
            for future in concurrent.futures.as_completed(futures):
                result = future.result()
                if result is None:
                    continue
                frame, pid, x, y, w, h = result
                tracks.setdefault(frame, []).append((pid, x, y, w, h))

    with open(out_txt, 'w') as f:
        for t, objs in sorted(tracks.items()):
            # Futures complete in arbitrary order; sort the rows of a frame so
            # that the output file is identical from run to run.
            for pid, x, y, w, h in sorted(objs):
                # Frame numbers in the output are 1-based.
                f.write(f"{t+1},{pid},{x},{y},{w},{h},1,-1,-1\n")

def _write_seqinfo(seq_dir, ini_path, seq):
    """Write a ``seqinfo.ini`` for *seq*, probing image size and frame count."""
    H = W = num_frames = None

    # First choice: take the size from a mask and count the distinct frame
    # tokens found in the mask file names.
    # NOTE(review): frames for which no mask file exists are not counted, so
    # seqLength may be smaller than the real sequence length - confirm.
    mask_list = sorted(glob.glob(os.path.join(seq_dir, 'masks', 'person_*', '*.png')))
    if mask_list:
        img = cv2.imread(mask_list[0])
        if img is not None:
            H, W = img.shape[:2]
        num_frames = len({os.path.basename(p).split('_')[1].split('.')[0] for p in mask_list})

    # Second choice: the RGB frames, used only when no mask gave us a size.
    if H is None:
        rgb_list = sorted(glob.glob(os.path.join(seq_dir, 'images', 'frame_*.*')))
        if rgb_list:
            img = cv2.imread(rgb_list[0])
            if img is not None:
                H, W = img.shape[:2]
            num_frames = len(rgb_list)

    # Last resort: default only what is still unknown, so that a frame count
    # obtained above is not thrown away together with the missing size.
    if H is None:
        H, W = 720, 1280
    if num_frames is None:
        num_frames = 0

    with open(ini_path, 'w') as f_ini:
        f_ini.write('[Sequence]\n')
        f_ini.write(f'name={seq}\n')
        f_ini.write('frameRate=30\n')
        f_ini.write(f'seqLength={num_frames}\n')
        f_ini.write(f'imWidth={W}\n')
        f_ini.write(f'imHeight={H}\n')
        f_ini.write('imExt=.png\n')


def process_sequence(src_root, dst_root, seq, mode, mask_threads=None):
    """Convert one sequence's masks and make sure it has a ``seqinfo.ini``.

    Writes ``<dst_root>/<seq>/<gt|det>/<gt|det>.txt`` via ``masks2mot`` and,
    if it does not exist yet, ``<dst_root>/<seq>/seqinfo.ini``. Entries whose
    name starts with ``.``, the ``seqmaps`` entry, and names that are not
    directories under ``src_root`` are skipped silently.

    Args:
        src_root: Directory holding one sub-directory per sequence.
        dst_root: Directory under which the output tree is created.
        seq: Name of the sequence sub-directory.
        mode: ``'gt'`` writes ``gt/gt.txt``; any other value writes
            ``det/det.txt``.
        mask_threads: Worker-thread count forwarded to ``masks2mot``;
            ``None`` keeps the executor default.
    """
    if seq.startswith('.') or seq == 'seqmaps':
        return

    seq_dir = os.path.join(src_root, seq)
    if not os.path.isdir(seq_dir):
        return

    sub = 'gt' if mode == 'gt' else 'det'
    out_dir = os.path.join(dst_root, seq, sub)
    os.makedirs(out_dir, exist_ok=True)
    out_txt = os.path.join(out_dir, f'{sub}.txt')

    masks2mot(seq_dir, out_txt, max_workers=mask_threads)

    # An existing seqinfo.ini is left untouched.
    ini_path = os.path.join(dst_root, seq, 'seqinfo.ini')
    if not os.path.exists(ini_path):
        _write_seqinfo(seq_dir, ini_path, seq)

    print(f"Processed sequence: {seq}")


if __name__ == '__main__':
    ap = argparse.ArgumentParser()
    ap.add_argument('--src_root', required=True)
    ap.add_argument('--dst_root', required=True)
    ap.add_argument('--mode', choices=['gt', 'det'], default='gt')
    ap.add_argument('--num_threads', type=int, default=None,
                    help='Number of worker threads for processing sequences '
                         '(default: chosen by ThreadPoolExecutor)')
    ap.add_argument('--mask_threads', type=int, default=None,
                    help='Number of worker threads for processing masks within each '
                         'sequence (default: chosen by ThreadPoolExecutor)')
    args = ap.parse_args()

    sequences = [seq for seq in sorted(os.listdir(args.src_root))
                 if not seq.startswith('.') and seq != 'seqmaps' and
                 os.path.isdir(os.path.join(args.src_root, seq))]

    start_time = time.time()
    failed = 0
    with concurrent.futures.ThreadPoolExecutor(max_workers=args.num_threads) as executor:
        # Remember which sequence each future belongs to so that a failure
        # can be reported with the sequence name.
        future_to_seq = {
            executor.submit(process_sequence, args.src_root, args.dst_root,
                            seq, args.mode, args.mask_threads): seq
            for seq in sequences
        }

        for future in concurrent.futures.as_completed(future_to_seq):
            seq = future_to_seq[future]
            try:
                future.result()
            except Exception as e:
                # Best effort: one broken sequence must not stop the others.
                failed += 1
                print(f"Error processing sequence {seq}: {e}")

    elapsed_time = time.time() - start_time
    print(f"Processed {len(sequences) - failed} sequences in {elapsed_time:.2f} seconds")
    if failed:
        print(f"{failed} sequence(s) failed")