import neuronxcc.nki as nki
import neuronxcc.nki.language as nl
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.typing as nt
import numpy as np

# SUBSTITUTE HERE

@nki.jit
def ref(Q, K, V, past_k, past_v, attention_mask):
    """
    Token-generation attention kernel in NKI.
    Optimized for Trainium/Inferentia.
    Fixes the partition broadcast error by explicitly broadcasting the mask.
    Hoists Q transpose out of the group loop for performance (Opt #2).
    """
    out_hbm = nl.ndarray(Q.shape, dtype=Q.dtype, buffer=nl.shared_hbm)

    # Common index tiles
    i_p1   = nl.arange(1)[:, None]      # (1,1)
    i_p4   = nl.arange(4)[:, None]      # (4,1)
    i_p16  = nl.arange(16)[:, None]     # (16,1)
    i_p64  = nl.arange(64)[:, None]     # (64,1)
    i_p128 = nl.arange(128)[:, None]    # (128,1)

    i_f4   = nl.arange(4)[None, :]      # (1,4)
    i_f64  = nl.arange(64)[None, :]     # (1,64)
    i_f128 = nl.arange(128)[None, :]    # (1,128)
    i_f512 = nl.arange(512)[None, :]    # (1,512)

    # Load mask once: shape (1,512)
    mask_1x512 = nl.load(attention_mask[0, 0, i_p1, i_f512])
    # Explicitly broadcast mask to (4, 512) to match head groups. 
    # This prevents the "Unexpected partition broadcast" assertion error in nl.where.
    mask_4x512 = nl.broadcast_to(mask_1x512, shape=(4, 512))

    # -- Hoist Q-transpose out of the group loop --
    # Load all 16 heads of Q: (16, 64)
    Q_all = nl.load(Q[0, i_p16, 0, i_f64])
    # Transpose Q to (64, 16) stationary for nc_matmul. Result in PSUM.
    Q_all_T = nisa.nc_transpose(Q_all)
    # Copy back to SBUF for slicing. qT_all is (64, 16)
    qT_all = nisa.tensor_copy(Q_all_T, dtype=Q.dtype)

    inv_sqrt_d = 0.125
    neg_inf = -3.4028235e38

    # Process 4 KV groups
    for g in range(4):
        h0 = g * 4

        # Extract the 4-wide slice of qT_all for this group
        # qT_all is (64, 16), first dim is partition.
        # We want the columns corresponding to heads [h0 : h0+4].
        # Result qT is (64, 4), which acts as stationary for nc_matmul (K=64, M=4).
        qT = qT_all[i_p64, (g * 4) + i_f4]

        # ---- Load K active: (64,1) using transpose load pattern (indices on dim 3)
        k_act = nl.load(K[0, g, 0, i_p64])
        
        # Active score: qT.T (4,64) @ k_act (64,1) -> (4,1)
        score_act_psum = nisa.nc_matmul(qT, k_act)
        score_act = nisa.tensor_scalar(score_act_psum, op0=np.multiply, operand0=inv_sqrt_d, dtype=nl.float32)

        # ---- Construct K_prior^T (64, 512) in SBUF
        k_prior_T = nl.ndarray((64, 512), dtype=K.dtype, buffer=nl.sbuf)

        for j in range(4):
            seq0 = j * 128
            # Load chunk (128,64)
            k_chunk = nl.load(past_k[0, g, seq0 + i_p128, i_f64])
            # Transpose to (64,128)
            k_chunk_T = nisa.tensor_copy(nisa.nc_transpose(k_chunk), dtype=K.dtype)
            # Copy to larger buffer
            k_prior_T[i_p64, seq0 + i_f128] = k_chunk_T[i_p64, i_f128]

        # Prior scores: qT.T (4,64) @ k_prior_T (64,512) -> (4,512)
        score_prior_psum = nisa.nc_matmul(qT, k_prior_T)
        score_prior = nisa.tensor_scalar(score_prior_psum, op0=np.multiply, operand0=inv_sqrt_d, dtype=nl.float32)

        # Apply mask using explicit broadcast tile (4, 512)
        score_prior_masked = nl.where(mask_4x512, score_prior, neg_inf)

        # ---- Softmax
        # Max over prior (4,512) -> (4,1)
        max_prior = nl.max(score_prior_masked, axis=1, keepdims=True)
        # Combine with active score max
        max_val = nl.maximum(max_prior, score_act)

        score_prior_shifted = nl.subtract(score_prior_masked, max_val)
        score_act_shifted = nl.subtract(score_act, max_val)

        # Exp + Sum Prior
        sum_exp = nl.zeros((4, 1), dtype=nl.float32, buffer=nl.sbuf)
        exp_prior = nisa.activation(op=nl.exp, data=score_prior_shifted, 
                                    reduce_op=nl.add, reduce_res=sum_exp, reduce_cmd=nisa.reduce_cmd.reset_reduce,
                                    dtype=nl.float32)
        
        exp_act = nl.exp(score_act_shifted)
        denom = nl.add(sum_exp, exp_act)

        # Calculate probabilities
        probs_prior = nisa.tensor_copy(nl.divide(exp_prior, denom), dtype=Q.dtype)
        probs_act = nisa.tensor_copy(nl.divide(exp_act, denom), dtype=Q.dtype)

        # ---- Attn Prior: probs_prior (4,512) @ V_prior (512,64) -> (4,64)
        attn_prior_psum = nl.zeros((4, 64), dtype=nl.float32, buffer=nl.psum)

        for j in range(4):
            seq0 = j * 128
            # Take prob chunk (4,128)
            p_chunk = probs_prior[i_p4, seq0 + i_f128]
            # Transpose to (128,4) stationary
            p_chunk_T = nisa.tensor_copy(nisa.nc_transpose(p_chunk), dtype=Q.dtype)
            
            # Load V chunk (128,64)
            v_chunk = nl.load(past_v[0, g, seq0 + i_p128, i_f64])
            
            # Matmul
            attn_prior_psum += nisa.nc_matmul(p_chunk_T, v_chunk)

        attn_prior = nisa.tensor_copy(attn_prior_psum, dtype=Q.dtype)

        # ---- Attn Active
        v_act = nl.load(V[0, g, i_p1, i_f64]) # (1,64)
        # Explicit broadcast to (4, 64) to safely match partition dimension
        v_act_4x64 = nl.broadcast_to(v_act, shape=(4, 64))
        # probs_act is (4,1), v_act_4x64 is (4,64) -> result (4,64)
        attn_act = nl.multiply(probs_act, v_act_4x64)

        # Final accumulation and store
        out_tile = nl.add(attn_prior, attn_act)
        nl.store(out_hbm[0, h0 + i_p4, 0, i_f64], out_tile)

    return out_hbm

