##### code taken from https://github.com/ajbrock/BigGAN-PyTorch/blob/master/layers.py
# Alexia: Changed so that if self.n_classes == 1, we don't use the projection discriminator
# Vikram: Provided defaults

import argparse
import functools

import torch
import torch.nn as nn
from torch.nn import init

from . import layers_biggan as layers


def disc_default_config():
    """Return a fresh ``argparse.Namespace`` with the default discriminator settings.

    Fields:
        ch: channel width multiplier.
        thin: if True build a thin (SN-GAN style) D, otherwise a wide one
            (BigGAN / SA-GAN style).
        im_ch: number of input image channels.
        im_size: input image height/width; must be a key of ``D_arch()``.
        kernel_size: convolution kernel size.
        attn: underscore-separated resolutions at which self-attention is
            requested (e.g. ``'64'`` or ``'32_64'``). Note that with the
            default ``im_size`` of 64 the D blocks run at resolution 32 and
            below, so ``'64'`` adds no attention layer.
        n_classes: number of classes; 1 means unconditional (no projection).
        init: weight initialisation style ('ortho', 'N02', 'glorot'/'xavier').
        adv_loss: name of the adversarial loss.
    """
    defaults = dict(
        ch=64,
        thin=False,
        im_ch=3,
        im_size=64,
        kernel_size=3,
        attn='64',
        n_classes=1,
        init='xavier',
        adv_loss='LSGAN',
    )
    return argparse.Namespace(**defaults)


# Discriminator architecture
def D_arch(im_ch=3, ch=64, attention='64', ksize='333333', dilation='111111'):
    """Return the BigGAN discriminator layout for every supported image size.

    The result maps an input image size (256, 128, 64, 32, 28, 16, 8 or 4) to
    a dict with one list position per residual block:

    * ``in_channels`` / ``out_channels``: channel counts. These are multiples
      of ``ch``, except that the first block takes ``im_ch`` channels.
    * ``downsample``: whether that block downsamples (``Discriminator`` uses
      ``nn.AvgPool2d(2)`` for True).
    * ``resolution``: nominal resolution label of the block, used to look up
      ``attention``. The 28 entry reuses the 32 layout, so its labels are
      those of the 32x32 case.
    * ``attention``: ``{resolution: bool}``, i.e. whether self-attention is
      requested at that resolution.

    Args:
        im_ch: number of input image channels.
        ch: channel width multiplier.
        attention: underscore-separated resolutions at which to add
            self-attention, e.g. ``'64'`` or ``'32_64'``. Empty tokens are
            ignored, so ``''`` (like ``'0'``) requests no attention.
        ksize, dilation: unused; kept for signature compatibility with the
            upstream BigGAN code.

    Returns:
        dict mapping image size to the layout dict described above.

    Raises:
        ValueError: if a non-empty token of ``attention`` is not an integer.
    """
    # Parse the requested attention resolutions once, skipping empty tokens
    # so that '' or a trailing '_' does not crash int().
    attn_res = {int(tok) for tok in attention.split('_') if tok.strip()}

    def _layout(in_mults, out_mults, n_down, resolution, log2_stop):
        """Build one layout entry; the first ``n_down`` blocks downsample."""
        return {'in_channels': [im_ch] + [ch * m for m in in_mults],
                'out_channels': [ch * m for m in out_mults],
                'downsample': [True] * n_down + [False] * (len(out_mults) - n_down),
                'resolution': list(resolution),
                # Attention flags for resolutions 4, 8, ..., 2 ** (log2_stop - 1)
                'attention': {2 ** i: 2 ** i in attn_res
                              for i in range(2, log2_stop)}}

    return {
        256: _layout([1, 2, 4, 8, 8, 16], [1, 2, 4, 8, 8, 16, 16], 6,
                     [128, 64, 32, 16, 8, 4, 4], 8),
        128: _layout([1, 2, 4, 8, 16], [1, 2, 4, 8, 16, 16], 5,
                     [64, 32, 16, 8, 4, 4], 8),
        64: _layout([1, 2, 4, 8], [1, 2, 4, 8, 16], 4,
                    [32, 16, 8, 4, 4], 7),
        32: _layout([1, 2, 4], [1, 2, 4, 8], 3, [16, 8, 4, 4], 6),
        # 28x28 inputs (e.g. MNIST-sized) reuse the 32x32 layout verbatim.
        28: _layout([1, 2, 4], [1, 2, 4, 8], 3, [16, 8, 4, 4], 6),
        16: _layout([2, 4], [2, 4, 4], 2, [8, 4, 4], 5),
        8: _layout([4], [4, 4], 1, [4, 4], 4),
        4: _layout([], [4], 0, [4], 3),
    }


