"""Utility functions to compute classic disentanglement metrics.

Currently supports:
    * Mutual Information (MI) matrix between latent variables and ground–truth factors.
    * Normalised MI (NMI) – each MI value divided by the entropy of the corresponding ground-truth factor.
    * Mutual Information Gap (MIG) as described in "Isolating Sources of Disentanglement in Variational
      Autoencoders" (Chen et al., 2018).

The implementation uses scikit-learn's KSG method for mutual information estimation which is
well-suited for continuous latent variables and discrete ground-truth factors.

All public helpers accept NumPy arrays.  Callers coming from JAX should first convert with
`np.asarray(jax.device_get(arr))`.
"""
from __future__ import annotations

import math
from typing import Tuple, Dict, Any

import numpy as np
from scipy.special import digamma
from sklearn.feature_selection import mutual_info_regression
from sklearn.neighbors import NearestNeighbors
import jax.numpy as jnp
# -----------------------------------------------------------------------------
# Basic helpers
# -----------------------------------------------------------------------------
from utils.printarr import printarr

def _entropy_knn(X: np.ndarray, k: int = 3) -> float:
    """Estimate differential entropy with the Kozachenko-Leonenko k-NN estimator.

    The estimate is::

        H = psi(N) - psi(k) + log(c_d) + (d / N) * sum_i log(eps_i)

    where ``eps_i`` is the Euclidean distance from sample ``i`` to its k-th
    nearest neighbour, ``d`` is the dimensionality and
    ``c_d = pi**(d/2) / Gamma(d/2 + 1)`` is the volume of the d-dimensional
    unit ball.  The ``log(c_d)`` term is required for the result to be an
    entropy in nats (for ``d == 1`` it equals ``log 2``); without it every
    estimate is shifted by a constant.

    Parameters
    ----------
    X : np.ndarray, shape (n_samples,) or (n_samples, n_features)
        The data to estimate entropy from.  1-D input is treated as a single
        feature.
    k : int
        Number of nearest neighbours to use (must be >= 1).

    Returns
    -------
    h : float
        Estimated differential entropy in nats.  Unlike a discrete entropy
        this value may legitimately be negative.

    Raises
    ------
    ValueError
        If ``k < 1`` or if there are not at least ``k + 1`` samples.
    """
    # Local import: the module only imports ``scipy.special`` at file level.
    from scipy.spatial import cKDTree

    X = np.asarray(X, dtype=np.float64)
    if X.ndim == 1:
        X = X.reshape(-1, 1)

    n_samples, n_dim = X.shape

    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    if n_samples <= k:
        raise ValueError(
            f"Need at least k + 1 = {k + 1} samples to query the {k}-th neighbour, got {n_samples}"
        )

    # Query k + 1 neighbours because every point is its own nearest neighbour
    # (distance 0 in column 0); column k is then the k-th *other* point.
    distances, _ = cKDTree(X).query(X, k=k + 1)
    epsilon = distances[:, k]

    # log-volume of the unit Euclidean ball in n_dim dimensions.
    log_unit_ball = 0.5 * n_dim * math.log(math.pi) - math.lgamma(0.5 * n_dim + 1.0)

    # Machine epsilon keeps the log finite when duplicated samples give eps_i == 0.
    mean_log_eps = np.mean(np.log(epsilon + np.finfo(float).eps))

    h = digamma(n_samples) - digamma(k) + log_unit_ball + n_dim * mean_log_eps
    return float(h)


def _entropy_discrete(X: np.ndarray) -> float:
    """Compute the Shannon entropy of a discrete sample from empirical frequencies.

    Parameters
    ----------
    X : np.ndarray
        Discrete data.  Any shape is accepted; the array is flattened and
        every element is treated as one observation.

    Returns
    -------
    h : float
        Plug-in entropy estimate ``-sum(p * log(p))`` in nats.  Always
        non-negative; exactly ``0.0`` for constant (or empty) input.
    """
    values = np.asarray(X).ravel()

    # Counts returned by np.unique are always >= 1, so every probability is
    # strictly positive and log(p) needs no epsilon guard.  (An epsilon inside
    # the log would make the entropy of a constant array slightly negative.)
    _, counts = np.unique(values, return_counts=True)
    probs = counts.astype(np.float64) / values.size

    h = -np.sum(probs * np.log(probs))

    # Each term is <= 0 so h >= 0 already; max() only normalises IEEE -0.0.
    return max(0.0, float(h))


