import argparse
from pathlib import Path
import cv2
import numpy as np
from concurrent import futures

from tqdm import tqdm


def extract_frames(video_path, output_dir):
    """Extract one frame per second from video and save as PNG.

    Frames are sampled every ``int(fps)`` frames starting at frame 0 and
    written to ``output_dir`` as ``<video stem>_<frame index>.png``, where the
    frame index is zero-padded to at least three digits.  If every expected
    output file already exists, nothing is decoded or written.

    Args:
        video_path: Path (str or ``Path``) of the video to read.
        output_dir: Directory (str or ``Path``) that receives the PNG files.

    Returns:
        None.
    """
    video_id = Path(video_path).stem
    target_dir = Path(output_dir)
    cap = cv2.VideoCapture(str(video_path))
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # A reported FPS below 1 (including 0 or NaN) would truncate to a
        # zero step, which range() rejects; sample every frame in that case.
        step = int(fps) if fps >= 1 else 1

        # Build the (frame index, output file) plan once so the
        # "already extracted" check and the extraction loop cannot drift.
        targets = [
            (index, target_dir / f"{video_id}_{str(index).zfill(3)}.png")
            for index in range(0, frame_count, step)
        ]

        # Skip decoding entirely when every expected frame is already on disk.
        if all(path.exists() for _, path in targets):
            return

        for index, path in targets:
            cap.set(cv2.CAP_PROP_POS_FRAMES, index)
            ret, frame = cap.read()
            if ret:
                cv2.imwrite(str(path), frame)
    finally:
        # Release the capture on every exit path, including the early return.
        cap.release()


def extract_middle_frame(video_path, output_dir):
    """Extract the middle frame from video and save as PNG.

    The middle frame is the one at index ``frame_count // 2``.  It is written
    to ``output_dir`` as ``<video stem>_<frame index>.png``, where the frame
    index is zero-padded to at least three digits.  Nothing is written when
    the frame cannot be read.

    Args:
        video_path: Path (str or ``Path``) of the video to read.
        output_dir: Directory (str or ``Path``) that receives the PNG file.

    Returns:
        None.
    """
    video_id = Path(video_path).stem
    cap = cv2.VideoCapture(str(video_path))
    try:
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        middle_frame_idx = frame_count // 2

        # Seek directly to the middle frame rather than decoding from the start.
        cap.set(cv2.CAP_PROP_POS_FRAMES, middle_frame_idx)
        ret, frame = cap.read()

        if ret:
            frame_id = str(middle_frame_idx).zfill(3)
            output_path = str(Path(output_dir) / f"{video_id}_{frame_id}.png")
            cv2.imwrite(output_path, frame)
    finally:
        # Release the capture even if seeking, reading or writing raised.
        cap.release()


if __name__ == "__main__":
    # Command-line entry point: extract the middle frame of every *.mp4 under
    # input_dir that belongs to this shard, using a small thread pool.
    parser = argparse.ArgumentParser(description="Extract frames from videos.")
    parser.add_argument("input_dir", help="Directory containing input videos.")
    parser.add_argument("output_dir", help="Directory to save output frames.")
    parser.add_argument("num_shards", type=int, help="Number of shards.")
    parser.add_argument("shard_id", type=int, help="ID of this shard (0-indexed).")
    args = parser.parse_args()

    # Reject values that would divide by zero or select an unintended slice.
    if args.num_shards < 1:
        parser.error("num_shards must be a positive integer.")
    if not 0 <= args.shard_id < args.num_shards:
        parser.error("shard_id must satisfy 0 <= shard_id < num_shards.")

    input_dir = Path(args.input_dir)
    output_dir = Path(args.output_dir)
    # Make sure the destination exists before any worker tries to write to it.
    output_dir.mkdir(parents=True, exist_ok=True)

    # Sort so that every shard sees the same ordering; the traversal order of
    # rglob is not guaranteed, and shards must partition the same list.
    video_paths = sorted(input_dir.rglob("*.mp4"))
    shard_size = int(np.ceil(len(video_paths) / args.num_shards))
    shard_start = args.shard_id * shard_size
    shard_end = min(shard_start + shard_size, len(video_paths))
    shard_video_paths = video_paths[shard_start:shard_end]

    # One executor for the whole shard, and exactly one task per video.
    with futures.ThreadPoolExecutor(max_workers=4) as executor:
        pending = {
            executor.submit(extract_middle_frame, video_path, output_dir): video_path
            for video_path in shard_video_paths
        }
        for done in tqdm(futures.as_completed(pending), total=len(pending)):
            # Report worker failures instead of discarding them, but keep
            # processing the remaining videos.
            error = done.exception()
            if error is not None:
                tqdm.write(f"Failed to process {pending[done]}: {error!r}")
