import paddle
import paddle.nn.functional as F
import paddle.nn as nn
from paddle.vision.models import resnet50
# from .blocks import *


class Normalize:
    """Per-channel affine normalisation: ``out[:, c] = (x[:, c] - mean[c]) / variance[c]``.

    ``variance`` is used directly as the divisor (no square root is taken),
    so its entries act as per-channel standard deviations.

    Args:
        opt: options object; only ``opt.input_channel`` (number of channels) is read.
        expected_values: per-channel means, one entry per channel.
        variance: per-channel divisors, one entry per channel.

    Raises:
        ValueError: if ``expected_values`` or ``variance`` does not have exactly
            ``opt.input_channel`` entries.
    """

    def __init__(self, opt, expected_values, variance):
        self.n_channels = opt.input_channel
        self.expected_values = expected_values
        self.variance = variance
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        # ``variance`` is checked too; otherwise a short list only fails at call time.
        if len(self.expected_values) != self.n_channels or len(self.variance) != self.n_channels:
            raise ValueError(
                f"expected {self.n_channels} per-channel values, got "
                f"{len(self.expected_values)} means and {len(self.variance)} divisors"
            )

    def __call__(self, x):
        """Return a normalised copy of ``x``; ``x`` itself is left unmodified.

        Channels are indexed on axis 1 (``x[:, c]``), so ``x`` is presumably a
        batched channel-first tensor -- TODO confirm with callers.
        """
        x_clone = x.clone()
        for channel in range(self.n_channels):
            x_clone[:, channel] = (x[:, channel] - self.expected_values[channel]) / self.variance[channel]
        return x_clone


class Denormalize:
    """Inverse of per-channel normalisation: ``out[:, c] = x[:, c] * variance[c] + mean[c]``.

    ``variance`` is used directly as the multiplier (no square root is taken),
    so its entries act as per-channel standard deviations.

    Args:
        opt: options object; only ``opt.input_channel`` (number of channels) is read.
        expected_values: per-channel means, one entry per channel.
        variance: per-channel multipliers, one entry per channel.

    Raises:
        ValueError: if ``expected_values`` or ``variance`` does not have exactly
            ``opt.input_channel`` entries.
    """

    def __init__(self, opt, expected_values, variance):
        self.n_channels = opt.input_channel
        self.expected_values = expected_values
        self.variance = variance
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        # ``variance`` is checked too; otherwise a short list only fails at call time.
        if len(self.expected_values) != self.n_channels or len(self.variance) != self.n_channels:
            raise ValueError(
                f"expected {self.n_channels} per-channel values, got "
                f"{len(self.expected_values)} means and {len(self.variance)} multipliers"
            )

    def __call__(self, x):
        """Return a denormalised copy of ``x``; ``x`` itself is left unmodified.

        Channels are indexed on axis 1 (``x[:, c]``), so ``x`` is presumably a
        batched channel-first tensor -- TODO confirm with callers.
        """
        x_clone = x.clone()
        for channel in range(self.n_channels):
            x_clone[:, channel] = x[:, channel] * self.variance[channel] + self.expected_values[channel]
        return x_clone


