import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn.utils import spectral_norm
from torch.nn.init import xavier_uniform_


def init_weights(m):
    """Initialise a ``nn.Linear`` / ``nn.Conv2d`` module in place.

    Intended to be passed to ``nn.Module.apply``. Matching modules get
    Xavier-uniform weights and a constant bias of 0.1; every other module
    type is left untouched.

    Args:
        m: the module to (possibly) initialise.
    """
    if not isinstance(m, (nn.Linear, nn.Conv2d)):
        return
    # NOTE(review): for layers wrapped by ``spectral_norm``, ``m.weight`` is
    # expected to share storage with ``m.weight_orig`` (see
    # torch.nn.utils.spectral_norm), so this init reaches the real parameter —
    # confirm on the PyTorch version in use.
    xavier_uniform_(m.weight)
    # Layers built with ``bias=False`` have ``m.bias is None``; skip them
    # instead of crashing.
    if m.bias is not None:
        m.bias.data.fill_(0.1)


def snconv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
    """Create a 2-D convolution whose weight is spectrally normalized.

    Every argument is forwarded unchanged to ``nn.Conv2d``; the resulting layer
    is then wrapped with ``torch.nn.utils.spectral_norm``.

    Returns:
        The wrapped ``nn.Conv2d`` module.
    """
    conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        bias=bias,
    )
    return spectral_norm(conv)


def snlinear(in_features, out_features, bias=True):
    """Create a fully connected layer whose weight is spectrally normalized.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: whether the layer learns an additive bias. Defaults to ``True``
            (the previous fixed behavior); exposed for parity with ``snconv2d``.

    Returns:
        An ``nn.Linear`` wrapped with ``torch.nn.utils.spectral_norm``.
    """
    return spectral_norm(nn.Linear(in_features=in_features, out_features=out_features, bias=bias))


def sn_embedding(num_embeddings, embedding_dim):
    """Create an embedding table whose weight is spectrally normalized.

    Args:
        num_embeddings: number of rows in the lookup table.
        embedding_dim: size of each embedding vector.

    Returns:
        An ``nn.Embedding`` wrapped with ``torch.nn.utils.spectral_norm``.
    """
    table = nn.Embedding(num_embeddings, embedding_dim)
    return spectral_norm(table)


class DiscBlock(nn.Module):
    """Residual discriminator block that halves the spatial resolution.

    Main path: [ReLU] -> 3x3 SN-conv -> ReLU -> 3x3 SN-conv -> 2x2 average pool.
    Shortcut:  1x1 SN-conv -> 2x2 average pool, applied to the block's raw input.
    The two paths are summed.

    Args:
        in_channels: channel count of the input tensor.
        out_channels: channel count of the output tensor.
        first_block: if True, the leading ReLU on the main path is skipped.
    """

    def __init__(self, in_channels, out_channels, first_block=False):
        super(DiscBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.first_block = first_block
        # In-place is safe for this module: forward only applies it to a
        # freshly computed conv output, never to a tensor read elsewhere.
        self.relu = nn.ReLU(inplace=True)
        self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.downsample = nn.AvgPool2d(2)
        self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # Main path. The pre-activation must be out-of-place: an in-place ReLU
        # would overwrite ``x`` itself, so the shortcut below would see relu(x)
        # instead of the raw input, and the caller's tensor would be mutated.
        h = x if self.first_block else F.relu(x)
        h = self.snconv2d1(h)
        h = self.relu(h)
        h = self.snconv2d2(h)
        h = self.downsample(h)

        # Shortcut on the unmodified input.
        shortcut = self.snconv2d0(x)
        shortcut = self.downsample(shortcut)

        return h + shortcut


class Discriminator(nn.Module):
    """Residual spectral-norm discriminator producing one logit per sample.

    A stack of ``num_layers`` ``DiscBlock``s, each halving H and W with
    average pooling, is followed by ReLU, a sum over the spatial dimensions
    and a spectral-norm linear layer mapping ``d_conv_dim`` features to 1.

    Args:
        d_conv_dim: channel width of every block.
        num_layers: total number of blocks. The first block (consuming the
            image) is always created, so values below 1 still build one block.
        in_channels: channels of the input image (default 3, i.e. RGB).
    """

    def __init__(self, d_conv_dim=64, num_layers=3, in_channels=3):
        super(Discriminator, self).__init__()
        self.d_conv_dim = d_conv_dim
        self.num_layers = num_layers
        self.in_channels = in_channels
        self.blocks = nn.ModuleList()
        self.blocks.append(DiscBlock(in_channels, d_conv_dim, first_block=True))
        for _ in range(num_layers - 1):
            self.blocks.append(DiscBlock(d_conv_dim, d_conv_dim))
        self.snlinear1 = snlinear(in_features=d_conv_dim, out_features=1)
        self.relu = nn.ReLU(inplace=True)

        # Weight init
        self.apply(init_weights)

    def forward(self, x):
        """Return a 1-D tensor of shape (n,) with one logit per input sample.

        ``x`` is expected to be (n, in_channels, H, W); e.g. a 128x128 input
        with num_layers=3 reaches 16x16 before the spatial sum.
        """
        for block in self.blocks:
            x = block(x)               # each block halves H and W
        x = self.relu(x)
        x = torch.sum(x, dim=[2, 3])   # (n, d_conv_dim)
        # Squeeze only the size-1 logit dim: a bare squeeze() would also drop
        # the batch dim when n == 1 and return a 0-d tensor.
        return self.snlinear1(x).squeeze(1)
