import torch
import torch.nn as nn
from .gen_resblock import GenBlock
from config import opt
import torch.nn.functional as F

class Args:
    """Empty attribute bag; the defaults live on the ``args`` instance below."""


# Module-level defaults used as the fallback ``args`` of the networks in this file.
args = Args()
vars(args).update(
    latent_dim=256,        # channel width of each SingleEncoder / per-input code size
    noise_dim=64,          # conditioning width handed to AdaIN by AdainResBlk
    gf_dim=256,            # channel count used throughout Decoder
    bottom_width=4,        # side of the square feature map Decoder starts from
    d_spectral_norm=True,  # wrap the dis-block convs (and SingleEncoder.l5) in spectral norm
)

class Decoder(nn.Module):
    """Decode a flat code vector into a 3-channel image with values in (-1, 1).

    Pipeline: a linear layer expands the code to ``ch * bottom_width ** 2``
    features, which are reshaped to ``(N, ch, bottom_width, bottom_width)``,
    passed through three ``GenBlock(..., upsample=True)`` stages, then
    BatchNorm -> activation -> 3x3 conv to 3 channels -> tanh.

    Args:
        args: object providing ``bottom_width``, ``gf_dim`` and ``latent_dim``.
        activation: module applied before the final conv and handed to every
            ``GenBlock``. The default instance is created once at definition
            time and shared between all decoders built with the default.
    """

    def __init__(self, args=args, activation=nn.ReLU(inplace=False)):
        super(Decoder, self).__init__()
        self.bottom_width = args.bottom_width
        self.activation = activation
        self.ch = args.gf_dim
        # The input is the concatenation of ``opt.n_fuse`` latent codes and a
        # noise vector, hence the in_features below.
        # NOTE(review): the noise width is read from the global ``opt`` while the
        # latent width comes from ``args`` (which has its own ``noise_dim``) -
        # confirm that mixing the two sources is intended.
        self.l1 = nn.Linear(args.latent_dim * opt.n_fuse + opt.noise_dim,
                            (self.bottom_width ** 2) * self.ch)
        # GenBlock is defined in gen_resblock (not visible here); presumably each
        # upsample=True block doubles the spatial size - TODO confirm.
        self.block2 = GenBlock(self.ch, self.ch, activation=activation, upsample=True)
        self.block3 = GenBlock(self.ch, self.ch, activation=activation, upsample=True)
        self.block4 = GenBlock(self.ch, self.ch, activation=activation, upsample=True)
        self.b5 = nn.BatchNorm2d(self.ch)
        self.c5 = nn.Conv2d(self.ch, 3, kernel_size=3, stride=1, padding=1)

    def encoding_size(self):
        """Return the constant 128.

        NOTE(review): the value is hard-coded and not derived from ``args``;
        its meaning is not evident from this file.
        """
        return 128

    def forward(self, h):
        """Decode ``h`` into an image tensor.

        Args:
            h: tensor whose last dimension matches ``l1.in_features``
                (fused latent codes followed by noise).

        Returns:
            Tensor with 3 channels and values in (-1, 1).
        """
        h = self.l1(h).view(-1, self.ch, self.bottom_width, self.bottom_width)
        h = self.block2(h)
        h = self.block3(h)
        h = self.block4(h)
        h = self.activation(self.b5(h))
        # Functional tanh: same result as nn.Tanh() without allocating a module
        # on every forward pass.
        return torch.tanh(self.c5(h))


