import torch

from transformers.models.llama import modeling_llama
from transformers.models.llama.modeling_llama import *
import copy

# should be implemented with a triton kernel in the future
def sparse_attn(
    q, k, v,
    kv_indices,
    attention_mask,
):
    """Scaled dot-product attention where every query sees only its own subset of keys.

    Reference PyTorch implementation; a fused (e.g. Triton) kernel would avoid
    materialising the gathered keys/values.

    Args:
        q: queries, [bsz, num_heads, q_len, head_dim].
        k, v: keys / values, [bsz, num_heads, kv_len, head_dim].
        kv_indices: integer key positions attended by every query,
            [bsz, num_heads, q_len, topk].
        attention_mask: additive mask [bsz, num_heads, q_len, kv_len], or None.
            Only the entries selected by ``kv_indices`` are used.

    Returns:
        Attention output, [bsz, num_heads, q_len, head_dim].
    """
    bsz, num_heads, kv_len, head_dim = k.size()
    q_len = q.size(2)
    topk = kv_indices.size(-1)

    # A single gather over a broadcast view replaces a per-query Python loop
    # (q_len gather launches into a zero-filled buffer). expand() allocates
    # nothing; only the selected [q_len, topk] rows are materialised.
    gather_index = kv_indices.unsqueeze(-1).expand(bsz, num_heads, q_len, topk, head_dim)

    def _select_rows(x):
        # x: [bsz, num_heads, kv_len, head_dim] -> [bsz, num_heads, q_len, topk, head_dim]
        return x.unsqueeze(2).expand(bsz, num_heads, q_len, kv_len, head_dim).gather(3, gather_index)

    selected_k = _select_rows(k)
    # score: [bsz, num_heads, q_len, topk]
    score = torch.matmul(
        q.unsqueeze(-2),
        selected_k.transpose(-2, -1),
    ).squeeze(-2) / (q.size(-1) ** 0.5)
    # Release the gathered keys before gathering values to lower peak memory.
    del selected_k

    if attention_mask is not None:
        score = score + attention_mask.gather(-1, kv_indices)

    # Softmax in float32 for stability, then back to the activation dtype.
    attn_weight = torch.softmax(score, dim=-1, dtype=torch.float32).to(q.dtype)

    selected_v = _select_rows(v)
    # [..., q_len, 1, topk] @ [..., q_len, topk, head_dim] -> [..., q_len, head_dim]
    return torch.matmul(attn_weight.unsqueeze(-2), selected_v).squeeze(-2)

