"""Toy implementation of the factorization."""
from typing import Optional
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tqdm import tqdm


class StiefelFactorizer:
    """Low-rank factorization whose factors g_k are kept mutually orthonormal.

    For every example ``i`` the matrix ``X_i = A_i^T A_i`` (``A_i = A[i]``, of
    shape ``[n_classes, n_params]``) is approximated as

        X_i ~= sum_k W[i, k] * g_k g_k^T,

    where ``g_k`` is row ``k`` of ``G`` (shape ``[rank, n_params]``). The
    objective is the squared Frobenius error summed over all examples.

    * ``W`` (shape ``[n_examples, rank]``) is updated with a multiplicative
      rule; starting from a nonnegative initialization it stays nonnegative.
    * ``G`` is initialized with orthonormal rows and updated with a
      Cayley-transform step on the Stiefel manifold, which keeps the rows
      orthonormal up to floating-point drift (see ``test_orthogonality``).
    """

    def __init__(
        self,
        A: tf.Tensor,
        rank: int,
        lr_G: float,
        eps: float,
        rcond: Optional[float] = None,
        loss_frequency: int = 25,
    ):
        """Store the problem data and randomly initialize ``W`` and ``G``.

        Args:
            A: Tensor of shape ``[n_examples, n_classes, n_params]``. It is cast
                to float32 so that it matches the dtype of ``W`` and ``G``
                (``tf.einsum`` requires identical dtypes).
            rank: Number of factors ``g_k`` (rows of ``G``).
            lr_G: Step size ``tau`` of the Stiefel update for ``G``.
            eps: Added to the denominator of the multiplicative ``W`` update
                to avoid division by zero.
            rcond: Not used by the current solve-based ``_G_update_step``;
                kept for backward compatibility of the constructor signature.
            loss_frequency: ``fit`` computes and records the loss every this
                many steps. Values ``<= 0`` disable loss computation.

        NOTE: ``lr_G`` and ``eps`` are read inside ``tf.function``-compiled
        methods, so their values are captured when those methods are first
        traced; changing the attributes afterwards has no effect.
        """
        # A.shape = [n_examples, n_classes, n_params]
        self.A = tf.cast(A, tf.float32)
        self.rank = rank
        self.lr_G = lr_G
        self.eps = eps
        self.loss_frequency = loss_frequency
        self.rcond = rcond
        self.n_examples, self.n_classes, self.n_params = self.A.shape

        # TODO: More theoretically principled computation of the scaling initialization factor.
        # We would want each reconstructed PEF at initialization to have roughly unit Frobenius norm.
        #
        # TODO: Might need to update to handle orthonormal G.
        # scale_factor = tf.sqrt(self.rank * self.n_params / 2)
        # scale_factor = tf.sqrt(self.rank * self.n_examples / 2)
        scale_factor = 0.5 * self.rank

        # Parameters to learn. They are updated by hand-written rules, not by
        # an optimizer, hence trainable=False.
        self.W = tf.Variable(tf.random.uniform([self.n_examples, self.rank], dtype=tf.float32) / scale_factor,
                             trainable=False, name='W')
        self.G = tf.Variable(self._make_initial_value_G(),
                             trainable=False, name='G')

    def _make_initial_value_G(self):
        """Return a random ``[rank, n_params]`` matrix with orthonormal rows."""
        # gram_schmidt orthonormalizes the columns of its input, so build the
        # factors as columns of an [n_params, rank] matrix and transpose.
        G = tf.random.normal([self.n_params, self.rank], dtype=tf.float32)
        G = tfp.math.gram_schmidt(G)
        return tf.transpose(G)

    def _compute_XH(self):
        """Return ``XH[i, k] = ||A_i g_k||^2 = g_k^T X_i g_k``, shape ``[n_examples, rank]``."""
        AG = tf.einsum('ick,jk->ijc', self.A, self.G)
        return tf.einsum('ijc,ijc->ij', AG, AG)

    def _compute_HH(self):
        """Return the elementwise-squared Gram matrix ``(G G^T)**2``, shape ``[rank, rank]``."""
        GG = tf.einsum('ij,kj->ik', self.G, self.G)
        return tf.square(GG)

    @tf.function
    def _W_update_step(self):
        """Apply the multiplicative update ``W <- W * XH / (W HH + eps)``.

        ``XH`` is the (negated, halved) data term and ``W HH`` the model term
        of the loss gradient with respect to ``W``; both are nonnegative, so a
        nonnegative ``W`` stays nonnegative.
        """
        XH = self._compute_XH()
        WHH = tf.einsum('ij,jk->ik', self.W, self._compute_HH())
        self.W.assign(self.W * XH / (WHH + self.eps))

    @tf.function
    def _compute_loss(self):
        """Return ``sum_i ||X_i - sum_k W[i, k] g_k g_k^T||_F^2`` as a scalar tensor.

        The squared norm is expanded as ``tr(XX) - 2 tr(WHX) + tr(WW HH)`` so
        no ``n_params x n_params`` matrix is ever formed.
        """
        WW = tf.einsum('ji,jk->ik', self.W, self.W)
        tr_WW_HH = tf.einsum('ij,ij->', WW, self._compute_HH())

        tr_WHX = tf.einsum('ij,ij->', self.W, self._compute_XH())

        # ||A_i^T A_i||_F == ||A_i A_i^T||_F; the latter is only n_classes^2.
        AAt = tf.einsum('icl,ikl->ick', self.A, self.A)
        tr_XX = tf.einsum('icl,icl->', AAt, AAt)

        return tr_XX - 2 * tr_WHX + tr_WW_HH

    @tf.function
    def _compute_G_loss_grad(self):
        """Return the Euclidean (ambient) gradient of the loss w.r.t. ``G``.

        Row ``k`` equals ``4 sum_j WW[k, j] GG[k, j] g_j
        - 4 sum_i W[i, k] A_i^T A_i g_k``, with ``WW = W^T W`` and
        ``GG = G G^T``. Shape ``[rank, n_params]``.
        """
        WW = tf.einsum('ji,jk->ik', self.W, self.W)
        GG = tf.einsum('ij,kj->ik', self.G, self.G)
        term1 = 4 * tf.einsum('ij,ij,jk->ik', WW, GG, self.G)

        GA = tf.einsum('il,jkl->ijk', self.G, self.A)
        term2 = -4 * tf.einsum('ji,ijk,jkl->il', self.W, GA, self.A)

        return term1 + term2

    @tf.function
    def _G_update_step(self):
        """Take one Cayley-transform step for ``G`` on the Stiefel manifold.

        With ``G`` and the Euclidean gradient ``D`` (both ``[rank, n_params]``),
        define in column convention ``U = [D^T, G^T]`` and ``V = [G^T, -D^T]``.
        The update is

            G <- G - tau * (U M)^T,   M = (I + tau/2 V^T U)^{-1} V^T G^T,

        computed with a ``2*rank x 2*rank`` linear solve, so no
        ``n_params x n_params`` matrix is formed.
        """
        tau = self.lr_G

        G = self.G
        D = self._compute_G_loss_grad()

        # Blocks of V^T U and V^T G^T.
        GD = tf.matmul(G, D, transpose_b=True)
        DD = -tf.matmul(D, D, transpose_b=True)
        GG = tf.matmul(G, G, transpose_b=True)
        DG = -tf.transpose(GD)

        VG = tf.concat([GG, DG], axis=0)  # V^T G^T, shape [2r, r]
        VU = tf.concat([                  # V^T U,   shape [2r, 2r]
            tf.concat([GD, GG], axis=1),
            tf.concat([DD, DG], axis=1),
        ], axis=0)

        msf = tf.eye(2 * self.rank) + tau / 2 * VU

        M = tf.linalg.solve(msf, VG)
        M1 = M[:self.rank]
        M2 = M[self.rank:]

        # (U M)^T = M1^T D + M2^T G
        grad2 = tf.matmul(M1, D, transpose_a=True) + tf.matmul(M2, G, transpose_a=True)

        self.G.assign_sub(tau * grad2)

    def test_orthogonality(self):
        """Return the largest |cosine| between two distinct rows of ``G``.

        A value near 0 means the rows are (close to) mutually orthogonal.
        Returns 0 when ``rank == 1`` since there are no distinct pairs.
        """
        norm_G = self.G / tf.sqrt(tf.reduce_sum(tf.square(self.G), axis=-1, keepdims=True))
        GG = tf.einsum('ij,kj->ik', norm_G, norm_G)
        # Zero the diagonal (self-similarity is always 1) before taking the max.
        off_diagonal = tf.abs(GG) * (1.0 - tf.eye(self.rank))
        return tf.reduce_max(off_diagonal)

    def fit(self, n_iters: int, *, update_G: bool = True, update_W: bool = True) -> list:
        """Run ``n_iters`` alternating update steps.

        NOTE: It is recommended to first run a few steps with
        ``update_W=False`` (updating G only) before alternating updates of
        both G and W. Otherwise the W-update step appears to mess up
        convergence, possibly by driving W to (close to) the zero matrix.

        Args:
            n_iters: Number of iterations.
            update_G: Whether to take a G step each iteration.
            update_W: Whether to take a W step each iteration.

        Returns:
            List of loss values (Python floats), one every
            ``loss_frequency`` steps; empty if ``loss_frequency <= 0``.
        """
        losses = []

        for step in tqdm(range(n_iters)):
            if update_G:
                self._G_update_step()

            if update_W:
                self._W_update_step()

            if self.loss_frequency > 0 and (step + 1) % self.loss_frequency == 0:
                loss = float(self._compute_loss().numpy())
                losses.append(loss)
                # tqdm.write prints above the progress bar without corrupting it.
                tqdm.write(str(loss))

        return losses
