import torch

def distance_to_next_zero_1d(x: torch.Tensor) -> torch.Tensor:
    """
    Return, for each position of a 1D tensor, the distance to the next zero.

    Each output element is the distance from its index to the index of the
    first element equal to 0 at or to its right (0 if the element itself is 0).
    Positions with no zero anywhere to their right are measured to a virtual
    zero just past the end, i.e. their distance is ``len(x) - i``.

    For example, if x = [1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0],
    the output is [4, 3, 2, 1, 0, 1, 0, 3, 2, 1, 0];
    and if x = [1, 0, 1, 1], the output is [1, 0, 2, 1].

    Args:
        x (torch.Tensor): A 1D tensor (e.g. of 0/1 values, or bool). An element
            counts as a zero when it compares equal to 0 in its own dtype.

    Returns:
        torch.Tensor: A 1D float32 tensor of the same length as ``x``, on the
            same device, holding the distances to the next zero.

    Raises:
        RuntimeError: from ``torch.cat`` if ``x`` is not 1D.
    """
    # Zero mask with a sentinel zero appended at index len(x). The sentinel
    # guarantees every position has a zero at or to its right, so distances are
    # never negative and positions after the last real zero measure to the end.
    # Comparing in x's own dtype avoids treating e.g. 0.5 as zero (which a
    # prior .long() truncation would do).
    is_zero = torch.cat(
        [x == 0, torch.ones(1, dtype=torch.bool, device=x.device)]
    )

    # Indices of all zeros (including the sentinel); nonzero returns them in
    # ascending order, which searchsorted requires.
    zero_indices = is_zero.nonzero(as_tuple=True)[0]

    # Only the original positions are needed; the sentinel slot is excluded.
    positions = torch.arange(is_zero.shape[0] - 1, device=x.device)

    # side="left" yields, for each position i, the first zero index >= i. The
    # sentinel is the largest zero index and exceeds every position, so the
    # result is always a valid index into zero_indices (no clamping needed).
    next_zero = zero_indices[
        torch.searchsorted(zero_indices, positions, side="left")
    ]

    return (next_zero - positions).float()


if __name__ == "__main__":
    # Example usage on a small batch of two equal-length sequences:
    # 1) [1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0]
    # 2) [1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1]
    # The function only accepts 1D input, so it is applied row by row.
    x_2d = torch.tensor([
        [1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0],
        [1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1],
    ])

    for row in x_2d:
        print(distance_to_next_zero_1d(row))
    # Expected output (row 2's trailing ones have no zero to their right, so
    # they are measured to the end of the sequence):
    # tensor([4., 3., 2., 1., 0., 1., 0., 3., 2., 1., 0.])
    # tensor([1., 0., 2., 1., 0., 2., 1., 0., 3., 2., 1.])
