from fractions import Fraction
import logging
import logging.handlers
import math
import av
import os
import sys
from typing import Optional
import decord
import numpy as np
from PIL import Image
import requests
from decord import VideoReader, cpu

from llava.constants import LOGDIR

# Canned user-facing messages. They are not referenced in this file;
# presumably the serving front end imports them -- verify against callers.
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
moderation_msg = "I am sorry. Your input may violate our content moderation guidelines. Please avoid using harmful or offensive content."

# Shared log file handler; created lazily on the first build_logger() call.
handler = None

import torch.distributed as dist


def load_video_decord(video_path, probs, prompt_len, force_sample=False, start_frame=None, end_frame=None, start_time=None, end_time=None, fps=None):
    """Sample frames from a video with decord using a task-dependent strategy.

    The sampling interval and resolution come from ``get_sampling_strategy``,
    driven by the clip duration, ``argmax(probs)`` and ``prompt_len``.

    Args:
        video_path: Path of the video file; must be an existing ``str`` path.
        probs: Scores whose argmax is used as the task type.
        prompt_len: Forwarded to ``get_sampling_strategy`` as the text token
            count.
        force_sample: Not used by this function.
        start_frame, end_frame: Optional frame range. Only honoured when both
            are given; takes precedence over the time range.
        start_time, end_time: Optional time range parsed with ``parse_time``.
            Only honoured when both are given.
        fps: Frame rate used to convert ``start_frame``/``end_frame`` to
            seconds. Falls back to the video's own rate when it is not a
            positive number.

    Returns:
        ``(frame_array, actual_duration, sampling_interval, resolution)``.
        ``frame_array`` is the stacked sampled frames (an empty
        ``(0, H, W, 3)`` uint8 array when nothing could be sampled) and
        ``actual_duration`` is the span between the first and last valid
        sampling time (0 when there is none).

    Raises:
        ValueError: If the path is invalid or the video cannot be opened.
    """
    if not isinstance(video_path, str) or not os.path.exists(video_path):
        raise ValueError(f"无效视频路径: {video_path}")

    # Configure decord: CPU context, single decoding thread.
    decord.bridge.set_bridge('native')  # avoid potential conflicts with torch
    ctx = decord.cpu(0)
    try:
        # NOTE(review): if opening fails here the bridge stays 'native',
        # because the restoring `finally` below only guards the second `try`.
        vr = decord.VideoReader(video_path, ctx=ctx, num_threads=1)
    except Exception as e:
        raise ValueError(f"无法打开视频文件：{video_path}") from e

    try:
        # Basic metadata.
        total_frames = len(vr)
        # Original frame size: try the reader's metadata first.
        try:
            # NOTE(review): confirm that VideoReader exposes get_meta_data();
            # an AttributeError simply falls through to the fallback below.
            meta = vr.get_meta_data()
            original_width = int(meta['width'])
            original_height = int(meta['height'])
        except (KeyError, AttributeError, IndexError):
            # Fallback when metadata is unavailable: inspect the first frame.
            if total_frames > 0:
                first_frame = vr[0].asnumpy()
                original_height, original_width = first_frame.shape[:2]
            else:
                original_width = original_height = 384  # default size

        avg_fps = vr.get_avg_fps()  # average frame rate reported by decord
        # Fall back to 30 fps when the reported rate is NaN or non-positive.
        actual_fps = avg_fps if not math.isnan(avg_fps) and avg_fps > 0 else 30.0
        used_fps = fps if (isinstance(fps, (int, float)) and (fps > 0)) else actual_fps

        # Total duration: try three methods in order and keep the first
        # one that yields a positive value.
        duration_methods = [
            lambda: total_frames / actual_fps,  # method 1: frame count / frame rate
            lambda: vr.get_frame_timestamp(total_frames-1)[-1] if total_frames >0 else 0,  # method 2: timestamp of the last frame
            lambda: len(vr) / 30.0  # method 3: last resort, assume 30 fps
        ]

        total_duration = 0.0
        for method in duration_methods:
            try:
                total_duration = method()
                if total_duration > 0:
                    break
            except Exception as e:
                print(f"时长计算方法失败: {str(e)}")
                continue
        if total_duration <= 0:
            total_duration = 0.0

        # Requested range; defaults to the whole video.
        start_sec = 0.0
        end_sec = total_duration

        # Frame-number arguments take precedence over time arguments.
        if start_frame is not None and end_frame is not None:
            start_sec = max(0, start_frame) / used_fps
            end_sec = max(start_sec, end_frame / used_fps)
        elif start_time is not None and end_time is not None:
            try:
                start_sec = parse_time(start_time)
                end_sec = parse_time(end_time)
            except ValueError as e:
                raise e

        # Clamp the range to [0, total_duration].
        start_sec = max(0.0, start_sec)
        end_sec = max(start_sec, min(end_sec, total_duration))
        # NOTE(review): end_sec >= start_sec is already guaranteed by the
        # max() above, so this swap and the negative-duration message below
        # cannot trigger.
        if end_sec < start_sec:
            start_sec, end_sec = end_sec, start_sec  # swap start_sec and end_sec
        # NOTE(review): start_frame/end_frame are not read again after this
        # point, so this swap has no effect on the result.
        if start_frame is not None and end_frame is not None and end_frame < start_frame:
            start_frame, end_frame = end_frame, start_frame  # swap start_frame and end_frame
        duration_sec = end_sec - start_sec
        if duration_sec < 0:
            print(f"无效的时间范围: start_time{start_time}, end_time{end_time}")
        # Enforce a minimum window of one second.
        if duration_sec <= 1:
            duration_sec = 1
            end_sec = start_sec + 1

        # Pick the sampling strategy for the most likely task type.
        task_type = np.argmax(probs)
        sampling_interval, resolution = get_sampling_strategy(int(duration_sec), task_type, prompt_len)

        # Evenly spaced sampling times in [start_sec, end_sec).
        num_samples = max(1, int(duration_sec / sampling_interval))
        times = np.linspace(
            start=start_sec,
            stop=end_sec,
            num=num_samples,
            endpoint=False,
            dtype=np.float32
        )

        # Convert timestamps to frame indices, dropping out-of-range ones.
        frame_indices = []
        valid_times = []
        for t in times:
            idx = int(t * actual_fps)
            if 0 <= idx < total_frames:
                frame_indices.append(idx)
                valid_times.append(t)
            elif idx >= total_frames:
                break  # past the end of the video

        # Read the frames in one batch, falling back to per-frame reads.
        try:
            if frame_indices:
                frames = vr.get_batch(frame_indices).asnumpy()
            else:
                # No frame index survived the range check.
                if total_frames == 0:
                    # The video has no frames: empty 384x384 array.
                    frames = np.zeros((0, 384, 384, 3), dtype=np.uint8)
                else:
                    # The video has frames but the range selected none:
                    # empty array with the original frame size.
                    frames = np.zeros((0, original_height, original_width, 3), dtype=np.uint8)
        except decord.DECORDError as e:
            print(f"批量读取失败，转为逐帧读取: {str(e)}")
            frames = []
            for idx in frame_indices:
                try:
                    frame = vr[idx].asnumpy()
                    frames.append(frame)
                except:
                    # Substitute a black placeholder frame of the original size.
                    black_frame = np.zeros((original_height, original_width, 3), dtype=np.uint8)
                    frames.append(black_frame)
            # Convert the list to a numpy array.
            frames = np.stack(frames) if frames else np.zeros(
                (0, 384, 384, 3) if total_frames == 0 else (0, original_height, original_width, 3), dtype=np.uint8
            )

        # Collect the frames unchanged. The disabled code below used to
        # replace empty frames with black frames of the original size.
        processed_frames = []
        # for frame in frames:
        #     if frame.size == 0 or frame.shape[0] == 0 or frame.shape[1] == 0:
        #         # empty frame: generate a black frame
        #         black_frame = np.zeros((original_height, original_width, 3), dtype=np.uint8)
        #         processed_frames.append(black_frame)
        #     else:
        #         # keep the original size
        #         processed_frames.append(frame)
        for frame in frames:
            processed_frames.append(frame)

        # Build the final numpy array.
        if processed_frames:
            frame_array = np.stack(processed_frames)
        else:
            # No usable frame: the size depends on whether the video had frames.
            if total_frames == 0:
                frame_array = np.zeros((0, 384, 384, 3), dtype=np.uint8)
            else:
                frame_array = np.zeros((0, original_height, original_width, 3), dtype=np.uint8)

        # Recompute the covered duration from the valid sampling times.
        actual_duration = valid_times[-1] - valid_times[0] if valid_times else 0
        # # ========== debug: dump the sampled frames ==========
        # # build the output directory path
        # base_dir = "/mnt/new_cpfs/yqs/model/frame_decord"
        # video_filename = os.path.basename(video_path)
        # base_name = os.path.splitext(video_filename)[0]
        # #
        # # resolve directory name collisions
        # target_dir_base = os.path.join(base_dir, base_name)
        # counter = 0
        # target_dir = target_dir_base
        # while os.path.exists(target_dir):
        #     target_dir = f"{target_dir_base}_{counter}"
        #     counter += 1
        # os.makedirs(target_dir, exist_ok=True)
        # # save every sampled frame
        # for frame_idx, frame in enumerate(frame_array):
        #     img = Image.fromarray(frame)
        #     frame_path = os.path.join(target_dir, f"frame_{frame_idx:04d}.jpg")
        #     img.save(frame_path)
        # # ========== end of debug dump ==============

        return frame_array, actual_duration, sampling_interval, resolution
    finally:
        del vr  # drop the reader reference explicitly
        decord.bridge.set_bridge('torch')  # switch the bridge back to torch

