from __future__ import absolute_import, print_function
import os
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
import torchvision
from torchvision.utils import save_image

def mkdir(path):
    """Create a directory (including missing parents) if it does not exist yet.

    Parameters:
        path (str) -- a single directory path

    Raises:
        FileExistsError -- if ``path`` already exists but is not a directory
    """
    # exist_ok=True replaces the former "os.path.exists() then makedirs()"
    # sequence, which could fail when another process created the directory
    # between the check and the creation.
    os.makedirs(path, exist_ok=True)

def get_model_setting(opt):
    """Build the identifier string that encodes a run's hyper-parameters.

    Parameters:
        opt -- options object; reads ModelName, Uncertain, lr, Dataset,
               BatchSize, z_dim, Metric, Proto_num, and Gamma (FactorVAE)
               or Beta (BetaVAE)

    Returns:
        str -- '<ModelName>_lr<lr>_<Dataset>_batchsize<BatchSize>_<gamma|beta><value>'
               '_zdim<z_dim>_<Metric>_protonum<Proto_num>', prefixed with
               'Uncertain' when opt.Uncertain == 'Yes'. For any other model
               name a message is printed and '' is returned.
    """
    # the two supported models differ only in the name of their weighting term
    if opt.ModelName == 'FactorVAE':
        weight_tag, weight_attr = '_gamma', 'Gamma'
    elif opt.ModelName == 'BetaVAE':
        weight_tag, weight_attr = '_beta', 'Beta'
    else:
        print('Wrong Model Name.')
        return ''

    prefix = 'Uncertain' if opt.Uncertain == 'Yes' else ''
    pieces = (
        prefix + opt.ModelName,
        '_lr', str(opt.lr),
        '_', opt.Dataset,
        '_batchsize', str(opt.BatchSize),
        weight_tag, str(getattr(opt, weight_attr)),
        '_zdim', str(opt.z_dim),
        '_', opt.Metric,
        '_protonum', str(opt.Proto_num),
    )
    return ''.join(pieces)

def reconstruction_loss(x, x_recon):
    """Binary cross-entropy reconstruction loss, averaged over the batch.

    Parameters:
        x (Tensor)       -- target tensor; its first dimension is the batch size
        x_recon (Tensor) -- reconstruction, passed to BCE as the prediction

    Returns:
        Tensor -- scalar; BCE summed over every element, divided by x.size(0)
    """
    n_samples = x.shape[0]
    summed_bce = F.binary_cross_entropy(x_recon, x, reduction='sum')
    # the original also carried a disabled F.mse_loss(x_recon, x, reduction='sum')
    # variant with the same per-batch normalisation
    return summed_bce / n_samples

def KL(alpha, device):
    """KL divergence from Dirichlet(alpha) to the uniform Dirichlet(1, ..., 1).

    Parameters:
        alpha (Tensor) -- Dirichlet concentration parameters, [batch, num_classes]
        device         -- device the all-ones reference parameters are moved to

    Returns:
        Tensor -- [batch, 1]; KL( Dir(alpha) || Dir(1, ..., 1) ) for each row
    """
    # Reference concentrations: one "1" per class. The class count is taken
    # from alpha (previously fixed at 10); float32 matches the former
    # torch.FloatTensor construction.
    num_classes = alpha.size(1)
    beta = torch.ones((1, num_classes), dtype=torch.float32).to(device)

    S_alpha = torch.sum(alpha, dim=1, keepdim=True)
    S_beta = torch.sum(beta, dim=1, keepdim=True)

    # log of the ratio of the two Dirichlet normalising constants
    lnB = torch.lgamma(S_alpha) - torch.sum(torch.lgamma(alpha), dim=1, keepdim=True)
    lnB_uni = torch.sum(torch.lgamma(beta), dim=1, keepdim=True) - torch.lgamma(S_beta)

    dg0 = torch.digamma(S_alpha)
    dg1 = torch.digamma(alpha)

    kl = torch.sum((alpha - beta) * (dg1 - dg0), dim=1, keepdim=True) + lnB + lnB_uni

    return kl

