import torch
import torch.nn as nn
from xad.models.bases import ConditionalDiscriminator, ConditionalGAN


class CGAN64(ConditionalGAN):
    """Conditional GAN whose encoder and decoder are sized for 3-channel 64x64 images."""

    class Encoder(ConditionalGAN.Encoder):
        """Strided-convolution encoder that maps an image batch to a latent vector."""

        def __init__(self, latent_dim: int = 512):
            """Build the convolutional stack followed by a linear projection.

            Args:
                latent_dim: Size of the produced latent vector.
            """
            super().__init__()
            # Four stride-2 stages, each halving the spatial resolution.
            widths = (3, 32, 64, 128, 256)
            stages = []
            for n_in, n_out in zip(widths[:-1], widths[1:]):
                stages.extend((
                    nn.Conv2d(n_in, n_out, 3, stride=2, padding=1),
                    nn.BatchNorm2d(n_out),
                    nn.LeakyReLU(inplace=True),
                ))
            # 4096 = 256 channels * 4 * 4, the flattened feature map of a 64x64 input.
            stages.append(nn.Flatten())
            stages.append(nn.Linear(4096, latent_dim))
            self.encoder = nn.Sequential(*stages)

        def forward(self, x: torch.Tensor, c: torch.Tensor = None):
            """Encode ``x``; the condition ``c`` is accepted but not used here."""
            return self.encoder(x)

    class Decoder(ConditionalGAN.Decoder):
        """Transposed-convolution decoder driven by a latent vector and a condition."""

        def __init__(self, latent_dim: int, condition_shape: torch.Size):
            """Build the condition embedding, the linear stem and the upsampling stack.

            Args:
                latent_dim: Size of the latent vector and of the condition embedding.
                condition_shape: Its ``numel()`` sets the number of embedding rows.
            """
            super().__init__()
            n_conditions = condition_shape.numel()
            self.encoder_condition = nn.Sequential(
                nn.Embedding(n_conditions, latent_dim),
            )
            # Latent and condition embedding are concatenated, hence 2 * latent_dim.
            self.decoder_lin = nn.Sequential(
                nn.Linear(latent_dim * 2, 256 * 2 ** 2),
            )
            # Five stride-2 stages take the 2x2 seed map up to 64x64.
            widths = (256, 512, 256, 128, 64, 16)
            upsampling = [nn.Unflatten(1, (256, 2, 2))]
            for n_in, n_out in zip(widths[:-1], widths[1:]):
                upsampling.extend((
                    nn.ConvTranspose2d(n_in, n_out, 3, stride=2, padding=1, output_padding=1),
                    nn.BatchNorm2d(n_out),
                    nn.LeakyReLU(inplace=True),
                ))
            # Resolution-preserving projection down to 3 output channels.
            upsampling.append(
                nn.ConvTranspose2d(16, 3, 3, stride=1, padding=1, output_padding=0)
            )
            self.decoder_conv = nn.Sequential(*upsampling)

        def forward(self, z: torch.Tensor, condition: torch.Tensor):
            """Decode ``z`` together with the embedded ``condition``.

            The embedding of ``condition`` is concatenated to ``z`` along dim 1.
            No output activation is applied; the last layer's raw values are returned.
            """
            embedded = self.encoder_condition(condition)
            joint = torch.cat([z, embedded], dim=1)
            hidden = self.decoder_lin(joint)
            return self.decoder_conv(hidden)

    def __init__(self, latent_dim: int, condition_shape: torch.Size):
        """Create the encoder/decoder pair for the given latent size and condition shape."""
        super().__init__(latent_dim, condition_shape)
        self.encoder = CGAN64.Encoder(latent_dim)
        self.decoder = CGAN64.Decoder(latent_dim, condition_shape)


class CDisc64(ConditionalDiscriminator):
    """Conditional discriminator for 3-channel images (e.g. 64x64 inputs).

    An image is encoded to a ``latent_dim`` feature vector. The score is the sum
    of an unconditional linear prediction on that vector and a conditional term
    built from inner products between the vector and label embeddings.
    """

    def __init__(self, latent_dim: int, condition_shape: torch.Size):
        """Build the encoder, the label embedding table and the linear predictor.

        Args:
            latent_dim: Size of the encoded feature vector.
            condition_shape: Shape of the condition; its ``numel()`` determines
                how many label columns ``ordinal_predictions`` consumes.
        """
        super().__init__(latent_dim, condition_shape)
        # NOTE: self.latent_dim / self.condition_shape are read below and are
        # expected to have been stored by the base-class constructor.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(inplace=True),
            nn.Flatten(),
            # 4096 = 256 channels * 4 * 4 after four stride-2 convolutions.
            nn.Linear(4096, self.latent_dim),
            nn.LeakyReLU(inplace=True)
        )
        # Two rows: every label entry looked up here must be 0 or 1.
        self.condition_embeddings = nn.Embedding(2, self.latent_dim)
        self.predictor = nn.Linear(self.latent_dim, 1)

    def parameterize(self):
        """Wrap the embedding table and the predictor with spectral normalization."""
        self.condition_embeddings = torch.nn.utils.parametrizations.spectral_norm(self.condition_embeddings)
        self.predictor = torch.nn.utils.parametrizations.spectral_norm(self.predictor)

    def ordinal_predictions(self, encoding: torch.Tensor, ordinal_lbls: torch.Tensor) -> torch.Tensor:
        """Return the conditional part of the score.

        For each of the first ``condition_shape.numel() - 1`` columns of
        ``ordinal_lbls``, the column's label embedding is multiplied element-wise
        with ``encoding`` and summed over the last dimension; the per-column
        results are added up.

        Args:
            encoding: Encoded features whose last dimension is ``latent_dim``.
            ordinal_lbls: Integer labels indexed as ``ordinal_lbls[:, i]``;
                values must be valid rows of the 2-row embedding table.

        Returns:
            The accumulated inner products, or the integer ``0`` when
            ``condition_shape.numel() - 1`` is not positive.
        """
        out = 0
        for i in range(self.condition_shape.numel() - 1):
            out += (self.condition_embeddings(ordinal_lbls[:, i]) * encoding).sum(-1)
        return out

    def forward(self, x: torch.Tensor, condition: torch.Tensor):
        """Score a batch of images under the given condition.

        Args:
            x: Image batch fed to the convolutional encoder.
            condition: Labels forwarded to ``ordinal_predictions``.

        Returns:
            One score per sample, with the batch dimension always kept.
        """
        z = self.encoder(x)
        # Drop only the trailing singleton dimension. A bare squeeze() would also
        # collapse a batch of size 1 to a 0-d tensor, which cannot then be
        # combined in place with the (1,)-shaped conditional term.
        y = self.predictor(z).squeeze(-1)
        y = y + self.ordinal_predictions(z, condition)
        return y
