def csr_to_sparse_tensor(csr_mat, device='cuda'):
    """
    Convert a SciPy sparse matrix to a float32 PyTorch sparse COO tensor.

    Args:
        csr_mat: scipy.sparse.csr_matrix - Input CSR matrix. Any other object
            exposing ``tocsr()`` (e.g. a COO or CSC matrix) is converted first.
        device: str - Target device ('cuda' or 'cpu')

    Returns:
        torch.Tensor - Sparse COO tensor with float32 values, int64 indices
        and the same shape as the input, on the specified device. Stored
        entries are passed through as-is; no coalescing is performed here.

    Raises:
        ValueError: If the input is not a csr_matrix and has no ``tocsr()``
            method. The original AttributeError is attached as the cause.
    """
    # Ensure the input is in CSR format
    if not isinstance(csr_mat, csr_matrix):
        try:
            csr_mat = csr_mat.tocsr()
        except AttributeError as exc:
            # Chain explicitly so the traceback shows why conversion failed.
            raise ValueError(
                "Input matrix must be scipy.sparse.csr_matrix or convertible to CSR"
            ) from exc

    indptr = csr_mat.indptr

    # CSR -> COO row coordinates: indptr[i + 1] - indptr[i] is the number of
    # stored entries in row i, so each row id is repeated that many times.
    row = np.repeat(np.arange(len(indptr) - 1), np.diff(indptr))

    # Stack row/column coordinates into the 2xN layout sparse_coo_tensor takes
    indices = torch.from_numpy(np.vstack((row, csr_mat.indices))).long()

    # Cast to float32 in NumPy before handing the buffer to torch:
    # torch.from_numpy only accepts a fixed set of dtypes, while the output
    # is float32 regardless of the input's value dtype.
    values = torch.from_numpy(np.asarray(csr_mat.data, dtype=np.float32))

    # Create the sparse tensor and move to specified device
    sparse_tensor = torch.sparse_coo_tensor(
        indices, values, torch.Size(csr_mat.shape)
    )
    return sparse_tensor.to(device)
