import torch
import torch.nn as nn
from typing import List, Tuple, Dict, Any, Union

from data.structures import MaskMatrices


class Generator(nn.Module):
    """Base class for generator modules that consume ``hv``/``he`` feature tensors.

    This class only records its configuration and owns a dropout layer.
    ``forward`` is left unimplemented and always raises ``NotImplementedError``,
    so concrete generators must override it.

    Args:
        hv_dim: Feature size associated with ``hv_ftr`` (presumably per-vertex
            hidden features -- TODO confirm against subclasses).
        he_dim: Feature size associated with ``he_ftr`` (presumably per-edge
            hidden features -- TODO confirm against subclasses).
        pos_dim: Size of the position representation used by subclasses
            (meaning not established here -- TODO confirm).
        need_momentum: Flag stored on the instance; not read by this base class.
        use_cuda: Flag stored on the instance; not read by this base class.
        p_dropout: Drop probability for the ``self.dropout`` layer.
    """

    def __init__(self, hv_dim: int, he_dim: int, pos_dim: int, need_momentum=False, use_cuda=False, p_dropout=0.0):
        super().__init__()
        # Dimensions are kept so subclasses can size their own layers.
        self.hv_dim, self.he_dim, self.pos_dim = hv_dim, he_dim, pos_dim
        # Configuration flags; only recorded here for subclasses to consult.
        self.need_momentum = need_momentum
        self.use_cuda = use_cuda
        # Assigned as a module attribute, so it is registered as a submodule
        # and follows the parent's train()/eval() mode.
        self.dropout = nn.Dropout(p=p_dropout)

    def forward(self, hv_ftr: torch.Tensor, he_ftr: torch.Tensor, mask_matrices: MaskMatrices,
                return_list: List[str], **kwargs) -> Tuple[Union[torch.Tensor, None], torch.Tensor, Dict[str, Any]]:
        """Run the generator. Subclasses must override this method.

        Args:
            hv_ftr: Tensor of ``hv`` features (shape not fixed here -- assumes
                last dim is ``hv_dim``; TODO confirm).
            he_ftr: Tensor of ``he`` features (shape not fixed here -- assumes
                last dim is ``he_dim``; TODO confirm).
            mask_matrices: Project ``MaskMatrices`` structure passed through to
                implementations.
            return_list: Names of extra outputs the caller requests (presumably
                used to populate the returned dict -- verify in subclasses).
            **kwargs: Implementation-specific extra arguments.

        Returns:
            Per the annotation, a 3-tuple of an optional tensor, a tensor and a
            ``dict`` of extra outputs.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()
