import torch
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl
import neuronxcc.nki.isa as nisa
import math

@nki.jit
def nki_rmsnorm_kernel(a_tensor, g_tensor):
    """Compute RMSNorm of ``a_tensor`` scaled by the weight ``g_tensor``.

    The result is ``a_tensor / RMS(a_tensor) * g_tensor`` where
    ``RMS(a_tensor) = sqrt(eps + (1/N) * sum(a_tensor * a_tensor))``,
    ``N = a_tensor.shape[1]`` and ``eps = 1e-5``.  The mean is reduced along
    the second (free) dimension, so every row is normalized independently.

    Args:
        a_tensor: input tensor, indexed here as [row, column].
        g_tensor: weight vector; its length must equal ``a_tensor.shape[1]``.

    Returns:
        A tensor allocated in shared HBM with the shape and dtype of
        ``a_tensor``.
    """
    # The weight must supply one entry per column of the input
    assert a_tensor.shape[1] == g_tensor.shape[0]

    tile_rows = 128  # rows handled per iteration (128-partition tile limit)
    eps = 1e-5
    total_rows = a_tensor.shape[0]
    hidden = a_tensor.shape[1]

    normalized = nl.ndarray(a_tensor.shape, dtype=a_tensor.dtype,
                            buffer=nl.shared_hbm)

    # Index helpers: rows within a tile, the single weight row, and columns
    part_idx = nl.arange(tile_rows)[:, None]
    single_idx = nl.arange(1)[:, None]
    col_idx = nl.arange(hidden)[None, :]

    # The weight is loaded once as a (1, N) row and reused by every tile
    g_as_row = g_tensor.reshape((1, g_tensor.shape[0]))
    g_row = nl.load(g_as_row[single_idx, col_idx])

    # Rows are never reduced against each other, so the row tiles are
    # independent; the masks disable rows past the end in the last tile.
    for tile in nl.affine_range(math.ceil(total_rows / tile_rows)):

        # Bring one tile of rows from external memory on chip
        rows = nl.load(a_tensor[tile * tile_rows + part_idx, col_idx],
                       mask=(tile * tile_rows + part_idx < total_rows))

        # Per-row sum of squares along the last dimension
        sum_sq = nl.sum(nl.square(rows), axis=[1])

        # 1 / sqrt(mean(x^2) + eps), using the rsqrt API for the
        # square root and reciprocal together
        inv_rms = nl.rsqrt(sum_sq / hidden + eps)

        # Normalize each row by its reciprocal RMS
        scaled = nl.multiply(rows, inv_rms)

        # Expand the weight row across the tile's rows
        gain = g_row.broadcast_to((tile_rows, g_tensor.shape[0]))

        # Apply the RMSNorm weight
        scaled[...] = nl.multiply(scaled, gain,
                            mask=(tile * tile_rows + part_idx < total_rows))

        # Write the normalized tile back to external memory
        nl.store(normalized[tile * tile_rows + part_idx, col_idx], value=scaled,
                 mask=(tile * tile_rows + part_idx < total_rows))

    return normalized

