import argparse
import gc
import timeit
from typing import Dict, Tuple, Optional

from pathlib import Path
import abc
import pandas as pd
import torch
import torch.nn.functional as F
from torch import nn

from custom_models.kernels.bwell_base import (
    matmul_dense_to_bwell_only_indices_tma_kernel,
    matmul_dense_to_bwell_only_indices_tma,
)

from custom_models.kernels.bwell_gated import (
    gated_downprojection_bwell_kernel,
    gated_downprojection_bwell,
    gated_downprojection_bwell_ord,
    gated_downprojection_bwell_ord_subtile,
    gated_downprojection_bwell_os,
    gated_downprojection_bwell_np2_kernel,
    gated_downprojection_bwell_np2,
    gated_upprojection_bwell,
    gated_up_ip_bwell,
)
import triton

from triton.tools.tensor_descriptor import TensorDescriptor



def matmul_dense_to_bwell_nn_128x256x64(
        a,
        b,
        N,
        K,
        num_blocks_n,
        ):
    """Launch the index-producing dense-to-bwell matmul with 128x256x64 tiles.

    Allocates the three output buffers, wraps ``a``, ``b`` and the value
    output in tensor descriptors, and launches
    ``matmul_dense_to_bwell_only_indices_tma_kernel`` on a 1-D grid of
    ``cdiv(M, 128) * num_blocks_n`` programs, with ``M = a.shape[0]``.

    Args:
        a: Left operand, described to the kernel in 128x64 tiles
            (presumably of shape ``(M, K)`` -- confirm against the kernel).
        b: Right operand, described to the kernel in 64x256 tiles
            (presumably of shape ``(K, N)`` -- confirm against the kernel).
        N: Number of output columns; forwarded to the kernel as ``N``.
        K: Forwarded to the kernel as ``K``.
        num_blocks_n: Number of column blocks; sizes the grid and the
            second dimension of ``nnzs``.

    Returns:
        Tuple ``(c, idxs, nnzs)``: a bfloat16 ``(M, N)`` tensor, a uint8
        ``(M, N)`` tensor and a uint8 ``(M, num_blocks_n)`` tensor. All are
        allocated uninitialised on ``a.device`` and handed to the kernel.
    """
    tile_m, tile_n, tile_k = 128, 256, 64
    device = a.device
    M = a.shape[0]

    # Output buffers passed to the kernel.
    c = torch.empty((M, N), device=device, dtype=torch.bfloat16)
    idxs = torch.empty((M, N), dtype=torch.uint8, device=device)
    nnzs = torch.empty((M, num_blocks_n), dtype=torch.uint8, device=device)

    descriptors = {
        "a_td": TensorDescriptor.from_tensor(a, block_shape=[tile_m, tile_k]),
        "b_td": TensorDescriptor.from_tensor(b, block_shape=[tile_k, tile_n]),
        # NOTE(review): the c tile is 128 wide while T_n is 256; presumably
        # the kernel handles each output tile in two halves -- confirm.
        "c_td": TensorDescriptor.from_tensor(c, block_shape=[128, 128]),
    }

    grid = (triton.cdiv(M, tile_m) * num_blocks_n,)
    matmul_dense_to_bwell_only_indices_tma_kernel[grid](
        **descriptors,
        idxs_ptr=idxs,
        nnzs_ptr=nnzs,
        M=M,
        N=N,
        K=K,
        num_blocks_n=num_blocks_n,
        warp_specialize=False,
        T_m=tile_m,
        T_n=tile_n,
        T_k=tile_k,
        G_m=8,
        num_warps=8,
        num_stages=4,
    )
    return c, idxs, nnzs


