import torch.nn as nn
import numpy as np


class InverseStylegan(nn.Module):
    """Image encoder built from a stack of ``DownScaleConv2d`` blocks.

    Layers, in order:
      * one ``DownScaleConv2d`` (with average-pool downscaling) per stage in
        ``range(resolution_log2, 2, -1)``; output channels follow
        ``channels_at_stage(stage - 2)``;
      * one more downscaling block that doubles the channel count;
      * one non-downscaling block that doubles the channel count again;
      * a closing ``LeakyReLU(0.2)``.

    Args:
        dataset_shape: ``(channels, H, W)`` of a single sample. Any sequence
            of three ints is accepted (e.g. ``torch.Size`` or a plain tuple).
    """

    def __init__(self, dataset_shape):
        super().__init__()
        channels, H, W = dataset_shape

        # Despite the name, the *smaller* spatial side determines the depth.
        max_resolution = min(H, W)
        resolution_log2 = int(np.ceil(np.log2(max_resolution)))

        # Upper bound on feature maps. ``int(np.prod(...))`` equals
        # ``torch.Size.numel()`` but also works for plain tuples/lists.
        fmap_max = min(512, int(np.prod(dataset_shape))) // 4
        fmap_decay = 1.
        fmap_base = max_resolution*8

        def channels_at_stage(stage):
            # fmap_base / 2**stage, capped at fmap_max and floored at the
            # input channel count (the floor wins if the two conflict).
            return max(min(int(fmap_base / (2.0**(stage*fmap_decay))),
                           fmap_max), channels)

        downscale_modules = []
        in_channels = channels
        # One halving block per stage, from resolution_log2 down to 3.
        for stage in range(resolution_log2, 2, -1):
            out_channels = channels_at_stage(stage-2)
            downscale_modules.append(DownScaleConv2d(in_channels, out_channels,
                                                     padding=1, kernel_size=3))
            in_channels = out_channels
        downscale_modules.append(DownScaleConv2d(in_channels, in_channels*2,
                                                 padding=1, kernel_size=3,
                                                 downscale=True))
        # NOTE(review): DownScaleConv2d currently ignores padding/kernel_size,
        # so this block still uses 3x3 convs despite the arguments.
        downscale_modules.append(DownScaleConv2d(in_channels*2, in_channels*4,
                                                 padding=0, kernel_size=1,
                                                 downscale=False))
        # NOTE(review): the previous block already ends in LeakyReLU(0.2);
        # applying it again scales negative values by a further 0.2. Kept
        # as-is to preserve existing model outputs.
        downscale_modules.append(nn.LeakyReLU(negative_slope=0.2))
        self.layers = nn.ModuleList(downscale_modules)

    def forward(self, x):
        """Apply every layer in ``self.layers`` sequentially to ``x``."""
        for layer in self.layers:
            x = layer(x)
        return x


class DownScale1d(nn.Module):
    """1-D sequence encoder built from a stack of ``DownScaleConv1d`` blocks.

    Layers, in order:
      * one ``DownScaleConv1d`` (with average-pool downscaling) per stage in
        ``range(resolution_log2, 2, -1)``; output channels follow
        ``channels_at_stage(stage - 2)``;
      * one more downscaling block that doubles the channel count;
      * one non-downscaling block that doubles the channel count again;
      * a closing ``LeakyReLU(0.2)``.

    Args:
        dataset_shape: ``(channels, length)`` of a single sample. Any
            sequence of two ints is accepted (e.g. ``torch.Size`` or a tuple).
    """

    def __init__(self, dataset_shape):
        super().__init__()
        c, max_resolution = dataset_shape
        resolution_log2 = int(np.ceil(np.log2(max_resolution)))

        # Upper bound on feature maps. ``int(np.prod(...))`` equals
        # ``torch.Size.numel()`` but also works for plain tuples/lists.
        fmap_max = min(512, int(np.prod(dataset_shape))) // 4
        fmap_decay = 1.
        fmap_base = max_resolution*8

        def channels_at_stage(stage):
            # fmap_base / 2**stage, capped at fmap_max and floored at the
            # input channel count (the floor wins if the two conflict).
            return max(min(int(fmap_base / (2.0**(stage*fmap_decay))),
                           fmap_max), c)

        downscale_modules = []
        in_channels = c
        # One halving block per stage, from resolution_log2 down to 3.
        for stage in range(resolution_log2, 2, -1):
            out_channels = channels_at_stage(stage-2)
            downscale_modules.append(DownScaleConv1d(in_channels, out_channels,
                                                     padding=1, kernel_size=3))
            in_channels = out_channels
        downscale_modules.append(DownScaleConv1d(in_channels, in_channels*2,
                                                 padding=1, kernel_size=3,
                                                 downscale=True))
        # NOTE(review): DownScaleConv1d currently ignores padding/kernel_size,
        # so this block still uses size-3 convs despite the arguments.
        downscale_modules.append(DownScaleConv1d(in_channels*2, in_channels*4,
                                                 padding=0, kernel_size=1,
                                                 downscale=False))
        # NOTE(review): the previous block already ends in LeakyReLU(0.2);
        # applying it again scales negative values by a further 0.2. Kept
        # as-is to preserve existing model outputs.
        downscale_modules.append(nn.LeakyReLU(negative_slope=0.2))
        self.layers = nn.ModuleList(downscale_modules)

    def forward(self, x):
        """Apply every layer in ``self.layers`` sequentially to ``x``."""
        for layer in self.layers:
            x = layer(x)
        return x