def get_sampling_strategy(duration_sec: int, task_type: int, text_tokens: int) -> tuple[float, int]:
    """Choose a frame sampling interval and per-frame resolution.

    Each candidate resolution ``y`` costs ``y ** 2`` tokens per frame. The
    candidates for the task type are tried in priority order; the first one
    that fits the token budget and the per-candidate interval threshold wins.

    Args:
        duration_sec: Length of the clip in seconds.
        task_type: Task category index (0-7) selecting the candidate tables.
        text_tokens: Tokens already consumed by the text prompt.

    Returns:
        ``(sampling_interval_seconds, resolution)``.

    Raises:
        KeyError: If ``task_type`` is not one of the known categories.
    """
    # Resolution candidates per task type, most preferred first.
    y_priority_map = {
        0: [27, 18, 14],  # attribute recognition: favour resolution
        1: [18, 14, 11],  # action recognition: balanced, leaning to frame count
        2: [14, 11],  # temporal grounding: favour frame count, no high resolution
        3: [27, 18, 14],  # spatial grounding: favour resolution, never the lowest
        4: [27, 18, 14, 11],  # retrieval / counting: balanced, leaning to resolution
        5: [18, 14, 11],  # summarisation: balanced, leaning to frame count
        6: [27, 18, 14, 11],  # retrieval / counting: balanced, leaning to resolution
        7: [27, 18, 14]   # OCR: balanced, leaning to resolution, never the lowest
    }
    # Largest tolerated interval (seconds per frame) for the i-th candidate.
    # Candidates beyond the end of the list have no threshold.
    y_time_map = {
        0: [4, 6],
        1: [2, 8],
        2: [10],
        3: [4, 8],
        4: [4, 6, 8],
        5: [2, 6],
        6: [4, 6, 10],
        7: [2, 6]
    }
    max_total_tokens = 21_000
    max_frames_limit = 120
    video_token_budget = max_total_tokens - text_tokens

    y_candidates = y_priority_map[task_type]
    time_thresholds = y_time_map[task_type]

    # Minimum sampling interval allowed for the clip length.
    if duration_sec <= 30:
        min_x = 0.5
    elif duration_sec <= 120:
        min_x = 1
    else:
        min_x = 2

    # Interval needed to stay within the frame cap, and the same value
    # rounded up to one decimal so rounding can never exceed the cap.
    x_min_based_on_max_frames = duration_sec / max_frames_limit
    interval_floor = math.ceil(x_min_based_on_max_frames * 10) / 10

    for i, y in enumerate(y_candidates):
        tokens_per_frame = y ** 2
        max_frames = video_token_budget // tokens_per_frame

        # Skip resolutions for which not even one frame fits the budget.
        if max_frames < 1:
            continue

        x_min = max(
            duration_sec / max_frames,
            min_x,
            x_min_based_on_max_frames,
        )

        # Too sparse for this resolution: try the next, cheaper one.
        if i < len(time_thresholds) and x_min > time_thresholds[i]:
            continue

        x_candidate = max(round(x_min, 1), interval_floor)

        # Rounding may have shortened the interval; re-check the budget.
        required_tokens = (duration_sec / x_candidate) * tokens_per_frame
        if required_tokens > video_token_budget:
            continue

        return x_candidate, y

    # Fallback: lowest-priority resolution. The frame count is clamped to at
    # least one, so an exhausted budget (zero or negative) yields a single
    # frame instead of a ZeroDivisionError or a negative interval.
    y = y_candidates[-1]
    max_frames = max(1, video_token_budget // (y ** 2))
    x_min = max(
        duration_sec / max_frames,
        min_x,
        x_min_based_on_max_frames,
    )
    return round(x_min, 1), y

def parse_time(time_str):
    """Convert a time value to seconds as a float, tolerating sloppy input.

    Supported string layouts (fraction separated by ``,`` or ``.``)::

        hh:mm:ss,ms   mm:ss,ms   ss,ms
        hh:mm:ss.ms   mm:ss.ms   ss.ms

    Surrounding spaces and missing leading zeros are accepted. A component
    that is not a valid integer counts as 0. Only the first three ``:``
    separated components are used. The fractional part is read as
    milliseconds: its leading digits are right-padded to three places and
    anything beyond three digits is dropped.

    A plain ``int`` or ``float`` is treated as a number of seconds.

    Args:
        time_str: Time string such as ``"00:21:30,000"``, or a number of
            seconds.

    Returns:
        The time in seconds as a ``float``.

    Raises:
        ValueError: If ``time_str`` is neither a string nor a number.
    """
    import re

    # Numeric input is already in seconds. bool is excluded on purpose.
    if isinstance(time_str, (int, float)) and not isinstance(time_str, bool):
        return float(time_str)
    if not isinstance(time_str, str):
        raise ValueError(f"无效时间格式: {time_str}")

    # Split at the first ',' or '.' into the clock part and the fraction.
    separator = re.search(r'[,.]', time_str)
    if separator is None:
        main_time, ms_str = time_str, ''
    else:
        main_time = time_str[:separator.start()]
        ms_str = time_str[separator.end():]

    def safe_int(s, default=0):
        """Return int(s), or `default` when s is not a valid integer."""
        try:
            return int(s.strip())
        except (ValueError, TypeError):
            return default

    # Keep at most three components and left-pad with zeros to hh, mm, ss.
    parts = main_time.strip().split(':')[:3]
    hh, mm, ss = (safe_int(p) for p in ['0'] * (3 - len(parts)) + parts)

    # Leading digits of the fraction, normalised to exactly three places.
    digits = re.match(r'^\d*', ms_str.strip()).group()
    ms = safe_int(digits.ljust(3, '0')[:3]) if digits else 0

    return hh * 3600 + mm * 60 + ss + ms / 1000.0

from bisect import bisect_left

def load_video_pyav(video_path, probs, prompt_len, force_sample=False, start_frame=None, end_frame=None, start_time=None, end_time=None, fps=None):
    """Sample frames from a video with PyAV using a task-dependent strategy.

    The video is decoded twice: a first pass collects every frame timestamp,
    a second pass (on a freshly opened container) decodes only up to the last
    frame that is actually needed. Both containers are closed even when an
    error occurs.

    Args:
        video_path: Path of the video file; must be an existing ``str`` path.
        probs: Scores whose argmax is used as the task type.
        prompt_len: Forwarded to ``get_sampling_strategy`` as the text token
            count.
        force_sample: Not used by this function.
        start_frame, end_frame: Optional frame range. Only honoured when both
            are given; takes precedence over the time range.
        start_time, end_time: Optional time range parsed with ``parse_time``.
            Only honoured when both are given.
        fps: Frame rate used to convert ``start_frame``/``end_frame`` to
            seconds. Falls back to the stream's average rate when it is not
            a positive number.

    Returns:
        ``(frame_array, actual_duration, sampling_interval, resolution)``.
        ``frame_array`` is the stacked RGB frames (an empty ``(0, H, W, 3)``
        uint8 array when nothing could be sampled) and ``actual_duration`` is
        the span between the first and last valid sampling time.

    Raises:
        ValueError: If the path is invalid, the file cannot be opened, or it
            contains no video stream.
    """
    if not isinstance(video_path, str) or not os.path.exists(video_path):
        raise ValueError(f"无效视频路径: {video_path}")

    try:
        container = av.open(video_path)
    except Exception as e:
        raise ValueError(f"无法打开视频文件：{video_path}") from e

    # ---- Pass 1: metadata and per-frame timestamps ----
    with container:
        video_stream = next((s for s in container.streams if s.type == 'video'), None)
        if video_stream is None:
            raise ValueError(f"视频文件中没有视频流：{video_path}")

        original_width = video_stream.width
        original_height = video_stream.height
        # average_rate can be missing; fall back to 30 fps in that case too.
        stream_rate = video_stream.average_rate
        avg_fps = stream_rate if stream_rate is not None and stream_rate > 0 else 30.0

        # Prefer the size of the first decoded frame over the stream header.
        try:
            container.seek(0)
            first_frame = next(container.decode(video_stream)).to_ndarray(format='rgb24')
            original_height, original_width = first_frame.shape[:2]
        except Exception:
            original_width = original_height = 384

        container.seek(0)
        frame_times = [frame.pts * frame.time_base for frame in container.decode(video_stream)]

    used_fps = fps if (isinstance(fps, (int, float)) and (fps > 0)) else avg_fps
    total_frames = len(frame_times)
    total_duration = total_frames / avg_fps if total_frames > 0 else 0

    # ---- Requested range (defaults to the whole video) ----
    start_sec = 0.0
    end_sec = total_duration
    if start_frame is not None and end_frame is not None:
        start_sec = max(0, start_frame) / used_fps
        end_sec = max(start_sec, end_frame / used_fps)
    elif start_time is not None and end_time is not None:
        start_sec = parse_time(start_time)
        end_sec = parse_time(end_time)

    # Clamp to [0, total_duration]; this also guarantees end_sec >= start_sec.
    start_sec = max(0.0, start_sec)
    end_sec = max(start_sec, min(end_sec, total_duration))
    duration_sec = end_sec - start_sec

    # ---- Sampling plan ----
    task_type = np.argmax(probs)
    sampling_interval, resolution = get_sampling_strategy(int(duration_sec), task_type, prompt_len)

    num_samples = max(1, int(duration_sec / sampling_interval))
    times = np.linspace(start_sec, end_sec, num=num_samples, endpoint=False, dtype=np.float32)

    # Map each sampling time to the first frame at or after it.
    valid_indices = []
    valid_times = []
    for t in times:
        idx = bisect_left(frame_times, t)
        if idx < total_frames:
            valid_indices.append(idx)
            valid_times.append(t)

    # ---- Pass 2: decode only the frames that are needed ----
    wanted = set(valid_indices)
    last_wanted = max(wanted, default=0)
    collected = {}
    with av.open(video_path) as container:
        video_stream = container.streams.video[0]
        try:
            for current_idx, frame in enumerate(container.decode(video_stream)):
                if current_idx in wanted:
                    try:
                        collected[current_idx] = frame.to_ndarray(format='rgb24')
                    except Exception:
                        # Conversion failed: substitute a black frame.
                        collected[current_idx] = np.zeros((original_height, original_width, 3), dtype=np.uint8)
                if current_idx >= last_wanted:
                    break
        except Exception as e:
            # Best effort: frames not reached become black frames below.
            print(f"解码错误: {str(e)}")

    # One output frame per sampling time; duplicated indices are kept.
    frames = [
        collected.get(idx, np.zeros((original_height, original_width, 3), dtype=np.uint8))
        for idx in valid_indices
    ]

    frame_array = np.stack(frames) if frames else np.zeros(
        (0, 384, 384, 3) if total_frames == 0 else (0, original_height, original_width, 3),
        dtype=np.uint8
    )

    actual_duration = valid_times[-1] - valid_times[0] if valid_times else 0

    return frame_array, actual_duration, sampling_interval, resolution


# 未修复oom，但添加了对输入范围的支持
# def load_video_pyav(video_path, probs, prompt_len, force_sample=False, 
#                    start_frame=None, end_frame=None, 
#                    start_time=None, end_time=None, fps=None):
#     # 打开视频文件
#     container = av.open(video_path)
#     video_stream = container.streams.video[0]
#     video_stream.thread_type = "AUTO"

#     # 获取视频实际参数
#     actual_fps = float(video_stream.average_rate) if video_stream.average_rate else 30.0
#     used_fps = fps if fps is not None else actual_fps

#     # 读取所有视频帧
#     video_frames = []
#     for frame in container.decode(video=0):
#         video_frames.append(frame)
#     total_original_frames = len(video_frames)
#     if total_original_frames == 0:
#         raise ValueError("视频文件为空或损坏")

#     # 确定截取范围
#     if start_frame is not None and end_frame is not None:
#         # 处理帧范围
#         start = max(0, min(int(start_frame), total_original_frames-1))
#         end = max(0, min(int(end_frame), total_original_frames-1))
#         if start > end:
#             start, end = end, start
#         video_frames = video_frames[start:end+1]
#     elif start_time is not None and end_time is not None:
#         # 转换时间为秒
#         start_sec = parse_time(start_time)
#         end_sec = parse_time(end_time)
#         if start_sec > end_sec:
#             start_sec, end_sec = end_sec, start_sec
#         # 计算帧数
#         start = int(start_sec * used_fps)
#         end = int(end_sec * used_fps)
#         start = max(0, min(start, total_original_frames-1))
#         end = max(0, min(end, total_original_frames-1))
#         video_frames = video_frames[start:end+1]

#     # 处理截取后参数
#     total_frames = len(video_frames)
#     if total_frames == 0:
#         raise ValueError("视频片段无有效帧")
    
#     # 计算实际持续时间
#     if total_frames > 0:
#         duration_sec = video_frames[-1].time - video_frames[0].time
#     else:
#         duration_sec = 0

#     # 动态采样策略
#     task_type = np.argmax(probs)
#     sampling_interval, resolution = get_sampling_strategy(duration_sec, task_type, prompt_len)
    
#     # 计算采样帧数
#     desired_frames = max(1, int(round(duration_sec / sampling_interval)))
    
#     # 生成采样索引
#     if total_frames < desired_frames:
#         frame_indices = list(range(total_frames))
#     else:
#         frame_indices = np.linspace(0, total_frames-1, desired_frames, dtype=int).tolist()
    
#     # 强制采样逻辑
#     if force_sample and len(frame_indices) > desired_frames:
#         frame_indices = np.linspace(0, total_frames-1, desired_frames, dtype=int).tolist()

#     # 提取并处理帧数据
#     sampled_frames = [video_frames[i] for i in frame_indices]
#     frame_array = np.stack([frame.to_ndarray(format="rgb24") for frame in sampled_frames])

#     return frame_array, duration_sec, sampling_interval, resolution

# 选帧
# def load_video_pyav(video_path, probs, prompt_len, force_sample=False):
#     container = av.open(video_path)
#     container.streams.video[0].thread_type = "AUTO"

#     video_frames = []
#     for packet in container.demux():
#         if packet.stream.type == 'video':
#             for frame in packet.decode():
#                 video_frames.append(frame)

#     total_frame_num = len(video_frames)
#     video_time = video_frames[-1].time if total_frame_num > 0 else 0
#     if video_time <= 0 or total_frame_num <= 0:
#         raise ValueError("Video file is empty or corrupted.")
#     duration_sec = video_time

#     task_type = np.argmax(probs) # 根据概率确定任务类型
#     sampling_interval, resolution = get_sampling_strategy(duration_sec, task_type, prompt_len) # 获取动态采样策略

#     # 计算需要抽取的帧数，至少1帧
#     desired_frames = max(1, int(round(duration_sec / sampling_interval)))

#     # 根据实际帧数调整
#     if total_frame_num < desired_frames: # 若总帧数不足，则只取现有的
#         frame_indices = list(range(total_frame_num))
#     else:
#         frame_indices = np.linspace(0, total_frame_num - 1, desired_frames, dtype=int).tolist() # 统一采样
#     # 强制采样逻辑（保留原参数兼容性）
#     if force_sample and len(frame_indices) > desired_frames:
#         frame_indices = np.linspace(0, total_frame_num - 1, desired_frames, dtype=int).tolist()

#     # 提取帧数据
#     sampled_frames = [video_frames[i] for i in frame_indices]
#     frame_array = np.stack([frame.to_ndarray(format="rgb24") for frame in sampled_frames])

#     return frame_array, duration_sec, sampling_interval, resolution

# 非选帧 原代码
# def load_video_pyav(video_path, max_frames_num, fps, force_sample=False):
#     container = av.open(video_path)
#     # 使用自动线程类型
#     container.streams.video[0].thread_type = "AUTO"

#     video_frames = []
#     for packet in container.demux():
#         if packet.stream.type == 'video':
#             for frame in packet.decode():
#                 video_frames.append(frame)

#     total_frame_num = len(video_frames)
#     video_time = video_frames[-1].time if total_frame_num > 0 else 0

#     if video_time <= 15:  # 每秒2帧
#         frames_to_extract = min(total_frame_num, 2*int(video_time))
#     if video_time <= 60:  # 每秒1帧
#         frames_to_extract = min(total_frame_num, int(video_time))
#     elif video_time <= 181:  # 固定抽60帧
#         frames_to_extract = 60
#     elif video_time <= 600:  # 固定抽90帧
#         frames_to_extract = 90
#     else:  # 600秒以上，固定抽120帧
#         frames_to_extract = 120

#     f = max(1, round(total_frame_num / frames_to_extract))  # 每过f帧抽一帧
#     frame_idx = np.linspace(0, total_frame_num - 1, frames_to_extract, dtype=int).tolist()

#     interval = video_frames[1].time - video_frames[0].time if len(video_frames) > 1 else 0

#     if len(frame_idx) > max_frames_num or force_sample:
#         sample_fps = max_frames_num
#         uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int)
#         frame_idx = uniform_sampled_frames.tolist()

#     frames = [video_frames[i] for i in frame_idx]
#     spare_frames = np.stack([x.to_ndarray(format="rgb24") for x in frames])

#     return spare_frames, video_time, interval


def process_video_with_decord(video_file, data_args):
    """Sample frames from a video with decord at a target frame rate.

    Frames are taken every ``round(native_fps / data_args.video_fps)`` frames.
    When ``data_args.frames_upbound`` is positive and either that selection
    is longer than the bound or ``data_args.force_sample`` is set, the
    selection is replaced by ``frames_upbound`` uniformly spaced frames.

    Args:
        video_file: Path of the video file.
        data_args: Object providing ``video_fps``, ``frames_upbound`` and
            ``force_sample``.

    Returns:
        ``(video, video_time, frame_time, num_frames_to_sample)``: the
        sampled frames as a numpy array, the video length in seconds, the
        frame timestamps as a comma-separated string such as
        ``"0.00s,1.00s"``, and the number of sampled frames.
    """
    vr = VideoReader(video_file, ctx=cpu(0), num_threads=1)
    total_frame_num = len(vr)
    native_fps = vr.get_avg_fps()
    video_time = total_frame_num / native_fps

    # Frames to skip between samples. Kept at >= 1 because a requested rate
    # above twice the native rate would round to 0 and make range() raise.
    stride = max(1, round(native_fps / data_args.video_fps))
    frame_idx = list(range(0, total_frame_num, stride))

    upbound = data_args.frames_upbound
    if upbound > 0 and (len(frame_idx) > upbound or data_args.force_sample):
        frame_idx = np.linspace(0, total_frame_num - 1, upbound, dtype=int).tolist()

    # Timestamps are frame index / native FPS in both branches. The earlier
    # code divided by the stride in the non-resampled branch, which is only
    # approximately right when video_fps == 1.
    frame_time = ",".join(f"{i / native_fps:.2f}s" for i in frame_idx)

    video = vr.get_batch(frame_idx).asnumpy()
    num_frames_to_sample = len(frame_idx)
    vr.seek(0)
    return video, video_time, frame_time, num_frames_to_sample

def process_video_with_pyav(video_file, data_args):
    """Sample frames from a video with PyAV at a target frame rate.

    All video frames are decoded first (so memory use grows with video
    length), then every ``stride``-th frame is kept. When
    ``data_args.frames_upbound`` is positive and the selection is longer
    than the bound, ``frames_upbound`` uniformly spaced frames are used
    instead.

    Args:
        video_file: Path of the video file.
        data_args: Object providing ``video_fps`` and ``frames_upbound``.

    Returns:
        The sampled frames stacked into one numpy array (RGB24).

    Raises:
        ValueError: If no video frame could be decoded.
    """
    # The context manager closes the container even if decoding fails.
    with av.open(video_file) as container:
        # Let the decoder choose its threading mode.
        container.streams.video[0].thread_type = "AUTO"

        video_frames = [
            frame
            for packet in container.demux()
            if packet.stream.type == 'video'
            for frame in packet.decode()
        ]
        total_frame_num = len(video_frames)
        if total_frame_num == 0:
            raise ValueError("Video file is empty or corrupted.")

        # Timestamp of the last frame, used as the video length.
        video_time = video_frames[-1].time
        if video_time and video_time > 0:
            stride = round(total_frame_num / video_time / data_args.video_fps)
        else:
            # Single-frame video or missing timestamp: avoid dividing by zero.
            stride = 1
        # A stride of 0 would make range() raise.
        stride = max(1, stride)
        frame_idx = list(range(0, total_frame_num, stride))

        upbound = data_args.frames_upbound
        if upbound > 0 and len(frame_idx) > upbound:
            frame_idx = np.linspace(0, total_frame_num - 1, upbound, dtype=int).tolist()

        return np.stack([video_frames[i].to_ndarray(format="rgb24") for i in frame_idx])


def rank0_print(*args):
    """Print ``args`` only on rank 0, or unconditionally when
    torch.distributed is not initialised.

    In the distributed case the output is prefixed with the rank.
    """
    if not dist.is_initialized():
        print(*args)
        return
    rank = dist.get_rank()
    if rank == 0:
        print(f"Rank {rank}: ", *args)


def rank_print(*args):
    """Print ``args`` on every process, prefixed with the rank when
    torch.distributed is initialised."""
    if not dist.is_initialized():
        print(*args)
        return
    print(f"Rank {dist.get_rank()}: ", *args)

def build_logger(logger_name, logger_filename):
    """Configure logging and return the logger called ``logger_name``.

    Side effects:
      * the first root handler gets the module's log format (a root handler
        is created through ``basicConfig`` when none exists);
      * ``sys.stdout`` and ``sys.stderr`` are replaced by ``StreamToLogger``
        wrappers feeding the "stdout" (INFO) and "stderr" (ERROR) loggers;
      * on the first call only, a daily-rotating file handler writing to
        ``LOGDIR/logger_filename`` is created and attached to every logger
        registered at that moment.
    """
    global handler

    log_format = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Apply the format to the root logger's first handler.
    root_logger = logging.getLogger()
    if not root_logger.handlers:
        logging.basicConfig(level=logging.INFO)
    root_logger.handlers[0].setFormatter(log_format)

    # Redirect stdout, then stderr. The order matters: StreamToLogger
    # captures the current sys.stdout when it is constructed.
    out_logger = logging.getLogger("stdout")
    out_logger.setLevel(logging.INFO)
    sys.stdout = StreamToLogger(out_logger, logging.INFO)

    err_logger = logging.getLogger("stderr")
    err_logger.setLevel(logging.ERROR)
    sys.stderr = StreamToLogger(err_logger, logging.ERROR)

    named_logger = logging.getLogger(logger_name)
    named_logger.setLevel(logging.INFO)

    # The shared file handler is only ever built once.
    if handler is not None:
        return named_logger

    os.makedirs(LOGDIR, exist_ok=True)
    log_path = os.path.join(LOGDIR, logger_filename)
    handler = logging.handlers.TimedRotatingFileHandler(log_path, when="D", utc=True)
    handler.setFormatter(log_format)

    # loggerDict also holds placeholder entries; skip anything that is not
    # a real Logger.
    for registered in logging.root.manager.loggerDict.values():
        if isinstance(registered, logging.Logger):
            registered.addHandler(handler)

    return named_logger


class StreamToLogger(object):
    """
    File-like object that forwards everything written to it to a logger,
    one log record per completed line.
    """

    def __init__(self, logger, log_level=logging.INFO):
        # The stream in place at construction time; unknown attributes
        # are delegated to it.
        self.terminal = sys.stdout
        self.logger = logger
        self.log_level = log_level
        # Text received so far that is not terminated by "\n".
        self.linebuf = ""

    def __getattr__(self, attr):
        # Only called for attributes this object does not define itself.
        return getattr(self.terminal, attr)

    def write(self, buf):
        """Log each complete line in ``buf``; buffer any unfinished tail."""
        pending = self.linebuf + buf
        self.linebuf = ""
        # keepends=True so that a chunk still shows whether it was
        # terminated. Callers are expected to write "\n" newlines, as with
        # sys.stdout.write().
        for chunk in pending.splitlines(True):
            if chunk.endswith("\n"):
                self.logger.log(self.log_level, chunk.rstrip())
            else:
                self.linebuf += chunk

    def flush(self):
        """Log whatever is left in the buffer, then clear it."""
        if self.linebuf:
            self.logger.log(self.log_level, self.linebuf.rstrip())
        self.linebuf = ""


def disable_torch_init():
    """
    Replace ``reset_parameters`` on ``torch.nn.Linear`` and
    ``torch.nn.LayerNorm`` with a no-op, skipping the redundant default
    initialization to speed up model creation.
    """
    import torch

    for layer_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        setattr(layer_cls, "reset_parameters", lambda self: None)


def violates_moderation(text):
    """
    Check whether the text violates OpenAI moderation API.

    Newlines are removed from ``text`` before it is sent. The API key is
    read from the ``OPENAI_API_KEY`` environment variable; a missing key
    raises ``KeyError``.

    Returns:
        The ``flagged`` value reported by the API, or ``False`` when the
        request fails or the response cannot be interpreted (the check is
        best effort).
    """
    import json

    url = "https://api.openai.com/v1/moderations"
    headers = {"Content-Type": "application/json", "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
    text = text.replace("\n", "")
    # Serialise with json.dumps so quotes, backslashes and control characters
    # in the text are escaped. Hand-built JSON broke on such input and let
    # the text inject extra fields into the request.
    data = json.dumps({"input": text}).encode("utf-8")
    try:
        ret = requests.post(url, headers=headers, data=data, timeout=5)
        flagged = ret.json()["results"][0]["flagged"]
    except requests.exceptions.RequestException as e:
        print(f"######################### Moderation Error: {e} #########################")
        flagged = False
    except (KeyError, IndexError, TypeError, ValueError) as e:
        # Unexpected response shape or a body that is not valid JSON.
        print(f"######################### Moderation Error: {e} #########################")
        flagged = False

    return flagged


def pretty_print_semaphore(semaphore):
    """Return a short description of ``semaphore`` (its private ``_value``
    counter and ``locked()`` state), or ``"None"`` when it is ``None``."""
    if semaphore is not None:
        return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
    return "None"
