import collections
import time

import numpy as np
import scipy.sparse as sps

from src.kmeans_printer import ColumnPrinter, VoidPrinter
from src.kmeans_utilities import symbasis, comms_invperm, coo_perm_col

import warnings
# Silence scipy's MatrixRankWarning. NOTE(review): presumably emitted by
# sps.linalg.spsolve when the regularized Hessian is singular; the inner loop
# of KmeansCR already treats a NaN candidate cost as a rejected step (see the
# np.isnan check), so the warning is not needed. Warning filters are global,
# so this affects every module in the process, not just this one.
warnings.filterwarnings("ignore", category=sps.linalg.MatrixRankWarning)


def orthogonal_hess(r):
    """
    Compute Hessian of the orthogonality constraint.

    Returns an array of shape ``(r * (r + 1) // 2, r * r, r * r)`` with one
    slice per index pair ``i <= j`` (enumerated row by row over the upper
    triangle). Slice ``k`` for pair ``(i, j)`` equals
    ``c * (kron(E_ij, I_r) + kron(E_ji, I_r))``, where ``E_ij`` is the
    ``r x r`` matrix with a single one at ``(i, j)`` and ``c`` is ``1`` for
    ``i == j`` and ``sqrt(2)`` otherwise.
    """
    eye_r = np.eye(r)
    upper_pairs = [(i, j) for i in range(r) for j in range(i, r)]

    # scaled unit matrices c * E_ij, one per upper-triangular pair
    units = np.zeros((len(upper_pairs), r, r))
    for k, (i, j) in enumerate(upper_pairs):
        units[k, i, j] = 1. if i == j else 2. ** 0.5

    transposed = units.transpose(0, 2, 1)
    return np.kron(units, eye_r) + np.kron(transposed, eye_r)


