import torch
import torch.nn as nn
import torch.nn.functional as F
from layers import GraphAttentionLayer


class GAT(nn.Module):
    """Dense GAT with an additional sigmoid selection head.

    Two attention stages: ``nheads`` first-layer ``GraphAttentionLayer`` heads
    whose outputs are concatenated along dim 1, followed by a single output
    attention layer producing ``nclass`` scores per row. A small MLP
    (``self.sel``) maps those (ELU-activated) scores to one value in (0, 1)
    per row.

    NOTE(review): ``coverage`` is stored on the instance but never read inside
    this class; presumably it is consumed by the training loss — confirm
    against callers.
    """

    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads, coverage, sel_hidden=512):
        """Build the attention layers and the selection head.

        Args:
            nfeat: Input feature size given to each first-layer attention head.
            nhid: Output size of each first-layer head; the concatenated hidden
                representation therefore has ``nhid * nheads`` columns.
            nclass: Number of output classes (output size of ``out_att`` and
                input size of the selection head).
            dropout: Dropout probability. Applied in ``forward`` before each
                attention stage and also forwarded to every
                ``GraphAttentionLayer``.
            alpha: Forwarded to every ``GraphAttentionLayer`` (presumably the
                LeakyReLU negative slope — confirm in ``layers``).
            nheads: Number of first-layer attention heads.
            coverage: Stored as ``self.coverage``; not used inside this class.
            sel_hidden: Width of the hidden layer of the selection head.
                Defaults to 512, the value that used to be hard-coded, so
                existing checkpoints still load.
        """
        super().__init__()
        self.dropout = dropout
        self.coverage = coverage

        self.attentions = [
            GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True)
            for _ in range(nheads)
        ]
        # A plain Python list is not tracked by nn.Module, so each head is
        # registered explicitly under "attention_{i}" to expose its parameters
        # to .parameters() / .to() / state_dict(). Converting to nn.ModuleList
        # would rename the keys to "attentions.{i}.*" and break existing
        # checkpoints, so the original naming is preserved.
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)

        # Input width is nhid * nheads because forward() concatenates the
        # first-layer head outputs along dim 1.
        self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)
        # Selection head: nclass -> sel_hidden -> 1, squashed into (0, 1).
        self.sel = nn.Sequential(
            nn.Linear(nclass, sel_hidden),
            nn.BatchNorm1d(sel_hidden),
            nn.ReLU(),
            nn.Linear(sel_hidden, 1),
            nn.Sigmoid(),
        )

    def forward(self, x, adj):
        """Run both attention stages and the selection head.

        Args:
            x: Node feature matrix (assumed ``(N, nfeat)`` — TODO confirm).
            adj: Adjacency structure, passed unchanged to every attention layer.

        Returns:
            A tuple ``(emb, [pred_sel, aux])`` where

            * ``emb`` is the dim-1 concatenation of the first-layer head
              outputs, captured before the second dropout;
            * ``pred_sel`` is ``log_softmax`` (over dim 1) of the ELU-activated
              output scores with the selection-head value appended as one
              extra column along dim 1;
            * ``aux`` is the same ``log_softmax`` tensor as the class part of
              ``pred_sel``.
        """
        x = F.dropout(x, self.dropout, training=self.training)
        x = torch.cat([att(x, adj) for att in self.attentions], dim=1)
        # F.dropout is not in-place here, so emb is unaffected by the next line.
        emb = x
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.elu(self.out_att(x, adj))

        pred = F.log_softmax(x, dim=1)
        sel = self.sel(x)
        pred_sel = torch.cat((pred, sel), dim=1)
        # Previously a second, identical F.log_softmax(x, dim=1) was computed
        # here; reuse the existing result instead of recomputing it.
        # NOTE(review): if a separate auxiliary classifier head was intended
        # (aux distinct from pred), it is missing — confirm with the loss code.
        aux = pred
        return emb, [pred_sel, aux]