import torch
import torch.nn as nn
from torch.distributions import Normal


class StyleGanBlockWithInput(nn.Module):
    """Synthesis block whose initial feature map is computed from an input.

    The input is flattened per sample and projected by a linear layer to a
    ``(channels, h, w)`` feature map. It then goes through
    LeakyReLU -> additive noise -> AdaIN -> 3x3 conv -> LeakyReLU ->
    additive noise -> AdaIN. Both AdaIN layers are modulated by ``latents``.

    Args:
        linear_dim (int): number of elements per sample of ``x`` after
            flattening (input size of the projection layer).
        channels (int): number of feature-map channels.
        latent_dim (int): dimension of the latent vectors fed to AdaIN.
        h (int): height of the produced feature map.
        w (int): width of the produced feature map.
    """

    def __init__(self, linear_dim, channels, latent_dim, h, w):
        super().__init__()

        self.ada1 = NormAdaIn(latent_dim, channels)
        self.conv2d = nn.Conv2d(channels, channels, kernel_size=(3, 3),
                                padding=1, bias=False)
        self.ada2 = NormAdaIn(latent_dim, channels)

        self.lin = nn.Linear(linear_dim, channels*h*w)

        self._h = h
        self._w = w
        self._channels = channels
        self.lrelu = nn.LeakyReLU(0.2)

        # Mean / std of the additive noise. They stay frozen Parameters (not
        # buffers) so that parameters() and existing checkpoints/optimizer
        # states remain compatible.
        self.mean = nn.Parameter(torch.zeros((channels, h, w)),
                                 requires_grad=False)
        self.std = nn.Parameter(torch.ones((channels, h, w)),
                                requires_grad=False)

    @property
    def noise_sampler(self):
        """Normal distribution over ``(channels, h, w)`` noise maps.

        The distribution is rebuilt from the *current* ``mean``/``std`` on
        every access. A ``Normal`` cached in ``__init__`` keeps the tensors
        it was constructed with. If a device/dtype conversion replaces the
        parameter tensors instead of mutating them, that cached sampler
        would keep producing noise on the old device/dtype.
        """
        # validate_args=False: std is constructed as ones, and argument
        # validation on every forward pass would only add overhead.
        return Normal(self.mean, self.std, validate_args=False)

    def forward(self, x, latents):
        """Project ``x`` to a feature map and modulate it with ``latents``.

        Args:
            x (Tensor): input whose first dimension is the batch; the rest
                is flattened and must have ``linear_dim`` elements.
            latents (Tensor): one latent row per batch element, consumed by
                the two AdaIN layers.

        Returns:
            Tensor of shape ``(batch, channels, h, w)``.
        """
        batch_size = latents.size(0)
        noise_shape = torch.Size([batch_size])

        # reshape (not view) so non-contiguous inputs are accepted too.
        x = x.reshape(batch_size, -1)
        x = self.lin(x).view(batch_size, self._channels, self._h, self._w)
        x = self.lrelu(x)
        x = x + self.noise_sampler.sample(noise_shape)
        x = self.ada1(x, latents)
        x = self.conv2d(x)
        x = self.lrelu(x)
        x = x + self.noise_sampler.sample(noise_shape)
        x = self.ada2(x, latents)

        return x


