import torch
import numpy as np
import math

def Maxprob(prob_X):
    """Return the largest class probability of every sample.

    Args:
        prob_X: Two-dimensional tensor of shape (n, c).

    Returns:
        Tensor of shape (n,) holding the row-wise maxima.

    Raises:
        ValueError: If ``prob_X`` is not a tensor or is not two-dimensional.
    """
    if not isinstance(prob_X, torch.Tensor):
        raise ValueError("Input must be a PyTorch tensor")
    if prob_X.dim() != 2:
        raise ValueError("Input tensor must be two-dimensional (n, c)")

    # ``Tensor.max(dim=...)`` returns a (values, indices) pair; only the values are needed.
    row_max, _ = prob_X.max(dim=1)
    return row_max


def Ent(prob_X, epsilon=1e-5):
    """Return an entropy-based confidence score ``1 - H(p) / log(K)`` per sample.

    Despite the name, the value is *one minus* the normalised Shannon entropy.
    Rows that are close to one-hot score about 1. Uniform rows score about 0.

    Args:
        prob_X: Tensor of shape (n, K) with non-negative class scores. Each row
            is renormalised to sum to 1 before the entropy is computed. A row
            that sums to 0 produces NaN.
        epsilon: Small constant added inside the logarithm to avoid ``log(0)``.
            Because of it, one-hot rows can score very slightly above 1 and
            uniform rows very slightly above 0.

    Returns:
        Tensor of shape (n,) with ``1 - H(p) / log(K)``. When ``K == 1`` there
        is no uncertainty, so every sample scores exactly 1.

    Raises:
        ValueError: If ``prob_X`` is not a tensor or is not two-dimensional.
    """
    if not isinstance(prob_X, torch.Tensor):
        raise ValueError("Input must be a PyTorch tensor")

    if prob_X.dim() != 2:
        raise ValueError("Input tensor must be two-dimensional (n, c)")

    # Ensure each row sums to 1 (in case the input is not already normalised).
    prob_X = prob_X / prob_X.sum(dim=1, keepdim=True)

    num_classes = prob_X.shape[1]
    if num_classes == 1:
        # log(1) == 0 would otherwise divide by zero and yield inf.
        return torch.ones_like(prob_X[:, 0])

    # sum_c p_c * log(p_c) is the NEGATIVE entropy -H(p).
    neg_entropy = torch.sum(prob_X * torch.log(prob_X + epsilon), dim=1)

    # 1 + (-H / log K) == 1 - H / log K, with H / log K in [0, 1].
    return 1 + neg_entropy / math.log(num_classes)


def JMDS(features, prob_X, num_rounds=1):
    """Compute the JMDS confidence score for every sample.

    The score is the product of two terms:

    * LPG: the gap between the largest and second-largest GMM log-posterior
      of each sample, divided by the largest such gap over all samples.
    * PPL: the probability in ``prob_X`` of the label predicted by the GMM.

    The GMM is initialised from ``prob_X``, using class masses and
    probability-weighted feature means. The first covariance estimate in
    ``gmm`` uses uniform sample weights. The GMM is then refined
    ``num_rounds`` times from its own responsibilities.

    Args:
        features: Feature tensor of shape (n, d).
        prob_X: Class-probability tensor of shape (n, c), with c >= 2, on the
            same device as ``features``.
        num_rounds: Number of refinement passes. The default of 1 matches the
            previous hard-coded behaviour.

    Returns:
        Tensor of shape (n,) with the per-sample score.
    """
    eps = 1e-8
    class_num = prob_X.shape[1]
    # Uniform weights for the first covariance estimate. They are allocated
    # next to ``features`` instead of via a hard-coded ``.cuda()``, so CPU
    # inputs work too.
    uniform = torch.ones(
        len(features), class_num, device=features.device, dtype=features.dtype
    ) / class_num

    # Initial parameters from the classifier output. eps guards against a class
    # with zero total probability (the same guard the refinement loop uses).
    pi = prob_X.sum(dim=0)
    mu = torch.matmul(prob_X.t(), features) / (pi.unsqueeze(dim=-1) + eps)
    zz, gamma = gmm(features, pi, mu, uniform)

    # Refine the mixture from its own responsibilities.
    for _ in range(num_rounds):
        pi = gamma.sum(dim=0)
        mu = torch.matmul(gamma.t(), features) / (pi.unsqueeze(dim=-1) + eps)
        zz, gamma = gmm(features, pi, mu, gamma)
    pred_label = gamma.argmax(dim=1)

    # LPG: margin between the two largest log-posteriors, scaled by the largest
    # margin. The clamp only matters when every margin is exactly 0; there it
    # yields 0 instead of 0/0 = NaN.
    top_two = zz.topk(2, dim=1).values
    margin = top_two[:, 0] - top_two[:, 1]
    lpg = margin / margin.max().clamp_min(torch.finfo(margin.dtype).tiny)

    # PPL: classifier probability of the GMM-predicted label. squeeze(dim=1)
    # keeps the sample axis even when n == 1.
    ppl = prob_X.gather(1, pred_label.unsqueeze(dim=1)).squeeze(dim=1)

    return lpg * ppl


