from abc import ABC, abstractmethod
from typing import Optional

import torch
import torch.nn as nn

from models.model_utils import POOLING_MAPPING
from src.models.model_utils import ACTIVATION_MAPPING


# Abstract interface for GNN models: subclasses must implement __init__ and forward.
class InterfaceGNN(nn.Module, ABC):
    """Abstract interface for graph neural network models.

    Combines ``torch.nn.Module`` with ``abc.ABC`` so that a concrete model
    cannot be instantiated unless it overrides both ``__init__`` and
    ``forward``.
    """

    @abstractmethod
    def __init__(self):
        # Subclasses must define their own ``__init__`` and call
        # ``super().__init__()`` so that ``nn.Module`` state is initialised.
        super().__init__()

    @abstractmethod
    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: Optional[torch.Tensor] = None,
        batch: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the model on a graph and return its output tensor.

        Args:
            x: Node features.
            edge_index: Graph connectivity (presumably a ``(2, num_edges)``
                index tensor — TODO confirm against implementations).
            edge_attr: Optional edge features.
            batch: Optional batch vector (presumably maps each node to its
                graph in a mini-batch — verify against implementations).

        Returns:
            The model's output tensor.

        Raises:
            NotImplementedError: Always; concrete subclasses must override
                this method. Raising here (instead of falling through and
                returning ``None``) keeps the ``-> torch.Tensor`` contract
                honest if a subclass calls ``super().forward(...)``.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement forward()"
        )