from __future__ import absolute_import, print_function
from typing import Union, Optional, Callable

import torch
from torch import nn
import torch.nn.init as init

import torch.nn.functional as F

class Discriminator(nn.Module):
    """Five-hidden-layer MLP that maps a ``z_dim`` vector to two logits.

    Every hidden layer is ``Linear(.., 1000)`` followed by an in-place
    ``LeakyReLU(0.2)``; the final layer is ``Linear(1000, 2)``.
    NOTE(review): presumably this is the FactorVAE total-correlation
    discriminator (there is a 'FactorVAE' branch elsewhere in this file that
    returns a latent sample) -- confirm against the training code.

    Args:
        z_dim: Size of the last dimension of the input ``z``.
    """

    def __init__(self, z_dim):
        super(Discriminator, self).__init__()
        self.z_dim = z_dim
        self.net = nn.Sequential(
            nn.Linear(z_dim, 1000),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1000, 1000),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1000, 1000),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1000, 1000),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1000, 1000),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1000, 2),
        )
        self.weight_init()

    def weight_init(self, mode='normal'):
        """Re-initialize every layer of each child ``nn.Sequential`` in place.

        Args:
            mode: ``'normal'`` to apply ``normal_init`` or ``'kaiming'`` to
                apply ``kaiming_init`` to each layer.

        Raises:
            ValueError: If ``mode`` is not ``'normal'`` or ``'kaiming'``.
        """
        initializers = {'kaiming': kaiming_init, 'normal': normal_init}
        if mode not in initializers:
            raise ValueError(
                "Unknown weight_init mode %r; expected 'kaiming' or 'normal'"
                % (mode,)
            )
        initializer = initializers[mode]

        # Every child registered here is an nn.Sequential, so it is iterable.
        for block in self._modules.values():
            for m in block:
                initializer(m)

    def forward(self, z):
        """Return the two logits for each row of ``z``.

        ``squeeze()`` drops every size-1 dimension, so a batch of one yields
        a tensor of shape ``(2,)`` rather than ``(1, 2)``.
        """
        return self.net(z).squeeze()

