

import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F

from fairseq import utils


class MultiheadAdd(nn.Module):
    """Fuse four same-shaped streams with single-head attention across them.

    ``forward`` receives four tensors of shape ``(tgt_len, bsz, embed_dim)``.
    At every ``(time, batch)`` position the four ``embed_dim`` vectors are
    stacked and attend to one another with scaled dot-product attention
    (see "Attention Is All You Need"), using one head spanning the whole
    embedding. The four attended vectors are averaged and passed through
    ``out_proj``, so the output has the same shape as each input.

    Attention runs over the four streams, not over time.

    Args:
        embed_dim: size of the last dimension of every input.
        dropout: dropout probability applied to the attention weights.
        bias: whether ``out_proj`` has a bias term (the query/key/value
            projections always have one).
    """

    def __init__(self, embed_dim, dropout=0., bias=True, ):
        super().__init__()
        self.embed_dim = embed_dim
        # Dropout probability for the attention weights (applied functionally).
        self.dropouts = dropout
        # A single head covering the full embedding.
        self.num_heads = 1
        self.head_dim = self.embed_dim
        self.scaling = self.embed_dim ** -0.5

        # Creation order is kept so parameter init consumes the RNG identically.
        self.querylinear = nn.Linear(embed_dim, embed_dim)
        self.keylinear = nn.Linear(embed_dim, embed_dim)
        self.valuelinear = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        # Not read inside this module; presumably toggled by external
        # ONNX-export code — kept for compatibility.
        self.onnx_trace = False
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise ``out_proj.weight`` and zero its bias if present.

        The query/key/value projections keep the default ``nn.Linear`` init.
        """
        nn.init.xavier_uniform_(self.out_proj.weight)
        # With ``bias=False`` the layer has no bias tensor (``bias is None``).
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.)

    def forward(self, el, le, ee, ll, key_padding_mask=None):
        """Attend across the four input streams and return their fused mix.

        Args:
            el, le, ee, ll: 3-D tensors of identical shape
                ``(tgt_len, bsz, embed_dim)``.
            key_padding_mask: optional tensor whose first two dimensions must
                be ``(bsz, tgt_len)``. It is only shape-checked; attention is
                across the four streams, so no time step is masked.

        Returns:
            Tensor of shape ``(tgt_len, bsz, embed_dim)``.

        Raises:
            AssertionError: if ``key_padding_mask`` has the wrong leading dims.
        """
        # (T, B, E) x 4 -> (T, B, 4, E): the streams become the attention axis.
        value = torch.stack([el, le, ee, ll], dim=-2)
        tgt_len, bsz, valuesize, embed_dim = value.size()

        q = self.querylinear(value) * self.scaling
        k = self.keylinear(value)
        v = self.valuelinear(value)

        # (T, B, 4, E) -> (B * heads, T, 4, head_dim); with one head this is a
        # plain swap of the time and batch axes.
        q = q.contiguous().view(
            tgt_len, bsz * self.num_heads, valuesize, self.head_dim
        ).transpose(0, 1)
        k = k.contiguous().view(
            -1, bsz * self.num_heads, valuesize, self.head_dim
        ).transpose(0, 1)
        v = v.contiguous().view(
            -1, bsz * self.num_heads, valuesize, self.head_dim
        ).transpose(0, 1)

        src_len = k.size(1)

        if key_padding_mask is not None:
            # Validated only; the mask is not applied (see docstring).
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        # (B, T, 4, 4): scores between every pair of streams per position.
        attn_weights = torch.matmul(q, k.transpose(-1, -2))
        assert list(attn_weights.size()) == [
            bsz * self.num_heads, src_len, valuesize, valuesize
        ]

        # Softmax computed in float32 for stability, then cast back.
        attn_weights = F.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).type_as(attn_weights)
        attn_weights = F.dropout(attn_weights, p=self.dropouts, training=self.training)

        # (B, T, 4, E) -> average over the four streams -> (B, T, E).
        attn = torch.matmul(attn_weights, v)
        attn = attn.sum(-2) / valuesize
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        return self.out_proj(attn)
