def get_ring_indices_vectorized(H, W, device='cpu'):
    """
    Generates ring-scan indices from inner to outer shells with alternating directions.

    Pixels are grouped into concentric shells by their L-infinity (Chebyshev)
    distance from the grid centre, and shells are visited innermost first.
    Within a shell, pixels are ordered by their angle around the centre,
    starting from the top-left diagonal direction. Even shells (0, 2, ...) are
    swept clockwise and odd shells counter-clockwise, in image coordinates
    where rows increase downward.

    A complete square shell therefore starts at its top-left corner. When
    H != W, the outer shells are not complete squares (e.g. only a left and a
    right column). Such a shell starts at the first pixel reached from the
    top-left diagonal in that shell's sweep direction.

    Args:
        H, W: Dimensions of the feature map.
        device: Torch device (e.g., 'cuda' or 'cpu').

    Returns:
        indices: 1D int64 tensor of length H*W. ``flat[indices]`` reorders a
            row-major flattened map into ring-scan order.
        inv_indices: 1D int64 tensor such that ``scanned[inv_indices]``
            restores the original row-major order.
    """

    # 1. Centred coordinate grid (y: rows, x: columns).
    y = torch.linspace(-(H - 1) / 2, (H - 1) / 2, H, device=device)
    x = torch.linspace(-(W - 1) / 2, (W - 1) / 2, W, device=device)
    grid_y, grid_x = torch.meshgrid(y, x, indexing='ij')

    # 2. L-infinity distance max(|x|, |y|) defines the concentric square shells.
    ring_dist = torch.max(torch.abs(grid_y), torch.abs(grid_x))

    # 3. Shell id = rank of ring_dist among its sorted unique values
    #    (0 is the innermost). return_inverse yields exactly that rank in one
    #    vectorised call, replacing a Python loop with one full-grid mask per shell.
    _, shell_id = torch.unique(ring_dist, return_inverse=True)

    # 4. Angle around the centre, shifted so the top-left diagonal (-3pi/4)
    #    maps to 0 and values lie in [0, 2pi). With y pointing down, an
    #    increasing angle is a clockwise sweep on screen.
    angle = torch.atan2(grid_y, grid_x)
    start_offset = -3 * math.pi / 4
    norm_angle = (angle - start_offset) % (2 * math.pi)

    # Top-left corners sit exactly on the seam. A last-bit rounding error in
    # atan2 could wrap them to ~2pi, moving them from the first to the last
    # position of their shell. Pin them to 0 explicitly. The comparisons are
    # exact because ring_dist is literally one of |y| and |x|.
    is_top_left = (grid_y == -ring_dist) & (grid_x == -ring_dist)
    norm_angle = norm_angle.masked_fill(is_top_left, 0.0)

    # 5. Alternate directions. Even shells keep the clockwise angle. Odd shells
    #    mirror it (2pi - a), which reverses the sweep while keeping the start
    #    point at 0.
    final_angle = torch.where(
        shell_id % 2 == 0,
        norm_angle,
        (2 * math.pi - norm_angle) % (2 * math.pi),
    )

    # 6. Order by shell first, then by angle within the shell. Two passes
    #    (angle, then a stable sort on the integer shell id) are exact. A packed
    #    float key such as `ring_dist * 100 + angle` loses float32 resolution
    #    as ring_dist grows.
    by_angle = torch.argsort(final_angle.reshape(-1))
    _, by_shell = torch.sort(shell_id.reshape(-1)[by_angle], stable=True)
    indices = by_angle[by_shell]

    # 7. Inverse permutation: inv_indices[indices[k]] = k.
    inv_indices = torch.empty_like(indices)
    inv_indices[indices] = torch.arange(indices.numel(), device=indices.device)

    return indices, inv_indices