class KmeansCR():
    """
    Constrained k-means solver based on regularized Newton steps.

    The iterate is a pair ``(V, Q)`` with ``V`` of shape ``(n, rank - 1)``
    and ``Q`` of shape ``(rank, rank)``. The objective is the sum of

    * a loss ``-||data.T @ V||_F^2`` and
    * a log-barrier ``-mu * sum(log(V @ Q[1:] + Q[0] / sqrt(n)))``.

    Each outer iteration estimates Lagrange multipliers by least squares,
    solves a ``beta``-regularized saddle-point system for a step, and
    retracts the step. The retraction rescales ``V`` to Frobenius norm
    ``sqrt(n_clusters - 1)`` and re-orthonormalizes ``Q`` by QR. The step
    is kept only if the objective does not increase.
    """

    # Attributes created by ``run`` and deleted by ``_cleanup`` when the run
    # ends. ``_sqrtK1`` is deliberately absent: it is set once in
    # ``__init__`` and must survive so that ``run`` can be called again.
    _RUN_ATTRS = (
        '_n_root',
        '_num_qconst', '_num_const',
        'v_size', 'r_size', 'q_size', '_D',
        'E', 'I_n', 'I_r', 'I_n_r1', 'I_d_r1',
        'perm_n_r1', '_symbasis', 'hOrth', 'XI', 'diag_const', 'diag_size',
        'J1q', 'J2q', 'J2v', 'J3v',
    )

    def __init__(
        self,
        n_clusters=None,
        max_iterations: int = 1000,
        min_gradient_norm: float = 1e-6,
        verbosity: int = 2,
        log_verbosity: int = 0
    ):
        """
        Parameters
        ----------
        n_clusters : int
            Number of clusters K. Required despite the ``None`` default. It
            fixes the Frobenius norm ``sqrt(K - 1)`` that the retraction
            imposes on ``V``.
        max_iterations : int
            Maximum number of outer iterations.
        min_gradient_norm : float
            Stop once the Lagrangian gradient norm drops below this value.
        verbosity : int
            ``1`` prints a start message and the stopping reason. ``>= 2``
            prints a per-iteration table and the stopping reason.
        log_verbosity : int
            When ``>= 1``, per-iteration values are recorded in the log.

        Raises
        ------
        TypeError
            If ``n_clusters`` is not given.
        """
        if n_clusters is None:
            raise TypeError("KmeansCR requires n_clusters to be specified")
        self._max_iterations = max_iterations
        self._min_gradient_norm = min_gradient_norm
        self._verbosity = verbosity
        self._log_verbosity = log_verbosity
        self._log = None
        self._sqrtK1 = (n_clusters - 1) ** 0.5

    def run(
        self,
        data,
        rank,
        beta=None,
        kappa=(2.0, 1.5),
        mu=1.0,
        initial_point=None,
        maxinner=None
    ):
        """
        Run the outer optimization loop from ``initial_point``.

        Parameters
        ----------
        data : array of shape (n, d)
            Data matrix. Its rows give ``n`` and its columns give ``d``.
        rank : int
            ``r``. ``V`` has ``r - 1`` columns and ``Q`` is ``r x r``.
        beta : float
            Initial regularization weight added to the Hessian diagonal.
            Required despite the ``None`` default.
        kappa : tuple of two floats
            ``beta`` is divided by ``kappa[0]`` after an accepted inner step
            and multiplied by ``kappa[1]`` after a rejected one.
        mu : float
            Weight of the log-barrier term.
        initial_point : tuple (V, Q)
            Starting point. Required despite the ``None`` default.
        maxinner : int, optional
            Maximum number of regularization attempts per outer iteration.
            Defaults to ``n // r``. At least one attempt is always made.

        Returns
        -------
        dict
            Keys ``"time"``, ``"log"``, ``"point"`` (final ``(V, Q)``),
            ``"cost"``, ``"gradient_norm"`` and ``"iterations"``.
            The loop stops on ``max_iterations``, on the gradient-norm
            threshold, or when the proposed point does not decrease the
            objective (the previous point is then returned).
        """
        # get Kmean parameters
        n, d = data.shape
        r = rank

        # store parameters for later use
        self._n = n
        self._n_root = n ** 0.5
        self._r = r
        self._d = d

        # number of constraints
        self._num_qconst = r * (r + 1) // 2
        self._num_const = int((d + 1) * (r - 1) + self._num_qconst + 1)

        # number of variables
        self.v_size = n * (r - 1)
        self.r_size = d * (r - 1)
        self.q_size = r * r

        # hyperparameters
        self._beta = beta
        self._kappa = kappa
        self._mu = mu

        # auxiliary matrices
        self._D = data.T
        row = np.arange(r - 1)
        col = row + 1
        self.E = sps.csr_array((np.ones(r - 1, dtype=np.int8), (row, col)),
                               shape=(r - 1, r))
        self.I_n = sps.eye_array(n)
        self.I_r = sps.eye_array(r)
        self.I_n_r1 = sps.eye_array(n * (r - 1))
        self.I_d_r1 = sps.eye_array(d * (r - 1))
        self.perm_n_r1 = comms_invperm(n, r - 1)
        self._symbasis = 2. * symbasis(self._r).T
        self.hOrth = orthogonal_hess(r)
        self.XI = (2. ** .5) * np.kron(data, np.eye(r - 1))
        self.diag_const = np.concatenate((np.ones(r ** 2),
                                          np.zeros(r + self._num_qconst + self.r_size)))
        self.diag_size = int(r ** 2 + r + self._num_qconst + self.r_size)

        # Jacobian building blocks (the parts that do not depend on V, Q)
        I_r1 = np.eye(r - 1)
        perm_r1_n = comms_invperm(r - 1, n)
        self.J1q = np.zeros((1, self.q_size))
        self.J2q = np.zeros((r - 1, self.q_size))
        self.J2v = np.kron(I_r1, np.ones((1, n)))[:, perm_r1_n]
        self.J3v = np.zeros((self._num_qconst, self.v_size))

        if maxinner is None:
            self._maxinner = int(n // r)
        else:
            self._maxinner = maxinner

        # V, Q variables
        X = initial_point

        ### Initializations ###
        start_time = time.perf_counter()

        # Number of outer (TR) iterations. The semantic is that `iteration`
        # counts the number of iterations fully executed so far.
        iteration = 0

        # Initialize solution and companion measures: f(x)
        self._loss, self._pen = self._cost(X[0], X[1])
        self._fx = self._loss + self._pen

        # Display:
        if self._verbosity == 1:
            print("Optimizing...")
        if self._verbosity >= 2:
            iteration_format_length = int(np.log10(self._max_iterations)) + 1
            column_printer = ColumnPrinter(
                columns=[
                    ("Iteration", f"{iteration_format_length}d"),
                    ("Inner", "5d"),
                    ("beta", "+.2e"),
                    ("Loss", "+.5e"),
                    ("Log-barrier", "+.5e"),
                    ("Gradient norm", ".5e")
                ]
            )
        else:
            column_printer = VoidPrinter()

        column_printer.print_header()

        self._initialize_log()
        while True:
            iteration += 1

            # **********************
            # ** Solve Subproblem **
            # **********************
            X_prop, loss_prop, pen_prop, norm_grad, numit = self._lagrangian_sub_problem(
                X)

            # Display (loss, barrier and gradient norm refer to the current X):
            column_printer.print_row(
                [iteration, numit, self._beta, self._loss, self._pen, norm_grad])

            self._add_log_entry(
                iteration=iteration,
                cost=self._fx,
                gradient_norm=norm_grad,
                num_inner=numit,
                beta=self._beta
            )

            # Check if the update had failed. Written as `not (<=)` so that a
            # NaN proposal also counts as a failure.
            model_decreased = (loss_prop + pen_prop) <= self._fx
            if not model_decreased:
                print("model did not decrease, stopped after "
                      f"{time.perf_counter() - start_time:.2f} seconds."
                      )
                return self._finish(start_time, X, norm_grad, iteration)

            X = X_prop
            self._loss = loss_prop
            self._pen = pen_prop
            self._fx = self._loss + self._pen

            # ** CHECK STOPPING criteria
            run_time = time.perf_counter() - start_time
            stopping_criterion = None
            if iteration >= self._max_iterations:
                stopping_criterion = (
                    "Terminated - max iterations reached after "
                    f"{run_time:.2f} seconds."
                )
            elif norm_grad < self._min_gradient_norm:
                stopping_criterion = (
                    f"Terminated - min grad norm reached after {iteration} "
                    f"iterations, {run_time:.2f} seconds."
                )

            if stopping_criterion:
                if self._verbosity >= 1:
                    print(stopping_criterion)
                    print("")
                break

        return self._finish(start_time, X, norm_grad, iteration)

    def _finish(self, start_time, X, norm_grad, iteration):
        """Build the result dictionary of ``run`` and release per-run data."""
        opt_res = {
            "time": time.perf_counter() - start_time,
            "log": self._log,
            "point": X,
            "cost": self._fx,
            "gradient_norm": norm_grad,
            "iterations": iteration
        }
        self._cleanup()

        return opt_res

    def _reconstruct_mat(self, v, q):
        """
        Reshape flat tangent vectors back to matrices.

        ``v`` is read row-major into ``(n, r - 1)``. ``q`` is read
        column-major into ``(r, r)``, matching the ``ravel('F')`` ordering
        used for the ``Q`` gradient.
        """
        v_mat = v.reshape(self._n, self._r - 1)
        q_mat = q.reshape(self._r, self._r, order='F')

        return v_mat, q_mat

    def _cost(self, V, Q):
        """
        Return ``(loss, barrier)`` at ``(V, Q)``.

        ``loss = -||D @ V||_F^2`` with ``D = data.T``. The barrier is
        ``-mu * sum(log(U))`` with ``U = V @ Q[1:] + Q[0] / sqrt(n)``.
        Negative entries of ``U`` yield NaN (the invalid-value warning is
        suppressed).
        """
        U = V @ Q[1:, :] + Q[0, :] / (self._n ** 0.5)

        F1 = -(self._D.dot(V) ** 2.).sum()
        with np.errstate(invalid='ignore'):
            F2 = -self._mu * np.log(U).sum()

        return F1, F2

    def _retraction(self, X, U):
        """
        Map ``X + U`` back onto the feasible set.

        ``V + dV`` is rescaled to Frobenius norm ``sqrt(n_clusters - 1)``.
        ``Q + dQ`` is replaced by the Q factor of its QR decomposition, with
        column signs chosen so that the matching R diagonal is non-negative.
        """
        # retract V
        V = X[0] + U[0]
        V = self._sqrtK1 * V / np.linalg.norm(V, 'fro')

        # retract Q
        # qr implementation from MANOPT
        Q = X[1] + U[1]
        q, r = np.linalg.qr(Q)

        # Compute signs or unit-modulus phase of entries of diagonal of r.
        s = np.diagonal(r).copy()
        s = np.where(s == 0, 1, s)
        s /= np.abs(s)

        s = np.expand_dims(s, axis=-1)
        Q = q * s.T

        return (V, Q)

    def _saddle_point_jacobian(self, V, Q):
        """
        Jacobian of the saddle point problem.

        Returns
        -------
        (QU1, VU1)
            Negated objective gradient w.r.t. ``V`` (row-major) and ``Q``
            (column-major).
        (J1v, J3q)
            The ``V``-dependent row of the constraint Jacobian (``2 V``, the
            derivative of ``||V||_F^2``) and the ``Q``-dependent
            orthogonality block.
        (U1, U2)
            ``1 / U`` and ``1 / U**2`` (the latter flattened column-major).
        """
        Qhat = Q[1:, :]

        U = V @ Qhat + Q[0, :] / self._n_root
        U1 = np.power(U, -1.)
        U2 = np.power(U1, 2.).ravel('F')

        QU1 = self._mu * (Q[1:, :] @ U1.T).ravel('F')
        QU1 += 2. * (self._D.T @ self._D.dot(V)).ravel()
        VU1 = (np.vstack((U1.sum(axis=0) / self._n_root, V.T @ U1))).ravel('F')
        VU1 *= self._mu

        J1v = 2. * V.ravel()
        J3q = self._symbasis @ np.kron(np.eye(self._r), Q.T)

        return (QU1, VU1), (J1v, J3q), (U1, U2)

    def _saddle_point_hessian(self, V, Q, Aux):
        """
        Sparse Hessian of the saddle point problem.

        ``Aux`` is the ``(U1, U2)`` pair from ``_saddle_point_jacobian``.
        Returns ``hV`` and ``hQ`` (CSC) for the V-V and Q-Q blocks, and
        ``hVQ`` as a dense V-Q cross block, all scaled by ``mu``.
        """
        Qhat = Q[1:, :]
        Vhat = np.pad(V, ((0, 0), (1, 0)), constant_values=1. / self._n_root)
        U1, U2 = Aux

        QI = coo_perm_col(sps.kron(Qhat.T, self.I_n), self.perm_n_r1)
        IV = sps.kron(self.I_r, Vhat)
        U2QI = (QI.T).multiply(U2).T
        U2IV = (IV.T).multiply(U2).T

        hV = self._mu * (QI.T @ U2QI)
        hQ = self._mu * (IV.T @ U2IV)
        hVQ = sps.kron(-U1, self.E) + (QI.T @ U2IV)
        hVQ *= self._mu

        return (hV.tocsc(), hQ.tocsc(), hVQ.todense())

    @staticmethod
    def _solve_multipliers(J, d):
        """
        Least-squares Lagrange multipliers.

        Returns ``lam`` minimizing ``||J @ lam - d||``, the residual
        ``G = J @ lam - d``, and ``||G||`` as a Python float.

        The norm is taken from ``G`` directly rather than from the residual
        output of ``np.linalg.lstsq``. That output is an empty array when
        ``J`` is rank deficient, which used to make ``.item()`` raise.
        """
        lam = np.linalg.lstsq(J, d)[0]
        G = J.dot(lam) - d
        return lam, G, float(np.linalg.norm(G))

    def _lagrangian_sub_problem(self, X):
        """
        Compute one regularized step from ``X = (V, Q)``.

        Multipliers are estimated by least squares. The constraint curvature
        is added to the barrier Hessian, and the saddle-point system is
        solved through a Schur complement. The system is re-solved with
        ``beta`` multiplied by ``kappa[1]`` until the retracted candidate's
        cost (plus ``1e-4`` times the step's inner product with ``-G``) is
        below the current cost. At most ``max(1, self._maxinner)`` attempts
        are made; on acceptance ``beta`` is divided by ``kappa[0]``.

        Returns
        -------
        (V, Q), loss, barrier, norm_grad, num_inner
            The last candidate (accepted or not), its two cost terms, the
            Lagrangian gradient norm at ``X``, and the number of attempts.
        """
        # compute gradient/Jacobian/Hessian
        g, Jacob, Aux = self._saddle_point_jacobian(X[0], X[1])
        hV, hQ, hVQ = self._saddle_point_hessian(X[0], X[1], Aux)
        d = np.concatenate((g[0], g[1]))
        J = np.block([[Jacob[0], self.J1q],
                      [self.J2v, self.J2q],
                      [self.J3v, Jacob[1]]]).T

        # solve for the Lagrangian multipliers; G holds the linear
        # coefficients and its norm is the reported gradient norm
        lam, G, norm_grad = self._solve_multipliers(J, d)

        # adding Hessian from the constraints (||V||_F^2 contributes 2 I)
        hOrth = np.einsum('i, ijk->jk', lam[-self._num_qconst:], self.hOrth)
        hV += lam[0] * 2. * self.I_n_r1
        hQ += sps.csr_array(hOrth)

        # prepare matrix blocks
        Gv = -G[:self.v_size]
        Gq = -G[self.v_size:(self.v_size + self.q_size)]
        Jv = J[:self.v_size]
        Jq = J[self.v_size:(self.v_size + self.q_size)]
        K12 = np.block([hVQ, Jv, self.XI])
        B = np.pad(Gq, (0, self._num_const))

        # blocks for computing the Schur complement
        K22 = sps.block_array([[hQ, Jq, None],
                               [Jq.T, None, None],
                               [None, None, self.I_d_r1]]).todense()

        # save the diagonal elements
        beta = self._beta
        hV_diag = hV.diagonal().copy()
        K22_diag = K22.diagonal().copy()
        # At least one attempt, so that the candidate returned below is
        # always defined (maxinner == 0 previously raised UnboundLocalError).
        num_attempts = max(1, int(self._maxinner))
        for j in range(num_attempts):
            # cubic regularization
            hV.setdiag(hV_diag + beta)
            np.fill_diagonal(K22, K22_diag + beta * self.diag_const)

            # compute the Schur complement
            S = K22 - K12.T @ sps.linalg.spsolve(hV, K12)

            # solve for the tangents
            Q_prop = np.linalg.solve(S, B - K12.T @ sps.linalg.spsolve(hV, Gv))
            V_prop = sps.linalg.spsolve(hV, Gv - K12 @ Q_prop)
            Q_prop = Q_prop[:self.q_size]

            # sufficient decrease
            gp = V_prop.dot(Gv) + Q_prop.dot(Gq)
            gp *= 1e-4

            # retract the tangents
            V_prop, Q_prop = self._reconstruct_mat(V_prop, Q_prop)
            V_prop, Q_prop = self._retraction(X, (V_prop, Q_prop))

            # check new solution; a NaN cost (e.g. singular solve or
            # infeasible barrier) counts as a rejection
            loss_prop, pen_prop = self._cost(V_prop, Q_prop)
            f_prop = loss_prop + pen_prop + gp
            if f_prop >= self._fx or np.isnan(f_prop):
                beta *= self._kappa[1]
            else:
                beta /= self._kappa[0]
                break

        self._beta = beta

        return (V_prop, Q_prop), loss_prop, pen_prop, norm_grad, j + 1

    def _initialize_log(self):
        """Create the run log; per-iteration lists only if log_verbosity >= 1."""
        self._log = {
            "optimizer": str(self),
            "stopping_criteria": {
                "max_iterations": self._max_iterations,
                "min_gradient_norm": self._min_gradient_norm
            },
            "iterations": collections.defaultdict(list)
            if self._log_verbosity >= 1
            else None,
        }

    def _add_log_entry(self, *, iteration, cost, **kwargs):
        """Append one iteration's values to the log (no-op if logging is off)."""
        if self._log_verbosity <= 0:
            return
        self._log["iterations"]["time"].append(time.perf_counter())
        self._log["iterations"]["iteration"].append(iteration)
        self._log["iterations"]["cost"].append(cost)
        for key, value in kwargs.items():
            self._log["iterations"][key].append(value)

    def _cleanup(self):
        """
        Delete the per-run attributes created by ``run``.

        ``_sqrtK1`` is kept because it is set in ``__init__``. Deleting it
        made a second ``run`` on the same instance fail in ``_retraction``.
        """
        for attr in self._RUN_ATTRS:
            delattr(self, attr)
