import torch
from torch import nn
from models.deeplab.deeplab_ori.backbone.resnet import ResNet101_multiscale,ResNet50_multiscale,Bottleneck
from models.shared import conv_block,up_conv
from models.layers import ASPP,PPM
from torch.nn import BatchNorm2d
import torch.nn.functional as F
from .shared import conv_block

def make_layer(block, inplanes, planes, blocks, stride=1, dilation=1, BatchNorm=nn.BatchNorm2d):
    """Chain ``blocks`` residual units of type ``block`` into an ``nn.Sequential``.

    Args:
        block: Residual unit class. It must expose an ``expansion`` attribute
            and accept ``(inplanes, planes, stride, dilation, downsample,
            BatchNorm)`` positionally, plus ``dilation``/``BatchNorm`` keywords.
        inplanes: Channel count entering the first unit.
        planes: Base width; every unit emits ``planes * block.expansion`` channels.
        blocks: Total number of units in the stack.
        stride: Stride of the first unit only (later units use the block's default).
        dilation: Dilation passed to every unit.
        BatchNorm: Normalisation layer constructor.

    Returns:
        ``nn.Sequential`` holding the units in order.
    """
    expanded = planes * block.expansion

    # A 1x1 projection shortcut is built only when the first unit changes the
    # stride or the channel count; otherwise no downsample module is passed.
    if stride == 1 and inplanes == expanded:
        shortcut = None
    else:
        shortcut = nn.Sequential(
            nn.Conv2d(inplanes, expanded, kernel_size=1, stride=stride, bias=False),
            BatchNorm(expanded),
        )

    units = [block(inplanes, planes, stride, dilation, shortcut, BatchNorm)]
    # Remaining units already operate at the expanded width.
    units.extend(
        block(expanded, planes, dilation=dilation, BatchNorm=BatchNorm)
        for _ in range(1, blocks)
    )
    return nn.Sequential(*units)


