from __future__ import print_function, division
import math
import gc

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

from ..shared import up_conv 
from .lib.functional import subtraction2,dotproduction2, aggregation



def sample_gumbel(shape, device, eps=1e-20):
    """Draw standard Gumbel(0, 1) noise of the requested shape.

    Uses inverse-transform sampling: ``-log(-log(U))`` with
    ``U ~ Uniform[0, 1)``.

    Args:
        shape: Output shape as any sequence of ints (``torch.Size``, tuple
            or list).
        device: Device on which the noise tensor is created.
        eps: Small constant added inside both logarithms to avoid ``log(0)``.

    Returns:
        Tensor of the given shape, in the default floating dtype, on
        ``device``.
    """
    # torch.rand always interprets its argument as a size, whereas the legacy
    # ``torch.Tensor(seq)`` constructor treats a plain list/tuple as data
    # (``torch.Tensor([2, 3])`` is a 2-element vector, not a 2x3 tensor).
    # Sampling directly on ``device`` also avoids a CPU allocation followed by
    # a host-to-device copy; note the noise then comes from that device's RNG.
    U = torch.rand(tuple(shape), device=device)
    return -(torch.log(-torch.log(U + eps) + eps))

def gumbel_softmax_sample(logits, temperature):
    """Return a soft sample: softmax of Gumbel-perturbed logits.

    Noise from ``sample_gumbel`` (same size and device as ``logits``) is added
    to the logits, the sum is divided by ``temperature``, and a softmax is
    taken over the last dimension.
    """
    noise = sample_gumbel(logits.size(), device=logits.device)
    perturbed = logits + noise
    return F.softmax(perturbed / temperature, dim=-1)

def gumbel_softmax(logits, temperature=1, hard=False):
    """
    input: [*, n_class]
    return: [*, n_class]; the soft Gumbel-softmax sample when ``hard`` is
            False, otherwise a one-hot vector along the last dimension

    With ``hard=True`` the forward value is the one-hot argmax of the soft
    sample, while gradients flow through the soft sample (the difference
    between the two is detached before being added back).
    """
    soft = gumbel_softmax_sample(logits, temperature)
    if not hard:
        return soft

    # Index of the winning class for every leading position.
    winners = soft.max(dim=-1)[1]
    one_hot = torch.zeros_like(soft).scatter_(-1, winners.unsqueeze(-1), 1)

    # Straight-through: value of ``one_hot``, gradient of ``soft``.
    return (one_hot - soft).detach() + soft

    