def fused_up_down_bwell_to_dense_(
    sparse_gate,
    indices_by_block,
    nonzeros_by_block,
    input,
    up_weight,
    down_weight,
    M,
    N,
    K,
    K_blocks,
    block_size,
):
    """Launch ``gated_downprojection_bwell_kernel`` with one program per row.

    Args:
        sparse_gate: Gate activations; its first dimension gives the row
            count used for the grid, the output and the kernel's ``M``.
        indices_by_block: Index tensor forwarded to the kernel.
        nonzeros_by_block: Per-block non-zero counts forwarded to the kernel.
        input: Input activations forwarded to the kernel.
        up_weight: Up-projection weight forwarded to the kernel.
        down_weight: Down-projection weight forwarded to the kernel.
        M: Ignored; the row count is always taken from
            ``sparse_gate.shape[0]``.
        N: Output width; also used as the row stride of ``input``, both
            weights and the output.
        K: Used as the row stride of ``sparse_gate`` and
            ``indices_by_block``.
        K_blocks: Forwarded to the kernel.
        block_size: Forwarded to the kernel.

    Returns:
        A bfloat16 tensor of shape ``(sparse_gate.shape[0], N)`` allocated on
        ``sparse_gate.device`` and passed to the kernel as ``output_ptr``.
    """
    # The caller-supplied M is deliberately not used (matches prior behaviour).
    rows = sparse_gate.shape[0]
    output = torch.empty(
        (rows, N), device=sparse_gate.device, dtype=torch.bfloat16
    )

    # Strides are derived from K and N rather than read from the tensors;
    # this assumes contiguous row-major operands -- TODO confirm with callers.
    strides = {
        "sparse_gate_stride_m": K,
        "sparse_gate_stride_k": 1,
        "indices_stride_m": K,
        "indices_stride_k": 1,
        "input_stride_m": N,
        "input_stride_n": 1,
        "up_weight_stride_k": N,
        "up_weight_stride_n": 1,
        "down_weight_stride_k": N,
        "down_weight_stride_n": 1,
        "output_stride_m": N,
        "output_stride_n": 1,
    }

    grid = (rows,)
    gated_downprojection_bwell_kernel[grid](
        sparse_gate_ptr=sparse_gate,
        indices_by_block_ptr=indices_by_block,
        nonzeros_by_block_ptr=nonzeros_by_block,
        input_ptr=input,
        up_weight_ptr=up_weight,
        down_weight_ptr=down_weight,
        output_ptr=output,
        M=rows,
        N=N,
        K_blocks=K_blocks,
        block_size=block_size,
        **strides,
        num_stages=1,
        num_warps=1,
    )
    return output


