import time
import torch


def _blockwise_permutation(n: int, block_len: int, device) -> torch.Tensor:
    """Return a permutation of ``range(n)`` that only shuffles within consecutive blocks of ``block_len``.

    A single ``randperm(block_len)`` is generated and reused for every block, offset by the
    block start. Values that fall past ``n`` in the last (partial) block are dropped, so the
    result contains every index in ``[0, n)`` exactly once.
    """
    n_blocks = -(-n // block_len)  # ceil(n / block_len)
    local_perm = torch.randperm(block_len, device=device).repeat(n_blocks)
    block_offsets = torch.arange(n_blocks, device=device).repeat_interleave(block_len) * block_len
    perm = local_perm + block_offsets
    # The last block may extend past n; its in-range values are still a permutation of its tail.
    return perm[perm < n]


def vec_sample_indices_from_classes(t: torch.Tensor,
                                    n_max: int,
                                    shuffle_within_class: bool = True,
                                    perm_len_for_repeat: int = -1,
                                    ) -> torch.Tensor:
    """
    Sample indices from a tensor based on specified classes.

    This function samples indices from a tensor `t` which contains classes in increasing order.
    It allows for sampling a specified maximum number of indices per class, with options to
    shuffle within each class and to optimize permutation generation for large vectors.

    Parameters
    ----------
    t : torch.Tensor
        A tensor with a single dimension specifying the classes in non-decreasing order.
        Equal values form one class; the values themselves may be of any comparable dtype.
    n_max : int
        The maximum number of indices to take from each class.
    shuffle_within_class : bool, optional
        If False, the first items within a class are taken. Default is True.
    perm_len_for_repeat : int, optional
        Creating permutations may be slow for very large vectors. If `perm_len_for_repeat` > 0,
        a small permutation of length `perm_len_for_repeat` is created and reused for every
        consecutive block of that length (offset by the block start). The result is still a
        valid permutation of all indices, but shuffling only happens within a block, so the
        members of a class that straddles a block boundary are not shuffled uniformly.
        Default is -1 (use a full permutation).

    Returns
    -------
    torch.Tensor
        A 1-D ``torch.long`` tensor of indices into `t`, grouped by class in increasing class
        order. An empty `t` yields an empty tensor.

    Raises
    ------
    AssertionError
        If `t` is not non-decreasing.

    Examples
    --------
    For `n_max=3`, the result should be:

        t        =  [0,1,1,1,1,1,2,2,2,2,3, 3 ]
        i        =  [0,1,2,3,4,5,6,7,8,9,10,11]
    returns
        idcs     =  [0,1,2,3,    6,7,8,  10,11]  (if shuffle_within_class = False)
        idcs     =  [0,2,5,3,    9,8,7,  10,11]  (if shuffle_within_class = True)

    this can be used for selection from t:
        t[idcs]  =  [0,1,1,1,    2,2,2,  3, 3]

    Notes
    -----
    The algorithm is vectorized and works roughly as follows:
        1. Identify the start and length of each class run.
        2. Crop the lengths to the target length.
        3. Use a combination of `torch.cumsum` and `repeat_interleave` to form the indices:
            a) First find offset_idcs:              [0 1 1 1 6 6 6 10 10]
            b) Then find restarted_count_from_0:    [0 0 1 2 0 1 2  0  1]
            c) indices          --> sum them up     [0 1 2 3 6 7 8 10 11]
        4. (Optional) The algorithm so far always takes the first `n_max` within a class.
           To shuffle the indices within a class:
            a) Create a permutation 'perm' to permute the tensor.
            b) Get the argsort indices 'unpermute_indices' to unpermute the tensor.
            c) Apply perm[unpermute_indices][indices].
    """
    with torch.no_grad():
        n = t.shape[0]
        if n == 0:
            return torch.empty(0, dtype=torch.long, device=t.device)

        ## 1.
        assert torch.all(t[1:] >= t[:-1]), "Assert monotonicity, i.e. t[i+1] - t[i] >= 0"
        # Positions where a new class starts (excluding position 0).
        jump_up_idcs = torch.where(t[1:] != t[:-1])[0] + 1
        # Boundaries [0, start_1, ..., start_k, n] are kept in long so they can be used as
        # indices and repeat counts regardless of t's dtype.
        idcs_incl_start_excl_end = torch.cat([
            torch.zeros(1, dtype=torch.long, device=t.device),
            jump_up_idcs,
            torch.full((1,), n, dtype=torch.long, device=t.device),
        ])
        start_indices = idcs_incl_start_excl_end[:-1]
        lengths = idcs_incl_start_excl_end[1:] - start_indices

        ## 2.
        target_lengths = torch.clamp(lengths, max=n_max)

        ## 3.a
        offset_indices = torch.repeat_interleave(start_indices, target_lengths)

        ## 3.b
        cum_lens = torch.cumsum(target_lengths, dim=0)
        # Output position at which each class's selected run begins.
        output_run_starts = cum_lens - target_lengths
        indices_to_subtract = torch.repeat_interleave(output_run_starts, target_lengths)

        lens_arange = torch.arange(int(cum_lens[-1]), device=t.device)
        restarted_count_from_0 = lens_arange - indices_to_subtract

        ## 3.c
        indices = offset_indices + restarted_count_from_0

        ## 4
        if shuffle_within_class:
            ## 4.a
            if perm_len_for_repeat > 0:
                # In case a full permutation is too slow (for very large len(t)), build a
                # block-wise permutation from one small permutation.
                perm = _blockwise_permutation(n, min(perm_len_for_repeat, n), t.device)
            else:
                perm = torch.randperm(n, device=t.device)

            ## 4.b
            # Sorting the permuted labels restores class order (t is sorted), while the
            # order of members within each class now follows the random permutation.
            unpermute_select_idcs = torch.argsort(t[perm])
            ## 4.c
            indices = perm[unpermute_select_idcs][indices]

        return indices
    
if __name__ == '__main__':
    # Use a GPU when available; a hard-coded 'cuda:2' fails on machines with fewer GPUs.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Non-monotonic example (should trip the monotonicity assertion):
    # tensor = torch.tensor([0,1,1,1,1,1,2,2,2,2,1, 1 ])
    # tensor = torch.tensor([0,1,1,1,1,1,2,2,2,2,3, 3 ])
    tensor = torch.tensor([0, 1, 1, 1, 1, 1, 3, 3, 4]).to(device)
    # Larger benchmark input:
    # num_classes = 1_000_000
    # tensor = torch.cat([
    #     torch.full((torch.randint(1, 6, (1,)).item(),), i) for i in range(num_classes)
    # ]).to(device)

    n_max = 3
    use_cuda = device.startswith('cuda')
    if use_cuda:
        torch.cuda.synchronize()
    start_t = time.perf_counter()
    sampled_indices = vec_sample_indices_from_classes(tensor, n_max)
    if use_cuda:
        # CUDA kernels run asynchronously; wait for them before reading the clock.
        torch.cuda.synchronize()
    print(f'needed {time.perf_counter()-start_t:0.6f} seconds')
    print(f'{sampled_indices=}')
    print(f'{tensor[sampled_indices][:100]=}')