
import torch

class Model(torch.nn.Module):
    """Enumerate the elements of each group in order of appearance."""

    def __init__(self):
        super().__init__()

    def forward(self, group_inds):
        """
        Assign each element its running index within its own group.

        e.g. [0, 0, 1, 0, 1] -> [0, 1, 0, 2, 1]

        group_inds: (N,) integer tensor of group ids (any integer dtype;
            ids need not be contiguous or sorted)
        Returns:
            out_inds: (N,) long - 0..K-1 index within group, numbered in
                order of first-to-last appearance in ``group_inds``
        """
        num_elements = group_inds.numel()
        if num_elements == 0:
            # Nothing to enumerate; keep shape/device, force the long dtype.
            return torch.zeros_like(group_inds, dtype=torch.long)

        # 1. Stable sort so identical ids become contiguous runs while
        #    keeping their original relative order. Without stable=True the
        #    order of ties (and hence the within-group numbering) is not
        #    guaranteed.
        sorted_groups, perm = torch.sort(group_inds, stable=True)

        # 2. Mark the first element of every run in the sorted sequence.
        run_start = torch.ones(
            num_elements, dtype=torch.bool, device=group_inds.device
        )
        run_start[1:] = sorted_groups[1:] != sorted_groups[:-1]

        # 3. Within-group index = sorted position - position of the run start.
        #    Zeroing every non-start position and taking a running maximum
        #    broadcasts each run's start position across the whole run
        #    (positions are increasing, so the latest start always wins).
        positions = torch.arange(num_elements, device=group_inds.device)
        start_positions = torch.cummax(
            positions.masked_fill(~run_start, 0), dim=0
        ).values
        inner_indices = positions - start_positions

        # 4. Undo the sort. ``perm`` is a permutation, so every slot of
        #    ``out`` is written exactly once. ``out`` is long regardless of
        #    the input dtype.
        out = torch.empty_like(perm)
        out[perm] = inner_indices
        return out

def get_init_inputs():
    """Return the positional constructor arguments for ``Model`` (none)."""
    init_args = []
    return init_args

def get_inputs():
    """Build sample forward inputs: one (100,) long tensor of ids in [0, 10)."""
    num_elements = 100
    num_groups = 10
    group_ids = torch.randint(low=0, high=num_groups, size=(num_elements,))
    return [group_ids.long()]
