"""Performs a LRM-NPEFF decomposition using a PyTorch implementation."""

from absl import app
from absl import flags

import torch

from npeff_torch.peis.fishers.formats import frdn_lrm_pefs
from npeff_torch.decomps.npeff import lrm_npeff_decomposer

###############################################################################

# Global absl flag registry; values are read inside main() after parsing.
FLAGS = flags.FLAGS


# Destination handed to runner.save() at the end of main(); main() requires
# this flag to be set.
flags.DEFINE_string('output_filepath', None, '')


# Comma-separated list of PEF files. Each file is loaded and normalized
# separately and the results are concatenated along dim 0.
flags.DEFINE_list('pef_filepaths', None, '')

# Forwarded unchanged to LrmNpeffRunner under the same keyword names.
flags.DEFINE_integer('n_components', None, '')
flags.DEFINE_integer('n_iters_G_only', None, '')
flags.DEFINE_integer('n_iters_joint', None, '')

# NOTE: 'learning_rate_G' is passed to the runner as `learning_rate_G_G_joint`,
# while 'learning_rate_G_G_only' is passed as `learning_rate_G_G_only`.
flags.DEFINE_float('learning_rate_G', None, '')
flags.DEFINE_float('learning_rate_G_G_only', None, '')

flags.DEFINE_float('mu_eps', 1e-9, "Epsilon for the multiplicative update on W.")

# Forwarded to the runner as `seed` and `log_loss_frequency` respectively.
flags.DEFINE_integer('rand_gen_seed', 48230, '')
flags.DEFINE_integer('log_loss_frequency', 10, '')


# Tolerances for the G-only and joint phases. A negative value (the default)
# is replaced by None before being handed to the runner; presumably None
# disables the corresponding stopping criterion -- confirm in LrmNpeffRunner.
flags.DEFINE_float('abs_tol_G_only', -1, '')
flags.DEFINE_float('rel_tol_G_only', -1, '')
flags.DEFINE_float('abs_tol_joint', -1, '')
flags.DEFINE_float('rel_tol_joint', -1, '')

###############################################################################


def _load_and_normalize_pef_file(filepath: str):
    """Loads the PEFs stored in one file and normalizes them in place.

    Args:
        filepath: Path of the PEF file; it is read once for the PEFs and once
            for their stored Frobenius norms.

    Returns:
        The loaded PEFs object after `normalize_pefs_in_place` has been
        applied to it with the norms read from the same file.
    """
    loaded = frdn_lrm_pefs.load_pefs(filepath)
    # The normalization mutates `loaded`; its return value (if any) is unused.
    lrm_npeff_decomposer.normalize_pefs_in_place(
        loaded,
        frdn_lrm_pefs.load_pef_frobenius_norms(filepath),
    )
    return loaded


def _load_and_normalize_pefs(filepaths):
    """Loads and normalizes every PEF file, then stacks the results.

    Args:
        filepaths: Iterable of PEF file paths, processed in order.

    Returns:
        A single tensor formed by concatenating the per-file results along
        dim 0, in the same order as `filepaths`.
    """
    per_file = []
    for path in filepaths:
        per_file.append(_load_and_normalize_pef_file(path))
    return torch.cat(per_file, dim=0)


###############################################################################


def _nonnegative_or_none(value):
    """Returns `value` if it is a non-negative number, otherwise None.

    Used for the tolerance flags, where an unset (None) or negative value is
    forwarded to the runner as None.
    """
    if value is not None and value >= 0:
        return value
    return None


@torch.no_grad()
def main(_):
    """Runs the LRM-NPEFF decomposition configured by the command-line flags.

    Loads and normalizes the PEFs listed in --pef_filepaths, moves them to the
    GPU when one is available (CPU otherwise), runs `LrmNpeffRunner`, prints
    the peak CUDA memory allocation, and saves the result to
    --output_filepath.

    Args:
        _: Positional argv passed by `app.run`; unused.

    Raises:
        app.UsageError: If --output_filepath is not set or --pef_filepaths is
            missing or empty.
    """
    # Validate up front with real exceptions: `assert` is stripped under -O,
    # and a missing file list would otherwise only fail while loading.
    if FLAGS.output_filepath is None:
        raise app.UsageError('--output_filepath is required.')
    if not FLAGS.pef_filepaths:
        raise app.UsageError('--pef_filepaths must list at least one file.')

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    pefs = _load_and_normalize_pefs(FLAGS.pef_filepaths)
    print('Loaded pefs')
    pefs = pefs.to(device)

    runner = lrm_npeff_decomposer.LrmNpeffRunner(
        pefs,
        n_components=FLAGS.n_components,
        seed=FLAGS.rand_gen_seed,
        mu_eps=FLAGS.mu_eps,
        learning_rate_G_G_only=FLAGS.learning_rate_G_G_only,
        # The joint-phase learning rate comes from --learning_rate_G.
        learning_rate_G_G_joint=FLAGS.learning_rate_G,
        n_iters_G_only=FLAGS.n_iters_G_only,
        n_iters_joint=FLAGS.n_iters_joint,
        log_loss_frequency=FLAGS.log_loss_frequency,
        # Unset or negative tolerances are forwarded as None.
        abs_tol_G_only=_nonnegative_or_none(FLAGS.abs_tol_G_only),
        rel_tol_G_only=_nonnegative_or_none(FLAGS.rel_tol_G_only),
        abs_tol_joint=_nonnegative_or_none(FLAGS.abs_tol_joint),
        rel_tol_joint=_nonnegative_or_none(FLAGS.rel_tol_joint),
    )
    runner.run()

    # Peak bytes allocated by the CUDA caching allocator during this process.
    max_gpu_memory_bytes = torch.cuda.memory.max_memory_allocated()
    print(f'max_gpu_memory_bytes: {max_gpu_memory_bytes}')

    runner.save(FLAGS.output_filepath)


if __name__ == "__main__":
    app.run(main)