class attention_conv_v(nn.Module):
    """Convolution followed by a gated aggregation over each K x K window.

    Three branches read the same input:

    * ``x_mlp``  - Conv-BN-ReLU producing the value map (B, Cout, H, W).
    * ``att_wq`` - pooled / upsampled conv stack producing the query map.
    * ``att_wk`` - 1x1 conv stack producing the key map.

    Query and key both have ``att_hidden * att_mh`` channels. They are
    combined by ``dotproduction2`` (defined in ``.lib.functional``; its exact
    semantics are not visible from this file), summed over ``att_hidden``,
    scaled by ``1 / sqrt(att_hidden)`` and passed through a sigmoid. The
    resulting weights are handed to ``aggregation`` together with the value
    map.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by ``x_mlp``.
        kernel_size: Window size K; padding is ``(K - 1) // 2``.
        att_hidden: Hidden width per attention head.
        att_mh: Number of attention heads.
        att_sm: Stored on the module; not used by the current forward pass.
        att_two_w: Stored on the module; not used by the current forward pass.
        visualization: If true, ``forward`` returns extra tensors for
            inspection (see ``forward``).

    Note:
        ``att_wq`` max-pools by 2 and then upsamples by 2, so its output is
        ``2 * (H // 2)`` high; ``get_att_hidden`` views it back to ``H`` and
        therefore raises for odd ``H`` or ``W``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, att_hidden, att_mh, att_sm, att_two_w, visualization):

        super(attention_conv_v, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = (kernel_size - 1) // 2

        self.att_hidden = att_hidden
        self.att_mh = att_mh
        self.att_out = att_hidden * self.att_mh
        self.att_sm = att_sm
        self.att_two_w = att_two_w
        self.visualization = visualization

        # Value branch.
        self.x_mlp = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=self.padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

        # Query branch: 7x7 conv at half resolution, then back to full
        # resolution and up to ``att_out`` channels.
        self.att_wq = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels, self.att_out // 4, kernel_size=7, stride=1, padding=3, bias=True),
            nn.BatchNorm2d(self.att_out // 4),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(self.att_out // 4, self.att_out, kernel_size=3, stride=1, padding=1, bias=True),
        )

        # Key branch: pointwise convolutions only.
        self.att_wk = nn.Sequential(
            nn.Conv2d(in_channels, self.att_out, kernel_size=1),
            nn.BatchNorm2d(self.att_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.att_out, self.att_out, kernel_size=1),  # attention Linear
        )

        self.att = None

    def get_att_hidden(self, x, w):
        """Apply branch ``w`` to ``x`` and return it as (B, att_out, H, W).

        The ``view`` doubles as a shape check: it raises if ``w`` does not
        produce exactly ``att_hidden * att_mh`` channels at the input's
        spatial size.
        """
        B, _, H, W = x.shape
        return w(x).view(B, self.att_hidden * self.att_mh, H, W)

    def forward(self, x):
        """Run the layer.

        Args:
            x: Input of shape (B, in_channels, H, W).

        Returns:
            The aggregated feature map when ``visualization`` is false.
            Otherwise a tuple ``(out, x_before, att)`` where ``x_before`` is
            the detached value map before aggregation and ``att`` holds the
            sigmoid gates (B, att_mh, K^2, H*W per the original annotations).
        """
        inputs = x
        x = self.x_mlp(inputs)  # B, Cout, H, W
        B, _, H, W = x.shape

        # Compute the query map exactly once. Running att_wq a second time
        # repeats the 7x7 convolution and, in training mode, updates its
        # BatchNorm running statistics twice per step.
        query = self.get_att_hidden(inputs, self.att_wq)
        key = self.get_att_hidden(inputs, self.att_wk)

        # B, att_hidden, att_mh, K^2, HW
        att = dotproduction2(query, key, self.kernel_size).view(B, self.att_hidden, self.att_mh, -1, H * W)
        # Reduce over att_hidden with the usual 1/sqrt(d) scaling -> B, att_mh, K^2, HW
        att = torch.sum(att, dim=1) / (self.att_hidden ** 0.5)
        # Independent sigmoid gates per window position (no normalisation
        # across the K^2 window). torch.sigmoid replaces the deprecated
        # F.sigmoid.
        att = torch.sigmoid(att)

        # Snapshot the values before aggregation so that visualization really
        # shows a "before" map; taking it afterwards would just duplicate the
        # output.
        x_before = x.detach() if self.visualization else None

        x = aggregation(x, att, kernel_size=self.kernel_size, padding=self.padding)

        if self.visualization:
            return x, x_before, att
        return x


class attention_conv_block_v(nn.Module):
    """Two ``attention_conv_v`` layers applied back to back.

    The first layer maps ``in_ch -> out_ch`` and the second ``out_ch ->
    out_ch``; both share the same kernel size, attention settings and
    ``visualization`` flag.
    """

    def __init__(self, in_ch, out_ch, kernel_size, att_hidden, att_mh, att_sm, att_two_w, visualization=False):
        super(attention_conv_block_v, self).__init__()
        self.kernel_size = kernel_size
        self.visualization = visualization

        # Positional arguments common to both layers.
        shared = (kernel_size, att_hidden, att_mh, att_sm, att_two_w, visualization)
        self.att_conv_1 = attention_conv_v(in_ch, out_ch, *shared)
        self.att_conv_2 = attention_conv_v(out_ch, out_ch, *shared)

    def forward(self, x):
        """Apply both layers.

        Without visualization the block output is returned on its own. With
        visualization a 4-tuple is returned: the block output, then three
        tensors taken from the *first* layer and cut to indices 0:4 along
        dim 1 - its detached output, the second tensor it returned, and its
        attention map viewed as (B, -1, K^2, H, W).
        """
        if not self.visualization:
            return self.att_conv_2(self.att_conv_1(x))

        B, _, H, W = x.shape
        first_out, first_detached, att = self.att_conv_1(x)

        # The second layer also runs in visualization mode; keep only its
        # feature map.
        out = self.att_conv_2(first_out)[0]
        k_sq = self.kernel_size ** 2

        return (
            out,
            first_out[:, 0:4, :, :].detach(),
            first_detached[:, 0:4, :, :],
            att[:, 0:4, :, :].view(B, -1, k_sq, H, W),
        )


class KAUNet_v(nn.Module):
    """
    U-Net shaped encoder/decoder built from ``attention_conv_block_v`` stages.
    U-Net paper : https://arxiv.org/abs/1505.04597

    The first encoder stage runs in visualization mode, so ``forward`` returns
    its inspection tensors next to the prediction.
    """

    def __init__(self, in_ch, out_ch, att_mh, att_sm, att_ks, att_two_w):
        super(KAUNet_v, self).__init__()

        base = 64
        widths = [base * (2 ** level) for level in range(5)]  # 64 ... 1024
        att_hidden = 32

        def make_block(c_in, c_out, visualization=False):
            # Every stage shares the attention hyper-parameters.
            return attention_conv_block_v(
                c_in, c_out, att_ks,
                att_hidden=att_hidden, att_mh=att_mh, att_sm=att_sm,
                att_two_w=att_two_w, visualization=visualization)

        # NOTE: submodules are created in the same order as before so that
        # state_dict ordering and seeded initialisation stay the same.
        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder.
        self.Conv1 = make_block(in_ch, widths[0], visualization=True)
        self.Conv2 = make_block(widths[0], widths[1])
        self.Conv3 = make_block(widths[1], widths[2])
        self.Conv4 = make_block(widths[2], widths[3])
        self.Conv5 = make_block(widths[3], widths[4])

        # Decoder: each level upsamples, then fuses with the skip connection
        # (hence the fuse block takes twice the target width as input).
        self.Up5 = up_conv(widths[4], widths[3])
        self.Up_conv5 = make_block(widths[4], widths[3])

        self.Up4 = up_conv(widths[3], widths[2])
        self.Up_conv4 = make_block(widths[3], widths[2])

        self.Up3 = up_conv(widths[2], widths[1])
        self.Up_conv3 = make_block(widths[2], widths[1])

        self.Up2 = up_conv(widths[1], widths[0])
        self.Up_conv2 = make_block(widths[1], widths[0])

        # 1x1 projection to the output channels.
        self.Conv = nn.Conv2d(widths[0], out_ch,
                              kernel_size=1, stride=1, padding=0)

        # Scale factor 1: the visualization tensors come from Conv1, which
        # runs on the network input itself.
        self.Up_v = nn.Upsample(scale_factor=1, mode='nearest')

    def forward(self, x):
        """Return ``(prediction, x_after, x_before, att)``.

        The last three entries are the visualization tensors produced by
        ``Conv1``, each passed through ``Up_v``; ``att`` keeps its 5-D
        (B, C, K^2, H, W) layout.
        """
        # Encoder path; only the first stage yields visualization extras.
        enc1, vis_after, vis_before, att = self.Conv1(x)
        enc2 = self.Conv2(self.Maxpool1(enc1))
        enc3 = self.Conv3(self.Maxpool2(enc2))
        enc4 = self.Conv4(self.Maxpool3(enc3))
        out = self.Conv5(self.Maxpool4(enc4))

        # Decoder path, deepest level first: upsample, concatenate the skip
        # tensor in front, fuse.
        decoder = (
            (self.Up5, enc4, self.Up_conv5),
            (self.Up4, enc3, self.Up_conv4),
            (self.Up3, enc2, self.Up_conv3),
            (self.Up2, enc1, self.Up_conv2),
        )
        for upsample, skip, fuse in decoder:
            out = fuse(torch.cat((skip, upsample(out)), dim=1))

        prediction = self.Conv(out)

        # nn.Upsample expects a 4-D map here: fold (C, K^2) into one axis,
        # resample, then restore the 5-D layout.
        B, C, K2, H, W = att.shape
        att_vis = self.Up_v(att.view(B, -1, H, W)).view(B, C, K2, H, W)

        return prediction, self.Up_v(vis_after), self.Up_v(vis_before), att_vis
