"""Given a k-means decomposition, computes the cluster identity for a set of gradients."""

from absl import app
from absl import flags

import torch

from npeff_torch.decomps.kmeans import kmeans
from npeff_torch.peis.gradients.formats import dn_gradients

###############################################################################

FLAGS = flags.FLAGS


# Destination passed to `.save(...)` for the resulting clustering in `main`.
# `main` asserts that this flag is set.
flags.DEFINE_string('output_filepath', None, '')

# Path handed to `kmeans.KmeansClusteringTorch.load` in `main`.
flags.DEFINE_string('kmeans_filepath', None, '')

# Comma-separated list of gradient files. Each entry is passed to
# `dn_gradients.load_gradients` and, when normalizing, `dn_gradients.load_norms`.
flags.DEFINE_list('gradients_filepath', None, '')
# List flags yield strings; `_get_n_examples_per_gradient` converts the entries
# to ints before they are used.
flags.DEFINE_list('n_examples_per_gradient', [],
                  "Comma-separated list of integers indicating the number of examples to use from each gradient file. "
                  "If provided, the list must be the same length as the --gradients_filepath list. "
                  "Leave empty to use all examples from all gradients. "
                  "Use a value of -1 for a particular gradient to use all examples from that particular gradient.")

# When true, `main` divides the loaded gradients in place by the norms loaded
# from the same files before computing cluster assignments.
flags.DEFINE_bool('normalize_gradients', True, '')

###############################################################################


def _get_n_examples_per_gradient():
    """Returns the per-file example limits given by --n_examples_per_gradient.

    Returns:
        A list of ints with one entry per path in --gradients_filepath. When
        --n_examples_per_gradient is empty, every entry is -1, which the
        loaders in this file treat as "use every example".

    Raises:
        AssertionError: If --n_examples_per_gradient is non-empty and its
            length differs from that of --gradients_filepath.
        ValueError: If an entry cannot be converted with `int`.
    """
    gradients_filepath = FLAGS.gradients_filepath
    n_examples_per_gradient = FLAGS.n_examples_per_gradient

    if not n_examples_per_gradient:
        return [-1] * len(gradients_filepath)

    # Compare the two lengths. Comparing a length against the list itself is
    # always False and would reject every explicitly provided list.
    assert len(gradients_filepath) == len(n_examples_per_gradient), (
        "--n_examples_per_gradient must have the same number of entries as "
        "--gradients_filepath."
    )
    # List flags hold strings, so convert each entry to an int.
    return [int(i) for i in n_examples_per_gradient]


def load_gradients(device: torch.device) -> torch.Tensor:
    """Loads the gradients from every --gradients_filepath entry as one tensor.

    Each file is read with `dn_gradients.load_gradients`. If the matching
    example limit is non-negative, only the leading `limit` entries along the
    first axis are kept; a negative limit keeps everything.

    Args:
        device: Device the concatenated tensor is moved to.

    Returns:
        The per-file results concatenated along dim 0, on `device`.
    """
    paths = FLAGS.gradients_filepath
    limits = _get_n_examples_per_gradient()

    pieces = []
    for path, limit in zip(paths, limits):
        loaded = dn_gradients.load_gradients(path)
        pieces.append(loaded[:limit] if limit >= 0 else loaded)

    stacked = torch.cat(pieces, dim=0)
    return stacked.to(device)


def load_norms(device: torch.device) -> torch.Tensor:
    """Loads the norms from every --gradients_filepath entry as one tensor.

    Each file is read with `dn_gradients.load_norms`. If the matching example
    limit is non-negative, only the leading `limit` entries along the first
    axis are kept; a negative limit keeps everything.

    Args:
        device: Device the concatenated tensor is moved to.

    Returns:
        The per-file results concatenated along dim 0, on `device`.
    """
    paths = FLAGS.gradients_filepath
    limits = _get_n_examples_per_gradient()

    pieces = []
    for path, limit in zip(paths, limits):
        loaded = dn_gradients.load_norms(path)
        pieces.append(loaded[:limit] if limit >= 0 else loaded)

    stacked = torch.cat(pieces, dim=0)
    return stacked.to(device)


###############################################################################


@torch.no_grad()
def main(_):
    """Assigns the loaded gradients to k-means clusters and saves the result.

    Loads the decomposition from --kmeans_filepath, loads the gradients
    (optionally divided by their norms), computes cluster assignments and
    centroid distances with the loaded decomposition, and saves a new
    `KmeansClusteringTorch` holding the original centroids together with the
    new assignments and distances to --output_filepath.

    Args:
        _: Unused argument supplied by `app.run`.
    """
    assert FLAGS.output_filepath is not None

    # Use the GPU when one is available, otherwise the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    fitted = kmeans.KmeansClusteringTorch.load(FLAGS.kmeans_filepath).to(device)

    grads = load_gradients(device)
    if FLAGS.normalize_gradients:
        # In-place division avoids allocating a second gradients tensor.
        # Presumably norms holds one value per gradient row, so the added
        # axis broadcasts across each row -- confirm against dn_gradients.
        grads.div_(load_norms(device)[:, None])

    assignments, distances = fitted.compute_clusters(grads)

    result = kmeans.KmeansClusteringTorch(
        centroids=fitted.centroids,
        cluster_assignments=assignments,
        centroid_distances=distances,
    )
    result.save(FLAGS.output_filepath)
    
    
# Script entry point: hand `main` to absl's application runner.
if __name__ == "__main__":
    app.run(main)