class AdaIN(nn.Module):
    """Adaptive instance normalisation driven by a conditioning vector.

    The feature map is instance-normalised without learned affine parameters;
    a linear layer maps the conditioning vector to a per-channel scale offset
    ``gamma`` and shift ``beta``, applied as ``(1 + gamma) * norm(x) + beta``.

    Args:
        noise_dim: width of the conditioning vector ``z``.
        num_features: number of channels of the feature map ``x``.
    """

    def __init__(self, noise_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        # One linear layer emits gamma and beta side by side (2 * num_features).
        self.fc = nn.Linear(noise_dim, num_features * 2)

    def forward(self, x, z):
        """Modulate ``x`` (N, num_features, H, W) with ``z`` (N, noise_dim)."""
        scale_shift = self.fc(z)
        gamma, beta = scale_shift.chunk(2, dim=1)
        # Append two singleton axes so both broadcast over the spatial dims.
        gamma = gamma.view(gamma.size(0), gamma.size(1), 1, 1)
        beta = beta.view(beta.size(0), beta.size(1), 1, 1)
        normalised = self.norm(x)
        return normalised * (gamma + 1) + beta

class AdainResBlk(nn.Module):
    """Residual block whose two 3x3 convolutions are each preceded by AdaIN.

    Residual path: AdaIN -> activation -> (optional 2x nearest upsample) ->
    conv -> AdaIN -> activation -> conv. The skip path applies the same
    optional upsample and, when the channel count changes, a bias-free 1x1
    convolution. The AdaIN conditioning width is taken from the module-level
    ``args.noise_dim``.

    Args:
        dim_in: input channels.
        dim_out: output channels.
        actv: activation module; the default instance is created once and
            shared by every block built with the default.
        upsample: when true, both paths double the spatial resolution.
    """

    def __init__(self, dim_in, dim_out,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.actv = actv
        self.upsample = upsample
        # A projection on the skip path is only needed when channels differ.
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out, args.noise_dim)

    def _build_weights(self, dim_in, dim_out, noise_dim=64):
        """Create the convolutions, the AdaIN layers and the optional 1x1 skip conv."""
        self.conv1 = nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1)
        self.norm1 = AdaIN(noise_dim, dim_in)
        self.norm2 = AdaIN(noise_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, kernel_size=1, stride=1,
                                     padding=0, bias=False)

    def _shortcut(self, x):
        """Skip path: optional 2x nearest upsample, then optional 1x1 projection."""
        skip = F.interpolate(x, scale_factor=2, mode='nearest') if self.upsample else x
        return self.conv1x1(skip) if self.learned_sc else skip

    def _residual(self, x, s):
        """Residual path; ``s`` is the conditioning vector fed to both AdaIN layers."""
        out = self.actv(self.norm1(x, s))
        if self.upsample:
            out = F.interpolate(out, scale_factor=2, mode='nearest')
        out = self.conv1(out)
        out = self.actv(self.norm2(out, s))
        return self.conv2(out)

    def forward(self, h, z):
        """Return residual(h, z) + shortcut(h).

        The sum is left unscaled; a division by sqrt(2) was tried and disabled.
        """
        return self._residual(h, z) + self._shortcut(h)

class AdainDecoder(nn.Module):
    """Decoder made of AdaIN residual blocks, conditioned on a noise vector.

    The fused code (``args.latent_dim * opt.n_fuse`` channels) is treated as a
    1x1 feature map and run through two same-resolution blocks followed by five
    ``upsample=True`` blocks (each doubles the spatial size via nearest
    interpolation in ``AdainResBlk``), all at constant channel width. A final
    InstanceNorm -> LeakyReLU -> 1x1 conv produces 3 channels; no output
    non-linearity is applied.

    Args:
        args: object providing ``latent_dim``.
    """

    def __init__(self, args=args):
        super().__init__()
        width = args.latent_dim * opt.n_fuse
        self.to_rgb = nn.Sequential(
            nn.InstanceNorm2d(width, affine=True),
            nn.LeakyReLU(0.2),
            nn.Conv2d(width, 3, kernel_size=1, stride=1, padding=0))

        # Blocks are created in the same order as a front-inserting stack would
        # create them (upsamplers first, bottleneck second) so that seeded
        # initialisation is unchanged, then laid out in execution order:
        # bottleneck blocks run first, last-created first.
        upsamplers = [AdainResBlk(width, width, upsample=True) for _ in range(5)]
        bottleneck = [AdainResBlk(width, width) for _ in range(2)]
        self.decode = nn.ModuleList(bottleneck[::-1] + upsamplers[::-1])

    def forward(self, h, z):
        """Decode code ``h`` (N, C) using conditioning vector ``z``; returns a 3-channel map."""
        # Treat the code vector as a C-channel feature map of size 1x1.
        feat = h.view(h.shape[0], h.shape[1], 1, 1)
        for block in self.decode:
            feat = block(feat, z)
        return self.to_rgb(feat)


"""Discriminator"""


def _downsample(x):
    """Halve the spatial resolution of ``x`` with non-overlapping 2x2 average pooling.

    Args:
        x: tensor accepted by ``F.avg_pool2d`` (e.g. ``(N, C, H, W)``).

    Returns:
        The pooled tensor; an odd trailing row/column is dropped (floor mode).
    """
    # Functional form: identical result to nn.AvgPool2d(kernel_size=2) without
    # constructing a throw-away module on every call.
    return F.avg_pool2d(x, kernel_size=2)