def uncertain_loss(p, alpha, global_step, annealing_step, device):
    """Digamma-based fit term plus an annealed KL regulariser on alpha.

    Parameters:
        p (Tensor)     -- per-class weights, same shape as alpha
                          (presumably one-hot targets -- confirm against caller)
        alpha (Tensor) -- Dirichlet concentration parameters, [batch, num_classes]
        global_step    -- current training step
        annealing_step -- step count over which the KL weight ramps up to 1
        device         -- forwarded to KL()

    Returns:
        Tensor -- [batch, 1]; fit term + min(1, global_step/annealing_step) * KL
    """
    strength = torch.sum(alpha, dim=1, keepdim=True)
    evidence = alpha - 1

    fit_term = torch.sum(
        p * (torch.digamma(strength) - torch.digamma(alpha)),
        dim=1,
        keepdim=True,
    )

    # KL weight grows linearly with the step count and is capped at 1
    kl_weight = min(1.0, global_step/annealing_step)

    # where p == 1 the evidence is removed (alpha -> 1) before the KL is taken
    masked_alpha = evidence * (1 - p) + 1
    regulariser = kl_weight * KL(masked_alpha, device)

    return (fit_term + regulariser)

def kl_divergence(mu, logvar):
    """KL divergence of N(mu, exp(logvar)) from N(0, I), averaged over the batch.

    Parameters:
        mu (Tensor)     -- means, [batch, z_dim] or [batch, z_dim, 1, 1]
        logvar (Tensor) -- log-variances, same layout as mu

    Returns:
        Tensor -- shape [1]; per-sample KL summed over z_dim, then batch mean
    """
    assert mu.size(0) != 0

    # flatten 4-D inputs down to [batch, z_dim]
    if mu.dim() == 4:
        mu = mu.view(mu.size(0), mu.size(1))
    if logvar.dim() == 4:
        logvar = logvar.view(logvar.size(0), logvar.size(1))

    per_dim_kld = -0.5 * (1 + logvar - mu ** 2 - torch.exp(logvar))
    per_sample_kld = per_dim_kld.sum(dim=1)

    return per_sample_kld.mean(dim=0, keepdim=True)

def kl_divergence_metric(mu_proto, logvar_proto, mu_input, logvar_input, sum=0):
    """KL( N(input) || N(prototype) ) between each input and each prototype.

    The divergence is asymmetric: D_kl(p, q) != D_kl(q, p). Here the input
    distribution is the first argument of the KL and the prototype the second.

    Parameters:
        mu_proto (Tensor)     -- prototype means, [1 or batch, proto_num, z_dim]
        logvar_proto (Tensor) -- prototype log-variances, same layout as mu_proto
        mu_input (Tensor)     -- input means, [batch, z_dim, 1, 1] or [batch, 1, z_dim]
        logvar_input (Tensor) -- input log-variances, same layout as mu_input
        sum                   -- unused; kept so existing call sites keep working

    Returns:
        Tensor -- [batch, proto_num], KL summed over z_dim
    """
    batch_size = mu_input.size(0)
    assert batch_size != 0
    # mu = [batch, z_dim, 1, 1] => [batch, 1, z_dim]
    if mu_input.data.ndimension() == 4:
        mu_input = mu_input.view(mu_input.size(0), 1, mu_input.size(1))
    if logvar_input.data.ndimension() == 4:
        logvar_input = logvar_input.view(logvar_input.size(0), 1, logvar_input.size(1))

    # Tensor.expand() returns a new view and does not modify its receiver, so
    # each result has to be assigned back for the shapes below to hold.
    # mu_proto = [1, proto_num, z_dim] => [batch, proto_num, z_dim]
    if mu_proto.size(0) == 1:
        mu_proto = mu_proto.expand(batch_size, mu_proto.size(1), mu_proto.size(2))
    # logvar_proto = [1, proto_num, z_dim] => [batch, proto_num, z_dim]
    if logvar_proto.size(0) == 1:
        logvar_proto = logvar_proto.expand(batch_size, logvar_proto.size(1), logvar_proto.size(2))
    # mu_input = [batch, 1, z_dim] => [batch, proto_num, z_dim]
    if mu_input.size(1) == 1:
        mu_input = mu_input.expand(batch_size, mu_proto.size(1), mu_proto.size(2))
    # logvar_input = [batch, 1, z_dim] => [batch, proto_num, z_dim]
    if logvar_input.size(1) == 1:
        logvar_input = logvar_input.expand(batch_size, logvar_proto.size(1), logvar_proto.size(2))

    # kld_input_proto = [batch, proto_num, z_dim]
    kld_input_proto = 0.5 * (logvar_proto-logvar_input) + (logvar_input.exp() + (mu_input-mu_proto).pow(2))/(2 * logvar_proto.exp()) - 0.5

    # total_kld = [batch, proto_num]
    total_kld = kld_input_proto.sum(2)

    return total_kld

