import torch
from torch import nn
from .shared import conv_block,up_conv
from .KAUNet.lib.functional import subtraction2,dotproduction2, aggregation


class up_guide(nn.Module):
    """Filter ``x_input`` with kernels predicted from a guide feature map.

    ``kernel_gen`` maps the guide tensor to ``kernel_size**2 * in_dim``
    channels. These are passed to ``aggregation`` as the weights applied to
    ``x_input``, and the result is refined by ``mlp`` (two stages of
    1x1 conv + BatchNorm + ReLU).

    Args:
        in_dim: Channel count of ``x_input``.
        guide_dim: Channel count of ``x_guide``.
        out: Channel count produced by ``mlp``.
        kernel_size: Side length of the predicted kernels. Must be a positive
            odd integer so that the ``kernel_size // 2`` padding keeps the
            spatial size unchanged.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, in_dim, guide_dim, out, kernel_size=3):
        super(up_guide, self).__init__()

        # An even kernel makes kernel_gen change the spatial size, so the
        # reshapes in forward() could never succeed; fail early instead.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )

        self.kernel_size = kernel_size
        # "Same" padding for the aggregation step (was hard-coded to 1, which
        # is only correct for kernel_size == 3).
        self.padding = kernel_size // 2

        # Guide -> per-position kernel weights. The first conv uses a
        # (kernel_size + 4) window with matching padding, so H and W are kept.
        self.kernel_gen = nn.Sequential(
            nn.Conv2d(guide_dim, guide_dim,
                      kernel_size=kernel_size + 4,
                      padding=self.padding + 2,
                      bias=True),
            nn.BatchNorm2d(guide_dim),
            nn.ReLU(),
            nn.Conv2d(guide_dim, kernel_size * kernel_size * in_dim,
                      kernel_size=1, padding=0, bias=True),
        )

        # Point-wise refinement of the filtered features.
        self.mlp = nn.Sequential(
            nn.Conv2d(in_dim, out, kernel_size=1, padding=0, bias=True),
            nn.BatchNorm2d(out),
            nn.ReLU(),
            nn.Conv2d(out, out, kernel_size=1, padding=0, bias=True),
            nn.BatchNorm2d(out),
            nn.ReLU(),
        )

    def forward(self, x_input, x_guide):
        """Apply the guide-predicted kernels to ``x_input``.

        Args:
            x_input: Tensor of shape (B, in_dim, H, W).
            x_guide: Guide tensor fed to ``kernel_gen``. Assumed to be
                (B, guide_dim, H, W) with the same B, H, W as ``x_input``;
                the reshapes below rely on that.

        Returns:
            Tensor of shape (B, out, H, W).
        """
        B, in_dim, H, W = x_input.shape

        kernel = self.kernel_gen(x_guide)  # B, K*K*in_dim, H, W
        kernel = kernel.reshape(B, 1, self.kernel_size ** 2, -1)  # B, 1, K^2, in_dim*H*W

        # Channels are stacked along the height axis so one single-channel
        # aggregation call covers every channel.
        # NOTE(review): with this layout the first/last rows of a channel
        # appear to border the neighbouring channel's rows -- confirm this is
        # intended.
        x_input = x_input.reshape(B, 1, in_dim * H, W)
        # `aggregation` comes from KAUNet.lib.functional; presumably the
        # SAN-style local aggregation op -- TODO confirm pad_mode=0 semantics.
        x_input = aggregation(x_input, kernel,
                              kernel_size=self.kernel_size,
                              stride=1,
                              padding=self.padding,
                              pad_mode=0)
        x_input = x_input.reshape(B, in_dim, H, W)

        x_input = self.mlp(x_input)
        return x_input


class UNet_guide(nn.Module):
    """
    U-Net style encoder/decoder whose decoder stages are `up_guide` blocks.

    The encoder follows the U-Net layout (https://arxiv.org/abs/1505.04597):
    five `conv_block` stages separated by 2x2 max-pooling. In the decoder,
    each `up_conv` output serves as the guide of an `up_guide` block that
    filters the matching encoder feature map. The decoder outputs d5, d4 and
    d3 are upsampled by factors 8, 4 and 2, concatenated with d2 along the
    channel axis and projected to `out_ch` channels by a 1x1 convolution.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super().__init__()

        # Channel widths of the five encoder levels: 64, 128, 256, 512, 1024.
        base = 64
        c1, c2, c3, c4, c5 = (base * mult for mult in (1, 2, 4, 8, 16))

        # Down-sampling between consecutive encoder stages.
        self.Maxpool1 = nn.MaxPool2d(2, 2)
        self.Maxpool2 = nn.MaxPool2d(2, 2)
        self.Maxpool3 = nn.MaxPool2d(2, 2)
        self.Maxpool4 = nn.MaxPool2d(2, 2)

        # Encoder.
        self.Conv1 = conv_block(in_ch, c1)
        self.Conv2 = conv_block(c1, c2)
        self.Conv3 = conv_block(c2, c3)
        self.Conv4 = conv_block(c3, c4)
        self.Conv5 = conv_block(c4, c5)

        # Decoder: each level is an up_conv followed by a guided filter block.
        self.Up5 = up_conv(c5, c4)
        self.Up_conv5 = up_guide(c4, c4, out=c4)

        self.Up4 = up_conv(c4, c3)
        self.Up_conv4 = up_guide(c3, c3, out=c3)

        self.Up3 = up_conv(c3, c2)
        self.Up_conv3 = up_guide(c2, c2, out=c2)

        self.Up2 = up_conv(c2, c1)
        self.Up_conv2 = up_guide(c1, c1, out=c1)

        # 1x1 head over the concatenation of all four decoder outputs.
        self.Conv = nn.Conv2d(c1 + c2 + c3 + c4, out_ch, kernel_size=1)

        # Parameter-free upsamplers used to align the decoder outputs
        # (_Up16 is not used by forward()).
        self._Up2 = nn.Upsample(scale_factor=2)
        self._Up4 = nn.Upsample(scale_factor=4)
        self._Up8 = nn.Upsample(scale_factor=8)
        self._Up16 = nn.Upsample(scale_factor=16)

    def forward(self, x):
        """Return the `out_ch`-channel map produced by the 1x1 head (no activation applied)."""
        # Encoder path.
        enc1 = self.Conv1(x)
        enc2 = self.Conv2(self.Maxpool1(enc1))
        enc3 = self.Conv3(self.Maxpool2(enc2))
        enc4 = self.Conv4(self.Maxpool3(enc3))
        bottom = self.Conv5(self.Maxpool4(enc4))

        # Decoder path: the upsampled deeper feature guides the filtering of
        # the encoder feature at the same level.
        dec5 = self.Up_conv5(x_input=enc4, x_guide=self.Up5(bottom))
        dec4 = self.Up_conv4(x_input=enc3, x_guide=self.Up4(dec5))
        dec3 = self.Up_conv3(x_input=enc2, x_guide=self.Up3(dec4))
        dec2 = self.Up_conv2(x_input=enc1, x_guide=self.Up2(dec3))

        # Fuse all decoder levels along the channel axis.
        fused = torch.cat(
            (self._Up8(dec5), self._Up4(dec4), self._Up2(dec3), dec2),
            dim=1,
        )
        return self.Conv(fused)






class up_guide_merge(nn.Module):
    """Guided filtering block that also merges the guide into the output.

    ``kernel_gen`` maps the guide tensor to ``kernel_size**2 * in_dim``
    channels. These are passed to ``aggregation`` as the weights applied to
    ``x_input``. The filtered features are then concatenated with the guide
    along the channel axis and refined by ``mlp`` (two stages of
    1x1 conv + BatchNorm + ReLU).

    Args:
        in_dim: Channel count of ``x_input``.
        guide_dim: Channel count of ``x_guide``.
        out: Channel count produced by ``mlp``.
        kernel_size: Side length of the predicted kernels. Must be a positive
            odd integer so that the ``kernel_size // 2`` padding keeps the
            spatial size unchanged.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, in_dim, guide_dim, out, kernel_size=3):
        super(up_guide_merge, self).__init__()

        # An even kernel makes kernel_gen change the spatial size, so the
        # reshapes in forward() could never succeed; fail early instead.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )

        self.kernel_size = kernel_size
        # "Same" padding for the aggregation step (was hard-coded to 1, which
        # is only correct for kernel_size == 3).
        self.padding = kernel_size // 2

        # Guide -> per-position kernel weights. The first conv uses a
        # (kernel_size + 4) window with matching padding, so H and W are kept.
        self.kernel_gen = nn.Sequential(
            nn.Conv2d(guide_dim, guide_dim,
                      kernel_size=kernel_size + 4,
                      padding=self.padding + 2,
                      bias=True),
            nn.BatchNorm2d(guide_dim),
            nn.ReLU(),
            nn.Conv2d(guide_dim, kernel_size * kernel_size * in_dim,
                      kernel_size=1, padding=0, bias=True),
        )

        # Point-wise refinement; its input is [filtered x_input, x_guide].
        self.mlp = nn.Sequential(
            nn.Conv2d(in_dim + guide_dim, out, kernel_size=1, padding=0, bias=True),
            nn.BatchNorm2d(out),
            nn.ReLU(),
            nn.Conv2d(out, out, kernel_size=1, padding=0, bias=True),
            nn.BatchNorm2d(out),
            nn.ReLU(),
        )

    def forward(self, x_input, x_guide):
        """Filter ``x_input`` with guide-predicted kernels and merge the guide.

        Args:
            x_input: Tensor of shape (B, in_dim, H, W).
            x_guide: Guide tensor fed to ``kernel_gen`` and concatenated with
                the filtered features. Assumed to be (B, guide_dim, H, W) with
                the same B, H, W as ``x_input``; the reshapes and the
                concatenation below rely on that.

        Returns:
            Tensor of shape (B, out, H, W).
        """
        B, in_dim, H, W = x_input.shape

        kernel = self.kernel_gen(x_guide)  # B, K*K*in_dim, H, W
        kernel = kernel.reshape(B, 1, self.kernel_size ** 2, -1)  # B, 1, K^2, in_dim*H*W

        # Channels are stacked along the height axis so one single-channel
        # aggregation call covers every channel.
        # NOTE(review): with this layout the first/last rows of a channel
        # appear to border the neighbouring channel's rows -- confirm this is
        # intended.
        x_input = x_input.reshape(B, 1, in_dim * H, W)
        # `aggregation` comes from KAUNet.lib.functional; presumably the
        # SAN-style local aggregation op -- TODO confirm pad_mode=0 semantics.
        x_input = aggregation(x_input, kernel,
                              kernel_size=self.kernel_size,
                              stride=1,
                              padding=self.padding,
                              pad_mode=0)
        x_input = x_input.reshape(B, in_dim, H, W)

        x_input = self.mlp(torch.cat([x_input, x_guide], dim=1))
        return x_input


class UNet_guide_merge(nn.Module):
    """
    U-Net style encoder/decoder whose decoder stages are `up_guide_merge` blocks.

    The encoder follows the U-Net layout (https://arxiv.org/abs/1505.04597):
    five `conv_block` stages separated by 2x2 max-pooling. In the decoder,
    each `up_conv` output serves as the guide of an `up_guide_merge` block
    that processes the matching encoder feature map. Only the last decoder
    output is passed to the 1x1 convolution head, which yields `out_ch`
    channels.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super().__init__()

        # Channel widths of the five encoder levels: 64, 128, 256, 512, 1024.
        base = 64
        c1, c2, c3, c4, c5 = (base * mult for mult in (1, 2, 4, 8, 16))

        # Down-sampling between consecutive encoder stages.
        self.Maxpool1 = nn.MaxPool2d(2, 2)
        self.Maxpool2 = nn.MaxPool2d(2, 2)
        self.Maxpool3 = nn.MaxPool2d(2, 2)
        self.Maxpool4 = nn.MaxPool2d(2, 2)

        # Encoder.
        self.Conv1 = conv_block(in_ch, c1)
        self.Conv2 = conv_block(c1, c2)
        self.Conv3 = conv_block(c2, c3)
        self.Conv4 = conv_block(c3, c4)
        self.Conv5 = conv_block(c4, c5)

        # Decoder: each level is an up_conv followed by a guided merge block.
        self.Up5 = up_conv(c5, c4)
        self.Up_conv5 = up_guide_merge(c4, c4, out=c4)

        self.Up4 = up_conv(c4, c3)
        self.Up_conv4 = up_guide_merge(c3, c3, out=c3)

        self.Up3 = up_conv(c3, c2)
        self.Up_conv3 = up_guide_merge(c2, c2, out=c2)

        self.Up2 = up_conv(c2, c1)
        self.Up_conv2 = up_guide_merge(c1, c1, out=c1)

        # 1x1 head on the finest decoder output.
        self.Conv = nn.Conv2d(c1, out_ch, kernel_size=1)

        # Parameter-free upsamplers; none of them is used by forward().
        self._Up2 = nn.Upsample(scale_factor=2)
        self._Up4 = nn.Upsample(scale_factor=4)
        self._Up8 = nn.Upsample(scale_factor=8)
        self._Up16 = nn.Upsample(scale_factor=16)

    def forward(self, x):
        """Return the `out_ch`-channel map produced by the 1x1 head (no activation applied)."""
        # Encoder path.
        enc1 = self.Conv1(x)
        enc2 = self.Conv2(self.Maxpool1(enc1))
        enc3 = self.Conv3(self.Maxpool2(enc2))
        enc4 = self.Conv4(self.Maxpool3(enc3))
        bottom = self.Conv5(self.Maxpool4(enc4))

        # Decoder path: the upsampled deeper feature guides (and is merged
        # with) the encoder feature at the same level.
        dec5 = self.Up_conv5(x_input=enc4, x_guide=self.Up5(bottom))
        dec4 = self.Up_conv4(x_input=enc3, x_guide=self.Up4(dec5))
        dec3 = self.Up_conv3(x_input=enc2, x_guide=self.Up3(dec4))
        dec2 = self.Up_conv2(x_input=enc1, x_guide=self.Up2(dec3))

        return self.Conv(dec2)
