import torch
import torch.nn as nn
import torch.nn.functional as F
from ..blocks import StyleGan, ResidualStack


class BiGeneratorImgToImg(nn.Module):
    """Image-to-image generator: an MLP encoder followed by a ``StyleGan`` block.

    The input is flattened, projected to a latent vector of size
    ``min(512, input_shape.numel())``, passed through four
    ``Linear + LeakyReLU(0.2)`` layers and finally handed to ``StyleGan``,
    which is constructed with the latent size and ``output_shape``.
    """

    def __init__(self, input_shape, output_shape, sigmoid=False):
        """Build the encoder MLP and the ``StyleGan`` block.

        Args:
            input_shape (torch.Size or sequence of int): shape of one
                conditioning sample, without the batch dimension. Must have
                more than two dimensions and a non-zero first entry.
            output_shape (torch.Size or sequence of int): shape of one
                generated sample, same constraints as ``input_shape``.
            sigmoid (bool): if True, ``forward`` applies a sigmoid to the
                output of the ``StyleGan`` block.

        Raises:
            ValueError: if either shape does not look like an image
                (two or fewer dimensions, or a leading dimension of 0).
        """
        super(BiGeneratorImgToImg, self).__init__()
        # Normalise to torch.Size so plain tuples/lists work as well.
        self.input_shape = torch.Size(input_shape)
        self.output_shape = torch.Size(output_shape)
        self._input_dim = self.input_shape.numel()
        self._output_dim = self.output_shape.numel()
        self._sigmoid = sigmoid

        # Validate before building any layer so that a non-image shape
        # always fails with the intended message rather than with an
        # unrelated tuple-unpacking error.
        in_channels = self._leading_channels(self.input_shape)
        out_channels = self._leading_channels(self.output_shape)
        if in_channels == 0 or out_channels == 0:
            raise ValueError("This generator should only be used with images")

        latent_dim = min(512, self._input_dim)

        to_latent = [nn.Flatten(), nn.Linear(self._input_dim, latent_dim)]
        linear = [nn.Sequential(nn.Linear(latent_dim, latent_dim),
                                nn.LeakyReLU(negative_slope=0.2))
                  for _ in range(4)]
        self._pre_stylegan = nn.Sequential(*(to_latent + linear))

        self._ff_style_gan = StyleGan(torch.Size([latent_dim]),
                                      self.output_shape)

    @staticmethod
    def _leading_channels(shape):
        """Return ``shape[0]`` for shapes with more than 2 dims, else 0."""
        return shape[0] if len(shape) > 2 else 0

    def forward(self, y):
        """Generate an output conditioned on ``y``.

        Args:
            y (tensor): batch of conditioning inputs; it is flattened per
                sample, so each sample must hold ``input_shape.numel()``
                values.

        Returns:
            tensor: output of the ``StyleGan`` block, passed through a
            sigmoid when the module was built with ``sigmoid=True``.
        """
        y = self._pre_stylegan(y)
        y = self._ff_style_gan(y)
        if self._sigmoid:
            # torch.sigmoid replaces the deprecated F.sigmoid.
            y = torch.sigmoid(y)
        return y

    def visualize_deconv(self):
        """Delegate to the ``StyleGan`` block's ``visualize_deconv``."""
        self._ff_style_gan.visualize_deconv()

    def visuzalize_deconv(self):
        """Misspelled legacy name of :meth:`visualize_deconv`.

        Kept so that existing callers continue to work.
        """
        self.visualize_deconv()


class BiGeneratorResidual(nn.Module):
    """Generator built from a ``ResidualStack`` followed by linear layers.

    The conditioning input is reshaped to ``input_shape`` (padded with
    leading singleton dimensions up to three dims), passed through
    ``ResidualStack(input_shape, 32, 8)``, flattened, and mapped by three
    square ``Linear`` layers plus a final ``Linear`` onto
    ``output_shape.numel()`` values.
    """

    def __init__(self, input_shape, output_shape, sigmoid=False):
        """Build the residual stack and the linear head.

        Args:
            input_shape (torch.Size or sequence of int): shape of one
                conditioning sample, without the batch dimension. 1-D and
                2-D shapes are left-padded with ones to three dimensions.
            output_shape (torch.Size or sequence of int): shape of one
                generated sample, without the batch dimension.
            sigmoid (bool): if True, ``forward`` applies a sigmoid to the
                final output.
        """
        super(BiGeneratorResidual, self).__init__()
        # Normalise to torch.Size so plain tuples/lists work as well.
        input_shape = torch.Size(input_shape)
        self.output_shape = torch.Size(output_shape)
        self._sigmoid = sigmoid
        if len(input_shape) == 2:
            self.input_shape = torch.Size((1,) + tuple(input_shape))
        elif len(input_shape) == 1:
            self.input_shape = torch.Size((1, 1) + tuple(input_shape))
        else:
            self.input_shape = input_shape

        in_features = self.input_shape.numel()
        # NOTE: the three square Linear layers have no activation between
        # them; kept as-is to stay compatible with existing parameters.
        self._ff = nn.Sequential(ResidualStack(self.input_shape, 32, 8),
                                 nn.Flatten(),
                                 *[nn.Linear(in_features, in_features)
                                   for _ in range(3)],
                                 nn.Linear(in_features,
                                           self.output_shape.numel()))

    def forward(self, y):
        """Generate an output conditioned on ``y``.

        Args:
            y (tensor): batch of conditioning inputs; the first dimension
                is the batch and each sample must hold
                ``input_shape.numel()`` values.

        Returns:
            tensor: tensor of shape ``(batch, *output_shape)``, passed
            through a sigmoid when the module was built with
            ``sigmoid=True``.
        """
        batch_size = y.size(0)
        # reshape (not view) so non-contiguous inputs, e.g. transposed
        # tensors, are accepted instead of raising a RuntimeError.
        y = y.reshape(batch_size, *self.input_shape)
        y = self._ff(y).reshape(batch_size, *self.output_shape)
        if self._sigmoid:
            # torch.sigmoid replaces the deprecated F.sigmoid.
            y = torch.sigmoid(y)
        return y