def jensen_shannon_distance(mu1, logvar1, mu2, logvar2, sum=0):
    """Jensen-Shannon-style divergence between prototypes and inputs.

    The midpoint distribution is taken as a Gaussian with mean (mu1 + mu2) / 2
    and variance (var1 + var2) / 2; the result is the average of the two
    Gaussian KL terms towards that midpoint.

    Parameters:
        mu1 (Tensor)     -- prototype means, [1 or batch, proto_num, z_dim]
        logvar1 (Tensor) -- prototype log-variances, same layout as mu1
        mu2 (Tensor)     -- input means, [batch, z_dim, 1, 1] or [batch, 1, z_dim]
        logvar2 (Tensor) -- input log-variances, same layout as mu2
        sum              -- unused; kept so existing call sites keep working

    Returns:
        Tensor -- [batch, proto_num], divergence summed over z_dim
    """
    batch_size = mu2.size(0)
    assert batch_size != 0
    # mu1, logvar1 is from prototypes
    # mu2 = [batch, z_dim, 1, 1] => [batch, 1, z_dim]
    if mu2.data.ndimension() == 4:
        mu2 = mu2.view(mu2.size(0), 1, mu2.size(1))
    if logvar2.data.ndimension() == 4:
        logvar2 = logvar2.view(logvar2.size(0), 1, logvar2.size(1))
    # mu1 = [1, proto_num, z_dim] => [batch, proto_num, z_dim]
    if mu1.size(0) == 1:
        mu1 = mu1.expand(batch_size, mu1.size(1), mu1.size(2))
    # logvar1 = [1, proto_num, z_dim] => [batch, proto_num, z_dim]
    if logvar1.size(0) == 1:
        logvar1 = logvar1.expand(batch_size, logvar1.size(1), logvar1.size(2))
    # mu2 = [batch, 1, z_dim] => [batch, proto_num, z_dim]
    if mu2.size(1) == 1:
        mu2 = mu2.expand(batch_size, mu1.size(1), mu1.size(2))
    # logvar2 = [batch, 1, z_dim] => [batch, proto_num, z_dim]
    # (expanded against the prototype shape; expanding to its own size was a no-op)
    if logvar2.size(1) == 1:
        logvar2 = logvar2.expand(batch_size, logvar1.size(1), logvar1.size(2))

    mu_m = (mu1+mu2) * 0.5
    logvar_m = (logvar1.exp() + logvar2.exp()).div(2).log()

    # [batch, proto_num, z_dim]
    kld1m = 0.5 * (logvar_m-logvar1) + (logvar1.exp() + (mu1-mu_m).pow(2))/(2 * logvar_m.exp()) - 0.5
    kld2m = 0.5 * (logvar_m-logvar2) + (logvar2.exp() + (mu2-mu_m).pow(2))/(2 * logvar_m.exp()) - 0.5
    jsd = 0.5*kld1m + 0.5*kld2m

    # total_jsd = [batch, proto_num]
    total_jsd = jsd.sum(2)

    return total_jsd

def jensen_tsallis_distance(mu1, logvar1, mu2, logvar2, sum=0):
    """Jensen-Tsallis-style distance between prototypes and inputs.

    Per latent dimension this evaluates I(1,1) + I(2,2) - 2*I(1,2), where
    I(i,j) is the value returned by calc_integral() for the two Gaussians.

    Parameters:
        mu1 (Tensor)     -- prototype means, [1 or batch, proto_num, z_dim]
        logvar1 (Tensor) -- prototype log-variances, same layout as mu1
        mu2 (Tensor)     -- input means, [batch, z_dim, 1, 1] or [batch, 1, z_dim]
        logvar2 (Tensor) -- input log-variances, same layout as mu2
        sum              -- unused; kept so existing call sites keep working

    Returns:
        Tensor -- [batch, proto_num], distance summed over z_dim
    """
    batch_size = mu2.size(0)
    assert batch_size != 0
    # mu1, logvar1 is from prototypes
    # mu2 = [batch, z_dim, 1, 1] => [batch, 1, z_dim]
    if mu2.data.ndimension() == 4:
        mu2 = mu2.view(mu2.size(0), 1, mu2.size(1))
    if logvar2.data.ndimension() == 4:
        logvar2 = logvar2.view(logvar2.size(0), 1, logvar2.size(1))

    # mu1 = [1, proto_num, z_dim] => [batch, proto_num, z_dim]
    if mu1.size(0) == 1:
        mu1 = mu1.expand(batch_size, mu1.size(1), mu1.size(2))
    # logvar1 = [1, proto_num, z_dim] => [batch, proto_num, z_dim]
    if logvar1.size(0) == 1:
        logvar1 = logvar1.expand(batch_size, logvar1.size(1), logvar1.size(2))
    # mu2 = [batch, 1, z_dim] => [batch, proto_num, z_dim]
    if mu2.size(1) == 1:
        mu2 = mu2.expand(batch_size, mu1.size(1), mu1.size(2))
    # logvar2 = [batch, 1, z_dim] => [batch, proto_num, z_dim]
    # (expanded against the prototype shape; expanding to its own size was a no-op)
    if logvar2.size(1) == 1:
        logvar2 = logvar2.expand(batch_size, logvar1.size(1), logvar1.size(2))

    # var = sigma^2, logvar = log(sigma^2)
    var1 = logvar1.exp()
    var2 = logvar2.exp()
    # jtd = [batch, proto_num, z_dim]
    jtd = (calc_integral(var1, mu1, var1, mu1) + calc_integral(var2, mu2, var2, mu2) - 2*calc_integral(var1, mu1, var2, mu2))

    # total_jtd = [batch, proto_num]
    total_jtd = jtd.sum(2)

    return total_jtd

