import torch.nn as nn
import torch

class Concatenate(nn.Module):
    """Module wrapper around ``torch.cat``.

    Lets tensor concatenation be used as a layer, e.g. inside a container
    module. The concatenation axis is fixed when the module is constructed.

    Args:
        dim: Dimension along which the inputs are concatenated. Defaults to 0.
    """

    def __init__(self, dim: int = 0):
        super().__init__()
        self.dim = dim

    def forward(self, inputs) -> torch.Tensor:
        """Concatenate ``inputs`` along ``self.dim``.

        Args:
            inputs: A sequence (e.g. list or tuple) of tensors, passed
                straight to ``torch.cat``. Per the ``torch.cat`` contract,
                they must match in shape except along ``self.dim``.

        Returns:
            The tensor produced by ``torch.cat(inputs, dim=self.dim)``.
        """
        return torch.cat(inputs, dim=self.dim)

    def extra_repr(self) -> str:
        # nn.Module.__repr__ calls this hook, so the configured axis appears
        # in printed model summaries, e.g. "Concatenate(dim=1)".
        return f"dim={self.dim}"