import torch
from torch import nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain.

    The input is upcast to float32, divided by the root mean square of its
    last dimension, multiplied by ``weight`` and finally cast back to the
    dtype the input arrived in.

    Args:
        hidden_size: Length of the ``weight`` vector, which starts as all ones.
        eps: Constant added to the mean square before taking the reciprocal
            square root.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` normalised along its last dimension, in ``x``'s dtype."""
        orig_dtype = x.dtype
        # Compute the statistics in float32 regardless of the incoming dtype.
        x_fp32 = x.to(torch.float32)
        mean_square = x_fp32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normed = x_fp32 * inv_rms
        return (self.weight * normed).to(orig_dtype)

class FFN(nn.Module):
    """Gated feed-forward block: ``down_proj(act(gate_proj(x)) * up_proj(x))``.

    ``act`` is SiLU and all three projections are linear layers without bias.

    Args:
        hidden_size: Feature size of the input and of the output.
        intermediate_size: Feature size of the gated inner representation.
    """

    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        # Creation order is kept as gate -> up -> down so parameter ordering
        # and random initialisation are unaffected.
        self.gate_proj = nn.Linear(
            in_features=hidden_size, out_features=intermediate_size, bias=False
        )
        self.up_proj = nn.Linear(
            in_features=hidden_size, out_features=intermediate_size, bias=False
        )
        self.down_proj = nn.Linear(
            in_features=intermediate_size, out_features=hidden_size, bias=False
        )
        self.act = F.silu

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated projection to ``x`` and map back to ``hidden_size``."""
        gate = self.act(self.gate_proj(x))
        value = self.up_proj(x)
        return self.down_proj(gate * value)

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention with no attention mask.

    Queries, keys and values are linear projections (with bias) of the same
    input; the concatenated head outputs go through a bias-free output
    projection. Every position attends to every other position.

    Args:
        hidden_dim: Feature size of the input and of the output.
        num_heads: Number of attention heads; must be positive and divide
            ``hidden_dim`` exactly.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``hidden_dim``.
    """

    def __init__(self, hidden_dim, num_heads):
        super().__init__()
        # Fail at construction time: otherwise the floor division below would
        # silently truncate and the mismatch would only surface later as an
        # opaque shape error inside ``forward``.
        if num_heads <= 0 or hidden_dim % num_heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.scaling = self.head_dim ** -0.5

        self.q_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
        self.o_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def _split_heads(self, t: torch.Tensor, bsz: int, seq_len: int) -> torch.Tensor:
        """Reshape (bsz, seq_len, hidden_dim) to (bsz, num_heads, seq_len, head_dim)."""
        return t.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over the sequence dimension of ``x``.

        Args:
            x: Tensor of shape (bsz, seq_len, hidden_dim).

        Returns:
            Tensor of shape (bsz, seq_len, hidden_dim).
        """
        bsz, seq_len, _ = x.size()
        q = self._split_heads(self.q_proj(x), bsz, seq_len)
        k = self._split_heads(self.k_proj(x), bsz, seq_len)
        v = self._split_heads(self.v_proj(x), bsz, seq_len)

        # (bsz, num_heads, seq_len, seq_len) similarity scores, scaled by
        # head_dim ** -0.5 before the softmax over the key axis.
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scaling
        attn_probs = F.softmax(attn_weights, dim=-1)
        attn_output = torch.matmul(attn_probs, v)  # (bsz, num_heads, seq_len, head_dim)

        # Merge the heads back into a single feature dimension.
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            bsz, seq_len, self.hidden_dim
        )
        return self.o_proj(attn_output)

class DecoderLayer(nn.Module):
    """Pre-norm residual layer: self-attention followed by a gated feed-forward.

    Each sub-layer sees an RMS-normalised copy of its input; the sub-layer
    output goes through dropout and is added back onto the un-normalised input.
    One ``nn.Dropout`` instance serves both sub-layers.

    Args:
        hidden_dim: Feature size of the layer's input and output.
        num_heads: Number of heads passed to ``Attention``.
        intermediate_dim: Inner feature size passed to ``FFN``.
        dropout: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, hidden_dim, num_heads, intermediate_dim, dropout=0.1):
        super().__init__()
        # Registration order is kept so parameter ordering and random
        # initialisation match the previous layout.
        self.self_attn = Attention(hidden_dim, num_heads)
        self.ffn = FFN(hidden_dim, intermediate_dim)
        self.norm1 = RMSNorm(hidden_dim)
        self.norm2 = RMSNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the attention and feed-forward residual steps on ``x``."""
        # Attention sub-layer with residual connection.
        x = x + self.dropout(self.self_attn(self.norm1(x)))
        # Feed-forward sub-layer with residual connection.
        x = x + self.dropout(self.ffn(self.norm2(x)))
        return x

class EAReranker(nn.Module):
    """Score a query/document embedding pair with a small transformer stack.

    Each embedding is expanded into ``num_channels`` tokens by independent
    linear projections. The sequence ``[cls, query tokens, sep, document
    tokens]`` is passed through the ``DecoderLayer`` stack, and the state at
    the ``cls`` position is fed to an MLP head that ends in a sigmoid.

    Args:
        input_dim: Feature size of the incoming query/document embeddings.
        hidden_dim: Feature size of every token inside the model.
        num_channels: Number of tokens produced per embedding.
        num_layers: Number of ``DecoderLayer`` blocks.
        num_heads: Attention heads per ``DecoderLayer``.
        intermediate_dim: Inner feed-forward size per ``DecoderLayer``.
        dropout: Dropout probability used in the layers and the head.
    """

    def __init__(self, input_dim=1024, hidden_dim=768, num_channels=4,
                 num_layers=4, num_heads=8, intermediate_dim=1024, dropout=0.1):
        super().__init__()
        # Learned marker tokens, broadcast over the batch in ``forward``.
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
        self.sep_token = nn.Parameter(torch.randn(1, 1, hidden_dim))

        # Multi-channel linear layers: one projection per channel, with
        # separate sets for queries and documents.
        self.q_channels = nn.ModuleList(
            nn.Linear(input_dim, hidden_dim) for _ in range(num_channels)
        )
        self.d_channels = nn.ModuleList(
            nn.Linear(input_dim, hidden_dim) for _ in range(num_channels)
        )

        self.decoder = nn.ModuleList(
            DecoderLayer(hidden_dim, num_heads, intermediate_dim, dropout)
            for _ in range(num_layers)
        )

        # Scoring head; positions inside the Sequential are fixed because they
        # determine the state_dict keys.
        head_layers = [
            RMSNorm(hidden_dim),
            nn.Linear(hidden_dim, 512),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 512),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 1),
            nn.Sigmoid(),
        ]
        self.classifier = nn.Sequential(*head_layers)

        # Not used by ``forward``; retained so the attribute stays available.
        self.dropout = nn.Dropout(dropout)

    def _project_multi_channel(self, x, channels):
        """Apply every layer in ``channels`` to ``x`` and stack results on dim 1."""
        return torch.stack([proj(x) for proj in channels], dim=1)

    def forward(self, q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tensor:
        """Return one relevance score in (0, 1) per batch element.

        Args:
            q_emb: Query embeddings of shape (bsz, input_dim).
            d_emb: Document embeddings of shape (bsz, input_dim).

        Returns:
            Tensor of shape (bsz,) holding the sigmoid output of the head.
        """
        batch = q_emb.size(0)
        query_tokens = self._project_multi_channel(q_emb, self.q_channels)
        doc_tokens = self._project_multi_channel(d_emb, self.d_channels)

        cls_tok = self.cls_token.expand(batch, -1, -1)
        sep_tok = self.sep_token.expand(batch, -1, -1)
        # Resulting sequence length is 2 * num_channels + 2.
        hidden = torch.cat([cls_tok, query_tokens, sep_tok, doc_tokens], dim=1)

        for block in self.decoder:
            hidden = block(hidden)

        # The state at position 0 (the cls token) summarises the pair.
        cls_state = hidden[:, 0]
        return self.classifier(cls_state).squeeze(-1)
