import torch
import torch.nn as nn
import torch.nn.functional as F

import climategan.strings as strings
from climategan.blocks import InterpolateNearest2d, SPADEResnetBlock
from climategan.norms import SpectralNorm


def create_painter(opts, no_init=False, verbose=0):
    """Build the painter network described by ``opts``.

    Args:
        opts: Options object, forwarded untouched to ``PainterSpadeDecoder``.
        no_init (bool, optional): Accepted for interface compatibility; this
            function does not read it.
        verbose (int, optional): When strictly positive, a one-line message is
            printed before the painter is created.

    Returns:
        PainterSpadeDecoder: the newly constructed painter.
    """
    announce = verbose > 0
    if announce:
        print("  - Add PainterSpadeDecoder Painter")
    painter = PainterSpadeDecoder(opts)
    return painter


class PainterSpadeDecoder(nn.Module):
    """SPADE-based decoder which paints an image from a latent z and a
    conditioning tensor.
    """

    def __init__(self, opts):
        """Create a SPADE-based decoder, which forwards z and the conditioning
        tensor cond (in the original paper, conditioning is on a semantic map
        only). All along, z is conditioned on cond.

        The first 3 SpadeResblocks (SRB: head_0, G_middle_0, G_middle_1) keep the
        channel dimension, with an upsampling between consecutive blocks, i.e.
        2 upsamplings at this point. Then each of the remaining
        (spade_n_up - 2) steps upsamples and applies an SRB which halves the
        number of channels. Before the final conv to 3 channels, the number of
        channels is therefore:
            final_nc = channels(z) // 2 ** (spade_n_up - 2)

        Args:
            opts: options object. The following fields are read:
                opts.gen.p.latent_dim: number of channels of z
                opts.gen.p.spade_n_up: number of total upsamplings from z
                opts.gen.p.spade_use_spectral_norm: forwarded to every SRB
                opts.gen.p.spade_param_free_norm: forwarded to every SRB
                opts.gen.p.use_final_shortcut: if truthy, a small conv block
                    is created and its output replaces the conditioning tensor
                    of the last SRB (see forward)
        """
        super().__init__()

        latent_dim = opts.gen.p.latent_dim
        # The conditioning tensor is expected to have 3 channels and every
        # SPADE block uses a kernel size of 3; both are fixed here.
        cond_nc = 3
        spade_n_up = opts.gen.p.spade_n_up
        spade_use_spectral_norm = opts.gen.p.spade_use_spectral_norm
        spade_param_free_norm = opts.gen.p.spade_param_free_norm
        spade_kernel_size = 3

        self.z_nc = latent_dim
        self.spade_n_up = spade_n_up

        # Spatial size of the latent built in forward() when z is None;
        # must be provided through set_latent_shape() before that.
        self.z_h = self.z_w = None

        # Maps a resized conditioning tensor to a latent when none is provided
        self.fc = nn.Conv2d(cond_nc, latent_dim, 3, padding=1)
        self.head_0 = SPADEResnetBlock(
            self.z_nc,
            self.z_nc,
            cond_nc,
            spade_use_spectral_norm,
            spade_param_free_norm,
            spade_kernel_size,
        )

        self.G_middle_0 = SPADEResnetBlock(
            self.z_nc,
            self.z_nc,
            cond_nc,
            spade_use_spectral_norm,
            spade_param_free_norm,
            spade_kernel_size,
        )
        self.G_middle_1 = SPADEResnetBlock(
            self.z_nc,
            self.z_nc,
            cond_nc,
            spade_use_spectral_norm,
            spade_param_free_norm,
            spade_kernel_size,
        )

        # Block i goes from z_nc // 2**i to z_nc // 2**(i + 1) channels
        self.up_spades = nn.Sequential(
            *[
                SPADEResnetBlock(
                    self.z_nc // 2 ** i,
                    self.z_nc // 2 ** (i + 1),
                    cond_nc,
                    spade_use_spectral_norm,
                    spade_param_free_norm,
                    spade_kernel_size,
                )
                for i in range(spade_n_up - 2)
            ]
        )

        self.final_nc = self.z_nc // 2 ** (spade_n_up - 2)

        self.final_spade = SPADEResnetBlock(
            self.final_nc,
            self.final_nc,
            cond_nc,
            spade_use_spectral_norm,
            spade_param_free_norm,
            spade_kernel_size,
        )
        self.final_shortcut = None
        if opts.gen.p.use_final_shortcut:
            # Its output is used as the conditioning tensor of final_spade,
            # hence cond_nc output channels.
            self.final_shortcut = nn.Sequential(
                SpectralNorm(nn.Conv2d(self.final_nc, cond_nc, 1)),
                nn.BatchNorm2d(cond_nc),
                nn.LeakyReLU(0.2, True),
            )

        # Final projection to a 3-channel image
        self.conv_img = nn.Conv2d(self.final_nc, 3, 3, padding=1)

        self.upsample = InterpolateNearest2d(scale_factor=2)

    def set_latent_shape(self, shape, is_input=True):
        """
        Sets the latent shape to start the upsampling from, i.e. z_h and z_w.
        If is_input is True, then this is the actual input shape which should
        be divided by 2 ** spade_n_up
        Otherwise, just sets z_h and z_w from shape[-2] and shape[-1]

        Args:
            shape (tuple | list | int): The shape to start sampling from. For a
                sequence, the last two items are used as (height, width); an
                int is used for both.
            is_input (bool, optional): Whether to divide shape by 2 ** spade_n_up

        Raises:
            ValueError: if shape is neither a list, a tuple nor an int.
        """
        if isinstance(shape, (list, tuple)):
            height, width = shape[-2], shape[-1]
        elif isinstance(shape, int):
            height = width = shape
        else:
            raise ValueError(f"Unknown shape type: {shape}")

        if is_input:
            factor = 2 ** self.spade_n_up
            height //= factor
            width //= factor

        self.z_h, self.z_w = height, width

    def _apply(self, fn, *args, **kwargs):
        """Apply fn through nn.Module._apply and return self.

        Extra arguments are forwarded as-is so that this override stays
        compatible with PyTorch versions whose Module._apply accepts more
        than fn (e.g. a ``recurse`` keyword).
        """
        super()._apply(fn, *args, **kwargs)
        return self

    def forward(self, z, cond):
        """Decode z, conditioned on cond, into a 3-channel image in [-1, 1].

        Args:
            z (torch.Tensor | None): latent tensor. If None, a latent is
                computed by resizing cond to (z_h, z_w) and applying self.fc,
                which requires set_latent_shape() to have been called.
            cond (torch.Tensor): conditioning tensor given to every SPADE
                block (presumably with 3 channels, as fixed in __init__).

        Returns:
            torch.Tensor: tanh-activated output of the final convolution.
        """
        if z is None:
            assert (
                self.z_h is not None and self.z_w is not None
            ), "z is None: call set_latent_shape() before forward()"
            z = self.fc(F.interpolate(cond, size=(self.z_h, self.z_w)))
        y = self.head_0(z, cond)
        y = self.upsample(y)
        y = self.G_middle_0(y, cond)
        y = self.upsample(y)
        y = self.G_middle_1(y, cond)

        for up in self.up_spades:
            y = self.upsample(y)
            y = up(y, cond)

        if self.final_shortcut is not None:
            # The last SPADE block is conditioned on features computed from y
            # instead of the original conditioning tensor.
            cond = self.final_shortcut(y)
        y = self.final_spade(y, cond)
        y = self.conv_img(F.leaky_relu(y, 2e-1))
        y = torch.tanh(y)
        return y

    def __str__(self):
        return strings.spadedecoder(self)