class Normalizer:
    """Dataset-aware wrapper that applies ``Normalize`` when the dataset needs it.

    ``cifar10`` and ``mnist`` use fixed per-channel statistics; ``gtsrb`` and
    ``celeba`` are passed through unchanged (identity).
    """

    # dataset name -> (per-channel means, per-channel divisors); None means identity.
    _STATS = {
        "cifar10": ((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
        "mnist": ((0.5,), (0.5,)),
        "gtsrb": None,
        "celeba": None,
    }

    def __init__(self, opt):
        self.normalizer = self._get_normalizer(opt)

    def _get_normalizer(self, opt):
        """Build the ``Normalize`` for ``opt.dataset``, or return ``None`` for identity.

        Raises:
            ValueError: if ``opt.dataset`` is not a known dataset. ValueError is a
                subclass of ``Exception``, so existing ``except Exception`` handlers
                still catch it.
        """
        try:
            stats = self._STATS[opt.dataset]
        except KeyError:
            raise ValueError(f"Invalid dataset: {opt.dataset!r}") from None
        if stats is None:
            return None
        expected_values, variance = stats
        # Fresh lists per instance so Normalize objects never share mutable state.
        return Normalize(opt, list(expected_values), list(variance))

    def __call__(self, x):
        """Return ``x`` normalised for the configured dataset (unchanged for identity)."""
        if self.normalizer is not None:
            x = self.normalizer(x)
        return x


class Denormalizer:
    """Dataset-aware wrapper that applies ``Denormalize`` when the dataset needs it.

    ``cifar10`` and ``mnist`` use fixed per-channel statistics; ``gtsrb`` and
    ``celeba`` are passed through unchanged (identity).
    """

    # dataset name -> (per-channel means, per-channel multipliers); None means identity.
    _STATS = {
        "cifar10": ((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
        "mnist": ((0.5,), (0.5,)),
        "gtsrb": None,
        "celeba": None,
    }

    def __init__(self, opt):
        self.denormalizer = self._get_denormalizer(opt)

    def _get_denormalizer(self, opt):
        """Build the ``Denormalize`` for ``opt.dataset``, or return ``None`` for identity.

        Raises:
            ValueError: if ``opt.dataset`` is not a known dataset. ValueError is a
                subclass of ``Exception``, so existing ``except Exception`` handlers
                still catch it.
        """
        try:
            stats = self._STATS[opt.dataset]
        except KeyError:
            raise ValueError(f"Invalid dataset: {opt.dataset!r}") from None
        if stats is None:
            return None
        expected_values, variance = stats
        # Fresh lists per instance so Denormalize objects never share mutable state.
        return Denormalize(opt, list(expected_values), list(variance))

    def __call__(self, x):
        """Return ``x`` denormalised for the configured dataset (unchanged for identity)."""
        if self.denormalizer is not None:
            x = self.denormalizer(x)
        return x


class MNISTBlock(paddle.nn.Layer):
    """Pre-activation convolution unit: BatchNorm -> ReLU -> 3x3 conv (no bias).

    Args:
        in_planes: number of channels entering the block.
        planes: number of channels produced by the convolution.
        stride: convolution stride; with ``padding=1`` and a 3x3 kernel,
            stride 1 keeps the spatial size.
    """

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.bn1 = paddle.nn.BatchNorm2D(in_planes)
        self.conv1 = paddle.nn.Conv2D(
            in_planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias_attr=False,
        )
        # Not read by this class; kept as a public attribute for any external users.
        self.ind = None

    def forward(self, x):
        activated = F.relu(self.bn1(x))
        return self.conv1(activated)


class NetC_MNIST(paddle.nn.Layer):
    """Small convolutional network for single-channel images that outputs 512-d features.

    forward() and get_feature() do not call layers by name: they run the
    sub-layers in the order they are registered in __init__ (via
    ``self.children()``). The order of the assignments below is therefore part
    of the behaviour. Reordering them, or registering a new sub-layer, changes
    what forward()/get_feature() compute.
    """
    def __init__(self):
        super(NetC_MNIST, self).__init__()
        # Spatial-size comments assume a 28x28 input -- TODO confirm with the data pipeline.
        self.conv1 = paddle.nn.Conv2D(1, 32, (3, 3), 2, 1)  # 14
        self.relu1 = paddle.nn.ReLU()
        self.layer2 = MNISTBlock(32, 64, 2)  # 7
        self.layer3 = MNISTBlock(64, 64, 2)  # 4
        self.flatten = paddle.nn.Flatten()
        self.linear6 = paddle.nn.Linear(64 * 4 * 4, 512)
        self.relu7 = paddle.nn.ReLU()
        self.dropout = paddle.nn.Dropout(0.3)
        # Size of the vector returned by forward() (output of linear6/relu7/dropout).
        self.feature_dim = 512 
        # No classification head is registered, so forward() ends at dropout.
        # self.linear9 = paddle.nn.Linear(512, 10)

    def forward(self, x):
        """Run every registered sub-layer in registration order.

        With the layers defined in __init__ this is conv1 -> relu1 -> layer2 ->
        layer3 -> flatten -> linear6 -> relu7 -> dropout, returning the 512-d
        feature vector. A Layer assigned to the instance later would presumably
        also be registered and applied here.
        """
        for module in self.children():
            x = module(x)
        return x
    
    def get_feature(self, x):
        """Return the convolutional feature map.

        Runs all registered sub-layers except the last four. With the layers in
        __init__ this stops after layer3, skipping flatten, linear6, relu7 and
        dropout. The slice is positional, so it shifts if sub-layers are added.
        """
        for module in list(self.children())[:-4]:
            x = module(x)
        return x


class BadNet(paddle.nn.Layer):
    """LeNet-style classifier: two conv/avg-pool stages followed by two linear layers.

    ``fc1`` expects 800 inputs for 3-channel images and 512 otherwise. Those sizes
    match 32x32 RGB input (32*5*5) and 28x28 single-channel input (32*4*4).

    The final stage applies softmax, so ``forward`` returns class probabilities
    rather than logits.

    Args:
        input_channels: number of input image channels.
        output_num: number of output classes.
    """

    def __init__(self, input_channels, output_num):
        super().__init__()
        self.conv1 = paddle.nn.Sequential(
            paddle.nn.Conv2D(in_channels=input_channels, out_channels=16, kernel_size=5, stride=1),
            paddle.nn.ReLU(),
            paddle.nn.AvgPool2D(kernel_size=2, stride=2)
        )

        self.conv2 = paddle.nn.Sequential(
            paddle.nn.Conv2D(in_channels=16, out_channels=32, kernel_size=5, stride=1),
            paddle.nn.ReLU(),
            paddle.nn.AvgPool2D(kernel_size=2, stride=2)
        )
        fc1_input_features = 800 if input_channels == 3 else 512
        self.fc1 = paddle.nn.Sequential(
            paddle.nn.Linear(in_features=fc1_input_features, out_features=512),
            paddle.nn.ReLU()
        )
        self.fc2 = paddle.nn.Sequential(
            paddle.nn.Linear(in_features=512, out_features=output_num),
            paddle.nn.Softmax(axis=-1)
        )
        # Defined but not applied in forward(); kept so the layer's attributes are unchanged.
        self.dropout = nn.Dropout(p=.5)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        # The PyTorch idiom ``x.view(x.size(0), -1)`` is not valid Paddle: ``Tensor.size``
        # is the element count (a property), not a method. Flatten all non-batch axes instead.
        x = paddle.flatten(x, start_axis=1)
        x = self.fc1(x)
        return self.fc2(x)


class GaussianClassifier(paddle.nn.Layer):
    """Distance-based classifier with one learnable prototype vector per class.

    Args:
        num_classes: number of class prototypes (default 10, as before).
        feature_dim: length of each prototype and of each input feature vector
            (default ``48 * 4 * 4``, as before).

    ``forward(x)`` returns the squared Euclidean distance from each row of ``x``
    to every prototype. Smaller values mean closer.
    """

    def __init__(self, num_classes=10, feature_dim=48 * 4 * 4):
        super(GaussianClassifier, self).__init__()
        self.embed = nn.Embedding(num_classes, feature_dim)

    def forward(self, x):
        # assumes x is already flattened to [batch, feature_dim] -- TODO confirm at call sites
        diff = x.unsqueeze(1) - self.embed.weight.unsqueeze(0)  # [batch, num_classes, feature_dim]
        return paddle.square(diff).sum(axis=-1)  # [batch, num_classes]


class AutoEncoder_MNIST(paddle.nn.Layer):
    """Convolutional autoencoder with a linear classifier on the encoder output.

    The per-layer shape comments hold for a [batch, 1, 28, 28] input. Under that
    input, the reconstruction is also [batch, 1, 28, 28]:
        encoder: 1x28x28 -> 12x16x16 -> 24x8x8 -> 48x4x4
        decoder: 48x4x4 -> 24x8x8 -> 12x16x16 -> 1x28x28 (sigmoid output)

    Args:
        cls_num: number of classifier outputs (default 10, as before).
    """

    def __init__(self, cls_num=10):
        super(AutoEncoder_MNIST, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2D(1, 12, 4, stride=2, padding=3),            # [batch, 12, 16, 16]
            nn.ReLU(),
            nn.Conv2D(12, 24, 4, stride=2, padding=1),           # [batch, 24, 8, 8]
            nn.ReLU(),
            nn.Conv2D(24, 48, 4, stride=2, padding=1),           # [batch, 48, 4, 4]
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Conv2DTranspose(48, 24, 4, stride=2, padding=1),  # [batch, 24, 8, 8]
            nn.ReLU(),
            nn.Conv2DTranspose(24, 12, 4, stride=2, padding=1),  # [batch, 12, 16, 16]
            nn.ReLU(),
            # (16 - 1) * 2 - 2 * 3 + 4 = 28, so the output matches the 28x28 input (not 32x32).
            nn.Conv2DTranspose(12, 1, 4, stride=2, padding=3),   # [batch, 1, 28, 28]
            nn.Sigmoid(),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(48 * 4 * 4, cls_num),
        )

    def get_feature(self, x):
        """Return the encoder output (un-flattened feature map)."""
        return self.encoder(x)

    def decode(self, f):
        """Reconstruct an image from an encoder feature map ``f``."""
        return self.decoder(f)

    def forward(self, x):
        """Return ``(pred, x_hat)``: classifier outputs and the reconstruction of ``x``."""
        f = self.get_feature(x)
        pred = self.classifier(f)
        x_hat = self.decode(f)
        return pred, x_hat


class Unflatten(nn.Layer):
    """Reshape a flat per-sample tensor to ``[batch, input_channels, height, width]``.

    The batch size is taken from ``x.shape[0]``. The remaining elements must
    total ``input_channels * height * width`` per sample.
    """

    def __init__(self, input_channels, height, width):
        super().__init__()
        self.input_channels, self.height, self.width = input_channels, height, width

    def forward(self, x):
        target_shape = [x.shape[0], self.input_channels, self.height, self.width]
        return x.reshape(target_shape)


class AutoencoderCifar(nn.Layer):
    """Convolutional autoencoder for [batch, 3, 32, 32] images with an MLP classifier on the code.

    Shapes (batch dim omitted):
        encoder: 3x32x32 -> 12x16x16 -> 24x8x8 -> 48x4x4 (flattened: ``z_dim`` = 768)
        decoder: 48x4x4 -> 24x8x8 -> 12x16x16 -> 3x32x32 (sigmoid output)

    Args:
        cls_num: number of classes produced by ``class_linear``.
    """

    def __init__(self, cls_num=10):
        super(AutoencoderCifar, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2D(3, 12, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2D(12, 24, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2D(24, 48, 4, stride=2, padding=1),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Conv2DTranspose(48, 24, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2DTranspose(24, 12, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2DTranspose(12, 3, 4, stride=2, padding=1),
            nn.Sigmoid(),
        )

        # Length of the flattened encoder output (48 * 4 * 4).
        self.z_dim = 768

        self.class_linear = nn.Sequential(
            nn.Linear(in_features=self.z_dim, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=cls_num),
        )

    def forward(self, x):
        """Return ``(decoded, logits)``: the reconstruction of ``x`` and class logits."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        z = paddle.flatten(encoded, 1)
        logits = self.class_linear(z)
        return decoded, logits

    def intervention_z(self, x, ind, value=0.):
        """Overwrite entries of the flattened latent code, then classify and reconstruct.

        Args:
            x: input batch.
            ind: column index into the flattened code, paired with row indices
                ``range(batch)``. It is presumably either a single int or one index
                per sample (length ``batch``) -- TODO confirm with callers.
            value: value written at the selected positions.

        Returns:
            ``(r, logits)``: the reconstruction decoded from the intervened code and
            the class logits computed from it.
        """
        encoded = self.encoder(x)
        # paddle.flatten shares storage with its input in dygraph mode. Without the
        # clone, the write below would silently mutate ``encoded``, a ReLU output
        # that autograd may need.
        z = paddle.flatten(encoded, 1).clone()
        z[list(range(x.shape[0])), ind] = value
        logits = self.class_linear(z)
        # Decode the intervened code explicitly instead of relying on aliasing.
        r = self.decoder(z.reshape(encoded.shape))
        return r, logits


class AutoencoderMnist(nn.Layer):
    """Convolutional autoencoder for [batch, 1, 28, 28] images with an MLP classifier on the code.

    Shapes (batch dim omitted):
        encoder: 1x28x28 -> 12x14x14 -> 24x7x7 -> 48x2x2 (flattened: ``z_dim`` = 192)
        decoder: 48x2x2 -> 24x6x6 -> 12x14x14 -> 1x28x28 (sigmoid output)

    Args:
        cls_num: number of classes produced by ``class_linear``.
    """

    def __init__(self, cls_num=10):
        super(AutoencoderMnist, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2D(1, 12, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2D(12, 24, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2D(24, 48, 4, stride=2, padding=0),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.Conv2DTranspose(48, 24, 4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2DTranspose(24, 12, 4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2DTranspose(12, 1, 4, stride=2, padding=1),
            nn.Sigmoid())

        # Length of the flattened encoder output (48 * 2 * 2).
        self.z_dim = 192

        self.class_linear = nn.Sequential(
            nn.Linear(in_features=self.z_dim, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=cls_num),
        )

    def forward(self, x):
        """Return ``(decoded, logits)``: the reconstruction of ``x`` and class logits."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        z = paddle.flatten(encoded, 1)
        logits = self.class_linear(z)
        return decoded, logits

    def intervention_z(self, x, ind, value=0.):
        """Overwrite entries of the flattened latent code, then classify and reconstruct.

        Args:
            x: input batch.
            ind: column index into the flattened code, paired with row indices
                ``range(batch)``. It is presumably either a single int or one index
                per sample (length ``batch``) -- TODO confirm with callers.
            value: value written at the selected positions.

        Returns:
            ``(r, logits)``: the reconstruction decoded from the intervened code and
            the class logits computed from it.
        """
        encoded = self.encoder(x)
        # paddle.flatten shares storage with its input in dygraph mode. Without the
        # clone, the write below would silently mutate ``encoded``, a ReLU output
        # that autograd may need.
        z = paddle.flatten(encoded, 1).clone()
        z[list(range(x.shape[0])), ind] = value
        logits = self.class_linear(z)
        # Decode the intervened code explicitly instead of relying on aliasing.
        r = self.decoder(z.reshape(encoded.shape))
        return r, logits