def OT_score():
    """Placeholder for an OT-based score; not implemented yet.

    Returns:
        None, always. The body contains no logic.
    """
    return None


def gmm(all_fea, pi, mu, all_output, epsilon=1e-4):
    """Score samples under a Gaussian mixture and normalise across components.

    For every component ``i``:

    1. A covariance is estimated from ``all_fea`` centred on ``mu[i]`` and
       weighted by column ``i`` of ``all_output``.
    2. Each sample's Gaussian log-density is computed under that mean and
       covariance, plus ``log(pi[i])``.

    The scores are then normalised across components with log-sum-exp.

    Args:
        all_fea: Feature matrix of shape (n, d).
        pi: Per-component weights of shape (c,). They are used through
            ``log(pi)`` and need not sum to 1, because the result is
            normalised. A zero weight gives that component zero posterior.
        mu: Component means of shape (c, d).
        all_output: Per-sample weights of shape (n, c) used to estimate each
            covariance.
        epsilon: Diagonal jitter added to each covariance. If the Cholesky
            factorisation still fails, a further ``100 * epsilon`` is added
            and the factorisation is retried once.

    Returns:
        ``(zz, gamma)``, both of shape (n, c). ``zz`` holds the normalised
        log-posteriors and ``gamma = exp(zz)`` the responsibilities, whose
        rows sum to 1.
    """
    n_dim = all_fea.shape[1]
    # Identity built on the features' device/dtype, not a hard-coded
    # ``.cuda()``, so the function also works on CPU.
    eye = torch.eye(n_dim, device=all_fea.device, dtype=all_fea.dtype)
    log_pi = torch.log(pi)  # hoisted out of the per-component loop
    log_norm_const = n_dim * math.log(2 * math.pi)

    log_probs = []
    for i, mu_i in enumerate(mu):
        diff = all_fea - mu_i
        weights = all_output[:, i].unsqueeze(dim=-1)
        # Clamp the total weight so a component with zero responsibility gets
        # the jitter-only covariance ``epsilon * I`` instead of 0/0 = NaN.
        total = weights.sum().clamp_min(torch.finfo(weights.dtype).tiny)
        cov = torch.matmul(diff.t(), diff * weights) / total + epsilon * eye
        try:
            chol = torch.linalg.cholesky(cov)
        except RuntimeError:
            # Not numerically positive-definite: retry once with more jitter.
            cov = cov + 100 * epsilon * eye
            chol = torch.linalg.cholesky(cov)

        # With cov = L L^T:
        #   (x - mu)^T cov^{-1} (x - mu) = ||L^{-1} (x - mu)||^2
        #   log|cov|                     = 2 * sum(log(diag(L)))
        chol_inv = torch.inverse(chol)
        whitened = torch.matmul(diff, chol_inv.t())
        mah_dist = whitened.pow(2).sum(dim=1)
        logdet = 2.0 * torch.log(torch.diagonal(chol)).sum()

        log_probs.append(-0.5 * (log_norm_const + logdet + mah_dist) + log_pi[i])

    log_probs = torch.stack(log_probs, dim=1)  # (n, c)
    zz = log_probs - torch.logsumexp(log_probs, dim=1, keepdim=True)
    gamma = torch.exp(zz)

    return zz, gamma
