import triton
import triton.language as tl
import torch
from torch import nn

import abc

from triton.tools.tensor_descriptor import TensorDescriptor
from typing import NamedTuple, Optional, Dict

from custom_models.kernels.bwell_gated import gated_downprojection_bwell



@triton.jit()
def matmul_dense_to_bwell_only_indices_tma_kernel(
    a_td,
    b_td,
    c_td,
    idxs_ptr,
    nnzs_ptr,
    M,
    N,
    K,
    T_m: tl.constexpr,
    T_n: tl.constexpr,
    T_k: tl.constexpr,
    num_blocks_n: tl.constexpr,
    G_m: tl.constexpr,
    warp_specialize: tl.constexpr,
):
    """Compute one (T_m, T_n) tile of ``A @ B`` and record where it is positive.

    Each program instance:

    1. accumulates its output tile in float32 over ``K`` in steps of ``T_k``
       using descriptor loads and ``tl.dot``;
    2. splits the tile into its left and right column halves and stores each
       half through ``c_td`` as bfloat16, with non-positive entries set to 0;
    3. for every row, writes the tile-local column numbers (``0 .. T_n - 1``)
       of the positive entries, packed to the front of that row's
       ``T_n``-wide slot in ``idxs`` (left-half hits first, then right-half);
    4. writes the per-row number of positive entries to ``nnzs[row, n]``.

    Args:
        a_td: descriptor for the left operand, loaded in (T_m, T_k) blocks.
        b_td: descriptor for the right operand, loaded in (T_k, T_n) blocks.
        c_td: descriptor for the output, stored in (T_m, T_n // 2) blocks.
        idxs_ptr: pointer to the index buffer, addressed with row stride ``N``.
        nnzs_ptr: pointer to the count buffer, addressed with row stride
            ``num_blocks_n``.
        M, N, K: problem sizes; ``N`` must be a multiple of ``T_n``.
        T_m, T_n, T_k: tile sizes; ``T_n`` must be even.
        num_blocks_n: number of column tiles, ``N // T_n``.
        G_m: number of row tiles grouped together when mapping program ids.
        warp_specialize: forwarded to ``tl.range`` for the K loop.

    NOTE(review): indices and counts are stored as uint8. Indices fit as long
    as ``T_n <= 256``, but a row whose ``T_n == 256`` entries are all positive
    has a count of 256, which does not fit in uint8. Confirm whether that case
    can occur for the intended activations.
    """
    tl.assume(N % T_n == 0)
    tl.assume(N / T_n == num_blocks_n)
    tl.assume(T_n % 2 == 0)

    pid = tl.program_id(0)
    num_blocks_m = tl.cdiv(M, T_m)

    # Grouped program-id mapping: within a group of up to G_m row tiles,
    # consecutive program ids walk the row tiles first and then advance to the
    # next column tile (presumably to improve cache reuse of B).
    elements_per_group = num_blocks_n * G_m
    group_id = pid // elements_per_group
    group_iter = pid % elements_per_group

    starting_grouped_m = group_id * G_m
    # The last group may contain fewer than G_m row tiles.
    current_group_size = min(num_blocks_m - starting_grouped_m, G_m)

    m = starting_grouped_m + (group_iter % current_group_size)
    n = group_iter // current_group_size

    start_index_m = m * T_m
    start_index_n = n * T_n

    c_block_accumulator = tl.zeros((T_m, T_n), dtype=tl.float32)
    for start_index_k in tl.range(0, K, T_k, warp_specialize=warp_specialize):
        a_block = a_td.load([start_index_m, start_index_k])
        b_block = b_td.load([start_index_k, start_index_n])
        c_block_accumulator = tl.dot(a_block, b_block, acc=c_block_accumulator)

    T_n_half: tl.constexpr = T_n // 2

    # (T_m, T_n) -> (T_m, 2, T_n_half) -> (T_m, T_n_half, 2), then split the
    # trailing axis: half 0 holds columns [0, T_n_half), half 1 the rest.
    c_block_accumulator = tl.reshape(c_block_accumulator, (T_m, 2, T_n_half))
    c_block_accumulator = tl.permute(c_block_accumulator, (0, 2, 1))
    c_block_accumulator_0, c_block_accumulator_1 = tl.split(
        c_block_accumulator)

    # Left half: keep positive entries, zero the rest.
    c_block_0 = c_block_accumulator_0.to(tl.bfloat16)
    is_positive_0 = c_block_accumulator_0 > 0
    c_td.store(
        [start_index_m, start_index_n],
        value=tl.where(is_positive_0, c_block_0, 0),
    )

    # Right half: same treatment, stored T_n_half columns further right.
    c_block_1 = c_block_accumulator_1.to(tl.bfloat16)
    is_positive_1 = c_block_accumulator_1 > 0
    c_td.store(
        [start_index_m, start_index_n + T_n_half],
        value=tl.where(is_positive_1, c_block_1, 0),
    )

    T_m_range = tl.arange(0, T_m)[:, None]
    T_n_range_half = tl.arange(0, T_n_half)[None, :]

    # Rows of the last row tile can lie beyond M when M is not a multiple of
    # T_m; every raw-pointer store below must skip them, otherwise it writes
    # past the end of the idxs / nnzs buffers.
    rows = start_index_m + T_m_range
    row_in_bounds = rows < M

    # Exclusive prefix count of positives gives each hit its packed position.
    is_positive_int_0 = is_positive_0.to(tl.int32)
    n_offsets_0 = tl.cumsum(is_positive_int_0, axis=1) - is_positive_int_0

    idxs_block_ptr_0 = idxs_ptr + rows * N + (
        start_index_n + n_offsets_0
    )
    tl.store(idxs_block_ptr_0, T_n_range_half.to(tl.uint8),
             mask=is_positive_0 & row_in_bounds)

    nnz_0 = tl.sum(is_positive_int_0, keep_dims=True, axis=1)

    # Right-half hits are appended after the nnz_0 left-half hits of the row.
    is_positive_int_1 = is_positive_1.to(tl.int32)
    n_offsets_1 = tl.cumsum(is_positive_int_1, axis=1) - is_positive_int_1

    idxs_block_ptr_1 = idxs_ptr + rows * N + (
        start_index_n + nnz_0 + n_offsets_1
    )
    tl.store(
        idxs_block_ptr_1,
        (T_n_range_half + T_n_half).to(tl.uint8),
        mask=is_positive_1 & row_in_bounds,
    )

    nnz_1 = tl.sum(is_positive_int_1, keep_dims=True, axis=1)

    # One count per (row, column tile); masked so that out-of-range rows of a
    # partial row tile do not write outside the (M, num_blocks_n) buffer.
    nnzs_block_ptr = nnzs_ptr + rows * num_blocks_n + n
    tl.store(nnzs_block_ptr, (nnz_0 + nnz_1).to(tl.uint8), mask=row_in_bounds)


