###########################################################################
# Created by: Hang Zhang 
# Email: zhang.hang@rutgers.edu 
# Copyright (c) 2018
###########################################################################
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from .syncbn import SyncBatchNorm

__all__ = ['ACFModule', 'MixtureOfSoftMaxACF']

class ACFModule(nn.Module):
    """ Multi-Head Attention module """
    def __init__(self, n_head, n_mix, d_model, d_k, d_v, norm_layer=SyncBatchNorm, 
                 kq_transform='conv', value_transform='conv',
                 pooling=True, concat=False, dropout=0.1):
        super(ACFModule, self).__init__()

        self.n_head = n_head
        self.n_mix = n_mix
        self.d_k = d_k
        self.d_v = d_v
        self.pooling = pooling
        self.concat = concat

        if self.pooling:
            self.pool = nn.AvgPool2d(3, 2, 1, count_include_pad=False)

        if kq_transform == 'conv':
            self.conv_qs = nn.Conv2d(d_model, n_head*d_k, 1)
            nn.init.normal_(self.conv_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        elif kq_transform == 'ffn':
            self.conv_qs = nn.Sequential(
                nn.Conv2d(d_model, n_head*d_k, 3, padding=1, bias=False),
                norm_layer(n_head*d_k),
                nn.ReLU(True),
                nn.Conv2d(n_head*d_k, n_head*d_k, 1),
            )
            nn.init.normal_(self.conv_qs[-1].weight, mean=0, std=np.sqrt(1.0 / d_k))
        elif kq_transform == 'dffn':
            self.conv_qs = nn.Sequential(
                nn.Conv2d(d_model, n_head*d_k, 3, padding=4, dilation=4, bias=False),
                norm_layer(n_head*d_k),
                nn.ReLU(True),
                nn.Conv2d(n_head*d_k, n_head*d_k, 1),
            )
            nn.init.normal_(self.conv_qs[-1].weight, mean=0, std=np.sqrt(1.0 / d_k))
        else:
            raise NotImplemented
        #self.conv_ks = nn.Conv2d(d_model, n_head*d_k, 1)
        self.conv_ks = self.conv_qs
        if value_transform == 'conv':
            self.conv_vs = nn.Conv2d(d_model, n_head*d_v, 1)
        else:
            raise NotImplemented

        #nn.init.normal_(self.conv_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.conv_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))

        self.attention = MixtureOfSoftMaxACF(n_mix=n_mix, d_k=d_k)

        self.conv = nn.Conv2d(n_head*d_v, d_model, 1, bias=False)
        self.norm_layer = norm_layer(d_model)

    def forward(self, x):
        residual = x

        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        b_, c_, h_, w_ = x.size()

        if self.pooling:
            qt = self.conv_ks(x).view(b_*n_head, d_k, h_*w_)
            kt = self.conv_ks(self.pool(x)).view(b_*n_head, d_k, h_*w_//4)
            vt = self.conv_vs(self.pool(x)).view(b_*n_head, d_v, h_*w_//4)
        else:
            kt = self.conv_ks(x).view(b_*n_head, d_k, h_*w_)
            qt = kt
            vt = self.conv_vs(x).view(b_*n_head, d_v, h_*w_)

        output, attn = self.attention(qt, kt, vt)

        output = output.transpose(1, 2).contiguous().view(b_, n_head*d_v, h_, w_)

        output = self.conv(output)
        if self.concat:
            output = torch.cat((self.norm_layer(output), residual), 1)
        else:
            output = self.norm_layer(output) + residual
        return output

    def demo(self, x):
        residual = x

        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        b_, c_, h_, w_ = x.size()

        if self.pooling:
            qt = self.conv_ks(x).view(b_*n_head, d_k, h_*w_)
            kt = self.conv_ks(self.pool(x)).view(b_*n_head, d_k, h_*w_//4)
            vt = self.conv_vs(self.pool(x)).view(b_*n_head, d_v, h_*w_//4)
        else:
            kt = self.conv_ks(x).view(b_*n_head, d_k, h_*w_)
            qt = kt
            vt = self.conv_vs(x).view(b_*n_head, d_v, h_*w_)

        _, attn = self.attention(qt, kt, vt)
        attn.view(b_, n_head, h_*w_, -1)
        return attn

    def extra_repr(self):
        return 'n_head={}, n_mix={}, d_k={}, pooling={}' \
            .format(self.n_head, self.n_mix, self.d_k, self.pooling)


class MixtureOfSoftMaxACF(nn.Module):
    """"Mixture of SoftMax"""
    def __init__(self, n_mix, d_k, attn_dropout=0.1):
        super(MixtureOfSoftMaxACF, self).__init__()
        self.temperature = np.power(d_k, 0.5)
        self.n_mix = n_mix
        self.att_drop = attn_dropout
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax1 = nn.Softmax(dim=1)
        self.softmax2 = nn.Softmax(dim=2)
        self.d_k = d_k
        if n_mix > 1:
            self.weight = nn.Parameter(torch.Tensor(n_mix, d_k))
            std = np.power(n_mix, -0.5)
            self.weight.data.uniform_(-std, std)

    def forward(self, qt, kt, vt):
        B, d_k, N = qt.size()
        m = self.n_mix
        assert d_k == self.d_k
        d = d_k // m
        if m > 1:
            # \bar{v} \in R^{B, d_k, 1}
            bar_qt = torch.mean(qt, 2, True)
            # pi \in R^{B, m, 1}
            pi = self.softmax1(torch.matmul(self.weight, bar_qt)).view(B*m, 1, 1)
        # reshape for n_mix
        q = qt.view(B*m, d, N).transpose(1, 2)
        N2 = kt.size(2)
        kt = kt.view(B*m, d, N2)
        v = vt.transpose(1, 2)
        # {Bm, N, N}
        attn = torch.bmm(q, kt)
        attn = attn / self.temperature
        attn = self.softmax2(attn)
        attn = self.dropout(attn)
        if m > 1:
            # attn \in R^{Bm, N, N2} => R^{B, N, N2}
            attn = (attn * pi).view(B, m, N, N2).sum(1)
        output = torch.bmm(attn, v)
        return output, attn
