import torch
import torch.nn as nn
import torch.nn.functional as F
# from models.dcn import DeformableConv2d
from torch.nn import Conv2d as DeformableConv2d


class MDF(nn.Module):
    """Multi-scale fusion of RGB and depth feature maps.

    ``rgb_input`` and ``depth_input`` are concatenated along the channel axis
    and fed through three conv -> BatchNorm -> ReLU stages whose kernel sizes
    default to 7, 5 and 3. Every stage receives the sum of the concatenated
    input and all earlier stage outputs. A final 1x1 conv -> BatchNorm -> ReLU
    projects that sum back to ``in_channels``, and the result is added to
    ``rgb_input`` as a residual.

    An earlier version of this module concatenated the stage outputs (so the
    channel count grew from stage to stage) instead of summing them; the
    summed form keeps every cascade conv at ``2 * in_channels`` channels.

    NOTE(review): ``DeformableConv2d`` is a module-level name. In this file it
    is currently imported as an alias of ``torch.nn.Conv2d`` (the deformable
    import is commented out), so these are ordinary convolutions unless that
    import is switched back.
    """

    def __init__(self, in_channels, kernel_sizes=(7, 5, 3)):
        """Build the cascade.

        Args:
            in_channels: channels per modality. ``rgb_input`` is expected to
                have this many channels, and the concatenation of
                ``rgb_input`` and ``depth_input`` must have
                ``2 * in_channels``.
            kernel_sizes: kernel sizes of the three cascade convolutions, in
                order. Each must be a positive odd integer so that
                ``padding=k // 2`` preserves the spatial size required by the
                residual sums.

        Raises:
            ValueError: if ``kernel_sizes`` is not exactly three positive odd
                integers.
        """
        super().__init__()

        kernel_sizes = tuple(kernel_sizes)
        if len(kernel_sizes) != 3 or any(
            not isinstance(k, int) or k <= 0 or k % 2 == 0 for k in kernel_sizes
        ):
            raise ValueError(
                f"kernel_sizes must be three positive odd ints, got {kernel_sizes!r}"
            )
        k1, k2, k3 = kernel_sizes
        fused_channels = in_channels * 2

        # Attribute names and registration order are kept stable so existing
        # state_dict checkpoints and parameter orderings remain compatible.
        self.conv1 = DeformableConv2d(fused_channels, fused_channels, kernel_size=k1, padding=k1 // 2)
        self.conv2 = DeformableConv2d(fused_channels, fused_channels, kernel_size=k2, padding=k2 // 2)
        self.conv3 = DeformableConv2d(fused_channels, fused_channels, kernel_size=k3, padding=k3 // 2)
        # 1x1 projection back to the per-modality channel count for the residual add.
        self.aggregation = DeformableConv2d(fused_channels, in_channels, kernel_size=1, padding=0)

        self.bn1 = nn.BatchNorm2d(fused_channels)
        self.bn2 = nn.BatchNorm2d(fused_channels)
        self.bn3 = nn.BatchNorm2d(fused_channels)
        self.bn4 = nn.BatchNorm2d(in_channels)

    def forward(self, rgb_input, depth_input):
        """Fuse ``depth_input`` into ``rgb_input``.

        Args:
            rgb_input: 4-D tensor (N, in_channels, H, W); BatchNorm2d requires
                4-D input.
            depth_input: 4-D tensor (N, C_d, H, W) with the same N, H, W as
                ``rgb_input`` and ``in_channels + C_d == 2 * in_channels``.

        Returns:
            ``rgb_input + xg``, where ``xg`` is the non-negative (ReLU) fused
            feature map of shape (N, in_channels, H, W).
        """
        fused = torch.cat([rgb_input, depth_input], dim=1)

        # ``acc`` is the running sum of the fused input and every stage output
        # so far; each stage consumes it. The adds are out-of-place because
        # the convs save their inputs for backward.
        x1 = F.relu(self.bn1(self.conv1(fused)))
        acc = fused + x1
        x2 = F.relu(self.bn2(self.conv2(acc)))
        acc = acc + x2
        x3 = F.relu(self.bn3(self.conv3(acc)))
        acc = acc + x3
        xg = F.relu(self.bn4(self.aggregation(acc)))

        return rgb_input + xg


if __name__ == '__main__':
    # Smoke test: one forward pass on random inputs, then check the shape.
    in_channels = 32  # channels per modality (RGB features and depth features)
    batch, height, width = 1, 64, 64

    # Random stand-ins for the RGB-branch and depth-branch feature maps; both
    # must carry ``in_channels`` channels.
    rgb_input = torch.randn(batch, in_channels, height, width)
    depth_input = torch.randn(batch, in_channels, height, width)

    # Create the network instance and run inference.
    multi_scale_net = MDF(in_channels)
    multi_scale_net.eval()
    with torch.no_grad():
        output = multi_scale_net(rgb_input, depth_input)

    # The residual connection means the output keeps the RGB input's shape.
    assert output.shape == rgb_input.shape, tuple(output.shape)
    print(f"output shape: {tuple(output.shape)}")
