import nki
import nki.isa as nisa
import nki.language as nl
import numpy as np
import torch
from torch_xla.core import xla_model as xm
import os


# SUBSTITUTE HERE

# Finite stand-in for -inf, used by flash_attention_core to mask scores where
# causal_mask == 0 via (qk - NEG_INF) * mask + NEG_INF. A true -inf would break
# that formula (inf * 0 = NaN for masked entries, inf - inf = NaN for kept
# ones), so the value must be finite.
# NOTE(review): 9984 is presumably chosen because it is exactly representable
# in bfloat16 (1.00111b * 2**13) -- confirm before changing it.
NEG_INF = -9984.0


@nki.jit
def flash_attention_core(q, k, v, causal_mask,
                          kernel_dtype, acc_type,
                          num_heads=8,
                          seq_len=2048,
                          d_head=128):
    """Reference causal attention forward pass: softmax statistics + unnormalized output.

    For every head h and every B_P_SIZE-row query tile, the kernel builds the
    masked score block ``s = q[h, rows] @ k[h]`` over all seq_len keys and
    writes:

    * ``m[h, rows, 0] = max_j s[:, j]`` (row max of the masked scores),
    * ``l[h, rows, 0] = m + log(sum_j exp(s[:, j] - m))`` (row log-sum-exp),
    * ``o[h, rows, :] = sum_j exp(s[:, j] - m) * v[h, j, :]``. This is NOT
      divided by the softmax denominator; ``o * exp(m - l)`` gives
      ``softmax(s) @ v``.

    Scores are not scaled by ``1/sqrt(d_head)`` (the exp activation uses
    ``scale=1.0``). Every key tile is processed and masked; tiles entirely
    above the diagonal are not skipped.

    Args:
        q: (num_heads, seq_len, d_head) queries, per the
            ``q[head_idx, q_start:q_end, 0:d_head]`` reads.
        k: (num_heads, d_head, seq_len) keys, i.e. already transposed, per the
            ``k[head_idx, 0:d_head, k_start:k_end]`` reads.
        v: (num_heads, seq_len, d_head) values.
        causal_mask: (seq_len, seq_len) multiplicative mask shared by all
            heads. 1 keeps a score; 0 replaces it with NEG_INF. Presumably
            only 0/1 is intended, since other values would scale
            ``(qk - NEG_INF)``.
        kernel_dtype: dtype of the q/k/v/probability tiles and of ``o``.
        acc_type: dtype of the score buffer, row max, row sums and log-sum-exp.
        num_heads, seq_len, d_head: problem sizes; must agree with the inputs.

    Returns:
        (o, l, m) allocated in shared HBM. ``o`` is
        (num_heads, seq_len, d_head) in ``kernel_dtype``; ``l`` and ``m`` are
        (num_heads, seq_len, 1) float32.

    Size constraints implied by the tiling (not checked here):
        * seq_len should be a multiple of B_F_SIZE (512). Otherwise the tail
          columns of ``qk_res_buf`` are never written but are still read by
          the row max. If seq_len is not a multiple of 128, tail query rows
          are never computed either.
        * seq_len must be <= 4096. ``REDUCTION_TILE = min(2048, seq_len // 2)``
          and only two REDUCTION_TILE passes feed the row sum and the PV
          product, so columns beyond ``2 * REDUCTION_TILE`` would be silently
          dropped from ``l`` and ``o`` while still contributing to ``m``.
        * NOTE(review): d_head is the leading dim of ``q_tile_T`` / ``k_tile``,
          so presumably d_head <= 128 -- confirm against NKI tile limits.
    """
    # B_P_SIZE: query rows per tile (also the key-chunk size of the PV
    # product). B_F_SIZE: key columns per QK matmul.
    B_P_SIZE = 128
    B_F_SIZE = 512
    # Width of each of the two exp/sum passes below. Together they cover
    # 2 * REDUCTION_TILE columns, which equals seq_len only for even
    # seq_len <= 4096.
    REDUCTION_TILE = min(2048, seq_len // 2)
    num_k_tiles = seq_len // B_F_SIZE
    num_q_tiles = seq_len // B_P_SIZE

    o = nl.ndarray((num_heads, seq_len, d_head), dtype=kernel_dtype, buffer=nl.shared_hbm)
    l = nl.ndarray((num_heads, seq_len, 1), dtype=nl.float32, buffer=nl.shared_hbm)
    m = nl.ndarray((num_heads, seq_len, 1), dtype=nl.float32, buffer=nl.shared_hbm)

    for head_idx in range(num_heads):
        for q_tile_idx in nl.affine_range(num_q_tiles):
            q_start = q_tile_idx * B_P_SIZE
            q_end = q_start + B_P_SIZE

            # Load the (B_P_SIZE, d_head) query tile and transpose it (through
            # PSUM) to (d_head, B_P_SIZE). Its leading dim then matches
            # k_tile's d_head for the QK matmul.
            q_tile = nl.ndarray((B_P_SIZE, d_head), dtype=kernel_dtype, buffer=nl.sbuf)
            nisa.dma_copy(dst=q_tile, src=q[head_idx, q_start:q_end, 0:d_head])

            q_tile_T_psum = nl.ndarray((d_head, B_P_SIZE), dtype=nl.float32, buffer=nl.psum)
            nisa.nc_transpose(dst=q_tile_T_psum, data=q_tile)
            q_tile_T = nl.ndarray((d_head, B_P_SIZE), dtype=kernel_dtype, buffer=nl.sbuf)
            nisa.tensor_copy(dst=q_tile_T, src=q_tile_T_psum)

            # Masked scores of this query tile against all seq_len keys.
            qk_res_buf = nl.ndarray((B_P_SIZE, seq_len), buffer=nl.sbuf, dtype=acc_type)

            # Staging tiles allocated once and overwritten by every k_i iteration.
            k_tile = nl.ndarray((d_head, B_F_SIZE), dtype=kernel_dtype, buffer=nl.sbuf)
            causal_tile = nl.ndarray((B_P_SIZE, B_F_SIZE), dtype=nl.float32, buffer=nl.sbuf)

            # Fill qk_res_buf one B_F_SIZE-wide column block at a time.
            for k_i in nl.affine_range(num_k_tiles):
                k_start = k_i * B_F_SIZE
                k_end = k_start + B_F_SIZE
                nisa.dma_copy(dst=k_tile, src=k[head_idx, 0:d_head, k_start:k_end])
                # (d_head, B_P_SIZE) x (d_head, B_F_SIZE) -> (B_P_SIZE, B_F_SIZE) scores.
                qk_psum = nl.ndarray((B_P_SIZE, B_F_SIZE), dtype=nl.float32, buffer=nl.psum)
                nisa.nc_matmul(dst=qk_psum, stationary=q_tile_T, moving=k_tile)
                qk_sbuf = nl.ndarray((B_P_SIZE, B_F_SIZE), dtype=acc_type, buffer=nl.sbuf)
                nisa.tensor_copy(dst=qk_sbuf, src=qk_psum)
                # Always apply causal mask via (qk + 9984) * mask - 9984.
                # causal_mask=0 for future positions → -9984; causal_mask=1 elsewhere → qk.
                nisa.dma_copy(dst=causal_tile, src=causal_mask[q_start:q_end, k_start:k_end])
                qk_shifted = nl.ndarray((B_P_SIZE, B_F_SIZE), dtype=acc_type, buffer=nl.sbuf)
                nisa.tensor_scalar(dst=qk_shifted, data=qk_sbuf, op0=nl.add, operand0=-NEG_INF)
                masked_shifted = nl.ndarray((B_P_SIZE, B_F_SIZE), dtype=acc_type, buffer=nl.sbuf)
                nisa.tensor_tensor(dst=masked_shifted, data1=qk_shifted, data2=causal_tile, op=nl.multiply)
                masked_qk = nl.ndarray((B_P_SIZE, B_F_SIZE), dtype=acc_type, buffer=nl.sbuf)
                nisa.tensor_scalar(dst=masked_qk, data=masked_shifted, op0=nl.add, operand0=NEG_INF)
                nisa.tensor_copy(dst=qk_res_buf[0:B_P_SIZE, k_start:k_end], src=masked_qk)

            # Compute max over full qk_res_buf (no dynamic offset — reads from offset 0)
            max_ = nl.ndarray((B_P_SIZE, 1), dtype=acc_type, buffer=nl.sbuf)
            nisa.tensor_reduce(dst=max_, op=nl.maximum, data=qk_res_buf, axis=1, keepdims=True)

            # m output: row max of the masked scores.
            nisa.dma_copy(dst=m[head_idx, q_start:q_end, 0:1], src=max_)

            # -max is the per-row activation bias, giving exp(s - max).
            neg_max = nl.ndarray((B_P_SIZE, 1), dtype=acc_type, buffer=nl.sbuf)
            nisa.tensor_scalar(dst=neg_max, data=max_, op0=nl.multiply, operand0=-1.0)

            # Two-pass exp+sum with STATIC offsets matching beta1's REDUCTION_TILE structure.
            # Static slices avoid the dynamic-offset activation read bug.
            # Only columns [0, 2 * REDUCTION_TILE) are covered by these passes.
            # Probabilities are stored in kernel_dtype; their row sums in acc_type.
            p_local_0 = nl.ndarray((B_P_SIZE, REDUCTION_TILE), dtype=kernel_dtype, buffer=nl.sbuf)
            nisa.activation(dst=p_local_0, op=nl.exp,
                            data=qk_res_buf[0:B_P_SIZE, 0:REDUCTION_TILE], bias=neg_max, scale=1.0)
            ps_0 = nl.ndarray((B_P_SIZE, 1), dtype=acc_type, buffer=nl.sbuf)
            nisa.tensor_reduce(dst=ps_0, op=nl.add, data=p_local_0, axis=1, keepdims=True)

            p_local_1 = nl.ndarray((B_P_SIZE, REDUCTION_TILE), dtype=kernel_dtype, buffer=nl.sbuf)
            nisa.activation(dst=p_local_1, op=nl.exp,
                            data=qk_res_buf[0:B_P_SIZE, REDUCTION_TILE:2*REDUCTION_TILE],
                            bias=neg_max, scale=1.0)
            ps_1 = nl.ndarray((B_P_SIZE, 1), dtype=acc_type, buffer=nl.sbuf)
            nisa.tensor_reduce(dst=ps_1, op=nl.add, data=p_local_1, axis=1, keepdims=True)

            # ps = sum_j exp(s - max) over the covered columns.
            ps = nl.ndarray((B_P_SIZE, 1), dtype=acc_type, buffer=nl.sbuf)
            nisa.tensor_tensor(dst=ps, data1=ps_0, data2=ps_1, op=nl.add)

            # PV matmul using two separate loops to avoid dynamic-offset nc_transpose.
            # Intermediate tensor_copy to static buffer before nc_transpose (matches CSA pattern).
            # NOTE(review): pv_accum is read and rewritten in every iteration of the
            # nl.affine_range loops below, and the p_slice_* / v_tile staging tiles are
            # reused across iterations. NKI documents affine_range for loops without
            # loop-carried dependencies (its cited exception is matmul accumulation), so
            # confirm this SBUF read-modify-write is handled correctly, or consider
            # nl.sequential_range.
            pv_accum = nl.ndarray((B_P_SIZE, d_head), dtype=nl.float32, buffer=nl.sbuf)
            nisa.memset(dst=pv_accum, value=0.0)

            p_slice_local = nl.ndarray((B_P_SIZE, B_P_SIZE), dtype=kernel_dtype, buffer=nl.sbuf)
            p_slice_T = nl.ndarray((B_P_SIZE, B_P_SIZE), dtype=kernel_dtype, buffer=nl.sbuf)
            v_tile = nl.ndarray((B_P_SIZE, d_head), dtype=kernel_dtype, buffer=nl.sbuf)

            # First REDUCTION_TILE cols (p_local_0): v rows 0..REDUCTION_TILE-1
            for k_i in nl.affine_range(REDUCTION_TILE // B_P_SIZE):
                k_start = k_i * B_P_SIZE
                nisa.tensor_copy(dst=p_slice_local,
                                 src=p_local_0[0:B_P_SIZE, k_start:k_start + B_P_SIZE])
                # Transpose the (q rows, keys) probability chunk to (keys, q rows).
                p_slice_T_psum = nl.ndarray((B_P_SIZE, B_P_SIZE), dtype=nl.float32, buffer=nl.psum)
                nisa.nc_transpose(dst=p_slice_T_psum, data=p_slice_local)
                nisa.tensor_copy(dst=p_slice_T, src=p_slice_T_psum)
                nisa.dma_copy(dst=v_tile, src=v[head_idx, k_start:k_start + B_P_SIZE, 0:d_head])
                # (keys, q rows) x (keys, d_head) -> (q rows, d_head) partial P @ V.
                pv_psum = nl.ndarray((B_P_SIZE, d_head), dtype=nl.float32, buffer=nl.psum)
                nisa.nc_matmul(dst=pv_psum, stationary=p_slice_T, moving=v_tile)
                tmp = nl.ndarray((B_P_SIZE, d_head), dtype=nl.float32, buffer=nl.sbuf)
                nisa.tensor_copy(dst=tmp, src=pv_psum)
                nisa.tensor_tensor(dst=pv_accum, data1=pv_accum, data2=tmp, op=nl.add)

            # Second REDUCTION_TILE cols (p_local_1): v rows REDUCTION_TILE..2*REDUCTION_TILE-1
            for k_i in nl.affine_range(REDUCTION_TILE // B_P_SIZE):
                k_start = k_i * B_P_SIZE
                v_start = REDUCTION_TILE + k_start
                nisa.tensor_copy(dst=p_slice_local,
                                 src=p_local_1[0:B_P_SIZE, k_start:k_start + B_P_SIZE])
                p_slice_T_psum = nl.ndarray((B_P_SIZE, B_P_SIZE), dtype=nl.float32, buffer=nl.psum)
                nisa.nc_transpose(dst=p_slice_T_psum, data=p_slice_local)
                nisa.tensor_copy(dst=p_slice_T, src=p_slice_T_psum)
                nisa.dma_copy(dst=v_tile, src=v[head_idx, v_start:v_start + B_P_SIZE, 0:d_head])
                pv_psum = nl.ndarray((B_P_SIZE, d_head), dtype=nl.float32, buffer=nl.psum)
                nisa.nc_matmul(dst=pv_psum, stationary=p_slice_T, moving=v_tile)
                tmp = nl.ndarray((B_P_SIZE, d_head), dtype=nl.float32, buffer=nl.sbuf)
                nisa.tensor_copy(dst=tmp, src=pv_psum)
                nisa.tensor_tensor(dst=pv_accum, data1=pv_accum, data2=tmp, op=nl.add)

            # o output: unnormalized P @ V (NOT divided by ps), cast to kernel_dtype.
            o_cast = nl.ndarray((B_P_SIZE, d_head), dtype=kernel_dtype, buffer=nl.sbuf)
            nisa.tensor_copy(dst=o_cast, src=pv_accum)
            nisa.dma_copy(dst=o[head_idx, q_start:q_end, 0:d_head], src=o_cast)

            # l output: log(ps) + max, i.e. the row log-sum-exp of the masked scores.
            log_ps = nl.ndarray((B_P_SIZE, 1), dtype=acc_type, buffer=nl.sbuf)
            nisa.activation(dst=log_ps, op=nl.log, data=ps)
            l_val = nl.ndarray((B_P_SIZE, 1), dtype=acc_type, buffer=nl.sbuf)
            nisa.tensor_tensor(dst=l_val, data1=log_ps, data2=max_, op=nl.add)
            nisa.dma_copy(dst=l[head_idx, q_start:q_end, 0:1], src=l_val)

    return o, l, m

def test_nki(ref_func, test_func, *, num_heads=8, d_head=128, seq_len=2048, seed=None):
  """Compare a candidate flash-attention kernel against a reference kernel.

  Both kernels receive the same random inputs and are called as
  ``func(q, k, v, causal_mask, kernel_dtype=nl.bfloat16, acc_type=nl.float32,
  num_heads=..., seq_len=..., d_head=...)``. Each must return ``(o, l, m)``.
  Each output pair is checked for identical shape and then compared with
  ``np.allclose(atol=0.01, rtol=0.001)``.

  Args:
    ref_func: reference kernel returning ``(o, l, m)``.
    test_func: kernel under test, same calling convention.
    num_heads, d_head, seq_len: problem sizes. The defaults are the values
      this harness originally hard-coded.
    seed: seed for the input RNG. If None, a fresh seed is drawn and printed,
      so a failing run can be replayed with ``seed=<printed value>``.

  Returns:
    True if all three outputs match in shape and value, otherwise False.
  """
  if seed is None:
    seed = np.random.SeedSequence().entropy
  rng = np.random.default_rng(seed)
  device = xm.xla_device()

  print(f"Testing flash attention with num_heads={num_heads}, d_head={d_head} and seq_len={seq_len}...")
  print(f"  input seed: {seed}")

  # Layouts follow the kernel's indexing: q and v are (heads, seq, d_head),
  # k is pre-transposed to (heads, d_head, seq).
  q_np = rng.random((num_heads, seq_len, d_head), dtype=np.float32)
  k_np = rng.random((num_heads, d_head, seq_len), dtype=np.float32)
  v_np = rng.random((num_heads, seq_len, d_head), dtype=np.float32)
  # Lower-triangular 0/1 mask: row i keeps key columns 0..i.
  causal_mask_np = np.tril(np.ones((seq_len, seq_len), dtype=np.float32))

  q = torch.from_numpy(q_np).to(device=device)
  k = torch.from_numpy(k_np).to(device=device)
  v = torch.from_numpy(v_np).to(device=device)
  causal_mask = torch.from_numpy(causal_mask_np).to(device=device)

  kernel_kwargs = dict(
    kernel_dtype=nl.bfloat16,
    acc_type=nl.float32,
    num_heads=num_heads,
    seq_len=seq_len,
    d_head=d_head,
  )
  o_ref, l_ref, m_ref = ref_func(q, k, v, causal_mask, **kernel_kwargs)
  o_test, l_test, m_test = test_func(q, k, v, causal_mask, **kernel_kwargs)

  fail = False
  for name, ref, test in (("o", o_ref, o_test), ("l", l_ref, l_test), ("m", m_ref, m_test)):
    ref_np = ref.detach().cpu().to(torch.float32).numpy()
    test_np = test.detach().cpu().to(torch.float32).numpy()
    # Compare shapes first. np.allclose broadcasts, so an (H, S, 1)
    # reference would otherwise "match" an (H, S, D) result.
    same_shape = ref_np.shape == test_np.shape
    if same_shape and np.allclose(ref_np, test_np, atol=0.01, rtol=0.001):
      continue
    print(f"FAIL: {name}_ref != {name}_test")
    print(f"{name}_ref shape:", ref_np.shape)
    print(f"{name}_test shape:", test_np.shape)
    if same_shape:
      # Preview the first 5 rows of head 0 (kernel outputs are 3-D).
      print(f"{name}_ref", ref_np[0, :5, :5])
      print(f"{name}_test", test_np[0, :5, :5])
    fail = True

  if fail:
    return False

  print("  All tests passed")
  return True

def benchmark_nki(nki_func):
  """Benchmark ``nki_func`` on random flash-attention inputs and print its P99 latency.

  Inputs are float32 arrays for num_heads=8, d_head=128, seq_len=2048:
  q and v are (heads, seq, d_head) and k is (heads, d_head, seq). They are
  drawn from the global NumPy RNG, together with a lower-triangular 0/1
  causal mask, then moved to the XLA device. ``nki_func`` is wrapped with
  ``nki.benchmark(warmup=2, iters=10)``.
  """
  heads, head_dim, length = 8, 128, 2048
  xla_dev = xm.xla_device()

  # Shapes for q, k, v, drawn in that order from the global NumPy RNG.
  shapes = (
    (heads, length, head_dim),
    (heads, head_dim, length),
    (heads, length, head_dim),
  )
  host_arrays = [np.random.rand(*shape).astype(np.float32) for shape in shapes]
  # Row i of the mask is 1 for columns 0..i and 0 elsewhere.
  host_arrays.append(np.tril(np.ones((length, length), dtype=np.float32)))
  device_args = [torch.from_numpy(arr).to(device=xla_dev) for arr in host_arrays]

  timed_kernel = nki.benchmark(warmup=2, iters=10)(nki_func)
  timed_kernel(
    *device_args,
    kernel_dtype=nl.bfloat16,
    acc_type=nl.float32,
    num_heads=heads,
    seq_len=length,
    d_head=head_dim,
  )
  p99 = timed_kernel.benchmark_result.nc_latency.get_latency_percentile(99)
  print("Latency: {:.3f} ms (P99)".format(p99 / 1000.0))

if __name__ == "__main__":
  import sys

  # Select the trn1 target before either kernel is invoked below.
  os.environ["NEURON_PLATFORM_TARGET_OVERRIDE"] = "trn1"
  # NOTE(review): `solution` is not defined in this file as shown; presumably
  # it is injected at the "# SUBSTITUTE HERE" marker -- confirm.
  test_result = test_nki(flash_attention_core, solution)
  if not test_result:
    print("Test failed")
    # Use sys.exit: the builtin `exit` is added by the `site` module for the
    # interactive interpreter and is missing when Python runs with -S.
    sys.exit(1)
  else:
    print("Running benchmark...")
    benchmark_nki(solution)