def calc_integral(sigma_i, mu_i, sigma_j, mu_j):
    """Element-wise Gaussian overlap term used by jensen_tsallis_distance.

    Evaluates exp(-(mu_i - mu_j)^2 / (2 * (sigma_i + sigma_j)))
              / sqrt(2 * pi * (sigma_i + sigma_j)).

    Parameters:
        sigma_i, sigma_j (Tensor) -- spread terms that are added together; the
                                     caller in this file passes variances
        mu_i, mu_j (Tensor)       -- means, broadcastable with the sigmas

    Returns:
        Tensor -- same (broadcast) shape as the inputs
    """
    total_var = sigma_i + sigma_j
    mean_gap = mu_i - mu_j

    gaussian_kernel = torch.exp(-mean_gap.pow(2) / (2 * total_var))
    scaled = gaussian_kernel / torch.sqrt(total_var)

    # constant 1 / sqrt(2*pi) factor of the normal density
    norm_const = 1 / np.sqrt(2 * np.pi)
    return scaled * norm_const

def wasserstein_distance(mu1, logvar1, mu2, logvar2, sum=0):
    """2-Wasserstein distance between diagonal Gaussians (prototypes vs inputs).

    Computed as sqrt(||mu1 - mu2||^2 + ||sigma1 - sigma2||^2) over z_dim,
    with sigma = exp(logvar / 2).

    Parameters:
        mu1 (Tensor)     -- prototype means, [1 or batch, proto_num, z_dim]
        logvar1 (Tensor) -- prototype log-variances, same layout as mu1
        mu2 (Tensor)     -- input means, [batch, z_dim, 1, 1] or [batch, 1, z_dim]
        logvar2 (Tensor) -- input log-variances, same layout as mu2
        sum              -- unused; kept so existing call sites keep working

    Returns:
        Tensor -- [batch, proto_num]
    """
    batch_size = mu2.size(0)
    assert batch_size != 0
    # mu1, logvar1 is from prototypes, mu2, logvar2 is from feature vector
    # mu2 = [batch, z_dim, 1, 1] => [batch, 1, z_dim]
    if mu2.data.ndimension() == 4:
        mu2 = mu2.view(mu2.size(0), 1, mu2.size(1))
    if logvar2.data.ndimension() == 4:
        logvar2 = logvar2.view(logvar2.size(0), 1, logvar2.size(1))

    # mu1 = [1, proto_num, z_dim] => [batch, proto_num, z_dim]
    if mu1.size(0) == 1:
        mu1 = mu1.expand(batch_size, mu1.size(1), mu1.size(2))
    # logvar1 = [1, proto_num, z_dim] => [batch, proto_num, z_dim]
    if logvar1.size(0) == 1:
        logvar1 = logvar1.expand(batch_size, logvar1.size(1), logvar1.size(2))
    # mu2 = [batch, 1, z_dim] => [batch, proto_num, z_dim]
    if mu2.size(1) == 1:
        mu2 = mu2.expand(batch_size, mu1.size(1), mu1.size(2))
    # logvar2 = [batch, 1, z_dim] => [batch, proto_num, z_dim]
    # (expanded against the prototype shape; expanding to its own size was a no-op)
    if logvar2.size(1) == 1:
        logvar2 = logvar2.expand(batch_size, logvar1.size(1), logvar1.size(2))

    # var = sigma^2, logvar = log(sigma^2) = 2*log(sigma)
    sigma1 = logvar1.div(2).exp()
    sigma2 = logvar2.div(2).exp()

    # Wasserstein distance = dist1(mu) + dist2(standard deviation)
    mu_dist = torch.norm(mu1 - mu2, p=2, dim=2).pow(2)
    var_dist = torch.norm(sigma1 - sigma2, p=2, dim=2).pow(2)

    total_wasserstein = mu_dist + var_dist

    return total_wasserstein.sqrt()
    
