import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F



class EmbeddingWrapper(nn.Module):
    """Thin ``nn.Module`` wrapper that builds and calls an embedding module."""

    def __init__(self, embedding_class, embedding_args=None):
        """
        Instantiate ``embedding_class`` and store it as ``self.embedding``.

        Arguments
        ---------

            embedding_class: class
                A PyTorch-compatible class (or other callable) used for
                embedding. It is passed un-initialised; this wrapper calls
                ``embedding_class(**embedding_args)`` to build it. The
                resulting object must be callable on the input given to
                ``forward``. If it is an ``nn.Module``, assigning it to
                ``self.embedding`` registers it as a submodule, so its
                parameters belong to this wrapper.

            embedding_args: dictionary or None
                Keyword arguments passed to ``embedding_class``. Defaults to
                ``None``, which is treated as no keyword arguments.

        Returns
        ---------
            None
        """
        super().__init__()

        # A ``None`` sentinel replaces the previous ``{}`` default so that no
        # dict object is shared between calls.
        if embedding_args is None:
            embedding_args = {}

        self.embedding = embedding_class(**embedding_args)

    def forward(self, X):
        """
        Apply the wrapped embedding to ``X``.

        Arguments
        ---------

            X:
                Input passed unchanged to ``self.embedding``. Its expected
                type and shape depend on ``embedding_class`` (not fixed here).

        Returns
        ---------
            Whatever ``self.embedding(X)`` returns.
        """
        return self.embedding(X)