def forward_reference(Q, K, V, past_k, past_v, attention_mask, kernel):
    """Direct NKI kernel call with attention inputs."""
    return kernel(Q, K, V, past_k, past_v, attention_mask)

def get_test_data(dtype):
    """Create test data for attention kernel."""
    # Q: (1, 16, 1, 64) - batch=1, num_heads=16, seq=1, head_dim=64
    Q = np.random.randn(1, 16, 1, 64).astype(dtype)
    # K: (1, 4, 1, 64) - batch=1, num_groups=4, seq=1, head_dim=64
    K = np.random.randn(1, 4, 1, 64).astype(dtype)
    # V: (1, 4, 1, 64) - batch=1, num_groups=4, seq=1, head_dim=64
    V = np.random.randn(1, 4, 1, 64).astype(dtype)
    # past_k: (1, 4, 512, 64) - batch=1, num_groups=4, past_seq=512, head_dim=64
    past_k = np.random.randn(1, 4, 512, 64).astype(dtype)
    # past_v: (1, 4, 512, 64) - batch=1, num_groups=4, past_seq=512, head_dim=64
    past_v = np.random.randn(1, 4, 512, 64).astype(dtype)
    # attention_mask: (1, 1, 1, 512) - batch=1, num_heads=1, seq=1, past_seq=512
    # Use ones for valid positions, zeros for masked positions
    attention_mask = np.random.randint(0, 2, (1, 1, 1, 512)).astype(np.bool_)
    return (Q, K, V, past_k, past_v, attention_mask)

def compare_outputs(reference_out, test_out, atol=1e-3, rtol=1e-3):
    """Compare test output against reference output."""
    ref_f32 = reference_out.astype(nl.float32)
    test_f32 = test_out.astype(nl.float32)
    if not np.allclose(ref_f32, test_f32, atol=atol, rtol=rtol):
        print("reference_out[:8]: %s", ref_f32.flatten()[:8])
        print("test_out[:8]: %s", test_f32.flatten()[:8])
        diff = np.abs(ref_f32 - test_f32)
        print("max_diff: %s", np.max(diff))
        print("mean_diff: %s", np.mean(diff))
        print("FAIL: test output does not match reference")
        return False
    return True

def test_nki(ref_func, test_func):
    np.random.seed(0)
    dtype = nl.bfloat16
    
    for _ in range(2):
        test_data = get_test_data(dtype)
        ref_out = forward_reference(*test_data, kernel=ref_func)
        test_out = forward_reference(*test_data, kernel=test_func)
        if not compare_outputs(ref_out, test_out):
            return False
    return True

def benchmark_nki(nki_func):
    dtype = nl.bfloat16
    test_data = get_test_data(dtype)
    
    bench_func = nki.benchmark(warmup=2, iters=10)(nki_func)
    bench_func(*test_data)
    latency_res = bench_func.benchmark_result.nc_latency
    p99 = latency_res.get_latency_percentile(99)
    print("Latency: {:.3f} ms (P99)".format(p99 / 1000.0))

if __name__ == "__main__":
    test_result = test_nki(ref, solution)
    if not test_result:
        print("Test failed")
        exit(1)
    else:
        print("Test passed")
        benchmark_nki(solution)
