import torch
import torch.nn.functional as F
from .blocks import *


class Normalize:
    def __init__(self, opt, expected_values, variance):
        self.n_channels = opt.input_channel
        self.expected_values = expected_values
        self.variance = variance
        assert self.n_channels == len(self.expected_values)

    def __call__(self, x):
        x_clone = x.clone()
        for channel in range(self.n_channels):
            x_clone[:, channel] = (x[:, channel] - self.expected_values[channel]) / self.variance[channel]
        return x_clone


class Denormalize:
    def __init__(self, opt, expected_values, variance):
        self.n_channels = opt.input_channel
        self.expected_values = expected_values
        self.variance = variance
        assert self.n_channels == len(self.expected_values)

    def __call__(self, x):
        x_clone = x.clone()
        for channel in range(self.n_channels):
            x_clone[:, channel] = x[:, channel] * self.variance[channel] + self.expected_values[channel]
        return x_clone


class Normalizer:
    def __init__(self, opt):
        self.normalizer = self._get_normalizer(opt)

    def _get_normalizer(self, opt):
        if opt.dataset == "cifar10":
            normalizer = Normalize(opt, (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        elif opt.dataset == "celeba":
            normalizer = None
        elif opt.dataset == "tinyimagenet":
            normalizer = Normalize(opt, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        else:
            raise Exception("Invalid dataset")
        return normalizer

    def __call__(self, x):
        if self.normalizer:
            x = self.normalizer(x)
        return x


class Denormalizer:
    def __init__(self, opt):
        self.denormalizer = self._get_denormalizer(opt)

    def _get_denormalizer(self, opt):
        if opt.dataset == "cifar10":
            denormalizer = Denormalize(opt, (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        elif opt.dataset == "celeba":
            denormalizer = None
        elif opt.dataset == "tinyimagenet":
            denormalizer = Denormalize(opt, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        else:
            raise Exception("Invalid dataset")
        return denormalizer


    def __call__(self, x):
        if self.denormalizer:
            x = self.denormalizer(x)
        return x


