from qwen_vl_utils.vision_process import *


def _read_video_decord_frame_idx(
    ele: dict, frame_idx: list, nframes: int
) -> torch.Tensor:
    """Read a caller-selected subset of frames from a video with decord.

    Frames are not sampled uniformly. Instead, the first ``nframes`` entries
    of ``frame_idx`` are taken and sorted ascending. Every index greater than
    the last frame of the video is then dropped.

    Args:
        ele (dict): configuration of the video. Supported keys:
            - video: path of the video, passed unchanged to
              ``decord.VideoReader``.
            ``video_start`` / ``video_end`` are NOT supported yet; their
            presence raises ``NotImplementedError``.
        frame_idx (list): candidate frame indices; only the first
            ``nframes`` entries are considered.
        nframes (int): maximum number of frames to read.

    Returns:
        torch.Tensor: the video tensor with shape (T, C, H, W), T <= nframes.

    Raises:
        NotImplementedError: if ``video_start`` or ``video_end`` is in ``ele``.
        ValueError: if no selected index lies within the video.
    """
    import decord

    video_path = ele["video"]
    # TODO: support start_pts and end_pts
    # Checked before opening the video so the unsupported option fails fast.
    if 'video_start' in ele or 'video_end' in ele:
        raise NotImplementedError("not support start_pts and end_pts in decord for now.")
    st = time.time()
    vr = decord.VideoReader(video_path)
    total_frames, video_fps = len(vr), vr.get_avg_fps()
    logger.info(f"decord:  {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")

    # Only the first ``nframes`` candidates are used; sorting makes the batch
    # come back in temporal order.
    selected = sorted(frame_idx[:nframes])
    # ``selected`` is ascending, so this keeps exactly the leading run of
    # indices that do not exceed the last frame.
    selected = [i for i in selected if i <= total_frames - 1]
    if not selected:
        raise ValueError(
            f"No index in frame_idx[:{nframes}] lies within the video "
            f"({total_frames} frames): {video_path}"
        )

    video = vr.get_batch(selected).asnumpy()
    # Reorder axes (0, 3, 1, 2) to obtain the (T, C, H, W) layout.
    video = torch.tensor(video).permute(0, 3, 1, 2)
    return video

def fetch_video_frame_idx(ele: dict, frame_idx, nframes, image_factor: int = IMAGE_FACTOR) -> torch.Tensor | list[Image.Image]:
    """Load the selected frames of a video and resize them.

    Frames are read with ``_read_video_decord_frame_idx``. They are then
    resized with bicubic, antialiased interpolation to the size chosen by
    ``smart_resize`` (called with ``factor=image_factor``).

    Args:
        ele (dict): video configuration. Besides the keys used by the reader,
            the optional keys ``min_pixels``, ``total_pixels``, ``max_pixels``
            and the pair ``resized_height``/``resized_width`` control the
            output size.
        frame_idx: candidate frame indices forwarded to the reader.
        nframes: maximum number of frames to read.
        image_factor (int): ``factor`` argument passed to ``smart_resize``.

    Returns:
        torch.Tensor: float tensor of shape (T, C, resized_height,
        resized_width). A list of images is never returned by this function.

    Raises:
        ValueError: if the reader returned no frames.
    """
    video = _read_video_decord_frame_idx(ele, frame_idx, nframes)
    # The reader can return fewer frames than requested (it drops
    # out-of-range indices), so the pixel budget uses the actual count.
    num_frames, _, height, width = video.shape
    if num_frames == 0:
        # Without this check the budget computation below divides by zero.
        raise ValueError(f"no frames were read from {ele.get('video')!r}")

    min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
    total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
    # Per-frame cap derived from the total budget, never below ~min_pixels.
    max_pixels = max(min(VIDEO_MAX_PIXELS, total_pixels / num_frames * FRAME_FACTOR), int(min_pixels * 1.05))
    # An explicit max_pixels in ``ele`` overrides the computed cap.
    max_pixels = ele.get("max_pixels", max_pixels)

    if "resized_height" in ele and "resized_width" in ele:
        resized_height, resized_width = smart_resize(
            ele["resized_height"],
            ele["resized_width"],
            factor=image_factor,
        )
    else:
        resized_height, resized_width = smart_resize(
            height,
            width,
            factor=image_factor,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )
    # Diagnostic only; routed through the module logger instead of stdout.
    logger.debug("%s*%s -> %s*%s", height, width, resized_height, resized_width)

    video = transforms.functional.resize(
        video,
        [resized_height, resized_width],
        interpolation=InterpolationMode.BICUBIC,
        antialias=True,
    ).float()
    return video

def process_vision_info_frame_idx(
    conversations: list[dict] | list[list[dict]], frame_idx, nframes: int
) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None]:
    """Collect the visual inputs of conversation(s), reading videos at given frames.

    Every video entry is loaded with ``fetch_video_frame_idx`` using the same
    ``frame_idx`` / ``nframes``. Image entries (``image`` or ``image_url``)
    are loaded with ``fetch_image``.

    Args:
        conversations: a conversation (list of message dicts) or a batch of them.
        frame_idx: candidate frame indices applied to every video.
        nframes (int): maximum number of frames to read per video.

    Returns:
        tuple: ``(image_inputs, video_inputs)``; each is ``None`` when empty.

    Raises:
        ValueError: if an extracted entry has none of ``image``, ``image_url``
            or ``video``.
    """
    vision_infos = extract_vision_info(conversations)
    image_inputs = []
    video_inputs = []
    for vision_info in vision_infos:
        # "video" is checked first so any entry that previously worked is
        # still routed to the video reader.
        if "video" in vision_info:
            video_inputs.append(fetch_video_frame_idx(vision_info, frame_idx, nframes))
        elif "image" in vision_info or "image_url" in vision_info:
            # Image entries used to reach the ValueError below, whose message
            # claims images are accepted; load them instead.
            image_inputs.append(fetch_image(vision_info))
        else:
            raise ValueError("image, image_url or video should in content.")
    return image_inputs or None, video_inputs or None