# developed based on deja vu and TEAL

import torch
import triton
import triton.language as tl
from typing import Optional
import os
# Import-time side effect: sets TRITON_PRINT_AUTOTUNING for the whole process,
# so every module importing this file inherits the setting.
# NOTE(review): presumably makes Triton print the configuration chosen by each
# autotune run — confirm against the Triton version in use.
os.environ['TRITON_PRINT_AUTOTUNING'] = '1'

def init_to_zero(*names):
    """Build a hook that zeroes selected kernel arguments in place.

    Args:
        *names: Keys of the arguments to reset.

    Returns:
        A callable taking a mapping from argument name to an object exposing
        ``zero_()``. It calls ``zero_()`` on each named entry, in the order
        the names were given, and returns ``None``.
    """
    def _zero_selected(kernel_args):
        # The generator is consumed lazily, so each lookup is immediately
        # followed by its zero_() call, exactly like a plain indexed loop.
        for target in (kernel_args[key] for key in names):
            target.zero_()

    return _zero_selected

# Autotuning candidates for norm_squared_kernel: every block size is tried
# with 1, 2 and 4 warps (warp count varies slowest, block size fastest).
# Each candidate's pre_hook zeroes the "Y" accumulator before the launch,
# because the kernel accumulates into it with atomic adds.
norm_configs = [
    triton.Config(
        {"Norm_BLOCK_N": block_n},
        num_warps=warps,
        pre_hook=init_to_zero("Y"),
    )
    for warps in (1, 2, 4)
    for block_n in (32, 64, 128, 256, 512)
]
@triton.autotune(
    configs=norm_configs,
    key=["N"]  # N=input_dim
)
@triton.jit
def norm_squared_kernel(
    Y,  # Output pointer (1 x 1)
    X,  # Input pointer (1 x 1 x N)
    N,  # N=input_dim
    Norm_BLOCK_N: tl.constexpr,  # Block size for input_dim (N)
):
    """Accumulate the sum of squares of the first N elements of X into Y.

    Each program along grid axis 0 handles one tile of ``Norm_BLOCK_N``
    consecutive elements, squares them in float32, reduces the tile to a
    scalar and atomically adds that partial sum to the single value at ``Y``.

    ``Y`` must be zero before the launch; every config in ``norm_configs``
    arranges that through its ``pre_hook``. Because partial sums from
    different programs are combined with atomic adds, the accumulation order
    is not fixed by this code.
    """
    # Tile of input positions owned by this program.
    start_n = tl.program_id(0)
    input_indices = start_n * Norm_BLOCK_N + tl.arange(0, Norm_BLOCK_N)

    # Guard the tail tile when N is not a multiple of Norm_BLOCK_N.
    input_mask = input_indices < N

    # other=0.0 matters: without it masked-off lanes hold undefined values,
    # which would be squared and summed into the norm.
    x = tl.load(X + input_indices, mask=input_mask, other=0.0,
                eviction_policy='evict_last')
    x_f32 = x.to(tl.float32)
    acc = tl.sum(x_f32 * x_f32)
    tl.atomic_add(Y, acc)


