import torch
import torch.nn as nn
import torch.nn.functional as F


class SingleOpAttnBlock(nn.Module):
    def __init__(self, input_dim, output_dim, dropout_prob):
        """
            Single-head Operation Message Attention Block (parameters only).
        :param input_dim: size of each incoming feature vector
        :param output_dim: size of each projected feature vector
        :param dropout_prob: drop probability handed to nn.Dropout(p=...)
        """
        super().__init__()
        self.out_features = output_dim
        self.alpha = 0.2  # negative slope of the LeakyReLU below

        def xavier_param(*shape):
            # Allocate a trainable tensor and fill it with Xavier-uniform
            # values (gain 1.414, roughly sqrt(2)).
            param = nn.Parameter(torch.empty(*shape))
            nn.init.xavier_uniform_(param, gain=1.414)
            return param

        # W is created (and consumes RNG) before a; this fixes both the
        # initial values under a given seed and the state_dict key order.
        self.W = xavier_param(input_dim, output_dim)
        self.a = xavier_param(2 * output_dim, 1)

        self.leaky_relu = nn.LeakyReLU(self.alpha)
        self.dropout = nn.Dropout(p=dropout_prob)

class MultiHeadOpAttnBlock(nn.Module):
    def __init__(self, input_dim, output_dim, dropout_prob, num_heads, activation, concat=True):
        """
            Multi-head wrapper around SingleOpAttnBlock (parameters only).
        :param input_dim: size of each incoming feature vector
        :param output_dim: output size produced by every individual head
        :param dropout_prob: drop probability handed to nn.Dropout(p=...)
        :param num_heads: how many independent attention heads to build
        :param activation: activation applied before the block's output
        :param concat: head aggregation flag; True = concatenate, False = average
        """
        super().__init__()
        # dropout (and activation, if it is an nn.Module) are assigned before
        # the heads so the submodule registration order matches earlier versions.
        self.dropout = nn.Dropout(p=dropout_prob)
        self.num_heads = num_heads
        self.concat = concat
        self.activation = activation

        # The heads are kept in a plain list for iteration. Each one is also
        # registered as "attention_<k>" so nn.Module tracks its parameters
        # under the same state_dict keys as before.
        self.attentions = []
        for head_idx in range(num_heads):
            head = SingleOpAttnBlock(input_dim, output_dim, dropout_prob)
            self.attentions.append(head)
            self.add_module(f"attention_{head_idx}", head)

class SingleMchAttnBlock(nn.Module):
    def __init__(self, node_input_dim, edge_input_dim, output_dim, dropout_prob):
        """
            Single-head Machine Message Attention Block that mixes node and edge features.
        :param node_input_dim: size of each incoming node feature vector
        :param edge_input_dim: size of each incoming edge feature vector
        :param output_dim: size of each projected feature vector
        :param dropout_prob: drop probability handed to nn.Dropout(p=...)
        """
        super().__init__()
        self.out_features = output_dim
        self.alpha = 0.2  # negative slope of the LeakyReLU below

        def xavier_param(*shape):
            # Allocate a trainable tensor and fill it with Xavier-uniform
            # values (gain 1.414, roughly sqrt(2)).
            param = nn.Parameter(torch.empty(*shape))
            nn.init.xavier_uniform_(param, gain=1.414)
            return param

        # Creation order W -> W_edge -> a fixes both RNG consumption and
        # the state_dict key order.
        self.W = xavier_param(node_input_dim, output_dim)
        self.W_edge = xavier_param(edge_input_dim, output_dim)
        # a holds three stacked output_dim-sized slices: row-node term,
        # column-node term and edge term (see get_attention_coef).
        self.a = xavier_param(3 * output_dim, 1)

        self.leaky_relu = nn.LeakyReLU(self.alpha)
        self.dropout = nn.Dropout(p=dropout_prob)

    def get_attention_coef(self, Wh, W_edge):
        """
            Score every (i, j) node pair from the two node features and the edge feature.
        :param Wh: projected node features, expected [sz_b, M, out_features]
        :param W_edge: projected edge features, expected [sz_b, M, M, out_features]
        :return: LeakyReLU-activated scores of shape [sz_b, M, M]. Entry [b, i, j]
                 is a_row . Wh[b, i] + a_col . Wh[b, j] + a_edge . W_edge[b, i, j].
                 No normalisation or dropout is applied here.
        """
        d = self.out_features
        a_row = self.a[:d, :]
        a_col = self.a[d:2 * d, :]
        a_edge = self.a[2 * d:, :]

        row_term = Wh @ a_row                         # [sz_b, M, 1]
        col_term = (Wh @ a_col).transpose(-1, -2)     # [sz_b, 1, M]
        edge_term = (W_edge @ a_edge).squeeze(-1)     # [sz_b, M, M]

        # Broadcasting expands row/col terms across the M x M grid.
        scores = row_term + col_term + edge_term
        return self.leaky_relu(scores)


class MultiHeadMchAttnBlock(nn.Module):
    def __init__(self, node_input_dim, edge_input_dim, output_dim, dropout_prob, num_heads, activation, concat=True):
        """
            Multi-head wrapper around SingleMchAttnBlock (parameters only).
        :param node_input_dim: size of each incoming node feature vector
        :param edge_input_dim: size of each incoming edge feature vector
        :param output_dim: output size produced by every individual head
        :param dropout_prob: drop probability handed to nn.Dropout(p=...)
        :param num_heads: how many independent attention heads to build
        :param activation: activation applied before the block's output
        :param concat: head aggregation flag; True = concatenate, False = average
        """
        super().__init__()
        # dropout (and activation, if it is an nn.Module) are assigned before
        # the heads so the submodule registration order matches earlier versions.
        self.dropout = nn.Dropout(p=dropout_prob)
        self.concat = concat
        self.activation = activation
        self.num_heads = num_heads

        # The heads are kept in a plain list for iteration. Each one is also
        # registered as "attention_<k>" so nn.Module tracks its parameters
        # under the same state_dict keys as before.
        self.attentions = []
        for head_idx in range(num_heads):
            head = SingleMchAttnBlock(node_input_dim, edge_input_dim, output_dim, dropout_prob)
            self.attentions.append(head)
            self.add_module(f"attention_{head_idx}", head)
