# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Optional, Tuple

import torch
import torch.nn as nn
import torch.utils.checkpoint
from einops import rearrange
from transformers.utils import logging

from fla.layers.utils import pad_input, unpad_input
from fla.modules import  RotaryEmbedding
from apex.normalization import FusedRMSNorm as RMSNorm
if TYPE_CHECKING:
    from fla.models.utils import Cache

try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
except ImportError:
    warnings.warn(
        "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
        category=ImportWarning
    )
    flash_attn_func = None

logger = logging.get_logger(__name__)


import torch.nn.functional as F
import math

class RoundSTE(torch.autograd.Function):
    """Round to two decimal places with a straight-through gradient.

    The forward pass rounds ``x`` to 2 decimal places. The backward pass
    returns the incoming gradient unchanged, i.e. rounding is treated as
    the identity for differentiation.
    """

    @staticmethod
    def forward(ctx, x):
        # Rounding function: keep two decimal places.
        rounded = torch.round(x, decimals=2)
        return rounded

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: pass the gradient through untouched.
        passthrough = grad_output
        return passthrough


class Attention(nn.Module):
    """Attention layer that applies a closed-form gated delta rule to the values.

    Instead of attending directly over ``V``, the layer first solves a unit
    lower-triangular system built from key-key scores (see
    ``Infinite_Deltanet``) to obtain transformed values ``U``. It then runs
    causal flash attention over ``(Q, K, U)``.
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        num_heads: int = 32,
        num_kv_heads: Optional[int] = None,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        window_size: Optional[int] = None,
        rope_theta: Optional[float] = 10000.,
        max_position_embeddings: Optional[int] = None,
        layer_idx: Optional[int] = None,
        kl2: bool = False,
        kernel: str = 'softmax_v1',
        usequ: bool = False,
        returnu: bool = False,
    ):
        """Build projections, optional norms and rotary embedding.

        Args:
            hidden_size: model width; split evenly across ``num_heads``.
            num_heads: number of query heads.
            num_kv_heads: number of key/value heads (defaults to ``num_heads``).
            qkv_bias: whether the q/k/v (and qu) projections have a bias.
            qk_norm: apply RMSNorm to per-head queries and keys.
            window_size: forwarded to the KV cache as ``window_size``.
            rope_theta: rotary base. ``None`` or any non-positive value
                (e.g. the conventional ``-1``) disables rotary embeddings.
            max_position_embeddings: lower bound on ``max_seqlen`` passed to rotary.
            layer_idx: index used to address this layer in the KV cache.
            kl2: L2-normalise keys and use an unscaled key-key score.
            kernel: score transform, one of ``'softmax_v1'``, ``'softmax_v2'``,
                ``'round'``, ``'linear'``, ``'relu'``, ``'exp'``.
            usequ: use a separate projection ``qu`` on the left side of the
                key-key score instead of the keys themselves.
            returnu: make ``Infinite_Deltanet`` return ``U`` instead of the
                attention output.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_kv_heads = self.num_heads if num_kv_heads is None else num_kv_heads
        self.num_kv_groups = num_heads // self.num_kv_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.kv_dim = self.num_kv_heads * self.head_dim
        self.qkv_bias = qkv_bias
        self.qk_norm = qk_norm

        self.window_size = window_size
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.layer_idx = layer_idx

        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.qkv_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        # One gate value per KV head per token.
        self.alpha_proj = nn.Linear(self.hidden_size, self.num_kv_heads, bias=True)
        # NOTE: not used by forward (beta is tied to alpha); presumably kept so
        # existing checkpoints still load.
        self.beta_proj = nn.Linear(self.hidden_size, self.num_kv_heads, bias=False)

        self.kl2 = kl2
        self.usequ = usequ
        if usequ:
            self.qu_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
        if qk_norm:
            self.q_norm = RMSNorm(self.head_dim)
            self.k_norm = RMSNorm(self.head_dim)
        # Rotary is disabled for None / non-positive theta. forward only tests
        # `self.rotary is not None`, so the two checks cannot drift apart.
        if self.rope_theta is None or self.rope_theta <= 0:
            self.rotary = None
        else:
            self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)
        # NOTE: not applied in forward; presumably kept for checkpoint compatibility.
        self.outnorm = RMSNorm(self.head_dim)
        self.kernel = kernel
        self.returnu = returnu

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Apply the layer to ``hidden_states``.

        Args:
            hidden_states: ``[batch_size, seq_len, hidden_size]``.
            attention_mask: must be ``None``. Padding is not supported by
                ``Infinite_Deltanet``. A non-2-D mask is rejected with a shape
                message first.
            past_key_values: optional cache. This layer's keys/values are
                appended under ``attn_state``.
            output_attentions: accepted for API compatibility. Attention
                weights are never materialised, so ``None`` is returned for them.
            use_cache: not read here. Caching is driven by ``past_key_values``.
            **kwargs: ``cu_seqlens`` is read and must be ``None``.

        Returns:
            ``(output, attentions, past_key_values)``. ``output`` has shape
            ``[batch_size, seq_len, hidden_size]`` and ``attentions`` is ``None``.

        Raises:
            AssertionError: if an attention mask or ``cu_seqlens`` is given.
            ImportError: if flash-attn is missing and ``returnu`` is False.
        """
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        # equivalent to cu_seqlens in `flash_attn`
        cu_seqlens = kwargs.get('cu_seqlens', None)

        # Infinite_Deltanet has no padding / varlen path. Validate everything
        # before the cache is updated so a rejected call leaves it untouched.
        assert attention_mask is None
        assert cu_seqlens is None
        # flash-attn is only needed for the final attention over U.
        if flash_attn_func is None and not self.returnu:
            raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")

        batch_size, q_len, _ = hidden_states.size()

        q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
        k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
        v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)

        if self.qk_norm:
            q, k = self.q_norm(q), self.k_norm(k)

        seqlen_offset, max_seqlen = 0, q_len
        if past_key_values is not None:
            seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
            max_seqlen = q.shape[1] + seqlen_offset

            if attention_mask is not None:
                # to eliminate the offsets of padding tokens
                seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
                max_seqlen = q.shape[1] + max(seqlen_offset)

        if self.max_position_embeddings is not None:
            max_seqlen = max(max_seqlen, self.max_position_embeddings)
        if self.rotary is not None:
            q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)

        qu = None
        if self.usequ:
            qu = rearrange(self.qu_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
            if self.rotary is not None:
                # Rotate qu the same way as the keys (second rotary output).
                _, qu = self.rotary(qu, qu, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)

        # Gates in (0, 2): [B, S, H_kv] -> [B, H_kv, S, 1].
        alpha = 2 * F.sigmoid(self.alpha_proj(hidden_states).transpose(-1, -2).unsqueeze(-1))
        # beta is tied to alpha; `beta_proj` is not used.
        beta = alpha

        if self.kl2:
            # Unit-norm keys; Infinite_Deltanet then skips the sqrt(head_dim) scale.
            k = F.normalize(k, dim=-1)
            kernel_scale = 1
        else:
            kernel_scale = None

        if past_key_values is not None:
            cache_has_content = past_key_values.get_seq_length(self.layer_idx) > 0
            k_cached, v_cached = past_key_values.update(
                attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
                layer_idx=self.layer_idx,
                offset=q_len,
                cache_kwargs=dict(window_size=self.window_size)
            )['attn_state']
            if cache_has_content:
                k, v = k_cached, v_cached
                k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
                v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
                # NOTE(review): only K/V are cached. `alpha`/`beta` (and `qu`)
                # cover just the current q_len tokens. With q_len == 1 they
                # broadcast over the whole cached sequence, so past tokens are
                # gated with the current token's alpha. With usequ the score
                # matrix becomes non-square. Confirm whether cached decoding
                # is expected to match full-sequence results.

        o = self.Infinite_Deltanet(q, k, v, alpha, beta, qu, kernel_scale=kernel_scale)
        o = o.reshape(batch_size, q_len, -1)
        o = self.o_proj(o)

        # Attention weights are never materialised by this layer, so there is
        # nothing to return even when output_attentions is True.
        attentions = None
        return o, attentions, past_key_values

    def Infinite_Deltanet(self, Q, K, V, Alpha=1, Beta=1, QU=None, kernel_scale=None):
        """Solve the gated delta-rule system in closed form, then attend with it.

        Scores are ``A[s, t] = <left_s, k_t> / kernel_scale``, where ``left``
        is ``QU`` when given and ``K`` otherwise. They are mapped through
        ``self.kernel``, and then the unit lower-triangular system

            (I + strictly_lower(Beta * A)) U = Alpha * V

        is solved. Row by row this is
        ``U_t = Alpha_t V_t - Beta_t * sum_{s<t} A[t, s] U_s``. The diagonal
        and upper triangle of the kernelized ``A`` are ignored by the solve.

        Args:
            Q: queries, ``[B, S, Hq, D]``.
            K, V: keys / values, ``[B, T, H, D]``.
            Alpha, Beta: gates, ``[B, H, T, 1]`` (or scalars).
            QU: optional left-side projection, ``[B, T, H, D]``.
            kernel_scale: divisor for the scores. Defaults to ``sqrt(D)``.

        Returns:
            ``U`` as ``[B, T, H, D]`` (in ``K``'s dtype/device) if ``self.returnu``,
            otherwise ``flash_attn_func(Q, K, U, causal=True)``.

        Raises:
            ValueError: if ``self.kernel`` is not a known kernel name.
        """
        if kernel_scale is None:
            kernel_scale = math.sqrt(K.shape[-1])
        left = K if QU is None else QU
        A = torch.einsum("b s h d, b t h d-> b h s t", left / kernel_scale, K)

        if self.kernel == 'softmax_v1':
            # Causal softmax over positions t <= s (diagonal included).
            causal = torch.ones(A.shape[-2:], dtype=torch.bool, device=A.device).tril(diagonal=0)
            A = F.softmax(torch.where(causal, A, float('-inf')), dim=-1).float()
        elif self.kernel == 'softmax_v2':
            # Softmax over strictly earlier positions t < s. Row 0 has no such
            # position. A softmax over an all -inf row is NaN, and its backward
            # turns every key gradient into NaN. So that row gets finite zeros
            # before the softmax and is zeroed (out of place) afterwards.
            strict = torch.ones(A.shape[-2:], dtype=torch.bool, device=A.device).tril(diagonal=-1)
            has_past = strict.any(dim=-1, keepdim=True)
            A = torch.where(strict, A, float('-inf'))
            A = torch.where(has_past, A, 0.0)
            A = torch.where(has_past, F.softmax(A, dim=-1), 0.0).float()
        elif self.kernel == 'round':
            A = RoundSTE.apply(A).float()
        elif self.kernel == 'linear':
            A = A.float()
        elif self.kernel == 'relu':
            A = F.relu(A).float()
        elif self.kernel == 'exp':
            A = torch.exp(A).float()
        else:
            raise ValueError(
                f"Unknown kernel {self.kernel!r}; expected one of "
                "'softmax_v1', 'softmax_v2', 'round', 'linear', 'relu', 'exp'."
            )

        lhs = Beta * A
        # A is float32 after the kernel, while V may be lower precision.
        # Give the solver matching dtypes explicitly.
        rhs = (Alpha * V.transpose(-3, -2)).to(lhs.dtype)
        U = torch.linalg.solve_triangular(lhs, rhs, upper=False, unitriangular=True)
        U = U.to(K).transpose(-3, -2)
        if self.returnu:
            return U
        return flash_attn_func(
            Q,
            K,
            U,
            causal=True,
        )
