import torch.nn as nn
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, SAGEConv, GINConv
from torch.nn import Linear
import numpy as np
from torch.nn import Sequential



class GCN_LP(torch.nn.Module):
    """Stack of ``GCNConv`` layers mapping node features to node embeddings.

    The ``_LP`` suffix suggests the embeddings feed a link-prediction head
    elsewhere — NOTE(review): confirm against the caller.

    Hyper-parameters read from ``args``:
        activation_fn: one of ``'relu'``, ``'leaky_relu'``, ``'tanh'``,
            ``'sigmoid'``; applied after every layer, including the last.
        K: number of ``GCNConv`` layers (values below 2 give one layer
            mapping ``in_channels`` straight to ``out_channels``).
        dropout: dropout probability applied between layers while training.
    """

    # Supported activation names -> callables. torch.tanh / torch.sigmoid are
    # used instead of the deprecated F.tanh / F.sigmoid aliases.
    _ACTIVATIONS = {
        'relu': F.relu,
        'leaky_relu': F.leaky_relu,
        'tanh': torch.tanh,
        'sigmoid': torch.sigmoid,
    }

    @staticmethod
    def _get_activation(name):
        """Return the activation callable registered under ``name``.

        Raises:
            ValueError: if ``name`` is not a supported activation.
        """
        try:
            return GCN_LP._ACTIVATIONS[name]
        except KeyError:
            raise ValueError(f"Unsupported activation function: {name}") from None

    def __init__(self, args, in_channels, hidden_channels, out_channels):
        super().__init__()

        self.activation = args.activation_fn
        self.activation_fn = self._get_activation(self.activation)

        # Layer widths: in -> hidden -> ... -> hidden -> out, i.e. K layers
        # (at least one). Consecutive pairs become one GCNConv each.
        num_layers = max(args.K, 1)
        dims = [in_channels] + [hidden_channels] * (num_layers - 1) + [out_channels]
        self.layers = nn.ModuleList(
            GCNConv(d_in, d_out, cached=False, add_self_loops=True)
            for d_in, d_out in zip(dims[:-1], dims[1:])
        )
        self.dropout = args.dropout

    def reset_parameter(self):
        """Re-initialise every layer: Xavier-uniform weights, zero biases.

        In PyG's ``GCNConv`` the inner ``lin`` is created without a bias and
        the bias lives on the conv module itself, so both locations are
        checked; looking only at ``lin.bias`` would leave the bias untouched.
        """
        for layer in self.layers:
            nn.init.xavier_uniform_(layer.lin.weight)
            for bias in (getattr(layer.lin, 'bias', None), getattr(layer, 'bias', None)):
                if bias is not None:
                    nn.init.zeros_(bias)

    def forward(self, x, adj_t):
        """Run the GCN stack and return the activated node embeddings.

        Args:
            x: node feature matrix fed to the first ``GCNConv``.
            adj_t: graph connectivity forwarded unchanged to every layer
                (presumably an edge index or transposed sparse adjacency,
                judging by the name — confirm with the caller).

        Returns:
            Output of the last layer passed through ``self.activation_fn``.
        """
        for conv in self.layers[:-1]:
            x = self.activation_fn(conv(x, adj_t))
            # Dropout only between layers, and only in training mode.
            x = F.dropout(x, p=self.dropout, training=self.training)
        # NOTE(review): the final activation is kept from the original code;
        # for link-prediction embeddings it clips/squashes the output range.
        return self.activation_fn(self.layers[-1](x, adj_t))
