import torch
from torch import nn
from torch.nn.init import xavier_uniform_ as xavier_uniform
from opt_einsum import contract

from model.activation import get_activation_fn


class Decoder(nn.Module):
    """Label-attention decoder producing one logit per (sequence, label) pair.

    Each label is described by ``attention_head`` feature vectors.  Those
    vectors are used as multi-head attention queries over the token
    representations ``h`` to build a label-specific summary of every
    sequence.  The same label features, pooled over heads, are projected into
    a per-label weight vector and bias that score the summary.

    NOTE: ``get_logits`` contracts the summary (last dim ``input_dim``) with
    weights of size ``attention_dim``, so ``forward`` only works when
    ``config.input_dim == config.attention_dim``.
    """

    def __init__(self, config):
        """Build the decoder layers.

        Args:
            config: object exposing ``input_dim``, ``attention_dim``,
                ``attention_head``, ``dropout_prob``, ``pooling`` (``'max'``
                and ``'mean'`` are the values handled by this class) and
                ``activation`` (forwarded to ``get_activation_fn``, whose
                result is instantiated with no arguments).
        """
        super().__init__()
        self.input_dim = config.input_dim
        self.attention_dim = config.attention_dim

        # Generate the per-label classifier weight and bias from label features.
        self.w_linear = nn.Linear(self.attention_dim, self.attention_dim)
        self.b_linear = nn.Linear(self.attention_dim, 1)

        # Projects token representations into the attention space.
        self.W = nn.Linear(self.input_dim, self.attention_dim)
        xavier_uniform(self.W.weight)

        self.attention_head = config.attention_head
        self.dropout = nn.Dropout(config.dropout_prob)

        self.pooling = config.pooling

        self.act_fn = get_activation_fn(config.activation)()
        # Reduces every label query to the per-head sub-dimension.
        self.u_reduce = nn.Linear(self.attention_dim,
                                  self.attention_dim // self.attention_head)

    def forward(self, h, word_mask, label_feature=None):
        """Compute label logits for a batch of sequences.

        Args:
            h: token representations, ``(batch, seq_len, input_dim)``.
            word_mask: ``(batch, >= seq_len)`` mask whose truthy entries mark
                real tokens, or ``None`` to attend over every position.
            label_feature: label query features,
                ``(label_count * attention_head, attention_dim)``.  Required
                despite the ``None`` default.

        Returns:
            Tensor of shape ``(batch, label_count)``.

        Raises:
            ValueError: if ``label_feature`` is ``None``.
        """
        if label_feature is None:
            # Fail early with a clear message instead of an AttributeError
            # raised deep inside the attention computation.
            raise ValueError(
                "label_feature must be provided to compute label-wise logits")

        m = self.get_label_queried_features(h, word_mask, label_feature)

        label_feature = self.transform_label_feats(label_feature)
        w = self.w_linear(label_feature)  # label_count * attention_dim
        b = self.b_linear(label_feature)  # label_count * 1
        return self.get_logits(m, w, b)

    def get_logits(self, m, w=None, b=None):
        """Score label-specific features with per-label weights and biases.

        Args:
            m: ``(batch, label_count, hidden)`` label-specific features.
            w: ``(label_count, hidden)`` per-label weights (required in
                practice; ``None`` is not handled).
            b: ``(label_count, 1)`` per-label biases (required in practice).

        Returns:
            Tensor of shape ``(batch, label_count)``.
        """
        logits = contract('blh,lh->bl', m, w) + b.squeeze(-1)
        return logits

    def transform_label_feats(self, label_feat):
        """Pool the ``attention_head`` feature vectors belonging to each label.

        Args:
            label_feat: tensor whose first dimension is
                ``label_count * attention_head``.

        Returns:
            ``(label_count, feat_dim)`` when ``self.pooling`` is ``'max'`` or
            ``'mean'``; for any other value the un-pooled
            ``(label_count, attention_head, feat_dim)`` view is returned.
        """
        label_count = label_feat.shape[0] // self.attention_head
        label_feat = label_feat.reshape(label_count, self.attention_head, -1)
        if self.pooling == 'max':
            label_feat = label_feat.max(dim=1)[0]
        elif self.pooling == 'mean':
            label_feat = label_feat.mean(dim=1)
        return label_feat

    @staticmethod
    def _mask_padding(score, word_mask):
        """Push attention scores at masked-out positions to a large negative.

        Args:
            score: floating-point scores,
                ``(batch, label_count, seq_len, attention_head)``.
            word_mask: ``(batch, >= seq_len)`` mask; truthy entries are kept.

        Returns:
            A new tensor shaped like ``score``.
        """
        # Truncate to the scored length, then flip: True marks positions to fill.
        drop = ~word_mask.bool()[:, :score.shape[-2]]
        # -1e6 is not representable in float16, and masked_fill rejects fill
        # values that overflow the tensor dtype; clamp to the dtype's finite
        # minimum so half-precision scores work.  Wider dtypes keep -1e6.
        fill_value = max(-1e6, torch.finfo(score.dtype).min)
        # Shape (batch, 1, seq_len, 1) broadcasts over labels and heads.
        return score.masked_fill(drop[:, None, :, None], fill_value)

    def get_label_queried_features(self, h, word_mask, label_feat):
        """Attend over tokens once per label and pool the result over heads.

        Args:
            h: token representations, ``(batch, seq_len, input_dim)``.
            word_mask: ``(batch, >= seq_len)`` mask of real tokens, or ``None``.
            label_feat: label queries,
                ``(label_count * attention_head, attention_dim)``.

        Returns:
            ``(batch, label_count, input_dim)`` when ``self.pooling`` is
            ``'max'`` or ``'mean'``; otherwise the un-pooled
            ``(batch, label_count, input_dim, attention_head)`` tensor.
            Dropout is applied to the returned tensor.
        """
        z = self.act_fn(self.W(h))  # batch_size * seq_length * att_dim
        batch_size, seq_length, att_dim = z.size()
        # Split the attention dimension into heads:
        # batch_size, seq_length, att_head, sub_dim
        z_reshape = z.reshape(batch_size, seq_length, self.attention_head,
                              att_dim // self.attention_head)

        label_count = label_feat.size(0) // self.attention_head
        # label_count, att_head, sub_dim
        u_reshape = self.u_reduce(
            label_feat.reshape(label_count, self.attention_head, att_dim))

        # batch_size, label_count, seq_length, att_head
        score = contract('abcd,ecd->aebc', z_reshape, u_reshape)
        if word_mask is not None:
            score = self._mask_padding(score, word_mask)
        alpha = torch.softmax(score, dim=2)  # normalise over seq_length

        # Attention-weighted sum of the raw token representations:
        # batch_size, label_count, input_dim, att_head
        m = contract('abd,aebc->aedc', h, alpha)

        # Collapse the head axis -> batch_size, label_count, input_dim
        if self.pooling == 'max':
            m = m.max(dim=-1)[0]
        elif self.pooling == 'mean':
            m = m.mean(dim=-1)

        m = self.dropout(m)
        return m
