from tqdm import tqdm
import numpy as np
import torch



def kl_div(mus, sigs, unc_per_dim=False):
    """Pairwise KL divergence KL(p_i || p_j) between diagonal Gaussians.

    Args:
        mus: 6-D tensor of means. Axis 1 indexes the Gaussians that are
            compared pairwise (it is moved to the front by the permute).
        sigs: tensor shaped like ``mus`` holding the diagonal covariance
            entries. Despite the name these are treated as variances, not
            standard deviations.
        unc_per_dim: if False, axes 3-5 are flattened and treated as one
            multivariate Gaussian. If True, every element is its own 1-D
            Gaussian (no reduction over axes 3-5).

    Returns:
        float64 tensor ``d`` with ``d[i, j] = KL(p_i || p_j)``. Its shape is
        ``(C, C, mus.shape[0], mus.shape[2])`` when ``unc_per_dim`` is
        False, and ``(C, C, mus.shape[0], *mus.shape[2:])`` otherwise,
        where ``C = mus.shape[1]``.
    """
    mus = mus.type(torch.float64).permute(1, 0, 2, 3, 4, 5)
    Sigs = sigs.type(torch.float64).permute(1, 0, 2, 3, 4, 5)
    if not unc_per_dim:
        flat_shape = (*mus.shape[:3], -1)
        mus = mus.reshape(flat_shape)
        Sigs = Sigs.reshape(flat_shape)
    # In every pairwise tensor below, ``X[:, None]`` supplies component i
    # (axis 0) and plain ``X`` broadcasts as component j (axis 1).
    tr_term = Sigs[:, None] / Sigs
    # log(det Sigma_j / det Sigma_i) as a difference of summed logs. Taking
    # the log of a product over many dimensions over/underflows in float64.
    log_sigs = torch.log(Sigs)
    det_term = log_sigs - log_sigs[:, None]
    quad_term = (mus - mus[:, None]) ** 2 / Sigs
    if not unc_per_dim:
        k = mus.shape[-1]
        tr_term = tr_term.sum(4)
        det_term = det_term.sum(4)
        quad_term = quad_term.sum(4)
    else:
        # Every element is its own 1-D Gaussian, so there is no trace sum.
        k = 1
    return .5 * (tr_term + det_term + quad_term - k)

def bhatt_div(mus, sigs, unc_per_dim=False):
    """Pairwise Bhattacharyya distance between diagonal Gaussians.

    For diagonal covariances the distance is
    ``1/8 * sum (mu_i - mu_j)^2 / S + 1/2 * sum log(S / sqrt(S_i S_j))``
    with ``S = (S_i + S_j) / 2``.

    Args:
        mus: 6-D tensor of means. Axis 1 indexes the Gaussians that are
            compared pairwise (it is moved to the front by the permute).
        sigs: tensor shaped like ``mus`` holding the diagonal covariance
            entries (variances).
        unc_per_dim: if False, axes 3-5 are flattened and treated as one
            multivariate Gaussian. If True, every element is its own 1-D
            Gaussian. The per-element values then sum to the flattened
            result.

    Returns:
        float64 symmetric tensor ``d[i, j]``. Its shape is
        ``(C, C, mus.shape[0], mus.shape[2])`` when ``unc_per_dim`` is
        False, and ``(C, C, mus.shape[0], *mus.shape[2:])`` otherwise,
        where ``C = mus.shape[1]``.
    """
    mus = mus.type(torch.float64).permute(1, 0, 2, 3, 4, 5)
    Sigs = sigs.type(torch.float64).permute(1, 0, 2, 3, 4, 5)
    if not unc_per_dim:
        flat_shape = (*mus.shape[:3], -1)
        mus = mus.reshape(flat_shape)
        Sigs = Sigs.reshape(flat_shape)
    mean_sig = (Sigs[:, None] + Sigs) / 2
    # The same 1/8 factor applies per element. A 1-D Gaussian's
    # 1/4 * dmu^2 / (s_i + s_j) equals 1/8 * dmu^2 / mean_sig.
    quad_term = (mus[:, None] - mus) ** 2 / mean_sig
    # log(mean_sig / sqrt(S_i S_j)) written as summed logs, so the product
    # over many dimensions cannot overflow.
    log_sigs = torch.log(Sigs)
    log_term = torch.log(mean_sig) - 0.5 * (log_sigs[:, None] + log_sigs)
    if not unc_per_dim:
        quad_term = quad_term.sum(4)
        log_term = log_term.sum(4)
    return (1 / 8) * quad_term + (1 / 2) * log_term


