import os
import cv2
from multiprocessing import Pool, cpu_count
from functools import partial
from tqdm import tqdm

def validate_video(
    full_path: str,
    frame_threshold: int,
    delete_bad: bool
):
    """
    Validate a single video file.

    A video is considered good when:
      - OpenCV can open it, and
      - its reported frame count (cv2.CAP_PROP_FRAME_COUNT) is
        >= frame_threshold.

    Args:
        full_path: Path of the video file to check.
        frame_threshold: Minimum number of frames a video must report.
        delete_bad: When True, a file that fails validation is removed
            from disk. Removal is best effort: a failure to delete is
            ignored and the file is still reported as bad.

    Returns:
        (is_good: bool, path: str, reason: str). ``reason`` is "" for a
        good video, otherwise "cannot open" or "only <N> frames".
    """
    cap = cv2.VideoCapture(full_path)
    try:
        if not cap.isOpened():
            reason = "cannot open"
        else:
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if frame_count < frame_threshold:
                reason = f"only {frame_count} frames"
            else:
                reason = ""
    finally:
        # Always release the handle, even if reading the frame count raised,
        # and before any deletion attempt below.
        cap.release()

    if not reason:
        return True, full_path, ""

    if delete_bad:
        try:
            os.remove(full_path)
        except OSError:
            # Best-effort cleanup: the file may already be gone or be
            # read-only. Only filesystem errors are ignored, so that
            # KeyboardInterrupt/SystemExit still propagate.
            pass
    return False, full_path, reason

def _collect_mp4_paths(root_dir: str):
    """Return paths of all files under root_dir ending in .mp4 (case-insensitive)."""
    found = []
    for dir_path, _dir_names, file_names in os.walk(root_dir):
        found.extend(
            os.path.join(dir_path, name)
            for name in file_names
            if name.lower().endswith(".mp4")
        )
    return found


def collect_and_validate_mp4_parallel(
    root_dir: str,
    output_filename: str = "valid_mp4_list.txt",
    frame_threshold: int = 100,
    delete_bad: bool = False,
    num_workers: int = None
):
    """
    Collect every .mp4 file under root_dir, validate them in parallel and
    write the absolute paths of the valid ones to a text file.

    Steps:
      1) Collect all .mp4 files (recursively, extension match is
         case-insensitive);
      2) Validate each video with validate_video in a multiprocessing Pool,
         showing overall progress with tqdm;
      3) Write the valid paths, one per line, sorted, UTF-8 encoded, to
         root_dir/output_filename;
      4) Print a short summary.

    Args:
        root_dir: Directory tree to scan; the output file is written here.
        output_filename: Name of the list file created inside root_dir.
        frame_threshold: Minimum frame count passed to validate_video.
        delete_bad: Passed to validate_video; when True, files that fail
            validation are deleted by the workers.
        num_workers: Number of worker processes. None means
            max(1, cpu_count() - 1).

    Returns:
        None. Results are written to the output file and printed.
    """
    # 1) Collect paths
    all_mp4 = _collect_mp4_paths(root_dir)
    total = len(all_mp4)
    print(f"🔍 Found {total} .mp4 files, starting parallel validation...")

    good_paths = []
    bad_paths = []

    # 2) Parallel validation (skipped entirely when there is nothing to
    #    check, so no worker processes are spawned for an empty tree)
    if total:
        if num_workers is None:
            num_workers = max(1, cpu_count() - 1)

        # partial to fix frame_threshold and delete_bad
        worker_func = partial(validate_video,
                              frame_threshold=frame_threshold,
                              delete_bad=delete_bad)

        with Pool(processes=num_workers) as pool:
            results = pool.imap_unordered(worker_func, all_mp4)
            for is_good, path, reason in tqdm(results,
                                              total=total,
                                              desc="Validating",
                                              unit="file"):
                if is_good:
                    good_paths.append(os.path.abspath(path))
                else:
                    bad_paths.append((path, reason))

    # imap_unordered yields in completion order, which differs between
    # runs; sort so the output file is reproducible.
    good_paths.sort()

    # 3) Write results. The encoding is explicit so that non-ASCII paths do
    #    not fail on platforms whose default encoding cannot represent them.
    out_path = os.path.join(root_dir, output_filename)
    with open(out_path, "w", encoding="utf-8") as f:
        f.writelines(p + "\n" for p in good_paths)

    # 4) Print summary
    print(f"\n✅ Valid videos: {len(good_paths)} files, written to {out_path}")
    if bad_paths:
        # bad_paths keeps (path, reason) pairs for debugging; only the
        # count is printed to keep the summary short.
        print(f"⚠️ Skipped or deleted files: {len(bad_paths)} files")

if __name__ == "__main__":
    # Directory tree to scan; the result list is written inside it.
    videos_root = "path/to/evaluation/videos"

    # Settings for this run, gathered in one place for easy tweaking.
    run_options = dict(
        output_filename="valid_mp4_list.txt",
        frame_threshold=20,
        delete_bad=False,
        num_workers=10,  # or None to automatically use cpu_count()-1
    )

    collect_and_validate_mp4_parallel(videos_root, **run_options)

# Example runs and the resulting dataset statistics:
# 
# Dance dataset processing:
# - Valid videos: 2862 files
# - Output: path/to/training/dataset/valid_mp4_list.txt
#
# Tutorial videos:
# - Valid videos: 528 files  
# - Output: path/to/tutorial/videos/valid_mp4_list.txt
#
# Sports videos:
# - Note: Skip first 2 seconds due to text overlay
# - Valid videos: 326 files
# - Output: path/to/sports/videos/valid_mp4_list.txt
#
# Paired dataset:
# - Valid videos: 1141 files
# - Output: path/to/paired/dataset/valid_mp4_list.txt