import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv
from .newConv import newGATConv

class GAT(nn.Module):
    """Two-layer graph attention network: GATConv -> ELU -> GATConv.

    The first layer runs ``heads`` attention heads and concatenates them
    (width ``hidden_features * heads``). The second layer uses one head with
    ``concat=False`` and produces ``num_classes`` values per node. While the
    module is in training mode, dropout with probability ``dropout`` is
    applied to the input of each layer.
    """

    def __init__(self, num_features, hidden_features, heads, num_classes, dropout=0.0):
        super().__init__()
        self.dropout = dropout
        concat_width = hidden_features * heads  # layer 1 concatenates its heads
        self.gc1 = GATConv(num_features, hidden_features, heads)
        self.gc2 = GATConv(concat_width, num_classes, heads=1, concat=False)

    def forward(self, x, edge_index, edge_weight=None):
        """Return the second layer's output; no final activation is applied.

        ``edge_weight`` is passed to each GATConv as its third positional
        argument. NOTE(review): PyG's GATConv only uses that argument in the
        attention computation when built with ``edge_dim``, which is not set
        here. Confirm this is intended.
        """
        dropped_in = F.dropout(x, self.dropout, training=self.training)
        hidden = F.elu(self.gc1(dropped_in, edge_index, edge_weight))
        hidden = F.dropout(hidden, self.dropout, training=self.training)
        return self.gc2(hidden, edge_index, edge_weight)

class GATL(nn.Module):
    """Three-layer GAT with a bias-free linear skip connection on every layer.

    Each layer computes ``conv(x) + lin(x)``. The linear map projects the
    layer input to the conv's output width so the two terms can be summed.
    Layers 1 and 2 concatenate ``heads`` heads (width
    ``hidden_features * heads``) and apply ELU. Layer 3 uses
    ``int(heads * 1.5)`` heads with ``concat=False`` to produce
    ``num_classes`` values per node, with no activation.

    Args:
        num_features: input feature size per node.
        hidden_features: per-head output size of layers 1 and 2.
        heads: number of attention heads in layers 1 and 2.
        num_classes: output size per node.
        dropout: dropout probability applied to each layer's input in
            training mode, matching ``GAT``. The default of 0.0 makes
            ``F.dropout`` a no-op, so existing callers see unchanged behavior.
    """

    def __init__(self, num_features, hidden_features, heads, num_classes, dropout=0.0):
        super().__init__()
        self.dropout = dropout
        width = hidden_features * heads  # output width of the concatenating layers
        self.gc1 = GATConv(num_features, hidden_features, heads)
        self.ln1 = nn.Linear(num_features, width, bias=False)
        self.gc2 = GATConv(width, hidden_features, heads)
        self.ln2 = nn.Linear(width, width, bias=False)
        self.gc3 = GATConv(width, num_classes, heads=int(heads * 1.5), concat=False)
        self.ln3 = nn.Linear(width, num_classes, bias=False)

    def forward(self, x, edge_index, edge_weight=None):
        """Return the last layer's output (conv + skip); no final activation.

        NOTE(review): ``edge_weight`` is forwarded as GATConv's third argument.
        PyG only uses it when the conv is built with ``edge_dim``, which is
        not set here. Confirm this is intended.
        """
        # The skip term reuses the same (dropped-out) layer input as the conv.
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.elu(self.gc1(x, edge_index, edge_weight) + self.ln1(x))
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.elu(self.gc2(x, edge_index, edge_weight) + self.ln2(x))
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.gc3(x, edge_index, edge_weight) + self.ln3(x)
        return x


class newGAT(nn.Module):
    """Two-layer attention network built from ``newGATConv`` layers.

    Same layout as ``GAT``: layer 1 uses ``heads`` heads (its output width is
    taken to be ``hidden_features * heads``), ELU follows, and layer 2 uses a
    single head with ``concat=False`` to produce ``num_classes`` values. In
    training mode, dropout with probability ``dropout`` is applied to the
    input of each layer.
    """

    def __init__(self, num_features, hidden_features, heads, num_classes, dropout=0.0):
        super().__init__()
        self.dropout = dropout
        concat_width = hidden_features * heads
        self.gc1 = newGATConv(num_features, hidden_features, heads)
        self.gc2 = newGATConv(concat_width, num_classes, heads=1, concat=False)

    def forward(self, x, edge_index, inf_edge_index=None, edge_weight=None, inf_edge_weight=None):
        """Return the second layer's output; no final activation is applied.

        Both layers receive the graph arguments in the order
        ``(inf_edge_index, edge_index, inf_edge_weight, edge_weight)``. Note
        that this differs from this method's own parameter order.
        NOTE(review): confirm this matches newGATConv's signature, and whether
        it accepts ``inf_edge_index=None``.
        """
        graph_args = (inf_edge_index, edge_index, inf_edge_weight, edge_weight)
        hidden = F.dropout(x, self.dropout, training=self.training)
        hidden = F.elu(self.gc1(hidden, *graph_args))
        hidden = F.dropout(hidden, self.dropout, training=self.training)
        return self.gc2(hidden, *graph_args)

class newGATL(nn.Module):
    """Three-layer ``newGATConv`` network with bias-free linear skip connections.

    Mirrors ``GATL``: each layer computes ``conv(x) + lin(x)``. Layers 1 and 2
    use ``heads`` heads and ELU. Layer 3 uses ``int(heads * 1.5)`` heads with
    ``concat=False`` to produce ``num_classes`` values per node, with no
    activation.

    Args:
        num_features: input feature size per node.
        hidden_features: per-head output size of layers 1 and 2.
        heads: number of attention heads in layers 1 and 2.
        num_classes: output size per node.
        dropout: dropout probability applied to each layer's input in
            training mode, matching ``newGAT``. The default of 0.0 makes
            ``F.dropout`` a no-op, so existing callers see unchanged behavior.
    """

    def __init__(self, num_features, hidden_features, heads, num_classes, dropout=0.0):
        super().__init__()
        self.dropout = dropout
        width = hidden_features * heads  # assumed output width of the concatenating layers
        self.gc1 = newGATConv(num_features, hidden_features, heads)
        self.ln1 = nn.Linear(num_features, width, bias=False)
        self.gc2 = newGATConv(width, hidden_features, heads)
        self.ln2 = nn.Linear(width, width, bias=False)
        self.gc3 = newGATConv(width, num_classes, heads=int(heads * 1.5), concat=False)
        self.ln3 = nn.Linear(width, num_classes, bias=False)

    def forward(self, x, edge_index, inf_edge_index=None, edge_weight=None, inf_edge_weight=None):
        """Return the last layer's output (conv + skip); no final activation.

        Every layer receives the graph arguments in the order
        ``(inf_edge_index, edge_index, inf_edge_weight, edge_weight)``.
        NOTE(review): confirm this matches newGATConv's signature, and whether
        it accepts ``inf_edge_index=None``.
        """
        graph_args = (inf_edge_index, edge_index, inf_edge_weight, edge_weight)
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.elu(self.gc1(x, *graph_args) + self.ln1(x))
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.elu(self.gc2(x, *graph_args) + self.ln2(x))
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.gc3(x, *graph_args) + self.ln3(x)
        return x