"""Stuff for analysis fo BERT/RoBERTa NMF decompositions."""
import dataclasses
from typing import Optional, Sequence, Tuple, Union

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

import numpy as np
import seaborn as sns
import tensorflow as tf

from em.models import transformer_model_vars as tmv
from em.util import flat_pack

# Type aliases.
# Either a string (presumably a variable name) or a tf.Variable itself.
NameOrVariable = Union[str, tf.Variable]


###############################################################################

def print_top_examples(
    W: np.ndarray,
    tokenizer,
    input_ids,
    labels,
    component: int,
    n_examples: int,
):
    """Print the examples with the largest coefficients for one NMF component.

    Args:
        W: Coefficient matrix indexed as ``W[example, component]``; examples are
            ranked by the values in column ``component`` (largest first).
        tokenizer: Tokenizer whose ``decode`` turns ``input_ids[i]`` back into
            text. Its ``pad_token`` (when set) is stripped from the output.
        input_ids: Per-example token ids, indexable by example index.
        labels: Per-example labels, indexable by example index. ``tf.Tensor``
            labels are converted with ``.numpy()`` before printing.
        component: Column of ``W`` to rank examples by.
        n_examples: Number of top examples to print.
    """
    _, inds = tf.math.top_k(W[:, component], k=n_examples)
    pad_token = tokenizer.pad_token
    for ind in inds.numpy():
        # Plain Python int so indexing behaves the same for lists, numpy arrays
        # and tensors.
        ind = int(ind)
        label = labels[ind]
        if isinstance(label, tf.Tensor):
            label = label.numpy()
        example = tokenizer.decode(input_ids[ind])
        # Some tokenizers leave pad_token unset; str.replace(None, ...) would raise.
        if pad_token:
            example = example.replace(pad_token, '')
        example = example.strip()
        print(f'{label}: {example}')


# LaTeX templates that `print_top_examples_for_latex` wraps around its table
# rows when `full_figure=True`. Two placeholder tokens are substituted with
# `str.replace` before printing:
#   ### => reordered component index (position of the component in `permutation`)
#   @@@ => original component index
# NOTE(review): the image path always says `top_32`, independent of how many
# example rows are printed — confirm that is intended.
_LATEX_TABLE_PREFIX = R"""\begin{figure}[h]
% 
\begin{center}
\includesvg[width=0.9\linewidth]{images/comp_top_and_loc/ro###_og@@@_top_32_and_loc.svg}
% 
\begin{footnotesize}
\begin{sc}
\begin{tabular}{cccl}
\toprule
Label & Prediction & Coefficient & Example \\
\midrule"""

# Closes the tabular/figure opened by `_LATEX_TABLE_PREFIX`; uses the same
# `###` / `@@@` placeholders in the caption and label.
_LATEX_TABLE_SUFFIX = R"""\bottomrule
\end{tabular}
\end{sc}
\end{footnotesize}
\caption{\textit{Original component index:} @@@. \textit{Description:} \textbf{[TODO]}}
\label{fig:ro###_comp_ex_and_loc}
\end{center}
\vskip -0.1in
% 
\end{figure}"""