class DownScaleConv1d(nn.Module):
    """Two bias-free size-3 1-D convolutions, optionally followed by pooling.

    Pipeline: Conv1d(in->out) -> LeakyReLU(0.2) -> Conv1d(out->out)
    -> [AvgPool1d(2) if ``downscale``] -> LeakyReLU(0.2).

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by both convolutions.
        padding: accepted for API compatibility but not used (see NOTE).
        kernel_size: accepted for API compatibility but not used (see NOTE).
        downscale: when True, halve the length with ``AvgPool1d(2)``
            before the final activation.
    """

    def __init__(self, in_channels, out_channels, padding=1, kernel_size=3,
                 downscale=True):
        super().__init__()
        # NOTE(review): ``padding`` and ``kernel_size`` are ignored; both convs
        # are always kernel_size=3, padding=1. Honoring them would change the
        # conv weight shapes, so this is only flagged here.
        stack = [
            nn.Conv1d(in_channels, out_channels,
                      kernel_size=3, padding=1, bias=False),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv1d(out_channels, out_channels,
                      kernel_size=3, padding=1, bias=False),
        ]
        if downscale:
            stack.append(nn.AvgPool1d(2))
        stack.append(nn.LeakyReLU(negative_slope=0.2))
        # Attribute name and ordering are kept so state_dict keys stay stable.
        self.downscale = nn.ModuleList(stack)

    def forward(self, x):
        """Run ``x`` through each module of ``self.downscale`` in order."""
        out = x
        for module in self.downscale:
            out = module(out)
        return out


class DownScaleConv2d(nn.Module):
    """Two bias-free 3x3 2-D convolutions, optionally followed by pooling.

    Pipeline: Conv2d(in->out) -> LeakyReLU(0.2) -> Conv2d(out->out)
    -> [AvgPool2d(2) if ``downscale``] -> LeakyReLU(0.2).

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by both convolutions.
        padding: accepted for API compatibility but not used (see NOTE).
        kernel_size: accepted for API compatibility but not used (see NOTE).
        downscale: when True, halve both spatial dims with ``AvgPool2d(2)``
            before the final activation.
    """

    def __init__(self, in_channels, out_channels, padding=1, kernel_size=3,
                 downscale=True):
        super().__init__()
        # NOTE(review): ``padding`` and ``kernel_size`` are ignored; both convs
        # are always kernel_size=3, padding=1. Honoring them would change the
        # conv weight shapes, so this is only flagged here.
        stack = [
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=3, padding=1, bias=False),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv2d(out_channels, out_channels,
                      kernel_size=3, padding=1, bias=False),
        ]
        if downscale:
            stack.append(nn.AvgPool2d(2))
        stack.append(nn.LeakyReLU(negative_slope=0.2))
        # Attribute name and ordering are kept so state_dict keys stay stable.
        self.downscale = nn.ModuleList(stack)

    def forward(self, x):
        """Run ``x`` through each module of ``self.downscale`` in order."""
        out = x
        for module in self.downscale:
            out = module(out)
        return out
