import torch
import torch.nn.functional as F
from scipy.stats import wasserstein_distance

def overlap_rate(X: torch.Tensor, Y: torch.Tensor):
    
    X_min, X_max = X.min(dim=0).values, X.max(dim=0).values
    Y_min, Y_max = Y.min(dim=0).values, Y.max(dim=0).values
    
    
    intersection_min = torch.max(X_min, Y_min)
    intersection_max = torch.min(X_max, Y_max)
    
    
    union_min = torch.min(X_min, Y_min)
    union_max = torch.max(X_max, Y_max)
    
    
    intersection_length = (intersection_max - intersection_min).clamp(min=0)
    union_length = union_max - union_min
    
    
    overlap_rate = intersection_length.sum() / union_length.sum()
    
    return overlap_rate.item()

def bhattacharyya_distance(mean_X, cov_X, mean_Y, cov_Y, epsilon=1e-6):
    
    cov_X_reg = cov_X + epsilon * torch.eye(cov_X.size(0))
    cov_Y_reg = cov_Y + epsilon * torch.eye(cov_Y.size(0))
    
    
    mean_diff = mean_X - mean_Y
    
    
    cov_mean = (cov_X_reg + cov_Y_reg) / 2
    
    
    inv_cov_mean = torch.inverse(cov_mean)
    term1 = 0.125 * torch.dot(mean_diff, torch.mv(inv_cov_mean, mean_diff))
    
    
    try:
        term2 = 0.5 * torch.logdet(cov_mean) - 0.25 * (torch.logdet(cov_X_reg) + torch.logdet(cov_Y_reg))
    except RuntimeError:
        print("Log determinant calculation failed due to numerical issues.")
        return float('inf')  
    
    
    BC = term1 + term2
    return BC

def overlap_rate_bhattacharyya(X: torch.Tensor, Y: torch.Tensor):
    
    mean_X = torch.mean(X, dim=0)
    mean_Y = torch.mean(Y, dim=0)
    
    cov_X = torch.cov(X.T)  
    cov_Y = torch.cov(Y.T)  
    
    
    BC = bhattacharyya_distance(mean_X, cov_X, mean_Y, cov_Y)
    print(f"Bhattacharyya distance: {BC:.4f}")
    
    overlap_rate = torch.exp(-BC)
    
    return overlap_rate.item()


def get_prob_distributions(expert_outputs):
    return F.softmax(expert_outputs, dim=-1)  


def compute_kl_divergence(p, q):
    epsilon = 1e-10
    p = p + epsilon
    q = q + epsilon
    return - (F.kl_div(p.log(), q, reduction='batchmean') + 
            F.kl_div(q.log(), p, reduction='batchmean')) / 2

def compute_wasserstein_distance(p, q):
    
    p_sorted, _ = torch.sort(p)
    q_sorted, _ = torch.sort(q)
    
    
    
    wasserstein_dist = torch.mean(torch.abs(p_sorted - q_sorted))
    
    return -wasserstein_dist
