import os
import cv2
import numpy as np
from PIL import Image
from moviepy import VideoFileClip, ImageSequenceClip
from tqdm import tqdm

def apply_frequency_domain_ideal_2D_LPF_to_gray_frames(frame, cutoff_frequency=0.25):
    """Low-pass filter one single-channel frame with an ideal (brick-wall) 2-D filter.

    The frame is transformed with ``cv2.dft``. Every frequency bin whose
    distance from the centred DC term exceeds
    ``cutoff_frequency * min(rows, cols) / 2`` is zeroed, and the spectrum is
    transformed back with ``cv2.idft``.

    Args:
        frame: 2-D array of shape (rows, cols) holding one grayscale frame.
        cutoff_frequency: Pass-band radius as a fraction of half the shorter
            image side. ``1.0`` keeps the circle inscribed in the spectrum;
            ``0`` keeps only the DC term.

    Returns:
        uint8 array of shape (rows, cols). It is the magnitude of the inverse
        transform, min-max normalised to [0, 255] independently for this frame.
    """
    rows, cols = frame.shape
    dft = cv2.dft(np.float32(frame), flags=cv2.DFT_COMPLEX_OUTPUT)
    # Move DC to the centre. Shift only the two spatial axes: the trailing
    # axis holds the (real, imag) pair and must not be rolled.
    dft_shift = np.fft.fftshift(dft, axes=(0, 1))

    crow, ccol = rows // 2, cols // 2
    radius = cutoff_frequency * min(rows, cols) / 2.0
    # Vectorised distance of every bin from the centred DC term. This replaces
    # a per-pixel Python loop that was the dominant per-frame cost.
    row_idx, col_idx = np.ogrid[:rows, :cols]
    distance = np.sqrt((row_idx - crow) ** 2 + (col_idx - ccol) ** 2)
    mask = np.zeros((rows, cols, 2), np.float32)
    # A 2-D boolean index on a 3-D array sets both the real and imag channels.
    mask[distance <= radius] = 1

    filtered_dft = dft_shift * mask
    dft_ishift = np.fft.ifftshift(filtered_dft, axes=(0, 1))
    img_back = cv2.idft(dft_ishift)
    # Take the magnitude of the complex result; idft output is unscaled, so the
    # min-max normalisation below is what maps it back to the 0..255 range.
    img_back = cv2.magnitude(img_back[:, :, 0], img_back[:, :, 1])
    img_back = cv2.normalize(img_back, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    return img_back

def process_video_with_ideal_LPF_gray(input_video_path, output_video_path, cutoff_frequency=0.25):
    """Grayscale + ideal-low-pass every frame of a video and write the result with libx264.

    Args:
        input_video_path: Path of the source video, opened with ``cv2.VideoCapture``.
        output_video_path: Destination passed to moviepy's ``write_videofile``.
        cutoff_frequency: Forwarded to
            ``apply_frequency_domain_ideal_2D_LPF_to_gray_frames``.

    Raises:
        IOError: if the input cannot be opened or yields no frames.
        ValueError: if the input reports a non-positive frame rate.
    """
    cap = cv2.VideoCapture(input_video_path)
    if not cap.isOpened():
        raise IOError(f"Could not open video: {input_video_path}")

    # NOTE: every processed frame is kept in memory until encoding starts.
    processed_frames = []
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            raise ValueError(f"Invalid frame rate {fps!r} reported for: {input_video_path}")

        # CAP_PROP_FRAME_COUNT is container metadata and may be an estimate
        # (or 0), so it only sizes the progress bar; frames are read until the
        # stream actually ends.
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        with tqdm(
            total=frame_count if frame_count > 0 else None,
            desc="Processing video frames",
            unit="frame",
        ) as progress:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                filtered_frame = apply_frequency_domain_ideal_2D_LPF_to_gray_frames(
                    gray_frame, cutoff_frequency=cutoff_frequency
                )
                # Expand back to 3 channels (RGB) for ImageSequenceClip.
                processed_frames.append(cv2.cvtColor(filtered_frame, cv2.COLOR_GRAY2RGB))
                progress.update(1)
    finally:
        cap.release()

    if not processed_frames:
        raise IOError(f"No frames could be read from: {input_video_path}")

    clip = ImageSequenceClip(processed_frames, fps=fps)
    try:
        clip.write_videofile(output_video_path, codec="libx264")
    finally:
        clip.close()
    
if __name__ == "__main__":
    # =========== process video folder ==============
    input_video_folder = "./datasets/scripts/video_folder"
    output_video_folder = "./datasets/scripts/output_video_folder"
    video_extensions = (".mp4", ".avi", ".mov", ".mkv")

    # Ensure the destination folder exists before any video is written.
    os.makedirs(output_video_folder, exist_ok=True)

    # List the folder once; skip subdirectories and non-video files, and sort
    # so the processing order is deterministic.
    video_files = sorted(
        name
        for name in os.listdir(input_video_folder)
        if name.lower().endswith(video_extensions)
        and os.path.isfile(os.path.join(input_video_folder, name))
    )
    total = len(video_files)

    for idx, video_file in enumerate(video_files, start=1):
        print(f"Processing video: {video_file}, progress: {idx}/{total}")
        input_video = os.path.join(input_video_folder, video_file)
        # Replace only the extension (not any ".mp4" substring in the name);
        # output is always H.264 in an .mp4 container.
        stem, _ = os.path.splitext(video_file)
        output_video = os.path.join(output_video_folder, f"{stem}_2D_LPF.mp4")
        try:
            process_video_with_ideal_LPF_gray(input_video, output_video, cutoff_frequency=0.25)
        except Exception as exc:  # batch boundary: report and continue with the next file
            print(f"Failed to process {video_file}: {exc}")
