from __future__ import print_function, division
import math
import gc

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

from .shared import conv_block, up_conv


class Head(nn.Module):
    """
    UNet encoder-decoder with separate edge and area branches fused by attention.

    Built on the UNet layout (paper: https://arxiv.org/abs/1505.04597).

    How the pieces fit together:
    - The last decoder stage (``Up_conv2``) produces ``2 * n_filters`` channels.
    - The first half feeds the "edge" branch and the second half feeds the
      "area" branch.
    - Each branch has a 7x7 ``conv_block`` (features) followed by a 7x7
      ``nn.Conv2d`` that yields a 1-channel map.
    - A query computed directly from the input (``img_q``) is scored against
      each branch's features per pixel, via a channel-wise dot product scaled
      by ``sqrt(n_filters)``.
    - A softmax over the two scores weights the branch features. The weighted
      features are concatenated and projected to the final output.

    The encoder downsamples four times with 2x2 max pooling. The skip
    concatenations in ``forward`` therefore presumably require H and W to be
    divisible by 16 (depends on ``up_conv`` -- TODO confirm).

    Args:
        in_ch: number of input channels.
        out_ch: number of channels of the final ``out`` map (default 1).
        n_filters: base width; encoder widths are
            ``n_filters * [1, 2, 4, 8, 16]``.
    """

    def __init__(self, in_ch=3, out_ch=1, n_filters=32):
        super().__init__()

        n1 = n_filters
        filters = [n1 * i for i in [1, 2, 4, 8, 16]]

        # NOTE: attribute names and creation order are kept stable so that
        # state_dict keys and seeded parameter initialization do not change.
        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder.
        self.Conv1 = conv_block(in_ch, filters[0])
        self.Conv2 = conv_block(filters[0], filters[1])
        self.Conv3 = conv_block(filters[1], filters[2])
        self.Conv4 = conv_block(filters[2], filters[3])
        self.Conv5 = conv_block(filters[3], filters[4])

        # Decoder: each Up_conv sees [skip, upsampled] concatenated.
        self.Up5 = up_conv(filters[4], filters[3])
        self.Up_conv5 = conv_block(filters[4], filters[3])

        self.Up4 = up_conv(filters[3], filters[2])
        self.Up_conv4 = conv_block(filters[3], filters[2])

        self.Up3 = up_conv(filters[2], filters[1])
        self.Up_conv3 = conv_block(filters[2], filters[1])

        self.Up2 = up_conv(filters[1], filters[0])
        # Emits 2 * filters[0] channels: first half -> edge, second -> area.
        self.Up_conv2 = conv_block(filters[1], 2 * filters[0])

        self.edge1 = conv_block(filters[0], filters[0], kernel_size=7)
        self.edge2 = nn.Conv2d(filters[0], 1, kernel_size=7, padding=3)

        self.area1 = conv_block(filters[0], filters[0], kernel_size=7)
        self.area2 = nn.Conv2d(filters[0], 1, kernel_size=7, padding=3)

        # Width of each branch; also the split point in the decoder output
        # and the dimension used to scale the attention scores.
        self.filter0 = filters[0]

        # Projects the fused edge+area features (2 * filters[0] == filters[1]).
        self.output1 = conv_block(filters[1], filters[1], kernel_size=7)
        self.output2 = nn.Conv2d(filters[1], out_ch, kernel_size=7, padding=3)

        # Query computed from the raw input, with the same width as each branch.
        self.img_q = conv_block(in_ch, filters[0], kernel_size=7)

    def forward(self, x):
        """
        Run the network on ``x``.

        Args:
            x: input batch, presumably shaped (B, in_ch, H, W) -- TODO confirm.

        Returns:
            Tuple ``(feat, edge, area, out)``:
            - feat: attention-weighted edge/area features
              (``2 * n_filters`` channels).
            - edge: 1-channel map from the edge branch.
            - area: 1-channel map from the area branch.
            - out: ``out_ch``-channel map computed from ``feat``.

            No activation is applied to edge, area, or out.
        """
        # Encoder.
        e1 = self.Conv1(x)
        e2 = self.Conv2(self.Maxpool1(e1))
        e3 = self.Conv3(self.Maxpool2(e2))
        e4 = self.Conv4(self.Maxpool3(e3))
        e5 = self.Conv5(self.Maxpool4(e4))

        # Decoder with skip connections (skip first, upsampled second).
        d5 = self.Up_conv5(torch.cat((e4, self.Up5(e5)), dim=1))
        d4 = self.Up_conv4(torch.cat((e3, self.Up4(d5)), dim=1))
        d3 = self.Up_conv3(torch.cat((e2, self.Up3(d4)), dim=1))
        d2 = self.Up_conv2(torch.cat((e1, self.Up2(d3)), dim=1))

        # Split the last decoder stage into the edge and area branches.
        edge_feat = self.edge1(d2[:, :self.filter0])
        edge_map = self.edge2(edge_feat)

        area_feat = self.area1(d2[:, self.filter0:])
        area_map = self.area2(area_feat)

        # Per-pixel scaled dot-product score of the image query against each
        # branch, normalized with a softmax over the two branches.
        scale = math.sqrt(self.filter0)
        img_q = self.img_q(x)
        att_e = (img_q * edge_feat).sum(dim=1, keepdim=True) / scale
        att_a = (img_q * area_feat).sum(dim=1, keepdim=True) / scale
        att = F.softmax(torch.cat((att_e, att_a), dim=1), dim=1)

        # Weight each branch out-of-place (no in-place writes into views of a
        # tensor that participates in autograd), then concatenate.
        feat = torch.cat(
            (edge_feat * att[:, 0:1], area_feat * att[:, 1:2]), dim=1
        )

        out = self.output2(self.output1(feat))

        return feat, edge_map, area_map, out


