import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from .residual import ResidualStack
from .stylegan import StyleGan


class GeneratorStyleGan(nn.Module):
    """Image generator that feeds style vectors through a ``StyleGan`` network.

    Style vectors are drawn from an element-wise ``Normal(mean, std)`` prior.
    ``mean`` (zeros) and ``std`` (ones) have shape ``[style_dim]`` and are
    stored as non-trainable parameters of the module.
    """

    def __init__(self, output_shape, style_dim=512, sigmoid=False):
        """Build the generator.

        Args:
            output_shape (torch.Size): shape of one generated sample. It must
                have more than two dimensions and a non-zero first entry
                (the first entry is read as the channel count).
            style_dim (int): length of a style vector.
            sigmoid (bool): when True, a sigmoid is applied to the output of
                the underlying network.

        Raises:
            ValueError: if ``output_shape`` does not describe an image.
        """
        super(GeneratorStyleGan, self).__init__()
        self._style_shape = torch.Size([style_dim])
        self._output_shape = output_shape
        self._output_dim = self._output_shape.numel()
        self._sigmoid = sigmoid

        # Only image-like shapes are accepted: at least three dimensions
        # and a non-zero leading (channel) dimension.
        if len(self._output_shape) <= 2 or self._output_shape[0] == 0:
            raise ValueError("This generator should only be used with images")

        self._ff_style_gan = StyleGan(self._style_shape,
                                      self._output_shape)

        # Stored as parameters (with gradients disabled) so they are part of
        # the module's registered state.
        self.mean = nn.Parameter(torch.zeros(self._style_shape),
                                 requires_grad=False)
        self.std = nn.Parameter(torch.ones(self._style_shape),
                                requires_grad=False)
        self._style_sampler = Normal(self.mean, self.std)

    def sample(self, batch_shape=torch.Size([1])):
        """Draw style vectors from the prior and generate from them.

        Args:
            batch_shape (torch.Size): leading shape of the drawn styles; the
                sampled tensor has shape ``batch_shape + [style_dim]``.

        Returns:
            The result of ``forward`` on the sampled styles.
        """
        styles = self._style_sampler.sample(batch_shape)
        return self(styles)

    def forward(self, y):
        """Returns p(x|y)

        Args:
            y (tensor): tensor on which we condition on (the style vectors
                handed to the underlying ``StyleGan`` network).

        Returns:
            The network output, passed through a sigmoid when the generator
            was built with ``sigmoid=True``.
        """
        y = self._ff_style_gan(y)
        if self._sigmoid:
            # torch.sigmoid is the canonical spelling; the functional alias
            # emits a deprecation warning on many PyTorch releases.
            y = torch.sigmoid(y)
        return y

    def visualize_deconv(self):
        """Delegate to ``visualize_deconv`` of the underlying network."""
        self._ff_style_gan.visualize_deconv()

    def visuzalize_deconv(self):
        """Misspelled alias of ``visualize_deconv`` kept for existing callers."""
        self.visualize_deconv()


