import paddle
import paddle.nn.functional as F
import paddle.nn as nn
from paddle.vision.models import resnet50
# from .blocks import *


class AutoEncoder_MNIST(paddle.nn.Layer):
    """Convolutional autoencoder for MNIST with a linear classifier head.

    Shapes for a ``[batch, 1, 28, 28]`` input:
    encoder -> ``[batch, 48, 4, 4]``, decoder -> ``[batch, 1, 28, 28]``,
    classifier -> ``[batch, cls_num]``.

    Args:
        cls_num (int): number of classes predicted by the linear head
            (defaults to 10, the previously hard-coded value).
    """

    def __init__(self, cls_num=10):
        super().__init__()
        self.encoder = nn.Sequential(
            # padding=3 turns 28x28 into (28 + 6 - 4) / 2 + 1 = 16x16.
            nn.Conv2D(1, 12, 4, stride=2, padding=3),            # [batch, 12, 16, 16]
            nn.ReLU(),
            nn.Conv2D(12, 24, 4, stride=2, padding=1),           # [batch, 24, 8, 8]
            nn.ReLU(),
            nn.Conv2D(24, 48, 4, stride=2, padding=1),           # [batch, 48, 4, 4]
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Conv2DTranspose(48, 24, 4, stride=2, padding=1),  # [batch, 24, 8, 8]
            nn.ReLU(),
            nn.Conv2DTranspose(24, 12, 4, stride=2, padding=1),  # [batch, 12, 16, 16]
            nn.ReLU(),
            # (16 - 1) * 2 - 2 * 3 + 4 = 28, i.e. back to the input resolution.
            nn.Conv2DTranspose(12, 1, 4, stride=2, padding=3),   # [batch, 1, 28, 28]
            nn.Sigmoid(),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            # GaussianClassifier()
            nn.Linear(48 * 4 * 4, cls_num),
        )

    def get_feature(self, x):
        """Return the encoder feature map for ``x`` (``[batch, 48, 4, 4]`` for 28x28 input)."""
        return self.encoder(x)

    def decode(self, f):
        """Decode an encoder feature map ``f`` back to image space (sigmoid output)."""
        return self.decoder(f)

    def forward(self, x):
        """Encode ``x``, classify the features and reconstruct the input.

        Returns:
            tuple: ``(pred, x_hat)`` -- the class logits and the reconstruction.
            Note the order is the reverse of the ``Autoencoder*`` classes,
            which return ``(decoded, logits)``.
        """
        f = self.get_feature(x)
        pred = self.classifier(f)
        x_hat = self.decode(f)
        return pred, x_hat


class Unflatten(nn.Layer):
    """Reshape a tensor to ``[batch, input_channels, height, width]``.

    The batch size is taken from the first dimension of the incoming tensor.
    The channel, height and width are fixed at construction time. The input
    must therefore contain exactly ``batch * input_channels * height * width``
    elements.
    """

    def __init__(self, input_channels, height, width):
        super(Unflatten, self).__init__()
        self.input_channels, self.height, self.width = input_channels, height, width

    def forward(self, x):
        batch = x.shape[0]
        target_shape = [batch, self.input_channels, self.height, self.width]
        return x.reshape(target_shape)


