# utils/video_utils.py
import cv2
from PIL import Image
import decord

def get_video_info_decord(video_path: str):
    """Open *video_path* with decord and report its length and frame rate.

    Returns:
        ``(frame_count, avg_fps)``: ``len()`` of the decord reader and
        ``reader.get_avg_fps()`` converted to ``float``.
    """
    reader = decord.VideoReader(video_path)
    frame_count = len(reader)
    avg_fps = float(reader.get_avg_fps())
    return frame_count, avg_fps

def build_1fps_indices_decord(video_path: str):
    """Return global frame indices that sample *video_path* at one frame per second.

    The index for second ``t`` is ``round(t * fps)``. Indices at or past the
    last frame are dropped, and so are consecutive duplicates. When the
    reported fps is not a positive finite number, 30.0 is assumed.

    Returns:
        Strictly ascending list of global frame indices (1 fps).
    """
    import math

    total, fps = get_video_info_decord(video_path)
    # Guard against 0, negative, or non-finite fps (e.g. NaN), any of which
    # would make int()/round() below raise or produce nonsense.
    if not (math.isfinite(fps) and fps > 0):
        fps = 30.0
    seconds = int(total // fps)
    idx = []
    for t in range(seconds + 1):
        f = int(round(t * fps))
        if f >= total:
            # f never decreases as t grows, so no later index can fit.
            break
        if not idx or f != idx[-1]:
            idx.append(f)
    return idx  # Global frame indices (1fps)

class LazyFrameProvider:
    """Callable that fetches individual frames of a video on demand.

    A single decord reader is opened at construction time. Each call
    fetches only the requested frame and returns it as a PIL image.
    """

    def __init__(self, video_path: str):
        # Kept for the lifetime of the provider so repeated lookups reuse it.
        self.vr = decord.VideoReader(video_path)

    def __call__(self, global_frame_idx: int) -> Image.Image:
        """Return frame *global_frame_idx* of the video as a PIL image."""
        frame = self.vr[global_frame_idx]
        pixels = frame.asnumpy()
        return Image.fromarray(pixels)

def get_video_frame_count_decord(video_path: str) -> int:
    """Return the number of frames decord reports for *video_path*.

    If decord raises ``decord.DECORDError``, the error is printed to stdout
    and ``0`` is returned instead of propagating the exception.
    """
    try:
        reader = decord.VideoReader(video_path, ctx=decord.cpu(0))
        frame_count = len(reader)
    except decord.DECORDError as err:
        print(f"Error reading video with decord: {video_path}, {err}")
        frame_count = 0
    return frame_count

def extract_frames_decord(video_path: str, fps: float = 1.0) -> (list, list):
    """Extract frames at roughly *fps* frames per second using decord.

    Used by AKS/TopK/Scope. Sampling keeps every N-th source frame, where
    ``N = round(video_fps / fps)`` clamped to at least 1. This approximates
    fractional targets such as 0.1, 0.5 or 2.

    Args:
        video_path: Path to the video file.
        fps: Target sampling rate. Values that cannot be converted to a
            positive finite float fall back to 1.0.

    Returns:
        ``(frames, original_indices)``: the sampled frames as PIL images and
        their frame indices in the source video. Both lists are empty when
        the video reports no usable frame rate, has no frames, or decord
        raises ``decord.DECORDError`` (the error is printed).
    """
    import math

    frames = []
    original_indices = []

    try:
        target_fps = float(fps)
    except (TypeError, ValueError, OverflowError):
        target_fps = 1.0
    if not (math.isfinite(target_fps) and target_fps > 0):
        target_fps = 1.0

    try:
        vr = decord.VideoReader(video_path, ctx=decord.cpu(0))
        video_fps = float(vr.get_avg_fps())
        # 0, negative, NaN or inf would give a zero/negative step or make
        # round() raise an exception that the DECORDError handler misses.
        if not (math.isfinite(video_fps) and video_fps > 0):
            return [], []

        frame_interval = max(1, int(round(video_fps / target_fps)))
        indices_to_extract = list(range(0, len(vr), frame_interval))
        if not indices_to_extract:  # zero-frame video: nothing to decode
            return [], []

        # Decode all sampled frames in a single batch call.
        frame_batch = vr.get_batch(indices_to_extract).asnumpy()
        frames = [Image.fromarray(frame) for frame in frame_batch]
        original_indices = indices_to_extract
    except decord.DECORDError as e:
        print(f"Error extracting frames with decord: {video_path}, {e}")
    return frames, original_indices

def extract_frames(video_path, fps=1):
    """Extract frames from *video_path* with OpenCV at roughly *fps* per second.

    Every N-th decoded frame is kept, where ``N = int(video_fps / fps)``
    clamped to at least 1. Requesting more fps than the source provides
    therefore yields every frame.

    Args:
        video_path: Path to the video file.
        fps: Target sampling rate; non-positive (or NaN) values fall back to 1.

    Returns:
        List of sampled frames as RGB PIL images. The list is empty if no
        frame can be read or OpenCV reports no positive finite frame rate.
    """
    import math

    frames = []
    if not (fps > 0):
        fps = 1
    vidcap = cv2.VideoCapture(video_path)
    try:
        video_fps = vidcap.get(cv2.CAP_PROP_FPS)
        # OpenCV may report 0 (or a non-finite value) when fps is unknown;
        # mirror extract_frames_decord and return no frames in that case.
        if not (math.isfinite(video_fps) and video_fps > 0):
            return frames
        # Clamp to 1: int() truncates to 0 when fps > video_fps, which would
        # otherwise make the modulo below divide by zero.
        frame_interval = max(1, int(video_fps / fps))

        count = 0
        while True:
            success, image = vidcap.read()
            if not success:
                break
            if count % frame_interval == 0:
                # OpenCV decodes to BGR; PIL expects RGB.
                frames.append(Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)))
            count += 1
    finally:
        # Release the capture even if decoding/conversion raises.
        vidcap.release()
    return frames
