import torch.nn.functional as F
from .transformer import Transformer, ActivationFunction
from .universal_transformer import UniversalTransformerDecoderWithLayer, UniversalTransformerEncoderWithLayer
from .relative_transformer import RelativeTransformerDecoderLayer, RelativeTransformerEncoderLayer


class UniversalRelativeTransformer(Transformer):
    """Transformer whose encoder and decoder use relative-position layers in a universal wrapper.

    The subclass only decides which encoder/decoder constructs the base
    ``Transformer`` should build. The encoder comes from
    ``UniversalTransformerEncoderWithLayer(RelativeTransformerEncoderLayer)``
    and the decoder from
    ``UniversalTransformerDecoderWithLayer(RelativeTransformerDecoderLayer)``.
    Every hyperparameter is forwarded to ``Transformer.__init__`` unchanged.
    """

    def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: ActivationFunction = F.relu, attention_dropout: float = 0):
        """Build the model by delegating to ``Transformer.__init__``.

        Args:
            d_model: Model width, forwarded to the base class.
            nhead: Number of attention heads, forwarded to the base class.
            num_encoder_layers: Encoder depth, forwarded to the base class.
            num_decoder_layers: Decoder depth, forwarded to the base class.
            dim_feedforward: Feed-forward hidden size, forwarded to the base class.
            dropout: Dropout probability, forwarded to the base class.
            activation: Activation function (default ``F.relu``), forwarded to
                the base class.
            attention_dropout: Attention dropout probability (default ``0``),
                forwarded to the base class.
        """
        # Wrap the relative-position layer classes with the universal
        # encoder/decoder helpers. The encoder is built before the decoder,
        # matching the evaluation order of the original argument list.
        encoder_type = UniversalTransformerEncoderWithLayer(RelativeTransformerEncoderLayer)
        decoder_type = UniversalTransformerDecoderWithLayer(RelativeTransformerDecoderLayer)

        # Arguments are passed positionally, so this order must match the
        # parameter order of Transformer.__init__.
        # NOTE(review): that signature is defined in .transformer, not here.
        super().__init__(
            d_model,
            nhead,
            num_encoder_layers,
            num_decoder_layers,
            dim_feedforward,
            dropout,
            activation,
            encoder_type,
            decoder_type,
            attention_dropout,
        )
