# transformers version 4.38.2
# Sliding window attention is not supported. See our paper for the reasons why we don't use it.
import torch
import torch.nn as nn
import math
from typing import Optional, Tuple
from transformers.cache_utils import Cache
import numpy as np
from flash_attn import flash_attn_func, flash_attn_varlen_func

from .selfextend_flash_attn import self_extend_flash_forward



def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head axis.

    Gives the same result as torch.repeat_interleave(x, dim=1, repeats=n_rep): an input of shape
    (batch, num_key_value_heads, seqlen, head_dim) becomes (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    When `n_rep` is 1 the input tensor itself is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Add a singleton axis right after the head axis and broadcast it to n_rep ...
    expanded = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    # ... then fold that axis into the head axis.
    return expanded.reshape(bsz, kv_heads * n_rep, seq_len, dim)


def rotate_half(x):
    """Swap the two halves of the last dimension, negating the half that moves to the front."""
    split = x.shape[-1] // 2
    front = x[..., :split]
    back = x[..., split:]
    return torch.cat([back.neg(), front], dim=-1)

# bs,s_len
# bs,s_len,dim
# def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
#     # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
#     cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
#     sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
#     cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
#     sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
#     q_embed = (q * cos[:,:, -q.shape[2]:]) + (rotate_half(q) * sin[:,:, -q.shape[2]:]) if q is not None else None
#     k_embed = (k * cos) + (rotate_half(k) * sin) if k is not None else None
#     return q_embed, k_embed

def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate `q` and/or `k` with the given cos/sin tables.

    `cos` and `sin` get one extra axis inserted at `unsqueeze_dim` before being multiplied with the
    states. Either of `q` / `k` may be None, in which case None is returned in its slot.
    `position_ids` is accepted but not used.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        if states is None:
            return None
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return _rotate(q), _rotate(k)

def apply_grouped_rotary_pos_emb(q, k, cos, sin, position_ids, g_size_1=1, g_size_2=4096):
    """Rotate `q` and/or `k` using group-compressed position ids.

    Key positions are `position_ids // g_size_1`; query positions are the same value shifted by
    `g_size_2 - g_size_2 // g_size_1`. `cos` / `sin` are squeezed on their two leading axes and
    then indexed with those ids. Either of `q` / `k` may be None, in which case None is returned
    in its slot.
    """
    key_ids = position_ids // g_size_1
    query_ids = key_ids + g_size_2 - g_size_2 // g_size_1

    # Drop the leading singleton axes (if any) so the tables can be indexed by position.
    cos_table = cos.squeeze(1).squeeze(0)
    sin_table = sin.squeeze(1).squeeze(0)

    # Both gathers always run, even when only one of q / k is supplied;
    # unsqueeze(1) inserts the axis the states carry their heads on.
    cos_for_q = cos_table[query_ids].unsqueeze(1)
    sin_for_q = sin_table[query_ids].unsqueeze(1)
    cos_for_k = cos_table[key_ids].unsqueeze(1)
    sin_for_k = sin_table[key_ids].unsqueeze(1)

    q_embed = None
    if q is not None:
        q_embed = (q * cos_for_q) + (rotate_half(q) * sin_for_q)

    k_embed = None
    if k is not None:
        k_embed = (k * cos_for_k) + (rotate_half(k) * sin_for_k)

    return q_embed, k_embed

def apply_rotary_pos_emb_by_heads(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors without reshaping cos/sin.

    Unlike `apply_rotary_pos_emb`, `cos` and `sin` are multiplied with the states exactly as given,
    so they must already be broadcastable against `q` / `k`.

    Args:
        q (`torch.Tensor`, *optional*): The query tensor, or None to skip it.
        k (`torch.Tensor`, *optional*): The key tensor, or None to skip it.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Accepted but not used.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Accepted but not used by this function.
    Returns:
        `tuple` of the rotated query and key tensors; a slot is None when the matching input was None.
    """
    q_embed = None
    if q is not None:
        q_embed = (q * cos) + (rotate_half(q) * sin)

    k_embed = None
    if k is not None:
        k_embed = (k * cos) + (rotate_half(k) * sin)

    return q_embed, k_embed


def self_extend_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    padding_mask: Optional[torch.LongTensor] = None,
    group_size_1: Optional[float] = 8,
    group_size_2: Optional[float] = 2048,
    scale_base: Optional[float] = -1,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Eager attention forward that mixes "neighbor" and "grouped" rotary positions.

    Written to be used as the `forward` of an attention module: it reads `q_proj`, `k_proj`,
    `v_proj`, `o_proj`, `rotary_emb`, `num_heads`, `num_key_value_heads`, `num_key_value_groups`,
    `head_dim`, `hidden_size`, `attention_dropout`, `layer_idx` and `training` from `self`.

    Two score matrices are computed from the same projections:

    * neighbor scores, with queries/keys rotated at their plain positions;
    * group scores, with positions floor-divided by `group_size_1` (queries additionally shifted
      by `group_size_2 - group_size_2 // group_size_1` once any position reaches `group_size_2`).

    Neighbor scores are kept where the key is fewer than `group_size_2` positions behind the query;
    group scores are used everywhere else. Softmax, dropout and the output projection follow.

    Args:
        hidden_states: Input of shape (bsz, q_len, hidden).
        attention_mask: Optional additive mask of shape (bsz, 1, q_len, kv_seq_len).
        position_ids: Query positions; required (it is indexed and passed to the rotary helpers).
        past_key_value: Optional cache, updated in place through `update`.
        output_attentions: When falsy, None is returned instead of the attention weights.
        use_cache: Unused here.
        padding_mask: Deprecated; only triggers a warning.
        group_size_1: Divisor applied to positions for the grouped scores.
        group_size_2: Width of the neighbor window.
        scale_base: When > 0, queries are multiplied by log(position + 1) / log(scale_base),
            clipped to be at least 1.

    Returns:
        `(attn_output, attn_weights or None, past_key_value)`.

    Raises:
        ValueError: If a cache is given but `self.layer_idx` is None, if the score or mask shapes
            are not the expected ones, or if `q_len` is neither 1 nor `kv_seq_len`.
    """
    # `warnings` is not imported at module level, so bring it into scope here.
    import warnings

    # `padding_mask` is a named parameter and therefore never lands in **kwargs; check both so the
    # deprecation warning is actually reachable.
    if padding_mask is not None or "padding_mask" in kwargs:
        warnings.warn(
            "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
        )
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        if self.layer_idx is None:
            raise ValueError(
                f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                "with a layer index."
            )
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    if scale_base > 0:
        scaled_query = query_states * ((position_ids + 1)[:, None, :, None].log() / np.log(scale_base)).clip(1).to(query_states.dtype) # log scale
        #scaled_query = query_states * (((0.1*(((position_ids+1)[:, None, :, None]/scale_base).log())+1)**2).clip(1)).to(query_states.dtype) # Yarn scale
    else:
        scaled_query = query_states

    # Keys are cached un-rotated: rotation is applied below, after the cache update, to all keys.
    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    query_position = position_ids
    # only consider bsz=1 for now.
    key_position = torch.arange(kv_seq_len, dtype=position_ids.dtype).to(query_position.device).view(1, kv_seq_len)

    neighbor_query_states, _ = apply_rotary_pos_emb(scaled_query, None, cos, sin, query_position)
    _, neighbor_key_states = apply_rotary_pos_emb(None, key_states, cos, sin, key_position)
    _re_group_size_2 = 0 if position_ids.max() < group_size_2 else group_size_2 # in case that, the smallest q position, g2-g2//g1 exceed the max position
    group_query_states, _ = apply_grouped_rotary_pos_emb(scaled_query, None, cos, sin, query_position, g_size_1=group_size_1, g_size_2=_re_group_size_2)
    _, group_key_states = apply_grouped_rotary_pos_emb(None, key_states, cos, sin, key_position, g_size_1=group_size_1, g_size_2=_re_group_size_2)

    group_key_states = repeat_kv(group_key_states, self.num_key_value_groups)
    neighbor_key_states = repeat_kv(neighbor_key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    neighbor_attn_weights = torch.matmul(neighbor_query_states, neighbor_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
    group_attn_weights = torch.matmul(group_query_states, group_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

    if group_attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
        raise ValueError(
            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
            f" {group_attn_weights.size()}"
        )

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
            )
        group_attn_weights = group_attn_weights + attention_mask
        neighbor_attn_weights = neighbor_attn_weights + attention_mask

    # Build a (q_len, kv_seq_len) selector: 1 -> take the neighbor score, 0 -> take the group score.
    if q_len == 1:
        # Decoding one token: only the most recent `group_size_2` keys count as neighbors.
        neighbor_attention_mask = torch.zeros((q_len, kv_seq_len), device=neighbor_attn_weights.device)
        neighbor_attention_mask[:, -group_size_2:] = 1
    elif q_len == kv_seq_len:
        # Prefill: start from the causal lower triangle ...
        neighbor_attention_mask = torch.ones((q_len, kv_seq_len), device=neighbor_attn_weights.device)
        neighbor_attention_mask = torch.tril(neighbor_attention_mask)
        if q_len-group_size_2 > 0:
            # ... and clear the entries whose key lies `group_size_2` or more positions behind the query.
            group_attention_mask =  torch.tril(torch.ones((q_len-group_size_2, kv_seq_len-group_size_2), device=group_attn_weights.device))
            neighbor_attention_mask[group_size_2:, :-group_size_2] -= group_attention_mask

    else:
        raise ValueError("q_len should be 1 or seq_len.")

    neighbor_attention_mask = neighbor_attention_mask.bool()
    attn_weights = torch.where(neighbor_attention_mask, neighbor_attn_weights, group_attn_weights)
    # Softmax is computed in float32 and cast back to the query dtype.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
    attn_output = torch.matmul(attn_weights, value_states)

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    # (bsz, heads, q_len, head_dim) -> (bsz, q_len, hidden_size)
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value


def concat_tensors(tensor_list, dim_range, half_head_dim):
    """Stitch per-tensor slices of the third axis together, for both halves of the head dimension.

    For the i-th tensor the slice `dim_range[i]:dim_range[i+1]` of axis 2 is taken, first as is and
    then again shifted by `half_head_dim`. All first-half slices come before all second-half slices
    in the result, which is concatenated along the last axis.
    """
    bounds = [(dim_range[idx], dim_range[idx + 1]) for idx in range(len(tensor_list))]
    first_half = [t[:, :, lo:hi] for t, (lo, hi) in zip(tensor_list, bounds)]
    second_half = [
        t[:, :, half_head_dim + lo:half_head_dim + hi] for t, (lo, hi) in zip(tensor_list, bounds)
    ]
    return torch.cat(first_half + second_half, dim=-1)

def flash_self_extend_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    padding_mask: Optional[torch.LongTensor] = None,
    group_size_1: Optional[float] = 8,
    group_size_2: Optional[float] = 2048,
    dim_range: Optional[dict] = {"low_dim": 32, "high_dim": 56},
    selected_dim: Optional[torch.Tensor] = None,
    scale_base: Optional[float] = -1,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Flash-attention forward with per-dimension group sizes for the grouped rotary positions.

    Written to be used as the `forward` of an attention module: it reads `q_proj`, `k_proj`,
    `v_proj`, `o_proj`, `rotary_emb`, `num_heads`, `num_key_value_heads`, `num_key_value_groups`,
    `head_dim`, `hidden_size`, `config.attention_dropout`, `layer_idx` and `training` from `self`.

    For every entry of `group_size_1` a set of grouped cos/sin tables is computed. A boolean mask
    per entry, built from `selected_dim` and `dim_range`, decides on which rotary dimensions that
    entry's tables replace the baseline tables (the baseline uses a group size of 1).

    * `q_len == 1` (decoding): keys are rotated by their distance to the query (exact distance for
      the last `group_size_2` keys, grouped distance for older ones) with the sine negated, the
      query is passed unrotated, and `flash_attn_func` is called.
    * `q_len == kv_seq_len` (prefill): neighbor and grouped query/key states are built and handed
      to `self_extend_flash_forward`.

    Args:
        hidden_states: Input of shape (bsz, q_len, hidden).
        attention_mask: Forwarded to `self_extend_flash_forward` in the prefill branch only.
        position_ids: Query positions; required.
        past_key_value: Optional cache, updated in place through `update`.
        output_attentions: Ignored; attention weights are never available on this path.
        use_cache: Unused here.
        padding_mask: Deprecated; only triggers a warning.
        group_size_1: Must support `len()` and have exactly 8 entries (asserted below), so the
            integer default cannot be used as is — callers have to pass a sequence.
        group_size_2: Width of the neighbor window.
        dim_range: Indexed with the integers 0..len(group_size_1); entry i and i+1 bound the
            dimensions assigned to group i. NOTE(review): the default dict has no integer keys,
            so callers presumably always pass their own boundaries — confirm.
        selected_dim: Indexed by `self.layer_idx`; required despite the None default.
            Presumably per-head indices into the first half of the head dimension — TODO confirm.
        scale_base: When > 0, queries are multiplied by log(position + 1) / log(scale_base),
            clipped to be at least 1.

    Returns:
        `(attn_output, None, past_key_value)`.

    Raises:
        ValueError: If a cache is given but `self.layer_idx` is None, or if `q_len` is neither 1
            nor `kv_seq_len`.
        AssertionError: If `group_size_1` does not have 8 entries.
    """
    # `warnings` is not imported at module level, so bring it into scope here.
    import warnings

    # `padding_mask` is a named parameter and therefore never lands in **kwargs; check both so the
    # deprecation warning is actually reachable.
    if padding_mask is not None or "padding_mask" in kwargs:
        warnings.warn(
            "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
        )
    bsz, q_len, _ = hidden_states.size()
    assert len(group_size_1) == 8
    # Group size used on every dimension that no mask selects (1 means no grouping).
    group_size_1_all = 1
    selected_dim = selected_dim[self.layer_idx]

    half_head_dim = self.head_dim // 2

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    query_position = position_ids
    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        if self.layer_idx is None:
            raise ValueError(
                f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                "with a layer index."
            )
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
    # These tables are only forwarded to the cache update below.
    cos, sin = self.rotary_emb(value_states, query_position)

    if scale_base > 0:
        scaled_query = query_states * ((position_ids + 1)[:, None, :, None].log() / np.log(scale_base)).clip(1).to(query_states.dtype) # log scale
        #scaled_query = query_states * (((0.1*(((position_ids+1)[:, None, :, None]/scale_base).log())+1)**2).clip(1)).to(query_states.dtype) # Yarn scale
    else:
        scaled_query = query_states

    # Keys are cached un-rotated: rotation is applied below, after the cache update, to all keys.
    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    key_position = torch.arange(kv_seq_len, dtype=position_ids.dtype).to(query_position.device).view(1, kv_seq_len) # only support batch=1 for now.

    attn_dropout = self.config.attention_dropout if self.training else 0.0
    if q_len == 1:
        # Distance from the current query to every key.
        neighbor_key_position = position_ids[:, -1] - key_position
        _re_group_size_2 = 0 if position_ids.max() < group_size_2 else group_size_2
        decode_k_cos_list, decode_k_sin_list = [], []
        for i in group_size_1:
            group_key_position_i = position_ids[:, -1]//i - key_position//i + (_re_group_size_2 - _re_group_size_2//i)
            # Grouped distance for old keys, exact distance for the last `group_size_2` keys.
            decode_key_position_i = torch.cat([group_key_position_i[:, :-group_size_2], neighbor_key_position[:,-group_size_2:]], dim=1)
            decode_k_cos_i, decode_k_sin_i = self.rotary_emb(value_states, decode_key_position_i)
            decode_k_cos_list.append(decode_k_cos_i)
            decode_k_sin_list.append(decode_k_sin_i)
        group_key_position_all = position_ids[:, -1]//group_size_1_all - key_position//group_size_1_all + (_re_group_size_2 - _re_group_size_2//group_size_1_all)
        decode_key_position_all = torch.cat([group_key_position_all[:, :-group_size_2], neighbor_key_position[:,-group_size_2:]], dim=1)
        decode_k_cos_all, decode_k_sin_all = self.rotary_emb(value_states, decode_key_position_all)

        # One boolean mask per group over the first half of the head dimension: True on the
        # selected dimensions that fall inside [dim_range[i], dim_range[i+1]).
        mask_list = torch.zeros((len(group_size_1), 1, self.num_heads, half_head_dim), dtype=torch.bool, device=query_states.device)
        for i in range(len(group_size_1)):
            mask_list[i].scatter_(-1, selected_dim.unsqueeze(0),
                                ((selected_dim >= dim_range[i]) & (selected_dim < dim_range[i+1])).unsqueeze(0))

        # Mirror the mask onto the second half of the head dimension.
        mask_list = torch.cat([mask_list, mask_list], dim=-1)

        mask_list = mask_list.unsqueeze(3).expand(-1, -1, -1, value_states.size(2), -1)  # [len(group_size_1), 1, num_heads, seq_len, head_dim]

        decode_k_cos = decode_k_cos_all.clone()
        decode_k_sin = decode_k_sin_all.clone()

        # Applied from the last group to the first, so where masks overlap the lowest-index group wins.
        for i in range(len(group_size_1)-1, -1, -1):
            decode_k_cos = torch.where(mask_list[i], decode_k_cos_list[i], decode_k_cos)
            decode_k_sin = torch.where(mask_list[i], decode_k_sin_list[i], decode_k_sin)

        decode_query_states = scaled_query.transpose(1,2).contiguous() # position 0: cos 0 = 1, sin 0 = 0
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        decode_value_states = repeat_kv(value_states, self.num_key_value_groups).transpose(1, 2).contiguous()
        # The sine is negated when rotating the keys by their distance to the query.
        _, decode_key_states = apply_rotary_pos_emb_by_heads(None, key_states, decode_k_cos, -decode_k_sin)

        decode_key_states = decode_key_states.transpose(1, 2).contiguous()
        attn_output = flash_attn_func(decode_query_states,
                                      decode_key_states,
                                      decode_value_states,
                                      attn_dropout,
                                      softmax_scale=None,
                                      causal=True)

    elif q_len == kv_seq_len:
        # set correct position_ids & apply RoPE.
        neighbor_q_cos, neighbor_q_sin = self.rotary_emb(value_states, query_position)
        neighbor_k_cos, neighbor_k_sin = self.rotary_emb(value_states, key_position)

        _re_group_size_2 = 0 if query_position.max() < group_size_2 else group_size_2 # in case that, the smallest q position, g2-g2//g1 exceed the max position

        group_query_position, group_key_position = [], []

        for i in group_size_1:
            # NOTE(review): true division here, unlike the floor division in the decode branch.
            group_query_position.append(query_position // i + _re_group_size_2 - _re_group_size_2 / i)
            group_key_position.append(key_position // i)

        group_q_cos_list, group_q_sin_list = [], []
        group_k_cos_list, group_k_sin_list = [], []
        for i in range(len(group_size_1)):
            group_q_cos_i, group_q_sin_i = self.rotary_emb(value_states, group_query_position[i])
            group_q_cos_list.append(group_q_cos_i)
            group_q_sin_list.append(group_q_sin_i)
            group_k_cos_i, group_k_sin_i = self.rotary_emb(value_states, group_key_position[i])
            group_k_cos_list.append(group_k_cos_i)
            group_k_sin_list.append(group_k_sin_i)

        group_query_position_all = query_position // group_size_1_all + _re_group_size_2 - _re_group_size_2 / group_size_1_all
        group_key_position_all = key_position // group_size_1_all
        group_q_cos_all, group_q_sin_all = self.rotary_emb(value_states, group_query_position_all)
        group_k_cos_all, group_k_sin_all = self.rotary_emb(value_states, group_key_position_all)

        # One boolean mask per group over the first half of the head dimension: True on the
        # selected dimensions that fall inside [dim_range[i], dim_range[i+1]).
        mask_list = torch.zeros((len(group_size_1), 1, self.num_heads, half_head_dim), dtype=torch.bool, device=query_states.device)
        for i in range(len(group_size_1)):
            mask_list[i].scatter_(-1, selected_dim.unsqueeze(0),
                                ((selected_dim >= dim_range[i]) & (selected_dim < dim_range[i+1])).unsqueeze(0))

        # Mirror the mask onto the second half of the head dimension.
        mask_list = torch.cat([mask_list, mask_list], dim=-1)

        mask_list = mask_list.unsqueeze(3).expand(-1, -1, -1, value_states.size(2), -1)  # [len(group_size_1), 1, num_heads, seq_len, head_dim]

        group_q_cos = group_q_cos_all.clone()
        group_q_sin = group_q_sin_all.clone()
        group_k_cos = group_k_cos_all.clone()
        group_k_sin = group_k_sin_all.clone()

        # Applied from the last group to the first, so where masks overlap the lowest-index group wins.
        for i in range(len(group_size_1)-1, -1, -1):
            group_q_cos = torch.where(mask_list[i], group_q_cos_list[i], group_q_cos)
            group_q_sin = torch.where(mask_list[i], group_q_sin_list[i], group_q_sin)
            group_k_cos = torch.where(mask_list[i], group_k_cos_list[i], group_k_cos)
            group_k_sin = torch.where(mask_list[i], group_k_sin_list[i], group_k_sin)

        neighbor_query_states, _ = apply_rotary_pos_emb(scaled_query, None, neighbor_q_cos, neighbor_q_sin, None)
        # Neighbor keys are rotated before the heads are repeated; they are repeated further down.
        _, neighbor_key_states = apply_rotary_pos_emb(None, key_states, neighbor_k_cos, neighbor_k_sin, None)

        # Grouped tables are per attention head, so the key heads are repeated before rotating.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        group_query_states, _ = apply_rotary_pos_emb_by_heads(scaled_query, None, group_q_cos, group_q_sin, None)
        _, group_key_states = apply_rotary_pos_emb_by_heads(None, key_states, group_k_cos, group_k_sin, None)

        # Move the head axis behind the sequence axis before calling the flash helper.
        neighbor_query_states = neighbor_query_states.transpose(1, 2).contiguous()
        neighbor_key_states = repeat_kv(neighbor_key_states, self.num_key_value_groups).transpose(1, 2).contiguous()
        group_query_states = group_query_states.transpose(1, 2).contiguous()
        group_key_states = group_key_states.transpose(1, 2).contiguous()
        value_states = value_states.transpose(1, 2).contiguous()

        attn_output = self_extend_flash_forward(self,
                                                query_position,
                                                group_size_2,
                                                neighbor_query_states,
                                                neighbor_key_states,
                                                group_query_states,
                                                group_key_states,
                                                value_states,
                                                attention_mask,
                                                bsz,
                                                q_len,
                                                kv_seq_len,
                                                attn_dropout,
                                            )
    else:
        raise ValueError("q_len should be 1 or seq_len.")

    attn_output = attn_output.contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()

    attn_output = self.o_proj(attn_output)

    # Neither attention call above returns attention weights, so none can be reported. Assigning
    # unconditionally also avoids the unbound local the original hit when `output_attentions` was True.
    attn_weights = None

    return attn_output, attn_weights, past_key_value

