import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, GINConv, SAGEConv, ResGatedGraphConv
from torch_geometric.nn import global_mean_pool, global_add_pool, global_max_pool
import torch.nn.functional as F


class GatedGCN(nn.Module):
    """Graph-level model built around one residual gated graph convolution.

    Node features are linearly embedded to ``nhid`` dimensions, passed
    ``nlayers`` times through a single ``ResGatedGraphConv`` (the same layer
    object is reused on every pass, so its weights are shared), mapped to
    ``nclass`` outputs per node, and finally max-pooled with ``data.batch``.
    """

    def __init__(self, nfeat, nhid, nclass, nlayers):
        """Build the embedding, the shared convolution and the read-out layer.

        Args:
            nfeat: Input size of the embedding layer applied to ``data.x``.
            nhid: Hidden size used by the embedding and the convolution.
            nclass: Output size of the final linear layer.
            nlayers: Number of times the convolution is applied in ``forward``.
        """
        super().__init__()
        self.nlayers = nlayers
        self.emb = nn.Linear(nfeat, nhid)
        self.conv = ResGatedGraphConv(nhid, nhid)
        self.out = nn.Linear(nhid, nclass)

    def forward(self, data):
        """Run the model on ``data`` and return the pooled outputs.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``
                (presumably a torch_geometric ``Data``/``Batch`` -- confirm
                against callers). ``x`` is cast to float before use.

        Returns:
            The max-pooled read-out with its last dimension squeezed away
            when that dimension has size 1.
        """
        node_feats = data.x.float()
        edges = data.edge_index
        hidden = self.emb(node_feats)

        # Every pass goes through the same convolution instance.
        for _ in range(self.nlayers):
            message = self.conv(hidden, edges)
            hidden = torch.tanh(message)

        per_node = self.out(hidden)
        pooled = global_max_pool(per_node, data.batch)
        return torch.squeeze(pooled, -1)

class GraphSAGE(nn.Module):
    """Graph-level model that repeatedly applies one ``SAGEConv`` layer.

    Pipeline: linear embedding of ``data.x`` -> ``nlayers`` passes through a
    single, weight-shared ``SAGEConv`` with ``tanh`` after each pass -> linear
    read-out to ``nclass`` values per node -> max pooling with ``data.batch``.
    """

    def __init__(self, nfeat, nhid, nclass, nlayers):
        """Create the three sub-modules and remember the pass count.

        Args:
            nfeat: Input size of the embedding layer applied to ``data.x``.
            nhid: Hidden size shared by the embedding and the convolution.
            nclass: Output size of the final linear layer.
            nlayers: Number of times the convolution is applied in ``forward``.
        """
        super().__init__()
        self.nlayers = nlayers
        self.emb = nn.Linear(nfeat, nhid)
        self.conv = SAGEConv(nhid, nhid)
        self.out = nn.Linear(nhid, nclass)

    def forward(self, data):
        """Compute the pooled prediction for ``data``.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``
                (presumably a torch_geometric ``Data``/``Batch`` -- confirm
                against callers). ``x`` is cast to float before use.

        Returns:
            The max-pooled read-out; a trailing dimension of size 1 is removed.
        """
        node_feats = data.x.float()
        edges = data.edge_index
        state = self.emb(node_feats)

        # The identical SAGEConv module is reused for each of the passes.
        for _ in range(self.nlayers):
            aggregated = self.conv(state, edges)
            state = torch.tanh(aggregated)

        scores = self.out(state)
        graph_scores = global_max_pool(scores, data.batch)
        return torch.squeeze(graph_scores, -1)

class GCN(nn.Module):
    """Graph-level model that iterates a single ``GCNConv`` layer.

    ``data.x`` is embedded with a linear layer, refined by applying the same
    ``GCNConv`` instance ``nlayers`` times (``tanh`` after each application),
    projected to ``nclass`` outputs per node and max-pooled with
    ``data.batch``.
    """

    def __init__(self, nfeat, nhid, nclass, nlayers):
        """Set up embedding, shared convolution and output projection.

        Args:
            nfeat: Input size of the embedding layer applied to ``data.x``.
            nhid: Hidden size used by the embedding and the convolution.
            nclass: Output size of the final linear layer.
            nlayers: Number of times the convolution is applied in ``forward``.
        """
        super().__init__()
        self.nlayers = nlayers
        self.emb = nn.Linear(nfeat, nhid)
        self.conv = GCNConv(nhid, nhid)
        self.out = nn.Linear(nhid, nclass)

    def forward(self, data):
        """Evaluate the network on ``data``.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``
                (presumably a torch_geometric ``Data``/``Batch`` -- confirm
                against callers). ``x`` is cast to float before use.

        Returns:
            The max-pooled read-out, squeezed along the last dimension when
            that dimension has size 1.
        """
        node_feats = data.x.float()
        edges = data.edge_index
        hidden = self.emb(node_feats)

        # Weight sharing: one GCNConv object serves all passes.
        for _ in range(self.nlayers):
            propagated = self.conv(hidden, edges)
            hidden = torch.tanh(propagated)

        node_out = self.out(hidden)
        graph_out = global_max_pool(node_out, data.batch)
        return torch.squeeze(graph_out, -1)

class GIN(nn.Module):
    """Graph-level model that iterates a single ``GINConv`` layer.

    The convolution wraps one ``nn.Linear(nhid, nhid)``. ``data.x`` is
    embedded to ``nhid`` dimensions, sent ``nlayers`` times through that same
    ``GINConv`` (with ``tanh`` after each pass), projected to ``nclass``
    outputs per node and max-pooled with ``data.batch``.
    """

    def __init__(self, nfeat, nhid, nclass, nlayers):
        """Instantiate the embedding, the shared GIN layer and the read-out.

        Args:
            nfeat: Input size of the embedding layer applied to ``data.x``.
            nhid: Hidden size used by the embedding and inside the convolution.
            nclass: Output size of the final linear layer.
            nlayers: Number of times the convolution is applied in ``forward``.
        """
        super().__init__()
        self.nlayers = nlayers
        self.emb = nn.Linear(nfeat, nhid)
        # Linear map handed to GINConv as its internal network.
        update_fn = nn.Linear(nhid, nhid)
        self.conv = GINConv(update_fn)
        self.out = nn.Linear(nhid, nclass)

    def forward(self, data):
        """Produce the pooled output for ``data``.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``
                (presumably a torch_geometric ``Data``/``Batch`` -- confirm
                against callers). ``x`` is cast to float before use.

        Returns:
            The max-pooled read-out with a size-1 last dimension removed.
        """
        node_feats = data.x.float()
        edges = data.edge_index
        hidden = self.emb(node_feats)

        # All passes share the one GINConv module created in __init__.
        for _ in range(self.nlayers):
            updated = self.conv(hidden, edges)
            hidden = torch.tanh(updated)

        logits = self.out(hidden)
        pooled = global_max_pool(logits, data.batch)
        return torch.squeeze(pooled, -1)
