import os
import cv2
import math
import argparse

def extract_frames(
    video_path: str,
    out_dir: str,
    stride: int = 1,
    target_size=(1920, 1080),
    max_frames: int = 270,
) -> int:
    """
    Save every `stride`-th frame of a video, resized to `target_size`.

    Frames are written into <out_dir>/original as frame_XXX.png. They are
    numbered sequentially by saved-frame index and zero-padded to at least
    3 digits. The same frames are also encoded into <out_dir>/ori_processed.mp4
    at the source frame rate, or 24 fps if the container reports no usable
    rate. Extraction stops once `max_frames` frames have been saved or the
    video ends, whichever comes first.

    Args:
        video_path: Path of the input video.
        out_dir: Output directory; created (with parents) if missing.
        stride: Keep one frame out of every `stride` source frames (>= 1).
        target_size: Output frame size as (width, height) in pixels.
        max_frames: Maximum number of frames to save (>= 1). Defaults to 270.

    Returns:
        The number of frames saved.

    Raises:
        ValueError: If `stride` or `max_frames` is less than 1.
        RuntimeError: If the input video cannot be opened, the output video
            cannot be created, or a PNG frame cannot be written.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if stride < 1:
        raise ValueError("stride must be >= 1")
    if max_frames < 1:
        raise ValueError("max_frames must be >= 1")

    # OpenCV expects an integer (width, height) pair for resize/VideoWriter.
    width, height = target_size
    size = (int(width), int(height))

    # makedirs creates intermediate directories, so out_dir is created too.
    orig_dir = os.path.join(out_dir, "original")
    os.makedirs(orig_dir, exist_ok=True)

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise RuntimeError(f"Failed to open video: {video_path}")

    writer = None
    try:
        src_fps = cap.get(cv2.CAP_PROP_FPS)
        # Containers may report 0, a negative value, NaN or inf when the
        # frame rate is unknown; fall back to 24 fps in all those cases.
        if not math.isfinite(src_fps) or src_fps <= 0:
            src_fps = 24.0

        # Derive the pad width from the frame cap rather than
        # CAP_PROP_FRAME_COUNT, which is frequently inaccurate; a pad that is
        # too narrow would break lexicographic ordering of the PNG names.
        pad = max(3, len(str(max_frames - 1)))

        video_out_path = os.path.join(out_dir, "ori_processed.mp4")
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        # NOTE: frames are written at the *source* fps, so with stride > 1
        # the output video plays back roughly `stride` times faster.
        writer = cv2.VideoWriter(video_out_path, fourcc, src_fps, size)
        if not writer.isOpened():
            raise RuntimeError(f"Failed to open video writer: {video_out_path}")

        saved = 0
        idx = 0
        while saved < max_frames:
            # read() == grab() + retrieve(). Skipped frames are only grabbed,
            # so only the frames we keep pay for retrieve()'s decoding.
            if not cap.grab():
                break
            if idx % stride == 0:
                ok, frame = cap.retrieve()
                if not ok:
                    break
                resized = cv2.resize(frame, size, interpolation=cv2.INTER_AREA)

                frame_path = os.path.join(orig_dir, f"frame_{saved:0{pad}d}.png")
                # imwrite reports failure via its return value, not an exception.
                if not cv2.imwrite(frame_path, resized):
                    raise RuntimeError(f"Failed to write frame: {frame_path}")

                writer.write(resized)
                saved += 1
            idx += 1
    finally:
        # Always release native handles, even if extraction fails midway.
        cap.release()
        if writer is not None:
            writer.release()

    print(f"Done. Saved {saved} frames to {orig_dir}")
    print(f"Video written to {video_out_path}")
    return saved

if __name__ == "__main__":
    # Command-line entry point: three positional arguments, then extraction
    # at a fixed 1920x1080 output size.
    cli = argparse.ArgumentParser(description="Extract frames from a video.")
    positional_specs = (
        ("video_path", str, "input video"),
        ("out_dir", str, "output directory"),
        ("stride", int, "frame frequency (save 1 frame every `stride` frames)"),
    )
    for arg_name, arg_type, arg_help in positional_specs:
        cli.add_argument(arg_name, type=arg_type, help=arg_help)
    opts = cli.parse_args()

    extract_frames(
        opts.video_path,
        opts.out_dir,
        stride=opts.stride,
        target_size=(1920, 1080),
    )