import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable
import torch.utils.data
import torch.utils.data.distributed
import numpy as np


class EncoderImg(nn.Module):
    """ResNet-style image encoder that outputs Gaussian parameters.

    The network applies a 3x3 convolution, a stack of ``ResnetBlock``s
    interleaved with stride-2 average pooling, flattens the result and feeds
    it to linear heads producing a mean and log-variance for the content
    (class) latent and, optionally, for the style latent.

    Args:
        flags: configuration object. The attributes read here are
            ``class_dim``, ``style_m1_dim``, ``img_size``, ``s0`` and
            ``factorized_representation``.
        nfilter: number of channels after the first convolution.
        nfilter_max: upper bound on the channel count of any block.
        **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, flags, nfilter=64, nfilter_max=1024, **kwargs):
        super().__init__()
        self.flags = flags
        self.c_dim = flags.class_dim
        self.s_dim = flags.style_m1_dim
        size = flags.img_size
        s0 = self.s0 = self.flags.s0
        nf = self.nf = nfilter
        nf_max = self.nf_max = nfilter_max

        # Submodules
        # One pooling stage per factor of two between img_size and s0.
        nlayers = int(np.log2(size / s0))
        # Channel count of the final feature map (capped at nf_max).
        self.nf0 = min(nf_max, nf * 2**nlayers)

        blocks = [
            ResnetBlock(nf, nf)
        ]

        for i in range(nlayers):
            # Channels double at every stage until they hit nf_max.
            nf0 = min(nf * 2**i, nf_max)
            nf1 = min(nf * 2**(i+1), nf_max)
            blocks += [
                nn.AvgPool2d(3, stride=2, padding=1),
                ResnetBlock(nf0, nf1),
            ]

        self.conv_img = nn.Conv2d(3, 1*nf, 3, padding=1)
        self.resnet = nn.Sequential(*blocks)
        self.fc_mu = nn.Linear(self.nf0*s0*s0, self.c_dim)
        self.fc_logvar = nn.Linear(self.nf0*s0*s0, self.c_dim)
        # Style heads only exist for the factorized representation.
        if flags.factorized_representation:
            self.fc_mu_s = nn.Linear(self.nf0*s0*s0, self.s_dim)
            self.fc_logvar_s = nn.Linear(self.nf0*s0*s0, self.s_dim)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: image batch with 3 channels (``conv_img`` takes 3 input
                channels). The spatial size is expected to be
                ``flags.img_size`` so that the final feature map flattens to
                ``nf0 * s0 * s0`` values per sample.

        Returns:
            Tuple ``(mu_s, logvar_s, mu_c, logvar_c)``. ``mu_s`` and
            ``logvar_s`` are ``None`` unless the style heads were built
            (``flags.factorized_representation``) and ``s_dim > 0``.
        """
        batch_size = x.size(0)
        out = self.conv_img(x)
        out = self.resnet(out)
        out = out.view(batch_size, self.nf0*self.s0*self.s0)
        # The activation is a pure function, so compute it once for all heads.
        feat = actvn(out)
        mu_c = self.fc_mu(feat)
        logvar_c = self.fc_logvar(feat)
        # Gate on the same flag that controls whether the style heads were
        # created; checking s_dim alone would fail with AttributeError when
        # the representation is not factorized.
        if self.flags.factorized_representation and self.s_dim > 0:
            mu_s = self.fc_mu_s(feat)
            # Use the dedicated log-variance head (not the mean head).
            logvar_s = self.fc_logvar_s(feat)
        else:
            mu_s = None
            logvar_s = None

        return mu_s, logvar_s, mu_c, logvar_c


