import torch
import torch.nn as nn
from .gen_resblock import GenBlock
from config import opt

class Args:
    """Bare attribute container for the default network hyper-parameters."""


# Module-level defaults; the model classes in this file take this object as
# the default value of their ``args`` parameter.
args = Args()
vars(args).update(
    latent_dim=256,        # channel width of each SingleEncoder / per-input code size
    gf_dim=256,            # channel width used throughout the Decoder
    bottom_width=4,        # side length of the Decoder's initial feature map
    d_spectral_norm=True,  # wrap encoder-side conv/linear layers in spectral norm
)

class Decoder(nn.Module):
    """Decode a fused latent vector into a 3-channel, tanh-squashed image.

    Pipeline: linear projection -> reshape to a ``bottom_width`` x
    ``bottom_width`` feature map with ``gf_dim`` channels -> three
    ``GenBlock`` stages built with ``upsample=True`` -> batch norm ->
    activation -> 3x3 convolution to 3 channels -> tanh.

    Args:
        args: object exposing ``bottom_width``, ``gf_dim`` and ``latent_dim``.
        activation: module applied after the final batch norm and also
            handed to every ``GenBlock``.
    """

    def __init__(self, args=args, activation=nn.ReLU(inplace=False)):
        super(Decoder, self).__init__()
        self.bottom_width = args.bottom_width
        self.activation = activation
        self.ch = args.gf_dim
        # The input is the concatenation of opt.n_fuse latent codes, hence the
        # input width of latent_dim * n_fuse rather than a single latent_dim.
        self.l1 = nn.Linear(args.latent_dim * opt.n_fuse, (self.bottom_width ** 2) * self.ch)
        # NOTE(review): presumably each upsampling GenBlock doubles the spatial
        # size -- confirm against gen_resblock.GenBlock.
        self.block2 = GenBlock(self.ch, self.ch, activation=activation, upsample=True)
        self.block3 = GenBlock(self.ch, self.ch, activation=activation, upsample=True)
        self.block4 = GenBlock(self.ch, self.ch, activation=activation, upsample=True)
        self.b5 = nn.BatchNorm2d(self.ch)
        self.c5 = nn.Conv2d(self.ch, 3, kernel_size=3, stride=1, padding=1)

    def encoding_size(self):
        """Return the constant 128.

        NOTE(review): this value is hard-coded and is not derived from
        ``args.latent_dim``; confirm what callers expect it to mean.
        """
        return 128

    def forward(self, h):
        """Decode ``h`` into an image tensor.

        Args:
            h: tensor whose last dimension is ``latent_dim * opt.n_fuse``.

        Returns:
            Tensor with 3 channels and values in ``(-1, 1)`` (tanh output).
        """
        # Project, then fold the flat vector into (N, ch, bottom_width, bottom_width).
        h = self.l1(h).view(-1, self.ch, self.bottom_width, self.bottom_width)
        h = self.block2(h)
        h = self.block3(h)
        h = self.block4(h)
        h = self.activation(self.b5(h))
        # Functional tanh: same result as nn.Tanh() without building a new
        # module object on every forward pass.
        return torch.tanh(self.c5(h))


"""Discriminator"""


def _downsample(x):
    """Apply 2x2 average pooling (stride 2, no padding) to ``x``.

    Args:
        x: tensor accepted by ``avg_pool2d``; pooling acts on the last two
            dimensions.

    Returns:
        The pooled tensor; each of the last two dimensions becomes
        ``floor(size / 2)``.
    """
    # Functional form: identical result to nn.AvgPool2d(kernel_size=2)(x)
    # without allocating a new Module on every call.
    return nn.functional.avg_pool2d(x, kernel_size=2)


class OptimizedDisBlock(nn.Module):
    """Input-stage residual block that always pools by 2x2 averaging.

    Main path: conv -> activation -> conv -> ``_downsample``.
    Shortcut:  ``_downsample`` -> 1x1 conv.
    The two paths are summed.

    Args:
        args: object exposing ``d_spectral_norm``; when truthy, all three
            convolutions are wrapped with ``nn.utils.spectral_norm``.
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by both paths.
        ksize: kernel size of the two main-path convolutions.
        pad: padding of the two main-path convolutions.
        activation: module applied between the two main-path convolutions.
    """

    _CONV_NAMES = ("c1", "c2", "c_sc")

    def __init__(self, args, in_channels, out_channels, ksize=3, pad=1, activation=nn.ReLU(inplace=False)):
        super(OptimizedDisBlock, self).__init__()
        self.activation = activation

        # All three convolutions are created first and normalised afterwards,
        # in this fixed order.
        self.c1 = nn.Conv2d(in_channels, out_channels, kernel_size=ksize, padding=pad)
        self.c2 = nn.Conv2d(out_channels, out_channels, kernel_size=ksize, padding=pad)
        self.c_sc = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
        if args.d_spectral_norm:
            for conv_name in self._CONV_NAMES:
                setattr(self, conv_name, nn.utils.spectral_norm(getattr(self, conv_name)))

    def residual(self, x):
        """Main path: conv, activation, conv, then 2x2 average pooling."""
        features = self.activation(self.c1(x))
        return _downsample(self.c2(features))

    def shortcut(self, x):
        """Skip path: pool first, then project channels with the 1x1 conv."""
        pooled = _downsample(x)
        return self.c_sc(pooled)

    def forward(self, x):
        """Return the sum of the main path and the skip path."""
        main_path = self.residual(x)
        skip_path = self.shortcut(x)
        return main_path + skip_path


class DisBlock(nn.Module):
    """Pre-activation residual block with optional 2x2 average pooling.

    Main path: activation -> conv -> activation -> conv, followed by
    ``_downsample`` when ``downsample`` is set. The shortcut is the identity
    unless the channel count changes or ``downsample`` is set, in which case
    it is a 1x1 convolution (followed by the same pooling when downsampling).

    Args:
        args: object exposing ``d_spectral_norm``; when truthy, every
            convolution is wrapped with ``nn.utils.spectral_norm``.
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by the block.
        hidden_channels: channels between the two convolutions; defaults to
            ``in_channels`` when ``None``.
        ksize: kernel size of the two main-path convolutions.
        pad: padding of the two main-path convolutions.
        activation: module applied before each main-path convolution.
        downsample: whether to pool both paths.
    """

    def __init__(self, args, in_channels, out_channels, hidden_channels=None, ksize=3, pad=1,
                 activation=nn.ReLU(inplace=False), downsample=False):
        super(DisBlock, self).__init__()
        self.activation = activation
        self.downsample = downsample
        # A projection shortcut is needed whenever the output shape differs
        # from the input shape.
        self.learnable_sc = (in_channels != out_channels) or downsample
        if hidden_channels is None:
            hidden_channels = in_channels
        use_spectral_norm = args.d_spectral_norm

        self.c1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=ksize, padding=pad)
        self.c2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=ksize, padding=pad)
        if use_spectral_norm:
            self.c1 = nn.utils.spectral_norm(self.c1)
            self.c2 = nn.utils.spectral_norm(self.c2)

        if self.learnable_sc:
            projection = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
            if use_spectral_norm:
                projection = nn.utils.spectral_norm(projection)
            self.c_sc = projection

    def residual(self, x):
        """Main path: two activation+conv stages, then optional pooling."""
        h = self.c1(self.activation(x))
        h = self.c2(self.activation(h))
        return _downsample(h) if self.downsample else h

    def shortcut(self, x):
        """Skip path: identity, or 1x1 projection with optional pooling."""
        if not self.learnable_sc:
            return x
        projected = self.c_sc(x)
        if self.downsample:
            return _downsample(projected)
        return projected

    def forward(self, x):
        """Return the sum of the main path and the skip path."""
        main_path = self.residual(x)
        skip_path = self.shortcut(x)
        return main_path + skip_path

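# NOTE: legacy single-network Encoder, kept for reference only. It concatenated
# the n_fuse inputs along the channel axis before encoding; the active Encoder
# defined further down encodes each input with its own SingleEncoder instead.
#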
# class Encoder(nn.Module):
#     def __init__(self, args=args, activation=nn.ReLU(inplace=False)):
#         super(Encoder, self).__init__()
#         self.ch = args.latent_dim
#         self.activation = activation
#         self.n_fuse = opt.n_fuse
#         # self.block1 = OptimizedDisBlock(args, 3, self.ch)
#         self.block1 = OptimizedDisBlock(args, 3*self.n_fuse, self.ch)
#         self.block2 = DisBlock(args, self.ch, self.ch, activation=activation, downsample=True)
#         self.block3 = DisBlock(args, self.ch, self.ch, activation=activation, downsample=False)
#         self.block4 = DisBlock(args, self.ch, self.ch, activation=activation, downsample=False)
#         self.l5 = nn.Linear(self.ch, 1, bias=False)
#         if args.d_spectral_norm:
#             self.l5 = nn.utils.spectral_norm(self.l5)
#
#     def forward(self, fuse_x):
#         # h = fuse_x[1]
#         assert len(fuse_x) == self.n_fuse
#         # weight = torch.Tensor([1.0 / self.n_fuse]).cuda()
#         # h = torch.zeros(fuse_x[0].shape).cuda()
#         # for x in fuse_x:
#         #     h += weight*x
#         h = torch.cat(fuse_x, 1)
#         h = self.block1(h)
#         h = self.block2(h)
#         h = self.block3(h)
#         h = self.block4(h)
#         h = self.activation(h)
#         # Global average pooling
#         h = h.sum(2).sum(2)
#         # h = self.l5(h)
#
#         return h

class SingleEncoder(nn.Module):
    """Encode one 3-channel image into a vector of ``args.latent_dim`` features.

    Four residual stages (one ``OptimizedDisBlock`` followed by three
    ``DisBlock``s, the first of which downsamples) are followed by the
    activation and a sum over both spatial dimensions.

    Args:
        args: object exposing ``latent_dim`` and ``d_spectral_norm``.
        activation: module used after the last stage and passed to the three
            ``DisBlock`` stages. ``block1`` keeps ``OptimizedDisBlock``'s own
            default activation.
    """

    def __init__(self, args=args, activation=nn.ReLU(inplace=False)):
        super(SingleEncoder, self).__init__()
        self.ch = args.latent_dim
        self.activation = activation
        width = self.ch
        self.block1 = OptimizedDisBlock(args, 3, width)
        self.block2 = DisBlock(args, width, width, activation=activation, downsample=True)
        self.block3 = DisBlock(args, width, width, activation=activation, downsample=False)
        self.block4 = DisBlock(args, width, width, activation=activation, downsample=False)
        # l5 is constructed (and registered) but is not applied in forward().
        self.l5 = nn.Linear(width, 1, bias=False)
        if args.d_spectral_norm:
            self.l5 = nn.utils.spectral_norm(self.l5)

    def forward(self, x):
        """Return per-channel features of ``x``.

        Args:
            x: image batch; assumed to be ``(N, 3, H, W)`` since ``block1``
                is built for 3 input channels.

        Returns:
            Tensor of shape ``(N, latent_dim)``.
        """
        h = x
        for stage in (self.block1, self.block2, self.block3, self.block4):
            h = stage(h)
        h = self.activation(h)
        # Spatial pooling by summation (a sum, not a mean) over H and then W.
        return h.sum(2).sum(2)

class Encoder(nn.Module):
    """Encode ``opt.n_fuse`` inputs independently and concatenate their codes.

    One ``SingleEncoder`` is created per input; the networks do not share
    weights.

    Args:
        args: object exposing ``latent_dim`` (plus whatever ``SingleEncoder``
            reads); forwarded to every ``SingleEncoder`` so the per-input
            code width always matches ``self.ch``.
        activation: activation module forwarded to every ``SingleEncoder``.
    """

    def __init__(self, args=args, activation=nn.ReLU(inplace=False)):
        super(Encoder, self).__init__()
        self.ch = args.latent_dim
        self.n_fuse = opt.n_fuse
        # ModuleList, not Sequential: the encoders run side by side on
        # different inputs and are never chained. Child names ("0", "1", ...)
        # and therefore state_dict keys are the same for both containers.
        self.encoders = nn.ModuleList(
            [SingleEncoder(args=args, activation=activation) for _ in range(self.n_fuse)]
        )

    def forward(self, fuse_x):
        """Encode every element of ``fuse_x`` and concatenate along dim 1.

        Args:
            fuse_x: sequence of exactly ``n_fuse`` inputs, one per encoder.

        Returns:
            Tensor whose dimension 1 has size ``n_fuse * latent_dim``.

        Raises:
            AssertionError: if ``fuse_x`` does not hold ``n_fuse`` items or
                the concatenated width is unexpected.
        """
        assert len(fuse_x) == self.n_fuse, (
            "expected %d inputs, got %d" % (self.n_fuse, len(fuse_x))
        )
        z_list = [encoder(x) for encoder, x in zip(self.encoders, fuse_x)]
        h = torch.cat(z_list, 1)
        assert h.shape[1] == self.n_fuse * self.ch, (
            "unexpected fused code width %d" % h.shape[1]
        )

        return h

class Generator(nn.Module):
    """Encoder-decoder generator with a residual connection to the first input.

    The inputs are encoded by ``Encoder``, decoded by ``Decoder``, and the
    decoded tensor is added to ``x[0]``.
    """

    def __init__(self):
        super(Generator, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Return ``decoder(encoder(x)) + x[0]``.

        Args:
            x: indexable collection of inputs passed unchanged to the
                encoder; its first element serves as the residual base.
        """
        latent = self.encoder(x)
        decoded = self.decoder(latent)
        # The decoder output acts as an additive correction to the first input.
        return decoded + x[0]