@nki.jit
def nki_matmul_tiled_(lhsT, rhs):
    """NKI kernel to compute a matrix multiplication operation in a tiled manner

    Tile sizes start at (gemm_stationary_fmax, pmax, gemm_moving_fmax) for
    (M, K, N) and are shrunk to the full dimension when a dimension is
    smaller than its tile.  Partial sums are accumulated in a float32 PSUM
    buffer and cast to the output dtype when copied out.

    Args:
        lhsT: an input tensor of shape [K,M].  It is the left-hand-side
            argument of the matrix multiplication, delivered transposed for
            optimal performance.  K and M must each be either no larger than
            their tile size (128) or a multiple of it.
        rhs: an input tensor of shape [K,N].  It is the right-hand-side
            argument of the matrix multiplication.  N must be either no
            larger than its tile size (512) or a multiple of it.
    Returns:
        result: the resulting output tensor of shape [M,N], allocated in
            shared HBM with the dtype of lhsT
    Raises:
        AssertionError: if the contraction dimensions differ, a dimension is
            empty, or a dimension is not evenly divisible by its tile size.
    """

    K, M = lhsT.shape
    K_, N = rhs.shape
    assert K == K_, "lhsT and rhs must have the same contraction dimension"
    assert M > 0 and N > 0 and K > 0, "matmul dimensions must be non-empty"

    # Shrink a tile to the whole dimension when the dimension is smaller
    TILE_M = min(nl.tile_size.gemm_stationary_fmax, M)  # at most 128
    TILE_K = min(nl.tile_size.pmax, K)  # at most 128
    TILE_N = min(nl.tile_size.gemm_moving_fmax, N)  # at most 512

    # The loops below iterate over whole tiles only (floor division), so a
    # remainder would silently leave part of the result unwritten or part of
    # the contraction unsummed.  Reject such shapes up front.
    assert M % TILE_M == 0, "M must be a multiple of the M tile size"
    assert N % TILE_N == 0, "N must be a multiple of the N tile size"
    assert K % TILE_K == 0, "K must be a multiple of the K tile size"

    result = nl.ndarray((M, N), dtype=lhsT.dtype, buffer=nl.shared_hbm)

    # Use affine_range to loop over tiles
    for m in nl.affine_range(M // TILE_M):
        for n in nl.affine_range(N // TILE_N):
            # Allocate a tensor in PSUM
            res_psum = nl.zeros((TILE_M, TILE_N), nl.float32, buffer=nl.psum)

            for k in nl.affine_range(K // TILE_K):
                # Declare the tiles on SBUF
                lhsT_tile = nl.ndarray((TILE_K, TILE_M), dtype=lhsT.dtype, buffer=nl.sbuf)
                rhs_tile = nl.ndarray((TILE_K, TILE_N), dtype=rhs.dtype, buffer=nl.sbuf)

                # Load tiles from lhsT and rhs
                lhsT_tile[...] = nl.load(lhsT[k * TILE_K:(k + 1) * TILE_K,
                                            m * TILE_M:(m + 1) * TILE_M])
                rhs_tile[...] = nl.load(rhs[k * TILE_K:(k + 1) * TILE_K,
                                            n * TILE_N:(n + 1) * TILE_N])

                # Accumulate partial-sums into PSUM
                res_psum += nl.matmul(lhsT_tile[...], rhs_tile[...], transpose_x=True)

            # Copy the result from PSUM back to SBUF, and cast to expected output data-type
            res_sb = nl.copy(res_psum, dtype=result.dtype)
            nl.store(result[m * TILE_M:(m + 1) * TILE_M, n * TILE_N:(n + 1) * TILE_N],
                    value=res_sb)

    return result

def solution(x, post_attention_layernorm_weight, up_proj_weight, gate_proj_weight, down_proj_weight):
    """Run RMSNorm followed by a gated (SiLU) MLP using the NKI kernels.

    Steps: flatten ``x`` to 2-D, apply ``nki_rmsnorm_kernel`` with the
    layernorm weight, project with the up and gate weights, combine as
    ``SiLU(gate) * up``, then project with the down weight.  Every
    projection passes transposed operands to ``nki_matmul_tiled_``, which
    expects its left operand already transposed.

    Prefill shapes noted by the original author (not checked here):
        x: (1, 32, 2048)
        post_attention_layernorm_weight: (2048,)
        up_proj_weight, gate_proj_weight: (8192, 2048)
        down_proj_weight: (2048, 8192)
    For those shapes the returned tensor is (32, 2048).

    Args:
        x: 3-D input; the last dimension is the hidden size.
        post_attention_layernorm_weight: RMSNorm weight vector.
        up_proj_weight: weight of the up projection.
        gate_proj_weight: weight of the gate projection.
        down_proj_weight: weight of the down projection.

    Returns:
        The 2-D output of the down projection.
    """
    # Unpacking three values keeps the requirement that x is 3-D
    _, _, hidden = x.shape
    tokens = x.view(-1, hidden)

    normed = nki_rmsnorm_kernel(tokens, post_attention_layernorm_weight)
    normed_t = normed.t()

    up = nki_matmul_tiled_(normed_t, up_proj_weight.t())
    gate = nki_matmul_tiled_(normed_t, gate_proj_weight.t())

    gated = torch.nn.functional.silu(gate) * up
    return nki_matmul_tiled_(gated.t(), down_proj_weight.t())