class BwellMLPBase(nn.Module, abc.ABC):
    """Abstract base holding the gate/up/down weights of an MLP as buffers.

    The weights are copied out of three ``nn.Linear`` layers and registered
    as the buffers ``gate_weight``, ``up_weight`` and ``down_weight``. The
    gate and down weights are stored transposed; the up weight keeps the
    ``nn.Linear`` layout. Subclasses implement ``forward``.

    Attributes set by ``__init__``:
        block_size: Fixed to 256.
        K: Second dimension of ``gate_weight``.
        N: Second dimension of ``down_weight``.
        K_blocks: ``K // block_size``.
    """

    def __init__(
        self,
        gate_linear: nn.Linear,
        up_linear: nn.Linear,
        down_linear: nn.Linear,
        **kwargs,
    ):
        """Copy the three weights into buffers and check their shapes.

        Args:
            gate_linear: Layer whose weight becomes ``gate_weight``
                (transposed copy).
            up_linear: Layer whose weight becomes ``up_weight`` (plain copy).
            down_linear: Layer whose weight becomes ``down_weight``
                (transposed copy).
            **kwargs: Accepted and ignored.

        Raises:
            AssertionError: If ``K`` is not a multiple of ``block_size`` or
                the three weight shapes are inconsistent with each other.
        """
        super().__init__()

        # (buffer name, source layer, store transposed?) in registration order.
        buffer_specs = (
            ("gate_weight", gate_linear, True),
            ("up_weight", up_linear, False),
            ("down_weight", down_linear, True),
        )
        for buffer_name, linear, transposed in buffer_specs:
            weight = linear.weight.detach().clone()
            if transposed:
                weight = weight.T
            self.register_buffer(buffer_name, weight.contiguous())

        self.block_size = 256
        self.K = self.gate_weight.size(1)
        self.N = self.down_weight.size(1)
        self.K_blocks = self.K // self.block_size

        divisible = self.K % self.block_size == 0
        assert divisible, "intermediate size (K) must be multiple of block_size"

        gate_rows = self.gate_weight.size(0)
        assert gate_rows == self.N, "gate weight shape mismatch"

        up_shape = tuple(self.up_weight.shape[:2])
        assert up_shape == (self.K, self.N), "up weight shape mismatch"

        down_rows = self.down_weight.size(0)
        assert down_rows == self.K, "down weight shape mismatch"

    @abc.abstractmethod
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to ``hidden_states``; implemented by subclasses."""

class BwellMLPv0(BwellMLPBase):
    """MLP that runs the gate matmul and then a gated down-projection.

    The down-projection wrapper is chosen once at construction time from the
    hidden size ``N``: ``gated_downprojection_bwell_np2`` when ``N`` is not a
    power of two, ``gated_downprojection_bwell`` otherwise. The choice is
    kept in ``self.downprojection`` and used by ``forward``.
    """

    def __init__(
        self,
        gate_linear: nn.Linear,
        up_linear: nn.Linear,
        down_linear: nn.Linear,
        **kwargs,
    ):
        """Build the weight buffers and select the down-projection wrapper.

        Args:
            gate_linear: Source of the gate weight.
            up_linear: Source of the up weight.
            down_linear: Source of the down weight.
            **kwargs: Accepted and ignored.
        """
        BwellMLPBase.__init__(
            self,
            gate_linear=gate_linear,
            up_linear=up_linear,
            down_linear=down_linear,
        )

        # Previously the selection was written to a local that was thrown
        # away, and the power-of-two branch picked the raw `_np2_kernel`
        # instead of a callable wrapper, so `forward` never honoured it.
        n_is_power_of_two = (self.N & (self.N - 1)) == 0
        # NOTE(review): assumes the np2 wrapper accepts the same keyword
        # arguments as gated_downprojection_bwell -- confirm in bwell_gated.
        self.downprojection = (
            gated_downprojection_bwell
            if n_is_power_of_two
            else gated_downprojection_bwell_np2
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to a 3-D ``hidden_states`` tensor.

        Args:
            hidden_states: Tensor unpacked as ``(batch, seq, _)`` and viewed
                as ``(batch * seq, N)``.

        Returns:
            The down-projection result viewed as ``(batch, seq, N)``.
        """
        batch_size, seq_len, _ = hidden_states.shape
        hidden_states = hidden_states.view(batch_size * seq_len, self.N)
        gate_activations, idxs, nz_by_block = (
            matmul_dense_to_bwell_only_indices_tma(
                a=hidden_states,
                b=self.gate_weight,
            )
        )
        return self.downprojection(
            sparse_gate=gate_activations,
            indices_by_block=idxs,
            nonzeros_by_block=nz_by_block,
            input=hidden_states,
            up_weight=self.up_weight,
            down_weight=self.down_weight,
        ).view(batch_size, seq_len, self.N)

