from collections.abc import Callable
from typing import Final, Optional

import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATv2Conv, GINConv, SAGEConv, global_mean_pool, global_add_pool, global_max_pool

def custom_sigmoid(x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    """Apply an element-wise sigmoid after dropping a trailing singleton axis.

    Only the last dimension is squeezed, and only when its size is 1, so a
    ``(N, 1)`` input becomes ``(N,)``.  A bare ``squeeze()`` would remove
    every size-1 axis, which collapses a batch of one, ``(1, 1)``, into a
    0-dim scalar instead of the ``(1,)`` tensor produced here.

    Args:
        x: Input tensor.
        *args: Accepted and ignored.
        **kwargs: Accepted and ignored.

    Returns:
        The sigmoid of ``x`` with a trailing size-1 axis removed, if present.
    """
    # squeeze(-1) is a no-op when the last axis is not of size 1.
    return F.sigmoid(x.squeeze(-1))

def no_activation(x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    """Identity activation: hand ``x`` back untouched.

    Any additional positional or keyword arguments are accepted and ignored.
    """
    return x

# Lookup table from a configuration name to the activation callable it selects.
# NOTE(review): the entries do not share one call signature (``F.softmax`` and
# ``F.log_softmax`` take a ``dim`` argument, while the two local helpers
# swallow any extra arguments) -- confirm call sites pass only what the
# selected entry accepts.
ACTIVATION_MAPPING: Final[dict[str, Callable[..., torch.Tensor]]] = {
    "softmax": F.softmax,
    "sigmoid": F.sigmoid,
    "final_sigmoid": custom_sigmoid,
    "relu": F.relu,
    "tanh": F.tanh,
    "leakyrelu": F.leaky_relu,
    "logsoftmax": F.log_softmax,
    "no_activation": no_activation,
}


# Lookup table from a configuration name to the torch_geometric convolution
# layer class it selects. Values are classes, not instances.
# NOTE(review): the layer constructors may not share one signature -- confirm
# how call sites instantiate each entry.
MESSAGE_PASSING_MAPPING: Final[dict[str, type[nn.Module]]] = {
    "gcn": GCNConv,
    "gat_v2": GATv2Conv,
    "gin": GINConv,
    "sage": SAGEConv,
}

# Lookup table from a configuration name to a torch_geometric global pooling
# function. "add" and "sum" are aliases for the same function; both the
# ``None`` key and the string "none" map to ``None`` instead of a function.
POOLING_MAPPING: Final[
    dict[Optional[str], Optional[Callable[..., torch.Tensor]]]
] = {
    "mean": global_mean_pool,
    "add": global_add_pool,
    "sum": global_add_pool,
    "max": global_max_pool,
    None: None,
    "none": None,
}