class OptimizedDisBlock(nn.Module):
    """Downsampling residual block without a leading activation.

    Residual path: conv -> activation -> conv -> 2x average pool.
    Skip path: 2x average pool -> 1x1 conv (pooling happens *before* the
    projection here). Both paths are summed.

    Args:
        args: object providing the boolean ``d_spectral_norm``.
        in_channels: input channels.
        out_channels: output channels.
        ksize: kernel size of the two residual convolutions.
        pad: padding of the two residual convolutions.
        activation: activation module; the default instance is shared by all
            blocks built with the default.
    """

    def __init__(self, args, in_channels, out_channels, ksize=3, pad=1, activation=nn.ReLU(inplace=False)):
        super().__init__()
        self.activation = activation

        self.c1 = nn.Conv2d(in_channels, out_channels, kernel_size=ksize, padding=pad)
        self.c2 = nn.Conv2d(out_channels, out_channels, kernel_size=ksize, padding=pad)
        self.c_sc = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
        if args.d_spectral_norm:
            # Wrapped only after all three convs exist, in creation order.
            for layer_name in ("c1", "c2", "c_sc"):
                setattr(self, layer_name, nn.utils.spectral_norm(getattr(self, layer_name)))

    def residual(self, x):
        """conv -> activation -> conv -> downsample."""
        hidden = self.activation(self.c1(x))
        return _downsample(self.c2(hidden))

    def shortcut(self, x):
        """downsample -> 1x1 conv."""
        pooled = _downsample(x)
        return self.c_sc(pooled)

    def forward(self, x):
        """Return ``residual(x) + shortcut(x)``."""
        main_path = self.residual(x)
        return main_path + self.shortcut(x)


class DisBlock(nn.Module):
    """Pre-activation residual block with optional 2x average-pool downsampling.

    Residual path: activation -> conv -> activation -> conv -> (downsample).
    Skip path: identity, or - when the channel count changes or downsampling
    is requested - a 1x1 conv followed by the optional downsample (projection
    happens *before* pooling here).

    Args:
        args: object providing the boolean ``d_spectral_norm``.
        in_channels: input channels.
        out_channels: output channels.
        hidden_channels: channels between the two convs; defaults to
            ``in_channels``.
        ksize: kernel size of the residual convolutions.
        pad: padding of the residual convolutions.
        activation: activation module (default instance is shared).
        downsample: when true, both paths are average-pooled by 2.
    """

    def __init__(self, args, in_channels, out_channels, hidden_channels=None, ksize=3, pad=1,
                 activation=nn.ReLU(inplace=False), downsample=False):
        super().__init__()
        self.activation = activation
        self.downsample = downsample
        self.learnable_sc = (in_channels != out_channels) or downsample
        if hidden_channels is None:
            hidden_channels = in_channels

        use_sn = args.d_spectral_norm
        self.c1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=ksize, padding=pad)
        self.c2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=ksize, padding=pad)
        if use_sn:
            self.c1 = nn.utils.spectral_norm(self.c1)
            self.c2 = nn.utils.spectral_norm(self.c2)

        # The projection conv only exists when the skip path cannot be an identity.
        if self.learnable_sc:
            self.c_sc = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
            if use_sn:
                self.c_sc = nn.utils.spectral_norm(self.c_sc)

    def residual(self, x):
        """activation -> conv -> activation -> conv, then pool if ``downsample``."""
        out = self.c1(self.activation(x))
        out = self.c2(self.activation(out))
        return _downsample(out) if self.downsample else out

    def shortcut(self, x):
        """Identity, or 1x1 projection followed by the optional pool."""
        if not self.learnable_sc:
            return x
        projected = self.c_sc(x)
        return _downsample(projected) if self.downsample else projected

    def forward(self, x):
        """Return ``residual(x) + shortcut(x)``."""
        main_path = self.residual(x)
        return main_path + self.shortcut(x)