class Discriminator(nn.Module):
    """BigGAN-style residual discriminator (adapted from BigGAN-PyTorch).

    The network works as follows:

    * A stack of ``layers.DBlock`` residual stages, each optionally followed
      by ``layers.Attention``, laid out by ``D_arch`` for the configured image
      size.
    * A ReLU, then global sum pooling.
    * A spectral-normalised linear head.

    When ``n_classes != 1`` a class-embedding projection term is added to the
    output. With ``n_classes == 1`` the model is unconditional.

    Args:
        disc_config: namespace-like object providing ``ch``, ``thin``,
            ``im_ch``, ``im_size``, ``kernel_size``, ``attn``, ``n_classes``
            and ``init``. ``None`` uses ``disc_default_config()``.
    """

    def __init__(self, disc_config=None):
        super(Discriminator, self).__init__()

        if disc_config is None:
            disc_config = disc_default_config()

        self.disc_config = disc_config

        self.param_count = 0
        # Spectral-norm hyperparameters shared by every SN layer factory below.
        num_D_SVs = 1
        num_D_SV_itrs = 1
        SN_eps = 1e-12
        # Width of the linear head's output (one score per image).
        output_dim = 1
        skip_init = False
        self.activation = nn.ReLU(inplace=False)

        # Width multiplier
        self.ch = disc_config.ch
        # Wide D as in BigGAN / SA-GAN unless a thin (SN-GAN style) D is requested
        self.D_wide = not disc_config.thin
        # Input image geometry
        self.im_size = disc_config.im_size
        self.im_ch = disc_config.im_ch
        # Default kernel size of the convolutions built by self.which_conv
        self.kernel_size = disc_config.kernel_size
        # Underscore-separated resolutions at which to insert self-attention
        self.attention = disc_config.attn
        # Number of classes; 1 disables the projection embedding
        self.n_classes = disc_config.n_classes

        # Initialization style (see init_weights)
        self.init = disc_config.init
        # Per-block layout for this image size (KeyError if im_size is unsupported)
        self.arch = D_arch(self.im_ch, self.ch, self.attention)[self.im_size]

        # Layer factories; there is no option to turn off SN in D right now.
        # Use the configured kernel size instead of a hard-coded 3.
        # padding = kernel_size // 2 keeps the spatial size for odd kernels
        # (assuming the default stride of 1). For the default of 3 this is
        # identical to the previous kernel_size=3, padding=1.
        self.which_conv = functools.partial(layers.SNConv2d,
                                            kernel_size=self.kernel_size,
                                            padding=self.kernel_size // 2,
                                            num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                            eps=SN_eps)
        self.which_linear = functools.partial(layers.SNLinear,
                                              num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                              eps=SN_eps)
        self.which_embedding = functools.partial(layers.SNEmbedding,
                                                 num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                                 eps=SN_eps)

        # self.blocks is doubly nested. The outer level has one entry per
        # stage. The inner level holds that stage's residual block plus an
        # optional trailing self-attention layer.
        blocks = []
        for index in range(len(self.arch['out_channels'])):
            out_ch = self.arch['out_channels'][index]
            resolution = self.arch['resolution'][index]
            stage = [layers.DBlock(in_channels=self.arch['in_channels'][index],
                                   out_channels=out_ch,
                                   which_conv=self.which_conv,
                                   wide=self.D_wide,
                                   activation=self.activation,
                                   # The first block receives the input image directly.
                                   preactivation=(index > 0),
                                   downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]
            if self.arch['attention'][resolution]:
                print('Adding attention layer in D at resolution %d' % resolution)
                stage.append(layers.Attention(out_ch, self.which_conv))
            blocks.append(stage)
        # Wrap in ModuleLists so every submodule is registered with this module.
        self.blocks = nn.ModuleList([nn.ModuleList(stage) for stage in blocks])
        # Linear output layer. The output dimension is typically 1, but may be
        # larger if we're e.g. turning this into a VAE with an inference output.
        self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)
        # Embedding for projection discrimination (conditional case only)
        if self.n_classes != 1:
            self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1])

        if not skip_init:
            self.init_weights()

    def init_weights(self):
        """Re-initialise every Conv2d / Linear / Embedding weight per ``self.init``.

        Supported styles are 'ortho', 'N02' (normal, std 0.02) and
        'glorot'/'xavier'. An unknown style prints a notice and leaves weights
        untouched. ``self.param_count`` is recomputed from scratch on each
        call, counting all parameters of the matched modules.
        """
        self.param_count = 0
        for module in self.modules():
            if not isinstance(module, (nn.Conv2d, nn.Linear, nn.Embedding)):
                continue
            if self.init == 'ortho':
                init.orthogonal_(module.weight)
            elif self.init == 'N02':
                init.normal_(module.weight, 0, 0.02)
            elif self.init in ['glorot', 'xavier']:
                init.xavier_uniform_(module.weight)
            else:
                print('Init style not recognized...')
            self.param_count += sum(p.data.nelement() for p in module.parameters())
        # Double quotes so the apostrophe is printed (the old 'D''s' printed "Ds").
        print("Param count for D's initialized parameters: %d" % self.param_count)

    def forward(self, x, y=None):
        """Score a batch of images.

        Args:
            x: image batch. It is presumably (N, im_ch, im_size, im_size);
                the sum over dims 2 and 3 requires a 4-D feature map.
            y: class indices, presumably a LongTensor of shape (N,). It is
                required when ``n_classes != 1`` and ignored otherwise.

        Returns:
            Tensor of shape (N,) with one score per image.

        Raises:
            ValueError: if the model is class-conditional and ``y`` is None.
        """
        if self.n_classes != 1 and y is None:
            raise ValueError('Discriminator with n_classes=%d requires class labels y'
                             % self.n_classes)
        h = x
        for blocklist in self.blocks:
            for block in blocklist:
                h = block(h)
        # Apply global sum pooling as in SN-GAN
        h = torch.sum(self.activation(h), [2, 3])
        # Class-unconditional evidence
        out = self.linear(h)
        # Projection of the pooled features onto the class embedding
        if self.n_classes != 1:
            out = out + torch.sum(self.embed(y) * h, 1, keepdim=True)
        return out.view(x.shape[0])
