def identity_with_cast(q, k, v, offset: int = 0):
    return q.to(v.dtype), k.to(v.dtype), v
