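R"""Given an LRM-NPEFF decomposition, rewrites each component as its rejection from a subset of the other components.

For each component, the projection onto every other component whose absolute cosine
similarity with it is at most --max_similarity is subtracted, and the result is
renormalized to unit norm. Projections are subtracted one component at a time, so the
result is an exact orthogonal rejection only when those components are mutually orthogonal.
"""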
import os

from absl import app
from absl import flags

import numpy as np
from tqdm import tqdm

from em.tools.nmf import lrm_npeff

FLAGS = flags.FLAGS

flags.DEFINE_string(
    "decomposition_filepath",
    None,
    "Path to the LRM-NPEFF decomposition to load (its G matrix is read).",
)
flags.DEFINE_string("output_filepath", None, "Path to h5 file to write output to.")

flags.DEFINE_float(
    "max_similarity",
    None,
    "Threshold on the absolute cosine similarity (|dot product| of the "
    "normalized components). Each component is rejected only from the other "
    "components whose similarity with it is at most this value; components "
    "that are more similar than this are left out of the rejection.",
)

flags.mark_flags_as_required(["decomposition_filepath", "output_filepath", "max_similarity"])


def make_rejected_perturbation(G, G2, cos_sims, component_index: int, max_sim=None):
    """Write the rejected, unit-normalized version of one component into ``G2``.

    Starting from a copy of ``G[component_index]``, subtracts the projection onto
    every other row ``G[i]`` whose entry ``cos_sims[component_index, i]`` is at
    most ``max_sim``. Rows with a larger similarity are skipped. The result is
    rescaled to unit L2 norm and stored in ``G2[component_index]``.

    NOTE(review): projections are subtracted one row at a time, using each
    ``G[i]`` directly as the projection direction. This gives an exact
    orthogonal rejection from the selected rows only when those rows are
    mutually orthonormal. Otherwise the result is approximate and depends on
    row order. The ``g.dot(G[i]) * G[i]`` formula also assumes every ``G[i]``
    has unit norm. main() calls ``normalize_components_to_unit_norm`` first,
    presumably for this reason; TODO confirm for other callers.

    Args:
      G: 2-D array whose rows are the components. It is not modified.
      G2: Output array with the same shape as ``G``. Row ``component_index``
        is overwritten in place.
      cos_sims: Square matrix of pairwise similarities between rows of ``G``.
        main() passes ``abs(G @ G.T)``.
      component_index: Index of the row of ``G`` to rewrite.
      max_sim: Similarity threshold. When ``None`` (the default),
        ``FLAGS.max_similarity`` is used, so existing callers are unaffected.
    """
    if max_sim is None:
        max_sim = FLAGS.max_similarity
    g_main = np.copy(G[component_index])

    for i, g_other in enumerate(G):
        # Never reject a component from itself; skip rows that are too similar.
        if i == component_index or cos_sims[component_index, i] > max_sim:
            continue
        g_main -= g_main.dot(g_other) * g_other

    norm = np.sqrt(np.sum(g_main**2))
    # If the rejection removes the component entirely, dividing by a zero norm
    # would fill the row with NaN and silently corrupt the saved output.
    # Keep the zero vector instead.
    if norm > 0:
        g_main /= norm

    G2[component_index] = g_main


def main(_):
    """Load a decomposition, rewrite every component via rejection, and save it.

    Loads the decomposition at --decomposition_filepath (including ``G``) and
    calls its ``normalize_components_to_unit_norm`` method. Each row of ``G`` is
    then rebuilt with ``make_rejected_perturbation``, and the decomposition
    with the rewritten ``G`` is saved to --output_filepath.
    """
    decomposition = lrm_npeff.LrmNpeffDecomposition.load(
        FLAGS.decomposition_filepath, read_G=True
    )
    decomposition.normalize_components_to_unit_norm()

    components = decomposition.G
    n_components = components.shape[0]
    # Destination buffer; each row is filled by make_rejected_perturbation.
    rejected = np.zeros_like(components)

    # Pairwise |dot products| of the components. These are cosine similarities
    # provided the normalization call above leaves rows at unit norm
    # (assumed; TODO confirm).
    abs_similarities = np.abs(components @ components.T)

    for idx in tqdm(range(n_components)):
        make_rejected_perturbation(components, rejected, abs_similarities, idx)

    decomposition.G = rejected
    decomposition.save(FLAGS.output_filepath)


if __name__ == "__main__":
    app.run(main)
