from __future__ import print_function, division
import math
import gc

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

# from .shared import conv_block, up_conv
from .KAUNet.KAUNet_v import attention_conv_block_v,gumbel_softmax
from .KAUNet.lib.functional import subtraction2,dotproduction2, aggregation
# from modules.modulated_deform_conv import ModulatedDeformConv,ModulatedDeformConvPack 
from utils_torch import set_lr_mult

class up_conv(nn.Module):
    """
    Up Convolution Block

    Doubles the spatial resolution with bilinear interpolation and then
    applies Conv2d -> BatchNorm2d -> ReLU.

    Args:
        in_ch: number of channels of the incoming feature map.
        out_ch: number of channels produced by the convolution.
        ks: square kernel size of the convolution; the padding is ``ks // 2``.
    """

    def __init__(self, in_ch, out_ch, ks=3):
        super().__init__()
        stages = [
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_ch, out_ch, kernel_size=ks,
                      stride=1, padding=ks // 2, bias=True),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        # Keep the attribute name and the layer order: both determine the
        # parameter names stored in a state dict.
        self.up = nn.Sequential(*stages)

    def forward(self, x):
        """Return the upsampled and convolved feature map for ``x``."""
        return self.up(x)


class conv_block(nn.Module):
    """
    Convolution Block

    Two consecutive Conv2d -> BatchNorm2d -> ReLU stages. The first stage maps
    ``in_ch`` to ``out_ch`` channels, the second keeps ``out_ch`` channels.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels of both convolutions.
        kernel_size: square kernel size; the padding is ``kernel_size // 2``.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3):
        super().__init__()
        pad = kernel_size // 2

        def _stage(c_in):
            # One Conv -> BN -> ReLU stage that outputs ``out_ch`` channels.
            return [
                nn.Conv2d(c_in, out_ch, kernel_size=kernel_size,
                          stride=1, padding=pad, bias=True),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]

        # Six layers in the same order as before, so the parameter names
        # (conv.0, conv.1, conv.3, conv.4) stay the same.
        self.conv = nn.Sequential(*_stage(in_ch), *_stage(out_ch))

    def forward(self, x):
        """Run ``x`` through both convolution stages."""
        return self.conv(x)


class UNet(nn.Module):
    """
    UNet - Basic Implementation
    Paper : https://arxiv.org/abs/1505.04597

    Five encoder levels (64, 128, 256, 512, 1024 channels) separated by 2x2
    max pooling, followed by four decoder levels. Each decoder level
    upsamples, concatenates the encoder feature map of the same level
    (skip connection) and applies a ``conv_block``. A final 1x1 convolution
    maps to ``out_ch`` channels; no activation is applied to the output.

    Args:
        in_ch: number of channels of the input image.
        out_ch: number of channels of the output map.
        kernel_size: kernel size used by every ``conv_block`` and ``up_conv``.
    """
    def __init__(self, in_ch=3, out_ch=1, kernel_size=3):
        super().__init__()

        n1 = 64
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]
        self.ks = kernel_size

        # NOTE: attribute names and registration order are kept as they are;
        # they determine the parameter names and ordering of the model.
        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder
        self.Conv1 = conv_block(in_ch, filters[0], kernel_size=self.ks)
        self.Conv2 = conv_block(filters[0], filters[1], kernel_size=self.ks)
        self.Conv3 = conv_block(filters[1], filters[2], kernel_size=self.ks)
        self.Conv4 = conv_block(filters[2], filters[3], kernel_size=self.ks)
        self.Conv5 = conv_block(filters[3], filters[4], kernel_size=self.ks)

        # Decoder: each Up_conv* receives skip + upsampled channels, i.e.
        # twice the number of channels produced by the matching Up*.
        self.Up5 = up_conv(filters[4], filters[3], ks=self.ks)
        self.Up_conv5 = conv_block(filters[4], filters[3], kernel_size=self.ks)

        self.Up4 = up_conv(filters[3], filters[2], ks=self.ks)
        self.Up_conv4 = conv_block(filters[3], filters[2], kernel_size=self.ks)

        self.Up3 = up_conv(filters[2], filters[1], ks=self.ks)
        self.Up_conv3 = conv_block(filters[2], filters[1], kernel_size=self.ks)

        self.Up2 = up_conv(filters[1], filters[0], ks=self.ks)
        self.Up_conv2 = conv_block(filters[1], filters[0], kernel_size=self.ks)

        self.Conv = nn.Conv2d(filters[0], out_ch, kernel_size=1, stride=1, padding=0)

    @staticmethod
    def _merge_skip(skip, up):
        """Concatenate ``skip`` and ``up`` along the channel dimension.

        Max pooling floors odd spatial sizes, so after the 2x upsampling
        ``up`` can be one pixel smaller than ``skip`` when the network input
        is not divisible by 16. In that case ``up`` is resized bilinearly to
        the spatial size of ``skip``; when the sizes already agree the
        tensors are concatenated unchanged.
        """
        if up.shape[2:] != skip.shape[2:]:
            up = F.interpolate(up, size=skip.shape[2:],
                               mode='bilinear', align_corners=True)
        return torch.cat((skip, up), dim=1)

    def forward(self, x):
        """Return the ``out_ch``-channel map computed from input ``x``."""
        # Encoder path: conv block, then halve the resolution for the next level.
        e1 = self.Conv1(x)
        e2 = self.Conv2(self.Maxpool1(e1))
        e3 = self.Conv3(self.Maxpool2(e2))
        e4 = self.Conv4(self.Maxpool3(e3))
        e5 = self.Conv5(self.Maxpool4(e4))

        # Decoder path: upsample, merge with the skip connection, conv block.
        d5 = self.Up_conv5(self._merge_skip(e4, self.Up5(e5)))
        d4 = self.Up_conv4(self._merge_skip(e3, self.Up4(d5)))
        d3 = self.Up_conv3(self._merge_skip(e2, self.Up3(d4)))
        d2 = self.Up_conv2(self._merge_skip(e1, self.Up2(d3)))

        return self.Conv(d2)

class UNet_ori(UNet):
    """Plain ``UNet`` with the parent's default kernel size.

    Unlike ``UNet``, both channel counts must be given explicitly.

    Args:
        in_ch: number of channels of the input image.
        out_ch: number of channels of the output map.
    """

    def __init__(self, in_ch, out_ch):
        super(UNet_ori, self).__init__(in_ch=in_ch, out_ch=out_ch)



# class UNet_ks5(UNet):
#     def __init__(self, in_ch=3, out_ch=1):
#         kernel_size=5
#         super(UNet_ks5, self).__init__(in_ch, out_ch,kernel_size)


# class UNet_ks1(UNet):
#     def __init__(self, in_ch=3, out_ch=1):
#         kernel_size=1
#         super(UNet_ks1, self).__init__(in_ch, out_ch,kernel_size)



# class PyramidPool(nn.Module):
#     def __init__(self, in_channels, out_channels, pool_size):
#         super(PyramidPool, self).__init__()
#         self.features = nn.Sequential(
#             nn.AdaptiveAvgPool2d(pool_size),
#             nn.Conv2d(in_channels, out_channels, 1, bias=True),
#             nn.BatchNorm2d(out_channels, momentum=0.95),
#             nn.ReLU(inplace=True)
#         )

#     def forward(self, x):
#         size = x.shape
#         output = F.upsample_bilinear(self.features(x), size[2:])
#         return output




# class guide_conv_block(nn.Module):
#     """
#     Convolution Block 
#     """

#     def __init__(self, in_ch, out_ch, kernel_size=3):
#         super(guide_conv_block, self).__init__()

#         padding = kernel_size // 2
#         mid_ch = out_ch
#         self.conv = nn.Sequential(
#             nn.Conv2d(in_ch, mid_ch, kernel_size=kernel_size,
#                       stride=1, padding=padding, bias=True),
#             nn.BatchNorm2d(mid_ch),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(mid_ch, mid_ch, kernel_size=kernel_size,
#                       stride=1, padding=padding, bias=True),
#             )

#         self.conv_global = nn.Sequential(
#             # nn.AdaptiveAvgPool2d(1),
#             nn.Conv2d(in_ch,mid_ch,kernel_size=7,stride=1,padding=int((7 - 1)/2) * 4,dilation=4),
#             # nn.Conv2d(in_ch,out_ch,kernel_size=1),
#             # nn.BatchNorm2d(out_ch),
#             # nn.ReLU(),
#         )

        
#         self.BNRe = nn.Sequential(nn.BatchNorm2d(mid_ch*2),
#                                     nn.PReLU(mid_ch*2),)
#         self.reduce = nn.Sequential(nn.Conv2d(mid_ch*2,out_ch,kernel_size=1))

#         self.refine_loc_glo = nn.Sequential(nn.AdaptiveAvgPool2d(1),
#                                             nn.Conv2d(out_ch,out_ch//8,kernel_size=1),
#                                             nn.ReLU(),
#                                             nn.Conv2d(out_ch//8,out_ch,kernel_size=1),
#                                             nn.Sigmoid())


#     def forward(self, x):
#         loc = self.conv(x) #+ self.conv_global(x)
#         glo = self.conv_global(x)
#         x = torch.cat([loc,glo],dim = 1)
#         x = self.BNRe(x)
#         x = self.reduce(x)
#         x = x*self.refine_loc_glo(x)
#         return x


# class adaptive_conv_block(nn.Module):
#     """
#     Convolution Block 
#     """

#     def __init__(self, in_ch, out_ch, kernel_size=3,layers = 2):
#         super(adaptive_conv_block, self).__init__()

#         padding = kernel_size // 2



#         self.layers = layers
#         self.conv_gate = nn.Sequential(
#             nn.MaxPool2d(kernel_size = 2, stride=2),
#             nn.Conv2d(in_ch, out_ch//4, kernel_size=7,stride=1, padding=3, bias=True),
#             nn.BatchNorm2d(out_ch//4),
#             nn.ReLU(inplace=True),
#             nn.Upsample(scale_factor=2),
#             nn.Conv2d(out_ch//4, self.layers, kernel_size=3,stride=1, padding=1, bias=True),
#             nn.Softmax(dim = 1),
#             )
            
        
#         self.Convs = nn.ModuleList()
#         for i in range(self.layers):
#             if i == 0:_in_ch = in_ch
#             else:_in_ch = out_ch
#             self.Convs.append(
#                             nn.Sequential(
#                             nn.Conv2d(_in_ch, out_ch, kernel_size=kernel_size,
#                                     stride=1, padding=padding, bias=True),
#                             nn.BatchNorm2d(out_ch),
#                             nn.ReLU(inplace=True),
#                             ))



#     def forward(self, x):
#         B,_,H,W = x.shape
#         gate = self.conv_gate(x)
#         out = None
#         for i in range(self.layers):
#             x = self.Convs[i](x)
#             if i == 0:
#                 out = x*gate[:,i:i+1]
#             else:
#                 out = out + x*gate[:,i:i+1]
      
#         return x

# import math 
# from torch.nn.modules.batchnorm import _BatchNorm

# class EMAU(nn.Module):
#     '''The Expectation-Maximization Attention Unit (EMAU).
#     Arguments:
#         c (int): The input and output channel number.
#         k (int): The number of the bases.
#         stage_num (int): The iteration number for EM.
#     '''
#     def __init__(self, c, k, stage_num=3):
#         super(EMAU, self).__init__()
#         self.stage_num = stage_num

#         mu = torch.Tensor(1, c, k)
#         mu.normal_(0, math.sqrt(2. / k))    # Init with Kaiming Norm.
#         mu = self._l2norm(mu, dim=1)
#         self.register_buffer('mu', mu)

#         self.conv1 = nn.Conv2d(c, c, 1)
#         self.conv2 = nn.Sequential(
#             nn.Conv2d(c, c, 1, bias=False),
#             nn.BatchNorm2d(c))        
        
#         for m in self.modules():
#             if isinstance(m, nn.Conv2d):
#                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
#                 m.weight.data.normal_(0, math.sqrt(2. / n))
#             elif isinstance(m, _BatchNorm):
#                 m.weight.data.fill_(1)
#                 if m.bias is not None:
#                     m.bias.data.zero_()
 

#     def forward(self, x):
#         '''
#         x:B*C*H*W
#         '''
#         idn = x
#         # The first 1x1 conv
#         x = self.conv1(x)

#         # The EM Attention
#         b, c, h, w = x.size()
#         x = x.view(b, c, h*w)               # b * c * n
#         mu = self.mu.repeat(b, 1, 1)        # b * c * k
#         with torch.no_grad():
#             for i in range(self.stage_num):
#                 x_t = x.permute(0, 2, 1)    # b * n * c
#                 z = torch.bmm(x_t, mu)      # b * n * k
#                 z = F.softmax(z, dim=2)     # b * n * k
#                 z_ = z / (1e-6 + z.sum(dim=1, keepdim=True))
#                 mu = torch.bmm(x, z_)       # b * c * k
#                 mu = self._l2norm(mu, dim=1)
#         # Rewritten by Lin.
#         # !!! The moving averaging operation is written in train.py, which is significant.
#         if self.training:
#             self.mu = 0.9*self.mu + mu.mean(dim=0, keepdim=True)*(1-0.9)
            

#         z_t = z.permute(0, 2, 1)            # b * k * n
#         x = mu.matmul(z_t)                  # b * c * n
#         x = x.view(b, c, h, w)              # b * c * h * w
#         x = F.relu(x, inplace=True)

#         # The second 1x1 conv
#         x = self.conv2(x)
#         x = x + idn
#         x = F.relu(x, inplace=True)

#         return x

#     def _l2norm(self, inp, dim):
#         '''Normalize the inp tensor with l2-norm.
#         Returns a tensor where each sub-tensor of input along the given dim is 
#         normalized such that the 2-norm of the sub-tensor is equal to 1.
#         Arguments:
#             inp (tensor): The input tensor.
#             dim (int): The dimension to slice over to get the sub-tensors.
#         Returns:
#             (tensor) The normalized tensor.
#         '''
#         return inp / (1e-6 + inp.norm(dim=dim, keepdim=True))


# class UNet_EMAU(nn.Module):
#     """
#     UNet - Basic Implementation
#     Paper : https://arxiv.org/abs/1505.04597
#     """
#     def __init__(self, in_ch=3, out_ch=1, kernel_size=3):
#         super(UNet_EMAU, self).__init__()

#         n1 = 64
#         filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]
#         self.ks = kernel_size
#         self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
#         self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
#         self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
#         self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

#         self.Conv1 = conv_block(in_ch, filters[0], kernel_size=self.ks)
#         self.Conv2 = conv_block(filters[0], filters[1], kernel_size=self.ks)
#         self.Conv3 = conv_block(filters[1], filters[2], kernel_size=self.ks)
#         self.Conv4 = conv_block(filters[2], filters[3], kernel_size=self.ks)
#         self.Conv5 = conv_block(filters[3], filters[4], kernel_size=self.ks)

#         self.Up5 = up_conv(filters[4], filters[3])
#         self.Up_conv5 = conv_block(filters[4], filters[3], kernel_size=self.ks)
#         # self.Up5_guide = nn.Sequential(guide_conv_block(filters[3],128),nn.Upsample(scale_factor=8,mode='bilinear'))

#         self.Up4 = up_conv(filters[3], filters[2])
#         self.Up_conv4 = conv_block(filters[3], filters[2], kernel_size=self.ks)
#         # self.Up4_guide = nn.Sequential(guide_conv_block(filters[2],128),nn.Upsample(scale_factor=4,mode='bilinear'))

#         self.Up3 = up_conv(filters[2], filters[1])
#         self.Up_conv3 = conv_block(filters[2], filters[1], kernel_size=self.ks)
#         # self.Up3_guide = nn.Sequential(guide_conv_block(filters[1],128),nn.Upsample(scale_factor=2,mode='bilinear'))

#         self.Up2 = up_conv(filters[1], filters[0])
#         self.Up_conv2 = conv_block(filters[1], filters[0], kernel_size=self.ks)
#         # self.Up2_guide = guide_conv_block(filters[0],128)


#         # self.final = nn.Sequential(nn.Conv2d(128*4,128,kernel_size=3,padding=1),
#         #                             nn.BatchNorm2d(128),
#         #                             nn.ReLU(),
#         #                             nn.Conv2d(128,out_ch,kernel_size=1))

#         self.em = EMAU(c=filters[0],k=32)

#         self.Conv = nn.Conv2d(filters[0], out_ch, kernel_size=1, stride=1, padding=0)

#        # self.active = torch.nn.Sigmoid()

#     def forward(self, x):

#         e1 = self.Conv1(x)

#         e2 = self.Maxpool1(e1)
#         e2 = self.Conv2(e2)

#         e3 = self.Maxpool2(e2)
#         e3 = self.Conv3(e3)

#         e4 = self.Maxpool3(e3)
#         e4 = self.Conv4(e4)

#         e5 = self.Maxpool4(e4)
#         e5 = self.Conv5(e5)

#         d5 = self.Up5(e5)
#         d5 = torch.cat((e4, d5), dim=1)

#         d5 = self.Up_conv5(d5)
#         # _d5 = self.Up5_guide(d5)

#         d4 = self.Up4(d5)
#         d4 = torch.cat((e3, d4), dim=1)
#         d4 = self.Up_conv4(d4)
#         # _d4 = self.Up4_guide(d4)

#         d3 = self.Up3(d4)
#         d3 = torch.cat((e2, d3), dim=1)
#         d3 = self.Up_conv3(d3)
#         # _d3 = self.Up3_guide(d3)

#         d2 = self.Up2(d3)
#         d2 = torch.cat((e1, d2), dim=1)
#         d2 = self.Up_conv2(d2)
#         # _d2 = self.Up2_guide(d2)
#         out = self.em(d2)
#         out = self.Conv(out)
#         # out = self.Conv(d2)
#         # out = torch.cat([_d5,_d4,_d3,_d2],dim = 1)
#         #d1 = self.active(out)
#         return out









# class Adapt_UNet(nn.Module):
#     """
#     UNet - Basic Implementation
#     Paper : https://arxiv.org/abs/1505.04597
#     """
#     def __init__(self, in_ch=3, out_ch=1,adapt_layers = 2):
#         super(Adapt_UNet, self).__init__()

#         n1 = 64
#         filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]
        
#         self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
#         self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
#         self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
#         self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

#         self.Conv1 = adaptive_conv_block(in_ch, filters[0],layers=adapt_layers)
#         self.Conv2 = adaptive_conv_block(filters[0], filters[1],layers=adapt_layers)
#         self.Conv3 = adaptive_conv_block(filters[1], filters[2],layers=adapt_layers)
#         self.Conv4 = adaptive_conv_block(filters[2], filters[3],layers=adapt_layers)
#         self.Conv5 = adaptive_conv_block(filters[3], filters[4],layers=adapt_layers)

#         self.Up5 = up_conv(filters[4], filters[3])
#         self.Up_conv5 = adaptive_conv_block(filters[4], filters[3],layers=adapt_layers)

#         self.Up4 = up_conv(filters[3], filters[2])
#         self.Up_conv4 = adaptive_conv_block(filters[3], filters[2],layers=adapt_layers)

#         self.Up3 = up_conv(filters[2], filters[1])
#         self.Up_conv3 = adaptive_conv_block(filters[2], filters[1],layers=adapt_layers)

#         self.Up2 = up_conv(filters[1], filters[0])
#         self.Up_conv2 = adaptive_conv_block(filters[1], filters[0],layers=adapt_layers)

#         self.Conv = nn.Conv2d(filters[0], out_ch, kernel_size=1, stride=1, padding=0)


#         self._Up2 = nn.Upsample(scale_factor=2)
#         self._Up4 = nn.Upsample(scale_factor=4)
#         self._Up8 = nn.Upsample(scale_factor=8)
#         self._Up16 = nn.Upsample(scale_factor=16)
#        # self.active = torch.nn.Sigmoid()

#     def forward(self, x):

#         e1 = self.Conv1(x)

#         e2 = self.Maxpool1(e1)
#         e2 = self.Conv2(e2)

#         e3 = self.Maxpool2(e2)
#         e3 = self.Conv3(e3)

#         e4 = self.Maxpool3(e3)
#         e4 = self.Conv4(e4)

#         e5 = self.Maxpool4(e4)
#         e5 = self.Conv5(e5)

#         d5 = self.Up5(e5)
#         d5 = torch.cat((e4, d5), dim=1)

#         d5 = self.Up_conv5(d5)

#         d4 = self.Up4(d5)
#         d4 = torch.cat((e3, d4), dim=1)
#         d4 = self.Up_conv4(d4)

#         d3 = self.Up3(d4)
#         d3 = torch.cat((e2, d3), dim=1)
#         d3 = self.Up_conv3(d3)

#         d2 = self.Up2(d3)
#         d2 = torch.cat((e1, d2), dim=1)
#         d2 = self.Up_conv2(d2)

#         out = self.Conv(d2)
#         # self.es = [e1,e2,e3,e4]

#         # d1 = self.active(out)

#         return out

 



# class Adapt_UNet_merge(nn.Module):
#     """
#     UNet - Basic Implementation
#     Paper : https://arxiv.org/abs/1505.04597
#     """
#     def __init__(self, in_ch=3, out_ch=1):
#         super(Adapt_UNet_merge, self).__init__()

#         n1 = 64
#         filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]
        
#         self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
#         self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
#         self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
#         self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

#         self.Conv1 = adaptive_conv_block(in_ch, filters[0])
#         self.Conv2 = adaptive_conv_block(filters[0], filters[1])
#         self.Conv3 = adaptive_conv_block(filters[1], filters[2])
#         self.Conv4 = adaptive_conv_block(filters[2], filters[3])
#         self.Conv5 = adaptive_conv_block(filters[3], filters[4])

#         self.Up5 = up_conv(filters[4], filters[3])
#         self.Up_conv5 = adaptive_conv_block(filters[4], filters[3])

#         self.Up4 = up_conv(filters[3], filters[2])
#         self.Up_conv4 = adaptive_conv_block(filters[3], filters[2])

#         self.Up3 = up_conv(filters[2], filters[1])
#         self.Up_conv3 = adaptive_conv_block(filters[2], filters[1])

#         self.Up2 = up_conv(filters[1], filters[0])
#         self.Up_conv2 = adaptive_conv_block(filters[1], filters[0])

#         # self.Conv = nn.Conv2d(filters[0], out_ch, kernel_size=1, stride=1, padding=0)
#         self.Conv = nn.Sequential(
#                         nn.Conv2d(sum(filters[:4]),filters[0], kernel_size=1, stride=1, padding=0),
#                         nn.BatchNorm2d(filters[0]),
#                         nn.ReLU(),
#                         nn.Conv2d(filters[0], out_ch, kernel_size=1, stride=1, padding=0))


#         self._Up2 = nn.Upsample(scale_factor=2)
#         self._Up4 = nn.Upsample(scale_factor=4)
#         self._Up8 = nn.Upsample(scale_factor=8)
#         self._Up16 = nn.Upsample(scale_factor=16)
#        # self.active = torch.nn.Sigmoid()

#     def forward(self, x):

#         e1 = self.Conv1(x)

#         e2 = self.Maxpool1(e1)
#         e2 = self.Conv2(e2)

#         e3 = self.Maxpool2(e2)
#         e3 = self.Conv3(e3)

#         e4 = self.Maxpool3(e3)
#         e4 = self.Conv4(e4)

#         e5 = self.Maxpool4(e4)
#         e5 = self.Conv5(e5)

#         d5 = self.Up5(e5)
#         d5 = torch.cat((e4, d5), dim=1)

#         d5 = self.Up_conv5(d5)

#         d4 = self.Up4(d5)
#         d4 = torch.cat((e3, d4), dim=1)
#         d4 = self.Up_conv4(d4)

#         d3 = self.Up3(d4)
#         d3 = torch.cat((e2, d3), dim=1)
#         d3 = self.Up_conv3(d3)

#         d2 = self.Up2(d3)
#         d2 = torch.cat((e1, d2), dim=1)
#         d2 = self.Up_conv2(d2)

#         # out = self.Conv(d2)
#         ds = torch.cat([d2,self._Up2(d3),self._Up4(d4),self._Up8(d5)],dim = 1)
#         out = self.Conv(ds)
#         # d1 = self.active(out)

#         return out

 


# class adaptive_conv_block_weighMul(nn.Module):
#     """
#     Convolution Block 
#     """

#     def __init__(self, in_ch, out_ch, kernel_size=3,layers = 2):
#         super(adaptive_conv_block_weighMul, self).__init__()

#         padding = kernel_size // 2

#         self.conv_gate = nn.Sequential(
#             nn.MaxPool2d(kernel_size = 2, stride=2),
#             nn.Conv2d(in_ch, out_ch//6, kernel_size=7,stride=1, padding=3, bias=True),
#             nn.BatchNorm2d(out_ch//6),
#             nn.ReLU(inplace=True),
#             nn.Upsample(scale_factor=2),
#             nn.Conv2d(out_ch//6, 2, kernel_size=3,stride=1, padding=1, bias=True),
#             nn.Softmax(dim = 1),
#             )
#         self.layers = layers
#         outs = [out_ch]*self.layers

#         self.conv1 = nn.Sequential(
#             nn.Conv2d(in_ch, outs[0], kernel_size=kernel_size,
#                       stride=1, padding=padding, bias=True),
#             nn.BatchNorm2d(outs[0]),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(outs[0], outs[0], kernel_size=kernel_size,
#                       stride=1, padding=padding, bias=True),
#             nn.BatchNorm2d(outs[0]),
#             nn.ReLU(inplace=True),
#             )

#         # self.conv2 = nn.Sequential(
#         #     nn.Conv2d(in_ch, outs[1], kernel_size=(1,kernel_size+4),
#         #               stride=1, padding=(0,kernel_size//2+2), bias=True),
#         #     nn.BatchNorm2d(outs[1]),
#         #     nn.ReLU(inplace=True),
#         #     )

#         # self.conv3 = nn.Sequential(
#         #     nn.Conv2d(in_ch, outs[2], kernel_size=(kernel_size+4,1),
#         #               stride=1, padding=(kernel_size//2+2,0), bias=True),
#         #     nn.BatchNorm2d(outs[2]),
#         #     nn.ReLU(inplace=True),
#         #     )

#         self.conv2 = nn.Sequential(
#             nn.Conv2d(in_ch, outs[1], kernel_size=1,
#                       stride=1, padding=0, bias=True),
#             nn.BatchNorm2d(outs[1]),
#             nn.ReLU(inplace=True),
#             )


#     def forward(self, x):
#         B,_,H,W = x.shape
#         gate = self.conv_gate(x)
#         x1 = self.conv1(x) 
#         x2 = self.conv2(x) 
#         x =  x1*gate[:,0:1]+x2*gate[:,1:2] #,x3*gate[:,2:3],x4*gate[:,3:4]],dim = 1)

#         return x


# class Adapt_UNet_weightMul(nn.Module):
#     """
#     UNet - Basic Implementation
#     Paper : https://arxiv.org/abs/1505.04597
#     """
#     def __init__(self, in_ch=3, out_ch=1):
#         super(Adapt_UNet_weightMul, self).__init__()

#         n1 = 60
#         filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]
        
#         self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
#         self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
#         self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
#         self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

#         self.Conv1 = adaptive_conv_block_weighMul(in_ch, filters[0])
#         self.Conv2 = adaptive_conv_block_weighMul(filters[0], filters[1])
#         self.Conv3 = adaptive_conv_block_weighMul(filters[1], filters[2])
#         self.Conv4 = adaptive_conv_block_weighMul(filters[2], filters[3])
#         self.Conv5 = adaptive_conv_block_weighMul(filters[3], filters[4])

#         self.Up5 = up_conv(filters[4], filters[3])
#         self.Up_conv5 = adaptive_conv_block_weighMul(filters[4], filters[3])

#         self.Up4 = up_conv(filters[3], filters[2])
#         self.Up_conv4 = adaptive_conv_block_weighMul(filters[3], filters[2])

#         self.Up3 = up_conv(filters[2], filters[1])
#         self.Up_conv3 = adaptive_conv_block_weighMul(filters[2], filters[1])

# class UNet(nn.Module):

#     def __init__(self, in_ch=3, num_classes=1):
#         super().__init__()

#         n1 = 64
#         filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]

#         self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

#         self.Conv1 = adaptive_conv_block(in_ch, filters[0])
#         # self.Conv1 = attention_conv_block_v(
#             # in_ch, filters[0], 3, att_hidden=32, att_mh=1, att_sm=5, att_two_w=True, visualization=False)
#         self.Conv2 = adaptive_conv_block(filters[0], filters[1])
#         self.Conv3 = adaptive_conv_block(filters[1], filters[2])
#         self.Conv4 = adaptive_conv_block(filters[2], filters[3])
#         self.Conv5 = adaptive_conv_block(filters[3], filters[4])

#         self.final = nn.Sequential(nn.Conv2d(sum(filters), 256,kernel_size=1, stride=1, padding=0),
#                                     nn.BatchNorm2d(256, momentum=0.95),
#                                     nn.ReLU(inplace=True),
#                                     nn.Conv2d(256, out_ch,kernel_size=1, stride=1, padding=0),
#                               )

#         self.Conv_gate = nn.Sequential(nn.Conv2d(sum(filters), 128,kernel_size=3, stride=1, padding=1),
#                                     nn.BatchNorm2d(128, momentum=0.95),
#                                     nn.ReLU(inplace=True),
#                                     nn.Conv2d(128, 128,kernel_size=3, stride=1, padding=1),
#                                     nn.BatchNorm2d(128, momentum=0.95),
#                                     nn.ReLU(inplace=True),
#                                     nn.Conv2d(128, 5,kernel_size=1, stride=1, padding=0),
#                                     nn.Softmax(dim=1)
#                               )

        
#         self._Up2 = nn.Upsample(scale_factor=2)
#         self._Up4 = nn.Upsample(scale_factor=4)
#         self._Up8 = nn.Upsample(scale_factor=8)
#         self._Up16 = nn.Upsample(scale_factor=16)

#         self.es = None

#         self.H,self.W = None,None

#     def forward(self, x):
#         _,_,self.H,self.W = x.shape
        
#         out = []

#         e1 = self.Conv1(x)

#         e2 = self.Maxpool(e1)
#         e2 = self.Conv2(e2)

#         e3 = self.Maxpool(e2)
#         e3 = self.Conv3(e3)

#         e4 = self.Maxpool(e3)
#         e4 = self.Conv4(e4)

#         e5 = self.Maxpool(e4)
#         e5 = self.Conv5(e5)
        

#         gate = torch.cat([e1,self._Up2(e2),self._Up4(e3),self._Up8(e4),self._Up16(e5)],dim = 1)
#         gate = self.Conv_gate(gate)

#         out = torch.cat([e1*gate[:,0:1],self._Up2(e2)*gate[:,1:2],self._Up4(e3)*gate[:,2:3],self._Up8(e4)*gate[:,3:4],self._Up16(e5)*gate[:,4:5]],dim = 1)
#         out = self.final(out)



#         # self.es = [e1,e2,e3,e4]

#         # d1 = self.active(out)

#         return out






# class kernel_select(nn.Module):
#     def __init__(self,in_dim,guide_dim,out):
#         super(kernel_select,self).__init__()
        
#         self.kernel_sizes =  [[1,1],[3,3],[5,5]]
#         self.kernel_gen = nn.Sequential(nn.Conv2d(guide_dim,guide_dim,kernel_size=7,padding=7//2,bias=True),
#                                         nn.BatchNorm2d(guide_dim),
#                                         nn.ReLU(),
#                                         nn.Conv2d(guide_dim,len(self.kernel_sizes),kernel_size=1,padding=0,bias=True))

#         self.kernels = nn.ModuleList()
#         for k1,k2 in self.kernel_sizes: 
#             self.kernels.append(
#                 nn.Sequential(nn.Conv2d(in_dim,out//2,kernel_size=k1,padding=k1//2,bias=True),
#                                 nn.BatchNorm2d(out//2),
#                                 nn.ReLU(),
#                                 nn.Conv2d(out//2,out,kernel_size=k2,padding=k2//2,bias=True),
#                                 nn.BatchNorm2d(out),
#                                 nn.ReLU(),)
#                                 )

        
#     def forward(self,x_input,x_guide):
#         '''
#         x_input:B*in_dim*H*W
#         '''
#         B,in_dim,H,W = x_input.shape

#         kernel_weight = self.kernel_gen(x_guide) #B*-1*H*W

#         out = []
#         for i in range(len(self.kernel_sizes)):
#             out.append(kernel_weight[:,i:i+1]*self.kernels[i](x_input))
        
#         out = torch.sum(torch.stack(out,dim = 1),dim=1)
        
#         return out





# # class UNet_parallel(nn.Module):
# #     """
# #     UNet - Basic Implementation
# #     Paper : https://arxiv.org/abs/1505.04597
# #     """

# #     def __init__(self, in_ch=3, out_ch=1):
# #         super(UNet_parallel, self).__init__()

# #         n1 = 64
# #         filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]

# #         self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

# #         self.K_Conv1 = conv_block(in_ch, filters[0],kernel_size=5)
# #         self.K_Conv2 = conv_block(filters[0], filters[1],kernel_size=5)
# #         self.K_Conv3 = conv_block(filters[1], filters[2],kernel_size=5)
# #         self.K_Conv4 = conv_block(filters[2], filters[3],kernel_size=5)
# #         self.K_Conv5 = conv_block(filters[3], filters[4],kernel_size=5)


# #         self.Conv1 = kernel_select(in_ch, filters[0],out=filters[0])
# #         self.Conv2 = kernel_select(filters[0], filters[1],out=filters[1])
# #         self.Conv3 = kernel_select(filters[1], filters[2],out=filters[2])
# #         self.Conv4 = kernel_select(filters[2], filters[3],out=filters[3])
# #         self.Conv5 = kernel_select(filters[3], filters[4],out=filters[4])



# #         # self.Conv = nn.Conv2d(sum(filters), out_ch,
# #                             #   kernel_size=1, stride=1, padding=0)
# #         self.final = nn.Sequential(nn.Conv2d(sum(filters), 256,kernel_size=1, stride=1, padding=0),
# #                                     nn.BatchNorm2d(256, momentum=0.95),
# #                                     nn.ReLU(inplace=True),
# #                                     nn.Conv2d(256, out_ch,kernel_size=1, stride=1, padding=0),
# #                               )

# #         self._Up2 = nn.Upsample(scale_factor=2)
# #         self._Up4 = nn.Upsample(scale_factor=4)
# #         self._Up8 = nn.Upsample(scale_factor=8)
# #         self._Up16 = nn.Upsample(scale_factor=16)

# #         # self.active = torch.nn.Softmax(dim=1)

# #     def forward(self, x):
# #         g1 = self.K_Conv1(x)

# #         g2 = self.Maxpool(g1)
# #         g2 = self.K_Conv2(g2)

# #         g3 = self.Maxpool(g2)
# #         g3 = self.K_Conv3(g3)
 
# #         g4 = self.Maxpool(g3)
# #         g4 = self.K_Conv4(g4)

# #         g5 = self.Maxpool(g4)
# #         g5 = self.K_Conv5(g5)


# #         e1 = self.Conv1(x_input=x,x_guide=g1)

# #         e2 = self.Maxpool(e1)
# #         e2 = self.Conv2(x_input=e2,x_guide=g2)

# #         e3 = self.Maxpool(e2)
# #         e3 = self.Conv3(x_input=e3,x_guide=g3)

# #         e4 = self.Maxpool(e3)
# #         e4 = self.Conv4(x_input=e4,x_guide=g4)

# #         e5 = self.Maxpool(e4)
# #         e5 = self.Conv5(x_input=e5,x_guide=g5)

# #         # d5 = self.Up5(e5)
# #         # d5 = self.Up_conv5(x_input=e4,x_guide=d5)

# #         # d4 = self.Up4(d5)
# #         # d4 = self.Up_conv4(x_input=e3,x_guide=d4)

# #         # d3 = self.Up3(d4)
# #         # d3 = self.Up_conv3(x_input=e2,x_guide=d3)

# #         # d2 = self.Up2(d3)
# #         # d2 = self.Up_conv2(x_input=e1,x_guide=d2)

# #         ds = torch.cat([self._Up16(e5),self._Up8(e4),self._Up4(e3),self._Up2(e2),e1],dim = 1)
# #         out = self.final(ds)
# #         # out = self.Conv(ds)

# #         # d1 = self.active(out)

# #         return out



# class conv(nn.Module):
#     """
#     Single Convolution Block (Conv -> BatchNorm -> ReLU, no upsampling)
#     """

#     def __init__(self, in_ch, out_ch):
#         super(conv, self).__init__()
#         self.up = nn.Sequential(
#             nn.Conv2d(in_ch, out_ch, kernel_size=3,
#                       stride=1, padding=1, bias=True),
#             nn.BatchNorm2d(out_ch),
#             nn.ReLU(inplace=True)
#         )

#     def forward(self, x):
#         x = self.up(x)
#         return x


# class UNet_WO_Pool(nn.Module):
#     """
#     UNet variant without pooling or upsampling (all stages use stride-1 convolutions)
#     Paper : https://arxiv.org/abs/1505.04597
#     """

#     def __init__(self, in_ch=3, out_ch=1):
#         super(UNet_WO_Pool, self).__init__()

#         n1 = 64
#         filters = [n1, n1, n1, n1, n1]

#         # self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
#         # self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
#         # self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
#         # self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

#         self.Conv1 = conv_block(in_ch, filters[0])
#         self.Conv2 = conv_block(filters[0], filters[1])
#         self.Conv3 = conv_block(filters[1], filters[2])
#         self.Conv4 = conv_block(filters[2], filters[3])
#         self.Conv5 = conv_block(filters[3], filters[4])

#         self.Up5 = conv(filters[4], filters[3])
#         self.Up_conv5 = conv_block(filters[4]*2, filters[3])

#         self.Up4 = conv(filters[3], filters[2])
#         self.Up_conv4 = conv_block(filters[3]*2, filters[2])

#         self.Up3 = conv(filters[2], filters[1])
#         self.Up_conv3 = conv_block(filters[2]*2, filters[1])

#         self.Up2 = conv(filters[1], filters[0])
#         self.Up_conv2 = conv_block(filters[1]*2, filters[0])

#         self.Conv = nn.Conv2d(filters[0], out_ch,
#                               kernel_size=1, stride=1, padding=0)

#         # self.active = torch.nn.Softmax(dim=1)

#     def forward(self, x):
#         e1 = self.Conv1(x)

#         # e2 = self.Maxpool1(e1)
#         e2 = self.Conv2(e1)

#         # e3 = self.Maxpool2(e2)
#         e3 = self.Conv3(e2)

#         # e4 = self.Maxpool3(e3)
#         e4 = self.Conv4(e3)

#         # e5 = self.Maxpool4(e4)
#         e5 = self.Conv5(e4)

#         d5 = self.Up5(e5)
#         d5 = torch.cat((e4, d5), dim=1)

#         d5 = self.Up_conv5(d5)

#         d4 = self.Up4(d5)
#         d4 = torch.cat((e3, d4), dim=1)
#         d4 = self.Up_conv4(d4)

#         d3 = self.Up3(d4)
#         d3 = torch.cat((e2, d3), dim=1)
#         d3 = self.Up_conv3(d3)

#         d2 = self.Up2(d3)
#         d2 = torch.cat((e1, d2), dim=1)
#         d2 = self.Up_conv2(d2)

#         out = self.Conv(d2)

#         # d1 = self.active(out)

#         return out




# class TranUNet(nn.Module):
#     """
#     UNet with Transformer encoder layers applied to the bottleneck features
#     Paper : https://arxiv.org/abs/1505.04597
#     """
#     def __init__(self, in_ch=3, out_ch=1,n1=64):
#         super(TranUNet, self).__init__()

#         filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]
        
#         self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
#         self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
#         self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
#         self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

#         self.Conv1 = conv_block(in_ch, filters[0])
#         self.Conv2 = conv_block(filters[0], filters[1])
#         self.Conv3 = conv_block(filters[1], filters[2])
#         self.Conv4 = conv_block(filters[2], filters[3])
#         self.Conv5 = conv_block(filters[3], filters[4])
        
#         self.layers = 2
#         self.pe = nn.Parameter(torch.randn(1,1,filters[4]))
#         self.transformers = nn.ModuleList()
#         for i in range(self.layers):
#             self.transformers.append(nn.TransformerEncoderLayer(d_model=filters[4], nhead=2))
        
#         self.Up5 = up_conv(filters[4], filters[3])
#         self.Up_conv5 = conv_block(filters[4], filters[3])

#         self.Up4 = up_conv(filters[3], filters[2])
#         self.Up_conv4 = conv_block(filters[3], filters[2])

#         self.Up3 = up_conv(filters[2], filters[1])
#         self.Up_conv3 = conv_block(filters[2], filters[1])

#         self.Up2 = up_conv(filters[1], filters[0])
#         self.Up_conv2 = conv_block(filters[1], filters[0])

#         self.Conv = nn.Conv2d(filters[0], out_ch,
#                               kernel_size=1, stride=1, padding=0)

#        # self.active = torch.nn.Sigmoid()

#     def forward(self, x):

#         e1 = self.Conv1(x)

#         e2 = self.Maxpool1(e1)
#         e2 = self.Conv2(e2)

#         e3 = self.Maxpool2(e2)
#         e3 = self.Conv3(e3)

#         e4 = self.Maxpool3(e3)
#         e4 = self.Conv4(e4)

#         e5 = self.Maxpool4(e4)
#         e5 = self.Conv5(e5)  # shape: (B, C, h, w)
        
#         B,C,h,w = e5.shape
#         e5 = e5.permute(2,3,0,1)
#         e5 = e5.reshape(h*w,B,C)
#         e5 = e5 + self.pe
#         for i in range(self.layers):
#             e5 = self.transformers[i](e5)
#         e5 = e5.reshape(h,w,B,C)
#         e5 = e5.permute(2,3,0,1)
        
        
#         d5 = self.Up5(e5)
#         d5 = torch.cat((e4, d5), dim=1)

#         d5 = self.Up_conv5(d5)

#         d4 = self.Up4(d5)
#         d4 = torch.cat((e3, d4), dim=1)
#         d4 = self.Up_conv4(d4)

#         d3 = self.Up3(d4)
#         d3 = torch.cat((e2, d3), dim=1)
#         d3 = self.Up_conv3(d3)

#         d2 = self.Up2(d3)
#         d2 = torch.cat((e1, d2), dim=1)
#         d2 = self.Up_conv2(d2)

#         out = self.Conv(d2)

#         #d1 = self.active(out)
#         return out
