from __future__ import print_function, division
import math
import gc

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

from .shared import up_conv
class Conv2d_Linear(nn.Module):
    """2-D convolution with line-shaped kernel supports and a learned spatial gate.

    Output channels are split into five contiguous groups whose boundaries are
    stored in ``self.deltas``. With ``d = out_channels // 8``, the first four
    groups have ``d`` channels each and the last group takes the remaining
    ``out_channels - 4 * d`` channels. When ``out_channels < 8``, only the
    last group is non-empty.

    Each group's kernels are multiplied by a fixed 0/1 mask (``self.sign``):

    * group 0: middle row of the kernel (horizontal line)
    * group 1: middle column of the kernel (vertical line)
    * group 2: main diagonal
    * group 3: anti-diagonal
    * group 4: the full ``kernel_size x kernel_size`` window

    After the masked convolution, every channel of group ``g`` is multiplied
    by channel ``g`` of a sigmoid gate. The gate is computed from the layer
    *input* by a 3x3, stride-1, padding-1 convolution (``self.Conv``).

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: side length of the square kernel (int).
        padding, stride, dilation, groups: forwarded to ``F.conv2d``.
        bias: accepted for signature compatibility with ``nn.Conv2d`` but
            ignored; the layer never adds a bias term.

    Raises:
        ValueError: if ``in_channels`` or ``out_channels`` is not divisible
            by ``groups``.

    Note:
        The gate branch always preserves the input's spatial size. The masked
        convolution must therefore preserve it too (e.g. ``stride=1``,
        ``dilation=1``, odd ``kernel_size`` and ``padding=kernel_size // 2``).
        Otherwise the gating multiply cannot broadcast.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding=0,
                 bias=None, stride=1, dilation=1, groups=1):
        super(Conv2d_Linear, self).__init__()
        if in_channels % groups != 0 or out_channels % groups != 0:
            raise ValueError(
                'in_channels and out_channels must both be divisible by groups')

        # F.conv2d expects (out, in // groups, k, k). For groups == 1 this is
        # the same shape (and the same random draw) as an ungrouped weight.
        self.weight = nn.Parameter(
            torch.randn(out_channels, in_channels // groups,
                        kernel_size, kernel_size))

        # Fixed 0/1 kernel-support mask, one pattern per channel group.
        # It is kept as a non-trainable nn.Parameter (not a buffer) so the
        # parameters() list and state_dict keys stay as they were.
        delta = out_channels // 8
        deltas = [0, delta, delta * 2, delta * 3, delta * 4, out_channels]
        self.deltas = deltas
        mask = torch.zeros(out_channels, 1, kernel_size, kernel_size,
                           dtype=torch.float32)
        mid = (kernel_size - 1) // 2
        diag = torch.arange(kernel_size)
        anti = torch.arange(kernel_size - 1, -1, -1)
        mask[deltas[0]:deltas[1], :, mid, :] = 1.    # horizontal line
        mask[deltas[1]:deltas[2], :, :, mid] = 1.    # vertical line
        mask[deltas[2]:deltas[3], :, diag, diag] = 1.  # main diagonal
        mask[deltas[3]:deltas[4], :, diag, anti] = 1.  # anti-diagonal
        mask[deltas[4]:deltas[5]] = 1.               # full window
        self.sign = nn.Parameter(mask, requires_grad=False)

        # The `bias` argument is intentionally not honoured (see docstring).
        self.bias = None
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups

        # Gate branch: one sigmoid map per channel group (5 groups).
        self.Conv = nn.Sequential(
            nn.Conv2d(in_channels, 5, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid())

    def forward(self, x):
        """Apply the masked convolution, then scale each channel group by its gate.

        Args:
            x: input tensor accepted by ``F.conv2d`` with ``in_channels``
                channels, e.g. shape (N, in_channels, H, W).

        Returns:
            Tensor with ``out_channels`` channels.
        """
        y = F.conv2d(x, self.weight * self.sign, self.bias, self.stride,
                     self.padding, self.dilation, self.groups)
        gate = self.Conv(x)  # 5 channels, values in (0, 1)
        bounds = self.deltas
        # Out-of-place scaling: one broadcast product per group, then
        # concatenated back in channel order. Empty groups contribute
        # zero-width slices.
        return torch.cat(
            [y[:, lo:hi] * gate[:, g:g + 1]
             for g, (lo, hi) in enumerate(zip(bounds[:-1], bounds[1:]))],
            dim=1)

class conv_block(nn.Module):
    """Two stacked ``Conv2d_Linear -> BatchNorm2d -> ReLU`` stages.

    The first stage maps ``in_ch`` to ``out_ch`` channels and the second maps
    ``out_ch`` to ``out_ch``. Both convolutions use stride 1 and padding
    ``kernel_size // 2``, so an odd ``kernel_size`` keeps the spatial size
    unchanged. ``bias=True`` is passed through, although ``Conv2d_Linear``
    in this file does not use a bias.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3):
        super(conv_block, self).__init__()

        half_k = kernel_size // 2
        stages = []
        # Build the stages in order so the module indices inside
        # `self.conv` are 0..5 and initialisation happens stage by stage.
        for stage_in in (in_ch, out_ch):
            stages.extend([
                Conv2d_Linear(stage_in, out_ch, kernel_size=kernel_size,
                              stride=1, padding=half_k, bias=True),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run both conv/BN/ReLU stages on ``x``."""
        return self.conv(x)


class UNet_Linear(nn.Module):
    """U-Net whose convolution blocks use the line-masked ``Conv2d_Linear`` layer.

    Paper : https://arxiv.org/abs/1505.04597

    The encoder has five ``conv_block`` levels of width 64, 128, 256, 512 and
    1024, separated by 2x2 max-pooling. The decoder has four ``up_conv`` +
    ``conv_block`` levels; each concatenates the matching encoder output
    (skip connection) with the upsampled features.

    The result is the raw output of a final 1x1 convolution; no activation
    is applied.

    NOTE(review): ``up_conv`` comes from ``.shared`` and is presumably a 2x
    upsampling block. If so, input height/width must be divisible by 16 for
    the skip concatenations to line up. Confirm against ``up_conv``.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(UNet_Linear, self).__init__()

        w1, w2, w3, w4, w5 = [64 * 2 ** level for level in range(5)]

        # Keep this registration order: it fixes the parameters() ordering
        # and the sequence of random initialisations.
        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder.
        self.Conv1 = conv_block(in_ch, w1)
        self.Conv2 = conv_block(w1, w2)
        self.Conv3 = conv_block(w2, w3)
        self.Conv4 = conv_block(w3, w4)
        self.Conv5 = conv_block(w4, w5)

        # Decoder: each Up_conv block takes skip + upsampled channels, which
        # equals the next-deeper width (assuming up_conv emits its 2nd arg
        # as channel count).
        self.Up5 = up_conv(w5, w4)
        self.Up_conv5 = conv_block(w5, w4)
        self.Up4 = up_conv(w4, w3)
        self.Up_conv4 = conv_block(w4, w3)
        self.Up3 = up_conv(w3, w2)
        self.Up_conv3 = conv_block(w3, w2)
        self.Up2 = up_conv(w2, w1)
        self.Up_conv2 = conv_block(w2, w1)

        # 1x1 projection to the requested number of output channels.
        self.Conv = nn.Conv2d(w1, out_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Encode ``x``, decode with skip connections, and project to ``out_ch``."""
        # Encoder: collect every level's output; the last one is the bottleneck.
        features = [self.Conv1(x)]
        for pool, block in ((self.Maxpool1, self.Conv2),
                            (self.Maxpool2, self.Conv3),
                            (self.Maxpool3, self.Conv4),
                            (self.Maxpool4, self.Conv5)):
            features.append(block(pool(features[-1])))

        # Decoder: consume the remaining encoder outputs deepest-first.
        decoded = features.pop()
        for upsample, block in ((self.Up5, self.Up_conv5),
                                (self.Up4, self.Up_conv4),
                                (self.Up3, self.Up_conv3),
                                (self.Up2, self.Up_conv2)):
            skip = features.pop()
            decoded = block(torch.cat((skip, upsample(decoded)), dim=1))

        return self.Conv(decoded)
