import torch
from torch import nn
from torchvision.ops import roi_align


class Conv3_3(nn.Module):
    """Thin wrapper around a single ``nn.Conv2d`` layer.

    With the default arguments this is a 3x3, stride-1, unpadded convolution
    without a bias term.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Kernel size forwarded to ``nn.Conv2d``.
        stride: Stride forwarded to ``nn.Conv2d``.
        padding: Padding forwarded to ``nn.Conv2d``.
        bias: Whether the convolution learns an additive bias.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias=False):
        super().__init__()
        # The attribute name ``conv`` is part of the state_dict key.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, x):
        """Return the convolution of ``x``."""
        return self.conv(x)


class Full_Connection(nn.Module):
    """Thin wrapper around a single ``nn.Linear`` layer.

    Args:
        in_channels: Size of each input feature vector.
        out_channels: Size of each output feature vector.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The attribute name ``fc`` is part of the state_dict keys.
        self.fc = nn.Linear(in_features=in_channels, out_features=out_channels)

    def forward(self, x):
        """Return the affine projection of ``x``."""
        return self.fc(x)


class Fusion(nn.Module):
    """Fuse RoI-aligned features from four feature maps into two outputs.

    Each of the four input feature maps is RoI-aligned to ``self.outputsize``
    and passed through two level-specific branches:

    * a 3x3 convolution producing ``cls`` channels (classification branch);
    * a fully connected layer on the flattened RoI (detection branch).

    The per-level results are summed across levels. The summed classification
    maps are global-average-pooled to one value per channel, and the summed
    detection vectors are projected to 3 values by a final linear layer.

    Args:
        inchannels: Channel count of every input feature map. All four maps
            must share it, because each per-level linear layer expects
            ``inchannels * 7 * 7`` input features.
        cls: Number of output channels of the classification branch.
    """

    def __init__(self, inchannels, cls):
        super().__init__()
        self.outputsize = (7, 7)
        # Length of one flattened RoI; derived from ``outputsize`` instead of
        # repeating the literal 7 * 7.
        pooled_dim = inchannels * self.outputsize[0] * self.outputsize[1]

        # Submodule names and creation order are kept as-is so existing
        # checkpoints (state_dict keys) and seeded initialisation still match.
        self.conv_R2 = Conv3_3(inchannels, cls, 3, 1, 1, False)
        self.fc_R2 = Full_Connection(pooled_dim, pooled_dim)

        self.conv_R3 = Conv3_3(inchannels, cls, 3, 1, 1, False)
        self.fc_R3 = Full_Connection(pooled_dim, pooled_dim)

        self.conv_R4 = Conv3_3(inchannels, cls, 3, 1, 1, False)
        self.fc_R4 = Full_Connection(pooled_dim, pooled_dim)

        self.conv_R5 = Conv3_3(inchannels, cls, 3, 1, 1, False)
        self.fc_R5 = Full_Connection(pooled_dim, pooled_dim)

        self.fc = Full_Connection(pooled_dim, 3)
        self.gap = nn.AdaptiveAvgPool2d(1)

    def _level_heads(self, feature, rois, spatial_scale, conv, fc):
        """RoI-align one feature map and run its two level-specific branches.

        Returns:
            ``(cls_map, det_vec)``: the convolution output on the pooled RoIs
            and the linear-layer output on the flattened pooled RoIs.
        """
        pooled = roi_align(feature, rois, self.outputsize, spatial_scale)
        # flatten(1) keeps the RoI dimension and merges the rest. Unlike
        # ``view(num_rois, -1)`` it also works when there are zero RoIs (where
        # -1 is ambiguous) and when ``pooled`` is not contiguous.
        return conv(pooled), fc(pooled.flatten(1))

    def forward(self, r2, r3, r4, r5, boxes):
        """Compute the fused detection and classification outputs.

        Args:
            r2: Feature map pooled with spatial scale 1/2.
            r3: Feature map pooled with spatial scale 1/4.
            r4: Feature map pooled with spatial scale 1/8.
            r5: Feature map pooled with spatial scale 1/16.
            boxes: Indexable collection of per-image box tensors. Only
                ``boxes[0]`` is used, so boxes for any further images are
                ignored. NOTE(review): presumably ``boxes[0]`` is a ``(K, 4)``
                tensor of ``(x1, y1, x2, y2)`` boxes in input-image
                coordinates, as ``roi_align`` expects -- confirm with callers.

        Returns:
            A tuple ``(det, cls)``: ``det`` is the final linear projection of
            the summed detection vectors, with 3 values per RoI; ``cls`` is
            the globally average-pooled sum of the classification maps,
            flattened to one row per RoI.
        """
        rois = [boxes[0]]
        levels = (
            (r2, 1.0 / 2, self.conv_R2, self.fc_R2),
            (r3, 1.0 / 4, self.conv_R3, self.fc_R3),
            (r4, 1.0 / 8, self.conv_R4, self.fc_R4),
            (r5, 1.0 / 16, self.conv_R5, self.fc_R5),
        )

        # Accumulate left to right (r2 + r3 + r4 + r5) so the floating-point
        # summation order matches the original explicit expression.
        r_cls = None
        r_det = None
        for feature, spatial_scale, conv, fc in levels:
            level_cls, level_det = self._level_heads(feature, rois, spatial_scale, conv, fc)
            r_cls = level_cls if r_cls is None else r_cls + level_cls
            r_det = level_det if r_det is None else r_det + level_det

        r_cls = self.gap(r_cls)

        return self.fc(r_det), r_cls.flatten(1)

# if __name__ == '__main__':
#     r2=torch.randn(1,128,112,112)
#     r3=torch.randn(1,128,56,56)
#     r4=torch.randn(1,128,28,28)
#     r5=torch.randn(1,128,14,14)
#     boxes=torch.randn(1,7,4)
#     fusion=Fusion(128,3)
#     res1,res2=fusion(r2,r3,r4,r5,boxes)
#     print('aaa')
