import tensorflow as tf
import numpy as np
from sklearn.decomposition import PCA


def compute_bhattacharyya_distance(mu1, mu2, sigma1, sigma2):
    """Return the Bhattacharyya distance between two Gaussians.

    The distance is the sum of two terms:
      * a mean-separation term: sum((mu1 - mu2)^2 / avg_sigma) / 8
      * a variance-mismatch term:
        0.5 * (sum(log(avg_sigma)) - 0.5 * sum(log(sigma1))
               - 0.5 * sum(log(sigma2)))
    where avg_sigma = (sigma1 + sigma2) / 2.

    Args:
      mu1: mean of the first Gaussian.
      mu2: mean of the second Gaussian.
      sigma1: variance(s) of the first Gaussian (per dimension, or a scalar).
      sigma2: variance(s) of the second Gaussian (per dimension, or a scalar).

    Returns:
      Scalar tensor holding the distance.

    Note: when the variances are scalars (the 'spherical' path in
    get_bhattacharyya_distance), the mean-separation term broadcasts over
    every dimension of the means, whereas each log sum reduces over that
    single scalar value.
    """
    pooled_variance = (sigma1 + sigma2) / 2

    # Squared mean difference, scaled by the pooled variance.
    mean_gap = mu1 - mu2
    separation_term = tf.reduce_sum(mean_gap**2 / pooled_variance) / 8

    # Log of the pooled variance against the logs of the two variances.
    log_pooled = tf.reduce_sum(tf.math.log(pooled_variance))
    log_first = tf.reduce_sum(tf.math.log(sigma1))
    log_second = tf.reduce_sum(tf.math.log(sigma2))
    mismatch_term = log_pooled - 0.5 * log_first
    mismatch_term = mismatch_term - 0.5 * log_second

    return separation_term + 0.5 * mismatch_term


def get_bhattacharyya_distance(per_class_stats, c1, c2, gaussian_type):
    """Look up two classes' statistics and return their Bhattacharyya distance.

    Args:
      per_class_stats: mapping from class label to a dict with 'mean' and
        'variance' entries.
      c1: label of the first class (a key of per_class_stats).
      c2: label of the second class (a key of per_class_stats).
      gaussian_type: 'spherical' collapses each class variance to its mean
        over all entries; any other value uses the variances as stored.

    Returns:
      The value of compute_bhattacharyya_distance for the two classes.
    """
    first, second = per_class_stats[c1], per_class_stats[c2]
    means = (first['mean'], second['mean'])
    variances = (first['variance'], second['variance'])

    if gaussian_type == 'spherical':
        # A spherical Gaussian shares one variance across all dimensions.
        variances = tuple(tf.reduce_mean(var) for var in variances)

    return compute_bhattacharyya_distance(*means, *variances)


def compute_per_class_mean_and_variance(features, target_labels, unique_labels):
    """Summarise each class by the mean and variance of its features.

    Args:
      features: feature rows, indexed along the first axis by sample.
      target_labels: label of each sample, aligned with features.
      unique_labels: the labels to compute statistics for.

    Returns:
      Dict keyed by int(label); each value is a dict with
        'mean': mean over the class's rows (axis 0), and
        'variance': variance over the class's rows (axis 0), floored at 1e-4.
    """
    stats_by_class = {}
    for raw_label in unique_labels:
        # Plain ints make the result indexable with ordinary Python keys.
        class_key = int(raw_label)

        # Select the rows whose label matches this class.
        row_indices = tf.where(tf.equal(target_labels, class_key))
        members = tf.gather_nd(features, row_indices)

        stats_by_class[class_key] = {
            'mean': tf.reduce_mean(members, axis=0),
            # The floor keeps constant features from producing zero variance.
            'variance': tf.maximum(
                tf.math.reduce_variance(members, axis=0), 1e-4),
        }
    return stats_by_class


def get_gbc_score(features, target_labels, gaussian_type='spherical',
                  n_components=64):
    """Compute Gaussian Bhattacharyya Coefficient (GBC).

    The features are projected with PCA, every class is summarised by the
    mean and variance of its projected features, and the score is the negated
    sum of exp(-Bhattacharyya distance) over all ordered pairs of distinct
    classes.

    Args:
      features: source features from the target data, as a 2-D
        (n_samples, n_features) array-like accepted by PCA.fit_transform.
      target_labels: ground truth labels in the target label space, one per
        sample. Labels are converted with int() to index the class stats.
      gaussian_type: type of gaussian used to represent class features. The
        possibilities are spherical (default) or diagonal.
      n_components: number of PCA components to keep. An integer value is
        capped at min(n_samples, n_features), the largest value PCA accepts,
        so small inputs do not fail.

    Returns:
      gbc: transferability metric score (scalar tensor).

    Raises:
      ValueError: if gaussian_type is neither 'diagonal' nor 'spherical'.
    """
    # Explicit check instead of `assert`: asserts vanish under `python -O`,
    # which would let an unknown type silently run the diagonal path.
    if gaussian_type not in ('diagonal', 'spherical'):
        raise ValueError(
            "gaussian_type must be 'diagonal' or 'spherical', got {!r}".format(
                gaussian_type))

    # PCA rejects n_components larger than min(n_samples, n_features).
    feature_shape = np.shape(features)
    if isinstance(n_components, int) and len(feature_shape) == 2:
        n_components = min(n_components, *feature_shape)
    pca = PCA(n_components=n_components)
    feature_space = pca.fit_transform(features)

    # Convert labels to Python ints once so they can be compared and used as
    # dict keys without repeated tensor conversions.
    unique_labels, _ = tf.unique(target_labels)
    class_ids = [int(label) for label in unique_labels]
    per_class_stats = compute_per_class_mean_and_variance(
        feature_space, target_labels, class_ids)

    per_class_overlap = []
    for c1 in class_ids:
        # exp(-distance) against every other class.
        overlaps = [
            tf.exp(-get_bhattacharyya_distance(
                per_class_stats, c1, c2, gaussian_type))
            for c2 in class_ids
            if c1 != c2
        ]
        per_class_overlap.append(tf.reduce_sum(overlaps))
    gbc = -tf.reduce_sum(per_class_overlap)

    return gbc


# src_x_list = []
# src_y_list = []
# rng = np.random.default_rng()

# for i in range(5000):

#     # np.random.seed(30)
#     src_x_list.append(np.random.randn(524))
#     # np.random.seed(30)
#     src_y_list.append(np.random.randint(0,10))

# src_x = np.array(src_x_list)
# # src_y = np.array(src_y_list)

# # src_y = src_y.reshape(50, 1)
# # print(src_x.shape)
# # print(src_y_list)
# gbc = get_gbc_score(src_x, src_y_list, 'diagonal')
# print(gbc)

# # print(GBC(src_x, src_y_list))