class DecoderImg(nn.Module):
    """ResNet-style image decoder mapping latent codes to a 3-channel map.

    A linear layer expands the latent vector to an ``nf0 x s0 x s0`` feature
    map, which is then passed through ``ResnetBlock``s interleaved with 2x
    upsampling and a final 3x3 convolution down to 3 channels.

    Args:
        flags: configuration object. The attributes read here are
            ``img_size``, ``s0``, ``class_dim``, ``factorized_representation``
            and (only when factorized) ``style_m1_dim``.
        nfilter: number of channels before the output convolution.
        nfilter_max: upper bound on the channel count of any block.
        **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, flags, nfilter=64, nfilter_max=512, **kwargs):
        super().__init__()
        self.flags = flags
        size = flags.img_size
        # The latent is style + content when factorized, content only otherwise.
        self.z_dim = (
            flags.style_m1_dim + flags.class_dim
            if flags.factorized_representation
            else flags.class_dim
        )

        self.s0 = self.flags.s0
        self.nf = nfilter
        self.nf_max = nfilter_max

        # One upsampling stage per factor of two between s0 and img_size.
        n_stages = int(np.log2(size / self.s0))
        self.nf0 = min(self.nf_max, self.nf * 2**n_stages)

        self.fc = nn.Linear(self.z_dim, self.nf0 * self.s0 * self.s0)

        def width(level):
            # Channel count at a given depth, capped at nf_max.
            return min(self.nf * 2**level, self.nf_max)

        stages = []
        for level in range(n_stages, 0, -1):
            stages.append(ResnetBlock(width(level), width(level - 1)))
            stages.append(nn.Upsample(scale_factor=2))
        stages.append(ResnetBlock(self.nf, self.nf))

        self.resnet = nn.Sequential(*stages)
        self.conv_img = nn.Conv2d(self.nf, 3, 3, padding=1)

    def forward(self, z_style, z_content):
        """Decode latent codes.

        Args:
            z_style: style latent; only used (concatenated in front of
                ``z_content`` along dim 1) when
                ``flags.factorized_representation`` is set.
            z_content: content latent; its first dimension is the batch size.

        Returns:
            Tuple ``(torch.tanh(out), torch.sigmoid(out))`` where ``out`` is
            the output of the final convolution.
        """
        n_samples = z_content.size(0)
        latent = z_content
        if self.flags.factorized_representation:
            latent = torch.cat((z_style, z_content), dim=1)
        hidden = self.fc(latent)
        hidden = hidden.view(n_samples, self.nf0, self.s0, self.s0)
        hidden = self.resnet(hidden)
        logits = self.conv_img(actvn(hidden))
        return torch.tanh(logits), torch.sigmoid(logits)


class ResnetBlock(nn.Module):
    """Pre-activation residual block built from two 3x3 convolutions.

    The output is ``shortcut(x) + 0.1 * conv_1(actvn(conv_0(actvn(x))))``.
    When the input and output channel counts differ, the shortcut is a
    bias-free 1x1 convolution; otherwise it is the identity.

    Args:
        fin: number of input channels.
        fout: number of output channels.
        fhidden: channels between the two convolutions; defaults to
            ``min(fin, fout)``.
        is_bias: whether the second convolution has a bias term.
    """

    def __init__(self, fin, fout, fhidden=None, is_bias=True):
        super().__init__()
        # Attributes
        self.is_bias = is_bias
        self.learned_shortcut = (fin != fout)
        self.fin, self.fout = fin, fout
        self.fhidden = min(fin, fout) if fhidden is None else fhidden

        # Submodules
        conv3x3 = dict(kernel_size=3, stride=1, padding=1)
        self.conv_0 = nn.Conv2d(fin, self.fhidden, **conv3x3)
        self.conv_1 = nn.Conv2d(self.fhidden, fout, bias=is_bias, **conv3x3)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(
                fin, fout, kernel_size=1, stride=1, padding=0, bias=False
            )

    def forward(self, x):
        """Return the shortcut branch plus the residual branch scaled by 0.1."""
        skip = self._shortcut(x)
        residual = self.conv_1(actvn(self.conv_0(actvn(x))))
        return skip + 0.1*residual

    def _shortcut(self, x):
        """Project ``x`` with the 1x1 convolution if one exists, else pass it through."""
        return self.conv_s(x) if self.learned_shortcut else x


def actvn(x):
    """Apply the leaky ReLU (negative slope 0.2) used throughout this module."""
    return F.leaky_relu(x, negative_slope=0.2)