def wasserstein_dist(mus, sigs, unc_per_dim=False):
    """Pairwise squared 2-Wasserstein distance between diagonal Gaussians.

    NOTE: this is the *squared* W2 distance. For diagonal covariances
    ``tr(S_i + S_j - 2 (S_j^1/2 S_i S_j^1/2)^1/2)`` reduces to
    ``sum (sigma_i - sigma_j)^2``, with sigma the standard deviations.

    Args:
        mus: 6-D tensor of means. Axis 1 indexes the Gaussians that are
            compared pairwise (it is moved to the front by the permute).
        sigs: tensor shaped like ``mus`` holding the diagonal covariance
            entries (variances). Their square roots are used as sigma.
        unc_per_dim: if False, axes 3-5 are flattened and summed over. If
            True, the per-element distances are returned.

    Returns:
        float64 symmetric, non-negative tensor ``d[i, j]`` with an exactly
        zero diagonal. Its shape is ``(C, C, mus.shape[0], mus.shape[2])``
        when ``unc_per_dim`` is False, and
        ``(C, C, mus.shape[0], *mus.shape[2:])`` otherwise, where
        ``C = mus.shape[1]``.
    """
    mus = mus.type(torch.float64).permute(1, 0, 2, 3, 4, 5)
    Sigs = sigs.type(torch.float64).permute(1, 0, 2, 3, 4, 5)
    stds = Sigs ** 0.5
    if not unc_per_dim:
        flat_shape = (*mus.shape[:3], -1)
        mus = mus.reshape(flat_shape)
        stds = stds.reshape(flat_shape)
    quad_term = (mus[:, None] - mus) ** 2
    # (sigma_i - sigma_j)^2 equals S_i + S_j - 2 sqrt(S_i S_j), but avoids the
    # cancellation that could make the result slightly negative.
    tr_term = (stds[:, None] - stds) ** 2
    dist = quad_term + tr_term
    if not unc_per_dim:
        dist = dist.sum(4)
    return dist



def paide(mus, sigs, pre_metric='KL', unc_per_dim=False):
    """Pairwise-distance estimate over a uniformly weighted set of Gaussians.

    Computes pairwise distances ``D[i, j]`` between the ``C = mus.shape[1]``
    diagonal Gaussians indexed by axis 1 of ``mus``/``sigs`` and returns::

        -(log(1/C) + (1/C) * sum_i log(sum_j exp(-D[i, j])))
          = -sum_i (1/C) * log(sum_j (1/C) * exp(-D[i, j]))

    This matches the form of the Kolchinsky & Tracey pairwise-distance
    estimator with uniform weights.

    Args:
        mus: 6-D tensor of means, components on axis 1.
        sigs: tensor shaped like ``mus`` holding variances.
        pre_metric: ``'KL'``, ``'Bhatt'`` or ``'Wass'``. Selects the pairwise
            distance used for ``D``.
        unc_per_dim: forwarded to the distance function. If True, the
            estimate is computed element-wise over axes 3-5.

    Returns:
        ``(pairwise_dist, dist)``. ``pairwise_dist`` is the float64 estimate
        with the two leading component axes of ``dist`` reduced away.
        ``dist`` is the raw pairwise distance tensor.

    Raises:
        ValueError: if ``pre_metric`` is not one of the supported names.
    """
    if pre_metric == 'KL':
        dist = kl_div(mus, sigs, unc_per_dim=unc_per_dim)
    elif pre_metric == 'Bhatt':
        dist = bhatt_div(mus, sigs, unc_per_dim=unc_per_dim)
    elif pre_metric == 'Wass':
        dist = wasserstein_dist(mus, sigs, unc_per_dim=unc_per_dim)
    else:
        raise ValueError(
            f"Unknown pre_metric {pre_metric!r}; expected 'KL', 'Bhatt' or 'Wass'"
        )
    numb_comp = mus.shape[1]
    # Build the weight directly in float64. Creating it as float32 first
    # would round 1/C (e.g. 1/3) before the cast.
    weight = torch.tensor([1 / numb_comp], dtype=torch.float64, device=mus.device)
    # log(sum_j exp(-D[i, j])) computed stably over j (axis 1), then summed over i.
    log_sum = torch.logsumexp(-dist, dim=1)
    pairwise_dist = -(torch.log(weight) + weight * log_sum.sum(0))
    return pairwise_dist, dist

def var_mus(mus):
    """Return the variance of ``mus`` along axis 1.

    The call is delegated to the input's own ``var`` method. The estimator
    therefore follows the array type: torch tensors default to the unbiased
    estimator, while NumPy arrays default to ``ddof=0``.
    """
    axis = 1
    variance = mus.var(axis)
    return variance