class StyleGanBlockConst(nn.Module):
    """Synthesis block that starts from a learned constant feature map.

    A learned ``(channels, h, w)`` constant is broadcast over the batch. A
    learned per-channel bias is added, and the result goes through
    LeakyReLU -> additive noise -> AdaIN -> 3x3 conv -> LeakyReLU ->
    additive noise -> AdaIN. Both AdaIN layers are modulated by ``latents``.

    Args:
        latent_dim (int): dimension of the latent vectors fed to AdaIN.
        channels (int): number of feature-map channels.
        h (int): height of the constant feature map.
        w (int): width of the constant feature map.
    """

    def __init__(self, latent_dim, channels, h, w):
        super().__init__()

        self.const = nn.Parameter(torch.zeros(channels, h, w).normal_())
        self.bias = nn.Parameter(torch.zeros(channels, 1, 1).normal_())

        self.ada1 = NormAdaIn(latent_dim, channels)
        self.conv2d = nn.Conv2d(channels, channels, kernel_size=(3, 3),
                                padding=1, bias=False)
        self.ada2 = NormAdaIn(latent_dim, channels)

        self._h = h
        self._w = w
        self._channels = channels
        self.lrelu = nn.LeakyReLU(0.2)

        # Mean / std of the additive noise. They stay frozen Parameters (not
        # buffers) so that parameters() and existing checkpoints/optimizer
        # states remain compatible.
        self.mean = nn.Parameter(torch.zeros((channels, h, w)),
                                 requires_grad=False)
        self.std = nn.Parameter(torch.ones((channels, h, w)),
                                requires_grad=False)

    @property
    def noise_sampler(self):
        """Normal distribution over ``(channels, h, w)`` noise maps.

        The distribution is rebuilt from the *current* ``mean``/``std`` on
        every access. A ``Normal`` cached in ``__init__`` keeps the tensors
        it was constructed with. If a device/dtype conversion replaces the
        parameter tensors instead of mutating them, that cached sampler
        would keep producing noise on the old device/dtype.
        """
        # validate_args=False: std is constructed as ones, and argument
        # validation on every forward pass would only add overhead.
        return Normal(self.mean, self.std, validate_args=False)

    def forward(self, _, latents):
        """Build a styled feature map from the learned constant.

        Args:
            _: ignored; present so the signature matches the other blocks.
            latents (Tensor): one latent row per batch element, consumed by
                the two AdaIN layers; its first dimension sets the batch.

        Returns:
            Tensor of shape ``(batch, channels, h, w)``.
        """
        batch_size = latents.size(0)
        noise_shape = torch.Size([batch_size])

        x = self.const.expand(torch.Size([batch_size]) + self.const.shape)
        x = x + self.bias  # (channels, 1, 1) broadcasts over batch and space
        x = self.lrelu(x)
        x = x + self.noise_sampler.sample(noise_shape)
        x = self.ada1(x, latents)
        x = self.conv2d(x)
        x = self.lrelu(x)
        x = x + self.noise_sampler.sample(noise_shape)
        x = self.ada2(x, latents)

        return x


class StyleGanUpscaleBlock(nn.Module):
    """Synthesis block that resizes its input and modulates it with styles.

    The input is nearest-neighbour resized to ``(h, w)`` and a 3x3 conv maps
    ``in_channels`` to ``out_channels``. The result then goes through
    LeakyReLU -> additive noise -> AdaIN -> 3x3 conv -> LeakyReLU ->
    additive noise -> AdaIN. Both AdaIN layers are modulated by ``latents``.

    Args:
        in_channels (int): channels of the incoming feature map.
        out_channels (int): channels of the produced feature map.
        latent_dim (int): dimension of the latent vectors fed to AdaIN.
        h (int): output height (target size of the resize).
        w (int): output width (target size of the resize).
    """

    def __init__(self, in_channels, out_channels, latent_dim, h, w):
        super().__init__()

        self.upsample = nn.Sequential(nn.Upsample(size=(h, w),
                                                  mode='nearest'),
                                      nn.Conv2d(in_channels,
                                                out_channels,
                                                kernel_size=(3, 3),
                                                padding=1,
                                                bias=False))
        self.ada1 = NormAdaIn(latent_dim, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3),
                               padding=1, bias=False)
        self.ada2 = NormAdaIn(latent_dim, out_channels)
        self.lrelu = nn.LeakyReLU(0.2)

        # Mean / std of the additive noise. They stay frozen Parameters (not
        # buffers) so that parameters() and existing checkpoints/optimizer
        # states remain compatible.
        self.mean = nn.Parameter(torch.zeros((out_channels, h, w)),
                                 requires_grad=False)
        self.std = nn.Parameter(torch.ones((out_channels, h, w)),
                                requires_grad=False)

    @property
    def noise_sampler(self):
        """Normal distribution over ``(out_channels, h, w)`` noise maps.

        The distribution is rebuilt from the *current* ``mean``/``std`` on
        every access. A ``Normal`` cached in ``__init__`` keeps the tensors
        it was constructed with. If a device/dtype conversion replaces the
        parameter tensors instead of mutating them, that cached sampler
        would keep producing noise on the old device/dtype.
        """
        # validate_args=False: std is constructed as ones, and argument
        # validation on every forward pass would only add overhead.
        return Normal(self.mean, self.std, validate_args=False)

    def forward(self, x, latents):
        """Resize ``x`` to ``(h, w)`` and modulate it with ``latents``.

        Args:
            x (Tensor): feature map with ``in_channels`` channels; its first
                dimension is the batch.
            latents (Tensor): one latent row per batch element, consumed by
                the two AdaIN layers.

        Returns:
            Tensor of shape ``(batch, out_channels, h, w)``.
        """
        batch_size = x.size(0)
        noise_shape = torch.Size([batch_size])

        x = self.upsample(x)
        x = self.lrelu(x)
        x = x + self.noise_sampler.sample(noise_shape)
        x = self.ada1(x, latents)
        x = self.conv2(x)
        x = self.lrelu(x)
        x = x + self.noise_sampler.sample(noise_shape)
        x = self.ada2(x, latents)

        return x


