import torch
import torch.nn as nn
import torch.nn.functional as F

class FPNEncoder(nn.Module):
    """Four-stage convolutional feature-pyramid encoder.

    Stage 0 works at the input resolution. Stages 1-3 each open with a
    stride-2 convolution, so every stage halves the spatial size of the
    previous one (rounding up for odd sizes), followed by two 3x3 convs.

    Args:
        feat_chs: indexable of at least four channel counts, one per stage,
            finest stage first.
        norm_type: normalization passed to every ``Conv2d`` ('BN' or 'IN').

    Forward returns ``[conv01, conv11, conv21, conv31]``: the last feature of
    each stage, finest first.
    """

    def __init__(self, feat_chs, norm_type='BN'):
        super(FPNEncoder, self).__init__()

        def conv(c_in, c_out, k, stride=1):
            # Padding k // 2 reproduces the explicit paddings for k = 7, 5, 3.
            return Conv2d(c_in, c_out, k, stride, padding=k // 2, norm_type=norm_type)

        c0, c1, c2, c3 = feat_chs[0], feat_chs[1], feat_chs[2], feat_chs[3]

        # Registration order (and thus state_dict / parameters() order) is
        # kept identical to the attribute names used by checkpoints.
        # Stage 0: input resolution, RGB -> c0.
        self.conv00 = conv(3, c0, 7)
        self.conv01 = conv(c0, c0, 5)
        # Stage 1: 5x5 stride-2 downsample.
        self.downsample1 = conv(c0, c1, 5, stride=2)
        self.conv10 = conv(c1, c1, 3)
        self.conv11 = conv(c1, c1, 3)
        # Stage 2: 5x5 stride-2 downsample.
        self.downsample2 = conv(c1, c2, 5, stride=2)
        self.conv20 = conv(c2, c2, 3)
        self.conv21 = conv(c2, c2, 3)
        # Stage 3: 3x3 stride-2 downsample.
        self.downsample3 = conv(c2, c3, 3, stride=2)
        self.conv30 = conv(c3, c3, 3)
        self.conv31 = conv(c3, c3, 3)

    def forward(self, x):
        level0 = self.conv01(self.conv00(x))
        level1 = self.conv11(self.conv10(self.downsample1(level0)))
        level2 = self.conv21(self.conv20(self.downsample2(level1)))
        level3 = self.conv31(self.conv30(self.downsample3(level2)))
        return [level0, level1, level2, level3]
    
class FPNEncoder5(nn.Module):
    """Five-stage convolutional feature-pyramid encoder.

    Stage 0 works at the input resolution. Stages 1-4 each open with a
    stride-2 convolution, so every stage halves the spatial size of the
    previous one (rounding up for odd sizes), followed by two 3x3 convs.

    Args:
        feat_chs: indexable of at least five channel counts, one per stage,
            finest stage first.
        norm_type: normalization passed to every ``Conv2d`` ('BN' or 'IN').

    Forward returns ``[conv01, conv11, conv21, conv31, conv41]``: the last
    feature of each stage, finest first.
    """

    def __init__(self, feat_chs, norm_type='BN'):
        super(FPNEncoder5, self).__init__()

        def conv(c_in, c_out, k, stride=1):
            # Padding k // 2 reproduces the explicit paddings for k = 7, 5, 3.
            return Conv2d(c_in, c_out, k, stride, padding=k // 2, norm_type=norm_type)

        c0, c1, c2, c3, c4 = feat_chs[0], feat_chs[1], feat_chs[2], feat_chs[3], feat_chs[4]

        # Registration order matches the original attribute order so that
        # state_dict keys and parameters() order are unchanged.
        # Stage 0: input resolution, RGB -> c0.
        self.conv00 = conv(3, c0, 7)
        self.conv01 = conv(c0, c0, 5)
        # Stage 1: 5x5 stride-2 downsample.
        self.downsample1 = conv(c0, c1, 5, stride=2)
        self.conv10 = conv(c1, c1, 3)
        self.conv11 = conv(c1, c1, 3)
        # Stage 2: 5x5 stride-2 downsample.
        self.downsample2 = conv(c1, c2, 5, stride=2)
        self.conv20 = conv(c2, c2, 3)
        self.conv21 = conv(c2, c2, 3)
        # Stage 3: 3x3 stride-2 downsample.
        self.downsample3 = conv(c2, c3, 3, stride=2)
        self.conv30 = conv(c3, c3, 3)
        self.conv31 = conv(c3, c3, 3)
        # Stage 4: 3x3 stride-2 downsample.
        self.downsample4 = conv(c3, c4, 3, stride=2)
        self.conv40 = conv(c4, c4, 3)
        self.conv41 = conv(c4, c4, 3)

    def forward(self, x):
        level0 = self.conv01(self.conv00(x))
        level1 = self.conv11(self.conv10(self.downsample1(level0)))
        level2 = self.conv21(self.conv20(self.downsample2(level1)))
        level3 = self.conv31(self.conv30(self.downsample3(level2)))
        level4 = self.conv41(self.conv40(self.downsample4(level3)))
        return [level0, level1, level2, level3, level4]

from romatch.models.mvsformer.RFB_block import RFB_modified
class FPNEncoder_RFB(nn.Module):
    """Four-stage feature-pyramid encoder with an RFB block ending each stage.

    Same conv layout as the plain four-stage encoder: stage 0 at input
    resolution, stages 1-3 opened by a stride-2 conv. The last conv output
    of each stage is additionally passed through ``RFB_modified(c, c)``.
    (Its internals live in ``romatch.models.mvsformer.RFB_block``. Presumably
    it keeps ``c`` channels and the spatial size; TODO confirm.)

    Args:
        feat_chs: indexable of at least four channel counts, finest first.
        norm_type: normalization passed to every ``Conv2d`` ('BN' or 'IN').

    Forward returns the four RFB-refined stage outputs, finest first.
    """

    def __init__(self, feat_chs, norm_type='BN'):
        super(FPNEncoder_RFB, self).__init__()

        def conv(c_in, c_out, k, stride=1):
            # Padding k // 2 reproduces the explicit paddings for k = 7, 5, 3.
            return Conv2d(c_in, c_out, k, stride, padding=k // 2, norm_type=norm_type)

        c0, c1, c2, c3 = feat_chs[0], feat_chs[1], feat_chs[2], feat_chs[3]

        # Registration order matches the original attribute order so that
        # state_dict keys and parameters() order are unchanged.
        # Stage 0: input resolution, RGB -> c0.
        self.conv00 = conv(3, c0, 7)
        self.conv01 = conv(c0, c0, 5)
        self.rfb01 = RFB_modified(c0, c0)
        # Stage 1: 5x5 stride-2 downsample.
        self.downsample1 = conv(c0, c1, 5, stride=2)
        self.conv10 = conv(c1, c1, 3)
        self.conv11 = conv(c1, c1, 3)
        self.rfb11 = RFB_modified(c1, c1)
        # Stage 2: 5x5 stride-2 downsample.
        self.downsample2 = conv(c1, c2, 5, stride=2)
        self.conv20 = conv(c2, c2, 3)
        self.conv21 = conv(c2, c2, 3)
        self.rfb21 = RFB_modified(c2, c2)
        # Stage 3: 3x3 stride-2 downsample.
        self.downsample3 = conv(c2, c3, 3, stride=2)
        self.conv30 = conv(c3, c3, 3)
        self.conv31 = conv(c3, c3, 3)
        self.rfb31 = RFB_modified(c3, c3)

    def forward(self, x):
        level0 = self.rfb01(self.conv01(self.conv00(x)))
        level1 = self.rfb11(self.conv11(self.conv10(self.downsample1(level0))))
        level2 = self.rfb21(self.conv21(self.conv20(self.downsample2(level1))))
        level3 = self.rfb31(self.conv31(self.conv30(self.downsample3(level2))))
        return [level0, level1, level2, level3]
    
class FPNDecoder(nn.Module):
    """Top-down FPN decoder with additive lateral connections.

    The coarsest feature ``conv31`` is repeatedly upsampled (bilinear,
    ``align_corners=True``) to the size of the next finer encoder feature
    and summed with a 1x1 projection of that feature to ``feat_chs[-1]``
    channels. Each merged map then goes through an output head
    (conv + BatchNorm + Swish) that restores the level's channel count.

    Args:
        feat_chs: channel counts of the four encoder levels, finest first.
            ``feat_chs[-1]`` is the shared top-down width, so ``conv31`` is
            expected to have that many channels.

    Forward returns ``[out0, out1, out2, out3]``, coarsest to finest, with
    ``feat_chs[3]``, ``feat_chs[2]``, ``feat_chs[1]`` and ``feat_chs[0]``
    channels respectively.
    """

    def __init__(self, feat_chs):
        super(FPNDecoder, self).__init__()
        final_ch = feat_chs[-1]
        # Coarsest level: 1x1 head directly on the top-down feature.
        self.out0 = nn.Sequential(nn.Conv2d(final_ch, feat_chs[3], kernel_size=1), nn.BatchNorm2d(feat_chs[3]), Swish())

        # inner*: 1x1 lateral projections of each skip feature to final_ch.
        # out*: 3x3 heads mapping the merged map back to the level's width.
        self.inner1 = nn.Conv2d(feat_chs[2], final_ch, 1)
        self.out1 = nn.Sequential(nn.Conv2d(final_ch, feat_chs[2], kernel_size=3, padding=1), nn.BatchNorm2d(feat_chs[2]), Swish())

        self.inner2 = nn.Conv2d(feat_chs[1], final_ch, 1)
        self.out2 = nn.Sequential(nn.Conv2d(final_ch, feat_chs[1], kernel_size=3, padding=1), nn.BatchNorm2d(feat_chs[1]), Swish())

        self.inner3 = nn.Conv2d(feat_chs[0], final_ch, 1)
        self.out3 = nn.Sequential(nn.Conv2d(final_ch, feat_chs[0], kernel_size=3, padding=1), nn.BatchNorm2d(feat_chs[0]), Swish())

    def forward(self, conv01, conv11, conv21, conv31):
        intra_feat = conv31
        out0 = self.out0(intra_feat)

        # Upsample to the skip feature's exact spatial size rather than by a
        # fixed factor of 2. A skip map is not always exactly twice the
        # coarser map (stride-2 convs round odd sizes up), and a fixed x2
        # would then break the addition. For exact doubling the result
        # equals scale_factor=2. The float32 cast is kept from the original
        # (presumably to upsample in full precision under AMP).
        intra_feat = F.interpolate(intra_feat.to(torch.float32), size=conv21.shape[2:],
                                   mode="bilinear", align_corners=True) + self.inner1(conv21)
        out1 = self.out1(intra_feat)

        intra_feat = F.interpolate(intra_feat.to(torch.float32), size=conv11.shape[2:],
                                   mode="bilinear", align_corners=True) + self.inner2(conv11)
        out2 = self.out2(intra_feat)

        intra_feat = F.interpolate(intra_feat.to(torch.float32), size=conv01.shape[2:],
                                   mode="bilinear", align_corners=True) + self.inner3(conv01)
        out3 = self.out3(intra_feat)

        return [out0, out1, out2, out3]
    
class FPNDecoder_my(nn.Module):
    """Top-down FPN decoder that also injects a (ViT) feature at every level.

    Like the additive FPN decoder: the coarsest map is upsampled (bilinear,
    ``align_corners=False``) and summed with a 1x1 projection of the next
    finer skip feature. In addition, a 1x1-projected ``vit_feature`` is
    resized to that level's size and added. The merged map then goes through
    an output head (3x3 conv + BatchNorm + Swish).

    Args:
        feat_chs: channel counts of the four encoder levels, finest first.
            ``feat_chs[-1]`` is the shared top-down width; ``conv31`` and
            ``vit_feature`` are expected to have that many channels.

    Forward returns ``[out0, out1, out2, out3]``, coarsest to finest.
    """

    def __init__(self, feat_chs):
        super(FPNDecoder_my, self).__init__()
        final_ch = feat_chs[-1]
        self.out0 = nn.Sequential(nn.Conv2d(final_ch, feat_chs[3], kernel_size=1), nn.BatchNorm2d(feat_chs[3]), Swish())

        # inner*: skip-feature lateral projections to final_ch.
        # inner*_vit: per-level 1x1 projections of the ViT feature.
        self.inner1 = nn.Conv2d(feat_chs[2], final_ch, 1)
        self.inner1_vit = nn.Conv2d(final_ch, final_ch, 1)
        self.out1 = nn.Sequential(nn.Conv2d(final_ch, feat_chs[2], kernel_size=3, padding=1), nn.BatchNorm2d(feat_chs[2]), Swish())

        self.inner2 = nn.Conv2d(feat_chs[1], final_ch, 1)
        self.inner2_vit = nn.Conv2d(final_ch, final_ch, 1)
        self.out2 = nn.Sequential(nn.Conv2d(final_ch, feat_chs[1], kernel_size=3, padding=1), nn.BatchNorm2d(feat_chs[1]), Swish())

        self.inner3 = nn.Conv2d(feat_chs[0], final_ch, 1)
        self.inner3_vit = nn.Conv2d(final_ch, final_ch, 1)
        self.out3 = nn.Sequential(nn.Conv2d(final_ch, feat_chs[0], kernel_size=3, padding=1), nn.BatchNorm2d(feat_chs[0]), Swish())

    def forward(self, conv01, conv11, conv21, conv31, vit_feature):
        intra_feat = conv31
        out0 = self.out0(intra_feat)

        # Both the top-down map and the ViT feature are resized to the skip
        # feature's exact size. A fixed scale_factor=2 breaks the sum when a
        # stride-2 conv rounded an odd size up; for exact doubling the
        # result is the same.
        intra_feat = (
            F.interpolate(intra_feat.to(torch.float32), size=conv21.shape[2:], mode="bilinear", align_corners=False)
            + self.inner1(conv21)
            + F.interpolate(self.inner1_vit(vit_feature), size=conv21.shape[2:], mode="bilinear", align_corners=False)
        )
        out1 = self.out1(intra_feat)

        intra_feat = (
            F.interpolate(intra_feat.to(torch.float32), size=conv11.shape[2:], mode="bilinear", align_corners=False)
            + self.inner2(conv11)
            + F.interpolate(self.inner2_vit(vit_feature), size=conv11.shape[2:], mode="bilinear", align_corners=False)
        )
        out2 = self.out2(intra_feat)

        intra_feat = (
            F.interpolate(intra_feat.to(torch.float32), size=conv01.shape[2:], mode="bilinear", align_corners=False)
            + self.inner3(conv01)
            + F.interpolate(self.inner3_vit(vit_feature), size=conv01.shape[2:], mode="bilinear", align_corners=False)
        )
        out3 = self.out3(intra_feat)

        return [out0, out1, out2, out3]
    
class FPNDecoder_concat(nn.Module):
    """Top-down FPN decoder fusing levels by concatenation plus a residual add.

    At each finer level, the upsampled top-down map is concatenated with the
    skip feature and fused by a 3x3 conv + BatchNorm + Swish down to the skip
    feature's width. The result is added to the skip feature. A 1x1 head
    (conv + BatchNorm + Swish) produces each level's output.

    Args:
        feat_chs: channel counts of the four encoder levels, finest first
            (indexed from the end: ``feat_chs[-1]`` is the coarsest).

    Forward returns ``[out0, out1, out2, out3]``, coarsest to finest, with
    the same channel counts as ``conv31, conv21, conv11, conv01``.
    """

    def __init__(self, feat_chs):
        super(FPNDecoder_concat, self).__init__()
        final_ch = feat_chs[-1]
        self.out0 = nn.Sequential(nn.Conv2d(final_ch, feat_chs[-1], kernel_size=1), nn.BatchNorm2d(feat_chs[-1]), Swish())

        # inner*: fuse concat(upsampled coarser map, skip) -> skip width.
        self.inner1 = nn.Sequential(nn.Conv2d(feat_chs[-1]+feat_chs[-2], feat_chs[-2], kernel_size=3, padding=1), nn.BatchNorm2d(feat_chs[-2]), Swish())
        self.out1 = nn.Sequential(nn.Conv2d(feat_chs[-2], feat_chs[-2], 1), nn.BatchNorm2d(feat_chs[-2]), Swish())

        self.inner2 = nn.Sequential(nn.Conv2d(feat_chs[-2]+feat_chs[-3], feat_chs[-3], kernel_size=3, padding=1), nn.BatchNorm2d(feat_chs[-3]), Swish())
        self.out2 = nn.Sequential(nn.Conv2d(feat_chs[-3], feat_chs[-3], 1), nn.BatchNorm2d(feat_chs[-3]), Swish())

        self.inner3 = nn.Sequential(nn.Conv2d(feat_chs[-3]+feat_chs[-4], feat_chs[-4], kernel_size=3, padding=1), nn.BatchNorm2d(feat_chs[-4]), Swish())
        self.out3 = nn.Sequential(nn.Conv2d(feat_chs[-4], feat_chs[-4], 1), nn.BatchNorm2d(feat_chs[-4]), Swish())

    def forward(self, conv01, conv11, conv21, conv31):
        intra_feat = conv31
        out0 = self.out0(intra_feat)

        # Upsample to the skip feature's exact size. With a fixed
        # scale_factor=2, torch.cat fails whenever a stride-2 conv rounded
        # an odd size up; for exact doubling the result is unchanged.
        up = F.interpolate(intra_feat.to(torch.float32), size=conv21.shape[2:], mode="bilinear", align_corners=False)
        intra_feat = conv21 + self.inner1(torch.cat((up, conv21), dim=1))
        out1 = self.out1(intra_feat)

        up = F.interpolate(intra_feat.to(torch.float32), size=conv11.shape[2:], mode="bilinear", align_corners=False)
        intra_feat = conv11 + self.inner2(torch.cat((up, conv11), dim=1))
        out2 = self.out2(intra_feat)

        up = F.interpolate(intra_feat.to(torch.float32), size=conv01.shape[2:], mode="bilinear", align_corners=False)
        intra_feat = conv01 + self.inner3(torch.cat((up, conv01), dim=1))
        out3 = self.out3(intra_feat)

        return [out0, out1, out2, out3]

class FPNDecoder5_concat(nn.Module):
    """Five-level top-down FPN decoder fusing by concatenation plus residual add.

    At each finer level, the upsampled top-down map is concatenated with the
    skip feature and fused by a 3x3 conv + BatchNorm + Swish down to the skip
    feature's width. The result is added to the skip feature. A 1x1 head
    (conv + BatchNorm + Swish) produces each level's output.

    Args:
        feat_chs: channel counts of the five encoder levels, finest first
            (indexed from the end: ``feat_chs[-1]`` is the coarsest).

    Forward returns ``[out0, ..., out4]``, coarsest to finest, with the same
    channel counts as ``conv41, conv31, conv21, conv11, conv01``.
    """

    def __init__(self, feat_chs):
        super(FPNDecoder5_concat, self).__init__()
        final_ch = feat_chs[-1]
        self.out0 = nn.Sequential(nn.Conv2d(final_ch, feat_chs[-1], kernel_size=1), nn.BatchNorm2d(feat_chs[-1]), Swish())

        # inner*: fuse concat(upsampled coarser map, skip) -> skip width.
        self.inner1 = nn.Sequential(nn.Conv2d(feat_chs[-1]+feat_chs[-2], feat_chs[-2], kernel_size=3, padding=1), nn.BatchNorm2d(feat_chs[-2]), Swish())
        self.out1 = nn.Sequential(nn.Conv2d(feat_chs[-2], feat_chs[-2], 1), nn.BatchNorm2d(feat_chs[-2]), Swish())

        self.inner2 = nn.Sequential(nn.Conv2d(feat_chs[-2]+feat_chs[-3], feat_chs[-3], kernel_size=3, padding=1), nn.BatchNorm2d(feat_chs[-3]), Swish())
        self.out2 = nn.Sequential(nn.Conv2d(feat_chs[-3], feat_chs[-3], 1), nn.BatchNorm2d(feat_chs[-3]), Swish())

        self.inner3 = nn.Sequential(nn.Conv2d(feat_chs[-3]+feat_chs[-4], feat_chs[-4], kernel_size=3, padding=1), nn.BatchNorm2d(feat_chs[-4]), Swish())
        self.out3 = nn.Sequential(nn.Conv2d(feat_chs[-4], feat_chs[-4], 1), nn.BatchNorm2d(feat_chs[-4]), Swish())

        self.inner4 = nn.Sequential(nn.Conv2d(feat_chs[-4]+feat_chs[-5], feat_chs[-5], kernel_size=3, padding=1), nn.BatchNorm2d(feat_chs[-5]), Swish())
        self.out4 = nn.Sequential(nn.Conv2d(feat_chs[-5], feat_chs[-5], 1), nn.BatchNorm2d(feat_chs[-5]), Swish())

    def forward(self, conv01, conv11, conv21, conv31, conv41):
        intra_feat = conv41
        out0 = self.out0(intra_feat)

        # Upsample to the skip feature's exact size. With a fixed
        # scale_factor=2, torch.cat fails whenever a stride-2 conv rounded
        # an odd size up; for exact doubling the result is unchanged.
        up = F.interpolate(intra_feat.to(torch.float32), size=conv31.shape[2:], mode="bilinear", align_corners=False)
        intra_feat = conv31 + self.inner1(torch.cat((up, conv31), dim=1))
        out1 = self.out1(intra_feat)

        up = F.interpolate(intra_feat.to(torch.float32), size=conv21.shape[2:], mode="bilinear", align_corners=False)
        intra_feat = conv21 + self.inner2(torch.cat((up, conv21), dim=1))
        out2 = self.out2(intra_feat)

        up = F.interpolate(intra_feat.to(torch.float32), size=conv11.shape[2:], mode="bilinear", align_corners=False)
        intra_feat = conv11 + self.inner3(torch.cat((up, conv11), dim=1))
        out3 = self.out3(intra_feat)

        up = F.interpolate(intra_feat.to(torch.float32), size=conv01.shape[2:], mode="bilinear", align_corners=False)
        intra_feat = conv01 + self.inner4(torch.cat((up, conv01), dim=1))
        out4 = self.out4(intra_feat)
        return [out0, out1, out2, out3, out4]
    

class FPNDecoder_concat_RFB(nn.Module):
    """Top-down FPN decoder fusing levels by concatenation and an RFB block.

    At each finer level, the upsampled top-down map is concatenated with the
    skip feature and fused by ``RFB_modified(in, out)`` to the skip feature's
    width. There is no residual add. A 1x1 head (conv + BatchNorm + Swish)
    produces each level's output. ``RFB_modified`` comes from
    ``romatch.models.mvsformer.RFB_block``; it presumably preserves the
    spatial size (TODO confirm).

    Args:
        feat_chs: channel counts of the four encoder levels, finest first
            (indexed from the end: ``feat_chs[-1]`` is the coarsest).

    Forward returns ``[out0, out1, out2, out3]``, coarsest to finest.
    """

    def __init__(self, feat_chs):
        super(FPNDecoder_concat_RFB, self).__init__()
        final_ch = feat_chs[-1]
        self.out0 = nn.Sequential(nn.Conv2d(final_ch, feat_chs[-1], kernel_size=1), nn.BatchNorm2d(feat_chs[-1]), Swish())

        # inner*: RFB fusion of concat(upsampled coarser map, skip) -> skip width.
        self.inner1 = RFB_modified(feat_chs[-1]+feat_chs[-2], feat_chs[-2])
        self.out1 = nn.Sequential(nn.Conv2d(feat_chs[-2], feat_chs[-2], 1), nn.BatchNorm2d(feat_chs[-2]), Swish())

        self.inner2 = RFB_modified(feat_chs[-2]+feat_chs[-3], feat_chs[-3])
        self.out2 = nn.Sequential(nn.Conv2d(feat_chs[-3], feat_chs[-3], 1), nn.BatchNorm2d(feat_chs[-3]), Swish())

        self.inner3 = RFB_modified(feat_chs[-3]+feat_chs[-4], feat_chs[-4])
        self.out3 = nn.Sequential(nn.Conv2d(feat_chs[-4], feat_chs[-4], 1), nn.BatchNorm2d(feat_chs[-4]), Swish())

    def forward(self, conv01, conv11, conv21, conv31):
        intra_feat = conv31
        out0 = self.out0(intra_feat)

        # Upsample to the skip feature's exact size. With a fixed
        # scale_factor=2, torch.cat fails whenever a stride-2 conv rounded
        # an odd size up; for exact doubling the result is unchanged.
        up = F.interpolate(intra_feat.to(torch.float32), size=conv21.shape[2:], mode="bilinear", align_corners=False)
        intra_feat = self.inner1(torch.cat((up, conv21), dim=1))
        out1 = self.out1(intra_feat)

        up = F.interpolate(intra_feat.to(torch.float32), size=conv11.shape[2:], mode="bilinear", align_corners=False)
        intra_feat = self.inner2(torch.cat((up, conv11), dim=1))
        out2 = self.out2(intra_feat)

        up = F.interpolate(intra_feat.to(torch.float32), size=conv01.shape[2:], mode="bilinear", align_corners=False)
        intra_feat = self.inner3(torch.cat((up, conv01), dim=1))
        out3 = self.out3(intra_feat)

        return [out0, out1, out2, out3]
class FPNDecoder_concat_doubleconv_noadd(nn.Module):
    """Top-down FPN decoder fusing by concatenation and two 3x3 convs, no residual.

    At each finer level, the upsampled top-down map is concatenated with the
    skip feature. It then passes through two (3x3 conv + BatchNorm + Swish)
    stages that reduce it to the skip feature's width. A 1x1 head
    (conv + BatchNorm + Swish) produces each level's output.

    Args:
        feat_chs: channel counts of the four encoder levels, finest first
            (indexed from the end: ``feat_chs[-1]`` is the coarsest).

    Forward returns ``[out0, out1, out2, out3]``, coarsest to finest.
    """

    def __init__(self, feat_chs):
        super(FPNDecoder_concat_doubleconv_noadd, self).__init__()
        final_ch = feat_chs[-1]
        self.out0 = nn.Sequential(nn.Conv2d(final_ch, feat_chs[-1], kernel_size=1), nn.BatchNorm2d(feat_chs[-1]), Swish())

        self.inner1 = nn.Sequential(nn.Conv2d(feat_chs[-1]+feat_chs[-2], feat_chs[-2], kernel_size=3, padding=1), nn.BatchNorm2d(feat_chs[-2]), Swish(),
                                    nn.Conv2d(feat_chs[-2], feat_chs[-2], kernel_size=3, padding=1), nn.BatchNorm2d(feat_chs[-2]), Swish(),)
        self.out1 = nn.Sequential(nn.Conv2d(feat_chs[-2], feat_chs[-2], 1), nn.BatchNorm2d(feat_chs[-2]), Swish())

        self.inner2 = nn.Sequential(nn.Conv2d(feat_chs[-2]+feat_chs[-3], feat_chs[-3], kernel_size=3, padding=1), nn.BatchNorm2d(feat_chs[-3]), Swish(),
                                    nn.Conv2d(feat_chs[-3], feat_chs[-3], kernel_size=3, padding=1), nn.BatchNorm2d(feat_chs[-3]), Swish())

        self.out2 = nn.Sequential(nn.Conv2d(feat_chs[-3], feat_chs[-3], 1), nn.BatchNorm2d(feat_chs[-3]), Swish())

        self.inner3 = nn.Sequential(nn.Conv2d(feat_chs[-3]+feat_chs[-4], feat_chs[-4], kernel_size=3, padding=1), nn.BatchNorm2d(feat_chs[-4]), Swish(),
                                    nn.Conv2d(feat_chs[-4], feat_chs[-4], kernel_size=3, padding=1), nn.BatchNorm2d(feat_chs[-4]), Swish())
        self.out3 = nn.Sequential(nn.Conv2d(feat_chs[-4], feat_chs[-4], 1), nn.BatchNorm2d(feat_chs[-4]), Swish())

    def forward(self, conv01, conv11, conv21, conv31):
        intra_feat = conv31
        out0 = self.out0(intra_feat)

        # Upsample to the skip feature's exact size. With a fixed
        # scale_factor=2, torch.cat fails whenever a stride-2 conv rounded
        # an odd size up; for exact doubling the result is unchanged.
        up = F.interpolate(intra_feat.to(torch.float32), size=conv21.shape[2:], mode="bilinear", align_corners=False)
        intra_feat = self.inner1(torch.cat((up, conv21), dim=1))
        out1 = self.out1(intra_feat)

        up = F.interpolate(intra_feat.to(torch.float32), size=conv11.shape[2:], mode="bilinear", align_corners=False)
        intra_feat = self.inner2(torch.cat((up, conv11), dim=1))
        out2 = self.out2(intra_feat)

        up = F.interpolate(intra_feat.to(torch.float32), size=conv01.shape[2:], mode="bilinear", align_corners=False)
        intra_feat = self.inner3(torch.cat((up, conv01), dim=1))
        out3 = self.out3(intra_feat)

        return [out0, out1, out2, out3]

class FPNDecoder_concat2(nn.Module):
    """Top-down FPN decoder: concat -> 1x1 fuse -> residual add -> 3x3 smooth.

    At each finer level, the upsampled top-down map is concatenated with the
    skip feature and fused by a 1x1 conv + BatchNorm + Swish to the skip
    feature's width. The result is added to the skip feature and then
    smoothed by a 3x3 conv + BatchNorm + Swish. A 1x1 head
    (conv + BatchNorm + Swish) produces each level's output.

    Args:
        feat_chs: channel counts of the four encoder levels, finest first
            (indexed from the end: ``feat_chs[-1]`` is the coarsest).

    Forward returns ``[out0, out1, out2, out3]``, coarsest to finest.
    """

    def __init__(self, feat_chs):
        super(FPNDecoder_concat2, self).__init__()
        final_ch = feat_chs[-1]
        self.out0 = nn.Sequential(nn.Conv2d(final_ch, feat_chs[-1], kernel_size=1), nn.BatchNorm2d(feat_chs[-1]), Swish())

        self.inner1 = nn.Sequential(nn.Conv2d(feat_chs[-1]+feat_chs[-2], feat_chs[-2], kernel_size=1), nn.BatchNorm2d(feat_chs[-2]), Swish())
        self.inner1_smooth = nn.Sequential(nn.Conv2d(feat_chs[-2], feat_chs[-2], kernel_size=3, padding=1), nn.BatchNorm2d(feat_chs[-2]), Swish())
        self.out1 = nn.Sequential(nn.Conv2d(feat_chs[-2], feat_chs[-2], 1), nn.BatchNorm2d(feat_chs[-2]), Swish())

        self.inner2 = nn.Sequential(nn.Conv2d(feat_chs[-2]+feat_chs[-3], feat_chs[-3], kernel_size=1), nn.BatchNorm2d(feat_chs[-3]), Swish())
        self.inner2_smooth = nn.Sequential(nn.Conv2d(feat_chs[-3], feat_chs[-3], kernel_size=3, padding=1), nn.BatchNorm2d(feat_chs[-3]), Swish())
        self.out2 = nn.Sequential(nn.Conv2d(feat_chs[-3], feat_chs[-3], 1), nn.BatchNorm2d(feat_chs[-3]), Swish())

        self.inner3 = nn.Sequential(nn.Conv2d(feat_chs[-3]+feat_chs[-4], feat_chs[-4], kernel_size=1), nn.BatchNorm2d(feat_chs[-4]), Swish())
        self.inner3_smooth = nn.Sequential(nn.Conv2d(feat_chs[-4], feat_chs[-4], kernel_size=3, padding=1), nn.BatchNorm2d(feat_chs[-4]), Swish())
        self.out3 = nn.Sequential(nn.Conv2d(feat_chs[-4], feat_chs[-4], 1), nn.BatchNorm2d(feat_chs[-4]), Swish())

    def forward(self, conv01, conv11, conv21, conv31):
        intra_feat = conv31
        out0 = self.out0(intra_feat)

        # Upsample to the skip feature's exact size. With a fixed
        # scale_factor=2, torch.cat fails whenever a stride-2 conv rounded
        # an odd size up; for exact doubling the result is unchanged.
        up = F.interpolate(intra_feat.to(torch.float32), size=conv21.shape[2:], mode="bilinear", align_corners=False)
        intra_feat = self.inner1_smooth(conv21 + self.inner1(torch.cat((up, conv21), dim=1)))
        out1 = self.out1(intra_feat)

        up = F.interpolate(intra_feat.to(torch.float32), size=conv11.shape[2:], mode="bilinear", align_corners=False)
        intra_feat = self.inner2_smooth(conv11 + self.inner2(torch.cat((up, conv11), dim=1)))
        out2 = self.out2(intra_feat)

        up = F.interpolate(intra_feat.to(torch.float32), size=conv01.shape[2:], mode="bilinear", align_corners=False)
        intra_feat = self.inner3_smooth(conv01 + self.inner3(torch.cat((up, conv01), dim=1)))
        out3 = self.out3(intra_feat)

        return [out0, out1, out2, out3]

class FPNDecoder_concat_light(nn.Module):
    """Lightweight top-down FPN decoder: concat -> 1x1 fuse -> residual add.

    At each finer level, the upsampled top-down map is concatenated with the
    skip feature and fused by a 1x1 conv + BatchNorm + Swish to the skip
    feature's width. The result is added to the skip feature. A 1x1 head
    (conv + BatchNorm + Swish) produces each level's output.

    Args:
        feat_chs: channel counts of the four encoder levels, finest first
            (indexed from the end: ``feat_chs[-1]`` is the coarsest).

    Forward returns ``[out0, out1, out2, out3]``, coarsest to finest.
    """

    def __init__(self, feat_chs):
        super(FPNDecoder_concat_light, self).__init__()
        final_ch = feat_chs[-1]
        self.out0 = nn.Sequential(nn.Conv2d(final_ch, feat_chs[-1], kernel_size=1), nn.BatchNorm2d(feat_chs[-1]), Swish())

        self.inner1 = nn.Sequential(nn.Conv2d(feat_chs[-1]+feat_chs[-2], feat_chs[-2], kernel_size=1), nn.BatchNorm2d(feat_chs[-2]), Swish())
        self.out1 = nn.Sequential(nn.Conv2d(feat_chs[-2], feat_chs[-2], 1), nn.BatchNorm2d(feat_chs[-2]), Swish())

        self.inner2 = nn.Sequential(nn.Conv2d(feat_chs[-2]+feat_chs[-3], feat_chs[-3], kernel_size=1), nn.BatchNorm2d(feat_chs[-3]), Swish())
        self.out2 = nn.Sequential(nn.Conv2d(feat_chs[-3], feat_chs[-3], 1), nn.BatchNorm2d(feat_chs[-3]), Swish())

        self.inner3 = nn.Sequential(nn.Conv2d(feat_chs[-3]+feat_chs[-4], feat_chs[-4], kernel_size=1), nn.BatchNorm2d(feat_chs[-4]), Swish())
        self.out3 = nn.Sequential(nn.Conv2d(feat_chs[-4], feat_chs[-4], 1), nn.BatchNorm2d(feat_chs[-4]), Swish())

    def forward(self, conv01, conv11, conv21, conv31):
        intra_feat = conv31
        out0 = self.out0(intra_feat)

        # Upsample to the skip feature's exact size. With a fixed
        # scale_factor=2, torch.cat fails whenever a stride-2 conv rounded
        # an odd size up; for exact doubling the result is unchanged.
        up = F.interpolate(intra_feat.to(torch.float32), size=conv21.shape[2:], mode="bilinear", align_corners=False)
        intra_feat = conv21 + self.inner1(torch.cat((up, conv21), dim=1))
        out1 = self.out1(intra_feat)

        up = F.interpolate(intra_feat.to(torch.float32), size=conv11.shape[2:], mode="bilinear", align_corners=False)
        intra_feat = conv11 + self.inner2(torch.cat((up, conv11), dim=1))
        out2 = self.out2(intra_feat)

        up = F.interpolate(intra_feat.to(torch.float32), size=conv01.shape[2:], mode="bilinear", align_corners=False)
        intra_feat = conv01 + self.inner3(torch.cat((up, conv01), dim=1))
        out3 = self.out3(intra_feat)

        return [out0, out1, out2, out3]
    
class FPNDecoder_concat_light_dw(nn.Module):
    """Top-down FPN decoder whose level fusion uses stacked depthwise blocks.

    At each finer level, the upsampled top-down map is concatenated with the
    skip feature and passed through ``num_dw`` depthwise blocks (see
    ``create_block``). A 1x1 conv then projects it to the skip feature's
    width, and the result is added to the skip feature. A 1x1 head
    (conv + BatchNorm + Swish) produces each level's output.

    Args:
        feat_chs: channel counts of the four encoder levels, finest first
            (indexed from the end: ``feat_chs[-1]`` is the coarsest).
        num_dw: number of depthwise blocks per fusion stage. With 0, the
            empty ``nn.Sequential`` is an identity and only the 1x1
            projection remains.

    Forward returns ``[out0, out1, out2, out3]``, coarsest to finest.
    """

    def __init__(self, feat_chs, num_dw):
        super(FPNDecoder_concat_light_dw, self).__init__()
        final_ch = feat_chs[-1]
        self.out0 = nn.Sequential(nn.Conv2d(final_ch, feat_chs[-1], kernel_size=1), nn.BatchNorm2d(feat_chs[-1]), Swish())

        # inner*: num_dw width-preserving depthwise blocks on the concatenation.
        # inner**: 1x1 projection back to the skip feature's width.
        self.inner1 = nn.Sequential(*[self.create_block(feat_chs[-1]+feat_chs[-2], feat_chs[-1]+feat_chs[-2], dw=True, kernel_size=5) for _ in range(num_dw)])
        self.inner11 = nn.Conv2d(feat_chs[-1]+feat_chs[-2], feat_chs[-2], 1, 1, 0)
        self.out1 = nn.Sequential(nn.Conv2d(feat_chs[-2], feat_chs[-2], 1), nn.BatchNorm2d(feat_chs[-2]), Swish())

        self.inner2 = nn.Sequential(*[self.create_block(feat_chs[-2]+feat_chs[-3], feat_chs[-2]+feat_chs[-3], dw=True, kernel_size=5) for _ in range(num_dw)])
        self.inner22 = nn.Conv2d(feat_chs[-2]+feat_chs[-3], feat_chs[-3], 1, 1, 0)
        self.out2 = nn.Sequential(nn.Conv2d(feat_chs[-3], feat_chs[-3], 1), nn.BatchNorm2d(feat_chs[-3]), Swish())

        self.inner3 = nn.Sequential(*[self.create_block(feat_chs[-3]+feat_chs[-4], feat_chs[-3]+feat_chs[-4], dw=True, kernel_size=5) for _ in range(num_dw)])
        self.inner33 = nn.Conv2d(feat_chs[-3]+feat_chs[-4], feat_chs[-4], 1, 1, 0)
        self.out3 = nn.Sequential(nn.Conv2d(feat_chs[-4], feat_chs[-4], 1), nn.BatchNorm2d(feat_chs[-4]), Swish())

    def create_block(
        self,
        in_dim,
        out_dim,
        dw=False,
        kernel_size=5,
        bias = True,
        norm_type = nn.BatchNorm2d,
    ):
        """Build ``conv(kxk) -> norm -> ReLU -> conv(1x1)`` as an ``nn.Sequential``.

        Args:
            in_dim: input channels.
            out_dim: output channels; must be a multiple of ``in_dim`` when
                ``dw`` is True.
            dw: if True, the kxk conv is grouped with ``groups=in_dim``
                (depthwise).
            kernel_size: size of the first conv; padding ``kernel_size // 2``
                keeps the spatial size for odd kernels.
            bias: whether the first conv has a bias.
            norm_type: ``nn.BatchNorm2d`` (built with momentum 0.1). Any other
                class is called as ``norm_type(num_channels=out_dim)``.
        """
        num_groups = 1 if not dw else in_dim
        if dw:
            assert (
                out_dim % in_dim == 0
            ), "outdim must be divisible by indim for depthwise"
        conv1 = nn.Conv2d(
            in_dim,
            out_dim,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
            groups=num_groups,
            bias=bias,
        )
        norm = norm_type(out_dim, momentum = 0.1) if norm_type is nn.BatchNorm2d else norm_type(num_channels = out_dim)
        relu = nn.ReLU(inplace=True)
        conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
        return nn.Sequential(conv1, norm, relu, conv2)

    def forward(self, conv01, conv11, conv21, conv31):
        intra_feat = conv31
        out0 = self.out0(intra_feat)

        # Upsample to the skip feature's exact size. With a fixed
        # scale_factor=2, torch.cat fails whenever a stride-2 conv rounded
        # an odd size up; for exact doubling the result is unchanged.
        up = F.interpolate(intra_feat.to(torch.float32), size=conv21.shape[2:], mode="bilinear", align_corners=False)
        intra_feat = conv21 + self.inner11(self.inner1(torch.cat((up, conv21), dim=1)))
        out1 = self.out1(intra_feat)

        up = F.interpolate(intra_feat.to(torch.float32), size=conv11.shape[2:], mode="bilinear", align_corners=False)
        intra_feat = conv11 + self.inner22(self.inner2(torch.cat((up, conv11), dim=1)))
        out2 = self.out2(intra_feat)

        up = F.interpolate(intra_feat.to(torch.float32), size=conv01.shape[2:], mode="bilinear", align_corners=False)
        intra_feat = conv01 + self.inner33(self.inner3(torch.cat((up, conv01), dim=1)))
        out3 = self.out3(intra_feat)

        return [out0, out1, out2, out3]
def init_bn(module):
    """Reset a normalization layer's affine parameters to the identity map.

    Sets ``module.weight`` to ones and ``module.bias`` to zeros, in that
    order, in place. Either parameter is skipped when it is None (e.g. a
    norm layer built without affine parameters).
    """
    for attr_name, fill_ in (("weight", nn.init.ones_), ("bias", nn.init.zeros_)):
        param = getattr(module, attr_name)
        if param is not None:
            fill_(param)


def init_uniform(module, init_method):
    """Re-initialize ``module.weight`` in place from a uniform distribution.

    Args:
        module: any object with a ``weight`` attribute (tensor or None).
        init_method: "kaiming" for ``kaiming_uniform_`` or "xavier" for
            ``xavier_uniform_``. Any other value leaves the weight untouched.

    The bias is never modified, and a None weight is skipped.
    """
    weight = module.weight
    if weight is None:
        return
    # Checked in order; the first matching name wins.
    for method_name, initializer in (("kaiming", nn.init.kaiming_uniform_),
                                     ("xavier", nn.init.xavier_uniform_)):
        if init_method == method_name:
            initializer(weight)
            break

class Swish(nn.Module):
    """Swish (a.k.a. SiLU) activation: ``x * sigmoid(x)``, computed out of place."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        gate = torch.sigmoid(x)
        return torch.mul(x, gate)
    
class Conv2d(nn.Module):
    """2D convolution optionally followed by normalization and a leaky ReLU.

    Attributes:
        conv (nn.Module): the ``nn.Conv2d``. It carries a bias only when
            ``bn`` is False, since a following norm layer would absorb it.
        bn (nn.Module or None): ``nn.InstanceNorm2d`` for ``norm_type='IN'``,
            ``nn.BatchNorm2d`` for ``norm_type='BN'``, or None when ``bn`` is
            False.
        relu (bool): whether to apply ``F.leaky_relu`` (slope 0.1, in place).

    Notes:
        Default momentum for the normalization layer is 0.1 (``bn_momentum``).
        Extra ``**kwargs`` (e.g. ``padding``) are forwarded to ``nn.Conv2d``.

    Raises:
        ValueError: if ``bn`` is True and ``norm_type`` is not 'IN' or 'BN'.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 relu=True, bn=True, bn_momentum=0.1, norm_type='IN', **kwargs):
        super(Conv2d, self).__init__()

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, bias=(not bn), **kwargs)
        self.kernel_size = kernel_size
        self.stride = stride
        if not bn:
            # No normalization requested: norm_type is irrelevant.
            self.bn = None
        elif norm_type == 'IN':
            self.bn = nn.InstanceNorm2d(out_channels, momentum=bn_momentum)
        elif norm_type == 'BN':
            self.bn = nn.BatchNorm2d(out_channels, momentum=bn_momentum)
        else:
            # Fail at construction instead of leaving ``self.bn`` unset, which
            # would only surface later as an AttributeError in forward().
            raise ValueError(f"Unsupported norm_type {norm_type!r}; expected 'IN' or 'BN'")
        self.relu = relu

    def forward(self, x):
        y = self.conv(x)
        if self.bn is not None:
            y = self.bn(y)
        if self.relu:
            # In place is safe: y is a fresh tensor produced above.
            y = F.leaky_relu(y, 0.1, inplace=True)
        return y

    def init_weights(self, init_method):
        """Default initialization: uniform conv weights and identity norm affine."""
        init_uniform(self.conv, init_method)
        if self.bn is not None:
            init_bn(self.bn)