"""Toy implementation of the factorization."""
import numpy as np
import tensorflow as tf
from tqdm import tqdm


class Factorizer:
    """Alternating optimizer for a low-rank factorization built from ``A``.

    Writing ``A_i`` for the ``[n_classes, n_params]`` slice of ``A`` belonging
    to example ``i`` and ``g_r`` for the ``r``-th row of ``G``, the objective
    evaluated by ``_compute_loss`` is::

        sum_i || A_i^T A_i  -  sum_r W[i, r] * g_r g_r^T ||_F^2

    ``W`` (shape ``[n_examples, rank]``) is updated with a multiplicative rule
    and ``G`` (shape ``[rank, n_params]``) with plain gradient descent.
    """

    def __init__(self, A: tf.Tensor, rank: int, lr_G: float, eps: float, loss_frequency: int = 25):
        """Store the data and randomly initialise the factors ``W`` and ``G``.

        Args:
            A: Tensor with shape ``[n_examples, n_classes, n_params]``.
            rank: Number of rows of ``G`` / columns of ``W``.
            lr_G: Step size of the gradient-descent update applied to ``G``.
            eps: Constant added to the denominator of the ``W`` update.
            loss_frequency: ``fit`` evaluates the loss every this many
                iterations. A falsy value (``0`` / ``None``) disables it.
        """
        # A.shape = [n_examples, n_classes, n_params]
        self.A = A
        self.rank = rank
        self.lr_G = lr_G
        self.eps = eps
        self.loss_frequency = loss_frequency
        self.n_examples, self.n_classes, self.n_params = A.shape

        # Parameters to learn. They are marked non-trainable because the
        # updates are applied by hand through `assign` / `assign_sub`.
        # W is drawn from U[0, 1), so it starts out entrywise non-negative.
        self.W = tf.Variable(tf.random.uniform([self.n_examples, self.rank], dtype=tf.float32),
                             trainable=False, name='W')
        # TODO: More theoretically principled computation of the scaling initialization factor.
        # We would want each reconstructed PEF at initialization to have roughly unit Frobenius norm.
        g_factor = tf.sqrt(self.rank * self.n_params / 2)
        self.G = tf.Variable(tf.random.normal([self.rank, self.n_params], dtype=tf.float32) / float(g_factor),
                             trainable=False, name='G')

    @tf.function
    def _W_update_step(self):
        """Apply one multiplicative update to ``W`` with ``G`` held fixed.

        Performs ``W <- W * XH / (W @ HH + eps)`` elementwise, where
        ``XH[i, r] = ||A_i g_r||^2`` and ``HH[r, s] = (g_r . g_s)^2``.
        """
        # Compute the numerator: XH[i, j] = sum_c (A_i g_j)[c]^2.
        tmp = tf.einsum('ick,jk->ijc', self.A, self.G)
        XH = tf.einsum('ijc,ijc->ij', tmp, tmp)

        # Compute the denominator: W @ HH with HH the elementwise square of
        # the Gram matrix G G^T.
        GG = tf.einsum('ij,kj->ik', self.G, self.G)
        HH = tf.square(GG)
        WHH = tf.einsum('ij,jk->ik', self.W, HH)

        self.W.assign(self.W * XH / (WHH + self.eps))

    @tf.function
    def _G_update_step(self):
        """Apply one gradient-descent step of size ``lr_G`` to ``G``.

        The gradient is that of the objective returned by ``_compute_loss``
        with respect to ``G`` (``W`` held fixed).
        """
        # Gradient of sum_{r,s} (W^T W)[r, s] * (g_r . g_s)^2:
        #   4 * ((W^T W) * (G G^T)) @ G      (`*` is elementwise)
        WW = tf.einsum('ji,jk->ik', self.W, self.W)
        GG = tf.einsum('ij,kj->ik', self.G, self.G)
        term1 = 4 * tf.einsum('ij,ij,jk->ik', WW, GG, self.G)

        # Gradient of -2 * sum_{i,r} W[i, r] * ||A_i g_r||^2:
        #   row r is -4 * sum_i W[i, r] * A_i^T (A_i g_r)
        # GA[r, i, c] = (A_i g_r)[c]
        GA = tf.einsum('il,jkl->ijk', self.G, self.A)
        term2 = -4 * tf.einsum('ji,ijk,jkl->il', self.W, GA, self.A)

        grad = term1 + term2
        self.G.assign_sub(self.lr_G * grad)

    @tf.function
    def _compute_loss(self):
        """Return the scalar objective for the current ``W`` and ``G``.

        The squared Frobenius error is expanded into three terms so that the
        ``[n_params, n_params]`` matrices are never materialised:
        ``tr_XX - 2 * tr_WHX + tr_WW_HH``.
        """
        # Model-model term: sum_{r,s} (W^T W)[r, s] * (g_r . g_s)^2.
        WW = tf.einsum('ji,jk->ik', self.W, self.W)
        GG = tf.einsum('ij,kj->ik', self.G, self.G)
        HH = tf.square(GG)
        tr_WW_HH = tf.einsum('ij,ij->', WW, HH)

        # Cross term: sum_{i,r} W[i, r] * ||A_i g_r||^2.
        tmp = tf.einsum('ick,jk->ijc', self.A, self.G)
        XH = tf.einsum('ijc,ijc->ij', tmp, tmp)
        tr_WHX = tf.einsum('ij,ij->', self.W, XH)

        # Data-data term: sum_i ||A_i A_i^T||_F^2, which equals
        # sum_i ||A_i^T A_i||_F^2 but only needs [n_classes, n_classes] blocks.
        tmp = tf.einsum('icl,ikl->ick', self.A, self.A)
        tr_XX = tf.einsum('icl,icl->', tmp, tmp)

        return tr_XX - 2 * tr_WHX + tr_WW_HH

    def fit(self, n_iters: int, *, update_G: bool = True, update_W: bool = True) -> list:
        """Run ``n_iters`` optimisation iterations.

        Each iteration performs the ``G`` step (if ``update_G``) followed by
        the ``W`` step (if ``update_W``). Every ``loss_frequency`` iterations
        the loss is evaluated, printed and recorded.

        Args:
            n_iters: Number of iterations to run.
            update_G: Whether to apply the gradient step on ``G``.
            update_W: Whether to apply the multiplicative step on ``W``.

        Returns:
            The recorded loss values (NumPy scalars), in evaluation order.
        """
        # NOTE: It is recommended to do at least a few steps updating G only
        # before performing alternating updates on both G and W. Otherwise,
        # it appears that the W-update step will mess up the convergence,
        # possibly because it drives W (close) to the zero matrix.
        losses = []
        log_every = self.loss_frequency

        for step in tqdm(range(n_iters)):
            if update_G:
                self._G_update_step()

            if update_W:
                self._W_update_step()

            # A falsy frequency disables logging instead of dividing by zero.
            if log_every and (step + 1) % log_every == 0:
                loss = self._compute_loss().numpy()
                print(loss)
                # Record the value so the returned history is not empty.
                losses.append(loss)

        return losses
