import torch
import torch.nn as nn


class Normalize(nn.Module):
    """Per-channel normalization layer computing ``(x - mean[c]) / std[c]``.

    Dimension 1 of the input is treated as the channel axis; channel ``c``
    is shifted by ``mean[c]`` and scaled by ``std[c]``.

    Args:
        mean: Indexable collection (list, tuple, array or tensor) holding one
            entry per channel. Entries are normally scalars. Entries beyond
            the input's channel count are ignored.
        std: Per-channel divisors, laid out like ``mean``.
    """

    def __init__(self, mean, std):
        super(Normalize, self).__init__()
        # Stored exactly as given so external readers of ``.mean`` / ``.std``
        # keep seeing the objects they passed in.
        self.mean = mean
        self.std = std

    def _channel_stat(self, values, reference):
        """Return ``values`` as a tensor that broadcasts over dim 1 of ``reference``."""
        channels = reference.size(1)
        # Integer/bool inputs are normalized in the default float dtype;
        # floating (or complex) inputs keep their own dtype.
        if reference.is_floating_point() or reference.is_complex():
            dtype = reference.dtype
        else:
            dtype = torch.get_default_dtype()
        stat = torch.as_tensor(values, dtype=dtype, device=reference.device)
        if stat.dim() == 0 or stat.size(0) < channels:
            raise IndexError(
                "Normalize needs at least %d per-channel entries" % channels
            )
        stat = stat[:channels]
        # Keep any per-entry axes right-aligned against a single channel
        # slice (N, ...), pad the gap with singleton axes, then swap the
        # channel axis into position 1 to line up with ``reference``.
        trailing = tuple(stat.shape[1:])
        padding = (1,) * (reference.dim() - 1 - len(trailing))
        return stat.reshape((channels,) + padding + trailing).transpose(0, 1)

    def forward(self, input, input2=None):
        """Normalize one batch, or two batches in a single pass.

        Args:
            input: Tensor with at least two dimensions; dim 1 is the channel
                axis. It is not modified.
            input2: Optional second tensor. When given it is concatenated to
                ``input`` along dim 0, normalized together with it, and the
                result is split back at ``len(input)``.

        Returns:
            The normalized tensor, or a ``(normalized_input,
            normalized_input2)`` tuple when ``input2`` is provided. Inputs
            with a non-floating dtype produce a floating-point result rather
            than being truncated back to the integer dtype.

        Raises:
            IndexError: If ``input`` has fewer than two dimensions, or
                ``mean`` / ``std`` hold fewer entries than there are channels.
        """
        paired = input2 is not None
        if paired:
            split_at = len(input)
            input = torch.cat((input, input2), dim=0)

        mean = self._channel_stat(self.mean, input)
        std = self._channel_stat(self.std, input)
        # Out-of-place broadcast arithmetic: yields a new tensor, so the
        # caller's input is left untouched.
        normalized = (input - mean) / std

        if paired:
            return normalized[:split_at], normalized[split_at:]
        return normalized

