import torch

def sparse_gather(input: torch.Tensor, dim: int, index: torch.Tensor,
                  default: float | torch.Tensor = 0.0) -> torch.Tensor:
    """
    Gather values from a sparse COO tensor along a specified dimension.

    This function behaves like torch.gather but works with sparse COO tensors.
    For positions that are not stored in the sparse tensor, it returns the default value.

    Like torch.gather, this selects values from input along dimension dim at positions
    specified by index. The key difference is that if a position isn't stored in the
    sparse tensor, default is returned instead of raising an error.

    Args:
        input: Sparse COO tensor to gather from
        dim: The dimension along which to gather values
        index: Dense tensor of indices specifying which values to gather.
               Must have same number of dimensions as input.
        default: Default value to return for positions not in the sparse tensor

    Returns:
        Dense tensor with same shape as index, containing gathered values
    """
    if not input.is_sparse:
        # If input is dense, use regular gather
        return torch.gather(input, dim, index)

    if input.layout != torch.sparse_coo:
        raise ValueError("sparse_gather only supports COO sparse tensors")

    # Coalesce to ensure indices are unique
    input = input.coalesce()

    # Get sparse tensor components
    sparse_indices = input.indices()  # Shape: [ndim, nnz]
    sparse_values = input.values()  # Shape: [nnz]
    input_shape = input.shape
    device = input.device
    dtype = input.dtype

    # Normalize dimension
    ndim = len(input_shape)
    dim = dim if dim >= 0 else ndim + dim
    if not (0 <= dim < ndim):
        raise ValueError(f"dim {dim} out of range for tensor of dimension {ndim}")

    # Validate index tensor
    if index.ndim != ndim:
        raise ValueError(f"index tensor must have same number of dimensions as input. "
                         f"Got index.ndim={index.ndim}, input.ndim={ndim}")

    # Initialize output with default values
    default = torch.as_tensor(default, dtype=dtype, device=device)
    output = torch.broadcast_to(default, index.shape).clone().to(dtype=dtype, device=device)

    if sparse_values.numel() == 0:
        # No stored values, return all defaults
        return output

    # Create coordinate grids corresponding to each position in the index tensor
    coord_grids = torch.meshgrid([torch.arange(s, device=device, dtype=torch.int32)
                                  for s in index.shape], indexing='ij')

    # For each position in the output, we need to construct the lookup coordinate
    # by replacing the dim-th coordinate with the corresponding index value
    lookup_coords = []
    for d in range(ndim):
        if d == dim:
            # Use the index values for the gather dimension
            lookup_coords.append(index)
        else:
            # Use the coordinate grid for other dimensions
            lookup_coords.append(coord_grids[d])

    # Stack coordinates to get shape [ndim, *index.shape]
    lookup_indices = torch.stack(lookup_coords, dim=0)

    # Flatten the lookup indices to [ndim, num_elements]
    original_shape = lookup_indices.shape[1:]
    lookup_indices_flat = lookup_indices.view(ndim, -1)
    output_flat = output.view(-1)

    # For efficient lookup, we'll match each lookup coordinate against sparse coordinates
    # This can be done more efficiently using broadcasting and matching

    if sparse_indices.size(1) > 0:
        # More memory-efficient approach: use hash-like lookups instead of large match tensor
        # Convert sparse coordinates to flat indices for efficient lookup
        sparse_strides = torch.tensor([input_shape[i + 1:].numel() for i in range(ndim)],
                                      device=device, dtype=torch.long)
        sparse_flat_indices = (sparse_indices * sparse_strides.unsqueeze(1)).sum(dim=0)

        # Convert lookup coordinates to flat indices
        lookup_flat_indices = (lookup_indices_flat * sparse_strides.unsqueeze(1)).sum(dim=0)

        # Use searchsorted for efficient lookup (requires sorted sparse indices)
        sorted_sparse_flat, sort_indices = torch.sort(sparse_flat_indices)

        # Find positions of lookup indices in the sorted sparse indices
        positions = torch.searchsorted(sorted_sparse_flat, lookup_flat_indices)

        # Handle boundary conditions and check for exact matches
        valid_positions = positions < len(sorted_sparse_flat)
        exact_matches = torch.zeros_like(lookup_flat_indices, dtype=torch.bool)

        if valid_positions.any():
            # Check which positions have exact matches
            candidate_indices = sorted_sparse_flat[positions[valid_positions]]
            lookup_subset = lookup_flat_indices[valid_positions]
            matches_subset = (candidate_indices == lookup_subset)
            exact_matches[valid_positions] = matches_subset

            # Get the corresponding sparse values for exact matches
            if matches_subset.any():
                matched_positions = positions[valid_positions][matches_subset]
                original_sparse_indices = sort_indices[matched_positions]
                matched_lookup_positions = torch.where(valid_positions)[0][matches_subset]
                output_flat[matched_lookup_positions] = sparse_values[original_sparse_indices]

    return output_flat.view(original_shape)
