from __future__ import print_function, division
import math
import gc

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

from .models.SAN.san import SAM
from .models.shared import conv_block, up_conv

class full_attention_conv(nn.Module):
    """3x3 convolution followed by global (all-pairs) spatial self-attention.

    The input is first mapped to ``out_channels`` feature maps by a padded
    3x3 convolution.  Two small 1x1-conv MLPs then turn every pixel of that
    feature map into a query vector and a key vector of size ``att_dim``.
    Each output position attends over all ``H * W`` positions (softmax over
    the raw dot-product scores) and the attended features are added back to
    the convolution output as a residual.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the convolution and
            returned by the layer.
        att_hidden: Hidden width of the query/key MLPs.
        att_dim: Dimensionality of the query and key vectors.
        kernel_size: Must be 3; only the 3x3 convolution is implemented.

    Note:
        The attention matrix has ``(H * W) ** 2`` entries per sample, so the
        layer is only practical on small feature maps.
    """

    def __init__(self, in_channels, out_channels, att_hidden=32, att_dim=32, kernel_size=3):
        assert (kernel_size == 3)

        super(full_attention_conv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.att_dim = att_dim
        self.conv = nn.Conv2d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, stride=1)
        self.att_key = nn.Sequential(
            nn.Conv2d(out_channels, att_hidden, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(att_hidden, att_dim, kernel_size=1)
        )
        self.att_query = nn.Sequential(
            nn.Conv2d(out_channels, att_hidden, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(att_hidden, att_dim, kernel_size=1)  # (B, att_dim, H, W)
        )

    def forward(self, x):
        """Apply the convolution and add the globally attended features.

        Args:
            x: Tensor of shape ``(B, in_channels, H, W)``.

        Returns:
            Tensor of shape ``(B, out_channels, H, W)`` with the dtype and
            device of the convolution output.
        """
        x = self.conv(x)
        B, Cout, H, W = x.shape

        # Flatten the spatial grid so attention becomes plain batched matmul.
        queries = self.att_query(x).flatten(2)  # (B, att_dim, H*W)
        keys = self.att_key(x).flatten(2)       # (B, att_dim, H*W)
        values = x.flatten(2)                   # (B, Cout,    H*W)

        # scores[b, i, j] = <query at position i, key at position j>.
        # Computed with bmm instead of a per-sample conv2d into a
        # pre-allocated float32 buffer, so dtype/device follow the input and
        # no Python loop over the batch is needed.  The scores are left
        # unscaled to keep the layer's original behaviour.
        scores = torch.bmm(queries.transpose(1, 2), keys)  # (B, H*W, H*W)
        weights = nn.functional.softmax(scores, dim=-1)    # normalise over keys

        # attended[b, c, i] = sum_j weights[b, i, j] * values[b, c, j]
        # (avoids materialising a (B, H, W, Cout, H*W) intermediate).
        attended = torch.bmm(values, weights.transpose(1, 2))  # (B, Cout, H*W)

        return x + attended.reshape(B, Cout, H, W)


class full_attention_conv_block(nn.Module):
    """Two-stage block: Conv3x3-BN-ReLU, then full_attention_conv-BN-ReLU.

    Args:
        in_channels: Channels of the incoming tensor.
        out_channels: Channels produced by both stages.
        att_hidden: Forwarded to ``full_attention_conv``.
        att_dim: Forwarded to ``full_attention_conv``.
        kernel_size: Forwarded to ``full_attention_conv``; the plain
            convolution of the first stage is always 3x3.
    """

    def __init__(self, in_channels, out_channels, att_hidden=32, att_dim=16, kernel_size=3):
        super(full_attention_conv_block, self).__init__()

        # Stage 1: ordinary padded 3x3 convolution.
        plain_stage = [
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # Stage 2: convolution with global spatial attention.
        attention_stage = [
            full_attention_conv(out_channels, out_channels,
                                att_hidden, att_dim, kernel_size),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # Layers keep indices 0..5 inside the Sequential container.
        self.conv = nn.Sequential(*(plain_stage + attention_stage))

    def forward(self, x):
        """Run both stages on ``x`` and return the result."""
        features = self.conv(x)
        return features


class bi_attention_conv(nn.Module):
    """Per-pixel soft selection between a 3x3 conv branch and a 1x1 MLP branch.

    Both branches and a query network are evaluated on the input.  For every
    pixel the scaled dot product between the query and each branch output
    gives one logit per branch; a softmax over the two logits yields the
    weights used to blend the branch outputs.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of each branch, of the query, and of the output.
    """

    def __init__(self, in_channels, out_channels):
        super(bi_attention_conv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        def pointwise_mlp():
            # 1x1 conv -> ReLU -> 1x1 conv, mapping in_channels to out_channels.
            return nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                          kernel_size=1, padding=0, stride=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels=out_channels, out_channels=out_channels,
                          kernel_size=1, padding=0, stride=1),
            )

        # Creation order (conv, mlp, query) is kept so parameter ordering
        # and random initialisation under a fixed seed stay the same.
        self.conv = nn.Conv2d(
            in_channels=in_channels, out_channels=out_channels,
            kernel_size=3, padding=1, stride=1)
        self.mlp = pointwise_mlp()
        self.query = pointwise_mlp()

    def forward(self, x):
        """Blend the two branch outputs with per-pixel softmax weights.

        Returns a tensor with ``out_channels`` channels and the spatial size
        of ``x``.
        """
        branch_conv = self.conv(x)
        branch_mlp = self.mlp(x)
        query = self.query(x)

        scale = math.sqrt(self.out_channels)
        logits = [
            torch.sum(branch * query, dim=1, keepdim=True) / scale
            for branch in (branch_conv, branch_mlp)
        ]
        # (B, 2, H, W): one weight per branch and pixel.
        weights = nn.functional.softmax(torch.cat(logits, dim=1), dim=1)
        w_conv, w_mlp = torch.chunk(weights, 2, dim=1)

        return branch_conv * w_conv + branch_mlp * w_mlp


class bi_attention_conv_block(nn.Module):
    """Two ``bi_attention_conv`` layers, each followed by BatchNorm and ReLU.

    The first layer maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super(bi_attention_conv_block, self).__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.append(bi_attention_conv(stage_in, out_channels))
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))
        # Layers keep indices 0..5 inside the Sequential container.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both attention-conv stages to ``x``."""
        features = self.conv(x)
        return features


class Bi_Attention_U_Net(nn.Module):
    """
    U-Net variant whose encoder stages 2-5 are ``bi_attention_conv_block``s.

    The first encoder stage and all decoder stages use the plain
    ``conv_block``; a final 1x1 convolution maps to ``out_ch`` channels and
    no activation is applied to the returned tensor.
    U-Net paper : https://arxiv.org/abs/1505.04597
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(Bi_Attention_U_Net, self).__init__()

        # Channel widths of the five resolution levels: 64, 128, ..., 1024.
        c1, c2, c3, c4, c5 = [64 * 2 ** level for level in range(5)]

        # NOTE: sub-modules are created in the same order as before so that
        # parameter ordering and seeded initialisation are unchanged.
        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: attention blocks on levels 2-5, plain conv on level 1.
        self.Conv2 = bi_attention_conv_block(c1, c2)
        self.Conv3 = bi_attention_conv_block(c2, c3)
        self.Conv4 = bi_attention_conv_block(c3, c4)
        self.Conv5 = bi_attention_conv_block(c4, c5)
        self.Conv1 = conv_block(in_ch, c1)

        # Decoder: each Up_convN receives the skip connection concatenated
        # with the output of UpN, hence the doubled input width.
        self.Up5 = up_conv(c5, c4)
        self.Up_conv5 = conv_block(c5, c4)

        self.Up4 = up_conv(c4, c3)
        self.Up_conv4 = conv_block(c4, c3)

        self.Up3 = up_conv(c3, c2)
        self.Up_conv3 = conv_block(c3, c2)

        self.Up2 = up_conv(c2, c1)
        self.Up_conv2 = conv_block(c2, c1)

        # Output head (no activation).
        self.Conv = nn.Conv2d(c1, out_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return the un-activated ``out_ch``-channel prediction for ``x``."""
        # Contracting path.
        e1 = self.Conv1(x)
        e2 = self.Conv2(self.Maxpool1(e1))
        e3 = self.Conv3(self.Maxpool2(e2))
        e4 = self.Conv4(self.Maxpool3(e3))
        e5 = self.Conv5(self.Maxpool4(e4))

        # Expanding path: upsample, concatenate the skip first, then fuse.
        decoder = (
            (self.Up5, self.Up_conv5, e4),
            (self.Up4, self.Up_conv4, e3),
            (self.Up3, self.Up_conv3, e2),
            (self.Up2, self.Up_conv2, e1),
        )
        y = e5
        for upsample, fuse, skip in decoder:
            y = fuse(torch.cat((skip, upsample(y)), dim=1))

        return self.Conv(y)


class tri_attention_conv(nn.Module):
    """Per-pixel soft selection between 5x5, 3x3 and 1x1 convolutions.

    The three padded convolutions and a query network are evaluated on the
    input.  For every pixel, the scaled dot product between the query and
    each convolution output gives one logit per kernel size; a softmax over
    the three logits yields the weights used to blend the outputs.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of each branch, of the query, and of the output.
    """

    def __init__(self, in_channels, out_channels):
        super(tri_attention_conv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        def same_conv(size):
            # Stride-1 convolution padded so the spatial size is preserved.
            return nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                             kernel_size=size, padding=size // 2, stride=1)

        # Creation order (conv5, conv3, conv1, query) is kept so parameter
        # ordering and seeded initialisation stay the same.
        self.conv5 = same_conv(5)
        self.conv3 = same_conv(3)
        self.conv1 = same_conv(1)
        self.query = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=1, padding=0, stride=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels,
                      kernel_size=1, padding=0, stride=1),
        )

    def forward(self, x):
        """Blend the three convolution outputs with per-pixel softmax weights.

        Returns a tensor with ``out_channels`` channels and the spatial size
        of ``x``.
        """
        out5 = self.conv5(x)
        out3 = self.conv3(x)
        out1 = self.conv1(x)
        query = self.query(x)

        scale = math.sqrt(self.out_channels)
        logits = [
            torch.sum(branch * query, dim=1, keepdim=True) / scale
            for branch in (out5, out3, out1)
        ]
        # (B, 3, H, W): one weight per kernel size and pixel.
        weights = nn.functional.softmax(torch.cat(logits, dim=1), dim=1)
        w5, w3, w1 = torch.chunk(weights, 3, dim=1)

        return out5 * w5 + out3 * w3 + out1 * w1


class tri_attention_conv_block(nn.Module):
    """Two ``tri_attention_conv`` layers, each followed by BatchNorm and ReLU.

    The first layer maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super(tri_attention_conv_block, self).__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.append(tri_attention_conv(stage_in, out_channels))
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))
        # Layers keep indices 0..5 inside the Sequential container.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both attention-conv stages to ``x``."""
        features = self.conv(x)
        return features


class Tri_Attention_U_Net(nn.Module):
    """
    U-Net variant whose encoder stages 2-5 are ``tri_attention_conv_block``s.

    The first encoder stage and all decoder stages use the plain
    ``conv_block``; a final 1x1 convolution maps to ``out_ch`` channels and
    no activation is applied to the returned tensor.
    U-Net paper : https://arxiv.org/abs/1505.04597
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(Tri_Attention_U_Net, self).__init__()

        # Channel widths of the five resolution levels: 64, 128, ..., 1024.
        widths = [64 * 2 ** level for level in range(5)]
        c1, c2, c3, c4, c5 = widths

        # NOTE: sub-modules are created in the same order as before so that
        # parameter ordering and seeded initialisation are unchanged.
        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: multi-kernel attention on levels 2-5, plain conv on level 1.
        self.Conv2 = tri_attention_conv_block(c1, c2)
        self.Conv3 = tri_attention_conv_block(c2, c3)
        self.Conv4 = tri_attention_conv_block(c3, c4)
        self.Conv5 = tri_attention_conv_block(c4, c5)
        self.Conv1 = conv_block(in_ch, c1)

        # Decoder: each Up_convN receives the skip connection concatenated
        # with the output of UpN, hence the doubled input width.
        self.Up5 = up_conv(c5, c4)
        self.Up_conv5 = conv_block(c5, c4)

        self.Up4 = up_conv(c4, c3)
        self.Up_conv4 = conv_block(c4, c3)

        self.Up3 = up_conv(c3, c2)
        self.Up_conv3 = conv_block(c3, c2)

        self.Up2 = up_conv(c2, c1)
        self.Up_conv2 = conv_block(c2, c1)

        # Output head (no activation).
        self.Conv = nn.Conv2d(c1, out_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return the un-activated ``out_ch``-channel prediction for ``x``."""
        # Contracting path.
        e1 = self.Conv1(x)
        e2 = self.Conv2(self.Maxpool1(e1))
        e3 = self.Conv3(self.Maxpool2(e2))
        e4 = self.Conv4(self.Maxpool3(e3))
        e5 = self.Conv5(self.Maxpool4(e4))

        # Expanding path: upsample, concatenate the skip first, then fuse.
        decoder = (
            (self.Up5, self.Up_conv5, e4),
            (self.Up4, self.Up_conv4, e3),
            (self.Up3, self.Up_conv3, e2),
            (self.Up2, self.Up_conv2, e1),
        )
        y = e5
        for upsample, fuse, skip in decoder:
            y = fuse(torch.cat((skip, upsample(y)), dim=1))

        return self.Conv(y)




class Full_Attention_U_Net(nn.Module):
    """
    U-Net variant with global-attention blocks in the two deepest decoder stages.

    The encoder and the two shallowest decoder stages use the plain
    ``conv_block``; ``Up_conv5`` and ``Up_conv4`` are
    ``full_attention_conv_block``s with ``att_hidden=32``.  A final 1x1
    convolution maps to ``out_ch`` channels; no activation is applied.
    U-Net paper : https://arxiv.org/abs/1505.04597
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(Full_Attention_U_Net, self).__init__()

        # Channel widths of the five resolution levels: 64, 128, ..., 1024.
        c1, c2, c3, c4, c5 = [64 * 2 ** level for level in range(5)]
        att_hidden = 32

        # NOTE: sub-modules are created in the same order as before so that
        # parameter ordering and seeded initialisation are unchanged.
        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: plain convolution blocks on every level.
        self.Conv1 = conv_block(in_ch, c1)
        self.Conv2 = conv_block(c1, c2)
        self.Conv3 = conv_block(c2, c3)
        self.Conv4 = conv_block(c3, c4)
        self.Conv5 = conv_block(c4, c5)

        # Decoder: attention on the two coarsest stages only.
        self.Up5 = up_conv(c5, c4)
        self.Up_conv5 = full_attention_conv_block(c5, c4, att_hidden)

        self.Up4 = up_conv(c4, c3)
        self.Up_conv4 = full_attention_conv_block(c4, c3, att_hidden)

        self.Up3 = up_conv(c3, c2)
        self.Up_conv3 = conv_block(c3, c2)

        self.Up2 = up_conv(c2, c1)
        self.Up_conv2 = conv_block(c2, c1)

        # Output head (no activation).
        self.Conv = nn.Conv2d(c1, out_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return the un-activated ``out_ch``-channel prediction for ``x``."""
        # Contracting path.
        e1 = self.Conv1(x)
        e2 = self.Conv2(self.Maxpool1(e1))
        e3 = self.Conv3(self.Maxpool2(e2))
        e4 = self.Conv4(self.Maxpool3(e3))
        e5 = self.Conv5(self.Maxpool4(e4))

        # Expanding path: upsample, concatenate the skip first, then fuse.
        decoder = (
            (self.Up5, self.Up_conv5, e4),
            (self.Up4, self.Up_conv4, e3),
            (self.Up3, self.Up_conv3, e2),
            (self.Up2, self.Up_conv2, e1),
        )
        d = e5
        for upsample, fuse, skip in decoder:
            d = fuse(torch.cat((skip, upsample(d)), dim=1))

        return self.Conv(d)



class sam_conv_block(nn.Module):
    """SAM layer -> BN -> ReLU -> 1x1 convolution -> BN -> ReLU.

    The SAM layer is configured with ``in_channels // 16`` relation planes,
    ``in_channels // 4`` output planes and 8 share planes; the 1x1
    convolution then maps those ``in_channels // 4`` planes to
    ``out_channels``.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the returned tensor.
        sa_type: Forwarded unchanged to ``SAM``.
        kernel_size: Forwarded unchanged to ``SAM``.
        stride: Forwarded unchanged to ``SAM``.
    """

    def __init__(self, in_channels, out_channels, sa_type, kernel_size=7, stride=1):
        super(sam_conv_block, self).__init__()
        rel_planes = in_channels // 16
        mid_planes = in_channels // 4

        self.sam = SAM(sa_type=sa_type, in_planes=in_channels, rel_planes=rel_planes,
                       out_planes=mid_planes, share_planes=8,
                       kernel_size=kernel_size, stride=stride)
        self.bn1 = nn.BatchNorm2d(mid_planes)
        self.conv = nn.Conv2d(mid_planes, out_channels, kernel_size=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # One ReLU instance is shared by both activation points.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the attention stage and the 1x1 projection stage to ``x``."""
        attended = self.sam(x)
        attended = self.relu(self.bn1(attended))

        projected = self.conv(attended)
        return self.relu(self.bn2(projected))


class SAM_U_Net(nn.Module):
    """
    U-Net variant whose encoder stages 2-5 are ``sam_conv_block``s.

    The first encoder stage and all decoder stages use the plain
    ``conv_block``; a final 1x1 convolution maps to ``out_ch`` channels and
    no activation is applied to the returned tensor.  ``sa_type`` is passed
    through to every ``sam_conv_block``.
    U-Net paper : https://arxiv.org/abs/1505.04597
    """

    def __init__(self, in_ch, out_ch, sa_type):
        super(SAM_U_Net, self).__init__()

        # Channel widths of the five resolution levels: 64, 128, ..., 1024.
        c1, c2, c3, c4, c5 = [64 * 2 ** level for level in range(5)]

        # NOTE: sub-modules are created in the same order as before so that
        # parameter ordering and seeded initialisation are unchanged.
        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: plain conv on level 1, SAM blocks on levels 2-5.
        self.Conv1 = conv_block(in_ch, c1)
        self.Conv2 = sam_conv_block(c1, c2, sa_type)
        self.Conv3 = sam_conv_block(c2, c3, sa_type)
        self.Conv4 = sam_conv_block(c3, c4, sa_type)
        self.Conv5 = sam_conv_block(c4, c5, sa_type)

        # Decoder: each Up_convN receives the skip connection concatenated
        # with the output of UpN, hence the doubled input width.
        self.Up5 = up_conv(c5, c4)
        self.Up_conv5 = conv_block(c5, c4)

        self.Up4 = up_conv(c4, c3)
        self.Up_conv4 = conv_block(c4, c3)

        self.Up3 = up_conv(c3, c2)
        self.Up_conv3 = conv_block(c3, c2)

        self.Up2 = up_conv(c2, c1)
        self.Up_conv2 = conv_block(c2, c1)

        # Output head (no activation).
        self.Conv = nn.Conv2d(c1, out_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return the un-activated ``out_ch``-channel prediction for ``x``."""
        # Contracting path.
        e1 = self.Conv1(x)
        e2 = self.Conv2(self.Maxpool1(e1))
        e3 = self.Conv3(self.Maxpool2(e2))
        e4 = self.Conv4(self.Maxpool3(e3))
        e5 = self.Conv5(self.Maxpool4(e4))

        # Expanding path: upsample, concatenate the skip first, then fuse.
        decoder = (
            (self.Up5, self.Up_conv5, e4),
            (self.Up4, self.Up_conv4, e3),
            (self.Up3, self.Up_conv3, e2),
            (self.Up2, self.Up_conv2, e1),
        )
        d = e5
        for upsample, fuse, skip in decoder:
            d = fuse(torch.cat((skip, upsample(d)), dim=1))

        return self.Conv(d)


class Multiout_U_Net(nn.Module):
    """
    U-Net with one prediction per decoder level (five outputs in total).

    Besides the usual full-resolution prediction, the features of the three
    coarser decoder stages and of the bottleneck are each reduced to 64
    channels by a ``conv_block``, passed through the SAME 1x1 output
    convolution (``self.Conv``), and bilinearly upsampled by 2, 4, 8 and 16
    respectively.  No activation is applied to any output.
    U-Net paper : https://arxiv.org/abs/1505.04597
    """

    def __init__(self, in_ch, out_ch):
        super(Multiout_U_Net, self).__init__()

        # Channel widths of the five resolution levels: 64, 128, ..., 1024.
        c1, c2, c3, c4, c5 = [64 * 2 ** level for level in range(5)]

        # NOTE: sub-modules are created in the same order as before so that
        # parameter ordering and seeded initialisation are unchanged.
        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder.
        self.Conv1 = conv_block(in_ch, c1)
        self.Conv2 = conv_block(c1, c2)
        self.Conv3 = conv_block(c2, c3)
        self.Conv4 = conv_block(c3, c4)
        self.Conv5 = conv_block(c4, c5)

        # Decoder: each Up_convN receives the skip connection concatenated
        # with the output of UpN, hence the doubled input width.
        self.Up5 = up_conv(c5, c4)
        self.Up_conv5 = conv_block(c5, c4)

        self.Up4 = up_conv(c4, c3)
        self.Up_conv4 = conv_block(c4, c3)

        self.Up3 = up_conv(c3, c2)
        self.Up_conv3 = conv_block(c3, c2)

        self.Up2 = up_conv(c2, c1)
        self.Up_conv2 = conv_block(c2, c1)

        # Side branches: bring coarse features down to c1 channels so the
        # shared output convolution can be applied to them.
        self.out_16 = conv_block(c5, c1)
        self.out_8 = conv_block(c4, c1)
        self.out_4 = conv_block(c3, c1)
        self.out_2 = conv_block(c2, c1)

        # Output head shared by all five predictions (no activation).
        self.Conv = nn.Conv2d(c1, out_ch, kernel_size=1, stride=1, padding=0)

        # Upsamplers restoring the side predictions to full resolution.
        self.out_Up2 = nn.Upsample(scale_factor=2, mode='bilinear')
        self.out_Up4 = nn.Upsample(scale_factor=4, mode='bilinear')
        self.out_Up8 = nn.Upsample(scale_factor=8, mode='bilinear')
        self.out_Up16 = nn.Upsample(scale_factor=16, mode='bilinear')

    def forward(self, x):
        """Return ``[full, x2, x4, x8, x16]`` predictions, finest level first.

        The list holds the full-resolution prediction followed by the side
        predictions taken from progressively coarser features, each
        upsampled by the factor in its name.
        """
        # Contracting path.
        e1 = self.Conv1(x)
        e2 = self.Conv2(self.Maxpool1(e1))
        e3 = self.Conv3(self.Maxpool2(e2))
        e4 = self.Conv4(self.Maxpool3(e3))
        e5 = self.Conv5(self.Maxpool4(e4))

        # Expanding path; every intermediate is kept for the side heads.
        up5 = self.Up5(e5)
        d5 = self.Up_conv5(torch.cat((e4, up5), dim=1))
        up4 = self.Up4(d5)
        d4 = self.Up_conv4(torch.cat((e3, up4), dim=1))
        up3 = self.Up3(d4)
        d3 = self.Up_conv3(torch.cat((e2, up3), dim=1))
        up2 = self.Up2(d3)
        d2 = self.Up_conv2(torch.cat((e1, up2), dim=1))

        # Full-resolution prediction first, then the side predictions from
        # fine to coarse.  (Original author's note: for a 240x240 input the
        # side maps are 120, 60, 30 and 15 pixels wide before upsampling.)
        outputs = [self.Conv(d2)]
        side_heads = (
            (self.out_2, self.out_Up2, d3),
            (self.out_4, self.out_Up4, d4),
            (self.out_8, self.out_Up8, d5),
            (self.out_16, self.out_Up16, e5),
        )
        for reduce_channels, upsample, features in side_heads:
            outputs.append(upsample(self.Conv(reduce_channels(features))))

        return outputs