class BwellMLPv01(BwellMLPBase):
    """MLP built on the fixed-tile gate matmul and the fused up/down kernel.

    ``nn.Module`` is no longer listed as an explicit base: it is already
    inherited through ``BwellMLPBase``, so the method resolution order is
    unchanged and the class now matches its sibling variants.
    """

    def __init__(
        self,
        gate_linear: nn.Linear,
        up_linear: nn.Linear,
        down_linear: nn.Linear,
        **kwargs,
    ):
        """Build the weight buffers.

        Args:
            gate_linear: Source of the gate weight.
            up_linear: Source of the up weight.
            down_linear: Source of the down weight.
            **kwargs: Accepted and ignored.
        """
        BwellMLPBase.__init__(
            self,
            gate_linear=gate_linear,
            up_linear=up_linear,
            down_linear=down_linear,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to a 3-D ``hidden_states`` tensor.

        Args:
            hidden_states: Tensor unpacked as ``(batch, seq, _)`` and viewed
                as ``(batch * seq, N)``.

        Returns:
            The fused kernel's output viewed as ``(batch, seq, N)``.
        """
        batch_size, seq_len, _ = hidden_states.shape
        num_rows = batch_size * seq_len
        # Build the 2-D view once and share it between both kernel launches.
        flat_hidden = hidden_states.view(num_rows, self.N)

        # The gate matmul produces K columns from an N-wide input, hence the
        # swapped N/K arguments below.
        gate_activations, idxs, nz_by_block = (
            matmul_dense_to_bwell_nn_128x256x64(
                a=flat_hidden,
                b=self.gate_weight,
                N=self.K,
                K=self.N,
                num_blocks_n=self.K_blocks,
            )
        )
        return fused_up_down_bwell_to_dense_(
            sparse_gate=gate_activations,
            indices_by_block=idxs,
            nonzeros_by_block=nz_by_block,
            input=flat_hidden,
            up_weight=self.up_weight,
            down_weight=self.down_weight,
            M=num_rows,
            N=self.N,
            K=self.K,
            K_blocks=self.K_blocks,
            block_size=self.block_size,
        ).view(batch_size, seq_len, self.N)


class BwellMLPv0Sort(BwellMLPBase):
    """MLP variant that orders rows by non-zero count on every call.

    Each forward pass derives a fresh ``priority`` by arg-sorting the rows in
    descending order of ``nz_by_block`` summed over its last dimension, and
    hands it to ``gated_downprojection_bwell_ord``.
    """

    def __init__(
        self,
        gate_linear: nn.Linear,
        up_linear: nn.Linear,
        down_linear: nn.Linear,
        layer_idx: int,
        **kwargs,
    ):
        """Build the weight buffers.

        Args:
            gate_linear: Source of the gate weight.
            up_linear: Source of the up weight.
            down_linear: Source of the down weight.
            layer_idx: Required by the signature but not used by this class.
            **kwargs: Accepted and ignored.
        """
        BwellMLPBase.__init__(
            self,
            gate_linear=gate_linear,
            up_linear=up_linear,
            down_linear=down_linear,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to a 3-D ``hidden_states`` tensor.

        Args:
            hidden_states: Tensor unpacked as ``(batch, seq, _)`` and viewed
                as ``(batch * seq, N)``.

        Returns:
            The ordered down-projection result viewed as ``(batch, seq, N)``.
        """
        batch_size, seq_len, _ = hidden_states.shape
        flat_hidden = hidden_states.view(batch_size * seq_len, self.N)

        gate_activations, idxs, nz_by_block = (
            matmul_dense_to_bwell_only_indices_tma(
                a=flat_hidden,
                b=self.gate_weight,
            )
        )

        # Rows with the largest summed non-zero count come first.
        row_nonzeros = nz_by_block.sum(-1)
        row_order = torch.argsort(row_nonzeros, descending=True)

        projected = gated_downprojection_bwell_ord(
            sparse_gate=gate_activations,
            indices_by_block=idxs,
            nonzeros_by_block=nz_by_block,
            priority=row_order,
            input=flat_hidden,
            up_weight=self.up_weight,
            down_weight=self.down_weight,
        )
        return projected.view(batch_size, seq_len, self.N)


class BwellMLPv0SyncSort(BwellMLPBase):
    """MLP variant that computes a row priority only on the start layer.

    NOTE(review): a second class with this exact name is defined later in
    this module and rebinds the name, so this definition is not reachable as
    ``BwellMLPv0SyncSort`` once the module has finished importing.

    Only the instance whose ``layer_idx`` equals ``start_layer_idx`` ever
    assigns ``self.priority``; every other instance keeps ``None`` and
    passes that to ``gated_downprojection_bwell_ord``.
    """

    def __init__(
        self,
        gate_linear: nn.Linear,
        up_linear: nn.Linear,
        down_linear: nn.Linear,
        layer_idx: int,
        start_layer_idx: int = 1,
        **kwargs,
    ):
        """Build the weight buffers and record the layer position.

        Args:
            gate_linear: Source of the gate weight.
            up_linear: Source of the up weight.
            down_linear: Source of the down weight.
            layer_idx: Index of this layer.
            start_layer_idx: Index of the layer that computes the priority.
            **kwargs: Accepted and ignored.
        """
        BwellMLPBase.__init__(
            self,
            gate_linear=gate_linear,
            up_linear=up_linear,
            down_linear=down_linear,
        )
        self.layer_idx = layer_idx
        self.start_layer_idx = start_layer_idx
        self.compute_priority = layer_idx == start_layer_idx
        self.priority = None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to a 3-D ``hidden_states`` tensor.

        Args:
            hidden_states: Tensor unpacked as ``(batch, seq, _)`` and viewed
                as ``(batch * seq, N)``.

        Returns:
            The ordered down-projection result viewed as ``(batch, seq, N)``.
        """
        batch_size, seq_len, _ = hidden_states.shape
        flat_hidden = hidden_states.view(batch_size * seq_len, self.N)

        gate_activations, idxs, nz_by_block = (
            matmul_dense_to_bwell_only_indices_tma(
                a=flat_hidden,
                b=self.gate_weight,
            )
        )

        if self.compute_priority:
            # Start layer: refresh the stored ordering from this batch.
            row_nonzeros = nz_by_block.sum(-1)
            self.priority = torch.argsort(row_nonzeros, descending=True)

        projected = gated_downprojection_bwell_ord(
            sparse_gate=gate_activations,
            indices_by_block=idxs,
            nonzeros_by_block=nz_by_block,
            priority=self.priority,
            input=flat_hidden,
            up_weight=self.up_weight,
            down_weight=self.down_weight,
        )
        return projected.view(batch_size, seq_len, self.N)

class BwellPriorityCache(nn.Module):
    """Holder for a single ``priority`` tensor kept as a module buffer.

    The buffer is registered as non-persistent, so it is left out of the
    module's ``state_dict``. A zero-element buffer means "nothing stored".
    """

    def __init__(self):
        """Start with an empty int64 ``priority`` buffer."""
        super().__init__()
        empty_priority = torch.empty(0, dtype=torch.int64)
        self.register_buffer("priority", empty_priority, persistent=False)

    def clear(self):
        """Replace the buffer with an empty tensor of the same dtype/device."""
        current = self.priority
        self.priority = current.new_empty(0)

    def get(self):
        """Return the stored priority, or ``None`` when the buffer is empty."""
        stored = self.priority
        return stored if stored.numel() != 0 else None

    def set(self, priority: torch.Tensor):
        """Store ``priority`` as the new buffer value."""
        self.priority = priority


class BwellMLPv0SyncSort(BwellMLPBase):
    """MLP variant that shares one row priority across layers via a cache.

    Behaviour by layer position relative to ``start_layer_idx``:

    * ``layer_idx < start_layer_idx`` ("pre-priority" layers): no shared
      priority exists yet for the current pass, so the layer runs
      ``pre_priority_forward``. That is ``_sorted_forward`` (sorts on its
      own) when ``sort_pre_priority`` is true, else ``_unsorted_forward``.
    * ``layer_idx == start_layer_idx``: the layer computes the priority from
      its own ``nz_by_block`` and stores it in ``priority_cache`` if one was
      given.
    * ``layer_idx > start_layer_idx``: the layer reads the priority from
      ``priority_cache`` (``None`` if there is no cache or it is empty).

    Two defects are fixed relative to the earlier revision:
    ``pre_priority_layer`` was computed as ``start_layer_idx < layer_idx``
    (true for layers *after* the start layer), and neither it nor
    ``pre_priority_forward`` was consulted by ``forward``, so early layers
    used whatever priority was left in the cache by a previous pass.
    """

    def __init__(
        self,
        gate_linear: nn.Linear,
        up_linear: nn.Linear,
        down_linear: nn.Linear,
        layer_idx: int,
        start_layer_idx: int = 1,
        sort_pre_priority: bool = False,
        priority_cache: BwellPriorityCache | None = None,
        **kwargs,
    ):
        """Build the weight buffers and record the priority-sharing setup.

        Args:
            gate_linear: Source of the gate weight.
            up_linear: Source of the up weight.
            down_linear: Source of the down weight.
            layer_idx: Index of this layer.
            start_layer_idx: Index of the layer that computes the priority.
            sort_pre_priority: Whether layers before the start layer sort
                rows by their own non-zero counts.
            priority_cache: Optional shared cache used to publish and read
                the priority.
            **kwargs: Accepted and ignored.
        """
        BwellMLPBase.__init__(
            self,
            gate_linear=gate_linear,
            up_linear=up_linear,
            down_linear=down_linear,
        )
        self.layer_idx = layer_idx
        self.start_layer_idx = start_layer_idx
        self.compute_priority = layer_idx == start_layer_idx
        self.priority_cache = priority_cache
        self.sort_pre_priority = sort_pre_priority
        # True only for layers that run before the priority is computed.
        self.pre_priority_layer = layer_idx < start_layer_idx
        self.pre_priority_forward = (
            self._sorted_forward if self.sort_pre_priority
            else self._unsorted_forward
        )

    def _gate_projection(self, hidden_states: torch.Tensor):
        """Flatten ``hidden_states`` to 2-D and run the gate matmul on it.

        Returns:
            ``(flat_hidden, gate_activations, idxs, nz_by_block)``.
        """
        batch_size, seq_len, _ = hidden_states.shape
        flat_hidden = hidden_states.view(batch_size * seq_len, self.N)
        gate_activations, idxs, nz_by_block = (
            matmul_dense_to_bwell_only_indices_tma(
                a=flat_hidden,
                b=self.gate_weight,
            )
        )
        return flat_hidden, gate_activations, idxs, nz_by_block

    def _ordered_downprojection(
        self,
        flat_hidden: torch.Tensor,
        gate_activations,
        idxs,
        nz_by_block,
        priority,
    ):
        """Call ``gated_downprojection_bwell_ord`` with the given priority."""
        return gated_downprojection_bwell_ord(
            sparse_gate=gate_activations,
            indices_by_block=idxs,
            nonzeros_by_block=nz_by_block,
            priority=priority,
            input=flat_hidden,
            up_weight=self.up_weight,
            down_weight=self.down_weight,
        )

    def _unsorted_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Forward pass using the plain ``gated_downprojection_bwell``."""
        batch_size, seq_len, _ = hidden_states.shape
        flat_hidden, gate_activations, idxs, nz_by_block = (
            self._gate_projection(hidden_states)
        )
        return gated_downprojection_bwell(
            sparse_gate=gate_activations,
            indices_by_block=idxs,
            nonzeros_by_block=nz_by_block,
            input=flat_hidden,
            up_weight=self.up_weight,
            down_weight=self.down_weight,
        ).view(batch_size, seq_len, self.N)

    def _sorted_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Forward pass that sorts rows by this layer's own non-zero counts."""
        batch_size, seq_len, _ = hidden_states.shape
        flat_hidden, gate_activations, idxs, nz_by_block = (
            self._gate_projection(hidden_states)
        )
        priority = torch.argsort(nz_by_block.sum(-1), descending=True)
        return self._ordered_downprojection(
            flat_hidden, gate_activations, idxs, nz_by_block, priority
        ).view(batch_size, seq_len, self.N)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to a 3-D ``hidden_states`` tensor.

        Args:
            hidden_states: Tensor unpacked as ``(batch, seq, _)`` and viewed
                as ``(batch * seq, N)``.

        Returns:
            The down-projection result viewed as ``(batch, seq, N)``.
        """
        if self.pre_priority_layer:
            # No priority for this pass yet; never read the shared cache here
            # because it may still hold the ordering of a previous batch.
            return self.pre_priority_forward(hidden_states)

        batch_size, seq_len, _ = hidden_states.shape
        flat_hidden, gate_activations, idxs, nz_by_block = (
            self._gate_projection(hidden_states)
        )

        if self.compute_priority:
            priority = torch.argsort(
                nz_by_block.sum(-1),
                descending=True,
            )
            if self.priority_cache is not None:
                self.priority_cache.set(priority)
        elif self.priority_cache is not None:
            priority = self.priority_cache.get()
        else:
            # No cache to read from; None is passed through unchanged.
            priority = None

        return self._ordered_downprojection(
            flat_hidden, gate_activations, idxs, nz_by_block, priority
        ).view(batch_size, seq_len, self.N)