import torch
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.typing as nt
import numpy as np
import torch_xla.core.xla_model as xm
import math
import time

# SUBSTITUTE HERE

# @nki.jit
# def nki_rmsnorm_kernel_reference(a_tensor, g_tensor):
#   # Calculate out_tensor = a_tensor/RMS(a_tensor) * g_tensor
#   # Where RMS(a_tensor) = sqrt(eps + (1/N) * sum(a_tensor * a_tensor))
#   # and N = a_tensor.shape[1], eps is 1e-5
#   # Reduction (mean) is performed in the free (2nd) dimension
#     out_tensor = nl.ndarray(a_tensor.shape, dtype=a_tensor.dtype,
#                             buffer=nl.shared_hbm)

#     # Make sure shapes match
#     assert a_tensor.shape[1] == g_tensor.shape[0]

#     # Generate tensor indices to index input tensor
#     ix = nl.arange(128)[:, None]
#     iw = nl.arange(1)[:, None]
#     iy = nl.arange(a_tensor.shape[1])[None, :]

#     num_rows = a_tensor.shape[0]

#     # Load RMSNorm weight once, reused by rows/tiles of a_tensor
#     g_tile = nl.load(g_tensor.reshape((1, g_tensor.shape[0]))[iw, iy])

#     # Process 128 rows at a time due to 128-partition tile size limitation
#     # Since we're not reducing across the first dimension
#     # Tiles can be processed independently
#     for i in nl.affine_range(math.ceil(a_tensor.shape[0]/128)):

#         # Load input data from external memory to on-chip memory
#         a_tile = nl.load(a_tensor[i * 128 + ix, iy],
#                          mask=(i * 128 + ix < num_rows))

#         # Compute element-wise square of a_tensor
#         in_square = nl.square(a_tile)

#         # Calculate sum of squared elements, along last dimension
#         square_sum = nl.sum(in_square, axis=[1])

#         # Scale and get a reciprocal
#         mean = square_sum / a_tensor.shape[1]

#         # Take square root of mean and then reciprocal with
#         # rsqrt API (one ISA instruction)
#         rms_reciprocal = nl.rsqrt(mean + 1e-5)

#         # Scale the input tensor
#         out_tile = nl.multiply(a_tile, rms_reciprocal)

#         # Broadcast weight along first axis to match tensor shape
#         # num_rows_active = min(num_rows - i * 128, 128)
#         g_bcast = g_tile.broadcast_to((128, g_tensor.shape[0]))

#         # Multiply with the RMSNorm weight
#         out_tile[...] = nl.multiply(out_tile, g_bcast,
#                             mask=(i * 128 + ix < num_rows))

#         # store the addition results back to external memory (out_tensor)
#         nl.store(out_tensor[i * 128 + ix, iy], value=out_tile,
#                  mask=(i * 128 + ix < num_rows))

#     return out_tensor

