"""Prints the loss of a saved k-means decomposition: the sum over examples of squared per-example L2 distances."""

from absl import app
from absl import flags

import torch

from npeff_torch.decomps.kmeans import kmeans

###############################################################################
FLAGS = flags.FLAGS

# Path to the saved k-means decomposition whose loss this script reports.
flags.DEFINE_string(
    name='decomposition_filepath',
    default=None,
    help='A kmeans decomposition.',
)


###############################################################################


@torch.no_grad()
def main(_):
    """Prints the k-means loss of the decomposition at --decomposition_filepath.

    The loss is the sum over examples of the squared per-example L2 distances
    stored in the decomposition.

    Args:
        _: Positional command-line arguments left over by absl (unused).
    """
    # These are l2-distances, per-example.
    centroid_distances = kmeans.KmeansClusteringTorch.load_centroid_distances(
        FLAGS.decomposition_filepath)

    # Accumulate on the CPU in float64: the stored distances may be float32
    # (and may live on an accelerator without float64 support), and summing
    # many float32 squares loses precision in the reported loss.
    distances = centroid_distances.cpu().double()

    # Squared l2-distance, summed over examples.
    final_loss = torch.dot(distances, distances).item()

    print(f'final_loss: {final_loss}')


if __name__ == "__main__":
    # Fail fast with a usage error instead of handing None to the loader.
    # Marked here rather than at import time so that importing this module
    # elsewhere does not make the flag globally required.
    flags.mark_flag_as_required('decomposition_filepath')
    app.run(main)
