# File path: llava/model/custom_attention.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


class StableLayerNorm(nn.Module):
    """Layer normalization over the last dimension, computed in float32.

    Mean/variance and the optional affine transform are evaluated in float32
    so half-precision inputs do not lose precision; the result is cast back
    to the input dtype, so the output dtype always matches the input dtype.

    Args:
        normalized_shape: Shape of the learnable ``weight`` and ``bias``.
            Only the last input dimension is normalized, so this is expected
            to equal the size of that dimension (e.g. an int such as
            ``head_dim``).
        eps: Added to the (biased) variance before the square root, exactly
            as in ``torch.nn.LayerNorm``.
        elementwise_affine: If True, apply a learnable per-element scale
            (``weight``, initialized to ones) and shift (``bias``, initialized
            to zeros) after normalization.
    """

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super().__init__()
        self.normalized_shape = normalized_shape
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.ones(normalized_shape))
            self.bias = nn.Parameter(torch.zeros(normalized_shape))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

    def forward(self, input):
        """Normalize ``input`` over its last dimension.

        Args:
            input: Tensor of any floating dtype; normalized along ``dim=-1``.

        Returns:
            Tensor with the same shape and dtype as ``input``.
        """
        input_dtype = input.dtype
        x = input.float()
        mean = x.mean(dim=-1, keepdim=True)
        # Biased variance (divide by N), matching torch.nn.LayerNorm.
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        # eps is applied once, inside the sqrt; var + eps > 0 already, so no
        # additional clamp is needed (clamping and then adding eps applied it twice).
        normalized = (x - mean) / torch.sqrt(var + self.eps)
        if self.elementwise_affine:
            # Apply the affine transform in float32 and cast once at the end,
            # so float32 params do not silently promote a half-precision output.
            normalized = normalized * self.weight.float() + self.bias.float()
        return normalized.to(input_dtype)


class StableSoftmax(nn.Module):
    """Softmax computed in float32 with logit clipping and a probability floor.

    Behavior, relative to ``torch.softmax``:
      * logits are shifted by their max along ``dim`` and clipped to
        [-50, 50], so anything more than 50 below the max is treated as
        exactly 50 below it;
      * every output probability is floored at 1e-8 and the result is
        renormalized, so no position receives exactly zero weight;
      * a slice whose max is non-finite (e.g. entirely ``-inf``) is shifted
        by 0 instead of by its max, avoiding ``inf - inf = NaN``; a fully
        ``-inf`` slice therefore comes out uniform.

    The output is cast back to the input dtype.

    Args:
        dim: Dimension along which the softmax is taken.
    """

    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim

    def forward(self, input):
        input_dtype = input.dtype
        x = input.float()
        row_max = x.max(dim=self.dim, keepdim=True).values
        # Subtracting a non-finite max would produce NaN (inf - inf); shift such
        # slices by 0 and let the clamp below bound their values instead.
        row_max = torch.where(torch.isfinite(row_max), row_max, torch.zeros_like(row_max))
        shifted = torch.clamp(x - row_max, min=-50.0, max=50.0)

        exp_values = torch.exp(shifted)
        sum_exp = torch.clamp(exp_values.sum(dim=self.dim, keepdim=True), min=1e-8)
        probs = exp_values / sum_exp
        # Floor tiny probabilities, then renormalize so each slice sums to 1.
        probs = torch.clamp(probs, min=1e-8, max=1.0)
        probs = probs / probs.sum(dim=self.dim, keepdim=True)

        return probs.to(input_dtype)
        
class RelationalAttention(nn.Module):
    """Multi-head self-attention with per-head Q/K normalization and an optional additive bias.

    Queries and keys are projected, split into heads, and each normalized
    with ``StableLayerNorm`` over ``head_dim``. The scaled dot-product scores
    optionally receive ``relational_bias``, and ``StableSoftmax`` turns them
    into weights. Because that softmax floors probabilities, an additive
    ``-inf`` bias leaves a small nonzero weight rather than exactly zero.

    Args:
        hidden_size: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads; must be positive.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``hidden_size``.
    """

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        if num_heads <= 0:
            raise ValueError("num_heads must be positive")
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        if self.head_dim * self.num_heads != self.hidden_size:
            raise ValueError("hidden_size must be divisible by num_heads")

        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.q_norm = StableLayerNorm(self.head_dim)
        self.k_norm = StableLayerNorm(self.head_dim)
        self.out_proj = nn.Linear(hidden_size, hidden_size)
        # Parameter-free, so building it once here leaves the state_dict unchanged
        # while avoiding a new module allocation on every forward call.
        self.attn_softmax = StableSoftmax(dim=-1)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape (bsz, seq_len, hidden_size) to (bsz, num_heads, seq_len, head_dim)."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        relational_bias: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Apply self-attention.

        Args:
            hidden_states: Tensor of shape (bsz, seq_len, hidden_size).
            relational_bias: Optional tensor added to the attention scores;
                it must broadcast against (bsz, num_heads, seq_len, seq_len).
                When None, no bias is added.

        Returns:
            Tensor of shape (bsz, seq_len, hidden_size).
        """
        bsz, seq_len, _ = hidden_states.shape

        query_states = self._shape(self.q_proj(hidden_states), seq_len, bsz)
        key_states = self._shape(self.k_proj(hidden_states), seq_len, bsz)
        value_states = self._shape(self.v_proj(hidden_states), seq_len, bsz)

        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)

        # (bsz, num_heads, seq_len, seq_len)
        attn_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) / (self.head_dim ** 0.5)

        if relational_bias is not None:
            attn_scores = attn_scores + relational_bias

        attn_weights = self.attn_softmax(attn_scores)
        # The bias (or norm outputs) may have promoted the scores' dtype; match
        # value_states so the matmul does not fail on a dtype mismatch.
        attn_weights = attn_weights.to(value_states.dtype)

        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, self.hidden_size)
        return self.out_proj(attn_output)