import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel as _DataParallel


class DataParallel(_DataParallel):
    """``DataParallel`` wrapper that forwards unknown attributes to the wrapped module.

    Attribute lookups that the wrapper cannot satisfy itself are retried on
    ``self.module``, so code written against the bare model (custom methods,
    plain attributes, submodules) keeps working after wrapping.
    """

    def __getattr__(self, name):
        """Resolve ``name`` on the wrapper, falling back to the wrapped module.

        Args:
            name: Attribute name that normal attribute lookup failed to find.

        Returns:
            The attribute from the wrapper's registered parameters, buffers
            or submodules if present, otherwise from ``self.module``.

        Raises:
            AttributeError: If neither the wrapper nor the wrapped module
                provides ``name``.
        """
        try:
            return super().__getattr__(name)
        except AttributeError:
            # If ``module`` itself is missing (partially initialised
            # instance), reading ``self.module`` below would re-enter this
            # method forever; surface the original error instead.
            if name == "module":
                raise
            # Use full attribute lookup: ``nn.Module.__getattr__`` only
            # searches parameters, buffers and submodules, so calling it
            # directly would miss plain attributes, methods and properties.
            return getattr(self.module, name)

            
