import gc
import logging
import math
import re
import warnings
from typing import Tuple, List, Optional

from PIL import Image
import numpy as np
import torch

import av
# Restrict PyAV's own logging to the ERROR level.
av.logging.set_level(av.logging.ERROR)
# Feature check: when `VideoFrame` has no `pict_type` attribute, the installed
# PyAV is treated as too old and the module name `av` is rebound to an
# ImportError *instance* (it is not raised here).
# NOTE(review): nothing in this part of the file checks for that placeholder,
# so a later `av.open(...)` would fail with an AttributeError instead of this
# ImportError - confirm whether a guard is intended elsewhere.
if not hasattr(av.video.frame.VideoFrame, "pict_type"):
    av = ImportError("PyAV too old")


# PyAV has some reference cycles, so `_read_from_stream` forces a garbage
# collection once every `_GC_COLLECTION_INTERVAL` calls.
_CALLED_TIMES = 0
_GC_COLLECTION_INTERVAL = 10


def _read_from_stream(container, start_pts, end_pts, stream, stream_name, target_frames:Optional[int]=None):
    """
    Decode the frames of one stream that lie inside ``[start_pts, end_pts]``.

    Parameters
    ----------
    container : av container
        opened container; it is seeked and decoded in place
    start_pts : float / Fraction
        start of the window; divided by ``stream.time_base`` to get stream ticks
    end_pts : float / Fraction
        end of the window; ``float("inf")`` reads until the stream ends
    stream : av stream
        stream used for seeking and for its ``time_base``
    stream_name : dict
        keyword arguments forwarded to ``container.decode``, e.g. ``{"video": 0}``
    target_frames : int, optional
        desired number of frames. When the window yields fewer, the result is
        padded first with frames that directly precede the first kept frame
        and then with frames that were decoded past the target count.

    Returns
    -------
    frames : list
        decoded frames ordered by ``pts``; an empty list when the initial seek
        or the decoding raises ``av.AVError`` (the error is logged).
    """
    global _CALLED_TIMES, _GC_COLLECTION_INTERVAL
    _CALLED_TIMES += 1
    if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
        gc.collect()

    # Convert the window borders into stream ticks.
    start_offset = int(start_pts / stream.time_base)
    if end_pts != float("inf"):
        end_offset = int(math.ceil(end_pts / stream.time_base))
    else:
        end_offset = end_pts

    frames = {}
    carry_over = []
    try:
        # TODO check if stream needs to always be the video stream here or not
        container.seek(start_offset, any_frame=False, backward=True, stream=stream)
    except av.AVError as e:
        logging.error(f"Invalid stream? ({e})")
        return []
    try:
        pts_bug = False
        for frame in container.decode(**stream_name):
            # The backward seek may land before the window: skip frames until
            # the first one at or after `start_offset` shows up.
            if pts_bug or (start_offset > 0 and frame.pts < start_offset):
                pts_bug = frame.pts < start_offset
                if pts_bug:
                    continue
            if target_frames is not None and target_frames <= len(frames):
                # Enough frames collected; keep the surplus as padding material.
                carry_over.append(frame)
            else:
                frames[frame.pts] = frame
            if frame.pts >= end_offset:
                break
    except av.AVError as e:
        logging.error(f"Decoding issue? No further decoding will happen. ({e})")
        return []

    # Order by pts and drop everything outside of the requested window.
    result = [
        frames[i] for i in sorted(frames) if start_offset <= frames[i].pts <= end_offset
    ]

    if target_frames is not None and len(result) > 0:
        missing_count = target_frames - len(result)
        if missing_count > 0:
            first_pts = result[0].pts
            if first_pts > start_offset:
                # 0.022 expressed in stream ticks (presumably a 22 ms lip-sync
                # tolerance); seek that far before the window start.
                lip_sync_gap = int(0.022 / stream.time_base)
                predecessors = []
                try:
                    container.seek(max([0, start_offset - lip_sync_gap]), any_frame=False, backward=True, stream=stream)
                    for frame in container.decode(**stream_name):
                        # Only frames in front of the first kept frame are real
                        # predecessors; stop instead of decoding the whole file.
                        if frame.pts >= first_pts:
                            break
                        predecessors.append(frame)
                except av.AVError as e:
                    # Same best-effort policy as the main pass: log and go on
                    # without predecessors.
                    logging.error(f"Could not decode preceding frames. ({e})")
                    predecessors = []
                # Keep the frames closest to the window.
                result = predecessors[-missing_count:] + result

            missing_count = target_frames - len(result)
            if missing_count > 0:
                successors = carry_over[0:missing_count]
                result = result + successors

            if target_frames > len(result):
                logging.warning("Not enough audio data. Is producing less data than necessary.")

    return result