class Decoder_Res(nn.Module):
    """Fusion decoder with a residual-bottleneck refinement stage.

    The low-level feature map is projected to ``low_dim`` channels, the
    high-level map is bilinearly resized to the low-level spatial size, and the
    two are concatenated and fused to ``out_dim`` channels.  The fused map then
    passes through two ``Bottleneck`` units and a classification head.

    Args:
        num_classes: Number of channels of the prediction output.
        high_level_inplanes: Channel count of ``x_high_level``.
        low_level_inplanes: Channel count of ``low_level_feat``.
        low_dim: Channels the low-level features are projected to.
        out_dim: Channels of the fused intermediate feature map.
        BatchNorm: Normalisation layer constructor used for every norm layer
            in this decoder, including the bottleneck stack.
    """

    def __init__(self, num_classes, high_level_inplanes, low_level_inplanes, low_dim=48, out_dim=256, BatchNorm=nn.BatchNorm2d):
        super().__init__()

        self.conv1 = nn.Conv2d(low_level_inplanes, low_dim, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = BatchNorm(low_dim)
        self.relu = nn.ReLU()

        self.last_conv = nn.Sequential(
            nn.Conv2d(high_level_inplanes + low_dim, out_dim, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm(out_dim),
            nn.ReLU(),
        )

        # Forward the caller's norm layer instead of hard-coding
        # nn.BatchNorm2d, so a custom BatchNorm reaches the bottlenecks too.
        self.last_bottle = make_layer(Bottleneck, inplanes=out_dim, planes=out_dim, blocks=2,
                                      stride=1, dilation=1, BatchNorm=BatchNorm)
        # The head's first conv expects Bottleneck.expansion * out_dim channels.
        self.last_conv2 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Conv2d(Bottleneck.expansion * out_dim, out_dim, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm(out_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(out_dim, num_classes, kernel_size=1, stride=1),
        )
        self._init_weight()

    def forward(self, x_high_level, low_level_feat):
        """Fuse the two feature maps and predict.

        Returns:
            ``(mid, x)``: ``mid`` is the fused ``out_dim``-channel map produced
            by ``last_conv`` (before the bottlenecks); ``x`` is the
            ``num_classes``-channel prediction.  Both have the spatial size of
            ``low_level_feat``.
        """
        low_level_feat = self.conv1(low_level_feat)
        low_level_feat = self.bn1(low_level_feat)
        low_level_feat = self.relu(low_level_feat)

        x_high_level = F.interpolate(x_high_level, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat((x_high_level, low_level_feat), dim=1)
        mid = self.last_conv(x)
        x = self.last_bottle(mid)
        x = self.last_conv2(x)
        return mid, x

    def _init_weight(self):
        """Kaiming-normal init for every conv; unit scale / zero shift for ``nn.BatchNorm2d``.

        Norm layers that are not ``nn.BatchNorm2d`` instances are left untouched.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()



class DeepLabv3p_boost_decoderRes(nn.Module):
    """DeepLabv3+-style network with three cascaded prediction heads.

    A ``ResNet50_multiscale`` backbone feeds an ``ASPP`` module.  The ASPP
    output is refined by ``Pred5`` (a ``Decoder``), whose intermediate features
    feed ``Pred4`` and then ``Pred3`` (both ``Decoder_Res``), each fused with a
    progressively shallower backbone stage.

    Args:
        in_ch: Number of input channels handed to the backbone.  Pretrained
            backbone weights are requested only when this equals 3.
        num_classes: Number of output channels of every prediction head.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, in_ch, num_classes, **kwargs):
        super().__init__()
        self.resnet = ResNet50_multiscale(in_ch, output_stride=16, pretrained=(in_ch == 3))

        # Non-trainable 3-vector initialised to [0.5, 0.5, 0.0].  It is not read
        # anywhere in this class -- presumably consumed by external code;
        # confirm before removing.
        self.ts = nn.Parameter(torch.zeros(3), requires_grad=False)
        self.ts[:2] = 0.5

        # Per-stage channel counts reported by the backbone (the original
        # author's note lists them as [3, 64, 256, 512, 2048]).
        filters = self.resnet.out_s
        self.aspp = ASPP(filters[-1], 256)

        self.Pred5 = Decoder(num_classes, high_level_inplanes=256, low_dim=256, out_dim=256, low_level_inplanes=filters[4])
        self.Pred4 = Decoder_Res(num_classes, high_level_inplanes=256, low_dim=512, out_dim=64, low_level_inplanes=filters[3])
        self.Pred3 = Decoder_Res(num_classes, high_level_inplanes=64, low_dim=256, out_dim=32, low_level_inplanes=filters[2])

    def forward(self, x):
        """Return ``(pred5, pred4, pred3)``, one prediction per decoder stage."""
        # Only the three deepest backbone outputs are consumed.
        _, _, e3, e4, e5 = self.resnet(x)

        aspp_e5 = self.aspp(e5)
        mid_e5, pred5 = self.Pred5(aspp_e5, e5)
        mid_e4, pred4 = self.Pred4(mid_e5, e4)
        _, pred3 = self.Pred3(mid_e4, e3)

        return pred5, pred4, pred3


class Decoder(nn.Module):
    """Basic fusion decoder: project low-level features, fuse with upsampled high-level ones, predict.

    Args:
        num_classes: Number of channels of the prediction output.
        high_level_inplanes: Channel count of ``x_high_level``.
        low_level_inplanes: Channel count of ``low_level_feat``.
        low_dim: Channels the low-level features are projected to.
        out_dim: Channels of the fused intermediate feature map.
        rm_BN: When true, the head is conv -> PReLU -> conv with no dropout
            or normalisation; otherwise it is
            dropout -> conv -> norm -> ReLU -> dropout -> conv.
        BatchNorm: Normalisation layer constructor.
    """

    def __init__(self, num_classes, high_level_inplanes, low_level_inplanes,low_dim=48,out_dim=256,rm_BN=False, BatchNorm = nn.BatchNorm2d):
        super(Decoder, self).__init__()

        # Low-level projection (3x3 conv, norm, ReLU).
        self.conv1 = nn.Conv2d(low_level_inplanes, low_dim, 3, 1, 1, bias=False)
        self.bn1 = BatchNorm(low_dim)
        self.relu = nn.ReLU()

        # Fusion of the concatenated high-/low-level maps.
        fused_in = high_level_inplanes + low_dim
        self.last_conv = nn.Sequential(
            nn.Conv2d(fused_in, out_dim, 3, 1, 1, bias=False),
            BatchNorm(out_dim),
            nn.ReLU(),
        )

        # Prediction head; layer positions inside the Sequential are kept so
        # the parameter names stay the same in both variants.
        refine = nn.Conv2d(out_dim, out_dim, 3, 1, 1, bias=False)
        if rm_BN:
            head = [refine, nn.PReLU()]
        else:
            head = [nn.Dropout(0.5), refine, BatchNorm(out_dim), nn.ReLU(), nn.Dropout(0.1)]
        head.append(nn.Conv2d(out_dim, num_classes, 1, 1))
        self.last_conv2 = nn.Sequential(*head)

        self._init_weight()

    def forward(self, x_high_level, low_level_feat):
        """Return ``(mid, x)``: the fused ``out_dim``-channel map and the prediction.

        Both outputs have the spatial size of ``low_level_feat``.
        """
        low = self.relu(self.bn1(self.conv1(low_level_feat)))
        # Bring the high-level map to the low-level resolution before concat.
        high = F.interpolate(x_high_level, size=low.size()[2:], mode='bilinear', align_corners=True)
        mid = self.last_conv(torch.cat((high, low), dim=1))
        return mid, self.last_conv2(mid)

    def _init_weight(self):
        """Kaiming-normal init for convs; weight 1 / bias 0 for ``nn.BatchNorm2d`` layers."""
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.kaiming_normal_(layer.weight)
                continue
            if isinstance(layer, nn.BatchNorm2d):
                layer.weight.data.fill_(1)
                layer.bias.data.zero_()

class DeepLabv3p_boostno(nn.Module):
    """Single-head DeepLabv3+-style network (no cascaded refinement).

    The ASPP output of the deepest backbone stage is fused directly with the
    third backbone output by one decoder.

    Args:
        in_ch: Number of input channels handed to the backbone.  Pretrained
            backbone weights are requested only when this equals 3.
        num_classes: Number of output channels of the prediction head.
        _decoder: Decoder class; called as
            ``_decoder(num_classes, high_level_inplanes=256, low_level_inplanes=...)``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, in_ch, num_classes, _decoder=Decoder, **kwargs):
        super().__init__()
        self.resnet = ResNet50_multiscale(in_ch, output_stride=16, pretrained=(in_ch == 3))

        # Non-trainable 3-vector initialised to [0.5, 0.5, 0.0]; not read in
        # this class -- presumably consumed by external code, confirm.
        self.ts = nn.Parameter(torch.zeros(3), requires_grad=False)
        self.ts[:2] = 0.5

        # Per-stage channel counts reported by the backbone (the original
        # author's note lists them as [3, 64, 256, 512, 2048]).
        filters = self.resnet.out_s
        self.aspp = ASPP(filters[-1], 256)

        self.Pred3 = _decoder(num_classes, high_level_inplanes=256, low_level_inplanes=filters[2])

    def forward(self, x):
        """Return the single prediction repeated three times.

        The 3-tuple mirrors the ``(pred5, pred4, pred3)`` output of the
        cascaded models in this file.
        """
        _, _, e3, _, e5 = self.resnet(x)

        aspp_e5 = self.aspp(e5)
        _, pred3 = self.Pred3(aspp_e5, e3)

        return pred3, pred3, pred3



class SeparableConv2D(nn.Module):
    """Depthwise-separable convolution (the original note says it is meant for ASPP).

    Pipeline: depthwise conv -> optional norm -> ReLU -> 1x1 pointwise conv.

    Args:
        in_channels: Input channels; also the depthwise group count.
        out_channels: Channels produced by the pointwise conv.
        kernel_size, stride, padding, dilation: Depthwise conv geometry.
        bias: Whether both convolutions carry a bias term.
        use_batchnorm: Insert ``norm_fn(in_channels)`` after the depthwise conv.
        norm_fn: Normalisation layer constructor.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, bias=True,
                 use_batchnorm=True,
                 norm_fn=nn.BatchNorm2d):
        super(SeparableConv2D, self).__init__()
        self.use_bn = use_batchnorm
        self.dilation = dilation

        self.depthwise = nn.Conv2d(
            in_channels, in_channels,
            kernel_size=kernel_size, stride=stride, padding=padding,
            dilation=dilation, groups=in_channels, bias=bias,
        )
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=bias)
        if self.use_bn:
            # The norm sits between the two convs, hence in_channels features.
            self.bn = norm_fn(in_channels)

        # Fixed-std normal initialisation of the two conv kernels.
        for conv, std in ((self.depthwise, 0.33), (self.pointwise, 0.06)):
            nn.init.normal_(conv.weight, std=std)

    def forward(self, x):
        """Apply depthwise conv, optional norm, in-place ReLU, then the pointwise conv."""
        out = self.depthwise(x)
        if self.use_bn:
            out = self.bn(out)
        out = F.relu(out, inplace=True)
        return self.pointwise(out)

    
    
class Decoder_FarSeg(nn.Module):
    """Fusion decoder built from 1x1 projections and separable convolutions.

    Both inputs are projected to ``low_dim`` channels with 1x1 conv/norm/ReLU,
    the high-level map is bilinearly resized to the low-level size, and the
    concatenation is processed by ``SeparableConv2D`` layers.

    Args:
        num_classes: Number of channels of the prediction output.
        high_level_inplanes: Channel count of ``x_high_level``.
        low_level_inplanes: Channel count of ``low_level_feat``.
        low_dim: Channels each input is projected to before concatenation.
        out_dim: Channels of the fused intermediate feature map.
        rm_BN: Unused in this class.
        BatchNorm: Normalisation layer constructor.
    """

    def __init__(self, num_classes, high_level_inplanes, low_level_inplanes,low_dim=48,out_dim=256,rm_BN=False, BatchNorm = nn.BatchNorm2d):
        super().__init__()

        def project(in_planes):
            # 1x1 conv -> norm -> ReLU down to low_dim channels.
            return nn.Sequential(
                nn.Conv2d(in_planes, low_dim, kernel_size=1, stride=1, padding=0, bias=False),
                BatchNorm(low_dim),
                nn.ReLU(),
            )

        self.conv1x1_low_level = project(low_level_inplanes)
        self.conv1x1_high_level = project(high_level_inplanes)

        # Runs before the layers below are created, so it only touches the two
        # projections; the separable convs keep their own initialisation and
        # the final 1x1 conv keeps the framework default.
        self._init_weight()

        self.last_conv1 = SeparableConv2D(
            low_dim * 2, out_dim, kernel_size=3, stride=1, padding=1, dilation=1,
            bias=True, use_batchnorm=True, norm_fn=BatchNorm,
        )
        self.last_conv2 = nn.Sequential(
            SeparableConv2D(
                out_dim, out_dim, kernel_size=3, stride=1, padding=1, dilation=1,
                bias=True, use_batchnorm=True, norm_fn=BatchNorm,
            ),
            nn.Conv2d(out_dim, num_classes, 1),
        )

    def forward(self, x_high_level, low_level_feat):
        """Return ``(mid, x)``: the fused ``out_dim``-channel map and the prediction."""
        low = self.conv1x1_low_level(low_level_feat)
        high = self.conv1x1_high_level(x_high_level)
        high = F.interpolate(high, size=low.size()[2:], mode='bilinear', align_corners=True)

        mid = self.last_conv1(torch.cat((high, low), dim=1))
        return mid, self.last_conv2(mid)

    def _init_weight(self):
        """Kaiming-normal init for convs; weight 1 / bias 0 for ``nn.BatchNorm2d`` layers."""
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.kaiming_normal_(layer.weight)
                continue
            if isinstance(layer, nn.BatchNorm2d):
                layer.weight.data.fill_(1)
                layer.bias.data.zero_()

class DeepLabv3p_boostno_FarSeg(DeepLabv3p_boostno):
    """``DeepLabv3p_boostno`` whose decoder defaults to ``Decoder_FarSeg``.

    Args:
        in_ch: Number of input channels handed to the backbone.
        num_classes: Number of output channels of the prediction head.
        _decoder: Decoder class to instantiate.
        **kwargs: Forwarded to the parent constructor, so this class accepts
            the same extra keyword arguments as the other models in this file.
    """

    def __init__(self, in_ch, num_classes, _decoder=Decoder_FarSeg, **kwargs):
        super().__init__(in_ch, num_classes, _decoder, **kwargs)


class Decoder_att(nn.Module):
    """Decoder that blends three candidate feature maps with per-pixel attention.

    Three candidates are derived from the projected inputs: the low-level map,
    the low-level map minus its 3x3 average-pooled version, and the low-level
    map minus the upsampled high-level map.  A query computed from the
    concatenated inputs scores each candidate (dot product over channels), the
    scores are softmax-normalised across the three candidates, and the value
    candidates are combined with those weights.

    Args:
        num_classes: Number of channels of the prediction output.
        high_level_inplanes: Channel count of ``x_high_level``.
        low_level_inplanes: Channel count of ``low_level_feat``.
        low_dim: Unused in this class.
        out_dim: Channels of query/key/value maps and of the blended output.
        rm_BN: When true, the head is conv -> ReLU -> conv; otherwise
            dropout -> conv -> norm -> ReLU -> dropout -> conv.
        BatchNorm: Normalisation layer constructor (used in the head only).
    """

    def __init__(self, num_classes, high_level_inplanes, low_level_inplanes, low_dim=48, out_dim=256, rm_BN=False, BatchNorm=nn.BatchNorm2d):
        super().__init__()

        self.query = nn.Sequential(
            nn.Conv2d(high_level_inplanes + low_level_inplanes, out_dim, kernel_size=3, stride=1, padding=1, bias=False),
        )

        # Key projections.
        self.conv_low = nn.Sequential(
            nn.Conv2d(low_level_inplanes, out_dim, kernel_size=1, stride=1, padding=0, bias=False),
        )
        self.conv_high = nn.Sequential(
            nn.Conv2d(high_level_inplanes, out_dim, kernel_size=1, stride=1, padding=0, bias=False),
        )

        # Value projections.
        self.conv_low_v = nn.Sequential(
            nn.Conv2d(low_level_inplanes, out_dim, kernel_size=1, stride=1, padding=0, bias=False),
        )
        self.conv_high_v = nn.Sequential(
            nn.Conv2d(high_level_inplanes, out_dim, kernel_size=1, stride=1, padding=0, bias=False),
        )

        # Learned per-candidate, per-channel offset added to the keys.
        self.pe = nn.Parameter(torch.randn(1, 3, out_dim, 1, 1))

        if rm_BN:
            self.last_conv2 = nn.Sequential(
                nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=1, padding=1, bias=False),
                nn.ReLU(),
                nn.Conv2d(out_dim, num_classes, kernel_size=1, stride=1),
            )
        else:
            self.last_conv2 = nn.Sequential(
                nn.Dropout(0.5),
                nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=1, padding=1, bias=False),
                BatchNorm(out_dim),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Conv2d(out_dim, num_classes, kernel_size=1, stride=1),
            )
        self._init_weight()

    @staticmethod
    def _candidates(low, high):
        """Stack [low, low - avgpool3x3(low), low - high] along a new dim 1."""
        smoothed = F.avg_pool2d(low, kernel_size=3, stride=1, padding=1)
        return torch.stack([low, low - smoothed, low - high], dim=1)

    def forward(self, x_high_level, low_level_feat):
        """Return ``(mid, x)``: the attention-blended ``out_dim`` map and the prediction.

        Both outputs have the spatial size of ``low_level_feat``.
        """
        size = low_level_feat.size()[2:]

        def upsample(t):
            return F.interpolate(t, size=size, mode='bilinear', align_corners=True)

        query = self.query(torch.cat([low_level_feat, upsample(x_high_level)], dim=1))

        key = self._candidates(self.conv_low(low_level_feat), upsample(self.conv_high(x_high_level)))
        key = key + self.pe

        val = self._candidates(self.conv_low_v(low_level_feat), upsample(self.conv_high_v(x_high_level)))

        # Channel-wise dot product of the query with each candidate key,
        # normalised across the three candidates at every pixel.
        att = torch.einsum('bchw,bachw->bahw', query, key)
        att = F.softmax(att, dim=1)

        # Attention-weighted sum of the candidate values.
        mid = torch.einsum('bahw,bachw->bchw', att, val)
        x = self.last_conv2(mid)
        return mid, x

    def _init_weight(self):
        """Kaiming-normal init for convs; weight 1 / bias 0 for ``nn.BatchNorm2d`` layers."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


class Decoder_sigmoid1(nn.Module):
    """Decoder that gates low-level features by their similarity to high-level ones.

    A per-pixel gate ``sigmoid(sum_c(low_k * high_k))`` is computed from one
    pair of projections; a second pair of projections supplies the features.
    The concatenation ``[low_v, gate * low_v, high_v]`` is fused to ``out_dim``
    channels and passed to the prediction head.

    Args:
        num_classes: Number of channels of the prediction output.
        high_level_inplanes: Channel count of ``x_high_level``.
        low_level_inplanes: Channel count of ``low_level_feat``.
        low_dim: Unused in this class.
        out_dim: Channels of every projection and of the fused output.
        rm_BN: When true, the head is conv -> ReLU -> conv; otherwise
            dropout -> conv -> norm -> ReLU -> dropout -> conv.
        BatchNorm: Normalisation layer constructor.
    """

    def __init__(self, num_classes, high_level_inplanes, low_level_inplanes, low_dim=48, out_dim=256, rm_BN=False, BatchNorm=nn.BatchNorm2d):
        super().__init__()

        def projection(in_planes):
            # 1x1 conv -> norm -> ReLU -> 1x1 conv, all at out_dim channels.
            return nn.Sequential(
                nn.Conv2d(in_planes, out_dim, kernel_size=1, stride=1, padding=0, bias=False),
                BatchNorm(out_dim),
                nn.ReLU(),
                nn.Conv2d(out_dim, out_dim, kernel_size=1, stride=1),
            )

        # `query` (and `pe` below) are not used by forward(); they are kept so
        # the parameter set, and therefore saved checkpoints, stay compatible.
        self.query = nn.Sequential(
            nn.Conv2d(high_level_inplanes + low_level_inplanes, out_dim, kernel_size=3, stride=1, padding=1, bias=False),
        )

        # Projections used for the similarity gate.
        self.conv_low = projection(low_level_inplanes)
        self.conv_high = projection(high_level_inplanes)
        # Projections used for the features that get concatenated.
        self.conv_low_v = projection(low_level_inplanes)
        self.conv_high_v = projection(high_level_inplanes)

        self.pe = nn.Parameter(torch.randn(1, 3, out_dim, 1, 1))
        self.last_conv = nn.Sequential(
            nn.Conv2d(3 * out_dim, out_dim, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm(out_dim),
            nn.ReLU(),
            nn.Conv2d(out_dim, out_dim, kernel_size=1, stride=1),
        )
        if rm_BN:
            self.last_conv2 = nn.Sequential(
                nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=1, padding=1, bias=False),
                nn.ReLU(),
                nn.Conv2d(out_dim, num_classes, kernel_size=1, stride=1),
            )
        else:
            self.last_conv2 = nn.Sequential(
                nn.Dropout(0.5),
                nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=1, padding=1, bias=False),
                BatchNorm(out_dim),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Conv2d(out_dim, num_classes, kernel_size=1, stride=1),
            )
        self._init_weight()

    def forward(self, x_high_level, low_level_feat):
        """Return ``(mid, x)``: the fused ``out_dim``-channel map and the prediction.

        Both outputs have the spatial size of ``low_level_feat``.  The
        ``query`` branch is intentionally not evaluated: its result never
        contributed to either output.
        """
        size = low_level_feat.size()[2:]

        def upsample(t):
            return F.interpolate(t, size=size, mode='bilinear', align_corners=True)

        # Per-pixel similarity gate in (0, 1), shape (B, 1, H, W).
        low_k = self.conv_low(low_level_feat)
        high_k = upsample(self.conv_high(x_high_level))
        gate = torch.sigmoid(torch.sum(low_k * high_k, dim=1, keepdim=True))

        low_v = self.conv_low_v(low_level_feat)
        high_v = upsample(self.conv_high_v(x_high_level))

        mid = self.last_conv(torch.cat([low_v, gate * low_v, high_v], dim=1))
        x = self.last_conv2(mid)
        return mid, x

    def _init_weight(self):
        """Kaiming-normal init for convs; weight 1 / bias 0 for ``nn.BatchNorm2d`` layers."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