def sparQ(
    q, k, v,
    attention_mask,
    num_top_dim_in_q,
    topk,
    local_window,
    use_mean_v = True,
):
    """SparQ attention: approximate scores from a few query dims, then attend to the top-k keys.

    Steps:
      1. Keep the ``num_top_dim_in_q`` largest-magnitude dimensions of each query
         and score every key on those dimensions only (approximate scores),
         scaled by sqrt(head_dim * |q_top|_1 / |q|_1) instead of sqrt(head_dim).
      2. Force the ``local_window`` most recent visible keys into the selection,
         then pick the ``topk`` best-scoring keys per query.
      3. Run exact attention over the selected keys (``sparse_attn``).
      4. Optionally (``use_mean_v``) mix in the mean of the visible values,
         weighted by the approximate probability mass of the unselected keys.

    Args:
        q: [bsz, num_heads, q_len, head_dim].
        k, v: [bsz, num_heads, kv_len, head_dim]. The queries are taken to be the
            last ``q_len`` positions of the key sequence (prefill: q_len == kv_len;
            cached decoding: q_len < kv_len).
        attention_mask: additive mask broadcastable to
            [bsz, num_heads, q_len, kv_len] (columns beyond kv_len are dropped),
            or None.
        num_top_dim_in_q: number of query dimensions used for the approximation;
            values <= 0 or > head_dim use all dimensions.
        topk: number of keys attended per query (clamped to kv_len).
        local_window: number of most recent visible keys that are always selected.
        use_mean_v: whether to apply the mean-value correction of step 4.

    Returns:
        [bsz, num_heads, q_len, head_dim].
    """
    bsz, num_heads, q_len, head_dim = q.size()
    kv_len = k.size(2)

    if attention_mask is not None:
        # The model-level mask can be wider than the current keys; only the
        # first kv_len columns correspond to entries of k/v.
        attention_mask = attention_mask[..., :kv_len]

    # ======================
    # 1. approximate attention scores
    if num_top_dim_in_q <= 0 or num_top_dim_in_q > head_dim:
        num_top_dim_in_q = head_dim
    # top_dim_indices / q_topk: [bsz, num_heads, q_len, num_top_dim_in_q]
    top_dim_indices = q.abs().topk(num_top_dim_in_q, dim=-1).indices
    q_topk = q.gather(-1, top_dim_indices)
    # Each query must use ITS OWN top dims against EVERY key. Zeroing the other
    # dims of q and doing a full matmul gives sum_{d in top(q_i)} q_i[d] * k_j[d]
    # for all (i, j). Note: k.gather(-1, top_dim_indices) would instead pair key
    # row i with query i's dims and yield only q_len keys.
    q_sparse = torch.zeros_like(q).scatter_(-1, top_dim_indices, q_topk)

    # scale: [bsz, num_heads, q_len, 1]
    # sqrt(head_dim * |q_top|_1 / |q|_1): shrinks the usual sqrt(d) by the fraction
    # of q's L1 mass that was kept, to prevent an overly sharp distribution.
    scale = (
        q_topk.abs()
        .sum(-1)
        .div_(q.abs().sum(-1))
        .mul_(head_dim)
        .pow_(0.5)
        .unsqueeze(-1)
    )

    # approx_score: [bsz, num_heads, q_len, kv_len]
    approx_score = torch.matmul(q_sparse, k.transpose(-2, -1)) / scale
    if attention_mask is not None:
        approx_score = approx_score + attention_mask

    # Query i sits at absolute position offset + i, so the causal and local masks
    # are shifted by offset (0 for prefill, kv_len - 1 for one-token decoding).
    offset = kv_len - q_len
    ones = torch.ones((q_len, kv_len), device=q.device, dtype=torch.bool)
    causal = torch.tril(ones, offset)
    is_local = causal ^ torch.tril(ones, offset - local_window)

    # ======================
    # 2. select keys: local window first (forced to max score), then best approximations
    topk = min(topk, kv_len)
    selection_score = approx_score.masked_fill(is_local, torch.finfo(approx_score.dtype).max)
    topk_indices = selection_score.topk(topk, dim=-1).indices
    del selection_score

    # ======================
    # 3. exact attention over the selected keys
    attn_output = sparse_attn(
        q, k, v,
        topk_indices,
        None if attention_mask is None
        else attention_mask.expand(bsz, num_heads, q_len, kv_len),
    )

    # ======================
    # 4. mean-value correction
    if use_mean_v:
        causal_weight = causal.to(v.dtype)
        # mean_v: [bsz, num_heads, q_len, head_dim], mean of the causally visible values
        mean_v = torch.matmul(causal_weight, v) / causal_weight.sum(-1, keepdim=True)
        # Approximate probability mass captured by the selected keys. It is taken from
        # the scores BEFORE the local-window override: after it, softmax puts all mass
        # on the forced max entries and kv_weight would always be ~1.
        kv_weight = (
            torch.softmax(approx_score, dim=-1, dtype=torch.float32)
            .gather(-1, topk_indices)
            .sum(-1, keepdim=True)
            .to(q.dtype)
        )
        attn_output = (1 - kv_weight) * mean_v + kv_weight * attn_output

    return attn_output