def permute_dims(z):
    """Shuffle every column of z independently across the batch dimension.

    Parameters:
        z (Tensor) -- 2-D tensor, [batch, dims]

    Returns:
        Tensor -- same shape as z; each column is a random row-permutation of
                  the corresponding input column (one fresh permutation per column)
    """
    assert z.dim() == 2

    n_rows = z.size(0)
    shuffled_columns = [
        column[torch.randperm(n_rows).to(z.device)]
        for column in z.split(1, 1)
    ]

    return torch.cat(shuffled_columns, 1)

def save_recon_image(x, x_recon, save_dir, save_name, channel):
    """Save a side-by-side comparison of inputs and their reconstructions.

    Two files are written into save_dir:
      * a matplotlib figure at os.path.join(save_dir, save_name) with the
        inputs on the top row and the reconstructions on the bottom row;
      * 'reconstruction_<save_name>.png', a torchvision image grid with the
        same two rows.
    At most the first 10 samples of the batch are shown.

    Parameters:
        x (Tensor)       -- inputs, [batch, channel, height, width]
        x_recon (Tensor) -- reconstructions with the same number of elements
                            as x (reshaped internally)
        save_dir (str)   -- existing output directory
        save_name (str)  -- base name used for both files
        channel (int)    -- 1 (grayscale) or 3 (RGB); other values leave the
                            matplotlib axes empty
    """
    n = min(x.size(0), 10)
    n_cols = n
    n_rows = 2
    # spatial size comes from the input instead of a fixed 32x32
    height, width = x.size(-2), x.size(-1)

    # squeeze=False keeps `axes` two-dimensional even for a single column, so
    # axes[row][col] also works when the batch holds just one sample
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols, n_rows), squeeze=False)
    for j in range(n_cols):
        if channel == 1:
            image = x[j].view(height, width).detach().cpu().numpy()
            recon_image = x_recon[j].view(height, width).detach().cpu().numpy()
            axes[0][j].imshow(image, cmap='gray', interpolation='none')
            axes[1][j].imshow(recon_image, cmap='gray', interpolation='none')
        elif channel == 3:
            image = x[j].view(3, height, width).detach().cpu().numpy()
            recon_image = x_recon[j].view(3, height, width).detach().cpu().numpy()
            # imshow expects (rows, cols, channels): CHW -> HWC. Reversing all
            # three axes would give (W, H, C) and draw the image transposed.
            image = np.transpose(image, (1, 2, 0))
            recon_image = np.transpose(recon_image, (1, 2, 0))
            axes[0][j].imshow(image)
            axes[1][j].imshow(recon_image)

    fig.savefig(os.path.join(save_dir, save_name),
                transparent=True,
                bbox_inches='tight',
                pad_inches=0, dpi=480)
    # close this specific figure so repeated calls do not accumulate figures
    plt.close(fig)

    comparison = torch.cat([x[:n], x_recon.view(x.size(0), channel, height, width)[:n]])
    save_image(comparison.cpu(), save_dir + '/' + 'reconstruction_' + save_name + '.png', nrow=n)

def save_proto_image(proto_img, save_dir, save_name, num_prototypes):
    """Write the first num_prototypes prototype images to separate PNG files.

    Parameters:
        proto_img            -- indexable collection of image tensors
        save_dir (str)       -- output directory
        save_name (str)      -- name fragment placed before the prototype index
        num_prototypes (int) -- how many entries of proto_img to save

    Each file is stored as '<save_dir>/Proto_<save_name><index>.png'.
    """
    for idx in range(num_prototypes):
        prototype = proto_img[idx].cpu()
        target_path = ''.join([save_dir, '/', 'Proto_', save_name, str(idx), '.png'])
        save_image(prototype, target_path)
