import torch
import torch.nn.functional as F
from torch_geometric_signed_directed.nn.directed import MagNet_node_classification, DiGCN_node_classification
from torch_geometric_signed_directed.utils import get_appr_directed_adj
from torch_geometric.nn import LINKX as pyg_LINKX


class MagNet(torch.nn.Module):
    """Node classifier wrapping ``MagNet_node_classification``.

    Args:
        num_features: Input feature dimension.
        num_classes: Number of output classes (passed as ``label_dim``).
        hidden_dim: Hidden width (passed as ``hidden``).
        K: Passed through as ``K`` (presumably the Chebyshev filter order — confirm).
        q: Passed through as ``q`` (presumably the magnetic phase parameter — confirm).
        num_layers: Number of layers (passed as ``layer``).
        dropout: Dropout probability passed to the wrapped model.

    The wrapped model is built with ``activation=True`` and ``cached=True``.
    """

    def __init__(self, num_features, num_classes, hidden_dim, K, q, num_layers=2, dropout=0):
        super().__init__()
        config = dict(
            num_features=num_features,
            label_dim=num_classes,
            hidden=hidden_dim,
            layer=num_layers,
            dropout=dropout,
            K=K,
            q=q,
            activation=True,
            cached=True,
        )
        self.model = MagNet_node_classification(**config)

    def forward(self, x, edge_index):
        """Run the wrapped model on node features ``x`` and graph ``edge_index``."""
        # The same real-valued features are used as both the real and the
        # imaginary part of the complex-valued input.
        real = imag = x
        return self.model(real, imag, edge_index)


class DiGCN(torch.nn.Module):
    """Node classifier wrapping ``DiGCN_node_classification``.

    The adjacency fed to the wrapped model is built from ``edge_index`` with
    ``get_appr_directed_adj`` and cached between calls, so repeated
    full-batch forwards do not rebuild it. The cache is rebuilt whenever
    ``forward`` receives a different graph (different edges or node count)
    or a different feature dtype. A model evaluated on another graph
    therefore never reuses a stale adjacency.

    Args:
        num_features: Input feature dimension.
        num_classes: Number of output classes (passed as ``label_dim``).
        hidden_dim: Hidden width of the wrapped model (passed as ``hidden``).
        dropout: Dropout probability passed to the wrapped model.
        alpha: First argument passed to ``get_appr_directed_adj``
            (presumably the APPR teleport probability — confirm).
    """

    def __init__(self, num_features, num_classes, hidden_dim, dropout=0, alpha=0):
        super(DiGCN, self).__init__()
        self.model = DiGCN_node_classification(num_features=num_features, label_dim=num_classes, hidden=hidden_dim, dropout=dropout)
        self.alpha = alpha
        # Cached output of get_appr_directed_adj (same attribute names as
        # before, for callers that inspect them).
        self.edge_index, self.edge_weight = None, None
        # Inputs the cache was built from: (edge_index, num_nodes, dtype).
        self._cache_key = None

    def _cache_is_valid(self, edge_index, num_nodes, dtype):
        """Return True if the cached adjacency was built from these inputs."""
        if self.edge_index is None or self._cache_key is None:
            return False
        cached_edges, cached_nodes, cached_dtype = self._cache_key
        if cached_nodes != num_nodes or cached_dtype != dtype:
            return False
        # Fast path: the caller passes the same tensor object every epoch.
        if cached_edges is edge_index:
            return True
        # A new tensor object may still describe the same graph; compare the
        # contents. Check shape and device first so torch.equal is never
        # asked to compare tensors that live on different devices.
        return (
            cached_edges.shape == edge_index.shape
            and cached_edges.device == edge_index.device
            and torch.equal(cached_edges, edge_index)
        )

    def forward(self, x, edge_index):
        """Classify nodes with features ``x`` on the graph ``edge_index``.

        The APPR adjacency is (re)built only when the graph, node count
        (``x.shape[0]``) or ``x.dtype`` differs from the cached one.
        """
        num_nodes = x.shape[0]
        if not self._cache_is_valid(edge_index, num_nodes, x.dtype):
            self.edge_index, self.edge_weight = get_appr_directed_adj(self.alpha, edge_index, num_nodes, x.dtype)
            self._cache_key = (edge_index, num_nodes, x.dtype)

        return self.model(x, self.edge_index, self.edge_weight)
    

class MLP(torch.nn.Module):
    """Multi-layer perceptron node classifier that ignores graph structure.

    Args:
        num_features: Input feature dimension.
        num_classes: Output dimension (number of classes).
        hidden_dim: Width of every hidden layer.
        num_layers: Total number of ``Linear`` layers; must be >= 1. With
            ``num_layers == 1`` the model is a single linear map (logistic
            regression). Otherwise it has ``num_layers - 1`` hidden layers of
            width ``hidden_dim``.
        dropout: Dropout probability applied after each hidden ReLU.

    Raises:
        ValueError: If ``num_layers < 1``.
    """

    def __init__(self, num_features, num_classes, hidden_dim, num_layers=2, dropout=0):
        super(MLP, self).__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        # Layer widths: input -> hidden_dim x (num_layers - 1) -> classes.
        # Layers are created in input-to-output order, so parameter
        # initialization consumes the RNG in the same order as before.
        dims = [num_features] + [hidden_dim] * (num_layers - 1) + [num_classes]
        self.lins = torch.nn.ModuleList(
            torch.nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])
        )

        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialize the parameters of every linear layer."""
        for lin in self.lins:
            lin.reset_parameters()

    def forward(self, x, edge_index):
        """Return log-probabilities over the last dimension.

        For ``x`` of shape ``(N, num_features)`` the output has shape
        ``(N, num_classes)``. ``edge_index`` is unused. It is accepted so this
        model shares the ``forward(x, edge_index)`` signature of the graph
        models.
        """
        for lin in self.lins[:-1]:
            x = lin(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lins[-1](x)
        return torch.log_softmax(x, dim=-1)
    

# Example invocation: python src/full_batch/run.py --model linkx --dataset directed-roman-empire --lr 0.01 --patience 200 --num_layers 2 --undirected
class LINKX(torch.nn.Module):
    """Node classifier wrapping PyG's ``LINKX`` with a log-softmax output.

    Args:
        num_nodes: Number of nodes in the graph (passed as ``num_nodes``).
        num_features: Input feature dimension (passed as ``in_channels``).
        num_classes: Number of output classes (passed as ``out_channels``).
        hidden_dim: Hidden width (passed as ``hidden_channels``).
        num_layers: Passed as ``num_layers``, ``num_edge_layers`` and
            ``num_node_layers`` of the wrapped model.
        dropout: Dropout probability passed to the wrapped model.
    """

    def __init__(self, num_nodes, num_features, num_classes, hidden_dim, num_layers=2, dropout=0):
        super().__init__()
        depth = num_layers
        self.model = pyg_LINKX(
            num_nodes=num_nodes,
            in_channels=num_features,
            hidden_channels=hidden_dim,
            out_channels=num_classes,
            num_layers=depth,
            num_edge_layers=depth,
            num_node_layers=depth,
            dropout=dropout,
        )

    def forward(self, x, edge_index):
        """Return log-probabilities (log-softmax over the last dimension)."""
        logits = self.model(x, edge_index)
        return logits.log_softmax(dim=-1)