"""Given a pair of NPEFF decompositions, computes the max cosine similarity between their component coefficients.

The npeff coefficients should have been computed on the same set of PEFs for each decomposition.
Useful for seeing if components in one decomposition roughly have a matching component in another decomposition.
"""
import os

from absl import app
from absl import flags

import h5py
import torch

from npeff_torch.util import hdf5_utils


###############################################################################
FLAGS = flags.FLAGS

# Both flags name HDF5 files; the script reads the coefficient matrix stored
# under the 'data/W' dataset of each file. A leading `~` in the path is expanded.
flags.DEFINE_string(
    'npeff_filepath_1',
    None,
    "Path to the HDF5 file of the first NPEFF decomposition (its 'data/W' coefficients are read).",
)
flags.DEFINE_string(
    'npeff_filepath_2',
    None,
    "Path to the HDF5 file of the second NPEFF decomposition (its 'data/W' coefficients are read).",
)

###############################################################################


def _read_and_normalize_coeffs(filepath: str, device: torch.device) -> torch.Tensor:
    """Load the 'data/W' coefficients from an HDF5 file and l2-normalize them along dim 0.

    Args:
        filepath: Path to the HDF5 file; a leading `~` is expanded.
        device: Device the loaded coefficients are moved to before normalizing.

    Returns:
        Tensor of shape [n_examples, n_components] in which the [n_examples]
        vector belonging to each component has unit l2 norm.
    """
    resolved_path = os.path.expanduser(filepath)
    with h5py.File(resolved_path, "r") as h5_file:
        # The result is handed to torch.from_numpy below, so load_h5_ds is
        # expected to return a NumPy array.
        raw_coeffs = hdf5_utils.load_h5_ds(h5_file['data/W'])

    coeffs = torch.from_numpy(raw_coeffs).to(device)

    # dim=0 runs over examples, so each component's column is scaled on its own.
    return torch.nn.functional.normalize(coeffs, dim=0)


###############################################################################


@torch.no_grad()
def main(_):
    """Print max-cosine-similarity statistics between two sets of component coefficients.

    Loads the normalized coefficient matrices named by the `npeff_filepath_1`
    and `npeff_filepath_2` flags, computes the cosine similarity between every
    pair of components (one from each file), and for each component keeps its
    best match in the other file. The median and mean of those best-match
    values are printed for each direction, followed by the average of the two
    means.

    Raises:
        ValueError: If the two coefficient matrices differ in their number of
            examples (dim 0), in which case the similarities are undefined.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    W1 = _read_and_normalize_coeffs(FLAGS.npeff_filepath_1, device)
    W2 = _read_and_normalize_coeffs(FLAGS.npeff_filepath_2, device)

    # The contraction below runs over the example axis, so both matrices must
    # have been computed on the same number of examples. Fail with a clear
    # message instead of an opaque einsum shape error.
    if W1.shape[0] != W2.shape[0]:
        raise ValueError(
            'Coefficient matrices must have the same number of examples, '
            f'got {W1.shape[0]} and {W2.shape[0]}.'
        )

    # Columns are already unit-norm, so the dot products are cosine similarities.
    # cs[c, k] = similarity of component c in file 1 to component k in file 2.
    cs = torch.einsum('ec,ek->ck', W1, W2)

    # Best match in the other decomposition for every component of each file.
    vals1 = torch.max(cs, dim=1).values
    vals2 = torch.max(cs, dim=0).values

    # NOTE: for an even number of elements torch.median returns the lower of
    # the two middle values rather than their average.
    median1 = torch.median(vals1).item()
    mean1 = torch.mean(vals1).item()

    median2 = torch.median(vals2).item()
    mean2 = torch.mean(vals2).item()

    print(f'1: {median1} [median], {mean1} [mean]')
    print(f'2: {median2} [median], {mean2} [mean]')

    print(f'overall mean: {(mean1 + mean2) / 2.0}')


if __name__ == "__main__":
    # Both paths default to None, which would otherwise only fail later with a
    # TypeError when the path is expanded. Requiring them here makes absl
    # report a usage error at flag-parsing time. This is done inside the guard
    # so that merely importing the module does not impose the requirement.
    flags.mark_flags_as_required(['npeff_filepath_1', 'npeff_filepath_2'])
    app.run(main)
