"""
@Description :   特征融合网络
@Author      :   tqychy 
@Time        :   2025/01/27 08:38:44
"""
import torch
import torch.nn as nn


# ------------ CrossAttention ------------ #
class AttentionBlock(nn.Module):
    """Post-norm (cross-)attention block: attention -> Add&Norm -> FFN -> Add&Norm.

    ``query`` attends over ``key``/``value``. Inputs are ``(n, dim)`` for the
    query and ``(m, dim)`` for key/value, optionally with the same leading
    batch dims ``(..., n, dim)`` / ``(..., m, dim)``. The output has the
    query's shape.

    Args:
        dim: Feature dimension of query/key/value and of the output.
        num_heads: Number of attention heads; ``dim`` must be divisible by it.
            Each head attends over its own ``dim // num_heads`` slice of the
            projected features.
    """

    def __init__(self, dim, num_heads=1):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        assert self.head_dim * num_heads == dim, "dim must be divisible by num_heads"

        # Linear projections for Q, K, V
        self.Wq = nn.Linear(dim, dim)
        self.Wk = nn.Linear(dim, dim)
        self.Wv = nn.Linear(dim, dim)

        # Position-wise feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, dim)
        )

        # Normalization layers (applied after each residual connection)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def _split_heads(self, x):
        """Reshape ``(..., L, dim)`` into ``(..., num_heads, L, head_dim)``."""
        x = x.reshape(*x.shape[:-1], self.num_heads, self.head_dim)
        return x.transpose(-3, -2)

    def _merge_heads(self, x):
        """Inverse of ``_split_heads``: ``(..., num_heads, L, head_dim)`` -> ``(..., L, dim)``."""
        x = x.transpose(-3, -2)
        return x.reshape(*x.shape[:-2], self.dim)

    def forward(self, query, key, value):
        """Attend ``query`` over ``key``/``value`` and apply the FFN sub-layer.

        :param query: [..., n, dim]
        :param key:   [..., m, dim]
        :param value: [..., m, dim]
        :return: [..., n, dim]
        """
        residual = query

        # Project and split into heads
        q = self._split_heads(self.Wq(query))  # (..., h, n, head_dim)
        k = self._split_heads(self.Wk(key))    # (..., h, m, head_dim)
        v = self._split_heads(self.Wv(value))  # (..., h, m, head_dim)

        # Scaled dot-product attention, computed independently per head.
        # The dot product runs over head_dim features, hence sqrt(head_dim).
        scores = torch.matmul(q, k.transpose(-2, -1)) / \
            (self.head_dim ** 0.5)  # (..., h, n, m)
        attn_weights = nn.functional.softmax(scores, dim=-1)
        output = self._merge_heads(torch.matmul(attn_weights, v))  # (..., n, dim)

        # Residual connection + normalization
        output = self.norm1(residual + output)

        # Feed-forward sub-layer with its own residual + normalization
        ffn_output = self.ffn(output)
        output = self.norm2(output + ffn_output)

        return output
    
class AttentionFuse(nn.Module):
    """Bidirectional cross-attention fusion of two feature tensors.

    ``f1`` attends to ``f2`` (via ``att1``) and ``f2`` attends to ``f1`` (via
    ``att2``); the two attended results are averaged element-wise. Each
    attended result has the shape of its query, so ``f1`` and ``f2`` must
    produce shapes that can be added together.

    Positional constructor args: ``cfg`` and ``logger`` (exactly two).
    ``cfg`` must provide ``cfg.NET.FEATURE_EXTRACT_DIM`` and
    ``cfg.NET.VIT.NUM_HEADS``.
    """

    def __init__(self, *args):
        super().__init__()
        cfg, logger = args
        self.cfg = cfg
        self.logger = logger

        net_cfg = cfg.NET
        feature_dim = net_cfg.FEATURE_EXTRACT_DIM
        head_count = net_cfg.VIT.NUM_HEADS

        # One block per direction; registration order (att1, att2) is kept.
        self.att1 = AttentionBlock(feature_dim, head_count)
        self.att2 = AttentionBlock(feature_dim, head_count)

    def forward(self, f1, f2):
        """Return the mean of f1-attends-f2 and f2-attends-f1."""
        f1_over_f2 = self.att1(f1, f2, f2)
        f2_over_f1 = self.att2(f2, f1, f1)
        fused = f1_over_f2 + f2_over_f1
        return fused / 2

# ------------ SelfGate ------------ #
class SelfGateV1(nn.Module):
    """GRU update-gate-like fusion module.

    Computes a per-element gate ``w = sigmoid(fc1(elu(fc(cat(c, t)))))`` and
    returns the convex mix ``c * w + t * (1 - w)``.

    Args:
        dim: Feature dimension of both inputs and of the output.
    """

    def __init__(self, dim=64):
        super().__init__()
        self.fc = nn.Linear(2 * dim, dim)
        self.fc1 = nn.Linear(dim, dim)
        self.activate = nn.ELU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, c, t):
        """
        :param c: [..., dim], e.g. [batch_size, n, dim]
        :param t: [..., dim], same shape as ``c``
        :return: mixed_feature, same shape as ``c``
        """
        w = self.fc(torch.cat((c, t), dim=-1))
        w = self.activate(w)
        w = self.sigmoid(self.fc1(w))  # gate in (0, 1), same shape as c

        # w already has c's shape, so no reshape is needed; this also lets the
        # module accept any number of leading dims instead of exactly 3-D input.
        return c * w + t * (1 - w)


class SelfGateV2(nn.Module):
    """GRU update-gate-like fusion module that concatenates the gated inputs.

    Computes a per-element gate ``w = sigmoid(fc1(elu(fc(cat(c, t)))))`` and
    returns ``cat(c * w, t * (1 - w))`` along the last dim.

    Args:
        dim: Feature dimension of each input (default 64, i.e. the original
            hard-coded ``Linear(128, 64)`` / ``Linear(64, 64)`` layout).
    """

    def __init__(self, dim=64):
        super().__init__()
        self.fc = nn.Linear(2 * dim, dim)
        self.fc1 = nn.Linear(dim, dim)
        self.activate = nn.ELU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, c, t):
        """
        :param c: [..., dim], e.g. [batch_size, n, dim]
        :param t: [..., dim], same shape as ``c``
        :return: mixed_feature [..., 2 * dim]
        """
        w = self.fc(torch.cat((c, t), dim=-1))
        w = self.activate(w)
        w = self.sigmoid(self.fc1(w))  # gate in (0, 1), same shape as c

        # w already has c's shape, so no reshape is needed (any leading dims work).
        return torch.cat((c * w, t * (1 - w)), dim=-1)