import math
import warnings
from typing import Optional, Tuple
import torch
from torch import nn

from quant.new_pack import triton_quantize_and_pack_along_last_dim
from quant.matmul import cuda_bmm_fA_qB_outer


def metric(k, v):
    """Score token positions by the mean absolute key value of the first batch element.

    ``k[0]`` is averaged over its dimensions 0 and 2; for a
    ``(batch, heads, seq, head_dim)`` key tensor that yields one score per sequence
    position. Only batch element 0 contributes. ``v`` is not used.
    """
    first_sample = k[0]
    magnitudes = torch.abs(first_sample)
    return magnitudes.mean(dim=(0, 2))


def rotate_half(x):
    """Swap the two halves of the last dimension, negating the (new) first half.

    With ``d = x.shape[-1]`` and ``h = d // 2``, ``x[..., :h]`` is the front half and
    ``x[..., h:]`` the back half; the result is ``cat(-back, front)`` along the last dim.
    """
    width = x.shape[-1]
    split_at = width // 2
    front, back = x.split([split_at, width - split_at], dim=-1)
    return torch.cat((-back, front), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Apply rotary position embeddings to the query and key tensors.

    Args:
        q, k: query / key tensors to rotate.
        cos, sin: per-position tables, indexed by ``position_ids``.
        position_ids: positions whose cos/sin rows are gathered.
        unsqueeze_dim: axis inserted into the gathered tables so they broadcast
            against ``q``/``k`` (e.g. 1 when ``position_ids`` is ``(batch, seq)`` and
            ``q``/``k`` are ``(batch, heads, seq, head_dim)``).

    Returns:
        ``(q_embed, k_embed)`` with ``x_embed = x * cos + rotate_half(x) * sin``.
    """
    angle_cos, angle_sin = (
        table[position_ids].unsqueeze(unsqueeze_dim) for table in (cos, sin)
    )

    def _embed(t):
        return (t * angle_cos) + (rotate_half(t) * angle_sin)

    return _embed(q), _embed(k)


class RotaryEmbedding(nn.Module):
    """Cached cos/sin tables for rotary position embeddings (RoPE).

    Args:
        dim: rotary dimension (the attention head size in this file).
        max_position_embeddings: number of positions precomputed at construction.
        base: frequency base; ``inv_freq[i] = 1 / base ** (2 * i / dim)``.
        device: device for ``inv_freq`` and the initial tables.

    ``cos_cached`` / ``sin_cached`` hold one row per position and are rebuilt with
    more rows whenever ``forward`` is asked for more positions than are cached.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin tables for positions ``0 .. seq_len - 1`` on ``device``/``dtype``."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        # Angle of every (position, frequency) pair. inv_freq is moved to `device` because a
        # rebuild triggered from forward() uses the input's device, which may differ.
        freqs = torch.outer(t, self.inv_freq.to(device))
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)`` for the first ``seq_len`` positions, cast to ``x.dtype``.

        Args:
            x: reference tensor, expected ``[bs, num_attention_heads, seq_len, head_size]``;
                its dtype is used for the output (and its device if the cache must grow).
            seq_len: number of positions to return. Defaults to ``x.shape[-2]`` when None
                (comparing None against the cached length used to raise TypeError).
        """
        if seq_len is None:
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
        

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand grouped key/value heads so each one serves ``n_rep`` attention heads.

    Same result as ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``:
    ``(batch, num_key_value_heads, seqlen, head_dim)`` ->
    ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``.
    The input must be 4-D; when ``n_rep == 1`` it is returned as-is (same object).
    """
    bsz, kv_heads, seq_len, width = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Add a repeat axis right after the head axis (expand is a view), then fold it into dim 1.
    tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, width)
    return tiled.reshape(bsz, kv_heads * n_rep, seq_len, width)


