from typing import Any
import torch
import torch.nn as nn

class ParallelWrapper(nn.Module):
    """Wrap a model in ``nn.DataParallel`` while keeping its embedding-dim API.

    The wrapped model must expose ``get_embedding_dim()``. It is called once at
    construction time and the result is cached, so later calls to
    :meth:`get_embedding_dim` do not reach into the (possibly replicated)
    inner model.

    NOTE(review): ``model`` is registered as a submodule twice: directly as
    ``self.model`` and again as ``self.parallel_model.module`` (inside
    ``nn.DataParallel``). ``parameters()`` de-duplicates shared tensors, but
    ``state_dict()`` will contain both the ``model.*`` and the
    ``parallel_model.module.*`` key prefixes. This is kept as-is because
    existing checkpoints may depend on these keys.
    """

    def __init__(self, model: nn.Module) -> None:
        """Args:
            model: Module to parallelize; must provide ``get_embedding_dim()``.
        """
        super().__init__()
        self.model = model
        # Cache once so callers never need to unwrap DataParallel's `.module`.
        self.embedding_dim = model.get_embedding_dim()
        self.parallel_model = nn.DataParallel(model)

    def forward(self, x: Any, *args: Any, **kwargs: Any) -> Any:
        """Run the wrapped model through ``nn.DataParallel``.

        Args:
            x: Primary input, passed first to the wrapped model.
            *args: Extra positional inputs, forwarded unchanged (DataParallel
                scatters tensors among them across devices).
            **kwargs: Extra keyword inputs, forwarded unchanged.

        Returns:
            Whatever ``nn.DataParallel`` returns for the wrapped model.
        """
        return self.parallel_model(x, *args, **kwargs)

    def get_embedding_dim(self) -> Any:
        """Return the value ``model.get_embedding_dim()`` gave at construction."""
        return self.embedding_dim
