from __future__ import print_function, division
import math
import gc

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

from .shared import conv_block, up_conv
from .KAUNet.KAUNet_v import attention_conv_block_v,gumbel_softmax
from .KAUNet.lib.functional import subtraction2,dotproduction2, aggregation
# from modules.modulated_deform_conv import ModulatedDeformConv,ModulatedDeformConvPack 
from utils_torch import set_lr_mult


class UNet_Base(nn.Module):
    """
    U-Net encoder/decoder with a pluggable per-stage convolution block.

    Paper : https://arxiv.org/abs/1505.04597

    Args:
        in_ch: number of channels of the input tensor.
        out_ch: number of channels produced by the final 1x1 convolution.
        n1: width of the first stage; every deeper stage doubles it
            (n1, 2*n1, 4*n1, 8*n1, 16*n1).
        block: callable ``block(in_channels, out_channels, kernel_size=3)``
            that builds the module used at every encoder and decoder stage.
    """
    def __init__(self, in_ch=3, out_ch=1, n1=64, block=conv_block):
        super(UNet_Base, self).__init__()

        # Channel widths of the five resolution levels.
        c1, c2, c3, c4, c5 = [n1 * 2 ** level for level in range(5)]

        # One 2x2/stride-2 max pooling between consecutive encoder stages.
        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder stages.
        self.Conv1 = block(in_ch, c1, kernel_size=3)
        self.Conv2 = block(c1, c2, kernel_size=3)
        self.Conv3 = block(c2, c3, kernel_size=3)
        self.Conv4 = block(c3, c4, kernel_size=3)
        self.Conv5 = block(c4, c5, kernel_size=3)

        # Decoder stages.  ``up_conv`` comes from ``.shared``; presumably it
        # doubles the spatial size while reducing channels -- TODO confirm.
        # Each ``Up_conv*`` consumes the concatenation of the skip feature and
        # the upsampled feature, hence the doubled input width.
        self.Up5 = up_conv(c5, c4)
        self.Up_conv5 = block(c5, c4, kernel_size=3)

        self.Up4 = up_conv(c4, c3)
        self.Up_conv4 = block(c4, c3, kernel_size=3)

        self.Up3 = up_conv(c3, c2)
        self.Up_conv3 = block(c3, c2, kernel_size=3)

        self.Up2 = up_conv(c2, c1)
        self.Up_conv2 = block(c2, c1, kernel_size=3)

        # Final 1x1 projection to the requested number of output channels.
        self.Conv = nn.Sequential(
            nn.Conv2d(c1, out_ch, kernel_size=1, stride=1, padding=0))

        # Parameter-free upsamplers; not used by ``forward`` in this class.
        self._Up2 = nn.Upsample(scale_factor=2)
        self._Up4 = nn.Upsample(scale_factor=4)
        self._Up8 = nn.Upsample(scale_factor=8)

    def forward(self, x):
        """Return the raw (un-activated) output of the final 1x1 convolution."""
        # Contracting path: every stage output is kept for its skip connection.
        enc1 = self.Conv1(x)
        enc2 = self.Conv2(self.Maxpool1(enc1))
        enc3 = self.Conv3(self.Maxpool2(enc2))
        enc4 = self.Conv4(self.Maxpool3(enc3))
        bottom = self.Conv5(self.Maxpool4(enc4))

        # Expanding path: upsample, concatenate the skip feature first along
        # the channel axis, then fuse with the stage block.
        dec = self.Up_conv5(torch.cat((enc4, self.Up5(bottom)), dim=1))
        dec = self.Up_conv4(torch.cat((enc3, self.Up4(dec)), dim=1))
        dec = self.Up_conv3(torch.cat((enc2, self.Up3(dec)), dim=1))
        dec = self.Up_conv2(torch.cat((enc1, self.Up2(dec)), dim=1))

        return self.Conv(dec)