class UNet_multihead_edge_att(nn.Module):
    """
    Four independent ``Head`` networks applied to the same input.

    Each head produces its own output map, edge map and area map. The four
    1-channel outputs are stacked along the channel axis behind a constant
    channel filled with -1, giving 5 channels in total.

    Args:
        in_ch: number of input channels given to every head (default 4).
        n_filters: base width of every head (default 32).
    """

    def __init__(self, in_ch=4, n_filters=32):
        super().__init__()
        self.filters = n_filters
        self.head1 = Head(in_ch=in_ch, out_ch=1, n_filters=n_filters)
        self.head2 = Head(in_ch=in_ch, out_ch=1, n_filters=n_filters)
        self.head3 = Head(in_ch=in_ch, out_ch=1, n_filters=n_filters)
        self.head4 = Head(in_ch=in_ch, out_ch=1, n_filters=n_filters)

        # NOTE: ``img_q`` and ``conv`` are NOT used by forward(); the
        # cross-head attention fusion that consumed them is disabled. They are
        # still registered submodules, so removing them would change
        # parameters() and the state_dict keys (e.g. when loading checkpoints).
        self.img_q = nn.Sequential(
            nn.Conv2d(in_channels=in_ch, out_channels=n_filters, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(n_filters),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(n_filters * 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=n_filters * 2, out_channels=n_filters * 2, kernel_size=5, stride=1, padding=2)
        )

        # Input width = 4 heads x (2 * n_filters) fused features per head.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=n_filters * 4 * 2, out_channels=n_filters, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(n_filters),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=n_filters, out_channels=n_filters, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(n_filters),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=n_filters, out_channels=5, kernel_size=5, stride=1, padding=2)
        )

    def forward(self, x):
        """
        Run all four heads on ``x``.

        Returns:
            Tuple ``(out, outs, edges, areas)``:
            - out: ``cat([-1 constant, out1, out2, out3, out4], dim=1)``.
            - outs: ``[out1, out2, out3, out4]``.
            - edges: the per-head edge maps, in head order.
            - areas: the per-head area maps, in head order.
        """
        _, edge1, area1, out1 = self.head1(x)
        _, edge2, area2, out2 = self.head2(x)
        _, edge3, area3, out3 = self.head3(x)
        _, edge4, area4, out4 = self.head4(x)

        # Constant channel 0 filled with -1 (presumably a fixed background
        # logit -- confirm with the loss/decoding code). full_like matches
        # out1's dtype and device, avoiding a host allocation + device copy and
        # the float32 promotion the old ``.float()`` caused under fp16/fp64.
        constant = torch.full_like(out1, -1.0)
        out = torch.cat((constant, out1, out2, out3, out4), dim=1)

        edges = [edge1, edge2, edge3, edge4]
        areas = [area1, area2, area3, area4]
        outs = [out1, out2, out3, out4]

        return out, outs, edges, areas
