import torch
import torchvision
from torch import nn
import torch.nn.functional as F
from Aug import Aug
from Net import Net
from Fusion import Fusion


class ARFM(nn.Module):
    """Region-scoring network built on top of a pretrained ``Net`` backbone.

    The backbone's ``features`` stack is cut into five consecutive stages.
    The outputs of stages 3-5 are each combined with the previous refined map
    by an ``Aug`` module (starting from the stage-2 output), and ``Fusion``
    turns the four resulting maps plus the supplied boxes into a detection
    stream and a classification stream.  The backbone's own ``cls_*`` head is
    additionally run on the stage-5 output to give a whole-image prediction.

    Shape comments below are the original author's notes for a 224x224 input
    and ``channels=128``; ``Net``, ``Aug`` and ``Fusion`` are defined elsewhere,
    so treat those shapes as assumptions to confirm against those modules.
    """

    def __init__(self, channels, weights_path='params/best4.pth'):
        """Build the network and optionally load backbone weights.

        Args:
            channels: Output channel count passed to the ``Aug`` modules and
                input channel count passed to ``Fusion``.
            weights_path: Checkpoint file holding the backbone ``state_dict``.
                Defaults to the previously hard-coded path; pass ``None`` to
                skip loading and keep the backbone's freshly built weights.
        """
        super().__init__()
        self.backbone = Net()
        if weights_path is not None:
            # NOTE(security): torch.load unpickles the file, so only point this
            # at checkpoints from a trusted source.
            state_dict = torch.load(weights_path, map_location='cpu')
            self.backbone.load_state_dict(state_dict)

        # Slices of the backbone feature stack; the slices reuse the backbone's
        # own layers, so no weights are duplicated.
        features = self.backbone.features
        self.layer1 = features[:4]  # b,64,224,224
        self.layer2 = features[4:9]  # b,128,112,112
        self.layer3 = features[9:16]  # b,256,56,56
        self.layer4 = features[16:23]  # b,512,28,28
        self.layer5 = features[23:28]  # b,512,14,14

        # Aug(f_inchannels, r_inchannels, out_channels)
        self.aug3_2 = Aug(256, 128, channels)
        self.aug4_3 = Aug(512, channels, channels)
        self.aug5_4 = Aug(512, channels, channels)
        # Second argument is presumably the number of classes - confirm in Fusion.
        self.fusion = Fusion(channels, 3)

        # Whole-image classification head: freeze these parameters so they are
        # not updated during training.  requires_grad_ works in place and
        # returns the module itself, so no re-assignment is needed.
        self.backbone.cls_conv.requires_grad_(False)
        self.backbone.cls_avgpool.requires_grad_(False)
        # NOTE(review): cls_classifier is left trainable, exactly as before,
        # although the comment above suggests the whole head should be frozen.
        # Confirm whether it should also get requires_grad_(False).

    def forward(self, x, boxes):
        """Run the region branch and the whole-image branch.

        Args:
            x: Input image batch fed to the first backbone stage.
            boxes: Boxes forwarded unchanged to ``Fusion``.

        Returns:
            A tuple ``(det_scores, cls_scores, whole)``: the ``Fusion`` detection
            output soft-maxed over dim 0, the ``Fusion`` classification output
            soft-maxed over dim 1, and the output of the backbone's
            whole-image classification head.
        """
        f1 = self.layer1(x)  # 1,64,224,224
        f2 = self.layer2(f1)  # 1,128,112,112
        f3 = self.layer3(f2)  # 1,256,56,56
        f4 = self.layer4(f3)  # 1,512,28,28
        f5 = self.layer5(f4)  # 1,512,14,14

        # Progressive refinement: each Aug mixes a backbone stage output with
        # the refined map produced one step earlier.
        r2 = f2  # 1,128,112,112
        r3 = self.aug3_2(f3, r2)  # 1,128,56,56
        r4 = self.aug4_3(f4, r3)  # 1,128,28,28
        r5 = self.aug5_4(f5, r4)  # 1,128,14,14

        out_det, out_cls = self.fusion(r2, r3, r4, r5, boxes)

        # Detection scores are normalised across dim 0, classification scores
        # across dim 1.
        det_scores = F.softmax(out_det, dim=0)
        cls_scores = F.softmax(out_cls, dim=1)

        # Whole-image classification on the deepest backbone feature map.
        whole = self.backbone.cls_conv(f5)
        whole = self.backbone.cls_avgpool(whole)
        whole = self.backbone.cls_flatten(whole)
        whole = self.backbone.cls_classifier(whole)
        return det_scores, cls_scores, whole

    @staticmethod
    def calculate_loss(combined_scores, target):
        """Return the summed binary cross-entropy of image-level scores.

        Args:
            combined_scores: Score tensor whose dim 0 is summed away to obtain
                one score per label entry.
            target: Label tensor with the same shape as the summed scores.
                Integer or boolean labels are accepted and cast to the score
                dtype.

        Returns:
            Scalar tensor holding the BCE loss with ``reduction="sum"``.
        """
        # Column-wise sum gives image-level scores matching the label layout.
        image_level_scores = torch.sum(combined_scores, dim=0)
        # BCE needs probabilities, so squeeze the sums into [0, 1].
        image_level_scores = torch.clamp(image_level_scores, min=0.0, max=1.0)
        # binary_cross_entropy rejects a target whose dtype differs from the
        # input's (e.g. long labels), so align it first.
        target = target.to(dtype=image_level_scores.dtype)
        return F.binary_cross_entropy(image_level_scores, target, reduction="sum")


if __name__ == '__main__':
    # Smoke test: push one random image and seven random boxes through the
    # network and print the combined loss.
    # The two randn calls and the model construction keep their original
    # order so the random number stream is consumed identically.
    dummy_image = torch.randn(1, 3, 224, 224)
    dummy_boxes = torch.randn(1, 7, 4)
    model = ARFM(128)
    det_scores, cls_scores, whole_pred = model(dummy_image, dummy_boxes)

    image_label = torch.tensor([[1, 0, 0]], dtype=torch.float32)
    whole_criterion = nn.BCELoss()

    # Region branch: detection scores times classification scores, scored
    # against the image-level label.
    region_loss = ARFM.calculate_loss(det_scores * cls_scores, image_label[0])
    # Whole-image branch. BCELoss needs inputs in [0, 1]; presumably the
    # backbone classifier ends in a sigmoid - confirm in Net.
    whole_loss = whole_criterion(whole_pred, image_label)
    print(region_loss + whole_loss)
