"""
Normalized Hilbert-Schmidt Independence Criterion (HSIC) with a Gaussian kernel,
based on the following references:
[1]: https://link.springer.com/chapter/10.1007/11564089_7
[2]: https://www.researchgate.net/publication/301818817_Kernel-based_Tests_for_Joint_Independence

Original HSIC calculation code by Inga Strumke (GitHub: https://github.com/strumke/hsic_python/blob/master/hsic.py) 
with some edits based on https://jejjohnson.github.io/research_journal/appendix/similarity/hsic/
"""
import numpy as np
from sklearn.preprocessing import KernelCenterer
from sklearn.metrics import silhouette_samples
import pandas as pd

def check_constant_columns_numpy(arr: np.ndarray) -> list[int]:
    """
    Find the columns of a 2-D NumPy array whose entries are all identical.

    A column counts as constant when every entry compares equal (``==``) to the
    entry in the first row; consequently a column containing NaN is never
    reported as constant.

    Args:
        arr (np.ndarray): Two-dimensional input array.

    Returns:
        list[int]: Indices (Python ints) of the constant columns, in ascending order.
    """
    if arr.shape[1] == 0:
        # No columns: nothing to report, and no first row needs to be read.
        return []
    matches_first_row = arr == arr[0]
    return np.flatnonzero(matches_first_row.all(axis=0)).tolist()

def gaussian_grammat(x, sigma=None):
    """
    Calculate the Gram matrix of x using a Gaussian kernel.

        K[i, j] = exp(-||x_i - x_j||**2 / (2 * sigma**2))

    Args:
        x: Array-like of shape (n_samples,) or (n_samples, n_features). A 1-D
            input is treated as n_samples one-dimensional observations.
        sigma: Kernel bandwidth. If None, it is estimated with the median
            heuristic ||x_i - x_j||**2 = 2 sigma**2, where the median is taken
            over the non-zero pairwise squared distances. A bandwidth of 0 is
            replaced by machine epsilon.

    Returns:
        np.ndarray: The (n_samples, n_samples) Gram matrix. If sigma is None and
        every pairwise distance is zero (all samples identical), the median
        heuristic is undefined and the returned matrix is all NaN.
    """
    x = np.asarray(x)
    if not np.issubdtype(x.dtype, np.floating):
        # Integer (especially unsigned) products would overflow / wrap around.
        x = x.astype(float)
    if x.ndim == 1:
        x = x.reshape(-1, 1)

    xxT = np.matmul(x, x.T)
    sq_norms = np.diag(xxT)
    # ||x_i - x_j||^2 = x_i.x_i + x_j.x_j - x_i.x_j - x_j.x_i
    xnorm = (sq_norms - xxT) + (sq_norms - xxT).T
    # Cancellation in the expansion above can leave tiny negative values. True
    # squared distances are non-negative, and negative ones would push kernel
    # entries above 1.
    np.maximum(xnorm, 0, out=xnorm)

    if sigma is None:
        nonzero = xnorm[xnorm != 0]
        if nonzero.size == 0:
            # No non-zero distance to base the median heuristic on. Return NaN
            # explicitly (the same result np.median([]) would propagate, but
            # without its RuntimeWarning) so callers can detect the case.
            return np.full(xnorm.shape, np.nan)
        sigma = np.sqrt(np.median(nonzero) * 0.5)

    # A zero bandwidth would divide by zero; use machine epsilon instead.
    if sigma == 0:
        sigma += np.finfo(float).eps

    KX = -0.5 * xnorm / sigma / sigma
    np.exp(KX, out=KX)
    return KX

def _raise_if_nan_gram(data, gram):
    """Print `data` and its Gram matrix, then raise ValueError, if `gram` contains NaN."""
    if np.any(np.isnan(gram)):
        print(f"Data:\n{data}")
        print(f"Gram matrix:\n{gram}")
        raise ValueError("NaN values in the Gram matrix. Check your input data.")


def HSIC(x, y):
    """
    Calculate the normalized HSIC between x and y using Gaussian kernels.

    Both Gram matrices are built with `gaussian_grammat` (median-heuristic
    bandwidth) and double-centered with sklearn's KernelCenterer (H K H with
    H = I - 11^T / n). Instead of the unnormalized empirical HSIC of [1], the
    value returned is the Frobenius inner product of the centered matrices
    divided by the product of their Frobenius norms:

        <HKxH, HKyH>_F / (||HKxH||_F * ||HKyH||_F)

    Any constant scaling of the estimator (e.g. 1/n**2) cancels out. For these
    positive semi-definite kernels, the result lies in [0, 1] up to
    floating-point rounding.

    Args:
        x: Array-like of shape (n_samples,) or (n_samples, n_features).
        y: Array-like with the same number of samples as x.

    Returns:
        The normalized HSIC value (NumPy float).

    Raises:
        ValueError: If either Gram matrix contains NaN (e.g. when all samples of
            an input are identical). The offending data and Gram matrix are
            printed before raising.
    """
    Kx = gaussian_grammat(x)
    Ky = gaussian_grammat(y)
    _raise_if_nan_gram(x, Kx)
    _raise_if_nan_gram(y, Ky)

    HKx = KernelCenterer().fit_transform(Kx)
    HKy = KernelCenterer().fit_transform(Ky)
    frobenius_norm_x = np.sqrt(np.sum(HKx**2))
    frobenius_norm_y = np.sqrt(np.sum(HKy**2))
    return np.sum(HKx * HKy) / (frobenius_norm_x * frobenius_norm_y)


