

import io
import os
import random
import re

import numpy as np
import decord
from PIL import Image


def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
    """Choose which frame indices to read from a video of ``vlen`` frames.

    Two families of strategies are supported:

    * ``'rand'`` / ``'middle'``: ``[0, vlen)`` is split into
      ``min(num_frames, vlen)`` contiguous intervals and one frame is taken
      from each. ``'rand'`` picks uniformly inside the interval (both ends
      included); ``'middle'`` picks the interval midpoint, or
      ``interval_start + fix_start`` when ``fix_start`` is given. If fewer
      than ``num_frames`` indices result, the last index is repeated until
      the list has ``num_frames`` entries.
    * ``'fps<rate>'`` (e.g. ``'fps0.5'``): frames are taken sequentially at
      ``<rate>`` frames per second, each at the centre of the time slice it
      represents. ``num_frames`` is ignored in this mode.

    Args:
        num_frames: Number of indices wanted for ``'rand'`` / ``'middle'``.
        vlen: Number of frames available.
        sample: ``'rand'``, ``'middle'`` or ``'fps<rate>'``.
        fix_start: Offset added to each interval start; only consulted when
            ``sample`` is ``'middle'``. It is not range-checked.
        input_fps: Frame rate of the source; only used by ``'fps<rate>'``.
        max_num_frames: If positive, caps the result length in
            ``'fps<rate>'`` mode.

    Returns:
        A list of integer frame indices.

    Raises:
        ValueError: If ``sample`` is not a recognised strategy.
        IndexError: If ``vlen`` is 0 while ``num_frames`` is positive in
            ``'rand'`` / ``'middle'`` mode (there is no frame to pad with).
    """
    if sample in ('rand', 'middle'):  # uniform sampling
        acc_samples = min(num_frames, vlen)
        # Split the video into `acc_samples` intervals and sample one frame
        # from each. Each range is stored as (first, last) with `last`
        # inclusive.
        intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
        ranges = [
            (int(lo), int(hi) - 1)
            for lo, hi in zip(intervals[:-1], intervals[1:])
        ]
        if sample == 'rand':
            # randint is inclusive on both ends, matching the inclusive
            # ranges above; max() keeps a degenerate range from raising.
            frame_indices = [random.randint(lo, max(lo, hi)) for lo, hi in ranges]
        elif fix_start is not None:
            frame_indices = [lo + fix_start for lo, _ in ranges]
        else:
            frame_indices = [(lo + hi) // 2 for lo, hi in ranges]

        if len(frame_indices) < num_frames:  # pad with the last frame
            missing = num_frames - len(frame_indices)
            frame_indices = frame_indices + [frame_indices[-1]] * missing
    elif isinstance(sample, str) and sample.startswith('fps'):
        # e.g. 'fps0.5': sequentially sample frames at 0.5 fps.
        output_fps = float(sample[3:])
        duration = float(vlen) / input_fps
        # Gap between sampled frames; also the clip length each one represents.
        delta = 1 / output_fps
        frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
        candidates = np.around(frame_seconds * input_fps).astype(int)
        # Rounding can land on `vlen` itself; drop anything out of range.
        frame_indices = [int(e) for e in candidates if e < vlen]
        if max_num_frames > 0 and len(frame_indices) > max_num_frames:
            frame_indices = frame_indices[:max_num_frames]
    else:
        raise ValueError('Unsupported sample strategy: {!r}'.format(sample))
    return frame_indices


def read_frames_decord(video_path, num_frames, sample='rand', fix_start=None, clip=None, min_num_frames=4):
    """Decode a sampled subset of frames from a video file using decord.

    The number of frames requested from the sampler is itself drawn at
    random from ``[min_num_frames, num_frames]`` on every call.

    Args:
        video_path: Path handed to ``decord.VideoReader``.
        num_frames: Upper bound (inclusive) on the number of frames to sample.
        sample: Sampling strategy forwarded to ``get_frame_indices``.
        fix_start: Forwarded to ``get_frame_indices``.
        clip: Optional ``(start, end)`` pair in seconds restricting sampling
            to that part of the video.
        min_num_frames: Lower bound (inclusive) on the number of frames.

    Returns:
        A list of ``PIL.Image`` objects, one per sampled frame.

    Raises:
        ValueError: If the clip contains no frame of the video, or (from
            ``np.random.randint``) if ``num_frames < min_num_frames``.
    """
    video_reader = decord.VideoReader(video_path, num_threads=1)
    total_frames = len(video_reader)
    fps = video_reader.get_avg_fps()

    vlen = total_frames
    if clip:
        start, end = clip
        start_index = int(start * fps)
        # Clamp the clip to frames that actually exist; otherwise a clip
        # reaching past the end of the video yields out-of-range indices.
        vlen = min(int((end - start) * fps), total_frames - start_index)
        if vlen <= 0:
            raise ValueError(
                'clip {!r} contains no frames of {!r}'.format(clip, video_path)
            )

    t_num_frames = np.random.randint(min_num_frames, num_frames + 1)

    frame_indices = get_frame_indices(
        t_num_frames, vlen, sample=sample, fix_start=fix_start,
        input_fps=fps
    )
    if clip:
        # Indices were computed relative to the clip; shift to video frames.
        frame_indices = [f + start_index for f in frame_indices]
    frames = video_reader.get_batch(frame_indices).asnumpy()  # (T, H, W, C), np.uint8
    return [Image.fromarray(frame) for frame in frames]


def extract_frame_number(filename):
    """Return the frame number encoded in a ``..._<digits>.jpg`` filename.

    Args:
        filename: File name expected to end in ``_<digits>.jpg``.

    Returns:
        The digits as an ``int``, or ``-1`` when the name does not match.
    """
    # The dot is escaped so that only a literal ".jpg" extension matches;
    # an unescaped "." would let a trailing digit be consumed as the
    # separator (e.g. "frame_123jpg" parsing as 12).
    match = re.search(r'_(\d+)\.jpg$', filename)
    return int(match.group(1)) if match else -1


def sort_frames(frame_paths):
    """Return ``frame_paths`` as a new list ordered by frame number.

    The frame number is taken from each path's base name via
    ``extract_frame_number``; the sort is stable, so paths with equal
    numbers keep their input order.
    """
    def _frame_key(path):
        return extract_frame_number(os.path.basename(path))

    ordered = list(frame_paths)
    ordered.sort(key=_frame_key)
    return ordered


def read_frames_folder(video_path, num_frames, sample='rand', fix_start=None, min_num_frames=4):
    """Load a sampled subset of frames from a folder of frame images.

    Files in ``video_path`` are ordered with ``sort_frames``. The number of
    frames requested is drawn at random from ``[min_num_frames, num_frames]``
    on every call; when the folder holds no more files than that, every file
    is returned.

    Args:
        video_path: Directory containing one image file per frame.
        num_frames: Upper bound (inclusive) on the number of frames to sample.
        sample: Sampling strategy forwarded to ``get_frame_indices``.
        fix_start: Forwarded to ``get_frame_indices``.
        min_num_frames: Lower bound (inclusive) on the number of frames.

    Returns:
        A list of RGB ``PIL.Image`` objects.

    Raises:
        ValueError: From ``np.random.randint`` if
            ``num_frames < min_num_frames``.
    """
    image_list = sort_frames(os.listdir(video_path))
    vlen = len(image_list)

    t_num_frames = np.random.randint(min_num_frames, num_frames + 1)

    # Pick the files first so that only the frames actually returned are
    # decoded, instead of decoding the whole folder and discarding most of it.
    if vlen > t_num_frames:
        frame_indices = get_frame_indices(
            t_num_frames, vlen, sample=sample, fix_start=fix_start
        )
        selected = [image_list[i] for i in frame_indices]
    else:
        selected = image_list

    frames = []
    for image in selected:
        fp = os.path.join(video_path, image)
        # convert() returns an independent image, so the file can be closed.
        with Image.open(fp) as img:
            frames.append(img.convert('RGB'))
    return frames


class FrameSampler:
    """Callable that loads sampled frames from a video file or frame folder.

    The constructor arguments are stored unchanged and forwarded to the
    reader on every call: ``max_num_frames`` as ``num_frames``, together
    with ``min_num_frames`` and ``sample``.
    """

    def __init__(self, max_num_frames=-1, min_num_frames=8, sample='rand'):
        # NOTE(review): the default -1 is forwarded as-is as `num_frames`;
        # confirm callers always override it with a real upper bound.
        self.sample = sample
        self.min_num_frames = min_num_frames
        self.max_num_frames = max_num_frames

    def __call__(self, file_name):
        """Read frames from ``file_name``.

        A name ending in ``'/'`` is read with ``read_frames_folder``;
        anything else is read with ``read_frames_decord``.
        """
        if file_name.endswith('/'):
            reader = read_frames_folder
        else:
            reader = read_frames_decord
        return reader(
            file_name,
            num_frames=self.max_num_frames,
            min_num_frames=self.min_num_frames,
            sample=self.sample,
        )


def decode_video_byte(video_bytes):
    """Wrap raw encoded video bytes in a ``decord.VideoReader``.

    The bytes are exposed to decord through an in-memory ``io.BytesIO``
    stream; the resulting reader is returned.
    """
    return decord.VideoReader(io.BytesIO(video_bytes))


def sample_mp4_frames(mp4_p, n_frames=None, fps=None, return_frame_indices=False, random_sample=False):
    """Sample frames from a video either by count or by target frame rate.

    Args:
        mp4_p: Path to a video file (``str``) or an already opened
            ``decord`` ``VideoReader``.
        n_frames: If given, the number of frames to return. They are spaced
            evenly from the first to the last frame, or drawn without
            replacement (then sorted) when ``random_sample`` is true.
        fps: Used only when ``n_frames`` is ``None``: frames are taken every
            ``video_fps / fps`` source frames starting at frame 0.
        return_frame_indices: Also return the list of sampled indices.
        random_sample: See ``n_frames``.

    Returns:
        ``(frames, video_duration)`` or, when ``return_frame_indices`` is
        true, ``(frames, video_duration, frame_indices)``. ``frames`` is a
        list of RGB ``PIL.Image`` objects and ``video_duration`` is the frame
        count divided by the reader's average FPS.

    Raises:
        TypeError: If ``mp4_p`` is neither a ``str`` nor a ``VideoReader``.
        ValueError: If both ``n_frames`` and ``fps`` are ``None``, or (from
            ``random.sample``) if ``random_sample`` is true and ``n_frames``
            exceeds the number of frames in the video.
    """
    if n_frames is None and fps is None:
        raise ValueError('either n_frames or fps must be provided')

    if isinstance(mp4_p, str):
        vr = decord.VideoReader(mp4_p, num_threads=1)
    elif isinstance(mp4_p, decord.video_reader.VideoReader):
        vr = mp4_p
    else:
        # Fail explicitly instead of hitting an unbound `vr` further down.
        raise TypeError(
            'mp4_p must be a path string or a decord VideoReader, got {}'.format(
                type(mp4_p).__name__
            )
        )

    total_frames = len(vr)
    video_fps = vr.get_avg_fps()  # frame rate of the video
    video_duration = total_frames / video_fps
    if n_frames is not None:
        if random_sample:
            frame_indices = sorted(random.sample(range(total_frames), n_frames))
        else:
            frame_indices = np.linspace(0, total_frames - 1, n_frames, dtype=int).tolist()
    else:
        # NOTE(review): the exclusive stop of `total_frames - 1` means the
        # last frame is never selected in this mode; kept as in the
        # original, confirm whether that is intended.
        frame_indices = [int(i) for i in np.arange(0, total_frames - 1, video_fps / fps)]
    frames = vr.get_batch(frame_indices).asnumpy()  # convert to a numpy array
    frames = [Image.fromarray(frame).convert("RGB") for frame in frames]
    if not return_frame_indices:
        return frames, video_duration
    return frames, video_duration, frame_indices


def sample_mp4_frames_by_indices(mp4_p, frame_indices: list):
    """Decode the frames at ``frame_indices`` from a video.

    Args:
        mp4_p: Path to a video file (``str``) or an already opened
            ``decord`` ``VideoReader``.
        frame_indices: Frame indices to decode, passed to the reader's
            ``get_batch``.

    Returns:
        A list of RGB ``PIL.Image`` objects, one per requested index.

    Raises:
        TypeError: If ``mp4_p`` is neither a ``str`` nor a ``VideoReader``.
    """
    if isinstance(mp4_p, str):
        vr = decord.VideoReader(mp4_p, num_threads=1)
    elif isinstance(mp4_p, decord.video_reader.VideoReader):
        vr = mp4_p
    else:
        # Fail explicitly instead of hitting an unbound `vr` below.
        raise TypeError(
            'mp4_p must be a path string or a decord VideoReader, got {}'.format(
                type(mp4_p).__name__
            )
        )
    # Decode only the requested frames.
    frames = vr.get_batch(frame_indices).asnumpy()  # convert to a numpy array
    return [Image.fromarray(frame).convert("RGB") for frame in frames]