

import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F

from fairseq import utils


class MultiheadAttentionLayers(nn.Module):
    """Multi-head attention that mixes two input streams, ``last`` and ``extands``.

    Both inputs are first reduced from ``embed_dim`` to ``embed_dim // 2``
    features by separate linear layers. Three attention results are then
    computed in the reduced space:

    * ``lee`` -- queries from ``last``, keys and values from ``extands``;
    * ``ell`` -- queries from ``extands``, keys and values from ``last``;
    * ``eee`` -- queries, keys and values all from ``extands``.

    Each result goes through its own output projection, the three are summed,
    and ``proj_linear`` maps the sum back to ``embed_dim`` features.

    Args:
        embed_dim: feature size of the inputs and of the returned output.
            Internally ``self.embed_dim`` holds half of this value.
        num_heads: number of attention heads; must divide ``embed_dim // 2``.
        dropout: dropout probability applied to the attention probabilities.
        bias: whether the q/k/v input projections and the three per-branch
            output projections carry bias terms.
        add_bias_kv: when true, ``bias_k`` / ``bias_v`` parameters are created.
            NOTE: ``forward`` does not read them.
        add_zero_attn: stored on the module. NOTE: ``forward`` does not read it.
    """

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False):
        super().__init__()
        # All attention math runs at half of the external feature size.
        self.embed_dim = int(embed_dim / 2)
        self.last_linears = nn.Linear(embed_dim, self.embed_dim)
        self.extand_linears = nn.Linear(embed_dim, self.embed_dim)
        self.proj_linear = nn.Linear(self.embed_dim, embed_dim)
        embed_dim = self.embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        # Fused q/k/v projection weights, one set per input stream.
        self.in_proj_weight_lasts = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
        self.in_proj_weight_extands = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
        if bias:
            self.in_proj_bias_lasts = Parameter(torch.Tensor(3 * embed_dim))
            self.in_proj_bias_extands = Parameter(torch.Tensor(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias_lasts', None)
            self.register_parameter('in_proj_bias_extands', None)
        self.out_proj_lee = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj_ell = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj_eee = nn.Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.reset_parameters()

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        """Switch ``forward`` to its ONNX-tracing code paths."""
        self.onnx_trace = True

    def reset_parameters(self):
        """Xavier-initialise the attention weights and zero the attention biases.

        ``last_linears``, ``extand_linears`` and ``proj_linear`` are not touched
        here and keep the initialisation ``nn.Linear`` gave them.
        """
        nn.init.xavier_uniform_(self.in_proj_weight_extands)
        nn.init.xavier_uniform_(self.in_proj_weight_lasts)
        nn.init.xavier_uniform_(self.out_proj_lee.weight)
        nn.init.xavier_uniform_(self.out_proj_eee.weight)
        nn.init.xavier_uniform_(self.out_proj_ell.weight)
        # The in-projection and out-projection biases are created from the same
        # ``bias`` flag, so one check covers all of them.
        if self.in_proj_bias_lasts is not None:
            nn.init.constant_(self.in_proj_bias_lasts, 0.)
            nn.init.constant_(self.in_proj_bias_extands, 0.)
            nn.init.constant_(self.out_proj_ell.bias, 0.)
            nn.init.constant_(self.out_proj_lee.bias, 0.)
            nn.init.constant_(self.out_proj_eee.bias, 0.)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def forward(self, last, extands,  key_padding_mask=None,
                need_weights=True, static_kv=False, attn_mask=None):
        """Run the three attention branches and combine them.

        Args:
            last: 3-D tensor whose sizes are read as ``(tgt_len, bsz, dim)``,
                with ``dim`` the ``embed_dim`` given to the constructor.
            extands: tensor that must have the same size as ``last`` once both
                have been reduced by their input linear layers.
            key_padding_mask: optional mask of size ``(bsz, src_len)``; key
                positions where it is true are set to ``-inf`` before softmax.
            need_weights: when true, also return the attention weights.
            static_kv: accepted for interface compatibility; not used.
            attn_mask: optional additive mask; it gets a leading dimension and
                is added to every score tensor -- presumably
                ``(tgt_len, src_len)``, confirm against callers.

        Returns:
            ``(output, attn_weights)``. ``output`` has the size of the original
            ``last``. ``attn_weights`` is the sum of the three head-averaged,
            post-dropout attention maps with size ``(bsz, tgt_len, src_len)``,
            or ``None`` when ``need_weights`` is false.
        """
        last = self.last_linears(last)
        extands = self.extand_linears(extands)

        tgt_len, bsz, embed_dim = last.size()
        assert embed_dim == self.embed_dim
        assert last.size() == extands.size()

        qe, ke, ve = self._project_heads(extands, self.in_proj_qkv_extands, bsz)
        ql, kl, vl = self._project_heads(last, self.in_proj_qkv_lasts, bsz)

        src_len = kl.size(1)

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(bsz * self.num_heads, 1, 1)

        # Branch order (le, el, ee) fixes the order of the dropout calls and is
        # kept stable so seeded runs stay reproducible.
        attn_lee, attn_weights_le = self._attend(ql, ke, ve, bsz, attn_mask, key_padding_mask)
        attn_ell, attn_weights_el = self._attend(qe, kl, vl, bsz, attn_mask, key_padding_mask)
        attn_eee, attn_weights_ee = self._attend(qe, ke, ve, bsz, attn_mask, key_padding_mask)

        attn_lee = self.out_proj_lee(self._merge_heads(attn_lee, tgt_len, bsz))
        attn_ell = self.out_proj_ell(self._merge_heads(attn_ell, tgt_len, bsz))
        attn_eee = self.out_proj_eee(self._merge_heads(attn_eee, tgt_len, bsz))

        if need_weights:
            # average attention weights over heads, then sum the three branches
            attn_weights_el = self._average_heads(attn_weights_el, bsz, tgt_len, src_len)
            attn_weights_ee = self._average_heads(attn_weights_ee, bsz, tgt_len, src_len)
            attn_weights_le = self._average_heads(attn_weights_le, bsz, tgt_len, src_len)
            attn_weights = attn_weights_el + attn_weights_le + attn_weights_ee
        else:
            attn_weights = None

        return self.proj_linear(attn_lee + attn_ell + attn_eee), attn_weights

    def _project_heads(self, x, in_proj_qkv, bsz):
        """Project ``x`` to scaled queries, keys and values in per-head layout."""
        q, k, v = in_proj_qkv(x)
        # Scale out of place: q/k/v are chunk views of one projection output,
        # and autograd rejects in-place edits of such multi-view outputs.
        q = q * self.scaling
        return self._split_heads(q, bsz), self._split_heads(k, bsz), self._split_heads(v, bsz)

    def _split_heads(self, x, bsz):
        """Reshape ``(len, bsz, embed_dim)`` to ``(bsz * num_heads, len, head_dim)``."""
        return x.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

    def _apply_padding_mask(self, scores, key_padding_mask, bsz, tgt_len, src_len):
        """Set the scores of masked key positions to ``-inf``."""
        scores = scores.view(bsz, self.num_heads, tgt_len, src_len)
        mask = key_padding_mask.unsqueeze(1).unsqueeze(2)
        # FP16 support: cast to float and back
        if self.onnx_trace:
            scores = torch.where(
                mask,
                torch.Tensor([float("-Inf")]),
                scores.float()
            ).type_as(scores)
        else:
            scores = scores.float().masked_fill(
                mask,
                float('-inf'),
            ).type_as(scores)
        return scores.view(bsz * self.num_heads, tgt_len, src_len)

    def _attend(self, q, k, v, bsz, attn_mask, key_padding_mask):
        """Compute one masked softmax attention; return ``(attn, probs)``."""
        tgt_len, src_len = q.size(1), k.size(1)
        scores = torch.bmm(q, k.transpose(1, 2))
        assert list(scores.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            scores += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            scores = self._apply_padding_mask(scores, key_padding_mask, bsz, tgt_len, src_len)

        probs = utils.softmax(
            scores, dim=-1, onnx_trace=self.onnx_trace,
        ).type_as(scores)
        probs = F.dropout(probs, p=self.dropout, training=self.training)

        attn = torch.bmm(probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        return attn, probs

    def _merge_heads(self, attn, tgt_len, bsz):
        """Reshape ``(bsz * num_heads, tgt_len, head_dim)`` to ``(tgt_len, bsz, embed_dim)``."""
        if not (self.onnx_trace and attn.size(1) == 1):
            attn = attn.transpose(0, 1)
        # When tracing for ONNX with a single target position the transpose is
        # skipped and the tensor is viewed directly.
        return attn.contiguous().view(tgt_len, bsz, self.embed_dim)

    def _average_heads(self, probs, bsz, tgt_len, src_len):
        """Average attention probabilities over the head dimension."""
        probs = probs.view(bsz, self.num_heads, tgt_len, src_len)
        return probs.sum(dim=1) / self.num_heads

    def in_proj_qkv_lasts(self, query):
        """Return the ``(q, k, v)`` chunks of the ``last`` in-projection of ``query``."""
        return self._in_proj_last(query).chunk(3, dim=-1)

    def in_proj_qkv_extands(self, query):
        """Return the ``(q, k, v)`` chunks of the ``extands`` in-projection of ``query``."""
        return self._in_proj_extand(query).chunk(3, dim=-1)

    def _in_proj_last(self, input, start=0, end=None):
        """Apply rows ``start:end`` of the ``last`` in-projection to ``input``."""
        weight = self.in_proj_weight_lasts
        bias = self.in_proj_bias_lasts
        weight = weight[start:end, :]
        if bias is not None:
            bias = bias[start:end]
        return F.linear(input, weight, bias)

    def _in_proj_extand(self, input, start=0, end=None):
        """Apply rows ``start:end`` of the ``extands`` in-projection to ``input``."""
        weight = self.in_proj_weight_extands
        bias = self.in_proj_bias_extands
        weight = weight[start:end, :]
        if bias is not None:
            bias = bias[start:end]
        return F.linear(input, weight, bias)

    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder buffered internal state (for incremental generation)."""
        input_buffer = self._get_input_buffer(incremental_state)
        # _get_input_buffer falls back to an empty dict, so this branch is
        # taken even when nothing has been buffered yet.
        if input_buffer is not None:
            for k in input_buffer.keys():
                input_buffer[k] = input_buffer[k].index_select(0, new_order)
            self._set_input_buffer(incremental_state, input_buffer)

    def _get_input_buffer(self, incremental_state):
        """Return the stored ``'attn_state'`` buffer, or ``{}`` if there is none."""
        return utils.get_incremental_state(
            self,
            incremental_state,
            'attn_state',
        ) or {}

    def _set_input_buffer(self, incremental_state, buffer):
        """Store ``buffer`` under the ``'attn_state'`` key."""
        utils.set_incremental_state(
            self,
            incremental_state,
            'attn_state',
            buffer,
        )
