@nki.jit
def solution(input_tensor, epsilon, gamma_vector, beta_vector):
  """Apply LayerNorm along the last dimension of a 2-D tensor.

  For every row ``x`` of ``input_tensor`` this computes::

      y = ((x - mean(x)) / sqrt(var(x) + epsilon)) * gamma + beta

  where mean/var are taken over the last (free) dimension and the variance
  is obtained as ``E[x^2] - (E[x])^2``.

  Args:
    input_tensor: 2-D tensor indexed as (rows, features); rows are processed
      in tiles of at most 128.
    epsilon: value added to the variance before the square root; passed
      straight through as the scalar operand of a ``tensor_scalar`` add.
    gamma_vector: 1-D scale vector whose length equals
      ``input_tensor.shape[1]``.
    beta_vector: 1-D shift vector whose length equals
      ``input_tensor.shape[1]``.

  Returns:
    A tensor in shared HBM with the same shape and dtype as ``input_tensor``.
  """

  # Row tile size (partition limit); column chunk size for nc_matmul (max 128x512)
  TILE_ROWS = 128
  PARAM_BCAST_CHUNK_COLS = 512

  assert input_tensor.shape[1] == gamma_vector.shape[0] == beta_vector.shape[0]

  num_rows = input_tensor.shape[0]
  n_f = input_tensor.shape[1]

  output_tensor = nl.ndarray(input_tensor.shape, dtype=input_tensor.dtype,
                             buffer=nl.shared_hbm)

  # Load gamma/beta once; reused for all row tiles.
  gamma_tile = nl.ndarray((1, gamma_vector.shape[0]), dtype=gamma_vector.dtype, buffer=nl.sbuf)
  beta_tile = nl.ndarray((1, beta_vector.shape[0]), dtype=beta_vector.dtype, buffer=nl.sbuf)
  nisa.dma_copy(dst=gamma_tile, src=gamma_vector.reshape((1, gamma_vector.shape[0])))
  nisa.dma_copy(dst=beta_tile, src=beta_vector.reshape((1, beta_vector.shape[0])))

  # Integer ceil-division, same form as the column-chunk loop below.
  num_row_tiles = (num_rows + TILE_ROWS - 1) // TILE_ROWS
  num_col_chunks = (n_f + PARAM_BCAST_CHUNK_COLS - 1) // PARAM_BCAST_CHUNK_COLS

  # Process 128 rows at a time (tile size limit); tiles are independent.
  for i in nl.affine_range(num_row_tiles):
    p_start = i * TILE_ROWS
    p_end = min(num_rows, p_start + TILE_ROWS)
    tile_rows = p_end - p_start

    # Load input tile from HBM to on-chip.
    x_tile = nl.ndarray((tile_rows, n_f), dtype=input_tensor.dtype, buffer=nl.sbuf)
    nisa.dma_copy(dst=x_tile, src=input_tensor[p_start:p_end, 0:n_f])

    # sum(x) and sum(x^2) along the last dimension.
    # The statistics are kept in float32 regardless of the input dtype so the
    # row sums are not rounded to a narrow (e.g. 16-bit) format before the
    # mean/variance arithmetic below, which is already float32.
    sum_x = nl.ndarray((tile_rows, 1), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_reduce(dst=sum_x, op=nl.add, data=x_tile, axis=1, keepdims=True)

    x_square = nl.ndarray((tile_rows, n_f), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_tensor(dst=x_square, data1=x_tile, data2=x_tile, op=nl.multiply)
    sum_x2 = nl.ndarray((tile_rows, 1), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_reduce(dst=sum_x2, op=nl.add, data=x_square, axis=1, keepdims=True)

    # mean = sum(x) / n_f, ex2 = sum(x^2) / n_f
    mean = nl.ndarray((tile_rows, 1), dtype=nl.float32, buffer=nl.sbuf)
    ex2 = nl.ndarray((tile_rows, 1), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_scalar(dst=mean, data=sum_x, op0=nl.multiply, operand0=1.0 / n_f)
    nisa.tensor_scalar(dst=ex2, data=sum_x2, op0=nl.multiply, operand0=1.0 / n_f)

    # var(x) = max(E[x^2] - (E[x])^2, 0)
    # The subtraction can round slightly below zero for (near-)constant rows;
    # clamping keeps sqrt(var + epsilon) from producing NaN in that case.
    mean_sq = nl.ndarray((tile_rows, 1), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_tensor(dst=mean_sq, data1=mean, data2=mean, op=nl.multiply)
    var = nl.ndarray((tile_rows, 1), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_scalar(dst=var, data=ex2, op0=nl.subtract, operand0=mean_sq,
                       op1=nl.maximum, operand1=0.0)

    # inv_std = 1 / sqrt(var + epsilon)
    var_eps = nl.ndarray((tile_rows, 1), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_scalar(dst=var_eps, data=var, op0=nl.add, operand0=epsilon)
    sqrt_var = nl.ndarray((tile_rows, 1), dtype=nl.float32, buffer=nl.sbuf)
    nisa.activation(dst=sqrt_var, op=nl.sqrt, data=var_eps)
    inv_std = nl.ndarray((tile_rows, 1), dtype=nl.float32, buffer=nl.sbuf)
    nisa.reciprocal(dst=inv_std, data=sqrt_var)

    # Normalize: (x - mean) * inv_std, with per-row (tile_rows, 1) operands.
    out_tile = nl.ndarray((tile_rows, n_f), dtype=input_tensor.dtype, buffer=nl.sbuf)
    nisa.tensor_scalar(dst=out_tile, data=x_tile, op0=nl.subtract, operand0=mean,
                       op1=nl.multiply, operand1=inv_std)

    # Broadcast gamma/beta to (tile_rows, n_f) in column chunks and apply:
    # out = out * gamma + beta
    # The broadcast is done as a matmul of a (1, tile_rows) row of ones with
    # the (1, chunk) parameter slice, giving a (tile_rows, chunk) PSUM tile.
    # NOTE(review): `ones` is float32 while the gamma/beta chunks keep their
    # own dtype -- confirm nc_matmul accepts that mix for non-float32 params.
    ones = nl.ndarray((1, tile_rows), dtype=nl.float32, buffer=nl.sbuf)
    nisa.memset(dst=ones, value=1.0)
    for j in nl.affine_range(num_col_chunks):
      j_start = j * PARAM_BCAST_CHUNK_COLS
      j_end = min(j_start + PARAM_BCAST_CHUNK_COLS, n_f)
      chunk = j_end - j_start

      gamma_chunk = nl.ndarray((1, chunk), dtype=gamma_vector.dtype, buffer=nl.sbuf)
      beta_chunk = nl.ndarray((1, chunk), dtype=beta_vector.dtype, buffer=nl.sbuf)
      nisa.dma_copy(dst=gamma_chunk, src=gamma_tile[0:1, j_start:j_end])
      nisa.dma_copy(dst=beta_chunk, src=beta_tile[0:1, j_start:j_end])

      gamma_bcast_psum = nl.ndarray((tile_rows, chunk), dtype=gamma_vector.dtype, buffer=nl.psum)
      beta_bcast_psum = nl.ndarray((tile_rows, chunk), dtype=beta_vector.dtype, buffer=nl.psum)
      nisa.nc_matmul(dst=gamma_bcast_psum, stationary=ones, moving=gamma_chunk, is_stationary_onezero=True)
      nisa.nc_matmul(dst=beta_bcast_psum, stationary=ones, moving=beta_chunk, is_stationary_onezero=True)

      # In-place scale then shift on this column slice of the output tile.
      nisa.tensor_tensor(dst=out_tile[0:tile_rows, j_start:j_end],
                         data1=out_tile[0:tile_rows, j_start:j_end],
                         data2=gamma_bcast_psum, op=nl.multiply)
      nisa.tensor_tensor(dst=out_tile[0:tile_rows, j_start:j_end],
                         data1=out_tile[0:tile_rows, j_start:j_end],
                         data2=beta_bcast_psum, op=nl.add)

    # Store result tile back to HBM.
    nisa.dma_copy(dst=output_tensor[p_start:p_end, 0:n_f], src=out_tile)

  return output_tensor