def print_top_examples_for_latex(
    decomp,
    pe_fishers_data,
    tokenizer,
    component: int,
    n_examples: int,
    label_map=('False', 'True'),
    permutation: Optional[np.ndarray] = None,
    full_figure: bool = False,
    capitalize_start_of_examples: bool = False,
):
    coeffs, inds = tf.math.top_k(decomp.W[:, component], k=n_examples)
    rows = []
    for coeff, ind in zip(coeffs, inds):
        coeff, ind = coeff.numpy(), ind.numpy()
        label = pe_fishers_data.labels[ind]
        prediction = np.argmax(pe_fishers_data.predicted_logits[ind])

        example = tokenizer.decode(pe_fishers_data.input_ids[ind])
        example = example.replace(tokenizer.pad_token, '')
        example = example.replace(tokenizer.cls_token, '')
        example = example.replace(tokenizer.sep_token, '')
        example = example.strip()

        if capitalize_start_of_examples:
            example = example[0].upper() + example[1:]

        row = [
            label_map[label],
            label_map[prediction],
            f"{coeff:.4f}",
            "\\textup{\\texttt{" + example + '}}'
        ]
        rows.append((" & ".join(row)) + R' \\')

    body = '\n'.join(rows)

    if not full_figure:
        print(body)
        return

    assert permutation is not None

    reordered_comp_index = permutation.tolist().index(component)

    prefix = _LATEX_TABLE_PREFIX.replace('###', str(reordered_comp_index)).replace('@@@', str(component))
    suffix = _LATEX_TABLE_SUFFIX.replace('###', str(reordered_comp_index)).replace('@@@', str(component))

    figure = '\n'.join([prefix, body, suffix])
    print('\n\n')
    print(figure)
    print('\n\n')


def plot_coeffs_for_top_examples(
    decomp,
    permutation,
    component: int,
    n_examples: int,
    ax=None,
    figsize=(9, 4),
    *,
    show=True
):
    """Show the full coefficient rows of the top examples for one component.

    The ``n_examples`` rows of ``decomp.W`` with the largest value in column
    ``component`` are drawn as a heatmap (largest first), with the columns
    reordered by ``permutation``.

    Args:
        decomp: Object exposing the coefficient matrix ``W`` indexed as
            ``W[example, component]``.
        permutation: Column order for the heatmap (any index accepted by numpy
            column indexing).
        component: Component used to select the top examples.
        n_examples: Number of examples (heatmap rows).
        ax: Axes to draw into. When None, a new figure (of size ``figsize`` if
            given) and axes are created.
        figsize: Size of the newly created figure; ignored when ``ax`` is given.
        show: Whether to call ``plt.show()`` at the end.
    """
    _, inds = tf.math.top_k(decomp.W[:, component], k=n_examples)

    x = decomp.W[inds.numpy(), :]
    x = x[:, permutation]

    if ax is None:
        # Only create a figure when we also create the axes; otherwise a caller
        # passing `ax` would get an extra, empty figure because of the default
        # `figsize`.
        if figsize is not None:
            plt.figure(figsize=figsize)
        ax = plt.subplot(111)
    fig = ax.figure

    for side in ("top", "right", "bottom", "left"):
        ax.spines[side].set_visible(False)

    ax.get_yaxis().tick_left()
    ax.get_xaxis().tick_bottom()

    ax.tick_params(axis='both', labelsize=11)

    ax.set_xlabel("Component", fontsize=14)
    ax.set_ylabel("Examples", fontsize=14)

    im = ax.imshow(
        x,
        cmap=sns.color_palette("rocket", as_cmap=True),
    )

    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="2%", pad=0.1)

    # Use the axes' own figure, which is not necessarily the current one.
    fig.colorbar(im, cax=cax)
    fig.tight_layout()

    if show:
        plt.show()


