import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import APPNP, GCNConv, SAGEConv, GATv2Conv

class APPNP_Model(torch.nn.Module):
    """Two-layer MLP whose outputs are propagated over the graph with APPNP.

    The MLP maps node features to per-class scores. ``APPNP`` then smooths those
    scores over ``edge_index``, and the result is returned as log-probabilities.

    Args:
        input_dim: Size of the last dimension of the input features.
        out_dim: Number of output classes (width of the returned scores).
        filter_num: Hidden width of the MLP.
        K: Number of propagation iterations, forwarded to ``APPNP``.
        alpha: Teleport probability, forwarded to ``APPNP``.
        dropout: Dropout probability applied to the hidden MLP activations
            while the module is in training mode.
    """

    def __init__(self, input_dim, out_dim, filter_num, K=10, alpha=0.1, dropout=0.5):
        super().__init__()
        self.dropout = dropout
        self.line1 = nn.Linear(input_dim, filter_num)
        self.line2 = nn.Linear(filter_num, out_dim)
        self.conv = APPNP(K=K, alpha=alpha)

    def reset_parameters(self):
        """Re-initialise the two linear layers in place."""
        self.line1.reset_parameters()
        self.line2.reset_parameters()

    def forward(self, x, edge_index):
        """Return log-probabilities over the last dimension for each input row.

        Args:
            x: Feature tensor whose last dimension is ``input_dim``.
            edge_index: Graph connectivity passed straight to ``APPNP``.
        """
        x = F.relu(self.line1(x))
        x = F.dropout(x, self.dropout, training=self.training)
        # No ReLU on line2: its outputs are class logits and must be allowed
        # to go negative. Clamping them at zero before propagation and
        # log_softmax discards information and can collapse predictions to a
        # uniform distribution.
        x = self.line2(x)
        x = self.conv(x, edge_index)
        return F.log_softmax(x, dim=-1)

class GCN_Model(torch.nn.Module):
    """Two stacked ``GCNConv`` layers producing per-node log-probabilities.

    Args:
        input_dim: Size of the last dimension of the node features.
        out_dim: Number of output classes.
        filter_num: Hidden channel count between the two convolutions.
        dropout: Dropout probability applied between the layers during training.
    """

    def __init__(self, input_dim, out_dim, filter_num, dropout=0.5):
        super().__init__()
        self.dropout = dropout
        # Attribute names (line1/line2) are part of the state_dict keys; keep them.
        self.line1 = GCNConv(input_dim, filter_num)
        self.line2 = GCNConv(filter_num, out_dim)

    def reset_parameters(self):
        """Re-initialise both graph convolutions, first layer first."""
        for layer in (self.line1, self.line2):
            layer.reset_parameters()

    def forward(self, x, edge_index):
        """Run both convolutions and return log-softmax over the last dimension."""
        hidden = self.line1(x, edge_index).relu()
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        logits = self.line2(hidden, edge_index)
        return logits.log_softmax(dim=-1)

class SAGE_Model(torch.nn.Module):
    """GraphSAGE classifier: two ``SAGEConv`` layers followed by log-softmax.

    Args:
        input_dim: Size of the last dimension of the node features.
        out_dim: Number of output classes.
        filter_num: Hidden channel count between the two SAGE layers.
        dropout: Dropout probability used on the hidden representation when training.
    """

    def __init__(self, input_dim, out_dim, filter_num, dropout=0.5):
        super().__init__()
        self.dropout = dropout
        # Construction order matters for random initialisation; keep line1 first.
        self.line1 = SAGEConv(input_dim, filter_num)
        self.line2 = SAGEConv(filter_num, out_dim)

    def reset_parameters(self):
        """Reset the learnable weights of both SAGE layers."""
        for conv in [self.line1, self.line2]:
            conv.reset_parameters()

    def forward(self, x, edge_index):
        """Return per-row log-probabilities computed from ``x`` and ``edge_index``."""
        h = torch.relu(self.line1(x, edge_index))
        h = F.dropout(h, p=self.dropout, training=self.training)
        return F.log_softmax(self.line2(h, edge_index), dim=-1)

class GAT_Model(torch.nn.Module):
    """Two-layer ``GATv2Conv`` classifier returning log-probabilities.

    The first layer uses ``heads`` attention heads whose outputs are concatenated
    (hence the ``filter_num * heads`` input width of the second layer). The
    second layer uses a single head with ``concat=False``.

    Args:
        input_dim: Size of the last dimension of the node features.
        out_dim: Number of output classes.
        filter_num: Per-head hidden width of the first attention layer.
        att_drop: Dropout passed to both ``GATv2Conv`` layers.
        dropout: Dropout probability on the hidden features between layers (training only).
        heads: Number of attention heads in the first layer.
    """

    def __init__(self, input_dim, out_dim, filter_num=8, att_drop=0.6, dropout=0.6, heads=8):
        super().__init__()
        self.dropout = dropout
        concat_width = filter_num * heads
        self.line1 = GATv2Conv(input_dim, filter_num, heads=heads, dropout=att_drop)
        self.line2 = GATv2Conv(concat_width, out_dim, heads=1, dropout=att_drop, concat=False)

    def reset_parameters(self):
        """Re-initialise both attention layers in place."""
        for attn in (self.line1, self.line2):
            attn.reset_parameters()

    def forward(self, x, edge_index):
        """Apply both attention layers and return log-softmax over the last dimension."""
        hidden = torch.relu(self.line1(x, edge_index))
        if self.training:
            hidden = F.dropout(hidden, p=self.dropout, training=True)
        else:
            hidden = F.dropout(hidden, p=self.dropout, training=False)
        out = self.line2(hidden, edge_index)
        return F.log_softmax(out, dim=-1)