import torch
from scipy import linalg
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

def orthogonal_loss(V):
    """Measure how far the rows of ``V`` are from being orthonormal.

    Computes ``|| V @ V.T - I ||_F`` (Frobenius norm), where ``I`` is the
    identity sized to the number of rows of ``V``.

    :param V: 2-D tensor whose rows are compared with each other.
    :return: scalar tensor; 0 when the rows of ``V`` are orthonormal.
    """
    gram = torch.matmul(V, V.T)
    deviation = gram - torch.eye(gram.shape[0], device=V.device)
    return torch.norm(deviation, p='fro')

def dissimilar_loss(V):
    """Mean absolute off-diagonal entry of the Gram matrix ``V @ V.T``.

    Penalizes pairwise inner products between distinct rows of ``V``. The
    diagonal (each row with itself) is excluded from the average.

    :param V: 2-D tensor whose rows are the vectors being compared.
    :return: scalar tensor. With a single row there are no off-diagonal
        entries, so the mean of the empty selection is NaN.
    """
    gram = V @ V.permute(1, 0)
    # Build the mask on the Gram matrix's own device (as orthogonal_loss does
    # with V.device) instead of a hard-coded CPU mask. This avoids a host-side
    # mask plus a cross-device transfer on every call for GPU inputs.
    off_diagonal = ~torch.eye(gram.shape[0], dtype=torch.bool, device=gram.device)
    return gram[off_diagonal].abs().mean()

