"""
Credit to https://github.com/lucidrains/stylegan2-pytorch/, from which this is mostly copied.
"""
import math
import warnings
from functools import partial
from math import log2

import torch
from torch import nn
import torch.nn.functional as F

# from adamp import AdamP
from torch.autograd import grad as torch_grad

# constants

# Small stabiliser added inside the rsqrt of Conv2DMod's weight
# demodulation so an all-zero filter never divides by zero.
EPS: float = 1.0e-8


def leaky_relu(p=0.2):
    """Build a new ``nn.LeakyReLU`` module whose negative slope is ``p``."""
    return nn.LeakyReLU(negative_slope=p)


class RGBBlock(nn.Module):
    """Project a feature map to an ``out_filters``-channel image.

    The projection is a 1x1 modulated convolution (no demodulation) driven
    by a style vector. The result is added to ``prev_rgb`` when one is given.
    It is optionally upsampled 2x before being returned.

    Args:
        latent_dim: Size of the style vector ``istyle``.
        input_channel: Channels of the incoming feature map.
        upsample: If truthy, bilinearly upsample the output 2x.
        out_filters: Channels of the produced image.
    """

    def __init__(self, latent_dim, input_channel, upsample, out_filters):
        super(RGBBlock, self).__init__()
        self.input_channel = input_channel
        self.to_style = nn.Linear(latent_dim, input_channel)

        # 1x1 kernel, demod=False: plain per-sample channel mixing.
        self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False)

        if upsample:
            self.upsample = nn.Upsample(scale_factor=2, mode='bilinear',
                                        align_corners=False)
        else:
            self.upsample = None

    def forward(self, x, prev_rgb, istyle):
        """Return the (optionally upsampled) image for this resolution.

        Args:
            x: Feature map with ``input_channel`` channels (4-D, as required
                by ``Conv2DMod``).
            prev_rgb: Image to accumulate onto, or None. It must match the
                shape of this block's projection.
            istyle: Style vector with last dimension ``latent_dim``.
        """
        style = self.to_style(istyle)
        x = self.conv(x, style)

        if prev_rgb is not None:
            x = x + prev_rgb

        if self.upsample is not None:
            x = self.upsample(x)

        return x


