from __future__ import print_function, division
import math
import gc

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

from ..shared import conv_block, up_conv
from .lib.functional import dotproduction2, aggregation


class attention_conv(nn.Module):
    """Convolution-like layer whose neighbourhood weights come from an attention branch.

    Two 1x1-conv embeddings of the input (``att_self`` and ``att_neighbors``)
    are combined by the project op ``dotproduction2`` over a
    ``kernel_size`` x ``kernel_size`` neighbourhood. The result is refined by
    ``att_mlp``, scaled by ``1 / sqrt(attention_hidden)`` and softmax-normalised.
    Separately, a 1x1 conv (``cnn``) projects the input to ``out_channels``.
    The project op ``aggregation`` then combines the projection with the
    attention weights.

    Args:
        in_channels: number of input channels.
        out_channels: number of channels produced by the 1x1 projection ``cnn``.
        kernel_size: neighbourhood size; only 3 is supported.
        attention_hidden: channel width of the attention embeddings and MLP.
        attention_share: the attention MLP emits ``out_channels // attention_share``
            maps. NOTE(review): presumably each map is shared by
            ``attention_share`` output channels inside ``aggregation``, which
            would require ``out_channels % attention_share == 0``; confirm.

    Raises:
        ValueError: if ``kernel_size`` is not 3.
    """

    def __init__(self, in_channels, out_channels, kernel_size, attention_hidden, attention_share):
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would let an unsupported size reach the
        # neighbourhood ops below.
        if kernel_size != 3:
            raise ValueError(
                f"attention_conv only supports kernel_size=3, got {kernel_size}")

        super(attention_conv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.attention_hidden = attention_hidden
        self.kernel_size = kernel_size
        # "Same" padding for an odd kernel (1 for kernel_size=3).
        self.padding = (kernel_size - 1) // 2

        # Embedding of each position itself.
        self.att_self = nn.Sequential(
            nn.Conv2d(in_channels, attention_hidden, kernel_size=1),
            nn.BatchNorm2d(attention_hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(attention_hidden, attention_hidden, kernel_size=1)  # attention Linear
        )

        # Embedding used for the neighbouring positions.
        self.att_neighbors = nn.Sequential(
            nn.Conv2d(in_channels, attention_hidden, kernel_size=1),
            nn.BatchNorm2d(attention_hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(attention_hidden, attention_hidden, kernel_size=1)  # attention Linear
        )

        # Maps the combined embeddings to out_channels // attention_share attention maps.
        self.att_mlp = nn.Sequential(nn.BatchNorm2d(attention_hidden),
                                     nn.ReLU(inplace=True),
                                     nn.Conv2d(attention_hidden, attention_hidden, kernel_size=1, bias=False),
                                     nn.BatchNorm2d(attention_hidden),
                                     nn.ReLU(inplace=True),
                                     nn.Conv2d(attention_hidden, out_channels // attention_share, kernel_size=1)
                                     )

        # Value projection that the attention weights are applied to.
        self.cnn = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)

    def forward(self, x):
        """Apply the attention-weighted neighbourhood aggregation to ``x``.

        NOTE(review): ``x`` is presumably (B, in_channels, H, W), since it feeds
        ``nn.Conv2d`` layers. The output shape is determined by ``aggregation``.
        """
        # Per the original shape note, att is (B, Cout//att_share, K^2, H*W).
        att = self.att_mlp(dotproduction2(self.att_self(x), self.att_neighbors(x), self.kernel_size)) / math.sqrt(self.attention_hidden)
        # dim=-2 is the K^2 axis in that layout, so each position's weights over
        # its neighbourhood sum to 1.
        att = F.softmax(att, dim=-2)
        x = self.cnn(x)

        return aggregation(x, att, kernel_size=self.kernel_size, padding=self.padding)


class attention_conv_block(nn.Module):
    """Two stacked ``attention_conv`` layers, each followed by BatchNorm and ReLU.

    ``self.conv`` is an ``nn.Sequential`` laid out as
    ``[attention_conv, BatchNorm2d, ReLU, attention_conv, BatchNorm2d, ReLU]``
    (indices 0-5). The first attention conv maps ``in_ch`` -> ``out_ch`` and
    the second maps ``out_ch`` -> ``out_ch``.

    Args:
        in_ch: input channels of the first attention conv.
        out_ch: output channels of both attention convs.
        kernel_size: forwarded to ``attention_conv``.
        att_hidden: forwarded to ``attention_conv`` as ``attention_hidden``.
        att_share: forwarded to ``attention_conv`` as ``attention_share``.
    """

    def __init__(self, in_ch, out_ch, kernel_size, att_hidden, att_share):
        super(attention_conv_block, self).__init__()

        layers = []
        # One (attention_conv, BN, ReLU) triple per stage. Layers are created
        # in the same order as a single literal Sequential would create them,
        # so parameter initialisation and Sequential indices are unchanged.
        for stage_in_ch in (in_ch, out_ch):
            layers.extend([
                attention_conv(stage_in_ch, out_ch, kernel_size, att_hidden, att_share),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through both attention-conv stages."""
        out = self.conv(x)
        return out


class KAUNet(nn.Module):
    """U-Net variant whose encoder stages 2-5 use attention convolutions.

    Layout, as constructed in ``__init__``:

    * Encoder: ``Conv1`` is a plain ``conv_block`` (``in_ch`` -> 64).
      ``Conv2``..``Conv5`` are ``attention_conv_block`` stages
      (64 -> 128 -> 256 -> 512 -> 1024), each preceded by a 2x2 max-pool.
    * Decoder: ``Up5``..``Up2`` (``up_conv``) are followed by ``Up_conv5``..
      ``Up_conv2`` (``conv_block``). Before each ``Up_conv``, the matching
      encoder feature map is concatenated along dim 1.
    * Head: a 1x1 ``nn.Conv2d`` to ``out_ch`` channels. Its output is returned
      directly; no final activation is applied.

    Base architecture: U-Net, https://arxiv.org/abs/1505.04597

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels of the final 1x1 convolution.
        att_share: divisor passed to every attention stage; each stage's
            attention MLP emits ``stage_out_channels // att_share`` maps.
            ``0`` is a shortcut that uses each stage's own output width as the
            divisor, i.e. a single attention map per stage.

    NOTE(review): the four 2x2 max-pools plus the skip concatenations
    presumably require input height/width divisible by 16 (depends on
    ``up_conv``); confirm.
    """

    def __init__(self, in_ch, out_ch, att_share):
        super(KAUNet, self).__init__()

        n1 = 64
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]
        # Hidden width of the attention branch for encoder stages 2..5
        # (equal to each stage's input width).
        att_hidden = filters[0:4]

        if att_share == 0:
            # Divisor equal to each stage's output width -> one attention map per stage.
            att_share = filters[1:5]
        else:
            # Same divisor for all four attention stages.
            att_share = [att_share] * 4

        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # NOTE: submodules are created in the original order (Conv2..Conv5
        # before Conv1). Reordering would change the random draws each layer's
        # initialisation receives and the order of self.parameters().
        self.Conv2 = attention_conv_block(
            filters[0], filters[1], 3, att_hidden=att_hidden[0], att_share=att_share[0])
        self.Conv3 = attention_conv_block(
            filters[1], filters[2], 3, att_hidden=att_hidden[1], att_share=att_share[1])
        self.Conv4 = attention_conv_block(
            filters[2], filters[3], 3, att_hidden=att_hidden[2], att_share=att_share[2])
        self.Conv5 = attention_conv_block(
            filters[3], filters[4], 3, att_hidden=att_hidden[3], att_share=att_share[3])

        self.Conv1 = conv_block(in_ch, filters[0])

        # Decoder: each Up_conv consumes the upsampled map concatenated with
        # the same-resolution encoder map, hence filters[k + 1] input channels.
        self.Up5 = up_conv(filters[4], filters[3])
        self.Up_conv5 = conv_block(filters[4], filters[3])

        self.Up4 = up_conv(filters[3], filters[2])
        self.Up_conv4 = conv_block(filters[3], filters[2])

        self.Up3 = up_conv(filters[2], filters[1])
        self.Up_conv3 = conv_block(filters[2], filters[1])

        self.Up2 = up_conv(filters[1], filters[0])
        self.Up_conv2 = conv_block(filters[1], filters[0])

        self.Conv = nn.Conv2d(filters[0], out_ch,
                              kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Encode ``x``, decode with skip connections, and return the 1x1-conv output."""
        # Encoder.
        e1 = self.Conv1(x)

        e2 = self.Maxpool1(e1)
        e2 = self.Conv2(e2)

        e3 = self.Maxpool2(e2)
        e3 = self.Conv3(e3)

        e4 = self.Maxpool3(e3)
        e4 = self.Conv4(e4)

        e5 = self.Maxpool4(e4)
        e5 = self.Conv5(e5)

        # Decoder with skip connections (encoder map first, then upsampled map).
        y = self.Up5(e5)
        y = torch.cat((e4, y), dim=1)
        y = self.Up_conv5(y)

        y = self.Up4(y)
        y = torch.cat((e3, y), dim=1)
        y = self.Up_conv4(y)

        y = self.Up3(y)
        y = torch.cat((e2, y), dim=1)
        y = self.Up_conv3(y)

        y = self.Up2(y)
        y = torch.cat((e1, y), dim=1)
        y = self.Up_conv2(y)

        return self.Conv(y)
