"""Learns k-means clusters of dense gradients."""

from absl import app
from absl import flags

import torch

from npeff_torch.decomps.kmeans import kmeans
from npeff_torch.peis.gradients.formats import dn_gradients

###############################################################################

FLAGS = flags.FLAGS


# Output location. main() refuses to run without it.
flags.DEFINE_string(
    'output_filepath', None,
    'Path at which the learned k-means decomposition is saved. Required.')

# Input gradient files. Their examples are concatenated in the order listed.
flags.DEFINE_list(
    'gradients_filepath', None,
    'Comma-separated list of paths to dense gradient files. The examples from '
    'all files are concatenated, in the order given, before clustering. Required.')
flags.DEFINE_list('n_examples_per_gradient', [],
                  "Comma-separated list of integers indicating the number of examples to use from each gradient file. "
                  "If provided, the list must be the same length as the --gradients_filepath list. "
                  "Leave empty to use all examples from all gradients. "
                  "Use a value of -1 for a particular gradient to use all examples from that particular gradient.")


# Both values are forwarded unchanged, as keyword arguments, to the --alg function.
flags.DEFINE_integer(
    'n_clusters', None,
    'Number of clusters (k); passed as n_clusters to the --alg function.')
flags.DEFINE_integer(
    'n_iterations', None,
    'Number of iterations; passed as n_iterations to the --alg function.')

flags.DEFINE_bool(
    'normalize_gradients', True,
    'If true, divide each gradient by the norm stored in its gradient file '
    'before clustering.')

flags.DEFINE_string(
    'alg', "compute_kmeans_lloyds",
    'Name of the function in the kmeans module used to compute the '
    'decomposition.')

###############################################################################


def _get_n_examples_per_gradient():
    gradients_filepath = FLAGS.gradients_filepath
    n_examples_per_gradient = FLAGS.n_examples_per_gradient

    if not n_examples_per_gradient:
        return [-1 for _ in gradients_filepath]

    assert len(gradients_filepath) == len(n_examples_per_gradient)
    return [int(i) for i in n_examples_per_gradient]


def load_gradients(device: torch.device) -> torch.Tensor:
    """Read every file in --gradients_filepath and return one combined tensor.

    Files are read in flag order through dn_gradients.load_gradients. If a
    file's entry from _get_n_examples_per_gradient() is non-negative, only that
    many leading rows are kept; a negative entry keeps the whole file. The
    pieces are concatenated along dim 0 and the result is moved to ``device``.
    """
    paths = FLAGS.gradients_filepath
    limits = _get_n_examples_per_gradient()

    chunks = []
    for path, limit in zip(paths, limits):
        chunk = dn_gradients.load_gradients(path)
        chunks.append(chunk[:limit] if limit >= 0 else chunk)

    combined = torch.cat(chunks, dim=0)
    return combined.to(device)


def load_norms(device: torch.device) -> torch.Tensor:
    """Read the norms stored in each --gradients_filepath file into one tensor.

    Uses the same per-file truncation rule as load_gradients, so the returned
    norms line up row-for-row with the gradients that function returns. Files
    are read via dn_gradients.load_norms, concatenated along dim 0, and the
    result is moved to ``device``.
    """
    sources = FLAGS.gradients_filepath
    per_file_counts = _get_n_examples_per_gradient()

    collected = []
    for source, count in zip(sources, per_file_counts):
        file_norms = dn_gradients.load_norms(source)
        if count < 0:
            # A negative count (conventionally -1) means keep every entry.
            collected.append(file_norms)
        else:
            collected.append(file_norms[:count])

    return torch.cat(collected, dim=0).to(device)


###############################################################################


@torch.no_grad()
def main(_):
    """Load (optionally normalized) gradients, cluster them, and save the result.

    All configuration is read from FLAGS. The clustering function is looked up
    by name (--alg) on the kmeans module. It is called with the concatenated
    gradients plus --n_clusters and --n_iterations. The object it returns is
    saved to --output_filepath through its ``save`` method. Peak CUDA memory is
    reported only when running on a CUDA device.

    Raises:
        app.UsageError: If --output_filepath or --gradients_filepath is
            missing, or --alg does not name a callable in the kmeans module.
    """
    # Validate cheap configuration first, so a bad flag fails before the
    # (potentially large) gradient files are loaded. UsageError is used
    # instead of assert so the checks survive `python -O`.
    if FLAGS.output_filepath is None:
        raise app.UsageError('--output_filepath must be provided.')
    if not FLAGS.gradients_filepath:
        raise app.UsageError('--gradients_filepath must list at least one file.')

    decomp_fn = getattr(kmeans, FLAGS.alg, None)
    if not callable(decomp_fn):
        raise app.UsageError(
            f'--alg={FLAGS.alg!r} is not a callable in the kmeans module.')

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    gradients = load_gradients(device)

    if FLAGS.normalize_gradients:
        norms = load_norms(device)
        # NOTE(review): assumes gradients is (n_examples, dim) and norms is
        # (n_examples,), so each row is divided by its own norm. Confirm this
        # against dn_gradients. A zero norm would yield inf/nan in that row.
        gradients.div_(norms[:, None])

    kmeans_decomp = decomp_fn(
        gradients,
        n_clusters=FLAGS.n_clusters,
        n_iterations=FLAGS.n_iterations,
    )

    # The peak-allocation counter is only meaningful on a CUDA device. On some
    # torch versions, querying it without CUDA raises, which would abort
    # before the decomposition is saved.
    if device.type == "cuda":
        max_gpu_memory_bytes = torch.cuda.max_memory_allocated(device)
        print(f'max_gpu_memory_bytes: {max_gpu_memory_bytes}')

    kmeans_decomp.save(FLAGS.output_filepath)


if __name__ == "__main__":
    # Script entry point: hand control to absl's app.run with main as the
    # callback (main ignores its single positional argument).
    app.run(main)