class GeneratorResidual(nn.Module):
    """Generator built from a ``ResidualStack`` followed by linear layers.

    The network is ``ResidualStack(base_shape, 32, 8)`` -> ``Flatten`` ->
    three ``Linear`` layers of width ``base_shape.numel()`` (stacked without
    any non-linearity in between) -> one ``Linear`` layer producing
    ``output_shape.numel()`` values, which ``forward`` reshapes to
    ``output_shape``.

    Inputs are drawn from an element-wise ``Normal(mean, std)`` prior of
    shape ``base_shape``; ``mean`` (zeros) and ``std`` (ones) are stored as
    non-trainable parameters.
    """

    def __init__(self, output_shape, base_shape=torch.Size([1, 16, 16]),
                 sigmoid=False):
        """Build the generator.

        Args:
            output_shape (torch.Size): shape of one generated sample.
            base_shape (torch.Size): shape of one latent input sample.
            sigmoid (bool): when True, a sigmoid is applied to the output.
        """
        super(GeneratorResidual, self).__init__()
        self.output_shape = output_shape
        self._sigmoid = sigmoid
        # NOTE(review): the linear layers assume the flattened output of
        # ResidualStack has base_shape.numel() features -- confirm against
        # the ResidualStack implementation.
        self._ff = nn.Sequential(ResidualStack(base_shape, 32, 8),
                                 nn.Flatten(),
                                 *[nn.Linear(base_shape.numel(),
                                             base_shape.numel())
                                   for _ in range(3)],
                                 nn.Linear(base_shape.numel(),
                                           self.output_shape.numel()))

        self.mean = nn.Parameter(torch.zeros(base_shape), requires_grad=False)
        self.std = nn.Parameter(torch.ones(base_shape), requires_grad=False)
        self._style_sampler = Normal(self.mean, self.std)

    def sample(self, batch_shape=torch.Size([1])):
        """Draw latent inputs from the prior and generate from them.

        Args:
            batch_shape (torch.Size): leading shape of the drawn latents; the
                sampled tensor has shape ``batch_shape + base_shape``.

        Returns:
            The result of ``forward`` on the sampled latents.
        """
        styles = self._style_sampler.sample(batch_shape)
        return self(styles)

    def forward(self, y):
        """Returns p(x|y)

        Args:
            y (tensor): tensor on which we condition on; its first dimension
                is used as the batch size.

        Returns:
            Tensor of shape ``(batch_size, *output_shape)``, passed through a
            sigmoid when the generator was built with ``sigmoid=True``.
        """
        batch_size = y.size(0)
        y = self._ff(y).view(batch_size, *self.output_shape)
        if self._sigmoid:
            # torch.sigmoid is the canonical spelling; the functional alias
            # emits a deprecation warning on many PyTorch releases.
            y = torch.sigmoid(y)
        return y


class GeneratorImprovedMNISTbase(nn.Module):
    """
    credit to https://github.com/caogang/wgan-gp/blob/master/gan_mnist.py
    (code copied and pasted almost exactly)

    Maps latent vectors of length ``LATENT_DIM`` to single-channel 28x28
    images with values in (0, 1). An optional normalisation layer (batch or
    instance) follows each of the two intermediate transposed convolutions.

    Latents are drawn from an element-wise ``Normal(mean, std)`` prior whose
    ``mean`` (zeros) and ``std`` (ones) are non-trainable parameters. Three
    latents are sampled once at construction and kept in ``fixed_styles``.
    """

    DIM = 64          # Model dimensionality
    LATENT_DIM = 128
    OUTPUT_DIM = 784  # Number of pixels in MNIST (28*28)

    # Accepted values of the ``norm_type`` constructor argument.
    _NORM_TYPES = (None, 'batch', 'instance')

    def __init__(self, norm_type):
        """Build the generator.

        Args:
            norm_type: ``'batch'``, ``'instance'`` or ``None`` (no
                normalisation).

        Raises:
            ValueError: if ``norm_type`` is not one of the accepted values.
        """
        super().__init__()

        # Explicit check instead of an assert: asserts are stripped under
        # ``python -O``, which would silently build an un-normalised network
        # for a misspelled norm type.
        if norm_type not in self._NORM_TYPES:
            raise ValueError(
                "norm_type must be 'batch', 'instance' or None, got "
                "{!r}".format(norm_type))

        DIM = self.DIM

        def batch_norm_if_desired(d):
            """Return the normalisation layer for ``d`` channels."""
            if norm_type == 'batch':
                return nn.BatchNorm2d(d, track_running_stats=False)
            if norm_type == 'instance':
                return nn.InstanceNorm2d(d)
            # No normalisation: an empty Sequential acts as a pass-through.
            return nn.Sequential()

        # Layers are created in the same order as before so that seeded
        # weight initialisation is unchanged.
        preprocess = nn.Sequential(
            nn.Linear(self.LATENT_DIM, 4*4*4*DIM),
            nn.ReLU(True),
        )
        block1 = nn.Sequential(
            nn.ConvTranspose2d(4*DIM, 2*DIM, 5),
            batch_norm_if_desired(2*DIM),
            nn.ReLU(True),
        )
        block2 = nn.Sequential(
            nn.ConvTranspose2d(2*DIM, DIM, 5),
            batch_norm_if_desired(DIM),
            nn.ReLU(True),
        )
        deconv_out = nn.ConvTranspose2d(DIM, 1, 8, stride=2)

        # Registration order is kept as-is: it determines the order of
        # parameters() and of the state_dict entries.
        self.block1 = block1
        self.block2 = block2
        self.deconv_out = deconv_out
        self.preprocess = preprocess
        self.sigmoid = nn.Sigmoid()

        self.mean = nn.Parameter(torch.zeros(self.LATENT_DIM),
                                 requires_grad=False)
        self.std = nn.Parameter(torch.ones(self.LATENT_DIM),
                                requires_grad=False)
        self._style_sampler = Normal(self.mean, self.std)
        # Latents drawn once so the same inputs can be re-rendered later.
        self.fixed_styles = nn.Parameter(
            self._style_sampler.sample(torch.Size([3])),
            requires_grad=False)

    def sample(self, batch_shape=torch.Size([1])):
        """Draw latents from the prior and generate images from them.

        Args:
            batch_shape (torch.Size): leading shape of the drawn latents.

        Returns:
            The result of ``forward`` on the sampled latents.
        """
        styles = self._style_sampler.sample(batch_shape)
        return self(styles)

    def forward(self, inp):
        """Generate images from latent vectors.

        Args:
            inp (tensor): latents whose last dimension is ``LATENT_DIM``.

        Returns:
            Tensor of shape ``(N, 1, 28, 28)`` with values in (0, 1).
        """
        output = self.preprocess(inp)
        output = output.view(-1, 4*self.DIM, 4, 4)
        output = self.block1(output)       # 4x4 -> 8x8 (kernel 5)
        output = output[:, :, :7, :7]      # crop to 7x7
        output = self.block2(output)       # 7x7 -> 11x11 (kernel 5)
        output = self.deconv_out(output)   # 11x11 -> 28x28 (kernel 8, stride 2)
        output = self.sigmoid(output)
        return output

    def get_fixed(self):
        """Generate images from the latents stored in ``fixed_styles``."""
        return self(self.fixed_styles)


