import torch
import torch.nn.functional as F
from torch import Tensor
from torch_geometric.typing import Adj
from torch_geometric.nn import GCNConv, SAGEConv, GATConv


class GCN(torch.nn.Module):
    """Stack of ``GCNConv`` layers followed by a linear classifier.

    Args:
        nlayer: Number of ``GCNConv`` layers. The first layer is always
            created, so any value below 1 still yields a single layer.
        nfeat: Input feature size of the first ``GCNConv`` layer.
        nhid: Output size of every ``GCNConv`` layer and input size of the
            classifier.
        ncls: Output size of the linear classifier.
        drop_rate: Dropout probability applied after each layer's ReLU
            while the module is in training mode.
    """

    def __init__(self, nlayer, nfeat, nhid, ncls, drop_rate: float = 0.5):
        super().__init__()
        self.drop_rate = drop_rate
        # Not read anywhere in this class; kept so that external code that
        # accesses the attribute keeps working.
        self.edge_weight = None
        self.graph_convs = torch.nn.ModuleList()
        self.graph_convs.append(GCNConv(nfeat, nhid))
        for _ in range(nlayer - 1):
            self.graph_convs.append(GCNConv(nhid, nhid))
        self.classifier = torch.nn.Linear(nhid, ncls)

    def forward(self, x: Tensor, edge_index: Adj):
        """Run every conv layer (conv -> ReLU -> dropout), then classify.

        Args:
            x: Input features (assumed 2-D with features along dim 1,
                since the softmaxes below reduce over dim 1 -- confirm
                against callers).
            edge_index: Graph connectivity, forwarded unchanged to every
                conv layer.

        Returns:
            A 3-tuple ``(logits, log_probs, probs)`` where ``logits`` is the
            classifier output and the other two are its ``log_softmax`` and
            ``softmax`` over dim 1.
        """
        for conv in self.graph_convs:
            x = conv(x, edge_index)
            x = F.relu(x)
            # Use the configured rate; omitting ``p`` would silently fall
            # back to F.dropout's default of 0.5 and ignore ``drop_rate``.
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        x = self.classifier(x)
        return x, F.log_softmax(x, dim=1), F.softmax(x, dim=1)


class SAGE(torch.nn.Module):
    """Stack of ``SAGEConv`` layers followed by a linear classifier.

    Args:
        nlayer: Number of ``SAGEConv`` layers. The first layer is always
            created, so any value below 1 still yields a single layer.
        nfeat: Input feature size of the first ``SAGEConv`` layer.
        nhid: Output size of every ``SAGEConv`` layer and input size of the
            classifier.
        ncls: Output size of the linear classifier.
        drop_rate: Dropout probability applied after each layer's ReLU
            while the module is in training mode.
    """

    def __init__(self, nlayer, nfeat, nhid, ncls, drop_rate: float = 0.5):
        super().__init__()
        self.drop_rate = drop_rate
        self.nlayer = nlayer
        self.graph_convs = torch.nn.ModuleList()
        self.graph_convs.append(SAGEConv(nfeat, nhid))
        for _ in range(nlayer - 1):
            self.graph_convs.append(SAGEConv(nhid, nhid))
        self.classifier = torch.nn.Linear(nhid, ncls)

    def forward(self, x: Tensor, edge_index: Adj):
        """Run every conv layer (conv -> ReLU -> dropout), then classify.

        Args:
            x: Input features (assumed 2-D with features along dim 1,
                since the softmaxes below reduce over dim 1 -- confirm
                against callers).
            edge_index: Graph connectivity, forwarded unchanged to every
                conv layer.

        Returns:
            A 3-tuple ``(logits, log_probs, probs)`` where ``logits`` is the
            classifier output and the other two are its ``log_softmax`` and
            ``softmax`` over dim 1.
        """
        for conv in self.graph_convs:
            x = conv(x, edge_index)
            x = F.relu(x)
            # Use the configured rate; omitting ``p`` would silently fall
            # back to F.dropout's default of 0.5 and ignore ``drop_rate``.
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        x = self.classifier(x)
        return x, F.log_softmax(x, dim=1), F.softmax(x, dim=1)

    
class GNN_Multi_Layer(torch.nn.Module):
    """Three stacked ``GCNConv`` layers with dropout on the input features.

    Args:
        in_channels: Input feature size of the first layer.
        hidden_channels: Output size of the first two layers.
        out_channels: Output size of the third (final) layer.
        drop_rate: Dropout probability applied to the input features while
            the module is in training mode. Defaults to 0.5, the value that
            was previously hard-coded.
    """

    def __init__(self, in_channels, hidden_channels, out_channels,
                 drop_rate: float = 0.5):
        super().__init__()
        self.drop_rate = drop_rate

        self.conv1 = GCNConv(in_channels, hidden_channels, cached=False, normalize=True)
        self.conv2 = GCNConv(hidden_channels, hidden_channels, cached=False, normalize=True)
        self.conv3 = GCNConv(hidden_channels, out_channels, cached=False, normalize=True)

    def forward(self, x, edge_index, edge_weight=None):
        """Apply input dropout, then the three conv layers.

        ReLU follows the first two layers only; the third layer's output is
        returned as the logits. Dropout is applied once, to the input, and
        not between layers.

        Args:
            x: Input features (assumed 2-D with features along dim 1,
                since the softmaxes below reduce over dim 1 -- confirm
                against callers).
            edge_index: Graph connectivity, forwarded to every conv layer.
            edge_weight: Optional edge weights, forwarded to every conv
                layer.

        Returns:
            A 3-tuple ``(logits, log_probs, probs)`` where ``logits`` is the
            output of the third layer and the other two are its
            ``log_softmax`` and ``softmax`` over dim 1.
        """
        x = F.dropout(x, p=self.drop_rate, training=self.training)
        x = self.conv1(x, edge_index, edge_weight).relu()
        x = self.conv2(x, edge_index, edge_weight).relu()
        x = self.conv3(x, edge_index, edge_weight)

        return x, F.log_softmax(x, dim=1), F.softmax(x, dim=1)

