# verify_orthonormality_of_decomp001.py
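R"""Check that the rows of ``G`` in a saved LRM-NPEFF decomposition are orthonormal.

Prints two numbers, both of which should be close to 0:
  1. the largest |cosine similarity| between two distinct rows of G;
  2. the largest deviation of any row's L2 norm from 1.

Usage (from the repository root):
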
cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=1 python local_scripts/m_npeff/stiefel/verify_orthonormality_of_decomp001.py

python3 local_scripts/m_npeff/stiefel/verify_orthonormality_of_decomp001.py

"""
import os
import dataclasses
from importlib import reload
import random

import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
import tensorflow as tf
from em.tools.nmf import lrm_npeff


###############################################################################
# Location of the saved decomposition to check. Edit these to point at a
# different file; the commented-out lines appear to be alternative locations
# kept for reference.

# NMF_PATH = "/tmp/blah.h5"

# NMF_DIR = "/fruitbasket/users/m/project_data/extract_merge1/m_npeff1/per_example_fishers/"
NMF_DIR = "/playpen/users/m/project_data/m_npeff1/per_example_fishers"
NMF_NAME = "test_stiefel_m_npeff003.h5"
NMF_PATH = os.path.join(NMF_DIR, NMF_NAME)
###############################################################################


# Load the decomposition from disk. read_G=True asks the loader to also read
# the G matrix (presumably it is skipped by default -- confirm in lrm_npeff).
nmf = lrm_npeff.LrmNpeffDecomposition.load(NMF_PATH, read_G=True)

# Matrix whose rows the orthogonality/unit-norm checks below examine.
# Assumed to be 2-D (n_rows x dim) -- TODO confirm against LrmNpeffDecomposition.
G = nmf.G


def test_orthogonality():
    """Return the largest absolute cosine similarity between distinct rows of ``G``.

    Each row of the module-level ``G`` is scaled to unit L2 norm and the Gram
    matrix of the normalized rows is formed. Its diagonal (each row against
    itself, always ~1) is zeroed out before taking the maximum. A result near 0
    means the rows are mutually orthogonal. Row norms themselves are ignored
    here; ``test_unitness`` checks those.

    Returns:
        The maximum off-diagonal ``|cos(g_i, g_j)|`` as a NumPy scalar. With a
        single row there are no off-diagonal pairs, and the result is 0.
    """
    # Normalize rows so the Gram matrix entries are cosine similarities.
    norm_G = G / tf.sqrt(tf.reduce_sum(tf.square(G), axis=-1, keepdims=True))
    GG = tf.einsum('ij,kj->ik', norm_G, norm_G)
    abs_GG = tf.abs(GG)
    # Zero the diagonal with a tensor derived from abs_GG itself. This way the
    # mask always matches GG's dtype (a default float32 `tf.eye` would fail
    # for float64 G), and a lone row yields 0 instead of a huge negative
    # sentinel value.
    off_diag = tf.linalg.set_diag(
        abs_GG, tf.zeros_like(tf.linalg.diag_part(abs_GG)))
    return tf.reduce_max(off_diag).numpy()


def test_unitness():
    """Return how far the row norms of ``G`` stray from 1.

    Takes the Euclidean (L2) norm of every row of the module-level ``G`` and
    reports the worst-case absolute deviation from 1 as a NumPy scalar. A
    value near 0 means every row has unit length.
    """
    row_norms = tf.norm(G, axis=-1)
    deviations = tf.abs(row_norms - 1)
    worst = tf.reduce_max(deviations)
    return worst.numpy()


# Report both diagnostics in order; values near 0 indicate orthonormal rows.
for check in (test_orthogonality, test_unitness):
    print(check())