class AutoencoderCifar(nn.Layer):
    """Convolutional autoencoder with an MLP classifier on the latent code.

    Input size: ``[batch, 3, 32, 32]``; output size: ``[batch, 3, 32, 32]``.
    The encoder produces a ``[batch, 48, 4, 4]`` feature map, which is
    flattened into a ``self.z_dim = 768``-dimensional latent for the classifier.

    Args:
        cls_num (int): number of classes predicted by ``class_linear``.
    """

    def __init__(self, cls_num=10):
        super().__init__()
        # Each Conv2D(k=4, s=2, p=1) halves the spatial size: 32 -> 16 -> 8 -> 4.
        self.encoder = nn.Sequential(
            nn.Conv2D(3, 12, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2D(12, 24, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2D(24, 48, 4, stride=2, padding=1),
            nn.ReLU(),
        )
        # Mirror of the encoder: 4 -> 8 -> 16 -> 32, sigmoid output.
        self.decoder = nn.Sequential(
            nn.Conv2DTranspose(48, 24, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2DTranspose(24, 12, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2DTranspose(12, 3, 4, stride=2, padding=1),
            nn.Sigmoid(),
        )

        # Flattened latent size: 48 channels * 4 * 4.
        self.z_dim = 768

        self.class_linear = nn.Sequential(
            nn.Linear(in_features=self.z_dim, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=cls_num),
        )

    def forward(self, x):
        """Reconstruct ``x`` and classify its latent code.

        Returns:
            tuple: ``(decoded, logits)`` -- the sigmoid reconstruction and the
            ``[batch, cls_num]`` class logits.
        """
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        z = paddle.flatten(encoded, 1)
        logits = self.class_linear(z)
        return decoded, logits

    def intervention_z(self, x, ind, value=0.):
        """Overwrite selected latent units, then run both heads on the result.

        Args:
            x: input batch, as accepted by ``forward``.
            ind: index into the flattened latent (size ``self.z_dim``), either
                a single index applied to every sample or one index per sample.
            value: value written into the selected latent units.

        Returns:
            tuple: ``(r, logits)`` -- the reconstruction decoded from the
            intervened latent and the class logits computed from it.
        """
        encoded = self.encoder(x)
        # Work on a copy: paddle.flatten shares storage with its input in
        # dygraph mode, so writing into it directly would silently mutate
        # ``encoded`` in place.
        z = paddle.flatten(encoded, 1).clone()
        z[list(range(x.shape[0])), ind] = value
        logits = self.class_linear(z)
        # Decode the intervened latent explicitly (instead of relying on the
        # storage aliasing) so reconstruction and logits see the same code.
        r = self.decoder(z.reshape(encoded.shape))
        return r, logits


class AutoencoderMnist(nn.Layer):
    """Convolutional autoencoder with an MLP classifier on the latent code.

    Input size: ``[batch, 1, 28, 28]``; output size: ``[batch, 1, 28, 28]``.
    The encoder produces a ``[batch, 48, 2, 2]`` feature map, which is
    flattened into a ``self.z_dim = 192``-dimensional latent for the classifier.

    Args:
        cls_num (int): number of classes predicted by ``class_linear``.
    """

    def __init__(self, cls_num=10):
        super().__init__()
        # Spatial sizes: 28 -> 14 -> 7 -> 2 (the last conv has no padding,
        # so (7 - 4) / 2 + 1 floors to 2).
        self.encoder = nn.Sequential(
            nn.Conv2D(1, 12, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2D(12, 24, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2D(24, 48, 4, stride=2, padding=0),
            nn.ReLU(),
        )
        # Spatial sizes: 2 -> 6 -> 14 -> 28, sigmoid output.
        self.decoder = nn.Sequential(
            nn.Conv2DTranspose(48, 24, 4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2DTranspose(24, 12, 4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2DTranspose(12, 1, 4, stride=2, padding=1),
            nn.Sigmoid(),
        )

        # Flattened latent size: 48 channels * 2 * 2.
        self.z_dim = 192

        self.class_linear = nn.Sequential(
            nn.Linear(in_features=self.z_dim, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=cls_num),
        )

    def forward(self, x):
        """Reconstruct ``x`` and classify its latent code.

        Returns:
            tuple: ``(decoded, logits)`` -- the sigmoid reconstruction and the
            ``[batch, cls_num]`` class logits.
        """
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        z = paddle.flatten(encoded, 1)
        logits = self.class_linear(z)
        return decoded, logits

    def intervention_z(self, x, ind, value=0.):
        """Overwrite selected latent units, then run both heads on the result.

        Args:
            x: input batch, as accepted by ``forward``.
            ind: index into the flattened latent (size ``self.z_dim``), either
                a single index applied to every sample or one index per sample.
            value: value written into the selected latent units.

        Returns:
            tuple: ``(r, logits)`` -- the reconstruction decoded from the
            intervened latent and the class logits computed from it.
        """
        encoded = self.encoder(x)
        # Work on a copy: paddle.flatten shares storage with its input in
        # dygraph mode, so writing into it directly would silently mutate
        # ``encoded`` in place.
        z = paddle.flatten(encoded, 1).clone()
        z[list(range(x.shape[0])), ind] = value
        logits = self.class_linear(z)
        # Decode the intervened latent explicitly (instead of relying on the
        # storage aliasing) so reconstruction and logits see the same code.
        r = self.decoder(z.reshape(encoded.shape))
        return r, logits


class AutoencoderCeleba(nn.Layer):
    """Convolutional autoencoder with an MLP classifier on the latent code.

    Input size: ``[batch, 3, 64, 64]``; output size: ``[batch, 3, 64, 64]``.
    The encoder produces a ``[batch, 48, 8, 8]`` feature map, which is
    flattened into a ``self.z_dim = 3072``-dimensional latent for the classifier.

    Args:
        cls_num (int): number of outputs of ``class_linear`` (default 8).
    """

    def __init__(self, cls_num=8):
        super().__init__()
        # Each Conv2D(k=4, s=2, p=1) halves the spatial size: 64 -> 32 -> 16 -> 8.
        self.encoder = nn.Sequential(
            nn.Conv2D(3, 12, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2D(12, 24, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2D(24, 48, 4, stride=2, padding=1),
            nn.ReLU(),
        )
        # Mirror of the encoder: 8 -> 16 -> 32 -> 64, sigmoid output.
        self.decoder = nn.Sequential(
            nn.Conv2DTranspose(48, 24, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2DTranspose(24, 12, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2DTranspose(12, 3, 4, stride=2, padding=1),
            nn.Sigmoid(),
        )

        # Flattened latent size: 48 channels * 8 * 8.
        self.z_dim = 3072

        self.class_linear = nn.Sequential(
            nn.Linear(in_features=self.z_dim, out_features=1024),
            nn.ReLU(),
            nn.Linear(in_features=1024, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=cls_num),
        )

    def forward(self, x):
        """Reconstruct ``x`` and classify its latent code.

        Returns:
            tuple: ``(decoded, logits)`` -- the sigmoid reconstruction and the
            ``[batch, cls_num]`` logits.
        """
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        z = paddle.flatten(encoded, 1)
        logits = self.class_linear(z)
        return decoded, logits

    def intervention_z(self, x, ind, value=0.):
        """Overwrite selected latent units, then run both heads on the result.

        Args:
            x: input batch, as accepted by ``forward``.
            ind: index into the flattened latent (size ``self.z_dim``), either
                a single index applied to every sample or one index per sample.
            value: value written into the selected latent units.

        Returns:
            tuple: ``(r, logits)`` -- the reconstruction decoded from the
            intervened latent and the logits computed from it.
        """
        encoded = self.encoder(x)
        # Work on a copy: paddle.flatten shares storage with its input in
        # dygraph mode, so writing into it directly would silently mutate
        # ``encoded`` in place.
        z = paddle.flatten(encoded, 1).clone()
        z[list(range(x.shape[0])), ind] = value
        logits = self.class_linear(z)
        # Decode the intervened latent explicitly (instead of relying on the
        # storage aliasing) so reconstruction and logits see the same code.
        r = self.decoder(z.reshape(encoded.shape))
        return r, logits
