import torch
import torch.nn as nn
import torch.nn.functional as F
import math


import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class UltraFastCrossAttentionFilter(nn.Module):
    """Select the top-k news embeddings most relevant to a query and gate them.

    A pooled query is projected and compared against projected news embeddings
    with a scaled dot product. The ``topk`` highest-scoring news items are kept,
    returned in their original (ascending index) order. Each kept item's value
    vector is weighted by a gate computed from its key and value projections.

    Args:
        d_model: Feature size of the query ``h_q`` and of the returned vectors.
        news_emb: Feature size of each news embedding in ``news_E``.
        hidden_dim: Size of the query/key projection space; falls back to
            ``d_model`` when not given (any falsy value).
        topk: Maximum number of news items kept per batch row.
        init_logit_scale: Currently unused. The logit scale is initialised
            from a fixed raw value of 6.0 (softplus(6.0) ~= 6.0). The argument
            is kept only for signature compatibility.
    """

    def __init__(self, d_model, news_emb, hidden_dim=None, topk=64, init_logit_scale=20.0):
        super().__init__()
        self.d_model = d_model
        self.topk = topk
        self.hidden_dim = hidden_dim or d_model

        # Projections: query and keys share the hidden_dim space; values map
        # straight to d_model so they can be returned directly.
        # NOTE: submodule/parameter creation order is intentionally unchanged
        # so seeded initialisation and state_dict ordering stay the same.
        self.q_proj = nn.Linear(d_model, self.hidden_dim, bias=False)
        self.k_proj = nn.Linear(news_emb, self.hidden_dim, bias=False)
        self.v_proj = nn.Linear(news_emb, d_model, bias=False)

        # Normalise Q and K to stabilise the magnitude of their dot products.
        self.q_norm = nn.LayerNorm(self.hidden_dim)
        self.k_norm = nn.LayerNorm(self.hidden_dim)

        # Gate MLP input is concat(K_proj, V): hidden_dim + d_model features.
        gate_hidden = max(self.hidden_dim, d_model)
        self.gate_fc = nn.Sequential(
            nn.Linear(self.hidden_dim + d_model, gate_hidden),
            nn.ReLU(),
            nn.Linear(gate_hidden, 1),
        )

        # Learnable attention temperature: softplus keeps it positive; it is
        # clamped to logit_scale_max in forward.
        self.logit_scale_raw = nn.Parameter(torch.tensor(6.0))
        self.logit_scale_max = 40.0

        # Learnable multiplier on the sigmoid gates (softplus(1) ~= 1.31),
        # clamped to gate_scale_max in forward.
        self.gate_scale_raw = nn.Parameter(torch.tensor(1.0))
        self.gate_scale_max = 5.0

        # Unconstrained learnable scale on the returned context vectors.
        self.ctx_scale = nn.Parameter(torch.tensor(1.0))

        # Warm-up: bias the final gate pre-activation toward -1
        # (sigmoid(-1) ~= 0.27) so gates start mostly closed.
        nn.init.constant_(self.gate_fc[-1].bias, -1.0)

    def forward(self, h_q, news_E, news_mask=None, use_last=False, return_topk=False):
        """Score news items against the query, keep the top-k, and gate them.

        Args:
            h_q: Query features of shape (B, d_model) or (B, L, d_model). A 3-D
                input is pooled over dim 1: mean, or the last step if
                ``use_last``.
            news_E: News embeddings of shape (B, N, news_emb).
            news_mask: Optional mask broadcastable to (B, N). True or nonzero
                marks a valid news item; invalid items are never preferred
                over valid ones.
            use_last: Pool a 3-D ``h_q`` by taking its last step instead of
                the mean.
            return_topk: Currently unused; accepted for call-site compatibility.

        Returns:
            A tuple ``(e_prime_compact, gates_top)``:
            ``e_prime_compact`` has shape (B, K, d_model) and holds the gated
            value vectors of the selected items, with K = min(topk, N), ordered
            by ascending original index. ``gates_top`` has shape (B, K) and
            holds the scaled gate weights; it is zero for any slot that had to
            be filled with a masked-out item.

        Raises:
            ValueError: If ``h_q`` is neither 2-D nor 3-D.
        """
        if h_q.dim() == 3:
            h_pooled = h_q[:, -1, :] if use_last else h_q.mean(dim=1)
        elif h_q.dim() == 2:
            h_pooled = h_q
        else:
            raise ValueError("h_q must be shape (B,D) or (B,L,D)")

        B, N, _ = news_E.size()

        Q = self.q_norm(self.q_proj(h_pooled)).unsqueeze(1)  # (B, 1, H)
        K = self.k_norm(self.k_proj(news_E))                 # (B, N, H)
        V = self.v_proj(news_E)                              # (B, N, d_model)

        scores_raw = (K * Q).sum(-1)                         # (B, N)
        logit_scale = torch.clamp(F.softplus(self.logit_scale_raw), max=self.logit_scale_max)
        scores = (scores_raw / math.sqrt(self.hidden_dim)) * logit_scale

        valid = None
        if news_mask is not None:
            # Logical mask of shape (B, N): `~` on an int/float mask would be a
            # bitwise NOT, and the expand lets gather() use broadcast masks.
            valid = news_mask.to(torch.bool).expand(B, N)
            # Use the dtype's own minimum: a literal like -1e9 overflows float16.
            scores = scores.masked_fill(~valid, torch.finfo(scores.dtype).min)

        K_sel = min(self.topk, N)
        _, top_idx = torch.topk(scores, k=K_sel, dim=1)      # (B, K_sel)
        # Re-sort selected indices ascending so kept items keep their original order.
        top_idx, _ = torch.sort(top_idx, dim=1)

        K_top = torch.gather(K, 1, top_idx.unsqueeze(-1).expand(-1, -1, K.size(-1)))  # (B, K_sel, H)
        V_top = torch.gather(V, 1, top_idx.unsqueeze(-1).expand(-1, -1, V.size(-1)))  # (B, K_sel, d_model)

        gate_in = torch.cat([K_top, V_top], dim=-1)          # (B, K_sel, H + d_model)
        gates_top = torch.sigmoid(self.gate_fc(gate_in)).squeeze(-1)  # (B, K_sel)

        if valid is not None:
            # With fewer valid items than K_sel, topk must fill the remaining
            # slots with masked items; zero their gates so they contribute nothing.
            gates_top = gates_top.masked_fill(~torch.gather(valid, 1, top_idx), 0.0)

        gate_scale = torch.clamp(F.softplus(self.gate_scale_raw), max=self.gate_scale_max)
        gates_top = gates_top * gate_scale

        e_prime_compact = V_top * gates_top.unsqueeze(-1) * self.ctx_scale  # (B, K_sel, d_model)
        return e_prime_compact, gates_top
        