import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class GlorotLinear(nn.Module):
    """Fully connected layer with Glorot (Xavier) uniform weights and zero bias.

    Args:
        input_size: Number of input features of the wrapped ``nn.Linear``.
        output_size: Number of output features of the wrapped ``nn.Linear``.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        layer = nn.Linear(input_size, output_size)
        # Replace the default nn.Linear initialisation: Xavier-uniform weights
        # first, then a bias filled with zeros.
        nn.init.xavier_uniform_(layer.weight)
        nn.init.zeros_(layer.bias)
        self.linear = layer

    def forward(self, inputs):
        """Return the affine projection of ``inputs`` by the wrapped layer."""
        projected = self.linear(inputs)
        return projected

class MLLinear(nn.Module):
    """Stack of ReLU-activated ``GlorotLinear`` layers plus an output projection.

    Args:
        linear_size: Sequence of layer widths. Each consecutive pair
            ``(linear_size[i], linear_size[i + 1])`` becomes one hidden layer.
        output_size: Number of features produced by the final projection.
    """

    def __init__(self, linear_size, output_size):
        super().__init__()
        # Pair every width with its successor to get (in, out) per hidden layer.
        layer_dims = zip(linear_size[:-1], linear_size[1:])
        hidden_layers = [GlorotLinear(fan_in, fan_out)
                         for fan_in, fan_out in layer_dims]
        self.linear = nn.ModuleList(hidden_layers)

        self.output = GlorotLinear(linear_size[-1], output_size)

    def forward(self, inputs):
        """Run ``inputs`` through the hidden layers and the output projection.

        The trailing dimension of the result is squeezed away when it has
        size 1; otherwise the squeeze leaves the tensor unchanged.
        """
        hidden = inputs
        for layer in self.linear:
            hidden = torch.relu(layer(hidden))
        projected = self.output(hidden)
        return projected.squeeze(-1)

class MLAttention(nn.Module):
    """Per-label attention over the positions of a sequence.

    A bias-free linear layer scores every position once per label; the scores
    are softmax-normalised along the sequence axis and used to pool ``inputs``.

    Args:
        labels_num: Number of labels, i.e. number of attention heads.
        hidden_size: Size of the last dimension of ``inputs``.
    """

    def __init__(self, labels_num, hidden_size):
        super().__init__()
        self.attention = nn.Linear(hidden_size, labels_num, bias=False)

    def forward(self, inputs, masks):
        """Return the attention-pooled features and the detached weights.

        Args:
            inputs: Tensor laid out as N, L, hidden_size.
            masks: Tensor laid out as N, L; entries equal to 1 mark positions
                that may receive attention, everything else is excluded.

        Returns:
            A pair ``(pooled, weights)`` with ``pooled`` laid out as
            N, labels_num, hidden_size and ``weights`` (detached from the
            autograd graph) laid out as N, labels_num, L.
        """
        # True where a position must be ignored; broadcast over labels.
        padding = (masks != 1).unsqueeze(1)  # N, 1, L
        scores = self.attention(inputs).transpose(1, 2)  # N, labels_num, L
        # -inf scores become exactly zero weight after the softmax.
        # NOTE: a sequence with no position equal to 1 yields NaN weights.
        scores = scores.masked_fill(padding, float("-inf"))
        weights = scores.softmax(dim=-1)
        pooled = torch.matmul(weights, inputs)  # N, labels_num, hidden_size
        return pooled, weights.detach()

class AttentionClassifier(nn.Module):
    """Label-wise attention followed by a small MLP head.

    Args:
        labels_num: Number of labels; used both as the number of attention
            heads and as the output size of the MLP head.
        hidden_size: Size of the last dimension of the input features.
    """

    def __init__(self, labels_num, hidden_size):
        super(AttentionClassifier, self).__init__()
        self.attention = MLAttention(labels_num, hidden_size)
        # Fixed: the original passed the undefined name `label_num`, so
        # constructing this module raised NameError.
        # NOTE(review): with output_size=labels_num every label's pooled vector
        # is projected to labels_num values, giving N, labels_num, labels_num.
        # If a single score per label is intended, output_size should be 1 so
        # MLLinear's trailing squeeze applies -- confirm against the caller.
        self.mllinear = MLLinear([hidden_size, 256], labels_num)

    def forward(self, inputs, masks):
        """Pool ``inputs`` per label with attention and apply the MLP head.

        Args:
            inputs: Feature tensor passed to ``MLAttention`` together with
                ``masks``.
            masks: Mask tensor; entries equal to 1 mark usable positions.

        Returns:
            The output of the MLP head. The attention weights are discarded.
        """
        out, _ = self.attention(inputs, masks)
        return self.mllinear(out)