"""
@Description :   碎片全局特征匹配网络
@Author      :   tqychy 
@Time        :   2025/02/14 09:57:50
"""
import torch.nn as nn


class PairingNet(nn.Module):
    """Pairing network for fragment global-feature matching.

    Composes two sub-modules:

    * ``feature_extract`` is called with the entries of the input dict as
      keyword arguments and must return exactly two values, ``(f_c, f_t)``.
    * ``fuse`` combines those two values into a single tensor.

    The fused tensor is flattened to 2-D, keeping its leading dimension.

    Args:
        feature_extract: Module producing the ``(f_c, f_t)`` feature pair.
        fuse: Module mapping ``(f_c, f_t)`` to one tensor.
    """

    def __init__(self, feature_extract: nn.Module, fuse: nn.Module):
        super().__init__()
        # Assigning nn.Module instances registers them as sub-modules, so
        # their parameters are visible to .parameters(), .to(), state_dict().
        self.feature_extract = feature_extract
        self.fuse = fuse

    def forward(self, inputs: dict):
        """Extract the feature pair, fuse it, and flatten the result.

        Args:
            inputs: Mapping whose keys are the keyword-argument names that
                ``feature_extract`` accepts (it is unpacked with ``**``).

        Returns:
            The fused tensor reshaped to ``(output.shape[0], -1)``. The
            leading dimension (presumably the batch dimension; confirm with
            the fuse module) is preserved and all remaining dimensions are
            merged into one.
        """
        f_c, f_t = self.feature_extract(**inputs)

        output = self.fuse(f_c, f_t)
        # reshape (not view): view raises RuntimeError when the fused tensor
        # is non-contiguous (e.g. after transpose/permute). reshape returns a
        # view when the strides allow it and copies only when necessary.
        output_flattened = output.reshape(output.shape[0], -1)

        return output_flattened