class Proto_classification(torch.nn.Module):
    """Prototype-based classifier built on a convolutional VAE.

    The encoder outputs ``2 * z_dim`` channels, split into ``mu`` and
    ``logvar``. A reparameterized sample ``z`` is decoded back to an image.
    Each of ``n_prototypes`` learnable prototypes is itself a (mu, logvar)
    pair. A caller-supplied ``distance`` function compares the encoded
    (mu, logvar) against every prototype, and a linear layer maps those
    distances to class logits.

    Spatially, the encoder maps H -> H/8 - 3 (e.g. 32x32 -> 1x1). The decoder
    maps h -> 8 * (h + 3), which inverts that for H divisible by 8.

    Args:
        chnum_in_: Number of input image channels.
        n_prototypes: Number of learnable prototypes.
        n_classes: Number of output classes.
        model_name: Not used by the constructor; kept for interface
            compatibility.
        z_dim: Latent dimensionality.
        sum_dist: Must match the value later passed to ``forward``.
            1 sizes the classifier input as ``n_prototypes`` (one distance
            per prototype). 0 sizes it as ``n_prototypes * z_dim`` (one
            distance per prototype per latent dimension).

    Raises:
        ValueError: If ``sum_dist`` is not 0 or 1.
    """

    def __init__(self, chnum_in_, n_prototypes, n_classes, model_name, z_dim, sum_dist=0):
        super(Proto_classification, self).__init__()
        if sum_dist not in (0, 1):
            # Otherwise self.linear would never be created and forward()
            # would fail later with an unrelated-looking error.
            raise ValueError("sum_dist must be 0 or 1, got %r" % (sum_dist,))
        self.z_dim = z_dim
        self.chnum_in = chnum_in_
        self.encoder = nn.Sequential(
            # Conv2d args: in_channels, out_channels, kernel_size, stride, padding
            nn.Conv2d(self.chnum_in, 32, 4, 2, 1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            #
            nn.Conv2d(32, 32, 4, 2, 1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            #
            nn.Conv2d(32, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            #
            nn.Conv2d(64, 256, 4, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # First z_dim output channels are mu, last z_dim are logvar.
            nn.Conv2d(256, 2*z_dim, 1)
        )

        self.decoder = nn.Sequential(
            #
            nn.Conv2d(z_dim, 256, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            #
            nn.ConvTranspose2d(256, 64, 4),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            #
            nn.ConvTranspose2d(64, 32, 4, 2, 1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            #
            nn.ConvTranspose2d(32, 32, 4, 2, 1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            #
            nn.ConvTranspose2d(32, self.chnum_in, 4, 2, 1),
            nn.Sigmoid(),
        )
        # Called before self.linear / self.prototypes exist, so only the
        # encoder and decoder are re-initialized. The linear layer keeps
        # PyTorch's default init, and the prototypes use Xavier (below).
        self.weight_init()

        if sum_dist == 1:
            self.linear = nn.Linear(n_prototypes, n_classes)
        else:
            self.linear = nn.Linear(n_prototypes * z_dim, n_classes)

        # One row per prototype: [:, :z_dim] is its mu, [:, z_dim:] its logvar.
        self.prototypes = nn.Parameter(
            torch.empty(n_prototypes, 2 * z_dim, dtype=torch.float32)
        )
        nn.init.xavier_uniform_(self.prototypes)

    def weight_init(self, mode='normal'):
        """Re-initialize every layer of the encoder and decoder in place.

        The two Sequential blocks are named explicitly rather than iterating
        ``self._modules``. That way the method still works when called after
        construction, when ``self.linear`` (a non-iterable ``nn.Linear``) is
        also registered.

        Args:
            mode: ``'normal'`` to apply ``normal_init`` or ``'kaiming'`` to
                apply ``kaiming_init`` to each layer.

        Raises:
            ValueError: If ``mode`` is not ``'normal'`` or ``'kaiming'``.
        """
        initializers = {'kaiming': kaiming_init, 'normal': normal_init}
        if mode not in initializers:
            raise ValueError(
                "Unknown weight_init mode %r; expected 'kaiming' or 'normal'"
                % (mode,)
            )
        initializer = initializers[mode]

        for block in (self.encoder, self.decoder):
            for m in block:
                initializer(m)

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x, model_name, distance, sum_dist=0, no_dec=False):
        """Encode ``x``, reconstruct it, and classify by prototype distances.

        Args:
            x: Input image batch (batch along dim 0).
            model_name: ``'BetaVAE'`` or ``'FactorVAE'``; selects the tuple
                that is returned.
            distance: Callable invoked as
                ``distance(proto_mu, proto_logvar, mu, logvar, sum_dist)``.
                Its output must fit ``self.linear`` (see ``sum_dist``).
            sum_dist: 1 or 0. Should equal the value given to the constructor.
                With 0 the distance output is flattened to ``[batch, -1]``.
            no_dec: If True, return only ``z.squeeze()`` and skip decoding
                and classification.

        Returns:
            ``z.squeeze()`` if ``no_dec``. Otherwise
            ``(logits, prototype_distances, x_recon, mu, logvar)`` for
            ``'BetaVAE'``, or the same tuple with ``z.squeeze()`` appended for
            ``'FactorVAE'``. Note that ``squeeze()`` also drops the batch dim
            when the batch size is 1.

        Raises:
            ValueError: If ``no_dec`` is False and ``sum_dist`` is not 0/1 or
                ``model_name`` is not one of the names above.
        """
        batch = x.size(0)
        f = self.encoder(x)

        # Reparameterization: channel split into mu / logvar. Both keep the
        # encoder's spatial dims (1x1 for 32x32 inputs).
        mu = f[:, :self.z_dim]
        logvar = f[:, self.z_dim:]
        z = self.reparameterize(mu, logvar)

        if no_dec:
            return z.squeeze()

        if sum_dist not in (0, 1):
            raise ValueError("sum_dist must be 0 or 1, got %r" % (sum_dist,))
        if model_name not in ('BetaVAE', 'FactorVAE'):
            raise ValueError(
                "model_name must be 'BetaVAE' or 'FactorVAE', got %r"
                % (model_name,)
            )

        # Image reconstruction.
        x_recon = self.decoder(z)

        # prototypes: [n_prototypes, 2*z_dim] -> proto_mu / proto_logvar:
        # [1, n_prototypes, z_dim] (leading dim added for broadcasting).
        proto_mu = self.prototypes[:, :self.z_dim].unsqueeze(0)
        proto_logvar = self.prototypes[:, self.z_dim:].unsqueeze(0)

        # sum_dist == 1: distances are fed to self.linear as returned
        # (expected [batch, n_prototypes]). sum_dist == 0: flattened to
        # [batch, -1], expected n_prototypes * z_dim per row.
        prototype_distances = distance(proto_mu, proto_logvar, mu, logvar, sum_dist)
        if sum_dist == 0:
            prototype_distances = prototype_distances.view(batch, -1)

        # [batch, n_classes]
        logits = self.linear(prototype_distances)

        if model_name == 'BetaVAE':
            return logits, prototype_distances, x_recon, mu, logvar
        return logits, prototype_distances, x_recon, mu, logvar, z.squeeze()


def kaiming_init(m):
    """Initialize a single module in place (not recursively).

    - ``nn.Linear`` / ``nn.Conv2d``: weight via ``init.kaiming_normal_`` with
      its default arguments; bias (if any) set to 0.
    - ``nn.BatchNorm1d`` / ``nn.BatchNorm2d``: affine weight set to 1 and
      bias set to 0, when those parameters exist (``affine=False`` layers
      have ``weight``/``bias`` of ``None``).
    - Anything else, including ``nn.ConvTranspose2d`` (not a subclass of
      ``nn.Conv2d``), is left untouched.

    Args:
        m: The module to initialize.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        init.kaiming_normal_(m.weight)
        if m.bias is not None:
            init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        # Guard both: affine=False batch norms register weight/bias as None.
        if m.weight is not None:
            init.ones_(m.weight)
        if m.bias is not None:
            init.zeros_(m.bias)


def normal_init(m):
    """Initialize a single module in place (not recursively).

    - ``nn.Linear`` / ``nn.Conv2d``: weight drawn from N(0, 0.02^2) via
      ``init.normal_(weight, 0, 0.02)``; bias (if any) set to 0.
    - ``nn.BatchNorm1d`` / ``nn.BatchNorm2d``: affine weight set to 1 and
      bias set to 0, when those parameters exist (``affine=False`` layers
      have ``weight``/``bias`` of ``None``).
    - Anything else, including ``nn.ConvTranspose2d`` (not a subclass of
      ``nn.Conv2d``), is left untouched.

    Args:
        m: The module to initialize.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        init.normal_(m.weight, 0, 0.02)
        if m.bias is not None:
            init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        # Guard both: affine=False batch norms register weight/bias as None.
        if m.weight is not None:
            init.ones_(m.weight)
        if m.bias is not None:
            init.zeros_(m.bias)