"""
@Description :   碎片拼接网络的碎片组合部分
@Author      :   tqychy 
@Time        :   2025/01/14 11:14:01
"""
import sys

sys.path.append("./")

import math

import torch
import torch.nn as nn

from nets.feature_fuse import AttentionBlock, AttentionFuse


class CrossAttnDecoder(nn.Module):
    """
    Cross-attention decoder used in Matching.

    Alternately refines two feature sequences ``A`` and ``B`` with stacked
    attention blocks, then converts the refined features into a dual-softmax
    similarity matrix.

    Constructed as ``CrossAttnDecoder(cfg, logger)``; reads
    ``cfg.NET.FEATURE_EXTRACT_DIM``, ``cfg.NET.DECODER.BLOCKS`` and
    ``cfg.NET.DECODER.NUM_HEADS``.
    """

    def __init__(self, *args):
        super().__init__()
        self.cfg, self.logger = args

        dim = self.cfg.NET.FEATURE_EXTRACT_DIM
        num_blocks = self.cfg.NET.DECODER.BLOCKS
        num_heads = self.cfg.NET.DECODER.NUM_HEADS

        # Interleaved layout: blocks[2k] updates A (attending to B) and
        # blocks[2k + 1] updates B (attending to the freshly updated A).
        self.blocks = nn.ModuleList()
        for _ in range(num_blocks):
            self.blocks.append(AttentionBlock(dim, num_heads))
            self.blocks.append(AttentionBlock(dim, num_heads))

    @staticmethod
    def get_similarity_matrix(feature1, feature2, pad_mask=None):
        """
        Compute a dual-softmax similarity matrix between two feature sets.

        Args:
            feature1: 3-D tensor of shape (batch, N, D) (``torch.bmm`` requires 3-D).
            feature2: 3-D tensor of shape (batch, M, D).
            pad_mask: boolean mask selecting padded entries of the
                (batch, N, M) score matrix. A leading-dims mask such as
                (batch, N) masks whole rows. ``None`` disables masking.

        Returns:
            Tensor of shape (batch, N, M). It is the element-wise product of
            two softmaxes of the scaled scores: one over dim 1 (across
            feature1 positions, per column) and one over the last dim (across
            feature2 positions, per row).
        """
        temperature = math.sqrt(feature1.shape[-1])
        similarity_matrix = torch.bmm(
            feature1, feature2.permute(0, 2, 1)) / temperature

        if pad_mask is not None:
            # Push padded scores to a very negative but *finite* value so they
            # vanish in the softmax. A fixed -1e9 overflows to -inf in float16,
            # and a fully padded row/column of -inf becomes NaN after softmax.
            fill_value = max(-1e9, torch.finfo(similarity_matrix.dtype).min)
            similarity_matrix[pad_mask] = fill_value

        s_i = torch.softmax(similarity_matrix, dim=1)   # normalized over N (each column sums to 1)
        s_j = torch.softmax(similarity_matrix, dim=-1)  # normalized over M (each row sums to 1)
        return s_i * s_j

    def forward(self, A, B, pad_mask):
        """
        Refine ``A`` and ``B`` with alternating attention, then score them.

        Args:
            A: features of the first set; the final similarity step needs
                (batch, N, D) — TODO confirm against AttentionBlock's contract.
            B: features of the second set, (batch, M, D) under the same assumption.
            pad_mask: forwarded to :meth:`get_similarity_matrix`.

        Returns:
            The (batch, N, M) dual-softmax similarity matrix.
        """
        # Pair consecutive blocks: (update-A block, update-B block).
        block_iter = iter(self.blocks)
        for update_a, update_b in zip(block_iter, block_iter):
            # Update A, using B as key and value.
            A = update_a(A, B, B)
            # Update B, using the new A as key and value.
            B = update_b(B, A, A)

        return self.get_similarity_matrix(A, B, pad_mask)


class MatchingNet(nn.Module):
    """
    Fragment-matching network with an optional classification head.

    Matching path: ``feature_extract`` -> ``fuse`` -> ``decoder``.
    Classification path: ``classify``.

    Args:
        feature_extract: called as ``feature_extract(**inputs)``; its result
            is unpacked into a pair of features.
        fuse: called as ``fuse(first, second)`` on that pair to produce a
            single feature.
        decoder: called as ``decoder(f1, f2, pad_mask)``.
        classify: optional; called as ``classify(preds, length)``. Required
            only for the classification path of :meth:`forward`.
    """

    def __init__(self, feature_extract: nn.Module, fuse: nn.Module, decoder: nn.Module, classify=None):
        super().__init__()
        self.feature_extract = feature_extract
        self.fuse = fuse
        self.decoder = decoder
        self.classify = classify

    def matching_forward(self, inputs1: dict, inputs2: dict, pad_mask: torch.Tensor):
        """Extract and fuse features for both inputs, then decode them together."""
        f1_c, f1_t = self.feature_extract(**inputs1)
        f2_c, f2_t = self.feature_extract(**inputs2)

        f1 = self.fuse(f1_c, f1_t)
        f2 = self.fuse(f2_c, f2_t)

        return self.decoder(f1, f2, pad_mask)

    def classify_forward(self, preds, length=None):
        """
        Run the classification head on ``preds``.

        Raises:
            TypeError: if the network was built without a ``classify`` module.
                Previously this surfaced as "'NoneType' object is not callable".
        """
        if self.classify is None:
            raise TypeError(
                "MatchingNet has no `classify` module; pass classify=... to use the classification path"
            )
        return self.classify(preds, length)

    def forward(self, *args):
        """
        Dispatch on the number of positional arguments.

        - 3 args ``(inputs1, inputs2, pad_mask)``: the matching path.
        - 1 arg ``(preds,)``: classification with ``length=None``.
        - 2 args ``(preds, length)``: classification.

        Raises:
            TypeError: for any other argument count.
        """
        num_args = len(args)
        if num_args == 3:
            return self.matching_forward(*args)
        if num_args == 1:
            return self.classify_forward(args[0], None)
        if num_args == 2:
            return self.classify_forward(*args)
        raise TypeError(
            f"MatchingNet.forward expects 1, 2 or 3 positional arguments, got {num_args}"
        )


if __name__ == "__main__":
    # Smoke test: parse a config path, load it into the global cfg, then run
    # AttentionFuse on two random tensors and print the output shape.
    import argparse

    from config.default import cfg

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--config_path",
        type=str,
        default="./config/matching_pairingnet/1000.yaml",
    )
    cli_args = arg_parser.parse_args()
    cfg.merge_from_file(cli_args.config_path)

    # NOTE(review): cfg is loaded but not used below, and AttentionFuse is built
    # with no arguments — confirm this matches its signature in nets.feature_fuse.
    sample_shape = (10, 256, 64)
    first_input = torch.rand(sample_shape)
    second_input = torch.rand(sample_shape)
    fuse_model = AttentionFuse()
    fused = fuse_model(first_input, second_input)
    print(fused.shape)