def plot_localization_for_single_component(
    comp_frac_in_subsets: np.ndarray,
    subset_labels: Sequence[str],
    size_x: int = 1,
    size_y: int = 1,
    ax=None,
    figsize=(2, 5),
    *,
    show=True,
):
    """Draw one component's per-subset fractions as a vertical heatmap strip.

    Each entry of ``comp_frac_in_subsets`` becomes a constant band of
    ``size_x`` rows by ``size_y`` columns, drawn on a fixed 0-1 color scale.

    Args:
        comp_frac_in_subsets: One value per subset (presumably a 1-D array).
        subset_labels: One y-axis label per entry of ``comp_frac_in_subsets``.
        size_x: Number of image rows per subset (band height). Any positive
            value works; labels are centered on their band.
        size_y: Number of image columns (strip width).
        ax: Axes to draw into. When None, a new figure (of size ``figsize`` if
            given) and axes are created.
        figsize: Size of the newly created figure; ignored when ``ax`` is given.
        show: Whether to call ``plt.show()`` at the end.

    Raises:
        ValueError: If ``subset_labels`` and ``comp_frac_in_subsets`` differ in
            length.
    """
    if len(subset_labels) != len(comp_frac_in_subsets):
        raise ValueError(
            f'Got {len(subset_labels)} labels for {len(comp_frac_in_subsets)} subsets.'
        )

    x = [f * np.ones([size_x, size_y], dtype=comp_frac_in_subsets.dtype) for f in comp_frac_in_subsets]
    x = np.concatenate(x, axis=0)

    if ax is None:
        # Only create a figure when we also create the axes; otherwise a caller
        # passing `ax` would get an extra, empty figure because of the default
        # `figsize`.
        if figsize is not None:
            plt.figure(figsize=figsize)
        ax = plt.subplot(111)
    fig = ax.figure

    for side in ("top", "right", "bottom", "left"):
        ax.spines[side].set_visible(False)

    ax.get_xaxis().set_visible(False)

    # imshow pixel rows are centered on integer coordinates, so band k spans
    # rows [k*size_x, (k+1)*size_x - 1] and its center is k*size_x + (size_x-1)/2.
    # This is exact for even and odd band heights alike.
    ax.set_yticks((size_x - 1) / 2 + size_x * np.arange(len(subset_labels)))
    ax.set_yticklabels(subset_labels, fontsize=14)

    # NOTE(review): the rows are flipped ([::-1]) but the tick labels are not,
    # so subset_labels[0] labels the band of the LAST subset. Confirm callers
    # pass labels in reversed order; otherwise flip the labels as well.
    im = ax.imshow(
        x[::-1],
        vmin=0,
        vmax=1,
        cmap=sns.color_palette("mako", as_cmap=True),
        interpolation=None,
    )

    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="25%", pad=0.1)

    # Use the axes' own figure, which is not necessarily the current one.
    fig.colorbar(im, cax=cax)
    fig.tight_layout()

    if show:
        plt.show()


# def plot_coeffs_for_top_examples2(
#     decomp,
#     permutation,
#     components: Sequence[int],
#     n_examples: int,
#     n_rows: int,
#     n_cols: int,
#     figsize=(9, 4),
#     *,
#     show=True
# ):
#     assert len(components) <= n_rows * n_cols

#     fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=figsize)

#     for i, component in enumerate(components):
#         _, inds = tf.math.top_k(decomp.W[:, component], k=n_examples)

#         row = i // n_cols
#         col = i % n_cols

#         x = decomp.W[inds, :]
#         x = x[:, permutation]

#         ax = axs[row][col]

#         ax.spines["top"].set_visible(False)  
#         ax.spines["right"].set_visible(False)  
#         ax.spines["bottom"].set_visible(False)  
#         ax.spines["left"].set_visible(False)  

#         ax.get_yaxis().tick_left()  
#         ax.get_xaxis().tick_bottom()

#         # ax.set_yticks(vertical_stretch // 2 + vertical_stretch * np.arange(len(subset_labels)))
#         # ax.set_yticklabels(subset_labels)

#         plt.xticks(fontsize=11)  
#         plt.yticks(fontsize=11)

#         # plt.xlabel("Components", fontsize=14)
#         # plt.ylabel("Examples", fontsize=14)

#         im = ax.imshow(
#             x,
#             cmap=sns.color_palette("rocket", as_cmap=True),
#         )

#         divider = make_axes_locatable(ax)
#         cax = divider.append_axes("right", size="2%", pad=0.1)
#         plt.colorbar(im, cax=cax)

#     # for ax in fig.get_axes():
#     #     ax.label_outer()

#     if show:
#         plt.show()


###############################################################################