def matmul_dense_to_bwell_only_indices_tma(a, b, warp_specialize: bool = False):
    """Launch the dense-to-block-sparse matmul kernel for ``a @ b``.

    Args:
        a: contiguous 2-D left operand of shape (M, K).
        b: 2-D right operand of shape (K, N); N must be a multiple of 256.
        warp_specialize: forwarded to the kernel's K loop; a notice is
            printed when it is enabled.

    Returns:
        A tuple ``(c, idxs, nnzs)``:
        - ``c``: (M, N) bfloat16 output in which non-positive entries are 0.
        - ``idxs``: (M, N) uint8 buffer; for each row and each 256-wide column
          tile, the tile-local column numbers of the positive entries are
          packed at the start of the tile's slot. The buffer is allocated
          with ``torch.empty``, so positions past the packed entries are
          uninitialised.
        - ``nnzs``: (M, N // 256) uint8 count of positive entries per row and
          column tile.

    Raises:
        AssertionError: if the inner dimensions differ, ``a`` is not
            contiguous, or N is not a multiple of 256.

    NOTE(review): counts are uint8 while a tile is 256 columns wide, so a row
    whose tile is entirely positive cannot be represented — confirm this
    cannot happen for the intended inputs.
    """
    assert a.shape[1] == b.shape[0], "Inner dimensions must match for tma nn ops"
    assert a.is_contiguous(), "Check memory in a is contiguous"
    M, K = a.shape
    N = b.shape[1]

    # Single source of truth for the tile sizes: the descriptors' block shapes
    # and the kernel's constexpr arguments must always agree.
    T_m = 128
    T_k = 64
    T_n = 256
    assert N % T_n == 0, "N must be multiple of 256 for this bwell nn ops"
    num_blocks_n = N // T_n

    if warp_specialize:
        print(
            "Warning: warp specialize enabled for non persistent TMA matmul"
            " is this running on blackwell?")

    c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
    idxs = torch.empty((M, N), dtype=torch.uint8, device=a.device)
    nnzs = torch.empty((M, num_blocks_n), dtype=torch.uint8, device=a.device)

    a_td = TensorDescriptor.from_tensor(a, block_shape=[T_m, T_k])
    b_td = TensorDescriptor.from_tensor(b, block_shape=[T_k, T_n])
    # The kernel stores each output tile as two half-width blocks.
    c_td = TensorDescriptor.from_tensor(c, block_shape=[T_m, T_n // 2])

    # One program per (row tile, column tile) pair.
    grid = (triton.cdiv(M, T_m) * num_blocks_n,)

    matmul_dense_to_bwell_only_indices_tma_kernel[grid](
        a_td=a_td,
        b_td=b_td,
        c_td=c_td,
        idxs_ptr=idxs,
        nnzs_ptr=nnzs,
        M=M,
        N=N,
        K=K,
        T_m=T_m,
        T_n=T_n,
        T_k=T_k,
        num_blocks_n=num_blocks_n,
        G_m=8,
        warp_specialize=warp_specialize,
        num_warps=8,
        num_stages=4,
    )

    return c, idxs, nnzs

@triton.jit()
def matmul_bwell_to_dense_only_indices_split_kernel(
    a_ptr,
    indices_by_block_ptr,
    nonzeros_by_block_ptr,
    b_ptr,
    c_ptr,
    M,
    N: tl.constexpr,
    K_blocks: tl.constexpr,
    block_size: tl.constexpr,
    split_size: tl.constexpr,
    a_m,
    a_k,
    indices_stride_m,
    indices_stride_k,
    b_stride_k,
    b_stride_n,
    c_stride_m,
    c_stride_n,
):
    """Compute one ``split_size``-wide slice of one output row of ``A @ B``.

    Only the entries of the row of ``A`` that are listed in the per-block
    index lists are visited. Program axis 0 selects the row, program axis 1
    selects which ``split_size`` slice of the output columns is produced.

    Args:
        a_ptr: values of the left operand, addressed with strides
            (``a_m``, ``a_k``).
        indices_by_block_ptr: for each row and each K block, block-local
            column numbers packed at the start of the block's
            ``block_size``-wide slot.
        nonzeros_by_block_ptr: number of valid indices per (row, K block);
            addressed with row stride ``K_blocks``.
        b_ptr: right operand, addressed with strides
            (``b_stride_k``, ``b_stride_n``).
        c_ptr: output, addressed with strides (``c_stride_m``, ``c_stride_n``);
            written as bfloat16.
        M, N: problem sizes (not read by the kernel body).
        K_blocks: number of blocks along K.
        block_size: width of one K block.
        split_size: number of output columns handled by one program.
    """
    # 64-bit row index so that row * stride cannot overflow 32-bit arithmetic
    # for large operands.
    row_idx = tl.program_id(0).to(tl.int64)
    split_start = tl.program_id(1) * split_size
    # Output columns owned by this program; loop-invariant.
    n_offsets = split_start + tl.arange(0, split_size)

    output_accumulator = tl.zeros((split_size,), dtype=tl.float32)
    for block_idx in range(K_blocks):
        num_nonzeros = tl.load(
            nonzeros_by_block_ptr + row_idx * K_blocks + block_idx,
        ).to(tl.int32)
        block_start = block_idx * block_size
        for idx in range(num_nonzeros):
            # The list entry lives at column (block_start + idx) of the index
            # buffer; the stored value is block-local, so block_start is added
            # back to obtain the position along K.
            index = tl.load(
                indices_by_block_ptr + row_idx * indices_stride_m
                + (block_start + idx) * indices_stride_k
            ).to(tl.int64) + block_start
            value = tl.load(a_ptr + row_idx * a_m + index * a_k)

            b_row = tl.load(
                b_ptr + index * b_stride_k + n_offsets * b_stride_n)

            # Widen before multiplying so the product is formed in float32
            # instead of being rounded in the input dtype first.
            output_accumulator += value.to(tl.float32) * b_row.to(tl.float32)

    c_row_ptr = c_ptr + row_idx * c_stride_m + n_offsets * c_stride_n
    tl.store(c_row_ptr, output_accumulator.to(tl.bfloat16))


def matmul_bwell_to_dense_only_indices_split(a, indices_by_block, nonzeros_by_block, b, split_size=1024):
    """Multiply a block-indexed sparse activation matrix by a dense matrix.

    Args:
        a: contiguous (M, K) values of the left operand.
        indices_by_block: contiguous index buffer with M rows and at least K
            columns; each ``K // K_blocks``-wide slot starts with the
            block-local column numbers of the entries to visit.
        nonzeros_by_block: contiguous (M, K_blocks) number of valid indices
            per row and K block.
        b: contiguous (K, N) dense right operand.
        split_size: number of output columns computed by one program; N must
            be a multiple of it.

    Returns:
        The (M, N) bfloat16 product, allocated on ``a``'s device.

    Raises:
        AssertionError: if an operand is not contiguous, the operand shapes
            are inconsistent, K is not divisible by K_blocks, or N is not
            divisible by ``split_size``.
    """
    assert a.is_contiguous(), 'values must be contiguous'
    assert indices_by_block.is_contiguous(), 'indices must be contiguous'
    assert b.is_contiguous(), 'b must be contiguous'
    # The kernel addresses the counts with a fixed row stride of K_blocks.
    assert nonzeros_by_block.is_contiguous(), 'nonzeros must be contiguous'

    M, K = a.shape
    b_rows, N = b.shape
    nnz_rows, K_blocks = nonzeros_by_block.shape

    # Mismatched operands would make the kernel read out of range or from the
    # wrong rows, so reject them up front.
    assert b_rows == K, "inner dimensions of a and b must match"
    assert nnz_rows == M, "nonzeros_by_block must have one row per row of a"
    assert (indices_by_block.shape[0] == M
            and indices_by_block.shape[1] >= K), (
        "indices_by_block must cover every row and column of a"
    )

    assert K % K_blocks == 0, "K must be divisible by K_blocks"
    assert N % split_size == 0, "N must be divisible by split_size"
    block_size = K // K_blocks
    n_splits = N // split_size

    c = torch.empty(
        (M, N),
        device=a.device,
        dtype=torch.bfloat16,
    )

    # One program per (output row, column slice).
    grid = (M, n_splits)

    matmul_bwell_to_dense_only_indices_split_kernel[grid](
        a_ptr=a,
        indices_by_block_ptr=indices_by_block,
        nonzeros_by_block_ptr=nonzeros_by_block,
        b_ptr=b,
        c_ptr=c,
        M=M,
        N=N,
        K_blocks=K_blocks,
        block_size=block_size,
        split_size=split_size,
        a_m=a.stride(0),
        a_k=a.stride(1),
        indices_stride_m=indices_by_block.stride(0),
        indices_stride_k=indices_by_block.stride(1),
        b_stride_k=b.stride(0),
        b_stride_n=b.stride(1),
        c_stride_m=c.stride(0),
        c_stride_n=c.stride(1),
        num_stages=1,
        num_warps=1,
    )
    return c



class BwellMLPNGBase(nn.Module, abc.ABC):
    """Abstract base for MLP modules built from an up and a down projection.

    The constructor captures transposed, contiguous copies of the two layers'
    weight matrices as buffers and derives the sizes used by subclasses:

    - ``up_weight``: shape (N, K), the transpose of ``up_linear.weight``.
    - ``down_weight``: shape (K, N), the transpose of ``down_linear.weight``.
    - ``K``: ``up_weight.shape[1]``; ``N``: ``down_weight.shape[1]``.
    - ``block_size``: 256; ``K_blocks``: ``K // block_size``.

    Only the ``weight`` of each layer is read here; any bias on the supplied
    layers is not stored by this class.
    """

    def __init__(
        self,
        up_linear: nn.Linear,
        down_linear: nn.Linear,
        **kwargs,
    ):
        """Copy the projection weights and validate their shapes.

        Args:
            up_linear: layer whose transposed weight becomes ``up_weight``.
            down_linear: layer whose transposed weight becomes ``down_weight``.
            **kwargs: accepted and ignored.

        Raises:
            AssertionError: if K is not a multiple of ``block_size`` or the
                two weight shapes are not (N, K) and (K, N) respectively.
        """
        super().__init__()

        # A single copy per weight: clone the transposed view straight into
        # contiguous memory so the buffers never alias the source parameters.
        up_w = up_linear.weight.detach().T.clone(
            memory_format=torch.contiguous_format)
        down_w = down_linear.weight.detach().T.clone(
            memory_format=torch.contiguous_format)

        self.register_buffer("up_weight", up_w)
        self.register_buffer("down_weight", down_w)

        self.block_size = 256
        self.K = self.up_weight.shape[1]
        self.N = self.down_weight.shape[1]
        assert self.K % self.block_size == 0, (
            "intermediate size (K) must be multiple of block_size"
        )
        self.K_blocks = self.K // self.block_size

        # K is defined from up_weight and N from down_weight, so the checks
        # that carry information are the two cross-checks below.
        assert self.up_weight.shape[0] == self.N, (
            "up weight shape mismatch"
        )
        assert self.down_weight.shape[0] == self.K, (
            "down weight shape mismatch"
        )

    @abc.abstractmethod
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to ``hidden_states``; implemented by subclasses."""
        pass

class BwellMLPNGv0(BwellMLPNGBase):
    """MLP that chains the two kernels' host wrappers defined in this module.

    The up projection is computed by ``matmul_dense_to_bwell_only_indices_tma``
    and its three outputs are fed to
    ``matmul_bwell_to_dense_only_indices_split`` for the down projection.
    """

    def __init__(
        self,
        up_linear: nn.Linear,
        down_linear: nn.Linear,
        **kwargs,
    ):
        """Forward both projection layers to the base class.

        Args:
            up_linear: layer providing the up-projection weight.
            down_linear: layer providing the down-projection weight.
            **kwargs: accepted and not forwarded.
        """
        super().__init__(up_linear=up_linear, down_linear=down_linear)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run the up projection followed by the down projection.

        Args:
            hidden_states: 3-D input; it is flattened to
                (batch_size * seq_len, N) before the first matmul.

        Returns:
            The result viewed as (batch_size, seq_len, N).
        """
        batch_size, seq_len, _ = hidden_states.shape

        # Collapse batch and sequence into a single row dimension.
        flat_input = hidden_states.reshape(batch_size * seq_len, self.N)
        flat_input = flat_input.contiguous()

        # Up projection: values plus the index/count bookkeeping that tells
        # the second kernel which entries to visit.
        up_activations, idxs, nz_by_block = matmul_dense_to_bwell_only_indices_tma(
            a=flat_input,
            b=self.up_weight,
        )

        # Down projection over the recorded entries only.
        flat_output = matmul_bwell_to_dense_only_indices_split(
            a=up_activations,
            indices_by_block=idxs,
            nonzeros_by_block=nz_by_block,
            b=self.down_weight,
            split_size=1024,
        )
        return flat_output.view(batch_size, seq_len, self.N)