import cv2
import numpy as np
from tqdm import tqdm
from moviepy import ImageSequenceClip
import os

def _grid_shape(frame_width, frame_height, grid_ratio):
    """Return ``(rows, cols)`` for a grid of roughly ``int(1 / grid_ratio)`` cells.

    Rows and columns are split according to the frame's aspect ratio so cells
    are roughly square. Both counts are clamped to at least 1 and to at most
    the frame's pixel height/width, so every cell spans at least one pixel.
    """
    aspect_ratio = frame_width / frame_height
    total_grids = int(1 / grid_ratio)
    rows = int(np.sqrt(total_grids / aspect_ratio))
    cols = int(rows * aspect_ratio)
    rows = min(max(1, rows), frame_height)
    cols = min(max(1, cols), frame_width)
    return rows, cols


def _cell_spans(length, cells):
    """Return ``(starts, sizes)`` splitting ``length`` pixels into ``cells`` runs.

    Every run is ``length // cells`` pixels long except the last, which also
    absorbs the remainder so the runs exactly cover ``[0, length)``.
    """
    step = length // cells
    starts = np.arange(cells) * step
    sizes = np.full(cells, step)
    sizes[-1] = length - starts[-1]
    return starts, sizes


def _quantize_mask(mask, grid_ratio, foreground_threshold, foreground_value):
    """Return a copy of 2-D ``mask`` in which every grid cell is all 255 or all 0.

    A cell becomes 255 when the fraction of its pixels ``>= foreground_value``
    is strictly greater than ``foreground_threshold``; otherwise it becomes 0.
    The result has the same shape and dtype as ``mask``.
    """
    height, width = mask.shape
    rows, cols = _grid_shape(width, height, grid_ratio)
    row_starts, row_sizes = _cell_spans(height, rows)
    col_starts, col_sizes = _cell_spans(width, cols)

    foreground = (mask >= foreground_value).astype(np.int64)
    # Count foreground pixels per cell in one vectorized pass: collapse each
    # band of rows, then each band of columns (starts are strictly increasing
    # because every cell is at least one pixel wide/tall).
    counts = np.add.reduceat(
        np.add.reduceat(foreground, row_starts, axis=0), col_starts, axis=1
    )
    ratios = counts / np.outer(row_sizes, col_sizes)
    cell_values = np.where(ratios > foreground_threshold, 255, 0).astype(mask.dtype)
    # Expand each cell's value back over its pixel footprint.
    return np.repeat(np.repeat(cell_values, row_sizes, axis=0), col_sizes, axis=1)


def process_mask_video_with_ratio(
    video_mask_path,
    output_video_path,
    grid_ratio=0.05,
    *,
    foreground_threshold=0.05,
    foreground_value=255,
):
    """Convert a mask video into a coarse, grid-quantized (blocky) mask video.

    Each frame is converted to grayscale when it has three dimensions. It is
    then divided into a grid of roughly ``int(1 / grid_ratio)`` cells. Every
    cell is filled with 255 if more than ``foreground_threshold`` of its pixels
    are foreground, and with 0 otherwise. The resulting frames are written with
    moviepy's ``ImageSequenceClip`` using the ``libx264`` codec at the source
    video's FPS.

    Args:
        video_mask_path: Path of the input mask video.
        output_video_path: Path of the output video to write.
        grid_ratio: Approximate fraction of the frame covered by one grid
            cell; must be > 0.
        foreground_threshold: A cell is filled when its foreground fraction is
            strictly greater than this value.
        foreground_value: Pixels with gray value ``>= foreground_value`` count
            as foreground. The default 255 counts only pure white, matching
            the original behavior. Lower it (e.g. 128) for masks degraded by
            lossy compression.

    Raises:
        ValueError: If ``grid_ratio`` is not positive, the video reports a
            non-positive FPS, or no frames could be decoded.
        OSError: If the video cannot be opened.
    """
    # Validate before touching the capture so bad input fails fast and clearly.
    if not grid_ratio > 0:
        raise ValueError(f"grid_ratio must be positive, got {grid_ratio!r}")

    cap = cv2.VideoCapture(video_mask_path)
    processed_frames = []
    try:
        if not cap.isOpened():
            raise OSError(f"Could not open video: {video_mask_path}")
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            raise ValueError(f"Invalid FPS ({fps!r}) reported for {video_mask_path}")
        # CAP_PROP_FRAME_COUNT may be only an estimate for some containers, so
        # it drives the progress bar only; frames are read until the stream ends.
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        with tqdm(total=frame_count if frame_count > 0 else None) as progress:
            while True:
                ok, frame = cap.read()
                if not ok:
                    break
                mask = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) if frame.ndim == 3 else frame
                quantized = _quantize_mask(
                    mask, grid_ratio, foreground_threshold, foreground_value
                )
                # Frames handed to ImageSequenceClip are 3-channel RGB.
                processed_frames.append(cv2.cvtColor(quantized, cv2.COLOR_GRAY2RGB))
                progress.update(1)
    finally:
        # Always release the capture, even when decoding/processing fails.
        cap.release()

    if not processed_frames:
        raise ValueError(f"No frames could be read from {video_mask_path}")
    clip = ImageSequenceClip(processed_frames, fps=fps)
    clip.write_videofile(output_video_path, codec="libx264")

def process_video_folder(input_folder, output_folder, grid_ratio=0.05):
    """Grid-quantize every mask video in ``input_folder`` into ``output_folder``.

    Only regular files directly inside ``input_folder`` whose names end in
    ``.mp4``, ``.avi`` or ``.mov`` (case-insensitive) are processed. Subfolders
    are not traversed. Each result is written under the same file name in
    ``output_folder`` via ``process_mask_video_with_ratio``. Files are handled
    in sorted name order so runs are deterministic.

    Args:
        input_folder: Folder containing the source mask videos.
        output_folder: Destination folder; created (with parents) if missing.
        grid_ratio: Forwarded to ``process_mask_video_with_ratio``.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_folder, exist_ok=True)
    for video_file in sorted(os.listdir(input_folder)):
        input_path = os.path.join(input_folder, video_file)
        if not os.path.isfile(input_path):
            continue
        # Lower-case so files like "clip.MP4" are not silently skipped.
        if not video_file.lower().endswith((".mp4", ".avi", ".mov")):
            continue
        output_path = os.path.join(output_folder, video_file)
        process_mask_video_with_ratio(input_path, output_path, grid_ratio)

if __name__ == "__main__":
    # Batch job: convert every mask video found in the input folder into a
    # grid-quantized mask video of the same name in the output folder.
    # A grid_ratio of 0.001 targets roughly 1000 cells per frame.
    process_video_folder(
        input_folder="./datasets/scripts/input_videos",
        output_folder="./datasets/scripts/output_videos",
        grid_ratio=0.001,
    )