import torch

def sparse_reshape(x: torch.Tensor, new_shape: tuple) -> torch.Tensor:
    """Reshape a sparse COO tensor between (B, S, V) and (B*S, V) without densifying.

    Only two transformations are supported:

    * merging the first two dimensions of a 3-D tensor, ``(B, S, V) -> (B*S, V)``;
    * the inverse split, ``(B*S, V) -> (B, S, V)``.

    The index arithmetic is row-major (``row = b * S + s``), i.e. the same
    layout as ``torch.Tensor.reshape``, so
    ``sparse_reshape(x, shape).to_dense()`` equals ``x.to_dense().reshape(shape)``.

    Args:
        x: Sparse COO tensor whose dimensions are all sparse (no dense dims).
            It may be uncoalesced; duplicate entries are summed.
        new_shape: Target shape. If it equals ``x.shape``, ``x`` is returned
            unchanged.

    Returns:
        ``x`` itself for the identity case, otherwise a new coalesced sparse
        COO tensor of shape ``new_shape``.

    Raises:
        ValueError: If ``x`` is not a ``torch.sparse_coo`` tensor, or the shapes
            are incompatible (last dimension differs, or ``B * S != BS``).
        NotImplementedError: For any other rank combination, or for hybrid
            tensors that carry dense dimensions.
    """
    if x.layout != torch.sparse_coo:
        raise ValueError(f"Expected a torch.sparse_coo tensor, got layout {x.layout}")

    # Identity: same shape -> return as-is
    if tuple(x.shape) == tuple(new_shape):
        return x

    merge = x.dim() == 3 and len(new_shape) == 2  # (B,S,V) -> (BS,V)
    split = x.dim() == 2 and len(new_shape) == 3  # (BS,V) -> (B,S,V)
    if not (merge or split):
        raise NotImplementedError(f"Only (B,S,V) <-> (BS,V) is supported, given '{x.shape=}' and '{new_shape=}'")
    if x.sparse_dim() != x.dim():
        # The index unpacking below needs one index row per dimension.
        raise NotImplementedError(
            f"Hybrid sparse tensors are not supported, given '{x.sparse_dim()=}' and '{x.dense_dim()=}'"
        )

    # `.indices()` raises on an uncoalesced tensor; coalescing also merges
    # duplicate entries and is a no-op for already-coalesced input.
    x = x.coalesce()
    indices, values = x.indices(), x.values()

    if merge:
        B, S, V = x.shape
        BS, V2 = new_shape
        if V2 != V or BS != B * S:
            raise ValueError(f"Cannot reshape {tuple(x.shape)} to {tuple(new_shape)}")
        b, s, v = indices
        rows = b * S + s  # row-major flattening of (b, s)
        idx2 = torch.stack([rows, v], dim=0)
        return torch.sparse_coo_tensor(idx2, values, (BS, V)).coalesce()

    BS, V = x.shape
    B, S, V2 = new_shape
    if V2 != V or BS != B * S:
        raise ValueError(f"Cannot reshape {tuple(x.shape)} to {tuple(new_shape)}")
    rows, v = indices
    # Indices are non-negative, so floor == truncation; the explicit rounding
    # mode keeps the intent unambiguous across torch versions.
    b = torch.div(rows, S, rounding_mode="floor")
    s = rows % S
    idx3 = torch.stack([b, s, v], dim=0)
    return torch.sparse_coo_tensor(idx3, values, (B, S, V)).coalesce()


