import torch
import torch.nn as nn
import torch.nn.functional as F

from .kernel import _attention
from .utils import get_block_map, HedgehogFeatureMap


class SparseLinearAttention(nn.Module):
    """Attention layer combining two outputs of the ``_attention`` kernel.

    ``forward`` builds a block map with ``get_block_map``, applies a feature
    map to the queries and keys, and calls ``_attention.apply``, which returns
    two tensors ``(o_s, o_l)``.  The layer output is ``o_s + proj_l(o_l)``.

    ``proj_l`` is zero-initialised (weight and bias), so right after
    construction the projection contributes zeros to that sum.

    NOTE(review): judging by the names, ``o_s`` is presumably the block-sparse
    branch and ``o_l`` the linear-attention branch -- confirm against the
    kernel implementation.

    Args:
        head_dim: Feature size used for ``proj_l`` (``head_dim -> head_dim``)
            and passed to ``HedgehogFeatureMap`` when that kernel is chosen.
        topk: Forwarded unchanged to ``get_block_map`` as ``topk_ratio``.
        feat_kernel: One of ``'hedgehog'``, ``'elu'``, ``'relu'``,
            ``'softmax'``.
        BLKQ: Query block size forwarded to ``get_block_map`` and the kernel.
        BLKK: Key block size forwarded to ``get_block_map`` and the kernel.
        use_bf16: Compute in ``torch.bfloat16`` when true, otherwise in
            ``torch.float16``.
        tie_feature_map_qk: When true, queries and keys share one feature-map
            object.

    Raises:
        NotImplementedError: If ``feat_kernel`` is not a supported name.
    """

    def __init__(
        self,
        head_dim,
        topk,
        feat_kernel="softmax",
        BLKQ=64,
        BLKK=64,
        use_bf16=True,
        tie_feature_map_qk=True,
    ):
        super().__init__()
        self.dtype = torch.bfloat16 if use_bf16 else torch.float16
        self.topk = topk
        self.BLKQ = BLKQ
        self.BLKK = BLKK
        self.proj_l = nn.Linear(head_dim, head_dim, dtype=self.dtype)

        self.feature_map_q = self._build_feature_map(feat_kernel, head_dim, self.dtype)
        if tie_feature_map_qk:
            # Share the very same object instead of building a second map
            # that would be thrown away immediately.
            self.feature_map_k = self.feature_map_q
        else:
            self.feature_map_k = self._build_feature_map(feat_kernel, head_dim, self.dtype)

        self.init_weights_()

    @staticmethod
    def _elu_feature_map(x):
        """Return ``elu(x) + 1``."""
        return F.elu(x) + 1

    @staticmethod
    def _softmax_feature_map(x):
        """Return the softmax of ``x`` over its last dimension."""
        return F.softmax(x, dim=-1)

    @staticmethod
    def _build_feature_map(feat_kernel, head_dim, dtype):
        """Create the feature-map callable selected by ``feat_kernel``.

        The function-based maps are static methods rather than closures so
        that a module holding them can be pickled (closures defined inside
        ``__init__`` cannot).

        Raises:
            NotImplementedError: If ``feat_kernel`` is not a supported name.
        """
        if feat_kernel == 'hedgehog':
            return HedgehogFeatureMap(head_dim).to(dtype)
        if feat_kernel == 'elu':
            return SparseLinearAttention._elu_feature_map
        if feat_kernel == 'relu':
            return nn.ReLU()
        if feat_kernel == 'softmax':
            return SparseLinearAttention._softmax_feature_map
        raise NotImplementedError(f"Not supported feature map `{feat_kernel}`.")

    def init_weights_(self):
        """Zero ``proj_l`` and re-initialise Hedgehog feature maps, if any."""
        with torch.no_grad():
            nn.init.zeros_(self.proj_l.weight)
            nn.init.zeros_(self.proj_l.bias)

        if isinstance(self.feature_map_q, HedgehogFeatureMap):
            self.feature_map_q.init_weights_()
            # When the maps are tied they are one object; initialise it once.
            if self.feature_map_k is not self.feature_map_q:
                self.feature_map_k.init_weights_()

    def forward(self, q, k, v, return_sparsity=False):
        """Run the attention kernel and combine its two outputs.

        Args:
            q: Query tensor.
            k: Key tensor.
            v: Value tensor.
            return_sparsity: When true, also return
                ``real_topk / sparse_map.shape[-1]``.

        NOTE(review): the expected tensor layout is dictated by
        ``get_block_map`` and ``_attention``, which are not visible here --
        confirm against those helpers.

        Returns:
            The output tensor cast back to the dtype ``q`` had on entry, or a
            tuple ``(output, ratio)`` when ``return_sparsity`` is true.
        """
        # Remember the caller's dtype so the result can be cast back to it.
        dtype = q.dtype

        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()

        # The block map is computed from the inputs in their original dtype,
        # before they are cast to the compute dtype below.
        sparse_map, lut, real_topk = get_block_map(q, k, topk_ratio=self.topk, BLKQ=self.BLKQ, BLKK=self.BLKK)

        q = q.to(self.dtype)
        k = k.to(self.dtype)
        v = v.to(self.dtype)
        c_q = self.feature_map_q(q).contiguous().to(self.dtype)
        c_k = self.feature_map_k(k).contiguous().to(self.dtype)

        o_s, o_l = _attention.apply(q, k, v, c_q, c_k, sparse_map, lut, real_topk, self.BLKQ, self.BLKK)
        o = (o_s + self.proj_l(o_l)).to(dtype)

        if return_sparsity:
            return o, real_topk / sparse_map.shape[-1]
        else:
            return o