class GeneratorImprovedMNIST(GeneratorImprovedMNISTbase):
    """``GeneratorImprovedMNISTbase`` built with ``norm_type=None``."""

    def __init__(self):
        super(GeneratorImprovedMNIST, self).__init__(norm_type=None)


class GeneratorBatchNormImprovedMNIST(GeneratorImprovedMNISTbase):
    """``GeneratorImprovedMNISTbase`` built with ``norm_type='batch'``."""

    def __init__(self):
        super(GeneratorBatchNormImprovedMNIST, self).__init__(
            norm_type='batch')


class GeneratorInstanceNormImprovedMNIST(GeneratorImprovedMNISTbase):
    """``GeneratorImprovedMNISTbase`` built with ``norm_type='instance'``."""

    def __init__(self):
        super(GeneratorInstanceNormImprovedMNIST, self).__init__(
            norm_type='instance')


class GeneratorBNMNIST(nn.Module):
    """Upsampling convolutional generator producing 1x28x28 images.

    A linear layer lifts a latent of length ``LATENT_DIM`` to a
    ``128 x INIT_SIZE x INIT_SIZE`` feature map, which is upsampled twice and
    convolved down to ``CHANNELS`` output channels with a final sigmoid.
    Latents come from an element-wise ``Normal(mean, std)`` prior (zeros /
    ones, stored as non-trainable parameters); three latents sampled at
    construction are kept in ``fixed_styles``.
    """

    IMG_DIM = 28
    INIT_SIZE = 8
    LATENT_DIM = 100
    CHANNELS = 1

    def __init__(self, norm_type='batch'):
        """Build the generator.

        Args:
            norm_type (str): ``'batch'`` or ``'instance'``; any other value
                raises ``KeyError`` from the lookup below.
        """
        super().__init__()

        norm_classes = {'batch': nn.BatchNorm2d,
                        'instance': nn.InstanceNorm2d}
        norm_cls = norm_classes[norm_type]
        # Running statistics are tracked only for batch normalisation.
        track_stats = (norm_type == 'batch')

        side = self.INIT_SIZE
        self.l1 = nn.Linear(self.LATENT_DIM, 128 * side ** 2)

        stages = [
            # 8x8 -> 16x16
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            norm_cls(128, track_running_stats=track_stats),
            nn.LeakyReLU(0.2, inplace=True),
            # 16x16 -> 32x32, then the unpadded conv trims to 30x30
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=0),
            norm_cls(64, track_running_stats=track_stats),
            nn.LeakyReLU(0.2, inplace=True),
            # 30x30 -> 28x28
            nn.Conv2d(64, self.CHANNELS, 3, stride=1, padding=0),
            nn.Sigmoid(),
        ]
        self.conv_blocks = nn.Sequential(*stages)

        latent_dim = self.LATENT_DIM
        self.mean = nn.Parameter(torch.zeros(latent_dim),
                                 requires_grad=False)
        self.std = nn.Parameter(torch.ones(latent_dim),
                                requires_grad=False)
        self._style_sampler = Normal(self.mean, self.std)
        fixed = self._style_sampler.sample(torch.Size([3]))
        self.fixed_styles = nn.Parameter(fixed, requires_grad=False)

    def sample(self, batch_shape=torch.Size([1])):
        """Draw latents of leading shape ``batch_shape`` and generate images."""
        latents = self._style_sampler.sample(batch_shape)
        return self(latents)

    def forward(self, z):
        """Generate images from latents ``z`` (last dimension ``LATENT_DIM``)."""
        hidden = self.l1(z)
        side = self.INIT_SIZE
        feature_map = hidden.view(hidden.shape[0], 128, side, side)
        return self.conv_blocks(feature_map)

    def get_fixed(self):
        """Generate images from the latents stored in ``fixed_styles``."""
        return self(self.fixed_styles)