def _is_discrete(X: np.ndarray) -> bool:
    """Heuristically decide whether an array holds discrete values.

    An array is considered discrete when

    * its dtype is an integer or boolean type, or
    * it is empty (consistent with empty integer arrays, and avoids a
      division by zero), or
    * fewer than 5% of its elements are distinct.

    Parameters
    ----------
    X : np.ndarray
        Array to check.  Any shape is accepted; all elements are considered.

    Returns
    -------
    is_discrete : bool
        True if X appears to contain discrete values.
    """
    X = np.asarray(X)

    # Integer and boolean data is categorical by construction.
    if np.issubdtype(X.dtype, np.integer) or np.issubdtype(X.dtype, np.bool_):
        return True

    if X.size == 0:
        return True

    # Use the element count (not len(), which is only the first axis) so the
    # ratio is meaningful for multi-dimensional input as well.
    unique_ratio = np.unique(X).size / X.size

    # Heuristic: relatively few distinct values means the data is discrete.
    return bool(unique_ratio < 0.05)

# -----------------------------------------------------------------------------
# Public API
# -----------------------------------------------------------------------------

def mutual_information_matrix(
    latents: np.ndarray,
    factors: np.ndarray,
    *,
    n_neighbors: int = 7,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute the factor/latent MI matrix and the per-factor entropies.

    Each factor column is classified once with ``_is_discrete``; that single
    decision selects both the entropy estimator (empirical vs. k-NN) and the
    ``discrete_features`` flag passed to scikit-learn, so the MI value and the
    entropy that later normalises it treat the factor the same way.

    Parameters
    ----------
    latents : np.ndarray, shape (N, L)
        Continuous latent representations.
    factors : np.ndarray, shape (N, K)
        Ground-truth factors (can be continuous or discrete).
    n_neighbors : int, default=7
        Number of neighbours used by the k-NN estimators (both the MI
        estimate and the entropy of continuous factors).
        Higher values reduce variance but may introduce bias.

    Returns
    -------
    mi_matrix : np.ndarray, shape (K, L)
        I(f_k; z_l) in nats.
    H_factors : np.ndarray, shape (K,)
        Entropy H(f_k) in nats.

    Notes
    -----
    No ``random_state`` is passed to ``mutual_info_regression``; per the
    scikit-learn documentation it adds a small amount of random noise to
    continuous variables, so repeated calls may differ slightly.
    """
    assert latents.shape[0] == factors.shape[0], "Latents and factors must share the first (sample) dimension."
    _, L = latents.shape
    _, K = factors.shape

    # Decide discreteness once per factor and reuse it everywhere below.
    discrete = [_is_discrete(factors[:, k]) for k in range(K)]

    H_factors = np.array([
        _entropy_discrete(factors[:, k]) if discrete[k] else _entropy_knn(factors[:, k], k=n_neighbors)
        for k in range(K)
    ])

    mi_matrix = np.zeros((K, L), dtype=np.float64)

    # The (always continuous) latent is the regression target and the factor
    # is the single input feature of each call.
    for k in range(K):
        factor_column = factors[:, k].reshape(-1, 1)
        for l in range(L):
            # One feature in -> length-1 array out; take the scalar explicitly
            # instead of relying on implicit size-1 array-to-scalar conversion.
            mi_matrix[k, l] = mutual_info_regression(
                factor_column,
                latents[:, l],
                discrete_features=discrete[k],
                n_neighbors=n_neighbors,
            )[0]

    return mi_matrix, H_factors


def normalised_mi_matrix(mi_matrix: np.ndarray, H_factors: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Normalise MI by the corresponding factor entropy: I(f_k; z_l) / H(f_k).

    Entries whose normalised value exceeds 1 are considered invalid (MI should
    not exceed the factor entropy) and are set to 0 in *both* returned
    matrices.

    Parameters
    ----------
    mi_matrix : np.ndarray, shape (K, L)
        Mutual information between factor k and latent l.  Not modified.
    H_factors : np.ndarray, shape (K,)
        Entropy of each factor.

    Returns
    -------
    nmi : np.ndarray, shape (K, L)
        Normalised MI with invalid entries zeroed.
    mi : np.ndarray, shape (K, L)
        Copy of ``mi_matrix`` with the same invalid entries zeroed.
    """
    eps = 1e-12  # guards against division by a zero entropy
    nmi = mi_matrix / (H_factors[:, None] + eps)

    # Build the mask BEFORE clipping: once nmi has been clipped no entry is
    # > 1 any more, so a mask computed afterwards would never select anything.
    invalid = nmi > 1.0

    nmi = np.where(invalid, 0.0, nmi)
    mi = np.where(invalid, 0.0, mi_matrix)

    return nmi, mi


def _peak_share(matrix, axis):
    """Mean of max/sum along ``axis``, rescaled so the uniform share maps to 0 and 1 to 1."""
    n_entries = matrix.shape[axis]
    uniform = 1.0 / n_entries

    peaks = matrix.max(axis)
    totals = matrix.sum(axis)

    # Slices that sum to zero have an undefined 0/0 share.  Score them as
    # uniform (rescaled contribution 0) instead of letting a single NaN
    # propagate through the mean and wipe out the whole metric.
    share = np.full(totals.shape, uniform, dtype=np.float64)
    np.divide(peaks, totals, out=share, where=totals != 0)

    return (share.mean() - uniform) / (1.0 - uniform)


def mc_metrics(nmi_matrix):
    """Compute the Modularity and Compactness metrics from the NMI matrix.

    Parameters
    ----------
    nmi_matrix : np.ndarray, shape (n_factors, n_latents)
        Normalised mutual information, factors along rows and latents along
        columns.

    Returns
    -------
    M : float
        Modularity.  For every latent, the fraction of its total NMI that is
        captured by its best factor; averaged over latents and rescaled so
        that a uniform spread gives 0 and a single-factor latent gives 1.
    C : float
        Compactness.  The same quantity computed per factor over the latents.

    Notes
    -----
    A latent (or factor) whose NMI entries sum to zero is scored as uniform,
    i.e. it contributes 0 to the respective metric.  With only one factor
    (for M) or one latent (for C) the rescaling divides by zero and the
    result is NaN.
    """
    M = _peak_share(nmi_matrix, 0)  # reduce over factors -> one share per latent
    C = _peak_share(nmi_matrix, 1)  # reduce over latents -> one share per factor
    return M, C
    

def compute_disentanglement_metrics(
    latents: np.ndarray,
    factors: np.ndarray,
    *,
    n_neighbors: int = 3,
) -> Dict[str, Any]:
    """Run the MI pipeline end to end and collect the results in a dict.

    Parameters
    ----------
    latents : np.ndarray
        Latent representations, forwarded to ``mutual_information_matrix``.
    factors : np.ndarray
        Ground-truth factors, forwarded to ``mutual_information_matrix``.
    n_neighbors : int, default=3
        Neighbour count forwarded to ``mutual_information_matrix``.

    Returns
    -------
    metrics : dict
        ``'mi_matrix'`` and ``'nmi_matrix'`` as returned by
        ``normalised_mi_matrix``, plus ``'entropy_factors'`` as returned by
        ``mutual_information_matrix``.
    """
    # Step 1: raw MI estimates and the entropy of every factor.
    raw_mi, factor_entropy = mutual_information_matrix(latents, factors, n_neighbors=n_neighbors)

    # Step 2: normalise; the helper also hands back its own version of the MI matrix.
    normalised, reported_mi = normalised_mi_matrix(raw_mi, factor_entropy)

    metrics: Dict[str, Any] = {}
    metrics['mi_matrix'] = reported_mi
    metrics['nmi_matrix'] = normalised
    metrics['entropy_factors'] = factor_entropy
    return metrics