def plot_fraction_of_pe_fisher_captured_histogram(
    dense_normalized_fishers, n_bins: int = 50, *, show: bool = True
):
    """Plot a histogram of the per-example (row) L2 norms of the given Fishers.

    Assumes ``pe_fishers_data`` was created with ``normalize=True``, i.e. the
    Fishers were divided by their dense norm. Pass ``pe_fishers_data.fishers`` as
    ``dense_normalized_fishers``; each row norm is then the fraction of the dense
    L2 norm that the stored entries retain.

    Args:
        dense_normalized_fishers: 2-D array with one row per example.
        n_bins: Number of histogram bins.
        show: Whether to call ``plt.show()`` at the end.
    """
    sparse_norms = np.linalg.norm(dense_normalized_fishers, axis=1)
    plt.hist(sparse_norms, n_bins)
    if show:
        plt.show()


def plot_fraction_of_pe_fisher_captured_histogram_for_fisher_sizes(
    dense_normalized_fishers, sizes: Sequence[int], alpha: float = 0.35, colors=None, n_bins: int = 50, show=True,
):
    """Overlay histograms of the L2 fraction retained for several Fisher sizes.

    For each ``size`` in ``sizes``, the row norms of the first ``size`` columns
    of ``dense_normalized_fishers`` are histogrammed. The input is presumably
    ``pe_fishers_data.fishers`` built with ``normalize=True``, with entries
    stored so that a column prefix is a meaningful truncation — TODO confirm.

    Args:
        dense_normalized_fishers: 2-D array with one row per example.
        sizes: Numbers of leading entries to keep; one histogram per size.
        alpha: Transparency of each histogram.
        colors: Optional per-size colors; defaults to matplotlib's cycle.
        n_bins: Number of histogram bins.
        show: Whether to call ``plt.show()`` at the end.

    Raises:
        ValueError: If ``colors`` is given with a different length than ``sizes``.
    """
    if colors is None:
        colors = len(sizes) * [None]
    if len(colors) != len(sizes):
        raise ValueError(f'Got {len(colors)} colors for {len(sizes)} sizes.')

    # Derived from the data rather than hard-coding a dataset size in the label.
    n_total = dense_normalized_fishers.shape[0]

    plt.figure(figsize=(15, 5))

    # Drop the top/right frame lines and keep ticks on the bottom/left only.
    ax = plt.subplot(111)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.get_xaxis().tick_bottom()
    ax.get_yaxis().tick_left()

    # Axis labels are slightly larger than tick labels so they stand out.
    plt.xticks(fontsize=10)
    plt.yticks(fontsize=10)
    plt.xlabel("L2-Fraction Retained", fontsize=12)
    plt.ylabel(f"Examples per Bin (Out of {n_total} Total)", fontsize=12)

    for size, color in zip(sizes, colors):
        sparse_norms = np.linalg.norm(dense_normalized_fishers[:, :size], axis=1)
        plt.hist(
            sparse_norms,
            n_bins,
            alpha=alpha,
            label=str(size),
            color=color,
        )
    plt.legend(
        loc='upper left',
        title='Non-Zero Elements',
        fontsize=12,
        title_fontsize=13,
    )
    plt.tight_layout()
    if show:
        plt.show()

###############################################################################