class GeneratorDCGAN(nn.Module):
    """DCGAN-style generator made of transposed convolutions.

    The network always produces a square image whose side is the next power
    of two at or above ``img_dim``; ``forward`` then crops it centrally to
    ``img_dim x img_dim``. Noise vectors of length ``nz`` are drawn from an
    element-wise ``Normal(mean, std)`` prior (zeros / ones, stored as
    non-trainable parameters).
    """

    def make_block(self, in_c, out_c, s, p):
        """Return a ConvTranspose2d(kernel 4) -> BatchNorm2d -> ReLU block.

        Args:
            in_c (int): input channels.
            out_c (int): output channels.
            s (int): stride of the transposed convolution.
            p (int): padding of the transposed convolution.
        """
        return nn.Sequential(
            nn.ConvTranspose2d(in_c, out_c,
                               kernel_size=4,
                               stride=s,
                               padding=p,
                               bias=False),
            nn.BatchNorm2d(out_c, track_running_stats=False),
            nn.ReLU(True),
        )

    def __init__(self, img_dim, nc, nz, ngf):
        """Build the generator.

        Args:
            img_dim (int): side length of the generated images; must be even
                and at least 6.
            nc (int): number of output channels.
            nz (int): length of the input noise vector.
            ngf (int): channel count of the last hidden feature map; earlier
                feature maps use ``ngf`` times a power of two.

        Raises:
            ValueError: if ``img_dim`` is odd or smaller than 6.
        """
        super().__init__()

        # Explicit checks instead of an assert (asserts vanish under -O).
        # An odd size cannot be cropped symmetrically, and sizes below 6
        # would need fewer than the two upsampling stages that are always
        # present (1x1 -> 4x4 -> 8x8), yielding a fractional channel count.
        if img_dim % 2 != 0:
            raise ValueError(
                "img_dim must be even, got {!r}".format(img_dim))
        if img_dim < 6:
            raise ValueError(
                "img_dim must be at least 6, got {!r}".format(img_dim))

        # math.log2 is more accurate than math.log(x, 2) for powers of two.
        log_out_dim = math.ceil(math.log2(img_dim))
        self.out_dim = 2**log_out_dim
        self.crop_by = (self.out_dim - img_dim) // 2
        self.nz = nz

        # Two doublings are fixed: the first block (1x1 -> 4x4) and the
        # output layer (x2). The rest are stride-2 blocks that halve the
        # channel count each time.
        remaining_upsamples = log_out_dim - 2
        layers = [self.make_block(in_c=nz,
                                  out_c=ngf*2**(remaining_upsamples-1),
                                  s=1, p=0)]
        for exponent in range(remaining_upsamples - 1, 0, -1):
            layers.append(self.make_block(in_c=ngf*2**exponent,
                                          out_c=ngf*2**(exponent-1),
                                          s=2, p=1))

        layers.extend([
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ])
        self.main = nn.Sequential(*layers)

        self.mean = nn.Parameter(torch.zeros(self.nz), requires_grad=False)
        self.std = nn.Parameter(torch.ones(self.nz), requires_grad=False)
        self._style_sampler = Normal(self.mean, self.std)

    def sample(self, batch_shape=torch.Size([1])):
        """Draw noise of leading shape ``batch_shape`` and generate images."""
        styles = self._style_sampler.sample(batch_shape)
        return self(styles)

    def forward(self, noise_vector):
        """Generate images from noise vectors.

        Args:
            noise_vector (tensor): noise whose last dimension is ``nz``; two
                trailing singleton dimensions are appended to make it a 1x1
                feature map.

        Returns:
            Tensor with the last two dimensions cropped centrally from
            ``out_dim`` to ``img_dim``; values lie in [-1, 1] (tanh output).
        """
        noise_image = noise_vector.unsqueeze(-1).unsqueeze(-1)
        out = self.main(noise_image)
        cr1 = self.crop_by
        cr2 = self.out_dim - self.crop_by
        return out[:, :, cr1:cr2, cr1:cr2]

