import torch
import torch.nn as nn
from typing import List, Tuple, Dict, Any

from data.structures import MaskMatrices


class Message(nn.Module):
    """
    Common base for message functions.

    The constructor only records the configuration and builds the dropout
    layer; ``forward`` is left unimplemented here and has to be provided by
    a subclass.
    """

    def __init__(self, hv_dim: int, he_dim: int, mv_dim: int, me_dim: int, pos_dim: int,
                 use_cuda=False, p_dropout=0.0, option_dict: Dict[str, Any] = None):
        """
        store the configuration shared by message functions
        :param hv_dim: size of the hidden vertex features (see ``forward``)
        :param he_dim: size of the hidden edge features (see ``forward``)
        :param mv_dim: presumably the size of the vertex message - not used in this class, confirm in subclasses
        :param me_dim: presumably the size of the edge message - not used in this class, confirm in subclasses
        :param pos_dim: size of the atom position features (see ``forward``)
        :param use_cuda: flag kept as given; nothing is moved to a device here
        :param p_dropout: probability handed to ``nn.Dropout``
        :param option_dict: extra options kept as given (``None`` stays ``None``)
        """
        super().__init__()

        # feature / message dimensions
        self.hv_dim, self.he_dim = hv_dim, he_dim
        self.mv_dim, self.me_dim = mv_dim, me_dim
        self.pos_dim = pos_dim

        # remaining configuration, stored untouched
        self.option_dict = option_dict
        self.use_cuda = use_cuda

        # the only sub-module registered by this class
        self.dropout = nn.Dropout(p_dropout)

    def forward(self, hv_ftr: torch.Tensor, he_ftr: torch.Tensor, pos_ftr: torch.Tensor,
                mask_matrices: MaskMatrices, return_list: List) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
        """
        interface of naive message passing with dynamic properties (not implemented in this class)
        :param hv_ftr: hidden vertex features, shape [n_vertex, hv_dim]
        :param he_ftr: hidden edge features, shape [n_edge, he_dim]
        :param pos_ftr: atom position features, shape [n_vertex, pos_dim]
        :param mask_matrices: mask matrices
        :param return_list: extra data to return
        :return: vertex message, edge message, and a dictionary described as node alignment
        :raises NotImplementedError: always, a subclass has to supply the computation
        """
        raise NotImplementedError()