class MMD_loss(nn.Module):
    """Maximum Mean Discrepancy (MMD) between two batches of feature vectors.

    The kernel is a sum of ``kernel_num`` Gaussian (RBF) kernels whose
    bandwidths form a geometric series with ratio ``kernel_mul`` spread
    around a base bandwidth. The base bandwidth is ``fix_sigma`` when it is
    truthy; otherwise it is the mean off-diagonal pairwise squared distance
    of the pooled samples.

    :param kernel_mul: ratio between consecutive kernel bandwidths.
    :param kernel_num: number of Gaussian kernels summed.
    :param fix_sigma: optional fixed base bandwidth. The default ``None``
        means the bandwidth is data-driven. It can also be set later via
        ``self.fix_sigma``.
    """

    def __init__(self, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        super(MMD_loss, self).__init__()
        self.kernel_num = kernel_num
        self.kernel_mul = kernel_mul
        self.fix_sigma = fix_sigma

    def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        """Return the summed multi-bandwidth Gaussian kernel of ``[source; target]``.

        :param source: 2-D tensor ``(n_source, dim)``.
        :param target: 2-D tensor ``(n_target, dim)`` with the same ``dim``.
        :return: tensor ``(n_source + n_target, n_source + n_target)``.
        """
        total = torch.cat([source, target], dim=0)
        n_samples = int(total.size(0))
        # Pairwise squared Euclidean distances via broadcasting:
        # entry [i, j] = ||total[j] - total[i]||^2.
        L2_distance = ((total.unsqueeze(0) - total.unsqueeze(1)) ** 2).sum(2)
        if fix_sigma:
            bandwidth = fix_sigma
        else:
            # Mean over the n^2 - n off-diagonal pairs (the diagonal is 0).
            # Detached so autograd treats the bandwidth as a constant.
            bandwidth = torch.sum(L2_distance.detach()) / (n_samples ** 2 - n_samples)
            # If every sample is identical the mean distance is 0, and
            # exp(-0 / 0) would be NaN. A tiny floor gives kernel value 1
            # everywhere and therefore a loss of 0.
            bandwidth = torch.clamp(bandwidth, min=torch.finfo(bandwidth.dtype).eps)
        # Shift the geometric series so it spreads around the base bandwidth.
        bandwidth = bandwidth / kernel_mul ** (kernel_num // 2)
        bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
        kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
        return sum(kernel_val)

    def forward(self, source, target):
        """Return the biased (V-statistic) squared MMD between ``source`` and ``target``.

        Each kernel block is averaged separately. For equal batch sizes this
        equals ``torch.mean(XX + YY - XY - YX)``. It also works when the two
        batches have different sizes, where the element-wise sum of the
        blocks would not broadcast.
        """
        batch_size = int(source.size()[0])
        kernels = self.guassian_kernel(source, target, kernel_mul=self.kernel_mul,
                                       kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)
        XX = kernels[:batch_size, :batch_size]
        YY = kernels[batch_size:, batch_size:]
        XY = kernels[:batch_size, batch_size:]
        YX = kernels[batch_size:, :batch_size]
        return XX.mean() + YY.mean() - XY.mean() - YX.mean()


def calculate_act_statistics(act):
    """Return ``(mean, covariance)`` of a batch of activation vectors.

    :param act: array-like ``(n_samples, n_features)``; each row is one
        observation (``rowvar=False``).
    :return: per-feature mean ``(n_features,)`` and the sample covariance
        matrix ``(n_features, n_features)``.
    """
    return np.mean(act, axis=0), np.cov(act, rowvar=False)
def calculate_fid(act1, act2):
    """Frechet distance between the Gaussian fits of two activation sets.

    Each set is summarised by its mean and covariance
    (``calculate_act_statistics``), and the two summaries are passed to
    ``calculate_frechet_distance``.
    """
    mean_a, cov_a = calculate_act_statistics(act1)
    mean_b, cov_b = calculate_act_statistics(act2)
    return calculate_frechet_distance(mean_a, cov_a, mean_b, cov_b)
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the (squared) Frechet Distance.

    For two multivariate Gaussians X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2):
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).

    Params:
    -- mu1   : Mean vector of the first set of activations
               (e.g. generated samples).
    -- sigma1: Covariance matrix of the first set of activations.
    -- mu2   : Mean vector of the second set of activations
               (e.g. precalculated on a reference data set).
    -- sigma2: Covariance matrix of the second set of activations.
    -- eps   : Offset added to the diagonal of both covariances when the
               matrix square root of their product is not finite.

    Returns:
    --   : d^2 as defined above.

    Raises:
    -- AssertionError: if the means or covariances have mismatched shapes.
    -- ValueError: if sqrt(C_1*C_2) has a diagonal imaginary part that is
                   not close to 0 (atol=1e-3).
    """
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
    assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"

    mean_gap = mu1 - mu2

    # The covariance product may be (nearly) singular; retry with a small
    # diagonal jitter if the square root is not finite.
    root, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(root).all():
        warning = (
            "fid calculation produces singular product; "
            "adding %s to diagonal of cov estimates"
        ) % eps
        print(warning)
        jitter = np.eye(sigma1.shape[0]) * eps
        root = linalg.sqrtm((sigma1 + jitter).dot(sigma2 + jitter))

    # Round-off can leave a small imaginary part; drop it only if negligible.
    if np.iscomplexobj(root):
        if not np.allclose(np.diagonal(root).imag, 0, atol=1e-3):
            raise ValueError("Imaginary component {}".format(np.max(np.abs(root.imag))))
        root = root.real

    return mean_gap.dot(mean_gap) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(root)




class InferenceBlock(nn.Module):
    """Three-layer MLP: Linear -> ELU -> Linear -> ELU -> Linear (all with bias).

    The layers live in ``self.module`` (an ``nn.Sequential``) at indices 0, 2
    and 4, with the ELU activations at 1 and 3.
    """

    def __init__(self, input_units, d_theta, output_units):
        """
        :param input_units: dimensionality of the input features.
        :param d_theta: width of the two intermediate hidden layers.
        :param output_units: dimensionality of the output.
        """
        super(InferenceBlock, self).__init__()
        widths = [input_units, d_theta, d_theta, output_units]
        layers = []
        for depth, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            # An activation sits between consecutive Linear layers only.
            if depth:
                layers.append(nn.ELU(inplace=True))
            layers.append(nn.Linear(fan_in, fan_out, bias=True))
        self.module = nn.Sequential(*layers)

    def forward(self, inps):
        """Apply the MLP to ``inps`` (last dimension must equal ``input_units``)."""
        return self.module(inps)

class Amortized(nn.Module):
    """Two independent ``InferenceBlock`` heads applied to the same input.

    ``weight_mean`` produces the mean (mu) and ``weight_log_variance``
    produces what is named a log-variance (logvar); both heads share the
    same layer sizes.
    """

    def __init__(self, input_units=400, d_theta=400, output_units=400):
        super(Amortized, self).__init__()
        sizes = (input_units, d_theta, output_units)
        self.weight_mean = InferenceBlock(*sizes)
        self.weight_log_variance = InferenceBlock(*sizes)

    def forward(self, inps):
        """Return ``(mean, log_variance)`` predicted from ``inps``."""
        return self.weight_mean(inps), self.weight_log_variance(inps)

def Wasserstein_score(self, mu, sigma, mu_s, sigma_s):
    """Squared 2-Wasserstein distance from one diagonal Gaussian to several others.

    For diagonal Gaussians parameterised by means and *standard deviations*
    (as returned by ``calc_mean_std``), the closed form is
        W_2^2 = ||mu - mu_b||^2 + ||sigma - sigma_b||^2.
    The previous spread term ``sigma + sigma_b - 2*sigma*sigma_b`` matched
    neither the std nor the variance form and was nonzero for identical
    distributions.

    :param self: unused; kept so existing call sites keep working.
    :param mu: mean tensor of the query distribution.
    :param sigma: standard-deviation tensor of the query distribution.
    :param mu_s: indexable collection of means (one per candidate).
    :param sigma_s: indexable collection of standard deviations, aligned
        with ``mu_s``.
    :return: 1-D tensor of length ``len(mu_s)``.
    """
    scores = []
    for b, mu_b in enumerate(mu_s):
        mean_dis = torch.sum((mu - mu_b) ** 2)
        std_dis = torch.sum((sigma - sigma_s[b]) ** 2)
        scores.append(mean_dis + std_dis)
    return torch.stack(scores)
def calc_mean_std(self, pac_token, eps=1e-6):
    """Per-channel mean and standard deviation over the token axis.

    ``pac_token`` is permuted with ``(0, 2, 1)``, so it must be 3-D. The
    statistics are taken over its dim 1, once per (dim 0, dim 2) pair. The
    variance is the unbiased estimate (torch default), and ``eps`` is added
    to it before the square root to avoid a zero standard deviation.

    :param self: unused.
    :param pac_token: 3-D tensor ``(N, L, C)``.
    :param eps: small value added to the variance.
    :return: ``(mean, std)``, each shaped ``(N, C, 1, 1)``.
    """
    channels_first = pac_token.permute(0, 2, 1)  # (N, C, L)
    n, c = channels_first.shape[:2]
    std = (channels_first.var(dim=2) + eps).sqrt().reshape(n, c, 1, 1)
    mean = channels_first.mean(dim=2).reshape(n, c, 1, 1)
    return mean, std
def sample(self, mu, logvar, L):  # reparameterization trick
    """Draw ``L`` reparameterized samples ``z = mu + eps * sigma``.

    ``sigma = exp(logvar).sqrt()`` and ``eps ~ N(0, 1)``. The noise is drawn
    by ``torch.randn`` with the default dtype and device, then converted to
    ``mu``'s type with ``type_as``. Writing the sample as a deterministic
    function of ``mu``/``logvar`` plus external noise keeps it differentiable
    with respect to both.

    :param self: unused.
    :param mu: mean tensor of any shape ``S``.
    :param logvar: log-variance tensor with the same shape as ``mu``.
    :param L: number of samples to draw.
    :return: tensor of shape ``(L, *S)``.
    """
    noise = torch.randn(L, *mu.shape).type_as(mu)
    std = logvar.exp().sqrt()
    return mu.unsqueeze(0) + noise * std.unsqueeze(0)

def _row_kl(p_logits, q_log_probs):
    """Row-wise KL(softmax(p_logits) || exp(q_log_probs)), summed over dim 1."""
    return torch.sum(
        F.softmax(p_logits, dim=-1) * (F.log_softmax(p_logits, dim=-1) - q_log_probs), 1)


def consistency_loss(scoresM1, scoresM2, type='euclidean'):
    """Consistency penalty between two 2-D score (logit) tensors of equal shape.

    :param scoresM1: 2-D tensor of scores from the first model.
    :param scoresM2: 2-D tensor of scores from the second model.
    :param type: one of
        - ``'euclidean'``: mean over the two models of the batch-averaged L2
          distance (per row) to the element-wise average of the scores;
        - ``'KL1'``: average of KL(p1 || q) and KL(p2 || q), where q is the
          softmax of the averaged *scores* (not the averaged probabilities);
        - ``'KL2'``: batch mean of KL(p2 || p1);
        - ``'KL3'``: batch mean of KL(p1 || p2);
        where p_i = softmax(scoresM_i) over the last dimension.
    :return: scalar tensor, or ``None`` for an unrecognised ``type``
        (kept for backward compatibility).
    """
    if type == 'euclidean':
        avg_pro = (scoresM1 + scoresM2) / 2.0
        # vector_norm has a zero (sub)gradient at the origin. The previous
        # sqrt(sum(x**2)) produced NaN gradients (0 * inf) for any row where
        # the two score tensors coincide.
        dis1 = torch.mean(torch.linalg.vector_norm(scoresM1 - avg_pro, dim=1))
        dis2 = torch.mean(torch.linalg.vector_norm(scoresM2 - avg_pro, dim=1))
        return (dis1 + dis2) / 2.0
    if type == 'KL1':
        avg_log_probs = F.log_softmax((scoresM1 + scoresM2) / 2.0, dim=-1)
        dis1 = torch.mean(_row_kl(scoresM1, avg_log_probs))
        dis2 = torch.mean(_row_kl(scoresM2, avg_log_probs))
        return (dis1 + dis2) / 2.0
    if type == 'KL2':
        return torch.mean(_row_kl(scoresM2, F.log_softmax(scoresM1, dim=-1)))
    if type == 'KL3':
        return torch.mean(_row_kl(scoresM1, F.log_softmax(scoresM2, dim=-1)))
    return None