class ScalableGenerator(nn.Sequential):
    """Fully connected generator whose depth and width come from class attributes.

    Subclasses set ``n_hidden_layers`` and ``hidden_dim``. The network is
    ``n_hidden_layers`` repetitions of ``Linear -> ReLU`` (width
    ``hidden_dim``) followed by ``Linear -> Sigmoid`` producing ``28**2``
    values per sample. Inputs of length ``nz`` are drawn from an
    element-wise ``Normal(mean, std)`` prior (zeros / ones, stored as
    non-trainable parameters).
    """

    n_hidden_layers = None
    hidden_dim = None
    nz = 100

    def __init__(self):
        """Build the layer stack from the class-level configuration.

        Raises:
            ValueError: if ``n_hidden_layers`` is unset, or if hidden layers
                are requested while ``hidden_dim`` is unset.
        """
        if self.n_hidden_layers is None:
            raise ValueError(
                "n_hidden_layers must be set on the subclass")
        if self.n_hidden_layers > 0 and self.hidden_dim is None:
            raise ValueError(
                "hidden_dim must be set on the subclass when "
                "n_hidden_layers > 0")

        in_dim = self.nz
        layers = []
        for _ in range(self.n_hidden_layers):
            layers.append(nn.Linear(in_dim, self.hidden_dim))
            layers.append(nn.ReLU())
            in_dim = self.hidden_dim
        layers.append(nn.Linear(in_dim, 28**2))
        layers.append(nn.Sigmoid())

        # nn.Module.__init__ must run before any parameter is assigned to
        # the instance, so the prior is registered only after this call.
        super().__init__(*layers)

        self.mean = nn.Parameter(torch.zeros(self.nz), requires_grad=False)
        self.std = nn.Parameter(torch.ones(self.nz), requires_grad=False)
        self._style_sampler = Normal(self.mean, self.std)

    def sample(self, batch_shape=torch.Size([1])):
        """Draw inputs of leading shape ``batch_shape`` and generate from them."""
        styles = self._style_sampler.sample(batch_shape)
        return self(styles)
