"""
Utility functions for calculating geometric mean of KL divergence values
"""
import numpy as np
from scipy import stats

def geometric_mean_safe(values, epsilon=1e-10):
    """
    Calculate the geometric mean of values, flooring non-positive entries.

    The geometric mean is only defined for positive numbers. Every value
    below ``epsilon``, including zeros and negatives, is therefore replaced
    by ``epsilon`` before averaging. This is a floor, not an offset: values
    that are already >= ``epsilon`` are left unchanged. NaN entries are not
    floored and propagate to the result.

    Args:
        values: list or array of values
        epsilon: lower bound applied to every value so that log(0) or
            log(negative) is never taken

    Returns:
        geometric mean value; ``nan`` for an empty 1-D input
    """
    # Clip (not shift) so that valid positive values keep their exact magnitude.
    safe_values = np.maximum(np.asarray(values), epsilon)

    # scipy emits runtime warnings on an empty sample before returning nan;
    # return the same nan directly without the noise.
    if safe_values.shape == (0,):
        return np.float64(np.nan)

    return stats.gmean(safe_values)

def calculate_metrics_with_geometric_kl(metrics_dict):
    """
    Summarise each metric's samples as a mean and a standard deviation.

    The ``'kl_divergence'`` metric is averaged with ``geometric_mean_safe``.
    Every other metric uses the arithmetic mean (``np.mean``). The spread is
    always ``np.std`` on the raw samples, including for KL divergence, so the
    ``_std`` entries are computed the same way for every metric.

    Args:
        metrics_dict: mapping of metric name -> sequence of sample values

    Returns:
        dict with keys ``'<name>_mean'`` and ``'<name>_std'`` for every
        metric, inserted in the same order as ``metrics_dict``
    """
    summary = {}

    for name, samples in metrics_dict.items():
        if name == 'kl_divergence':
            centre = geometric_mean_safe(samples)
        else:
            centre = np.mean(samples)

        summary[f'{name}_mean'] = centre
        # Deliberately the ordinary std (np.default ddof=0) even for KL, to
        # stay consistent with the other metrics' spread values.
        summary[f'{name}_std'] = np.std(samples)

    return summary