class DeepLabv3p_boost_similar_sigmoid1_cat(nn.Module):
    """Cascaded three-head model using ``Decoder_sigmoid1`` for the two finer stages.

    ``Pred5`` (a plain ``Decoder``) fuses the ASPP output with the deepest
    backbone stage; ``Pred4`` and ``Pred3`` (``Decoder_sigmoid1``) each refine
    the previous stage's intermediate features with a shallower backbone stage.

    Args:
        in_ch: Number of input channels handed to the backbone.  Pretrained
            backbone weights are requested only when this equals 3.
        num_classes: Number of output channels of every prediction head.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, in_ch, num_classes, **kwargs):
        super().__init__()
        self.resnet = ResNet50_multiscale(in_ch, output_stride=16, pretrained=(in_ch == 3))

        # Non-trainable 3-vector initialised to [0.5, 0.5, 0.0]; not read in
        # this class -- presumably consumed by external code, confirm.
        self.ts = nn.Parameter(torch.zeros(3), requires_grad=False)
        self.ts[:2] = 0.5

        # Per-stage channel counts reported by the backbone (the original
        # author's note lists them as [3, 64, 256, 512, 2048]).
        filters = self.resnet.out_s
        self.aspp = ASPP(filters[-1], 256)

        self.Pred5 = Decoder(num_classes, high_level_inplanes=256, low_dim=256, out_dim=256, low_level_inplanes=filters[4])
        self.Pred4 = Decoder_sigmoid1(num_classes, high_level_inplanes=256, low_dim=512, out_dim=128, low_level_inplanes=filters[3])
        self.Pred3 = Decoder_sigmoid1(num_classes, high_level_inplanes=128, low_dim=256, out_dim=128, low_level_inplanes=filters[2])

    def forward(self, x):
        """Return ``(pred5, pred4, pred3)``, one prediction per decoder stage."""
        # Only the three deepest backbone outputs are consumed.
        _, _, e3, e4, e5 = self.resnet(x)

        aspp_e5 = self.aspp(e5)
        mid_e5, pred5 = self.Pred5(aspp_e5, e5)
        mid_e4, pred4 = self.Pred4(mid_e5, e4)
        _, pred3 = self.Pred3(mid_e4, e3)

        return pred5, pred4, pred3




class Decoder_sigmoid2(nn.Module):
    """Similarity-gated decoder with a dedicated projection for the gated branch.

    A per-pixel gate ``sigmoid(sum_c(low_k * high_k))`` multiplies a separate
    low-level projection (``_conv_low``).  The concatenation
    ``[low_v, gate * _conv_low(low), high_v]`` is fused to ``out_dim`` channels
    and passed to the prediction head.

    Args:
        num_classes: Number of channels of the prediction output.
        high_level_inplanes: Channel count of ``x_high_level``.
        low_level_inplanes: Channel count of ``low_level_feat``.
        low_dim: Unused in this class.
        out_dim: Channels of every projection and of the fused output.
        rm_BN: When true, the head is conv -> ReLU -> conv; otherwise
            dropout -> conv -> norm -> ReLU -> dropout -> conv.
        BatchNorm: Normalisation layer constructor.
    """

    def __init__(self, num_classes, high_level_inplanes, low_level_inplanes, low_dim=48, out_dim=256, rm_BN=False, BatchNorm=nn.BatchNorm2d):
        super().__init__()

        def projection(in_planes):
            # 1x1 conv -> norm -> ReLU -> 1x1 conv, all at out_dim channels.
            return nn.Sequential(
                nn.Conv2d(in_planes, out_dim, kernel_size=1, stride=1, padding=0, bias=False),
                BatchNorm(out_dim),
                nn.ReLU(),
                nn.Conv2d(out_dim, out_dim, kernel_size=1, stride=1),
            )

        # `query` (and `pe` below) are not used by forward(); they are kept so
        # the parameter set, and therefore saved checkpoints, stay compatible.
        self.query = nn.Sequential(
            nn.Conv2d(high_level_inplanes + low_level_inplanes, out_dim, kernel_size=3, stride=1, padding=1, bias=False),
        )

        # Projections used for the similarity gate.
        self.conv_low = projection(low_level_inplanes)
        # Low-level projection that the gate is applied to.
        self._conv_low = projection(low_level_inplanes)
        self.conv_high = projection(high_level_inplanes)
        # Projections used for the ungated concatenated features.
        self.conv_low_v = projection(low_level_inplanes)
        self.conv_high_v = projection(high_level_inplanes)

        self.pe = nn.Parameter(torch.randn(1, 3, out_dim, 1, 1))
        self.last_conv = nn.Sequential(
            nn.Conv2d(3 * out_dim, out_dim, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm(out_dim),
            nn.ReLU(),
            nn.Conv2d(out_dim, out_dim, kernel_size=1, stride=1),
        )
        if rm_BN:
            self.last_conv2 = nn.Sequential(
                nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=1, padding=1, bias=False),
                nn.ReLU(),
                nn.Conv2d(out_dim, num_classes, kernel_size=1, stride=1),
            )
        else:
            self.last_conv2 = nn.Sequential(
                nn.Dropout(0.5),
                nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=1, padding=1, bias=False),
                BatchNorm(out_dim),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Conv2d(out_dim, num_classes, kernel_size=1, stride=1),
            )
        self._init_weight()

    def forward(self, x_high_level, low_level_feat):
        """Return ``(mid, x)``: the fused ``out_dim``-channel map and the prediction.

        Both outputs have the spatial size of ``low_level_feat``.  The
        ``query`` branch is intentionally not evaluated: its result never
        contributed to either output.
        """
        size = low_level_feat.size()[2:]

        def upsample(t):
            return F.interpolate(t, size=size, mode='bilinear', align_corners=True)

        # Per-pixel similarity gate in (0, 1), shape (B, 1, H, W).
        low_k = self.conv_low(low_level_feat)
        high_k = upsample(self.conv_high(x_high_level))
        gate = torch.sigmoid(torch.sum(low_k * high_k, dim=1, keepdim=True))
        gated = gate * self._conv_low(low_level_feat)

        low_v = self.conv_low_v(low_level_feat)
        high_v = upsample(self.conv_high_v(x_high_level))

        mid = self.last_conv(torch.cat([low_v, gated, high_v], dim=1))
        x = self.last_conv2(mid)
        return mid, x

    def _init_weight(self):
        """Kaiming-normal init for convs; weight 1 / bias 0 for ``nn.BatchNorm2d`` layers."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