# Autotuning candidates for sparse_gemv_kernel, grouped by BLOCK_N.
# Each entry is (BLOCK_M, BLOCK_N, num_warps).
#
# The BLOCK_N=64 group previously began with a second copy of the
# (16, 32, 1) entry, so that config was benchmarked twice and (16, 64) was
# never tried. It is replaced by (16, 64, 1).
# NOTE(review): (16, 64, 1) is inferred from the grouping — confirm intent.
#
# Each candidate's pre_hook zeroes the "Y" output before the launch, because
# the kernel accumulates into it with atomic adds.
configs = [
    triton.Config(
        {"BLOCK_M": block_m, "BLOCK_N": block_n},
        num_warps=warps,
        pre_hook=init_to_zero("Y"),
    )
    for block_m, block_n, warps in (
        # BLOCK_N = 16
        (16, 16, 1),
        (32, 16, 1),
        (64, 16, 1),
        # BLOCK_N = 32
        (16, 32, 1),
        (32, 32, 1),
        (64, 32, 1),
        (64, 32, 2),
        (128, 32, 1),
        (128, 32, 2),
        (128, 32, 4),
        # BLOCK_N = 64
        (16, 64, 1),
        (32, 64, 1),
        (32, 64, 2),
        (32, 64, 4),
        (64, 64, 1),
        (64, 64, 2),
        (64, 64, 4),
        (128, 64, 1),
        (128, 64, 2),
        (128, 64, 4),
        # BLOCK_N = 128
        (8, 128, 1),
        (8, 128, 2),
        (8, 128, 4),
        (16, 128, 1),
        (16, 128, 2),
        (16, 128, 4),
        (32, 128, 1),
        (32, 128, 2),
        (32, 128, 4),
        (64, 128, 1),
        (64, 128, 2),
        (64, 128, 4),
        (128, 128, 2),
        (128, 128, 4),
        (256, 128, 8),
        # BLOCK_N = 256
        (16, 256, 4),
        (32, 256, 2),
        (64, 256, 4),
        (128, 256, 4),
        (256, 256, 8),
        # BLOCK_N = 512
        (16, 512, 4),
        (32, 512, 4),
        (64, 512, 4),
        (128, 512, 4),
    )
]

@triton.autotune(
    configs=configs,
    # N=input_dim, M=out_dim. threshold_ratio is part of the key, so each
    # distinct threshold value is tuned separately.
    key=["N", "M", "threshold_ratio"]
)
@triton.jit
def sparse_gemv_kernel(
    Y,  # Output pointer (1 x 1 x M)
    A,  # Weight pointer (M x N, column-major)
    X,  # Input pointer (1 x 1 x N)
    threshold_ratio,  # Threshold for sparsity
    norm_squared_ptr,     # Squared norm of X
    N, M,             # N=input_dim, M=out_dim
    BLOCK_N: tl.constexpr,  # Block size for input_dim (N)
    BLOCK_M: tl.constexpr,  # Block size for out_dim (M)
):
    """Accumulate one tile of ``Y += A @ sparse(X)``.

    Grid axis 0 tiles the input dimension (N) and grid axis 1 tiles the
    output dimension (M). An input element ``x[n]`` takes part only when
    ``x[n]^2 > threshold_ratio * norm_squared``; weight columns of the
    remaining elements are masked out of the load and contribute zero.

    Each program reduces its BLOCK_M x BLOCK_N tile over the input axis in
    float32 and atomically adds the BLOCK_M partial sums into ``Y``. ``Y``
    must be zero before the launch; every config in ``configs`` arranges that
    through its ``pre_hook``.

    ``A`` is addressed as column-major with a column stride of exactly ``M``:
    element (m, n) is read at ``A + n * M + m``.
    """
    # Block indices
    start_n = tl.program_id(0)
    start_m = tl.program_id(1)
    input_indices = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
    output_indices = start_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # Guard the tail tiles in both dimensions.
    input_mask = input_indices < N
    output_mask = output_indices < M

    # A is column major: A[output][input] = A + input * M + output.
    # The column offset is widened to 64 bits because input * M reaches
    # M * N, which overflows 32-bit arithmetic for very large weights.
    column_offsets = input_indices.to(tl.int64) * M
    A_ptr = A + (column_offsets[None, :] + output_indices[:, None])
    X_ptr = X + input_indices
    Y_ptr = Y + output_indices

    norm_squared = tl.load(norm_squared_ptr)

    # other=0.0 matters: masked-off lanes would otherwise be undefined, and
    # an undefined NaN/inf multiplied by a zero weight still yields NaN.
    x = tl.load(X_ptr, mask=input_mask, other=0.0, eviction_policy='evict_last')

    # Square in float32 so the comparison uses the same precision as the
    # float32-accumulated norm instead of the input's storage dtype.
    x_f32 = x.to(tl.float32)
    sparse_mask = (x_f32 * x_f32 > threshold_ratio * norm_squared) & input_mask

    # Load only the weight columns whose input element survived the threshold.
    a = tl.load(A_ptr, mask=output_mask[:, None] & sparse_mask[None, :],
                other=0.0, eviction_policy='evict_first')

    # Partial matrix-vector product for this tile, reduced over the input axis.
    acc = tl.sum(a.to(tl.float32) * x_f32[None, :], axis=1)

    # Several programs (one per input tile) contribute to the same outputs.
    tl.atomic_add(Y_ptr, acc, mask=output_mask)


