import torch
import torch.nn as nn
import torch.nn.functional as F
# from layers import GraphAttentionLayer
from torch_geometric.nn import GATConv

class GAT(torch.nn.Module):
    """Two-layer graph attention network with a selective-prediction head.

    For every node the model outputs class log-probabilities together with a
    selection score in (0, 1) produced by ``self.sel`` (sigmoid output).

    Args:
        num_features: Dimensionality of the input node features.
        num_classes: Number of output classes.
        hid: Hidden channels per attention head in the first GATConv layer.
        in_head: Number of attention heads in the first layer. Head outputs
            are concatenated, so the second layer sees ``hid * in_head``
            input channels.
        out_head: Number of attention heads in the second layer. Head
            outputs are averaged (``concat=False``), giving ``num_classes``
            output channels.
        dropout: Dropout probability used consistently for the input
            features, between the two layers, and for the attention
            coefficients inside both GATConv layers.
        coverage: Stored on the module but not read by it. Presumably it is
            the target coverage consumed by a selective loss elsewhere;
            verify against the caller.
    """

    def __init__(self, num_features, num_classes, hid=8, in_head=8, out_head=4, dropout=0.6, coverage=0.8):
        super().__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.hid = hid
        self.in_head = in_head
        self.out_head = out_head
        self.dropout = dropout
        self.coverage = coverage

        # Previously hard-coded to 0.6; use the configured rate so the
        # `dropout` argument is honored by both attention layers.
        self.conv1 = GATConv(self.num_features, self.hid, heads=self.in_head, dropout=self.dropout)
        self.conv2 = GATConv(self.hid * self.in_head, self.num_classes, concat=False,
                             heads=self.out_head, dropout=self.dropout)
        # Selection head: maps the num_classes-dim logits to one score per
        # node, squashed into (0, 1) by the final Sigmoid.
        self.sel = nn.Sequential(
            nn.Linear(num_classes, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 1),
            nn.Sigmoid(),
        )

    def forward(self, x, edge_index):
        """Run the GAT and the selection head.

        Args:
            x: Node feature matrix, presumably ``(num_nodes, num_features)``
                as expected by ``GATConv``.
            edge_index: Graph connectivity as accepted by ``GATConv``
                (typically a ``(2, num_edges)`` COO index tensor).

        Returns:
            A list ``[pred_sel, aux]``:
            - ``pred_sel``: class log-probabilities with the selection score
              appended as the last column, i.e. ``num_classes + 1`` columns.
            - ``aux``: class log-probabilities. In this implementation it is
              identical to the first ``num_classes`` columns of ``pred_sel``.
        """
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)

        pred = F.log_softmax(x, dim=1)
        # The selection head consumes the raw second-layer outputs, not the
        # log-probabilities.
        sel = self.sel(x)
        # The auxiliary output was computed with exactly the same expression
        # as `pred`; reuse it instead of recomputing log_softmax.
        aux = pred
        pred_sel = torch.cat((pred, sel), dim=1)
        return [pred_sel, aux]