@nki.jit
def nki_matmul_tiled_reference_(lhsT, rhs):
    """NKI kernel to compute a matrix multiplication operation in a tiled manner

    Args:
        lhsT: an input tensor of shape [K,M], where both K and M are multiples for
            128.  It is the left-hand-side argument of the matrix multiplication,
            delivered transposed for optimal performance.
        rhs: an input tensor of shape [K,N], where K is a multiple of 128, and N
            is a multiple of 512.  It is the right-hand-side argument of the matrix
            multiplication.
    Returns:
        result: the resulting output tensor of shape [M,N]
    """

    K, M = lhsT.shape
    K_, N = rhs.shape
    assert K == K_, "lhsT and rhs must have the same contraction dimension"
    result = nl.ndarray((M, N), dtype=lhsT.dtype, buffer=nl.shared_hbm)

    TILE_M = nl.tile_size.gemm_stationary_fmax  # 128
    TILE_K = nl.tile_size.pmax  # 128
    TILE_N = nl.tile_size.gemm_moving_fmax  # 512

    if M < TILE_M:
        TILE_M = M
    
    if N < TILE_N:
        TILE_N = N

    if K < TILE_K:
        TILE_K = K

    # Use affine_range to loop over tiles
    for m in nl.affine_range(M // TILE_M):
        for n in nl.affine_range(N // TILE_N):
            # Allocate a tensor in PSUM
            res_psum = nl.zeros((TILE_M, TILE_N), nl.float32, buffer=nl.psum)

            for k in nl.affine_range(K // TILE_K):
                # Declare the tiles on SBUF
                lhsT_tile = nl.ndarray((TILE_K, TILE_M), dtype=lhsT.dtype, buffer=nl.sbuf)
                rhs_tile = nl.ndarray((TILE_K, TILE_N), dtype=rhs.dtype, buffer=nl.sbuf)

                # Load tiles from lhsT and rhs
                lhsT_tile[...] = nl.load(lhsT[k * TILE_K:(k + 1) * TILE_K,
                                            m * TILE_M:(m + 1) * TILE_M])
                rhs_tile[...] = nl.load(rhs[k * TILE_K:(k + 1) * TILE_K,
                                            n * TILE_N:(n + 1) * TILE_N])

                # Accumulate partial-sums into PSUM
                res_psum += nl.matmul(lhsT_tile[...], rhs_tile[...], transpose_x=True)

            # Copy the result from PSUM back to SBUF, and cast to expected output data-type
            res_sb = nl.copy(res_psum, dtype=result.dtype)
            nl.store(result[m * TILE_M:(m + 1) * TILE_M, n * TILE_N:(n + 1) * TILE_N],
                    value=res_sb)

    return result

def forward_reference(x, up_proj_weight, gate_proj_weight, down_proj_weight):
    b, s, h = x.shape
    x = x.view(-1, h)
    # x = RmsNorm.apply(x, post_attention_layernorm_weight, 1e-5, len(x.shape) - 1)
    # x = nki_rmsnorm_kernel_reference(x, post_attention_layernorm_weight)
    up = nki_matmul_tiled_reference_(x.t(), up_proj_weight)
    gate = nki_matmul_tiled_reference_(x.t(), gate_proj_weight)
    act = torch.nn.SiLU()(gate) * up
    output = nki_matmul_tiled_reference_(act.t() , down_proj_weight)

    return output

def get_test_weights(hidden_size, intermediate_size, dtype, device):
    """Create test weights for MLP."""
    up_proj_weight = torch.randn(hidden_size, intermediate_size, dtype=dtype, device=device)
    gate_proj_weight = torch.randn(hidden_size, intermediate_size, dtype=dtype, device=device)
    down_proj_weight = torch.randn(intermediate_size, hidden_size, dtype=dtype, device=device)
    return (
        up_proj_weight,
        gate_proj_weight,
        down_proj_weight,
    )


def compare_outputs(reference_out, test_out, atol=1e-3, rtol=1e-3):
    """Compare test output against reference output."""
    if not torch.allclose(reference_out, test_out, atol=atol, rtol=rtol):
        ref_cpu = reference_out.cpu()
        test_cpu = test_out.cpu()
        print("reference_out[:8]: %s", ref_cpu.flatten()[:8])
        print("test_out[:8]: %s", test_cpu.flatten()[:8])
        diff = (ref_cpu - test_cpu).abs()
        print("max_diff: %s", diff.max())
        print("mean_diff: %s", diff.mean())
        print("FAIL: test output does not match reference")
        return False
    return True

def benchmark_forward(forward_fn, x, weights, perf_iters=50):
    """Benchmark a forward function."""
    # Warmup
    _ = forward_fn(x, *weights)
    
    t0 = time.time()
    for _ in range(perf_iters):
        _ = forward_fn(x, *weights)
    t1 = time.time()
    ms_per_iter = (t1 - t0) * 1000.0 / perf_iters
    return ms_per_iter

def test_nki(ref_func, test_func):
    torch.manual_seed(0)
    dtype = torch.bfloat16
    device = xm.xla_device()
    hidden_size = 2048
    intermediate_size = 4096
    weights = get_test_weights(hidden_size, intermediate_size, dtype, device)
    
    for _ in range(2):
        batch, seq = 1, 1
        x = torch.randn(batch, seq, hidden_size, dtype=dtype, device=device)
        ref_out = ref_func(x, *weights)
        test_out = test_func(x, *weights)
        if not compare_outputs(ref_out, test_out):
            return False
    return True

def benchmark_nki(nki_func):
    torch.manual_seed(0)
    dtype = torch.bfloat16
    device = xm.xla_device()
    hidden_size = 2048
    intermediate_size = 4096
    weights = get_test_weights(hidden_size, intermediate_size, dtype, device)
    batch, seq = 1, 1
    x = torch.randn(batch, seq, hidden_size, dtype=dtype, device=device)
    
    test_ms = benchmark_forward(nki_func, x, weights, 50)
    print("Latency: {:.3f} ms/iter".format(test_ms))

if __name__ == "__main__":
    test_result = test_nki(forward_reference, solution)
    if not test_result:
        print("Test failed")
        exit(1)
    else:
        print("Test passed")
        benchmark_nki(solution)