def sparse_gemv(
    x: torch.Tensor,
    weight: torch.Tensor,
    threshold_ratio: float,
    norm_squared: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Compute y = weight @ sparse(x), where weight is stored in column-major format.

    sparse(x) keeps an element x[i] only when
    x[i]^2 > threshold_ratio * ||x||^2; all other elements contribute zero.

    Args:
        x: Input tensor [1, 1, input_dim]
        weight: Weight matrix [out_dim, input_dim] (column-major, i.e.
            strides (1, out_dim))
        threshold_ratio: Sparsity threshold ratio
        norm_squared: Optional, squared norm of input vector, as a tensor the
            kernel can read one scalar from. If None, it will be computed.
            NOTE(review): the internally computed value is a 1-element
            float32 tensor on x.device; a caller-supplied tensor presumably
            needs to match — confirm.

    Returns:
        output: Result tensor [1, 1, out_dim], accumulated in float32 and
        cast to x.dtype

    Raises:
        ValueError: If x is not a single vector of length input_dim, or if
            weight is not laid out column-major with column stride out_dim.
    """
    M, N = weight.shape  # M=out_dim, N=input_dim
    batch_size, seq_len, input_dim = x.shape

    # The kernels read exactly N elements from the start of x, so anything
    # other than a single vector of matching length would silently compute
    # on the wrong data.
    if batch_size * seq_len != 1 or input_dim != N:
        raise ValueError(
            f"Expected x of shape [1, 1, {N}] to match weight of shape "
            f"[{M}, {N}], got {tuple(x.shape)}"
        )

    if weight.stride(0) != 1:
        raise ValueError("Weight matrix must be column-major (first dimension contiguous)")

    # The kernel addresses element (m, n) at n * M + m, so the column stride
    # must be exactly M. It is irrelevant when there is only one column.
    if N > 1 and weight.stride(1) != M:
        raise ValueError(
            f"Weight matrix must be densely packed column-major "
            f"(expected stride(1) == {M}, got {weight.stride(1)})"
        )

    input_vec = x.contiguous()

    # Compute squared norm (if not provided)
    if norm_squared is None:
        # Zero-initialised explicitly so correctness does not hinge on the
        # autotuner's pre_hook resetting the accumulator.
        norm_squared = torch.zeros(1, device=x.device, dtype=torch.float32)
        grid = lambda META: (
            triton.cdiv(N, META["Norm_BLOCK_N"]),
        )
        norm_squared_kernel[grid](
            norm_squared,
            input_vec,
            N
        )

    # Create output tensor. float32 because the kernel accumulates with
    # atomic adds; zero-initialised for the same reason as above.
    output_vec = torch.zeros(
        1, 1, M,  # Output dimension is the row count of weight matrix
        device=x.device,
        dtype=torch.float32,
    )

    # Define grid size
    grid = lambda META: (
        triton.cdiv(N, META["BLOCK_N"]),  # Number of blocks in input dimension
        triton.cdiv(M, META["BLOCK_M"]),  # Number of blocks in output dimension
    )

    # Call the optimized kernel
    sparse_gemv_kernel[grid](
        output_vec,        # Output pointer
        weight,            # Weight matrix (column-major)
        input_vec,         # Input vector
        threshold_ratio,   # Sparsity threshold ratio
        norm_squared,      # Squared norm of input vector
        N,                 # Input dimension
        M,                 # Output dimension
    )

    return output_vec.to(dtype=x.dtype)