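"""Print the mean normalized eigenvalue spectrum of the PEFs in --pef_filepaths.

For every example, the squared L2 norm of each of its `rank` PEF vectors is
used as an eigenvalue (this relies on the PEFs having been run through SVD).
Each example's eigenvalues are normalized to sum to 1, then averaged over all
examples from all files, and the resulting length-`rank` list is printed.
"""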
from typing import List

from absl import app
from absl import flags

import torch
from npeff_torch.peis.fishers.formats import frdn_lrm_pefs


###############################################################################

# Comma-separated list of PEF file paths, e.g. --pef_filepaths=a.h5,b.h5.
flags.DEFINE_list(
    name='pef_filepaths',
    default=None,
    help='Assumed to have been run through SVD.',
)

# Module-level handle to absl's global flag registry; values are available
# after app.run has parsed the command line.
FLAGS = flags.FLAGS

###############################################################################


def _load_eigenvalues_for_file(filepath: str, device: torch.device) -> torch.Tensor:
    """Compute per-example eigenvalues from a single PEF file.

    Each eigenvalue is the squared L2 norm of one of an example's `rank` PEF
    vectors, i.e. the sum of squares over the last (d_proj) axis.

    Args:
        filepath: path of the file to read with `frdn_lrm_pefs.load_pefs`.
        device: device the PEFs are moved to before the reduction.

    Returns:
        Tensor of shape [examples, rank].
    """
    loaded_pefs = frdn_lrm_pefs.load_pefs(filepath)
    # Layout: [examples, rank, d_proj] (the einsum below requires 3 dims).
    pefs_on_device = loaded_pefs.to(device)
    # Contract d_proj away: dot product of each PEF vector with itself.
    return torch.einsum('erp,erp->er', pefs_on_device, pefs_on_device)


def _load_eigenvalues(filepaths: List[str], device: torch.device) -> torch.Tensor:
    eigenvalues = [
        _load_eigenvalues_for_file(filepath, device)
        for filepath in filepaths
    ]
    return torch.cat(eigenvalues, dim=0)


###############################################################################


@torch.no_grad()
def main(_):
    """Print the mean per-example normalized eigenvalue spectrum.

    Loads eigenvalues of shape [examples, rank] from every file in
    --pef_filepaths, normalizes each example's eigenvalues to sum to 1, and
    prints their mean over examples as a Python list of length `rank`.

    Raises:
        app.UsageError: if --pef_filepaths was not provided.
    """
    if not FLAGS.pef_filepaths:
        # The flag defaults to None; fail with a usage message instead of an
        # opaque TypeError raised while iterating the paths.
        raise app.UsageError('--pef_filepaths must be provided.')

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # shape = [examples, rank]
    eigenvalues = _load_eigenvalues(FLAGS.pef_filepaths, device)

    # Normalize each example's eigenvalues to sum to 1 over the rank axis.
    # NOTE(review): an example whose eigenvalues are all zero yields NaNs
    # here, which then propagate into the mean below.
    eigenvalues /= torch.sum(eigenvalues, dim=-1, keepdim=True)

    # shape = [rank]
    mean_eigen_values = torch.mean(eigenvalues, dim=0)
    # Tensor.tolist() yields plain Python floats directly; skipping the numpy
    # hop also works for dtypes numpy lacks (e.g. bfloat16).
    print(mean_eigen_values.cpu().tolist())


if __name__ == "__main__":
    # app.run parses command-line flags (e.g. --pef_filepaths), then calls main.
    app.run(main)
