import torch
import torch.nn as nn


class ModelwNorm(nn.Module):
    """Wrap ``model`` so its input is normalized per channel before the forward pass.

    ``forward(x)`` returns ``model((x - mean) / std)``, with ``mean``/``std``
    shaped ``(1, C, 1, 1)`` and therefore broadcast along dimension 1 of ``x``.
    The input is presumably an ``(N, C, H, W)`` image batch; confirm against
    callers.

    Args:
        model: Module that receives the normalized input.
        mean: Per-channel means. Defaults to the values this wrapper has
            always used. They appear to be commonly cited CIFAR-100
            statistics; confirm.
        std: Per-channel standard deviations. Must have the same length as
            ``mean``.

    Raises:
        ValueError: If ``mean`` and ``std`` have different lengths.
    """

    _DEFAULT_MEAN = (0.507, 0.487, 0.441)
    _DEFAULT_STD = (0.267, 0.256, 0.276)

    def __init__(self, model, mean=_DEFAULT_MEAN, std=_DEFAULT_STD):
        super().__init__()
        self.model = model
        mean_t = torch.as_tensor(mean, dtype=torch.float32).reshape(1, -1, 1, 1)
        std_t = torch.as_tensor(std, dtype=torch.float32).reshape(1, -1, 1, 1)
        if mean_t.numel() != std_t.numel():
            raise ValueError(
                f"mean and std must have the same length, got "
                f"{mean_t.numel()} and {std_t.numel()}"
            )
        # Registered as buffers so the statistics follow .to()/.cuda()/.double()
        # together with the wrapper. persistent=False keeps them out of
        # state_dict, so checkpoints saved before this change (when these were
        # plain attributes) still load with strict=True.
        self.register_buffer("mean", mean_t, persistent=False)
        self.register_buffer("std", std_t, persistent=False)

    def forward(self, x):
        """Normalize ``x`` with the stored statistics and run the wrapped model."""
        # For floating-point inputs, match x's dtype so half/bfloat16 inputs are
        # not promoted to float32 before reaching a reduced-precision model.
        # For integer inputs, keep the floating stats so the result is promoted
        # to float, as before.
        dtype = x.dtype if x.is_floating_point() else self.mean.dtype
        # .to() is a no-op when device and dtype already match.
        mean = self.mean.to(device=x.device, dtype=dtype)
        std = self.std.to(device=x.device, dtype=dtype)
        return self.model((x - mean) / std)
