#!/usr/bin/env python3
import os
import cv2
import argparse
import numpy as np
from tqdm import tqdm
from pathlib import Path
from PIL import Image

def extract_smoke(frames_dir: str,
                  mask_video: str,
                  out_dir: str,
                  history: int = 100,
                  var_thresh: int = 50,
                  blur_ksize: int = 255) -> np.ndarray:
    """Cut soft-masked smoke out of a frame sequence and save it as grayscale.

    Frames are the ``.png``/``.jpg`` files in *frames_dir*. They are processed
    in sorted filename order and paired one-to-one, in order, with the frames
    of *mask_video*. For each pair:

    1. The mask is resized to the frame size and binarised at 127.
    2. It is feathered with a Gaussian blur.
    3. It is used as a per-pixel alpha on the frame.
    4. The result is converted to grayscale and written to
       ``out_dir/NNNN.png``.

    Processing stops early, with a warning, if the mask video runs out of
    frames.

    Args:
        frames_dir: Directory holding the input frames.
        mask_video: Video whose frames are the masks. Pixels brighter than
            127 keep the frame content.
        out_dir: Output directory. Created if missing.
        history: Kept for backward compatibility; currently unused.
        var_thresh: Kept for backward compatibility; currently unused.
        blur_ksize: Side length of the square Gaussian kernel used to
            feather the mask edges. Must be a positive odd integer.

    Returns:
        uint8 array of shape ``(num_frames, H, W)`` holding the masked
        grayscale frames, where ``H, W`` is the size of the first frame.
        If the mask video yields no frames, the shape is ``(0, H, W)``.

    Raises:
        ValueError: If *blur_ksize* is not a positive odd integer.
        RuntimeError: If no frames are found, the mask video cannot be
            opened, a frame cannot be read or written, or a frame's size
            differs from the first frame's.
    """
    # cv2.GaussianBlur only accepts positive odd kernel sizes (or 0).
    if blur_ksize <= 0 or blur_ksize % 2 == 0:
        raise ValueError(
            f"blur_ksize must be a positive odd integer, got {blur_ksize}")

    frames = sorted(f for f in os.listdir(frames_dir)
                    if f.lower().endswith(('.png', '.jpg')))
    if not frames:
        raise RuntimeError(f"{frames_dir} no png/jpg files")

    cap = cv2.VideoCapture(mask_video)
    if not cap.isOpened():
        raise RuntimeError(f"Failed to open mask video: {mask_video}")

    # NOTE: earlier versions fed every frame to a MOG2 background subtractor
    # configured by `history` / `var_thresh`. Its output was never used, so
    # that per-frame work has been dropped; the parameters stay for callers.

    Path(out_dir).mkdir(parents=True, exist_ok=True)
    ksize = (blur_ksize, blur_ksize)
    images_out = []

    # try/finally guarantees the capture is released on every error path.
    try:
        first_path = os.path.join(frames_dir, frames[0])
        first_frame = cv2.imread(first_path)
        if first_frame is None:
            raise RuntimeError(f"Failed in reading the frame: {first_path}")
        H, W = first_frame.shape[:2]

        for idx, fname in enumerate(tqdm(frames, desc="Extracting smoke")):
            ok, mask_frame = cap.read()
            if not ok:
                # Mask video is shorter than the frame list: stop early.
                print("⚠ mask not enough")
                break

            frame_path = os.path.join(frames_dir, fname)
            if idx == 0:
                # Already loaded above to determine H, W.
                frame = first_frame
            else:
                frame = cv2.imread(frame_path)
                if frame is None:
                    raise RuntimeError(
                        f"Failed in reading the frame: {frame_path}")
                # All outputs are stacked into one array, so sizes must match.
                if frame.shape[:2] != (H, W):
                    raise RuntimeError(
                        f"Frame size mismatch: {frame_path} is "
                        f"{frame.shape[1]}x{frame.shape[0]}, expected {W}x{H}")

            # Nearest-neighbour resize avoids inventing grey levels
            # before thresholding.
            mask_frame = cv2.resize(mask_frame, (W, H),
                                    interpolation=cv2.INTER_NEAREST)
            mask_gray = cv2.cvtColor(mask_frame, cv2.COLOR_BGR2GRAY)
            _, mask_bin = cv2.threshold(mask_gray, 127, 255, cv2.THRESH_BINARY)

            # Feather the hard binary mask into a soft alpha matte in [0, 1].
            mask_blur = cv2.GaussianBlur(mask_bin, ksize, 0)
            alpha = mask_blur.astype(np.float32) / 255.0

            smoke_rgb = (alpha[..., None] * frame).astype(np.uint8)
            smoke_gray = cv2.cvtColor(smoke_rgb, cv2.COLOR_BGR2GRAY)

            out_path = os.path.join(out_dir, f"{idx:04d}.png")
            if not cv2.imwrite(out_path, smoke_gray):
                raise RuntimeError(f"Failed to write output image: {out_path}")
            images_out.append(smoke_gray)
    finally:
        cap.release()

    # np.stack rejects an empty list; return a well-shaped empty array instead.
    if not images_out:
        return np.empty((0, H, W), dtype=np.uint8)
    return np.stack(images_out, axis=0)

# ----------------------------- CLI -----------------------------
if __name__ == "__main__":
    # Command-line entry point: every path argument is mandatory.
    cli = argparse.ArgumentParser()
    for flag, help_text in (("--frames_dir", "img dir"),
                            ("--mask_video", "mask from SAM"),
                            ("--out_dir", "save smoke imgs")):
        cli.add_argument(flag, required=True, help=help_text)
    opts = cli.parse_args()

    # Run the extraction with default settings and report the stacked result.
    stacked = extract_smoke(opts.frames_dir, opts.mask_video, opts.out_dir)
    print(f"Done! Extract {len(stacked)} frame, shape={stacked.shape}")