import torch
from torch import nn
from .shared import *
import torch.nn.functional as F


class Condidence_gate_varmean(nn.Module):
    """Confidence gate built from learned features plus local variance and mean.

    (The class name keeps the original "Condidence" spelling so existing
    imports and checkpoints keep working.)

    The input is tiled into non-overlapping ``kernel_size x kernel_size``
    windows. Two branches each produce a map at ``1 / kernel_size`` of the
    input resolution:

    * a strided-convolution branch with ``in_channel // 2`` output channels;
    * a statistics branch that takes the per-window variance (unbiased,
      ``torch.var`` default) and mean of every input channel
      (``2 * in_channel`` maps), then projects them with a 1x1 conv to the
      remaining ``in_channel - in_channel // 2`` channels.

    The branches are concatenated back to ``in_channel`` channels and fed to a
    small 1x1-conv head ending in a sigmoid, so the output lies in ``[0, 1]``.

    Args:
        in_channel: Channels of the input feature map.
        out_channel: Channels of the produced gate map (default 1).
        kernel_size: Window size of the downsampling (default 2).
        stride: Must equal ``kernel_size``; windows never overlap.

    Raises:
        ValueError: If ``kernel_size != stride``.
    """

    def __init__(self, in_channel, out_channel=1, kernel_size=2, stride=2):
        super().__init__()
        # Raise instead of assert so the check survives ``python -O``.
        if kernel_size != stride:
            raise ValueError(
                f"kernel_size ({kernel_size}) must equal stride ({stride})"
            )
        self.ks = kernel_size
        # Use the configured window, not a hard-coded 2. Otherwise this branch
        # and the statistics branch disagree in resolution for kernel_size != 2.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channel, in_channel // 2, kernel_size=kernel_size, stride=stride),
            nn.BatchNorm2d(in_channel // 2),
            nn.ReLU(),
        )
        # Input is [variance, mean] stacked along channels -> 2 * in_channel.
        self.conv_var = nn.Sequential(
            nn.Conv2d(in_channel * 2, in_channel - in_channel // 2, kernel_size=1),
            nn.BatchNorm2d(in_channel - in_channel // 2),
            nn.ReLU(),
        )
        self.conv_gate = nn.Sequential(
            nn.Conv2d(in_channel, in_channel // 4, kernel_size=1),
            nn.BatchNorm2d(in_channel // 4),
            nn.ReLU(),
            nn.Conv2d(in_channel // 4, out_channel, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return a gate map of shape ``(B, out_channel, H // ks, W // ks)``.

        ``x`` is read as ``(B, C, H, W)`` with ``C == in_channel``. Trailing
        rows/columns that do not fill a whole window are ignored. The strided
        convolution floors the same way, so both branches always align.
        """
        B, C, H, W = x.shape
        ks = self.ks
        h, w = H // ks, W // ks
        # Crop to whole windows so the reshape below is valid for any H, W.
        v = x[:, :, : h * ks, : w * ks]
        # (B, C, h, ks, w, ks) -> (B, C, ks, ks, h, w) -> (B, C, ks*ks, h, w):
        # dim 2 now holds the ks*ks pixels of each window.
        v = v.reshape(B, C, h, ks, w, ks)
        v = v.permute(0, 1, 3, 5, 2, 4).reshape(B, C, ks * ks, h, w)
        v = torch.cat([torch.var(v, dim=2), torch.mean(v, dim=2)], dim=1)
        v = self.conv_var(v)

        x = self.conv(x)

        return self.conv_gate(torch.cat([x, v], dim=1))



class Condidence_gate_var(nn.Module):
    """Confidence gate built from learned features plus local variance.

    (The class name keeps the original "Condidence" spelling so existing
    imports and checkpoints keep working.)

    The input is tiled into non-overlapping ``kernel_size x kernel_size``
    windows. Two branches each produce a map at ``1 / kernel_size`` of the
    input resolution:

    * a strided-convolution branch with ``in_channel // 2`` output channels;
    * a statistics branch that takes the per-window variance (unbiased,
      ``torch.var`` default) of every input channel and projects it with a
      1x1 conv to the remaining ``in_channel - in_channel // 2`` channels.

    The branches are concatenated back to ``in_channel`` channels and fed to a
    small 1x1-conv head ending in a sigmoid, so the output lies in ``[0, 1]``.

    Args:
        in_channel: Channels of the input feature map.
        out_channel: Channels of the produced gate map (default 1).
        kernel_size: Window size of the downsampling (default 2).
        stride: Must equal ``kernel_size``; windows never overlap.

    Raises:
        ValueError: If ``kernel_size != stride``.
    """

    def __init__(self, in_channel, out_channel=1, kernel_size=2, stride=2):
        super().__init__()
        # Raise instead of assert so the check survives ``python -O``.
        if kernel_size != stride:
            raise ValueError(
                f"kernel_size ({kernel_size}) must equal stride ({stride})"
            )
        self.ks = kernel_size
        # Use the configured window, not a hard-coded 2. Otherwise this branch
        # and the variance branch disagree in resolution for kernel_size != 2.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channel, in_channel // 2, kernel_size=kernel_size, stride=stride),
            nn.BatchNorm2d(in_channel // 2),
            nn.ReLU(),
        )
        # Input is the per-window variance map -> in_channel channels.
        self.conv_var = nn.Sequential(
            nn.Conv2d(in_channel, in_channel - in_channel // 2, kernel_size=1),
            nn.BatchNorm2d(in_channel - in_channel // 2),
            nn.ReLU(),
        )
        self.conv_gate = nn.Sequential(
            nn.Conv2d(in_channel, in_channel // 4, kernel_size=1),
            nn.BatchNorm2d(in_channel // 4),
            nn.ReLU(),
            nn.Conv2d(in_channel // 4, out_channel, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return a gate map of shape ``(B, out_channel, H // ks, W // ks)``.

        ``x`` is read as ``(B, C, H, W)`` with ``C == in_channel``. Trailing
        rows/columns that do not fill a whole window are ignored. The strided
        convolution floors the same way, so both branches always align.
        """
        B, C, H, W = x.shape
        ks = self.ks
        h, w = H // ks, W // ks
        # Crop to whole windows so the reshape below is valid for any H, W.
        v = x[:, :, : h * ks, : w * ks]
        # (B, C, h, ks, w, ks) -> (B, C, ks, ks, h, w) -> (B, C, ks*ks, h, w):
        # dim 2 now holds the ks*ks pixels of each window.
        v = v.reshape(B, C, h, ks, w, ks)
        v = v.permute(0, 1, 3, 5, 2, 4).reshape(B, C, ks * ks, h, w)
        v = torch.var(v, dim=2)
        v = self.conv_var(v)

        x = self.conv(x)

        return self.conv_gate(torch.cat([x, v], dim=1))
        

class UNet_gate(nn.Module):
    """
    U-Net with deep supervision and per-scale confidence gates.

    Based on U-Net (Paper : https://arxiv.org/abs/1505.04597). On top of the
    4-level encoder/decoder with skip concatenations, the network emits a
    1x1-conv prediction at the bottleneck and at every decoder level. For the
    four coarsest of those predictions it also emits a gate map, computed by a
    ``CG`` module from the encoder feature map one level shallower than the
    prediction's level. With the gate classes defined in this file (2x2,
    stride-2 downsampling by default), each gate therefore lands at the same
    resolution as its prediction.

    Args:
        in_ch: Input channels.
        out_ch: Channels of every per-scale prediction.
        kernel_size: Kernel size passed to the encoder ``conv_block``s.
        decoder_ks: Kernel size passed to the decoder ``conv_block``s.
        CG: Gate module class, instantiated as ``CG(channels)``.

    NOTE(review): the order in which submodules are created below determines
    parameter-initialisation order and state_dict layout; do not reorder.
    """
    def __init__(self, in_ch=3, out_ch=1, kernel_size=3,decoder_ks = 3,CG = Condidence_gate_var):
        super().__init__()

        n1 = 64
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]
        self.ks = kernel_size
        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Not used in forward(); has no parameters.
        self.Uppool = nn.Upsample(scale_factor=2,mode = 'bilinear',align_corners= True)
        

        # Encoder blocks (conv_block comes from .shared; presumably keeps the
        # spatial size -- TODO confirm).
        self.Conv1 = conv_block(in_ch, filters[0], kernel_size=self.ks)
        self.Conv2 = conv_block(filters[0], filters[1], kernel_size=self.ks)
        self.Conv3 = conv_block(filters[1], filters[2], kernel_size=self.ks)
        self.Conv4 = conv_block(filters[2], filters[3], kernel_size=self.ks)
        self.Conv5 = conv_block(filters[3], filters[4], kernel_size=self.ks)
        
        # Per-scale prediction heads: plain 1x1 convs, no activation.
        self.Conv6_cla = nn.Conv2d(filters[4],out_ch,kernel_size=1)
        self.Conv5_cla = nn.Conv2d(filters[3],out_ch,kernel_size=1)
        self.Conv4_cla = nn.Conv2d(filters[2],out_ch,kernel_size=1)
        self.Conv3_cla = nn.Conv2d(filters[1],out_ch,kernel_size=1)
        self.Conv2_cla = nn.Conv2d(filters[0],out_ch,kernel_size=1)
        
        # Gates: each takes the encoder map one level shallower than the head
        # it pairs with (e.g. Conv6_gate reads e4, paired with c6 from e5).
        self.Conv6_gate = CG(filters[3])
        self.Conv5_gate = CG(filters[2])
        self.Conv4_gate = CG(filters[1])
        self.Conv3_gate = CG(filters[0])
        
        # Decoder: upsample (up_conv from .shared), concatenate the skip, fuse.
        self.Up5 = up_conv(filters[4], filters[3])
        self.Up_conv5 = conv_block(filters[3]+filters[3], filters[3], kernel_size=decoder_ks)

        self.Up4 = up_conv(filters[3], filters[2])
        self.Up_conv4 = conv_block(filters[2]+filters[2], filters[2], kernel_size=decoder_ks)

        self.Up3 = up_conv(filters[2], filters[1])
        self.Up_conv3 = conv_block(filters[1]+filters[1], filters[1], kernel_size=decoder_ks)

        self.Up2 = up_conv(filters[1], filters[0])
        self.Up_conv2 = conv_block(filters[0]+filters[0], filters[0], kernel_size=decoder_ks)


        # No final activation module: outputs are returned as raw conv outputs.

    def forward(self, x):
        """Run the network.

        Args:
            x: Input batch, presumably ``(B, in_ch, H, W)`` with H and W
                divisible by 16 so the skip concatenations line up -- TODO
                confirm against conv_block/up_conv in .shared.

        Returns:
            ``(preds, gates)`` where
            ``preds = (c6, c5, c4, c3, c2)`` are the raw 1x1-conv outputs
            (no activation), ordered from the bottleneck (coarsest) to the
            finest decoder level, and
            ``gates = (g6, g5, g4, g3)`` are the ``CG`` outputs paired with
            c6..c3. c2 has no gate.
        """

        e1 = self.Conv1(x)

        e2 = self.Maxpool1(e1)
        e2 = self.Conv2(e2)
        g3 = self.Conv3_gate(e1)
        
        e3 = self.Maxpool2(e2)
        e3 = self.Conv3(e3)
        g4 = self.Conv4_gate(e2)

        e4 = self.Maxpool3(e3)
        e4 = self.Conv4(e4)
        g5 = self.Conv5_gate(e3)

        e5 = self.Maxpool4(e4)
        e5 = self.Conv5(e5)
        c6 = self.Conv6_cla(e5)
        g6 = self.Conv6_gate(e4)
        
        # Decoder: each level concatenates the matching encoder skip first.
        d5 = self.Up5(e5)
        d5 = torch.cat((e4, d5), dim=1)
        d5 = self.Up_conv5(d5)
        c5 = self.Conv5_cla(d5)
        
        d4 = self.Up4(d5)
        d4 = torch.cat((e3, d4), dim=1)
        d4 = self.Up_conv4(d4)
        c4 = self.Conv4_cla(d4)

        d3 = self.Up3(d4)
        d3 = torch.cat((e2, d3), dim=1)
        d3 = self.Up_conv3(d3)
        c3 = self.Conv3_cla(d3)

        d2 = self.Up2(d3)
        d2 = torch.cat((e1, d2), dim=1)
        d2 = self.Up_conv2(d2)
        c2 = self.Conv2_cla(d2)

        return (c6,c5,c4,c3,c2),(g6,g5,g4,g3)



class UNet_gate_varmean(UNet_gate):
    """``UNet_gate`` whose confidence gates default to ``Condidence_gate_varmean``.

    Identical to ``UNet_gate`` except for the default ``CG``. Its gates use
    both the per-window variance and mean of the encoder features, whereas
    the base class defaults to ``Condidence_gate_var`` (variance only). All
    arguments are forwarded unchanged to ``UNet_gate.__init__``.
    """

    def __init__(self, in_ch=3, out_ch=1, kernel_size=3, decoder_ks=3, CG=Condidence_gate_varmean):
        super().__init__(in_ch, out_ch, kernel_size, decoder_ks, CG)
