import os
import cv2
import numpy as np
from tqdm import tqdm
from PIL import Image
import argparse

def _apply_soft_mask(frame, mask_gray, mask_threshold, ksize):
    """Attenuate a BGR ``frame`` by a feathered binary mask and return grayscale uint8.

    The mask is binarized (pixels > ``mask_threshold`` become 255), blurred with a
    ``ksize`` x ``ksize`` Gaussian to soften its edge, and used as a per-pixel
    alpha in [0, 1]. Pixels where the blurred mask is 0 become exactly 0 (black).
    """
    _, mask_binary = cv2.threshold(mask_gray, mask_threshold, 255, cv2.THRESH_BINARY)
    mask_blurred = cv2.GaussianBlur(mask_binary, (ksize, ksize), 0)
    alpha = (mask_blurred / 255.0).astype(np.float32)[..., None]  # (H, W, 1) broadcasts over BGR
    blended = np.clip(alpha * frame.astype(np.float32), 0, 255).astype(np.uint8)
    return cv2.cvtColor(blended, cv2.COLOR_BGR2GRAY)


def process_frames_with_dynamic_background_to_images(
    data_dir, mask_video_path, blur_ksize=80, mask_threshold=127
):
    """Apply per-frame masks from a video to a PNG frame sequence.

    Reads every ``*.png`` in ``<data_dir>/original`` (sorted by filename) and, in
    lockstep, one frame of ``mask_video_path`` per image. Each mask frame is
    resized to the size of the first image, converted to grayscale and saved to
    ``<data_dir>/mask``. The mask is then binarized, feathered with a Gaussian
    blur and used as a per-pixel alpha on the image (unmasked regions go black);
    the grayscale result is saved to ``<data_dir>/frame_cam00``. Both output
    directories use the same contiguous ``frame_{i:03d}.png`` naming, so mask
    ``i`` always pairs with image ``i``.

    Args:
        data_dir: Root directory containing the ``original`` sub-directory.
        mask_video_path: Path to a video readable by ``cv2.VideoCapture``.
        blur_ksize: Gaussian kernel size used to feather the mask edge. Even
            values are bumped to the next odd number because OpenCV requires an
            odd kernel size. Defaults to 80 (i.e. an 81x81 kernel).
        mask_threshold: Gray level a mask pixel must exceed to count as
            foreground before blurring.

    Returns:
        ``np.array`` built from the list of grayscale result images in output
        order; for N >= 1 frames this is ``(N, H, W)`` uint8. If no frame was
        produced, ``np.array([])`` is returned.

    Raises:
        FileNotFoundError: ``<data_dir>/original`` does not exist.
        RuntimeError: No PNG frames were found, or the first frame is unreadable.
        ValueError: ``blur_ksize`` is not positive, or the mask video cannot be opened.
    """
    src_dir = os.path.join(data_dir, "original")
    out_img_dir = os.path.join(data_dir, "frame_cam00")
    out_mask_dir = os.path.join(data_dir, "mask")

    if not os.path.isdir(src_dir):
        raise FileNotFoundError(f"source frame directory not found: {src_dir}")
    if blur_ksize < 1:
        raise ValueError(f"blur_ksize must be positive, got {blur_ksize}")

    os.makedirs(out_img_dir, exist_ok=True)
    os.makedirs(out_mask_dir, exist_ok=True)

    frames = sorted(f for f in os.listdir(src_dir) if f.lower().endswith(".png"))
    if not frames:
        raise RuntimeError(f"no .png frames found in {src_dir}")

    first_path = os.path.join(src_dir, frames[0])
    first_img = cv2.imread(first_path)
    if first_img is None:
        raise RuntimeError(f"failed to read first frame: {first_path}")
    h, w = first_img.shape[:2]

    # GaussianBlur requires an odd kernel size; computed once, not per frame.
    ksize = blur_ksize + 1 if blur_ksize % 2 == 0 else blur_ksize

    cap = cv2.VideoCapture(mask_video_path)
    if not cap.isOpened():
        raise ValueError(f"cannot open mask video: {mask_video_path}")

    images = []
    try:
        for fname in tqdm(frames, desc="Processing Frames"):
            # Consume one mask frame per source image, even if the image later
            # fails to load, so the video stays aligned with the image list.
            ret, mask_frame = cap.read()
            if not ret:
                print("mask not enough")
                break

            frame_path = os.path.join(src_dir, fname)
            frame = cv2.imread(frame_path)
            if frame is None:
                # Skip before writing anything so no orphan mask file is left.
                print(f"fail: {frame_path}")
                continue
            if frame.shape[:2] != (h, w):
                frame = cv2.resize(frame, (w, h))

            mask_gray = cv2.cvtColor(cv2.resize(mask_frame, (w, h)), cv2.COLOR_BGR2GRAY)
            result_gray = _apply_soft_mask(frame, mask_gray, mask_threshold, ksize)

            out_name = f"frame_{len(images):03d}.png"
            cv2.imwrite(os.path.join(out_mask_dir, out_name), mask_gray)
            cv2.imwrite(os.path.join(out_img_dir, out_name), result_gray)
            images.append(result_gray)
    finally:
        # Release the capture even if processing raises part-way through.
        cap.release()

    return np.array(images)

if __name__ == "__main__":
    # CLI: positional args are the mask video first, then the data root.
    parser = argparse.ArgumentParser(description="Process frames with dynamic background and save masks & outputs.")
    parser.add_argument("mask_video_path", type=str, help="mask")
    parser.add_argument("data_dir", type=str, help="data root")
    args = parser.parse_args()

    process_frames_with_dynamic_background_to_images(args.data_dir, args.mask_video_path)
    # Report the directories the function actually writes to ("frame_cam00", not "frames_cam00").
    print(
        f"Done. {os.path.join(args.data_dir, 'frame_cam00')}\n"
        f"  - {os.path.join(args.data_dir, 'mask')}"
    )