import paddle
import paddle.nn.functional as F
import paddle.nn as nn
# from .blocks import *


class Normalize:
    """Per-channel standardisation ``(x - mean) / scale`` applied along axis 1.

    Args:
        opt: configuration object; only ``opt.input_channel`` is read, and it
            must equal ``len(expected_values)``.
        expected_values: per-channel values subtracted from the input.
        variance: per-channel divisors. Despite the name, each entry is used
            directly as the divisor (it is not square-rooted).
    """

    def __init__(self, opt, expected_values, variance):
        channels = opt.input_channel
        self.n_channels = channels
        self.expected_values = expected_values
        self.variance = variance
        assert channels == len(expected_values)

    def __call__(self, x):
        """Return a normalised copy of ``x``; ``x`` itself is left untouched.

        Channels are indexed on axis 1 (``x[:, c]``), e.g. NCHW layout. ``x``
        must support ``clone()`` and slice assignment -- presumably a framework
        tensor; TODO confirm the tensor type at call sites.
        """
        normalized = x.clone()
        for c in range(self.n_channels):
            mean = self.expected_values[c]
            scale = self.variance[c]
            normalized[:, c] = (x[:, c] - mean) / scale
        return normalized


class Denormalize:
    """Inverse of per-channel standardisation: ``x * scale + mean`` on axis 1.

    Args:
        opt: configuration object; only ``opt.input_channel`` is read, and it
            must equal ``len(expected_values)``.
        expected_values: per-channel values added back after scaling.
        variance: per-channel multipliers. Despite the name, each entry is
            used directly as the multiplier (it is not square-rooted).
    """

    def __init__(self, opt, expected_values, variance):
        channels = opt.input_channel
        self.n_channels = channels
        self.expected_values = expected_values
        self.variance = variance
        assert channels == len(expected_values)

    def __call__(self, x):
        """Return a de-normalised copy of ``x``; ``x`` itself is left untouched.

        Channels are indexed on axis 1 (``x[:, c]``), e.g. NCHW layout. ``x``
        must support ``clone()`` and slice assignment -- presumably a framework
        tensor; TODO confirm the tensor type at call sites.
        """
        restored = x.clone()
        for c in range(self.n_channels):
            scale = self.variance[c]
            shift = self.expected_values[c]
            restored[:, c] = x[:, c] * scale + shift
        return restored


class Normalizer:
    """Dataset-aware wrapper that applies a ``Normalize`` when one is configured.

    ``opt.dataset`` selects the behaviour: "cifar10" and "mnist" get a
    ``Normalize`` with fixed per-channel statistics, "gtsrb" and "celeba" pass
    inputs through unchanged, and any other value raises.
    """

    def __init__(self, opt):
        self.normalizer = self._get_normalizer(opt)

    def _get_normalizer(self, opt):
        """Return the ``Normalize`` for ``opt.dataset``, or ``None`` for pass-through.

        Raises:
            Exception: if ``opt.dataset`` is not a supported dataset name.
        """
        dataset = opt.dataset
        if dataset == "cifar10":
            return Normalize(opt, [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261])
        if dataset == "mnist":
            return Normalize(opt, [0.5], [0.5])
        if dataset in ("gtsrb", "celeba"):
            return None
        raise Exception("Invalid dataset")

    def __call__(self, x):
        """Normalise ``x`` if a normaliser is configured; otherwise return it as-is."""
        transform = self.normalizer
        return transform(x) if transform else x


class Denormalizer:
    """Dataset-aware wrapper that applies a ``Denormalize`` when one is configured.

    ``opt.dataset`` selects the behaviour: "cifar10" and "mnist" get a
    ``Denormalize`` with fixed per-channel statistics, "gtsrb" and "celeba"
    pass inputs through unchanged, and any other value raises.
    """

    def __init__(self, opt):
        self.denormalizer = self._get_denormalizer(opt)

    def _get_denormalizer(self, opt):
        """Return the ``Denormalize`` for ``opt.dataset``, or ``None`` for pass-through.

        Raises:
            Exception: if ``opt.dataset`` is not a supported dataset name.
        """
        dataset = opt.dataset
        if dataset == "cifar10":
            return Denormalize(opt, [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261])
        if dataset == "mnist":
            return Denormalize(opt, [0.5], [0.5])
        if dataset in ("gtsrb", "celeba"):
            return None
        raise Exception("Invalid dataset")

    def __call__(self, x):
        """De-normalise ``x`` if a denormaliser is configured; otherwise return it as-is."""
        transform = self.denormalizer
        return transform(x) if transform else x


