import torch
import torch.nn as nn
from typing import Tuple, List, Dict, Any

from data.structures import MaskMatrices


class Derivator(nn.Module):
    """Base module that records dimension settings and leaves ``forward`` unimplemented.

    The constructor only stores its arguments as attributes; it creates no
    layers or parameters. ``forward`` always raises ``NotImplementedError``,
    so a usable module has to supply its own ``forward``.

    NOTE(review): ``v_dim`` / ``e_dim`` / ``pq_dim`` presumably are the feature
    sizes of the ``v``, ``e`` and ``p``/``q`` tensors passed to ``forward`` --
    confirm against the concrete implementations.
    """

    def __init__(
            self,
            v_dim: int,
            e_dim: int,
            pq_dim: int,
            use_cuda=False,
            p_dropout=0.0,
    ):
        """Store the configuration values on the instance.

        Args:
            v_dim: Saved as ``self.v_dim``.
            e_dim: Saved as ``self.e_dim``.
            pq_dim: Saved as ``self.pq_dim``.
            use_cuda: Saved as ``self.use_cuda``; not acted upon here.
            p_dropout: Saved as ``self.p_dropout``; no dropout layer is
                built here.
        """
        super().__init__()

        # Runtime options first, then the dimension settings; the
        # assignments are independent of one another.
        self.use_cuda = use_cuda
        self.p_dropout = p_dropout

        self.v_dim = v_dim
        self.e_dim = e_dim
        self.pq_dim = pq_dim

    def forward(
            self,
            v: torch.Tensor,
            e: torch.Tensor,
            m: torch.Tensor,
            p: torch.Tensor,
            q: torch.Tensor,
            mask_matrices: MaskMatrices,
            return_list: List[str],
            **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
        """Placeholder forward pass; always raises.

        The signature declares the expected call shape: five tensors, a
        ``MaskMatrices`` object, a list of strings, and arbitrary keyword
        arguments, with an annotated return of two tensors plus a dict.
        NOTE(review): ``return_list`` presumably names the extra entries to
        place in the returned dict -- verify against implementations.

        Raises:
            NotImplementedError: Unconditionally.
        """
        raise NotImplementedError