class SingleEncoder(nn.Module):
    """Encode one 3-channel image into a ``latent_dim``-wide feature vector.

    Four residual stages at constant width (``OptimizedDisBlock`` and a
    downsampling ``DisBlock`` followed by two same-resolution ``DisBlock``s),
    an activation, and a sum over both spatial axes.

    Args:
        args: object providing ``latent_dim`` and ``d_spectral_norm``.
        activation: activation module used by ``block2``-``block4`` and after
            the last block. ``block1`` is built without it and therefore uses
            ``OptimizedDisBlock``'s own default.
    """

    def __init__(self, args=args, activation=nn.ReLU(inplace=False)):
        super().__init__()
        width = args.latent_dim
        self.ch = width
        self.activation = activation
        self.block1 = OptimizedDisBlock(args, 3, width)
        self.block2 = DisBlock(args, width, width, activation=activation, downsample=True)
        self.block3 = DisBlock(args, width, width, activation=activation, downsample=False)
        self.block4 = DisBlock(args, width, width, activation=activation, downsample=False)
        # Scalar head: registered (and optionally spectral-normalised) but
        # forward() does not apply it.
        self.l5 = nn.Linear(width, 1, bias=False)
        if args.d_spectral_norm:
            self.l5 = nn.utils.spectral_norm(self.l5)

    def forward(self, x):
        """Return the ``(N, ch)`` code for image batch ``x``."""
        feat = x
        for stage in (self.block1, self.block2, self.block3, self.block4):
            feat = stage(feat)
        feat = self.activation(feat)
        # Global *sum* pooling (not an average): collapse H, then W.
        return feat.sum(2).sum(2)

class Encoder(nn.Module):
    """Encode ``opt.n_fuse`` images with independent encoders and concatenate the codes.

    Args:
        args: object providing ``latent_dim`` (and whatever ``SingleEncoder``
            reads); forwarded to every sub-encoder.
        activation: activation module forwarded to every sub-encoder.
    """

    def __init__(self, args=args, activation=nn.ReLU(inplace=False)):
        super(Encoder, self).__init__()
        self.ch = args.latent_dim
        self.n_fuse = opt.n_fuse
        # Pass the constructor arguments through; building SingleEncoder() with
        # no arguments would silently ignore a caller-supplied args/activation
        # and could contradict the width check in forward().
        # ModuleList (rather than Sequential) because the encoders are indexed
        # independently, never chained; parameter names ("encoders.<i>. ...")
        # are the same for both containers.
        self.encoders = nn.ModuleList(
            [SingleEncoder(args=args, activation=activation) for _ in range(self.n_fuse)])

    def forward(self, fuse_x):
        """Encode each input with its own encoder.

        Args:
            fuse_x: sequence of exactly ``n_fuse`` image batches.

        Returns:
            Tensor of shape ``(N, n_fuse * ch)``: the per-input codes
            concatenated along dim 1, in input order.

        Raises:
            AssertionError: if the number of inputs or the fused width is wrong.
        """
        assert len(fuse_x) == self.n_fuse, \
            "expected %d inputs, got %d" % (self.n_fuse, len(fuse_x))
        z_list = [encoder(x) for encoder, x in zip(self.encoders, fuse_x)]
        h = torch.cat(z_list, 1)
        assert h.shape[1] == self.n_fuse * self.ch, \
            "fused code has width %d, expected %d" % (h.shape[1], self.n_fuse * self.ch)

        return h

class Generator(nn.Module):
    """Encoder/decoder generator that predicts a residual on top of the first input.

    The inputs are encoded and fused by ``Encoder``, the noise vector is
    appended to the fused code, ``Decoder`` turns the result into an image,
    and that image is added to ``x[0]``.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x, z):
        """Generate an output image.

        Args:
            x: sequence of input image batches; ``x[0]`` is the base image the
                decoded output is added to.
            z: noise tensor concatenated to the fused code along dim 1.

        Returns:
            ``decoder(cat([encoder(x), z], 1)) + x[0]``.
        """
        code = self.encoder(x)
        conditioned = torch.cat((code, z), dim=1)
        decoded = self.decoder(conditioned)
        # Residual formulation: the decoder output is an offset to the first input.
        return decoded + x[0]

# class Generator(nn.Module):
#     def __init__(self):
#         super(Generator, self).__init__()
#         self.encoder = Encoder()
#         self.decoder = AdainDecoder()
#
#     def forward(self, x, z):
#         h = self.encoder(x)
#         y = self.decoder(h, z)
#         output = y + x[0]
#         return output