class DeepLabv3p_boost_similar_sigmoid2_cat(nn.Module):
    """Cascaded three-head model using ``Decoder_sigmoid2`` for the two finer stages.

    ``Pred5`` (a plain ``Decoder``) fuses the ASPP output with the deepest
    backbone stage; ``Pred4`` and ``Pred3`` (``Decoder_sigmoid2``) each refine
    the previous stage's intermediate features with a shallower backbone stage.

    Args:
        in_ch: Number of input channels handed to the backbone.  Pretrained
            backbone weights are requested only when this equals 3.
        num_classes: Number of output channels of every prediction head.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, in_ch, num_classes, **kwargs):
        super().__init__()
        self.resnet = ResNet50_multiscale(in_ch, output_stride=16, pretrained=(in_ch == 3))

        # Non-trainable 3-vector initialised to [0.5, 0.5, 0.0]; not read in
        # this class -- presumably consumed by external code, confirm.
        self.ts = nn.Parameter(torch.zeros(3), requires_grad=False)
        self.ts[:2] = 0.5

        # Per-stage channel counts reported by the backbone (the original
        # author's note lists them as [3, 64, 256, 512, 2048]).
        filters = self.resnet.out_s
        self.aspp = ASPP(filters[-1], 256)

        self.Pred5 = Decoder(num_classes, high_level_inplanes=256, low_dim=256, out_dim=256, low_level_inplanes=filters[4])
        self.Pred4 = Decoder_sigmoid2(num_classes, high_level_inplanes=256, low_dim=512, out_dim=128, low_level_inplanes=filters[3])
        self.Pred3 = Decoder_sigmoid2(num_classes, high_level_inplanes=128, low_dim=256, out_dim=128, low_level_inplanes=filters[2])

    def forward(self, x):
        """Return ``(pred5, pred4, pred3)``, one prediction per decoder stage."""
        # Only the three deepest backbone outputs are consumed.
        _, _, e3, e4, e5 = self.resnet(x)

        aspp_e5 = self.aspp(e5)
        mid_e5, pred5 = self.Pred5(aspp_e5, e5)
        mid_e4, pred4 = self.Pred4(mid_e5, e4)
        _, pred3 = self.Pred3(mid_e4, e3)

        return pred5, pred4, pred3



class DeepLabv3p_boost_similar(nn.Module):
    """Cascaded three-head model built from plain ``Decoder`` stages.

    The ASPP output is fused with the deepest backbone stage by ``Pred5``;
    ``Pred4`` and ``Pred3`` each refine the previous stage's intermediate
    features with a shallower backbone stage.

    Args:
        in_ch: Number of input channels handed to the backbone.  Pretrained
            backbone weights are requested only when this equals 3.
        num_classes: Number of output channels of every prediction head.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, in_ch, num_classes, **kwargs):
        super().__init__()
        self.resnet = ResNet50_multiscale(in_ch, output_stride=16, pretrained=(in_ch == 3))

        # Non-trainable 3-vector initialised to [0.5, 0.5, 0.0]; not read in
        # this class -- presumably consumed by external code, confirm.
        self.ts = nn.Parameter(torch.zeros(3), requires_grad=False)
        self.ts[:2] = 0.5

        # Per-stage channel counts reported by the backbone (the original
        # author's note lists them as [3, 64, 256, 512, 2048]).
        filters = self.resnet.out_s
        self.aspp = ASPP(filters[-1], 256, BatchNorm=nn.BatchNorm2d)

        self.Pred5 = Decoder(num_classes, high_level_inplanes=256, low_dim=256, out_dim=256, low_level_inplanes=filters[4])
        self.Pred4 = Decoder(num_classes, high_level_inplanes=256, low_dim=512, out_dim=128, low_level_inplanes=filters[3])
        self.Pred3 = Decoder(num_classes, high_level_inplanes=128, low_dim=256, out_dim=128, low_level_inplanes=filters[2])

    def forward(self, x):
        """Return ``(pred5, pred4, pred3)``, one prediction per decoder stage."""
        # Only the three deepest backbone outputs are consumed.
        _, _, e3, e4, e5 = self.resnet(x)

        aspp_e5 = self.aspp(e5)
        mid_e5, pred5 = self.Pred5(aspp_e5, e5)
        mid_e4, pred4 = self.Pred4(mid_e5, e4)
        _, pred3 = self.Pred3(mid_e4, e3)

        return pred5, pred4, pred3