class Conv2DMod(nn.Module):
    """Modulated 2-D convolution (StyleGAN2 weight (de)modulation).

    Every sample in the batch is convolved with its own kernel. The shared
    weight is scaled along the input-channel axis by ``y + 1``. When
    ``demod`` is true, the kernel is then rescaled so that each output
    filter has unit L2 norm (stabilised by ``EPS``).

    Args:
        in_chan: Input channels.
        out_chan: Output channels (filters).
        kernel: Square kernel size.
        demod: Apply weight demodulation.
        stride: Convolution stride.
        dilation: Convolution dilation.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1,
                 dilation=1, **kwargs):
        super(Conv2DMod, self).__init__()
        self.filters = out_chan
        self.demod = demod
        self.kernel = kernel
        self.stride = stride
        self.dilation = dilation
        self.weight = nn.Parameter(torch.randn((out_chan, in_chan, kernel,
                                                kernel)))
        nn.init.kaiming_normal_(self.weight, a=0, mode='fan_in',
                                nonlinearity='leaky_relu')

    def _get_same_padding(self, size, kernel, dilation, stride):
        """Return symmetric padding intended to keep the spatial size ``size``.

        For stride 1 this is ``dilation * (kernel - 1) // 2``. When
        ``dilation * (kernel - 1)`` is odd (e.g. an even kernel), the floor
        division makes the output one pixel smaller than the input.
        """
        return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2

    def forward(self, x, y):
        """Convolve each sample of ``x`` with a kernel modulated by ``y``.

        Args:
            x: Input of shape (batch, in_chan, h, w).
            y: Per-sample styles of shape (batch, in_chan).

        Returns:
            Tensor of shape (batch, out_chan, h_out, w_out).
        """
        b, _, h, w = x.shape

        # (b, 1, in, 1, 1) styles broadcast against (1, out, in, k, k) weight.
        w1 = y[:, None, :, None, None]
        w2 = self.weight[None, :, :, :, :]
        weights = w2 * (w1 + 1)

        if self.demod:
            d = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + EPS)
            weights = weights * d

        # Fold the batch into the channel axis; groups=b then applies each
        # sample's own kernel in a single conv call.
        x = x.reshape(1, -1, h, w)

        _, _, *ws = weights.shape
        weights = weights.reshape(b * self.filters, *ws)

        padding = self._get_same_padding(h, self.kernel, self.dilation,
                                         self.stride)
        # stride/dilation must reach the conv: the padding above assumes them.
        x = F.conv2d(x, weights, stride=self.stride, padding=padding,
                     dilation=self.dilation, groups=b)

        # Unfold the batch using the real output size. It differs from
        # (h, w) e.g. for even kernels.
        return x.reshape(b, self.filters, *x.shape[-2:])


class StyleVectorizer(nn.Module):
    """Mapping MLP made of ``depth`` (Linear(emb, emb), LeakyReLU) pairs.

    Args:
        emb: Width of every layer (input and output size).
        depth: Number of Linear + LeakyReLU pairs.
    """

    def __init__(self, emb, depth):
        super(StyleVectorizer, self).__init__()
        # Each pair is created Linear-first, matching layer order in the net.
        self.net = nn.Sequential(*[
            layer
            for _ in range(depth)
            for layer in (nn.Linear(emb, emb), leaky_relu())
        ])

    def forward(self, x):
        """Pass ``x`` (last dimension ``emb``) through the MLP."""
        return self.net(x)


class GeneratorBlock(nn.Module):
    """One synthesis stage of the generator.

    The block optionally upsamples its input 2x. It then applies two
    modulated 3x3 convolutions, each followed by additive noise and
    LeakyReLU, and finally produces an RGB skip output via ``RGBBlock``.

    Args:
        latent_dim: Size of the style vector ``istyle``.
        input_channels: Channels of the incoming feature map.
        filters: Channels produced by both convolutions.
        out_channels: Channels of the RGB output.
        upsample: Bilinearly upsample the input 2x before the convolutions.
        upsample_rgb: Forwarded to ``RGBBlock`` to upsample its output 2x.
    """

    def __init__(self, latent_dim, input_channels, filters,
                 out_channels, upsample=True, upsample_rgb=True):
        super(GeneratorBlock, self).__init__()
        if upsample:
            self.upsample = nn.Upsample(scale_factor=2, mode='bilinear',
                                        align_corners=False)
        else:
            self.upsample = None

        self.to_style1 = nn.Linear(latent_dim, input_channels)
        # NOTE(review): to_noise1/to_noise2 are never called in forward(), so
        # the noise below is added unscaled. The upstream version fed the noise
        # through these layers, which Generator._init_weights zero-initialises,
        # to get a learned per-channel scale. Confirm that dropping this was
        # intentional.
        self.to_noise1 = nn.Linear(1, filters)
        self.conv1 = Conv2DMod(input_channels, filters, 3)

        self.to_style2 = nn.Linear(latent_dim, filters)
        self.to_noise2 = nn.Linear(1, filters)
        self.conv2 = Conv2DMod(filters, filters, 3)

        self.activation = leaky_relu()
        self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, out_channels)

    def forward(self, x, prev_rgb, istyle):
        """Run the block.

        Args:
            x: Feature map of shape (batch, input_channels, H, W).
            prev_rgb: Previous block's RGB output to accumulate onto, or None.
            istyle: Style vectors of shape (batch, latent_dim).

        Returns:
            Tuple ``(features, rgb)``.
        """
        if self.upsample is not None:
            x = self.upsample(x)

        # One noise channel per sample, broadcast across every feature channel.
        # It is allocated directly with x's device and dtype. CPU float32 noise
        # costs a host-to-device copy, and for half-precision models it would
        # promote activations to float32, which conv2's half weights reject.
        noise_shape = (x.shape[0], 1, x.shape[2], x.shape[3])
        noise1 = torch.randn(noise_shape, device=x.device, dtype=x.dtype)
        noise2 = torch.randn(noise_shape, device=x.device, dtype=x.dtype)

        style1 = self.to_style1(istyle)
        x = self.conv1(x, style1)
        x = self.activation(x + noise1)

        style2 = self.to_style2(istyle)
        x = self.conv2(x, style2)
        x = self.activation(x + noise2)

        rgb = self.to_rgb(x, prev_rgb, istyle)
        return x, rgb


class Generator(nn.Module):
    """StyleGAN2-style generator: a mapping network plus synthesis blocks.

    Synthesis starts from a 4x4 map. This is a learned constant, or, with
    ``no_const``, a transposed conv of the mean style. Every block after the
    first doubles the resolution, and per-block RGB outputs are accumulated
    through upsampled skip connections. If ``image_size`` is not a power of
    two, the network renders at the next power of two and center-crops.

    Args:
        channels: Channels of the output image.
        image_size: Side length of the square output image.
        latent_dim: Size of z and w.
        network_capacity: Base multiplier for per-block filter counts.
        no_const: Derive the 4x4 input from the styles instead of using a
            learned constant.
        fmap_max: Upper bound on filters per block.
        style_depth: Number of Linear + LeakyReLU pairs in the mapping net.
        mixing_prob: Probability of style mixing in ``sample`` (training only).
    """

    def __init__(self, channels, image_size, latent_dim, network_capacity=16,
                 no_const=False, fmap_max=512,
                 style_depth=8, mixing_prob=0.9):
        super(Generator, self).__init__()
        self.image_size = image_size
        self.latent_dim = latent_dim
        # NOTE: power-of-2 sizes used to get log2(image_size) blocks, one more
        # than needed; the last block rendered at 2 * image_size before
        # cropping. State dicts saved with that layout have one extra
        # ``blocks`` entry.
        self.num_layers = self._num_layers_for(image_size)
        if not log2(image_size).is_integer():
            warnings.warn(
                f'Image size {image_size} is not a power of 2; rendering at '
                f'{2 ** (self.num_layers + 1)} and center-cropping.'
            )
        self.mixing_prob = mixing_prob

        # Largest filter counts at the lowest resolution, capped at fmap_max.
        filters = [network_capacity * (2 ** (i + 1))
                   for i in range(self.num_layers)][::-1]
        filters = list(map(partial(min, fmap_max), filters))
        init_channels = filters[0]
        filters = [init_channels, *filters]

        self.no_const = no_const

        if no_const:
            # (b, latent_dim, 1, 1) -> (b, init_channels, 4, 4)
            self.to_initial_block = nn.ConvTranspose2d(latent_dim,
                                                       init_channels,
                                                       4, 1, 0, bias=False)
        else:
            self.initial_block = nn.Parameter(torch.randn((1, init_channels, 4,
                                                           4)))

        self.blocks = nn.ModuleList([])
        last = self.num_layers - 1
        for ind, (in_chan, out_chan) in enumerate(zip(filters[:-1], filters[1:])):
            self.blocks.append(GeneratorBlock(
                latent_dim,
                in_chan,
                out_chan,
                channels,
                upsample=ind != 0,          # first block stays at 4x4
                upsample_rgb=ind != last,   # final RGB is not upsampled
            ))

        self.style_vectorizer = StyleVectorizer(latent_dim, style_depth)

        self._init_weights()

        # Frozen (requires_grad=False) parameters registered on the module;
        # z_dist is a standard normal built from them.
        self.zero = nn.Parameter(torch.zeros(1,), requires_grad=False)
        self.ones = nn.Parameter(torch.ones(1,), requires_grad=False)
        self.z_dist = torch.distributions.Normal(self.zero, self.ones)

    @staticmethod
    def _num_layers_for(image_size):
        """Return the number of synthesis blocks needed for ``image_size``.

        The first block runs at 4x4 and each later block doubles the size, so
        ``n`` blocks render at ``2 ** (n + 1)``. This returns the smallest
        ``n >= 1`` with ``2 ** (n + 1) >= image_size``.
        """
        return max(1, math.ceil(log2(image_size)) - 1)

    def forward(self, styles):
        """Render a batch of images from per-block styles.

        Args:
            styles: Tensor of shape (batch, n_styles, latent_dim). Style ``i``
                drives block ``i``; styles beyond ``len(self.blocks)`` are
                ignored, except that they still enter the mean used when
                ``no_const`` is set.

        Returns:
            Tensor of shape (batch, channels, image_size, image_size).

        Raises:
            ValueError: If fewer styles than blocks are supplied.
        """
        batch_size, num_styles = styles.shape[0], styles.shape[1]
        if num_styles < len(self.blocks):
            raise ValueError(
                f'Expected at least {len(self.blocks)} styles per sample, '
                f'got {num_styles}.'
            )

        if self.no_const:
            avg_style = styles.mean(dim=1)[:, :, None, None]
            x = self.to_initial_block(avg_style)
        else:
            x = self.initial_block.expand(batch_size, -1, -1, -1)

        rgb = None
        for style, block in zip(styles.transpose(0, 1), self.blocks):
            x, rgb = block(x, rgb, style)

        # Center-crop down to image_size (a no-op for powers of 2).
        start = (rgb.shape[2] - self.image_size) // 2
        stop = start + self.image_size
        return rgb[:, :, start:stop, start:stop]

    def sample(self, batch_shape=torch.Size([1])):
        """Sample images from the prior.

        Draws z from ``z_dist``, maps it to w and renders. In training mode,
        with probability ``mixing_prob``, two w's are drawn. Blocks before a
        random switchover index use the first w and the remaining blocks use
        the second.

        Args:
            batch_shape: Leading shape of the z samples. The mixing branch
                concatenates and chunks along dim 0, so a 1-D ``(batch,)``
                shape is assumed.

        Returns:
            The output of ``forward`` for the sampled styles.
        """

        def sample_z():
            # z_dist has batch shape (1,), so samples end in a singleton dim.
            return self.z_dist.sample(
                batch_shape + torch.Size([self.latent_dim])
            ).squeeze(-1)

        # Short-circuits so no random number is consumed outside training.
        do_mixing = self.training and torch.rand(1).item() < self.mixing_prob

        if do_mixing:
            # mixing regularisation uses two zs
            zs = torch.cat([sample_z() for _ in range(2)], dim=0)
            w1, w2 = self.style_vectorizer(zs).chunk(dim=0, chunks=2)
            # In [0, num_layers); 0 means every block uses w2.
            switchover_location = int(torch.rand(()).item() * self.num_layers)
            w = torch.cat([
                w1.unsqueeze(1).expand(-1, switchover_location, -1),
                w2.unsqueeze(1).expand(-1,
                                       self.num_layers - switchover_location,
                                       -1),
            ], dim=1)
        else:
            w = self.style_vectorizer(sample_z())
            w = w.unsqueeze(1).expand(-1, self.num_layers, -1)

        return self(w)

    def _init_weights(self):
        """Kaiming-init weights of modules whose exact type is nn.Conv2d or
        nn.Linear, then zero every block's noise projections.

        Biases of those modules are left as constructed. Because the check is
        ``type(m) in {...}``, subclasses and other conv types (Conv2DMod,
        ConvTranspose2d) are skipped.
        """
        for m in self.modules():
            if type(m) in {nn.Conv2d, nn.Linear}:
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in',
                                        nonlinearity='leaky_relu')

        for block in self.blocks:
            nn.init.zeros_(block.to_noise1.weight)
            nn.init.zeros_(block.to_noise2.weight)
            nn.init.zeros_(block.to_noise1.bias)
            nn.init.zeros_(block.to_noise2.bias)


