import random
import itertools
import torch

class SamplingPolicy(torch.utils.data.Sampler):
    """Sampler that yields dataset indices in bundles of same-content snippets.

    The snippets of each long-form content are split into bundles of
    ``sample_size`` indices drawn from the same temporal window. Bundles from
    different contents are then interleaved: every round emits one bundle from
    each content that still has bundles left, in a freshly shuffled order.
    Flattened, each aligned run of ``sample_size`` consecutive indices
    therefore comes from a single content.
    """

    def __init__(self, num_snippets, sample_size, sampling_window):
        """
        Assuming our dataset is an instance of torch.utils.data.Dataset which chronologically
        enumerates all snippets from all the long-form contents, this class allows us to
        iterate over the dataset according to our proposed Sampling Policy.

        num_snippets (list):
            The number of snippets associated with each long-form content, i.e. [M_0, M_1, ..., M_N].
        sample_size (int):
            Number of snippets sampled from each long-form content in creation of a minibatch.
            Must be >= 1.
        sampling_window (int):
            The farthest two samples from the same long-form content can be from one another.
            Use -1 to sample from the entire content; otherwise it must be a multiple of
            ``sample_size`` that is >= ``sample_size``.

        Raises:
            ValueError: if ``sample_size`` or ``sampling_window`` is invalid.
        """
        self.k = sample_size
        self.w = sampling_window
        self._check_params()
        # (start offset in the flat dataset, snippet count) for each content.
        # A running offset avoids re-summing every prefix (quadratic in N).
        self.splits = []
        offset = 0
        for count in num_snippets:
            self.splits.append((offset, count))
            offset += count

    def _check_params(self):
        """Raise ValueError unless sample_size / sampling_window are a valid combination."""
        if self.k < 1:
            raise ValueError(f"sample_size must be >= 1, got {self.k}")
        if self.w == -1:
            return
        if self.w < self.k or self.w % self.k:
            raise ValueError(
                "sampling_window must be -1 or a multiple of sample_size that is >= "
                f"sample_size ({self.k}), got {self.w}"
            )

    def _window(self, l):
        """Return the chunk length used to split a content of ``l`` snippets.

        With ``sampling_window == -1`` the whole content is a single window.
        ``max(l, 1)`` keeps the value usable as a ``range`` step when ``l == 0``.
        """
        self._check_params()
        if self.w == -1:
            return max(l, 1)
        return self.w

    def len_helper(self, l):
        """Number of indices a content of ``l`` snippets contributes per epoch.

        Each window contributes its length rounded down to a multiple of
        ``sample_size``. Full windows are already multiples (enforced by
        ``_check_params``), so only the trailing partial window loses snippets.
        """
        window = self._window(l)
        full_windows, remainder = divmod(l, window)
        return (full_windows * (window // self.k) + remainder // self.k) * self.k

    def grouping_helper(self, bias, l):
        """Group the snippets ``bias .. bias + l - 1`` into shuffled ``sample_size``-tuples.

        The range is cut into consecutive windows of ``_window(l)`` snippets.
        Each window is permuted and grouped into tuples of ``sample_size``
        indices, and any leftover indices in the window are dropped. The
        combined list of tuples is shuffled before it is returned.
        """
        window = self._window(l)
        snippets = range(bias, bias + l)
        groups = []
        for start in range(0, l, window):
            chunk = snippets[start:start + window]
            shuffled = random.sample(population=chunk, k=len(chunk))
            # zip over k references to one iterator -> consecutive k-tuples;
            # the incomplete tail of the window is dropped.
            groups.extend(zip(*[iter(shuffled)] * self.k))
        random.shuffle(groups)
        return groups

    def __iter__(self):
        """Yield indices: one bundle per remaining content per round, round order reshuffled."""
        # Within-content grouping of snippets into bundles of size k.
        bundles_per_content = [self.grouping_helper(bias, l) for bias, l in self.splits]
        # remaining[i] = number of bundles of content i not yet emitted.
        remaining = [len(bundles) for bundles in bundles_per_content]
        # Contents with no complete bundle (fewer than k snippets) must be left
        # out up front; indexing their empty list would raise IndexError.
        active = [i for i, n in enumerate(remaining) if n]
        ordered = []
        while active:
            random.shuffle(active)
            for i in active:
                remaining[i] -= 1
                ordered.append(bundles_per_content[i][remaining[i]])
            active = [i for i in active if remaining[i]]
        return iter(list(itertools.chain.from_iterable(ordered)))

    def __len__(self):
        """Total number of indices yielded by one pass of ``__iter__``."""
        return sum(self.len_helper(l) for _, l in self.splits)
