"""
# ncp: nonnegative tensor decomposition (CANDECOMP/PARAFAC) by block-coordinate update
#  min 0.5*||M - A_1 \\circ ... \\circ A_N||_F^2
#  subject to A_1>=0, ..., A_N>=0
#
# input: 
#       M: input nonnegative tensor
#       r: estimated rank (each A_i has r columns); should be the exact rank or a moderate overestimate
#       opts. (keyword arguments of ncp in this Python port)
#           tol: tolerance for relative change of function value, default: 1e-4
#           maxit: max number of iterations, default: 500
#           maxT: max running time in seconds of CPU time (time.process_time), default: 1e6
#           rw: control the extrapolation weight, default: 1
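#           A0: initial point in cell struct, default: Gaussian random
#           matrices (MATLAB version only: this Python port takes no A0
#           argument and initializes each factor with uniform random
#           entries from np.random.rand)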
# output:
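#       A: nonnegative ktensor (in this Python port: a list of N nonnegative
#          factor matrices, the n-th of shape (M.shape[n], r))
#       Out. (MATLAB version only; not returned by this Python port)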
#           iter: number of iterations
#           hist_obj: history of objective values
#           hist_rel: history of relative objective changes (row 1) and relative residuals (row 2)
#
# The original MATLAB implementation requires the MATLAB Tensor Toolbox from
# http://www.sandia.gov/~tgkolda/TensorToolbox/
# (this Python port uses NumPy and TensorLy instead)
#
# More information can be found at:
# http://www.caam.rice.edu/~optimization/bcu/
"""

from copy import deepcopy
import time
import warnings
import pandas as pd
import numpy as np
from numpy import linalg
from scipy import stats
from tensorly import unfolding_dot_khatri_rao
# from tensorly.kruskal_tensor import kruskal_to_tensor
from tensorly.cp_tensor import kruskal_to_tensor
import utils


def ncp(M, r, tol=1e-4, maxit=500, maxT=1e+6, rw=1, verbose=False):
    """Nonnegative CP decomposition by block-coordinate (prox-linear) update.

    Approximately solves

        min 0.5 * ||M - [[A_1, ..., A_N]]||_F^2   s.t.  A_1 >= 0, ..., A_N >= 0

    by sweeping over the factor matrices. Each block takes one projected
    gradient step from an extrapolated point. When a sweep does not decrease
    the objective, the next sweep starts from the last accepted iterate.

    Args:
        M: N-way nonnegative data tensor (``numpy.ndarray``).
        r: rank, i.e. the number of columns of every factor matrix.
        tol: iteration stops after 3 consecutive sweeps whose relative
            objective change is below ``tol``, or as soon as the relative
            residual ||M - model||_F / ||M||_F is below ``tol``.
        maxit: maximum number of sweeps.
        maxT: CPU-time budget in seconds (``time.process_time``). A warning
            is issued when it is exceeded.
        rw: scale of the cap on the per-block extrapolation weight.
        verbose: print basic information about ``M`` before iterating.

    Returns:
        List of N nonnegative factor matrices; the n-th has shape
        ``(M.shape[n], r)``. An all-zero ``M`` yields all-zero factors.
    """
    N = M.ndim  # M is an N-way tensor
    M_shape = M.shape
    M_norm = linalg.norm(M)  # Frobenius norm (2-norm of the flattened array)
    obj0 = .5 * M_norm ** 2  # objective value of the all-zero model

    if verbose:
        print('N=', N)
        print('M_shape=', M_shape)
        print('M_norm=', M_norm)
        print('obj0=', obj0)
        print()

    if M_norm == 0:
        # Zero factors fit exactly; the scaling and the step sizes below
        # would otherwise divide by zero.
        return [np.zeros((m, r)) for m in M_shape]

    # Random initial factors, each scaled to Frobenius norm M_norm**(1/N)
    # (multiplied, not divided) so that the product of the N factor norms
    # equals ||M||_F.
    A0, Asq = [], []
    for m in M_shape:
        A0m = np.random.rand(m, r)  # uniform on [0, 1): already nonnegative
        A0m *= M_norm ** (1 / N) / linalg.norm(A0m, ord='fro')
        A0.append(A0m)
        Asq.append(A0m.T @ A0m)  # cached Gram matrix A_n^T A_n

    # Iteration state:
    #   A: current iterates, A0: last accepted iterates, Am: extrapolated points
    #   L, L0: current and previous block Lipschitz constants
    #   obj, obj0: current and last accepted objective values
    Am = deepcopy(A0)
    A = deepcopy(A0)

    nstall = 0  # consecutive sweeps with relative objective change < tol
    t0 = 1  # used for the extrapolation weight update
    wA = np.ones((N, 1))  # extrapolation weights
    L0 = np.ones((N, 1))
    L = np.ones((N, 1))

    start_time = time.process_time()

    for _ in range(maxit):
        for n in range(N):
            # B^T B for B = Khatri-Rao product of every factor except A[n],
            # formed as the Hadamard product of the cached Gram matrices.
            Bsq = np.ones((r, r))
            for i in range(N):
                if i != n:
                    Bsq = Bsq * Asq[i]

            # The block gradient A[n] @ Bsq - MB is Lipschitz with constant
            # ||Bsq||_2 (spectral norm). The default Frobenius norm is only a
            # looser upper bound and would shorten every step.
            L0[n] = L[n]
            L[n] = linalg.norm(Bsq, ord=2)

            # MTTKRP of M with the current factors (blocks < n already updated).
            MB = unfolding_dot_khatri_rao(M, (None, A), n)

            # Projected gradient step from the extrapolated point.
            Gn = Am[n] @ Bsq - MB
            A[n] = np.maximum(0, Am[n] - Gn / L[n])
            Asq[n] = A[n].T @ A[n]

        # Objective from the last block's quantities:
        # ||A_N B^T||^2 = <A_N^T A_N, B^T B> and <M_(N), A_N B^T> = <A_N, MB>.
        obj = .5 * (
            np.sum(Asq[-1] * Bsq)
            - 2 * np.sum(A[-1] * MB)
            + M_norm ** 2
        )

        relerr1 = np.abs(obj - obj0) / (obj0 + 1)
        # Roundoff can make obj slightly negative near an exact fit.
        relerr2 = (2 * max(obj, 0)) ** .5 / M_norm

        # Stop after 3 consecutive stalled sweeps, or once the fit is within tol.
        nstall = nstall + 1 if relerr1 < tol else 0
        if nstall >= 3 or relerr2 < tol:
            break
        if time.process_time() - start_time > maxT:
            warnings.warn("Time over")
            break

        # correction and extrapolation
        t = (1 + np.sqrt(1 + 4 * t0 ** 2)) / 2
        if obj >= obj0:
            # No decrease: the next sweep starts from the last accepted point.
            Am = list(A0)
        else:
            w = (t0 - 1) / t
            for n in range(N):
                # NOTE(review): the BCU convergence analysis appears to cap
                # this weight by rw*sqrt(L0/L); the plain ratio is used here.
                wA[n] = np.minimum(w, rw * L0[n] / L[n])
                Am[n] = A[n] + wA[n] * (A[n] - A0[n])
            # Copy the list. `A0 = A` would alias the two, so the next sweep's
            # A[n] updates would also overwrite A0[n]. The extrapolation term
            # A[n] - A0[n] would then always be zero, and the restart above
            # would reuse the rejected iterate.
            A0 = list(A)
            t0 = t
            obj0 = obj

    return A  # nonnegative factor matrices


def predict(A, forecast_step):
    """Repeat a static two-mode reconstruction along a new trailing time axis.

    The slice ``A[0] @ diag(column means of A[-1]) @ A[1].T`` (shape
    ``(I, J)``) is copied ``forecast_step`` times, giving an array of shape
    ``(I, J, forecast_step)`` whose time slices are all identical.
    """
    mean_loadings = np.diag(A[-1].mean(axis=0))
    snapshot = A[0] @ mean_loadings @ A[1].T
    return np.repeat(snapshot[..., np.newaxis], forecast_step, axis=-1)


class NCP:
    """Forecaster built on a nonnegative CP decomposition (see ``ncp``).

    ``fit`` factorizes the data tensor. ``predict`` returns the same
    reconstructed matrix for every forecast step.
    """

    def __init__(self, rank=4):
        self.rank = rank  # number of CP components (columns per factor)

    def fit(self, data, t=0):
        """Factorize ``data`` and store the factor matrices in ``self.factors``.

        Args:
            data: tensor to factorize, or a ``pandas.DataFrame`` that is first
                converted with ``utils.list2tensor_from_index``. That
                conversion reads ``self.date_range`` and ``self.n_attributes``,
                which ``__init__`` does not set, so the caller must assign
                them before fitting a DataFrame.
            t: unused here (presumably kept for a shared ``fit`` signature).

        Returns:
            One-element array with the CPU time (seconds) spent in ``ncp``.
        """
        if isinstance(data, pd.DataFrame):
            # convert the data to a tensor
            data = utils.list2tensor_from_index(
                data, self.date_range, self.n_attributes)

        tic = time.process_time()
        self.factors = ncp(data, self.rank, tol=1e-6, maxit=50)
        toc = time.process_time() - tic
        return np.array((toc,))

    def predict(self, forecast_step):
        """Return an ``(I, J, forecast_step)`` forecast.

        Uses ``factors[0]`` (I, r), ``factors[1]`` (J, r) and the column means
        of the last factor; assumes a 3-way fit (TODO confirm for other
        orders). Every time slice equals
        ``factors[0] @ diag(mean(factors[-1], axis=0)) @ factors[1].T``.
        """
        time_factor = np.diag(self.factors[-1].mean(axis=0))
        pred = self.factors[0] @ time_factor @ self.factors[1].T
        pred = np.tile(pred.T, (forecast_step, 1, 1)).transpose()
        return pred

    def save(self, outpath):
        """Write each factor matrix to ``<outpath>/A<i>.txt`` as text."""
        for i, A in enumerate(self.factors):
            np.savetxt(outpath + f"/A{i}.txt", A)

    def set_params(self, args):
        """Return the model parameters as a dict; ``args`` is ignored."""
        return dict(n_components=self.rank)


class FOLD:
    """Seasonal forecaster: fold the time axis by ``period`` and factorize.

    The last axis of the data is treated as time. It is reshaped into
    (cycle, phase) axes and factorized with ``ncp``. The phase factor, scaled
    by the mean cycle loadings, is then tiled to produce forecasts.
    """

    def __init__(self, period, rank=4):
        self.period = period  # number of time steps per cycle
        self.rank = rank  # number of CP components

    def fit(self, data, t=0):
        """Fold ``data`` along its last axis, factorize it, and store factors.

        Args:
            data: tensor whose last axis is time, or a ``pandas.DataFrame``
                converted with ``utils.list2tensor_from_index``. That
                conversion reads ``self.date_range`` and ``self.n_attributes``,
                which ``__init__`` does not set.
            t: unused here (presumably kept for a shared ``fit`` signature).

        Returns:
            One-element array with the CPU time (seconds) spent folding and
            factorizing.

        Raises:
            ValueError: if the last axis is shorter than one period.
        """
        if isinstance(data, pd.DataFrame):
            # convert the data to a tensor
            data = utils.list2tensor_from_index(
                data, self.date_range, self.n_attributes)

        tic = time.process_time()

        # Fold: keep only the most recent n_cycle whole periods, so the
        # folded series ends on a period boundary.
        n_cycle = data.shape[-1] // self.period
        if n_cycle == 0:
            # Otherwise the slice below degenerates to data[..., 0:] and the
            # reshape fails with an opaque size-mismatch error.
            raise ValueError(
                f"FOLD needs at least one full period ({self.period} steps) "
                f"along the last axis, got {data.shape[-1]}")
        folded_shape = (*data.shape[:-1], n_cycle, self.period)
        folded = data[..., -self.period * n_cycle:].reshape(folded_shape)
        print("FOLD:", data.shape, "=>", folded.shape)

        # Factorize; the last two factors belong to the cycle and phase axes.
        factors = ncp(folded, self.rank, maxit=5)
        cycle, phase = factors[-2], factors[-1]
        # Absorb the average cycle loading into the phase factor and drop the
        # cycle factor, leaving one factor per axis of ``data``.
        self.factors = [*factors[:-2], phase @ np.diag(cycle.mean(axis=0))]

        toc = time.process_time() - tic
        return np.array((toc,))

    def predict(self, forecast_step):
        """Forecast ``forecast_step`` steps after the end of the fitted data.

        The fitted series ends on a period boundary, so the phase factor is
        repeated starting from phase 0. All non-time factors are combined
        with the tiled phase factor in a CP reconstruction; for 3-way data
        this is exactly ``[factors[0], factors[1], time_factor]``.
        """
        time_factor = np.tile(self.factors[-1],
                              (forecast_step // self.period + 1, 1))
        time_factor = time_factor[:forecast_step]
        factors = [*self.factors[:-1], time_factor]
        pred = kruskal_to_tensor((None, factors))
        return pred

    def save(self, outpath):
        """Write each factor matrix to ``<outpath>/A<i>.txt`` as text."""
        for i, A in enumerate(self.factors):
            np.savetxt(outpath + f"/A{i}.txt", A)

    def set_params(self, args):
        """Return the model parameters as a dict; ``args`` is ignored."""
        return dict(n_components=self.rank)

