from __future__ import print_function, division
import math
import gc

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

from .shared import conv_block, up_conv


class FuseGate(nn.Module):
    """Gated fusion of a shallow (skip) feature map with a deep (decoder) one.

    Both inputs are multiplied by sigmoid masks predicted from a separate
    guidance tensor ``x``. The gated deep features are merged into the gated
    shallow features through a sigmoid-weighted, non-negative residual update,
    and the sum is refined by a conv-BN-ReLU layer.

    Args:
        in_ch: Channel count of ``shallow``, ``deep`` and of the output.
        os: Downsampling factor applied to ``x`` by the mask convolutions
            (``kernel_size == stride == os``, no padding), so a guidance map of
            spatial size ``(H, W)`` yields masks of size ``(H // os, W // os)``.
            The name shadows the stdlib ``os`` module inside ``__init__``; it is
            kept because callers pass it by keyword.
        kernel_size: Kernel size of the update and refinement convolutions.
            Padding is ``kernel_size // 2``, which preserves spatial size only
            for odd values.
        guide_ch: Channel count of the guidance tensor ``x``. Defaults to 64,
            the value that used to be hard-coded.
    """

    def __init__(self, in_ch, os, kernel_size=3, guide_ch=64):
        super(FuseGate, self).__init__()
        padding = kernel_size // 2

        # Masks in (0, 1) predicted from the guidance tensor. kernel == stride
        # == os means non-overlapping os x os patches, i.e. x is downsampled
        # by a factor of os to the resolution of shallow/deep.
        self.x_mask_shallow = nn.Sequential(
            nn.Conv2d(in_channels=guide_ch, out_channels=in_ch, kernel_size=os, stride=os, padding=0),
            nn.BatchNorm2d(in_ch),
            nn.Sigmoid()
        )

        self.x_mask_deep = nn.Sequential(
            nn.Conv2d(in_channels=guide_ch, out_channels=in_ch, kernel_size=os, stride=os, padding=0),
            nn.BatchNorm2d(in_ch),
            nn.Sigmoid()
        )

        # Per-element weight in (0, 1) applied to the residual update.
        self.update_coeff = nn.Sequential(
            nn.Conv2d(in_channels=in_ch * 2, out_channels=in_ch, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm2d(in_ch),
            nn.Sigmoid()
        )
        # Non-negative residual update computed from both gated inputs.
        self.update_value = nn.Sequential(
            nn.Conv2d(in_channels=in_ch * 2, out_channels=in_ch, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm2d(in_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=in_ch, out_channels=in_ch, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm2d(in_ch),
            nn.ReLU()
        )

        # Final refinement of the fused features.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=in_ch, out_channels=in_ch, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm2d(in_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, shallow, deep):
        """Fuse ``shallow`` and ``deep`` under guidance from ``x``.

        Args:
            x: Guidance tensor with ``guide_ch`` channels. After the stride-
                ``os`` mask convolutions its spatial size must broadcast
                against ``shallow``/``deep``.
            shallow: Shallow (skip-connection) features with ``in_ch`` channels.
            deep: Deep (decoder) features; must have exactly the same shape as
                ``shallow``.

        Returns:
            Fused features with ``in_ch`` channels.

        Raises:
            ValueError: If ``shallow`` and ``deep`` differ in shape.
        """
        # Explicit check rather than ``assert`` so it is not stripped by -O.
        if shallow.shape != deep.shape:
            raise ValueError(
                "FuseGate expects shallow and deep to have the same shape, "
                "got {} and {}".format(tuple(shallow.shape), tuple(deep.shape))
            )

        gated_shallow = shallow * self.x_mask_shallow(x)
        gated_deep = deep * self.x_mask_deep(x)

        # Concatenate once and share the result between both update branches.
        stacked = torch.cat((gated_shallow, gated_deep), dim=1)
        fused = gated_shallow + self.update_coeff(stacked) * self.update_value(stacked)

        return self.conv(fused)


class UNet_fuse_x(nn.Module):
    """U-Net whose decoder stages are merged with encoder skips via FuseGate.

    Built on the U-Net layout of https://arxiv.org/abs/1505.04597. Instead of
    plain concatenation, each decoder stage is combined with its encoder skip
    by a ``FuseGate`` that is additionally guided by the first encoder feature
    map (64 channels). The gate's stride-``os`` mask convolutions bring that
    map down to the resolution of the stage being fused (os = 8, 4, 2, 1).

    Args:
        in_ch: Number of input image channels.
        out_ch: Number of output channels. The final 9x9 convolution returns
            raw scores; no activation is applied.

    NOTE(review): skip and decoder tensors must match exactly inside FuseGate,
    which presumably requires H and W divisible by 16 (depends on ``up_conv``;
    verify).
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(UNet_fuse_x, self).__init__()

        base = 64
        # Channel width doubles at every depth: 64, 128, 256, 512, 1024.
        widths = [base * 2 ** depth for depth in range(5)]
        ch1, ch2, ch3, ch4, ch_bottom = widths

        # One 2x2 max-pool per downsampling step.
        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder.
        self.Conv1 = conv_block(in_ch, ch1)
        self.Conv2 = conv_block(ch1, ch2)
        self.Conv3 = conv_block(ch2, ch3)
        self.Conv4 = conv_block(ch3, ch4)
        self.Conv5 = conv_block(ch4, ch_bottom)

        # Decoder: upsample, then gate-fuse with the matching encoder skip.
        # ``os`` is the downsampling factor from the Conv1 output to that stage.
        self.Up5 = up_conv(ch_bottom, ch4)
        self.Fuse5 = FuseGate(in_ch=ch4, os=8, kernel_size=3)

        self.Up4 = up_conv(ch4, ch3)
        self.Fuse4 = FuseGate(in_ch=ch3, os=4, kernel_size=5)

        self.Up3 = up_conv(ch3, ch2)
        self.Fuse3 = FuseGate(in_ch=ch2, os=2, kernel_size=5)

        self.Up2 = up_conv(ch2, ch1)
        self.Fuse2 = FuseGate(in_ch=ch1, os=1, kernel_size=9)

        # Output head; padding 4 keeps the spatial size for the 9x9 kernel.
        self.Conv = nn.Conv2d(ch1, out_ch, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        """Run the encoder-decoder and return raw ``out_ch``-channel scores."""
        # Encoder path.
        enc1 = self.Conv1(x)
        enc2 = self.Conv2(self.Maxpool1(enc1))
        enc3 = self.Conv3(self.Maxpool2(enc2))
        enc4 = self.Conv4(self.Maxpool3(enc3))
        bottom = self.Conv5(self.Maxpool4(enc4))

        # Decoder path; every gate is guided by the full-resolution enc1.
        dec = self.Fuse5(enc1, enc4, self.Up5(bottom))
        dec = self.Fuse4(enc1, enc3, self.Up4(dec))
        dec = self.Fuse3(enc1, enc2, self.Up3(dec))
        dec = self.Fuse2(enc1, enc1, self.Up2(dec))

        return self.Conv(dec)
