import torch
import torch.nn as nn
import torch.nn.functional as F
# from layers import GATConv
from torch_geometric.nn import GATConv

class GAT(torch.nn.Module):
    """Two-layer graph attention network built from two ``GATConv`` layers.

    The first layer maps ``num_features`` inputs to ``hid`` channels on each
    of ``in_head`` attention heads; its output width is ``hid * in_head``,
    which is what the second layer takes as input. The second layer maps to
    ``num_classes`` channels using ``out_head`` heads with ``concat=False``
    (per the PyG ``GATConv`` docs, the heads are then averaged rather than
    concatenated).

    Args:
        num_features: Size of each input feature vector.
        num_classes: Number of output classes.
        hid: Output channels per head of the first layer.
        in_head: Number of attention heads in the first layer.
        out_head: Number of attention heads in the second layer.
        dropout: Dropout probability. It is used consistently for the
            feature dropout applied in ``forward`` and is passed as the
            ``dropout`` argument of both ``GATConv`` layers.
    """

    def __init__(
        self,
        num_features: int,
        num_classes: int,
        hid: int = 8,
        in_head: int = 8,
        out_head: int = 4,
        dropout: float = 0.6,
    ) -> None:
        super().__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.hid = hid
        self.in_head = in_head
        self.out_head = out_head
        self.dropout = dropout

        # Both layers take the configured dropout rate; previously the first
        # layer silently ignored ``dropout`` and always used 0.6.
        self.conv1 = GATConv(self.num_features, self.hid, heads=self.in_head,
                             dropout=self.dropout)
        # The first layer's heads are concatenated, hence hid * in_head inputs.
        self.conv2 = GATConv(self.hid * self.in_head, self.num_classes,
                             concat=False, heads=self.out_head,
                             dropout=self.dropout)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Return log-probabilities (``log_softmax`` over ``dim=1``).

        Args:
            x: Input features, passed to the first ``GATConv`` layer.
                Presumably shaped ``(num_nodes, num_features)`` -- confirm
                against the caller.
            edge_index: Graph connectivity, forwarded unchanged to both
                ``GATConv`` layers.

        Returns:
            The second layer's output after ``log_softmax`` along ``dim=1``.
        """
        # Feature dropout is only active in training mode (self.training).
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)