class DeepLabv3p_boost_similar_PPM(nn.Module):
    """Cascaded three-head model that uses a ``PPM`` context module instead of ASPP.

    The PPM output is fused with the deepest backbone stage by ``Pred5``;
    ``Pred4`` and ``Pred3`` each refine the previous stage's intermediate
    features with a shallower backbone stage.  All stages are plain ``Decoder``s.

    Args:
        in_ch: Number of input channels handed to the backbone.  Pretrained
            backbone weights are requested only when this equals 3.
        num_classes: Number of output channels of every prediction head.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, in_ch, num_classes, **kwargs):
        super().__init__()
        self.resnet = ResNet50_multiscale(in_ch, output_stride=16, pretrained=(in_ch == 3))

        # Non-trainable 3-vector initialised to [0.5, 0.5, 0.0]; not read in
        # this class -- presumably consumed by external code, confirm.
        self.ts = nn.Parameter(torch.zeros(3), requires_grad=False)
        self.ts[:2] = 0.5

        # Per-stage channel counts reported by the backbone (the original
        # author's note lists them as [3, 64, 256, 512, 2048]).
        filters = self.resnet.out_s
        self.ppm = PPM(filters[-1], 256, BatchNorm=nn.BatchNorm2d)

        self.Pred5 = Decoder(num_classes, high_level_inplanes=256, low_dim=256, out_dim=256, low_level_inplanes=filters[4])
        self.Pred4 = Decoder(num_classes, high_level_inplanes=256, low_dim=512, out_dim=128, low_level_inplanes=filters[3])
        self.Pred3 = Decoder(num_classes, high_level_inplanes=128, low_dim=256, out_dim=128, low_level_inplanes=filters[2])

    def forward(self, x):
        """Return ``(pred5, pred4, pred3)``, one prediction per decoder stage."""
        # Only the three deepest backbone outputs are consumed.
        _, _, e3, e4, e5 = self.resnet(x)

        ppm_e5 = self.ppm(e5)
        mid_e5, pred5 = self.Pred5(ppm_e5, e5)
        mid_e4, pred4 = self.Pred4(mid_e5, e4)
        _, pred3 = self.Pred3(mid_e4, e3)

        return pred5, pred4, pred3


class DeepLabv3p_boost_similarA(nn.Module):
    """Cascaded three-head model; variant with a 48-channel low-level projection in ``Pred5``.

    Apart from ``Pred5`` using ``low_dim=48``, the layout is the ASPP ->
    ``Pred5`` -> ``Pred4`` -> ``Pred3`` cascade of plain ``Decoder`` stages.

    Args:
        in_ch: Number of input channels handed to the backbone.  Pretrained
            backbone weights are requested only when this equals 3.
        num_classes: Number of output channels of every prediction head.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, in_ch, num_classes, **kwargs):
        super().__init__()
        self.resnet = ResNet50_multiscale(in_ch, output_stride=16, pretrained=(in_ch == 3))

        # Non-trainable 3-vector initialised to [0.5, 0.5, 0.0]; not read in
        # this class -- presumably consumed by external code, confirm.
        self.ts = nn.Parameter(torch.zeros(3), requires_grad=False)
        self.ts[:2] = 0.5

        # Per-stage channel counts reported by the backbone (the original
        # author's note lists them as [3, 64, 256, 512, 2048]).
        filters = self.resnet.out_s
        self.aspp = ASPP(filters[-1], 256)

        self.Pred5 = Decoder(num_classes, high_level_inplanes=256, low_dim=48, out_dim=256, low_level_inplanes=filters[4])
        self.Pred4 = Decoder(num_classes, high_level_inplanes=256, low_dim=512, out_dim=128, low_level_inplanes=filters[3])
        self.Pred3 = Decoder(num_classes, high_level_inplanes=128, low_dim=256, out_dim=128, low_level_inplanes=filters[2])

    def forward(self, x):
        """Return ``(pred5, pred4, pred3)``, one prediction per decoder stage."""
        # Only the three deepest backbone outputs are consumed.
        _, _, e3, e4, e5 = self.resnet(x)

        aspp_e5 = self.aspp(e5)
        mid_e5, pred5 = self.Pred5(aspp_e5, e5)
        mid_e4, pred4 = self.Pred4(mid_e5, e4)
        _, pred3 = self.Pred3(mid_e4, e3)

        return pred5, pred4, pred3


