from fractions import Fraction
import bisect
import math
import os
import sys

import torch
from tqdm import tqdm

sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from utils.io_overrides import (
    read_chunk,
    read_video_timestamps,
)


class _DummyDataset(object):
    """
    Dummy dataset used for DataLoader in VideoClips.
    Defined at top level so it can be pickled when forking.
    """

    def __init__(self, x):
        self.x = x

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return read_video_timestamps(self.x[idx])


class VideoClips(object):
    """
    Given a list of video files, computes all consecutive subvideos of size
    `clip_length_in_frames`, where the distance between each subvideo in the
    same video is defined by `frames_between_clips`.

    Creating this instance the first time is time-consuming, as it needs to
    decode all the videos in `video_paths`. It is recommended that you
    cache the results after instantiation of the class (the `metadata`
    property) and pass them back later through `_precomputed_metadata`.

    Recreating the clips for different clip lengths is fast, and can be done
    with the `compute_clips` method.

    Arguments:
        video_paths (List[str]): paths to the video files
        clip_length_in_frames (int): size of a clip in number of frames
        frames_between_clips (int): step (in frames) between each clip
        _precomputed_metadata (dict, optional): the `metadata` of a previous
            instance. When given, timestamps are not re-read from disk and
            `video_paths` is taken from the metadata instead.
        num_workers (int): how many subprocesses to use for data loading.
            0 means that the data will be loaded in the main process. (default: 0)
    """

    def __init__(
        self,
        video_paths,
        clip_length_in_frames=16,
        frames_between_clips=1,
        _precomputed_metadata=None,
        num_workers=0,
    ):

        self.video_paths = video_paths
        self.num_workers = num_workers
        self.cumulative_sizes = []

        if _precomputed_metadata is None:
            self._compute_frame_pts()
        else:
            self._init_from_metadata(_precomputed_metadata)
        self.compute_clips(clip_length_in_frames, frames_between_clips)

    @staticmethod
    def _collate_fn(batch):
        """
        Collate a batch of ``read_video_timestamps`` results.

        Each element is a ``(timestamps, fps)`` pair. The timestamps are
        converted to Python floats and the pairs are returned as a plain list
        rather than stacked, because per-video timestamp lists need not have
        equal length. This is a staticmethod so the DataLoader can pickle it
        for worker processes without pickling the VideoClips instance.
        """
        return [([float(pts) for pts in item[0]], item[1]) for item in batch]

    def _compute_frame_pts(self):
        """
        Read timestamps and frame rate of every video in ``video_paths``.

        Populates ``self.video_pts`` (one 1-D float64 tensor per video) and
        ``self.video_fps`` (one entry per video), in ``video_paths`` order.
        """
        self.video_pts = []
        self.video_fps = []

        # strategy: use a DataLoader to parallelize read_video_timestamps
        # so need to create a dummy dataset first
        import torch.utils.data

        dl = torch.utils.data.DataLoader(
            _DummyDataset(self.video_paths),
            batch_size=16,
            num_workers=self.num_workers,
            collate_fn=self._collate_fn,
        )

        with tqdm(total=len(dl)) as pbar:
            for batch in dl:
                pbar.update(1)
                pts_lists, fps = zip(*batch)
                # float64 keeps the timestamps as close as possible to the
                # values read from the file; the default float32 would round
                # them before get_clip passes them back to read_chunk.
                self.video_pts.extend(
                    torch.as_tensor(pts, dtype=torch.float64) for pts in pts_lists
                )
                self.video_fps.extend(fps)

    def _init_from_metadata(self, metadata):
        """
        Restore ``video_paths``, ``video_pts`` and ``video_fps`` from a dict
        shaped like the one returned by the ``metadata`` property.
        """
        self.video_paths = metadata["video_paths"]
        assert len(self.video_paths) == len(metadata["video_pts"]), (
            "metadata['video_pts'] must have one entry per video path"
        )
        self.video_pts = metadata["video_pts"]
        assert len(self.video_paths) == len(metadata["video_fps"]), (
            "metadata['video_fps'] must have one entry per video path"
        )
        self.video_fps = metadata["video_fps"]

    @property
    def metadata(self):
        """Dict of per-video data that can be fed back as `_precomputed_metadata`."""
        _metadata = {
            "video_paths": self.video_paths,
            "video_pts": self.video_pts,
            "video_fps": self.video_fps,
        }
        return _metadata

    def compute_clips(self, num_frames, step):
        """
        Compute all consecutive sequences of clips from video_pts.
        Always returns clips of size `num_frames`, meaning that the
        last few frames in a video can potentially be dropped.

        Arguments:
            num_frames (int): number of frames for the clip
            step (int): distance between two clips

        Raises:
            ValueError: if ``num_frames`` or ``step`` is smaller than 1.
        """
        # Validate before touching any state so a bad call leaves the
        # previously computed clips intact.
        if num_frames < 1:
            raise ValueError("num_frames must be >= 1, got {}".format(num_frames))
        if step < 1:
            raise ValueError("step must be >= 1, got {}".format(step))

        self.num_frames = num_frames
        self.step = step
        self.clips = [
            self._unfold(video_pts, num_frames, step) for video_pts in self.video_pts
        ]
        # Explicit int64 so the cumulative sizes are ints even with no videos.
        clip_lengths = torch.as_tensor(
            [len(v) for v in self.clips], dtype=torch.int64
        )
        self.cumulative_sizes = clip_lengths.cumsum(0).tolist()

    def _unfold(self, tensor, size, step, dilation=1):
        """
        similar to tensor.unfold, but with the dilation
        and specialized for 1d tensors

        Returns all consecutive windows of `size` elements, with
        `step` between windows. The distance between each element
        in a window is given by `dilation`. The result is a
        ``(num_windows, size)`` view sharing storage with `tensor`;
        ``num_windows`` is 0 when `tensor` is too short for one window.
        """
        assert tensor.dim() == 1
        o_stride = tensor.stride(0)
        numel = tensor.numel()
        new_stride = (step * o_stride, dilation * o_stride)
        # One window spans dilation * (size - 1) + 1 consecutive elements.
        new_size = ((numel - (dilation * (size - 1) + 1)) // step + 1, size)
        if new_size[0] < 1:
            new_size = (0, size)
        return torch.as_strided(tensor, new_size, new_stride)

    def __len__(self):
        return self.num_clips()

    def num_videos(self):
        """Number of videos in `video_paths`."""
        return len(self.video_paths)

    def num_clips(self) -> int:
        """
        Number of subclips that are available in the video list.
        """
        if len(self.cumulative_sizes) > 0:
            return self.cumulative_sizes[-1]
        else:
            return 0

    def get_clip_location(self, idx):
        """
        Converts a flattened representation of the indices into a video_idx, clip_idx
        representation.
        """
        # bisect_right picks the first video whose cumulative clip count
        # exceeds idx; videos with zero clips are skipped naturally.
        video_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if video_idx == 0:
            clip_idx = idx
        else:
            clip_idx = idx - self.cumulative_sizes[video_idx - 1]
        return video_idx, clip_idx

    def get_clip(self, idx, vsize: tuple = (86, 48)):
        """
        Gets a subclip from a list of videos.

        Arguments:
            idx (int): index of the subclip. Must satisfy
                0 <= idx < num_clips().
            vsize (tuple): forwarded unchanged to ``read_chunk``.

        Returns:
            video (Tensor): frames from ``read_chunk`` with dims 0 and 1 swapped
            audio (Tensor): samples from ``read_chunk`` with dims 0 and 1 swapped
            info (Dict): info object returned by ``read_chunk``
            video_idx (int): index of the video in `video_paths`

        Raises:
            IndexError: if ``idx`` is negative or >= num_clips().
        """
        num_clips = self.num_clips()
        # Negative indices are rejected too: get_clip_location would otherwise
        # map them to clips counted from the end of the first video.
        if not 0 <= idx < num_clips:
            raise IndexError(
                "Index {} out of range "
                "({} number of clips)".format(idx, num_clips)
            )
        video_idx, clip_idx = self.get_clip_location(idx)
        video_path = self.video_paths[video_idx]
        clip_pts = self.clips[video_idx][clip_idx]

        # read_chunk receives only the first and last timestamps of the clip
        # together with the expected frame count (vframes=num_frames).
        start_pts = clip_pts[0].item()
        end_pts = clip_pts[-1].item()

        video, audio, info, _, computed_asamples = read_chunk(
            video_path,
            start_pts, end_pts,
            vframes=self.num_frames, aframes='auto',
            vsize=vsize,
        )

        assert len(video) == self.num_frames, (
            "expected {} video frames, got {}".format(self.num_frames, len(video))
        )
        assert len(audio) == computed_asamples, (
            "expected {} audio samples, got {}".format(computed_asamples, len(audio))
        )

        return video.transpose(1, 0), audio.transpose(1, 0), info, video_idx
