import torch
from torch import nn
import torch.nn.functional as F

class up_conv(nn.Module):
    """Upsample x2 (bilinear) followed by Conv -> BatchNorm -> ReLU.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        ks: convolution kernel size; padding is ``ks // 2``, so odd sizes keep
            the upsampled spatial size unchanged.
    """

    def __init__(self, in_ch, out_ch, ks=3):
        super().__init__()
        upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        conv = nn.Conv2d(in_ch, out_ch, kernel_size=ks, stride=1, padding=ks // 2, bias=True)
        # The attribute name ``up`` and the layer order define the state_dict keys
        # (``up.1.weight`` ...), so both are kept fixed.
        self.up = nn.Sequential(upsample, conv, nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True))

    def forward(self, x):
        return self.up(x)


class conv_block(nn.Module):
    """Two stacked Conv -> BatchNorm -> ReLU stages ("double conv").

    Args:
        in_ch: channels entering the first convolution.
        out_ch: channels produced by both convolutions.
        kernel_size: kernel size of both convolutions; padding is
            ``kernel_size // 2`` (spatial size preserved for odd sizes).
    """

    def __init__(self, in_ch, out_ch, kernel_size=3):
        super().__init__()
        pad = kernel_size // 2
        layers = []
        for c_in in (in_ch, out_ch):
            layers += [
                nn.Conv2d(c_in, out_ch, kernel_size=kernel_size,
                          stride=1, padding=pad, bias=True),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        # A single ``conv`` Sequential (indices 0..5) keeps state_dict keys stable.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)



class FGlo(nn.Module):
    """Channel gate (squeeze-and-excitation style) applied to a 4-D feature map.

    The FGlo class is employed to refine the joint feature of both the local
    feature and its surrounding context. Global average pooling yields one
    value per channel. A two-layer bottleneck MLP with a final sigmoid maps it
    to per-channel weights in [0, 1], and the input is multiplied by them.

    Args:
        channel: number of input (and output) channels.
        reduction: bottleneck reduction ratio. The hidden width is
            ``channel // reduction``, clamped to at least 1 so that
            ``channel < reduction`` does not create a zero-width bottleneck
            (which would make the gate a constant).
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        hidden = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Scale each channel of ``x`` (B, C, H, W) by its learned gate; same shape out."""
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y



    
class UNet_boost_feacat(nn.Module):
    """U-Net with per-depth classifier heads and confidence-gated skip fusion.

    Based on U-Net (https://arxiv.org/abs/1505.04597). Every decoder depth has
    a 1x1 classifier head. Before each fusion, the previous (coarser) head's
    logits are:
      * upsampled x2,
      * softmaxed over the channel axis,
      * reduced to the per-pixel maximum probability ``rate``.
    The encoder skip is then scaled by ``1 - rate`` and the upsampled decoder
    features by ``rate``, and the two are concatenated.

    NOTE(review): with ``out_ch=1`` (the default) the softmax runs over a
    single channel, so ``rate`` is identically 1 and every encoder skip is
    multiplied by 0. The gating presumably expects ``out_ch >= 2`` — confirm.

    Args:
        in_ch: number of input image channels.
        out_ch: number of logit channels produced by every classifier head.
        kernel_size: kernel size of the encoder ``conv_block``s (the decoder
            uses 1x1 convolutions).

    ``forward`` returns ``(c6, c5, c4, c3, c2)``. These are logits from the
    deepest head (1/16 of the input resolution) up to the full-resolution
    head. Input height and width must be divisible by 16 so the skip tensors
    line up.
    """

    def __init__(self, in_ch=3, out_ch=1, kernel_size=3):
        super().__init__()

        n1 = 64
        filters = [n1 * 2 ** i for i in range(5)]  # 64, 128, 256, 512, 1024
        self.ks = kernel_size
        # Submodules are registered in the original order so the order of
        # parameters() (and thus saved optimizer state) is unchanged.
        for i in range(1, 5):
            setattr(self, f'Maxpool{i}', nn.MaxPool2d(kernel_size=2, stride=2))
        self.Uppool = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        for i, (c_in, c_out) in enumerate(zip([in_ch] + filters[:-1], filters), start=1):
            setattr(self, f'Conv{i}', conv_block(c_in, c_out, kernel_size=self.ks))

        # Classifier heads: Conv6_cla (1024 ch, deepest) ... Conv2_cla (64 ch, finest).
        for level, ch in zip(range(6, 1, -1), reversed(filters)):
            setattr(self, f'Conv{level}_cla', nn.Conv2d(ch, out_ch, kernel_size=1))

        # Decoder stages 5..2: upsample filters[level-1] -> filters[level-2] channels,
        # then fuse the concatenated (skip, upsampled) pair back to filters[level-2].
        for level in range(5, 1, -1):
            hi, lo = filters[level - 1], filters[level - 2]
            setattr(self, f'Up{level}', up_conv(hi, lo, ks=1))
            setattr(self, f'Up_conv{level}', conv_block(lo + lo, lo, kernel_size=1))

        # Not used by forward(); removing it would change the state_dict keys.
        # NOTE(review): its parameters get no gradient (relevant e.g. for DDP).
        self.Conv = nn.Conv2d(filters[0], out_ch, kernel_size=3, stride=1, padding=1)

    def _gate(self, logits):
        """Upsample ``logits`` x2 and return the per-pixel max softmax probability (B, 1, H, W)."""
        return F.softmax(self.Uppool(logits), dim=1).max(dim=1, keepdim=True).values

    def forward(self, x):
        # Encoder: Conv1 at full resolution, then (pool -> conv) four times.
        feats = [self.Conv1(x)]
        pools = (self.Maxpool1, self.Maxpool2, self.Maxpool3, self.Maxpool4)
        convs = (self.Conv2, self.Conv3, self.Conv4, self.Conv5)
        for pool, conv in zip(pools, convs):
            feats.append(conv(pool(feats[-1])))
        e1, e2, e3, e4, e5 = feats

        logits = [self.Conv6_cla(e5)]
        d = e5
        stages = (
            (self.Up5, self.Up_conv5, self.Conv5_cla, e4),
            (self.Up4, self.Up_conv4, self.Conv4_cla, e3),
            (self.Up3, self.Up_conv3, self.Conv3_cla, e2),
            (self.Up2, self.Up_conv2, self.Conv2_cla, e1),
        )
        for up, fuse, head, skip in stages:
            upsampled = up(d)
            rate = self._gate(logits[-1])
            d = fuse(torch.cat((skip * (1 - rate), upsampled * rate), dim=1))
            logits.append(head(d))
        return tuple(logits)

    
    

class UNet_boost(nn.Module):
    """U-Net whose decoder is driven by upsampled class logits instead of features.

    Based on U-Net (https://arxiv.org/abs/1505.04597). The encoder is five
    ``conv_block`` stages. Each decoder stage concatenates an encoder skip with
    the x2-upsampled logits of the previous (coarser) classifier head, which
    has only ``out_ch`` channels. It then refines the result with a
    ``conv_block`` of kernel size ``decoder_ks``.

    Args:
        in_ch: number of input image channels.
        out_ch: number of logit channels produced by every classifier head.
        kernel_size: kernel size of the encoder ``conv_block``s.
        decoder_ks: kernel size of the decoder ``conv_block``s.

    ``forward`` returns ``(c6, c5, c4, c3, c2)``. These are logits from 1/16
    of the input resolution up to full resolution. Input height and width must
    be divisible by 16 so the skip tensors line up.
    """

    def __init__(self, in_ch=3, out_ch=1, kernel_size=3, decoder_ks=1):
        super().__init__()

        n1 = 64
        filters = [n1 * 2 ** i for i in range(5)]  # 64, 128, 256, 512, 1024
        self.ks = kernel_size
        # Registration order matches the original attribute order.
        for i in range(1, 5):
            setattr(self, f'Maxpool{i}', nn.MaxPool2d(kernel_size=2, stride=2))
        self.Uppool = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        for i, (c_in, c_out) in enumerate(zip([in_ch] + filters[:-1], filters), start=1):
            setattr(self, f'Conv{i}', conv_block(c_in, c_out, kernel_size=self.ks))

        # Classifier heads: Conv6_cla (deepest) ... Conv2_cla (finest).
        for level, ch in zip(range(6, 1, -1), reversed(filters)):
            setattr(self, f'Conv{level}_cla', nn.Conv2d(ch, out_ch, kernel_size=1))

        # Decoder stage N fuses skip (filters[N-2] ch) with upsampled logits (out_ch ch).
        for level in range(5, 1, -1):
            ch = filters[level - 2]
            setattr(self, f'Up_conv{level}', conv_block(ch + out_ch, ch, kernel_size=decoder_ks))

        # Not used by forward(); removing it would change the state_dict keys.
        self.Conv = nn.Conv2d(filters[0], out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # Encoder: Conv1 at full resolution, then (pool -> conv) four times.
        feats = [self.Conv1(x)]
        pools = (self.Maxpool1, self.Maxpool2, self.Maxpool3, self.Maxpool4)
        convs = (self.Conv2, self.Conv3, self.Conv4, self.Conv5)
        for pool, conv in zip(pools, convs):
            feats.append(conv(pool(feats[-1])))
        e1, e2, e3, e4, e5 = feats

        logits = [self.Conv6_cla(e5)]
        stages = (
            (self.Up_conv5, self.Conv5_cla, e4),
            (self.Up_conv4, self.Conv4_cla, e3),
            (self.Up_conv3, self.Conv3_cla, e2),
            (self.Up_conv2, self.Conv2_cla, e1),
        )
        for fuse, head, skip in stages:
            d = fuse(torch.cat((skip, self.Uppool(logits[-1])), dim=1))
            logits.append(head(d))
        return tuple(logits)

    

class UNet_boost_3decoderks(UNet_boost):
    """``UNet_boost`` preset whose decoder ``conv_block``s default to 3x3 kernels.

    Identical to ``UNet_boost`` except that ``decoder_ks`` defaults to 3
    instead of 1.
    """

    def __init__(self, in_ch=3, out_ch=1, kernel_size=3, decoder_ks=3):
        super().__init__(in_ch=in_ch, out_ch=out_ch,
                         kernel_size=kernel_size, decoder_ks=decoder_ks)

class UNet_boost_Scale(nn.Module):
    """U-Net with deep supervision and certainty-scaled skip fusion.

    A prediction head ``PredN`` (``conv_block`` + 1x1 conv) is attached at
    every depth. Before each decoder fusion, the coarser prediction is turned
    into a per-pixel certainty ``c`` in (0, 1) (see ``_certainty``). The
    encoder skip is then scaled by ``2 - c`` and the upsampled decoder features
    by ``1 + c``. Confident regions therefore weight decoder features more,
    and uncertain regions weight the encoder skip more.

    Args:
        in_ch: number of input image channels.
        out_ch: number of logit channels per prediction head. ``out_ch == 1``
            is treated as a binary (sigmoid) output when computing certainty.
        sig: slope of the sigmoid that maps the top-2 probability margin to the
            certainty ``c`` (midpoint at a margin of 0.5). Defaults to 10.

    ``forward`` returns ``(pred5, pred4, pred3, pred2, pred1)``. These are
    logits from 1/16 of the input resolution up to full resolution. Input
    height and width must be divisible by 16 so the skip tensors line up.
    """

    def __init__(self, in_ch=3, out_ch=1, sig=10):
        super().__init__()
        n1 = 64
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]
        self.sig = sig
        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Conv1 = conv_block(in_ch, filters[0])
        self.Conv2 = conv_block(filters[0], filters[1])
        self.Conv3 = conv_block(filters[1], filters[2])
        self.Conv4 = conv_block(filters[2], filters[3])
        self.Conv5 = conv_block(filters[3], filters[4])
        # Each Up_convN consumes skip + upsampled features:
        # 2 * filters[N-2] == filters[N-1] channels.
        self.Up5 = up_conv(filters[4], filters[3])
        self.Up_conv5 = conv_block(filters[4], filters[3])
        self.Up4 = up_conv(filters[3], filters[2])
        self.Up_conv4 = conv_block(filters[3], filters[2])
        self.Up3 = up_conv(filters[2], filters[1])
        self.Up_conv3 = conv_block(filters[2], filters[1])
        self.Up2 = up_conv(filters[1], filters[0])
        self.Up_conv2 = conv_block(filters[1], filters[0])
        self.Pred5 = self._pred_head(filters[4], filters[0], out_ch)
        self.Pred4 = self._pred_head(filters[3], filters[0], out_ch)
        self.Pred3 = self._pred_head(filters[2], filters[0], out_ch)
        self.Pred2 = self._pred_head(filters[1], filters[0], out_ch)
        self.Pred1 = self._pred_head(filters[0], filters[0], out_ch)

    @staticmethod
    def _pred_head(in_ch, mid_ch, out_ch):
        """Build a head: ``conv_block(in_ch -> mid_ch)`` then a 1x1 conv to ``out_ch`` logits."""
        return nn.Sequential(
            conv_block(in_ch, mid_ch),
            nn.Conv2d(mid_ch, out_ch, kernel_size=1, stride=1, padding=0),
        )

    @staticmethod
    def _certainty(pred, sig):
        """Map logits ``pred`` (B, C, h, w) to a certainty map (B, 1, 2h, 2w) in (0, 1).

        Class probabilities are upsampled x2 (bilinear, align_corners=True).
        The margin between the two largest probabilities is passed through
        ``sigmoid((margin - 0.5) * sig)``. With a single channel the logit is
        read as a binary sigmoid output with probabilities ``(1 - p, p)``,
        whose top-2 margin is ``|2p - 1|``.
        """
        def upsample(t):
            return F.interpolate(t, scale_factor=2, mode='bilinear', align_corners=True)

        if pred.size(1) == 1:
            # topk(k=2) would fail on a single channel.
            margin = (2 * upsample(pred.sigmoid()) - 1).abs()
        else:
            top2 = upsample(pred.softmax(dim=1)).topk(k=2, dim=1)[0]
            margin = top2[:, 0:1] - top2[:, 1:2]
        return ((margin - 0.5) * sig).sigmoid()

    @staticmethod
    def _fuse(skip, up, c):
        """Concatenate ``skip * (2 - c)`` and ``up * (1 + c)`` along the channel axis."""
        return torch.cat((skip * (2 - c), up * (1 + c)), dim=1)

    def forward(self, x):
        # Encoder.
        e1 = self.Conv1(x)
        e2 = self.Conv2(self.Maxpool1(e1))
        e3 = self.Conv3(self.Maxpool2(e2))
        e4 = self.Conv4(self.Maxpool3(e3))
        e5 = self.Conv5(self.Maxpool4(e4))

        # Decoder: each fusion is weighted by the certainty of the coarser prediction.
        pred5 = self.Pred5(e5)
        d5 = self.Up_conv5(self._fuse(e4, self.Up5(e5), self._certainty(pred5, self.sig)))
        pred4 = self.Pred4(d5)
        d4 = self.Up_conv4(self._fuse(e3, self.Up4(d5), self._certainty(pred4, self.sig)))
        pred3 = self.Pred3(d4)
        d3 = self.Up_conv3(self._fuse(e2, self.Up3(d4), self._certainty(pred3, self.sig)))
        pred2 = self.Pred2(d3)
        d2 = self.Up_conv2(self._fuse(e1, self.Up2(d3), self._certainty(pred2, self.sig)))
        pred1 = self.Pred1(d2)
        return pred5, pred4, pred3, pred2, pred1
    


class UNet_boost_big(nn.Module):
    """Standard-width U-Net decoder with a 1x1 classifier head at every depth.

    Based on U-Net (https://arxiv.org/abs/1505.04597). Each decoder stage
    performs three steps:
      * upsamples with ``up_conv`` (3x3 conv),
      * concatenates the matching encoder skip,
      * refines with a ``conv_block`` of kernel size ``decoder_ks``.

    Args:
        in_ch: number of input image channels.
        out_ch: number of logit channels produced by every classifier head.
        kernel_size: kernel size of the encoder ``conv_block``s.
        decoder_ks: kernel size of the decoder ``conv_block``s.

    ``forward`` returns ``(c6, c5, c4, c3, c2)``. These are logits from 1/16
    of the input resolution up to full resolution. Input height and width must
    be divisible by 16 so the skip tensors line up.
    """

    def __init__(self, in_ch=3, out_ch=1, kernel_size=3, decoder_ks=3):
        super().__init__()

        n1 = 64
        filters = [n1 * 2 ** i for i in range(5)]  # 64, 128, 256, 512, 1024
        self.ks = kernel_size
        # Registration order matches the original attribute order.
        for i in range(1, 5):
            setattr(self, f'Maxpool{i}', nn.MaxPool2d(kernel_size=2, stride=2))
        # Not used by forward(); parameter-free.
        self.Uppool = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        for i, (c_in, c_out) in enumerate(zip([in_ch] + filters[:-1], filters), start=1):
            setattr(self, f'Conv{i}', conv_block(c_in, c_out, kernel_size=self.ks))

        # Classifier heads: Conv6_cla (deepest) ... Conv2_cla (finest).
        for level, ch in zip(range(6, 1, -1), reversed(filters)):
            setattr(self, f'Conv{level}_cla', nn.Conv2d(ch, out_ch, kernel_size=1))

        for level in range(5, 1, -1):
            hi, lo = filters[level - 1], filters[level - 2]
            setattr(self, f'Up{level}', up_conv(hi, lo, ks=3))
            setattr(self, f'Up_conv{level}', conv_block(lo + lo, lo, kernel_size=decoder_ks))

        # Not used by forward(); removing it would change the state_dict keys.
        self.Conv = nn.Conv2d(filters[0], out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # Encoder: Conv1 at full resolution, then (pool -> conv) four times.
        feats = [self.Conv1(x)]
        pools = (self.Maxpool1, self.Maxpool2, self.Maxpool3, self.Maxpool4)
        convs = (self.Conv2, self.Conv3, self.Conv4, self.Conv5)
        for pool, conv in zip(pools, convs):
            feats.append(conv(pool(feats[-1])))
        e1, e2, e3, e4, e5 = feats

        outputs = [self.Conv6_cla(e5)]
        d = e5
        stages = (
            (self.Up5, self.Up_conv5, self.Conv5_cla, e4),
            (self.Up4, self.Up_conv4, self.Conv4_cla, e3),
            (self.Up3, self.Up_conv3, self.Conv3_cla, e2),
            (self.Up2, self.Up_conv2, self.Conv2_cla, e1),
        )
        for up, fuse, head, skip in stages:
            d = fuse(torch.cat((skip, up(d)), dim=1))
            outputs.append(head(d))
        return tuple(outputs)



class UNet_boost_big_FGlo(nn.Module):
    """``UNet_boost_big`` variant that gates each fused decoder input with ``FGlo``.

    Based on U-Net (https://arxiv.org/abs/1505.04597). Each decoder stage
    performs four steps:
      * upsamples with ``up_conv`` (3x3 conv),
      * concatenates the matching encoder skip,
      * re-weights the concatenated channels with ``FGlo``,
      * refines with a ``conv_block`` of kernel size ``decoder_ks``.

    Args:
        in_ch: number of input image channels.
        out_ch: number of logit channels produced by every classifier head.
        kernel_size: kernel size of the encoder ``conv_block``s.
        decoder_ks: kernel size of the decoder ``conv_block``s.

    ``forward`` returns ``(c6, c5, c4, c3, c2)``. These are logits from 1/16
    of the input resolution up to full resolution. Input height and width must
    be divisible by 16 so the skip tensors line up.
    """

    def __init__(self, in_ch=3, out_ch=1, kernel_size=3, decoder_ks=3):
        super().__init__()

        n1 = 64
        filters = [n1 * 2 ** i for i in range(5)]  # 64, 128, 256, 512, 1024
        self.ks = kernel_size
        # Registration order matches the original attribute order.
        for i in range(1, 5):
            setattr(self, f'Maxpool{i}', nn.MaxPool2d(kernel_size=2, stride=2))
        # Not used by forward(); parameter-free.
        self.Uppool = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        for i, (c_in, c_out) in enumerate(zip([in_ch] + filters[:-1], filters), start=1):
            setattr(self, f'Conv{i}', conv_block(c_in, c_out, kernel_size=self.ks))

        # Classifier heads: Conv6_cla (deepest) ... Conv2_cla (finest).
        for level, ch in zip(range(6, 1, -1), reversed(filters)):
            setattr(self, f'Conv{level}_cla', nn.Conv2d(ch, out_ch, kernel_size=1))

        for level in range(5, 1, -1):
            hi, lo = filters[level - 1], filters[level - 2]
            setattr(self, f'Up{level}', up_conv(hi, lo, ks=3))
            setattr(self, f'Up_conv{level}', nn.Sequential(
                FGlo(lo + lo),
                conv_block(lo + lo, lo, kernel_size=decoder_ks)))

        # Not used by forward(); removing it would change the state_dict keys.
        self.Conv = nn.Conv2d(filters[0], out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # Encoder: Conv1 at full resolution, then (pool -> conv) four times.
        feats = [self.Conv1(x)]
        pools = (self.Maxpool1, self.Maxpool2, self.Maxpool3, self.Maxpool4)
        convs = (self.Conv2, self.Conv3, self.Conv4, self.Conv5)
        for pool, conv in zip(pools, convs):
            feats.append(conv(pool(feats[-1])))
        e1, e2, e3, e4, e5 = feats

        outputs = [self.Conv6_cla(e5)]
        d = e5
        stages = (
            (self.Up5, self.Up_conv5, self.Conv5_cla, e4),
            (self.Up4, self.Up_conv4, self.Conv4_cla, e3),
            (self.Up3, self.Up_conv3, self.Conv3_cla, e2),
            (self.Up2, self.Up_conv2, self.Conv2_cla, e1),
        )
        for up, fuse, head, skip in stages:
            d = fuse(torch.cat((skip, up(d)), dim=1))
            outputs.append(head(d))
        return tuple(outputs)


    

class UNet_boost_big_cla(nn.Module):
    """``UNet_boost_big`` variant with two-layer (1x1 conv, PReLU, 1x1 conv) heads.

    Based on U-Net (https://arxiv.org/abs/1505.04597). Each classifier head
    projects its feature map to ``filters[0]`` channels, applies PReLU, then
    maps to ``out_ch`` logits. The decoder matches ``UNet_boost_big``: it uses
    ``up_conv`` (3x3), skip concatenation and a ``conv_block`` of kernel size
    ``decoder_ks``.

    Args:
        in_ch: number of input image channels.
        out_ch: number of logit channels produced by every classifier head.
        kernel_size: kernel size of the encoder ``conv_block``s.
        decoder_ks: kernel size of the decoder ``conv_block``s.

    ``forward`` returns ``(c6, c5, c4, c3, c2)``. These are logits from 1/16
    of the input resolution up to full resolution. Input height and width must
    be divisible by 16 so the skip tensors line up.
    """

    def __init__(self, in_ch=3, out_ch=1, kernel_size=3, decoder_ks=3):
        super().__init__()

        n1 = 64
        filters = [n1 * 2 ** i for i in range(5)]  # 64, 128, 256, 512, 1024
        self.ks = kernel_size
        # Registration order matches the original attribute order.
        for i in range(1, 5):
            setattr(self, f'Maxpool{i}', nn.MaxPool2d(kernel_size=2, stride=2))
        # Not used by forward(); parameter-free.
        self.Uppool = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        for i, (c_in, c_out) in enumerate(zip([in_ch] + filters[:-1], filters), start=1):
            setattr(self, f'Conv{i}', conv_block(c_in, c_out, kernel_size=self.ks))

        # Classifier heads: Conv6_cla (deepest) ... Conv2_cla (finest).
        for level, ch in zip(range(6, 1, -1), reversed(filters)):
            setattr(self, f'Conv{level}_cla', nn.Sequential(
                nn.Conv2d(ch, filters[0], kernel_size=1),
                nn.PReLU(),
                nn.Conv2d(filters[0], out_ch, kernel_size=1)))

        for level in range(5, 1, -1):
            hi, lo = filters[level - 1], filters[level - 2]
            setattr(self, f'Up{level}', up_conv(hi, lo, ks=3))
            setattr(self, f'Up_conv{level}', conv_block(lo + lo, lo, kernel_size=decoder_ks))

        # Not used by forward(); removing it would change the state_dict keys.
        self.Conv = nn.Conv2d(filters[0], out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # Encoder: Conv1 at full resolution, then (pool -> conv) four times.
        feats = [self.Conv1(x)]
        pools = (self.Maxpool1, self.Maxpool2, self.Maxpool3, self.Maxpool4)
        convs = (self.Conv2, self.Conv3, self.Conv4, self.Conv5)
        for pool, conv in zip(pools, convs):
            feats.append(conv(pool(feats[-1])))
        e1, e2, e3, e4, e5 = feats

        outputs = [self.Conv6_cla(e5)]
        d = e5
        stages = (
            (self.Up5, self.Up_conv5, self.Conv5_cla, e4),
            (self.Up4, self.Up_conv4, self.Conv4_cla, e3),
            (self.Up3, self.Up_conv3, self.Conv3_cla, e2),
            (self.Up2, self.Up_conv2, self.Conv2_cla, e1),
        )
        for up, fuse, head, skip in stages:
            d = fuse(torch.cat((skip, up(d)), dim=1))
            outputs.append(head(d))
        return tuple(outputs)