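"""Shared helpers for analysing HANS results: model accuracy, fractions of NMF components that appear tuned to predicted labels, and plots comparing NMF components across models."""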

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np
import seaborn as sns

from em.projects.ll import hans_util
from em.projects.wino import nmf_components_fisher as ncf


###############################################################################


def get_accuracy(container):
    """Return the fraction of examples whose arg-max logit equals the label.

    Args:
        container: Object exposing ``predicted_logits`` (class scores on the
            last axis) and ``labels`` (integer class indices).

    Returns:
        The accuracy as a float64 in ``[0, 1]``.
    """
    logits = container.predicted_logits
    labels = container.labels
    # The predicted class is the index of the largest logit.
    correct = labels == np.argmax(logits, axis=-1)
    return np.mean(correct, dtype=np.float64)


def get_fractions_per_subset(container, indicator, selection_parameters):
    """Return, for each NMF in ``container``, the fraction of its components deemed tuned.

    Components are selected with ``ncf.get_components_appearing_tuned`` using
    ``indicator`` and ``selection_parameters``, then grouped per NMF with
    ``ncf.group_by_nmf``.  Each fraction is the number of selected components
    in that group divided by the NMF's component count (``nmf.W.shape[-1]``).

    Returns:
        A list of floats, one per entry of ``container.nmfs`` (truncated to the
        shorter of ``container.nmfs`` and the grouped results).
    """
    tuned = ncf.get_components_appearing_tuned(
        container,
        indicator=indicator,
        selection_parameters=selection_parameters,
    )
    grouped = ncf.group_by_nmf(container, tuned)

    fractions = []
    for nmf, group in zip(container.nmfs, grouped):
        fractions.append(len(group) / nmf.W.shape[-1])
    return fractions

###############################################################################
###############################################################################


def plot_frac_true_false(container, title=None, *, show=True, selection_parameters=None):
    """Bar-plot, per NMF, the fraction of components that appear tuned to each predicted class.

    Draws on the current pyplot figure/axes. For every NMF in ``container`` two
    side-by-side bars are drawn:

    - red: the fraction of components selected for examples with ``predictions == 1``
      (labelled "Pred: Non-Entailment");
    - green: the fraction of components selected for examples with ``predictions == 0``
      (labelled "Pred: Entailment").

    Args:
        container: Object exposing ``predictions``, ``nmfs`` and ``n_nmfs``.
        title: Optional plot title.
        show: If True, call ``plt.show()`` at the end.
        selection_parameters: ``ncf.SelectionParameters`` passed to the component
            selection. Defaults to ``coeff_factor=0.4, frac_threshold=0.75,
            p_value_threshold=0.05`` (the previously hard-coded values).
    """
    from matplotlib.ticker import MaxNLocator

    if selection_parameters is None:
        selection_parameters = ncf.SelectionParameters(
            coeff_factor=0.4,
            frac_threshold=0.75,
            p_value_threshold=0.05,
        )

    # NOTE(review): the 1 -> non-entailment / 0 -> entailment mapping follows the
    # HANS label convention; confirm it matches how `predictions` is encoded.
    pred_trues = container.predictions == 1
    pred_falses = container.predictions == 0

    frac_trues = get_fractions_per_subset(container, pred_trues, selection_parameters)
    frac_falses = get_fractions_per_subset(container, pred_falses, selection_parameters)

    x_axis = np.arange(container.n_nmfs)
    width = 0.4
    # Offset the two bars by half a width each so they sit side by side around
    # the integer position of each parameter subset.
    plt.bar(x_axis - width / 2, frac_trues, width, label='Pred: Non-Entailment', color='red')
    plt.bar(x_axis + width / 2, frac_falses, width, label='Pred: Entailment', color='green')

    # Subsets are discrete; avoid fractional tick positions such as 0.5.
    plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))

    if title is not None:
        plt.title(title)

    plt.xlabel("Parameter Subset")
    plt.ylabel("Fraction of Components")
    plt.legend()

    if show:
        plt.show()

###############################################################################


def _sort_comps(nmf, H, comp_sort, mass_fraction):
    """Reorder the rows of ``H`` by NMF component magnitude, optionally truncating.

    Args:
        nmf: Object whose ``W`` attribute holds the NMF coefficients; column ``j``
            of ``W`` is taken to correspond to row ``j`` of ``H``.
        H: Component matrix whose rows are reordered.
        comp_sort: ``None`` to leave ``H`` untouched, ``'avg_coeff'`` to rank
            components by the sum of their coefficients over axis 0 of ``W``, or
            ``'avg_sq_coeff'`` to rank by the sum of squared coefficients.
            Ranking is in descending order of magnitude.
        mass_fraction: If given, keep only the leading ranked components: those
            whose cumulative magnitude is below ``mass_fraction`` of the total,
            plus the next one (so at least one component is always kept).

    Returns:
        ``H`` itself when ``comp_sort`` is None; otherwise a new, reordered and
        possibly truncated array.

    Raises:
        ValueError: If ``mass_fraction`` is given without ``comp_sort``, or if
            ``comp_sort`` is not a recognised option.
    """
    if comp_sort is None:
        if mass_fraction is not None:
            raise ValueError('comp_sort must not be None if mass_fraction is not None')
        return H

    # Sums rather than means: the constant 1/n factor changes neither the
    # ranking nor the mass fractions.
    if comp_sort == 'avg_coeff':
        weights = nmf.W
    elif comp_sort == 'avg_sq_coeff':
        weights = nmf.W**2
    else:
        raise ValueError(comp_sort)
    magnitudes = np.sum(weights, axis=0)

    order = np.argsort(-magnitudes)
    ranked = H[order]
    if mass_fraction is None:
        return ranked

    cumulative = np.cumsum(magnitudes[order])
    n_below = (cumulative < mass_fraction * np.sum(magnitudes)).sum()
    return ranked[:n_below + 1]


def _unit_norm_full_H(nmf):
    """Return ``nmf.get_full_H()`` with every row scaled to unit L2 norm.

    The norm is taken over the same rows that are returned, so dot products of
    the results are true cosine similarities. All-zero rows stay zero instead
    of becoming NaN.
    """
    H = nmf.get_full_H()
    norms = np.sqrt(np.sum(H**2, axis=-1, keepdims=True))
    # Avoid 0/0 -> NaN for degenerate (all-zero) components.
    return H / np.where(norms > 0, norms, 1.0)


def plot_component_similarity(
    container1,
    container2,
    nmf_index: int,
    container1_number=None,
    container2_number=None,
    title=None,
    *,
    comp_sort=None,
    mass_fraction=None,
    show=True,
):
    """Plot a heatmap of cosine similarities between two models' NMF components.

    The components are the rows of ``nmf.get_full_H()`` for
    ``container1.nmfs[nmf_index]`` and ``container2.nmfs[nmf_index]``.
    Rows of the heatmap (y-axis) are container1's components; columns (x-axis)
    are container2's.

    Args:
        container1, container2: Objects exposing a ``nmfs`` sequence.
        nmf_index: Which NMF (parameter subset) to compare in both containers.
        container1_number, container2_number: If given, used in the y/x axis
            labels as "MNLI Model <number>".
        title: Optional plot title.
        comp_sort: Optional component ordering passed to ``_sort_comps``
            (``None``, ``'avg_coeff'`` or ``'avg_sq_coeff'``).
        mass_fraction: Optional truncation passed to ``_sort_comps``; requires
            ``comp_sort``.
        show: If True, call ``plt.show()`` at the end.

    Raises:
        ValueError: Propagated from ``_sort_comps`` for invalid sort options.
    """
    # Computing

    nmf1 = container1.nmfs[nmf_index]
    nmf2 = container2.nmfs[nmf_index]

    # Normalise the vectors that are actually compared (the full H rows), so
    # the dot products below are cosine similarities bounded by 1.
    H1 = _sort_comps(nmf1, _unit_norm_full_H(nmf1), comp_sort, mass_fraction)
    H2 = _sort_comps(nmf2, _unit_norm_full_H(nmf2), comp_sort, mass_fraction)

    cos_sims = H1 @ H2.T

    # Plotting

    plt.figure(figsize=(9, 9))
    ax = plt.subplot(111)

    for side in ("top", "right", "bottom", "left"):
        ax.spines[side].set_visible(False)

    ax.yaxis.tick_left()
    ax.xaxis.tick_bottom()
    # tick_params persists for ticks created later (unlike plt.xticks(fontsize=...)
    # called before the data is drawn).
    ax.tick_params(axis="both", labelsize=11)

    im = ax.imshow(
        cos_sims,
        # vmin=0 assumes non-negative components (as produced by NMF) — TODO confirm.
        vmin=0,
        vmax=1,
        cmap=sns.color_palette("rocket", as_cmap=True),
        # The string 'none' disables resampling; ``None`` would fall back to
        # rcParams['image.interpolation'] and may smooth neighbouring cells.
        interpolation="none",
    )

    # This 2/1 ordering is correct: rows of cos_sims (y-axis) come from
    # container1 and columns (x-axis) from container2.
    if container2_number is not None:
        plt.xlabel(f'MNLI Model {container2_number}', fontsize=14)
    if container1_number is not None:
        plt.ylabel(f'MNLI Model {container1_number}', fontsize=14)

    if title is not None:
        plt.title(title, fontsize=14)

    # Thin colorbar attached to the right edge of the heatmap axes.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="2%", pad=0.1)

    plt.colorbar(im, cax=cax)

    plt.tight_layout()
    if show:
        plt.show()