def repeat_kv_quant(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each KV head of a 4-D cache tensor ``n_rep`` times along dim 1.

    Used for the quantized-cache tensors (packed data, scales, minimums) and the sink
    keys. Presumably those are packed along the last dimension by the quantizer, so the
    last dim need not be ``head_dim``. The operation only touches dim 1, so this is exactly
    :func:`repeat_kv`. It delegates instead of keeping a second copy of the same code, and
    the name is kept for call-site readability.

    Returns the input unchanged (same object) when ``n_rep == 1``.
    """
    return repeat_kv(hidden_states, n_rep)


def quant_step(k_bits,v_bits,group_size,residual_length,sink_num,sink_max_size,past_key_value,kv_seq_len):
    """Move the oldest full-precision tokens of one layer's KV cache into the quantized cache.

    The first ``quant_len`` tokens of the full-precision key/value buffers are quantized
    with ``triton_quantize_and_pack_along_last_dim`` and appended to the quantized tensors.
    The remaining tokens stay in full precision. ``quant_len`` is the largest multiple of
    ``group_size`` that is <= ``len(full) - residual_length``, so at least
    ``residual_length`` tokens remain unquantized.

    When ``sink_num > 0``, "sink" tokens are chosen among the tokens being quantized:
    those with the lowest ``metric`` score. For each sink the function keeps:

    - its un-quantized key/value, in ``sink_k`` / ``sink_v``;
    - its index in the quantized cache, in ``sink_id``;
    - its score, in ``sink_scores``.

    All four are re-sorted by ascending score whenever sinks are added. Before
    quantization, the entries of the newly selected sink tokens are overwritten with
    their mean.

    Args:
        k_bits: bit width passed to the quantizer for keys.
        v_bits: bit width passed to the quantizer for values.
        group_size: quantization group size; ``quant_len`` is a multiple of it.
        residual_length: minimum number of most recent tokens left in full precision.
        sink_num: number of sinks selected when none are stored yet, and the maximum
            number added per later call; 0 disables sink handling.
        sink_max_size: later calls add sinks only while fewer than this many are stored.
            The check happens before adding, so the total can exceed it by up to
            ``sink_num - 1``.
        past_key_value: cache tuple ``(key_quant_trans, key_full, key_scale_trans,
            key_mn_trans, value_quant, value_full, value_scale, value_mn,
            [sink_id, sink_k, sink_v, sink_scores,] kv_seq_len)``. The sink entries are
            present iff ``sink_num > 0``.
        kv_seq_len: stored unchanged as the last element of the returned tuple.

    Returns:
        A new cache tuple with the same layout.

    NOTE(review): sink selection and replacement only look at batch index 0 (``metric``
    uses ``k[0]`` and the writes below index ``[0, ...]``); presumably batch size 1 is
    assumed.
    NOTE(review): the to-quantize tensors are views of ``past_key_value[1]`` / ``[5]``,
    so the sink-mean writes modify the caller's tensors in place.
    NOTE(review): assumes ``key_states_full`` holds at least ``residual_length`` tokens,
    otherwise ``quant_len`` is negative. ``ModelAttention_SinkQ.forward`` only calls this
    once it holds ``residual_length + group_size``.
    """
    # Unpack the cache tuple (same layout as the tuples rebuilt at the end of this function).
    key_states_quant_trans = past_key_value[0]
    key_states_full = past_key_value[1]
    key_scale_trans = past_key_value[2]
    key_mn_trans = past_key_value[3]
    value_states_quant = past_key_value[4]
    value_states_full = past_key_value[5]
    value_scale = past_key_value[6]
    value_mn = past_key_value[7]
    if sink_num>0:
        sink_id = past_key_value[8]
        sink_k = past_key_value[9]
        sink_v  = past_key_value[10]
        sink_scores = past_key_value[11]
    # Largest multiple of group_size that still leaves >= residual_length full-precision tokens.
    quant_len=((key_states_full.shape[-2]-residual_length)//group_size)*group_size
    # Split along dim 2 into the prefix to quantize and the tail kept in full precision.
    # These are views of the incoming tensors (see the in-place writes below).
    key_states_to_quant=key_states_full[:,:,:quant_len,:]
    value_states_to_quant=value_states_full[:,:,:quant_len,:]
    key_states_full=key_states_full[:,:,quant_len:,:]
    value_states_full=value_states_full[:,:,quant_len:,:]

    if sink_num>0:
        # One score per token being quantized (metric reads batch 0 only; `v` is unused there).
        scores = metric(key_states_to_quant,value_states_to_quant)
        if sink_id is not None:
            # Indices (within this chunk) of tokens scoring strictly below the
            # sink_num-th smallest stored sink score; sink_scores is kept sorted ascending.
            mask=(scores<sink_scores[sink_num-1]).nonzero().squeeze(1)
            if mask.shape[0]>0 and sink_id.shape[0]<sink_max_size:
                values=scores[mask]
                # Up to sink_num lowest-scoring candidates.
                sink_values,sink_indices=torch.topk(values,min(sink_num,mask.shape[0]),largest=False)
                # Map positions within `values` back to positions within this chunk.
                sink_indices=mask[sink_indices]
            else:
                sink_values,sink_indices = None, None
        else:
            # No sinks stored yet: take the sink_num lowest-scoring tokens of this chunk.
            # NOTE(review): torch.topk raises if quant_len < sink_num.
            sink_values,sink_indices=torch.topk(scores,sink_num,largest=False)
        if sink_indices is not None:
            if sink_id is None:
                # No offset is added here. That is correct only if nothing was quantized before
                # (presumably true whenever sink_id is None; confirm).
                sink_id=sink_indices
                # Advanced indexing copies batch 0's rows: 3-D (heads, n_sink, last_dim).
                sink_k=key_states_to_quant[0,:,sink_indices,:]
                sink_v=value_states_to_quant[0,:,sink_indices,:]
                sink_scores=sink_values
            else:
                # Offset chunk-relative indices by the number of tokens already quantized
                # (dim 2 of value_states_quant is the token axis; see the dim=2 cat below).
                sink_id=torch.cat([sink_id,sink_indices+value_states_quant.shape[2]])
                sink_k=torch.cat([sink_k,key_states_to_quant[0,:,sink_indices,:]],dim=1)
                sink_v=torch.cat([sink_v,value_states_to_quant[0,:,sink_indices,:]],dim=1)
                sink_scores=torch.cat([sink_scores,sink_values])

            # Keep every sink tensor ordered by ascending score, so that
            # sink_scores[sink_num-1] above is the sink_num-th smallest score.
            sink_scores,reverse_indices=sink_scores.sort()
            sink_id=sink_id[reverse_indices]
            sink_k=sink_k[:,reverse_indices,:]
            sink_v=sink_v[:,reverse_indices,:]

            # Overwrite the new sink tokens (batch 0) with their mean over the selected tokens,
            # per head and per last-dim channel. Their original values survive in sink_k/sink_v,
            # which forward() uses in place of the quantized entries at sink_id.
            key_states_to_quant[0,:,sink_indices,:]=key_states_to_quant[0,:,sink_indices,:].mean(dim=1).unsqueeze(1)
            value_states_to_quant[0,:,sink_indices,:]=value_states_to_quant[0,:,sink_indices,:].mean(dim=1).unsqueeze(1)


    # Keys are transposed so the token axis is last before packing, so the quantized key
    # cache grows along dim 3.
    key_states_quant_trans_new, key_scale_trans_new, key_mn_trans_new = triton_quantize_and_pack_along_last_dim(key_states_to_quant.transpose(2, 3).contiguous(), group_size, k_bits)
    if key_states_quant_trans is not None:
        key_states_quant_trans = torch.cat([key_states_quant_trans, key_states_quant_trans_new], dim=3)
        key_scale_trans = torch.cat([key_scale_trans, key_scale_trans_new], dim=3)
        key_mn_trans = torch.cat([key_mn_trans, key_mn_trans_new], dim=3)
    else:
        key_states_quant_trans = key_states_quant_trans_new
        key_scale_trans = key_scale_trans_new
        key_mn_trans = key_mn_trans_new


    # Values are packed along their last dim; the quantized value cache grows along dim 2 (tokens).
    value_states_quant_new, scale, mn = triton_quantize_and_pack_along_last_dim(value_states_to_quant.contiguous(), group_size, v_bits)
    if value_states_quant is not None:
        value_states_quant = torch.cat([value_states_quant, value_states_quant_new], dim=2)
        value_scale = torch.cat([value_scale, scale], dim=2)
        value_mn = torch.cat([value_mn, mn], dim=2)
    else:
        value_states_quant = value_states_quant_new
        value_scale = scale
        value_mn = mn
    # Rebuild the cache tuple in the same layout; kv_seq_len is passed through unchanged.
    if sink_num>0:
        past_key_value = (key_states_quant_trans, key_states_full, key_scale_trans, key_mn_trans, value_states_quant, value_states_full, value_scale, value_mn,sink_id, sink_k,sink_v,sink_scores, kv_seq_len)
    else:
        past_key_value = (key_states_quant_trans, key_states_full, key_scale_trans, key_mn_trans, value_states_quant, value_states_full, value_scale, value_mn, kv_seq_len)
    return past_key_value
        

class ModelAttention_SinkQ(nn.Module):
    """
    Multi-headed attention ('Attention Is All You Need') with a mixed-precision KV cache.

    Older tokens are stored quantized (``config.k_bits`` / ``config.v_bits``, groups of
    ``config.group_size``) by ``quant_step``. The most recent tokens stay in full
    precision until the full-precision buffer reaches ``residual_length + group_size``
    tokens.

    When ``sink_num > 0`` (forced to 0 for layers 0 and 1), ``quant_step`` also keeps a
    few "sink" tokens in full precision. Their attention scores and value contributions
    are then computed from those full-precision copies instead of the quantized cache.

    Cache tuple layout (``past_key_value`` in and out)::

        (key_states_quant_trans, key_states_full, key_scale_trans, key_mn_trans,
         value_states_quant, value_states_full, value_scale, value_mn,
         [sink_id, sink_k, sink_v, sink_scores,]   # only when sink_num > 0
         kv_seq_len)
    """

    def __init__(self, config,layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        # Number of query heads sharing each key/value head (grouped-query attention).
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.k_bits = config.k_bits
        self.v_bits = config.v_bits
        self.group_size = config.group_size
        self.residual_length = config.residual_length
        self.sink_num = config.sink_num
        self.sink_max_size = config.sink_max_size

        # no sink in shallow layers
        if layer_idx in [0, 1]:
            self.sink_num = 0

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.rotary_emb = RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def _attention_probs(self, attn_weights, attention_mask, bsz, q_len, kv_seq_len, dtype):
        """Validate shapes, add the attention mask, and softmax (in fp32) to ``dtype``.

        Raises:
            ValueError: if ``attn_weights`` is not ``(bsz, num_heads, q_len, kv_seq_len)``
                or ``attention_mask`` is not ``(bsz, 1, q_len, kv_seq_len)``.
        """
        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask
            # Clamp at the dtype's most negative finite value, avoiding a host-device mix.
            attn_weights = torch.max(
                attn_weights,
                torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device),
            )

        # upcast attention to fp32
        return nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(dtype)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Attend ``hidden_states`` over the cached plus new keys/values.

        Args:
            hidden_states: ``(bsz, q_len, hidden_size)`` input.
            attention_mask: optional additive mask of shape ``(bsz, 1, q_len, kv_seq_len)``.
            position_ids: positions used to gather the rotary cos/sin rows.
            past_key_value: cache tuple from a previous call (layout in the class
                docstring), or None on the first call.
            output_attentions: when True, also return the attention probabilities.
            use_cache: when True, return the updated cache. Its oldest full-precision
                tokens are quantized once enough have accumulated.

        Returns:
            ``(attn_output, attn_weights or None, past_key_value)``. When ``use_cache`` is
            False the incoming ``past_key_value`` is returned unchanged.

        Raises:
            ValueError: on unexpected attention-weight, mask, or output shapes.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            # The last cache entry counts the tokens processed before this call.
            kv_seq_len += past_key_value[-1]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        # Sink entries stay None unless this layer tracks sinks and some were selected.
        sink_id = sink_k = sink_v = sink_scores = None

        if past_key_value is not None:
            key_states_quant_trans = past_key_value[0]
            key_states_full = past_key_value[1]
            key_scale_trans = past_key_value[2]
            key_mn_trans = past_key_value[3]
            value_states_quant = past_key_value[4]
            value_states_full = past_key_value[5]
            value_scale = past_key_value[6]
            value_mn = past_key_value[7]
            if self.sink_num > 0:
                sink_id, sink_k, sink_v, sink_scores = past_key_value[8:12]
            has_sinks = sink_id is not None

            # Scores against the quantized keys (heads repeated for grouped-query attention).
            if key_states_quant_trans is not None:
                key_states_quant_trans_repeat = repeat_kv_quant(key_states_quant_trans, self.num_key_value_groups)
                key_scale_trans_repeat = repeat_kv_quant(key_scale_trans, self.num_key_value_groups)
                key_mn_trans_repeat = repeat_kv_quant(key_mn_trans, self.num_key_value_groups)
                att_qkquant = cuda_bmm_fA_qB_outer(self.group_size, query_states, key_states_quant_trans_repeat,
                                                   key_scale_trans_repeat, key_mn_trans_repeat, self.k_bits)
                if has_sinks:
                    # Sink positions use their full-precision keys instead of the quantized ones.
                    att_sink = torch.matmul(
                        query_states,
                        repeat_kv_quant(sink_k.unsqueeze(0), self.num_key_value_groups).transpose(2, 3),
                    )
                    att_qkquant[:, :, :, sink_id] = att_sink
            else:
                att_qkquant = None

            # Scores against the full-precision tail extended with this call's keys.
            if key_states_full is not None:
                key_states_full = torch.cat([key_states_full, key_states], dim=2)
            else:
                key_states_full = key_states
            key_states_full_repeat = repeat_kv(key_states_full, self.num_key_value_groups)
            att_qkfull = torch.matmul(query_states, key_states_full_repeat.transpose(2, 3))
            if att_qkquant is not None:
                attn_weights = torch.cat([att_qkquant, att_qkfull], dim=-1) / math.sqrt(self.head_dim)
            else:
                attn_weights = att_qkfull / math.sqrt(self.head_dim)

            attn_weights = self._attention_probs(
                attn_weights, attention_mask, bsz, q_len, kv_seq_len, query_states.dtype
            )

            if value_states_full is not None:
                value_states_full = torch.cat([value_states_full, value_states], dim=2)
            else:
                value_states_full = value_states
            value_full_length = value_states_full.shape[-2]
            value_states_full_repeat = repeat_kv(value_states_full, self.num_key_value_groups)
            if value_states_quant is None:
                attn_output = torch.matmul(attn_weights, value_states_full_repeat)
            else:
                # The last value_full_length columns belong to the full-precision tail;
                # everything before them addresses the quantized values.
                attn_weights_quant = attn_weights[:, :, :, :-value_full_length]
                if has_sinks:
                    attn_weights_sink = attn_weights[:, :, :, sink_id]
                    # Sink values come from sink_v below, so drop them from the quantized product.
                    # Zero them on a copy so the probabilities returned with output_attentions
                    # stay intact.
                    attn_weights_quant = attn_weights_quant.clone()
                    attn_weights_quant[:, :, :, sink_id] = 0
                value_states_quant_repeat = repeat_kv_quant(value_states_quant, self.num_key_value_groups)
                value_scale_repeat = repeat_kv_quant(value_scale, self.num_key_value_groups)
                value_mn_repeat = repeat_kv_quant(value_mn, self.num_key_value_groups)
                attn_output = cuda_bmm_fA_qB_outer(self.group_size, attn_weights_quant, value_states_quant_repeat,
                                                   value_scale_repeat, value_mn_repeat, self.v_bits)
                attn_output += torch.matmul(attn_weights[:, :, :, -value_full_length:], value_states_full_repeat)
                if has_sinks:
                    attn_output += torch.matmul(
                        attn_weights_sink, repeat_kv(sink_v.unsqueeze(0), self.num_key_value_groups)
                    )
        else:
            # First call: plain attention over this call's keys/values only.
            key_states_repeat = repeat_kv(key_states, self.num_key_value_groups)
            value_states_repeat = repeat_kv(value_states, self.num_key_value_groups)
            attn_weights = torch.matmul(query_states, key_states_repeat.transpose(2, 3)) / math.sqrt(self.head_dim)
            attn_weights = self._attention_probs(
                attn_weights, attention_mask, bsz, q_len, kv_seq_len, query_states.dtype
            )
            attn_output = torch.matmul(attn_weights, value_states_repeat)

            # The new cache starts with every token in full precision.
            key_states_quant_trans = key_scale_trans = key_mn_trans = None
            value_states_quant = value_scale = value_mn = None
            key_states_full = key_states
            value_states_full = value_states

        if use_cache:
            cache = (key_states_quant_trans, key_states_full, key_scale_trans, key_mn_trans,
                     value_states_quant, value_states_full, value_scale, value_mn)
            if self.sink_num > 0:
                cache += (sink_id, sink_k, sink_v, sink_scores)
            past_key_value = cache + (kv_seq_len,)
            # Only the cache being returned is quantized. Without use_cache there is no new
            # cache, and the incoming past_key_value may be None or lack this call's tokens.
            if key_states_full.shape[-2] >= self.residual_length + self.group_size:
                past_key_value = quant_step(self.k_bits, self.v_bits, self.group_size, self.residual_length,
                                            self.sink_num, self.sink_max_size, past_key_value, kv_seq_len)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights, past_key_value
    