import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import spearmanr
from scipy.stats import pearsonr
import matplotlib.colors as cm
from get_args import get_args
import os
# Dataset names, in the order plot_array uses them as y-axis tick labels
# of the variance heat map.
datasets = [
    "Ethiopic", "Fashion", "Kannada", "KMNIST",
    "MNIST", "nko", "Osmanya", "Vai",
]

def get_correlations(biases, variances, labels):
    """Compute the per-cluster Spearman correlation between biases and variances.

    ``biases`` and ``variances`` are flattened in row-major order and the
    per-unit label vector is repeated to cover them, so every row of the
    inputs is assumed to share the same column-to-label assignment.

    Args:
        biases: Array of bias values. Its number of elements must be a whole
            multiple of the number of labels (e.g. shape
            ``(n_rows, n_units)``).
        variances: Array with the same number of elements as ``biases``.
        labels: Cluster label of each unit, e.g. shape ``(n_units,)`` or
            ``(1, n_units)``.

    Returns:
        A ``(correlations, num_labels)`` tuple of float arrays with one entry
        per distinct label, in ascending label order. ``correlations`` holds
        the Spearman coefficient rounded to two decimals; ``num_labels``
        holds how many flattened samples carry that label.

    Raises:
        ValueError: If ``biases`` and ``variances`` differ in size, or if the
            size of ``biases`` is not a whole multiple of the label count.
    """
    biases = np.asarray(biases).ravel()
    variances = np.asarray(variances).ravel()
    unit_labels = np.asarray(labels).ravel()

    if biases.size != variances.size:
        raise ValueError(
            f"biases and variances must have the same number of elements, "
            f"got {biases.size} and {variances.size}"
        )

    # Repeat the labels once per row of the flattened inputs instead of
    # assuming a fixed number of rows.
    if unit_labels.size:
        reps, leftover = divmod(biases.size, unit_labels.size)
    else:
        reps, leftover = 0, biases.size
    if leftover:
        raise ValueError(
            f"number of values ({biases.size}) is not a multiple of the "
            f"number of labels ({unit_labels.size})"
        )
    tiled_labels = np.tile(unit_labels, reps)

    unique_labels, counts = np.unique(tiled_labels, return_counts=True)
    correlations = np.zeros(unique_labels.shape)

    # Index results by position in the sorted unique labels rather than by
    # the label value itself, so float-typed or non-contiguous labels work.
    for idx, label in enumerate(unique_labels):
        mask = tiled_labels == label
        spearman, _ = spearmanr(biases[mask], variances[mask])
        correlations[idx] = round(spearman, 2)

    return correlations, counts.astype(float)
def plot_array(array, biases, labels, plot_path, title):
    """Save a two-panel figure: a heat map above a strip of cluster labels.

    The top panel shows ``array`` with the 'hot' colormap, a colorbar, and
    the module-level ``datasets`` names as y tick labels. The bottom panel
    shows ``labels`` with the 'tab20' colormap. Each cluster gets a tick at
    the centre of its span: the bottom axis shows the cluster's correlation
    from ``get_correlations(biases, array, labels)``, and a secondary top
    axis shows the 1-based cluster number.

    Args:
        array: 2-D array drawn as the heat map; presumably one row per entry
            of ``datasets`` and one column per unit (TODO confirm).
        biases: Bias values, passed to ``get_correlations`` with ``array``.
        labels: 2-D array of cluster labels drawn as the bottom strip.
            NOTE(review): tick placement assumes the labels are sorted so
            that each cluster is one contiguous run - confirm.
        plot_path: Directory under which ``plots/`` is created.
        title: File name (without extension) of the saved PNG.

    Returns:
        None. The figure is written to ``{plot_path}/plots/{title}.png`` and
        then closed.
    """
    fig, axes = plt.subplots(2, 1, sharex=True, constrained_layout=True, figsize=(6,3))

    # --- Top panel: heat map of the values -------------------------------
    ax = axes[0]
    im = ax.imshow(array, aspect=1000, interpolation='none', cmap='hot')
    ax.set_yticks(np.arange(len(datasets)))
    ax.set_yticklabels(datasets)
    ax.yaxis.set_tick_params(length=0)
    ax.xaxis.set_tick_params(length=0)
    ax.set_xticks([])

    cbar = plt.colorbar(im, ax=ax)
    cbar.set_label('Normalized task variance', rotation=90)
    ax.set_title("Units")

    correlations, num_labels = get_correlations(biases, array, labels)

    # --- Bottom panel: cluster label strip -------------------------------
    ax = axes[1]
    ax.imshow(labels, aspect=1000, interpolation='none', cmap='tab20')
    ax.set_yticks([])
    ax.set_yticklabels([])
    ax.yaxis.set_tick_params(length=0)

    # num_labels counts samples over the flattened value array, whereas the
    # x-axis is measured in columns of `labels`; derive the conversion factor
    # from the data instead of hard-coding it for one dataset size.
    total_samples = num_labels.sum()
    units_per_sample = np.size(labels) / total_samples if total_samples else 0.0
    # Centre of each cluster's span: end of the span minus half its width.
    x_ticks = (np.cumsum(num_labels) - num_labels / 2) * units_per_sample

    ax.set_xticks(x_ticks)
    ax.set_xticklabels(correlations, fontsize=5, rotation=90)
    # Only the first 20 tick labels are coloured, because zip() stops at the
    # 20 entries taken from tab20.
    colors = [cm.to_hex(plt.cm.tab20(i)) for i in range(20)]
    for xtick, color in zip(ax.get_xticklabels(), colors):
        xtick.set_color(color)
    ax.xaxis.set_tick_params(length=0)

    # Secondary axis on top carries the 1-based cluster numbers.
    ax_t = ax.secondary_xaxis('top')
    ax_t.xaxis.set_tick_params(length=0)
    ax_t.set_xticks(x_ticks)
    ax_t.set_xticklabels(np.arange(len(correlations)) + 1, fontsize=5)
    for xtick, color in zip(ax_t.get_xticklabels(), colors):
        xtick.set_color(color)

    ax.text(0.45, 2.1, "Cluster #", transform=ax.transAxes, size=6)
    ax.text(0.4, -2.1, "Correlation with biases", transform=ax.transAxes, size=6)

    os.makedirs(f"{plot_path}/plots", exist_ok=True)
    # Save this figure explicitly rather than the implicit "current figure",
    # then close it so repeated calls do not accumulate open figures.
    fig.savefig(f"{plot_path}/plots/{title}.png")
    plt.close(fig)
    
# Script driver: these statements run whenever the module is executed or
# imported. get_args() is a project-local helper; only its `results_path`
# attribute is used below.
args = get_args()
# Apply the plotting style sheet. The path is relative, so the file must be
# present in the current working directory.
plt.style.use("./plot_params.dms")

# Load the three precomputed arrays from <results_path>/data.
# NOTE(review): presumably variances and biases are (n_datasets, n_units)
# and labels is (1, n_units) - confirm against the script that writes them.
variances = np.load(f"{args.results_path}/data/clustered_normalized_variances.npy")
biases = np.load(f"{args.results_path}/data/clustered_biases.npy")
labels = np.load(f"{args.results_path}/data/clustered_labels.npy")
# Writes <results_path>/plots/fig2a.png.
plot_array(variances, biases, labels, args.results_path, "fig2a")