def _adj_matmul(X, Y):
    """Return ``X @ Y`` for a sparse COO or dense ``X``; sparse products are coalesced."""
    if X.is_sparse:
        return torch.sparse.mm(X, Y).coalesce()
    return X @ Y


def _dense_row_slice(M, idx):
    """Return ``M[idx, :]`` as a dense ``(len(idx), M.size(1))`` tensor."""
    rows = M.index_select(0, idx)
    return rows.to_dense() if rows.is_sparse else rows


def _dense_col_slice_t(M, idx):
    """Return ``M[:, idx].T`` as a dense ``(len(idx), M.size(0))`` tensor."""
    cols = M.index_select(1, idx)
    return (cols.to_dense() if cols.is_sparse else cols).t()


def _pairwise_entries(M, src, dst):
    """Return the 1-D dense tensor whose i-th value is ``M[src[i], dst[i]]``."""
    if not M.is_sparse:
        return M[src, dst]
    # Advanced indexing (M[src, dst]) is not available for sparse COO tensors:
    # select the needed rows, then keep only the nonzero in column dst[i] of row i.
    rows = M.index_select(0, src).coalesce()
    row, col = rows.indices()
    hit = col == dst[row]
    out = torch.zeros(src.numel(), dtype=rows.dtype, device=rows.device)
    return out.index_add_(0, row[hit], rows.values()[hit])


def capped_shortest_path_mat(A, edge_index, k=6, batch_size=100000, *, progress=True):
    """
    Compute inverse shortest-path lengths, capped at ``min(k, 6)``, for node pairs.

    The distance of a pair ``(s, t)`` is the smallest ``p`` in ``1..min(k, 6)``
    for which a walk of length ``p`` from ``s`` to ``t`` exists, i.e.
    ``(A^p)[s, t] > 0``; the shortest such walk is the shortest path. Pairs with
    no such ``p`` get the sentinel length 999. Lengths 1-3 are read directly from
    ``A``, ``A^2`` and ``A^3``; lengths 4-6 are split at an intermediate node,
    ``(A^p)[s, t] = sum_m (A^a)[s, m] * (A^b)[m, t]`` with ``a + b = p``, so no
    power above ``A^3`` is ever materialized.

    Args:
        A: Square adjacency matrix, sparse COO or dense. Entries are assumed
            non-negative, so a positive entry / product sum means a walk exists.
            ``A[s, t] > 0`` is read as an edge ``s -> t`` (directed graphs work;
            for symmetric ``A`` this is the usual undirected distance).
        edge_index: LongTensor of shape ``(2, E)``; column ``i`` is the pair
            ``(edge_index[0, i], edge_index[1, i])``. All work runs on its device.
        k: Maximum path length to consider. Values above 6 behave like 6;
            ``k <= 0`` leaves every pair at the sentinel.
        batch_size: Pairs processed per step. Rows/columns of ``A^2``/``A^3`` are
            densified per batch, so peak memory grows with ``batch_size * num_nodes``.
        progress: Wrap the batch loop in a ``tqdm`` progress bar.

    Returns:
        CPU float tensor of shape ``(E,)`` holding ``1 / distance``
        (``1 / 999`` for pairs that are unreachable within the cap).
    """
    device = edge_index.device
    A = A.to(device)
    if A.is_sparse:
        A = A.coalesce()
    # Only A, A^2 and A^3 are built, so 3 + 3 = 6 is the longest detectable walk.
    k = min(k, 6)

    if edge_index.size(1) == 0:
        return torch.empty(0)

    # powers[p] == A^p for 1 <= p <= min(k, 3); index 0 is padding.
    powers = [None, A]
    while len(powers) <= min(k, 3):
        powers.append(_adj_matmul(powers[-1], A))

    # (walk length, source-side power, target-side power), longest first.
    splits = [(p, a, b) for p, a, b in ((6, 3, 3), (5, 2, 3), (4, 2, 2)) if p <= k]

    batches = edge_index.split(batch_size, dim=1)
    score_list = []
    for src, dst in (tqdm(batches) if progress else batches):
        # 999 marks "not reachable within the cap"; it becomes 1/999 in the output.
        dist = torch.full((src.numel(),), 999.0, device=device)

        # Write the longest lengths first so every shorter walk overwrites them:
        # the value left is the smallest p with a walk, i.e. the shortest path.
        for p, a, b in splits:
            walks = (
                _dense_row_slice(powers[a], src) * _dense_col_slice_t(powers[b], dst)
            ).sum(dim=1)
            dist[walks > 0] = p
        for p in range(min(k, 3), 0, -1):
            dist[_pairwise_entries(powers[p], src, dst) > 0] = p

        score_list.append(dist.cpu())  # collect on CPU to release device memory

    return 1 / torch.cat(score_list)