class MNISTBlock(paddle.nn.Layer):
    """Pre-activation conv unit: BatchNorm -> ReLU -> 3x3 convolution (no bias).

    Args:
        in_planes: channels of the incoming feature map.
        planes: channels produced by the convolution.
        stride: convolution stride. With kernel 3 and padding 1, stride 1
            keeps the spatial size.
    """

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.bn1 = paddle.nn.BatchNorm2D(in_planes)
        self.conv1 = paddle.nn.Conv2D(
            in_planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias_attr=False,
        )
        # Not read by this class; kept so external code touching it still works.
        self.ind = None

    def forward(self, x):
        activated = F.relu(self.bn1(x))
        return self.conv1(activated)


class NetC_MNIST(paddle.nn.Layer):
    """Small CNN classifier with 10 outputs that also exposes intermediate features.

    Expects single-channel input (``conv1`` has 1 input channel). The spatial
    comments assume a 28x28 input (presumably MNIST -- confirm at callers): the
    three stride-2 convolutions give 14 -> 7 -> 4, which matches
    ``linear6``'s ``64 * 4 * 4`` input features.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = paddle.nn.Conv2D(1, 32, (3, 3), 2, 1)  # -> 14x14
        self.relu1 = paddle.nn.ReLU()
        self.layer2 = MNISTBlock(32, 64, 2)  # -> 7x7
        self.layer3 = MNISTBlock(64, 64, 2)  # -> 4x4
        self.flatten = paddle.nn.Flatten()
        self.linear6 = paddle.nn.Linear(64 * 4 * 4, 512)
        self.relu7 = paddle.nn.ReLU()
        self.dropout = paddle.nn.Dropout(0.3)
        self.linear9 = paddle.nn.Linear(512, 10)

    def forward(self, x):
        """Run the network.

        Returns:
            ``(activations, scores)``: ``activations`` is a list holding the
            outputs of ``layer2`` and ``layer3`` (in that order); ``scores`` is
            the raw output of ``linear9`` (10 values per sample, no softmax).
        """
        stem = self.relu1(self.conv1(x))
        feat2 = self.layer2(stem)
        feat3 = self.layer3(feat2)
        hidden = self.relu7(self.linear6(self.flatten(feat3)))
        scores = self.linear9(self.dropout(hidden))
        return [feat2, feat3], scores

class BadNet(paddle.nn.Layer):
    """Two conv stages plus two fully-connected layers, ending in a softmax.

    Args:
        input_channels: channels of the input image. ``fc1`` expects 800
            flattened features when this is 3 and 512 otherwise. That matches a
            3x32x32 input (32->28->14->10->5, 32*5*5 = 800) or a 1x28x28 input
            (28->24->12->8->4, 32*4*4 = 512); other spatial sizes will not fit.
            Presumably CIFAR-10-like / MNIST-like data -- confirm at callers.
        output_num: number of classes produced by ``fc2``.
    """

    def __init__(self, input_channels, output_num):
        super().__init__()
        # 5x5 conv (no padding) followed by 2x2 average pooling.
        self.conv1 = paddle.nn.Sequential(
            paddle.nn.Conv2D(in_channels=input_channels, out_channels=16, kernel_size=5, stride=1),
            paddle.nn.ReLU(),
            paddle.nn.AvgPool2D(kernel_size=2, stride=2)
        )

        self.conv2 = paddle.nn.Sequential(
            paddle.nn.Conv2D(in_channels=16, out_channels=32, kernel_size=5, stride=1),
            paddle.nn.ReLU(),
            paddle.nn.AvgPool2D(kernel_size=2, stride=2)
        )
        # Hard-coded flattened size; see the class docstring for the arithmetic.
        fc1_input_features = 800 if input_channels == 3 else 512
        self.fc1 = paddle.nn.Sequential(
            paddle.nn.Linear(in_features=fc1_input_features, out_features=512),
            paddle.nn.ReLU()
        )
        # Ends in Softmax, so forward() returns probabilities rather than logits.
        self.fc2 = paddle.nn.Sequential(
            paddle.nn.Linear(in_features=512, out_features=output_num),
            paddle.nn.Softmax(axis=-1)
        )
        # NOTE: defined but not applied in forward(); kept for attribute compatibility.
        self.dropout = nn.Dropout(p=.5)

    def forward(self, x):
        """Return per-class probabilities of shape ``(batch, output_num)``."""
        x = self.conv1(x)
        x = self.conv2(x)
        # Flatten all axes except the batch axis. In Paddle, ``Tensor.size`` is
        # the total element count (a property), not a PyTorch-style
        # ``size(dim)`` method, so ``x.view(x.size(0), -1)`` cannot work here.
        x = paddle.flatten(x, start_axis=1)
        x = self.fc1(x)
        x = self.fc2(x)
        return x