def plot_component_locations(
    frac_in_subsets: np.ndarray,
    subset_labels: Sequence[str],
    vertical_stretch: int = 1,
    *,
    yticks_fontsize: int = 14,
    show=True,
):
    """Heatmap of every NMF component's mass fraction across parameter subsets.

    Args:
        frac_in_subsets: Array indexed ``[component, subset, ...]``; values are
            drawn on a fixed 0-1 color scale with components along the x-axis.
        subset_labels: One y-axis label per subset (``frac_in_subsets.shape[1]``).
        vertical_stretch: Number of image rows each subset occupies. Must be odd
            so that the label tick lands on the middle row.
        yticks_fontsize: Font size of the subset labels.
        show: Whether to call ``plt.show()`` at the end.
    """
    assert vertical_stretch % 2, 'vertical_stretch must be odd'
    assert len(subset_labels) == frac_in_subsets.shape[1]

    n_components = frac_in_subsets.shape[0]
    n_subsets = len(subset_labels)

    # Duplicate each value `vertical_stretch` times along a new trailing axis,
    # flatten per component, then transpose so subsets run down the rows.
    stretched = np.repeat(frac_in_subsets[..., None], vertical_stretch, axis=-1)
    stretched = stretched.reshape([n_components, -1]).T

    plt.figure(figsize=(15, 5))
    ax = plt.subplot(111)

    for side in ("top", "right", "bottom", "left"):
        ax.spines[side].set_visible(False)

    ax.get_yaxis().tick_left()
    ax.get_xaxis().tick_bottom()

    # Middle row of each (odd-height) band.
    tick_rows = vertical_stretch * np.arange(n_subsets) + vertical_stretch // 2
    ax.set_yticks(tick_rows)
    ax.set_yticklabels(subset_labels)

    plt.xticks(fontsize=11)
    plt.yticks(fontsize=yticks_fontsize)
    plt.xlabel("NMF Component", fontsize=14)
    plt.title('NMF Component Localization', fontsize=18)

    # NOTE(review): rows are flipped so the first subset ends up at the bottom,
    # but the tick labels are not flipped, so subset_labels[0] sits next to the
    # LAST subset's band. Confirm callers pass labels in reversed order.
    image = ax.imshow(
        stretched[::-1],
        vmin=0,
        vmax=1,
        cmap=sns.color_palette("rocket", as_cmap=True),
        interpolation=None,
    )

    colorbar_ax = make_axes_locatable(ax).append_axes("right", size="2%", pad=0.1)
    plt.colorbar(image, cax=colorbar_ax)

    plt.tight_layout()
    if show:
        plt.show()


def plot_sim_matrix(sim_matrix, *, show=True):
    """Draw a component-by-component cosine similarity matrix as a heatmap.

    Args:
        sim_matrix: Square matrix of similarities, drawn with a fixed color
            scale of ``vmin=0`` to ``vmax=1``.
        show: Whether to call ``plt.show()`` at the end.
    """
    plt.figure(figsize=(9, 9))
    ax = plt.subplot(111)

    for side in ("top", "right", "bottom", "left"):
        ax.spines[side].set_visible(False)

    ax.get_yaxis().tick_left()
    ax.get_xaxis().tick_bottom()

    plt.xticks(fontsize=11)
    plt.yticks(fontsize=11)

    axis_title = "NMF Component"
    plt.xlabel(axis_title, fontsize=14)
    plt.ylabel(axis_title, fontsize=14)
    plt.title('NMF Component Cosine Similarity Matrix', fontsize=18)

    image = ax.imshow(
        sim_matrix,
        vmin=0,
        vmax=1,
        cmap=sns.color_palette("rocket", as_cmap=True),
        interpolation=None,
    )

    colorbar_ax = make_axes_locatable(ax).append_axes("right", size="3.5%", pad=0.1)
    plt.colorbar(image, cax=colorbar_ax)

    plt.tight_layout()
    if show:
        plt.show()

# def make_comp_loc_and_sim_plot(sim_matrix, frac_in_subsets, *, show=True):

#     fig, axs = plt.subplots(2)
#     # fig.suptitle('Vertically stacked subplots')

#     for ax in axs:
#         # These suck for plotting data as images.
#         ax.spines["top"].set_visible(False)  
#         ax.spines["right"].set_visible(False)  
#         ax.spines["bottom"].set_visible(False)  
#         ax.spines["left"].set_visible(False)  
#         ax.get_yaxis().tick_left()  
        
#     axs[0].get_xaxis().tick_top()  
#     axs[1].get_xaxis().tick_bottom()

#     aspect = 8
        
