"""Plots histograms of and some statistics of absolute cosine similarities between pseudo-Fisher vectors."""


from absl import app
from absl import flags

import matplotlib.pyplot as plt
import torch

from npeff_torch.decomps.kmeans import kmeans
from npeff_torch.decomps.npeff import lrm_npeff_decomps

###############################################################################
# Decomposition kinds this script knows how to load (see the branches in `main`).
_DECOMPOSITION_TYPES = ['npeff', 'kmeans']
###############################################################################
FLAGS = flags.FLAGS


flags.DEFINE_string(
    'decomposition_filepath', None,
    'Path to the saved decomposition to analyze. Must match --decomposition_type.')
flags.DEFINE_enum(
    'decomposition_type', 'npeff', _DECOMPOSITION_TYPES,
    'Kind of decomposition stored at --decomposition_filepath. '
    'The kmeans option is a hack for allowing us to use this for gradient clusters.')

flags.DEFINE_integer('n_bins', 15, 'Number of bins in the histogram.')


flags.DEFINE_string(
    'figure_filepath', None,
    'Path to save figure to. If None, then will just display it.')

###############################################################################


def _load_unit_norm_pfvs(decomposition_type, filepath, device):
    """Loads a decomposition and returns its normalized component vectors.

    Args:
        decomposition_type: Either 'npeff' or 'kmeans'.
        filepath: Path handed to the decomposition class's `load`.
        device: Torch device the decomposition is moved to before normalizing.

    Returns:
        The `G` attribute for an 'npeff' decomposition, or the `centroids`
        attribute for a 'kmeans' one, read after
        `normalize_reduced_components_to_unit_norm_()` has been called.

    Raises:
        ValueError: If `decomposition_type` is not a recognized kind.
    """
    if decomposition_type == 'npeff':
        decomposition = lrm_npeff_decomps.LrmNpeffDecomposition.load(
            filepath, load_W=True, load_G=True)
    elif decomposition_type == 'kmeans':
        decomposition = kmeans.KmeansClusteringTorch.load(filepath)
    else:
        raise ValueError(decomposition_type)

    decomposition = decomposition.to(device)
    decomposition.normalize_reduced_components_to_unit_norm_()

    # Only the component matrix is returned; the rest of the decomposition
    # becomes unreferenced once this function returns.
    if decomposition_type == 'npeff':
        return decomposition.G
    return decomposition.centroids


def _pairwise_abs_cos_similarities(pfvs):
    """Returns the absolute inner product of every unordered pair of distinct rows.

    NOTE(review): the values are cosine similarities only if the rows of
    `pfvs` have unit norm; that is presumably what the normalization call in
    the loader guarantees -- confirm against the decomposition classes.

    Args:
        pfvs: 2-D tensor with one vector per row.

    Returns:
        1-D tensor holding the strictly-upper-triangular entries of the
        absolute Gram matrix, so the diagonal is excluded and each pair is
        counted exactly once.
    """
    abs_gram = torch.einsum('cf,kf->ck', pfvs, pfvs).abs_()

    n_vectors = abs_gram.shape[0]
    # offset=1 selects the entries strictly above the main diagonal. Build the
    # indices on the same device as the matrix to avoid a host->device copy.
    rows, cols = torch.triu_indices(n_vectors, n_vectors, offset=1, device=abs_gram.device)
    return abs_gram[rows, cols]


@torch.no_grad()
def main(_):
    """Prints mean/median of pairwise |cos sim| and plots their histogram.

    Reads the decomposition named by --decomposition_filepath and
    --decomposition_type, prints the mean and median of the pairwise absolute
    similarities, and either saves the histogram to --figure_filepath or shows
    it interactively when that flag is unset.

    Args:
        _: Unused positional argument supplied by the caller (`app.run`).

    Raises:
        ValueError: If --decomposition_type is not a recognized kind.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    pfvs = _load_unit_norm_pfvs(FLAGS.decomposition_type, FLAGS.decomposition_filepath, device)
    flat_abs_cos_similarities = _pairwise_abs_cos_similarities(pfvs)
    del pfvs

    # Print some summary statistics.
    mean_abs_sim = flat_abs_cos_similarities.mean().cpu().numpy()
    median_abs_sim = torch.median(flat_abs_cos_similarities).cpu().numpy()
    # TODO: Maybe fraction above threshold?

    print(f'mean: {mean_abs_sim}')
    print(f'median: {median_abs_sim}')

    # TODO: Make the plot nicer.
    plt.rcParams.update({'font.size': 20})
    # Honor --n_bins; the flag was previously defined but never passed to hist.
    plt.hist(flat_abs_cos_similarities.cpu().numpy(), bins=FLAGS.n_bins)
    # All abs cosine sims are in the range [0,1], so make the plot x-axis consistent.
    plt.xlim(0.0, 1.0)
    plt.yticks([])

    if FLAGS.figure_filepath:
        plt.savefig(FLAGS.figure_filepath)
    else:
        plt.show()


# Script entry point: hand `main` to absl's app runner when executed directly.
if __name__ == "__main__":
    app.run(main)
