import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd.variable import Variable
from parser_1 import _parser
from models.layers import GraphAttentionNetwork

# NOTE(review): `_parser()` runs at import time, so importing this module
# presumably parses command-line arguments as a side effect. `args` is not
# referenced anywhere else in this file — confirm whether it is needed here.
args = _parser()

class GAT(nn.Module):
    """Multi-head graph attention network followed by max pooling.

    ``nheads`` attention heads run in parallel on the same input and their
    outputs are concatenated along dimension 2. A single output attention layer
    maps that concatenation to ``nclass`` features. ELU is then applied, and the
    result is max-pooled over dimension 1.
    """

    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
        """Build the attention heads and the output attention layer.

        Args:
            nfeat: First size argument given to every head (presumably the
                per-node input feature size — confirm in models.layers).
            nhid: Second size argument of every head. The output layer is
                built with ``nhid * nheads`` as its first size argument, to
                match the concatenated head outputs.
            nclass: Second size argument of the output attention layer.
            dropout: Dropout probability. It is used in ``forward`` and also
                passed on to every ``GraphAttentionNetwork`` layer.
            alpha: Passed through to ``GraphAttentionNetwork`` (presumably a
                LeakyReLU negative slope — confirm in models.layers).
            nheads: Number of parallel attention heads.
        """
        super().__init__()
        self.dropout = dropout

        # The heads are deliberately kept in a plain list and registered with
        # add_module(), not wrapped in nn.ModuleList. This keeps parameter
        # names as ``attention_{i}.*`` so existing state_dicts / checkpoints
        # still load.
        self.attentions = [
            GraphAttentionNetwork(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True)
            for _ in range(nheads)
        ]
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)

        self.out_att = GraphAttentionNetwork(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)

    def forward(self, data):
        """Apply the heads and the output layer, then max-pool over dim 1.

        Args:
            data: Indexable container. ``data[0]`` is fed to the heads as the
                features ``x``. ``data[1]`` is the adjacency passed to every
                attention layer. Head outputs are concatenated on dim 2 and
                pooled over dim 1, so ``x`` is presumably shaped
                (batch, nodes, features) — TODO confirm.

        Returns:
            The maximum over dimension 1 of the ELU-activated output-layer
            result, with every size-1 dimension squeezed away. The result stays
            connected to the autograd graph, so gradients reach all attention
            parameters.
        """
        # NOTE(review): data[2] (a node mask in callers) is not applied. Any
        # padded nodes therefore take part in the max pooling below — confirm
        # this is intended.
        x, adj = data[0], data[1]
        x = F.dropout(x, self.dropout, training=self.training)
        x = torch.cat([att(x, adj) for att in self.attentions], dim=2)
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.elu(self.out_att(x, adj))
        # Max pooling over nodes (usually performs better than averaging).
        # NOTE(review): squeeze() also drops the batch dimension when it has
        # size 1, and the class dimension when nclass == 1. It is kept so the
        # return shape seen by existing callers does not change.
        x = torch.max(x, dim=1)[0].squeeze()
        # Return the graph-connected tensor as-is. Re-wrapping it, e.g. with
        # Variable(x, requires_grad=True), would detach it into a new leaf and
        # block every gradient to the model's parameters.
        return x
