import numpy as np
import matplotlib.pyplot as plt
from entropy_estimators import continuous
from scipy.stats import pearsonr
from scipy.stats import spearmanr
from get_args import get_args

# Parse the script's options at module level; `args.results_path` is the root
# directory from which the clustered .npy inputs are read and under which the
# output figure is written.
args = get_args()
def get_correlations(biases, variances, labels):
    """Compute the Spearman bias/variance correlation within each cluster.

    ``biases`` and ``variances`` are flattened in C order, and ``labels`` (one
    entry per unit) is tiled to line up with the flattened values. That layout
    presumably means the arrays are shaped ``(R, len(labels))``, i.e. ``R``
    stacked copies of the per-unit layout. ``R`` is inferred as
    ``biases.size // len(labels)`` rather than assumed.

    Args:
        biases: Array with ``R * len(labels)`` elements.
        variances: Array with the same number of elements as ``biases``.
        labels: Cluster label per unit. Labels need not be contiguous or
            non-negative integers.

    Returns:
        ``(correlations, num_labels)``: two float arrays with one entry per
        distinct label, in ``np.unique`` (sorted) order. ``correlations[i]`` is
        the Spearman rho for the i-th label. ``num_labels[i]`` is the number of
        flattened samples carrying that label, which is ``R`` times the
        cluster size.

    Raises:
        ValueError: If ``biases.size`` is not a multiple of ``len(labels)``.
    """
    labels = np.ravel(labels)
    biases = np.ravel(biases)
    variances = np.ravel(variances)

    unique_labels = np.unique(labels)
    correlations = np.zeros(unique_labels.shape)
    num_labels = np.zeros_like(correlations)
    if unique_labels.size == 0:
        return correlations, num_labels

    # Infer how many stacked copies of the per-unit layout were flattened,
    # instead of assuming a fixed count.
    repeats, remainder = divmod(biases.size, labels.size)
    if remainder:
        raise ValueError(
            f"biases has {biases.size} elements, which is not a multiple of "
            f"the {labels.size} labels"
        )
    tiled_labels = np.tile(labels, repeats)

    # Index outputs by position, not by label value, so that arbitrary label
    # values (e.g. -1 or gaps) cannot alias or overflow the result arrays.
    for i, label in enumerate(unique_labels):
        mask = tiled_labels == label
        num_labels[i] = np.count_nonzero(mask)
        rho, _ = spearmanr(biases[mask], variances[mask])
        correlations[i] = rho

    return correlations, num_labels

def get_entropies(variances, labels):
    """Estimate the entropy of each cluster's variance vectors.

    Args:
        variances: 2-D array whose second axis is aligned with ``labels``
            (column ``j`` belongs to the unit labelled ``labels[j]``).
        labels: 1-D array of cluster labels, one per column of ``variances``.
            Labels need not be contiguous or non-negative integers.

    Returns:
        Float array with one ``continuous.get_h`` estimate per distinct label,
        in ``np.unique`` (sorted) order. Some entries may be non-finite
        (e.g. -inf), so callers should filter them before further statistics.
    """
    unique_labels = np.unique(labels)
    entropies = np.zeros(unique_labels.shape)
    # Index by position, not label value, so arbitrary labels cannot alias or
    # overflow the output array.
    for i, label in enumerate(unique_labels):
        # Transpose so each selected unit becomes one row. Presumably
        # continuous.get_h expects (n_samples, n_features); verify against
        # entropy_estimators.
        samples = variances[:, labels == label].T
        entropies[i] = continuous.get_h(samples)

    return entropies

plt.style.use("./plot_params.dms")


variances = np.load(f"{args.results_path}/data/clustered_normalized_variances.npy")
biases = np.load(f"{args.results_path}/data/clustered_biases.npy")
labels = np.load(f"{args.results_path}/data/clustered_labels.npy").squeeze()

entropies = get_entropies(variances, labels)
correlations, num_labels = get_correlations(biases, variances, labels)

# Drop clusters whose entropy or correlation is not finite (e.g. -inf entropy
# estimates, NaN correlations): a single such value makes np.corrcoef, and
# therefore the r reported in the title, NaN. Masking by value, rather than
# deleting a fixed index, stays correct whichever cluster is degenerate.
finite = np.isfinite(entropies) & np.isfinite(correlations)
entropies = entropies[finite]
correlations = correlations[finite]
num_labels = num_labels[finite]

plt.scatter(entropies, correlations, s=3)
corrcoef = np.corrcoef(entropies, correlations)

plt.xlabel("TV Entropy")
plt.ylabel("TV/Bias Correlation")

plt.title(f"TV/Bias Correlation vs. TV Entropy, r={round(corrcoef[0, 1], 3)}")
plt.tight_layout()
plt.savefig(f"{args.results_path}/plots/fig2b.png")




