"""
Credit to https://github.com/lucidrains/stylegan2-pytorch/, from which this is mostly copied.
"""
import math
from math import log2
from functools import partial

from torch import nn
import torch.nn.functional as F


def leaky_relu(p=0.2):
    """Build a LeakyReLU activation module.

    Args:
        p: negative slope applied to inputs below zero.

    Returns:
        A new ``nn.LeakyReLU`` module whose ``negative_slope`` is ``p``.
    """
    activation = nn.LeakyReLU(negative_slope=p)
    return activation


class DiscriminatorBlock(nn.Module):
    """Residual discriminator block: two 3x3 convs plus a 1x1 skip projection.

    Args:
        input_channels: number of channels in the block's input.
        filters: number of channels the block outputs.
        downsample: when true, both the main path (through an extra strided
            3x3 conv) and the skip path (through a strided 1x1 conv) use
            stride 2, so the two branches stay the same size and can be summed.
    """

    def __init__(self, input_channels, filters, downsample=True):
        super().__init__()
        skip_stride = 2 if downsample else 1
        # 1x1 projection so the skip branch matches the main branch's channel
        # count (and spatial size, when downsampling).
        self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=skip_stride)

        self.net = nn.Sequential(
            nn.Conv2d(input_channels, filters, 3, padding=1),
            leaky_relu(),
            nn.Conv2d(filters, filters, 3, padding=1),
            leaky_relu(),
        )

        # Optional strided 3x3 conv applied after the main branch.
        self.downsample = (
            nn.Conv2d(filters, filters, 3, padding=1, stride=2)
            if downsample
            else None
        )

    def forward(self, x):
        """Return ``(main(x) + skip(x)) / sqrt(2)``."""
        skip = self.conv_res(x)
        out = self.net(x)
        if self.downsample is not None:
            out = self.downsample(out)
        # Scale the residual sum by 1/sqrt(2); presumably to keep its magnitude
        # comparable to a single branch, as in StyleGAN2.
        scale = 1 / math.sqrt(2)
        return (out + skip) * scale


class Discriminator(nn.Module):
    """StyleGAN2-style discriminator producing one score per input image.

    The network is a stack of ``DiscriminatorBlock``s. Every block except the
    last downsamples, so a padded input of side ``2**k`` ends as a 2x2 feature
    map that is flattened and fed to a single-output linear head.

    Args:
        image_size: side length of the square images the network is built
            for. A non-power-of-two size is handled by zero-padding the input
            up to the next power of two in ``forward``.
        network_capacity: output channels of the first block; each later block
            doubles the previous block's channel count.
        fmap_max: upper bound applied to every channel count.
    """

    def __init__(self, image_size, network_capacity=16, fmap_max=512):
        super(Discriminator, self).__init__()
        # Number of downsampling blocks needed to shrink the padded
        # (power-of-two) resolution down to 2x2.
        num_layers = math.ceil(log2(image_size) - 1)
        num_init_filters = 3  # forward() always feeds 3-channel images

        filters = [num_init_filters] + [
            network_capacity * (2 ** i) for i in range(num_layers + 1)
        ]
        filters = list(map(partial(min, fmap_max), filters))
        chan_in_out = list(zip(filters[:-1], filters[1:]))

        # Only the final block keeps its input resolution.
        last_index = len(chan_in_out) - 1
        self.blocks = nn.ModuleList(
            DiscriminatorBlock(in_chan, out_chan, downsample=(ind != last_index))
            for ind, (in_chan, out_chan) in enumerate(chan_in_out)
        )

        # The last block leaves a 2x2 map with filters[-1] channels.
        latent_dim = 2 * 2 * filters[-1]
        self.to_out = nn.Linear(latent_dim, 1)

        self._init_weights()

    def forward(self, x):
        """Score a batch of square images.

        Args:
            x: image batch unpacked as ``(b, c, h, w)``; ``h`` must equal
                ``w``. Any channel count other than 3 is collapsed to its
                channel mean and repeated to 3 channels.

        Returns:
            A 1-D tensor with one score per sample.
        """
        # Adapt real data that has a non-power-of-two size or != 3 channels.
        _, c, h, w = x.shape
        assert h == w, f"Discriminator expects square images, got {h}x{w}"
        log_size = log2(h)
        if log_size != int(log_size):
            # Resolution is not a power of 2: zero-pad (as evenly as possible
            # on each side) up to the next power of 2.
            pad_by = 2 ** math.ceil(log_size) - h
            pad_begin = pad_by // 2
            pad_end = pad_by - pad_begin
            x = F.pad(x, (pad_begin, pad_end, pad_begin, pad_end))
        if c != 3:
            # Average the channels into one, then broadcast to 3 (a view).
            x = x.mean(dim=1, keepdim=True).expand(-1, 3, -1, -1)

        for block in self.blocks:
            x = block(x)

        x = x.flatten(start_dim=1)
        x = self.to_out(x)
        return x.view(-1)

    def _init_weights(self):
        """Apply Kaiming-normal init to every conv and linear weight.

        Biases keep PyTorch's default initialization.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                # NOTE(review): a=0 yields gain sqrt(2), although the
                # activations use slope 0.2; kept as-is — confirm intended.
                nn.init.kaiming_normal_(
                    m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu'
                )