class DeepLabv3p_boost_similarB(nn.Module):
    """Cascaded three-head model; variant with wider low-level projections.

    ``Pred4`` and ``Pred3`` project their low-level inputs to 768 and 384
    channels respectively; otherwise the layout is the ASPP -> ``Pred5`` ->
    ``Pred4`` -> ``Pred3`` cascade of plain ``Decoder`` stages.

    Args:
        in_ch: Number of input channels handed to the backbone.  Pretrained
            backbone weights are requested only when this equals 3.
        num_classes: Number of output channels of every prediction head.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, in_ch, num_classes, **kwargs):
        super().__init__()
        self.resnet = ResNet50_multiscale(in_ch, output_stride=16, pretrained=(in_ch == 3))

        # Non-trainable 3-vector initialised to [0.5, 0.5, 0.0]; not read in
        # this class -- presumably consumed by external code, confirm.
        self.ts = nn.Parameter(torch.zeros(3), requires_grad=False)
        self.ts[:2] = 0.5

        # Per-stage channel counts reported by the backbone (the original
        # author's note lists them as [3, 64, 256, 512, 2048]).
        filters = self.resnet.out_s
        self.aspp = ASPP(filters[-1], 256)

        self.Pred5 = Decoder(num_classes, high_level_inplanes=256, low_dim=256, out_dim=256, low_level_inplanes=filters[4])
        self.Pred4 = Decoder(num_classes, high_level_inplanes=256, low_dim=768, out_dim=128, low_level_inplanes=filters[3])
        self.Pred3 = Decoder(num_classes, high_level_inplanes=128, low_dim=384, out_dim=128, low_level_inplanes=filters[2])

    def forward(self, x):
        """Return ``(pred5, pred4, pred3)``, one prediction per decoder stage."""
        # Only the three deepest backbone outputs are consumed.
        _, _, e3, e4, e5 = self.resnet(x)

        aspp_e5 = self.aspp(e5)
        mid_e5, pred5 = self.Pred5(aspp_e5, e5)
        mid_e4, pred4 = self.Pred4(mid_e5, e4)
        _, pred3 = self.Pred3(mid_e4, e3)

        return pred5, pred4, pred3

class DeepLabv3p_boost_5similar(nn.Module):
    """ResNet-50 backbone + ASPP feeding five chained ``Decoder`` heads.

    ``forward`` returns one prediction per head, deepest first:
    ``(pred5, pred4, pred3, pred2, pred1)``.
    """

    def __init__(self, in_ch, num_classes, **kwargs):
        """Assemble the backbone, ASPP, the ``e1`` stem and the decoder heads.

        Args:
            in_ch: number of input channels. Pretrained backbone weights are
                requested only when this equals 3.
            num_classes: forwarded as the first positional argument of every
                ``Decoder``.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.resnet = ResNet50_multiscale(in_ch, output_stride=16, pretrained=(in_ch == 3))

        # Constant, non-trainable vector [0.5, 0.5, 0.5, 0.5, 0.0]. It is
        # stored on the module but never read inside this class.
        ts = torch.zeros(5)
        ts[:4] = 0.5
        self.ts = nn.Parameter(ts, requires_grad=False)

        # Channel counts published by the backbone
        # (presumably [3, 64, 256, 512, 2048] -- confirm against ResNet50_multiscale).
        filters = self.resnet.out_s
        self.aspp = ASPP(filters[-1], 256)

        # Heads are built deepest-first; the construction order is kept so
        # that seeded weight initialisation stays reproducible.
        # Pred5 relies on Decoder's default low_dim / out_dim.
        self.Pred5 = Decoder(
            num_classes,
            high_level_inplanes=256,
            low_level_inplanes=filters[4],
        )
        self.Pred4 = Decoder(
            num_classes,
            high_level_inplanes=256,
            low_level_inplanes=filters[3],
            low_dim=256,
            out_dim=128,
        )
        self.Pred3 = Decoder(
            num_classes,
            high_level_inplanes=128,
            low_level_inplanes=filters[2],
            low_dim=256,
            out_dim=64,
        )
        self.Pred2 = Decoder(
            num_classes,
            high_level_inplanes=64,
            low_level_inplanes=filters[1],
            low_dim=128,
            out_dim=32,
        )

        # Stem applied to the shallowest backbone output before Pred1:
        # three 5x5 conv -> BatchNorm -> ReLU stages, each 32 channels wide.
        stem = []
        for stage_in in (filters[0], 32, 32):
            stem.append(nn.Conv2d(stage_in, 32, kernel_size=5, stride=1, padding=2, bias=False))
            stem.append(nn.BatchNorm2d(32))
            stem.append(nn.ReLU())
        self.conv_e1 = nn.Sequential(*stem)

        self.Pred1 = Decoder(
            num_classes,
            high_level_inplanes=32,
            low_level_inplanes=32,
            low_dim=32,
            out_dim=32,
        )

    def forward(self, x):
        """Return ``(pred5, pred4, pred3, pred2, pred1)`` for the batch ``x``."""
        feat1, feat2, feat3, feat4, feat5 = self.resnet(x)

        # head(running_state, backbone_feature) -> (next_state, prediction)
        state, pred5 = self.Pred5(self.aspp(feat5), feat5)
        state, pred4 = self.Pred4(state, feat4)
        state, pred3 = self.Pred3(state, feat3)
        state, pred2 = self.Pred2(state, feat2)

        # The shallowest feature goes through the conv stem before the last head.
        _, pred1 = self.Pred1(state, self.conv_e1(feat1))

        return pred5, pred4, pred3, pred2, pred1




