"""Converts a file of dense gradients to a file of dense LRM PEFs.

Renames (and reshapes) a few datasets and sets each PEF's Frobenius norm to the
square of the corresponding gradient norm.

NOTE: Only the PEFs, their norms, and a subset of the `data` group attributes are
written to the output. Everything else is not carried over.
"""

import os

from absl import app
from absl import flags

import h5py
import numpy as np

from npeff_torch.util import hdf5_utils

###############################################################################

FLAGS = flags.FLAGS


flags.DEFINE_string(
    'output_filepath', None,
    'Path of the HDF5 file to write the dense LRM PEFs to. '
    'An existing file at this path is overwritten.')

flags.DEFINE_string(
    'gradients_filepath', None,
    'Path of the HDF5 file of dense gradients to convert.')


###############################################################################


def _set_data_group_attributes(f_in: h5py.File, f_out: h5py.File):
    """Populate the attributes of the output file's `data` group.

    Fixed LRM metadata (format name, rank 1, a single class) is written first.
    A few attributes are then copied verbatim from the input `data` group, and
    `random_projection_params` is copied only when the input has it.

    Args:
        f_in: Open input file; its `data` group supplies the attributes to copy.
        f_out: Open output file; its `data` group must already exist.
    """
    src_attrs = f_in['data'].attrs
    dst_attrs = f_out['data'].attrs

    # Each gradient is stored as a rank-1 LRM PEF with a single "class".
    dst_attrs['pef_format'] = "frdn_lrm"
    # NOTE(review): the PEF format version is taken from the gradient format
    # version; confirm this is intended.
    dst_attrs['pef_format_version'] = src_attrs['gradient_format_version']
    dst_attrs['rank'] = 1
    dst_attrs['n_classes'] = 1

    # Copied unchanged, in the same order as before.
    for attr_name in ('parameter_infos', 'n_og_parameters', 'n_parameters'):
        dst_attrs[attr_name] = src_attrs[attr_name]

    optional_name = 'random_projection_params'
    if optional_name in src_attrs:
        dst_attrs[optional_name] = src_attrs[optional_name]


def _set_pefs_and_norms(f_in: h5py.File, f_out: h5py.File):
    """Write the PEF tensor and PEF Frobenius norms derived from the gradients.

    Reads `data/gradients` and `data/norms` from `f_in` and writes
    `data/pefs` and `data/pef_frobenius_norms` to `f_out`.
    """
    gradients = hdf5_utils.load_h5_ds(f_in['data/gradients'])
    gradient_norms = hdf5_utils.load_h5_ds(f_in['data/norms'])

    # Insert a length-1 rank axis:
    # [n_examples, n_parameters] -> [n_examples, 1 (rank), n_parameters].
    pefs = gradients[:, np.newaxis, :]
    hdf5_utils.save_h5_ds(f_out, 'data/pefs', pefs)

    # For the rank-1 PEF built from gradient g, the Frobenius norm is ||g||^2.
    pef_norms = np.square(gradient_norms)
    hdf5_utils.save_h5_ds(f_out, 'data/pef_frobenius_norms', pef_norms)


###############################################################################


def main(_):
    """Convert the dense-gradients file into a dense LRM PEF file.

    Reads `--gradients_filepath` and writes `--output_filepath`. The output is
    opened in h5py mode "w", so an existing file at that path is truncated.

    Raises:
        app.UsageError: if either required flag is unset, or if both flags refer
            to the same file. Opening the output for writing would then
            truncate, or fail to open, the input being read.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    missing = [
        '--' + name
        for name in ('gradients_filepath', 'output_filepath')
        if getattr(FLAGS, name) is None
    ]
    if missing:
        raise app.UsageError('Missing required flag(s): ' + ', '.join(missing))

    if os.path.realpath(FLAGS.gradients_filepath) == os.path.realpath(FLAGS.output_filepath):
        raise app.UsageError('--output_filepath must differ from --gradients_filepath.')

    with h5py.File(FLAGS.gradients_filepath, "r") as f_in, h5py.File(FLAGS.output_filepath, "w") as f_out:
        # The helpers below write into `data/...`, so the group must exist first.
        f_out.create_group('data')
        _set_data_group_attributes(f_in, f_out)
        _set_pefs_and_norms(f_in, f_out)


if __name__ == '__main__':
    # absl parses the command-line flags before invoking main().
    app.run(main)