def distance(x_i: np.ndarray, x_j: np.ndarray) -> float:
    """
    Return the HSIC-based distance 1 - HSIC(x_i, x_j), clipped at 0.

    This is the `metric` callable passed to `silhouette_samples` in
    `evaluate_clusters_hsic`. For identical or highly dependent variables, the
    normalized HSIC can come out slightly above 1 because of rounding. That
    would make the raw distance slightly negative, so it is clipped at 0.

    Args:
        x_i (np.ndarray): Observations of the first variable.
        x_j (np.ndarray): Observations of the second variable (same number of samples).

    Returns:
        float: Non-negative distance, always as a Python float.
    """
    # NOTE(review): max(0.0, nan) evaluates to 0.0, so a NaN HSIC would be
    # reported as distance 0 (perfect dependence) instead of surfacing.
    return float(max(0.0, 1.0 - HSIC(x_i, x_j)))

def evaluate_clusters_hsic(clusters: dict[str, pd.DataFrame], silhouette_threshold=0.5, outliers_threshold=0.25, random_state: int | np.random.Generator | None = None) -> dict[str, list[float | list[str] | list[str]]]:
    """
    Evaluate the variable groups using the mean of the silhouette score, based on 1-HSIC(x_i, x_j) distance.
    It returns the clusters with poor silhouette (< silhouette_threshold), the outliers in those clusters,
    and the full list of variables in each cluster.

    The silhouette_threshold is used to determine which clusters are considered "bad" and should be returned.
    The outliers_threshold is used to detect variables within a cluster that have poor silhouette scores and are considered outliers.

    Args:
        clusters (dict[str, pd.DataFrame]): A dictionary where keys are cluster names (str) and values are DataFrame
            slices representing variables in the cluster. All DataFrames must have the same number of rows
            (observations), since they are stacked side by side.
        silhouette_threshold (float, optional): The silhouette score threshold below which clusters are considered poor. Defaults to 0.5.
        outliers_threshold (float, optional): The silhouette score threshold below which variables are considered outliers within a cluster. Defaults to 0.25.
        random_state (int | np.random.Generator | None, optional): Seed or generator for the small Gaussian jitter
            added when constant columns are present. None (default) draws from NumPy's global random state via
            np.random.normal, so results are then reproducible only through np.random.seed.

    Returns:
        dict[str, list[float | list[str] | list[str]]]: A dictionary where:
            - Each key is the cluster name (str).
            - Each value is a list containing:
                - The average silhouette score of the cluster, rounded to 2 decimals (float).
                - A list of variable names identified as outliers in the cluster (list[str]).
                - The full list of variable names in the cluster (list[str]).
    """
    results = {}

    # Stack the clusters side by side: rows are observations, columns are
    # variables, in the iteration order of `clusters`.
    samples_matrix = np.hstack([group.to_numpy() for group in clusters.values()])

    if check_constant_columns_numpy(samples_matrix):
        # A constant variable has no non-zero pairwise distance, so the median-
        # heuristic bandwidth in gaussian_grammat is undefined and HSIC raises.
        # Add a small amount of noise to the whole dataset to avoid that.
        rng = np.random if random_state is None else np.random.default_rng(random_state)
        samples_matrix = samples_matrix + rng.normal(0, 0.0001, samples_matrix.shape)

    # Silhouette treats each variable as a sample: (n_variables, n_observations).
    samples_matrix = samples_matrix.T

    # One label per variable, in the same order as the stacked columns.
    labels = np.concatenate([[key] * group.shape[1] for key, group in clusters.items()])

    silhouette_scores = silhouette_samples(X=samples_matrix, labels=labels, metric=distance)

    start = 0
    for group_label, group in clusters.items():
        n_vars = group.shape[1]
        # This cluster's variables occupy a contiguous slice, in column order.
        group_scores = silhouette_scores[start:start + n_vars]
        start += n_vars

        avg_silhouette = np.mean(group_scores) if n_vars else 0

        # Positions (within this cluster) of variables scoring below the outlier threshold.
        outliers_idx = np.flatnonzero(group_scores < outliers_threshold)
        outliers_names = group.columns[outliers_idx].tolist()

        full_variable_names = group.columns.tolist()

        results[group_label] = [round(avg_silhouette, 2), outliers_names, full_variable_names]

    # Keep only the clusters whose (rounded) average silhouette is below the threshold.
    return {k: v for k, v in results.items() if v[0] < silhouette_threshold}