class LinearBlock(nn.Module):
    """Stack of fully connected layers mapping a flattened input to ``style_dim``.

    The first layer maps ``linear_dim -> style_dim``. Each of the remaining
    ``n_layers - 1`` layers maps ``style_dim -> style_dim`` and is applied
    with a residual connection (``x = layer(x) + x``). Every linear layer is
    followed by a LeakyReLU with negative slope 0.2.
    """

    def __init__(self, linear_dim, style_dim, n_layers):
        """Build the block.

        Args:
            linear_dim (int): number of elements per sample of the input
                after flattening.
            style_dim (int): output dimension (width of every layer after
                the first).
            n_layers (int): total number of layers, including the first
                ``linear_dim -> style_dim`` layer. Values below 1 behave
                like 1 (only the first layer is built).
        """
        super().__init__()

        self._ff_reduce = nn.Sequential(nn.Linear(linear_dim, style_dim),
                                        nn.LeakyReLU(negative_slope=0.2))

        # Submodule names (_ff_reduce, _lin_block) are state_dict keys;
        # keep them stable for checkpoint compatibility.
        self._lin_block = nn.ModuleList(
            nn.Sequential(nn.Linear(style_dim, style_dim),
                          nn.LeakyReLU(negative_slope=0.2))
            for _ in range(n_layers - 1)
        )

    def forward(self, x):
        """Flatten ``x`` per sample and map it to a ``(batch, style_dim)`` tensor."""
        batch_size = x.size(0)
        # reshape (not view) so non-contiguous inputs are accepted too.
        x = self._ff_reduce(x.reshape(batch_size, -1))
        for layer in self._lin_block:
            x = layer(x) + x  # residual connection
        return x


class NormAdaIn(nn.Module):
    """Adaptive instance normalisation whose styles come from a latent vector.

    A linear layer maps every latent vector to ``2 * channels`` values. The
    first ``channels`` are per-channel scales and the last ``channels`` are
    per-channel shifts. The input is instance-normalised and then modulated
    as ``scale * norm(x) + shift``.

    See StyleGAN (and citations therein): https://arxiv.org/pdf/1812.04948.pdf
    """

    def __init__(self, latent_dim, channels):
        """Build the style projection and the normalisation layer.

        Args:
            latent_dim (int): size of the latent (style) vectors.
            channels (int): number of channels of the input being normalised.
        """
        super().__init__()
        # Submodule names are state_dict keys; keep them stable.
        self.lin = nn.Linear(latent_dim, 2 * channels)
        self._channels = channels
        self._norm = nn.InstanceNorm2d(channels)

    def forward(self, x, latents):
        """Normalise ``x`` per sample and channel, then apply latent styles.

        Args:
            x (Tensor): input whose first two dimensions are
                ``(batch, channels)``.
            latents (Tensor): one latent row per batch element; mapped to
                scale/shift styles by ``self.lin``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        n_channels = self._channels
        style_params = self.lin(latents)

        # Styles are (batch, channels); append one singleton dim per
        # remaining dimension of x so they broadcast over it.
        trailing_ones = [1] * (x.dim() - 2)
        broadcast_shape = torch.Size([x.size(0), n_channels] + trailing_ones)

        gamma = style_params[:, :n_channels].view(broadcast_shape)
        beta = style_params[:, n_channels:].view(broadcast_shape)

        normalized = self._norm(x)
        return gamma * normalized + beta