class UNet(nn.Module):
    """
    UNet - encoder path only
    Paper : https://arxiv.org/abs/1505.04597

    This class holds just the contracting half: five ``conv_block`` stages
    separated by 2x2 max-pooling. ``forward`` returns the output of every
    stage so that callers can attach their own decoder.
    """

    def __init__(self, in_ch=3, kernel_size=3):
        """Create the pooling layers and the five convolution stages.

        Args:
            in_ch: channel count of the network input.
            kernel_size: forwarded as ``kernel_size`` to every ``conv_block``.
        """
        super().__init__()

        base = 64
        # Stage widths double at every level: [64, 128, 256, 512, 1024].
        filters = [base * 2 ** level for level in range(5)]

        self.filters = filters
        self.ks = kernel_size

        # Maxpool1..Maxpool4 are registered before Conv1..Conv5, as in the
        # original layout, so module ordering is unchanged.
        for idx in range(1, 5):
            setattr(self, f"Maxpool{idx}", nn.MaxPool2d(kernel_size=2, stride=2))

        # Stage i maps the previous stage's width (or in_ch) to filters[i - 1].
        stage_inputs = [in_ch] + filters[:-1]
        for idx, (width_in, width_out) in enumerate(zip(stage_inputs, filters), start=1):
            setattr(self, f"Conv{idx}", conv_block(width_in, width_out, kernel_size=self.ks))

    def forward(self, x):
        """Return the five stage outputs ``(e1, e2, e3, e4, e5)``, shallowest first."""
        e1 = self.Conv1(x)
        e2 = self.Conv2(self.Maxpool1(e1))
        e3 = self.Conv3(self.Maxpool2(e2))
        e4 = self.Conv4(self.Maxpool3(e3))
        e5 = self.Conv5(self.Maxpool4(e4))
        return e1, e2, e3, e4, e5