#     # axs[0].set_box_aspect(frac_in_subsets.shape[0] / frac_in_subsets.shape[1])
#     axs[0].set_box_aspect(1)
#     axs[0].imshow(
#         frac_in_subsets.T[::-1],
#         vmin=0,
#         vmax=1,
#         cmap=sns.color_palette("mako", as_cmap=True),
#         # aspect=aspect,
#         interpolation=None,
#     )
#     # plt.colorbar()

#     axs[1].set_box_aspect(1)
#     axs[1].imshow(
#         sim_matrix,
#         vmin=0,
#         vmax=1,
#         cmap=sns.color_palette("rocket", as_cmap=True),
#         interpolation=None,
#     )
#     # plt.colorbar()

#     plt.tight_layout()
#     if show:
#         plt.show()


###############################################################################


def _layer_indices_from_var_names(names: Sequence[str]) -> Tuple[int, ...]:
    """Return the distinct layer indices found in ``names``, sorted ascending.

    Names for which ``tmv.extract_layer_index`` returns ``None`` are skipped.
    """
    found = set()
    for name in names:
        layer_index = tmv.extract_layer_index(name)
        if layer_index is not None:
            found.add(layer_index)
    return tuple(sorted(found))


@dataclasses.dataclass
class ComponentLocalizationInfo:
    """Deprecated: attributes an NMF component's mass to model layers / pooler.

    Constructing an instance always raises in ``__post_init__``; use the
    ``ComponentLocalizationInfo`` from ``bert_nmf_analysis2.py`` instead. The
    code after that ``raise`` never runs, so the ``_var_names``,
    ``_var_shapes``, ``_layer_indices`` and ``_packer`` attributes that the
    methods below read are never set on a real instance.
    """

    # We only really need these for their names and shapes.
    variables: Sequence[tf.Variable]

    def __post_init__(self) -> None:
        raise Exception("Deprecated. Use the ComponentLocalizationInfo from bert_nmf_analysis2.py.")
        # Unreachable: kept from the pre-deprecation implementation.
        self._var_names = [v.name for v in self.variables]
        self._var_shapes = [v.shape for v in self.variables]

        # Sorted, de-duplicated layer indices present in the variable names.
        self._layer_indices = _layer_indices_from_var_names(self._var_names)

        # Splits a flat component vector back into per-variable pieces.
        self._packer = flat_pack.FlatPacker(self._var_shapes)

    def fraction_per_layer(self, component: np.ndarray) -> list:
        """Return, per entry of ``self._layer_indices``, the share of the
        component's total sum that falls in variables of that layer.

        Layers with no matching variables contribute 0.0.
        """
        # NOTE: The returned list might not sum up to 1 since
        # embeddings and pooler are not part of it. The entries in the
        # returned list will be in increasing order of the layers present
        # in the variables.
        # NOTE(review): a component whose total sum is 0 divides by zero below.
        unpacked_comp = self._packer.decode_tf(component)

        unpacked_comp_sums = [tf.reduce_sum(c) for c in unpacked_comp]
        comp_sum = tf.reduce_sum(component).numpy()

        ret = []
        for ind in self._layer_indices:
            # Sums of the variables belonging to layer `ind`.
            subset_sums = [
                s for s, n in zip(unpacked_comp_sums, self._var_names)
                if tmv.extract_layer_index(n) == ind
            ]
            if len(subset_sums) == 0:
                ret.append(0.0)
            else:
                ret.append(tf.reduce_sum(subset_sums).numpy() / comp_sum)

        return ret

    def fraction_in_pooler(self, component: np.ndarray):
        """Return the share of the component's total sum that lies in variables
        for which ``tmv.is_pooler_layer(name)`` is true."""
        unpacked_comp = self._packer.decode_tf(component)
        comp_sum = tf.reduce_sum(component).numpy()
        pooler_sum = tf.reduce_sum([
            tf.reduce_sum(s) for s, n in zip(unpacked_comp, self._var_names)
            if tmv.is_pooler_layer(n)
        ]).numpy()
        return pooler_sum / comp_sum
