"""Copies over the PEFs with reduced rank and recomputes PEF norms.

NOTE: Assumes that the PEFs have already been passed through an SVD, so that
their rows are orthogonal and sorted in descending order of importance.

NOTE: This could probably be done better by computing the PEF norms on the fly
during the decomposition.

NOTE: Only the PEFs and their norms are written, and the attributes of the
`data` group are copied over. Nothing else from the input file is carried over.
"""
import os

from absl import app
from absl import flags

import h5py
import numpy as np
import torch

from npeff_torch.util import hdf5_utils

###############################################################################

FLAGS = flags.FLAGS


flags.DEFINE_string(
    'output_filepath', None,
    'Path of the HDF5 file to write the truncated PEFs, their Frobenius norms '
    'and the copied data-group attributes to. Must differ from --pefs_filepath.')

flags.DEFINE_string(
    'pefs_filepath', None,
    'Path of the HDF5 file whose data/pefs dataset holds the source PEFs.')

# lower_bound rejects non-positive ranks at flag-parse time; a negative value
# would otherwise silently slice rows off the end of each PEF.
flags.DEFINE_integer(
    'output_rank', None,
    'Number of leading PEF rows to keep per example (at least 1 and at most '
    'the rank stored in --pefs_filepath).',
    lower_bound=1)


###############################################################################

def _read_in_pefs(
    filepath: str,
    rank: int,
    device: torch.device,
) -> torch.Tensor:
    """Load the leading ``rank`` rows of every PEF stored in ``filepath``.

    Reads the ``data/pefs`` dataset and keeps ``[:, :rank, :]``. Per the module
    NOTE, the rows are assumed to be sorted by descending importance, so this
    keeps the most important ones.

    Args:
        filepath: path of the source HDF5 file; ``~`` is expanded.
        rank: number of rows to keep per example. Must satisfy
            ``1 <= rank <= data/pefs.shape[1]``.
        device: device the returned tensor is moved to.

    Returns:
        Tensor of shape ``[examples, rank, n_parameters]`` on ``device``.

    Raises:
        ValueError: if ``rank`` is ``None`` or outside ``[1, shape[1]]``.
    """
    with h5py.File(os.path.expanduser(filepath), "r") as f:
        ds = f['data/pefs']
        max_rank = ds.shape[1]
        # Explicit check rather than `assert` (stripped under -O). A negative
        # rank would otherwise slice rows off the end of each PEF, and None
        # would fail with an opaque TypeError.
        if rank is None or not 1 <= rank <= max_rank:
            raise ValueError(
                f"rank must be in [1, {max_rank}] for {filepath!r}, got {rank!r}.")
        return torch.from_numpy(ds[:, :rank, :]).to(device)


def _compute_lrm_pef_frobenius_norms(pefs: torch.Tensor) -> torch.Tensor:
    # pefs.shape = [examples, rank, n_parameters]
    AtA = torch.einsum('ecj,ekj->eck', pefs, pefs)
    sq_norm = torch.einsum('eck,eck->e', AtA, AtA)
    return torch.sqrt(sq_norm)


def _copy_attributes(f: h5py.File, filepath: str):
    """Copy the ``data`` group's metadata attributes from ``filepath`` into ``f``.

    A fixed set of attributes must exist on the source ``data`` group; a
    missing one raises ``KeyError``. ``random_projection_params`` is copied only
    when the source has it. ``f`` must already contain a ``data`` group.
    """
    required_names = (
        'pef_format',
        'pef_format_version',
        'parameter_infos',
        'n_og_parameters',
        'rank',
        'n_classes',
        'n_parameters',
    )
    optional_name = 'random_projection_params'

    with h5py.File(os.path.expanduser(filepath), "r") as src_file:
        src_attrs = src_file['data'].attrs

        for attr_name in required_names:
            f['data'].attrs[attr_name] = src_attrs[attr_name]

        if optional_name in src_attrs:
            f['data'].attrs[optional_name] = src_attrs[optional_name]


###############################################################################


@torch.no_grad()
def main(_):
    """Truncate the PEFs to ``--output_rank`` rows and write a new PEF file.

    Reads ``data/pefs`` from ``--pefs_filepath`` and keeps the leading
    ``--output_rank`` rows of each example. It then recomputes the Frobenius
    norms of the truncated PEFs and writes both to ``--output_filepath``,
    together with the source ``data`` group's attributes.

    Raises:
        app.UsageError: if a required flag is missing, or if the output path
            resolves to the same file as the input path.
    """
    for flag_name in ('pefs_filepath', 'output_filepath', 'output_rank'):
        if getattr(FLAGS, flag_name) is None:
            raise app.UsageError(f'--{flag_name} must be set.')

    in_path = os.path.realpath(os.path.expanduser(FLAGS.pefs_filepath))
    out_path = os.path.realpath(os.path.expanduser(FLAGS.output_filepath))
    # The output is opened with mode "w", which truncates it. If it were the
    # input file, its attributes would be gone before they could be copied.
    if in_path == out_path:
        raise app.UsageError(
            '--output_filepath must refer to a different file than --pefs_filepath.')

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    pefs = _read_in_pefs(FLAGS.pefs_filepath, rank=FLAGS.output_rank, device=device)
    pef_frobenius_norms = _compute_lrm_pef_frobenius_norms(pefs)

    with h5py.File(os.path.expanduser(FLAGS.output_filepath), "w") as f:
        # No .detach() needed: the whole function runs under torch.no_grad().
        hdf5_utils.save_h5_ds(f, 'data/pefs', pefs.cpu().numpy())
        hdf5_utils.save_h5_ds(f, 'data/pef_frobenius_norms', pef_frobenius_norms.cpu().numpy())
        _copy_attributes(f, FLAGS.pefs_filepath)
        # The copied 'rank' attribute describes the source PEFs. The PEFs
        # written here were truncated along axis 1 (the rank axis), so record
        # their actual rank instead.
        f['data'].attrs['rank'] = pefs.shape[1]


if __name__ == "__main__":
    # absl parses the command-line flags, then invokes main() with leftover argv.
    app.run(main=main)