def read_chunk(filename, start_pts=0, end_pts=None, vframes='auto', aframes='auto', vsize:tuple=(86, 48)):
    """
    Reads an AV chunk from a file, returning both the video and audio frames.

    Parameters
    ----------
    filename : str
        path to the video file
    start_pts : float / Fraction
        the start presentation time of the video
    end_pts : float / Fraction, optional
        the end presentation time; ``None`` reads until the end of the file
    vframes : int / 'auto' / None
        number of video frames to aim for. With 'auto' the number is derived
        from the chunk duration and the stream time base; for an open-ended
        chunk no target is used.
    aframes : int / 'auto' / None
        number of audio frames to aim for. With 'auto' the number is derived
        from the chunk duration, the stream time base and the codec frame
        size; without an end time or a frame size no target is used.
    vsize : tuple
        size every video frame is resized to (handed to PIL, i.e.
        ``(width, height)``)

    Returns
    -------
    vframes : Tensor[T, C, H, W] or None
        the `T` resized RGB video frames, channels first
    aframes : Tensor or None
        the stacked ``to_ndarray()`` output of the audio frames
    info : Dict
        metadata for the video and audio. Can contain the fields video_fps (float)
        and audio_fps (int)
    computed_vframes : int or None
        the video frame target that was used (``None`` if there was none)
    computed_aframes : int or None
        the audio frame target that was used (``None`` if there was none)

    Notes
    -----
    ``vframes`` and ``aframes`` are both ``None`` as soon as one of the two
    cannot be stacked into a tensor (e.g. no frames of that kind were read).

    Raises
    ------
    ValueError
        if ``end_pts`` is smaller than ``start_pts``
    """
    if end_pts is None:
        end_pts = float("inf")

    if end_pts < start_pts:
        raise ValueError(
            "end_pts should be larger than start_pts, got "
            f"start_pts={start_pts}, end_pts={end_pts}"
        )

    # An infinite duration cannot be turned into a frame count.
    open_ended = end_pts == float("inf")

    info = {}
    video_frames = []
    audio_frames = []
    # Bound up front so the return statement also works for files that lack
    # a video or an audio stream.
    computed_vframes = None if vframes == 'auto' else vframes
    computed_aframes = None if aframes == 'auto' else aframes

    container = av.open(filename, metadata_errors="ignore")
    try:
        if container.streams.video:
            stream = container.streams.video[0]
            if vframes == 'auto' and not open_ended:
                computed_vframes = max([
                    int((end_pts - start_pts) / stream.time_base) - 1,
                    1
                ])
            logging.debug(f"Chunking video frames to {computed_vframes} frames.")
            video_frames = _read_from_stream(
                container,
                start_pts,
                end_pts,
                stream,
                {"video": 0},
                computed_vframes,
            )
            video_fps = stream.average_rate
            if video_fps is not None:
                info["video_fps"] = float(video_fps)

        if container.streams.audio:
            stream = container.streams.audio[0]
            if aframes == 'auto' and not open_ended:
                frame_size = stream.codec_context.frame_size
                # A frame size of 0 would divide by zero; read without a
                # target in that case.
                if frame_size:
                    computed_aframes = max([
                        int((end_pts - start_pts) / stream.time_base / frame_size) - 1,
                        1
                    ])
            logging.debug(f"Chunking audio frames to {computed_aframes} frames.")
            audio_frames = _read_from_stream(
                container,
                start_pts,
                end_pts,
                stream,
                {"audio": 0},
                computed_aframes,
            )
            info["audio_fps"] = stream.rate
    finally:
        container.close()

    # (H, W, C) images are moved to channels-first before stacking.
    video_arrays = [
        _resize_img_ndarray(frame.to_rgb().to_ndarray(), size=vsize).transpose((2, 0, 1))
        for frame in video_frames
    ]
    audio_arrays = [frame.to_ndarray() for frame in audio_frames]

    try:
        vframes = torch.as_tensor(np.stack(video_arrays))
        aframes = torch.as_tensor(np.stack(audio_arrays))
    except Exception as e:
        # All-or-nothing: a chunk with unusable video or audio yields neither.
        logging.debug(f"Could not stack the frames into tensors. ({e})")
        vframes = aframes = None

    return vframes, aframes, info, computed_vframes, computed_aframes


def _resize_img_ndarray(frame:np.ndarray, size:tuple=(86, 48)) -> np.ndarray:
    '''
    Return `frame` as an image array of the given `size`.

    The array goes through PIL: it is only resized when its PIL size differs
    from `size`, and PIL's plain resize does not try to preserve the aspect
    ratio.
    '''
    picture = Image.fromarray(frame)
    if picture.size == size:
        # Already at the requested size, hand it back unchanged.
        return np.asarray(picture)
    return np.asarray(picture.resize(size))

def _can_read_timestamps_from_packets(container):
    """
    Tell whether the codec extradata of the container's first stream contains
    the marker ``b"Lavc"``; missing or empty extradata counts as "no".
    """
    codec_blob = container.streams[0].codec_context.extradata
    if not codec_blob:
        return False
    return b"Lavc" in codec_blob


def read_video_timestamps(filename):
    """
    List the video frames timestamps.

    Note that the function decodes the whole video frame-by-frame unless the
    timestamps can be taken from the packets directly.

    Parameters
    ----------
    filename : str
        path to the video file

    Returns
    -------
    pts : List[Fraction]
        presentation timestamps for each one of the frames in the video, in
        ascending order. Empty if the file has no video stream.
    video_fps : float or None
        the average frame rate for the video; ``None`` if the file has no
        video stream or the stream reports no average rate

    """
    ts = []
    video_fps = None

    container = av.open(filename, metadata_errors="ignore")
    try:
        if container.streams.video:
            stream = container.streams.video[0]
            if _can_read_timestamps_from_packets(container):
                # fast path: read the timestamps off the packets. Packets are
                # not guaranteed to arrive in presentation order, so sort to
                # match the ordering of the decoding path below.
                ts = sorted(
                    x.pts * stream.time_base
                    for x in container.demux(video=0)
                    if x.pts is not None
                )
            else:
                video_frames = _read_from_stream(
                    container,
                    0, float("inf"),
                    stream, {"video": 0}
                )
                ts = [x.pts * stream.time_base for x in video_frames]
            # The average rate can be unknown; keep None instead of failing
            # in float().
            if stream.average_rate is not None:
                video_fps = float(stream.average_rate)
    finally:
        container.close()

    return ts, video_fps
