from __future__ import print_function, division
import math
import gc

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

# from .shared import conv_block, up_conv
from .KAUNet.KAUNet_v import attention_conv_block_v,gumbel_softmax
from .KAUNet.lib.functional import subtraction2,dotproduction2, aggregation
# from modules.modulated_deform_conv import ModulatedDeformConv,ModulatedDeformConvPack 
from utils_torch import set_lr_mult

class up_conv(nn.Module):
    """
    Up Convolution Block

    Doubles the spatial resolution with bilinear interpolation
    (``align_corners=True``) and then applies Conv2d -> BatchNorm2d -> ReLU.

    Args:
        in_ch: number of channels of the incoming feature map.
        out_ch: number of channels produced by the convolution.
        ks: square kernel size of the convolution. The padding is ``ks // 2``,
            so an odd ``ks`` keeps the upsampled height and width.
    """

    def __init__(self, in_ch, out_ch,ks = 3):
        super(up_conv, self).__init__()
        stages = [
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_ch, out_ch, kernel_size=ks, stride=1,
                      padding=ks // 2, bias=True),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        # The attribute name `up` is part of the state_dict keys; keep it.
        self.up = nn.Sequential(*stages)

    def forward(self, x):
        """Return the upsampled and convolved feature map for ``x``."""
        return self.up(x)


class conv_block(nn.Module):
    """
    Convolution Block

    Two back-to-back Conv2d -> BatchNorm2d -> ReLU stages. The first stage
    maps ``in_ch`` to ``out_ch`` channels, the second keeps ``out_ch``.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels of both convolutions.
        kernel_size: square kernel size; the padding is ``kernel_size // 2``,
            so an odd kernel preserves height and width.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3):
        super(conv_block, self).__init__()

        # Both convolutions share everything except their input width.
        shared = dict(kernel_size=kernel_size, stride=1,
                      padding=kernel_size // 2, bias=True)
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, **shared),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, **shared),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run ``x`` through both convolution stages."""
        return self.conv(x)


class UNet(nn.Module):
    """
    UNet - Basic Implementation
    Paper : https://arxiv.org/abs/1505.04597

    Five-level encoder / four-level decoder with skip connections. Channel
    widths are 64, 128, 256, 512 and 1024. The result of the final 1x1
    convolution is returned as-is (no activation).

    The skip concatenations need matching spatial sizes, so the input height
    and width have to survive four 2x poolings followed by four 2x
    upsamplings unchanged (i.e. be multiples of 16).

    Args:
        in_ch: channels of the input tensor.
        out_ch: channels of the returned map.
        kernel_size: kernel size used by every conv_block and up_conv.
    """
    def __init__(self, in_ch=3, out_ch=1, kernel_size=3):
        super(UNet, self).__init__()

        base = 64
        c1, c2, c3, c4, c5 = [base * mult for mult in (1, 2, 4, 8, 16)]
        self.ks = kernel_size

        # One 2x2 max-pool per encoder transition.
        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder blocks.
        self.Conv1 = conv_block(in_ch, c1, kernel_size=self.ks)
        self.Conv2 = conv_block(c1, c2, kernel_size=self.ks)
        self.Conv3 = conv_block(c2, c3, kernel_size=self.ks)
        self.Conv4 = conv_block(c3, c4, kernel_size=self.ks)
        self.Conv5 = conv_block(c4, c5, kernel_size=self.ks)

        # Decoder: UpN halves the channels while doubling the resolution;
        # Up_convN then fuses the skip tensor concatenated with that result,
        # which is why its input width is twice its output width.
        self.Up5 = up_conv(c5, c4, ks=self.ks)
        self.Up_conv5 = conv_block(c5, c4, kernel_size=self.ks)

        self.Up4 = up_conv(c4, c3, ks=self.ks)
        self.Up_conv4 = conv_block(c4, c3, kernel_size=self.ks)

        self.Up3 = up_conv(c3, c2, ks=self.ks)
        self.Up_conv3 = conv_block(c3, c2, kernel_size=self.ks)

        self.Up2 = up_conv(c2, c1, ks=self.ks)
        self.Up_conv2 = conv_block(c2, c1, kernel_size=self.ks)

        # 1x1 projection to the requested number of output channels.
        self.Conv = nn.Conv2d(c1, out_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return the un-activated output map for the input batch ``x``."""
        # Contracting path.
        e1 = self.Conv1(x)
        e2 = self.Conv2(self.Maxpool1(e1))
        e3 = self.Conv3(self.Maxpool2(e2))
        e4 = self.Conv4(self.Maxpool3(e3))
        e5 = self.Conv5(self.Maxpool4(e4))

        # Expanding path; the skip tensor always comes first in the concat.
        d5 = self.Up_conv5(torch.cat((e4, self.Up5(e5)), dim=1))
        d4 = self.Up_conv4(torch.cat((e3, self.Up4(d5)), dim=1))
        d3 = self.Up_conv3(torch.cat((e2, self.Up3(d4)), dim=1))
        d2 = self.Up_conv2(torch.cat((e1, self.Up2(d3)), dim=1))

        return self.Conv(d2)


class UNet_ks5(UNet):
    """UNet preset that builds the parent network with ``kernel_size=5``."""

    def __init__(self, in_ch=3, out_ch=1):
        super(UNet_ks5, self).__init__(in_ch, out_ch, kernel_size=5)


class UNet_ks1(UNet):
    """UNet preset that builds the parent network with ``kernel_size=1``."""

    def __init__(self, in_ch=3, out_ch=1):
        super(UNet_ks1, self).__init__(in_ch, out_ch, kernel_size=1)



class PyramidPool(nn.Module):
    """
    Pyramid pooling branch.

    Adaptively average-pools the input to ``pool_size``, applies a 1x1
    Conv2d -> BatchNorm2d -> ReLU, and bilinearly resizes the result back to
    the input's height and width.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the 1x1 convolution.
        pool_size: target size handed to ``nn.AdaptiveAvgPool2d``.
    """
    def __init__(self, in_channels, out_channels, pool_size):
        super(PyramidPool, self).__init__()
        self.features = nn.Sequential(
            nn.AdaptiveAvgPool2d(pool_size),
            nn.Conv2d(in_channels, out_channels, 1, bias=True),
            # NOTE(review): in PyTorch `momentum` is the weight given to the
            # *new* batch statistic, so 0.95 makes the running stats follow
            # the latest batch almost entirely. If a 0.95 decay was intended
            # this should be 0.05 - confirm before changing.
            nn.BatchNorm2d(out_channels, momentum=0.95),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """Return the pooled features resized to the spatial size of ``x``."""
        pooled = self.features(x)
        # F.upsample_bilinear is deprecated; this is its documented
        # equivalent (bilinear interpolation with align_corners=True).
        return F.interpolate(pooled, size=x.shape[2:], mode='bilinear',
                             align_corners=True)




class guide_conv_block(nn.Module):
    """
    Two-branch convolution block with channel re-weighting.

    A local branch (conv -> BN -> ReLU -> conv) and a dilated 7x7 branch are
    computed from the same input, concatenated, normalised (BatchNorm +
    PReLU), reduced to ``out_ch`` channels by a 1x1 convolution and finally
    scaled per channel by a sigmoid gate derived from global average pooling.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels. The gate's hidden width is
            ``out_ch // 8``.
        kernel_size: kernel size of the two local-branch convolutions.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3):
        super(guide_conv_block, self).__init__()

        mid_ch = out_ch
        same_pad = kernel_size // 2

        # Local branch; note there is no norm/activation after the 2nd conv.
        local_layers = [
            nn.Conv2d(in_ch, mid_ch, kernel_size=kernel_size,
                      stride=1, padding=same_pad, bias=True),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_ch, mid_ch, kernel_size=kernel_size,
                      stride=1, padding=same_pad, bias=True),
        ]
        self.conv = nn.Sequential(*local_layers)

        # Wide-context branch: a single 7x7 convolution with dilation 4.
        # Padding (7 - 1) / 2 * 4 = 12 keeps the spatial size unchanged.
        wide_ks, wide_dilation = 7, 4
        wide_pad = ((wide_ks - 1) // 2) * wide_dilation
        self.conv_global = nn.Sequential(
            nn.Conv2d(in_ch, mid_ch, kernel_size=wide_ks, stride=1,
                      padding=wide_pad, dilation=wide_dilation),
        )

        # Normalise the concatenated branches, then squeeze back to out_ch.
        fused_ch = mid_ch * 2
        self.BNRe = nn.Sequential(
            nn.BatchNorm2d(fused_ch),
            nn.PReLU(fused_ch),
        )
        self.reduce = nn.Sequential(nn.Conv2d(fused_ch, out_ch, kernel_size=1))

        # Per-channel gate computed from globally pooled features.
        squeezed_ch = out_ch // 8
        self.refine_loc_glo = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(out_ch, squeezed_ch, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(squeezed_ch, out_ch, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the gated fusion of the local and dilated branches."""
        fused = torch.cat([self.conv(x), self.conv_global(x)], dim=1)
        fused = self.reduce(self.BNRe(fused))
        return fused * self.refine_loc_glo(fused)


class adaptive_conv_block(nn.Module):
    """
    Spatially gated stack of convolution stages.

    ``layers`` Conv2d -> BatchNorm2d -> ReLU stages are applied one after
    another. A gate network predicts, for every pixel, a softmax distribution
    over the stages, and the block returns the gate-weighted sum of the
    stage outputs.

    The gate path pools by 2 and upsamples by 2 again, so the input height
    and width must be even for the gate to line up with the features.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels of every stage. The gate's hidden
            width is ``out_ch // 4``.
        kernel_size: kernel size of the stage convolutions.
        layers: number of stacked stages (and of gate channels); must be >= 1.

    Raises:
        ValueError: if ``layers`` is smaller than 1.

    Note:
        Earlier revisions computed the weighted sum but returned the raw
        output of the last stage, so the gate had no effect on the result.
        Checkpoints trained with that behaviour will produce different
        activations with this version.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3,layers = 2):
        super(adaptive_conv_block, self).__init__()

        if layers < 1:
            raise ValueError(
                "adaptive_conv_block needs at least one layer, got %r" % (layers,))

        padding = kernel_size // 2

        self.layers = layers
        # Gate: computed at half resolution, brought back by nearest-neighbour
        # upsampling, then turned into per-pixel weights over the stages.
        self.conv_gate = nn.Sequential(
            nn.MaxPool2d(kernel_size = 2, stride=2),
            nn.Conv2d(in_ch, out_ch//4, kernel_size=7,stride=1, padding=3, bias=True),
            nn.BatchNorm2d(out_ch//4),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(out_ch//4, self.layers, kernel_size=3,stride=1, padding=1, bias=True),
            nn.Softmax(dim = 1),
            )

        self.Convs = nn.ModuleList()
        for i in range(self.layers):
            # Only the first stage sees the block input width.
            stage_in = in_ch if i == 0 else out_ch
            self.Convs.append(
                nn.Sequential(
                    nn.Conv2d(stage_in, out_ch, kernel_size=kernel_size,
                              stride=1, padding=padding, bias=True),
                    nn.BatchNorm2d(out_ch),
                    nn.ReLU(inplace=True),
                ))

    def forward(self, x):
        """Return the gate-weighted sum of all stage outputs for ``x``."""
        gate = self.conv_gate(x)

        out = None
        feat = x
        for i, stage in enumerate(self.Convs):
            feat = stage(feat)
            # gate[:, i:i+1] keeps the channel axis so it broadcasts over
            # the out_ch feature channels.
            weighted = feat * gate[:, i:i + 1]
            out = weighted if out is None else out + weighted

        # Fix: return the gated mixture, not the last stage's raw features.
        return out

import math 
from torch.nn.modules.batchnorm import _BatchNorm

class EMAU(nn.Module):
    '''The Expectation-Maximization Attention Unit (EMAU).

    Keeps ``k`` bases of dimension ``c`` in the buffer ``mu`` (shape
    1 x c x k). In ``forward`` the bases are refined against the current
    features for ``stage_num`` EM iterations, the features are rebuilt from
    the refined bases, passed through a 1x1 conv + BatchNorm and added back
    to the input.

    Arguments:
        c (int): The input and output channel number.
        k (int): The number of the bases.
        stage_num (int): The iteration number for EM.
    '''
    def __init__(self, c, k, stage_num=3):
        super(EMAU, self).__init__()
        self.stage_num = stage_num

        # Bases: normal init with std sqrt(2 / k), then unit l2-norm along
        # the channel axis. Stored as a buffer, i.e. not a trainable parameter.
        mu = torch.Tensor(1, c, k)
        mu.normal_(0, math.sqrt(2. / k))    # Init with Kaiming Norm.
        mu = self._l2norm(mu, dim=1)
        self.register_buffer('mu', mu)

        # 1x1 conv before the EM attention; 1x1 conv (no bias) + BN after it.
        self.conv1 = nn.Conv2d(c, c, 1)
        self.conv2 = nn.Sequential(
            nn.Conv2d(c, c, 1, bias=False),
            nn.BatchNorm2d(c))        
        
        # Re-initialise the sub-modules: conv weights ~ N(0, sqrt(2 / n)) with
        # n = kernel_h * kernel_w * out_channels; BN weight 1 and bias 0.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, _BatchNorm):
                m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()
 

    def forward(self, x):
        '''Apply EM attention to ``x`` and return a tensor of the same shape.

        x: B*C*H*W
        '''
        # Kept for the residual connection at the end.
        idn = x
        # The first 1x1 conv
        x = self.conv1(x)

        # The EM Attention
        b, c, h, w = x.size()
        x = x.view(b, c, h*w)               # b * c * n
        mu = self.mu.repeat(b, 1, 1)        # b * c * k
        # The EM iterations run without autograd tracking.
        # NOTE(review): `z` is only bound inside this loop, so stage_num must
        # be >= 1 or the code after the loop fails with a NameError.
        with torch.no_grad():
            for i in range(self.stage_num):
                # E-step: soft assignment of every position to the k bases.
                x_t = x.permute(0, 2, 1)    # b * n * c
                z = torch.bmm(x_t, mu)      # b * n * k
                z = F.softmax(z, dim=2)     # b * n * k
                # M-step: bases become assignment-weighted means of the
                # features (weights normalised over the n positions).
                z_ = z / (1e-6 + z.sum(dim=1, keepdim=True))
                mu = torch.bmm(x, z_)       # b * c * k
                mu = self._l2norm(mu, dim=1)
        # Rewritten by Lin.
        # The original note here said the moving-average update lives in
        # train.py; in this version it is applied right here instead: while
        # training, the stored bases are blended (0.9 old / 0.1 new) with the
        # batch mean of the refined bases.
        if self.training:
            self.mu = 0.9*self.mu + mu.mean(dim=0, keepdim=True)*(1-0.9)
            

        # Rebuild the features from the refined bases and the assignments.
        # NOTE(review): `mu` and `z` both come out of the no_grad block, so
        # this reconstruction appears detached from conv1 - confirm intended.
        z_t = z.permute(0, 2, 1)            # b * k * n
        x = mu.matmul(z_t)                  # b * c * n
        x = x.view(b, c, h, w)              # b * c * h * w
        x = F.relu(x, inplace=True)

        # The second 1x1 conv
        x = self.conv2(x)
        # Residual connection followed by the final ReLU.
        x = x + idn
        x = F.relu(x, inplace=True)

        return x

    def _l2norm(self, inp, dim):
        '''Normalize the inp tensor with l2-norm.
        Returns a tensor where each sub-tensor of input along the given dim is 
        normalized such that the 2-norm of the sub-tensor is equal to 1.
        A small constant (1e-6) is added to the norm to avoid division by zero.
        Arguments:
            inp (tensor): The input tensor.
            dim (int): The dimension to slice over to get the sub-tensors.
        Returns:
            (tensor) The normalized tensor.
        '''
        return inp / (1e-6 + inp.norm(dim=dim, keepdim=True))


class UNet_EMAU(nn.Module):
    """
    UNet with an EM attention unit in front of the output projection.

    Same encoder/decoder layout as the basic UNet
    (https://arxiv.org/abs/1505.04597) with widths 64..1024; the last decoder
    feature map goes through ``EMAU(c=64, k=32)`` before the final 1x1
    convolution. No activation is applied to the returned map.

    Args:
        in_ch: channels of the input tensor.
        out_ch: channels of the returned map.
        kernel_size: kernel size of every conv_block. The up_conv layers are
            built without it and therefore use their own default kernel size.
    """
    def __init__(self, in_ch=3, out_ch=1, kernel_size=3):
        super(UNet_EMAU, self).__init__()

        base = 64
        c1, c2, c3, c4, c5 = [base * mult for mult in (1, 2, 4, 8, 16)]
        self.ks = kernel_size

        # One 2x2 max-pool per encoder transition.
        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder blocks.
        self.Conv1 = conv_block(in_ch, c1, kernel_size=self.ks)
        self.Conv2 = conv_block(c1, c2, kernel_size=self.ks)
        self.Conv3 = conv_block(c2, c3, kernel_size=self.ks)
        self.Conv4 = conv_block(c3, c4, kernel_size=self.ks)
        self.Conv5 = conv_block(c4, c5, kernel_size=self.ks)

        # Decoder: upsample, then fuse with the concatenated skip tensor.
        self.Up5 = up_conv(c5, c4)
        self.Up_conv5 = conv_block(c5, c4, kernel_size=self.ks)

        self.Up4 = up_conv(c4, c3)
        self.Up_conv4 = conv_block(c4, c3, kernel_size=self.ks)

        self.Up3 = up_conv(c3, c2)
        self.Up_conv3 = conv_block(c3, c2, kernel_size=self.ks)

        self.Up2 = up_conv(c2, c1)
        self.Up_conv2 = conv_block(c2, c1, kernel_size=self.ks)

        # EM attention on the full-resolution decoder features.
        self.em = EMAU(c=c1, k=32)

        # 1x1 projection to the requested number of output channels.
        self.Conv = nn.Conv2d(c1, out_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return the un-activated output map for the input batch ``x``."""
        # Contracting path.
        e1 = self.Conv1(x)
        e2 = self.Conv2(self.Maxpool1(e1))
        e3 = self.Conv3(self.Maxpool2(e2))
        e4 = self.Conv4(self.Maxpool3(e3))
        e5 = self.Conv5(self.Maxpool4(e4))

        # Expanding path; the skip tensor always comes first in the concat.
        d5 = self.Up_conv5(torch.cat((e4, self.Up5(e5)), dim=1))
        d4 = self.Up_conv4(torch.cat((e3, self.Up4(d5)), dim=1))
        d3 = self.Up_conv3(torch.cat((e2, self.Up3(d4)), dim=1))
        d2 = self.Up_conv2(torch.cat((e1, self.Up2(d3)), dim=1))

        # Attention first, projection second.
        return self.Conv(self.em(d2))









class Adapt_UNet(nn.Module):
    """
    UNet whose encoder and decoder blocks are adaptive_conv_block modules.

    Layout and channel widths (64..1024) follow the basic UNet
    (https://arxiv.org/abs/1505.04597). The result of the final 1x1
    convolution is returned without an activation.

    Args:
        in_ch: channels of the input tensor.
        out_ch: channels of the returned map.
        adapt_layers: ``layers`` argument forwarded to every
            adaptive_conv_block.
    """
    def __init__(self, in_ch=3, out_ch=1,adapt_layers = 2):
        super(Adapt_UNet, self).__init__()

        base = 64
        c1, c2, c3, c4, c5 = [base * mult for mult in (1, 2, 4, 8, 16)]

        # One 2x2 max-pool per encoder transition.
        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder blocks.
        self.Conv1 = adaptive_conv_block(in_ch, c1, layers=adapt_layers)
        self.Conv2 = adaptive_conv_block(c1, c2, layers=adapt_layers)
        self.Conv3 = adaptive_conv_block(c2, c3, layers=adapt_layers)
        self.Conv4 = adaptive_conv_block(c3, c4, layers=adapt_layers)
        self.Conv5 = adaptive_conv_block(c4, c5, layers=adapt_layers)

        # Decoder: upsample, then fuse with the concatenated skip tensor.
        self.Up5 = up_conv(c5, c4)
        self.Up_conv5 = adaptive_conv_block(c5, c4, layers=adapt_layers)

        self.Up4 = up_conv(c4, c3)
        self.Up_conv4 = adaptive_conv_block(c4, c3, layers=adapt_layers)

        self.Up3 = up_conv(c3, c2)
        self.Up_conv3 = adaptive_conv_block(c3, c2, layers=adapt_layers)

        self.Up2 = up_conv(c2, c1)
        self.Up_conv2 = adaptive_conv_block(c2, c1, layers=adapt_layers)

        # 1x1 projection to the requested number of output channels.
        self.Conv = nn.Conv2d(c1, out_ch, kernel_size=1, stride=1, padding=0)

        # Plain upsamplers; forward() in this class does not use them.
        self._Up2 = nn.Upsample(scale_factor=2)
        self._Up4 = nn.Upsample(scale_factor=4)
        self._Up8 = nn.Upsample(scale_factor=8)
        self._Up16 = nn.Upsample(scale_factor=16)

    def forward(self, x):
        """Return the un-activated output map for the input batch ``x``."""
        # Contracting path.
        e1 = self.Conv1(x)
        e2 = self.Conv2(self.Maxpool1(e1))
        e3 = self.Conv3(self.Maxpool2(e2))
        e4 = self.Conv4(self.Maxpool3(e3))
        e5 = self.Conv5(self.Maxpool4(e4))

        # Expanding path; the skip tensor always comes first in the concat.
        d5 = self.Up_conv5(torch.cat((e4, self.Up5(e5)), dim=1))
        d4 = self.Up_conv4(torch.cat((e3, self.Up4(d5)), dim=1))
        d3 = self.Up_conv3(torch.cat((e2, self.Up3(d4)), dim=1))
        d2 = self.Up_conv2(torch.cat((e1, self.Up2(d3)), dim=1))

        return self.Conv(d2)

 



class Adapt_UNet_merge(nn.Module):
    """
    Adaptive-block UNet with a multi-scale output head.

    Encoder/decoder layout follows the basic UNet
    (https://arxiv.org/abs/1505.04597) built from adaptive_conv_block
    modules. Instead of projecting only the last decoder map, all four
    decoder maps are upsampled to full resolution, concatenated
    (64 + 128 + 256 + 512 channels) and fed to a
    1x1 conv -> BN -> ReLU -> 1x1 conv head. No final activation is applied.

    Args:
        in_ch: channels of the input tensor.
        out_ch: channels of the returned map.
    """
    def __init__(self, in_ch=3, out_ch=1):
        super(Adapt_UNet_merge, self).__init__()

        base = 64
        c1, c2, c3, c4, c5 = [base * mult for mult in (1, 2, 4, 8, 16)]

        # One 2x2 max-pool per encoder transition.
        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder blocks.
        self.Conv1 = adaptive_conv_block(in_ch, c1)
        self.Conv2 = adaptive_conv_block(c1, c2)
        self.Conv3 = adaptive_conv_block(c2, c3)
        self.Conv4 = adaptive_conv_block(c3, c4)
        self.Conv5 = adaptive_conv_block(c4, c5)

        # Decoder: upsample, then fuse with the concatenated skip tensor.
        self.Up5 = up_conv(c5, c4)
        self.Up_conv5 = adaptive_conv_block(c5, c4)

        self.Up4 = up_conv(c4, c3)
        self.Up_conv4 = adaptive_conv_block(c4, c3)

        self.Up3 = up_conv(c3, c2)
        self.Up_conv3 = adaptive_conv_block(c3, c2)

        self.Up2 = up_conv(c2, c1)
        self.Up_conv2 = adaptive_conv_block(c2, c1)

        # Head over the concatenation of all four decoder maps.
        merged_ch = c1 + c2 + c3 + c4
        self.Conv = nn.Sequential(
            nn.Conv2d(merged_ch, c1, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(c1),
            nn.ReLU(),
            nn.Conv2d(c1, out_ch, kernel_size=1, stride=1, padding=0),
        )

        # Upsamplers that bring the coarser decoder maps to full resolution
        # (_Up16 is not used by forward()).
        self._Up2 = nn.Upsample(scale_factor=2)
        self._Up4 = nn.Upsample(scale_factor=4)
        self._Up8 = nn.Upsample(scale_factor=8)
        self._Up16 = nn.Upsample(scale_factor=16)

    def forward(self, x):
        """Return the un-activated output map for the input batch ``x``."""
        # Contracting path.
        e1 = self.Conv1(x)
        e2 = self.Conv2(self.Maxpool1(e1))
        e3 = self.Conv3(self.Maxpool2(e2))
        e4 = self.Conv4(self.Maxpool3(e3))
        e5 = self.Conv5(self.Maxpool4(e4))

        # Expanding path; the skip tensor always comes first in the concat.
        d5 = self.Up_conv5(torch.cat((e4, self.Up5(e5)), dim=1))
        d4 = self.Up_conv4(torch.cat((e3, self.Up4(d5)), dim=1))
        d3 = self.Up_conv3(torch.cat((e2, self.Up3(d4)), dim=1))
        d2 = self.Up_conv2(torch.cat((e1, self.Up2(d3)), dim=1))

        # Align every decoder level with d2 and merge them, finest first.
        full_res = [d2, self._Up2(d3), self._Up4(d4), self._Up8(d5)]
        return self.Conv(torch.cat(full_res, dim=1))

 


class adaptive_conv_block_weighMul(nn.Module):
    """
    Two-branch convolution block mixed by a per-pixel gate.

    Branch 1 is a double (conv -> BN -> ReLU) with ``kernel_size`` kernels,
    branch 2 a single 1x1 (conv -> BN -> ReLU). A gate network predicts a
    two-way softmax for every pixel and the block returns
    ``branch1 * gate[0] + branch2 * gate[1]``.

    The gate path pools by 2 and upsamples by 2 again, so the input height
    and width must be even for the gate to line up with the branches.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels of both branches. The gate's
            hidden width is ``out_ch // 6``.
        kernel_size: kernel size of the first branch.
        layers: length of the per-branch width list. Two entries are read
            and the gate always has two channels, so it must be >= 2.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3,layers = 2):
        super(adaptive_conv_block_weighMul, self).__init__()

        same_pad = kernel_size // 2
        gate_ch = out_ch // 6

        # Gate: computed at half resolution, upsampled back, softmax over
        # the two branches.
        self.conv_gate = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_ch, gate_ch, kernel_size=7, stride=1, padding=3, bias=True),
            nn.BatchNorm2d(gate_ch),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(gate_ch, 2, kernel_size=3, stride=1, padding=1, bias=True),
            nn.Softmax(dim=1),
        )
        self.layers = layers
        branch_widths = [out_ch] * self.layers

        # Branch 1: two stacked kernel_size convolutions.
        deep_ch = branch_widths[0]
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_ch, deep_ch, kernel_size=kernel_size,
                      stride=1, padding=same_pad, bias=True),
            nn.BatchNorm2d(deep_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(deep_ch, deep_ch, kernel_size=kernel_size,
                      stride=1, padding=same_pad, bias=True),
            nn.BatchNorm2d(deep_ch),
            nn.ReLU(inplace=True),
        )

        # Branch 2: a single 1x1 convolution. (Earlier experiments, since
        # disabled, used 1 x (k+4) and (k+4) x 1 kernels for extra branches.)
        shallow_ch = branch_widths[1]
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_ch, shallow_ch, kernel_size=1,
                      stride=1, padding=0, bias=True),
            nn.BatchNorm2d(shallow_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Return the gate-weighted sum of the two branches for ``x``."""
        # The values are unused; the unpacking doubles as a 4-D input check.
        _b, _c, _h, _w = x.shape

        gate = self.conv_gate(x)
        deep = self.conv1(x)
        shallow = self.conv2(x)
        # Slicing with 0:1 / 1:2 keeps the channel axis for broadcasting.
        return deep * gate[:, 0:1] + shallow * gate[:, 1:2]


class Adapt_UNet_weightMul(nn.Module):
    """
    UNet built from adaptive_conv_block_weighMul modules.

    Layout follows the basic UNet (https://arxiv.org/abs/1505.04597) but with
    a base width of 60, i.e. channel widths 60, 120, 240, 480 and 960. The
    result of the final 1x1 convolution is returned without an activation.

    Args:
        in_ch: channels of the input tensor.
        out_ch: channels of the returned map.
    """
    def __init__(self, in_ch=3, out_ch=1):
        super(Adapt_UNet_weightMul, self).__init__()

        base = 60
        c1, c2, c3, c4, c5 = [base * mult for mult in (1, 2, 4, 8, 16)]

        # One 2x2 max-pool per encoder transition.
        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder blocks.
        self.Conv1 = adaptive_conv_block_weighMul(in_ch, c1)
        self.Conv2 = adaptive_conv_block_weighMul(c1, c2)
        self.Conv3 = adaptive_conv_block_weighMul(c2, c3)
        self.Conv4 = adaptive_conv_block_weighMul(c3, c4)
        self.Conv5 = adaptive_conv_block_weighMul(c4, c5)

        # Decoder: upsample, then fuse with the concatenated skip tensor.
        self.Up5 = up_conv(c5, c4)
        self.Up_conv5 = adaptive_conv_block_weighMul(c5, c4)

        self.Up4 = up_conv(c4, c3)
        self.Up_conv4 = adaptive_conv_block_weighMul(c4, c3)

        self.Up3 = up_conv(c3, c2)
        self.Up_conv3 = adaptive_conv_block_weighMul(c3, c2)

        self.Up2 = up_conv(c2, c1)
        self.Up_conv2 = adaptive_conv_block_weighMul(c2, c1)

        # 1x1 projection, wrapped in a Sequential (state_dict key `Conv.0.*`).
        projection = nn.Conv2d(c1, out_ch, kernel_size=1, stride=1, padding=0)
        self.Conv = nn.Sequential(projection)

        # Plain upsamplers; forward() in this class does not use them.
        self._Up2 = nn.Upsample(scale_factor=2)
        self._Up4 = nn.Upsample(scale_factor=4)
        self._Up8 = nn.Upsample(scale_factor=8)
        self._Up16 = nn.Upsample(scale_factor=16)

    def forward(self, x):
        """Return the un-activated output map for the input batch ``x``."""
        # Contracting path.
        e1 = self.Conv1(x)
        e2 = self.Conv2(self.Maxpool1(e1))
        e3 = self.Conv3(self.Maxpool2(e2))
        e4 = self.Conv4(self.Maxpool3(e3))
        e5 = self.Conv5(self.Maxpool4(e4))

        # Expanding path; the skip tensor always comes first in the concat.
        d5 = self.Up_conv5(torch.cat((e4, self.Up5(e5)), dim=1))
        d4 = self.Up_conv4(torch.cat((e3, self.Up4(d5)), dim=1))
        d3 = self.Up_conv3(torch.cat((e2, self.Up3(d4)), dim=1))
        d2 = self.Up_conv2(torch.cat((e1, self.Up2(d3)), dim=1))

        return self.Conv(d2)

 


class UNet_refine(nn.Module):
    """
    Gated multi-scale encoder without a decoder path.

    The five encoder feature maps (widths 64..1024) are upsampled to the
    input resolution, each multiplied by its own channel of a per-pixel
    5-way softmax gate computed from the input image, concatenated and
    projected by a 1x1 conv -> BN -> ReLU -> 1x1 conv head. No final
    activation is applied.

    The input height and width must be multiples of 16 so that the
    upsampled maps line up with the full-resolution gate.

    Args:
        in_ch: channels of the input tensor.
        out_ch: channels of the returned map.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(UNet_refine, self).__init__()

        n1 = 64
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # NOTE(review): forward() runs the blocks in self.ms, not these five.
        # They are kept so existing checkpoints (state_dict keys) still load.
        self.Conv1 = conv_block(in_ch, filters[0])
        self.Conv2 = conv_block(filters[0], filters[1])
        self.Conv3 = conv_block(filters[1], filters[2])
        self.Conv4 = conv_block(filters[2], filters[3])
        self.Conv5 = conv_block(filters[3], filters[4])

        # Output head over the concatenation of all five gated scales.
        # NOTE(review): PyTorch's BatchNorm `momentum` weights the *new*
        # batch statistic; 0.95 looks like a decay value - confirm.
        self.final = nn.Sequential(nn.Conv2d(sum(filters), 256,kernel_size=1, stride=1, padding=0),
                                    nn.BatchNorm2d(256, momentum=0.95),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(256, out_ch,kernel_size=1, stride=1, padding=0),
                              )

        # Gate network: full-resolution softmax over the five scales.
        self.Conv_gate = nn.Sequential(nn.Conv2d(in_ch, 128,kernel_size=7, stride=1, padding=3),
                                    nn.BatchNorm2d(128, momentum=0.95),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(128, 128,kernel_size=3, stride=1, padding=1),
                                    nn.BatchNorm2d(128, momentum=0.95),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(128, 5,kernel_size=3, stride=1, padding=1),
                                    nn.Softmax(dim=1)
                              )

        # Number of encoder passes. Pass 0 is a plain encoder; every later
        # pass re-encodes each level from its previous features concatenated
        # with the pooled, freshly refined level above it.
        self.recursive = 1
        self.ms = nn.ModuleList()
        for i in range(self.recursive):
            if i ==0:
                self.ms.append(nn.ModuleList([conv_block(in_ch, filters[0]),
                                            conv_block(filters[0], filters[1]),
                                            conv_block(filters[1], filters[2]),
                                            conv_block(filters[2], filters[3]),
                                            conv_block(filters[3], filters[4])]))
            else:
                self.ms.append(nn.ModuleList([conv_block(filters[0], filters[0]),
                                            conv_block(filters[0]+filters[1], filters[1]),
                                            conv_block(filters[1]+filters[2], filters[2]),
                                            conv_block(filters[2]+filters[3], filters[3]),
                                            conv_block(filters[3]+filters[4], filters[4])]))

        # Nearest-neighbour upsamplers bringing each level to full resolution.
        self._Up2 = nn.Upsample(scale_factor=2)
        self._Up4 = nn.Upsample(scale_factor=4)
        self._Up8 = nn.Upsample(scale_factor=8)
        self._Up16 = nn.Upsample(scale_factor=16)

        # Encoder features of the most recent forward() call, at their
        # native resolutions; None until forward() has run.
        self.es = None

        # Spatial size of the most recent input.
        self.H,self.W = None,None

    def forward(self, x):
        """Return the un-activated output map for ``x`` and cache e1..e4 in ``self.es``."""
        _,_,self.H,self.W = x.shape

        for i in range(self.recursive):
            if i == 0:
                e1 = self.ms[i][0](x)

                e2 = self.Maxpool(e1)
                e2 = self.ms[i][1](e2)

                e3 = self.Maxpool(e2)
                e3 = self.ms[i][2](e3)

                e4 = self.Maxpool(e3)
                e4 = self.ms[i][3](e4)

                e5 = self.Maxpool(e4)
                e5 = self.ms[i][4](e5)
            else:
                e1 = self.ms[i][0](e1)

                e2 = self.ms[i][1](torch.cat([e2,self.Maxpool(e1)],dim = 1))

                e3 = self.ms[i][2](torch.cat([e3,self.Maxpool(e2)],dim = 1))

                e4 = self.ms[i][3](torch.cat([e4,self.Maxpool(e3)],dim = 1))

                e5 = self.ms[i][4](torch.cat([e5,self.Maxpool(e4)],dim = 1))

        gate = self.Conv_gate(x)

        # Each scale is brought to full resolution and weighted by its own
        # gate channel (the k:k+1 slice keeps the axis for broadcasting).
        out = torch.cat([e1*gate[:,0:1],self._Up2(e2)*gate[:,1:2],self._Up4(e3)*gate[:,2:3],self._Up8(e4)*gate[:,3:4],self._Up16(e5)*gate[:,4:5]],dim = 1)
        out = self.final(out)

        self.es = [e1,e2,e3,e4]

        return out

    def get_es(self,):
        """Return the cached encoder features e1..e4 upsampled to e1's resolution.

        Must be called after ``forward``. The cache in ``self.es`` is left
        untouched: earlier revisions overwrote it with the upsampled maps, so
        a second call upsampled them again and returned wrong sizes.
        """
        feats = self.es
        return [feats[0], self._Up2(feats[1]), self._Up4(feats[2]), self._Up8(feats[3])]

    def get_att(self,):
        """Return an attention map reshaped to ``(B, 1, -1, H, W)``.

        NOTE(review): this reads ``self.Conv1.att_conv_1.att``, but Conv1 is
        built as a plain conv_block in ``__init__``; unless that class
        provides ``att_conv_1`` this raises AttributeError. It looks like a
        leftover from an attention-based variant - confirm before relying
        on it.
        """
        att = self.Conv1.att_conv_1.att[:,:1]
        att = att.reshape(att.shape[0],1,-1,self.H,self.W)
        return att



class UNet_grid(nn.Module):
    """
    Five-level encoder whose multi-scale features are fused by a learned gate.

    There is no decoder path.  Each encoder level is produced by an
    ``adaptive_conv_block`` (levels 2-5 after a 2x2 max-pool), every level is
    upsampled back to the resolution of level 1, and the five maps are
    combined as follows:

    1. the concatenated maps are fed to ``Conv_gate``, which ends in a softmax
       over 5 channels, i.e. one weight per level at every pixel;
    2. each level is multiplied by its weight and the weighted maps are
       concatenated again;
    3. ``final`` (1x1 conv -> BN -> ReLU -> 1x1 conv) maps the result to
       ``out_ch`` channels.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.

    NOTE(review): the BatchNorm layers use ``momentum=0.95``.  In PyTorch the
    momentum is the weight given to the *new* batch statistic, so the running
    statistics follow the most recent batch closely; confirm this is intended.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(UNet_grid, self).__init__()

        n1 = 64
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: one block per resolution level.
        self.Conv1 = adaptive_conv_block(in_ch, filters[0])
        self.Conv2 = adaptive_conv_block(filters[0], filters[1])
        self.Conv3 = adaptive_conv_block(filters[1], filters[2])
        self.Conv4 = adaptive_conv_block(filters[2], filters[3])
        self.Conv5 = adaptive_conv_block(filters[3], filters[4])

        # Output head applied to the gated, concatenated features.
        self.final = nn.Sequential(
            nn.Conv2d(sum(filters), 256, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(256, momentum=0.95),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, out_ch, kernel_size=1, stride=1, padding=0),
        )

        # Predicts one softmax-normalised weight per encoder level and pixel.
        self.Conv_gate = nn.Sequential(
            nn.Conv2d(sum(filters), 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128, momentum=0.95),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128, momentum=0.95),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 5, kernel_size=1, stride=1, padding=0),
            nn.Softmax(dim=1),
        )

        self._Up2 = nn.Upsample(scale_factor=2)
        self._Up4 = nn.Upsample(scale_factor=4)
        self._Up8 = nn.Upsample(scale_factor=8)
        self._Up16 = nn.Upsample(scale_factor=16)

        self.es = None

        # Spatial size of the most recent input, recorded by forward().
        self.H, self.W = None, None

    def forward(self, x):
        """
        Encode ``x`` at five scales and return the gated fusion of all levels.

        Args:
            x: Input batch, unpacked as a 4-D tensor.  The upsampled levels are
               concatenated with level 1, so (assuming ``adaptive_conv_block``
               preserves spatial size) height and width need to be divisible
               by 16.

        Returns:
            Output of ``self.final`` on the gated multi-scale features.
        """
        _, _, self.H, self.W = x.shape

        e1 = self.Conv1(x)
        e2 = self.Conv2(self.Maxpool(e1))
        e3 = self.Conv3(self.Maxpool(e2))
        e4 = self.Conv4(self.Maxpool(e3))
        e5 = self.Conv5(self.Maxpool(e4))

        # Upsample each level once and reuse the result for both the gate
        # input and the gated output instead of upsampling everything twice.
        pyramid = [
            e1,
            self._Up2(e2),
            self._Up4(e3),
            self._Up8(e4),
            self._Up16(e5),
        ]

        gate = self.Conv_gate(torch.cat(pyramid, dim=1))

        # Channel k of the gate weights level k+1.
        gated = [feat * gate[:, idx:idx + 1] for idx, feat in enumerate(pyramid)]
        out = self.final(torch.cat(gated, dim=1))

        return out






class kernel_select(nn.Module):
    """
    Blend three parallel convolution branches with per-pixel weights.

    Each branch is ``Conv(k) -> BN -> ReLU -> Conv(k) -> BN -> ReLU`` with
    ``k`` equal to 1, 3 and 5 respectively.  A small network applied to a
    separate guide tensor (7x7 conv -> BN -> ReLU -> 1x1 conv) emits one
    weight map per branch.  The weights are the raw outputs of that last
    convolution; no softmax or other normalisation is applied.

    Args:
        in_dim: Channels of the tensor fed to the branches.
        guide_dim: Channels of the guide tensor.
        out: Output channels of every branch (the hidden width is ``out // 2``).
    """

    def __init__(self,in_dim,guide_dim,out):
        super(kernel_select, self).__init__()

        self.kernel_sizes = [[1, 1], [3, 3], [5, 5]]
        hidden = out // 2

        self.kernel_gen = nn.Sequential(
            nn.Conv2d(guide_dim, guide_dim, kernel_size=7, padding=7 // 2, bias=True),
            nn.BatchNorm2d(guide_dim),
            nn.ReLU(),
            nn.Conv2d(guide_dim, len(self.kernel_sizes), kernel_size=1, padding=0, bias=True),
        )

        def build_branch(first_k, second_k):
            # Two "same"-padded conv stages; padding k // 2 keeps H and W.
            return nn.Sequential(
                nn.Conv2d(in_dim, hidden, kernel_size=first_k, padding=first_k // 2, bias=True),
                nn.BatchNorm2d(hidden),
                nn.ReLU(),
                nn.Conv2d(hidden, out, kernel_size=second_k, padding=second_k // 2, bias=True),
                nn.BatchNorm2d(out),
                nn.ReLU(),
            )

        self.kernels = nn.ModuleList(
            [build_branch(first_k, second_k) for first_k, second_k in self.kernel_sizes]
        )

    def forward(self,x_input,x_guide):
        """
        Apply every branch to ``x_input`` and sum them, weighted by the guide.

        Args:
            x_input: Tensor of shape ``B x in_dim x H x W``.
            x_guide: Guide tensor passed to ``self.kernel_gen``; its output
                     must broadcast against the branch outputs.

        Returns:
            Weighted sum of the branch outputs.
        """
        # Unpacking four values rejects anything that is not a 4-D tensor.
        _batch, _channels, _height, _width = x_input.shape

        weights = self.kernel_gen(x_guide)  # one weight channel per branch

        weighted = [
            weights[:, idx:idx + 1] * self.kernels[idx](x_input)
            for idx, _size in enumerate(self.kernel_sizes)
        ]

        return torch.stack(weighted, dim=1).sum(dim=1)





# class UNet_parallel(nn.Module):
#     """
#     UNet - Basic Implementation
#     Paper : https://arxiv.org/abs/1505.04597
#     """

#     def __init__(self, in_ch=3, out_ch=1):
#         super(UNet_parallel, self).__init__()

#         n1 = 64
#         filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]

#         self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

#         self.K_Conv1 = conv_block(in_ch, filters[0],kernel_size=5)
#         self.K_Conv2 = conv_block(filters[0], filters[1],kernel_size=5)
#         self.K_Conv3 = conv_block(filters[1], filters[2],kernel_size=5)
#         self.K_Conv4 = conv_block(filters[2], filters[3],kernel_size=5)
#         self.K_Conv5 = conv_block(filters[3], filters[4],kernel_size=5)


#         self.Conv1 = kernel_select(in_ch, filters[0],out=filters[0])
#         self.Conv2 = kernel_select(filters[0], filters[1],out=filters[1])
#         self.Conv3 = kernel_select(filters[1], filters[2],out=filters[2])
#         self.Conv4 = kernel_select(filters[2], filters[3],out=filters[3])
#         self.Conv5 = kernel_select(filters[3], filters[4],out=filters[4])



#         # self.Conv = nn.Conv2d(sum(filters), out_ch,
#                             #   kernel_size=1, stride=1, padding=0)
#         self.final = nn.Sequential(nn.Conv2d(sum(filters), 256,kernel_size=1, stride=1, padding=0),
#                                     nn.BatchNorm2d(256, momentum=0.95),
#                                     nn.ReLU(inplace=True),
#                                     nn.Conv2d(256, out_ch,kernel_size=1, stride=1, padding=0),
#                               )

#         self._Up2 = nn.Upsample(scale_factor=2)
#         self._Up4 = nn.Upsample(scale_factor=4)
#         self._Up8 = nn.Upsample(scale_factor=8)
#         self._Up16 = nn.Upsample(scale_factor=16)

#         # self.active = torch.nn.Softmax(dim=1)

#     def forward(self, x):
#         g1 = self.K_Conv1(x)

#         g2 = self.Maxpool(g1)
#         g2 = self.K_Conv2(g2)

#         g3 = self.Maxpool(g2)
#         g3 = self.K_Conv3(g3)
 
#         g4 = self.Maxpool(g3)
#         g4 = self.K_Conv4(g4)

#         g5 = self.Maxpool(g4)
#         g5 = self.K_Conv5(g5)


#         e1 = self.Conv1(x_input=x,x_guide=g1)

#         e2 = self.Maxpool(e1)
#         e2 = self.Conv2(x_input=e2,x_guide=g2)

#         e3 = self.Maxpool(e2)
#         e3 = self.Conv3(x_input=e3,x_guide=g3)

#         e4 = self.Maxpool(e3)
#         e4 = self.Conv4(x_input=e4,x_guide=g4)

#         e5 = self.Maxpool(e4)
#         e5 = self.Conv5(x_input=e5,x_guide=g5)

#         # d5 = self.Up5(e5)
#         # d5 = self.Up_conv5(x_input=e4,x_guide=d5)

#         # d4 = self.Up4(d5)
#         # d4 = self.Up_conv4(x_input=e3,x_guide=d4)

#         # d3 = self.Up3(d4)
#         # d3 = self.Up_conv3(x_input=e2,x_guide=d3)

#         # d2 = self.Up2(d3)
#         # d2 = self.Up_conv2(x_input=e1,x_guide=d2)

#         ds = torch.cat([self._Up16(e5),self._Up8(e4),self._Up4(e3),self._Up2(e2),e1],dim = 1)
#         out = self.final(ds)
#         # out = self.Conv(ds)

#         # d1 = self.active(out)

#         return out



class conv(nn.Module):
    """
    Single 3x3 convolution followed by batch-norm and ReLU.

    Despite the attribute name ``up``, nothing is upsampled: stride 1 with
    padding 1 keeps the spatial size unchanged.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
    """

    def __init__(self, in_ch, out_ch):
        super(conv, self).__init__()
        stages = [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        # The attribute keeps the name "up" so state-dict keys stay the same.
        self.up = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` after the conv -> BN -> ReLU stack."""
        return self.up(x)


class UNet_WO_Pool(nn.Module):
    """
    U-Net-shaped network without pooling or upsampling.

    The usual encoder / decoder wiring with skip connections is kept, but no
    stage changes the resolution: the encoder blocks are chained directly and
    the decoder uses plain ``conv`` blocks where a U-Net would upsample.
    Every stage is 64 channels wide.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels of the final 1x1 convolution.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(UNet_WO_Pool, self).__init__()

        width = 64        # channel count shared by all stages
        fused = width * 2  # channels after concatenating a skip connection

        # Encoder.
        self.Conv1 = conv_block(in_ch, width)
        self.Conv2 = conv_block(width, width)
        self.Conv3 = conv_block(width, width)
        self.Conv4 = conv_block(width, width)
        self.Conv5 = conv_block(width, width)

        # Decoder: a resolution-preserving conv, then fusion with the skip.
        self.Up5 = conv(width, width)
        self.Up_conv5 = conv_block(fused, width)

        self.Up4 = conv(width, width)
        self.Up_conv4 = conv_block(fused, width)

        self.Up3 = conv(width, width)
        self.Up_conv3 = conv_block(fused, width)

        self.Up2 = conv(width, width)
        self.Up_conv2 = conv_block(fused, width)

        self.Conv = nn.Conv2d(width, out_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """
        Run the encoder chain, then the decoder with skip connections.

        Args:
            x: Input batch with ``in_ch`` channels.

        Returns:
            Output of the final 1x1 convolution; no activation is applied.
        """
        # Encoder: blocks are chained directly, with no pooling in between.
        e1 = self.Conv1(x)
        e2 = self.Conv2(e1)
        e3 = self.Conv3(e2)
        e4 = self.Conv4(e3)
        e5 = self.Conv5(e4)

        # Decoder: each step concatenates the matching encoder feature (first)
        # with the processed deeper feature (second) along the channel axis.
        d5 = self.Up_conv5(torch.cat((e4, self.Up5(e5)), dim=1))
        d4 = self.Up_conv4(torch.cat((e3, self.Up4(d5)), dim=1))
        d3 = self.Up_conv3(torch.cat((e2, self.Up3(d4)), dim=1))
        d2 = self.Up_conv2(torch.cat((e1, self.Up2(d3)), dim=1))

        return self.Conv(d2)




class TranUNet(nn.Module):
    """
    U-Net with transformer encoder layers applied at the bottleneck.

    The convolutional encoder and decoder follow the standard U-Net layout
    (https://arxiv.org/abs/1505.04597).  Between them, the deepest feature map
    is flattened into a sequence of ``h * w`` tokens, offset by the learned
    vector ``pe`` and passed through ``self.layers`` transformer encoder
    layers before being folded back into a feature map.

    ``pe`` has shape ``(1, 1, C)``, so the same offset is broadcast to every
    token and every batch element.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels of the final 1x1 convolution.
        n1: Channel width of the first level; deeper levels use 2x, 4x, 8x
            and 16x this value.
    """
    def __init__(self, in_ch=3, out_ch=1,n1=64):
        super(TranUNet, self).__init__()

        c1, c2, c3, c4, c5 = [n1 * mult for mult in (1, 2, 4, 8, 16)]

        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Convolutional encoder.
        self.Conv1 = conv_block(in_ch, c1)
        self.Conv2 = conv_block(c1, c2)
        self.Conv3 = conv_block(c2, c3)
        self.Conv4 = conv_block(c3, c4)
        self.Conv5 = conv_block(c4, c5)

        # Bottleneck transformer stack.
        self.layers = 2
        self.pe = nn.Parameter(torch.randn(1, 1, c5))
        self.transformers = nn.ModuleList(
            [nn.TransformerEncoderLayer(d_model=c5, nhead=2) for _ in range(self.layers)]
        )

        # Decoder: each upsampling halves the channels, and the skip
        # concatenation restores them before the following conv block.
        self.Up5 = up_conv(c5, c4)
        self.Up_conv5 = conv_block(c5, c4)

        self.Up4 = up_conv(c4, c3)
        self.Up_conv4 = conv_block(c4, c3)

        self.Up3 = up_conv(c3, c2)
        self.Up_conv3 = conv_block(c3, c2)

        self.Up2 = up_conv(c2, c1)
        self.Up_conv2 = conv_block(c2, c1)

        self.Conv = nn.Conv2d(c1, out_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """
        Encode, run the bottleneck through the transformer layers, decode.

        Args:
            x: Input batch with ``in_ch`` channels.

        Returns:
            Output of the final 1x1 convolution; no activation is applied.
        """
        e1 = self.Conv1(x)
        e2 = self.Conv2(self.Maxpool1(e1))
        e3 = self.Conv3(self.Maxpool2(e2))
        e4 = self.Conv4(self.Maxpool3(e3))
        e5 = self.Conv5(self.Maxpool4(e4))  # B x C x h x w

        B, C, h, w = e5.shape

        # Feature map -> token sequence of shape (h*w, B, C), the
        # sequence-first layout the encoder layers are given here.
        tokens = e5.permute(2, 3, 0, 1).reshape(h * w, B, C)
        tokens = tokens + self.pe
        for idx in range(self.layers):
            tokens = self.transformers[idx](tokens)

        # Token sequence -> feature map of shape (B, C, h, w).
        e5 = tokens.reshape(h, w, B, C).permute(2, 3, 0, 1)

        # Decoder: the skip feature goes first in every concatenation.
        d5 = self.Up_conv5(torch.cat((e4, self.Up5(e5)), dim=1))
        d4 = self.Up_conv4(torch.cat((e3, self.Up4(d5)), dim=1))
        d3 = self.Up_conv3(torch.cat((e2, self.Up3(d4)), dim=1))
        d2 = self.Up_conv2(torch.cat((e1, self.Up2(d3)), dim=1))

        return self.Conv(d2)
