R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


python3 -i local_scripts/m_npeff/m_npeff_factorizer_test001.py

"""
import dataclasses
import random

import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
import tensorflow as tf

from em.tools.m_npeff import m_npeff1

###############################################################################


@dataclasses.dataclass()
class Inputs:
    """Synthetic factorization problem together with its ground truth.

    Instances are built by ``make_inputs``; the shapes noted below are the
    ones that function produces.
    """

    # Ground-truth basis; one basis vector per row, shape (r, m).
    G: np.ndarray
    # Per-example indices into the rows of G, shape (n, c), int32.
    basis_inds: np.ndarray
    # Observed data: A[i, j] is the row G[basis_inds[i, j]], shape (n, c, m).
    A: np.ndarray


def make_inputs(n: int, m: int, r: int, c: int) -> "Inputs":
    """Build a synthetic factorization problem with a known ground-truth basis.

    Args:
        n: Number of examples.
        m: Dimensionality of each basis vector.
        r: Number of basis vectors (rows of ``G``).
        c: Number of basis-vector indices drawn per example.

    Returns:
        An ``Inputs`` where ``G`` has shape ``(r, m)``, ``basis_inds`` has
        shape ``(n, c)`` (int32) and ``A`` has shape ``(n, c, m)`` (float32),
        with ``A[i, j] == G[basis_inds[i, j]]``. The shapes hold even when
        ``n == 0`` or ``c == 0``.

    Raises:
        ValueError: from ``random.randrange`` if ``r == 0`` while ``n * c > 0``.
    """
    G = tf.random.normal([r, m]).numpy()

    # Indices are drawn with replacement, so an example may repeat a basis
    # vector. Drawing n * c values in row-major order and reshaping consumes
    # the `random` stream exactly like the nested per-row draw. The reshape
    # also keeps shape (n, c) when n == 0, where np.array([]) would give (0,).
    basis_inds = np.array(
        [random.randrange(r) for _ in range(n * c)], dtype=np.int32
    ).reshape(n, c)

    # Fancy indexing gathers every row in one shot: A[i, j] = G[basis_inds[i, j]].
    A = G[basis_inds].astype(np.float32)

    return Inputs(G=G, basis_inds=basis_inds, A=A)

###############################################################################


def cosine_sim(A, B):
    """Return the pairwise cosine similarity between the rows of ``A`` and ``B``.

    Args:
        A: 2-D array of shape ``(p, k)``.
        B: 2-D array of shape ``(q, k)``.

    Returns:
        Array of shape ``(p, q)`` whose ``[i, j]`` entry is the cosine
        similarity of ``A[i]`` and ``B[j]``. An all-zero row has similarity 0
        with every other row, instead of producing NaN from 0/0.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    A_norm = np.linalg.norm(A, axis=-1, keepdims=True)
    B_norm = np.linalg.norm(B, axis=-1, keepdims=True)
    # Divide zero-norm rows by 1 so they stay all-zero rather than becoming NaN.
    A_unit = A / np.where(A_norm == 0, 1, A_norm)
    B_unit = B / np.where(B_norm == 0, 1, B_norm)
    return np.einsum('ik,jk->ij', A_unit, B_unit)


###############################################################################


# ---- Configuration ----------------------------------------------------------
N = 32  # number of synthetic examples
M = 128  # dimensionality of each basis vector
R = 12  # number of ground-truth basis vectors; also the factorization rank
C = 2  # basis-vector indices drawn per example
N_ITERS = 1000  # optimization steps passed to factorizer.fit

# ---- Data -------------------------------------------------------------------
inputs = make_inputs(n=N, m=M, r=R, c=C)

# ---- Fit --------------------------------------------------------------------
factorizer = m_npeff1.Factorizer(
    A=tf.constant(inputs.A),
    rank=R,
    lr_G=1e-3,
    eps=1e-7,
)
factorizer.fit(N_ITERS)

# ---- Inspect ----------------------------------------------------------------
# Rows index the true basis (inputs.G); columns index rows of factorizer.G
# (assumes factorizer.G is (rank, M) like inputs.G -- TODO confirm). The
# learned basis is presumably only identified up to permutation.
sim_matrix = cosine_sim(inputs.G, factorizer.G.numpy())
plt.imshow(sim_matrix)
plt.show()

# Learned W; its exact meaning is defined inside m_npeff1 (not visible here).
plt.imshow(factorizer.W.numpy())
plt.show()