class SparQAttn(LlamaAttention):
    """Llama attention layer that can run SparQ sparse attention.

    Extra config fields (read with defaults, so they may be absent); typical
    values are given in parentheses:

    * ``num_top_dim_in_q`` (16), default ``-1``: query dimensions used to
      approximate the attention scores.
    * ``topk`` (128 / 256 / 512), default ``-1``: keys attended per query.
      A non-positive value selects the dense attention path.
    * ``local_window`` (``topk / 4``), default ``100``: most recent keys that
      are always attended.
    """

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        # Convert the config once instead of once per field.
        sparq_cfg = config.to_dict()
        self.num_top_dim_in_q = sparq_cfg.get("num_top_dim_in_q", -1)
        self.topk = sparq_cfg.get("topk", -1)
        self.local_window = sparq_cfg.get("local_window", 100)

    def dense_attn(
        self,
        q, k, v,
        attention_mask = None
    ):
        """Standard softmax attention (with dropout), used when SparQ is disabled.

        Args:
            q: [bsz, num_heads, q_len, head_dim]; k, v: [bsz, num_heads, kv_len, head_dim].
            attention_mask: additive mask whose last dim is sliced to kv_len, or None.

        Returns:
            [bsz, num_heads, q_len, head_dim].
        """
        head_dim = q.size(-1)
        attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : k.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        return torch.matmul(attn_weights, v)

    def forward(
        self,
        hidden_states,
        attention_mask = None,
        position_ids = None,
        past_key_value = None,
        output_attentions = False,
        use_cache = False,
        cache_position = None,
        **kwargs,
    ):
        """Project, apply RoPE, update the KV cache, attend, and project out.

        Returns:
            ``(attn_output, None, past_key_value)``. Neither attention path exposes
            its probabilities, so the middle element is always None, including when
            ``output_attentions`` is True.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        past_key_value = getattr(self, "past_key_value", past_key_value)
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        # NOTE(review): passing seq_len positionally matches the transformers 4.36/4.37
        # rotary API; confirm against the installed version.
        cos, sin = self.rotary_emb(value_states, kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        if attention_mask is not None:
            # The mask may be wider than the keys; keep only the columns for existing keys.
            attention_mask = attention_mask[:, :, :, : key_states.shape[-2]]

        if self.topk <= 0:
            attn_output = self.dense_attn(query_states, key_states, value_states, attention_mask)
        else:
            attn_output = sparQ(
                query_states, key_states, value_states, attention_mask,
                self.num_top_dim_in_q, self.topk, self.local_window,
            )

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        # Attention probabilities are never materialised for the caller.
        return attn_output, None, past_key_value

def apply_sparq(model, inplace=False):
    """Return ``model`` with every ``LlamaAttention`` submodule replaced by ``SparQAttn``.

    Each replacement is built from the original module's config and layer index,
    moved to the device/dtype of its ``q_proj`` weight, loaded with the original
    weights via ``load_state_dict(strict=False)``, and put in the same train/eval
    mode as the module it replaces.

    Args:
        model: any ``nn.Module``. If it is itself a ``LlamaAttention``, its
            replacement is returned directly.
        inplace: if False (default) the input is left untouched and a converted
            deep copy is returned. If True, modules are swapped inside ``model``
            itself, which avoids holding a second copy of all weights.

    Returns:
        The converted module. When ``model`` contains no ``LlamaAttention``,
        ``model`` itself is returned.
    """

    def _make_replacement(attn):
        replacement = SparQAttn(attn.config, layer_idx=attn.layer_idx)
        weight = attn.q_proj.weight
        replacement.to(device=weight.device, dtype=weight.dtype)
        replacement.load_state_dict(attn.state_dict(), strict=False)
        # A freshly built module starts in training mode; keep the caller's mode.
        replacement.train(attn.training)
        return replacement

    if isinstance(model, LlamaAttention):
        return _make_replacement(model)

    if not any(isinstance(m, LlamaAttention) for m in model.modules()):
        return model

    # Copy the tree at most once (never per ancestor), then swap children in the copy.
    root = model if inplace else copy.deepcopy(model)

    # Collect first, then swap, so the module tree is not mutated while it is walked.
    to_replace = [
        (parent, name, child)
        for parent in root.modules()
        for name, child in parent.named_children()
        if isinstance(child, LlamaAttention)
    ]
    for parent, name, child in to_replace:
        parent._modules[name] = _make_replacement(child)
    return root
                    
        