class attention_delta(nn.Module):
    """
    Convolution block: attention-weighted local differences plus a plain
    double convolution.

    For every pixel the ``kernel_size x kernel_size`` neighbourhood is
    gathered (zero padded at the border), the centre value is subtracted from
    every tap, the differences are scaled by a softmax attention map over the
    taps, and a 1x1 convolution mixes the result.  The block output is that
    branch added to ``x_conv(x)``.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        kernel_size: side of the square neighbourhood; expected to be odd so
            that the unfolded map keeps the input's spatial size.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):

        super(attention_delta, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = (kernel_size - 1) // 2

        self.att_out = out_channels

        # Plain branch: (conv -> BN -> ReLU) twice, spatial size preserved.
        self.x_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=self.padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding=self.padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

        # Attention branch: computed at half resolution, upsampled back and
        # projected to one logit per neighbourhood tap (kernel_size**2 channels).
        self.att_back = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels, self.att_out // 4, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(self.att_out // 4),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(self.att_out // 4, self.kernel_size ** 2, kernel_size=3, stride=1, padding=1, bias=True),
        )

        # Explicit zero padding + unpadded unfold extracts the neighbourhoods.
        self.unfold = torch.nn.Unfold(kernel_size=kernel_size, padding=0)
        self.pad = torch.nn.ZeroPad2d(kernel_size // 2)
        # 1x1 mixing of all (channel, tap) pairs.
        self.conv_delta = nn.Sequential(
            nn.Conv2d(self.kernel_size ** 2 * in_channels, self.att_out, kernel_size=1, padding=0, bias=True),
            nn.BatchNorm2d(self.att_out),
            nn.ReLU(),
        )

        # Not read anywhere in this class; kept for attribute compatibility.
        self.att = None

    def _attention_logits(self, x, H, W):
        """Run ``att_back`` on ``x`` and return a map of spatial size (H, W)."""
        pad_h, pad_w = H % 2, W % 2
        if pad_h or pad_w:
            # MaxPool2d(2) followed by Upsample(2) only reproduces even sizes,
            # so odd sides are extended by one replicated row/column here and
            # cropped below.  Even-sized inputs take the original path.
            x = F.pad(x, (0, pad_w, 0, pad_h), mode='replicate')
        return self.att_back(x)[:, :, :H, :W]

    def forward(self, x):
        """
        Args:
            x: tensor of shape (B, in_channels, H, W).

        Returns:
            Tensor of shape (B, out_channels, H, W).
        """
        B, _, H, W = x.shape
        taps = self.kernel_size ** 2
        centre = taps // 2  # index of the neighbourhood's centre tap

        # (B, 1, taps, H, W): one weight per tap, shared by all channels.
        att = self._attention_logits(x, H, W).reshape(B, -1, taps, H, W)
        att = F.softmax(att, dim=2)

        # (B, C, taps, H, W): neighbourhood of every pixel.
        patches = self.unfold(self.pad(x)).reshape(B, -1, taps, H, W)
        delta = (patches - patches[:, :, centre:centre + 1]) * att

        return self.conv_delta(delta.reshape(B, -1, H, W)) + self.x_conv(x)




class UNet_delta(UNet_Base):
    """U-Net whose encoder/decoder stages are ``attention_delta`` blocks.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        n1: width of the first stage (deeper stages double it).
    """
    def __init__(self, in_ch=3, out_ch=1, n1=64):
        # Forward the caller's ``n1``; it was previously pinned to 64 no
        # matter what was passed.  NOTE(review): checkpoints saved from a
        # model built with a non-default ``n1`` actually have width 64.
        super(UNet_delta, self).__init__(
            in_ch=in_ch, out_ch=out_ch, n1=n1, block=attention_delta)




class attention_pure(nn.Module):
    """
    Convolution block made only of attention-weighted neighbourhood mixing.

    For every pixel the ``kernel_size x kernel_size`` neighbourhood is
    gathered (zero padded at the border), each tap is scaled by a softmax
    attention weight, and a 1x1 convolution mixes the result.  ``x_conv`` is
    constructed but NOT used by ``forward``; it is kept so the module's
    parameters / state_dict layout stay as they were.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        kernel_size: side of the square neighbourhood; expected to be odd so
            that the unfolded map keeps the input's spatial size.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):

        super(attention_pure, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = (kernel_size - 1) // 2

        self.att_out = out_channels

        # Double convolution branch; registered but unused in forward().
        self.x_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=self.padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding=self.padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

        # Attention branch: computed at half resolution, upsampled back and
        # projected to one logit per neighbourhood tap (kernel_size**2 channels).
        self.att_back = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels, self.att_out // 4, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(self.att_out // 4),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(self.att_out // 4, self.kernel_size ** 2, kernel_size=3, stride=1, padding=1, bias=True),
        )

        # Explicit zero padding + unpadded unfold extracts the neighbourhoods.
        self.unfold = torch.nn.Unfold(kernel_size=kernel_size, padding=0)
        self.pad = torch.nn.ZeroPad2d(kernel_size // 2)
        # 1x1 mixing of all (channel, tap) pairs.
        self.conv_delta = nn.Sequential(
            nn.Conv2d(self.kernel_size ** 2 * in_channels, self.att_out, kernel_size=1, padding=0, bias=True),
            nn.BatchNorm2d(self.att_out),
            nn.ReLU(),
        )

        # Not read anywhere in this class; kept for attribute compatibility.
        self.att = None

    def _attention_logits(self, x, H, W):
        """Run ``att_back`` on ``x`` and return a map of spatial size (H, W)."""
        pad_h, pad_w = H % 2, W % 2
        if pad_h or pad_w:
            # MaxPool2d(2) followed by Upsample(2) only reproduces even sizes,
            # so odd sides are extended by one replicated row/column here and
            # cropped below.  Even-sized inputs take the original path.
            x = F.pad(x, (0, pad_w, 0, pad_h), mode='replicate')
        return self.att_back(x)[:, :, :H, :W]

    def forward(self, x):
        """
        Args:
            x: tensor of shape (B, in_channels, H, W).

        Returns:
            Tensor of shape (B, out_channels, H, W).
        """
        B, _, H, W = x.shape
        taps = self.kernel_size ** 2

        # (B, 1, taps, H, W): one weight per tap, shared by all channels.
        att = self._attention_logits(x, H, W).reshape(B, -1, taps, H, W)
        att = F.softmax(att, dim=2)

        # (B, C, taps, H, W): neighbourhood of every pixel, re-weighted.
        patches = self.unfold(self.pad(x)).reshape(B, -1, taps, H, W)
        weighted = patches * att

        return self.conv_delta(weighted.reshape(B, -1, H, W))



class UNet_att_pure(UNet_Base):
    """U-Net whose encoder/decoder stages are ``attention_pure`` blocks.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        n1: width of the first stage (deeper stages double it).
    """
    def __init__(self, in_ch=3, out_ch=1, n1=64):
        # Forward the caller's ``n1``; it was previously pinned to 64 no
        # matter what was passed.  NOTE(review): checkpoints saved from a
        # model built with a non-default ``n1`` actually have width 64.
        super(UNet_att_pure, self).__init__(
            in_ch=in_ch, out_ch=out_ch, n1=n1, block=attention_pure)




class attention_pure_add(nn.Module):
    """
    Convolution block: attention-weighted neighbourhood mixing plus a plain
    double convolution.

    For every pixel the ``kernel_size x kernel_size`` neighbourhood is
    gathered (zero padded at the border), each tap is scaled by a softmax
    attention weight, and a 1x1 convolution mixes the result.  The block
    output is that branch added to ``x_conv(x)``.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        kernel_size: side of the square neighbourhood; expected to be odd so
            that the unfolded map keeps the input's spatial size.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):

        super(attention_pure_add, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = (kernel_size - 1) // 2

        self.att_out = out_channels

        # Plain branch: (conv -> BN -> ReLU) twice, spatial size preserved.
        self.x_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=self.padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding=self.padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

        # Attention branch: computed at half resolution, upsampled back and
        # projected to one logit per neighbourhood tap (kernel_size**2 channels).
        self.att_back = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels, self.att_out // 4, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(self.att_out // 4),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(self.att_out // 4, self.kernel_size ** 2, kernel_size=3, stride=1, padding=1, bias=True),
        )

        # Explicit zero padding + unpadded unfold extracts the neighbourhoods.
        self.unfold = torch.nn.Unfold(kernel_size=kernel_size, padding=0)
        self.pad = torch.nn.ZeroPad2d(kernel_size // 2)
        # 1x1 mixing of all (channel, tap) pairs.
        self.conv_delta = nn.Sequential(
            nn.Conv2d(self.kernel_size ** 2 * in_channels, self.att_out, kernel_size=1, padding=0, bias=True),
            nn.BatchNorm2d(self.att_out),
            nn.ReLU(),
        )

        # Not read anywhere in this class; kept for attribute compatibility.
        self.att = None

    def _attention_logits(self, x, H, W):
        """Run ``att_back`` on ``x`` and return a map of spatial size (H, W)."""
        pad_h, pad_w = H % 2, W % 2
        if pad_h or pad_w:
            # MaxPool2d(2) followed by Upsample(2) only reproduces even sizes,
            # so odd sides are extended by one replicated row/column here and
            # cropped below.  Even-sized inputs take the original path.
            x = F.pad(x, (0, pad_w, 0, pad_h), mode='replicate')
        return self.att_back(x)[:, :, :H, :W]

    def forward(self, x):
        """
        Args:
            x: tensor of shape (B, in_channels, H, W).

        Returns:
            Tensor of shape (B, out_channels, H, W).
        """
        B, _, H, W = x.shape
        taps = self.kernel_size ** 2

        # (B, 1, taps, H, W): one weight per tap, shared by all channels.
        att = self._attention_logits(x, H, W).reshape(B, -1, taps, H, W)
        att = F.softmax(att, dim=2)

        # (B, C, taps, H, W): neighbourhood of every pixel, re-weighted.
        patches = self.unfold(self.pad(x)).reshape(B, -1, taps, H, W)
        weighted = patches * att

        return self.conv_delta(weighted.reshape(B, -1, H, W)) + self.x_conv(x)



class UNet_att_pure_add(UNet_Base):
    """U-Net whose encoder/decoder stages are ``attention_pure_add`` blocks.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        n1: width of the first stage (deeper stages double it).
    """
    def __init__(self, in_ch=3, out_ch=1, n1=64):
        # Forward the caller's ``n1``; it was previously pinned to 64 no
        # matter what was passed.  NOTE(review): checkpoints saved from a
        # model built with a non-default ``n1`` actually have width 64.
        super(UNet_att_pure_add, self).__init__(
            in_ch=in_ch, out_ch=out_ch, n1=n1, block=attention_pure_add)




class attention_tanh(nn.Module):
    """
    Convolution block: gated neighbourhood mixing plus a plain double
    convolution.

    For every pixel the ``kernel_size x kernel_size`` neighbourhood is
    gathered (zero padded at the border) and each tap is scaled by a gate in
    [0, 1) produced by ``att_back`` (its last layers are Tanh then ReLU; no
    softmax is applied).  A 1x1 convolution mixes the gated taps and the
    result is added to ``x_conv(x)``.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        kernel_size: side of the square neighbourhood; expected to be odd so
            that the unfolded map keeps the input's spatial size.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):

        super(attention_tanh, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = (kernel_size - 1) // 2

        self.att_out = out_channels

        # Plain branch: (conv -> BN -> ReLU) twice, spatial size preserved.
        self.x_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=self.padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding=self.padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

        # Gate branch: computed at half resolution, upsampled back, projected
        # to one value per neighbourhood tap, then squashed by Tanh and
        # clipped at zero by ReLU.
        self.att_back = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels, self.att_out // 4, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(self.att_out // 4),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(self.att_out // 4, self.kernel_size ** 2, kernel_size=3, stride=1, padding=1, bias=True),
            nn.Tanh(),
            nn.ReLU()
        )

        # Explicit zero padding + unpadded unfold extracts the neighbourhoods.
        self.unfold = torch.nn.Unfold(kernel_size=kernel_size, padding=0)
        self.pad = torch.nn.ZeroPad2d(kernel_size // 2)
        # 1x1 mixing of all (channel, tap) pairs.
        self.conv_delta = nn.Sequential(
            nn.Conv2d(self.kernel_size ** 2 * in_channels, self.att_out, kernel_size=1, padding=0, bias=True),
            nn.BatchNorm2d(self.att_out),
            nn.ReLU(),
        )

        # Not read anywhere in this class; kept for attribute compatibility.
        self.att = None

    def _attention_gates(self, x, H, W):
        """Run ``att_back`` on ``x`` and return a map of spatial size (H, W)."""
        pad_h, pad_w = H % 2, W % 2
        if pad_h or pad_w:
            # MaxPool2d(2) followed by Upsample(2) only reproduces even sizes,
            # so odd sides are extended by one replicated row/column here and
            # cropped below.  Even-sized inputs take the original path.
            x = F.pad(x, (0, pad_w, 0, pad_h), mode='replicate')
        return self.att_back(x)[:, :, :H, :W]

    def forward(self, x):
        """
        Args:
            x: tensor of shape (B, in_channels, H, W).

        Returns:
            Tensor of shape (B, out_channels, H, W).
        """
        B, _, H, W = x.shape
        taps = self.kernel_size ** 2

        # (B, 1, taps, H, W): one gate per tap, shared by all channels.
        gates = self._attention_gates(x, H, W).reshape(B, -1, taps, H, W)

        # (B, C, taps, H, W): neighbourhood of every pixel, gated.
        patches = self.unfold(self.pad(x)).reshape(B, -1, taps, H, W)
        gated = patches * gates

        return self.conv_delta(gated.reshape(B, -1, H, W)) + self.x_conv(x)



class UNet_att_tanh(UNet_Base):
    """U-Net whose encoder/decoder stages are ``attention_tanh`` blocks.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        n1: width of the first stage (deeper stages double it).
    """
    def __init__(self, in_ch=3, out_ch=1, n1=64):
        # Forward the caller's ``n1``; it was previously pinned to 64 no
        # matter what was passed.  NOTE(review): checkpoints saved from a
        # model built with a non-default ``n1`` actually have width 64.
        super(UNet_att_tanh, self).__init__(
            in_ch=in_ch, out_ch=out_ch, n1=n1, block=attention_tanh)




class attention_tanh_delta(nn.Module):
    """
    Convolution block: gated local differences plus a plain double
    convolution.

    For every pixel the ``kernel_size x kernel_size`` neighbourhood is
    gathered (zero padded at the border), the centre value is subtracted from
    every tap, and each difference is scaled by a gate in [0, 1) produced by
    ``att_back`` (its last layers are Tanh then ReLU; no softmax is applied).
    A 1x1 convolution mixes the gated differences and the result is added to
    ``x_conv(x)``.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        kernel_size: side of the square neighbourhood; expected to be odd so
            that the unfolded map keeps the input's spatial size.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):

        super(attention_tanh_delta, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = (kernel_size - 1) // 2

        self.att_out = out_channels

        # Plain branch: (conv -> BN -> ReLU) twice, spatial size preserved.
        self.x_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=self.padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding=self.padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

        # Gate branch: computed at half resolution, upsampled back, projected
        # to one value per neighbourhood tap, then squashed by Tanh and
        # clipped at zero by ReLU.
        self.att_back = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels, self.att_out // 4, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(self.att_out // 4),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(self.att_out // 4, self.kernel_size ** 2, kernel_size=3, stride=1, padding=1, bias=True),
            nn.Tanh(),
            nn.ReLU()
        )

        # Explicit zero padding + unpadded unfold extracts the neighbourhoods.
        self.unfold = torch.nn.Unfold(kernel_size=kernel_size, padding=0)
        self.pad = torch.nn.ZeroPad2d(kernel_size // 2)
        # 1x1 mixing of all (channel, tap) pairs.
        self.conv_delta = nn.Sequential(
            nn.Conv2d(self.kernel_size ** 2 * in_channels, self.att_out, kernel_size=1, padding=0, bias=True),
            nn.BatchNorm2d(self.att_out),
            nn.ReLU(),
        )

        # Not read anywhere in this class; kept for attribute compatibility.
        self.att = None

    def _attention_gates(self, x, H, W):
        """Run ``att_back`` on ``x`` and return a map of spatial size (H, W)."""
        pad_h, pad_w = H % 2, W % 2
        if pad_h or pad_w:
            # MaxPool2d(2) followed by Upsample(2) only reproduces even sizes,
            # so odd sides are extended by one replicated row/column here and
            # cropped below.  Even-sized inputs take the original path.
            x = F.pad(x, (0, pad_w, 0, pad_h), mode='replicate')
        return self.att_back(x)[:, :, :H, :W]

    def forward(self, x):
        """
        Args:
            x: tensor of shape (B, in_channels, H, W).

        Returns:
            Tensor of shape (B, out_channels, H, W).
        """
        B, _, H, W = x.shape
        taps = self.kernel_size ** 2
        centre = taps // 2  # index of the neighbourhood's centre tap

        # (B, 1, taps, H, W): one gate per tap, shared by all channels.
        gates = self._attention_gates(x, H, W).reshape(B, -1, taps, H, W)

        # (B, C, taps, H, W): neighbourhood of every pixel.
        patches = self.unfold(self.pad(x)).reshape(B, -1, taps, H, W)
        delta = (patches - patches[:, :, centre:centre + 1]) * gates

        return self.conv_delta(delta.reshape(B, -1, H, W)) + self.x_conv(x)



class UNet_att_tanh_delta(UNet_Base):
    """U-Net whose encoder/decoder stages are ``attention_tanh_delta`` blocks.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        n1: width of the first stage (deeper stages double it).
    """
    def __init__(self, in_ch=3, out_ch=1, n1=64):
        # Forward the caller's ``n1``; it was previously pinned to 64 no
        # matter what was passed.  NOTE(review): checkpoints saved from a
        # model built with a non-default ``n1`` actually have width 64.
        super(UNet_att_tanh_delta, self).__init__(
            in_ch=in_ch, out_ch=out_ch, n1=n1, block=attention_tanh_delta)



# class conv_block_DCN(nn.Module):
#     """
#     Convolution Block 
#     """

#     def __init__(self, in_ch, out_ch, kernel_size=3):
#         super(conv_block_DCN, self).__init__()

#         padding = kernel_size // 2

#         self.conv = nn.Sequential(
#             ModulatedDeformConvPack(in_ch, out_ch, kernel_size=kernel_size,
#                       stride=1, padding=padding, bias=True),
#             nn.BatchNorm2d(out_ch),
#             nn.ReLU(inplace=True),
#             ModulatedDeformConvPack(out_ch, out_ch, kernel_size=kernel_size,
#                       stride=1, padding=padding, bias=True),
#             nn.BatchNorm2d(out_ch),
#             nn.ReLU(inplace=True))
#         # self.conv_res = nn.Sequential(
#         #     nn.Conv2d(in_ch, out_ch, kernel_size=1,
#         #               stride=1, padding=0, bias=True),
#         #     nn.BatchNorm2d(out_ch),
#         #     nn.ReLU(inplace=True),)

#     def forward(self, x):
#         x = self.conv(x) #+ self.conv_res(x)

#         return x

# class UNet_DCN(UNet_Base):
#     def __init__(self,in_ch=3, out_ch=1,n1=64):
#         block = conv_block_DCN
#         super(UNet_DCN,self).__init__(in_ch=in_ch, out_ch=out_ch,n1=64,block = block)



class avg_delta(nn.Module):
    """
    Convolution block: mean-removed neighbourhood mixing plus a smoothed
    1x1 projection.

    Detail branch: every pixel's ``kernel_size x kernel_size`` neighbourhood
    (zero padded at the border) has its own mean subtracted, then a 1x1
    convolution mixes all (channel, tap) pairs.
    Smooth branch (``x_conv``): 9x9 stride-1 average pooling followed by a
    1x1 convolution, BatchNorm and ReLU.
    The block output is the sum of the two branches.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        kernel_size: side of the square neighbourhood; expected to be odd so
            that the unfolded map keeps the input's spatial size.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super(avg_delta, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = (kernel_size - 1) // 2
        self.att_out = out_channels

        # Smooth branch; padding=4 keeps the spatial size under the 9x9 pool.
        smooth_layers = [
            nn.AvgPool2d(kernel_size=9, stride=1, padding=4),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.x_conv = nn.Sequential(*smooth_layers)
        # ``set_lr_mult`` lives in ``utils_torch``; presumably it tags this
        # branch with a 10x learning-rate multiplier -- TODO confirm.
        set_lr_mult(self.x_conv, 10)

        # Explicit zero padding + unpadded unfold extracts the neighbourhoods.
        self.unfold = torch.nn.Unfold(kernel_size=kernel_size, padding=0)
        self.pad = torch.nn.ZeroPad2d(kernel_size // 2)

        # Detail branch: 1x1 mixing of all (channel, tap) pairs.
        mix_layers = [
            nn.Conv2d(kernel_size ** 2 * in_channels, out_channels, kernel_size=1, padding=0, bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        self.conv_delta = nn.Sequential(*mix_layers)

        # Not read anywhere in this class.
        self.att = None

    def forward(self, x):
        """
        Args:
            x: tensor of shape (B, in_channels, H, W).

        Returns:
            Tensor of shape (B, out_channels, H, W).
        """
        batch, _, height, width = x.shape
        taps = self.kernel_size ** 2

        # (B, C, taps, H, W): neighbourhood of every pixel.
        patches = self.unfold(self.pad(x)).reshape(batch, -1, taps, height, width)
        # Remove each neighbourhood's own mean (taken over the tap axis).
        centred = patches - patches.mean(dim=2, keepdim=True)

        detail = self.conv_delta(centred.reshape(batch, -1, height, width))
        return detail + self.x_conv(x)




class UNet_avg_delta_Sep(UNet_Base):
    """U-Net whose encoder/decoder stages are ``avg_delta`` blocks.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        n1: width of the first stage (deeper stages double it).
    """
    def __init__(self, in_ch=3, out_ch=1, n1=64):
        # Forward the caller's ``n1``; it was previously pinned to 64 no
        # matter what was passed.  NOTE(review): checkpoints saved from a
        # model built with a non-default ``n1`` actually have width 64.
        super(UNet_avg_delta_Sep, self).__init__(
            in_ch=in_ch, out_ch=out_ch, n1=n1, block=avg_delta)

