@nki.jit
def solution(q, k, v):
  """Beta 2 migrated attention kernel (matches `4_attention_test.py` semantics).

  Inputs are shaped (d_head=128, seqlen). Output is shaped (seqlen, d_head).
  """
  d_head, seqlen = q.shape
  assert q.shape == k.shape == v.shape
  assert d_head == nl.tile_size.pmax  # 128
  assert seqlen % 128 == 0
  assert seqlen >= 512

  PMAX = nl.tile_size.pmax                # 128
  Q_TILE = PMAX                           # 128 queries per tile
  K_TILE_SOFTMAX = nl.tile_size.gemm_moving_fmax  # 512 keys per tile for QK
  K_TILE_OUT = PMAX                       # 128 keys per tile for output matmul

  out = nl.ndarray((seqlen, d_head), dtype=q.dtype, buffer=nl.shared_hbm)

  n_q_tiles = seqlen // Q_TILE
  n_k_tiles_softmax = seqlen // K_TILE_SOFTMAX
  n_k_tiles_out = seqlen // K_TILE_OUT

  # Temporary tiles reused per q-tile.
  q_tile = nl.ndarray((PMAX, Q_TILE), dtype=q.dtype, buffer=nl.sbuf)
  k_tile = nl.ndarray((PMAX, K_TILE_SOFTMAX), dtype=k.dtype, buffer=nl.sbuf)

  qk_buf = nl.ndarray((PMAX, seqlen), dtype=nl.float32, buffer=nl.sbuf)
  row_max = nl.ndarray((PMAX, 1), dtype=nl.float32, buffer=nl.sbuf)
  norm_buf = nl.ndarray((PMAX, seqlen), dtype=nl.float32, buffer=nl.sbuf)
  exp_buf = nl.ndarray((PMAX, seqlen), dtype=nl.float32, buffer=nl.sbuf)
  sum_exp = nl.ndarray((PMAX, 1), dtype=nl.float32, buffer=nl.sbuf)
  inv_sum = nl.ndarray((PMAX, 1), dtype=nl.float32, buffer=nl.sbuf)
  scores = nl.ndarray((PMAX, seqlen), dtype=nl.float32, buffer=nl.sbuf)

  scores_tile = nl.ndarray((PMAX, K_TILE_OUT), dtype=nl.float32, buffer=nl.sbuf)
  v_tile = nl.ndarray((PMAX, K_TILE_OUT), dtype=v.dtype, buffer=nl.sbuf)
  attn_out = nl.ndarray((Q_TILE, d_head), dtype=q.dtype, buffer=nl.sbuf)

  for i_q in nl.affine_range(n_q_tiles):
    q_start = i_q * Q_TILE
    q_end = q_start + Q_TILE

    # Load q tile once.
    nisa.dma_copy(dst=q_tile, src=q[0:PMAX, q_start:q_end])

    # Build qk_buf for this q tile in 512-key chunks.
    for i_k in nl.affine_range(n_k_tiles_softmax):
      k_start = i_k * K_TILE_SOFTMAX
      k_end = k_start + K_TILE_SOFTMAX
      nisa.dma_copy(dst=k_tile, src=k[0:PMAX, k_start:k_end])
      # Declare qk_psum inside the loop so each nc_matmul starts fresh (no cross-iteration accumulation).
      qk_psum = nl.ndarray((Q_TILE, K_TILE_SOFTMAX), dtype=nl.float32, buffer=nl.psum)
      nisa.nc_matmul(dst=qk_psum, stationary=q_tile, moving=k_tile)
      nisa.tensor_copy(dst=qk_buf[0:PMAX, k_start:k_end], src=qk_psum)

    # Softmax across seqlen dimension for each query row.
    nisa.tensor_reduce(dst=row_max, op=nl.maximum, data=qk_buf, axis=1, keepdims=True)
    nisa.tensor_scalar(dst=norm_buf, data=qk_buf, op0=nl.subtract, operand0=row_max)
    nisa.activation(dst=exp_buf, op=nl.exp, data=norm_buf)
    nisa.tensor_reduce(dst=sum_exp, op=nl.add, data=exp_buf, axis=1, keepdims=True)
    nisa.reciprocal(dst=inv_sum, data=sum_exp)
    nisa.tensor_scalar(dst=scores, data=exp_buf, op0=nl.multiply, operand0=inv_sum)

    # Attention output: accumulate in sbuf, not psum
    attn_accum = nl.ndarray((Q_TILE, d_head), dtype=nl.float32, buffer=nl.sbuf)
    nisa.memset(dst=attn_accum, value=0.0)

    scores_t = nl.ndarray((PMAX, Q_TILE),    dtype=nl.float32, buffer=nl.sbuf)
    v_t      = nl.ndarray((PMAX, K_TILE_OUT), dtype=nl.float32, buffer=nl.sbuf)
    tmp_sbuf = nl.ndarray((Q_TILE, d_head), dtype=nl.float32, buffer=nl.sbuf)

    for i_k in nl.affine_range(n_k_tiles_out):
      k_start = i_k * K_TILE_OUT
      k_end   = k_start + K_TILE_OUT

      nisa.tensor_copy(dst=scores_tile, src=scores[0:PMAX, k_start:k_end])
      nisa.dma_copy(dst=v_tile, src=v[0:PMAX, k_start:k_end])

      # Tensor Engine nc_transpose: SBUF -> PSUM.  Declare psum INSIDE the loop
      # so the compiler allocates a fresh slot each iteration (overwrite, not accumulate).
      # Then tensor_copy brings the transposed tile to SBUF for nc_matmul.
      scores_t_psum = nl.ndarray((PMAX, Q_TILE),    dtype=nl.float32, buffer=nl.psum)
      nisa.nc_transpose(dst=scores_t_psum, data=scores_tile)
      nisa.tensor_copy(dst=scores_t, src=scores_t_psum)

      v_t_psum = nl.ndarray((PMAX, K_TILE_OUT), dtype=nl.float32, buffer=nl.psum)
      nisa.nc_transpose(dst=v_t_psum, data=v_tile)
      nisa.tensor_copy(dst=v_t, src=v_t_psum)

      # Declare tmp_psum inside the loop so each nc_matmul starts fresh.
      tmp_psum = nl.ndarray((Q_TILE, d_head), dtype=nl.float32, buffer=nl.psum)
      nisa.nc_matmul(dst=tmp_psum, stationary=scores_t, moving=v_t)
      nisa.tensor_copy(dst=tmp_sbuf, src=tmp_psum)
      nisa.tensor_tensor(dst=attn_accum, data1=attn_accum, data2=tmp_sbuf, op=nl.add)

    # Store result
    nisa.tensor_copy(dst=attn_out, src=attn_accum)
    nisa.dma_copy(dst=out[q_start:q_end, 0:d_head], src=attn_out)

  return out