class UNet_boost_similar(nn.Module):
    """``UNet`` encoder + ASPP feeding three chained ``Decoder`` heads.

    ``forward`` returns one prediction per head, deepest first:
    ``(pred5, pred4, pred3)``.
    """

    def __init__(self, in_ch, num_classes, **kwargs):
        """Assemble the encoder, the ASPP module and the decoder heads.

        Args:
            in_ch: number of input channels, passed to ``UNet``.
            num_classes: forwarded as the first positional argument of every
                ``Decoder``.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.unet = UNet(in_ch)

        # Constant, non-trainable vector [0.5, 0.5, 0.0]. It is stored on the
        # module but never read inside this class.
        ts = torch.zeros(3)
        ts[:2] = 0.5
        self.ts = nn.Parameter(ts, requires_grad=False)

        # Encoder stage widths as set by UNet: [64, 128, 256, 512, 1024].
        filters = self.unet.filters
        self.aspp = ASPP(filters[-1], 256)

        # Heads are built deepest-first; the construction order is kept so
        # that seeded weight initialisation stays reproducible.
        self.Pred5 = Decoder(
            num_classes,
            high_level_inplanes=256,
            low_level_inplanes=filters[4],
            low_dim=256,
            out_dim=256,
        )
        self.Pred4 = Decoder(
            num_classes,
            high_level_inplanes=256,
            low_level_inplanes=filters[3],
            low_dim=512,
            out_dim=128,
        )
        self.Pred3 = Decoder(
            num_classes,
            high_level_inplanes=128,
            low_level_inplanes=filters[2],
            low_dim=256,
            out_dim=128,
        )

    def forward(self, x):
        """Return ``(pred5, pred4, pred3)`` for the input batch ``x``."""
        # Five encoder outputs; only the three deepest are consumed.
        _, _, feat3, feat4, feat5 = self.unet(x)

        # head(running_state, encoder_feature) -> (next_state, prediction)
        state, pred5 = self.Pred5(self.aspp(feat5), feat5)
        state, pred4 = self.Pred4(state, feat4)
        _, pred3 = self.Pred3(state, feat3)

        return pred5, pred4, pred3


class UNet_boost_5similar(nn.Module):
    """``UNet`` encoder + ASPP feeding five chained ``Decoder`` heads.

    ``forward`` returns one prediction per head, deepest first:
    ``(pred5, pred4, pred3, pred2, pred1)``.
    """

    def __init__(self, in_ch, num_classes, **kwargs):
        """Assemble the encoder, ASPP, the ``e1`` stem and the decoder heads.

        Args:
            in_ch: number of input channels, passed to ``UNet``.
            num_classes: forwarded as the first positional argument of every
                ``Decoder``.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.unet = UNet(in_ch)

        # Constant, non-trainable vector [0.5, 0.5, 0.5, 0.5, 0.0]. It is
        # stored on the module but never read inside this class.
        ts = torch.zeros(5)
        ts[:4] = 0.5
        self.ts = nn.Parameter(ts, requires_grad=False)

        # Encoder stage widths as set by UNet: [64, 128, 256, 512, 1024].
        filters = self.unet.filters
        self.aspp = ASPP(filters[-1], 256)

        # Heads are built deepest-first; the construction order is kept so
        # that seeded weight initialisation stays reproducible.
        self.Pred5 = Decoder(
            num_classes,
            high_level_inplanes=256,
            low_level_inplanes=filters[4],
            low_dim=256,
            out_dim=256,
        )
        self.Pred4 = Decoder(
            num_classes,
            high_level_inplanes=256,
            low_level_inplanes=filters[3],
            low_dim=512,
            out_dim=128,
        )
        self.Pred3 = Decoder(
            num_classes,
            high_level_inplanes=128,
            low_level_inplanes=filters[2],
            low_dim=256,
            out_dim=64,
        )
        self.Pred2 = Decoder(
            num_classes,
            high_level_inplanes=64,
            low_level_inplanes=filters[1],
            low_dim=128,
            out_dim=32,
        )

        # Stem applied to the shallowest encoder output before Pred1:
        # three 5x5 conv -> BatchNorm -> ReLU stages, each 32 channels wide.
        # NOTE(review): only the first conv carries a bias; it is kept as-is
        # because dropping it would change the module's state_dict keys.
        stem = []
        for stage_in, with_bias in ((filters[0], True), (32, False), (32, False)):
            stem.append(nn.Conv2d(stage_in, 32, kernel_size=5, stride=1, padding=2, bias=with_bias))
            stem.append(nn.BatchNorm2d(32))
            stem.append(nn.ReLU())
        self.conv_e1 = nn.Sequential(*stem)

        self.Pred1 = Decoder(
            num_classes,
            high_level_inplanes=32,
            low_level_inplanes=32,
            low_dim=32,
            out_dim=32,
        )

    def forward(self, x):
        """Return ``(pred5, pred4, pred3, pred2, pred1)`` for the batch ``x``."""
        feat1, feat2, feat3, feat4, feat5 = self.unet(x)

        # head(running_state, encoder_feature) -> (next_state, prediction)
        state, pred5 = self.Pred5(self.aspp(feat5), feat5)
        state, pred4 = self.Pred4(state, feat4)
        state, pred3 = self.Pred3(state, feat3)
        state, pred2 = self.Pred2(state, feat2)

        # The shallowest feature goes through the conv stem before the last head.
        _, pred1 = self.Pred1(state, self.conv_e1(feat1))

        return pred5, pred4, pred3, pred2, pred1