import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from .utils import (StyleGanBlockWithInput, StyleGanBlockConst,
                    StyleGanUpscaleBlock, LinearBlock)


class StyleGan(nn.Module):

    """StyleGan: Two dimensional deconvolution based on upsampling

    Here we implement the network presented in "A Style-Based Generator
    Architecture for Generative Adversarial Networks"

    ===================================================================
    CITE:
    KARRAS, Tero; LAINE, Samuli; AILA, Timo. A style-based generator
    architecture for generative adversarial networks. In: Proceedings of the
    IEEE Conference on Computer Vision and Pattern Recognition. 2019. p.
    4401-4410.
    ===================================================================

    Implementation is mainly inspired by:

    - https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb

    Pipeline: the input is mapped to a latent vector by a ``LinearBlock``;
    a 4 x 4 base block (``StyleGanBlockWithInput`` or, with ``const_base``,
    ``StyleGanBlockConst``) starts the feature map; a chain of
    ``StyleGanUpscaleBlock`` modules, each configured for a target size that
    doubles per stage (clamped to the requested H/W on the last stage), grows
    it; and a 1 x 1 convolution projects to the requested number of output
    channels. The latent vector is passed to the base block and to every
    upscale block.
    """

    def __init__(self, input_shape, output_shape, const_base=False):
        """
        Args:

        input_shape (torch.Size): shape of a single input sample. Only its
            number of elements (``input_shape.numel()``) is used; it is stored
            as ``self.input_dim``.
        output_shape (tuple): ``(channels, H, W)`` of the generated output.
            ``channels`` is stored as ``self.final_channels`` (3 would
            correspond to RGB). H and W must both be >= 4.
        const_base (bool): if True the base block is a ``StyleGanBlockConst``,
            otherwise a ``StyleGanBlockWithInput`` built from ``input_dim``.

        Raises:

        ValueError: if H < 4 or W < 4.
        """

        super().__init__()
        self.input_shape = input_shape
        self.input_dim = input_shape.numel()
        self.final_channels, H, W = output_shape
        # 512 is used in the paper
        self._latent_dim = min(self.input_dim, 512)

        # Spatial size of the base block; the paper uses 4 x 4.
        h = 4
        w = 4
        # Validate before taking log2 below: for non-positive sizes np.log2
        # yields -inf/nan and int() would fail with an unrelated error.
        if H < h or W < w:
            raise ValueError(
                f"Input shape error: Height must be >={h} and width must be >={w}")

        max_resolution = max(H, W)
        # Number of doublings needed to reach (at least) max_resolution from 1.
        resolution_log2 = int(np.ceil(np.log2(max_resolution)))

        fmap_max = 512
        fmap_decay = 1.0
        fmap_base = max_resolution * 8

        def channels_at_stage(stage):
            # Channel count shrinks by 2 ** (stage * fmap_decay), is capped at
            # fmap_max and never drops below the number of output channels.
            return max(min(int(fmap_base / (2.0 ** (stage * fmap_decay))),
                           fmap_max), self.final_channels)

        channels = channels_at_stage(0)

        # Produces the latent vector consumed by every block in forward().
        # NOTE(review): the third argument (4) presumably is the number of
        # layers in the block — confirm against .utils.LinearBlock.
        self._deconv_styles = LinearBlock(self.input_dim, self._latent_dim, 4)

        if const_base:
            self._input_block = StyleGanBlockConst(self._latent_dim, channels,
                                                   h, w)
        else:
            self._input_block = StyleGanBlockWithInput(self.input_dim,
                                                       channels,
                                                       self._latent_dim, h, w)

        upscale_module = []
        in_channels = channels
        # skip stage 1 and 2 as we start from a 4 x 4 image (and not a 1 x 1)
        for stage in range(3, resolution_log2 + 1):
            out_channels = channels_at_stage(stage - 2)
            # Double the target size each stage, clamping to the requested
            # size on the stage that would reach or overshoot it.
            h = h * 2 if 2 * h < H else H
            w = w * 2 if 2 * w < W else W
            upscale_module.append(StyleGanUpscaleBlock(in_channels,
                                                       out_channels,
                                                       self._latent_dim, h, w))

            in_channels = out_channels

        self._upscaling = nn.ModuleList(upscale_module)
        self._to_rgb = nn.Conv2d(in_channels, self.final_channels,
                                 kernel_size=(1, 1), bias=False)

    def forward(self, x):
        """Generate an output map from ``x``.

        Args:
            x: input tensor, given both to the latent mapping block and to the
                base block. Presumably ``(batch, input_dim)`` —
                ``visualize_deconv`` feeds a ``(1, input_dim)`` tensor; TODO
                confirm for other callers.

        Returns:
            The result of the final 1 x 1 convolution, with
            ``final_channels`` channels. Its spatial size is expected to be
            ``(H, W)`` given how the upscale blocks are configured.
        """
        latents = self._deconv_styles(x)

        # go through the base block
        img = self._input_block(x, latents)

        # pass through all upscalings using the latents in the AdaIn modules;
        # no activation after the last one, right before the 1 x 1 projection
        for i, m in enumerate(self._upscaling):
            img = m(img, latents)
            if i < len(self._upscaling) - 1:
                img = F.leaky_relu(img, negative_slope=0.2)

        return self._to_rgb(img)

    def visualize_deconv(self, aspect=1):
        """Plot every output channel for one random input, one figure each.

        Args:
            aspect: forwarded to ``matplotlib.pyplot.imshow``.
        """
        import matplotlib.pyplot as plt

        # Create the input on the model's device so this also works after the
        # network has been moved to a GPU.
        device = self._to_rgb.weight.device
        random_input = torch.randn(self.input_dim, device=device).view(1, -1)
        with torch.no_grad():
            output = self.forward(random_input)
        # Remove only the batch dimension: a bare squeeze() would also drop
        # the channel dimension when final_channels == 1, and the loop below
        # would then iterate over image rows instead of channels.
        viz = output.squeeze(0).cpu().numpy()

        for channel_map in viz:
            plt.figure(figsize=(10, 5))
            m = plt.imshow(channel_map, aspect=aspect)
            plt.colorbar(m)
            plt.show()
